In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import esm
from transformers import AutoTokenizer, AutoModel
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

class AntibodyDataset(Dataset):
    """Dataset of tokenized antibody sequences paired with scaled assay targets.

    Each item is a dict holding the tokenizer's ``input_ids`` and
    ``attention_mask`` (first dimension squeezed away) and ``targets``, a
    float32 tensor with one standardized value per entry of
    ``TARGET_COLUMNS``.
    """

    # Order of the regression targets in ``self.targets`` and in every item.
    TARGET_COLUMNS = ['% Aggregate', 'KD (nM)', 'Tm1', 'Tm2']

    def __init__(self, seq_df, asec_df, binding_df, stability_df, tokenizer, max_length=1024):
        """
        Dataset class for antibody sequences and their properties

        Args:
            seq_df: DataFrame with 'Sample ID', 'VH_seq' and 'VL_seq' columns
            asec_df: DataFrame with 'ID' and '% Aggregate' columns
            binding_df: DataFrame with 'Sample ID' and 'KD (nM)' columns
            stability_df: DataFrame with 'Sample', 'Tm1' and 'Tm2' columns
            tokenizer: ESM-2 tokenizer
            max_length: Maximum sequence length for padding/truncation
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.process_data(seq_df, asec_df, binding_df, stability_df)

    @staticmethod
    def _impute_targets(frame, columns):
        """Return a copy of ``frame`` whose ``columns`` are numeric and NaN-free.

        Every listed column is coerced with ``pd.to_numeric`` (unparseable
        entries become NaN) and missing values are replaced by the column
        mean. A column with no usable value at all is filled with 0.0 so NaN
        cannot reach the scaler or the loss. Other columns are left untouched.
        """
        cleaned = frame.copy()
        for column in columns:
            values = pd.to_numeric(cleaned[column], errors='coerce')
            fill_value = values.mean()
            if pd.isna(fill_value):
                # No measurement at all for this target: the mean is undefined.
                fill_value = 0.0
            cleaned[column] = values.fillna(fill_value)
        return cleaned

    def process_data(self, seq_df, asec_df, binding_df, stability_df):
        """Merge the assay tables onto the sequences and build scaled targets.

        Sets ``self.sequences``, ``self.processed_data``, ``self.scaler`` and
        ``self.targets``.
        """
        # Work on a copy so the caller's DataFrame is not mutated.
        self.sequences = seq_df.copy()

        # Combine VH and VL sequences with a linker
        self.sequences['combined_sequence'] = (
            self.sequences['VH_seq'] +
            'GGGGS' +  # Add a standard linker
            self.sequences['VL_seq']
        )

        # Left joins keep every sequence row even when an assay has no
        # matching record; the gaps are imputed below.
        merged_data = pd.merge(
            self.sequences,
            asec_df[['ID', '% Aggregate']],
            left_on='Sample ID',
            right_on='ID',
            how='left'
        )

        merged_data = pd.merge(
            merged_data,
            binding_df[['Sample ID', 'KD (nM)']],
            on='Sample ID',
            how='left'
        )

        merged_data = pd.merge(
            merged_data,
            stability_df[['Sample', 'Tm1', 'Tm2']],
            left_on='Sample ID',
            right_on='Sample',
            how='left'
        )

        # Coerce and impute all four targets the same way (Tm1/Tm2 included),
        # so a stray non-numeric entry in any of them cannot break scaling.
        self.processed_data = self._impute_targets(merged_data, self.TARGET_COLUMNS)

        # Scale targets. The scaler is fit on every row of this dataset, so
        # its statistics cover all samples regardless of any later split.
        self.scaler = StandardScaler()
        self.targets = self.scaler.fit_transform(
            self.processed_data[self.TARGET_COLUMNS].values
        )

    def __len__(self):
        """Return the number of rows in the merged table."""
        return len(self.processed_data)

    def __getitem__(self, idx):
        """Tokenize the sequence at position ``idx`` and attach its targets."""
        # Select the column first: avoids building a full row Series per item.
        sequence = self.processed_data['combined_sequence'].iloc[idx]

        # Tokenize sequence, padded/truncated to a fixed length
        inputs = self.tokenizer(
            sequence,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Get target values
        targets = torch.tensor(self.targets[idx], dtype=torch.float32)

        return {
            # squeeze(0) drops the first dimension of the tokenizer output
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'targets': targets
        }

class ESM2ForAntibodyPrediction(nn.Module):
    """Pretrained ESM-2 encoder with an MLP regression head on top."""

    def __init__(self, model_name="facebook/esm2_t33_650M_UR50D", num_labels=4,
                 num_layers_to_freeze=8, dropout=0.1):
        """
        Fine-tuning ESM-2 for antibody property prediction

        Args:
            model_name: Name of the ESM-2 model to use
            num_labels: Number of properties to predict
            num_layers_to_freeze: How many leading encoder layers to freeze
                (the embeddings are always frozen)
            dropout: Dropout probability used before and inside the head
        """
        super().__init__()

        # Load pretrained ESM-2 model
        self.esm2 = AutoModel.from_pretrained(model_name)

        # Freeze the embeddings and the first encoder layers
        self.freeze_layers(num_layers_to_freeze=num_layers_to_freeze)

        # Add prediction head. The Sequential layout is kept as-is so that
        # existing checkpoints (classifier.0 / .3 / .6) still load.
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Sequential(
            nn.Linear(self.esm2.config.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_labels)
        )

    def freeze_layers(self, num_layers_to_freeze):
        """Freeze the embeddings and the first n transformer layers.

        Args:
            num_layers_to_freeze: Number of leading encoder layers whose
                parameters get ``requires_grad = False``. A value larger than
                the layer count freezes every layer.

        Raises:
            ValueError: If ``num_layers_to_freeze`` is negative. A negative
                slice bound would silently freeze "all but the last n"
                layers instead of the first n.
        """
        if num_layers_to_freeze < 0:
            raise ValueError(
                f"num_layers_to_freeze must be >= 0, got {num_layers_to_freeze}"
            )

        modules_to_freeze = [
            self.esm2.embeddings,
            *self.esm2.encoder.layer[:num_layers_to_freeze]
        ]
        for module in modules_to_freeze:
            for param in module.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """Run the encoder and predict the properties.

        Args:
            input_ids: Token ids passed through to the encoder
            attention_mask: Attention mask passed through to the encoder

        Returns:
            Output of the prediction head applied to the hidden state at
            sequence position 0; its last dimension is ``num_labels``.
        """
        # Get ESM-2 embeddings
        outputs = self.esm2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Use the representation at position 0 (CLS token) as the summary
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)

        # Predict properties
        logits = self.classifier(pooled_output)

        return logits

def _batch_loss(model, batch, criterion, device):
    """Move one batch to ``device`` and return the criterion's loss for it."""
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    targets = batch['targets'].to(device)
    return criterion(model(input_ids, attention_mask), targets)


def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=2e-5,
                save_path='best_esm2_model.pth'):
    """Training loop with validation.

    Trains with AdamW (head at ``learning_rate``, encoder at a tenth of it),
    cosine learning-rate annealing over ``num_epochs`` and an MSE loss. The
    state dict with the lowest mean validation loss is saved to ``save_path``.

    Args:
        model: Module exposing ``classifier`` and ``esm2`` submodules and
            callable as ``model(input_ids, attention_mask)``
        train_loader: Sized iterable of batches with 'input_ids',
            'attention_mask' and 'targets' entries
        val_loader: Sized iterable of validation batches, same layout
        num_epochs: Number of passes over ``train_loader``
        learning_rate: Learning rate of the prediction head
        save_path: Where the best state dict is written

    Returns:
        Dict with per-epoch 'train_loss' and 'val_loss' lists (mean loss per
        batch) and the scalar 'best_val_loss'.

    Raises:
        ValueError: If either loader yields no batches.
    """
    # Fail before any work: an empty loader would otherwise end in a
    # ZeroDivisionError when the epoch averages are computed.
    if len(train_loader) == 0:
        raise ValueError("train_loader must contain at least one batch")
    if len(val_loader) == 0:
        raise ValueError("val_loader must contain at least one batch")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Initialize optimizer with weight decay; the pretrained encoder gets a
    # 10x smaller learning rate than the freshly initialized head.
    optimizer = torch.optim.AdamW(
        [
            {'params': model.classifier.parameters(), 'lr': learning_rate},
            {'params': model.esm2.parameters(), 'lr': learning_rate/10}
        ],
        weight_decay=0.01
    )

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs
    )

    criterion = nn.MSELoss()
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0

        for batch in train_loader:
            optimizer.zero_grad()
            loss = _batch_loss(model, batch, criterion, device)

            loss.backward()
            # Clip the global gradient norm to 1.0 before the update
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()

        # Validation phase
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                val_loss += _batch_loss(model, batch, criterion, device).item()

        mean_train_loss = train_loss / len(train_loader)
        mean_val_loss = val_loss / len(val_loader)
        history['train_loss'].append(mean_train_loss)
        history['val_loss'].append(mean_val_loss)

        # Update learning rate
        scheduler.step()

        # Save best model
        if mean_val_loss < best_val_loss:
            best_val_loss = mean_val_loss
            torch.save(model.state_dict(), save_path)

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Training Loss: {mean_train_loss:.4f}')
        print(f'Validation Loss: {mean_val_loss:.4f}')
        # Index 0 is the first param group, i.e. the prediction head
        print(f'Learning Rate: {scheduler.get_last_lr()[0]:.2e}')

    history['best_val_loss'] = best_val_loss
    return history

def main():
    """Load the assay CSVs, build the dataset and loaders, and fine-tune ESM-2."""
    # Input tables, read in this order: sequences, aggregation, binding, stability
    csv_paths = (
        'antibody_sequences.csv',
        'asec_data.csv',
        'binding_data.csv',
        'stability_data.csv',
    )
    sequences, aggregation, binding, stability = (
        pd.read_csv(path) for path in csv_paths
    )

    # Tokenizer matching the pretrained checkpoint
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    full_dataset = AntibodyDataset(sequences, aggregation, binding, stability, tokenizer)

    # 80/20 random train/validation split
    n_train = int(0.8 * len(full_dataset))
    n_val = len(full_dataset) - n_train
    train_split, val_split = torch.utils.data.random_split(
        full_dataset, [n_train, n_val]
    )

    # Both loaders share batch size and worker count; only the training
    # loader shuffles. Batch size is kept small due to model size.
    shared_options = {'batch_size': 8, 'num_workers': 4}
    train_loader = DataLoader(train_split, shuffle=True, **shared_options)
    val_loader = DataLoader(val_split, shuffle=False, **shared_options)

    # Build the model with its defaults and train it
    train_model(ESM2ForAntibodyPrediction(), train_loader, val_loader)

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()