In [1]:
import os
os.chdir('/Users/alexascunceparis/Desktop/BSC/TCRranker_v3')

# TRAINING THE MODEL TCR-p

In [2]:
from potential_calc import *
from utils import *
from extract_contacts import filter_contacts

## TRAIN MODEL FOR TCR-P BY POSITION (P1_P9)

### Define the folds

In [None]:
import pandas as pd

def create_empty_contact_model(length):
    """Creates an empty contact model for a given peptide length."""
    return {f'P{i+1}': pd.DataFrame() for i in range(length)}

def main():
    # Step 1: List of PDB IDs to exclude from training
    validation_set = []
    training_set = []
    # Dictionary to store models by peptide length
    contact_models = {}

    # Step 2: Parse chain information from the general.txt file
    chain_dict = parse_general_file('./structures_annotation/general.txt')

    # Step 3: Path to the folder containing contact maps
    folder_path = './contact_maps'
    contact_files = [
        f for f in os.listdir(folder_path)
        if f.endswith('_contacts.csv') and f.split('_')[0] not in validation_set and f.split('_')[0] in training_set]
    
    print(len(contact_files), "contact files found in the folder.")

    # Step 4: Process each contact file
    for contact_file in contact_files:
        pdb_id = contact_file.split('_')[0]
        pdb_path = f"./pdb_files/{pdb_id}.pdb"

        print(f"Processing contacts for PDB ID: {pdb_id}")
        
        # Read the contacts for the current PDB ID
        contacts = pd.read_csv(os.path.join(folder_path, contact_file))

        # Check if the pdb_id exists in chain_dict, if not, use the default settings
        chains = chain_dict.get(pdb_id, {
            'tcra_chain': 'D',
            'tcrb_chain': 'E',
            'peptide_chain': 'C',
            'b2m_chain': 'B',
            'mhc_chain': 'A'
        })
        chain_dict[pdb_id] = chains

        if chains and all(chains.values()):
            try:
                # Filter contacts for TCR-peptide and TCR-MHC
                contacts_TCR_p, contacts_TCR_MHC = filter_contacts(
                    contacts,
                    chains['tcra_chain'],
                    chains['tcrb_chain'],
                    chains['peptide_chain'],
                    chains['mhc_chain'])

                if not contacts_TCR_p.empty:
                    # Extract the peptide sequence and determine its length
                    _, _, peptide_sequence = extract_specific_sequences(pdb_path, chain_dict, extract_sequences)
                    peptide_length = len(peptide_sequence)

                    # Create a model for this peptide length if it doesn't exist
                    if peptide_length not in contact_models:
                        contact_models[peptide_length] = create_empty_contact_model(peptide_length)

                    # Distribute contacts into appropriate positions (P1, P2, ..., Pn)
                    for _, contact in contacts_TCR_p.iterrows():
                        resid_to = contact['resid_to']

                        # Ensure resid_to falls within the peptide length
                        if 1 <= resid_to <= peptide_length:
                            position = f'P{resid_to}'
                            contact_models[peptide_length][position] = pd.concat(
                                [contact_models[peptide_length][position], contact.to_frame().T])

            except Exception as e:
                print(f"Error processing PDB ID {pdb_id}: {e}")

    # Step 5: Create directories and save potentials for each peptide length
    for length, model in contact_models.items():
        # Create a directory for the current peptide length if it doesn't exist
        length_directory = f"./model/TCR-p-L{length}"
        os.makedirs(length_directory, exist_ok=True)

        for position, contacts in model.items():
            if not contacts.empty:
                print(f"Calculating TCR-peptide potential for length {length}, position {position}")

                # Calculate the TCR-peptide potential for the current contact data
                data_TCR_p = calculate_potential(contacts, peptide=True)
                
                # Generate the output file name based on the residue position
                output_p = os.path.join(length_directory, f"tcr_p_potential_{position}.csv")
                
                # Save the result to a CSV file in the corresponding folder
                data_TCR_p.to_csv(output_p, index=False)
                print(f"TCR-peptide potential for length {length}, position {position} saved to {output_p}")
            else:
                print(f"No valid TCR-peptide contacts for length {length}, position {position}.")

    print("Training complete for all PDB IDs (excluding specified ones).")

if __name__ == "__main__":
    main()


143 contact files found in the folder.
Processing contacts for PDB ID: 7n2p
Processing contacts for PDB ID: 4jfd
Processing contacts for PDB ID: 7l1d
Processing contacts for PDB ID: 4jry
Processing contacts for PDB ID: 5wkf
Processing contacts for PDB ID: 7jwj
Processing contacts for PDB ID: 8gvb
Processing contacts for PDB ID: 5ivx
Processing contacts for PDB ID: 2ypl
Processing contacts for PDB ID: 5isz
Processing contacts for PDB ID: 1mwa
Processing contacts for PDB ID: 5wlg
Processing contacts for PDB ID: 7n2o
Processing contacts for PDB ID: 5jhd
Processing contacts for PDB ID: 6vm7
Processing contacts for PDB ID: 4ftv
Processing contacts for PDB ID: 3rgv
Processing contacts for PDB ID: 8ye4
Processing contacts for PDB ID: 8gom
Processing contacts for PDB ID: 7pb2
Processing contacts for PDB ID: 4ms8
Processing contacts for PDB ID: 2ak4
Processing contacts for PDB ID: 3uts
Processing contacts for PDB ID: 1nam
Processing contacts for PDB ID: 8en8
Processing contacts for PDB ID: 8shi

# TRAIN MODEL FOR TCR-MHC

### Select training set

In [None]:
import os
import pandas as pd

def main():
    # Step 1: Create a list of PDB IDs to exclude from training
    exclude_pdb_ids = []
    selected_pdb_files = ['5brz', '5c0a', '6p64', '3gsn', '5jzi', '7l1d', '7pb2', '7nme', '6avg', '3ffc', '7n2p', '7n2q', '8enh', '4prh', '6mtm', '7dzn', '3kps', '4mji', '7r80', '8f5a', '7dzm', '8shi', '6uon', '8rlv', '5w1v', '7ndu', '7na5', '5ivx', '1g6r', '1mwa', '2ol3', '4mxq']

    # Step 2: Read chain information from the general.txt file
    chain_dict = parse_general_file('./structures_annotation/general.txt')

    # Step 3: Path to the folder containing contact maps
    folder_path = './contact_maps'
    contact_files = [
        f for f in os.listdir(folder_path)
        if f.endswith('_contacts.csv') and f.split('_')[0] in selected_pdb_files]
    print(len(contact_files), "contact files found in the folder.")

    # Initialize DataFrame to store all contacts
    all_contacts_TCR_MHC = pd.DataFrame()

    # Step 4: Iterate over all contact files
    for contact_file in contact_files:
        pdb_id = contact_file.split('_')[0]  # Extract the PDB ID from the file name
        
        if pdb_id in exclude_pdb_ids:  # Skip if the PDB ID is in the exclusion list
            print(f"Skipping PDB ID: {pdb_id} (in exclusion list)")
            continue

        print(f"Processing contacts for PDB ID: {pdb_id}")
        
        # Read the contacts for the current PDB ID
        contacts = pd.read_csv(os.path.join(folder_path, contact_file))
        chains = chain_dict.get(pdb_id, {
            'tcra_chain': 'D',
            'tcrb_chain': 'E',
            'peptide_chain': 'C',
            'b2m_chain': 'B',
            'mhc_chain': 'A'
        })
        chain_dict[pdb_id] = chains

        if chains and all(chains.values()):  # Ensure all chains are present
            try:
                # Filter contacts for the current PDB ID
                contacts_TCR_p, contacts_TCR_MHC = filter_contacts(
                    contacts,
                    chains['tcra_chain'],
                    chains['tcrb_chain'],
                    chains['peptide_chain'],
                    chains['mhc_chain'])
                    #,threshold=1)

                # Accumulate the filtered contacts
                if not contacts_TCR_MHC.empty:
                    all_contacts_TCR_MHC = pd.concat([all_contacts_TCR_MHC, contacts_TCR_MHC], ignore_index=True)

            except Exception as e:
                print(f"Error processing contacts for PDB ID {pdb_id}: {e}")
        else:
            print(f"Missing chains for PDB ID: {pdb_id}. Skipping...")

    # Step 5: Check if we have valid data to calculate potentials
    if not all_contacts_TCR_MHC.empty:
        print("Calculating TCR-MHC potential")
        data_TCR_MHC = calculate_potential(all_contacts_TCR_MHC, peptide=True)
        output_p = "./model/TCR_MHC_weighted.csv"  # Output model path
        
        # Ensure the output directory exists
        os.makedirs(os.path.dirname(output_p), exist_ok=True)

        data_TCR_MHC.to_csv(output_p, index=False)
        print(f"TCR-MHC potential saved to {output_p}")
    else:
        print("No valid TCR-MHC contacts for the training set.")

if __name__ == "__main__":
    main()

32 contact files found in the folder.
Processing contacts for PDB ID: 7n2p
Processing contacts for PDB ID: 8rlv
Processing contacts for PDB ID: 7l1d
Processing contacts for PDB ID: 5ivx
Processing contacts for PDB ID: 1mwa
Processing contacts for PDB ID: 7pb2
Processing contacts for PDB ID: 8shi
Processing contacts for PDB ID: 7n2q
Processing contacts for PDB ID: 3gsn
Processing contacts for PDB ID: 2ol3
Processing contacts for PDB ID: 5w1v
Processing contacts for PDB ID: 6uon
Processing contacts for PDB ID: 5jzi
Processing contacts for PDB ID: 6mtm
Processing contacts for PDB ID: 7nme
Processing contacts for PDB ID: 8enh
Processing contacts for PDB ID: 7dzm
Processing contacts for PDB ID: 3ffc
Processing contacts for PDB ID: 4prh
Processing contacts for PDB ID: 7r80
Processing contacts for PDB ID: 1g6r
Processing contacts for PDB ID: 7ndu
Processing contacts for PDB ID: 8f5a
Processing contacts for PDB ID: 7na5
Processing contacts for PDB ID: 7dzn
Processing contacts for PDB ID: 5brz
