In [122]:
from sceptr import sceptr
import pandas as pd
from src.model import sceptr_unidirectional, load_trained
from pathlib import Path
import torch
import re
import warnings
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Folder holding the evaluation artefacts of the trained classifier.
dir = Path.cwd().joinpath("results", "sceptr", "eval", "trained-sceptr-caneval-2")

# The epoch is pinned by hand; the two commented-out lines below would instead
# read it from the name of an "eval-set-auc-<epoch>.csv" file in that folder.
#pattern = re.compile(r"eval-set-auc-(.*).csv")
#bestepoch = int(pattern.match(str(list(dir.glob("eval-set-auc-*.csv"))[0].name)).group(1))
bestepoch = 49

# Load the checkpoint that was saved for the chosen epoch.
model = load_trained(
    dir.joinpath(f"Epoch {bestepoch}", f"classifier-{bestepoch}.pth"),
    sceptr_unidirectional,
)

In [123]:
# Root folder of the evaluation tables.
eval_root = Path.cwd() / "data" / "sceptr-eval"

# Sorted so the file order (and everything derived from it) is the same on
# every run; glob order otherwise depends on the filesystem.
files = sorted(eval_root.glob("**/*.tsv"))

# Each entry is (table, label) with label 1 when "cancer" occurs in the path
# and 0 otherwise.  Only the part of the path below eval_root is inspected, so
# a working directory whose own name contains "cancer" cannot flip every label.
dfs = [
    (
        pd.read_csv(path, sep = "\t", dtype = object),
        int("cancer" in path.relative_to(eval_root).as_posix()),
    )
    for path in files
]

In [124]:
def alphafirst(df):
    """Return whether every column from position 3 onwards is missing in the first row of ``df``.

    NOTE(review): the name suggests columns 3+ are the beta-chain fields, so a
    true result would mean the table starts with alpha-chain rows -- confirm
    against the table schema.
    """
    first_row_missing = df.iloc[0].isna().to_numpy()
    return first_row_missing[3:].all()

def splitidx(df):
    """Return the index label of the first row whose ``CDR3A`` entry is missing.

    Raises ``IndexError`` when no row has a missing ``CDR3A``.
    """
    missing_cdr3a = df["CDR3A"].isna()
    return df.index[missing_cdr3a.to_numpy()][0]

# Records for the correctly classified repertoires, kept separately for control
# (label 0) and cancer (label 1) files.  Each record is
# (file path, label, {"alpha": [...], "beta": [...]}), where the lists hold
# (row position within that chain's section of the table, weight) pairs for the
# rows whose entry in model.last_weights is non-zero.
control_idx = []
cancer_idx = []

for i, item in tqdm(list(enumerate(dfs))):
    df, label = item
    vecs = torch.from_numpy(sceptr.calc_vector_representations(df))
    vecs = vecs.cuda() if torch.cuda.is_available() else vecs
    pred = model(vecs)

    # Only repertoires whose rounded prediction equals the label are analysed.
    if int(round(pred.item(), 0)) != label:
        continue

    # Read right after the forward pass.  Presumably one weight per input row,
    # laid out as a column -- TODO confirm against src.model.
    weights = model.last_weights
    nonzeros = torch.nonzero(weights)[:, 0].tolist()

    # Locate the row where the table switches from one chain to the other.
    # Assumes one chain's rows are stacked above the other's and that CDR3A is
    # missing on beta rows -- TODO confirm against the table schema.
    alpha_first = alphafirst(df)
    if alpha_first:
        # Alpha rows come first: the boundary is the first row lacking CDR3A.
        idx = splitidx(df)
    else:
        # Beta rows come first, so row 0 already lacks CDR3A and splitidx would
        # return 0, pushing every selected row into a single chain.  The
        # boundary is instead the first row that HAS a CDR3A.
        idx = df[~pd.isna(df["CDR3A"])].index[0]

    # Positions are re-based so each list indexes into its own chain's section.
    first = [(row, weights[row].item()) for row in nonzeros if row < idx]
    second = [(row - idx, weights[row].item()) for row in nonzeros if row >= idx]
    alpha, betas = (first, second) if alpha_first else (second, first)

    record = (files[i], label, {"alpha": alpha, "beta": betas})
    (cancer_idx if label == 1 else control_idx).append(record)

  0%|          | 0/19 [00:00<?, ?it/s]

In [125]:
import tidytcells as tt

def cleandf(df):
    """Keep alpha/beta-chain rows with a CDR3 and standardise their gene calls.

    A row is kept when ``junction_aa`` is present and both ``v_call`` and
    ``j_call`` are either missing or start with ``TRA`` / ``TRB``.  Every
    non-missing call that survives is then passed through
    ``tt.tr.standardise(..., enforce_functional=True, suppress_warnings=True)``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table providing ``v_call``, ``j_call`` and ``junction_aa`` columns.

    Returns
    -------
    pandas.DataFrame
        A filtered copy of ``df``; the caller's frame is not modified.
    """
    def alpha_beta_or_missing(calls):
        # na=False makes missing entries compare as "no match"; they are then
        # re-admitted explicitly by the isna() term.
        return (
            calls.str.startswith("TRA", na = False)
            | calls.str.startswith("TRB", na = False)
            | calls.isna()
        )

    def standardise_present(call):
        # Missing calls read by pandas are NaN, not None, so an ``is not None``
        # guard never fires; pd.isna covers both and leaves them untouched.
        if pd.isna(call):
            return call
        return tt.tr.standardise(call, enforce_functional = True, suppress_warnings = True)

    # We enforce the V call and J call to be from Alpha or Beta Chains First
    keep = (
        alpha_beta_or_missing(df["v_call"])
        & alpha_beta_or_missing(df["j_call"])
        & df["junction_aa"].notna()
    )
    df = df[keep].copy()

    # We then enforce them to be functional
    df["v_call"] = df["v_call"].apply(standardise_present)
    df["j_call"] = df["j_call"].apply(standardise_present)

    return df


# Lookup table with (at least) an "LTX_ID" column and a "filename" column.
seqfnames = pd.read_csv("rds_file_locations/tcrseq_pbmcfnames.csv")
# chains[<evaluation file name>][<"alpha" | "beta">] -> rows of the raw table
# picked out by the model, with their weight in a "prob" column, sorted by
# descending weight.
chains = {}
# Only the correctly classified cancer repertoires are inspected here.
nonzero_idx = cancer_idx[:]

for file, label, idx in tqdm(nonzero_idx):
    if label == 0:
        dir = Path.cwd() / "data" / "full-trimmed" / "control"
        pat = file.name.replace(".tsv", "")
        raws = list(dir.glob(f"*{pat}*"))
    else:
        patid = file.name.replace("_positive.tsv", "")
        # Single .loc lookup instead of chained boolean + column indexing.
        fname = seqfnames.loc[seqfnames["LTX_ID"] == patid, "filename"].tolist()
        fname = [i.replace(".gz", "") for i in fname]
        dir = Path.cwd() / "data" / "full-trimmed" / "pbmc_cancer"
        raws = [dir / i for i in fname if (dir / i).exists()]

    for raw in raws:
        chain = "alpha" if "alpha" in str(raw) else "beta"

        if raw.suffix == ".tsv":
            df = pd.read_csv(raw, delimiter = "\t")
            df = df[["v_call", "j_call", "junction_aa", "duplicate_count"]]
            df = df.dropna(axis=0, how="all")
            df = cleandf(df)
        else:
            # A multi-character delimiter is only supported by the Python
            # parser; naming the engine avoids relying on a silent fallback.
            df = pd.read_csv(raw, delimiter = ", ", index_col=None, header=None, engine = "python")
            df.columns = ["junction_aa", "duplicate_count"]

        # NOTE(review): the positions index into the table as filtered above;
        # this presumes the evaluation tables were built from identically
        # filtered data -- confirm.
        rows = [row for row, _ in idx[chain]]
        probs = [weight for _, weight in idx[chain]]

        # Build the result as a new frame rather than assigning a column and
        # sorting in place on an iloc slice (SettingWithCopy pattern).
        selected = (
            df.iloc[rows]
            .assign(prob = probs)
            .sort_values(by = ["prob"], ascending = False)
        )
        chains.setdefault(file.name, {})[chain] = selected


  0%|          | 0/9 [00:00<?, ?it/s]

In [136]:
# Stack the selected beta-chain rows of every file and write them to
# "top-ranking-tcrs.csv", ordered by descending "prob".
chain = "beta"
fnames = list(chains)
per_file_rows = [chains[name][chain] for name in fnames]
cdr3s = pd.concat(per_file_rows)
ranked = cdr3s.sort_values(by = ["prob"], ascending = False)
ranked.to_csv("top-ranking-tcrs.csv")

In [129]:
from collections import Counter

def visualise(chain):
    """Print gene usage and recurring CDR3s for one chain across all files.

    Reads the module-level ``chains`` mapping (file name -> chain name ->
    DataFrame with ``v_call``, ``j_call`` and ``junction_aa`` columns) and
    prints:

    * how often each V call and each J call occurs over all files;
    * for every CDR3 that occurs more than once, the files containing it
      together with the matching rows.

    Parameters
    ----------
    chain : str
        Key of the per-file table to summarise, e.g. ``"alpha"`` or ``"beta"``.
    """
    frames = {f: chains[f][chain] for f in chains}

    def pooled(column):
        # One flat list over all files (avoids the quadratic sum(lists, [])).
        return [value for frame in frames.values() for value in frame[column].tolist()]

    vs = pooled("v_call")
    js = pooled("j_call")
    cdr3s = pooled("junction_aa")

    print ("{:30}| {:7}".format("V Call", "Repeat"))
    for key, repeats in Counter(vs).items():
        print (f"{key:30}| {repeats:7}")

    print ("{:30}| {:7}".format("J Call", "Repeat"))
    for key, repeats in Counter(js).items():
        print (f"{key:30}| {repeats:7}")

    print ("")
    # most_common() orders by descending count and keeps first-seen order among
    # ties, so the loop can stop at the first CDR3 that occurs only once.
    for key, item in Counter(cdr3s).most_common():
        if item == 1:
            break

        print (f"CDR3: {key}")
        for f, frame in frames.items():
            # Exact comparison: the counts above are per exact sequence, whereas
            # str.contains would also match longer CDR3s that merely contain
            # the key (and would interpret the key as a regular expression).
            matches = frame["junction_aa"] == key
            if matches.any():
                print (f)
                print (frame.loc[matches])
        print ("")


In [130]:
# Print V/J call counts and the recurring CDR3s for the selected alpha-chain rows.
visualise("alpha")

V Call                        | Repeat 
TRAV19                        |      50
TRAV27                        |       1
TRAV10                        |       2
TRAV3                         |       2
TRAV8-6                       |       4
TRAV8-2                       |       1
TRAV13-2                      |       1
TRAV23/DV6                    |       2
J Call                        | Repeat 
TRAJ6                         |      29
TRAJ28                        |      15
TRAJ15                        |       2
TRAJ45                        |       3
TRAJ32                        |       1
TRAJ57                        |       1
TRAJ22                        |       1
TRAJ54                        |       3
TRAJ30                        |       3
TRAJ53                        |       1
TRAJ56                        |       3
TRAJ29                        |       1

CDR3: CALSEAAGAGSYQLTF
LTX0474_positive.tsv
       v_call  j_call       junction_aa  duplicate_count      prob
9587   T

In [131]:
# Print V/J call counts and the recurring CDR3s for the selected beta-chain rows.
visualise("beta")

V Call                        | Repeat 
TRBV2                         |     240
TRBV10-3                      |      46
J Call                        | Repeat 
TRBJ1-1                       |      50
TRBJ1-2                       |      37
TRBJ2-2                       |       4
TRBJ2-7                       |      89
TRBJ2-1                       |      72
TRBJ2-3                       |      17
TRBJ1-4                       |       1
TRBJ1-3                       |       2
TRBJ2-5                       |       3
TRBJ1-5                       |       6
TRBJ1-6                       |       1
TRBJ2-4                       |       4

CDR3: CAISEDDANEQYF
LTX0538_positive.tsv
        v_call   j_call    junction_aa  duplicate_count      prob
2759  TRBV10-3  TRBJ2-7  CAISEDDANEQYF               13  0.036713
LTX0648_positive.tsv
        v_call   j_call    junction_aa  duplicate_count      prob
4802  TRBV10-3  TRBJ2-7  CAISEDDANEQYF                9  0.043292
LTX0672_positive.tsv
        v_ca