# Spectra on BAL Samples

In [1]:
import scipy
import json
import pickle
import os
import shutil
import collections
import pickle
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from collections import OrderedDict
from spectra import spectra as spc

In [2]:
def get_factor_celltypes(adata, obs_key, cellscore_obsm_key = 'SPECTRA_cell_scores'):
    '''
    Assigns Spectra factors to cell types by analyzing the factor cell scores.
    Cell type specific factors will have zero cell scores except in their respective cell type.

    A factor whose mean cell score is non-zero in every cell type is labelled 'global'.
    Any other factor is labelled with a cell type in which its mean cell score is non-zero.
    If that holds for several (but not all) cell types, the cell type that sorts last wins.
    This makes the result reproducible across runs; the previous version iterated over an
    unordered set. Factors whose mean score is zero in every cell type get no entry.

    adata: AnnData , object containing the Spectra output
    obs_key: str , column name in adata.obs containing the cell type annotations
    cellscore_obsm_key: str , key for adata.obsm containing the Spectra cell scores

    returns: dict , dictionary of {factor index : 'cell type'}
    '''
    import pandas as pd

    # One row per cell, one column per factor, plus the cell type label
    cell_scores_df = pd.DataFrame(adata.obsm[cellscore_obsm_key])
    cell_scores_df['celltype'] = list(adata.obs[obs_key])

    # Mean score of every factor within each cell type (computed once).
    # groupby sorts the cell types, which fixes the iteration order below.
    group_means = cell_scores_df.groupby('celltype').mean()

    # Global factors: non-zero mean in every cell type
    is_global = (group_means != 0).all()
    specific_means = group_means.loc[:, ~is_global]

    factors_inv = {}
    # Cell type specific factors; a later (sorted) cell type overwrites an earlier one
    for celltype in group_means.index:
        for factor in specific_means.columns:
            if specific_means.loc[celltype, factor]:
                factors_inv[factor] = celltype

    # Add global
    for factor, global_flag in is_global.items():
        if global_flag:
            factors_inv[factor] = 'global'

    return factors_inv

In [3]:
def check_gene_set_dictionary(adata, 
                              annotations, 
                              obs_key = 'cell_type_annotations', 
                              global_key = 'global', 
                              return_dict = True):
    '''
    Filters the annotations dictionary so that it contains only genes present in the adata.
    Checks that the annotations dictionary cell type keys and the adata cell types are identical,
    and reports which labels differ when they are not.
    Checks that all gene sets in the annotations dictionary contain >2 genes after filtering.

    adata: AnnData , data to use with Spectra
    annotations: dict , gene set annotations dictionary to use with Spectra
        ({cell type: {gene set name: list of genes}}); it is not modified
    obs_key: str , column name for cell type annotations in adata.obs
    global_key: str , key for global gene sets in gene set annotation dictionary
    return_dict: bool , return filtered gene set annotation dictionary

    returns: dict , filtered gene set annotation dictionary (None if return_dict is False)

    '''
    # test if keys match
    adata_labels = set(adata.obs[obs_key]) | {global_key}  # cell type labels in adata object
    annotation_labels = set(annotations.keys())
    dict_keys_ok = annotation_labels == adata_labels
    if dict_keys_ok:
        print('Cell type labels in gene set annotation dictionary and AnnData object are identical')
    else:
        # Previously a mismatch passed silently; say exactly which labels differ
        missing = sorted(map(str, adata_labels - annotation_labels))
        extra = sorted(map(str, annotation_labels - adata_labels))
        if missing:
            print(f"Cell type labels missing from gene set annotation dictionary: {missing}")
        if extra:
            print(f"Gene set annotation dictionary keys not found in AnnData object: {extra}")

    # Build the gene lookup once rather than querying var_names for every gene
    genes_in_adata = set(adata.var_names)

    counter = 0
    annotations_new = {}
    for k, v in annotations.items():
        annotations_new[k] = {}
        for k2, v2 in v.items():
            # keep input order; drop genes not measured in adata
            v2 = [x for x in v2 if x in genes_in_adata]
            annotations_new[k][k2] = v2
            length = len(v2)
            if length < 3:
                print(f"gene set {k2} for cell type {k} is of length {length}")
                counter += 1

    if counter > 0:
        print(f"{counter} gene set(s) are too small. Gene sets must contain at least 3 genes")
    elif counter == 0 and dict_keys_ok:
        print('Your gene set annotation dictionary is correctly formatted.')
    if return_dict:
        return annotations_new

### Read in data

In [4]:
# Root directory that every input read and every output written below is relative to
data_dir = "/projects/b1038/Pulmonary/cpuritz/PASC/data"

In [5]:
# Integrated BAL object (v12) and the separate raw object.
# NOTE(review): adata_raw is not used anywhere else in this notebook — confirm it is needed.
integrated_path = f"{data_dir}/01BAL/01integrated_BAL_v12/01integrated_BAL_v12.h5ad"
raw_path = f"{data_dir}/01BAL/raw/adata_raw.h5ad"
adata = sc.read_h5ad(integrated_path)
adata_raw = sc.read_h5ad(raw_path)

In [6]:
# Spreadsheet of gene sets (columns used below: cell_type, gene_set_name, gene_set)
gene_set_xlsx = f"{data_dir}/01BAL/spectra/spectra-gene-sets.xlsx"
gene_sets = pd.read_excel(gene_set_xlsx)

In [7]:
# cytopus names cell types with hyphens, so rename the underscore/short labels to match
underscore_to_hyphen = {
    "B_memory": "B-memory",
    "CD4_T": "CD4-T",
    "CD8_T": "CD8-T",
    "pDC": "p-DC",
}
gene_sets["cell_type"] = gene_sets["cell_type"].replace(underscore_to_hyphen)

In [8]:
import ast

# Nest the spreadsheet rows as {cell type: {gene set name: [genes]}}.
# The gene_set column holds Python list literals stored as text. Parse them with
# ast.literal_eval instead of eval, so a malformed or malicious cell in the
# spreadsheet cannot execute arbitrary code.
input_gene_sets = collections.defaultdict(dict)
for cell_type, set_name, genes in zip(gene_sets.cell_type, gene_sets.gene_set_name, gene_sets.gene_set):
    input_gene_sets[cell_type][set_name] = ast.literal_eval(genes)

In [9]:
# Display the cell type keys parsed from the gene set spreadsheet
input_gene_sets.keys()

dict_keys(['all-cells', 'B-memory', 'CD4-T', 'CD8-T', 'DC', 'MDC', 'p-DC', 'leukocyte', 'mast', 'TFH', 'Treg', 'Mac', 'CD8-T_KLRG1pos-effector', 'TSCM', 'endo-aerocyte', 'B-memory-switched', 'gran', 'Langerhans', 'TRM', 'lung-endo-venous', 'capillary', 'gdT', 'plasma', 'B-pb-mature', 'plasma-blast', 'B', 'TCM', 'T-naive', 'TEM', 'B-memory-DN', 'mono', 'ILC3-NCRpos', 'CD56dim-NK', 'cDC1', 'T', 'mo-DC', 'B-pb-t2', 'B-memory-IgM-MZ', 'cDC2', 'Lti', 'FDC', 'cDC3', 'CD8-T_KLRG1neg-effector', 'CD8-T-progenitor-exhausted', 'NK', 'ILC2', 'ILC1', 'CD56bright-NK', 'B-memory-non-switched', 'B-pb-t1', 'NK-adaptive', 'B-naive', 'endo-systemic-venous', 'GC-B', 'ILC3-NCRneg'])

In [10]:
# Our own cell type annotations, sorted alphabetically for display
np.sort(adata.obs["cell_type"].cat.categories)

array(['B cells', 'CD4 T cells-1', 'CD4 T cells-2', 'CD8 T cells-1',
       'CD8 T cells-2', 'CD8 T cells-3', 'DC1', 'DC2', 'Epithelial cells',
       'Mast cells', 'Migratory DC', 'MoAM-1', 'MoAM-2', 'MoAM-3',
       'MoAM-4', 'Monocytes-1', 'Monocytes-2', 'Perivascular macrophages',
       'Plasma cells', 'Proliferating T cells',
       'Proliferating macrophages', 'SARS-CoV-2', 'TRAM-1', 'TRAM-2',
       'TRAM-3', 'TRAM-4', 'TRAM-5', 'TRAM-6', 'TRAM-7', 'Tregs',
       'gdT cells and NK cells', 'pDC'], dtype=object)

### Map our cell types to Spectra's coarser labels

In [11]:
# Spectra/cytopus label for each of our fine-grained annotations, written as
# (coarse label, [fine labels]) groups. The flattened dict keeps this listing order.
_coarse_groups = [
    ("Mac", ["TRAM-1", "TRAM-2", "TRAM-3", "TRAM-4", "TRAM-5", "TRAM-6", "TRAM-7",
             "Proliferating macrophages"]),
    ("MDC", ["MoAM-1", "MoAM-2", "MoAM-3", "MoAM-4", "Perivascular macrophages"]),
    ("mono", ["Monocytes-1", "Monocytes-2"]),
    ("CD4-T", ["CD4 T cells-1", "CD4 T cells-2"]),
    ("CD8-T", ["CD8 T cells-1", "CD8 T cells-2", "CD8 T cells-3"]),
    ("Treg", ["Tregs"]),
    ("NK", ["gdT cells and NK cells"]),
    ("T", ["Proliferating T cells"]),
    ("cDC1", ["DC1"]),
    ("cDC2", ["DC2"]),
    ("cDC3", ["Migratory DC"]),
    ("p-DC", ["pDC"]),
    ("mast", ["Mast cells"]),
    ("B", ["B cells"]),
    ("plasma", ["Plasma cells"]),
]
cell_type_map = {fine: coarse for coarse, fine_labels in _coarse_groups for fine in fine_labels}

In [12]:
# Check both directions: our labels with no Spectra mapping, then map keys that match none of our labels
annotated_types = adata.obs["cell_type"].cat.categories
print([label for label in annotated_types if label not in cell_type_map])
print([key for key in cell_type_map if key not in annotated_types])

['Epithelial cells', 'SARS-CoV-2']
[]


### Filter gene sets

In [13]:
# Relabel cells with their Spectra cell type; labels without a mapping (e.g. 'Epithelial cells')
# are kept as-is. This builds the categorical explicitly instead of calling Series.replace on a
# categorical to merge categories, which pandas >= 2.1 deprecates. Categories keep the original
# order, with merged labels taking the position of their first source category, which is what
# replace() produced.
fine_labels = adata.obs["cell_type"]
coarse_categories = list(dict.fromkeys(cell_type_map.get(c, c) for c in fine_labels.cat.categories))
adata.obs["cell_type_spectra"] = pd.Categorical(
    [cell_type_map.get(c, c) for c in fine_labels],
    categories=coarse_categories,
    ordered=fine_labels.cat.ordered,
)

In [14]:
# Keep only gene sets for cell types present in the data, plus the "all-cells" (global) sets
present_cell_types = [*adata.obs["cell_type_spectra"].unique().tolist(), "all-cells"]
input_gene_sets = {
    cell_type: sets
    for cell_type, sets in input_gene_sets.items()
    if cell_type in present_cell_types
}

In [15]:
# Cell types present in the data that have no gene sets still need an (empty) entry
missing_cell_types = set(present_cell_types) - set(input_gene_sets)
for missing_ct in missing_cell_types:
    input_gene_sets[missing_ct] = {}

In [16]:
# The gene set dictionary must use the key "global" for sets that apply to all cells
all_cell_sets = input_gene_sets.pop("all-cells")
input_gene_sets["global"] = all_cell_sets

In [17]:
# Keep only genes measured in adata and drop gene sets left with fewer than min_genes genes.
# Genes keep their original order (duplicates removed). The previous set() intersection
# gave a hash-seed-dependent gene order, so the pickled Spectra inputs were not
# reproducible between runs.
min_genes = 3
measured_genes = set(adata.var_names)
n_excluded = n_kept = 0
for ct, sets in input_gene_sets.items():
    for name in list(sets):  # snapshot keys: entries may be deleted below
        kept_genes = [g for g in dict.fromkeys(sets[name]) if g in measured_genes]
        if len(kept_genes) < min_genes:
            del sets[name]
            n_excluded += 1
        else:
            sets[name] = kept_genes
            n_kept += 1
print(f"{n_excluded} gene sets excluded due to < {min_genes} genes, {n_kept} gene sets kept")

118 gene sets excluded due to < 3 genes, 68 gene sets kept


### Coarsen cell type labels
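We merge all macrophages (TRAMs, MoAMs, and perivascular and proliferating macrophages), monocytes, and DC2 cells into one coarse cell type ("Mac"). Likewise, all T and NK cells (CD4, CD8, Tregs, gdT/NK, and proliferating T cells) are merged into "T".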

In [18]:
# Fold MDC, mono and cDC2 into "Mac": pool their gene sets under "Mac" and relabel the cells.
to_merge = ["Mac", "MDC", "mono", "cDC2"]
merge_into, *merge_from = to_merge
for ct in merge_from:
    incoming = input_gene_sets.pop(ct)
    # dict.update silently overwrites a same-named gene set already pooled under merge_into
    clashes = sorted(set(incoming) & set(input_gene_sets[merge_into]))
    if clashes:
        print(f"Warning: gene set(s) {clashes} from {ct} overwrite those already in {merge_into}")
    input_gene_sets[merge_into].update(incoming)
adata.obs.loc[adata.obs.cell_type_spectra.isin(merge_from), "cell_type_spectra"] = merge_into

In [19]:
# Fold NK, Treg, CD4-T and CD8-T into "T": pool their gene sets under "T" and relabel the cells.
to_merge = ["T", "NK", "Treg", "CD4-T", "CD8-T"]
merge_into, *merge_from = to_merge
for ct in merge_from:
    incoming = input_gene_sets.pop(ct)
    # dict.update silently overwrites a same-named gene set already pooled under merge_into
    clashes = sorted(set(incoming) & set(input_gene_sets[merge_into]))
    if clashes:
        print(f"Warning: gene set(s) {clashes} from {ct} overwrite those already in {merge_into}")
    input_gene_sets[merge_into].update(incoming)
adata.obs.loc[adata.obs.cell_type_spectra.isin(merge_from), "cell_type_spectra"] = merge_into

In [20]:
# Categories still include the merged-away labels until unused ones are dropped
adata.obs["cell_type_spectra"].cat.categories

Index(['Mac', 'CD4-T', 'CD8-T', 'MDC', 'mono', 'cDC2', 'NK', 'Treg', 'T', 'B',
       'cDC1', 'cDC3', 'p-DC', 'mast', 'plasma', 'Epithelial cells',
       'SARS-CoV-2'],
      dtype='object')

### Drop cell type categories that are no longer used after merging

In [21]:
# Remove categories that no cell carries any more (the labels merged into Mac/T)
spectra_labels = adata.obs["cell_type_spectra"]
adata.obs["cell_type_spectra"] = spectra_labels.cat.remove_unused_categories()

In [22]:
# Final set of Spectra cell type labels
adata.obs["cell_type_spectra"].cat.categories

Index(['Mac', 'T', 'B', 'cDC1', 'cDC3', 'p-DC', 'mast', 'plasma',
       'Epithelial cells', 'SARS-CoV-2'],
      dtype='object')

In [23]:
# Validate only (return_dict=False): labels must match and every set needs >= 3 genes
check_gene_set_dictionary(
    adata,
    input_gene_sets,
    obs_key="cell_type_spectra",
    return_dict=False,
)

Cell type labels in gene set annotation dictionary and AnnData object are identical
Your gene set annotation dictionary is correctly formatted.


### Allow one novel factor per cell type (including global), except for epithelial and SARS-CoV-2 cells

In [24]:
# Number of factors per key: one per gene set, plus n_novel unannotated factors,
# except for the cell types listed in no_novel.
n_novel = 1
no_novel = {"Epithelial cells", "SARS-CoV-2"}
L = {}
for ct, sets in input_gene_sets.items():
    L[ct] = len(sets) if ct in no_novel else len(sets) + n_novel
L

{'p-DC': 3,
 'mast': 2,
 'Mac': 5,
 'plasma': 2,
 'B': 2,
 'cDC1': 2,
 'T': 19,
 'cDC3': 2,
 'Epithelial cells': 0,
 'SARS-CoV-2': 0,
 'global': 40}

### Save model inputs (the model itself is trained separately on a GPU node)

In [25]:
# Write everything the separate training job needs: gene sets, factor counts and the AnnData
spectra_dir = f"{data_dir}/01BAL/spectra"
for file_name, payload in (("input_gene_sets.pkl", input_gene_sets), ("L.pkl", L)):
    with open(f"{spectra_dir}/{file_name}", "wb") as handle:
        pickle.dump(payload, handle)
adata.write_h5ad(f"{spectra_dir}/adata.h5ad")

In [31]:
### Model trained on a Quest GPU node. ###

In [5]:
# Read in the trained model, the AnnData it was trained on, and the gene sets used.
# These pickles are produced by this project's own pipeline; never unpickle untrusted files.
spectra_dir = f"{data_dir}/01BAL/spectra"
adata = sc.read_h5ad(f"{spectra_dir}/adata.h5ad")
with open(f"{spectra_dir}/model.pkl", "rb") as handle:
    model = pickle.load(handle)
with open(f"{spectra_dir}/input_gene_sets.pkl", "rb") as handle:
    input_gene_sets = pickle.load(handle)

### Add Spectra factor scores to the AnnData object

In [6]:
# {factor index: cell type or 'global'}, derived from where each factor's cell scores are non-zero
factor_celltypes = get_factor_celltypes(adata, obs_key="cell_type_spectra")

In [7]:
# Gene loadings: one row per gene in the Spectra vocabulary, one column per factor,
# with factor columns labelled "<cell type>_<factor index>".
n_factors = adata.uns['SPECTRA_factors'].shape[0]
factor_labels = [f"{factor_celltypes[k]}_{k}" for k in range(n_factors)]
vocab_genes = adata.var[adata.var['spectra_vocab']].index
gene_weights = pd.DataFrame(adata.uns['SPECTRA_factors'], index=factor_labels, columns=vocab_genes).T

In [8]:
# Per-cell factor scores indexed by cell, with columns labelled "<cell type>_<factor index>"
n_factors = adata.uns['SPECTRA_factors'].shape[0]
score_columns = [f"{factor_celltypes[k]}_{k}" for k in range(n_factors)]
cell_scores = pd.DataFrame(adata.obsm['SPECTRA_cell_scores'], index=adata.obs_names, columns=score_columns)

In [9]:
# Append one column per factor score to the cell metadata (both frames are indexed by obs_names)
adata.obs = pd.concat((adata.obs, cell_scores), axis="columns")

In [None]:
# Rename each factor score column after the input gene set the model matches it to.
# NOTE(review): "0" presumably marks a factor with no matching gene set — confirm against Spectra.
orig_factor_names = model.matching(adata.uns["SPECTRA_markers"], input_gene_sets)
orig_factor_names_map = {}
for i, orig_name in enumerate(orig_factor_names):
    curr_name = cell_scores.columns[i]
    ct = factor_celltypes[i]
    if orig_name == "0":
        new_name = f"F_{i}_{ct}"
    elif ct != "global":
        # Drop the cell type only when it is a leading prefix of the gene set name, so it is
        # not repeated. str.replace removed *every* occurrence, which mangled names for short
        # labels such as "T" or "B" (e.g. "CD8-T_..." became "CD8-_...").
        suffix = orig_name[len(ct):] if orig_name.startswith(ct) else orig_name
        new_name = f"F_{i}_{ct}_{suffix}"
    else:
        new_name = f"F_{i}_{orig_name}"
    orig_factor_names_map[curr_name] = new_name

with open(f"{data_dir}/01BAL/spectra/orig_factor_names_map.pkl", 'wb') as f:
    pickle.dump(orig_factor_names_map, f)
adata.obs.rename(columns = orig_factor_names_map, inplace = True)
list(adata.obs.columns)

## Save output

In [None]:
# Version tag and output location (relative to data_dir) for the annotated object.
# NOTE(review): inputs are read from "01BAL/" but outputs go to "01NEP/" — confirm this is intended.
v = "v12_spectra"
out_name = f"01integrated_BAL_{v}"
out_dir = f"01NEP/{out_name}"

In [None]:
# Write the annotated AnnData and its cell metadata. Create the versioned output
# directory first: it is new for this version, and writing fails if it does not exist.
out_path = f"{data_dir}/{out_dir}"
os.makedirs(out_path, exist_ok=True)
adata.write(f"{out_path}/{out_name}.h5ad")
adata.obs.to_csv(f"{out_path}/{out_name}-metadata.csv")