# TODO: move these notebook checks into proper unit tests

In [1]:
import sys
sys.path.append('../..')

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

from epytope.Core import Peptide, Allele, TCREpitope, ImmuneReceptorChain, ImmuneReceptor

from epytope.IO import IRDatasetAdapterFactory
from epytope.TCRSpecificityPrediction import TCRSpecificityPredictorFactory

In [3]:
# Shared HLA allele for both test epitopes.
allele = Allele("HLA-A*02:01")
peptide = Peptide("SYFPEITHI")

# epitope_1 is built from a Peptide object, epitope_2 from a plain string
# (presumably to exercise both accepted input forms of TCREpitope).
epitope_1 = TCREpitope(allele=allele, peptide=peptide)
epitope_2 = TCREpitope(allele=allele, peptide="EAAGIGILTV")

In [4]:
# Number of receptors kept from the dataset; the predictor tests below run on
# this subset only (presumably to keep every external method fast).
N_RECEPTORS = 20

# Load the VDJdb export through the dataset adapter, then truncate the repertoire.
path_data = '../../scrubs/vdjdb.tsv'
tcr_repertoire = IRDatasetAdapterFactory("vdjdb")
tcr_repertoire.from_path(path_data)
tcr_repertoire.receptors = tcr_repertoire.receptors[:N_RECEPTORS]

In [5]:
# Print each registered predictor followed by its comma-joined versions
# (a method with no version entries prints as just its name).
available_methods = TCRSpecificityPredictorFactory.available_methods()
for method_name, method_versions in available_methods.items():
    print(method_name, ",".join(method_versions))

imrex 
titan 1.0.0
ergo-ii 
pmtnet 
epitcr 
atm-tcr 


In [10]:
# Per-method keyword arguments, forwarded verbatim as ``**req_model`` to
# ``predictor.predict`` in the test cells. Keys seen here:
#   repository -> path to a local checkout of the method (presumably its upstream code)
#   conda      -> name of the conda environment to run the method in (presumably)
#   cuda       -> GPU flag (presumably enables CUDA for the method)
reqs = dict([
    ("ergo-ii", dict(repository="../../external/ERGO-II")),
    ("pmtnet", dict(repository="../../external/pMTnet", conda="epytope_tf20")),
    ("epitcr", dict(repository="../../external/epiTCR", conda="epytope_python3.8")),
    ("atm-tcr", dict(repository="../../external/ATM-TCR", conda="epytope_torch10", cuda=True)),
    ("titan", dict(conda="epytope_torch10")),
    ("imrex", dict(conda="tcr_predictors")),
])

# Method-specific option values.
# NOTE(review): not referenced by the test cells visible in this notebook.
choices = dict([
    ("ergo-ii", dict(dataset=["vdjdb", "mcpas"])),
])

## Test Pairwise

In [40]:
# Pairwise mode: every receptor is scored against every epitope, so the result
# has one row per receptor and one column group per epitope.
epitopes_pairwise = [epitope_1, epitope_2]

for name, req_model in reqs.items():
    print(name)

    predictor = TCRSpecificityPredictorFactory(name)
    results = predictor.predict(tcr_repertoire, epitopes_pairwise, pairwise=True, **req_model)

    assert len(results) == len(tcr_repertoire.receptors), "Results have wrong length"
    for epitope in epitopes_pairwise:
        assert epitope in results, "Epitope not in result"
        epitope_frame = results[epitope]
        column_names = [el.lower() for el in epitope_frame.columns.tolist()]
        assert name in column_names, "Method not in results"
        # Check the column that actually matched the method name (by position,
        # so a Series is returned) rather than assuming it is the first column.
        method_scores = epitope_frame.iloc[:, column_names.index(name)]
        assert method_scores.isna().sum() < len(results), "Method always yield NaN"

print("### All pairwise Tests succeded")

ergo-ii
pmtnet
epitcr
atm-tcr
titan
imrex
### All pairwise Tests succeded


## Test Non-Pairwise

In [39]:
# Non-pairwise mode: row i of the result is expected to hold receptor i scored
# against epitopes_list[i], so the list is sized from the repertoire (alternating
# the two epitopes) instead of a hard-coded "* 10" that must match the truncation.
epitopes_list = [(epitope_1, epitope_2)[i % 2] for i in range(len(tcr_repertoire.receptors))]

for name, req_model in reqs.items():
    print(name)
    predictor = TCRSpecificityPredictorFactory(name)
    results = predictor.predict(tcr_repertoire, epitopes_list, pairwise=False, **req_model)

    assert len(results) == len(tcr_repertoire.receptors), "Results have wrong length"
    column_names = [el.lower() for el in results["Method"].columns]
    assert name in column_names, "Method not in results"
    # Check the column that actually matched the method name, not column 0.
    method_scores = results["Method"].iloc[:, column_names.index(name)]
    assert method_scores.isna().sum() < len(results), "Method always yield NaN"
    for i, epitope in enumerate(epitopes_list):
        assert results.at[i, ("Epitope", "Peptide")] == epitope.peptide, f"Wrong epitope at position {i}"
        assert results.at[i, ("Epitope", "MHC")] == epitope.allele, f"Wrong MHC at position {i}"

print("### All non-pairwise Tests succeded")

ergo-ii
pmtnet
epitcr
atm-tcr
titan
imrex
### All non-pairwise Tests succeded
