In [1]:
import os, sys
import logging
import glob
import json
import itertools
import collections
import importlib
import tempfile
from typing import *

import tqdm

import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.svm import SVC
from sklearn.ensemble import GradientBoostingClassifier
from matplotlib import pyplot as plt
import seaborn as sns
from scipy import stats

import anndata as ad
import scanpy as sc

import torch
import torch.nn as nn
import skorch
import skorch.helper

from transformers import BertModel, BertForMaskedLM, BertTokenizer, FeatureExtractionPipeline

SRC_DIR = os.path.join(os.path.dirname(os.getcwd()), "tcr")
assert os.path.isdir(SRC_DIR), f"Cannot find src dir: {SRC_DIR}"
sys.path.append(SRC_DIR)
import custom_metrics
import data_loader as dl
import featurization as ft
import canonical_models as models
import model_utils
import plot_utils
import utils
MODEL_DIR = os.path.join(SRC_DIR, "models")
sys.path.append(MODEL_DIR)
import transformer_custom as trans
import conv

DEVICE = utils.get_device(3)
TRAINED_MODEL_DIR = "/home/wukevin/projects/tcr/tcr_models"
SCRIPTS_DIR = os.path.join(
    os.path.dirname(SRC_DIR),
    "scripts",
)
assert os.path.isdir(SCRIPTS_DIR)

PLOT_DIR = os.path.join(os.path.dirname(SRC_DIR), "plots/pird_antigen_cv")
assert os.path.isdir(PLOT_DIR)

In [2]:
# Load in the PIRD dataset
pird_data = dl.load_pird(with_antigen_only=True)
pird_data = pird_data.loc[~pd.isnull(pird_data['CDR3.beta.aa'])]
pird_data

INFO:root:PIRD data 0.1655 data labelled with antigen sequence
INFO:root:PIRD: Removing 95 entires with non amino acid residues
INFO:root:Entries with antigen sequence: 8429/51044
INFO:root:Unique antigen sequences: 73
INFO:root:PIRD data TRA/TRB instances: Counter({'TRB': 46428, 'TRA': 4011, 'TRA-TRB': 605})
INFO:root:PIRD entries with TRB sequence: 4607
INFO:root:PIRD entries with TRB sequence: 47040
INFO:root:PIRD entries with TRA and TRB:  605


Unnamed: 0,ICDname,Disease.name,Category,Antigen,Antigen.sequence,HLA,Locus,CDR3.alpha.aa,CDR3.beta.aa,CDR3.alpha.nt,...,Cell.subtype,Prepare.method,Evaluate.method,Case.num,Control.type,Control.num,Filteration,Journal,Pubmed.id,Grade
0,A15,Tuberculosis,Pathogen,CFP10,TAAQAAVVRFQEAAN,DRB1*15:03,TRA-TRB,CIEHTNSGGSNYKLTF,CASSLEETQYF,,...,CD4,Multiple PCR,Antigen-specific ex vivo proliferation,22.0,,,,Nature,28636589,5
1,A15,Tuberculosis,Pathogen,CFP10,TAAQAAVVRFQEAAN,DRB1*15:03,TRA-TRB,CIVHTNSGGSNYKLTF,CASSPEETQYF,,...,CD4,Multiple PCR,Antigen-specific ex vivo proliferation,22.0,,,,Nature,28636589,5
2,A15,Tuberculosis,Pathogen,CFP10,TAAQAAVVRFQEAAN,DRB1*15:03,TRA-TRB,CIVKTNSGGSNYKLTF,CASSFEETQYF,,...,CD4,Multiple PCR,Antigen-specific ex vivo proliferation,22.0,,,,Nature,28636589,5
1845,A15,Tuberculosis,Pathogen,Rv1195,ADTLQSIGATTVASN,DRB1*15:03,TRA-TRB,CAGAGGGGFKTIF,CASSVALASGANVLTF,,...,CD4,Multiple PCR,Antigen-specific ex vivo proliferation,22.0,,,,Nature,28636589,5
1846,A15,Tuberculosis,Pathogen,Rv1195,ADTLQSIGATTVASN,DRB1*15:03,TRA-TRB,CAGPTGGSYIPTF,CASSVALATGEQYF,,...,CD4,Multiple PCR,Antigen-specific ex vivo proliferation,22.0,,,,Nature,28636589,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
51124,Z88,Allergy,Allergy,Beryllium sulfate,FWIDLFETIG,DPB1*02:01,TRB,,CASSLSQGGDTQYF,,...,CD4,,Selection of antigen- specific T cells using p...,,,,,J Immunol,24719461,4
51125,Z88,Allergy,Allergy,Beryllium sulfate,FWIDLFETIG,DPB1*02:01,TRB,,CASSLSQGGEKLF,,...,CD4,,Selection of antigen- specific T cells using p...,,,,,J Immunol,24719461,4
51126,Z88,Allergy,Allergy,Beryllium sulfate,FWIDLFETIG,DPB1*02:01,TRB,,CASSLSQGGRPMF,,...,CD4,,Selection of antigen- specific T cells using p...,,,,,J Immunol,24719461,4
51127,Z88,Allergy,Allergy,Beryllium sulfate,FWIDLFETIG,DPB1*02:01,TRB,,CASSMGQGGETQYF,,...,CD4,,Selection of antigen- specific T cells using p...,,,,,J Immunol,24719461,4


In [1]:
REDUCE_LR_ON_PLATEAU_PARAMS = {
    "mode": "min",
    "factor": 0.1,
    "patience": 10,
    "min_lr": 1e-6,
}

def convnet_on_antigen(antigen:str) -> Tuple[str, float]:
    """
    Train a convnet on the antigen and return Tuple[antigen, test_auprc]
    If invalid, return nan
    """
    pird_pos_table = pird_data.loc[pird_data['Antigen.sequence'] == antigen]
    #
    pird_pos_trbs = list(pird_pos_table['CDR3.beta.aa'])
    if len(pird_pos_trbs) < 20:
        return antigen, np.nan
    neg_trbs = dl.sample_unlabelled_tcrdb_trb(len(pird_pos_table) * 5)
    
    full_dset = dl.TCRSupervisedIdxDataset(
        pird_pos_trbs + neg_trbs,
        np.array([1] * len(pird_pos_trbs) + [0] * len(neg_trbs))
    )
    train_dset = dl.DatasetSplit(full_dset, split='train', valid=0.0, test=0.3)
    test_dset = dl.DatasetSplit(full_dset, split='test', valid=0.0, test=0.3)
    
    torch.manual_seed(42)
    with tempfile.TemporaryDirectory() as tmpdir:
        logging.info(f"Running in temporary dir: {tmpdir}")
        net = skorch.NeuralNet(
            module=conv.OnePartConvNet,
            module__use_embedding=True,
            module__n_output=2,
            module__max_input_len=full_dset.max_len,
            criterion=torch.nn.CrossEntropyLoss,
            optimizer=torch.optim.Adam,
            max_epochs=250,
            batch_size=512,
            lr=1e-3,
            callbacks=[
                skorch.callbacks.LRScheduler(
                    policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
                    mode='min',
                    factor=0.1,
                    patience=10,
                    min_lr=1e-6,
                ),
                skorch.callbacks.GradientNormClipping(gradient_clip_value=5),
                skorch.callbacks.EpochScoring(
                    "average_precision",
                    lower_is_better=False,
                    on_train=False,
                    name="valid_auprc",
                ),
                skorch.callbacks.EarlyStopping(
                    patience=25,
                    monitor="valid_auprc",
                    lower_is_better=False,
                ),
                skorch.callbacks.Checkpoint(  # Seems to cause errors if placed before scoring
                    dirname=tmpdir,
                    fn_prefix="net_",
                    monitor="valid_auprc_best",
                ),
            ],
            device=3,
        )
        net.classes_ = np.unique(test_dset.all_labels())
        net.fit(train_dset)
        # Restore best weights and evaluate on test set
        cp = skorch.callbacks.Checkpoint(dirname=tmpdir, fn_prefix="net_")
#         net.load_params(checkpoint=cp)
        test_truth = test_dset.all_labels()
        test_preds = net.predict_proba(test_dset)[:, 1]
        test_auprc = metrics.average_precision_score(test_truth, test_preds)
    return antigen, test_auprc
    
convnet_on_antigen("FPRPWLHGL")

NameError: name 'Tuple' is not defined

In [4]:
antigen_auprc_pairs = [convnet_on_antigen(a) for a in utils.dedup(pird_data['Antigen.sequence'])]
len(antigen_auprc_pairs)

INFO:root:Got maximum length of 22
INFO:root:Split train with 970 examples
INFO:root:Split test with 416 examples
INFO:root:Running in temporary dir: /tmp/tmphsl1uj0n


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.8163[0m         [32m0.1578[0m        [35m0.6622[0m     +  0.0255
      2        [36m0.5344[0m         [32m0.2012[0m        [35m0.6552[0m     +  0.0243
      3        [36m0.4366[0m         [32m0.3027[0m        [35m0.6484[0m     +  0.0392
      4        [36m0.3731[0m         [32m0.3800[0m        [35m0.6336[0m     +  0.0337
      5        [36m0.3353[0m         [32m0.4502[0m        [35m0.6101[0m     +  0.0375
      6        [36m0.3127[0m         [32m0.4901[0m        [35m0.5762[0m     +  0.0313
      7        [36m0.2933[0m         [32m0.5338[0m        [35m0.5341[0m     +  0.0237
      8        [36m0.2754[0m         [32m0.5463[0m        [35m0.4881[0m     +  0.0239
      9        [36m0.2593[0m         [32m0.5553[0m        [35m0.4436[0m     +  0.0241
     10        [36m0.2444[0m         [3

INFO:root:Got maximum length of 20
INFO:root:Split train with 239 examples
INFO:root:Split test with 103 examples
INFO:root:Running in temporary dir: /tmp/tmp2xkvv5bx


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1928[0m         [32m0.1823[0m        [35m0.7052[0m     +  0.0096
      2        [36m0.8366[0m         0.1765        [35m0.6801[0m        0.0095
      3        [36m0.5950[0m         [32m0.1901[0m        [35m0.6565[0m     +  0.0092
      4        [36m0.4429[0m         [32m0.2986[0m        [35m0.6340[0m     +  0.0092
      5        [36m0.3504[0m         [32m0.4274[0m        [35m0.6111[0m     +  0.0091
      6        [36m0.2899[0m         [32m0.4686[0m        [35m0.5882[0m     +  0.0092
      7        [36m0.2477[0m         [32m0.5115[0m        [35m0.5653[0m     +  0.0091
      8        [36m0.2154[0m         [32m0.5510[0m        [35m0.5419[0m     +  0.0092
      9        [36m0.1892[0m         [32m0.5595[0m        [35m0.5181[0m     +  0.0091
     10        [36m0.1666[0m         [32m0.6419

INFO:root:Got maximum length of 20
INFO:root:Split train with 239 examples
INFO:root:Split test with 103 examples
INFO:root:Running in temporary dir: /tmp/tmp8gzwymp9


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.2322[0m         [32m0.1191[0m        [35m0.7080[0m     +  0.0094
      2        [36m0.8683[0m         [32m0.1378[0m        [35m0.6808[0m     +  0.0094
      3        [36m0.6143[0m         [32m0.1851[0m        [35m0.6557[0m     +  0.0092
      4        [36m0.4461[0m         [32m0.2906[0m        [35m0.6326[0m     +  0.0091
      5        [36m0.3363[0m         [32m0.5974[0m        [35m0.6094[0m     +  0.0091
      6        [36m0.2624[0m         [32m0.6739[0m        [35m0.5854[0m     +  0.0091
      7        [36m0.2103[0m         [32m0.7179[0m        [35m0.5612[0m     +  0.0093
      8        [36m0.1727[0m         [32m0.7755[0m        [35m0.5355[0m     +  0.0091
      9        [36m0.1448[0m         [32m0.8690[0m        [35m0.5088[0m     +  0.0090
     10        [36m0.1226[0m         [3

INFO:root:Got maximum length of 19
INFO:root:Split train with 143 examples
INFO:root:Split test with 61 examples
INFO:root:Running in temporary dir: /tmp/tmp4r23hdwk


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.0881[0m         [32m0.4194[0m        [35m0.6324[0m     +  0.0082
      2        [36m0.7382[0m         0.4096        [35m0.6312[0m        0.0182
      3        [36m0.4962[0m         0.3571        [35m0.6268[0m        0.0127
      4        [36m0.3415[0m         0.3293        [35m0.6204[0m        0.0136
      5        [36m0.2485[0m         0.2893        [35m0.6130[0m        0.0115
      6        [36m0.1923[0m         0.2717        [35m0.6050[0m        0.0100
      7        [36m0.1552[0m         0.2617        [35m0.5968[0m        0.0086
      8        [36m0.1287[0m         0.2688        [35m0.5876[0m        0.0081
      9        [36m0.1090[0m         0.2455        [35m0.5781[0m        0.0078
     10        [36m0.0939[0m         0.2288        [35m0.5680[0m        0.0078
     11        [36m0.0817[0m 

INFO:root:Got maximum length of 20
INFO:root:Split train with 433 examples
INFO:root:Split test with 185 examples
INFO:root:Running in temporary dir: /tmp/tmpq94l0p6t


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.0848[0m         [32m0.0729[0m        [35m0.7235[0m     +  0.0127
      2        [36m0.7790[0m         [32m0.0775[0m        [35m0.6967[0m     +  0.0124
      3        [36m0.5883[0m         [32m0.0839[0m        [35m0.6721[0m     +  0.0122
      4        [36m0.4774[0m         [32m0.0978[0m        [35m0.6497[0m     +  0.0123
      5        [36m0.4117[0m         [32m0.1152[0m        [35m0.6282[0m     +  0.0122
      6        [36m0.3697[0m         [32m0.1559[0m        [35m0.6069[0m     +  0.0120
      7        [36m0.3389[0m         [32m0.2053[0m        [35m0.5866[0m     +  0.0123
      8        [36m0.3137[0m         [32m0.2916[0m        [35m0.5674[0m     +  0.0121
      9        [36m0.2914[0m         [32m0.3448[0m        [35m0.5498[0m     +  0.0120
     10        [36m0.2701[0m         [3

INFO:root:Got maximum length of 20
INFO:root:Split train with 412 examples
INFO:root:Split test with 176 examples
INFO:root:Running in temporary dir: /tmp/tmpixce070v


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1769[0m         [32m0.3222[0m        [35m0.7145[0m     +  0.0122
      2        [36m0.8595[0m         0.2891        [35m0.6917[0m        0.0122
      3        [36m0.6532[0m         0.2435        [35m0.6699[0m        0.0117
      4        [36m0.5338[0m         0.1776        [35m0.6488[0m        0.0116
      5        [36m0.4714[0m         0.1480        [35m0.6280[0m        0.0116
      6        [36m0.4366[0m         0.1418        [35m0.6083[0m        0.0116
      7        [36m0.4138[0m         0.1405        [35m0.5898[0m        0.0117
      8        [36m0.3950[0m         0.1407        [35m0.5731[0m        0.0116
      9        [36m0.3760[0m         0.1409        [35m0.5583[0m        0.0115
     10        [36m0.3559[0m         0.1429        [35m0.5455[0m        0.0115
     11        [36m0.3344[0m 

INFO:root:Got maximum length of 21
INFO:root:Split train with 189 examples
INFO:root:Split test with 81 examples
INFO:root:Running in temporary dir: /tmp/tmpq7v3gmbm


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.5393[0m         [32m0.1380[0m        [35m0.7728[0m     +  0.0090
      2        [36m1.0800[0m         [32m0.1662[0m        [35m0.7481[0m     +  0.0089
      3        [36m0.7657[0m         [32m0.3000[0m        [35m0.7231[0m     +  0.0085
      4        [36m0.5688[0m         [32m0.3081[0m        [35m0.6989[0m     +  0.0085
      5        [36m0.4531[0m         0.2119        [35m0.6753[0m        0.0085
      6        [36m0.3800[0m         0.2256        [35m0.6511[0m        0.0084
      7        [36m0.3286[0m         0.2518        [35m0.6263[0m        0.0085
      8        [36m0.2887[0m         0.3045        [35m0.6025[0m        0.0085
      9        [36m0.2546[0m         [32m0.3396[0m        [35m0.5785[0m     +  0.0083
     10        [36m0.2244[0m         [32m0.3925[0m        [35m0.5538[0m 

INFO:root:Got maximum length of 20
INFO:root:Split train with 151 examples
INFO:root:Split test with 65 examples
INFO:root:Running in temporary dir: /tmp/tmp0h09yz0r


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1857[0m         [32m0.1950[0m        [35m0.7191[0m     +  0.0083
      2        [36m0.7800[0m         [32m0.2193[0m        [35m0.6997[0m     +  0.0082
      3        [36m0.5314[0m         [32m0.2296[0m        [35m0.6838[0m     +  0.0079
      4        [36m0.3970[0m         0.2258        [35m0.6710[0m        0.0078
      5        [36m0.3245[0m         0.2110        [35m0.6592[0m        0.0078
      6        [36m0.2788[0m         0.2066        [35m0.6485[0m        0.0078
      7        [36m0.2443[0m         0.2040        [35m0.6385[0m        0.0079
      8        [36m0.2142[0m         0.2015        [35m0.6285[0m        0.0078
      9        [36m0.1868[0m         0.2131        [35m0.6181[0m        0.0078
     10        [36m0.1623[0m         [32m0.2416[0m        [35m0.6079[0m     +  0.0077
    

INFO:root:Got maximum length of 20
INFO:root:Split train with 223 examples
INFO:root:Split test with 95 examples
INFO:root:Running in temporary dir: /tmp/tmpi0ho5cgx


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1151[0m         [32m0.1321[0m        [35m0.7114[0m     +  0.0188
      2        [36m0.7836[0m         [32m0.1677[0m        [35m0.6898[0m     +  0.0153
      3        [36m0.5629[0m         [32m0.2142[0m        [35m0.6694[0m     +  0.0110
      4        [36m0.4280[0m         [32m0.3017[0m        [35m0.6487[0m     +  0.0092
      5        [36m0.3473[0m         [32m0.5973[0m        [35m0.6268[0m     +  0.0090
      6        [36m0.2958[0m         [32m0.7205[0m        [35m0.6048[0m     +  0.0088
      7        [36m0.2591[0m         [32m0.7311[0m        [35m0.5834[0m     +  0.0087
      8        [36m0.2297[0m         [32m0.7478[0m        [35m0.5620[0m     +  0.0091
      9        [36m0.2043[0m         0.7467        [35m0.5403[0m        0.0089
     10        [36m0.1824[0m         0.7475      

INFO:root:Got maximum length of 20
INFO:root:Split train with 151 examples
INFO:root:Split test with 65 examples
INFO:root:Running in temporary dir: /tmp/tmprzge9sea


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.2244[0m         [32m0.3165[0m        [35m0.7214[0m     +  0.0168
      2        [36m0.8117[0m         [32m0.3321[0m        [35m0.7017[0m     +  0.0105
      3        [36m0.5492[0m         [32m0.3402[0m        [35m0.6844[0m     +  0.0086
      4        [36m0.3928[0m         0.2660        [35m0.6702[0m        0.0080
      5        [36m0.2970[0m         0.2536        [35m0.6557[0m        0.0080
      6        [36m0.2334[0m         0.2679        [35m0.6407[0m        0.0079
      7        [36m0.1883[0m         0.2764        [35m0.6247[0m        0.0079
      8        [36m0.1543[0m         0.2823        [35m0.6072[0m        0.0078
      9        [36m0.1281[0m         0.3150        [35m0.5885[0m        0.0180
     10        [36m0.1074[0m         0.3351        [35m0.5692[0m        0.0132
     11      

INFO:root:Got maximum length of 19
INFO:root:Split train with 185 examples
INFO:root:Split test with 79 examples
INFO:root:Running in temporary dir: /tmp/tmpgdq_esu8


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1105[0m         [32m0.5756[0m        [35m0.6250[0m     +  0.0090
      2        [36m0.7814[0m         [32m0.5780[0m        [35m0.6225[0m     +  0.0087
      3        [36m0.5525[0m         0.5694        [35m0.6180[0m        0.0085
      4        [36m0.4073[0m         [32m0.6306[0m        [35m0.6110[0m     +  0.0083
      5        [36m0.3196[0m         0.5978        [35m0.6003[0m        0.0083
      6        [36m0.2634[0m         0.5945        [35m0.5870[0m        0.0080
      7        [36m0.2238[0m         0.5957        [35m0.5719[0m        0.0196
      8        [36m0.1942[0m         0.6278        [35m0.5561[0m        0.0145
      9        [36m0.1701[0m         0.6238        [35m0.5400[0m        0.0138
     10        [36m0.1495[0m         0.6238        [35m0.5233[0m        0.0104
     11      

INFO:root:Got maximum length of 20
INFO:root:Split train with 109 examples
INFO:root:Split test with 47 examples
INFO:root:Running in temporary dir: /tmp/tmpiluhzq75


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.9712[0m         [32m0.2083[0m        [35m0.7172[0m     +  0.0159
      2        [36m0.6314[0m         [32m0.2159[0m        [35m0.6904[0m     +  0.0166
      3        [36m0.4341[0m         0.2019        [35m0.6666[0m        0.0107
      4        [36m0.3291[0m         [32m0.3333[0m        [35m0.6443[0m     +  0.0094
      5        [36m0.2646[0m         0.2667        [35m0.6239[0m        0.0078
      6        [36m0.2157[0m         0.2667        [35m0.6061[0m        0.0076
      7        [36m0.1754[0m         0.2500        [35m0.5894[0m        0.0074
      8        [36m0.1434[0m         0.2436        [35m0.5739[0m        0.0072
      9        [36m0.1194[0m         0.2436        [35m0.5593[0m        0.0076
     10        [36m0.1012[0m         0.2436        [35m0.5452[0m        0.0075
     11      

INFO:root:Got maximum length of 20
INFO:root:Split train with 286 examples
INFO:root:Split test with 122 examples
INFO:root:Running in temporary dir: /tmp/tmp68wze67f


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1309[0m         [32m0.2592[0m        [35m0.7180[0m     +  0.0106
      2        [36m0.8061[0m         [32m0.2785[0m        [35m0.6969[0m     +  0.0103
      3        [36m0.5936[0m         [32m0.2863[0m        [35m0.6785[0m     +  0.0099
      4        [36m0.4682[0m         [32m0.2898[0m        [35m0.6618[0m     +  0.0206
      5        [36m0.3943[0m         [32m0.2962[0m        [35m0.6460[0m     +  0.0174
      6        [36m0.3473[0m         0.2915        [35m0.6309[0m        0.0153
      7        [36m0.3125[0m         [32m0.3222[0m        [35m0.6171[0m     +  0.0201
      8        [36m0.2825[0m         [32m0.3286[0m        [35m0.6045[0m     +  0.0117
      9        [36m0.2549[0m         [32m0.3430[0m        [35m0.5923[0m     +  0.0101
     10        [36m0.2286[0m         [32m0.3484

INFO:root:Got maximum length of 21
INFO:root:Split train with 176 examples
INFO:root:Split test with 76 examples
INFO:root:Running in temporary dir: /tmp/tmpmwxwhb98


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.4751[0m         [32m0.1788[0m        [35m0.7645[0m     +  0.0088
      2        [36m1.0023[0m         0.1761        [35m0.7405[0m        0.0089
      3        [36m0.6825[0m         [32m0.1908[0m        [35m0.7158[0m     +  0.0134
      4        [36m0.4845[0m         [32m0.1980[0m        [35m0.6930[0m     +  0.0116
      5        [36m0.3739[0m         [32m0.2055[0m        [35m0.6707[0m     +  0.0126
      6        [36m0.3117[0m         [32m0.2334[0m        [35m0.6480[0m     +  0.0090
      7        [36m0.2735[0m         [32m0.2484[0m        [35m0.6263[0m     +  0.0085
      8        [36m0.2464[0m         0.2473        [35m0.6058[0m        0.0140
      9        [36m0.2236[0m         [32m0.2664[0m        [35m0.5873[0m     +  0.0155
     10        [36m0.2014[0m         [32m0.2705[0m      

INFO:root:Got maximum length of 20
INFO:root:Split train with 370 examples
INFO:root:Split test with 158 examples
INFO:root:Running in temporary dir: /tmp/tmpo22hmfam


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1832[0m         [32m0.2751[0m        [35m0.7083[0m     +  0.0115
      2        [36m0.8455[0m         0.2506        [35m0.6875[0m        0.0113
      3        [36m0.6254[0m         [32m0.3130[0m        [35m0.6674[0m     +  0.0111
      4        [36m0.4931[0m         [32m0.4188[0m        [35m0.6481[0m     +  0.0117
      5        [36m0.4112[0m         [32m0.5120[0m        [35m0.6276[0m     +  0.0261
      6        [36m0.3539[0m         [32m0.6273[0m        [35m0.6062[0m     +  0.0248
      7        [36m0.3106[0m         [32m0.6710[0m        [35m0.5849[0m     +  0.0133
      8        [36m0.2758[0m         [32m0.7312[0m        [35m0.5633[0m     +  0.0117
      9        [36m0.2467[0m         0.7151        [35m0.5421[0m        0.0146
     10        [36m0.2221[0m         0.7240        [35m0.

INFO:root:Got maximum length of 20
INFO:root:Split train with 101 examples
INFO:root:Split test with 43 examples
INFO:root:Running in temporary dir: /tmp/tmpyb453_7j


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1843[0m         [32m0.0909[0m        [35m0.7161[0m     +  0.0074
      2        [36m0.7911[0m         [32m0.1250[0m        [35m0.6940[0m     +  0.0073
      3        [36m0.5284[0m         0.1250        [35m0.6752[0m        0.0070
      4        [36m0.3678[0m         [32m0.2000[0m        [35m0.6568[0m     +  0.0068
      5        [36m0.2725[0m         0.2000        [35m0.6386[0m        0.0070
      6        [36m0.2135[0m         [32m0.2500[0m        [35m0.6194[0m     +  0.0070
      7        [36m0.1728[0m         [32m0.3333[0m        [35m0.5980[0m     +  0.0071
      8        [36m0.1415[0m         [32m1.0000[0m        [35m0.5744[0m     +  0.0071
      9        [36m0.1167[0m         1.0000        [35m0.5487[0m        0.0070
     10        [36m0.0966[0m         1.0000        [35m0.5212[0m 

INFO:root:Got maximum length of 18
INFO:root:Split train with 118 examples
INFO:root:Split test with 50 examples
INFO:root:Running in temporary dir: /tmp/tmpfs_00zl0


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.0833[0m         [32m0.1335[0m        [35m0.6977[0m     +  0.0078
      2        [36m0.7553[0m         0.1217        [35m0.6793[0m        0.0077
      3        [36m0.5384[0m         0.1156        [35m0.6639[0m        0.0074
      4        [36m0.4049[0m         0.1123        [35m0.6492[0m        0.0075
      5        [36m0.3261[0m         0.1123        [35m0.6355[0m        0.0074
      6        [36m0.2737[0m         0.1123        [35m0.6223[0m        0.0075
      7        [36m0.2327[0m         0.1120        [35m0.6085[0m        0.0074
      8        [36m0.1986[0m         0.1120        [35m0.5948[0m        0.0074
      9        [36m0.1696[0m         0.1185        [35m0.5815[0m        0.0072
     10        [36m0.1454[0m         0.1222        [35m0.5692[0m        0.0075
     11        [36m0.1245[0m 

INFO:root:Got maximum length of 34
INFO:root:Split train with 17283 examples
INFO:root:Split test with 7407 examples
INFO:root:Running in temporary dir: /tmp/tmpovo68w87


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.4988[0m         [32m0.1949[0m        [35m0.4571[0m     +  0.4868
      2        [36m0.4417[0m         [32m0.2066[0m        [35m0.4489[0m     +  0.4218
      3        [36m0.4296[0m         [32m0.2125[0m        [35m0.4485[0m     +  0.4821
      4        [36m0.4219[0m         [32m0.2153[0m        0.4503     +  0.5507
      5        [36m0.4150[0m         [32m0.2200[0m        0.4525     +  0.5023
      6        [36m0.4080[0m         [32m0.2235[0m        0.4558     +  0.5560
      7        [36m0.4008[0m         [32m0.2261[0m        0.4607     +  0.5316
      8        [36m0.3934[0m         0.2256        0.4659        0.5710
      9        [36m0.3859[0m         0.2256        0.4687        0.5066
     10        [36m0.3778[0m         0.2236        0.4740        0.5164
     11        [36m0.3692[0m         0

INFO:root:Got maximum length of 20
INFO:root:Split train with 88 examples
INFO:root:Split test with 38 examples
INFO:root:Running in temporary dir: /tmp/tmpa1tse2vt


Stopping since valid_auprc has not improved in the last 25 epochs.
  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.2277[0m         [32m0.1825[0m        [35m0.7196[0m     +  0.0071
      2        [36m0.7803[0m         [32m0.1833[0m        [35m0.6953[0m     +  0.0072
      3        [36m0.4982[0m         [32m0.2000[0m        [35m0.6748[0m     +  0.0068
      4        [36m0.3375[0m         [32m0.2917[0m        [35m0.6559[0m     +  0.0068
      5        [36m0.2487[0m         [32m0.3929[0m        [35m0.6374[0m     +  0.0068
      6        [36m0.1939[0m         0.3929        [35m0.6182[0m        0.0068
      7        [36m0.1545[0m         [32m0.5000[0m        [35m0.5986[0m     +  0.0068
      8        [36m0.1244[0m         0.5000        [35m0.5776[0m        0.0068
      9        [36m0.1007[0m         [32m0.5833[0m        [35m0.5551[0m     +  

INFO:root:Got maximum length of 23
INFO:root:Split train with 298 examples
INFO:root:Split test with 128 examples
INFO:root:Running in temporary dir: /tmp/tmpr9v1h_xp


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.8696[0m         [32m0.2356[0m        [35m0.6766[0m     +  0.0272
      2        [36m0.5887[0m         [32m0.2704[0m        [35m0.6681[0m     +  0.0158
      3        [36m0.4394[0m         [32m0.3113[0m        [35m0.6558[0m     +  0.0233
      4        [36m0.3658[0m         [32m0.3836[0m        [35m0.6407[0m     +  0.0133
      5        [36m0.3213[0m         [32m0.4407[0m        [35m0.6236[0m     +  0.0109
      6        [36m0.2886[0m         [32m0.4854[0m        [35m0.6043[0m     +  0.0112
      7        [36m0.2607[0m         [32m0.5598[0m        [35m0.5838[0m     +  0.0106
      8        [36m0.2354[0m         [32m0.5657[0m        [35m0.5630[0m     +  0.0105
      9        [36m0.2126[0m         [32m0.6302[0m        [35m0.5421[0m     +  0.0158
     10        [36m0.1922[0m         [3

INFO:root:Got maximum length of 23
INFO:root:Split train with 844 examples
INFO:root:Split test with 362 examples
INFO:root:Running in temporary dir: /tmp/tmp2325jdnv


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.8309[0m         [32m0.2008[0m        [35m0.6605[0m     +  0.0374
      2        [36m0.5285[0m         0.1921        [35m0.6371[0m        0.0526
      3        [36m0.4898[0m         0.1565        [35m0.6091[0m        0.0486
      4        [36m0.4693[0m         0.1487        [35m0.5801[0m        0.0446
      5        [36m0.4335[0m         0.1522        [35m0.5560[0m        0.0411
      6        [36m0.3967[0m         0.1590        [35m0.5384[0m        0.0543
      7        [36m0.3679[0m         0.1622        [35m0.5272[0m        0.0266
      8        [36m0.3489[0m         0.1652        [35m0.5165[0m        0.0231
      9        [36m0.3333[0m         0.1705        [35m0.5037[0m        0.0224
     10        [36m0.3162[0m         0.1722        [35m0.4880[0m        0.0224
     11        [36m0.2993[0m 

    103        [36m0.0069[0m         0.2869        0.5817        0.0313
    104        [36m0.0067[0m         0.2870        0.5828        0.0363
    105        [36m0.0065[0m         0.2875        0.5839        0.0218
    106        [36m0.0064[0m         0.2876        0.5850        0.0218
    107        [36m0.0063[0m         0.2875        0.5863        0.0230
    108        [36m0.0061[0m         0.2874        0.5875        0.0234
    109        [36m0.0060[0m         0.2873        0.5887        0.0232
    110        [36m0.0059[0m         0.2880        0.5898        0.0231
    111        [36m0.0057[0m         0.2879        0.5908        0.0433
    112        [36m0.0056[0m         0.2884        0.5918        0.0355
    113        [36m0.0055[0m         0.2893        0.5929        0.0418
    114        [36m0.0054[0m         0.2896        0.5939        0.0339
    115        [36m0.0053[0m         0.2891        0.5947        0.0227
    116        [36m0.0052[0m        

INFO:root:Got maximum length of 25
INFO:root:Split train with 601 examples
INFO:root:Split test with 257 examples
INFO:root:Running in temporary dir: /tmp/tmpt7q9qxl7


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m0.6313[0m         [32m0.1950[0m        [35m0.6799[0m     +  0.0150
      2        [36m0.5013[0m         [32m0.2449[0m        [35m0.6741[0m     +  0.0149
      3        [36m0.4280[0m         [32m0.3214[0m        [35m0.6676[0m     +  0.0148
      4        [36m0.3722[0m         [32m0.3893[0m        [35m0.6611[0m     +  0.0146
      5        [36m0.3268[0m         [32m0.4607[0m        [35m0.6545[0m     +  0.0143
      6        [36m0.2914[0m         [32m0.5051[0m        [35m0.6473[0m     +  0.0145
      7        [36m0.2640[0m         [32m0.5224[0m        [35m0.6391[0m     +  0.0145
      8        [36m0.2422[0m         [32m0.5270[0m        [35m0.6298[0m     +  0.0147
      9        [36m0.2238[0m         0.5259        [35m0.6182[0m        0.0146
     10        [36m0.2074[0m         0.5169      

INFO:root:Got maximum length of 19
INFO:root:Split train with 155 examples
INFO:root:Split test with 67 examples
INFO:root:Running in temporary dir: /tmp/tmpgd449lam


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.2508[0m         [32m0.4365[0m        [35m0.6368[0m     +  0.0084
      2        [36m0.8962[0m         [32m0.4648[0m        [35m0.6349[0m     +  0.0083
      3        [36m0.6329[0m         0.4316        [35m0.6318[0m        0.0079
      4        [36m0.4563[0m         0.4399        [35m0.6264[0m        0.0078
      5        [36m0.3420[0m         0.4511        [35m0.6184[0m        0.0079
      6        [36m0.2656[0m         [32m0.4831[0m        [35m0.6089[0m     +  0.0078
      7        [36m0.2126[0m         [32m0.5496[0m        [35m0.5970[0m     +  0.0078
      8        [36m0.1733[0m         [32m0.5514[0m        [35m0.5827[0m     +  0.0079
      9        [36m0.1427[0m         [32m0.6035[0m        [35m0.5660[0m     +  0.0079
     10        [36m0.1184[0m         [32m0.6657[0m        [35m0.

INFO:root:Got maximum length of 20
INFO:root:Split train with 256 examples
INFO:root:Split test with 110 examples
INFO:root:Running in temporary dir: /tmp/tmpmox9dwi5


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1475[0m         [32m0.2203[0m        [35m0.7133[0m     +  0.0163
      2        [36m0.8135[0m         [32m0.2295[0m        [35m0.6925[0m     +  0.0210
      3        [36m0.5911[0m         [32m0.3584[0m        [35m0.6718[0m     +  0.0127
      4        [36m0.4558[0m         [32m0.3960[0m        [35m0.6522[0m     +  0.0097
      5        [36m0.3743[0m         [32m0.4749[0m        [35m0.6346[0m     +  0.0096
      6        [36m0.3234[0m         [32m0.5489[0m        [35m0.6172[0m     +  0.0094
      7        [36m0.2878[0m         [32m0.5551[0m        [35m0.6000[0m     +  0.0094
      8        [36m0.2598[0m         [32m0.5787[0m        [35m0.5832[0m     +  0.0094
      9        [36m0.2355[0m         0.5437        [35m0.5670[0m        0.0095
     10        [36m0.2140[0m         0.5610      

INFO:root:Got maximum length of 20
INFO:root:Split train with 101 examples
INFO:root:Split test with 43 examples
INFO:root:Running in temporary dir: /tmp/tmpitu1qo4n


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.1607[0m         [32m0.1111[0m        [35m0.7152[0m     +  0.0076
      2        [36m0.7629[0m         [32m0.2000[0m        [35m0.6928[0m     +  0.0077
      3        [36m0.4940[0m         [32m0.2500[0m        [35m0.6723[0m     +  0.0072
      4        [36m0.3268[0m         [32m0.5000[0m        [35m0.6515[0m     +  0.0071
      5        [36m0.2277[0m         [32m1.0000[0m        [35m0.6314[0m     +  0.0071
      6        [36m0.1682[0m         1.0000        [35m0.6102[0m        0.0072
      7        [36m0.1298[0m         1.0000        [35m0.5866[0m        0.0072
      8        [36m0.1025[0m         1.0000        [35m0.5615[0m        0.0115
      9        [36m0.0820[0m         1.0000        [35m0.5352[0m        0.0125
     10        [36m0.0663[0m         1.0000        [35m0.5078[0m        0.

INFO:root:Got maximum length of 20
INFO:root:Split train with 945 examples
INFO:root:Split test with 405 examples
INFO:root:Running in temporary dir: /tmp/tmp3oqrre77


  epoch    train_loss    valid_auprc    valid_loss    cp     dur
-------  ------------  -------------  ------------  ----  ------
      1        [36m1.0622[0m         [32m0.1648[0m        [35m0.6890[0m     +  0.0355
      2        [36m0.5717[0m         [32m0.5677[0m        [35m0.6375[0m     +  0.0373
      3        [36m0.3811[0m         [32m0.6276[0m        [35m0.5844[0m     +  0.0248
      4        [36m0.3111[0m         [32m0.6655[0m        [35m0.5287[0m     +  0.0249
      5        [36m0.2792[0m         [32m0.6870[0m        [35m0.4744[0m     +  0.0258
      6        [36m0.2589[0m         [32m0.7014[0m        [35m0.4260[0m     +  0.0260
      7        [36m0.2417[0m         [32m0.7059[0m        [35m0.3854[0m     +  0.0258
      8        [36m0.2261[0m         [32m0.7115[0m        [35m0.3529[0m     +  0.0259
      9        [36m0.2116[0m         [32m0.7151[0m        [35m0.3279[0m     +  0.0260
     10        [36m0.1983[0m         [3

     99        [36m0.0054[0m         0.7642        0.2580        0.0237
    100        [36m0.0052[0m         0.7640        0.2584        0.0236
Stopping since valid_auprc has not improved in the last 25 epochs.


73

In [5]:
antigen_auprc_pairs = [pair for pair in antigen_auprc_pairs if not np.isnan(pair[1])]
len(antigen_auprc_pairs)

26

In [9]:
df = pd.DataFrame(
    [pair[1] for pair in antigen_auprc_pairs],
    index=[pair[0] for pair in antigen_auprc_pairs],
    columns=['ConvNet']
)
df

Unnamed: 0,ConvNet
LLWNGPMAV,0.714525
RPRGEVRFL,0.804326
ATDALMTGY,0.968797
HSKKKCDEL,0.777029
KAFSPEVIPMF,0.699305
KRWIILGLNK,0.434463
TPQDLNTML,0.667731
EIYKRWII,0.34129
HPKVSSEVHI,0.613221
IIKDYGKQM,0.716234


In [10]:
df.to_csv("antigen_cv_convnet_baseline.csv")