This tutorial explains how to run EpiGen inference, annotate phenotype-associated (PA) TCRs and epitopes, and ensemble the results. The overall process is as follows:

**inference --> annotation --> ensemble**

To start, prepare an input file like `data/sample_tcrs.csv` that contains CDR3b sequences in the 'tcr' column. Then download the epitope database from https://zenodo.org/records/14624873/files/tumor_associated_epitopes.csv?download=1. Let's assume it is placed under the `data/` directory.

In [None]:
cd EpiGen
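# Download into data/ while staying in the repository root,
# because the paths used later (e.g. data/sample_tcrs.csv) are relative to it
mkdir -p data
wget -P data https://zenodo.org/records/14624873/files/tumor_associated_epitopes.csv
wget -P data https://zenodo.org/records/14861398/files/obs_annotated_cancer_wu_ens_th0.5.csv
wget -P data https://zenodo.org/records/14861864/files/sample_tcrs.csv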

Next, choose the high-level parameters such as `top_k`, `method`, `ens_th`, and `tokenizer_path`. The first three are described below.

- `top_k` is the number of epitopes generated for each TCR. Although TCR-pMHC interaction is highly specific, a single TCR is estimated to recognize multiple epitopes (ref: A. K. Sewell, Why must T cells be cross-reactive?, Nature Reviews Immunology, 2012). This hyperparameter should be chosen based on the characteristics of your dataset. In our manuscript, we set top_k=1 for the cancer dataset analysis and top_k=8 for the COVID-19 dataset analysis. The difference comes from the number of PA T cells detected: COVID-19 involves a single species (SARS-CoV-2) with only ~1,500 CD8+ T cell epitopes, whereas the database of tumor-associated epitopes is much larger. For more details, please refer to Supplementary Note 4 of the manuscript.

- `method` is the matching method used to query the generated epitopes against the epitope database. The current code supports two methods: 'substring' and 'levenshtein'. 'levenshtein' can be used together with a threshold, which relaxes strict matching. 'substring' is much faster and is recommended for the first run because, in our experience, the two methods give similar results.

- `ens_th` is the threshold used when ensembling multiple annotation files. The default mode runs inference using 11 independent models. An `ens_th` of 0.5 means a TCR is considered PA if it is predicted as PA by at least 6 models (more than 0.5 of the total models). You may want to tune this value to get more robust results.
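
The vote-counting rule described for `ens_th` can be illustrated with a small sketch. This is not the `EpitopeEnsembler` implementation: `min_votes` and `is_pa` are hypothetical helpers, and the strict "greater than" comparison is an assumption taken from the 6-of-11 example above.

```python
import math

def min_votes(n_models: int, ens_th: float) -> int:
    """Smallest vote count whose fraction strictly exceeds ens_th (assumed rule)."""
    return math.floor(n_models * ens_th) + 1

def is_pa(votes: list[int], ens_th: float) -> bool:
    """votes holds one 0/1 call per model; PA if the positive fraction exceeds ens_th."""
    return sum(votes) / len(votes) > ens_th

# With 11 models and ens_th=0.5, six positive calls are needed
print(min_votes(11, 0.5))               # 6
print(is_pa([1] * 6 + [0] * 5, 0.5))    # True
print(is_pa([1] * 5 + [0] * 6, 0.5))    # False
```

Raising `ens_th` demands agreement from more models, which trades sensitivity for robustness.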

Considering the above, let's define a simple configuration as a Python dictionary:

In [None]:
cfg = {
    "exp_dir": "example_run",
    "input_file": "data/sample_tcrs.csv",
    "epitope_db": "data/tumor_associated_epitopes.csv",
    "top_k": 4,
    "method": 'substring',
    "ens_th": 0.5,
}

Also, import the required Python modules:

In [None]:
import os
import pandas as pd
import numpy as np
from pathlib import Path
from epigen import EpiGenPredictor, EpitopeAnnotator, EpitopeEnsembler, visualize_match_overlaps_parallel

Now, read in the data and run inference. 

In [None]:
# Read the input data
tcrs = pd.read_csv(cfg['input_file'])
# CDR3b sequences are expected in the 'tcr' column of the input file
tcrs = tcrs["tcr"].tolist()

# Predict from TCR sequences
predictor = EpiGenPredictor()
results = predictor.predict_all(
    tcr_sequences=tcrs,
    output_dir=f"{cfg['exp_dir']}/predictions",
    top_k=cfg['top_k']
)


This will create a 'predictions' directory under 'example_run'. You need GPUs with more than 24GB of memory to run this (to hold the GPT2-small architecture). Here, `predict_all()` is a wrapper around `predict()` that runs inference using 11 independent models in sequence. If you want to speed up the process, you can run the models in parallel using `predict()` instead. After running the above code, you will find `example_run/predictions/predictions_{i}.csv` files that contain the generated epitopes in the 'pred_{i}' columns.

Now, run the annotation step using the following code:

In [None]:
# Annotate phenotype-associated epitopes / tcrs
os.environ["TOKENIZERS_PARALLELISM"] = "false"
annotator = EpitopeAnnotator(cfg['epitope_db'])
annotator.annotate_all(
    predictions_dir=f"{cfg['exp_dir']}/predictions",
    output_dir=f"{cfg['exp_dir']}/annotations",
    top_k=cfg['top_k'],
    method=cfg['method']
)

This step does not use GPUs but may take a while depending on the size of the dataset; multiprocessing is used to speed it up. It creates an 'annotations' directory under 'example_run' with `example_run/annotations/annotations_{i}.csv` files that contain the annotations in the 'match_{i}', 'ref_epitope_{i}', and 'ref_protein_{i}' columns. For a given TCR, match_1==1 means that pred_1 (the epitope) was found to match an entry in the database and was labeled as PA.
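
To make the two matching modes concrete, here is a minimal pure-Python sketch. It is not the `EpitopeAnnotator` code: `levenshtein` and `is_match` are hypothetical helpers, and both the substring direction (a hit in either direction counts) and the meaning of `threshold` (a maximum edit distance) are assumptions made for illustration.

```python
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance (insert, delete, substitute)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]

def is_match(pred: str, ref: str, method: str = "substring", threshold: int = 1) -> bool:
    """Return True if a predicted epitope matches a reference epitope."""
    if method == "substring":
        # Assumption: containment in either direction counts as a match
        return pred in ref or ref in pred
    if method == "levenshtein":
        # Assumption: threshold is the maximum allowed edit distance
        return levenshtein(pred, ref) <= threshold
    raise ValueError(f"Unknown method: {method}")

# A single substitution fails the substring test but passes Levenshtein with threshold=1
print(is_match("GILGFVFTL", "GILGFVFTM", method="substring"))    # False
print(is_match("GILGFVFTL", "GILGFVFTM", method="levenshtein"))  # True
```

The quadratic cost of the edit-distance computation per pair is one reason the substring mode is so much faster on a large database.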

Here, you can check the agreement between the eleven models by running:

In [None]:
annotation_files = [f"{cfg['exp_dir']}/annotations/annotations_{model_idx}.csv" for model_idx in range(1, 1 + 11)]
# Visualize the match overlaps between 11 annotations
similarity_matrix, file_names = visualize_match_overlaps_parallel(
    files_list=annotation_files,
    outdir=f"{cfg['exp_dir']}/ensemble",
    top_k=cfg['top_k'],
)

Depending on the agreement, you may want to tune the `ens_th` and `top_k` parameters.
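
As a rough illustration of what "agreement" can mean, the sketch below computes the pairwise Jaccard overlap between the sets of TCRs that each model labels as PA. The data are toy values, `jaccard` is a hypothetical helper, and this is not necessarily the similarity that `visualize_match_overlaps_parallel` computes.

```python
from itertools import combinations

def jaccard(a: set, b: set) -> float:
    """Size of the intersection over size of the union; 1.0 if both sets are empty."""
    union = a | b
    return len(a & b) / len(union) if union else 1.0

# Toy stand-in: TCRs that each model labels as PA
pa_by_model = {
    "model_1": {"TCR_A", "TCR_B", "TCR_C"},
    "model_2": {"TCR_B", "TCR_C", "TCR_D"},
    "model_3": {"TCR_C"},
}

# Overlap for every unordered pair of models
pairwise = {
    (m, n): jaccard(pa_by_model[m], pa_by_model[n])
    for m, n in combinations(pa_by_model, 2)
}
for pair, score in pairwise.items():
    print(pair, round(score, 2))
```

Low pairwise overlap suggests that individual models disagree often, in which case a stricter `ens_th` may be worth trying.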

Finally, run the ensemble step using the following code:

In [None]:
# Ensemble 11 annotations to get the final robust annotation
ensembler = EpitopeEnsembler(threshold=cfg['ens_th'])
final_results = ensembler.ensemble(
    annotation_files,
    output_path=f"{cfg['exp_dir']}/ensemble/annotations_ens_all_th{cfg['ens_th']}.csv",
    top_k=cfg['top_k']
)

print(final_results)

Running the above code creates an 'ensemble' directory under 'example_run' that holds the final annotation file with the ensembled results. Please repeat the pipeline up to this point at least five times to obtain multiple annotation files; you'll want to check the consistency of the results later. Because EpiGen calls the `model.generate()` function of GPT-2, it generates different epitope sequences for the same TCR each time it is run. This is why we take an ensemble of 11 different models, which gave robust results on our dataset.
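
One simple way to check consistency across repeated runs is to measure, for each TCR, the fraction of runs in which it was called PA. The sketch below uses toy sets in place of the real ensemble outputs; `pa_stability` is a hypothetical helper, not part of EpiGen.

```python
from collections import Counter

def pa_stability(runs: list[set[str]]) -> dict[str, float]:
    """Fraction of runs in which each TCR was called PA."""
    counts = Counter(tcr for run in runs for tcr in run)
    return {tcr: n / len(runs) for tcr, n in counts.items()}

# Toy stand-in: PA TCRs from five independent pipeline runs
runs = [
    {"TCR_A", "TCR_B"},
    {"TCR_A"},
    {"TCR_A", "TCR_C"},
    {"TCR_A", "TCR_B"},
    {"TCR_A", "TCR_B"},
]

stability = pa_stability(runs)
# TCRs that were called PA in every run
stable = sorted(tcr for tcr, frac in stability.items() if frac == 1.0)
print(stable)  # ['TCR_A']
```

TCRs with a stability well below 1.0 are sensitive to sampling noise and deserve extra caution in downstream analysis.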

You can now merge this with your single-cell transcriptomics data to analyze TCR-phenotype associations. From this point on, the functions you need depend on the structure of your dataset. As an example, download an annotated observation file from https://zenodo.org/records/14624873/files/obs_annotated_cancer_wu_ens_th0.5.csv (into `data`), into which we'll inject our TCR epitope annotation information. Let's define a utility function to merge the annotations with the single-cell data.

In [None]:
def merge_annotations(
    site_file: str,
    annotation_file: str,
    output_dir: str = "merged",
    randomize: bool = False,
    random_seed: int = 42
    ):
    """Merge site data with new annotations by matching TCR sequences.

    Args:
        site_file: Path to site_added.csv
        annotation_file: Path to annotation_ens_th0.5.csv
        output_dir: Directory to save output file
        randomize: Whether to randomize annotation matches
        random_seed: Seed for reproducible randomization
    """
    print("\n=== Starting Annotation Merge ===")
    print(f"• Mode: {'Randomized' if randomize else 'Normal'}")

    # Read input files
    print("• Reading input files...")
    site_df = pd.read_csv(site_file)
    annot_df = pd.read_csv(annotation_file)

    if randomize:
        print("• Randomizing annotation matches...")
        np.random.seed(random_seed)

        # Identify columns to shuffle
        match_cols = [col for col in annot_df.columns if any(x in col for x in ['match_', 'ref_epitope_', 'ref_protein_'])]

        # Group columns by their index (e.g., match_0, ref_epitope_0, ref_protein_0)
        col_groups = {}
        for col in match_cols:
            idx = col.split('_')[-1]
            if idx.isdigit():
                if idx not in col_groups:
                    col_groups[idx] = []
                col_groups[idx].append(col)

        # Shuffle each group of columns together
        for idx, cols in col_groups.items():
            shuffle_idx = np.random.permutation(len(annot_df))
            annot_df[cols] = annot_df[cols].iloc[shuffle_idx].values

    # Get columns to keep from site_df
    keep_cols = []
    drop_patterns = ['pred_', 'ref_epitope_', 'ref_protein_', 'match_']
    for col in site_df.columns:
        if not any(pattern in col for pattern in drop_patterns):
            keep_cols.append(col)

    # Create clean site dataframe
    print("• Removing old predictions and annotations...")
    site_clean = site_df[keep_cols].copy()

    # Rename columns for merging
    annot_df = annot_df.rename(columns={'tcr': 'cdr3'})

    # Merge dataframes
    print("• Merging with new annotations...")
    merged_df = site_clean.merge(annot_df, on='cdr3', how='left')

    # Create output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate output filename
    site_stem = Path(site_file).stem
    annot_stem = Path(annotation_file).stem
    random_suffix = "_randomized" if randomize else ""
    output_file = output_dir / f"{site_stem}_merged_{annot_stem}{random_suffix}.csv"

    # Save merged file
    merged_df.to_csv(output_file, index=False)

    # Print statistics
    print("\n=== Merge Summary ===")
    print(f"• Total cells in site file: {len(site_df)}")
    print(f"• Total TCRs in annotation file: {len(annot_df)}")
    # Use the first prediction column that exists (indices may start at 0 or 1)
    pred_cols = sorted(col for col in merged_df.columns if col.startswith('pred_'))
    if pred_cols:
        print(f"• Cells matched with annotations: {merged_df[pred_cols[0]].notna().sum()}")
        print(f"• Cells without matches: {merged_df[pred_cols[0]].isna().sum()}")

    # Print match statistics
    match_cols = sorted(col for col in merged_df.columns if col.startswith('match_'))
    for col in match_cols[:4]:  # Show first 4 positions
        matches = merged_df[col].sum()
        total = merged_df[col].notna().sum()
        if total > 0:
            print(f"• Match rate for {col}: {matches/total*100:.1f}%")

    print(f"\n• Results saved to: {output_file}")
    print("===========================")

    # Return the merged table so callers can keep working with it
    return merged_df
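
Before calling the function, note one caveat of a left merge on `cdr3`: if the annotation table lists the same TCR more than once, every matching cell row is duplicated. The toy sketch below (made-up column values, not the real files) shows the effect and one possible guard that uses the `validate` argument of `DataFrame.merge`.

```python
import pandas as pd

# Toy cells: two cells share the same CDR3 sequence
cells = pd.DataFrame({"cell_id": ["c1", "c2", "c3"],
                      "cdr3": ["CASSA", "CASSB", "CASSA"]})
# Toy annotations: CASSA appears twice
annot = pd.DataFrame({"cdr3": ["CASSA", "CASSA", "CASSB"],
                      "match_1": [1, 1, 0]})

# A plain left merge duplicates every cell whose TCR is repeated in annot
naive = cells.merge(annot, on="cdr3", how="left")
print(len(cells), len(naive))  # 3 5

# Deduplicate first; validate raises if the right-hand keys are not unique
annot_unique = annot.drop_duplicates(subset="cdr3")
safe = cells.merge(annot_unique, on="cdr3", how="left", validate="many_to_one")
print(len(safe))  # 3
```

Whether to keep the first duplicate or to aggregate duplicates is a design choice that depends on how the annotation file was produced.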

Call the function with the site file and the annotation file to inject our epitope annotation information into the transcriptomics data.

In [None]:
# Merge the ensembled annotations into the cancer_wu observation file
merged_df = merge_annotations(
    site_file="data/obs_annotated_cancer_wu_ens_th0.5.csv",
    annotation_file=f"{cfg['exp_dir']}/ensemble/annotations_ens_all_th{cfg['ens_th']}.csv",
    output_dir=f"{cfg['exp_dir']}"
)

Running the above code creates the merged file under the `example_run` directory. We can now use this file to analyze TCR-phenotype associations. Next, download the `cancer_wu` dataset to obtain the raw transcriptomics data.

In [None]:
from research.cancer_wu.download import download_and_preprocess
download_and_preprocess(outdir="data/cancer_wu", input_file="research/cancer_wu/data_links.txt")

This will download the `cancer_wu` dataset from https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE139555. We need to define some utility functions to read in the transcriptomics data. Please refer to `research/cancer_wu/analyze.py` and `research/cancer_wu/utils.py` for more details.

In [None]:
# from research.cancer_wu.utils import *
# from research.cancer_wu.analyze import *
import scanpy as sc
import pandas as pd
import os
import anndata as ad
import scirpy as ir
import mudata as md
from mudata import MuData

CELL_TYPES = ['8.1-Teff', '8.2-Tem', '8.3a-Trm', '8.3b-Trm', '8.3c-Trm']
SAMPLES = ['CN1', 'CN2', 'CT1', 'CT2', 'EN1', 'EN2', 'EN3', 'ET1', 'ET2', 'ET3',
           'LB6', 'LN1', 'LN2', 'LN3', 'LN4', 'LN5', 'LN6', 'LT1', 'LT2', 'LT3',
           'LT4', 'LT5', 'LT6', 'RB1', 'RB2', 'RB3', 'RN1', 'RN2', 'RN3', 'RT1',
           'RT2', 'RT3']
           
def read_tcell_integrated(data_dir, transpose=False):
    """
    Read the main gene expression data
    """
    # Read the H5AD file
    adata = sc.read_h5ad(f"{data_dir}/GSE139555_tcell_integrated.h5ad")
    if transpose:
        adata = adata.transpose()
    metadata = pd.read_csv(f"{data_dir}/GSE139555%5Ftcell%5Fmetadata.txt", sep="\t", index_col=0)
    # Make sure the index of the metadata matches the obs_names of the AnnData object
    adata.obs = adata.obs.join(metadata, how='left')
    print("Successfully read GSE139555_tcell_integrated!")
    return adata


def read_all_data(data_dir, obs_cache=None, filter_cdr3_notna=True, filter_cell_types=True):
    """
    The main function to read CD8+ T cell data from Wu et al. dataset
    Both gene expression and TCR sequences are read

    Parameters
    ----------
    data_dir: str
        Root directory of the data
    obs_cache: str / None
        csv file that contains some annotated TCR data. As there are multiple annotation steps,
        this file is always read after the very first annotation
    filter_cdr3_notna: bool
        Drop the rows that do not have viable CDR3 sequence information
    filter_cell_types: bool
        Drop the rows that are not CD8+ T cells
    """
    samples = ['CN1', 'CT2', 'EN3', 'ET3', 'LB6', 'LN3', 'LN6', 'LT3', 'LT6', 'RB2', 'RN2', 'RT2',
               'CN2', 'EN1', 'ET1', 'LN1', 'LN4', 'LT1', 'LT4', 'RB3', 'RN3', 'RT3',
               'CT1', 'EN2', 'ET2', 'LN2', 'LN5', 'LT2', 'LT5', 'RB1', 'RN1', 'RT1']
    # Read T-cell integrated (gene expression data)
    adata = read_tcell_integrated(data_dir)

    # Read the TCR sequencing data using scirpy (ir)
    airrs = []
    for sample in [s for s in os.listdir(data_dir) if s in samples]:
        for x in os.listdir(f"{data_dir}/{sample}"):
            if x.endswith("contig_annotations.csv") or x.endswith("annotations.csv"):
                airr = ir.io.read_10x_vdj(f"{data_dir}/{sample}/{x}")
                # Add a column to identify the source file
                airr.obs['new_cell_id'] = airr.obs.index.map(lambda x: sample + "_" + x)
                airr.obs.index = airr.obs['new_cell_id']
                airrs.append(airr)
    # Merge the AIRR objects
    if len(airrs) > 1:
        merged_airr = ad.concat(airrs)
    else:
        merged_airr = airrs[0]

    if obs_cache:
        print(f"Reading cache from {obs_cache}..")
        df_cache = pd.read_csv(obs_cache)

        # Merge df_cache to adata.obs based on cell_id
        # Set cell_id as index in df_cache to match adata.obs
        df_cache = df_cache.set_index('cell_id')

        # Keep only the cells that exist in df_cache
        common_cells = adata.obs.index.intersection(df_cache.index)
        adata = adata[common_cells].copy()

        # Add the columns that exist only in df_cache; combine_first keeps
        # existing non-null values in adata.obs and fills the rest from df_cache
        adata.obs = adata.obs.combine_first(df_cache)

        # For columns that exist in both, prefer df_cache values
        for col in df_cache.columns:
            if col in adata.obs:
                adata.obs[col] = df_cache[col]

        print(f"Updated adata.obs with {len(df_cache.columns)} columns from cache")
        print(f"Retained {len(common_cells)} cells after matching with cache")

    if filter_cell_types:
        print("Get only CD8+ T cells..")
        adata = adata[adata.obs['ident'].isin(CELL_TYPES)].copy()

    if filter_cdr3_notna:
        # Filter based on non-NA cdr3 values:
        valid_cells = adata.obs['cdr3'].notna()
        print(f"Filtering out {(~valid_cells).sum()} cells with NA cdr3 values")
        adata = adata[valid_cells].copy()

    mdata = MuData({"airr": merged_airr, "gex": adata})

    print(f"Successfully merged {len(airrs)} AIRR objects!")
    print(f"(read_all_data) The number of CD8+ T cells: {len(adata.obs)}")
    return mdata


def read_all_raw_data(data_dir):
    samples = os.listdir(data_dir)
    adata_list = []

    for sample in samples:
        if sample in SAMPLES:
            sample_path = os.path.join(data_dir, sample)
            file = os.listdir(sample_path)[0]

            # Read the data with the prefix applied to barcodes
            adata = sc.read_10x_mtx(
                path=sample_path,
                var_names="gene_symbols",
                make_unique=True,
                prefix=file.split(".")[0] + "."
            )

            # Rename the barcodes
            prefix = f"{sample}_"
            adata.obs_names = [f"{prefix}{barcode}" for barcode in adata.obs_names]

            # Append the annotated data to the list
            adata_list.append(adata)

    # Concatenate all the data into one AnnData object
    combined_adata = ad.concat(adata_list, axis=0)
    print("Successfully read all RAW data!")

    return combined_adata


def filter_and_update_combined_adata(combined_adata, processed_adata):
    # Get the common indices (barcodes) between the combined_adata and processed_adata
    common_indices = processed_adata.obs_names.intersection(combined_adata.obs_names)

    # Filter combined_adata to keep only those cells present in processed_adata
    filtered_combined_adata = combined_adata[common_indices].copy()

    # Copy obs from processed_adata to filtered_combined_adata
    for col in processed_adata.obs.columns:
        # Add a new column in filtered_combined_adata if it doesn't already exist
        if col not in filtered_combined_adata.obs.columns:
            filtered_combined_adata.obs[col] = None

        # Copy the data from processed_adata.obs to filtered_combined_adata.obs, matching by index
        filtered_combined_adata.obs[col] = processed_adata.obs.loc[common_indices, col]

    print(f"Filtered the combined data using the processed adata! (Finding intersection). Num of rows={len(filtered_combined_adata)}")

    return filtered_combined_adata

Now, let's run the differential gene expression analysis between phenotype-associated (PA) T cells, i.e., those marked with match_{i} equal to 1, and the other background T cells.

In [None]:
from epigen import DEGAnalyzer
# Read the processed gene expression data of CD8+ T cells and inject our epitope annotation
mdata = read_all_data(data_dir="data/cancer_wu", obs_cache=f"{cfg['exp_dir']}/obs_annotated_cancer_wu_ens_th0.5_merged_annotations_ens_all_th{cfg['ens_th']}.csv")
# Read the raw gene expression data of all samples
raw_adata = read_all_raw_data(data_dir="data/cancer_wu")
# Restrict the raw data to the cells kept in the TCR-GEX data and copy over their annotations
raw_adata_filtered = filter_and_update_combined_adata(raw_adata, mdata['gex'])

# Perform DEG analysis
for k in range(1, 1 + cfg['top_k']):
    analyzer = DEGAnalyzer(output_dir=f"{cfg['exp_dir']}/gex_grouped", top_k=k)
    analyzer.analyze(raw_adata_filtered.copy())

For more thorough results, please use a complete dataset instead of sample_tcrs.csv. Also, run the whole pipeline (prediction, annotation, ensemble, and so on) multiple times.