In [None]:
from biodatasets import list_datasets, load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

In [None]:
# Fetch the "pathogen" dataset and export it as NumPy arrays.
pathogen = load_dataset("pathogen")

# One input column ("sequence") and one target column ("class") are requested;
# later cells read them back by position as X[0] and y[0].
X, y = pathogen.to_npy_arrays(
    input_names=["sequence"],
    target_names=["class"],
)

pathogen.display_description()

In [None]:
# Encode each amino-acid symbol as an integer index
def get_seq_column_map(X):
    """Build a symbol -> integer index lookup for the sequences in ``X[0]``.

    Args:
        X: Indexable whose first element is an iterable of sequences. Each
            sequence is an iterable of hashable, mutually orderable symbols
            (for example strings of one-letter codes).

    Returns:
        dict mapping every distinct symbol to an index in
        ``range(number_of_symbols)``, assigned in sorted symbol order.
    """
    unique = set()
    for sequence in X[0]:
        unique.update(sequence)

    # Number the symbols in sorted order. Set iteration order for strings
    # changes between interpreter runs (hash randomisation), so numbering in
    # set order would re-label symbols after a kernel restart and make
    # previously saved embedding weights meaningless.
    return {symbol: index for index, symbol in enumerate(sorted(unique))}
    
# Build the symbol -> index lookup from the loaded data. The dataset and model
# cells below read this module-level name.
pathogen_map = get_seq_column_map(X)
print(pathogen_map)

In [None]:
class PathogenDataset(Dataset):
    """Map-style dataset pairing integer-encoded sequences with their labels.

    Args:
        pathogen_map: Lookup from sequence symbol to integer index.
        data: Pair ``(sequences, labels)``, both indexed by sample position.
    """

    def __init__(self, pathogen_map, data):
        sequences, labels = data[0], data[1]
        self.pathogen_map = pathogen_map
        self.X = sequences
        self.Y = labels

    def __one_hot(self, Y):
        """Return a float64 one-hot matrix with one row per entry of ``Y``.

        The matrix has ``Y.max() + 1`` columns. This helper is not called
        anywhere in the visible code.
        """
        n_rows = Y.size
        n_classes = Y.max() + 1
        encoded = np.zeros((n_rows, n_classes))
        encoded[np.arange(n_rows), Y] = 1
        return encoded.astype(np.float64)

    def __len__(self):
        """Number of samples, taken from the label container."""
        return len(self.Y)

    def __getitem__(self, idx):
        """Return ``(encoded_sequence_tensor, label)`` for sample ``idx``."""
        lookup = self.pathogen_map
        encoded = [lookup[symbol] for symbol in self.X[idx]]
        return torch.as_tensor(encoded), self.Y[idx]

def collate_padd(batch):
    """Collate variable-length samples into one zero-padded batch.

    Args:
        batch: Sequence of ``(sequence_tensor, label)`` samples.

    Returns:
        ``((padded, lengths), labels)`` where ``padded`` is the batch-first,
        zero-padded float32 tensor of sequences, ``lengths`` holds each
        sequence's original length, and ``labels`` is a float32 tensor.
    """
    sequences = []
    targets = []
    for sample in batch:
        sequences.append(sample[0])
        targets.append(sample[1])

    # Record the true lengths before padding hides them.
    lengths = torch.as_tensor(list(map(len, sequences)))
    padded = pad_sequence(sequences, batch_first=True).to(torch.float32)
    labels = torch.as_tensor(targets).to(torch.float32)
    return (padded, lengths), labels
    
# Split ~ 80% 10% 10% by position: the first 80000 samples are used for
# training, the next 10000 for validation and the remainder for testing.
training_set = PathogenDataset(pathogen_map, (X[0][:80000], y[0][:80000]))
validation_set = PathogenDataset(pathogen_map, (X[0][80000:90000], y[0][80000:90000]))
testing_set = PathogenDataset(pathogen_map, (X[0][90000:], y[0][90000:]))

# Only the training loader shuffles; validation and testing keep dataset order.
training_loader = DataLoader(
    training_set, batch_size=8, shuffle=True, collate_fn=collate_padd
)
validation_loader = DataLoader(
    validation_set, batch_size=8, collate_fn=collate_padd
)
testing_loader = DataLoader(
    testing_set, batch_size=8, collate_fn=collate_padd
)

# Pull one batch so the cell output shows what the model will receive.
next(iter(training_loader))

In [None]:
class Net(nn.Module):
    """Embedding -> single-layer LSTM -> two linear layers, one logit per sequence.

    Args:
        input_dim: Number of distinct symbols (rows of the embedding table).
            When omitted, ``len(pathogen_map)`` is looked up at construction
            time.
    """

    def __init__(self, input_dim=None):
        super().__init__()

        if input_dim is None:
            # Resolve the default when the model is built, not when the class
            # is defined, so a rebuilt ``pathogen_map`` is picked up without
            # re-running this cell.
            input_dim = len(pathogen_map)

        self.embed = nn.Embedding(
            num_embeddings=input_dim,
            embedding_dim=512,
        )

        self.lstm = nn.LSTM(
            input_size=512,
            hidden_size=256,
            num_layers=1,
            batch_first=True,
        )

        self.linear_1 = nn.Linear(
            in_features=256,
            out_features=128,
        )

        self.dropout = nn.Dropout(p=0.25)

        self.linear_2 = nn.Linear(
            in_features=128,
            out_features=1,
        )

    def forward(self, x, sequence_len):
        """Compute one raw logit per sequence.

        Args:
            x: Integer tensor of padded symbol indices, batch first.
            sequence_len: Unpadded length of every row of ``x`` (tensor or
                sequence of ints).

        Returns:
            Tensor of shape ``(batch,)`` holding un-activated logits.
        """
        lengths = torch.as_tensor(sequence_len, dtype=torch.long)
        embed = self.embed(x)

        # Packing makes the LSTM skip padded positions; lengths must be on the CPU.
        packed_input = pack_padded_sequence(
            embed, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        lstm_1_seq, _ = self.lstm(packed_input)
        output, _ = pad_packed_sequence(lstm_1_seq, batch_first=True)

        # Take the LSTM output at each sequence's last real time step, with
        # both index tensors on the same device as ``output``.
        batch_index = torch.arange(output.size(0), device=output.device)
        last_step = (lengths - 1).to(output.device)
        out_forward = output[batch_index, last_step]

        dropout = self.dropout(out_forward)

        # NOTE(review): there is no activation between linear_1 and linear_2,
        # so outside of dropout the two layers compose to one affine map.
        # Left as is to keep the architecture unchanged.
        linear_1 = self.linear_1(dropout)
        dropout = self.dropout(linear_1)
        linear_2 = self.linear_2(dropout)

        # Squeeze only the feature axis. A bare torch.squeeze would also drop
        # the batch axis for a batch of one and no longer match the labels.
        return linear_2.squeeze(-1)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device : {device}")
# Only ask for the CUDA device name when CUDA is usable:
# torch.cuda.get_device_name() raises on a machine without a GPU, which would
# defeat the CPU fallback chosen above.
torch.cuda.get_device_name() if device == "cuda" else None

In [None]:
# Train the network, logging per-batch training metrics and per-epoch
# validation accuracy to TensorBoard.
model = Net().to(device)
print(model)

writer = SummaryWriter()

# BCEWithLogitsLoss applies the sigmoid internally, so it must receive raw
# logits; the explicit sigmoid below is used only to derive class predictions.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Built once at cell level; the evaluation cell reuses this name.
sigmoid = nn.Sigmoid()

# Monotonic step counter so points from later epochs do not overwrite earlier
# ones in TensorBoard.
global_step = 0

for epoch in range(10):
    model.train()
    tqdm_bar = tqdm(training_loader, desc=f"epoch {epoch}", position=0)

    # Training
    for (inputs, sequence_len), labels in tqdm_bar:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # collate_padd hands the indices over as float32; the embedding needs
        # integers. reshape(-1) keeps a batch of one as shape (1,) so it
        # matches the labels.
        logits = model(inputs.long(), sequence_len).reshape(-1)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        writer.add_scalar('Loss/train', loss.item(), global_step)

        # Training accuracy of this batch
        predicted = torch.round(sigmoid(logits.detach()))
        batch_correct = (predicted == labels).sum().item()
        writer.add_scalar('accuracy/train', batch_correct / labels.size(0), global_step)

        global_step += 1

    # Validation accuracy, accumulated over the whole validation set
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for (inputs, sequence_len), labels in validation_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs.long(), sequence_len).reshape(-1)
            predicted = torch.round(sigmoid(logits))

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total:
        writer.add_scalar('accuracy/validation', correct / total, epoch)

writer.close()

In [None]:
!tensorboard --logdir=runs

In [None]:
# Persist only the model's state dict (its parameters and buffers); the
# evaluation cell restores it into a freshly constructed Net.
PATH = './pathogen_net.pth'
torch.save(obj=model.state_dict(), f=PATH)

In [None]:
# Rebuild the network and restore the saved weights. map_location lets a
# checkpoint written on a GPU load on a CPU-only machine.
model = Net().to(device)
model.load_state_dict(torch.load(PATH, map_location=device))
model.eval()


# Testing Accuracy
correct, total = 0, 0
all_predicted, all_y = [], []
with torch.no_grad():
    for (inputs, sequence_len), labels in testing_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # The loader yields float32 indices; the embedding needs integers.
        # reshape(-1) keeps a batch of one as shape (1,).
        logits = model(inputs.long(), sequence_len).reshape(-1)
        # torch.sigmoid is used directly so this cell does not depend on a
        # name created in the training cell.
        predicted = torch.round(torch.sigmoid(logits))

        all_predicted.extend(predicted.tolist())
        all_y.extend(labels.tolist())

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(confusion_matrix(all_y, all_predicted))
print(f'Accuracy of nn: {correct / total}')