In [114]:
from torch import nn
import torch

device = "cpu"

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

torch.device(device)

device(type='mps')

In [115]:
import pandas as pd

rcsb_data = pd.read_csv("data/rcsb/RCSB_PDB_Macromolecular_Structure_Dataset_with_Structural_Features.csv")

In [116]:
filtered_data = rcsb_data[rcsb_data["Entity Polymer Type"] == "Protein"]
filtered_data = filtered_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work", "Helix", "Sheet", "Coil"]]
filtered_data = filtered_data.dropna(subset=["Helix", "Sheet", "Coil"])

In [117]:
amino_acids = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y", "X", "U", "O"]
additional_tokens = ["<pad>", "<sos>", "<eos>"]

amino_acid_tokens = {aa : i for i, aa in enumerate(additional_tokens + amino_acids)}

def amino_acid_tokenizer(amino_acid : str, amino_acid_tokens) -> torch.Tensor:
    return torch.tensor([amino_acid_tokens[aa] for aa in amino_acid] + [amino_acid_tokens["<eos>"]], dtype=torch.long)

In [118]:
scrambled_data = filtered_data.sample(frac=1)
scrambled_data = scrambled_data.reset_index(drop=True)

data_size = len(scrambled_data)
train_size = int(data_size * 0.8)
test_size = data_size - train_size
validation_size = int(train_size * 0.2)

train_data = scrambled_data.iloc[:train_size - validation_size]
validation_data = scrambled_data.iloc[train_size - validation_size:train_size]
test_data = scrambled_data.iloc[train_size:]

train_data = train_data.reset_index(drop=True)
validation_data = validation_data.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)

In [119]:
train_input_df = train_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work"]]
train_output_df = train_data.loc[:,["Helix", "Sheet", "Coil"]]
validation_input_df = validation_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work"]]
validation_output_df = validation_data.loc[:,["Helix", "Sheet", "Coil"]]
test_input_df = test_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work"]]
test_output_df = test_data.loc[:,["Helix", "Sheet", "Coil"]]

In [120]:
def create_dataset(input_df, output_df):
    input_tensors = torch.tensor(input_df.drop(columns=["Sequence"]).values, dtype=torch.float32)

    tokenized_sequences = input_df["Sequence"].apply(amino_acid_tokenizer, amino_acid_tokens = amino_acid_tokens)

    input_tensor_sequences = torch.nn.utils.rnn.pad_sequence(tokenized_sequences, batch_first = True, padding_value=amino_acid_tokens["<pad>"])
    
    sequence_list_lengths = torch.tensor([len(sequence) for sequence in tokenized_sequences], dtype=torch.int)

    output_tensors = torch.tensor(output_df.values, dtype=torch.float32)

    return torch.utils.data.TensorDataset(input_tensors, input_tensor_sequences, sequence_list_lengths, output_tensors)

In [121]:
train_dataset = create_dataset(train_input_df, train_output_df)
validation_dataset = create_dataset(validation_input_df, validation_output_df)
test_dataset = create_dataset(test_input_df, test_output_df)

In [122]:
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=16, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

In [123]:
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Optimizer
from torch.utils.data import DataLoader

In [124]:
from typing import List


def train(epoch : int,
          model : nn.Module,
          device: torch.device,
          train_dataloader : DataLoader,
          optimizer : Optimizer,
          loss_fns : List[nn.Module],
          tensorboard_writer : SummaryWriter = None) -> float:
    running_loss = 0.

    model.train(True)

    for idx, (input_tensors, input_tensor_sequences, sequence_list_lengths, output_tensors) in enumerate(train_dataloader):
        input_tensors = input_tensors.to(device)
        input_tensor_sequences = input_tensor_sequences.to(device)
        output_tensors = output_tensors.to(device)

        optimizer.zero_grad()

        teacher_forcing_ratio = 0.9 ** epoch

        output, decoded = model(input_tensors, input_tensor_sequences, sequence_list_lengths, True, teacher_forcing_ratio)
        mse_loss = loss_fns[0](output, output_tensors)
        ce_loss = loss_fns[1](decoded.view(-1, decoded.size(-1)), input_tensor_sequences.view(-1))

        loss = mse_loss + ce_loss
  
        loss.backward()
        optimizer.step()
  
        running_loss += loss.item()

        if tensorboard_writer is not None:
            tensorboard_writer.add_scalar("Loss/train", loss.item(), epoch * len(train_dataloader) + idx)

    running_loss /= len(train_dataloader)

    if tensorboard_writer is not None:
        tensorboard_writer.add_scalar("Loss/train/epoch", running_loss, epoch)
    
    return running_loss

In [125]:
def validate(epoch : int, 
             model : nn.Module,
             device : torch.device,
             validation_dataloader : DataLoader,
             loss_fns : List[nn.Module],
             tensorboard_writer : SummaryWriter = None) -> float:
    running_loss = 0.

    model.eval()

    with torch.no_grad():
        for idx, (input_tensors, input_tensor_sequences, sequence_list_lengths, output_tensors) in enumerate(validation_dataloader):
            input_tensors = input_tensors.to(device)
            input_tensor_sequences = input_tensor_sequences.to(device)
            output_tensors = output_tensors.to(device)

            output, decoded = model(input_tensors, input_tensor_sequences, sequence_list_lengths, True)

            mse_loss = loss_fns[0](output, output_tensors)
            ce_loss = loss_fns[1](decoded.view(-1, decoded.size(-1)), input_tensor_sequences.view(-1))

            loss = mse_loss + ce_loss

            running_loss += loss.item()

            if tensorboard_writer is not None:
                tensorboard_writer.add_scalar("Loss/validation", loss.item(), epoch * len(validation_dataloader) + idx)
    
    running_loss /= len(validation_dataloader)

    if tensorboard_writer is not None:
        tensorboard_writer.add_scalar("Loss/validation/epoch", running_loss, epoch)
    
    return running_loss

In [126]:
def test(model : nn.Module,
         device : torch.device,
         test_dataloader : DataLoader,
         loss_fn : nn.Module) -> float:
    running_loss = 0.

    model.eval()

    with torch.no_grad():
        for idx, (input_tensors, input_tensor_sequences, sequence_list_lengths, output_tensors)  in enumerate(test_dataloader):
            input_tensors = input_tensors.to(device)
            input_tensor_sequences = input_tensor_sequences.to(device)
            output_tensors = output_tensors.to(device)

            output = model(input_tensors, input_tensor_sequences, sequence_list_lengths)
            loss = loss_fn(output, output_tensors)

            running_loss += loss.item()

    running_loss /= len(test_dataloader)

    return running_loss

In [127]:
from datetime import datetime

In [128]:
def train_model(model : nn.Module,
                device : torch.device,
                training_dataloader : DataLoader,
                validation_dataloader : DataLoader,
                test_dataloader : DataLoader,
                optimizer : Optimizer,
                train_val_loss_fns : List[nn.Module],
                test_loss_fn : nn.Module,
                epochs : int,
                model_name : str = "SecondCount") -> float:
    model.to(device)

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    tensorboard_writer = SummaryWriter(f"runs/{model_name}_{timestamp}")

    best_validation_loss = float("inf")

    for epoch in range(epochs):
        train_loss = train(epoch, model, device, training_dataloader, optimizer, train_val_loss_fns, tensorboard_writer)
        validation_loss = validate(epoch, model, device, validation_dataloader, train_val_loss_fns, tensorboard_writer)

        print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {train_loss} - Validation Loss: {validation_loss}")

        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            torch.save(model.state_dict(), f"models/{model_name}_{timestamp}_{epoch}.pt")


    test_loss = test(model, device, test_dataloader, test_loss_fn)

    print(f"Test Loss: {test_loss}")

    tensorboard_writer.close()

    return test_loss

In [129]:
class Encoder(nn.Module):
    def __init__(self, amino_acid_tokens : dict,  embedding_dim : int, hidden_size : int,  device : torch):
        super(Encoder, self).__init__()

        self.amino_acid_tokens = amino_acid_tokens
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.device = device

        self.amino_acid_embedding = nn.Embedding(len(amino_acid_tokens), embedding_dim, padding_idx=amino_acid_tokens["<pad>"])
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first = True)

    def forward(self, input_tensor_sequences : torch.Tensor, sequence_list_lengths : torch.Tensor) -> torch.Tensor:
        embedded_sequences = self.amino_acid_embedding(input_tensor_sequences)

        packed_sequences = nn.utils.rnn.pack_padded_sequence(embedded_sequences, sequence_list_lengths, batch_first = True, enforce_sorted = False)

        output, hidden = self.lstm(packed_sequences)

        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first = True)

        return output, hidden

In [130]:
class Decoder(nn.Module):
    def __init__(self, amino_acid_tokens : dict, embedding_dim : int, hidden_size : int, device : torch):
        super(Decoder, self).__init__()

        self.amino_acid_tokens = amino_acid_tokens
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.device = device

        self.amino_acid_embedding = nn.Embedding(len(amino_acid_tokens), embedding_dim, padding_idx=amino_acid_tokens["<pad>"])
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first = True)
        self.fc = nn.Linear(hidden_size, len(amino_acid_tokens))

    def forward(self, input_tensor_sequences : torch.Tensor, hidden : torch.Tensor, teacher_forcing_ratio : float = 0) -> torch.Tensor:
        batch_size, sequence_length = input_tensor_sequences.size()

        decoder_input = torch.empty(batch_size, 1, dtype=torch.long).fill_(self.amino_acid_tokens["<sos>"]).to(self.device)
        decoder_hidden = hidden
        decoder_outputs = []

        for idx in range(sequence_length):
            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if teacher_forcing_ratio > 0 and torch.rand(1).item() < teacher_forcing_ratio:
                decoder_input = input_tensor_sequences[:, idx].unsqueeze(1)
            else:
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()

        ans = torch.stack(decoder_outputs, dim=1)
        ans = torch.nn.functional.log_softmax(ans, dim=-1)
        return ans
    
    def forward_step(self, input_tensor : torch.Tensor, hidden : torch.Tensor) -> torch.Tensor:
        embedded_tensor = self.amino_acid_embedding(input_tensor)
        output, hidden = self.lstm(embedded_tensor, hidden)
        output = self.fc(output)
        return output, hidden

In [131]:
class SecondCountModel(nn.Module):
    def __init__(self, amino_acid_tokens : dict, embedding_dim : int, hidden_size : int, structural_features_size : int, device : torch, output_size : int):
        super(SecondCountModel, self).__init__()

        self.amino_acid_tokens = amino_acid_tokens
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.structural_features_size = structural_features_size
        self.device = device

        self.encoder = Encoder(amino_acid_tokens, embedding_dim, hidden_size, device)
        self.decoder = Decoder(amino_acid_tokens, embedding_dim, hidden_size, device)

        self.fc = nn.Sequential(
            nn.Linear(hidden_size + structural_features_size, hidden_size + structural_features_size),
            nn.ReLU(),
            nn.Linear(hidden_size + structural_features_size, max(structural_features_size, hidden_size)),
            nn.ReLU(),
            nn.Linear(max(structural_features_size, hidden_size), output_size)
        )

    def forward(self, input_tensors : torch.Tensor, input_tensor_sequences : torch.Tensor, sequence_list_lengths : torch.Tensor, decoder: bool = False, teacher_forcing_ratio : float = 0) -> torch.Tensor:
        _, hidden = self.encoder(input_tensor_sequences, sequence_list_lengths)

        output = self.fc(torch.cat([hidden[0].squeeze(0), input_tensors], dim=1))

        if decoder:
            return output, self.decoder(input_tensor_sequences, hidden, teacher_forcing_ratio)
        
        return output
    

In [132]:
model = SecondCountModel(amino_acid_tokens, 32, 64, 5, device, 3)
optimizer = torch.optim.Adam(model.parameters())
train_validation_loss_fns = [nn.MSELoss(), nn.NLLLoss()]
test_loss_fn = nn.MSELoss()

In [133]:
train_model(model, device, training_dataloader, validation_dataloader, test_dataloader, optimizer, train_validation_loss_fns, test_loss_fn, 1, "SecondCount")

Epoch 1/1 - Train Loss: 57970.3019176922 - Validation Loss: 56589.856765071905
Test Loss: 54483.54437169895


54483.54437169895

In [135]:
model.to(device)
total_diff = torch.zeros(3).to(device)

for idx, (input_tensors, input_tensor_sequences, sequence_list_lengths, output_tensors) in enumerate(test_dataloader):
    input_tensors = input_tensors.to(device)
    input_tensor_sequences = input_tensor_sequences.to(device)
    output_tensors = output_tensors.to(device)

    output = model(input_tensors, input_tensor_sequences, sequence_list_lengths, True)
    print(output)
    print(output_tensors)

(tensor([[8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996],
        [8.1359, 8.4342, 8.4996]], device='mps:0', grad_fn=<LinearBackward0>), tensor([[[[-3.0224e+00, -4.1372e+00, -4.1497e+00,  ..., -4.2820e+00,
           -4.4127e+00, -4.3524e+00]],

         [[-3.6370e+00, -5.1629e+00, -4.8892e+00,  ..., -5.2681e+00,
           -5.3045e+00, -5.4506e+00]],

         [[-4.1791e+00, -5.9967e+00, -5.4843e+00,  ..., -6.0674e+00,
           -6.1071e+00, -6.3402e+00]],

         ...,

         [[-6.3451e+00, -7.0145e+00, -6.0595e+00,  ..., -7.0166e+00,

In [None]:

total_diff / len(test_dataloader)