# Train VAE to generate sequences

### dataset

In [None]:
import os
# Expose only GPU index 2 to this process. The variable is set before torch is
# imported (later in this file) so it is in place before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import pandas as pd

# Read the labelled sequence table and keep the two columns used downstream.
df = pd.read_csv("/data/human_virus_600k_seq_label_20aa.csv")
seq_20aa = df['sequence'].tolist()
label_seq = df['label'].tolist()
# Binary target: 1 for rows labelled 'human', 0 for every other label value.
label_20aa = [int(v == 'human') for v in label_seq]

### To accelerate training, we compute the embeddings of the whole dataset first.

To speed up training, we precompute the embeddings once, processing the sequences in batches to avoid memory overflow.

In [None]:
import torch
from transformers import AutoTokenizer, EsmModel
from tqdm import tqdm

seq = seq_20aa

# ESM tokenizer; the model can be downloaded from
# https://huggingface.co/facebook/esm2_t33_650M_UR50D
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Either the original ESM or the LoRA post-trained ESM checkpoint can be used
# here to extract the embedding features.
ESMmodel = EsmModel.from_pretrained("./post_train_esm/checkpoint-14980").to(device)
ESMmodel = torch.nn.DataParallel(ESMmodel)

batch_size = 4000  # Adjust this value according to your memory situation

batches_tcra = [seq[i:i+batch_size] for i in range(0, len(seq), batch_size)]

all_tcra_last_hidden_states = []

for batch_tcra in tqdm(batches_tcra):
    # Every sequence is padded/truncated to exactly 22 tokens (max_length=20+2;
    # presumably 20 residues plus two special tokens -- confirm against the tokenizer).
    batch_tcra_inputs = tokenizer(batch_tcra, return_tensors="pt", padding='max_length', truncation=True, max_length=20+2).to(device)

    # Inference only: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        batch_tcra_outputs = ESMmodel(**batch_tcra_inputs)

    # Move the result to host memory right away so it does not accumulate on the GPU.
    last_hidden_state = batch_tcra_outputs.last_hidden_state.cpu()
    all_tcra_last_hidden_states.append(last_hidden_state)


total_seq_last_hidden_states = torch.cat(all_tcra_last_hidden_states, dim=0)

# Save three row shards with the last token position removed ([:, :-1, :]).
# Each slice is a *view* of the full tensor, and torch.save serialises the whole
# underlying storage of a view. Without .clone() every shard file would contain
# the embeddings of the entire dataset, so clone first to write only the slice.
torch.save(total_seq_last_hidden_states[:200000,:-1,:].clone(),"esm_human_virus_0to200000")
torch.save(total_seq_last_hidden_states[200000:400000,:-1,:].clone(),"esm_human_virus_200000to400000")
torch.save(total_seq_last_hidden_states[400000:,:-1,:].clone(),"esm_human_virus_400000to610760")

In [None]:
import torch

# Loading all three shards and concatenating them needs a large amount of host
# memory (the original note estimated a ~200G footprint).
# Each shard must be loaded from its own file: the file names match the three
# torch.save calls that wrote them. Loading the first shard three times would
# yield 600000 duplicated rows that no longer line up with the labels.
esm_data_tmp1 = torch.load("esm_human_virus_0to200000")
esm_data_tmp2 = torch.load("esm_human_virus_200000to400000")
esm_data_tmp3 = torch.load("esm_human_virus_400000to610760")

# Row order is shard 1, shard 2, shard 3, i.e. the original dataset order.
esm_data=torch.cat((esm_data_tmp1,esm_data_tmp2,esm_data_tmp3),dim=0)

# The per-shard tensors are no longer needed once concatenated; release them.
del esm_data_tmp1
del esm_data_tmp2
del esm_data_tmp3

# Kept as the last expression so a notebook displays the resulting shape.
esm_data.shape

In [None]:
# Human Virus dataset & dataloader
import encoding_matrix
from data_util import seq_encoding_with_matrix, data_padding

# Encode every peptide sequence with the numeric lookup matrix.
train_peptide_human_virus = [
    seq_encoding_with_matrix(peptide, matrix=encoding_matrix.NUMBER_MATRIX)
    for peptide in seq_20aa
]

# A dummy 20-residue sequence is placed in front of the real data before padding
# and its row is sliced off again right afterwards ([1:]). Presumably this pins
# the padded length to at least 20 -- confirm against data_padding.
padding_seq = 'A' * 20
padding_seq_NUMBER = [
    seq_encoding_with_matrix(padding_seq, matrix=encoding_matrix.NUMBER_MATRIX)
]

padding_type = 'end'
all_NUMBER_encoding = data_padding(
    padding_seq_NUMBER + train_peptide_human_virus,
    padding_type,
)[1:, :, :]

import numpy as np
from sklearn.model_selection import train_test_split

# First split: 70 % train / 30 % held out. Encodings, labels and ESM embeddings
# are split together so the three stay index-aligned.
(
    train_pep, X_temp,
    train_labels, label_temp,
    train_esm, esm_temp,
) = train_test_split(
    all_NUMBER_encoding, label_20aa, esm_data,
    test_size=0.3, random_state=42,
)
# Second split: the held-out 30 % is halved into validation and test sets.
(
    valid_pep, test_pep,
    valid_labels, test_labels,
    valid_esm, test_esm,
) = train_test_split(
    X_temp, label_temp, esm_temp,
    test_size=0.5, random_state=42,
)


import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class TrainDataset(torch.utils.data.Dataset):
    """Index-aligned dataset yielding ``(tcra[i], esm[i], label[i])`` triples.

    ``tcra`` and ``esm`` are stored as given and only need to support integer
    indexing; ``label`` is converted to a ``torch.long`` tensor once, up front.
    The dataset length is taken from ``tcra``.

    NOTE(review): the parameter names do not constrain what is stored. The
    construction site in this file passes the ESM embeddings as ``tcra`` and
    the encoded peptides as ``esm``, so each item there is
    ``(embedding, encoded peptide, label)``.
    """

    def __init__(self, tcra, esm, label):
        self.esm = esm
        self.tcra_inputs = tcra
        self.label = torch.tensor(label, dtype=torch.long)

    def __len__(self):
        return len(self.tcra_inputs)

    def __getitem__(self, index):
        sample = (self.tcra_inputs[index], self.esm[index], self.label[index])
        return sample


# The ESM embeddings go into the `tcra` slot and the encoded peptides into the
# `esm` slot, so every item is (embedding, encoded peptide, label).
train_dataset = TrainDataset(tcra=train_esm, esm=train_pep, label=train_labels)
valid_dataset = TrainDataset(tcra=valid_esm, esm=valid_pep, label=valid_labels)
test_dataset = TrainDataset(tcra=test_esm, esm=test_pep, label=test_labels)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=64, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)


### VAE model

Actually, we only train the decoder and the classifier; the encoder is kept fixed because of computational resource constraints.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def apply_rope(x, seq_len, head_dim):
    """Apply a rotary position embedding along the last dimension of ``x``.

    Args:
        x: tensor whose last two dimensions are ``(seq_len, head_dim)``; the
            cos/sin tables are broadcast over two leading dimensions.
        seq_len: number of positions; position ``p`` is rotated by ``p * theta``.
        head_dim: size of the last dimension of ``x`` (expected to be even so the
            even- and odd-indexed halves have the same width).

    Returns:
        Tensor of the same shape as ``x``. The rotated even-indexed features
        occupy the first half of the last dimension and the rotated odd-indexed
        features the second half.
    """
    dev = x.device

    # One frequency per feature pair: 1 / 10000^(2k / head_dim).
    pair_index = torch.arange(0, head_dim, 2, device=dev).float()
    inv_freq = 1.0 / (10000 ** (pair_index / head_dim))

    # (seq_len, 1) * (head_dim / 2,) -> (seq_len, head_dim / 2) rotation angles.
    positions = torch.arange(0, seq_len, device=dev).float().unsqueeze(1)
    angles = positions * inv_freq
    cos = torch.cos(angles)[None, None]
    sin = torch.sin(angles)[None, None]

    even = x[..., ::2]
    odd = x[..., 1::2]
    rotated_even = even * cos - odd * sin
    rotated_odd = even * sin + odd * cos
    return torch.cat([rotated_even, rotated_odd], dim=-1)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with rotary position embeddings.

    Queries and keys are rotated with ``apply_rope`` before the dot product.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear projections for queries, keys, and values
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Output linear projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: ``(batch_size, q_len, embed_dim)``.
            key: ``(batch_size, kv_len, embed_dim)``.
            value: ``(batch_size, kv_len, embed_dim)``.
            mask: optional tensor broadcastable to
                ``(batch_size, num_heads, q_len, kv_len)``; positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            ``(output, attention_weights)`` with shapes
            ``(batch_size, q_len, embed_dim)`` and
            ``(batch_size, num_heads, q_len, kv_len)``.
        """
        batch_size, q_len, _ = query.size()
        # Keys/values may have a different length from the queries
        # (cross-attention), so reshape them with their own lengths instead of
        # reusing the query length.
        k_len = key.size(1)
        v_len = value.size(1)

        # Linear projections
        Q = self.q_proj(query)  # (batch_size, q_len, embed_dim)
        K = self.k_proj(key)    # (batch_size, kv_len, embed_dim)
        V = self.v_proj(value)  # (batch_size, kv_len, embed_dim)

        # Split into multiple heads: (batch_size, num_heads, len, head_dim)
        Q = Q.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, v_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to queries and keys, each over its own position range
        Q = apply_rope(Q, q_len, self.head_dim)
        K = apply_rope(K, k_len, self.head_dim)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (batch_size, num_heads, q_len, kv_len)

        if mask is not None:
            # Masked positions get -inf so they receive zero weight after softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)  # (batch_size, num_heads, q_len, kv_len)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values
        attention_output = torch.matmul(attention_weights, V)  # (batch_size, num_heads, q_len, head_dim)

        # Concatenate heads and project
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.embed_dim)  # (batch_size, q_len, embed_dim)
        output = self.out_proj(attention_output)  # (batch_size, q_len, embed_dim)

        return output, attention_weights
    


In [None]:
import torch.nn as nn
class Decoder_esm_to_seq_with_Contrastive(nn.Module):
    """Decode per-position ESM embeddings into amino-acid logits (21 classes).

    The input is first passed through a linear re-projection. The re-projected
    embedding is returned alongside the logits so a contrastive loss can be
    applied to it.

    Args:
        decoder_hidden_size: width of the incoming embeddings (e.g. 1280).
        decoder_num_heads: number of heads of the attention sub-module.
    """

    def __init__(self, decoder_hidden_size, decoder_num_heads):
        super().__init__()

        # parameter
        self.decoder_hidden_size = decoder_hidden_size # 1280
        self.decoder_num_heads = decoder_num_heads

        # network
        self.re_pos_net_linear = nn.Linear(self.decoder_hidden_size,self.decoder_hidden_size)
        # Honour the constructor argument instead of a hard-coded head count of 4.
        # The layer is currently not used in forward() but is kept so that the
        # module's parameter set (state_dict keys) stays the same.
        self.re_pos_net_attn = MultiHeadAttention(self.decoder_hidden_size, self.decoder_num_heads)

        self.hidden_to_aa = nn.Linear(self.decoder_hidden_size, 21)


    def forward(self, esm_x, label):
        """Return ``(logits, re-projected embedding, label)``.

        Args:
            esm_x: ``[batch_size, max_seq_length, decoder_hidden_size]``.
            label: passed through unchanged.

        Returns:
            ``decoder_ids`` of shape ``[batch_size, 21, max_seq_length]``, the
            re-projected ``esm_x`` (same shape as the input), and ``label``.
        """
        esm_x = self.re_pos_net_linear(esm_x)
        # Attention between positions is intentionally disabled: adding attention
        # between amino acids greatly reduces the final reconstruction accuracy.
        # esm_x,_ = self.re_pos_net_attn(esm_x, esm_x, esm_x)
        decoder_ids = self.hidden_to_aa(esm_x) # [batch_size, max_seq_length, decoder_hidden_size] -> [batch_size, max_seq_length, 21]
        decoder_ids = nn.functional.elu(decoder_ids)
        # Put the class dimension second: [batch_size, max_seq_length, 21] -> [batch_size, 21, max_seq_length]
        decoder_ids = decoder_ids.permute(0,2,1)
        return decoder_ids, esm_x, label


def loss_function2(recon_x, x, re_pos_esm, labels, margin=200):
    """Reconstruction loss plus a pairwise contrastive loss on the embeddings.

    Args:
        recon_x: predicted logits with the class dimension second.
        x: target class probabilities (e.g. one-hot, float), same shape as
            ``recon_x``.
        re_pos_esm: per-sample embeddings; flattened to ``(batch, -1)``.
        labels: 1-D tensor with one class id per sample.
        margin: samples with different labels are pushed apart until their
            Euclidean distance reaches this value.

    Returns:
        ``(total, reconstruction, contrastive)`` where ``total`` is the sum of
        the other two. The reconstruction term is a label-smoothed (0.2)
        cross-entropy despite the historical ``BCE`` name.
    """
    BCE = nn.functional.cross_entropy(recon_x, x, label_smoothing=0.2)

    # Work on the device of the embeddings rather than on a module-level global,
    # so the function is self-contained and labels may arrive on any device.
    target_device = re_pos_esm.device
    re_pos_esm = re_pos_esm.reshape(re_pos_esm.shape[0], -1)

    # pairwise_matrix[i, j] is 1.0 when samples i and j share a label, else 0.0.
    categories = labels.to(target_device).unsqueeze(1)
    pairwise_matrix = (categories == categories.T).float()
    distance_matrix = torch.cdist(re_pos_esm, re_pos_esm, p=2)

    # Same-label pairs are pulled together; different-label pairs are penalised
    # only while they are closer than `margin`.
    positive_loss = pairwise_matrix * torch.pow(distance_matrix, 2)
    negative_loss = (1 - pairwise_matrix) * torch.pow(torch.clamp(margin - distance_matrix, min=0.0), 2)
    CL = torch.mean(positive_loss + negative_loss)

    return BCE + CL, BCE, CL


def loss_function(recon_x, x):
    BCE = nn.functional.cross_entropy(recon_x, x, label_smoothing=0.2)
    return BCE

# Decoder over 1280-wide embeddings with 4 attention heads, placed on `device`.
model = Decoder_esm_to_seq_with_Contrastive(
    decoder_hidden_size=1280,
    decoder_num_heads=4,
).to(device)


In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 6
for epoch in range(num_epochs):
    model.train()
    # Each batch is (embedding, encoded sequence, binary class label).
    for i, (embed, labels, l_binary) in enumerate(train_loader):
        # Forward pass. The first position of the stored embedding is dropped
        # (presumably the leading special token -- confirm against the tokenizer).
        outputs, esm_x, label = model(embed[:,1:,:].float().to(device), l_binary)
        # One-hot targets with the class dimension second, matching `outputs`.
        labels_onehot = nn.functional.one_hot(labels.squeeze().to(torch.int64), num_classes=21).permute(0,2,1)
        # loss = loss_function(recon_x = outputs, x = labels_onehot.squeeze().to(device).float()) # only compute reconstruction loss
        loss, BCE, CL = loss_function2(recon_x = outputs,
                              x = labels_onehot.squeeze().to(device).float(),
                              re_pos_esm=esm_x,
                              labels=l_binary)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # NOTE: the statistics below describe only the LAST training batch of the epoch.
    # diff_mask marks positions where the predicted residue differs from the target.
    diff_mask = torch.ne(outputs.argmax(dim=1).int().cpu(), labels_onehot.argmax(dim=1).int().cpu())
    diff_count = torch.mean(torch.sum(diff_mask, dim=1).float())
    max_count = torch.max(torch.sum(diff_mask, dim=1).float())
    mse = nn.functional.mse_loss(outputs.argmax(dim=1).float().cpu(), labels_onehot.argmax(dim=1).float().cpu())
    print("train reconstruction",diff_count, 'max', max_count, 'mse', mse)
    print("loss",loss.cpu().item())

    # Run validation every third epoch.
    if (epoch + 1) % 3 ==0:
        model.eval()
        # Currently unused placeholders, kept in case later cells refer to them.
        valid_probs = []
        valid_cls_labels = []
        # No gradients are needed for validation; this avoids building the
        # autograd graph and the memory that goes with it.
        with torch.no_grad():
            for i, (embed, labels, l_binary) in enumerate(valid_loader):
                # Forward pass
                outputs, esm_x, label = model(embed[:,1:,:].float().to(device), l_binary)
                labels_onehot = nn.functional.one_hot(labels.squeeze().to(torch.int64), num_classes=21).permute(0,2,1)
                # loss = loss_function(recon_x = outputs, x = labels_onehot.squeeze().to(device).float())
                loss, BCE, CL = loss_function2(recon_x = outputs, x = labels_onehot.squeeze().to(device).float(),
                              re_pos_esm=esm_x,
                              labels=l_binary)

        # NOTE: as for training, these numbers describe only the LAST validation batch.
        print("                valid loss",loss.cpu().item())
        diff_mask = torch.ne(outputs.argmax(dim=1).int().cpu(), labels_onehot.argmax(dim=1).int().cpu())
        diff_count = torch.mean(torch.sum(diff_mask, dim=1).float())
        max_count = torch.max(torch.sum(diff_mask, dim=1).float())
        mse = nn.functional.mse_loss(outputs.argmax(dim=1).float().cpu(), labels_onehot.argmax(dim=1).float().cpu())
        print("      valid reconstruction",diff_count, 'max', max_count, 'mse', mse)