# Model training
General training procedure:  
Step 1: Pretraining with masked-residue prediction  
Prepare your protein sequence dataset.  
Mask residues randomly (e.g., 15% of tokens) and train the model to predict the masked residues.  
Save the model weights after pretraining.  
Step 2: Fine-Tuning on Contact Map Prediction  
Prepare your labeled dataset with protein sequences and corresponding contact maps.  
Initialize your model with the pretrained weights.  
Replace the output layer (if necessary) with one suited for contact-map prediction (e.g., a pairwise scoring mechanism or CNN layers that predict an N×N contact map, where N is the sequence length).  
Train on the contact map data in a supervised manner.  

In [1]:
# Example code for masked-residue pretraining.
# Generated with ChatGPT.
# Note: this code has a problem with how it handles padding, but the essence is there.

import torch
import torch.nn as nn
import torch.optim as optim
import random

# Special tokens and vocabulary: the 20 one-letter amino-acid codes, followed by
# the mask token (which therefore gets the last id).
MASK_TOKEN = "<MASK>"
VOCAB = list("ACDEFGHIKLMNPQRSTVWY") + [MASK_TOKEN]  # Example vocabulary
# Bidirectional lookup tables between tokens and integer ids (id == list index).
TOKEN_TO_ID = dict(zip(VOCAB, range(len(VOCAB))))
ID_TO_TOKEN = dict(enumerate(VOCAB))

# Function to generate masked sequences
def mask_sequence(sequence, mask_fraction=0.15, *, rng=None, mask_token=None):
    """
    Mask a fraction of the amino acids in the sequence with the <MASK> token.

    Exactly ``int(len(sequence) * mask_fraction)`` distinct positions are chosen
    uniformly at random (without replacement) and replaced by the mask token.
    For short sequences this count can be 0, in which case nothing is masked.

    Args:
        sequence (str): Input amino acid sequence.
        mask_fraction (float): Fraction of tokens to mask; must lie in [0, 1].
        rng (random.Random, optional): Source of randomness. Defaults to the
            module-level ``random`` functions; pass a seeded ``random.Random``
            for reproducible masking.
        mask_token (str, optional): Token written at masked positions.
            Defaults to ``MASK_TOKEN``.

    Returns:
        tuple: Masked sequence (list of tokens) and target (list of original tokens).

    Raises:
        ValueError: If ``mask_fraction`` is outside [0, 1] (or is NaN).
    """
    # Validate up front: otherwise some bad values (e.g. a negative fraction on a
    # short sequence) silently mask nothing, and others fail inside random.sample.
    if not 0.0 <= mask_fraction <= 1.0:
        raise ValueError(f"mask_fraction must be in [0, 1], got {mask_fraction!r}")

    sampler = random if rng is None else rng
    token = MASK_TOKEN if mask_token is None else mask_token

    tokens = list(sequence)
    target = tokens.copy()

    num_to_mask = int(len(tokens) * mask_fraction)
    for idx in sampler.sample(range(len(tokens)), num_to_mask):
        tokens[idx] = token

    return tokens, target

# Create synthetic dataset
def create_dataset(sequences, mask_fraction=0.15):
    """
    Build masked-language-model training pairs from raw sequences.

    Each sequence is passed through ``mask_sequence`` independently and in
    input order, so the i-th output pair corresponds to the i-th sequence.

    Args:
        sequences (list): Amino acid sequences (strings).
        mask_fraction (float): Fraction of tokens to mask in each sequence.

    Returns:
        list: ``(masked_sequence, target_sequence)`` pairs, both as token lists.
    """
    return [
        (masked, target)
        for masked, target in (mask_sequence(seq, mask_fraction) for seq in sequences)
    ]

# Define a simple model
class ProteinMLMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4), num_layers=2
        )
        self.fc = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x)
        transformed = self.transformer(embedded)
        logits = self.fc(transformed)
        return logits

# Training loop
def train_model(model, dataset, epochs=10, batch_size=8, lr=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0
        for i in range(0, len(dataset), batch_size):
            batch = dataset[i:i+batch_size]
            masked_seqs, target_seqs = zip(*batch)

            # Convert tokens to IDs
            masked_ids = [torch.tensor([TOKEN_TO_ID[tok] for tok in seq]) for seq in masked_seqs]
            target_ids = [torch.tensor([TOKEN_TO_ID[tok] for tok in seq]) for seq in target_seqs]

            # Pad sequences to the same length
            masked_ids = nn.utils.rnn.pad_sequence(masked_ids, batch_first=True, padding_value=TOKEN_TO_ID[MASK_TOKEN])
            target_ids = nn.utils.rnn.pad_sequence(target_ids, batch_first=True, padding_value=-100)

            optimizer.zero_grad()

            # Forward pass
            logits = model(masked_ids)

            # Reshape logits and targets for loss computation
            logits = logits.view(-1, logits.size(-1))
            target_ids = target_ids.view(-1)

            loss = criterion(logits, target_ids)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataset):.4f}")

# Example usage: a tiny end-to-end pretraining run on toy data.
if __name__ == "__main__":
    # Three short toy sequences written with the one-letter vocabulary above.
    sequences = [
        "ACDEFGHIKLMNPQRSTVWY",
        "LMNPQRSTVWYACDEFGHI",
        "QRSTVWYACDEFGHIKLMN",
    ]

    # Turn each sequence into a (masked tokens, original tokens) pair.
    dataset = create_dataset(sequences, mask_fraction=0.15)

    # Small model sized for the demo; vocab_size covers every token id.
    model = ProteinMLMModel(
        vocab_size=len(VOCAB),
        embedding_dim=32,
        hidden_dim=64,
    )

    # Short training run; per-epoch losses are printed by train_model.
    train_model(model, dataset, epochs=5)




Epoch 1/5, Loss: 1.0440
Epoch 2/5, Loss: 0.8455
Epoch 3/5, Loss: 0.6898
Epoch 4/5, Loss: 0.5936
Epoch 5/5, Loss: 0.5116
