In [9]:
from torch import nn
import torch

# Pick the fastest available backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

torch.device(device)

device(type='mps')

In [10]:
import pandas as pd

# Load the pre-filtered PDB chains (file name suggests sequences of length <= 2048 — TODO confirm).
pdb_data = pd.read_csv(filepath_or_buffer="data/pdb/data_filtered_le_2048.csv")

In [11]:
# Drop every row that has any missing value so the string tokenizers never see NaN.
pdb_data = pdb_data.dropna(axis="index", how="any")

In [12]:
# Token id 0 is reserved for padding; real vocabulary ids start at 1.
pad_token, pad_token_val = "<pad>", 0

In [13]:
# The 20 standard amino acids plus "X" for unknown residues.
amino_acids = list("ARNDCQEGHILKMFPSTWYVX")

# Ids are assigned from 1 upward so that 0 stays free for padding.
amino_acid_tokens = dict(zip(amino_acids, range(1, len(amino_acids) + 1)))
tokens_amino_acid = {token: residue for residue, token in amino_acid_tokens.items()}

def amino_acid_tokenizer(amino_acid : str, amino_acid_tokens) -> torch.Tensor:
    """Convert a residue string into a 1-D long tensor of token ids.

    Args:
        amino_acid: the amino-acid sequence, one letter per residue.
        amino_acid_tokens: mapping from residue letter to token id.

    Raises:
        KeyError: if a letter is not present in ``amino_acid_tokens``.
    """
    token_ids = list(map(amino_acid_tokens.__getitem__, amino_acid))
    return torch.tensor(token_ids, dtype=torch.long)

In [14]:
# DSSP secondary-structure codes; "-" marks residues with no assigned structure.
dssp_types = list("GHITEBSP-")

# Ids are assigned from 1 upward so that 0 stays free for padding.
dssp_tokens = dict(zip(dssp_types, range(1, len(dssp_types) + 1)))
tokens_dssp = {token: code for code, token in dssp_tokens.items()}

def dssp_tokenizer(dssp : str, dssp_tokens) -> torch.Tensor:
    """Convert a DSSP string into a 1-D long tensor of token ids.

    Args:
        dssp: secondary-structure string, one code per residue.
        dssp_tokens: mapping from DSSP code to token id.

    Raises:
        KeyError: if a code is not present in ``dssp_tokens``.
    """
    token_ids = [dssp_tokens[symbol] for symbol in dssp]
    return torch.tensor(token_ids, dtype=torch.long)

In [15]:
# Longest amino-acid sequence in the cleaned data; the decoder pads its output to this length.
max_sequence_length = pdb_data.loc[:, "sequence"].str.len().max()

In [16]:
def create_dataset(input_df, output_df):
    """Tokenize paired amino-acid / DSSP strings and pack them into a TensorDataset.

    Args:
        input_df: pandas Series (or any iterable) of amino-acid sequence strings.
        output_df: pandas Series (or any iterable) of DSSP strings, paired with
            ``input_df`` by position.

    Returns:
        ``TensorDataset(inputs, targets, lengths)``. ``inputs`` and ``targets``
        are right-padded with ``pad_token_val`` to the longest string in their
        column. ``lengths`` holds each sequence's true length.

    Raises:
        ValueError: if there are no rows, or the two inputs differ in row count.
            Also raised if any row's sequence and structure strings differ in
            length. The decoder predicts one label per residue, so a mismatch
            would otherwise cause misaligned or infinite losses downstream.
    """
    # Materialize as lists: pairing becomes explicitly positional and
    # pad_sequence receives a real list of tensors.
    sequences = list(input_df)
    structures = list(output_df)

    if not sequences:
        raise ValueError("create_dataset needs at least one row")
    if len(sequences) != len(structures):
        raise ValueError(
            f"input_df has {len(sequences)} rows but output_df has {len(structures)}"
        )

    mismatched = [i for i, (seq, ss) in enumerate(zip(sequences, structures)) if len(seq) != len(ss)]
    if mismatched:
        raise ValueError(
            f"{len(mismatched)} row(s) have sequence/structure length mismatches "
            f"(first at position {mismatched[0]})"
        )

    input_tensor_seq = torch.nn.utils.rnn.pad_sequence(
        [amino_acid_tokenizer(seq, amino_acid_tokens) for seq in sequences],
        batch_first = True,
        padding_value = pad_token_val,
    )
    output_tensor_seq = torch.nn.utils.rnn.pad_sequence(
        [dssp_tokenizer(ss, dssp_tokens) for ss in structures],
        batch_first = True,
        padding_value = pad_token_val,
    )
    seq_length_tensor = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)

    return torch.utils.data.TensorDataset(input_tensor_seq, output_tensor_seq, seq_length_tensor)

In [17]:
# Tokenize every (sequence, secondary structure) pair into one dataset.
pdb_dataset = create_dataset(
    input_df = pdb_data["sequence"],
    output_df = pdb_data["secondary_structure"],
)

In [18]:
# Dedicated RNG so the dataset split is reproducible. manual_seed returns the same generator object.
generator = torch.Generator()
generator = generator.manual_seed(32678236876854694)

In [19]:
# 80 / 10 / 10 train / validation / test split driven by the seeded generator.
train_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(
    dataset = pdb_dataset,
    lengths = [0.8, 0.1, 0.1],
    generator = generator,
)

In [20]:
# Reuse the seeded split generator for the training shuffle, so batch order is
# reproducible under Restart & Run All.
# Evaluation loaders are not shuffled. Their losses are averaged over the whole
# loader, so order does not matter for the result, and a fixed order keeps the
# per-batch TensorBoard curves comparable from epoch to epoch.
batch_size = 16

training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=generator)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [21]:
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Optimizer
from torch.utils.data import DataLoader

In [22]:
from typing import List, Tuple

In [23]:
import math
def train(epoch : int,
          encoder : nn.Module,
          decoder : nn.Module,
          encoder_optimizer : Optimizer,
          decoder_optimizer : Optimizer,
          dataloader : DataLoader,
          decoder_criterion : nn.Module,
          teacher_forcing_ratio_decay : float = 0.,
          writer : SummaryWriter = None,
          teacher_forcing_decay_epochs : float = 10.) -> float:
    """Run one training epoch and return the mean of the per-batch losses.

    Batches are moved to the device of the encoder's parameters, so the data
    always lands where the model actually lives.

    Args:
        epoch: zero-based epoch index; drives the teacher-forcing decay and the
            TensorBoard step numbers.
        encoder: called as ``encoder(sequence_data, seq_len)`` returning
            ``(outputs, hidden)``.
        decoder: called as ``decoder(seq_len, outputs, hidden, ss_data, ratio)``
            and must return a tuple whose first element is per-position
            log-probabilities with classes on the last axis.
        encoder_optimizer: optimizer stepping the encoder's parameters.
        decoder_optimizer: optimizer stepping the decoder's parameters.
        dataloader: yields ``(sequence_data, ss_data, seq_len)`` batches.
        decoder_criterion: loss on the flattened ``(N, classes)`` log-probs and
            ``(N,)`` targets.
        teacher_forcing_ratio_decay: despite its name, the teacher-forcing ratio
            at epoch 0. The ratio used for this epoch is
            ``teacher_forcing_ratio_decay * exp(-epoch / teacher_forcing_decay_epochs)``.
        writer: optional TensorBoard writer for per-batch and per-epoch losses.
        teacher_forcing_decay_epochs: e-folding time of that decay, in epochs.

    Returns:
        Average of the per-batch losses over the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; cannot compute an average training loss")

    first_param = next(encoder.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # The ratio depends only on the epoch, so compute it once.
    teacher_forcing_ratio = teacher_forcing_ratio_decay * math.exp(-epoch / teacher_forcing_decay_epochs)

    encoder.train()
    decoder.train()

    running_loss = 0.

    for idx, (sequence_data, ss_data, seq_len) in enumerate(dataloader):
        sequence_data = sequence_data.to(model_device)
        ss_data = ss_data.to(model_device)

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_output, encoder_hidden = encoder(sequence_data, seq_len)
        decoder_output, _, _ = decoder(seq_len, encoder_output, encoder_hidden, ss_data, teacher_forcing_ratio)

        loss = decoder_criterion(decoder_output.reshape(-1, decoder_output.size(-1)), ss_data.reshape(-1))
        loss.backward()

        decoder_optimizer.step()
        encoder_optimizer.step()

        # A single .item() per batch: one host/device sync instead of two.
        batch_loss = loss.item()
        running_loss += batch_loss

        if writer is not None:
            writer.add_scalar("Loss/Train/Decoder/Batch", batch_loss, epoch * num_batches + idx)

    avg_loss = running_loss / num_batches

    if writer is not None:
        writer.add_scalar("Loss/Train/Decoder/Epoch", avg_loss, epoch)

    return avg_loss

In [24]:
def validate(epoch : int,
             encoder : nn.Module,
             decoder : nn.Module,
             dataloader : DataLoader,
             decoder_criterion : nn.Module,
             writer : SummaryWriter = None) -> float:
    """Evaluate without gradients and return the mean of the per-batch losses.

    The decoder is called without targets, i.e. it decodes from its own
    predictions (no teacher forcing). Batches are moved to the device of the
    encoder's parameters, so the data always lands where the model lives.

    Args:
        epoch: zero-based epoch index, used for the TensorBoard step numbers.
        encoder: called as ``encoder(sequence_data, seq_len)`` returning
            ``(outputs, hidden)``.
        decoder: called as ``decoder(seq_len, outputs, hidden)``; the first
            element of its result is per-position log-probabilities.
        dataloader: yields ``(sequence_data, ss_data, seq_len)`` batches.
        decoder_criterion: loss on flattened log-probs and targets.
        writer: optional TensorBoard writer for per-batch and per-epoch losses.

    Returns:
        Average of the per-batch losses.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; cannot compute an average validation loss")

    first_param = next(encoder.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    encoder.eval()
    decoder.eval()

    running_loss = 0.

    with torch.no_grad():
        for idx, (sequence_data, ss_data, seq_len) in enumerate(dataloader):
            sequence_data = sequence_data.to(model_device)
            ss_data = ss_data.to(model_device)

            encoder_output, encoder_hidden = encoder(sequence_data, seq_len)
            decoder_output, _, _ = decoder(seq_len, encoder_output, encoder_hidden)

            # One .item() per batch, reused for the running sum and the log.
            batch_loss = decoder_criterion(
                decoder_output.reshape(-1, decoder_output.size(-1)), ss_data.reshape(-1)
            ).item()
            running_loss += batch_loss

            if writer is not None:
                writer.add_scalar("Loss/Validation/Decoder/Batch", batch_loss, epoch * num_batches + idx)

    avg_loss = running_loss / num_batches

    if writer is not None:
        writer.add_scalar("Loss/Validation/Decoder/Epoch", avg_loss, epoch)

    return avg_loss

In [25]:
def test(encoder : nn.Module,
         decoder : nn.Module,
         dataloader : DataLoader,
         decoder_criterion : nn.Module) -> float:
    encoder.eval()
    decoder.eval()

    running_loss = 0.

    with torch.no_grad():
        for idx, (sequence_data, ss_data, seq_len) in enumerate(dataloader):
            sequence_data = sequence_data.to(device)
            ss_data = ss_data.to(device)

            encoder_output, encoder_hidden = encoder(sequence_data, seq_len)
            decoder_output, _, _ = decoder(seq_len, encoder_output, encoder_hidden)

            decoder_loss = decoder_criterion(decoder_output.view(-1, decoder_output.size(-1)), ss_data.view(-1))

            running_loss += decoder_loss.item()
                
    avg_loss = running_loss / len(dataloader)

    return avg_loss

In [26]:
from datetime import datetime

In [27]:
import pathlib

def train_model(encoder : nn.Module,
                decoder : nn.Module,
                device : torch.device,
                training_dataloader : DataLoader,
                validation_dataloader : DataLoader,
                test_dataloader : DataLoader,
                encoder_optimizer : Optimizer,
                decoder_optimizer : Optimizer,
                decoder_loss_fn : nn.Module,
                epochs : int,
                last_epoch : int = 0,
                best_loss : float = float("inf"),
                teacher_forcing_ratio_decay : float = 0.,
                model_name : str | None = None) -> float:
    """Train and validate for each epoch, checkpoint, then evaluate on the test set.

    Outputs:
      * ``models/<model_name>/model_<epoch>.tar``: written every epoch.
      * ``models/<model_name>/best_model.tar``: written whenever validation
        loss improves.
      * ``runs/<model_name>``: TensorBoard logs.

    Each checkpoint stores the encoder/decoder/optimizer state dicts, the
    epoch's validation ``loss``, the running ``best_loss`` and ``epoch``.

    Args:
        encoder, decoder: models to train; both are moved to ``device``.
        device: target device for the models.
        training_dataloader, validation_dataloader, test_dataloader: data splits.
        encoder_optimizer, decoder_optimizer: optimizers for the two models.
        decoder_loss_fn: loss applied to the decoder's log-probabilities.
        epochs: epoch index to stop before (exclusive upper bound).
        last_epoch: first epoch index to run. To resume, pass the checkpoint's
            ``epoch + 1``.
        best_loss: best validation loss seen so far. To resume, pass the
            checkpoint's ``best_loss``.
        teacher_forcing_ratio_decay: initial teacher-forcing ratio, forwarded
            to ``train``.
        model_name: run name; defaults to ``SecondCount-<timestamp>``.

    Returns:
        Test loss of the model as it stands after the last epoch. This is the
        final model, not the best checkpoint.
    """
    if model_name is None:
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        model_name = f"SecondCount-{timestamp}"

    tensorboard_writer = SummaryWriter(log_dir=f"runs/{model_name}")

    # Close the writer even when training is interrupted (e.g. KeyboardInterrupt)
    # so pending events are flushed and the file handle is released.
    try:
        encoder.to(device)
        decoder.to(device)

        checkpoint_dir = pathlib.Path("models") / model_name
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        for epoch in range(last_epoch, epochs):
            print(f"Epoch {epoch + 1}/{epochs}")

            train_loss = train(
                epoch = epoch,
                encoder = encoder,
                decoder = decoder,
                encoder_optimizer = encoder_optimizer,
                decoder_optimizer = decoder_optimizer,
                dataloader = training_dataloader,
                decoder_criterion = decoder_loss_fn,
                teacher_forcing_ratio_decay = teacher_forcing_ratio_decay,
                writer = tensorboard_writer
            )
            validation_loss = validate(
                epoch = epoch,
                encoder = encoder,
                decoder = decoder,
                dataloader = validation_dataloader,
                decoder_criterion = decoder_loss_fn,
                writer = tensorboard_writer
            )

            print(f"Train Loss: {train_loss}, Validation Loss: {validation_loss}")

            # Update best_loss before saving, so this epoch's checkpoint records
            # the best loss including this epoch (needed for a correct resume).
            improved = validation_loss < best_loss
            if improved:
                best_loss = validation_loss

            checkpoint = {
                "encoder" : encoder.state_dict(),
                "decoder" : decoder.state_dict(),
                "encoder_optimizer" : encoder_optimizer.state_dict(),
                "decoder_optimizer" : decoder_optimizer.state_dict(),
                "loss" : validation_loss,
                "best_loss" : best_loss,
                "epoch" : epoch
            }
            torch.save(checkpoint, checkpoint_dir / f"model_{epoch}.tar")

            if improved:
                torch.save(checkpoint, checkpoint_dir / "best_model.tar")

        test_loss = test(
            encoder = encoder,
            decoder = decoder,
            dataloader = test_dataloader,
            decoder_criterion = decoder_loss_fn
        )

        print(f"Test Loss: {test_loss}")
    finally:
        tensorboard_writer.close()

    return test_loss


In [28]:
class Encoder(nn.Module):
    """Embeds amino-acid token ids and encodes them with an LSTM.

    Args:
        amino_acid_tokens: the token vocabulary. Only its length is used; one
            extra embedding row (``len + 1``) is reserved so ids can start at 1.
        padding_token: id of the padding token. Its embedding row is the
            ``padding_idx`` of ``nn.Embedding``.
        embedding_dim: width of each residue embedding.
        hidden_size: LSTM hidden size per direction.
        bidirectional: run the LSTM in both directions.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self,
                 amino_acid_tokens : dict,
                 padding_token : int,
                 embedding_dim : int,
                 hidden_size : int,
                 bidirectional : bool = False,
                 num_layers : int = 1):
        super().__init__()

        vocab_size = len(amino_acid_tokens) + 1
        self.amino_acid_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_token)
        self.lstm = nn.LSTM(input_size=embedding_dim,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=bidirectional)

    def forward(self, input_seq : torch.Tensor, seq_len : torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a right-padded batch.

        Args:
            input_seq: (batch, width) token ids, right-padded.
            seq_len: (batch,) true lengths; they need not be sorted.

        Returns:
            ``(outputs, (h_n, c_n))``. ``outputs`` is batch-first and re-padded
            with zeros up to ``max(seq_len)``. ``(h_n, c_n)`` is the LSTM's
            final state.
        """
        embedded = self.amino_acid_embedding(input_seq)
        # pack_padded_sequence expects its lengths on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(embedded, seq_len.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, hidden = self.lstm(packed)
        output = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)[0]
        return output, hidden

In [29]:
class Attention(nn.Module):
    """Additive attention: score = Va(tanh(Wa(query) + Ua(keys))), masked then softmaxed.

    Args:
        hidden_size: width of each key vector and of the internal projection.
        num_layers: the query is the flattened stack of ``num_layers`` hidden
            states, so ``Wa`` takes ``hidden_size * num_layers`` inputs.
    """

    def __init__(self, hidden_size : int, num_layers : int = 1):
        super().__init__()

        self.Wa = nn.Linear(hidden_size * num_layers, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query : torch.Tensor, keys : torch.Tensor, mask : torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over ``keys``.

        Args:
            query: any tensor whose per-sample elements flatten to
                ``hidden_size * num_layers`` values.
            keys: (batch, steps, hidden_size).
            mask: (batch, steps) boolean; False positions get zero weight.

        Returns:
            ``(context, weights)`` with shapes (batch, 1, hidden_size) and
            (batch, 1, steps).
        """
        flat_query = query.reshape(query.size(0), 1, -1)
        energy = torch.tanh(self.Wa(flat_query) + self.Ua(keys))
        # (batch, steps, 1) -> (batch, 1, steps)
        scores = self.Va(energy).transpose(1, 2)
        scores = scores.masked_fill(~mask.unsqueeze(1), float("-inf"))

        weights = torch.nn.functional.softmax(scores, dim=-1)
        return torch.bmm(weights, keys), weights

In [30]:
class Decoder(nn.Module):
    """Autoregressive LSTM decoder with attention that emits one DSSP class distribution per residue.

    At every step the decoder:
      * attends over the encoder outputs;
      * feeds back its previous arg-max prediction or, under teacher forcing,
        the ground-truth label;
      * emits scores over ``len(dssp_tokens) + 1`` classes, where the extra
        class is the padding token.

    Args:
        device: stored as ``self.device``. At run time, per-step tensors are
            created on ``encoder_outputs.device``, so they always match the
            model's actual placement.
        dssp_tokens: DSSP vocabulary; only its length is used.
        max_sequence_length: length the time axis of the output is padded to.
        embedding_dim: width of the label embedding.
        padding_token: id of the padding class. It is also used as the first
            (start) decoder input.
        hidden_size: encoder hidden size per direction. It is doubled when
            ``encoder_bidirectional`` is set.
        encoder_bidirectional: whether the encoder's final state holds
            forward/backward pairs that must be concatenated.
        num_layers: LSTM depth; must equal the encoder's layer count because the
            encoder state seeds this LSTM.
        dropout: dropout applied to the label embedding.
    """

    def __init__(self,
                 device : torch.device,
                 dssp_tokens : dict,
                 max_sequence_length : int,
                 embedding_dim : int,
                 padding_token : int,
                 hidden_size : int,
                 encoder_bidirectional : bool = False,
                 num_layers : int = 1,
                 dropout : float = 0.1):
        super(Decoder, self).__init__()

        self.device = device
        self.padding_token = padding_token
        self.dssp_tokens = dssp_tokens
        self.encoder_bidirectional = encoder_bidirectional
        self.max_sequence_length = max_sequence_length

        # A bidirectional encoder's forward and backward states are concatenated
        # before seeding this decoder, so every width here doubles.
        if encoder_bidirectional:
            hidden_size *= 2

        self.attention = Attention(hidden_size = hidden_size, num_layers = num_layers)
        self.dssp_embedding = nn.Embedding(num_embeddings = len(dssp_tokens) + 1,
                                           embedding_dim = embedding_dim,
                                           padding_idx = padding_token)
        # Each step consumes [label embedding ; attention context].
        self.lstm = nn.LSTM(input_size = embedding_dim + hidden_size,
                            hidden_size = hidden_size,
                            batch_first = True,
                            num_layers = num_layers)
        self.fc = nn.Linear(in_features = hidden_size,
                            out_features = len(dssp_tokens) + 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                seq_len : torch.Tensor,
                encoder_outputs : torch.Tensor,
                encoder_hidden : torch.Tensor,
                ss_tensor : torch.Tensor | None = None,
                teacher_forcing_ratio : float = 0.) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Decode a batch greedily, optionally with per-sample teacher forcing.

        Args:
            seq_len: (batch,) true sequence lengths, on any device.
            encoder_outputs: (batch, max(seq_len), features) attention keys.
            encoder_hidden: ``(h, c)`` final encoder LSTM state.
            ss_tensor: (batch, >= max(seq_len)) ground-truth label ids. It is
                only read when ``teacher_forcing_ratio > 0``.
            teacher_forcing_ratio: per-sample, per-step probability of feeding
                the true label instead of the prediction.

        Returns:
            ``(log_probs, decoder_hidden, attentions)``:
              * ``log_probs``: shape (batch, max_sequence_length, 1, classes).
                At positions at or beyond a sequence's length, the padding
                class has log-probability 0 and every other class -inf.
              * ``attentions``: shape (batch, max(seq_len), max(seq_len)).
        """
        run_device = encoder_outputs.device
        batch_size = seq_len.size(0)
        longest_sequence = seq_len.max().item()

        # Build the mask on seq_len's own device (CPU or accelerator), then move
        # it next to the encoder outputs.
        positions = torch.arange(longest_sequence, device=seq_len.device)
        mask = (positions.unsqueeze(0) < seq_len.unsqueeze(1)).to(run_device)

        # The padding token doubles as the start-of-sequence input.
        decoder_input = torch.full((batch_size, 1), self.padding_token, dtype=torch.long, device=run_device)

        decoder_hidden = encoder_hidden
        if self.encoder_bidirectional:
            # nn.LSTM stacks states as (layer0_fwd, layer0_bwd, layer1_fwd, ...);
            # concatenate each layer's forward/backward pair along the feature axis.
            decoder_hidden = tuple(torch.cat((state[0::2], state[1::2]), dim=-1) for state in encoder_hidden)

        decoder_outputs = []
        attentions = []

        for idx in range(longest_sequence):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(decoder_input, mask, encoder_outputs, decoder_hidden)
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            # Greedy feedback: the next input is this step's top class, detached from the graph.
            _, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze(-1).detach()

            if teacher_forcing_ratio > 0 and ss_tensor is not None:
                teacher_forcing_mask = torch.rand(batch_size, device=run_device) < teacher_forcing_ratio
                teacher_forcing_input = ss_tensor[:, idx].unsqueeze(1)
                decoder_input[teacher_forcing_mask] = teacher_forcing_input[teacher_forcing_mask]

        ans = torch.stack(decoder_outputs, dim=1)  # (batch, longest, 1, classes)
        ans = torch.nn.functional.log_softmax(ans, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        # Right-pad the time axis to max_sequence_length. Targets are padded to
        # the dataset-wide maximum, so the output must be the same length.
        pad_len = self.max_sequence_length - longest_sequence
        ans = torch.nn.functional.pad(ans, (0, 0, 0, 0, 0, pad_len))
        valid = torch.nn.functional.pad(mask, (0, pad_len)).view(batch_size, self.max_sequence_length, 1, 1)

        # Invalid positions become a one-hot log-distribution on the padding class.
        filler = torch.full((ans.size(-1),), float("-inf"), dtype=ans.dtype, device=ans.device)
        filler[self.padding_token] = 0.
        ans = torch.where(valid, ans, filler)

        return ans, decoder_hidden, attentions

    def forward_step(self,
                     input_tensor : torch.Tensor,
                     mask : torch.Tensor,
                     encoder_outputs : torch.Tensor,
                     encoder_hidden : Tuple[torch.Tensor, torch.Tensor],) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one decoding step.

        Args:
            input_tensor: (batch, 1) label ids fed at this step.
            mask: (batch, steps) validity mask over ``encoder_outputs``.
            encoder_outputs: attention keys.
            encoder_hidden: the running decoder LSTM state ``(h, c)``. The name
                is kept for compatibility.

        Returns:
            ``(scores, new_state, attention_weights)``. ``scores`` has shape
            (batch, 1, classes) and holds pre-softmax logits.
        """
        embedded_tensor = self.dropout(self.dssp_embedding(input_tensor))

        # Query: every layer's hidden state, batch-first: (batch, num_layers, hidden).
        query = encoder_hidden[0].permute(1, 0, 2)
        context, attention_weights = self.attention(query, encoder_outputs, mask)

        step_input = torch.cat((embedded_tensor, context), dim=2)

        output, encoder_hidden = self.lstm(step_input, encoder_hidden)
        output = self.fc(output)
        return output, encoder_hidden, attention_weights

In [31]:
# Architecture settings shared by encoder and decoder. The encoder's final
# (h, c) state seeds the decoder LSTM, so hidden_size, num_layers and the
# bidirectional flag must agree between the two. embedding_dim is shared only
# for convenience.
embedding_dim = 32
hidden_size = 16
num_layers = 2
bidirectional = True

# Pass the token dicts (as the signatures expect). Only their lengths are used,
# and they equal the lengths of the underlying lists.
encoder = Encoder(
    amino_acid_tokens = amino_acid_tokens,
    padding_token = pad_token_val,
    embedding_dim = embedding_dim,
    hidden_size = hidden_size,
    bidirectional = bidirectional,
    num_layers = num_layers
)

decoder = Decoder(
    device = device,
    dssp_tokens = dssp_tokens,
    embedding_dim = embedding_dim,
    max_sequence_length = max_sequence_length,
    padding_token = pad_token_val,
    hidden_size = hidden_size,
    encoder_bidirectional = bidirectional,
    num_layers = num_layers
)

encoder_optimizer = torch.optim.Adam(encoder.parameters())
decoder_optimizer = torch.optim.Adam(decoder.parameters())


In [32]:
# The decoder emits log-probabilities, hence NLL. Targets equal to the padding
# id are excluded from the mean.
decoder_loss = nn.NLLLoss(reduction="mean", ignore_index=pad_token_val)

In [33]:
# Initial teacher-forcing ratio. The training loop decays it across epochs.
teacher_forcing_ratio_decay: float = 0.9

In [34]:
# Full training run. The value displayed by this cell is the final test loss.
train_model(
    encoder = encoder,
    encoder_optimizer = encoder_optimizer,
    decoder = decoder,
    decoder_optimizer = decoder_optimizer,
    decoder_loss_fn = decoder_loss,
    training_dataloader = training_dataloader,
    validation_dataloader = validation_dataloader,
    test_dataloader = test_dataloader,
    device = device,
    epochs = 25,
    teacher_forcing_ratio_decay = teacher_forcing_ratio_decay,
)

Epoch 1/25


  return torch._C._nn.pad(input, pad, mode, value)


KeyboardInterrupt: 