In [None]:
import os
import json
import argparse
import numpy as np
import torch
import pandas as pd
import anndata
import scanpy as sc
import pickle

from STELLAR import STELLAR
from datasets import GraphDataset
from matplotlib import pyplot as plt
from sklearn.metrics import pairwise_distances

In [None]:
# Append the parent directory (relative to the working directory) to the import search path.
# NOTE(review): the first cell already imports STELLAR and datasets before this runs;
# if those modules live in the parent directory, this path tweak has to execute before
# that import for a fresh-kernel "Run All" to succeed — confirm where they are located.
import sys
sys.path.append("../")

### Loading experiment parameters

In [None]:
# Directory that holds config.json for this run; the notebook also writes its outputs here.
# Set the STELLAR_EXPERIMENT_DIR environment variable to run on another machine without
# editing the notebook. When it is unset or empty, the original location is used.
EXPERIMENT_DIR = (
    os.environ.get("STELLAR_EXPERIMENT_DIR")
    or "/home/dani/Documents/Thesis/Methods/IMCBenchmark/output/stellar_seurat"
)
CONFIG_PATH = os.path.join(EXPERIMENT_DIR, 'config.json')

# Load the run parameters from the JSON file.
with open(CONFIG_PATH) as f:
    config = json.load(f)

In [None]:
# Collect the run settings on an argparse.Namespace (passed to STELLAR(args, dataset) later on).
parser = argparse.ArgumentParser(description='STELLAR')

# args=[] parses an empty argument list instead of the kernel's own command line.
args = parser.parse_args(args=[])

# Values copied one-to-one from config.json; each key is assigned exactly once.
args.train_dataset = config['train_dataset']
args.val_dataset = config['val_dataset']
args.epochs = config['epochs']
args.lr = config['lr']
args.wd = config['wd']
args.sample_rate = config['sample_rate']
args.batch_size = config['batch_size']
args.distance_threshold = config['distance_threshold']
args.num_heads = config['num_heads']
args.seed = config['seed']
args.num_seed_class = config['num_seed_class']

# Pick the GPU when torch reports one, otherwise fall back to the CPU.
args.cuda = torch.cuda.is_available()
args.device = torch.device("cuda" if args.cuda else "cpu")

# Reuse the pickled graph from an earlier run when the cache file exists.
args.use_processed_graph = True

In [None]:
# Show the assembled settings namespace as the cell output.
args

### Preparing dataset

In [None]:
def create_labels_dict(train_df, val_df):
    """Build the label <-> integer id lookup tables shared by both tables.

    The distinct values of the ``cell_type`` column of ``train_df`` and
    ``val_df`` are pooled, sorted with ``np.sort`` and numbered from 0 in
    that order.

    Args:
        train_df: table with a ``cell_type`` column.
        val_df: table with a ``cell_type`` column.

    Returns:
        A pair ``(cell_type_dict, inverse_dict)``: the first maps each label
        to its integer id, the second maps each id back to its label.
    """
    pooled_labels = set(train_df['cell_type']) | set(val_df['cell_type'])
    ordered_labels = np.sort(list(pooled_labels)).tolist()

    cell_type_dict = {name: idx for idx, name in enumerate(ordered_labels)}
    inverse_dict = {idx: name for idx, name in enumerate(ordered_labels)}
    return cell_type_dict, inverse_dict

In [None]:
# Read both annotated tables and drop the rows whose cell type is 'unlabeled'.
train_df = pd.read_csv(args.train_dataset)
train_df = train_df.loc[train_df['cell_type'].ne('unlabeled')].reset_index(drop=True)

val_df = pd.read_csv(args.val_dataset)
val_df = val_df.loc[val_df['cell_type'].ne('unlabeled')].reset_index(drop=True)

# One label <-> id mapping that covers the labels of both tables.
cell_type_dict, inverse_dict = create_labels_dict(train_df, val_df)

# Only the training labels become integer ids; val_df keeps its original labels.
train_df = train_df.assign(cell_type=train_df['cell_type'].map(cell_type_dict))

In [None]:
# Preview the first training rows; cell_type holds the mapped integer ids at this point.
train_df.head()

#### Compute graph between cells

In [None]:
def get_own_edge_index(pos, distance_threshold):
    """List every ordered pair of distinct points lying closer than a threshold.

    Args:
        pos: coordinates accepted by ``sklearn.metrics.pairwise_distances``
            (``prepare_graph`` passes the ``x``/``y`` columns, one row per cell).
        distance_threshold: strict upper bound on the pairwise distance.

    Returns:
        A list of ``[i, j]`` row-index pairs with ``i != j``. Pairs are
        ordered, so ``[i, j]`` and ``[j, i]`` are separate entries.
    """
    # Boolean matrix: True where two points are closer than the threshold.
    is_close = pairwise_distances(pos) < distance_threshold
    # A point is never linked to itself.
    np.fill_diagonal(is_close, False)
    # Row/column indices of the remaining True entries, one pair per row.
    return np.argwhere(is_close).tolist()


def prepare_graph(train_df, test_df, distance_threshold, sample_rate,
                  feature_start_col=9, random_state=1):
    """Sample the labeled table and build features, labels and edge lists.

    Args:
        train_df: labeled table with ``cell_type``, ``x`` and ``y`` columns.
        test_df: unlabeled table with ``x`` and ``y`` columns.
        distance_threshold: two cells of the same table are connected when
            their ``x``/``y`` distance is strictly below this value.
        sample_rate: fraction of ``train_df`` rows to keep; the number of
            sampled rows is ``round(sample_rate * len(train_df))``.
        feature_start_col: positional index of the first feature column; all
            columns from this position onward are used as features.
            Defaults to 9, the previously hard-coded offset.
            NOTE(review): presumably the columns before it hold metadata —
            confirm against the CSV layout.
        random_state: seed passed to ``DataFrame.sample``. Defaults to 1, the
            previously hard-coded value.

    Returns:
        ``(train_X, train_y, test_X, labeled_edges, unlabeled_edges)``.
        Edge indices refer to row positions in the sampled ``train_df`` and
        in ``test_df`` respectively.
    """
    n_sampled = round(sample_rate * len(train_df))
    train_df = train_df.sample(n=n_sampled, random_state=random_state)

    train_X = train_df.iloc[:, feature_start_col:].values
    test_X = test_df.iloc[:, feature_start_col:].values
    train_y = train_df['cell_type'].values

    # Positions are taken after sampling so edge indices line up with train_X rows.
    labeled_pos = train_df[['x', 'y']].values
    unlabeled_pos = test_df[['x', 'y']].values
    labeled_edges = get_own_edge_index(labeled_pos, distance_threshold)
    unlabeled_edges = get_own_edge_index(unlabeled_pos, distance_threshold)

    return train_X, train_y, test_X, labeled_edges, unlabeled_edges

In [None]:
# Cache file for the tuple returned by prepare_graph.
PROCESSED_GRAPH_FILE = os.path.join(EXPERIMENT_DIR, "dataset_preprocessed.pkl")

if args.use_processed_graph and os.path.exists(PROCESSED_GRAPH_FILE):
    # NOTE: the cache is reused whenever the file exists. It is not invalidated when
    # distance_threshold, sample_rate or the input CSVs change, so delete the file
    # (or set args.use_processed_graph = False) after changing any of them.
    # NOTE(security): pickle.load can execute arbitrary code; only load cache files
    # that this notebook wrote itself.
    with open(PROCESSED_GRAPH_FILE, "rb") as file:
        packed_graph = pickle.load(file)
else:
    packed_graph = prepare_graph(train_df, val_df, args.distance_threshold, args.sample_rate)

    # Save to .pkl so later runs can skip the graph construction.
    with open(PROCESSED_GRAPH_FILE, 'wb') as file:
        pickle.dump(packed_graph, file)

labeled_X, labeled_y, unlabeled_X, labeled_edges, unlabeled_edges = packed_graph
dataset = GraphDataset(labeled_X, labeled_y, unlabeled_X, labeled_edges, unlabeled_edges)

### Training

In [None]:
# Construct STELLAR from the run settings and the graph dataset, then call its train() method.
stellar = STELLAR(args, dataset)
stellar.train()

### Validation

In [None]:
# stellar.pred() returns four values; the first one is not used in this cell.
_, pred_prob, pred_prob_list, pred_labels = stellar.pred()

# Object dtype lets an entry hold either an integer id or a label string.
pred_labels = pred_labels.astype('object')

# Translate ids that belong to a known cell type; any other id is left unchanged.
for idx, raw_label in enumerate(pred_labels):
    pred_labels[idx] = inverse_dict.get(raw_label, raw_label)

# Identifier and ground-truth columns of the validation table, under their output names.
results_df = val_df[['sample_id', 'object_id', 'cell_type']].rename(
    columns={
        'sample_id': 'image_id',
        'object_id': 'cell_id',
        'cell_type': 'label',
    }
)
results_df['pred'] = pred_labels.tolist()
results_df['pred_prob'] = pred_prob.tolist()
results_df['prob_list'] = pred_prob_list.tolist()

results_path = os.path.join(EXPERIMENT_DIR, 'stellar_results.csv')
results_df.to_csv(results_path, index=False)

results_df.head()

### Visualizing results

Cells are colored according to predicted cell types. Novel classes are denoted with numbers.

In [None]:
# Wrap the unlabeled feature matrix in an AnnData object and attach the predictions.
adata = anndata.AnnData(unlabeled_X)
adata.obs['pred'] = pd.Categorical(pred_labels)

# Neighbourhood graph first, then the UMAP embedding.
sc.pp.neighbors(adata)
sc.tl.umap(adata)

# return_fig=True hands back the figure so it can be written to disk below.
fig = sc.pl.umap(adata, color=['pred'], size=5, return_fig=True)

umap_pdf_path = os.path.join(EXPERIMENT_DIR, 'UMAP_predictions.pdf')
fig.savefig(umap_pdf_path, format="pdf", bbox_inches="tight")