In [6]:
import os
import warnings
import random
import numpy as np
import pandas as pd

In [7]:
def load_data_split(dat,split_type, seed):
    n_fold = 5
    idx_test_fold = 0
    idx_val_fold = -1
    idx_test = None
    idx_train = None
    x_pep = dat.epi
    x_tcr = dat.tcr
    
    if split_type == 'random':
        pass
    elif split_type == 'epi':
        unique_peptides = np.unique(x_pep)
        n_total = len(unique_peptides)
    elif split_type == 'tcr':
        unique_tcrs = np.unique(x_tcr)
        n_total = len(unique_tcrs)
        
    np.random.seed(seed)    
    idx_shuffled = np.arange(n_total)
    np.random.shuffle(idx_shuffled)
    
    # Determine data split from folds
    n_test = int(round(n_total / n_fold))
    n_train = n_total - n_test

    # Determine position of current test fold
    test_fold_start_index = idx_test_fold * n_test
    test_fold_end_index = (idx_test_fold + 1) * n_test

    if split_type == 'random':
        pass
    elif split_type == 'epi':
        if idx_val_fold < 0:
            idx_test_pep = idx_shuffled[test_fold_start_index:test_fold_end_index]
            test_peptides = unique_peptides[idx_test_pep]
            idx_test = [index for index, pep in enumerate(x_pep) if pep in test_peptides]
            idx_train = list(set(range(len(x_pep))).difference(set(idx_test)))
        else:
            pass
    elif split_type == 'tcr':
        if idx_val_fold < 0:
            idx_test_tcr = idx_shuffled[test_fold_start_index:test_fold_end_index]
            test_tcrs = unique_tcrs[idx_test_tcr]
            idx_test = [index for index, tcr in enumerate(x_tcr) if tcr in test_tcrs]
            idx_train = list(set(range(len(x_tcr))).difference(set(idx_test)))
        else:
            pass
        
    testData = dat.iloc[idx_test, :].sample(frac=1).reset_index(drop=True)
    trainData = dat.iloc[idx_train, :].sample(frac=1).reset_index(drop=True)
    
    print('================check Overlapping========================')
    print('number of overlapping tcrs: ', str(len(set(trainData.tcr).intersection(set(testData.tcr)))))
    print('number of overlapping epitopes: ', str(len(set(trainData.epi).intersection(set(testData.epi)))))
    
    return trainData, testData


def epi_token_tcr(df, name):
#     df['epi_tcr'] = '<epi>' + df['epi'] + '<eoepi>' + '$' + '<tcr>' + df['tcr'] + '<eotcr>' + '<EOS>'
    df['epi_tcr'] = df['epi'] + '$' + df['tcr'] + '<EOS>'
    df['epi_tcr'].to_csv(f'./data/{name}.txt', header=False, index=False, sep='\t')
    return

In [8]:
df = pd.read_csv('combined_dataset_repTCRs.csv')
df = df[:150008] # only use positively bind pairs

# Filter the dataframe to include only those epitopes with frequency > 100
epitope_counts = df['epi'].value_counts()
filtered_df = df[df['epi'].isin(epitope_counts[epitope_counts > 100].index)].reset_index(drop=True)

#### writing training and testing data for different splits into disk
trainData, testData = load_data_split(filtered_df, 'epi', 42)

number of overlapping tcrs:  2920
number of overlapping epitopes:  0


## ${EPI\$TCR1[SEP]EPI\$TCR2[SEP]EPI\$TCR3[EOS]}$

In [19]:
def create_k_shot_samples(df, name, k_shots, compact_format=False, seed=None):

    if seed is not None:
        np.random.seed(seed)

    epitope_groups = df.groupby('epi')['tcr'].apply(list)
    all_k_shot_samples = []

    # Iterate over each epitope group
    for epitope, tcrs in epitope_groups.items():
        total_tcrs = len(tcrs)

        # Number of k-shot samples to create
        num_samples = round(total_tcrs * 1.0) 
        num_samples = min(num_samples, len(tcrs) // k_shots)

        for _ in range(num_samples):
            # Sample k TCRs for this epitope, ensuring no duplicates
            sampled_tcrs = np.random.choice(tcrs, k_shots, replace=False)

            if compact_format:
                # EPI$TCR1$TCR2$TCR3<EOS> format
                k_shot_sample = f"{epitope}${'$'.join(sampled_tcrs)}<EOS>"
            else:
                # EPI$TCR1[SEP]EPI$TCR2[SEP]EPI$TCR3<EOS> format
                k_shot_sample = '[SEP]'.join([f"{epitope}${tcr}" for tcr in sampled_tcrs]) + '<EOS>'

            all_k_shot_samples.append(k_shot_sample)

    # Saving the k-shot samples to a file
    with open(f'{name}_{k_shots}_shot_samples_seed_{seed}.txt', 'w') as file:
        for sample in all_k_shot_samples:
            file.write(sample + '\n')

    # Indicating completion
    print(f'{k_shots}-shot samples for {name} are saved to {name}_{k_shots}_shot_samples with seed {seed}')

## ${EPI\$TCR1$TCR2$TCR3[EOS]}$

In [21]:
# Example usage
# create_k_shot_samples(trainData, 'training', 3, compact_format=False)
# create_k_shot_samples(testData, 'testing', 3, compact_format=False)

create_k_shot_samples(trainData, 'training', 10, compact_format=True, seed=99)
create_k_shot_samples(testData, 'testing', 10, compact_format=True, seed=99)

create_k_shot_samples(trainData, 'training', 10, compact_format=True, seed=73)
create_k_shot_samples(testData, 'testing', 10, compact_format=True, seed=73)

create_k_shot_samples(trainData, 'training', 10, compact_format=True, seed=42)
create_k_shot_samples(testData, 'testing', 10, compact_format=True, seed=42)

10-shot samples for training are saved to training_10_shot_samples with seed 99
10-shot samples for testing are saved to testing_10_shot_samples with seed 99
10-shot samples for training are saved to training_10_shot_samples with seed 73
10-shot samples for testing are saved to testing_10_shot_samples with seed 73
10-shot samples for training are saved to training_10_shot_samples with seed 42
10-shot samples for testing are saved to testing_10_shot_samples with seed 42


# Sample TCRs with replacements

In [4]:
def create_k_shot_samples_w_replacement(df, name, k_shots, compact_format=False):
    epitope_groups = df.groupby('epi')['tcr'].apply(list)
    all_k_shot_samples = []

    # Iterate over each epitope group
    for epitope, tcrs in epitope_groups.items():
        total_tcrs = len(tcrs)

        # Number of k-shot samples to create
        num_samples = round(total_tcrs * 1.0) 
        num_samples = min(num_samples, len(tcrs))

        for _ in range(num_samples):
            # Sample k TCRs for this epitope, ensuring no duplicates
            sampled_tcrs = np.random.choice(tcrs, k_shots, replace=False)

            if compact_format:
                # EPI$TCR1$TCR2$TCR3<EOS> format
                k_shot_sample = f"{epitope}${'$'.join(sampled_tcrs)}<EOS>"
            else:
                # EPI$TCR1[SEP]EPI$TCR2[SEP]EPI$TCR3<EOS> format
                k_shot_sample = '[SEP]'.join([f"{epitope}${tcr}" for tcr in sampled_tcrs]) + '<EOS>'

            all_k_shot_samples.append(k_shot_sample)

    # Saving the k-shot samples to a file
    with open(f'{name}_{k_shots}_shot_samples.txt', 'w') as file:
        for sample in all_k_shot_samples:
            file.write(sample + '\n')

    # Indicating completion
    print(f'{k_shots}-shot samples for {name} are saved to {name}_{k_shots}_shot_samples.txt')

# Example usage
create_k_shot_samples_w_replacement(trainData, 'training', 10, compact_format=False)
create_k_shot_samples_w_replacement(testData, 'testing', 10, compact_format=False)

create_k_shot_samples_w_replacement(trainData, 'training', 5, compact_format=False)
create_k_shot_samples_w_replacement(testData, 'testing', 5, compact_format=False)

create_k_shot_samples_w_replacement(trainData, 'training', 3, compact_format=False)
create_k_shot_samples_w_replacement(testData, 'testing', 3, compact_format=False)

10-shot samples for training are saved to training_10_shot_samples.txt
10-shot samples for testing are saved to testing_10_shot_samples.txt
5-shot samples for training are saved to training_5_shot_samples.txt
5-shot samples for testing are saved to testing_5_shot_samples.txt
3-shot samples for training are saved to training_3_shot_samples.txt
3-shot samples for testing are saved to testing_3_shot_samples.txt


In [5]:
# Example usage
create_k_shot_samples_w_replacement(trainData, 'training', 10, compact_format=True)
create_k_shot_samples_w_replacement(testData, 'testing', 10, compact_format=True)

create_k_shot_samples_w_replacement(trainData, 'training', 5, compact_format=True)
create_k_shot_samples_w_replacement(testData, 'testing', 5, compact_format=True)

create_k_shot_samples_w_replacement(trainData, 'training', 3, compact_format=True)
create_k_shot_samples_w_replacement(testData, 'testing', 3, compact_format=True)

10-shot samples for training are saved to training_10_shot_samples.txt
10-shot samples for testing are saved to testing_10_shot_samples.txt
5-shot samples for training are saved to training_5_shot_samples.txt
5-shot samples for testing are saved to testing_5_shot_samples.txt
3-shot samples for training are saved to training_3_shot_samples.txt
3-shot samples for testing are saved to testing_3_shot_samples.txt
