In [1]:
import torch
import numpy as np
import pandas as pd
from torch import Tensor, nn
from tqdm.notebook import tqdm
from torchaudio.models import Conformer
from torch.utils.data import Dataset, DataLoader
from transformers import Trainer, TrainingArguments
# Default to the CPU and switch to the first CUDA device only when one is usable.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'



In [2]:
class conformer_encoder(nn.Module) :
    """Per-position regressor that fuses a Conformer branch with a Transformer branch.

    Token ids are embedded, then processed in parallel by a torchaudio
    ``Conformer`` and an ``nn.TransformerEncoder``. Each branch gets a residual
    connection to the embedding, the two results are concatenated along the
    feature axis, and an MLP head maps every position to 2 output values.

    Args:
        kernel_size: depthwise-convolution kernel size passed to ``Conformer``.
        num_channels: model width shared by the embedding and both branches.
        num_layers: number of layers in the Conformer and in the Transformer encoder.
        feed_forward: hidden size of the feed-forward sublayers in both branches.
        num_heads: number of attention heads in both branches.
        vocab_size: number of rows in the token embedding table. The default
            (457) is the value that used to be hard-coded.
    """

    def __init__(self, kernel_size: int, num_channels: int, num_layers: int, feed_forward = 1024, num_heads = 16,
                 vocab_size: int = 457) :
        super().__init__()

        # Embedding table followed by a small MLP that widens the features up to num_channels.
        self.embedding = nn.Sequential(nn.Embedding(vocab_size, num_channels//4), nn.ReLU(), nn.Linear(num_channels//4, num_channels//2), 
                                        nn.ReLU(), nn.Linear(num_channels//2, num_channels))
        
        # batch_first=True is required: forward() feeds (batch, seq, channel) tensors,
        # the same layout Conformer consumes. With the default seq-first layout the
        # self-attention would run across the batch axis and mix unrelated sequences.
        trans_encoder = nn.TransformerEncoderLayer(num_channels, num_heads, feed_forward, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(trans_encoder, num_layers)
        self.encoder =  Conformer(num_channels, num_heads, feed_forward, num_layers, kernel_size)     
        # Head input is 2*num_channels because both branches are concatenated.
        # NOTE(review): there is no activation between the second and third Linear;
        # left unchanged so the module indices (and saved state_dict keys) stay the same.
        self.result = nn.Sequential(nn.Linear(num_channels*2, num_channels), nn.ReLU(), nn.Linear(num_channels, num_channels//2),
                                    nn.Linear(num_channels//2, num_channels//4), nn.ReLU(), nn.Linear(num_channels//4, num_channels//8), 
                                    nn.ReLU(), nn.Linear(num_channels//8, 2))
        
    def forward(self, input_ids, length, mask) :
        """Run both branches and return per-position predictions.

        Args:
            input_ids: integer token ids indexed as (batch, position).
            length: tensor of per-sequence lengths; its maximum is used to trim
                the inputs and it is forwarded to ``Conformer`` as ``lengths``.
            mask: (batch, position) multiplier applied to the embedding and the
                output. Presumably 1 for real tokens and 0 for padding -- confirm
                against the data pipeline.

        Returns:
            Tensor of shape (batch, max(length), 2), multiplied by ``mask``.
        """
        
        # Add a trailing axis so the mask broadcasts over the feature dimension.
        mask = torch.unsqueeze(mask, dim=-1)
        # Drop positions beyond the longest sequence in this batch.
        max_len = torch.max(length)
        mask = mask[:, :max_len]
        input_ids = input_ids[:, :max_len]
        embedding = self.embedding(input_ids)*mask
        
        # Conformer returns (output, output_lengths); only the output is used.
        encoded, _ = self.encoder(embedding, length)
        trans_encoded = self.transformer_encoder(embedding*mask) + embedding
        result_input = torch.concat(((encoded + embedding)*mask, trans_encoded), dim=-1)
        output = self.result(result_input)*mask

        return output

# DECODER FILE

The decoder will focus on one base at a time. It will use positional information to understand the base's location.

Additionally, it will ignore bases that cannot react with the current base by masking irrelevant parts of the encoder's output.

In [None]:
class decoder(nn.Module) :
    """Per-position predictor built from a Conformer and a Transformer decoder.

    Token ids are embedded and encoded by a torchaudio ``Conformer``. An
    ``nn.TransformerDecoder`` then uses the embedding as its target sequence
    and the Conformer output as memory, with padded positions masked out of
    both attentions. The two representations are concatenated and an MLP head
    maps every position to 2 values. When ``labels`` is given, an L1 loss is
    returned instead of the predictions.

    NOTE(review): the previous constructor never created the modules that
    ``forward`` used and built its "decoder layer" from ``conformer_encoder``
    with misplaced arguments, so the class could not run. The wiring below is
    the reviewer's interpretation of the intended design -- confirm. Only
    padding is masked; masking of specific base pairs is not implemented here.

    Args:
        kernel_size: depthwise-convolution kernel size passed to ``Conformer``.
        num_channels: model width shared by the embedding and both sub-networks.
        num_layers: number of layers in the Conformer and in the Transformer decoder.
        feed_forward: hidden size of the feed-forward sublayers.
        num_heads: number of attention heads.
        vocab_size: number of rows in the token embedding table. Defaults to
            457, the vocabulary size used by ``conformer_encoder``; the former
            value of 1 rejected every token id other than 0.
    """

    def __init__(self, kernel_size: int, num_channels: int, num_layers: int, feed_forward = 1024, num_heads = 16,
                 vocab_size: int = 457) :
        super().__init__()

        # Embedding table followed by a small MLP that widens the features up to num_channels.
        self.embedding = nn.Sequential(nn.Embedding(vocab_size, num_channels//4), nn.ReLU(), nn.Linear(num_channels//4, num_channels//2), 
                                        nn.ReLU(), nn.Linear(num_channels//2, num_channels))
        
        # forward() needs an encoder to produce the memory the decoder attends to.
        self.encoder = Conformer(num_channels, num_heads, feed_forward, num_layers, kernel_size)
        # nn.TransformerDecoder must wrap a decoder layer; batch_first=True matches
        # the (batch, seq, channel) tensors produced in forward().
        decoder_layer = nn.TransformerDecoderLayer(num_channels, num_heads, feed_forward, batch_first=True)
        self.transformer_Decoder = nn.TransformerDecoder(decoder_layer, num_layers)    
        # Head input is 2*num_channels because two representations are concatenated.
        self.result = nn.Sequential(nn.Linear(num_channels*2, num_channels), nn.ReLU(), nn.Linear(num_channels, num_channels//2),
                                    nn.Linear(num_channels//2, num_channels//4), nn.ReLU(), nn.Linear(num_channels//4, num_channels//8), 
                                    nn.ReLU(), nn.Linear(num_channels//8, 2))
        self.loss = nn.L1Loss()
        
    def forward(self, input_ids, length, mask, labels=None) :
        """Predict 2 values per position, or return the L1 loss when labels are given.

        Args:
            input_ids: integer token ids indexed as (batch, position).
            length: tensor of per-sequence lengths; its maximum is used to trim
                the inputs, and it defines which positions count as padding.
            mask: (batch, position) multiplier applied to the embedding and the
                output. Presumably 1 for real tokens and 0 for padding -- confirm.
            labels: optional targets, trimmed to max(length) along axis 1 and
                compared element-wise with the output (so presumably shaped
                (batch, position, 2) -- confirm).

        Returns:
            Without ``labels``: tensor of shape (batch, max(length), 2).
            With ``labels``: the mean L1 loss as a tensor of shape (1,).
        """
        
        # Add a trailing axis so the mask broadcasts over the feature dimension.
        mask = torch.unsqueeze(mask, dim=-1)
        # Drop positions beyond the longest sequence in this batch.
        max_len = torch.max(length)
        mask = mask[:, :max_len]
        input_ids = input_ids[:, :max_len]
        embedding = self.embedding(input_ids)*mask
        
        # Conformer returns (output, output_lengths); only the output is used.
        encoded, _ = self.encoder(embedding, length)
        memory = (encoded + embedding)*mask

        # True marks positions at or beyond each sequence's length, which the
        # decoder's self- and cross-attention must not attend to.
        positions = torch.arange(embedding.shape[1], device=embedding.device)
        padding = positions.unsqueeze(0) >= length.to(embedding.device).unsqueeze(1)
        decoded = self.transformer_Decoder(embedding, memory,
                                           tgt_key_padding_mask=padding,
                                           memory_key_padding_mask=padding) + embedding
        result_input = torch.concat((memory, decoded), dim=-1)
        output = self.result(result_input)*mask
        
        if labels is not None :
            
            y = labels[:, :max_len]
            # Entries whose label is exactly 0 get their prediction zeroed, so they
            # add no error; they still count in the denominator of the mean.
            cover = y != 0
            # Out-of-place multiply: same values, without mutating a graph tensor.
            output = output * cover
            # Shape (1,) rather than a scalar -- presumably so per-replica losses
            # can be gathered; confirm against the training loop.
            loss = torch.unsqueeze( self.loss(output, y), dim=0)
            return loss
        
        return output