Evaluate DeepTCR. Note that DeepTCR does _not_ provide a pre-trained model, so its VAE is trained from scratch in this notebook. Run this notebook under the `deeptcr` conda environment (Python 3.8), which is NOT the same as our default environment, because DeepTCR requires TensorFlow instead of PyTorch.

TF/CUDA (10.1) compatibility table: https://www.tensorflow.org/install/source#gpu

References:
- Manuscript: https://www.nature.com/articles/s41467-021-21879-w#data-availability
- Code: https://github.com/sidhomj/DeepTCR
- Tutorial for unsupervised training of DeepTCR: https://github.com/sidhomj/DeepTCR/blob/master/tutorials/unsupervised/8%20-%20VAE%20Inference.ipynb

In [3]:
import os, sys
import collections
import logging
import itertools
from typing import *

import numpy as np 
import pandas as pd

# Locate the project's "tcr" source directory as a sibling of the current
# working directory, and make it importable. The sys.path update must happen
# before the local `canonical_models` import below.
SRC_DIR = os.path.join(os.path.dirname(os.getcwd()), "tcr")
assert os.path.isdir(SRC_DIR)
sys.path.append(SRC_DIR)
import canonical_models as models

# Output directory for the embeddings produced by this notebook. It must
# already exist; this cell does not create it.
RESULTS_DIR = os.path.join(
    os.path.dirname(SRC_DIR),
    "external_eval", "deeptcr_embeds",
)
assert os.path.isdir(RESULTS_DIR), f"Cannot find {RESULTS_DIR}"

# Restrict the CUDA devices visible to this process. This is set before
# DeepTCR is imported in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"]="2"

In [4]:
# Imported here rather than in the first cell so that the import happens after
# CUDA_VISIBLE_DEVICES has been set.
from DeepTCR.DeepTCR import DeepTCR_U

# 'murine_antigens' is the DeepTCR object name; the recorded inference output
# shows checkpoints being restored from murine_antigens/models/.
# NOTE(review): CUDA_VISIBLE_DEVICES is set to "2" in the setup cell, yet
# device=3 is requested here. Presumably only one GPU is visible to the
# process at this point -- confirm how DeepTCR interprets `device` and whether
# the intended GPU is actually used.
DTCRU = DeepTCR_U('murine_antigens', device=3)
DTCRU

<DeepTCR.DeepTCR.DeepTCR_U at 0x7f752bb1bee0>

In [5]:
# Load the murine antigen data shipped with the DeepTCR repository; this is
# the data the VAE is trained on in the next cell.
# NOTE(review): the directory is an absolute, user-specific path -- adjust it
# when running on another machine.
# Presumably aa_column_beta=0 / count_column=1 are the column indices of the
# beta-chain CDR3 amino-acid sequence and of the count in each input file, and
# aggregate_by_aa=True collapses rows sharing an amino-acid sequence -- verify
# against the DeepTCR documentation.
DTCRU.Get_Data(
    directory="/home/wukevin/downloads/DeepTCR/Data/Murine_Antigens",
    Load_Prev_Data=False, aggregate_by_aa=True, count_column=1,
    aa_column_beta=0,
)

Loading Data...
Embedding Sequences...
Data Loaded


In [6]:
# Train the DeepTCR VAE on the data loaded above. Load_Prev_Data=False is
# passed so that, presumably, training is re-run rather than loaded from a
# previous run -- verify against the DeepTCR documentation.
# NOTE(review): no random seed is set anywhere in this notebook, so the
# resulting embeddings may differ between runs -- confirm whether that matters.
DTCRU.Train_VAE(Load_Prev_Data=False)

Instructions for updating:
Use `tf.keras.layers.Conv2D` instead.
Instructions for updating:
Please use `layer.__call__` method instead.
Instructions for updating:
Use keras.layers.Flatten instead.
Instructions for updating:
Use keras.layers.dropout instead.
Instructions for updating:
Use keras.layers.Dense instead.
Instructions for updating:
Use `tf.keras.layers.Conv2DTranspose` instead.
Epoch = 0, Iteration = 0 Total Loss: 1.10987: Recon Loss: 1.03073: Latent Loss: 0.07913: Sparsity Loss: 0.00000: Recon Accuracy: 0.05407
Epoch = 1, Iteration = 0 Total Loss: 1.09609: Recon Loss: 1.01961: Latent Loss: 0.07648: Sparsity Loss: 0.00000: Recon Accuracy: 0.07234
Epoch = 2, Iteration = 0 Total Loss: 1.08236: Recon Loss: 1.00877: Latent Loss: 0.07358: Sparsity Loss: 0.00000: Recon Accuracy: 0.09324
Epoch = 3, Iteration = 0 Total Loss: 1.06518: Recon Loss: 0.99580: Latent Loss: 0.06938: Sparsity Loss: 0.00000: Recon Accuracy: 0.12457
Epoch = 4, Iteration = 0 Total Loss: 1.04377: Recon Loss: 0.9

## Load in LCMV data

In [7]:
###
### THIS IS COPIED FROM data_loaders.py DUE TO CROSS-ENVIRONMENT
### COMPATIBILITY ISSUES
###

def dedup(x: Iterable[Any]) -> List[Any]:
    """
    Return the distinct items of x as a list, keeping each item at the
    position of its first occurrence. Items must be hashable.

    >>> dedup([1, 2, 0, 1, 3, 2])
    [1, 2, 0, 3]
    >>> dedup(dedup([1, 2, 0, 1, 3, 2]))
    [1, 2, 0, 3]
    """
    # list(set(x)) would lose the order of occurrence, so track what has been
    # seen in a set while appending first occurrences to an ordered list.
    already_seen = set()
    ordered_unique = []
    for item in x:
        if item in already_seen:
            continue
        already_seen.add(item)
        ordered_unique.append(item)
    return ordered_unique

def load_lcmv_table(
    fname: str = "/home/wukevin/projects/tcr/tcr/data/lcmv_tetramer_tcr.txt",
    drop_na: bool = True,
    drop_unsorted: bool = True,
) -> pd.DataFrame:
    """
    Read the tab-delimited LCMV table and expand multi-chain entries.

    fname: path to the tab-delimited table; it must provide the columns
        "tetramer", "TRA" and "TRB"
    drop_na: discard rows missing any of "tetramer", "TRB" or "TRA"
    drop_unsorted: discard rows whose "tetramer" value is "Unsorted"

    A row whose TRA and/or TRB field holds several ";"-separated sequences is
    replaced by one row per (TRA, TRB) combination. Every other column is
    carried over unchanged, and each expanded row keeps the index label of the
    row it came from.
    """
    table = pd.read_csv(fname, delimiter="\t")
    if drop_na:
        table = table.dropna(axis=0, how="any", subset=["tetramer", "TRB", "TRA"])
    if drop_unsorted:
        unsorted_idx = table.index[table["tetramer"] == "Unsorted"]
        table = table.drop(index=unsorted_idx)

    # Take entries with multiple TRA or TRB sequences and split them, carrying
    # over all of the other metadata to each row
    expanded_rows = []
    for _, source_row in table.iterrows():
        alpha_chains = source_row["TRA"].split(";")
        beta_chains = source_row["TRB"].split(";")
        for alpha, beta in itertools.product(alpha_chains, beta_chains):
            expanded = source_row.copy(deep=True)
            expanded["TRA"] = alpha
            expanded["TRB"] = beta
            expanded_rows.append(expanded)

    return pd.DataFrame(expanded_rows, columns=table.columns)

def dedup_and_merge_labels(
    sequences: Sequence[str], labels: Sequence[str], sep: str = ","
) -> Tuple[List[str], List[str]]:
    """
    Remove duplicates in sequences and merge labels accordingly
    sep is the label separator, used to split and rejoin labels
    Return is sorted!

    Returns two lists of equal length: the sorted unique sequences, and for
    each of them its label. A sequence whose occurrences all carry the same
    label keeps that label verbatim; otherwise its labels are split on sep,
    de-duplicated, sorted and re-joined with sep.

    Raises IndexError if labels is shorter than sequences.

    >>> dedup_and_merge_labels(['a', 'b', 'a'], ['x', 'y', 'y'])
    (['a', 'b'], ['x,y', 'y'])
    >>> dedup_and_merge_labels(['a', 'b', 'a', 'a'], ['x', 'y', 'y,x', 'z'])
    (['a', 'b'], ['x,y,z', 'y'])
    >>> dedup_and_merge_labels(['a', 'b', 'd', 'c'], ['x', 'z', 'y', 'n'])
    (['a', 'b', 'c', 'd'], ['x', 'z', 'n', 'y'])
    """
    # unique returns the *sorted* unique elements of an array; inverse_idx[j]
    # is the position within uniq_sequences of sequences[j]
    uniq_sequences, inverse_idx = np.unique(sequences, return_inverse=True)

    # Bucket the labels by unique sequence in a single pass over the input,
    # instead of re-scanning inverse_idx once per unique sequence. Labels stay
    # in their original order of occurrence within each bucket.
    labels_per_seq: List[List[str]] = [[] for _ in range(len(uniq_sequences))]
    for orig_pos, uniq_pos in enumerate(np.ravel(inverse_idx)):
        labels_per_seq[uniq_pos].append(labels[orig_pos])

    uniq_labels, agg_count = [], 0
    for seq_labels in labels_per_seq:
        # dict.fromkeys de-duplicates while preserving order of occurrence
        match_labels = list(dict.fromkeys(seq_labels))
        if len(match_labels) == 1:
            uniq_labels.append(match_labels[0])
        else:  # Aggregate labels
            label_parts = dict.fromkeys(
                itertools.chain.from_iterable(l.split(sep) for l in match_labels)
            )
            merged = sep.join(sorted(label_parts))
            # Log the label exactly as it is stored (i.e. sorted)
            logging.debug(f"Merging {match_labels} -> {merged}")
            agg_count += 1
            uniq_labels.append(merged)
    assert len(uniq_sequences) == len(uniq_labels)
    logging.info(
        f"Deduped from {len(sequences)} -> {len(uniq_sequences)} merging {agg_count} labels"
    )
    return list(uniq_sequences), uniq_labels

def dedup_lcmv_table(
    lcmv_tab: pd.DataFrame,
    blacklist_label_combos: Sequence[str] = (
        "TetMid,TetNeg",
        "TetNeg,TetPos",
        "TetMid,TetNeg,TetPos",
    ),
) -> Tuple[List[Tuple[str, str]], List[str]]:
    """
    Collapse duplicate (TRA, TRB) pairs of the LCMV table and discard pairs
    whose merged label is ambiguous.

    Duplicate pairs have their "tetramer" labels merged through
    dedup_and_merge_labels; any pair whose merged label appears in
    blacklist_label_combos is then removed. This centralises the handling of
    the duplicates and the few ambiguous labels present in the LCMV table.

    Returns two lists of equal length:
    - List of (TRA, TRB) pairs
    - List of corresponding labels (may be merged)
    """
    # Fuse each chain pair into a single "TRA|TRB" key so that pairs can be
    # de-duplicated as plain strings
    lcmv_ab = ["|".join(pair) for pair in zip(lcmv_tab["TRA"], lcmv_tab["TRB"])]
    tetramer_labels = list(lcmv_tab["tetramer"])
    lcmv_ab_dedup, lcmv_labels_dedup = dedup_and_merge_labels(lcmv_ab, tetramer_labels)

    all_label_counter = collections.Counter(lcmv_labels_dedup)
    logging.info(f"Combined labels {all_label_counter.most_common()}")
    logging.info(f"Filtering out labels {blacklist_label_combos}")

    # Positions of the pairs whose merged label is not blacklisted
    good_label_idx = []
    for idx, merged_label in enumerate(lcmv_labels_dedup):
        if merged_label in blacklist_label_combos:
            continue
        good_label_idx.append(idx)
    logging.info(f"Retaining {len(good_label_idx)} pairs with unambiguous labels")

    lcmv_ab_good, lcmv_labels_good = [], []
    for idx in good_label_idx:
        lcmv_ab_good.append(lcmv_ab_dedup[idx])
        lcmv_labels_good.append(lcmv_labels_dedup[idx])
    assert len(lcmv_ab_good) == len(lcmv_labels_good) == len(good_label_idx)

    label_counter = collections.Counter(lcmv_labels_good)
    logging.info(f"LCMV deduped labels: {label_counter.most_common()}")

    # Resplit the fused keys into (TRA, TRB) pairs
    return [tuple(joined.split("|")) for joined in lcmv_ab_good], lcmv_labels_good

# Parse the LCMV tetramer table, expanding multi-chain entries
lcmv = load_lcmv_table()

# Collapse duplicate TRA/TRB pairs and drop pairs with ambiguous labels
lcmv_dedup_tra_trb, lcmv_dedup_labels = dedup_lcmv_table(lcmv)
lcmv_dedup_tra, lcmv_dedup_trb = zip(*lcmv_dedup_tra_trb)

lcmv_dedup = pd.DataFrame(
    dict(
        TRA=lcmv_dedup_tra,
        TRB=lcmv_dedup_trb,
        tetramer=lcmv_dedup_labels,
        # Pos and mid are both positive labels
        label=[
            any(tag in l for tag in ("TetPos", "TetMid"))
            for l in lcmv_dedup_labels
        ],
    )
)
lcmv_dedup

Unnamed: 0,TRA,TRB,tetramer,label
0,CAAAAAGNYKYVF,CASSLLGGSYEQYF,TetNeg,False
1,CAAAASNTNKVVF,CASSLGLGANTGQLYF,TetNeg,False
2,CAAAASSGSWQLIF,CASGPREANERLFF,TetNeg,False
3,CAAADNYAQGLTF,CASGEGPDYTF,TetNeg,False
4,CAAADNYAQGLTF,CASRDWGDEQYF,TetNeg,False
...,...,...,...,...
17697,CVVNTGKLTF,CASSYGNERLFF,TetNeg,False
17698,CVYSNNRIFF,CASSLWDRGDTQYF,TetNeg,False
17699,CWGSALGRLHF,CASSDQGANTEVFF,TetNeg,False
17700,CWGSALGRLHF,CASSSGLGAEQYF,TetNeg,False


In [8]:
# Boolean label per deduplicated TRA/TRB pair, as a numpy array: True when
# the tetramer label contains "TetPos" or "TetMid".
lcmv_labels = np.array(lcmv_dedup['label'])
lcmv_labels

array([False, False, False, ..., False, False, False])

In [9]:
# Run the deduplicated LCMV TRA/TRB pairs through the trained DeepTCR model.
# In the recorded run this returned an array of shape (17702, 256),
# presumably one 256-dimensional embedding per input pair -- TODO confirm
# against the DeepTCR documentation.
lcmv_features = DTCRU.Sequence_Inference(
    alpha_sequences=lcmv_dedup['TRA'],
    beta_sequences=lcmv_dedup['TRB'],
)
lcmv_features.shape

INFO:tensorflow:Restoring parameters from murine_antigens/models/model_0/model.ckpt


INFO:tensorflow:Restoring parameters from murine_antigens/models/model_0/model.ckpt


(17702, 256)

In [12]:
# Write the LCMV DeepTCR features to RESULTS_DIR as a .npy file.
# NOTE(review): row order follows lcmv_dedup, but lcmv_labels is not saved
# here -- presumably downstream consumers rebuild the labels in the same
# order; confirm.
np.save(os.path.join(RESULTS_DIR, "lcmv_deeptcr_embed.npy"), lcmv_features)