In [220]:
from torch import nn
import torch

# Pick the compute backend once for the whole notebook:
# CUDA when present, otherwise Apple's MPS backend, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Last expression of the cell: displays the resolved torch.device.
# `device` itself stays a plain string.
torch.device(device)

device(type='mps')

In [221]:
import pandas as pd

# Raw RCSB PDB export; the filtering cell that follows starts from this table.
rcsb_data = pd.read_csv(
    "data/rcsb/RCSB_PDB_Macromolecular_Structure_Dataset_with_Structural_Features.csv"
)

In [222]:
# Keep protein entities only, restrict the table to the columns used for
# modelling, and drop every row that is missing any of those columns.
selected_columns = [
    "Sequence",
    "Number of Residues",
    "Molecular Weight per Deposited Model",
    "Molecular Weight (Entity)",
    "R Free",
    "R Work",
    "Helix",
    "Sheet",
    "Coil",
]
is_protein = rcsb_data["Entity Polymer Type"] == "Protein"
filtered_data = rcsb_data.loc[is_protein, selected_columns].dropna()

In [223]:
# Vocabulary layout: the special tokens come first (so "<pad>" is id 0),
# followed by the one-letter residue codes in the order listed here.
additional_tokens = ["<pad>", "<sos>", "<eos>"]
amino_acids = list("ACDEFGHIKLMNPQRSTVWYXUO")

vocabulary = additional_tokens + amino_acids

# id -> symbol and symbol -> id lookups built from the same ordering.
tokens_to_amino_acid = dict(enumerate(vocabulary))
amino_acid_tokens = {symbol: index for index, symbol in tokens_to_amino_acid.items()}

def amino_acid_tokenizer(amino_acid : str, amino_acid_tokens) -> torch.Tensor:
    """Encode a residue string as a 1-D long tensor of token ids.

    Args:
        amino_acid: sequence of one-letter codes; each character is looked up
            individually.
        amino_acid_tokens: mapping from symbol to integer id; must contain
            every character of ``amino_acid`` and the ``"<eos>"`` key.

    Returns:
        Tensor of length ``len(amino_acid) + 1``; the last entry is the
        ``"<eos>"`` id.

    Raises:
        KeyError: if a character is not present in ``amino_acid_tokens``.
    """
    token_ids = [amino_acid_tokens[residue] for residue in amino_acid]
    token_ids.append(amino_acid_tokens["<eos>"])
    return torch.tensor(token_ids, dtype=torch.long)

In [224]:
# Normalize additional features
# Each numeric input column is shifted by its mean and divided by its
# standard deviation (as computed by Series.std()).
# The statistics are taken over the whole filtered table, i.e. before the
# train/validation/test split performed in the next cell.
columns_to_normalize = [
    "Number of Residues",
    "Molecular Weight per Deposited Model",
    "Molecular Weight (Entity)",
    "R Free",
    "R Work",
]
for column_name in columns_to_normalize:
    column_values = filtered_data[column_name]
    filtered_data[column_name] = (column_values - column_values.mean()) / column_values.std()

In [225]:
# Fixed seed for the shuffle so that the split membership (and therefore every
# reported loss) is reproducible under Restart Kernel -> Run All.
SPLIT_SEED = 42

# Shuffle all rows once, then cut the shuffled table into contiguous slices.
scrambled_data = filtered_data.sample(frac=1, random_state=SPLIT_SEED).reset_index(drop=True)

# 80% train+validation / 20% test; validation is 20% of the train+validation part.
data_size = len(scrambled_data)
train_size = int(data_size * 0.8)
test_size = data_size - train_size
validation_size = int(train_size * 0.2)
validation_start = train_size - validation_size

# Each split gets its own 0-based index.
train_data = scrambled_data.iloc[:validation_start].reset_index(drop=True)
validation_data = scrambled_data.iloc[validation_start:train_size].reset_index(drop=True)
test_data = scrambled_data.iloc[train_size:].reset_index(drop=True)

# The three slices must partition the shuffled table exactly.
assert len(train_data) + len(validation_data) + len(test_data) == data_size
assert len(test_data) == test_size

In [226]:
# Column groups shared by all three splits: the model inputs (raw sequence
# plus five numeric features) and the three regression targets.
input_columns = [
    "Sequence",
    "Number of Residues",
    "Molecular Weight per Deposited Model",
    "Molecular Weight (Entity)",
    "R Free",
    "R Work",
]
output_columns = ["Helix", "Sheet", "Coil"]

train_input_df = train_data.loc[:, input_columns]
train_output_df = train_data.loc[:, output_columns]

validation_input_df = validation_data.loc[:, input_columns]
validation_output_df = validation_data.loc[:, output_columns]

test_input_df = test_data.loc[:, input_columns]
test_output_df = test_data.loc[:, output_columns]

In [227]:
def create_dataset(input_df, output_df):
    """Turn one data split into a ``TensorDataset``.

    Args:
        input_df: frame with a ``"Sequence"`` column (residue strings) plus
            numeric feature columns; every column other than ``"Sequence"``
            is converted to a float tensor.
        output_df: frame of numeric targets, converted to a float tensor.

    Returns:
        ``TensorDataset`` yielding, per sample, the tuple
        ``(numeric features, padded token ids, token count, targets)``.
        Token ids are padded with the ``"<pad>"`` id to the longest sequence
        in this split; the token count includes the trailing ``"<eos>"``.
    """
    input_tensors = torch.tensor(input_df.drop(columns=["Sequence"]).values, dtype=torch.float)

    # Build a plain list of tensors: that is the input type pad_sequence
    # documents, and it does not depend on the frame's index labels.
    tokenized_sequences = [
        amino_acid_tokenizer(sequence, amino_acid_tokens)
        for sequence in input_df["Sequence"]
    ]

    input_tensor_sequences = torch.nn.utils.rnn.pad_sequence(
        tokenized_sequences,
        batch_first=True,
        padding_value=amino_acid_tokens["<pad>"],
    )

    # int64 lengths: these are later handed to pack_padded_sequence, which
    # operates on 64-bit integer lengths.
    sequence_list_lengths = torch.tensor(
        [len(sequence) for sequence in tokenized_sequences], dtype=torch.long
    )

    output_tensors = torch.tensor(output_df.values, dtype=torch.float)

    return torch.utils.data.TensorDataset(
        input_tensors, input_tensor_sequences, sequence_list_lengths, output_tensors
    )

In [228]:
# One TensorDataset per split, built from the matching (inputs, targets) pair.
train_dataset, validation_dataset, test_dataset = (
    create_dataset(split_inputs, split_targets)
    for split_inputs, split_targets in (
        (train_input_df, train_output_df),
        (validation_input_df, validation_output_df),
        (test_input_df, test_output_df),
    )
)

In [229]:
# Mini-batch size shared by every split.
BATCH_SIZE = 32

# Only the training split is reshuffled every epoch.
# The evaluation loaders keep a fixed order: the evaluation functions average
# per-batch mean losses, so reshuffling would change which samples fall into
# the (smaller) last batch and make the reported loss vary between runs of the
# very same weights.
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [230]:
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Optimizer
from torch.utils.data import DataLoader

In [231]:
from typing import List, Tuple

In [232]:
def train(epoch : int,
          encoder : nn.Module,
          model : nn.Module,
          encoder_optimizer : Optimizer,
          model_optimizer : Optimizer,
          dataloader : DataLoader,
          model_criterion : nn.Module,
          writer : SummaryWriter = None) -> float:
    """Run one optimisation pass over ``dataloader``.

    Both networks are put in training mode and updated jointly from the same
    loss. Tensors are moved to the module-level ``device``.

    Args:
        epoch: zero-based epoch index, used only to compute logging steps.
        encoder: sequence encoder, called as ``encoder(tokens, lengths)``;
            its second return value is passed on to ``model``.
        model: prediction head, called as ``model(features, encoder_hidden)``.
        encoder_optimizer: optimizer stepping the encoder parameters.
        model_optimizer: optimizer stepping the head parameters.
        dataloader: yields ``(features, tokens, lengths, targets)`` batches.
        model_criterion: loss called as ``criterion(prediction, target)``.
        writer: optional TensorBoard writer; per-batch losses are logged under
            ``"Loss/Train"`` and the epoch mean under ``"Loss/Train/Epoch"``.

    Returns:
        Mean of the per-batch losses over the epoch.
    """
    encoder.train()
    model.train()

    batches_per_epoch = len(dataloader)
    loss_total = 0.

    for step, batch in enumerate(dataloader):
        features, tokens, lengths, targets = (tensor.to(device) for tensor in batch)

        encoder_optimizer.zero_grad()
        model_optimizer.zero_grad()

        # Only the encoder's final state feeds the head; the per-step outputs are unused.
        _, encoder_state = encoder(tokens, lengths)
        predictions = model(features, encoder_state)

        batch_loss = model_criterion(predictions, targets)
        batch_loss.backward()

        model_optimizer.step()
        encoder_optimizer.step()

        batch_loss_value = batch_loss.item()
        loss_total += batch_loss_value

        if writer is not None:
            # Global step = number of batches seen since the start of training.
            writer.add_scalar("Loss/Train", batch_loss_value, epoch * batches_per_epoch + step)

    epoch_loss = loss_total / batches_per_epoch

    if writer is not None:
        writer.add_scalar("Loss/Train/Epoch", epoch_loss, epoch)

    return epoch_loss

In [233]:
def validate(epoch : int,
             encoder : nn.Module,
             model : nn.Module,
             dataloader : DataLoader,
             model_criterion : nn.Module,
             writer : SummaryWriter = None) -> float:
    """Evaluate both networks on ``dataloader`` without updating them.

    The networks are switched to eval mode and gradients are disabled.
    Tensors are moved to the module-level ``device``.

    Args:
        epoch: zero-based epoch index, used only to compute logging steps.
        encoder: sequence encoder, called as ``encoder(tokens, lengths)``.
        model: prediction head, called as ``model(features, encoder_hidden)``.
        dataloader: yields ``(features, tokens, lengths, targets)`` batches.
        model_criterion: loss called as ``criterion(prediction, target)``.
        writer: optional TensorBoard writer; per-batch losses are logged under
            ``"Loss/Validation"`` and the mean under ``"Loss/Validation/Epoch"``.

    Returns:
        A single float: the mean of the per-batch losses. (The annotation
        previously claimed a 2-tuple, which the function never returned.)
    """
    encoder.eval()
    model.eval()

    running_loss = 0.

    with torch.no_grad():
        for idx, (structural_data, sequence_data, sequence_length, output_data) in enumerate(dataloader):
            structural_data = structural_data.to(device)
            sequence_data = sequence_data.to(device)
            sequence_length = sequence_length.to(device)
            output_data = output_data.to(device)

            # Only the encoder's final state is consumed by the head.
            _, encoder_hidden = encoder(sequence_data, sequence_length)
            model_output = model(structural_data, encoder_hidden)

            model_loss = model_criterion(model_output, output_data)

            running_loss += model_loss.item()

            if writer is not None:
                writer.add_scalar("Loss/Validation", model_loss.item(), epoch * len(dataloader) + idx)

    # Unweighted mean over batches (a shorter last batch counts as much as a full one).
    running_loss /= len(dataloader)

    if writer is not None:
        writer.add_scalar("Loss/Validation/Epoch", running_loss, epoch)

    return running_loss

In [234]:
def test(encoder : nn.Module,
         model : nn.Module,
         dataloader : DataLoader,
         model_criterion : nn.Module) -> float:
    """Compute the mean loss of the encoder/head pair on ``dataloader``.

    The networks are switched to eval mode and gradients are disabled.
    Tensors are moved to the module-level ``device``. Nothing is logged.

    Args:
        encoder: sequence encoder, called as ``encoder(tokens, lengths)``.
        model: prediction head, called as ``model(features, encoder_hidden)``.
        dataloader: yields ``(features, tokens, lengths, targets)`` batches.
        model_criterion: loss called as ``criterion(prediction, target)``.

    Returns:
        A single float: the mean of the per-batch losses. (The annotation
        previously claimed a 2-tuple, which the function never returned.)
    """
    encoder.eval()
    model.eval()

    running_loss = 0.

    with torch.no_grad():
        # No batch index is needed here, so iterate over the loader directly.
        for structural_data, sequence_data, sequence_length, output_data in dataloader:
            structural_data = structural_data.to(device)
            sequence_data = sequence_data.to(device)
            sequence_length = sequence_length.to(device)
            output_data = output_data.to(device)

            # Only the encoder's final state is consumed by the head.
            _, encoder_hidden = encoder(sequence_data, sequence_length)
            model_output = model(structural_data, encoder_hidden)

            model_loss = model_criterion(model_output, output_data)

            running_loss += model_loss.item()

    # Unweighted mean over batches (a shorter last batch counts as much as a full one).
    running_loss /= len(dataloader)

    return running_loss

In [235]:
from datetime import datetime

In [236]:
import pathlib

def train_model(encoder : nn.Module,
                model : nn.Module,
                device : torch.device,
                training_dataloader : DataLoader,
                validation_dataloader : DataLoader,
                test_dataloader : DataLoader,
                encoder_optimizer : Optimizer,
                model_optimizer : Optimizer,
                model_loss_fn : nn.Module,
                epochs : int,
                model_name : str = "SecondCount") -> float:
    """Train the encoder/head pair, checkpoint the best epoch, report the test loss.

    Each epoch runs ``train`` then ``validate``. Whenever the validation loss
    improves, the weights of BOTH networks are saved under
    ``models/<model_name>-<timestamp>/`` (``model.pt`` for the head,
    ``encoder.pt`` for the encoder). After the last epoch the best checkpoint
    is loaded back, so the returned test loss describes the weights that were
    saved. TensorBoard logs go to ``runs/<model_name>-<timestamp>``.

    Args:
        encoder: sequence encoder network.
        model: prediction head consuming the encoder's final state.
        device: device both networks are moved to. NOTE: ``train``,
            ``validate`` and ``test`` move batches to the module-level
            ``device`` variable, so this argument must name the same device.
        training_dataloader: batches used for optimisation.
        validation_dataloader: batches used for checkpoint selection.
        test_dataloader: batches used for the final evaluation.
        encoder_optimizer: optimizer for the encoder parameters.
        model_optimizer: optimizer for the head parameters.
        model_loss_fn: loss called as ``loss(prediction, target)``.
        epochs: number of epochs to run.
        model_name: prefix for the log and checkpoint directory names.

    Returns:
        Mean per-batch loss on ``test_dataloader``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    tensorboard_writer = SummaryWriter(log_dir=f"runs/{model_name}-{timestamp}")

    # try/finally so the writer is flushed and closed even if training is
    # interrupted (e.g. KeyboardInterrupt from stopping the cell).
    try:
        encoder.to(device)
        model.to(device)

        # Create models directory if it does not exist
        checkpoint_dir = pathlib.Path(f"models/{model_name}-{timestamp}")
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        model_checkpoint = checkpoint_dir / "model.pt"
        encoder_checkpoint = checkpoint_dir / "encoder.pt"

        best_model_loss = float("inf")

        for epoch in range(epochs):
            print(f"Epoch {epoch + 1}/{epochs}")

            train_loss = train(epoch, encoder, model, encoder_optimizer, model_optimizer, training_dataloader, model_loss_fn, tensorboard_writer)
            validation_loss = validate(epoch, encoder, model, validation_dataloader, model_loss_fn, tensorboard_writer)

            print(f"Train Losses: {train_loss}")
            print(f"Validation Losses: {validation_loss}")

            if validation_loss < best_model_loss:
                # The head is useless without the encoder it was trained with,
                # so both state dicts are written for the best epoch.
                torch.save(model.state_dict(), model_checkpoint)
                torch.save(encoder.state_dict(), encoder_checkpoint)
                best_model_loss = validation_loss

        # Evaluate the checkpointed (best-validation) weights rather than
        # whatever the last epoch left behind. The directory is unique to this
        # run, so existing files can only come from the loop above; nothing is
        # loaded when no epoch ever improved (e.g. epochs == 0).
        if model_checkpoint.exists() and encoder_checkpoint.exists():
            model.load_state_dict(torch.load(model_checkpoint, map_location=device))
            encoder.load_state_dict(torch.load(encoder_checkpoint, map_location=device))

        test_loss = test(encoder, model, test_dataloader, model_loss_fn)

        print(f"Test Losses: {test_loss}")
    finally:
        tensorboard_writer.close()

    return test_loss


In [237]:
class Encoder(nn.Module):
    """Embeds token ids and runs them through an LSTM.

    Args:
        amino_acid_tokens: symbol -> id mapping; its size sets the embedding
            vocabulary and its ``"<pad>"`` entry is used as ``padding_idx``.
        embedding_dim: size of each token embedding.
        hidden_size: LSTM hidden size per direction and layer.
        bidirecitonal: whether the LSTM is bidirectional (the parameter name
            keeps its original spelling so existing keyword callers still work).
        num_layer: number of stacked LSTM layers.
        device: stored on the instance as ``self.device``; it is not used to
            place the module or its inputs.
    """

    def __init__(self, amino_acid_tokens : dict,  embedding_dim : int, hidden_size : int,  bidirecitonal : bool, num_layer : int , device : torch.device):
        super().__init__()

        self.amino_acid_tokens = amino_acid_tokens
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.device = device

        self.amino_acid_embedding = nn.Embedding(len(amino_acid_tokens), embedding_dim, padding_idx=amino_acid_tokens["<pad>"])
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first = True, bidirectional = bidirecitonal, num_layers = num_layer)

    def forward(self, input_tensor_sequences : torch.Tensor, sequence_list_lengths : torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encode a batch of padded token-id sequences.

        Args:
            input_tensor_sequences: ``(batch, max_len)`` integer token ids.
            sequence_list_lengths: ``(batch,)`` number of valid tokens per row;
                any integer dtype and any device are accepted.

        Returns:
            ``(output, hidden)`` where ``output`` is the per-step LSTM output
            re-padded to the longest sequence in the batch (batch first) and
            ``hidden`` is the ``(h_n, c_n)`` pair returned by ``nn.LSTM``.
        """
        embedded_sequences = self.amino_acid_embedding(input_tensor_sequences)

        # pack_padded_sequence needs the lengths on the CPU; they are also cast
        # to int64 so int32 length tensors are accepted regardless of version.
        lengths = sequence_list_lengths.to(device="cpu", dtype=torch.int64)

        # enforce_sorted=False: batches arrive in arbitrary length order.
        packed_sequences = nn.utils.rnn.pack_padded_sequence(embedded_sequences, lengths, batch_first = True, enforce_sorted = False)

        output, hidden = self.lstm(packed_sequences)

        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first = True)

        return output, hidden

In [238]:
class SecondCountModel(nn.Module):
    """Fully connected head on top of the encoder's final hidden state.

    The LSTM hidden state is flattened per sample, batch-normalised,
    concatenated with the numeric features and passed through a three-layer
    MLP (Linear -> BatchNorm -> LeakyReLU -> Dropout, twice, then Linear).

    Args:
        hidden_size: LSTM hidden size per direction and layer.
        structural_features_size: number of numeric features per sample.
        output_size: number of predicted values per sample.
        bidirectional: whether the encoder LSTM is bidirectional.
        num_layers: number of encoder LSTM layers.
    """

    def __init__(self, hidden_size : int, structural_features_size : int, output_size : int, bidirectional : bool, num_layers : int):
        super().__init__()

        # Width of the flattened hidden state: one vector per direction and layer.
        direction_count = 2 if bidirectional else 1
        flattened_size = hidden_size * direction_count * num_layers

        combined_size = flattened_size + structural_features_size
        bottleneck_size = max(structural_features_size, flattened_size)

        # Note: the stored hidden_size is the flattened width, not the constructor argument.
        self.hidden_size = flattened_size
        self.structural_features_size = structural_features_size

        self.hidden_norm = nn.BatchNorm1d(flattened_size)

        self.fc = nn.Sequential(
            nn.Linear(combined_size, combined_size),
            nn.BatchNorm1d(combined_size),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(combined_size, bottleneck_size),
            nn.BatchNorm1d(bottleneck_size),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(bottleneck_size, output_size)
        )

    def forward(self, input_tensors : torch.Tensor, hidden : torch.Tensor) -> torch.Tensor:
        """Predict the outputs for one batch.

        Args:
            input_tensors: ``(batch, structural_features_size)`` numeric features.
            hidden: the encoder's hidden state; only element ``[0]`` is used,
                indexed as ``(layers * directions, batch, hidden)``.

        Returns:
            ``(batch, output_size)`` tensor of raw predictions.
        """
        sample_count = input_tensors.size(0)

        # Bring the batch axis first, then flatten every layer/direction vector per sample.
        final_state = hidden[0]
        flattened_state = final_state.permute(1, 0, 2).reshape(sample_count, -1)

        combined = torch.cat([self.hidden_norm(flattened_state), input_tensors], dim=1)

        return self.fc(combined)
    

In [239]:
class MSLELoss(nn.Module):
    """Mean squared error between ``log1p`` of predictions and targets.

    Both tensors are clamped from below at ``epsilon`` before the logarithm,
    so values smaller than ``epsilon`` (including negatives) are treated as
    ``epsilon``.

    Args:
        epsilon: lower bound applied to predictions and targets.
    """

    def __init__(self, epsilon : float = 1e-7):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, input : torch.Tensor, target : torch.Tensor) -> torch.Tensor:
        """Return the scalar loss averaged over every element."""
        lower_bound = self.epsilon
        log_prediction = torch.log1p(input.clamp(min=lower_bound))
        log_target = torch.log1p(target.clamp(min=lower_bound))
        return (log_prediction - log_target).pow(2).mean()

In [240]:
# Hyper-parameters shared by the encoder and the prediction head.
hidden_size = 128
bidirectional = True
num_layers = 4

embedding_dim = 32            # size of each token embedding
structural_features_size = 5  # numeric input columns (everything except "Sequence")
output_size = 3               # targets: Helix, Sheet, Coil

encoder = Encoder(amino_acid_tokens, embedding_dim, hidden_size, bidirectional, num_layers, device)
model = SecondCountModel(hidden_size, structural_features_size, output_size, bidirectional, num_layers)

# One Adam optimizer per network, both with the library defaults.
encoder_optimizer = torch.optim.Adam(encoder.parameters())
model_optimizer = torch.optim.Adam(model.parameters())


In [241]:
# Loss passed to train_model: mean squared error between log1p-transformed
# predictions and targets, with the default epsilon clamp.
model_loss = MSLELoss()

In [242]:
# NOTE(review): this value is not referenced by any other cell visible in this
# notebook (none of the training functions take a teacher-forcing argument) —
# presumably a leftover; confirm before removing.
teacher_forcing_ratio_decay = 0.5

In [243]:
# Train for up to 1024 epochs. TensorBoard logs are written to
# runs/SecondCount-<timestamp> and checkpoints to models/SecondCount-<timestamp>/.
# The call returns the test loss, which is shown as this cell's output.
train_model(encoder, model, device, training_dataloader, validation_dataloader, test_dataloader, encoder_optimizer, model_optimizer, model_loss, 1024, "SecondCount")

Epoch 1/1024


KeyboardInterrupt: 

In [None]:
encoder.to(device)
model.to(device)

# Inference only: switch to eval mode so dropout is disabled and the BatchNorm
# running statistics are not updated by test data (the networks may still be in
# training mode, e.g. after an interrupted training run).
encoder.eval()
model.eval()

# no_grad: predictions are only printed, so no autograd graph is needed.
with torch.no_grad():
    for structural_data, sequence_data, sequence_length, output_data in test_dataloader:
        structural_data = structural_data.to(device)
        sequence_data = sequence_data.to(device)
        sequence_length = sequence_length.to(device)
        output_data = output_data.to(device)

        encoder_output, encoder_hidden = encoder(sequence_data, sequence_length)
        model_output = model(structural_data, encoder_hidden)

        # Print inputs, targets and predictions side by side for manual inspection.
        print(f"Structural Data: {structural_data}")
        print(f"Output Data: {output_data}")
        print(f"Model Output: {model_output}")