In [1]:
from torch import nn
import torch

device = "cpu"

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

torch.device(device)

device(type='mps')

In [2]:
import pandas as pd

rcsb_data = pd.read_csv("data/rcsb/RCSB_PDB_Macromolecular_Structure_Dataset_with_Structural_Features.csv")

In [3]:
filtered_data = rcsb_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work", "Helix", "Sheet", "Coil"]]
filtered_data = filtered_data.dropna(subset=["Helix", "Sheet", "Coil"])

In [4]:
amino_acids = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y", "X", "U", "O"]
additional_tokens = ["<pad>", "<sos>", "<eos>"]

amino_acid_tokens = {aa : i for i, aa in enumerate(additional_tokens + amino_acids)}

def amino_acid_tokenizer(amino_acid : str, amino_acid_tokens) -> torch.Tensor:
    return torch.tensor([amino_acid_tokens[aa] for aa in amino_acid] + [amino_acid_tokens["<eos>"]], dtype=torch.long)

In [5]:
scrambled_data = filtered_data.sample(frac=1)
scrambled_data = scrambled_data.reset_index(drop=True)

data_size = len(scrambled_data)
train_size = int(data_size * 0.8)
test_size = data_size - train_size
validation_size = int(train_size * 0.2)

train_data = scrambled_data.iloc[:train_size - validation_size]
validation_data = scrambled_data.iloc[train_size - validation_size:train_size]
test_data = scrambled_data.iloc[train_size:]

train_data = train_data.reset_index(drop=True)
validation_data = validation_data.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)

In [6]:
train_input_df = train_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work"]]
train_output_df = train_data.loc[:,["Helix", "Sheet", "Coil"]]
validation_input_df = validation_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work"]]
validation_output_df = validation_data.loc[:,["Helix", "Sheet", "Coil"]]
test_input_df = test_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work"]]
test_output_df = test_data.loc[:,["Helix", "Sheet", "Coil"]]

In [7]:
def create_dataset(input_df, output_df):
    input_tensors = torch.tensor(input_df.drop(columns=["Sequence"]).values, dtype=torch.float32)

    tokenized_sequences = input_df["Sequence"].apply(amino_acid_tokenizer, amino_acid_tokens = amino_acid_tokens)

    input_tensor_sequences = torch.nn.utils.rnn.pad_sequence(tokenized_sequences, batch_first = True, padding_side="left", pad_value=amino_acid_tokens["<pad>"])
    
    sequence_list_lengths = torch.tensor([len(sequence) for sequence in tokenized_sequences], dtype=torch.int)

    output_tensors = torch.tensor(output_df.values, dtype=torch.float32)

    return torch.utils.data.TensorDataset(input_tensors, input_tensor_sequences, sequence_list_lengths, output_tensors)

In [8]:
train_dataset = create_dataset(train_input_df, train_output_df)
validation_dataset = create_dataset(validation_input_df, validation_output_df)
test_dataset = create_dataset(test_input_df, test_output_df)

In [9]:
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=16, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

In [10]:
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Optimizer
from torch.utils.data import DataLoader

In [11]:
def train(epoch : int,
          model : nn.Module,
          device: torch.device,
          train_dataloader : DataLoader,
          optimizer : Optimizer,
          loss_fn : nn.Module,
          tensorboard_writer : SummaryWriter = None) -> float:
    running_loss = 0.

    model.train(True)

    for idx, (input_tensors, input_tensor_sequences, sequence_list_lengths, output_tensors) in enumerate(train_dataloader):
        input_tensors = input_tensors.to(device)
        input_tensor_sequences = input_tensor_sequences.to(device)
        output_tensors = output_tensors.to(device)

        optimizer.zero_grad()

        output = model(input_tensors, input_tensor_sequences, sequence_list_lengths)
        loss = loss_fn(output, output_tensors)
  
        loss.backward()
        optimizer.step()
  
        running_loss += loss.item()

        if tensorboard_writer is not None:
            tensorboard_writer.add_scalar("Loss/train", loss.item(), epoch * len(train_dataloader) + idx)

    running_loss /= len(train_dataloader)

    if tensorboard_writer is not None:
        tensorboard_writer.add_scalar("Loss/train/epoch", running_loss, epoch)
    
    return running_loss

In [12]:
def validate(epoch : int, 
             model : nn.Module,
             device : torch.device,
             validation_dataloader : DataLoader,
             loss_fn : nn.Module,
             tensorboard_writer : SummaryWriter = None) -> float:
    running_loss = 0.

    model.eval()

    with torch.no_grad():
        for idx, (input_tensors, input_tensor_sequences, sequence_list_lengths, output_tensors) in enumerate(validation_dataloader):
            input_tensors = input_tensors.to(device)
            input_tensor_sequences = input_tensor_sequences.to(device)
            output_tensors = output_tensors.to(device)

            output = model(input_tensors, input_tensor_sequences, sequence_list_lengths)
            loss = loss_fn(output, output_tensors)

            running_loss += loss.item()

            if tensorboard_writer is not None:
                tensorboard_writer.add_scalar("Loss/validation", loss.item(), epoch * len(validation_dataloader) + idx)
    
    running_loss /= len(validation_dataloader)

    if tensorboard_writer is not None:
        tensorboard_writer.add_scalar("Loss/validation/epoch", running_loss, epoch)
    
    return running_loss

In [13]:
def test(model : nn.Module,
         device : torch.device,
         test_dataloader : DataLoader,
         loss_fn : nn.Module) -> float:
    running_loss = 0.

    model.eval()

    with torch.no_grad():
        for idx, (input_tensors, input_tensor_sequences, sequence_list_lengths, output_tensors)  in enumerate(test_dataloader):
            input_tensors = input_tensors.to(device)
            input_tensor_sequences = input_tensor_sequences.to(device)
            output_tensors = output_tensors.to(device)

            output = model(input_tensors, input_tensor_sequences, sequence_list_lengths)
            loss = loss_fn(output, output_tensors)

            running_loss += loss.item()

    running_loss /= len(test_dataloader)

    return running_loss

In [14]:
from datetime import datetime

In [15]:
def train_model(model : nn.Module,
                device : torch.device,
                training_dataloader : DataLoader,
                validation_dataloader : DataLoader,
                test_dataloader : DataLoader,
                optimizer : Optimizer,
                loss_fn : nn.Module,
                epochs : int,
                model_name : str = "SecondCount") -> float:
    model.to(device)

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    tensorboard_writer = SummaryWriter(f"runs/{model_name}_{timestamp}")

    best_validation_loss = float("inf")

    for epoch in range(epochs):
        train_loss = train(epoch, model, device, training_dataloader, optimizer, loss_fn, tensorboard_writer)
        validation_loss = validate(epoch, model, device, validation_dataloader, loss_fn, tensorboard_writer)

        print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {train_loss} - Validation Loss: {validation_loss}")

        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            torch.save(model.state_dict(), f"models/{model_name}_{timestamp}_{epoch}.pt")


    test_loss = test(model, device, test_dataloader, loss_fn)

    print(f"Test Loss: {test_loss}")

    tensorboard_writer.close()

    return test_loss

In [None]:
class Encoder(nn.Module):
    def __init__(self, amino_acid_tokens : dict,  embedding_dim : int, hidden_size : int,  device : torch):
        super(Encoder, self).__init__()

        self.amino_acid_tokens = amino_acid_tokens
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.device = device

        self.amino_acid_embedding = nn.Embedding(len(amino_acid_tokens), embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first = True)

    def forward(self, input_tensor_sequences : torch.Tensor, sequence_list_lengths : torch.Tensor) -> torch.Tensor:
        embedded_sequences = self.amino_acid_embedding(input_tensor_sequences)
        output, (hidden, cell) = self.lstm(embedded_sequences)
        return output, hidden, cell

In [None]:
class Decoder(nn.Module):
    def __init__(self, amino_acid_tokens : dict, embedding_dim : int, hidden_size : int, device : torch):
        super(Decoder, self).__init__()

        self.amino_acid_tokens = amino_acid_tokens
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.device = device

        self.amino_acid_embedding = nn.Embedding(len(amino_acid_tokens), embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first = True)
        self.fc = nn.Linear(hidden_size, len(amino_acid_tokens))

    def forward(self, input_tensor_sequences : torch.Tensor, sequence_list_lengths : torch.Tensor, hidden : torch.Tensor, cell : torch.Tensor) -> torch.Tensor:
        embedded_sequences = self.amino_acid_embedding(input_tensor_sequences)
        output, (hidden, cell) = self.lstm(embedded_sequences, (hidden, cell))
        output = self.fc(output)
        return output, hidden, cell

In [19]:
class SecondCountModel(nn.Module):
    def __init__(self, amino_acid_tokens : dict, hidden_size : int, structural_features_size : int, device : torch, output_size : int):
        super(SecondCountModel, self).__init__()
        self.device = device

        self.amino_acid_embedding = nn.Embedding(len(amino_acid_tokens) - 1, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first = True, bidirectional = True)
        self.amino_acid_tokens = amino_acid_tokens

        self.structural_features_to_hidden = nn.Sequential(
            nn.Linear(structural_features_size, hidden_size * 2),
            nn.ReLU(),
            nn.Linear(hidden_size * 2, hidden_size),
        )

        self.hidden_to_output = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )


    def forward(self, structural_features : torch.Tensor, input_tensor_sequences : torch.Tensor, sequence_list_lengths : torch.Tensor):
        amino_acid_embeddings = self.amino_acid_embedding(input_tensor_sequences).to(self.device)
        structural_features_hidden = self.structural_features_to_hidden(structural_features).to(self.device)

        packed_amino_acid_embeddings = nn.utils.rnn.pack_padded_sequence(amino_acid_embeddings, sequence_list_lengths.cpu(), batch_first = True, enforce_sorted=False).to(self.device)

        h_0 = structural_features_hidden.unsqueeze(0).repeat(2, 1, 1).to(self.device)
        c_0 = structural_features_hidden.unsqueeze(0).repeat(2, 1, 1).to(self.device)

        packed_output, _ = self.lstm(packed_amino_acid_embeddings, (h_0, c_0))

        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first = True)

        # Mean pool the output
        output = torch.mean(output, dim = 1)

        output = self.hidden_to_output(output)

        return output

In [20]:
model = SecondCountModel(amino_acid_tokens, 128, 5, device, 3)
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.MSELoss()

In [21]:
train_model(model, device, training_dataloader, validation_dataloader, test_dataloader, optimizer, loss_fn, 1024, "SecondCount")

Epoch 1/1024 - Train Loss: 31654.050488068027 - Validation Loss: 22190.404670516305
Epoch 2/1024 - Train Loss: 26856.925991274904 - Validation Loss: 21642.05111667799
Epoch 3/1024 - Train Loss: 26516.446228293873 - Validation Loss: 21432.164684527852
Epoch 4/1024 - Train Loss: 26069.580709798887 - Validation Loss: 20752.058780570653
Epoch 5/1024 - Train Loss: 25226.136799241776 - Validation Loss: 20789.548700747284
Epoch 6/1024 - Train Loss: 24994.43340238213 - Validation Loss: 19798.483120329485
Epoch 7/1024 - Train Loss: 24564.953969364084 - Validation Loss: 19047.331298828125
Epoch 8/1024 - Train Loss: 22957.56353040137 - Validation Loss: 18659.743355129078
Epoch 9/1024 - Train Loss: 22093.011772055812 - Validation Loss: 17791.923596722147
Epoch 10/1024 - Train Loss: 21285.149921533844 - Validation Loss: 18379.051456351903
Epoch 11/1024 - Train Loss: 21025.096761245393 - Validation Loss: 17260.668236243208
Epoch 12/1024 - Train Loss: 19878.792093468546 - Validation Loss: 16457.21597

KeyboardInterrupt: 

In [22]:
model.to(device)
for idx, (input_tensors, input_tensor_sequences, sequence_list_lengths, output_tensors) in enumerate(test_dataloader):
    input_tensors = input_tensors.to(device)
    input_tensor_sequences = input_tensor_sequences.to(device)
    output_tensors = output_tensors.to(device)

    output = model(input_tensors, input_tensor_sequences, sequence_list_lengths)
    print(output)
    print(output_tensors)

tensor([[117.8054,  55.2328, 117.3696],
        [272.5121, 122.8096, 263.7769],
        [125.5346, 123.9469, 199.1798],
        [104.6472,  89.2093, 151.2075],
        [151.6633, 145.1387, 234.8581],
        [ 94.9836,  86.4288, 144.4266],
        [151.9038, 102.0007, 186.4629],
        [212.7014, 127.0935, 243.3509],
        [201.2157, 107.5353, 214.0625],
        [163.8992,  65.0516, 148.3757],
        [ 87.0900,  91.8793, 148.9137],
        [ 91.8534,  87.7834, 144.7154],
        [792.9969, 496.5347, 932.1163],
        [133.1007,  78.8339, 150.8832],
        [144.3036,  58.2812, 132.0340],
        [124.6502,  89.5621, 161.0484]], device='mps:0',
       grad_fn=<LinearBackward0>)
tensor([[ 110.,   64.,  100.],
        [ 467.,  208.,  484.],
        [  20.,   82.,  126.],
        [  43.,   80.,  136.],
        [  19.,   84.,  115.],
        [  42.,   77.,  138.],
        [  39.,   27.,   90.],
        [ 295.,   94.,  332.],
        [  54.,   31.,   31.],
        [ 144.,   34.,  112.],