# Imports and configuration

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config IPCompleter.use_jedi = False

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from pathlib import Path
from IPython.display import display
import pickle

import os
import warnings
# NOTE(review): this silences *every* warning, including pandas indexing/alignment
# warnings that can point at real data problems; consider narrowing the filter.
warnings.filterwarnings("ignore")
%config InlineBackend.figure_format = 'png'
# Exported figures: embed TrueType fonts in PDFs and keep SVG text as editable text.
plt.rcParams['pdf.fonttype'] = 'truetype'
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['figure.dpi'] = 120
sns.set_style('white')

In [2]:
from functions.utils import *
from functions.clustering import *

# Data

In [3]:
# Expression table with genes as rows and samples as columns; point this at your own file.
EXPRESSION_PATH = '/uftp2/Datasets/TCGA/current_version/data/projects/SKCM/expressions.tsv.gz'
# Number of samples (columns) to keep for a quick demo run; set to None to use all of them.
N_SAMPLES = 10

# iloc[:, :None] selects every column, so N_SAMPLES = None needs no special case.
expression = read_dataset(EXPRESSION_PATH).iloc[:, :N_SAMPLES]
expression.head()

Unnamed: 0_level_0,TCGA-FW-A3I3-06,TCGA-FS-A1ZD-06,TCGA-EE-A2M6-06,TCGA-WE-A8K4-01,TCGA-FS-A4FD-06,TCGA-EE-A3AD-06,TCGA-GN-A4U7-06,TCGA-ER-A19J-06,TCGA-D3-A5GU-06,TCGA-W3-AA1V-06
Gene,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
A1BG,0.0,0.0,0.0,0.0,0.036089,0.0,0.0,0.0,0.0,0.06477
A1CF,0.025217,0.007251,0.004564,0.024397,0.006793,0.025941,0.013212,0.016306,0.030594,0.034704
A2M,73.571877,504.510696,525.628828,140.915524,406.081849,2926.982979,457.965691,1438.106707,225.667659,98.306524
A2ML1,24.33331,0.273186,0.158535,0.109855,0.324048,3.859605,0.408411,3.658979,0.435077,0.0641
A3GALT2,0.032785,0.158301,0.048754,0.08675,0.0,0.0,0.0,0.034769,0.0,0.074897


In [5]:
# Reference cohort: unscaled feature values of reference samples plus a 'Batch' column,
# which is later used to pick the nearest reference batch for every new sample.
raw_features = read_dataset('data/unscaled_features.tsv')

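# Calculate features

Score every sample with ssGSEA on the model's gene sets and with PROGENy pathway activities, then keep the feature columns in the order stored in `ie_dict['X']`.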

In [6]:
# NOTE: pickle.load can execute arbitrary code -- only open model files from a trusted source.
with open('model/IE_clusters.pickle', 'rb') as f:
    ie_dict = pickle.load(f)

# Gene sets shipped with the model; presumably gene names absent from `expression.index`
# are looked up under alternative names (see the matching log in the cell output).
gmt = gmt_genes_alt_names(ie_dict['gmt_dict'], expression.index, verbose=True)

# PROGENy coefficient table with its index turned back into a regular column.
prog_coeffs = ie_dict['progeny_coeffs'].reset_index()

Matched: 356
Trying to find new names for 1 genes in 19706 known
querying 1-1...done.
Finished.
1 input query terms found dup hits:
	[('TRBC1', 2)]
Pass "returnall=True" to return complete lists of duplicate or missing query terms.
1 genes were not converted


In [7]:
# `expression` is genes x samples, so it is transposed before ssGSEA scoring.
ssgsea_df = ssgsea_formula(expression.T, gmt)
# PROGENy activities, transposed -- presumably to samples x pathways like ssgsea_df (confirm).
progen_df = run_progeny(expression, prog_coeffs=prog_coeffs).T
# Place both feature sets side by side and select/order the columns exactly as in ie_dict['X'].
features_df = pd.concat([ssgsea_df, progen_df], axis=1).loc[:, ie_dict['X'].columns]

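# Find the reference batch, scale, and classify

For each sample, the reference batch is chosen by a vote among its nearest reference samples (on unscaled features). The sample is then median-scaled together with that batch's samples and classified with the model loaded from `model/ovr_knn_calibrated.pickle`.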

In [9]:
# NOTE: pickle.load can execute arbitrary code -- only open model files from a trusted source.
# The classification cells call `model.predict_proba(...)` and read `model.classes_`;
# presumably a scikit-learn-style classifier -- confirm against how the pickle was produced.
with open('model/ovr_knn_calibrated.pickle', 'rb') as f:
    model = pickle.load(f)

In [4]:
def predict_batch(ser, X, batch, n_neighbors=20):
    """Guess which reference batch a sample belongs to by a nearest-neighbour vote.

    Euclidean distances are computed between ``ser`` and every row of ``X``.
    The ``n_neighbors`` closest rows vote with their label from ``batch``, and
    the most common label wins. Ties are broken in favour of the tied label
    whose member lies closest to ``ser``.

    Parameters
    ----------
    ser : pd.Series
        Feature values of the query sample, indexed by feature name.
    X : pd.DataFrame
        Reference samples (rows) x features (columns). Columns are aligned
        with ``ser.index`` by pandas.
    batch : pd.Series
        Batch label per reference sample. It must contain every label of ``X.index``.
    n_neighbors : int, default 20
        Number of nearest reference samples that take part in the vote.

    Returns
    -------
    The winning batch label.

    Raises
    ------
    ValueError
        If none of the neighbours has a non-missing batch label (e.g. ``X`` is empty).
    """
    # Row-wise Euclidean distance; pandas aligns X's columns with ser's index.
    distances = np.sqrt(((X - ser) ** 2).sum(axis=1))
    # nsmallest returns values sorted ascending, so the labels come out closest-first.
    nearest_batches = batch.loc[distances.nsmallest(n_neighbors).index]

    # value_counts drops missing labels, so this can be empty even when X is not.
    batch_counts = nearest_batches.value_counts()
    if batch_counts.empty:
        # Without this guard the tie-break below finds nothing and the result is unbound.
        raise ValueError('no labelled reference samples available to vote on a batch')
    top_batches = batch_counts.index[batch_counts == batch_counts.max()]

    # Tie-break: the first tied label in closest-first order, i.e. the one with the nearest member.
    return next(label for label in nearest_batches if label in top_batches)

In [14]:
from tqdm import tqdm

# Samples to classify, and the reference cohort split into features and batch labels.
sample_df = features_df
ref_df = raw_features
batch = ref_df['Batch']
ref_df = ref_df.drop(columns='Batch')

batch_pred = {}
probas = {}
# total= lets tqdm show a real progress bar; iterrows() alone has no length.
for name, ser in tqdm(sample_df.iterrows(), total=len(sample_df)):

    # Keep only the features this sample actually has (no NaNs, no stray 'Batch' entry).
    ser = ser.drop(labels='Batch', errors='ignore').dropna()
    # Reference samples restricted to the same features; rows with missing values are dropped.
    X_ref = ref_df[ser.index].dropna()

    ident_batch = predict_batch(ser, X_ref, batch)
    batch_pred[name] = ident_batch

    # X_ref can have fewer rows than `batch` after dropna(), so align the labels explicitly.
    # The implicit boolean-Series reindexing used before makes pandas emit a UserWarning
    # (hidden here by the global warnings filter).
    ref_coh = X_ref[batch.loc[X_ref.index] == ident_batch]

    # Median-scale the sample together with its reference cohort, then take the sample's row back.
    # NOTE(review): if `name` also occurs in ref_coh's index, .loc[name] yields several rows and
    # .to_frame() below fails -- confirm sample IDs never collide with reference IDs.
    scaled_features = median_scale(pd.concat([ref_coh, ser.to_frame().T]))
    scaled_features = scaled_features.loc[name]

    prob = model.predict_proba(scaled_features.to_frame().T)
    probas[name] = list(prob[0])

10it [00:01,  8.95it/s]


In [15]:
# A sample is 'Unclassified' unless at least one class probability exceeds this value.
PROBA_THRESHOLD = 0.47

batch_pred = pd.Series(batch_pred)
# Rows = model classes, columns = samples. Passing index= (instead of assigning .index later)
# keeps this cell re-runnable: on a re-run `probas` is already a DataFrame, and the constructor
# conforms it back to the class rows, dropping the previously added 'Unclassified' row.
probas = pd.DataFrame(probas, index=model.classes_)

# 1.0 for samples where no class clears the threshold, 0.0 otherwise.
probas.loc['Unclassified'] = (~(probas > PROBA_THRESHOLD).any()).astype(float)

# Per sample (column): the row with the highest value; 'Unclassified' wins only when it is 1.
class_predict = probas.idxmax()

class_predict.value_counts()

Lymphoid-Cell-Enriched    7
Unclassified              3
dtype: int64

In [16]:
# How many samples were matched to each reference batch.
batch_pred.value_counts()

Lung_Adenocarcinoma_GSE31210_batch_0                     9
Squamous_Cell_Carcinoma_of_the_Head_and_Neck_GSE40774    1
dtype: int64