In [11]:
import os, sys
import importlib
import logging
import tempfile
from typing import *

import tqdm

import numpy as np
import pandas as pd
from sklearn import metrics

import torch
from torch.utils.data import Subset
import skorch
import skorch.helper

# Make the project's source tree importable. SRC_DIR is the "tcr" directory
# that sits next to the current working directory's parent.
SRC_DIR = os.path.join(os.path.dirname(os.getcwd()), "tcr")
assert os.path.isdir(SRC_DIR), f"Cannot find src dir: {SRC_DIR}"
# Guard the append so re-running this cell in a live kernel does not pile up
# duplicate sys.path entries.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
import data_loader as dl
import utils
MODEL_DIR = os.path.join(SRC_DIR, "models")
if MODEL_DIR not in sys.path:
    sys.path.append(MODEL_DIR)
import conv

# When True, convnet_on_antigen() drops training sequences whose minimum
# distance to the test sequences is below 2.
FILT_EDIT_DIST = False

# NOTE(review): presumably resolves GPU index 3 to a device object; it is not
# referenced by the cells visible here (the net is built with device=3).
DEVICE = utils.get_device(3)
# NOTE(review): absolute, machine-specific path; not referenced by the cells
# visible here. Consider making it configurable.
TRAINED_MODEL_DIR = "/home/wukevin/projects/tcr/tcr_models"

In [3]:
# Load the PIRD dataset and keep only the rows that carry a beta-chain CDR3
pird_data = (
    dl.load_pird(with_antigen_only=True)
    .loc[lambda table: table['CDR3.beta.aa'].notna()]
)
pird_data.head()

INFO:root:PIRD data 0.1655 data labelled with antigen sequence
INFO:root:PIRD: Removing 95 entires with non amino acid residues
INFO:root:Entries with antigen sequence: 8429/51044
INFO:root:Unique antigen sequences: 73
INFO:root:PIRD data TRA/TRB instances: Counter({'TRB': 46428, 'TRA': 4011, 'TRA-TRB': 605})
INFO:root:PIRD entries with TRB sequence: 4607
INFO:root:PIRD entries with TRB sequence: 47040
INFO:root:PIRD entries with TRA and TRB:  605


Unnamed: 0,ICDname,Disease.name,Category,Antigen,Antigen.sequence,HLA,Locus,CDR3.alpha.aa,CDR3.beta.aa,CDR3.alpha.nt,...,Cell.subtype,Prepare.method,Evaluate.method,Case.num,Control.type,Control.num,Filteration,Journal,Pubmed.id,Grade
0,A15,Tuberculosis,Pathogen,CFP10,TAAQAAVVRFQEAAN,DRB1*15:03,TRA-TRB,CIEHTNSGGSNYKLTF,CASSLEETQYF,,...,CD4,Multiple PCR,Antigen-specific ex vivo proliferation,22.0,,,,Nature,28636589,5
1,A15,Tuberculosis,Pathogen,CFP10,TAAQAAVVRFQEAAN,DRB1*15:03,TRA-TRB,CIVHTNSGGSNYKLTF,CASSPEETQYF,,...,CD4,Multiple PCR,Antigen-specific ex vivo proliferation,22.0,,,,Nature,28636589,5
2,A15,Tuberculosis,Pathogen,CFP10,TAAQAAVVRFQEAAN,DRB1*15:03,TRA-TRB,CIVKTNSGGSNYKLTF,CASSFEETQYF,,...,CD4,Multiple PCR,Antigen-specific ex vivo proliferation,22.0,,,,Nature,28636589,5
1845,A15,Tuberculosis,Pathogen,Rv1195,ADTLQSIGATTVASN,DRB1*15:03,TRA-TRB,CAGAGGGGFKTIF,CASSVALASGANVLTF,,...,CD4,Multiple PCR,Antigen-specific ex vivo proliferation,22.0,,,,Nature,28636589,5
1846,A15,Tuberculosis,Pathogen,Rv1195,ADTLQSIGATTVASN,DRB1*15:03,TRA-TRB,CAGPTGGSYIPTF,CASSVALATGEQYF,,...,CD4,Multiple PCR,Antigen-specific ex vivo proliferation,22.0,,,,Nature,28636589,5


In [13]:
# Load TCRdb and keep only the rows labelled as TRB chains
tcrdb_data = dl.load_tcrdb().loc[lambda table: table['tra_trb'].eq('TRB')]
tcrdb_data.head()

Unnamed: 0,accession,RunId,AASeq,cloneFraction,tra_trb
0,PRJNA330606,SRR4102112,CANTGTGFNEQFF,0.008305,TRB
1,PRJNA330606,SRR4102112,CASSHTRGVGTQYF,0.003841,TRB
2,PRJNA330606,SRR4102112,CSGVHEQYF,0.003824,TRB
3,PRJNA330606,SRR4102112,CASSLPNGEGSSYEQYF,0.002825,TRB
4,PRJNA330606,SRR4102112,CASSQGGIAGDVYEQYF,0.002614,TRB


In [1]:
# Re-import the data loader module so edits made on disk take effect
importlib.reload(dl)

# Hyperparameters for torch's ReduceLROnPlateau learning-rate scheduler
REDUCE_LR_ON_PLATEAU_PARAMS = dict(mode="min", factor=0.1, patience=10, min_lr=1e-6)

# Seeded generator dedicated to sampling negative sequences from TCRdb
tcrdb_neg_rng = np.random.default_rng(64)

def convnet_on_antigen(
    antigen: str, *, min_positives: int = 20, neg_per_pos: int = 5
) -> Tuple[str, float]:
    """
    Train a convnet that separates TRBs annotated with ``antigen`` from
    background TRBs and report its test-set AUPRC.

    Positives are the ``CDR3.beta.aa`` values of ``pird_data`` rows whose
    ``Antigen.sequence`` equals ``antigen``. Negatives are drawn without
    replacement from ``tcrdb_data['AASeq']`` using the shared module-level
    ``tcrdb_neg_rng``, so the sampled negatives depend on how many draws have
    already been made from that generator in the current kernel.

    Args:
        antigen: Antigen sequence to build the classifier for.
        min_positives: Minimum number of positive sequences required to train.
        neg_per_pos: Number of negatives sampled per positive sequence.

    Returns:
        ``(antigen, test_auprc)``, where ``test_auprc`` is the average
        precision on the ``test=0.3`` split. When fewer than ``min_positives``
        positives exist, returns ``(antigen, nan)`` without training.
    """
    pird_pos_table = pird_data.loc[pird_data['Antigen.sequence'] == antigen]
    pird_pos_trbs = list(pird_pos_table['CDR3.beta.aa'])
    if len(pird_pos_trbs) < min_positives:
        return antigen, np.nan

    # Get a negative set of sequences from TCRdb, sampled at neg_per_pos
    # negatives per positive sequence
    n_negatives = int(len(pird_pos_trbs) * neg_per_pos)
    neg_idx = tcrdb_neg_rng.choice(
        np.arange(len(tcrdb_data)), size=n_negatives, replace=False
    )
    rand_neg_trbs = list(tcrdb_data.iloc[neg_idx]['AASeq'])

    # Label 1 = antigen-specific (PIRD), label 0 = background (TCRdb)
    full_dset = dl.TCRSupervisedIdxDataset(
        pird_pos_trbs + rand_neg_trbs,
        np.array([1] * len(pird_pos_trbs) + [0] * len(rand_neg_trbs))
    )
    train_dset_raw = dl.DatasetSplit(full_dset, split='train', valid=0.0, test=0.3)
    test_dset = dl.DatasetSplit(full_dset, split='test', valid=0.0, test=0.3)

    # Filter out training sequences too similar to test sequences
    if FILT_EDIT_DIST:
        train_seqs = train_dset_raw.all_sequences()
        test_seqs = test_dset.all_sequences()
        train_dists = dl.min_dist_train_test_seqs(train_seqs, test_seqs)
        far_idx = np.where(train_dists >= 2)[0]  # Keep only train seqs at least 2 edits away
        train_dset = Subset(train_dset_raw, far_idx)
    else:
        train_dset = train_dset_raw

    torch.manual_seed(42)
    with tempfile.TemporaryDirectory() as tmpdir:
        logging.info(f"Running in temporary dir: {tmpdir}")
        net = skorch.NeuralNet(
            module=conv.OnePartConvNet,
            module__use_embedding=True,
            module__n_output=2,
            module__max_input_len=full_dset.max_len,
            criterion=torch.nn.CrossEntropyLoss,
            optimizer=torch.optim.Adam,
            max_epochs=250,
            batch_size=512,
            lr=1e-3,
            callbacks=[
                # Scheduler settings come from the shared module-level constant
                # instead of being repeated here as literals.
                skorch.callbacks.LRScheduler(
                    policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
                    **REDUCE_LR_ON_PLATEAU_PARAMS,
                ),
                skorch.callbacks.GradientNormClipping(gradient_clip_value=5),
                skorch.callbacks.EpochScoring(
                    "average_precision",
                    lower_is_better=False,
                    on_train=False,
                    name="valid_auprc",
                ),
                skorch.callbacks.EarlyStopping(
                    patience=25,
                    monitor="valid_auprc",
                    lower_is_better=False,
                ),
                skorch.callbacks.Checkpoint(  # Seems to cause errors if placed before scoring
                    dirname=tmpdir,
                    fn_prefix="net_",
                    monitor="valid_auprc_best",
                ),
            ],
            # NOTE(review): hard-coded GPU index; the notebook also defines
            # DEVICE = utils.get_device(3) -- confirm whether that should be used.
            device=3,
        )
        # NOTE(review): presumably needed because a plain NeuralNet does not set
        # classes_, which the "average_precision" scorer relies on -- confirm.
        net.classes_ = np.unique(test_dset.all_labels())
        net.fit(train_dset)
        # Evaluate with the weights the net holds when fit() returns. The
        # checkpoint written to tmpdir is deliberately not restored
        # (net.load_params(checkpoint=...)): the author observed better
        # performance with the restore disabled.
        test_truth = test_dset.all_labels()
        test_preds = net.predict_proba(test_dset)[:, 1]
        test_auprc = metrics.average_precision_score(test_truth, test_preds)
    return antigen, test_auprc
    
# Run the pipeline on a single antigen; the following cell runs it over every antigen in pird_data
convnet_on_antigen("FPRPWLHGL")

IndentationError: expected an indented block (<ipython-input-1-e93a52739a6c>, line 44)

In [17]:
# Train and evaluate one model per antigen sequence, collecting (antigen, test AUPRC) pairs
antigen_auprc_pairs = list(map(convnet_on_antigen, utils.dedup(pird_data['Antigen.sequence'])))
len(antigen_auprc_pairs)

INFO:root:Max len not set, using empirical max len of 26
INFO:root:Using maximum length of 26
INFO:root:Split train with 970 examples
INFO:root:Split test with 416 examples
INFO:root:Running in temporary dir: /tmp/tmprzvpv1m9


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.5475[0m         [32m0.1221[0m        [35m0.6887[0m     +  0.0244
      2        [36m0.3900[0m         [32m0.2035[0m        [35m0.6639[0m     +  0.0539
      3        [36m0.3453[0m         [32m0.3061[0m        [35m0.6462[0m     +  0.0409
      4        [36m0.3056[0m         [32m0.3584[0m        [35m0.6313[0m     +  0.0460
      5        [36m0.2774[0m         [32m0.3771[0m        [35m0.6102[0m     +  0.0288
      6        [36m0.2589[0m         [32m0.3948[0m        [35m0.5801[0m     +  0.0245
      7        [36m0.2408[0m         [32m0.4083[0m        [35m0.5403[0m     +  0.0241
      8        [36m0.2228[0m         [32m0.4139[0m        [35m0.4951[0m     +  0.0242
      9        [36m0.2085[0m         [32m0.4256[0m        [35m0.4526[0m     +  0.0244
     10        [36m0.1964[0m         0.4

INFO:root:Max len not set, using empirical max len of 20
INFO:root:Using maximum length of 20
INFO:root:Split train with 239 examples
INFO:root:Split test with 103 examples
INFO:root:Running in temporary dir: /tmp/tmp4m067d32


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.2464[0m         [32m0.1333[0m        [35m0.7201[0m     +  0.0179
      2        [36m0.8146[0m         0.1268        [35m0.6895[0m        0.0114
      3        [36m0.5351[0m         0.1169        [35m0.6591[0m        0.0206
      4        [36m0.3710[0m         0.1233        [35m0.6306[0m        0.0139
      5        [36m0.2826[0m         0.1233        [35m0.6011[0m        0.0117
      6        [36m0.2344[0m         [32m0.1417[0m        [35m0.5706[0m     +  0.0182
      7        [36m0.2041[0m         [32m0.1726[0m        [35m0.5392[0m     +  0.0116
      8        [36m0.1820[0m         0.1500        [35m0.5081[0m        0.0093
      9        [36m0.1633[0m         0.1333        [35m0.4773[0m        0.0087
     10        [36m0.1460[0m         0.1339        [35m0.4471[0m        0.0088
     11      

INFO:root:Max len not set, using empirical max len of 20
INFO:root:Using maximum length of 20
INFO:root:Split train with 239 examples
INFO:root:Split test with 103 examples
INFO:root:Running in temporary dir: /tmp/tmp81bsk6rt


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1707[0m         [32m0.0779[0m        [35m0.7178[0m     +  0.0100
      2        [36m0.7982[0m         [32m0.0800[0m        [35m0.6919[0m     +  0.0097
      3        [36m0.5535[0m         [32m0.0867[0m        [35m0.6673[0m     +  0.0094
      4        [36m0.4004[0m         [32m0.1258[0m        [35m0.6442[0m     +  0.0091
      5        [36m0.3047[0m         [32m0.2369[0m        [35m0.6202[0m     +  0.0092
      6        [36m0.2406[0m         [32m0.4695[0m        [35m0.5946[0m     +  0.0203
      7        [36m0.1940[0m         [32m0.6421[0m        [35m0.5683[0m     +  0.0209
      8        [36m0.1599[0m         [32m0.7403[0m        [35m0.5407[0m     +  0.0172
      9        [36m0.1341[0m         [32m0.8345[0m        [35m0.5130[0m     +  0.0204
     10        [36m0.1141[0m         [3

INFO:root:Max len not set, using empirical max len of 20
INFO:root:Using maximum length of 20
INFO:root:Split train with 143 examples
INFO:root:Split test with 61 examples
INFO:root:Running in temporary dir: /tmp/tmp7bqs92ph


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1996[0m         [32m0.3340[0m        [35m0.7207[0m     +  0.0083
      2        [36m0.8160[0m         0.1909        [35m0.7004[0m        0.0085
      3        [36m0.5559[0m         0.2061        [35m0.6830[0m        0.0078
      4        [36m0.3962[0m         0.2390        [35m0.6676[0m        0.0077
      5        [36m0.2994[0m         0.1671        [35m0.6527[0m        0.0076
      6        [36m0.2382[0m         0.1689        [35m0.6381[0m        0.0075
      7        [36m0.1950[0m         0.1740        [35m0.6232[0m        0.0076
      8        [36m0.1616[0m         0.1679        [35m0.6076[0m        0.0074
      9        [36m0.1348[0m         0.1692        [35m0.5921[0m        0.0073
     10        [36m0.1126[0m         0.1728        [35m0.5773[0m        0.0075
     11        [36m0.0947[0m 

INFO:root:Max len not set, using empirical max len of 20
INFO:root:Using maximum length of 20
INFO:root:Split train with 433 examples
INFO:root:Split test with 185 examples
INFO:root:Running in temporary dir: /tmp/tmpzyq3b0on


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.2429[0m         [32m0.0494[0m        [35m0.7185[0m     +  0.0120
      2        [36m0.8851[0m         [32m0.0507[0m        [35m0.6894[0m     +  0.0119
      3        [36m0.6483[0m         0.0449        [35m0.6621[0m        0.0276
      4        [36m0.5047[0m         0.0469        [35m0.6345[0m        0.0263
      5        [36m0.4204[0m         0.0472        [35m0.6068[0m        0.0157
      6        [36m0.3706[0m         0.0488        [35m0.5790[0m        0.0126
      7        [36m0.3382[0m         [32m0.0510[0m        [35m0.5513[0m     +  0.0120
      8        [36m0.3148[0m         [32m0.0526[0m        [35m0.5245[0m     +  0.0119
      9        [36m0.2948[0m         0.0524        [35m0.4988[0m        0.0119
     10        [36m0.2756[0m         [32m0.0544[0m        [35m0.4748[0m     +  0.

INFO:root:Max len not set, using empirical max len of 20
INFO:root:Using maximum length of 20
INFO:root:Split train with 412 examples
INFO:root:Split test with 176 examples
INFO:root:Running in temporary dir: /tmp/tmplwv5snhy


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1473[0m         [32m0.2450[0m        [35m0.7086[0m     +  0.0120
      2        [36m0.8232[0m         [32m0.2499[0m        [35m0.6830[0m     +  0.0118
      3        [36m0.6228[0m         [32m0.2514[0m        [35m0.6585[0m     +  0.0116
      4        [36m0.5076[0m         0.2378        [35m0.6350[0m        0.0118
      5        [36m0.4447[0m         0.2239        [35m0.6125[0m        0.0117
      6        [36m0.4084[0m         0.2066        [35m0.5910[0m        0.0112
      7        [36m0.3845[0m         0.2022        [35m0.5710[0m        0.0118
      8        [36m0.3653[0m         0.2013        [35m0.5526[0m        0.0116
      9        [36m0.3470[0m         0.2237        [35m0.5360[0m        0.0113
     10        [36m0.3285[0m         0.2359        [35m0.5209[0m        0.0116
     11      

INFO:root:Max len not set, using empirical max len of 20
INFO:root:Using maximum length of 20
INFO:root:Split train with 189 examples
INFO:root:Split test with 81 examples
INFO:root:Running in temporary dir: /tmp/tmpr4r7kskh


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.2010[0m         [32m0.0629[0m        [35m0.7268[0m     +  0.0131
      2        [36m0.8359[0m         0.0625        [35m0.6996[0m        0.0151
      3        [36m0.5818[0m         [32m0.0644[0m        [35m0.6751[0m     +  0.0118
      4        [36m0.4236[0m         [32m0.0730[0m        [35m0.6525[0m     +  0.0095
      5        [36m0.3275[0m         [32m0.0767[0m        [35m0.6302[0m     +  0.0087
      6        [36m0.2687[0m         [32m0.0878[0m        [35m0.6087[0m     +  0.0083
      7        [36m0.2286[0m         [32m0.1065[0m        [35m0.5881[0m     +  0.0081
      8        [36m0.1968[0m         [32m0.1175[0m        [35m0.5684[0m     +  0.0082
      9        [36m0.1698[0m         [32m0.1739[0m        [35m0.5500[0m     +  0.0080
     10        [36m0.1461[0m         [32m0.1765

INFO:root:Max len not set, using empirical max len of 19
INFO:root:Using maximum length of 19
INFO:root:Split train with 151 examples
INFO:root:Split test with 65 examples
INFO:root:Running in temporary dir: /tmp/tmpdic8hsl0


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1393[0m         [32m0.2104[0m        [35m0.6430[0m     +  0.0091
      2        [36m0.7847[0m         0.2102        [35m0.6425[0m        0.0137
      3        [36m0.5477[0m         [32m0.2245[0m        [35m0.6404[0m     +  0.0157
      4        [36m0.3977[0m         [32m0.3035[0m        [35m0.6352[0m     +  0.0106
      5        [36m0.3056[0m         [32m0.3435[0m        [35m0.6258[0m     +  0.0184
      6        [36m0.2474[0m         [32m0.4196[0m        [35m0.6135[0m     +  0.0112
      7        [36m0.2072[0m         [32m0.4313[0m        [35m0.6004[0m     +  0.0086
      8        [36m0.1763[0m         [32m0.4316[0m        [35m0.5862[0m     +  0.0081
      9        [36m0.1509[0m         [32m0.5280[0m        [35m0.5722[0m     +  0.0123
     10        [36m0.1289[0m         0.5234      

INFO:root:Max len not set, using empirical max len of 19
INFO:root:Using maximum length of 19
INFO:root:Split train with 223 examples
INFO:root:Split test with 95 examples
INFO:root:Running in temporary dir: /tmp/tmpbn596r9q


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1991[0m         [32m0.2847[0m        [35m0.6259[0m     +  0.0172
      2        [36m0.8544[0m         [32m0.4417[0m        [35m0.6229[0m     +  0.0228
      3        [36m0.6121[0m         [32m0.6046[0m        [35m0.6180[0m     +  0.0186
      4        [36m0.4579[0m         [32m0.6514[0m        [35m0.6094[0m     +  0.0130
      5        [36m0.3653[0m         0.6514        [35m0.5986[0m        0.0101
      6        [36m0.3076[0m         [32m0.6521[0m        [35m0.5855[0m     +  0.0091
      7        [36m0.2678[0m         [32m0.6528[0m        [35m0.5699[0m     +  0.0091
      8        [36m0.2372[0m         [32m0.6544[0m        [35m0.5524[0m     +  0.0091
      9        [36m0.2111[0m         0.6052        [35m0.5341[0m        0.0145
     10        [36m0.1878[0m         0.6061        [35m0.

INFO:root:Max len not set, using empirical max len of 18
INFO:root:Using maximum length of 18
INFO:root:Split train with 151 examples
INFO:root:Split test with 65 examples
INFO:root:Running in temporary dir: /tmp/tmpyo_insv6


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1150[0m         [32m0.2016[0m        [35m0.6922[0m     +  0.0085
      2        [36m0.7710[0m         [32m0.2025[0m        [35m0.6770[0m     +  0.0133
      3        [36m0.5323[0m         [32m0.2027[0m        [35m0.6627[0m     +  0.0114
      4        [36m0.3779[0m         [32m0.2033[0m        [35m0.6501[0m     +  0.0159
      5        [36m0.2822[0m         0.1965        [35m0.6358[0m        0.0110
      6        [36m0.2228[0m         0.1991        [35m0.6210[0m        0.0092
      7        [36m0.1833[0m         [32m0.2172[0m        [35m0.6043[0m     +  0.0185
      8        [36m0.1543[0m         [32m0.2589[0m        [35m0.5864[0m     +  0.0114
      9        [36m0.1317[0m         [32m0.2953[0m        [35m0.5684[0m     +  0.0144
     10        [36m0.1133[0m         0.2949        [35m0.

INFO:root:Max len not set, using empirical max len of 20
INFO:root:Using maximum length of 20
INFO:root:Split train with 185 examples
INFO:root:Split test with 79 examples
INFO:root:Running in temporary dir: /tmp/tmpd8rirws1


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.2546[0m         [32m0.1150[0m        [35m0.7139[0m     +  0.0090
      2        [36m0.8876[0m         [32m0.1182[0m        [35m0.6886[0m     +  0.0086
      3        [36m0.6356[0m         0.1137        [35m0.6643[0m        0.0083
      4        [36m0.4729[0m         [32m0.1349[0m        [35m0.6403[0m     +  0.0083
      5        [36m0.3684[0m         [32m0.1705[0m        [35m0.6152[0m     +  0.0084
      6        [36m0.2992[0m         [32m0.2635[0m        [35m0.5897[0m     +  0.0084
      7        [36m0.2507[0m         [32m0.3031[0m        [35m0.5637[0m     +  0.0084
      8        [36m0.2149[0m         [32m0.4083[0m        [35m0.5373[0m     +  0.0082
      9        [36m0.1873[0m         [32m0.5583[0m        [35m0.5112[0m     +  0.0081
     10        [36m0.1648[0m         0.5583      

INFO:root:Max len not set, using empirical max len of 20
INFO:root:Using maximum length of 20
INFO:root:Split train with 109 examples
INFO:root:Split test with 47 examples
INFO:root:Running in temporary dir: /tmp/tmp6ruw6vb3


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.9224[0m         [32m0.3409[0m        [35m0.7124[0m     +  0.0077
      2        [36m0.6057[0m         0.3167        [35m0.6895[0m        0.0075
      3        [36m0.4126[0m         0.2255        [35m0.6686[0m        0.0070
      4        [36m0.2945[0m         0.1838        [35m0.6487[0m        0.0072
      5        [36m0.2208[0m         0.1875        [35m0.6291[0m        0.0168
      6        [36m0.1704[0m         0.1500        [35m0.6104[0m        0.0126
      7        [36m0.1342[0m         0.1667        [35m0.5922[0m        0.0102
      8        [36m0.1079[0m         0.1667        [35m0.5744[0m        0.0087
      9        [36m0.0885[0m         0.1667        [35m0.5573[0m        0.0081
     10        [36m0.0735[0m         0.1667        [35m0.5407[0m        0.0076
     11        [36m0.0622[0m 

INFO:root:Max len not set, using empirical max len of 25
INFO:root:Using maximum length of 25
INFO:root:Split train with 286 examples
INFO:root:Split test with 122 examples
INFO:root:Running in temporary dir: /tmp/tmp8j89jjmb


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.5441[0m         [32m0.2548[0m        [35m0.6826[0m     +  0.0106
      2        [36m0.4173[0m         [32m0.2591[0m        [35m0.6782[0m     +  0.0103
      3        [36m0.3536[0m         0.2277        [35m0.6727[0m        0.0100
      4        [36m0.3036[0m         0.2480        [35m0.6677[0m        0.0100
      5        [36m0.2573[0m         [32m0.2766[0m        [35m0.6629[0m     +  0.0095
      6        [36m0.2171[0m         [32m0.2970[0m        [35m0.6582[0m     +  0.0098
      7        [36m0.1848[0m         [32m0.3005[0m        [35m0.6532[0m     +  0.0100
      8        [36m0.1608[0m         [32m0.3151[0m        [35m0.6468[0m     +  0.0098
      9        [36m0.1430[0m         [32m0.3181[0m        [35m0.6381[0m     +  0.0100
     10        [36m0.1286[0m         [32m0.3196[0m      

INFO:root:Max len not set, using empirical max len of 20
INFO:root:Using maximum length of 20
INFO:root:Split train with 176 examples
INFO:root:Split test with 76 examples
INFO:root:Running in temporary dir: /tmp/tmpolz6zycm


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1771[0m         [32m0.2527[0m        [35m0.7070[0m     +  0.0087
      2        [36m0.7935[0m         0.2208        [35m0.6878[0m        0.0087
      3        [36m0.5495[0m         0.1570        [35m0.6711[0m        0.0082
      4        [36m0.4058[0m         0.1364        [35m0.6559[0m        0.0083
      5        [36m0.3197[0m         0.1431        [35m0.6395[0m        0.0082
      6        [36m0.2641[0m         0.1455        [35m0.6229[0m        0.0083
      7        [36m0.2242[0m         0.1509        [35m0.6069[0m        0.0083
      8        [36m0.1924[0m         0.1583        [35m0.5914[0m        0.0081
      9        [36m0.1658[0m         0.1611        [35m0.5768[0m        0.0081
     10        [36m0.1434[0m         0.1678        [35m0.5630[0m        0.0081
     11        [36m0.1247[0m 

INFO:root:Max len not set, using empirical max len of 23
INFO:root:Using maximum length of 23
INFO:root:Split train with 370 examples
INFO:root:Split test with 158 examples
INFO:root:Running in temporary dir: /tmp/tmph25jq24g


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.8361[0m         [32m0.1259[0m        [35m0.6743[0m     +  0.0115
      2        [36m0.5334[0m         [32m0.1315[0m        [35m0.6667[0m     +  0.0112
      3        [36m0.3723[0m         [32m0.1570[0m        [35m0.6559[0m     +  0.0111
      4        [36m0.2995[0m         [32m0.1960[0m        [35m0.6423[0m     +  0.0109
      5        [36m0.2613[0m         [32m0.2298[0m        [35m0.6259[0m     +  0.0110
      6        [36m0.2340[0m         [32m0.2878[0m        [35m0.6063[0m     +  0.0111
      7        [36m0.2105[0m         [32m0.3931[0m        [35m0.5846[0m     +  0.0110
      8        [36m0.1896[0m         [32m0.4917[0m        [35m0.5615[0m     +  0.0113
      9        [36m0.1709[0m         [32m0.5740[0m        [35m0.5379[0m     +  0.0114
     10        [36m0.1547[0m         [3

INFO:root:Max len not set, using empirical max len of 19
INFO:root:Using maximum length of 19
INFO:root:Split train with 101 examples
INFO:root:Split test with 43 examples
INFO:root:Running in temporary dir: /tmp/tmppv90s_3i


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.0641[0m         [32m0.1250[0m        [35m0.6137[0m     +  0.0157
      2        [36m0.7181[0m         0.0909        [35m0.6135[0m        0.0097
      3        [36m0.4838[0m         0.1111        [35m0.6125[0m        0.0082
      4        [36m0.3361[0m         0.1250        [35m0.6080[0m        0.0079
      5        [36m0.2437[0m         0.1250        [35m0.5999[0m        0.0077
      6        [36m0.1853[0m         [32m0.1667[0m        [35m0.5894[0m     +  0.0137
      7        [36m0.1466[0m         [32m0.2000[0m        [35m0.5760[0m     +  0.0088
      8        [36m0.1196[0m         [32m0.3333[0m        [35m0.5606[0m     +  0.0080
      9        [36m0.0996[0m         0.3333        [35m0.5438[0m        0.0073
     10        [36m0.0842[0m         [32m1.0000[0m        [35m0.5263[0m     +  0.

INFO:root:Max len not set, using empirical max len of 19
INFO:root:Using maximum length of 19
INFO:root:Split train with 118 examples
INFO:root:Split test with 50 examples
INFO:root:Running in temporary dir: /tmp/tmp4u3kapzo


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1715[0m         [32m0.1198[0m        [35m0.6294[0m     +  0.0077
      2        [36m0.8230[0m         [32m0.1263[0m        [35m0.6255[0m     +  0.0078
      3        [36m0.5663[0m         [32m0.1287[0m        [35m0.6219[0m     +  0.0073
      4        [36m0.3985[0m         [32m0.1395[0m        [35m0.6180[0m     +  0.0073
      5        [36m0.2944[0m         0.1368        [35m0.6116[0m        0.0071
      6        [36m0.2306[0m         [32m0.1513[0m        [35m0.6029[0m     +  0.0070
      7        [36m0.1880[0m         [32m0.1958[0m        [35m0.5934[0m     +  0.0125
      8        [36m0.1572[0m         [32m0.2494[0m        [35m0.5831[0m     +  0.0117
      9        [36m0.1322[0m         [32m0.4180[0m        [35m0.5729[0m     +  0.0132
     10        [36m0.1113[0m         [32m0.4184

INFO:root:Max len not set, using empirical max len of 35
INFO:root:Using maximum length of 35
INFO:root:Split train with 17283 examples
INFO:root:Split test with 7407 examples
INFO:root:Running in temporary dir: /tmp/tmp71x_y9yl


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.4798[0m         [32m0.1882[0m        [35m0.4904[0m     +  0.3245
      2        [36m0.4346[0m         [32m0.1936[0m        [35m0.4524[0m     +  0.4162
      3        [36m0.4221[0m         [32m0.1953[0m        0.4541     +  0.3193
      4        [36m0.4128[0m         0.1930        0.4567        0.3874
      5        [36m0.4045[0m         0.1939        0.4588        0.3606
      6        [36m0.3959[0m         0.1908        0.4627        0.3627
      7        [36m0.3871[0m         0.1896        0.4676        0.3335
      8        [36m0.3786[0m         0.1902        0.4737        0.3190
      9        [36m0.3689[0m         0.1909        0.4794        0.4057
     10        [36m0.3596[0m         0.1932        0.4841        0.3419
     11        [36m0.3493[0m         0.1907        0.4923        0.5055
     12    

INFO:root:Max len not set, using empirical max len of 21
INFO:root:Using maximum length of 21
INFO:root:Split train with 88 examples
INFO:root:Split test with 38 examples
INFO:root:Running in temporary dir: /tmp/tmpjd_60k6z


Stopping since valid_auprc has not improved in the last 25 epochs.
  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.5857[0m            nan        [35m0.8192[0m        0.0072
      2        [36m1.0736[0m            nan        [35m0.7890[0m        0.0071
      3        [36m0.7057[0m            nan        [35m0.7572[0m        0.0068
      4        [36m0.4637[0m            nan        [35m0.7258[0m        0.0069
      5        [36m0.3181[0m            nan        [35m0.6945[0m        0.0069
      6        [36m0.2325[0m            nan        [35m0.6629[0m        0.0068
      7        [36m0.1780[0m            nan        [35m0.6318[0m        0.0067
      8        [36m0.1384[0m            nan        [35m0.6013[0m        0.0068
      9        [36m0.1076[0m            nan        [35m0.5709[0m        0.0069
     10        [36m0.0838[0m            nan        [3

  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]


     17        [36m0.0239[0m            nan        [35m0.3612[0m        0.0166


  recall = tps / tps[-1]


     18        [36m0.0214[0m            nan        [35m0.3404[0m        0.0128
     19        [36m0.0194[0m            nan        [35m0.3208[0m        0.0096
     20        [36m0.0178[0m            nan        [35m0.3026[0m        0.0075
     21        [36m0.0165[0m            nan        [35m0.2856[0m        0.0073
     22        [36m0.0153[0m            nan        [35m0.2702[0m        0.0072
     23        [36m0.0143[0m            nan        [35m0.2561[0m        0.0070
     24        [36m0.0133[0m            nan        [35m0.2435[0m        0.0103
Stopping since valid_auprc has not improved in the last 25 epochs.


  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
INFO:root:Max len not set, using empirical max len of 21
INFO:root:Using maximum length of 21
INFO:root:Split train with 298 examples
INFO:root:Split test with 128 examples
INFO:root:Running in temporary dir: /tmp/tmpmxk2x5z7


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.4493[0m         [32m0.3708[0m        [35m0.7590[0m     +  0.0108
      2        [36m1.0123[0m         0.3689        [35m0.7372[0m        0.0105
      3        [36m0.7170[0m         0.3456        [35m0.7150[0m        0.0102
      4        [36m0.5423[0m         0.2977        [35m0.6939[0m        0.0101
      5        [36m0.4435[0m         0.3199        [35m0.6740[0m        0.0132
      6        [36m0.3857[0m         0.3261        [35m0.6539[0m        0.0101
      7        [36m0.3470[0m         0.3453        [35m0.6339[0m        0.0102
      8        [36m0.3172[0m         0.3662        [35m0.6142[0m        0.0100
      9        [36m0.2893[0m         [32m0.3798[0m        [35m0.5943[0m     +  0.0102
     10        [36m0.2622[0m         [32m0.4076[0m        [35m0.5750[0m     +  0.0102
     11      

INFO:root:Max len not set, using empirical max len of 22
INFO:root:Using maximum length of 22
INFO:root:Split train with 844 examples
INFO:root:Split test with 362 examples
INFO:root:Running in temporary dir: /tmp/tmpohn9qu78


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.8431[0m         [32m0.1389[0m        [35m0.6621[0m     +  0.0404
      2        [36m0.5789[0m         0.1257        [35m0.6549[0m        0.0405
      3        [36m0.5304[0m         0.1360        [35m0.6482[0m        0.0286
      4        [36m0.4888[0m         [32m0.1589[0m        [35m0.6372[0m     +  0.0340
      5        [36m0.4394[0m         [32m0.1996[0m        [35m0.6246[0m     +  0.0362
      6        [36m0.3997[0m         [32m0.2243[0m        [35m0.6093[0m     +  0.0235
      7        [36m0.3736[0m         [32m0.2383[0m        [35m0.5887[0m     +  0.0225
      8        [36m0.3541[0m         [32m0.2433[0m        [35m0.5606[0m     +  0.0333
      9        [36m0.3353[0m         [32m0.2581[0m        [35m0.5262[0m     +  0.0344
     10        [36m0.3174[0m         [32m0.2725[0m      

INFO:root:Max len not set, using empirical max len of 20
INFO:root:Using maximum length of 20
INFO:root:Split train with 601 examples
INFO:root:Split test with 257 examples
INFO:root:Running in temporary dir: /tmp/tmpnr4npk6d


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1974[0m         [32m0.1567[0m        [35m0.7124[0m     +  0.0143
      2        [36m0.8324[0m         [32m0.1665[0m        [35m0.6834[0m     +  0.0147
      3        [36m0.5891[0m         0.1611        [35m0.6551[0m        0.0144
      4        [36m0.4419[0m         0.1336        [35m0.6279[0m        0.0143
      5        [36m0.3563[0m         0.1239        [35m0.6005[0m        0.0142
      6        [36m0.3080[0m         0.1217        [35m0.5729[0m        0.0141
      7        [36m0.2787[0m         0.1059        [35m0.5458[0m        0.0142
      8        [36m0.2583[0m         0.1029        [35m0.5193[0m        0.0143
      9        [36m0.2415[0m         0.1024        [35m0.4937[0m        0.0139
     10        [36m0.2258[0m         0.1010        [35m0.4696[0m        0.0141
     11        [36m0.

INFO:root:Max len not set, using empirical max len of 19
INFO:root:Using maximum length of 19
INFO:root:Split train with 155 examples
INFO:root:Split test with 67 examples
INFO:root:Running in temporary dir: /tmp/tmppir70u43


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.3454[0m         [32m0.3061[0m        [35m0.6326[0m     +  0.0182
      2        [36m0.9442[0m         [32m0.3390[0m        [35m0.6301[0m     +  0.0111
      3        [36m0.6530[0m         [32m0.4885[0m        [35m0.6248[0m     +  0.0085
      4        [36m0.4566[0m         0.3512        [35m0.6171[0m        0.0080
      5        [36m0.3300[0m         0.3443        [35m0.6077[0m        0.0078
      6        [36m0.2482[0m         0.3967        [35m0.5970[0m        0.0078
      7        [36m0.1940[0m         [32m0.4889[0m        [35m0.5843[0m     +  0.0078
      8        [36m0.1560[0m         [32m0.5572[0m        [35m0.5698[0m     +  0.0078
      9        [36m0.1281[0m         0.5542        [35m0.5539[0m        0.0077
     10        [36m0.1064[0m         0.5547        [35m0.5368[0m        0.

INFO:root:Max len not set, using empirical max len of 26
INFO:root:Using maximum length of 26
INFO:root:Split train with 256 examples
INFO:root:Split test with 110 examples
INFO:root:Running in temporary dir: /tmp/tmpj3961ygl


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.6288[0m         [32m0.2042[0m        [35m0.7009[0m     +  0.0231
      2        [36m0.4456[0m         [32m0.2260[0m        [35m0.6847[0m     +  0.0135
      3        [36m0.3675[0m         [32m0.2886[0m        [35m0.6717[0m     +  0.0103
      4        [36m0.3279[0m         [32m0.3155[0m        [35m0.6602[0m     +  0.0102
      5        [36m0.2965[0m         [32m0.3433[0m        [35m0.6499[0m     +  0.0100
      6        [36m0.2650[0m         [32m0.3752[0m        [35m0.6395[0m     +  0.0098
      7        [36m0.2340[0m         [32m0.3936[0m        [35m0.6292[0m     +  0.0147
      8        [36m0.2061[0m         0.3672        [35m0.6191[0m        0.0166
      9        [36m0.1838[0m         0.3779        [35m0.6072[0m        0.0189
     10        [36m0.1660[0m         0.3810        [35m0.

INFO:root:Max len not set, using empirical max len of 19
INFO:root:Using maximum length of 19
INFO:root:Split train with 101 examples
INFO:root:Split test with 43 examples
INFO:root:Running in temporary dir: /tmp/tmpvniyh0u1


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.3579[0m            nan        [35m0.6084[0m        0.0121
      2        [36m0.9240[0m            nan        [35m0.6062[0m        0.0108
      3        [36m0.6176[0m            nan        [35m0.6032[0m        0.0104
      4        [36m0.4091[0m            nan        [35m0.5955[0m        0.0107
      5        [36m0.2766[0m            nan        [35m0.5825[0m        0.0091
      6        [36m0.1941[0m            nan        [35m0.5659[0m        0.0083
      7        [36m0.1441[0m            nan        [35m0.5468[0m        0.0078
      8        [36m0.1128[0m            nan        [35m0.5259[0m        0.0075
      9        [36m0.0914[0m            nan        [35m0.5028[0m        0.0074
     10        [36m0.0755[0m            nan        [35m0.4783[0m        0.0072
     11        [36m0.0633[0m          

  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]


     20        [36m0.0209[0m            nan        [35m0.2449[0m        0.0065


  recall = tps / tps[-1]


     21        [36m0.0191[0m            nan        [35m0.2283[0m        0.0163
     22        [36m0.0174[0m            nan        [35m0.2130[0m        0.0139
     23        [36m0.0160[0m            nan        [35m0.1990[0m        0.0124
     24        [36m0.0147[0m            nan        [35m0.1865[0m        0.0102
Stopping since valid_auprc has not improved in the last 25 epochs.


  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]
INFO:root:Max len not set, using empirical max len of 20
INFO:root:Using maximum length of 20
INFO:root:Split train with 945 examples
INFO:root:Split test with 405 examples
INFO:root:Running in temporary dir: /tmp/tmpjzm1bsow


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1910[0m         [32m0.0972[0m        [35m0.6904[0m     +  0.0530
      2        [36m0.6289[0m         [32m0.1085[0m        [35m0.6359[0m     +  0.0259
      3        [36m0.4077[0m         [32m0.1159[0m        [35m0.5790[0m     +  0.0226
      4        [36m0.3352[0m         [32m0.1189[0m        [35m0.5210[0m     +  0.0218
      5        [36m0.3094[0m         [32m0.1261[0m        [35m0.4664[0m     +  0.0347
      6        [36m0.2943[0m         [32m0.1351[0m        [35m0.4189[0m     +  0.0328
      7        [36m0.2798[0m         [32m0.1444[0m        [35m0.3806[0m     +  0.0222
      8        [36m0.2637[0m         [32m0.1696[0m        [35m0.3514[0m     +  0.0880
      9        [36m0.2463[0m         [32m0.1824[0m        [35m0.3299[0m     +  0.0229
     10        [36m0.2290[0m         [3

73

In [18]:
# Keep only the entries whose second element (the AUPRC score) is a real
# number. NaN scores appear in the training logs above for runs where
# valid_auprc could not be computed, and they would otherwise end up in the
# summary table built from these pairs.
antigen_auprc_pairs = list(
    filter(lambda entry: not np.isnan(entry[1]), antigen_auprc_pairs)
)
# Show how many entries survive the filter.
len(antigen_auprc_pairs)

26

In [19]:
# Tabulate the retained scores: one row per entry, labelled by the entry's
# first element (the antigen sequence), with the entry's second element (the
# AUPRC score) stored in a single 'ConvNet' column.
df = pd.DataFrame(
    data=[scored_antigen[1] for scored_antigen in antigen_auprc_pairs],
    index=[scored_antigen[0] for scored_antigen in antigen_auprc_pairs],
    columns=['ConvNet'],
)
df

Unnamed: 0,ConvNet
LLWNGPMAV,0.530086
RPRGEVRFL,0.386815
ATDALMTGY,0.878661
HSKKKCDEL,0.268149
KAFSPEVIPMF,0.561982
KRWIILGLNK,0.413079
TPQDLNTML,0.388053
EIYKRWII,0.162382
HPKVSSEVHI,0.583317
IIKDYGKQM,0.623517


In [20]:
# Persist the results table to the current working directory. The output
# filename depends on the FILT_EDIT_DIST flag so that the two configurations
# are written to separate files and do not overwrite each other.
df.to_csv(
    "antigen_cv_convnet_baseline_edit_dist_filt.csv"
    if FILT_EDIT_DIST
    else "antigen_cv_convnet_baseline.csv"
)