In [1]:
from torch import nn
import torch

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

torch.device(device)

device(type='mps')

In [3]:
import pandas as pd

# Raw PDB table; presumably filtered to sequences of length <= 2048 judging by the filename — TODO confirm.
pdb_data = pd.read_csv(filepath_or_buffer="data/pdb/data_filtered_le_2048.csv")

In [4]:
# Discard every row that has a missing value in any column.
pdb_data = pdb_data.dropna(axis=0, how="any")

In [28]:
# Padding token string and the integer id used to right-pad token tensors.
pad_token, pad_token_val = "<pad>", 0

In [13]:
# Residue alphabet. Ids start at 1; 0 is the value create_dataset pads with (pad_token_val).
amino_acids = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V", "X"]

amino_acid_tokens = {}
for token_id, residue in enumerate(amino_acids, start=1):
    amino_acid_tokens[residue] = token_id

# Reverse lookup: id -> residue letter.
tokens_amino_acid = {token_id: residue for residue, token_id in amino_acid_tokens.items()}

def amino_acid_tokenizer(amino_acid : str, amino_acid_tokens) -> torch.Tensor:
    """Convert a residue string into a 1-D ``torch.long`` tensor of token ids.

    Each character is looked up in `amino_acid_tokens`; an unknown character
    raises KeyError.
    """
    token_ids = list(map(amino_acid_tokens.__getitem__, amino_acid))
    return torch.tensor(token_ids, dtype=torch.long)

In [20]:
# DSSP secondary-structure codes. Their ids start right after the
# amino-acid ids (offset by len(amino_acid_tokens)).
dssp_types = ["G", "H", "I", "T", "E", "B", "S", "P", "-"]

dssp_tokens_offset = len(amino_acid_tokens)

dssp_tokens = {code: dssp_tokens_offset + position for position, code in enumerate(dssp_types, start=1)}
# Reverse lookup: id -> DSSP code.
tokens_dssp = {token_id: code for code, token_id in dssp_tokens.items()}

def dssp_tokenizer(dssp : str, dssp_tokens) -> torch.Tensor:
    """Convert a DSSP string into a 1-D ``torch.long`` tensor of token ids.

    Each character is looked up in `dssp_tokens`; an unknown character raises
    KeyError.
    """
    return torch.tensor([dssp_tokens[code] for code in dssp], dtype=torch.long)

In [22]:
# Shuffle once with a fixed seed so the 80/10/10 train/validation/test
# partition is identical after every kernel restart. Otherwise checkpoints
# saved by a previous train_model run may be evaluated on rows they were
# trained on.
split_seed = 42
train_fraction = 0.8
validation_fraction = 0.1

scrambled_data = pdb_data.sample(frac=1, random_state=split_seed)
scrambled_data = scrambled_data.reset_index(drop=True)

data_size = len(scrambled_data)
train_size = int(data_size * train_fraction)
validation_size = int(data_size * validation_fraction)
test_size = data_size - train_size - validation_size

# Explicit positional slicing; each split gets its own fresh RangeIndex.
train_data = scrambled_data.iloc[:train_size].reset_index(drop=True)
validation_data = scrambled_data.iloc[train_size:train_size + validation_size].reset_index(drop=True)
test_data = scrambled_data.iloc[train_size + validation_size:].reset_index(drop=True)

In [23]:
# Inputs are residue sequences; targets are the secondary-structure strings
# that create_dataset tokenizes with dssp_tokenizer.
train_input_df, train_output_df = train_data["sequence"], train_data["secondary_structure"]
validation_input_df, validation_output_df = validation_data["sequence"], validation_data["secondary_structure"]
test_input_df, test_output_df = test_data["sequence"], test_data["secondary_structure"]

In [29]:
def create_dataset(input_df, output_df):
    """Build a TensorDataset pairing tokenized sequences with DSSP targets.

    Each string in `input_df` is tokenized with `amino_acid_tokenizer` and each
    string in `output_df` with `dssp_tokenizer`. Each side is then right-padded
    with `pad_token_val` to its own longest entry (batch-first layout).

    Returns:
        TensorDataset of (inputs, targets) with shapes (rows, longest input)
        and (rows, longest target).
    """
    pad = torch.nn.utils.rnn.pad_sequence

    input_tensor_seq = pad(
        [amino_acid_tokenizer(sequence, amino_acid_tokens) for sequence in input_df],
        batch_first=True,
        padding_value=pad_token_val,
    )

    output_tensor_seq = pad(
        [dssp_tokenizer(structure, dssp_tokens) for structure in output_df],
        batch_first=True,
        padding_value=pad_token_val,
    )

    return torch.utils.data.TensorDataset(input_tensor_seq, output_tensor_seq)

In [30]:
# Tokenize and pad every split the same way.
train_dataset, validation_dataset, test_dataset = (
    create_dataset(inputs, outputs)
    for inputs, outputs in (
        (train_input_df, train_output_df),
        (validation_input_df, validation_output_df),
        (test_input_df, test_output_df),
    )
)

In [10]:
# Training batches are reshuffled every epoch. drop_last=True avoids a
# trailing batch of size 1, which nn.BatchNorm1d (used by SecondCountModel)
# rejects in training mode. Evaluation loaders keep a fixed order so
# per-step logs and epoch averages are reproducible.
batch_size = 8

training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [11]:
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Optimizer
from torch.utils.data import DataLoader

In [12]:
from typing import List, Tuple

In [13]:
def train(epoch : int,
          encoder : nn.Module,
          decoder : nn.Module,
          model : nn.Module,
          encoder_optimizer : Optimizer,
          decoder_optimizer : Optimizer,
          model_optimizer : Optimizer,
          dataloader : DataLoader,
          decoder_criterion : nn.Module,
          model_criterion : nn.Module,
          teacher_forcing_ratio_decay : float = 0.,
          writer : SummaryWriter = None) -> Tuple[Tuple[float, float], float]:
    """Run one training epoch and return the mean losses over its batches.

    For each batch, the encoder output feeds both the decoder (scored by
    `decoder_criterion` against the input sequence) and `model` (scored by
    `model_criterion` against `output_data`). The two losses are summed,
    back-propagated jointly, and then all three optimizers step.

    Args:
        epoch: epoch index, used to compute TensorBoard global steps.
        encoder, decoder, model: networks being trained.
        encoder_optimizer, decoder_optimizer, model_optimizer: their optimizers.
        dataloader: yields (structural_data, sequence_data, sequence_length, output_data).
        decoder_criterion: loss applied to flattened decoder outputs.
        model_criterion: loss applied to `model` outputs.
        teacher_forcing_ratio_decay: passed to the decoder as its
            teacher-forcing ratio.
        writer: optional TensorBoard writer for per-batch and per-epoch losses.

    Returns:
        ((mean decoder loss, mean model loss), mean total loss).
    """
    for network in (encoder, decoder, model):
        network.train()

    num_batches = len(dataloader)
    decoder_sum = 0.
    model_sum = 0.
    total_sum = 0.

    for batch_idx, (structural_data, sequence_data, sequence_length, output_data) in enumerate(dataloader):
        structural_data, sequence_data, sequence_length, output_data = (
            tensor.to(device) for tensor in (structural_data, sequence_data, sequence_length, output_data)
        )

        for optimizer in (encoder_optimizer, decoder_optimizer, model_optimizer):
            optimizer.zero_grad()

        encoder_output, encoder_hidden = encoder(sequence_data, sequence_length)
        decoder_output, _, _ = decoder(sequence_data, encoder_output, encoder_hidden, teacher_forcing_ratio_decay)
        model_output = model(structural_data, encoder_hidden)

        # Flatten decoder outputs to (N, vocab) and targets to (N,) for a
        # token-level criterion.
        decoder_loss = decoder_criterion(decoder_output.view(-1, decoder_output.size(-1)), sequence_data.view(-1))
        model_loss = model_criterion(model_output, output_data)
        total_loss = decoder_loss + model_loss

        total_loss.backward()

        for optimizer in (decoder_optimizer, model_optimizer, encoder_optimizer):
            optimizer.step()

        decoder_value = decoder_loss.item()
        model_value = model_loss.item()
        total_value = total_loss.item()

        decoder_sum += decoder_value
        model_sum += model_value
        total_sum += total_value

        if writer is not None:
            step = epoch * num_batches + batch_idx
            writer.add_scalar("Loss/Train/Decoder", decoder_value, step)
            writer.add_scalar("Loss/Train/Model", model_value, step)
            writer.add_scalar("Loss/Train/Total", total_value, step)

    decoder_mean = decoder_sum / num_batches
    model_mean = model_sum / num_batches
    total_mean = total_sum / num_batches

    if writer is not None:
        writer.add_scalar("Loss/Train/Decoder/Epoch", decoder_mean, epoch)
        writer.add_scalar("Loss/Train/Model/Epoch", model_mean, epoch)
        writer.add_scalar("Loss/Train/Total/Epoch", total_mean, epoch)

    return (decoder_mean, model_mean), total_mean

In [14]:
def validate(epoch : int,
             encoder : nn.Module,
             decoder : nn.Module,
             model : nn.Module,
             dataloader : DataLoader,
             decoder_criterion : nn.Module,
             model_criterion : nn.Module,
             writer : SummaryWriter = None) -> Tuple[Tuple[float, float], float]:
    """Evaluate the three networks on `dataloader` without gradients.

    The decoder is called without a teacher-forcing ratio, so it uses its
    default. Per-batch and per-epoch losses are logged to `writer` when one
    is given.

    Returns:
        ((mean decoder loss, mean model loss), mean total loss), where the
        total is the sum of the two per-batch losses.
    """
    for network in (encoder, decoder, model):
        network.eval()

    num_batches = len(dataloader)
    decoder_sum = 0.0
    model_sum = 0.0
    total_sum = 0.

    with torch.no_grad():
        for batch_idx, (structural_data, sequence_data, sequence_length, output_data) in enumerate(dataloader):
            structural_data, sequence_data, sequence_length, output_data = (
                tensor.to(device) for tensor in (structural_data, sequence_data, sequence_length, output_data)
            )

            encoder_output, encoder_hidden = encoder(sequence_data, sequence_length)
            decoder_output, _, _ = decoder(sequence_data, encoder_output, encoder_hidden)
            model_output = model(structural_data, encoder_hidden)

            decoder_loss = decoder_criterion(decoder_output.view(-1, decoder_output.size(-1)), sequence_data.view(-1))
            model_loss = model_criterion(model_output, output_data)

            decoder_value = decoder_loss.item()
            model_value = model_loss.item()
            total_value = decoder_value + model_value

            decoder_sum += decoder_value
            model_sum += model_value
            total_sum += total_value

            if writer is not None:
                step = epoch * num_batches + batch_idx
                writer.add_scalar("Loss/Validation/Decoder", decoder_value, step)
                writer.add_scalar("Loss/Validation/Model", model_value, step)
                writer.add_scalar("Loss/Validation/Total", total_value, step)

    decoder_mean = decoder_sum / num_batches
    model_mean = model_sum / num_batches
    total_mean = total_sum / num_batches

    if writer is not None:
        writer.add_scalar("Loss/Validation/Decoder/Epoch", decoder_mean, epoch)
        writer.add_scalar("Loss/Validation/Model/Epoch", model_mean, epoch)
        writer.add_scalar("Loss/Validation/Total/Epoch", total_mean, epoch)

    return (decoder_mean, model_mean), total_mean

In [15]:
def test(encoder : nn.Module,
         decoder : nn.Module,
         model : nn.Module,
         dataloader : DataLoader,
         decoder_criterion : nn.Module,
         model_criterion : nn.Module) -> Tuple[Tuple[float, float], float]:
    """Compute mean decoder, model and total losses over `dataloader`.

    Runs in eval mode under ``torch.no_grad()``. Nothing is logged.

    Returns:
        ((mean decoder loss, mean model loss), mean total loss), where the
        total is the sum of the two per-batch losses.
    """
    for network in (encoder, decoder, model):
        network.eval()

    num_batches = len(dataloader)
    decoder_sum = 0.0
    model_sum = 0.0
    total_sum = 0.

    with torch.no_grad():
        for structural_data, sequence_data, sequence_length, output_data in dataloader:
            structural_data, sequence_data, sequence_length, output_data = (
                tensor.to(device) for tensor in (structural_data, sequence_data, sequence_length, output_data)
            )

            encoder_output, encoder_hidden = encoder(sequence_data, sequence_length)
            decoder_output, _, _ = decoder(sequence_data, encoder_output, encoder_hidden)
            model_output = model(structural_data, encoder_hidden)

            decoder_loss = decoder_criterion(decoder_output.view(-1, decoder_output.size(-1)), sequence_data.view(-1))
            model_loss = model_criterion(model_output, output_data)

            decoder_value = decoder_loss.item()
            model_value = model_loss.item()

            decoder_sum += decoder_value
            model_sum += model_value
            total_sum += decoder_value + model_value

    return (decoder_sum / num_batches, model_sum / num_batches), total_sum / num_batches

In [16]:
from datetime import datetime

In [17]:
import pathlib

def train_model(encoder : nn.Module,
                decoder : nn.Module,
                model : nn.Module,
                device : torch.device,
                training_dataloader : DataLoader,
                validation_dataloader : DataLoader,
                test_dataloader : DataLoader,
                encoder_optimizer : Optimizer,
                decoder_optimizer : Optimizer,
                model_optimizer : Optimizer,
                model_loss_fn : nn.Module,
                decoder_loss_fn : nn.Module,
                epochs : int,
                teacher_forcing_ratio_decay : float = 0.,
                model_name : str = "SecondCount") -> float:
    """Train for `epochs` epochs, keep the best-validation checkpoint, and
    return the test loss of that checkpoint.

    TensorBoard logs go to ``runs/<model_name>-<timestamp>``. State dicts go
    to ``models/<model_name>-<timestamp>/{encoder,decoder,model}.pt`` and are
    rewritten whenever the validation total loss improves. The best
    checkpoint is reloaded before the final test pass, so the returned value
    describes the weights that were kept, not those of the last epoch.

    NOTE(review): `train`, `validate` and `test` move batches to the
    module-level `device`, while the networks are moved to this `device`
    argument. Keep the two the same.

    Args:
        encoder, decoder, model: networks to train.
        device: device the networks (and reloaded checkpoints) are placed on.
        training_dataloader, validation_dataloader, test_dataloader: data splits.
        encoder_optimizer, decoder_optimizer, model_optimizer: optimizers.
        model_loss_fn: criterion for `model` outputs.
        decoder_loss_fn: criterion for decoder outputs.
        epochs: number of training epochs.
        teacher_forcing_ratio_decay: forwarded unchanged to `train` each epoch.
        model_name: prefix for the run and checkpoint directories.

    Returns:
        Mean total (decoder + model) test loss.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = f"{model_name}-{timestamp}"
    checkpoint_dir = pathlib.Path("models") / run_name
    networks = {"encoder": encoder, "decoder": decoder, "model": model}

    tensorboard_writer = SummaryWriter(log_dir=f"runs/{run_name}")
    # Close the writer even if training raises or is interrupted, so
    # buffered events are flushed and the file handle is released.
    try:
        for network in networks.values():
            network.to(device)

        # Create models directory if it does not exist
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        best_total_loss = float("inf")
        has_checkpoint = False

        for epoch in range(epochs):
            print(f"Epoch {epoch + 1}/{epochs}")

            train_losses, train_total_losses = train(epoch, encoder, decoder, model, encoder_optimizer, decoder_optimizer, model_optimizer, training_dataloader, decoder_loss_fn, model_loss_fn, teacher_forcing_ratio_decay, tensorboard_writer)
            validation_losses, validation_total_losses = validate(epoch, encoder, decoder, model, validation_dataloader, decoder_loss_fn, model_loss_fn, tensorboard_writer)

            print(f"Train Losses: {train_losses}, Train Total Loss: {train_total_losses}")
            print(f"Validation Losses: {validation_losses}, Validation Total Loss: {validation_total_losses}")

            if validation_total_losses < best_total_loss:
                best_total_loss = validation_total_losses
                for name, network in networks.items():
                    torch.save(network.state_dict(), checkpoint_dir / f"{name}.pt")
                has_checkpoint = True

        # Evaluate the checkpoint that was kept rather than the last epoch's
        # weights (skipped if no epoch ever improved, e.g. NaN losses).
        if has_checkpoint:
            for name, network in networks.items():
                network.load_state_dict(torch.load(checkpoint_dir / f"{name}.pt", map_location=device))

        test_losses, test_total_losses = test(encoder, decoder, model, test_dataloader, decoder_loss_fn, model_loss_fn)

        print(f"Test Losses: {test_losses}, Test Total Loss: {test_total_losses}")
    finally:
        tensorboard_writer.close()

    return test_total_losses


In [18]:
class Encoder(nn.Module):
    """Embed amino-acid token ids and encode them with an LSTM.

    Padded positions are skipped by packing the batch with the true
    per-sequence lengths before the LSTM.

    Args:
        amino_acid_tokens: mapping from token string to integer id. Must
            contain ``"<pad>"``, which becomes the embedding padding index.
        embedding_dim: size of each token embedding.
        hidden_size: LSTM hidden size (per direction).
        device: device the encoder is meant to run on (a ``torch.device`` or
            device string); only stored here.
        bidirectional: use a bidirectional LSTM.
        num_layers: number of stacked LSTM layers.
    """
    def __init__(self, amino_acid_tokens : dict,  embedding_dim : int, hidden_size : int,  device : torch.device, bidirectional : bool = False, num_layers : int = 1):
        super(Encoder, self).__init__()

        self.amino_acid_tokens = amino_acid_tokens
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.device = device
        self.bidirectional = bidirectional
        self.num_layers = num_layers

        # nn.Embedding only accepts ids 0..num_embeddings-1, so size the table
        # by the largest id rather than the number of entries. Ids need not
        # be contiguous from 0 (the notebook's residue ids start at 1).
        num_embeddings = max(amino_acid_tokens.values()) + 1
        self.amino_acid_embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=amino_acid_tokens["<pad>"])
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first = True, bidirectional = bidirectional, num_layers = num_layers)

    def forward(self, input_tensor_sequences : torch.Tensor, sequence_list_lengths : torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encode a padded batch of token ids.

        Args:
            input_tensor_sequences: (batch, seq_len) token ids.
            sequence_list_lengths: (batch,) true length of each sequence.

        Returns:
            output: (batch, max(sequence_list_lengths), num_directions * hidden_size)
                LSTM outputs, zero-padded past each sequence's length.
            hidden: the LSTM's final (h_n, c_n), each of shape
                (num_layers * num_directions, batch, hidden_size).
        """
        embedded_sequences = self.amino_acid_embedding(input_tensor_sequences)

        # pack_padded_sequence needs the lengths on the CPU.
        packed_sequences = nn.utils.rnn.pack_padded_sequence(embedded_sequences, sequence_list_lengths.cpu(), batch_first = True, enforce_sorted = False)

        output, hidden = self.lstm(packed_sequences)

        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first = True)

        return output, hidden

In [19]:
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention: score = Va(tanh(Wa q + Ua k)).

    Args:
        hidden_size: width of each key vector and of the internal projection.
        num_layers: number of per-layer hidden vectors flattened into the
            query, hence Wa's input width of ``hidden_size * num_layers``.
    """
    def __init__(self, hidden_size : int, num_layers : int = 1):
        super(Attention, self).__init__()

        self.Wa = nn.Linear(hidden_size * num_layers, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query : torch.Tensor, keys : torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over `keys` with `query`.

        Args:
            query: per-example query; everything after dim 0 is flattened to
                ``hidden_size * num_layers`` (assumed (batch, num_layers,
                hidden) — TODO confirm at call sites).
            keys: (batch, src_len, hidden_size) vectors to attend over.

        Returns:
            context: (batch, 1, hidden_size) weighted sum of keys.
            attention_weights: (batch, 1, src_len), softmax over src_len.
        """
        flat_query = query.reshape(query.size(0), 1, -1)
        energy = torch.tanh(self.Wa(flat_query) + self.Ua(keys))

        # (batch, src_len, 1) -> (batch, 1, src_len)
        scores = self.Va(energy).transpose(1, 2)

        weights = scores.softmax(dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights

In [20]:
class Decoder(nn.Module):
    """Attention LSTM decoder that re-generates the input token sequence.

    Decoding starts from ``"<sos>"`` and runs one step per input position.
    At each step the next input is either the ground-truth token (teacher
    forcing) or the greedy prediction.

    Args:
        amino_acid_tokens: token string -> id mapping. Must contain
            ``"<pad>"`` (embedding padding index) and ``"<sos>"`` (start token).
        embedding_dim: size of each token embedding.
        hidden_size: encoder hidden size per direction; doubled internally
            when `encoder_bidirectional` is True.
        device: device on which the start tokens are created.
        dropout: dropout probability applied to token embeddings.
        encoder_bidirectional: whether the encoder LSTM was bidirectional.
        num_layers: number of LSTM layers (must match the encoder's).
    """
    def __init__(self, amino_acid_tokens : dict, embedding_dim : int, hidden_size : int, device : torch.device, dropout : float = 0.1, encoder_bidirectional : bool = False, num_layers : int = 1):
        super(Decoder, self).__init__()

        # forward() concatenates the forward/backward encoder states per
        # layer, so the decoder LSTM must be twice as wide.
        if encoder_bidirectional:
            hidden_size *= 2

        self.amino_acid_tokens = amino_acid_tokens
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.device = device
        self.encoder_bidirectional = encoder_bidirectional
        self.num_layers = num_layers
        self.dropout_rate = dropout

        # Size the vocabulary by the largest id (+1) so that non-contiguous
        # ids stay in range for both the embedding lookup and the NLL targets.
        vocab_size = max(amino_acid_tokens.values()) + 1

        self.attention = Attention(hidden_size, num_layers)
        self.amino_acid_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=amino_acid_tokens["<pad>"])
        self.lstm = nn.LSTM(embedding_dim + hidden_size, hidden_size, batch_first = True, num_layers = num_layers)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_tensor_sequences : torch.Tensor, encoder_outputs : torch.Tensor, encoder_hidden : Tuple[torch.Tensor, torch.Tensor], teacher_forcing_ratio : float = 0.) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Decode `sequence_length` steps, where that length is the width of
        `input_tensor_sequences`.

        Args:
            input_tensor_sequences: (batch, seq_len) target token ids, also
                used as teacher-forcing inputs.
            encoder_outputs: encoder outputs used as attention keys.
            encoder_hidden: encoder final (h_n, c_n) used as the initial state.
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth token instead of the greedy prediction.

        Returns:
            log_probs: (batch, seq_len, vocab) log-softmax outputs.
            decoder_hidden: final (h, c) of the decoder LSTM.
            attentions: (batch, seq_len, src_len) attention weights.
        """
        batch_size, sequence_length = input_tensor_sequences.size()

        decoder_input = torch.full((batch_size, 1), self.amino_acid_tokens["<sos>"], dtype=torch.long, device=self.device)

        if self.encoder_bidirectional:
            # nn.LSTM stacks directions as [l0_fwd, l0_bwd, l1_fwd, ...]; merge
            # each layer's two directions along the feature axis.
            encoder_even_hidden = (encoder_hidden[0][::2], encoder_hidden[1][::2])
            encoder_odd_hidden = (encoder_hidden[0][1::2], encoder_hidden[1][1::2])
            encoder_hidden = (torch.cat((encoder_even_hidden[0], encoder_odd_hidden[0]), dim=-1), torch.cat((encoder_even_hidden[1], encoder_odd_hidden[1]), dim=-1))

        decoder_hidden = encoder_hidden

        decoder_outputs = []
        attentions = []

        for idx in range(sequence_length):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(decoder_input, encoder_outputs, decoder_hidden)
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if teacher_forcing_ratio > 0 and torch.rand(1).item() < teacher_forcing_ratio:
                decoder_input = input_tensor_sequences[:, idx].unsqueeze(1)
            else:
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()

        # Each step yields (batch, 1, vocab). Concatenate along time to get
        # (batch, seq_len, vocab); torch.stack would add an extra axis and
        # produce (batch, seq_len, 1, vocab).
        ans = torch.cat(decoder_outputs, dim=1)
        ans = torch.nn.functional.log_softmax(ans, dim=-1)
        attentions = torch.cat(attentions, dim=1)
        return ans, decoder_hidden, attentions

    def forward_step(self, input_tensor : torch.Tensor, encoder_outputs : torch.Tensor, encoder_hidden : Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Run a single decoding step.

        Args:
            input_tensor: (batch, 1) token ids for this step.
            encoder_outputs: attention keys.
            encoder_hidden: current decoder LSTM state (h, c); h is the
                attention query.

        Returns:
            (logits of shape (batch, 1, vocab), new (h, c), attention weights).
        """
        embedded_tensor = self.dropout(self.amino_acid_embedding(input_tensor))

        # h: (num_layers, batch, hidden) -> (batch, num_layers, hidden)
        query = encoder_hidden[0].permute(1, 0, 2)
        context, attention_weights = self.attention(query, encoder_outputs)
        lstm_input = torch.cat((embedded_tensor, context), dim=2)

        output, encoder_hidden = self.lstm(lstm_input, encoder_hidden)
        output = self.fc(output)
        return output, encoder_hidden, attention_weights

In [21]:
class SecondCountModel(nn.Module):
    """MLP head over the encoder's flattened final hidden state plus
    per-example structural features.

    Args:
        hidden_size: encoder LSTM hidden size per direction.
        structural_features_size: width of the structural feature vector.
        output_size: number of predicted values.
        bidirectional: whether the encoder was bidirectional (doubles width).
        num_layers: number of encoder LSTM layers (multiplies width).

    NOTE(review): the final ELU allows outputs in (-1, 0). MSLELoss applies
    log1p, which diverges as outputs approach -1. Consider a non-negative
    output activation if the targets are counts — TODO confirm intent.
    """
    def __init__(self, hidden_size : int, structural_features_size : int, output_size : int, bidirectional : bool = False, num_layers : int = 1):
        super(SecondCountModel, self).__init__()

        per_direction_size = hidden_size
        # One hidden_size vector per layer and per direction, flattened.
        flat_hidden_size = hidden_size * (2 if bidirectional else 1) * num_layers
        combined_size = flat_hidden_size + structural_features_size

        self.hidden_size = flat_hidden_size
        self.structural_features_size = structural_features_size
        self.output_size = output_size
        self.bidirectional = bidirectional
        self.num_layers = num_layers

        self.hidden_norm = nn.BatchNorm1d(flat_hidden_size)

        self.fc = nn.Sequential(
            nn.Linear(combined_size, combined_size),
            nn.BatchNorm1d(combined_size),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(combined_size, flat_hidden_size),
            nn.BatchNorm1d(flat_hidden_size),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(flat_hidden_size, per_direction_size),
            nn.BatchNorm1d(per_direction_size),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(per_direction_size, output_size),
            nn.ELU()
        )

    def forward(self, input_tensors : torch.Tensor, hidden : Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Predict from structural features and the encoder's LSTM state.

        Args:
            input_tensors: (batch, structural_features_size) features.
            hidden: LSTM (h_n, c_n); only h_n, of shape
                (num_layers * num_directions, batch, hidden), is used.

        Returns:
            (batch, output_size) predictions.
        """
        h_n = hidden[0]
        batch_size = h_n.size(1)
        # (layers*directions, batch, hidden) -> (batch, layers*directions*hidden)
        flat_hidden = h_n.permute(1, 0, 2).reshape(batch_size, -1)

        features = torch.cat([self.hidden_norm(flat_hidden), input_tensors], dim=1)
        return self.fc(features)
    

In [22]:
class MSLELoss(nn.Module):
    """Mean squared logarithmic error: mean((log1p(input) - log1p(target)) ** 2)."""
    def __init__(self):
        super().__init__()

    def forward(self, input : torch.Tensor, target : torch.Tensor) -> torch.Tensor:
        """Return the scalar mean over all elements of the squared log1p gap."""
        log_difference = torch.log1p(input) - torch.log1p(target)
        return log_difference.pow(2).mean()

In [23]:
# 4-layer bidirectional encoder/decoder with 64 hidden units per direction.
# SecondCountModel takes 5 structural features and predicts 3 values.
encoder = Encoder(amino_acid_tokens, embedding_dim=32, hidden_size=64, device=device, bidirectional=True, num_layers=4)
decoder = Decoder(amino_acid_tokens, embedding_dim=32, hidden_size=64, device=device, dropout=0.1, encoder_bidirectional=True, num_layers=4)
model = SecondCountModel(hidden_size=64, structural_features_size=5, output_size=3, bidirectional=True, num_layers=4)

# One Adam optimizer (default hyper-parameters) per network.
encoder_optimizer, decoder_optimizer, model_optimizer = (
    torch.optim.Adam(network.parameters()) for network in (encoder, decoder, model)
)


In [24]:
# Regression loss for SecondCountModel outputs.
model_loss = MSLELoss()
# Token loss for the reconstructed sequence. The decoder returns log-softmax
# outputs, which is what NLLLoss expects. Targets equal to pad_token_val (the
# value create_dataset pads with) are ignored, so the decoder isn't trained
# to predict padding and the mean covers real residues only.
decoder_loss = nn.NLLLoss(ignore_index=pad_token_val)

In [25]:
# Per-step probability that the decoder is fed the ground-truth token. It is
# passed unchanged to every epoch's `train` call, so no decay is applied.
teacher_forcing_ratio_decay: float = 0.5

In [None]:
train_model(
    encoder, decoder, model, device,
    training_dataloader, validation_dataloader, test_dataloader,
    encoder_optimizer, decoder_optimizer, model_optimizer,
    model_loss_fn=model_loss,
    decoder_loss_fn=decoder_loss,
    epochs=1024,
    teacher_forcing_ratio_decay=teacher_forcing_ratio_decay,
    model_name="SecondCount",
)

Epoch 1/1024


In [None]:
# Inspect predictions for the first test batch only; printing every batch
# would flood the notebook output.
encoder.to(device)
decoder.to(device)
model.to(device)

# Inference mode: disable dropout, use BatchNorm running stats, skip autograd.
encoder.eval()
decoder.eval()
model.eval()

with torch.no_grad():
    for structural_data, sequence_data, sequence_length, output_data in test_dataloader:
        structural_data, sequence_data, sequence_length, output_data = (
            tensor.to(device) for tensor in (structural_data, sequence_data, sequence_length, output_data)
        )

        encoder_output, encoder_hidden = encoder(sequence_data, sequence_length)
        # Decoder.forward takes (inputs, encoder_outputs, encoder_hidden) and
        # returns (log_probs, hidden, attentions).
        decoder_output, _, _ = decoder(sequence_data, encoder_output, encoder_hidden)
        model_output = model(structural_data, encoder_hidden)

        print(f"Structural Data: {structural_data}")
        print(f"Output Data: {output_data}")
        print(f"Model Output: {model_output}")
        print(f"Decoder Predicted Tokens: {decoder_output.argmax(dim=-1)}")
        break