In [None]:
import re
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

# Pick the compute backend: a CUDA GPU if present, otherwise Apple's MPS backend,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

## Data Loader

In [None]:
# Captures the text between "(VFC" and the next ")" in a header, e.g. "(VFC0001)" -> "0001".
VFC_match = re.compile(r'[(]VFC(.*?)[)]', re.S)

# Token ids: 0 = "<unk>", 1..26 = residue letters, 27 = "<pad>".
VOCAB = ["<unk>", *'AFCUDNEQGHLIKOMPRSTVWYBZJX', "<pad>"]

# Class labels; a label's position in this list is the integer target fed to the loss.
HEADERS = [
    'NotVF',
    'VFC0001', 'VFC0346', 'VFC0083', 'VFC0235', 'VFC0258', 'VFC0272', 'VFC0315',
    'VFC0325', 'VFC0086', 'VFC0204', 'VFC0271', 'VFC0301', 'VFC0251', 'VFC0282',
]

BATCH_SIZE = 100    # records per DataLoader batch
EMBED_SIZE = 20     # embedding width (d_model) shared by both models
MULTY_HEADER = 4    # number of attention heads in the Transformer encoder

# Sanity checks on the hyper-parameters above.
assert EMBED_SIZE % 2 == 0              # sinusoidal encoding fills even/odd channels with sin/cos
assert MULTY_HEADER % 2 == 0            # NOTE(review): nn.MultiheadAttention itself does not need an even head count
assert EMBED_SIZE % MULTY_HEADER == 0   # each head gets EMBED_SIZE // MULTY_HEADER channels

def read_fasta(file):
    """Parse a FASTA file into ``(sequence, header)`` pairs.

    A line starting with '>' opens a new record. The stripped header line is
    kept verbatim, including the leading '>'. Every following non-blank line,
    up to the next header, is appended to that record's sequence. Blank lines
    are skipped wherever they occur, so a run of blank lines does not end
    parsing early.

    Args:
        file: path of the FASTA file to read.

    Returns:
        list of ``(sequence, header)`` tuples in file order.

    Raises:
        ValueError: if sequence data appears before the first header line.
    """
    headers = []
    seq_chunks = []  # one list of line fragments per record; joined once at the end
    with open(file, 'r') as f:
        for line_no, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            if line[0] == '>':
                headers.append(line)
                seq_chunks.append([])
            elif not seq_chunks:
                raise ValueError(f"{file}: line {line_no}: sequence data before the first '>' header")
            else:
                seq_chunks[-1].append(line)
    return [(''.join(chunks), header) for chunks, header in zip(seq_chunks, headers)]


def tokenizer_embedder(seq, vocab=None):
    """Map a residue string to a list of integer token ids.

    Each character of ``seq`` is replaced by its index in ``vocab``. If a
    token appears more than once, the first index is used, matching
    ``list.index``. Characters not in the vocabulary map to the index of
    ``"<unk>"`` instead of raising.

    Args:
        seq: residue string, e.g. "MKV...".
        vocab: token list; defaults to the module-level ``VOCAB``.

    Returns:
        list[int] of token ids, one per character of ``seq``.

    Raises:
        ValueError: if a character is unknown and ``vocab`` has no "<unk>" token.
    """
    if vocab is None:
        vocab = VOCAB
    # Dict lookup instead of list.index per residue (O(len(seq)) vs O(len(seq) * len(vocab))).
    index_of = {}
    for idx, token in enumerate(vocab):
        index_of.setdefault(token, idx)
    unk_idx = index_of.get("<unk>")
    ids = []
    for residue in seq:
        idx = index_of.get(residue, unk_idx)
        if idx is None:
            raise ValueError(f"residue {residue!r} is not in the vocabulary and there is no '<unk>' token")
        ids.append(idx)
    return ids

def header_parser(header, pattern=None):
    """Map a FASTA header line to its class label string.

    Headers starting with '>VF' get the label 'VFC' + the first capture group
    of ``pattern``. By default ``pattern`` is the module-level ``VFC_match``,
    which captures the text between '(VFC' and ')'. Every other header maps to
    'NotVF'.

    Args:
        header: header line including the leading '>'.
        pattern: regex (compiled or string) with one capturing group;
            defaults to ``VFC_match``.

    Returns:
        label string such as 'VFC0001' or 'NotVF'.

    Raises:
        ValueError: if a '>VF' header contains no match for ``pattern``.
    """
    if not header.startswith('>VF'):
        return 'NotVF'
    if pattern is None:
        pattern = VFC_match
    match = re.search(pattern, header)
    if match is None:
        raise ValueError(f"VF header has no '(VFC...)' category: {header!r}")
    return 'VFC' + match.group(1)

def header_encoder(parsed_header):
    """Return the integer class id for a parsed header label.

    The id is the label's position in ``HEADERS``. A label missing from
    ``HEADERS`` raises ValueError (propagated from ``list.index``).
    """
    class_id = HEADERS.index(parsed_header)
    return class_id

def collate_batch(data_batch, dtype=torch.float32, batch_first=False):
    """Turn a list of ``(sequence, header)`` records into padded model inputs.

    Args:
        data_batch: list of ``(sequence string, header line)`` pairs, as
            returned by ``read_fasta``.
        dtype: dtype of the token-id tensor. Both models in this notebook feed
            it to ``nn.Embedding``, which only accepts integer indices, so
            callers pass ``torch.int``; the float default would fail there.
        batch_first: pad to ``[batch, max_len]`` if True, else ``[max_len, batch]``.

    Returns:
        ``(labels, tokens, lengths)``:
          labels  -- class ids (int64) on ``device``;
          tokens  -- token ids padded with the '<pad>' index, on ``device``;
          lengths -- unpadded lengths (int64), kept on the CPU because
                     ``pack_padded_sequence`` rejects lengths that live on an accelerator.
    """
    pad_idx = float(VOCAB.index('<pad>'))
    labels, token_seqs, lengths = [], [], []
    for seq, header in data_batch:
        labels.append(header_encoder(header_parser(header)))
        token_seqs.append(torch.tensor(tokenizer_embedder(seq), dtype=dtype))
        lengths.append(len(seq))
    tokens = torch.nn.utils.rnn.pad_sequence(token_seqs, batch_first=batch_first, padding_value=pad_idx)
    return torch.tensor(labels).to(device), tokens.to(device), torch.tensor(lengths, dtype=torch.int64)

## Training step

In [None]:
def train_epoch(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Each batch is unpacked as ``(labels, sequences, lengths)``. The model is
    called as ``model(sequences, lengths)`` and the loss as
    ``loss_fn(prediction, labels)``.

    Returns:
        Mean of the per-batch losses, or ``float('nan')`` if the dataloader
        yields no batches.
    """
    model.train()  # training mode (enables dropout)
    loss_sum = 0.0
    n_batches = 0
    for header_batch, seq_batch, seq_len_batch in dataloader:
        # Clear gradients *before* backward so nothing accumulated earlier
        # (e.g. by a stray backward() in another cell) leaks into this step.
        optimizer.zero_grad()
        pred = model(seq_batch, seq_len_batch)
        loss = loss_fn(pred, header_batch)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        n_batches += 1
    return loss_sum / n_batches if n_batches else float('nan')


def test_epoch(dataloader, model, loss_fn):
    lossSum = 0
    correctSum = 0
    model.eval()
    with torch.no_grad():
        for (header_batch, seq_batch, seq_len_batch) in dataloader:
            pred = model(seq_batch, seq_len_batch)
            lossSum += loss_fn(pred,header_batch).item()
            correctSum += (pred.argmax(1) == header_batch).type(torch.float).sum().item()
    avgTestingLoss = lossSum/len(dataloader)             ## /num_batches
    avgTestingAcc  = correctSum/len(dataloader.dataset)  ## /size
    return avgTestingLoss,avgTestingAcc

## Model: Transformer

In [None]:
## Latest Suggestion: https://pytorch.org/tutorials/beginner/translation_transformer.html
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an embedded batch, then apply dropout.

    Args:
        d_model: embedding width. Must be even: sin fills the even channels
            and cos the odd ones.
        max_len: number of positions precomputed. Inputs longer than this are
            truncated to ``max_len`` along the sequence axis in ``forward``.
        batch_first: True for inputs laid out ``[batch, seq, d_model]``,
            False for ``[seq, batch, d_model]``.
    """

    def __init__(self, d_model, max_len, batch_first):
        super().__init__()
        self.dropout = nn.Dropout()  # default p=0.5 -- NOTE(review): high for embedding dropout; confirm intended
        self.batch_first = batch_first
        positions = torch.arange(max_len).unsqueeze(1)                       # [max_len, 1]
        inv_freq = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq                                         # [max_len, d_model // 2]
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Insert a singleton batch axis where the input layout keeps its batch dimension.
        pe = table.unsqueeze(0) if batch_first else table.unsqueeze(1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        seq_dim = 1 if self.batch_first else 0
        # Both x and the table are cut to the shorter of the two lengths.
        n = min(x.size(seq_dim), self.pe.size(seq_dim))
        return self.dropout(x.narrow(seq_dim, 0, n) + self.pe.narrow(seq_dim, 0, n))


class TransformerModel(nn.Module):
    """Transformer-encoder sequence classifier.

    Token ids are embedded, position-encoded and passed through
    ``num_layers`` encoder layers. The encoder output at the first sequence
    position is then mapped to ``out_class`` logits.

    Args:
        VOCAB_size: number of rows in the embedding table.
        d_model: model width; must be divisible by ``nhead``, and even for the
            sinusoidal positional encoding.
        nhead: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        out_class: number of output classes.
        batch_first: True for ``[batch, seq]`` inputs, False for ``[seq, batch]``
            (suggested: True).
        max_len: positions precomputed by the positional encoding; longer
            inputs are truncated by it.
    """

    def __init__(self, VOCAB_size, d_model, nhead, num_layers, out_class, batch_first, max_len=5000):
        super().__init__()
        self.batch_first = batch_first
        self.embedding = nn.Embedding(VOCAB_size, d_model)    # token id -> d_model vector
        self.pos_encoder = PositionalEncoding(d_model, max_len, batch_first)
        encode_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=batch_first)
        encoder_norm = nn.LayerNorm(d_model)
        self.TransformerEncoder = nn.TransformerEncoder(encode_layer, num_layers, encoder_norm)
        self.linear = nn.Linear(d_model, out_class)

    def _padding_mask(self, x, seq_len_batch):
        """Boolean ``[batch, seq]`` mask, True at padded positions of the encoded ``x``."""
        seq_dim = 1 if self.batch_first else 0
        width = x.size(seq_dim)  # length after any truncation by the positional encoding
        lengths = torch.as_tensor(seq_len_batch, device=x.device)
        # Never mask every key of a row: an all-masked row makes attention produce NaN.
        lengths = lengths.clamp(min=1)
        positions = torch.arange(width, device=x.device)
        return positions.unsqueeze(0) >= lengths.unsqueeze(1)

    def forward(self, input_batch, seq_len_batch=None):
        """Return ``[batch, out_class]`` logits.

        ``seq_len_batch`` holds the unpadded length of every sequence. When it
        is given, padded positions are excluded from attention, so a
        sequence's logits do not depend on how much padding its batch has.
        When it is None, no mask is applied.
        """
        x = self.embedding(input_batch)
        x = self.pos_encoder(x)
        mask = None if seq_len_batch is None else self._padding_mask(x, seq_len_batch)
        x = self.TransformerEncoder(x, src_key_padding_mask=mask)
        first_token = x[:, 0, :] if self.batch_first else x[0, :, :]
        return self.linear(first_token)


In [None]:
if_bfA = True
torch.manual_seed(0)  # reproducible weight init and shuffling

train_records = read_fasta('trainset.faa')
test_records = read_fasta('testset.faa')
train_dl = torch.utils.data.DataLoader(train_records, batch_size=BATCH_SIZE, shuffle=True,
                                       collate_fn=lambda x: collate_batch(x, dtype=torch.int, batch_first=if_bfA))
test_dl = torch.utils.data.DataLoader(test_records, batch_size=BATCH_SIZE, shuffle=False,
                                      collate_fn=lambda x: collate_batch(x, dtype=torch.int, batch_first=if_bfA))

# collate_batch moves labels and tokens to `device`, so the model must live there too.
modelA = TransformerModel(len(VOCAB), EMBED_SIZE, MULTY_HEADER, num_layers=2,
                          out_class=len(HEADERS), batch_first=if_bfA).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(modelA.parameters(), lr=1e-3)

epochs = 5
for t in range(epochs):
    avgTrainingLoss = train_epoch(train_dl, modelA, loss_fn, optimizer)
    avgTestingLoss, avgTestingAcc = test_epoch(test_dl, modelA, loss_fn)
    print(f'Epoch {t+1} ---- train loss: {avgTrainingLoss:.4f} | '
          f'test loss: {avgTestingLoss:.4f} | test acc: {avgTestingAcc:>7f}')

## Model: LSTM

Pipeline: sequence encoder (Mamba or LSTM; only the LSTM is implemented below) → linear classifier → VF category.



In [None]:
## https://www.cnblogs.com/BlueBlueSea/p/13723560.html
class LSTM_Net(nn.Module):
    """Bidirectional-LSTM sequence classifier.

    Token ids are embedded and packed by their true lengths, so padding is
    never fed to the LSTM. The top layer's final forward and backward hidden
    states are concatenated and mapped to ``out_class`` logits.

    Args:
        VOCAB_size: number of rows in the embedding table.
        d_model: embedding width (LSTM input size).
        hidden_size: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        out_class: number of output classes.
        batch_first: True for ``[batch, seq]`` inputs, False for ``[seq, batch]``
            (suggested: False).
    """

    def __init__(self, VOCAB_size, d_model, hidden_size, num_layers, out_class, batch_first):
        super().__init__()
        self.batch_first = batch_first
        self.embedding = nn.Embedding(VOCAB_size, d_model)
        self.lstm = nn.LSTM(
            input_size=d_model, hidden_size=hidden_size, num_layers=num_layers, batch_first=batch_first, bidirectional=True
        )
        self.linear = nn.Linear(hidden_size * 2, out_class)

    def forward(self, input_batch, seq_len_batch):
        """Return ``[batch, out_class]`` logits; ``seq_len_batch`` holds the unpadded lengths."""
        x = self.embedding(input_batch)
        # pack_padded_sequence requires lengths as a 1-D int64 CPU tensor, even when x is on a GPU.
        lengths = torch.as_tensor(seq_len_batch, dtype=torch.int64).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=self.batch_first, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        # h_n is [num_layers * 2, batch, hidden_size] for either batch_first setting and is
        # returned in the original batch order. The last two rows are the top layer's forward
        # state (after each sequence's last real token) and backward state (after its first
        # token). Reading the padded output at time -1 instead would hit padding for shorter
        # sequences.
        final = torch.cat((h_n[-2], h_n[-1]), dim=1)   # [batch, 2 * hidden_size]
        return self.linear(final)


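## nn.LSTM returns output, (h_n, c_n); h_n has shape [num_layers * 2, batch, hidden_size] for a bidirectional LSTM, regardless of batch_first.
## Final representation: torch.cat((h_n[-2], h_n[-1]), dim=1) -> [batch, 2*hidden_size] (top layer, forward + backward directions).
## Avoid output[:, -1, :] (batch_first=True) / output[-1] (batch_first=False) after pad_packed_sequence: for sequences shorter than the batch maximum that slot is padding, and for the backward direction it is not the final state.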

In [None]:
if_bfB = False
torch.manual_seed(0)  # reproducible weight init and shuffling

train_records = read_fasta('trainset.faa')
test_records = read_fasta('testset.faa')
train_dl = torch.utils.data.DataLoader(train_records, batch_size=BATCH_SIZE, shuffle=True,
                                       collate_fn=lambda x: collate_batch(x, dtype=torch.int, batch_first=if_bfB))
test_dl = torch.utils.data.DataLoader(test_records, batch_size=BATCH_SIZE, shuffle=False,
                                      collate_fn=lambda x: collate_batch(x, dtype=torch.int, batch_first=if_bfB))

# collate_batch moves labels and tokens to `device`, so the model must live there too.
modelB = LSTM_Net(len(VOCAB), EMBED_SIZE, hidden_size=33, num_layers=2,
                  out_class=len(HEADERS), batch_first=if_bfB).to(device)
loss_fnB = nn.CrossEntropyLoss()
optimizerB = torch.optim.SGD(modelB.parameters(), lr=1e-3)

epochs = 5
for t in range(epochs):
    avgTrainingLoss = train_epoch(train_dl, modelB, loss_fnB, optimizerB)
    avgTestingLoss, avgTestingAcc = test_epoch(test_dl, modelB, loss_fnB)
    print(f'Epoch {t+1} ---- train loss: {avgTrainingLoss:.4f} | '
          f'test loss: {avgTestingLoss:.4f} | test acc: {avgTestingAcc:>7f}')
