In [123]:
from torch import nn
import torch

device = "cpu"

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

torch.device(device)

device(type='mps')

In [124]:
import pandas as pd

rcsb_data = pd.read_csv("data/rcsb/RCSB_PDB_Macromolecular_Structure_Dataset_with_Structural_Features.csv")

In [150]:
filtered_data = rcsb_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work", "Helix", "Sheet", "Coil"]]
filtered_data = filtered_data.dropna(subset=["Helix", "Sheet", "Coil"])

In [151]:
amino_acid_tokens = {
    "A": 1,
    "C": 2,
    "D": 3,
    "E": 4,
    "F": 5,
    "G": 6,
    "H": 7,
    "I": 8,
    "K": 9,
    "L": 10,
    "M": 11,
    "N": 12,
    "P": 13,
    "Q": 14,
    "R": 15,
    "S": 16,
    "T": 17,
    "V": 18,
    "W": 19,
    "Y": 20,
    "X": 21,
    "U": 22,
    "O": 23,
    "<pad>": 0
}

def amino_acid_tokenizer(amino_acid : str, amino_acid_tokens) -> torch.Tensor:
    return torch.tensor([amino_acid_tokens[aa] for aa in amino_acid], dtype = torch.int)

In [152]:
tokenized_sequences = filtered_data["Sequence"].apply(amino_acid_tokenizer, amino_acid_tokens = amino_acid_tokens)

# Padding sequences with zeros to make them all the same length
sequences = torch.nn.utils.rnn.pad_sequence(tokenized_sequences, batch_first = True)
filtered_data["Sequence"] = [sequence for sequence in sequences]

In [162]:
scrambled_data = filtered_data.sample(frac=1)
scrambled_data = scrambled_data.reset_index(drop=True)

data_size = len(scrambled_data)
train_size = int(data_size * 0.8)
test_size = data_size - train_size
validation_size = int(train_size * 0.2)

train_data = scrambled_data.iloc[:train_size - validation_size]
validation_data = scrambled_data.iloc[train_size - validation_size:train_size]
test_data = scrambled_data.iloc[train_size:]

train_data = train_data.reset_index(drop=True)
validation_data = validation_data.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)

In [163]:
train_input_df = train_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work"]]
train_output_df = train_data.loc[:,["Helix", "Sheet", "Coil"]]
validation_input_df = validation_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work"]]
validation_output_df = validation_data.loc[:,["Helix", "Sheet", "Coil"]]
test_input_df = test_data.loc[:,["Sequence", "Number of Residues", "Molecular Weight per Deposited Model", "Molecular Weight (Entity)", "R Free", "R Work"]]
test_output_df = test_data.loc[:,["Helix", "Sheet", "Coil"]]

In [164]:
def create_dataset(input_df, output_df):
    input_tensors = torch.tensor(input_df.drop(columns=["Sequence"]).values, dtype=torch.float32)
    input_tensor_sequences = torch.stack(tuple(input_df["Sequence"].values), 0)

    output_tensors = torch.tensor(output_df.values, dtype=torch.float32)

    return torch.utils.data.TensorDataset(input_tensors, input_tensor_sequences, output_tensors)

In [165]:
train_dataset = create_dataset(train_input_df, train_output_df)
validation_dataset = create_dataset(validation_input_df, validation_output_df)
test_dataset = create_dataset(test_input_df, test_output_df)

In [166]:
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=True)

In [167]:
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Optimizer
from torch.utils.data import DataLoader

In [168]:
def train(epoch : int,
          model : nn.Module,
          device: torch.device,
          train_dataloader : DataLoader,
          optimizer : Optimizer,
          loss_fn : nn.Module,
          tensorboard_writer : SummaryWriter = None) -> float:
    running_loss = 0.

    model.train(True)

    for idx, (input_tensors, input_tensor_sequences, output_tensors) in enumerate(train_dataloader):
        input_tensors = input_tensors.to(device)
        input_tensor_sequences = input_tensor_sequences.to(device)
        output_tensors = output_tensors.to(device)

        optimizer.zero_grad()

        output = model(input_tensors, input_tensor_sequences)
        loss = loss_fn(output, output_tensors)
  
        loss.backward()
        optimizer.step()
  
        running_loss += loss.item()

        if tensorboard_writer is not None:
            tensorboard_writer.add_scalar("Loss/train", loss.item(), epoch * len(train_dataloader) + idx)

    running_loss /= len(train_dataloader)

    if tensorboard_writer is not None:
        tensorboard_writer.add_scalar("Loss/train/epoch", running_loss, epoch)
    
    return running_loss

In [169]:
def validate(epoch : int, 
             model : nn.Module,
             device : torch.device,
             validation_dataloader : DataLoader,
             loss_fn : nn.Module,
             tensorboard_writer : SummaryWriter = None) -> float:
    running_loss = 0.

    model.eval()

    with torch.no_grad():
        for idx, (input_tensors, input_tensor_sequences, output_tensors) in enumerate(validation_dataloader):
            input_tensors = input_tensors.to(device)
            input_tensor_sequences = input_tensor_sequences.to(device)
            output_tensors = output_tensors.to(device)

            output = model(input_tensors, input_tensor_sequences)
            loss = loss_fn(output, output_tensors)

            running_loss += loss.item()

            if tensorboard_writer is not None:
                tensorboard_writer.add_scalar("Loss/validation", loss.item(), epoch * len(validation_dataloader) + idx)
    
    running_loss /= len(validation_dataloader)

    if tensorboard_writer is not None:
        tensorboard_writer.add_scalar("Loss/validation/epoch", running_loss, epoch)
    
    return running_loss

In [170]:
def test(model : nn.Module,
         device : torch.device,
         test_dataloader : DataLoader,
         loss_fn : nn.Module) -> float:
    running_loss = 0.

    model.eval()

    with torch.no_grad():
        for idx, (input_tensors, input_tensor_sequences, output_tensors) in enumerate(test_dataloader):
            input_tensors = input_tensors.to(device)
            input_tensor_sequences = input_tensor_sequences.to(device)
            output_tensors = output_tensors.to(device)

            output = model(input_tensors, input_tensor_sequences)
            loss = loss_fn(output, output_tensors)

            running_loss += loss.item()

    running_loss /= len(test_dataloader)

    return running_loss

In [171]:
from datetime import datetime

In [172]:
def train_model(model : nn.Module,
                device : torch.device,
                training_dataloader : DataLoader,
                validation_dataloader : DataLoader,
                test_dataloader : DataLoader,
                optimizer : Optimizer,
                loss_fn : nn.Module,
                epochs : int,
                model_name : str = "SecondCount") -> float:
    model.to(device)

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    tensorboard_writer = SummaryWriter(f"runs/{model_name}_{timestamp}")

    best_validation_loss = float("inf")

    for epoch in range(epochs):
        train_loss = train(epoch, model, device, training_dataloader, optimizer, loss_fn, tensorboard_writer)
        validation_loss = validate(epoch, model, device, validation_dataloader, loss_fn, tensorboard_writer)

        print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {train_loss} - Validation Loss: {validation_loss}")

        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            torch.save(model.state_dict(), f"models/{model_name}_{timestamp}_{epoch}.pt")


    test_loss = test(model, device, test_dataloader, loss_fn)

    print(f"Test Loss: {test_loss}")

    tensorboard_writer.close()

    return test_loss

In [211]:
class SecondCountModel(nn.Module):
    def __init__(self, device : torch.device, embedding_dim : int, structural_features : int, amino_acid_tokens : dict, hidden_size : int = 128):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.structural_features = structural_features
        self.amino_acid_tokens = amino_acid_tokens
        self.hidden_size = hidden_size
        self.device = device

        self.embedding = nn.Embedding(len(amino_acid_tokens), embedding_dim, padding_idx=amino_acid_tokens["<pad>"])

        # Encoder-Decoder LSTM Model

        self.encoder_lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=1, batch_first=True)
        self.decoder_lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=1, batch_first=True)

        # Fully Connected Layer
        self.fc = nn.Sequential(
            nn.Linear(hidden_size + structural_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 3)
        )
    
    def forward(self, input_tensors, input_tensor_sequences):
        embedded_sequences = self.embedding(input_tensor_sequences)
        batch_size, sequence_length, _ = embedded_sequences.size()

        _, (encoding_hidden, encoding_cell) = self.encoder_lstm(embedded_sequences)

        decoder_outputs, _ = self.decoder_lstm(embedded_sequences, (encoding_hidden, encoding_cell))

        # Extracting the last hidden state of the decoder
        last_hidden_state = decoder_outputs[:, -1, :]

        # Concatenating the last hidden state of the decoder with the structural features
        concatenated_features = torch.cat((last_hidden_state, input_tensors), dim=1)

        output = self.fc(concatenated_features)

        return output



In [218]:
model = SecondCountModel(device, 32, train_dataset.tensors[0].shape[1], amino_acid_tokens)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

In [219]:
train_model(model, device, training_dataloader, validation_dataloader, test_dataloader, optimizer, loss_fn, 128, "SecondCount")

Epoch 1/128 - Train Loss: 58957.595703125 - Validation Loss: 57226.11297817888
Epoch 2/128 - Train Loss: 59072.78733965611 - Validation Loss: 59947.20780576509
Epoch 3/128 - Train Loss: 59065.20229769378 - Validation Loss: 57350.127794989225
Epoch 4/128 - Train Loss: 59049.20721888646 - Validation Loss: 57256.86722117457
Epoch 5/128 - Train Loss: 59044.51104496452 - Validation Loss: 57322.02475080819
Epoch 6/128 - Train Loss: 58974.49274188046 - Validation Loss: 57229.198747306036
Epoch 7/128 - Train Loss: 58925.97316798581 - Validation Loss: 57627.59018049569
Epoch 8/128 - Train Loss: 58905.5217658297 - Validation Loss: 57176.26650053879
Epoch 9/128 - Train Loss: 58863.157290529474 - Validation Loss: 56840.27262931035
Epoch 10/128 - Train Loss: 58806.80977927129 - Validation Loss: 56904.486900592674
Epoch 11/128 - Train Loss: 58768.254281522924 - Validation Loss: 56767.202619881464
Epoch 12/128 - Train Loss: 58683.64304721616 - Validation Loss: 59163.9437971444
Epoch 13/128 - Train Lo

KeyboardInterrupt: 

In [220]:
model.to(device)
for idx, (input_tensors, input_tensor_sequences, output_tensors) in enumerate(test_dataloader):
    input_tensors = input_tensors.to(device)
    input_tensor_sequences = input_tensor_sequences.to(device)
    output_tensors = output_tensors.to(device)

    output = model(input_tensors, input_tensor_sequences)
    print(output)
    print(output_tensors)

tensor([[7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.2891, 6.9583, 7.9596],
        [7.289