In [1]:
from sceptr import sceptr
import pandas as pd
from src.model import sceptr_unidirectional, load_trained
from pathlib import Path
import torch
import re
import warnings
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

warnings.filterwarnings("ignore")

dir = Path.cwd() / "results" / "sceptr" / "eval" / "trained-sceptr-caneval-0"
pattern = re.compile(r"eval-set-auc-(.*).csv")
bestepoch = int(pattern.match(str(list(dir.glob("eval-set-auc-*.csv"))[0].name)).group(1))

model = dir / f"Epoch {bestepoch}" / f"classifier-{bestepoch}.pth"
model = load_trained(model, sceptr_unidirectional)

In [2]:
files = list((Path.cwd() / "data" / "sceptr-eval").glob("**/*.tsv"))
dfs = [(pd.read_csv(i, sep = "\t", dtype = object), int("cancer" in str(i))) for i in files]

In [3]:
control_idx = []
cancer_idx = []

for i, item in tqdm(list(enumerate(dfs))):
    df, label = item
    vecs = torch.from_numpy(sceptr.calc_vector_representations(df))
    vecs = vecs.cuda() if torch.cuda.is_available() else vecs
    pred = model(vecs)

    if int(round(pred.item(), 0)) == label:
        nonzeros = torch.nonzero(model.last_weights)[:, 0].tolist()

        if label == 1:
            cancer_idx.append((files[i], label, nonzeros))
        else:
            control_idx.append((files[i], label, nonzeros))

  0%|          | 0/19 [00:00<?, ?it/s]

In [4]:
import tidytcells as tt

def cleandf(df):
    # We enforce the V call and J call to be from Alpha or Beta Chains First
    enforce_abv = df["v_call"].str.startswith("TRA") | df["v_call"].str.startswith("TRB") | pd.isna(df["v_call"])
    enforce_abj = df["j_call"].str.startswith("TRA") | df["j_call"].str.startswith("TRB") | pd.isna(df["j_call"])
    enforce_cdr3_notempty = ~pd.isna(df["junction_aa"])
    df = df[enforce_abv & enforce_abj & enforce_cdr3_notempty].copy()
    # We then enforce them to be functional
    df["v_call"] = df["v_call"].apply(lambda i: tt.tr.standardise(i, enforce_functional = True, suppress_warnings = True) \
        if i is not None else i)

    df["j_call"] = df["j_call"].apply(lambda i: tt.tr.standardise(i, enforce_functional = True, suppress_warnings = True) \
        if i is not None else i)
    
    return df


seqfnames = pd.read_csv("rds_file_locations/tcrseq_pbmcfnames.csv")
chains = {}
nonzero_idx = control_idx[:]

for file, label, idx in tqdm(nonzero_idx):
    if label == 0:
        dir = Path.cwd() / "data" / "full-trimmed" / "control"
        pat = file.name.replace(".tsv", "")
        raws = list(dir.glob(f"*{pat}*"))
    else:
        patid = file.name.replace("_positive.tsv", "")
        fname = seqfnames[seqfnames["LTX_ID"] == patid]["filename"].tolist()
        fname = [i.replace(".gz", "") for i in fname]
        dir = Path.cwd() / "data" / "full-trimmed" / "pbmc_cancer"
        raws = [dir / i for i in fname if (dir / i).exists()]
    
    for raw in raws:
        if raw.suffix == ".tsv":
            df = pd.read_csv(raw, delimiter = "\t")
            df = df[["v_call", "j_call", "junction_aa", "duplicate_count"]]
            df = df.dropna(axis=0, how="all")
            df = cleandf(df)
        else:
            df = pd.read_csv(raw, delimiter = ", ", index_col=None, header=None)
            df.columns = ["junction_aa", "duplicate_count"]
            
        idx = [i for i in idx if i < len(df)]
        if file.name not in chains.keys():
            chains[file.name]["alpha" if "alpha" in str(raw) else "beta"] = {}
        chains[file.name]["alpha" if "alpha" in str(raw) else "beta"] = df.iloc[idx]

  0%|          | 0/9 [00:00<?, ?it/s]

KeyError: 'dcr_HCW_0125.tsv'