In [None]:
from mhcflurry import Class1AffinityPredictor
import pipeline 
import tcrgp
 
# Load the mhcflurry class I affinity predictor with its default arguments;
# calculate_mhc_rewards() below reads this module-level object.
mhcflurry_model = Class1AffinityPredictor.load()
# Object returned by the project-local `pipeline` module.
# NOTE(review): presumably a fitted classifier exposing predict_proba — confirm.
classificator = pipeline.load_class()

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import math
import numpy as np
from tqdm import tqdm

# ---- Sequence alphabet and length limits ----
# One-letter residue codes; a residue's index here is its one-hot column.
AMINO_ACIDS = "ARNDCEQGHILKMFPSTWYV"
NUM_AMINO_ACIDS = len(AMINO_ACIDS)
# Bounds used when sampling the length of a random starting peptide.
MIN_LENGTH, MAX_LENGTH = 8, 12
# Number of one-hot rows reserved for the MHC and TCR parts of the state.
MHC_LENGTH = 34
TCR_MAX_LENGTH = 20

# ---- Training loop settings ----
NUM_EPISODES = 10000
MAX_STEPS_PER_EPISODE = 100
# One action per (position, residue) pair of the peptide.
MAX_ACTIONS = NUM_AMINO_ACIDS * MAX_LENGTH

# ---- Peptide selection criteria ----
# Minimum total reward for a peptide to be kept.
REWARD_THRESHOLD = 0.8
# Normalized alignment score above which two peptides count as too similar.
SIMILARITY_THRESHOLD = 0.7


def one_hot_encode(sequence, alphabet=AMINO_ACIDS, max_length=MAX_LENGTH):
    """Return a flat one-hot encoding of ``sequence``.

    The result is a 1-D float tensor of ``max_length * len(alphabet)`` entries
    laid out position-major: entry ``pos * len(alphabet) + col`` is 1 when the
    residue at ``pos`` is ``alphabet[col]``. Characters beyond ``max_length``
    are ignored; positions past the end of ``sequence`` stay all-zero.

    Raises:
        KeyError: if ``sequence`` contains a character not in ``alphabet``.
    """
    column_of = {residue: column for column, residue in enumerate(alphabet)}
    width = len(alphabet)
    flat_encoding = torch.zeros(max_length * width)
    for position, residue in enumerate(sequence[:max_length]):
        flat_encoding[position * width + column_of[residue]] = 1
    return flat_encoding

def encode_mhc_sequence(mhc_symbol, mhc_mapping):
    """One-hot encode the sequence registered for ``mhc_symbol``.

    Args:
        mhc_symbol: key to look up in ``mhc_mapping``.
        mhc_mapping: mapping from MHC symbol to its amino-acid sequence.

    Returns:
        Flat one-hot tensor covering ``MHC_LENGTH`` positions.

    Raises:
        KeyError: if ``mhc_symbol`` is not in ``mhc_mapping``. The previous
            fallback string "Default_Sequence" is not made of AMINO_ACIDS
            characters, so it also ended in a KeyError, but one that named a
            stray character instead of the missing symbol.
    """
    try:
        mhc_sequence = mhc_mapping[mhc_symbol]
    except KeyError:
        raise KeyError(
            f"Unknown MHC symbol {mhc_symbol!r}: no sequence in mhc_mapping"
        ) from None
    return one_hot_encode(mhc_sequence, alphabet=AMINO_ACIDS, max_length=MHC_LENGTH)

def encode_tcr_sequence(tcr_sequence):
    """One-hot encode a TCR sequence over ``TCR_MAX_LENGTH`` positions (flattened)."""
    encoded_tcr = one_hot_encode(tcr_sequence, max_length=TCR_MAX_LENGTH)
    return encoded_tcr

def decode_state(state, alphabet=AMINO_ACIDS, max_length=MAX_LENGTH):
    """Decode a flat one-hot encoding back into a sequence string.

    Args:
        state: tensor whose element count is a multiple of ``len(alphabet)``;
            it is reshaped to one row per position.
        alphabet: characters corresponding to the one-hot columns.
        max_length: maximum number of positions to decode.

    Returns:
        The decoded string. Decoding stops at the first all-zero row, which is
        how ``one_hot_encode`` represents positions past the end of a sequence.
        Without this stop, ``argmax`` of a zero row is 0 and every padding
        position would be decoded as ``alphabet[0]``.
    """
    rows = state.reshape(-1, len(alphabet))
    residues = []
    for row in rows[:max_length]:
        if not bool((row != 0).any()):
            # Padding row: the encoded sequence ended before this position.
            break
        residues.append(alphabet[torch.argmax(row).item()])
    return "".join(residues)

class PeptideQNetwork(nn.Module):
    """Three-layer fully connected network mapping a state vector to one value per action.

    Layer widths are ``input_size -> 128 -> 64 -> output_size`` with ReLU after
    the first two layers and no activation on the output.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        first_hidden, second_hidden = 128, 64
        self.fc1 = nn.Linear(input_size, first_hidden)
        self.fc2 = nn.Linear(first_hidden, second_hidden)
        self.fc3 = nn.Linear(second_hidden, output_size)

    def forward(self, x):
        """Return the raw (unactivated) output of the last layer for input ``x``."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)


def calculate_mhc_rewards(peptide_sequence, mhc_symbols):
    """Score one peptide against several MHC symbols in a single predictor call.

    The peptide is repeated once per symbol and passed to
    ``mhcflurry_model.predict`` together with ``mhc_symbols``. Each prediction
    ``p`` is mapped linearly to ``1.0 - p / 50000``, so 0 becomes 1.0 and
    50000 becomes 0.0.

    Returns:
        NumPy array with one reward per entry of ``mhc_symbols``.
    """
    repeated_peptide = [peptide_sequence] * len(mhc_symbols)
    predicted_ic50s = np.array(mhcflurry_model.predict(repeated_peptide, mhc_symbols))
    return 1.0 - predicted_ic50s / 50000

def calculate_tcr_reward(peptide_sequence, tcr_sequences):
    """Score one peptide against one or more TCR sequences with ``tcrgp``.

    Args:
        peptide_sequence: peptide string to score.
        tcr_sequences: an iterable of TCR sequence strings, or a single TCR
            sequence string.

    Returns:
        For an iterable input, a list with one reward per TCR sequence, in
        order (an empty input gives an empty list). For a single string input,
        that TCR's reward alone (not wrapped in a list).

    Each reward is the first element of ``tcrgp.predict([peptide], [tcr])``.
    """
    # A bare string would otherwise be iterated character by character.
    single_tcr = isinstance(tcr_sequences, str)
    if single_tcr:
        tcr_sequences = [tcr_sequences]

    rewards = []
    for tcr_sequence in tcr_sequences:
        # Pass the current TCR only; passing the whole collection here was a bug.
        tcr_predict = tcrgp.predict([peptide_sequence], [tcr_sequence])
        rewards.append(tcr_predict[0])

    return rewards[0] if single_tcr else rewards

def calculate_classification_reward(peptide_sequence):
    """Return the class-1 probability of a peptide as a one-element list.

    The peptide is wrapped in a single-item batch and passed to the
    module-level ``classificator`` (the object assigned from
    ``pipeline.load_class()``); column 1 of each returned probability row is
    collected.

    Returns:
        List with one entry per batch item, i.e. ``[p_class_1]``.
    """
    # The loaded model is bound to `classificator`; the former reference to
    # `classifier` was an undefined name.
    probabilities = classificator.predict_proba([peptide_sequence])
    return [row[1] for row in probabilities]


def calculate_best_mhc_tcr_reward(peptide_sequence, mhc_symbols, tcr_sequences):
    """Find the MHC/TCR pair that maximizes the product of their rewards.

    Args:
        peptide_sequence: peptide string to score.
        mhc_symbols: non-empty iterable of MHC symbols.
        tcr_sequences: non-empty iterable of TCR sequence strings.

    Returns:
        Tuple ``(best_reward, best_mhc, best_tcr, individual_mhc_reward,
        individual_tcr_reward)`` where ``best_reward`` is the largest
        ``mhc_reward * tcr_reward`` product, ``best_mhc`` / ``best_tcr`` are
        the pair that achieves it, and the two ``individual_*`` values are the
        maxima of the MHC and TCR rewards taken separately.

    Raises:
        ValueError: if either ``mhc_symbols`` or ``tcr_sequences`` is empty.
    """
    mhc_symbols = list(mhc_symbols)
    tcr_sequences = list(tcr_sequences)
    if not mhc_symbols or not tcr_sequences:
        raise ValueError("mhc_symbols and tcr_sequences must both be non-empty")

    # Both helpers score the whole collection at once and return one value per
    # entry; the singular `calculate_mhc_reward` called before does not exist.
    mhc_rewards = np.asarray(calculate_mhc_rewards(peptide_sequence, mhc_symbols), dtype=float)
    tcr_rewards = np.asarray(calculate_tcr_reward(peptide_sequence, tcr_sequences), dtype=float)

    # Entry [i, j] is the combined reward of MHC i with TCR j.
    total_rewards_matrix = np.outer(mhc_rewards, tcr_rewards)

    best_mhc_index, best_tcr_index = np.unravel_index(
        np.argmax(total_rewards_matrix), total_rewards_matrix.shape
    )
    best_mhc = mhc_symbols[best_mhc_index]
    best_tcr = tcr_sequences[best_tcr_index]
    best_reward = total_rewards_matrix[best_mhc_index, best_tcr_index]

    # Highest individual rewards, which need not belong to the best pair.
    individual_mhc_reward = np.max(mhc_rewards)
    individual_tcr_reward = np.max(tcr_rewards)

    return best_reward, best_mhc, best_tcr, individual_mhc_reward, individual_tcr_reward



def compute_reward(peptide_sequence, mhc_symbol, tcr_sequences):
    """Compute the reward of a peptide for one MHC symbol and a set of TCRs.

    The peptide is first passed through ``calculate_classification_reward``.
    If its class-1 probability is below 0.5 the peptide gets no reward;
    otherwise the best MHC/TCR reward is computed.

    Returns:
        ``(0, None, None, 0, 0)`` for a peptide below the classification
        cut-off, else the 5-tuple produced by
        ``calculate_best_mhc_tcr_reward`` for ``[mhc_symbol]`` and
        ``tcr_sequences``.
    """
    classification_reward = calculate_classification_reward(peptide_sequence)
    # The classifier helper returns a one-element list; comparing that list
    # with 0.5 directly raises TypeError, so reduce it to a scalar first.
    class_1_probability = float(np.ravel(classification_reward)[0])
    if class_1_probability < 0.5:
        return 0, None, None, 0, 0
    return calculate_best_mhc_tcr_reward(peptide_sequence, [mhc_symbol], tcr_sequences)
        
def modify_peptide(peptide_sequence, action):
    """Apply an action index to a peptide and return the resulting string.

    ``action`` encodes a position (``action // NUM_AMINO_ACIDS``) and a residue
    (``AMINO_ACIDS[action % NUM_AMINO_ACIDS]``). A position inside the peptide
    replaces the residue there. A position at or past the end adds the residue
    at the end, provided the peptide is shorter than ``MAX_LENGTH``. Otherwise
    the peptide is returned unchanged.
    """
    position, residue_index = divmod(action, NUM_AMINO_ACIDS)
    new_aa = AMINO_ACIDS[residue_index]
    current_length = len(peptide_sequence)

    if position < current_length:
        # Substitution in place.
        return peptide_sequence[:position] + new_aa + peptide_sequence[position + 1:]
    if current_length < MAX_LENGTH:
        # Position lies beyond the current end: grow the peptide by one residue.
        return peptide_sequence[:position] + new_aa + peptide_sequence[position:]
    return peptide_sequence
    
def select_action(state, q_network, epsilon):
    """Pick an action index with an epsilon-greedy rule.

    With probability ``epsilon`` a uniformly random index in
    ``[0, NUM_AMINO_ACIDS * MAX_LENGTH)`` is returned; otherwise the index of
    the largest value along dimension 0 of ``q_network(state)``.
    """
    explore = random.random() < epsilon
    if explore:
        return random.randrange(NUM_AMINO_ACIDS * MAX_LENGTH)
    # Greedy choice; no gradients are needed for action selection.
    with torch.no_grad():
        q_values = q_network(state)
        _, greedy_index = q_values.max(0)
    return greedy_index.item()

def perform_action(action, state, mhc_symbol, tcr_sequence, mhc_mapping):
    """Apply ``action`` to the peptide held in ``state`` and score the result.

    The first ``NUM_AMINO_ACIDS * MAX_LENGTH`` entries of ``state`` are decoded
    into a peptide. Actions below ``MAX_ACTIONS`` modify it; any other action
    leaves it as is. The resulting peptide is scored against ``mhc_symbol`` and
    ``tcr_sequence``, and a new state is built from the encoded peptide, MHC
    and TCR.

    Returns:
        ``(new_state, total_reward, best_mhc, best_tcr,
        individual_mhc_reward, individual_tcr_reward)``.
    """
    peptide_part = state[:NUM_AMINO_ACIDS * MAX_LENGTH]
    peptide_sequence = decode_state(peptide_part)
    if action < MAX_ACTIONS:
        new_peptide_sequence = modify_peptide(peptide_sequence, action)
    else:
        new_peptide_sequence = peptide_sequence

    reward_info = compute_reward(new_peptide_sequence, mhc_symbol, [tcr_sequence])
    total_reward, best_mhc, best_tcr, individual_mhc_reward, individual_tcr_reward = reward_info

    # State layout: peptide block, then MHC block, then TCR block.
    state_parts = [
        one_hot_encode(new_peptide_sequence),
        encode_mhc_sequence(mhc_symbol, mhc_mapping),
        encode_tcr_sequence(tcr_sequence),
    ]
    new_state = torch.cat(state_parts, dim=0)
    return new_state, total_reward, best_mhc, best_tcr, individual_mhc_reward, individual_tcr_reward


def needleman_wunsch_normalized(seq1, seq2, match=3, mismatch=1, gap=0):
    """Global alignment score of two sequences, normalized by the best possible score.

    A dynamic-programming table is filled in which aligning two positions adds
    ``match`` for equal characters or ``mismatch`` otherwise, and skipping a
    position in either sequence adds ``gap``. The final score is divided by
    ``max(len(seq1), len(seq2)) * match``.

    Returns:
        The normalized score as a NumPy float.
    """
    rows, cols = len(seq1), len(seq2)

    # scores[i, j] = best score aligning seq1[:i] with seq2[:j].
    scores = np.zeros((rows + 1, cols + 1))
    scores[:, 0] = np.arange(rows + 1) * gap
    scores[0, :] = np.arange(cols + 1) * gap

    for i, residue1 in enumerate(seq1, start=1):
        for j, residue2 in enumerate(seq2, start=1):
            pair_score = match if residue1 == residue2 else mismatch
            scores[i, j] = max(
                scores[i - 1, j - 1] + pair_score,  # align the two residues
                scores[i - 1, j] + gap,             # skip a residue of seq1
                scores[i, j - 1] + gap,             # skip a residue of seq2
            )

    best_possible = max(rows, cols) * match
    return scores[rows, cols] / best_possible


def train_q_network(mhc_mapping, tcr_sequence_list, max_selected_peptides=100,
                    gamma=0.99, max_usage_per_target=10,
                    log_path="training_log_dict.txt"):
    """Train a PeptideQNetwork online and collect high-reward peptides.

    Each episode starts from a random peptide, a random MHC symbol from
    ``mhc_mapping`` and a random TCR from ``tcr_sequence_list``. At every step
    an epsilon-greedy action modifies the peptide, the network is updated with
    a one-step Q-learning target, and the step is logged. Peptides whose total
    reward reaches ``REWARD_THRESHOLD`` are kept, provided they are not too
    similar to already kept peptides and their MHC / TCR has not been used
    ``max_usage_per_target`` times already.

    Args:
        mhc_mapping: non-empty mapping from MHC symbol to MHC sequence.
        tcr_sequence_list: non-empty sequence of TCR sequence strings.
        max_selected_peptides: training stops once this many peptides are kept.
        gamma: discount factor applied to the next state's best Q-value.
        max_usage_per_target: maximum number of kept peptides per MHC symbol
            and per TCR sequence.
        log_path: file that per-step records are appended to as CSV lines.

    Returns:
        Dict mapping each kept peptide to a dict with keys ``'total_reward'``,
        ``'mhc_reward'``, ``'mhc_protein'``, ``'tcr_reward'`` and
        ``'tcr_sequence'``.

    Raises:
        ValueError: if ``mhc_mapping`` or ``tcr_sequence_list`` is empty.
    """
    if len(mhc_mapping) == 0:
        raise ValueError("mhc_mapping must contain at least one MHC symbol")
    if len(tcr_sequence_list) == 0:
        raise ValueError("tcr_sequence_list must contain at least one TCR sequence")

    input_size = (NUM_AMINO_ACIDS * MAX_LENGTH) + (NUM_AMINO_ACIDS * MHC_LENGTH) + (NUM_AMINO_ACIDS * TCR_MAX_LENGTH)
    output_size = MAX_ACTIONS
    learning_rate = 0.001
    epsilon_start = 1.0
    epsilon_end = 0.01
    epsilon_decay = 300

    q_network = PeptideQNetwork(input_size, output_size)
    optimizer = optim.Adam(q_network.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss()

    peptide_rewards = {}
    data_buffer = []
    mhc_usage = {}
    tcr_usage = {}

    def is_too_similar(new_peptide, peptide_rewards, threshold):
        """True if any *other* kept peptide aligns to new_peptide above threshold."""
        # A peptide is always maximally similar to itself; skipping it keeps
        # the "replace with a higher reward" path below reachable.
        return any(
            existing_peptide != new_peptide
            and needleman_wunsch_normalized(new_peptide, existing_peptide) > threshold
            for existing_peptide in peptide_rewards
        )

    def flush_log():
        """Append all buffered step records to log_path and empty the buffer."""
        with open(log_path, "a") as file:
            for data in data_buffer:
                file.write(f"{','.join(map(str, data))}\n")
        data_buffer.clear()

    mhc_symbols = list(mhc_mapping.keys())

    for episode in tqdm(range(NUM_EPISODES), desc="Training Episodes"):
        initial_peptide = ''.join(random.choices(AMINO_ACIDS, k=random.randint(MIN_LENGTH, MAX_LENGTH)))
        mhc_symbol = random.choice(mhc_symbols)
        tcr_sequence = random.choice(tcr_sequence_list)

        state = torch.cat(
            [
                one_hot_encode(initial_peptide),
                encode_mhc_sequence(mhc_symbol, mhc_mapping),
                encode_tcr_sequence(tcr_sequence),
            ],
            dim=0,
        )

        for t in range(MAX_STEPS_PER_EPISODE):
            # NOTE(review): epsilon depends on the within-episode step `t`, so
            # it restarts at epsilon_start every episode. With the current
            # constants (decay 300, 100 steps) it never drops below ~0.72.
            # Confirm whether a global step counter was intended.
            epsilon = epsilon_end + (epsilon_start - epsilon_end) * math.exp(-1. * t / epsilon_decay)
            state_tensor = state.float()
            action = select_action(state_tensor, q_network, epsilon)

            (new_state, total_reward, best_mhc, best_tcr,
             individual_mhc_reward, individual_tcr_reward) = perform_action(
                action, state_tensor, mhc_symbol, tcr_sequence, mhc_mapping)
            new_peptide = decode_state(new_state[:NUM_AMINO_ACIDS * MAX_LENGTH])

            # One-step Q-learning target: r + gamma * max_a' Q(s', a').
            current_q = q_network(state_tensor)[action].unsqueeze(0)
            with torch.no_grad():
                # The target is a constant; no graph is needed for this pass.
                max_next_q = q_network(new_state.float()).max().item()
            expected_q = total_reward + (gamma * max_next_q)
            loss = loss_fn(current_q, torch.tensor([expected_q], dtype=torch.float))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            state = new_state

            data_buffer.append((episode, t, new_peptide, total_reward, loss.item(), best_mhc, best_tcr, individual_mhc_reward, individual_tcr_reward))

            within_usage_caps = (
                mhc_usage.get(best_mhc, 0) < max_usage_per_target
                and tcr_usage.get(best_tcr, 0) < max_usage_per_target
            )
            if total_reward >= REWARD_THRESHOLD and within_usage_caps:
                previous = peptide_rewards.get(new_peptide)
                improves = previous is None or previous['total_reward'] < total_reward
                if improves and not is_too_similar(new_peptide, peptide_rewards, SIMILARITY_THRESHOLD):
                    if previous is not None:
                        # Replacing an entry: release the usage it had claimed.
                        mhc_usage[previous['mhc_protein']] -= 1
                        tcr_usage[previous['tcr_sequence']] -= 1
                    peptide_rewards[new_peptide] = {
                        'total_reward': total_reward,
                        'mhc_reward': individual_mhc_reward,
                        'mhc_protein': best_mhc,
                        'tcr_reward': individual_tcr_reward,
                        'tcr_sequence': best_tcr
                    }
                    mhc_usage[best_mhc] = mhc_usage.get(best_mhc, 0) + 1
                    tcr_usage[best_tcr] = tcr_usage.get(best_tcr, 0) + 1

            if len(data_buffer) >= 1000:
                flush_log()

            if len(peptide_rewards) >= max_selected_peptides:
                print(f"Reached {max_selected_peptides} selected peptides. Terminating training.")
                break

        if len(peptide_rewards) >= max_selected_peptides:
            break

    # Write whatever is still buffered when training ends.
    flush_log()

    return peptide_rewards

#### Synthetic results: build the MHC mapping and run training

In [32]:
# `pd` is not imported by any cell shown above, so import it here to keep this
# cell runnable on a fresh kernel (re-importing is harmless).
import pandas as pd

# NOTE(review): absolute, machine-specific path — consider a configurable
# data directory so the notebook runs elsewhere.
synth = pd.read_csv('/home/sjurc/Documents/python/neoantigen/automate/results/synthetic_results.csv')


In [33]:
# Map each MHC molecule name to its sequence; later rows win on duplicate names.
mhc_mapping = {
    molecule: sequence
    for molecule, sequence in zip(synth['mhc_molecule'], synth['sequence'])
}

In [34]:
# Rebind mhc_mapping to a shallow copy with the same items.
mhc_mapping = dict(mhc_mapping.items())
# NOTE(review): tcr_sequence_list is not defined in the cells shown above; it
# must already exist in the kernel — confirm where it comes from.
tcr_sequences = tcr_sequence_list

# Run training and keep the selected peptides.
peptides_synth = train_q_network(mhc_mapping, tcr_sequences)

Training Episodes:   9%|████▌                                             | 923/10000 [47:21<7:45:41,  3.08s/it]

Reached 100 selected peptides. Terminating training.





In [35]:
# Display the dict returned by train_q_network in the training cell.
peptides_synth

{'CMMAKPGNIITR': {'total_reward': 0.8416502457058846,
  'mhc_reward': 0.9681164272540279,
  'mhc_protein': 'A*03:01',
  'tcr_reward': 0.8693688300415965,
  'tcr_sequence': 'CASTSLLLSSTYEQYF'},
 'TDGDAVLNINVW': {'total_reward': 0.843514191283461,
  'mhc_reward': 0.9935921757370347,
  'mhc_protein': 'B*44:02',
  'tcr_reward': 0.8489541402213159,
  'tcr_sequence': 'CASSQEGRDRGDEQYF'},
 'SDGPYVLHHNQW': {'total_reward': 0.8346809807867269,
  'mhc_reward': 0.9650600067725384,
  'mhc_protein': 'B*44:02',
  'tcr_reward': 0.864900601961696,
  'tcr_sequence': 'CASSQEGRDRGDEQYF'},
 'RHCCGVSYIKSL': {'total_reward': 0.8481511594888836,
  'mhc_reward': 0.9683243928166887,
  'mhc_protein': 'B*38:01',
  'tcr_reward': 0.8758956872105206,
  'tcr_sequence': 'CASSAGGEVEQFF'},
 'QHCCGQNRIESL': {'total_reward': 0.8930926329628337,
  'mhc_reward': 0.9979520491840042,
  'mhc_protein': 'B*38:01',
  'tcr_reward': 0.8949253961581511,
  'tcr_sequence': 'CASSAGGEVEQFF'},
 'HHCCTATRFCML': {'total_reward': 0.8114806

In [36]:
from collections import Counter

# Tally how often each MHC protein and each TCR sequence appears among the
# selected peptides.
mhc_counter = Counter(entry['mhc_protein'] for entry in peptides_synth.values())
tcr_seq_counter = Counter(entry['tcr_sequence'] for entry in peptides_synth.values())

mhc_counter, tcr_seq_counter

(Counter({'A*03:01': 2,
          'B*44:02': 3,
          'B*38:01': 6,
          'B*35:01': 3,
          'A*30:02': 4,
          'B*45:01': 6,
          'B*15:03': 7,
          'A*29:02': 3,
          'A*24:03': 1,
          'C*14:02': 2,
          'A*32:01': 1,
          'B*15:01': 5,
          'A*33:01': 1,
          'A*25:01': 1,
          'B*57:01': 10,
          'C*12:03': 9,
          'B*27:05': 3,
          'A*11:01': 2,
          'B*37:01': 1,
          'A*24:02': 6,
          'B*44:03': 1,
          'A*02:02': 1,
          'B*53:01': 3,
          'A*02:01': 1,
          'C*06:02': 2,
          'A*01:01': 2,
          'B*58:01': 4,
          'B*15:17': 3,
          'A*31:01': 3,
          'B*54:01': 4}),
 Counter({'CASTSLLLSSTYEQYF': 1,
          'CASSQEGRDRGDEQYF': 3,
          'CASSAGGEVEQFF': 4,
          'CARRTGGSAFF': 6,
          'CSARDVSPGTQYF': 3,
          'CATALRDRVDQPQHF': 6,
          'CASSLGGGNQPQHF': 2,
          'CASSLALAYGYTF': 3,
          'CASTPGMGGTYEQYF': 2