In [None]:
%pip install torch
%pip install pandas
%pip install matplotlib
%pip install scikit-learn

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, TensorDataset

In [None]:
def load_data(file_path):
    """Read a protein CSV and return its sequence and structure columns.

    Args:
        file_path: Location of a CSV file (anything ``pd.read_csv`` accepts)
            containing a ``seq`` column and an ``sst3`` column.

    Returns:
        A ``(sequences, structures)`` tuple of two plain Python lists holding
        the ``seq`` and ``sst3`` column values in row order.

    Raises:
        KeyError: If either column is absent from the file.
    """
    frame = pd.read_csv(file_path)
    return frame['seq'].tolist(), frame['sst3'].tolist()

AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'

# Residue letter -> one-hot column, built once so encoding is a dict lookup
# instead of two linear scans of AMINO_ACIDS per residue.
_AMINO_ACID_INDEX = {letter: column for column, letter in enumerate(AMINO_ACIDS)}


# Function to one hot encode a sequence
def one_hot_encode(sequence):
    """One-hot encode a sequence of amino-acid letters.

    Args:
        sequence: Iterable with a length (normally a ``str``) whose items are
            single letters from ``AMINO_ACIDS``.

    Returns:
        An integer array of shape ``(len(sequence), len(AMINO_ACIDS))`` in
        which row ``i`` has a single 1 in the column of residue ``i``.

    Raises:
        ValueError: If an item is not exactly one of the letters in
            ``AMINO_ACIDS``.
    """
    one_hot_matrix = np.zeros((len(sequence), len(AMINO_ACIDS)), dtype=int)

    for row, amino_acid in enumerate(sequence):
        # Exact-key lookup: unlike a substring test on AMINO_ACIDS, this
        # rejects multi-letter tokens such as 'AC' and the empty string.
        column = _AMINO_ACID_INDEX.get(amino_acid)
        if column is None:
            raise ValueError(f"Unknown amino acid {amino_acid} in sequence.")
        one_hot_matrix[row, column] = 1

    return one_hot_matrix


# Function to build flattened one-hot windows, dropping positions too close to a sequence edge.
def extract_features_and_labels(sequences, structures, window_size=15):
    """Turn sequences into per-residue sliding-window features and labels.

    Every sequence is padded on both sides with ``window_size // 2`` copies of
    ``'X'``. For each residue position a window of ``window_size`` characters
    is cut from the padded string; windows that contain an ``'X'`` are
    discarded. Because ``'X'`` is also the padding sentinel, a window holding
    a literal ``'X'`` residue from the input is discarded as well.

    Args:
        sequences: Iterable of amino-acid strings.
        structures: Iterable of structure strings, paired with ``sequences``
            by position; ``structure[i]`` is used as the label of residue ``i``.
        window_size: Number of residues per window.

    Returns:
        ``(X, y)`` as NumPy arrays: one flattened one-hot window per kept
        position, and the matching structure label.
    """
    flank = window_size // 2
    padding = 'X' * flank
    features = []
    labels = []

    for seq, sst in zip(sequences, structures):
        padded = padding + seq + padding
        for start in range(len(seq)):
            segment = padded[start:start + window_size]
            if 'X' in segment:
                # Window overlaps the padding (or an 'X' residue): skip it
                continue
            features.append(one_hot_encode(segment).flatten())
            labels.append(sst[start])

    return np.array(features), np.array(labels)


# Map structure letters to integer class ids for classification
def preprocess_labels(y):
    """Encode structure labels as integers with a freshly fitted encoder.

    A new ``LabelEncoder`` is fitted on ``y`` on every call, so the mapping
    depends on the labels present in ``y``. Ids follow the sorted order of
    the distinct labels, which gives C = 0, E = 1, H = 2 when all three occur.

    Args:
        y: Array-like of structure labels.

    Returns:
        ``(y_encoded, label_encoder)``: the integer-encoded labels and the
        fitted encoder that produced them.
    """
    encoder = LabelEncoder()
    return encoder.fit_transform(y), encoder

def prepare_data(X, y, batch_size=64, shuffle=True):
    """Wrap feature and label arrays in a batched ``DataLoader``.

    Args:
        X: Array-like of features; converted to a ``float32`` tensor.
        y: Array-like of integer class ids; converted to a ``long`` tensor.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples on every pass.
            Defaults to ``True`` (the previous fixed behaviour); pass
            ``False`` for evaluation data where a stable order is wanted.

    Returns:
        A ``DataLoader`` yielding ``(features, labels)`` batches.
    """
    X_tensor = torch.tensor(X, dtype=torch.float32)
    y_tensor = torch.tensor(y, dtype=torch.long)

    dataset = TensorDataset(X_tensor, y_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


In [None]:
# Fraction of samples whose highest-scoring class matches the label
def accuracy(preds, labels):
    """Return the share of correct top-1 predictions.

    Args:
        preds: 2-D tensor of per-class scores, one row per sample.
        labels: 1-D tensor of true class ids, one per row of ``preds``.

    Returns:
        Number of rows whose arg-max equals the label, divided by
        ``len(labels)``, as a Python float.
    """
    predicted = preds.max(dim=1).indices
    hits = (predicted == labels).sum().item()
    return hits / len(labels)

# Training function

def train_rnn(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device='cpu', input_size=20):
    """Train ``model`` and report validation accuracy after every epoch.

    Each batch of inputs is reshaped to ``(batch, seq_len, input_size)``
    before it is fed to the model; ``seq_len`` is inferred from the batch, so
    any window length works as long as the flattened feature length is a
    multiple of ``input_size``.

    Args:
        model: Network to train; moved to ``device`` first.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` batches scored with
            ``evaluate_rnn`` at the end of every epoch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model's parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and batches are moved to.
        input_size: Number of features per time step (20 for the one-hot
            amino-acid encoding).

    Returns:
        ``(train_losses, train_accuracies, val_accuracies)``: three lists with
        one entry per epoch. The loss is the mean of the per-batch losses.

    Raises:
        ValueError: If ``train_loader`` or ``val_loader`` yields no samples.
    """
    model.to(device)
    train_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        correct_train = 0
        total_train = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # (batch, flat) -> (batch, seq_len, input_size); seq_len is inferred
            inputs = inputs.view(inputs.size(0), -1, input_size)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        if total_train == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError
            raise ValueError("train_loader yielded no samples; cannot compute training metrics.")

        avg_train_loss = running_loss / num_batches
        train_acc = correct_train / total_train
        train_losses.append(avg_train_loss)
        train_accuracies.append(train_acc)

        # Evaluation on validation set
        val_acc = evaluate_rnn(model, val_loader, device, input_size)
        val_accuracies.append(val_acc)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}, Validation Accuracy: {val_acc:.4f}')

    return train_losses, train_accuracies, val_accuracies

# Validation function
def evaluate_rnn(model, val_loader, device='cpu', input_size=20):
    """Return the classification accuracy of ``model`` on ``val_loader``.

    The model is switched to evaluation mode and left in that mode; gradients
    are disabled while scoring. Inputs are reshaped to
    ``(batch, seq_len, input_size)`` with ``seq_len`` inferred per batch.

    Args:
        model: Network to evaluate; expected to already live on ``device``.
        val_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        input_size: Number of features per time step.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Reshape inputs to 3D (batch_size, sequence_length, input_size)
            inputs = inputs.view(inputs.size(0), -1, input_size)

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("val_loader yielded no samples; cannot compute accuracy.")

    return correct / total


In [None]:
class VanillaRNN(nn.Module):
    """Single-layer ``nn.RNN`` classifier that predicts from the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the recurrent hidden state.
        output_size: Number of output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        # batch_first=True makes the input (batch, seq, input_size)
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)

        # Fully connected layer for output prediction
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, input_size) to (batch, output_size) scores."""
        # No explicit h_0: nn.RNN starts from a zero state created on x's
        # device and dtype, which avoids a CPU allocation plus device copy on
        # every call and keeps non-float32 models working.
        out, _ = self.rnn(x)

        # Keep only the hidden state of the last time step: (batch, hidden_size)
        last_step = out[:, -1, :]

        return self.fc(last_step)
    

class LSTM(nn.Module):
    """Stacked ``nn.LSTM`` classifier that predicts from the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the hidden and cell states.
        output_size: Number of output classes.
        num_layers: Number of stacked LSTM layers (default 2).
        dropout: Dropout probability applied between LSTM layers and again
            before the fully connected layer.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout = 0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # nn.LSTM only applies dropout between layers, and warns when a
        # non-zero value is given for a single layer, so pass 0 in that case.
        inter_layer_dropout = dropout if num_layers > 1 else 0

        # batch_first=True makes the input (batch, seq, input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=inter_layer_dropout)

        # Dropout layer before the fully connected layer (optional, for further regularization)
        self.dropout = nn.Dropout(dropout)

        # Fully connected layer for output prediction
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, input_size) to (batch, output_size) scores."""
        # No explicit (h_0, c_0): nn.LSTM starts from zero states created on
        # x's device and dtype, avoiding a per-call CPU allocation and copy.
        out, _ = self.lstm(x)

        # Keep only the top layer's output at the last time step: (batch, hidden_size)
        last_step = out[:, -1, :]

        # Apply dropout before the fully connected layer
        last_step = self.dropout(last_step)

        # Map the hidden state to class scores
        return self.fc(last_step)
    
class GRU(nn.Module):
    """Stacked ``nn.GRU`` classifier that predicts from the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the hidden state.
        output_size: Number of output classes.
        num_layers: Number of stacked GRU layers (default 2).
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # batch_first=True makes the input (batch, seq, input_size)
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)

        # Fully connected layer for output prediction
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, input_size) to (batch, output_size) scores."""
        # No explicit h_0: nn.GRU starts from a zero state created on x's
        # device and dtype, avoiding a per-call CPU allocation and copy.
        out, _ = self.gru(x)

        # Keep only the top layer's output at the last time step: (batch, hidden_size)
        last_step = out[:, -1, :]

        # Map the hidden state to class scores
        return self.fc(last_step)
    


In [None]:
# Load the training and test data
train_file = 'training_data.csv'
test_file = 'test_data.csv'

train_sequences, train_structures = load_data(train_file)
test_sequences, test_structures = load_data(test_file)

# Feature extraction (sliding window and one-hot encoding)
X_train, y_train = extract_features_and_labels(train_sequences, train_structures)
X_test, y_test = extract_features_and_labels(test_sequences, test_structures)

# Preprocess labels: fit the encoder on the training labels only
y_train_encoded, label_encoder = preprocess_labels(y_train)
# Reuse the fitted encoder so test labels get the same integer ids as in
# training. Fitting a second encoder on the test labels could assign different
# ids if the two sets do not contain the same labels.
y_test_encoded = label_encoder.transform(y_test)

# Create DataLoaders for training and testing
train_loader = prepare_data(X_train, y_train_encoded)
test_loader = prepare_data(X_test, y_test_encoded)

In [None]:
# Hyperparameters
input_size = 20  # One-hot encoded amino acid (20 possible values for each amino acid)
hidden_size = 64
output_size = 3  # C, E, H (3 secondary structure types)
num_epochs = 50
learning_rate = 0.001
SEED = 42  # Fixed seed so weight initialisation and batch shuffling repeat across runs

# Seed before the model is built (weight init) and before the loaders are
# iterated (shuffle order is drawn from torch's global generator).
torch.manual_seed(SEED)

# Initialize the model, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Choose the desired model, loss function and optimiser
model = VanillaRNN(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


# Train the model
# NOTE: test_loader is passed as the validation loader, so the reported
# "validation" accuracy is measured on the test set.
train_losses, train_accuracies, val_accuracies = train_rnn(model, train_loader, test_loader, criterion, optimizer, num_epochs, device)

# Plot the training loss, then the training and validation accuracy.
# Each figure is described as (title, y-axis label, curves), where a curve is
# (values, keyword arguments forwarded to plt.plot).
figure_specs = [
    ('Training Loss Over Time', 'Loss', [
        (train_losses, {'label': 'Training Loss'}),
    ]),
    ('Training and Validation Accuracy Over Time', 'Accuracy', [
        (train_accuracies, {'label': 'Training Accuracy', 'color': 'blue'}),
        (val_accuracies, {'label': 'Validation Accuracy', 'color': 'orange'}),
    ]),
]

for title, y_label, curves in figure_specs:
    plt.figure(figsize=(10, 5))
    for values, plot_kwargs in curves:
        plt.plot(values, **plot_kwargs)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()
    plt.show()