In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import warnings
warnings.filterwarnings('ignore')

torch.manual_seed(42)
np.random.seed(42)

class ProteinDataset(Dataset):
    """Dataset class for protein sequences and their per-residue labels"""

    def __init__(self, sequences, dssp3_labels, dssp8_labels):
        self.sequences = sequences
        self.dssp3_labels = dssp3_labels
        self.dssp8_labels = dssp8_labels

        # Amino acid to integer mapping (20 standard amino acids + unknown)
        # X for unknown
        self.aa_to_idx = {
            'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'Q': 5, 'E': 6, 'G': 7,
            'H': 8, 'I': 9, 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14,
            'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, 'X': 20
        }

        # Label mappings
        self.dssp3_to_idx = {'H': 0, 'E': 1, 'C': 2}
        self.dssp8_to_idx = {'H': 0, 'B': 1, 'E': 2, 'G': 3, 'I': 4, 'T': 5, 'S': 6, '-': 7}

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        dssp3_str = self.dssp3_labels[idx]
        dssp8_str = self.dssp8_labels[idx]

        seq_len = len(sequence)

        # Convert amino acid sequence to integer indices
        seq_indices = [self.aa_to_idx.get(aa, 20) for aa in sequence]

        # Convert labels to integer indices
        dssp3_indices = [self.dssp3_to_idx.get(label, 2) for label in dssp3_str[:seq_len]]
        dssp8_indices = [self.dssp8_to_idx.get(label, 7) for label in dssp8_str[:seq_len]]

        # Max length in dataset is around 1740-1750, so keeping max length as 1800 for safety. This could be changed for dataset with variable length.
        max_length = 1800
        if seq_len > max_length:
            seq_indices = seq_indices[:max_length]
            dssp3_indices = dssp3_indices[:max_length]
            dssp8_indices = dssp8_indices[:max_length]
            seq_len = max_length
        else:
            # Pad with appropriate values
            pad_length = max_length - seq_len
            seq_indices.extend([20] * pad_length)
            dssp3_indices.extend([2] * pad_length)
            dssp8_indices.extend([7] * pad_length)

        return {
            'sequence': torch.tensor(seq_indices, dtype=torch.long),
            'dssp3': torch.tensor(dssp3_indices, dtype=torch.long),
            'dssp8': torch.tensor(dssp8_indices, dtype=torch.long),
            'length': seq_len
        }

class ProteinEmbedding(nn.Module):
    """Embedding layer for amino acid sequences"""

    def __init__(self, vocab_size=21, embed_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=20)
        self.position_embedding = nn.Embedding(1800, embed_dim)

    def forward(self, x):
        batch_size, seq_len = x.shape
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)

        # Amino acid embeddings + positional embeddings
        x = self.embedding(x) + self.position_embedding(positions)
        return x

class ConvolutionalBlock(nn.Module):
    """Convolutional block as described in the paper"""

    def __init__(self, in_channels, out_channels, kernel_size=7, dropout=0.1):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size//2, stride=1)
        self.batch_norm = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x

class ProteinStructurePredictor(nn.Module):
    """ Protein secondary structure prediction model. Modified for DSSP3 and DSSP8 predictions only. """

    def __init__(self, vocab_size=21, embed_dim=128, hidden_dim=512, dropout=0.1):
        super().__init__()

        # Embedding layer
        self.embedding = ProteinEmbedding(vocab_size, embed_dim)

        # Convolutional encoder
        self.conv_layers = nn.ModuleList([
            ConvolutionalBlock(embed_dim, hidden_dim, dropout=dropout),
            ConvolutionalBlock(hidden_dim, hidden_dim, dropout=dropout),
            ConvolutionalBlock(hidden_dim, hidden_dim//2, dropout=dropout),
            ConvolutionalBlock(hidden_dim//2, hidden_dim//2, dropout=dropout),
            ConvolutionalBlock(hidden_dim//2, hidden_dim//4, dropout=dropout)
        ])

        # Task-specific output heads for per-residue prediction
        final_dim = hidden_dim // 4

        # DSSP3 classifier (3 classes: H, E, C)
        self.dssp3_classifier = nn.Linear(final_dim, 3)

        # DSSP8 classifier (8 classes: H, B, E, G, I, T, S, -)
        self.dssp8_classifier = nn.Linear(final_dim, 8)

    def forward(self, x, lengths=None):
        # Embedding
        x = self.embedding(x)  # (batch_size, seq_len, embed_dim)

        # Transpose for conv1d (batch_size, embed_dim, seq_len)
        x = x.transpose(1, 2)

        # Convolutional layers
        for conv_layer in self.conv_layers:
            x = conv_layer(x)

        # Transpose back to (batch_size, seq_len, hidden_dim)
        x = x.transpose(1, 2)

        # Per-residue predictions
        dssp3_pred = self.dssp3_classifier(x)  # (batch_size, seq_len, 3)
        dssp8_pred = self.dssp8_classifier(x)  # (batch_size, seq_len, 8)

        return {
            'dssp3': dssp3_pred,
            'dssp8': dssp8_pred
        }

class ProteinStructureTrainer:
    """Trainer class for the protein structure prediction model"""

    def __init__(self, model, device='cpu'):
        self.model = model.to(device)
        self.device = device
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)

    def train_epoch(self, train_loader, optimizer, scheduler=None):
        self.model.train()
        total_loss = 0
        num_batches = 0

        for batch in train_loader:
            sequences = batch['sequence'].to(self.device)
            lengths = batch['length'].to(self.device)
            dssp3_labels = batch['dssp3'].to(self.device)
            dssp8_labels = batch['dssp8'].to(self.device)

            optimizer.zero_grad()

            outputs = self.model(sequences, lengths)

            # Create masks for valid positions
            batch_size, seq_len = sequences.shape
            mask = torch.arange(seq_len, device=self.device).unsqueeze(0) < lengths.unsqueeze(1)

            # Flatten predictions and labels for loss calculation
            dssp3_pred_flat = outputs['dssp3'].view(-1, 3)
            dssp8_pred_flat = outputs['dssp8'].view(-1, 8)

            dssp3_labels_flat = dssp3_labels.view(-1)
            dssp8_labels_flat = dssp8_labels.view(-1)

            mask_flat = mask.view(-1)
            dssp3_labels_flat[~mask_flat] = -1
            dssp8_labels_flat[~mask_flat] = -1

            # Calculate losses for each task
            dssp3_loss = self.criterion(dssp3_pred_flat, dssp3_labels_flat)
            dssp8_loss = self.criterion(dssp8_pred_flat, dssp8_labels_flat)

            # Multi-task loss (weighted sum)
            total_batch_loss = dssp3_loss + dssp8_loss

            total_batch_loss.backward()
            optimizer.step()

            if scheduler:
                scheduler.step()

            total_loss += total_batch_loss.item()
            num_batches += 1

        return total_loss / num_batches

    def evaluate(self, test_loader):
        self.model.eval()
        predictions = {'dssp3': [], 'dssp8': []}
        true_labels = {'dssp3': [], 'dssp8': []}

        with torch.no_grad():
            for batch in test_loader:
                sequences = batch['sequence'].to(self.device)
                lengths = batch['length'].to(self.device)

                outputs = self.model(sequences, lengths)

                # Create mask for valid positions
                batch_size, seq_len = sequences.shape
                mask = torch.arange(seq_len, device=self.device).unsqueeze(0) < lengths.unsqueeze(1)

                # Get predictions for valid positions only
                for i in range(batch_size):
                    valid_len = lengths[i].item()

                    # Get predictions for this sequence
                    dssp3_pred = torch.argmax(outputs['dssp3'][i, :valid_len], dim=1).cpu().numpy()
                    dssp8_pred = torch.argmax(outputs['dssp8'][i, :valid_len], dim=1).cpu().numpy()

                    # Get true labels for this sequence
                    dssp3_true = batch['dssp3'][i, :valid_len].numpy()
                    dssp8_true = batch['dssp8'][i, :valid_len].numpy()

                    predictions['dssp3'].extend(dssp3_pred)
                    predictions['dssp8'].extend(dssp8_pred)

                    true_labels['dssp3'].extend(dssp3_true)
                    true_labels['dssp8'].extend(dssp8_true)

        # Calculate accuracies
        accuracies = {}
        for task in predictions.keys():
            accuracies[task] = accuracy_score(true_labels[task], predictions[task])

        return accuracies, predictions, true_labels

def load_and_prepare_data(file_path):
    """Load and prepare the dataset from the input file"""

    df = pd.read_csv(file_path)

    sequences = df['input'].tolist()
    dssp3_labels = df[' dssp3'].tolist()
    dssp8_labels = df[' dssp8'].tolist()

    return sequences, dssp3_labels, dssp8_labels

def create_train_test_split(sequences, dssp3_labels, dssp8_labels, test_size=0.2, random_state=42):
    """Create train/test split with proper stratification"""

    # Create indices
    indices = list(range(len(sequences)))

    # Split the indices
    train_indices, test_indices = train_test_split(
        indices,
        test_size=test_size,
        random_state=random_state,
        shuffle=True
    )

    # Create train sets
    train_sequences = [sequences[i] for i in train_indices]
    train_dssp3 = [dssp3_labels[i] for i in train_indices]
    train_dssp8 = [dssp8_labels[i] for i in train_indices]

    # Create test sets
    test_sequences = [sequences[i] for i in test_indices]
    test_dssp3 = [dssp3_labels[i] for i in test_indices]
    test_dssp8 = [dssp8_labels[i] for i in test_indices]

    return {
        'train': (train_sequences, train_dssp3, train_dssp8),
        'test': (test_sequences, test_dssp3, test_dssp8),
        'train_indices': train_indices,
        'test_indices': test_indices
    }

def main():
    """Main training and evaluation function"""

    sequences, dssp3_labels, dssp8_labels = load_and_prepare_data(r"C:\Users\Admin\Desktop\SummerSiege_ProteinProject\SummerSiege_FinalSubmission\CB513_HHblits.csv")
    print(f"Successfully loaded {len(sequences)} sequences from file")

    # Create train/test split
    print("Creating train/test split...")
    split_data = create_train_test_split(sequences, dssp3_labels, dssp8_labels, test_size=0.2)

    train_sequences, train_dssp3, train_dssp8 = split_data['train']
    test_sequences, test_dssp3, test_dssp8 = split_data['test']

    print(f"Training sequences: {len(train_sequences)}")
    print(f"Testing sequences: {len(test_sequences)}")

    # Create datasets
    train_dataset = ProteinDataset(train_sequences, train_dssp3, train_dssp8)
    test_dataset = ProteinDataset(test_sequences, test_dssp3, test_dssp8)

    # Further split training data into train/validation
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    print(f"Final split - Train: {len(train_dataset)}, Validation: {len(val_dataset)}, Test: {len(test_dataset)}")

    # Create data loaders
    batch_size = 4
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = ProteinStructurePredictor(
        vocab_size=21,
        embed_dim=128,
        hidden_dim=512,
        dropout=0.1
    )

    # Initialize trainer
    trainer = ProteinStructureTrainer(model, device)

    # Setup optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

    # Training loop
    print("Starting training...")
    num_epochs = 30
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        # Train
        train_loss = trainer.train_epoch(train_loader, optimizer, scheduler)

        # Evaluate on validation set
        val_accuracies, _, _ = trainer.evaluate(val_loader)

        # Save best model
        current_val_acc = (val_accuracies['dssp3'] + val_accuracies['dssp8']) / 2
        if current_val_acc > best_val_acc:
            best_val_acc = current_val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'model_config': {
                    'vocab_size': 21,
                    'embed_dim': 128,
                    'hidden_dim': 512,
                    'dropout': 0.1
                },
                'epoch': epoch,
                'val_accuracy': current_val_acc
            }, 'best_protein_structure_model.pth')

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Accuracies - DSSP3: {val_accuracies['dssp3']:.4f}, "
              f"DSSP8: {val_accuracies['dssp8']:.4f}")
        print(f"  Average Val Accuracy: {current_val_acc:.4f}")
        if current_val_acc > best_val_acc:
            print("  * New best model saved!")
        print()

    # Final evaluation on test set
    print("Evaluating on test set...")
    test_accuracies, test_predictions, test_true_labels = trainer.evaluate(test_loader)

    print("Final Test Results:")
    print(f"DSSP3 Accuracy: {test_accuracies['dssp3']:.4f}")
    print(f"DSSP8 Accuracy: {test_accuracies['dssp8']:.4f}")
    print(f"Average Test Accuracy: {(test_accuracies['dssp3'] + test_accuracies['dssp8']) / 2:.4f}")

    # Print detailed classification reports
    print("\nDSSP3 Classification Report:")
    dssp3_labels_names = ['H', 'E', 'C']
    print(classification_report(test_true_labels['dssp3'], test_predictions['dssp3'],
                              target_names=dssp3_labels_names))

    print("\nDSSP8 Classification Report:")
    dssp8_labels_names = ['H', 'B', 'E', 'G', 'I', 'T', 'S', '-']
    print(classification_report(test_true_labels['dssp8'], test_predictions['dssp8'],
                              target_names=dssp8_labels_names))

    print("Training completed!")
    print("Best model saved as 'best_protein_structure_model.pth'")

def predict_structure(model_path, sequences):
    """Function to make predictions on new sequences"""

    # Load model
    checkpoint = torch.load(model_path, map_location='cpu')

    config = checkpoint['model_config']
    model = ProteinStructurePredictor(**config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Label decoders
    idx_to_dssp3 = {0: 'H', 1: 'E', 2: 'C'}
    idx_to_dssp8 = {0: 'H', 1: 'B', 2: 'E', 3: 'G', 4: 'I', 5: 'T', 6: 'S', 7: '-'}

    # Amino acid mapping
    aa_to_idx = {
        'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'Q': 5, 'E': 6, 'G': 7,
        'H': 8, 'I': 9, 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14,
        'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, 'X': 20
    }

    predictions = []

    with torch.no_grad():
        for sequence in sequences:
            seq_len = len(sequence)

            # Convert to indices
            seq_indices = [aa_to_idx.get(aa, 20) for aa in sequence]

            # Pad/truncate
            max_length = 1800
            if len(seq_indices) > max_length:
                seq_indices = seq_indices[:max_length]
                seq_len = max_length
            else:
                seq_indices.extend([20] * (max_length - len(seq_indices)))

            # Convert to tensor
            seq_tensor = torch.tensor(seq_indices, dtype=torch.long).unsqueeze(0)
            length_tensor = torch.tensor([seq_len], dtype=torch.long)

            # Predict
            outputs = model(seq_tensor, length_tensor)

            # Get per-residue predictions
            dssp3_pred = torch.argmax(outputs['dssp3'][0, :seq_len], dim=1).numpy()
            dssp8_pred = torch.argmax(outputs['dssp8'][0, :seq_len], dim=1).numpy()

            # Decode predictions
            dssp3_decoded = ''.join([idx_to_dssp3[i] for i in dssp3_pred])
            dssp8_decoded = ''.join([idx_to_dssp8[i] for i in dssp8_pred])

            predictions.append({
                'sequence': sequence,
                'dssp3': dssp3_decoded,
                'dssp8': dssp8_decoded
            })

    return predictions

if __name__ == "__main__":
    main()

Successfully loaded 511 sequences from file
Creating train/test split...
Training sequences: 408
Testing sequences: 103
Final split - Train: 326, Validation: 82, Test: 103
Using device: cpu
Starting training...
Epoch 1/30:
  Train Loss: 1.5993
  Val Accuracies - DSSP3: 0.7879, DSSP8: 0.7065
  Average Val Accuracy: 0.7472

Epoch 2/30:
  Train Loss: 1.2871
  Val Accuracies - DSSP3: 0.8024, DSSP8: 0.7229
  Average Val Accuracy: 0.7626

Epoch 3/30:
  Train Loss: 1.2163
  Val Accuracies - DSSP3: 0.8091, DSSP8: 0.7341
  Average Val Accuracy: 0.7716

Epoch 4/30:
  Train Loss: 1.1459
  Val Accuracies - DSSP3: 0.8097, DSSP8: 0.7370
  Average Val Accuracy: 0.7734

Epoch 5/30:
  Train Loss: 1.1193
  Val Accuracies - DSSP3: 0.8148, DSSP8: 0.7417
  Average Val Accuracy: 0.7783

Epoch 6/30:
  Train Loss: 1.0656
  Val Accuracies - DSSP3: 0.8170, DSSP8: 0.7458
  Average Val Accuracy: 0.7814

Epoch 7/30:
  Train Loss: 1.0622
  Val Accuracies - DSSP3: 0.8263, DSSP8: 0.7513
  Average Val Accuracy: 0.7888