##### Categorizing Imports

In [1]:
import tensorflow as tf
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torchtext
from collections import OrderedDict as od

In [2]:
# importing the data

import os
# from google.colab import drive
# drive.mount('/content/drive')
# os.listdir("/content/drive/MyDrive/Transformer-Viruses")

import pandas as pd
viralData = pd.read_csv("../data/genome_data_with_transmission_levels.csv")
viralData.drop("Unnamed: 0", axis=1, inplace=True)
print(viralData.columns)
display(viralData.head())

Index(['Locus', 'Position/Length Indicator?', 'Virus Name', 'Genome',
       'Estimated Transmission Level'],
      dtype='object')


Unnamed: 0,Locus,Position/Length Indicator?,Virus Name,Genome,Estimated Transmission Level
0,NC_034975,1239568,Mamastrovirus 4,AAGAAGGAGGTTATCAAAGAGGAAAAGATCAAGAACAATGACATCC...,4b
1,NC_001560,1972577,Indiana vesiculovirus,ACGAAGACAAACAAACCATTATTATCATTAAAAGGCTCAGGAGAAA...,2
2,NC_038236,1972577,Indiana vesiculovirus,ACGAAGACAAACAAACCATTATTACCATTAAAAGGCTCAGGAGAAA...,2
3,NC_021928,2560526,Human orthorubulavirus 4,ACCAAGGGGAGAAGAGATATGGATACTGATCTGGAAAATTAAAGGT...,4b
4,NC_014373,565995,Bundibugyo ebolavirus,CGGACACACAAAAAGAATGAAGGATTTTGAATCTTTATTGTGTGCG...,3


## Create the tokenizer


In [3]:
# define a vocabulary, this should be a mapping from all of the possible characters
# that can appear in our sequence
# we can either do this manually by defining a dictionary and writing {'A': 0, 'G': 1, 'T': 2, ...}
# or you can use torchtext.vocab.build_vocab_from_iterator (this is probably easier)
from torchtext.vocab import vocab
from torchtext.data.utils import get_tokenizer

sample_sequence = viralData.iloc[0, 3]

vocabs = od()
vocabs["A"] = int(1)
vocabs["G"] = int(2)
vocabs["T"] = int(3)
vocabs["C"] = int(4)
vocabs["U"] = int(5)
# print("Vocab is: {}".format(vocabs)) # sanity check that all of the possible values are reflected

vocab = vocab(vocabs, specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
tokenizer = get_tokenizer(vocab)
print(tokenizer(list(sample_sequence)))

viralData["Genome"] = viralData["Genome"].map(lambda x: tokenizer(list(x)))
viralData.head()

[1, 1, 2, 1, 1, 2, 2, 1, 2, 2, 3, 3, 1, 3, 4, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 1, 2, 1, 3, 4, 1, 1, 2, 1, 1, 4, 1, 1, 3, 2, 1, 4, 1, 3, 4, 4, 2, 4, 4, 1, 1, 1, 3, 4, 4, 3, 4, 3, 2, 3, 2, 4, 3, 2, 1, 4, 4, 4, 4, 1, 3, 4, 3, 1, 4, 1, 4, 3, 1, 2, 2, 1, 3, 3, 2, 2, 2, 2, 4, 3, 2, 3, 4, 3, 3, 2, 2, 1, 2, 3, 4, 1, 4, 1, 4, 4, 1, 2, 1, 1, 4, 3, 4, 4, 3, 3, 2, 1, 3, 2, 1, 1, 1, 2, 1, 2, 4, 2, 1, 1, 4, 3, 2, 1, 1, 1, 4, 1, 2, 4, 1, 2, 3, 2, 2, 2, 2, 4, 1, 1, 3, 2, 3, 2, 2, 2, 3, 2, 2, 3, 4, 4, 4, 4, 1, 1, 3, 2, 4, 2, 3, 2, 2, 3, 2, 2, 2, 3, 3, 4, 1, 4, 2, 1, 2, 1, 1, 2, 2, 1, 3, 2, 1, 2, 2, 1, 2, 1, 4, 3, 4, 1, 3, 4, 2, 4, 1, 1, 1, 1, 2, 2, 4, 1, 1, 4, 1, 1, 1, 3, 1, 4, 3, 3, 4, 1, 3, 2, 2, 1, 1, 3, 3, 3, 2, 1, 4, 3, 2, 2, 1, 4, 4, 1, 2, 1, 3, 1, 3, 2, 1, 3, 2, 2, 4, 1, 4, 1, 1, 3, 3, 4, 4, 1, 1, 4, 2, 4, 4, 1, 4, 3, 4, 4, 3, 4, 3, 3, 2, 4, 1, 3, 1, 3, 4, 1, 1, 1, 1, 1, 1, 4, 3, 4, 1, 2, 2, 3, 2, 2, 3, 4, 1, 1, 3, 2, 1, 3, 4, 1, 1, 4, 2, 1, 2, 2, 3, 2, 4, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 3, 1, 4, 

Unnamed: 0,Locus,Position/Length Indicator?,Virus Name,Genome,Estimated Transmission Level
0,NC_034975,1239568,Mamastrovirus 4,"[1, 1, 2, 1, 1, 2, 2, 1, 2, 2, 3, 3, 1, 3, 4, ...",4b
1,NC_001560,1972577,Indiana vesiculovirus,"[1, 4, 2, 1, 1, 2, 1, 4, 1, 1, 1, 4, 1, 1, 1, ...",2
2,NC_038236,1972577,Indiana vesiculovirus,"[1, 4, 2, 1, 1, 2, 1, 4, 1, 1, 1, 4, 1, 1, 1, ...",2
3,NC_021928,2560526,Human orthorubulavirus 4,"[1, 4, 4, 1, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 1, ...",4b
4,NC_014373,565995,Bundibugyo ebolavirus,"[4, 2, 2, 1, 4, 1, 4, 1, 4, 1, 1, 1, 1, 1, 2, ...",3


In [4]:
type(viralData["Genome"][0][0])
print(viralData.columns)
display(viralData.copy())

Index(['Locus', 'Position/Length Indicator?', 'Virus Name', 'Genome',
       'Estimated Transmission Level'],
      dtype='object')


Unnamed: 0,Locus,Position/Length Indicator?,Virus Name,Genome,Estimated Transmission Level
0,NC_034975,1239568,Mamastrovirus 4,"[1, 1, 2, 1, 1, 2, 2, 1, 2, 2, 3, 3, 1, 3, 4, ...",4b
1,NC_001560,1972577,Indiana vesiculovirus,"[1, 4, 2, 1, 1, 2, 1, 4, 1, 1, 1, 4, 1, 1, 1, ...",2
2,NC_038236,1972577,Indiana vesiculovirus,"[1, 4, 2, 1, 1, 2, 1, 4, 1, 1, 1, 4, 1, 1, 1, ...",2
3,NC_021928,2560526,Human orthorubulavirus 4,"[1, 4, 4, 1, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 1, ...",4b
4,NC_014373,565995,Bundibugyo ebolavirus,"[4, 2, 2, 1, 4, 1, 4, 1, 4, 1, 1, 1, 1, 1, 2, ...",3
...,...,...,...,...,...
375,HM181996,2734523,Cardiovirus D,"[3, 3, 3, 2, 1, 1, 4, 2, 4, 2, 2, 2, 1, 4, 4, ...",2
376,MG210578,2044876,Human hepegivirus 1,"[1, 1, 4, 3, 2, 3, 3, 2, 3, 3, 2, 3, 1, 2, 4, ...",4a
377,MW149253,1912147,Mamastrovirus sp.,"[4, 4, 1, 1, 2, 1, 4, 1, 2, 2, 3, 2, 2, 3, 2, ...",4b
378,MG026496,2039694,Salivirus sp.,"[2, 2, 4, 2, 2, 2, 4, 3, 3, 2, 3, 2, 2, 1, 4, ...",4a


## Create our dataset

In [5]:
numpyViralData = viralData.copy()

for i in range(0, len(viralData["Genome"])):
    numpyViralData.at[i, "Genome"] = np.array(viralData.loc[i, "Genome"])

display(numpyViralData)
print(type(numpyViralData["Genome"][0][0]))
# tensorViralData = tf.convert_to_tensor(numpyViralData)
# print(numpyViralData)

Unnamed: 0,Locus,Position/Length Indicator?,Virus Name,Genome,Estimated Transmission Level
0,NC_034975,1239568,Mamastrovirus 4,"[1, 1, 2, 1, 1, 2, 2, 1, 2, 2, 3, 3, 1, 3, 4, ...",4b
1,NC_001560,1972577,Indiana vesiculovirus,"[1, 4, 2, 1, 1, 2, 1, 4, 1, 1, 1, 4, 1, 1, 1, ...",2
2,NC_038236,1972577,Indiana vesiculovirus,"[1, 4, 2, 1, 1, 2, 1, 4, 1, 1, 1, 4, 1, 1, 1, ...",2
3,NC_021928,2560526,Human orthorubulavirus 4,"[1, 4, 4, 1, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 1, ...",4b
4,NC_014373,565995,Bundibugyo ebolavirus,"[4, 2, 2, 1, 4, 1, 4, 1, 4, 1, 1, 1, 1, 1, 2, ...",3
...,...,...,...,...,...
375,HM181996,2734523,Cardiovirus D,"[3, 3, 3, 2, 1, 1, 4, 2, 4, 2, 2, 2, 1, 4, 4, ...",2
376,MG210578,2044876,Human hepegivirus 1,"[1, 1, 4, 3, 2, 3, 3, 2, 3, 3, 2, 3, 1, 2, 4, ...",4a
377,MW149253,1912147,Mamastrovirus sp.,"[4, 4, 1, 1, 2, 1, 4, 1, 2, 2, 3, 2, 2, 3, 2, ...",4b
378,MG026496,2039694,Salivirus sp.,"[2, 2, 4, 2, 2, 2, 4, 3, 3, 2, 3, 2, 2, 1, 4, ...",4a


<class 'numpy.int64'>


In [6]:
import torch
from torch.utils.data import Dataset

# Pytorch defines a nice dataset class that only requires we implement two functions:
# 1. __len__
# 2. __getitem__
# Our dataset is just the set of all samples that we want to train our model on
# __len__ should get us the total number of samples
# __getitem__ takes in an integer and should give us the corresponding element in the dataset

# You might be wondering why they have those weird underscores? That enables us to
# call those functions on the class instances directly
# For instance if we define an instance of our Dataset class, then we can get the length of it
# by invoking `len` on it directly
# new_dataset = GenomicR0ValueDataset(some_sample_sequences, some_sample_r0_values, new_tokenizer)
# len(new_dataset) # this will call GenomicR0ValueDataset.__len__


class viralDataset(Dataset):
    def __init__(self, sequences: list, r0_values: np.ndarray, tokenizer):
        self.sequences = sequences
        self.r0_values = r0_values.clone().detach().numpy()
        self.tokenizer = tokenizer
        # here we should initialize some class variables using `self.` so that
        # we can access them further along
        # make sure all of our sequences are tokenized so that when we return them in __getitem__
        # we don't have to do any post-processing on them during training

    def __len__(self):
        return self.r0_values.shape[0]
        # this needs to return the number of elements in our dataset
        # so this should be the total number of sequences

    def __getitem__(self, index):
        try:
            sequence = self.sequences[index].tolist()
            # sequence = self.tokenizer(sequence)
            # Ensure the sequence is a PyTorch tensor
            sequence = torch.tensor(sequence, dtype=torch.float32)
        except Exception as e:
            print(f"Error processing sequence at index {index}: {e}")
            raise

        try:
            r0_value = self.r0_values[index]
            # Ensure the r0_value is a PyTorch tensor
            r0_value = torch.tensor(r0_value, dtype=torch.float32)
        except Exception as e:
            print(f"Error processing r0_value at index {index}: {e}")
            raise

        return sequence, r0_value
        # get item should take in an index and return the corresponding genomic sequence
        # AND its r0 value. we need to return both because every time we pass a sequence
        # through our model, we need to compare its predicted r0 value to the correct r0
        # value. so the format of this output should be returning two things like this:
        # return _, _

## Create our dataloader

In [7]:
# imports
import torch
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms
from torch import nn, optim

In [8]:
numpyViralData.shape

(380, 5)

In [9]:
# finding the maximum length of the sequences
max_len = 0
for i in range(0, len(numpyViralData["Genome"])):
    if len(numpyViralData["Genome"][i]) > max_len:
        max_len = len(numpyViralData["Genome"][i])
print(max_len)

31028


In [10]:
sequences = []
# ! WE ARE USING LISTS HERE, NOT NUMPY ARRAYS


for i in range(numpyViralData.shape[0]):
    stacked_array = np.stack(numpyViralData.loc[i, "Genome"])
    tensor = torch.tensor(stacked_array, dtype=torch.float32)
    sequences.append(tensor)

# sequences = np.vstack(sequences)
# display(np.array(sequences))
# print(type(sequences[0]))
display(sequences)

[tensor([1., 1., 2.,  ..., 1., 1., 1.]),
 tensor([1., 4., 2.,  ..., 4., 2., 3.]),
 tensor([1., 4., 2.,  ..., 4., 2., 3.]),
 tensor([1., 4., 4.,  ..., 2., 2., 3.]),
 tensor([4., 2., 2.,  ..., 4., 4., 1.]),
 tensor([4., 2., 4.,  ..., 2., 4., 2.]),
 tensor([4., 2., 4.,  ..., 2., 2., 2.]),
 tensor([2., 1., 2.,  ..., 3., 1., 4.]),
 tensor([3., 3., 4.,  ..., 1., 1., 3.]),
 tensor([1., 4., 4.,  ..., 2., 2., 3.]),
 tensor([1., 2., 1.,  ..., 2., 4., 3.]),
 tensor([2., 4., 3.,  ..., 1., 3., 4.]),
 tensor([2., 4., 3.,  ..., 1., 3., 4.]),
 tensor([2., 4., 3.,  ..., 1., 3., 4.]),
 tensor([2., 4., 3.,  ..., 1., 3., 4.]),
 tensor([2., 4., 3.,  ..., 1., 3., 4.]),
 tensor([2., 4., 3.,  ..., 1., 3., 4.]),
 tensor([2., 4., 3.,  ..., 1., 3., 4.]),
 tensor([2., 4., 3.,  ..., 1., 3., 4.]),
 tensor([2., 4., 3.,  ..., 1., 3., 4.]),
 tensor([2., 4., 3.,  ..., 1., 3., 4.]),
 tensor([1., 3., 2.,  ..., 3., 1., 2.]),
 tensor([1., 3., 2.,  ..., 3., 3., 4.]),
 tensor([3., 3., 1.,  ..., 3., 1., 3.]),
 tensor([4., 2.,

In [11]:
type(sequences[0][0])

torch.Tensor

In [12]:
sequences[0].tolist()

[1.0,
 1.0,
 2.0,
 1.0,
 1.0,
 2.0,
 2.0,
 1.0,
 2.0,
 2.0,
 3.0,
 3.0,
 1.0,
 3.0,
 4.0,
 1.0,
 1.0,
 1.0,
 2.0,
 1.0,
 2.0,
 2.0,
 1.0,
 1.0,
 1.0,
 1.0,
 2.0,
 1.0,
 3.0,
 4.0,
 1.0,
 1.0,
 2.0,
 1.0,
 1.0,
 4.0,
 1.0,
 1.0,
 3.0,
 2.0,
 1.0,
 4.0,
 1.0,
 3.0,
 4.0,
 4.0,
 2.0,
 4.0,
 4.0,
 1.0,
 1.0,
 1.0,
 3.0,
 4.0,
 4.0,
 3.0,
 4.0,
 3.0,
 2.0,
 3.0,
 2.0,
 4.0,
 3.0,
 2.0,
 1.0,
 4.0,
 4.0,
 4.0,
 4.0,
 1.0,
 3.0,
 4.0,
 3.0,
 1.0,
 4.0,
 1.0,
 4.0,
 3.0,
 1.0,
 2.0,
 2.0,
 1.0,
 3.0,
 3.0,
 2.0,
 2.0,
 2.0,
 2.0,
 4.0,
 3.0,
 2.0,
 3.0,
 4.0,
 3.0,
 3.0,
 2.0,
 2.0,
 1.0,
 2.0,
 3.0,
 4.0,
 1.0,
 4.0,
 1.0,
 4.0,
 4.0,
 1.0,
 2.0,
 1.0,
 1.0,
 4.0,
 3.0,
 4.0,
 4.0,
 3.0,
 3.0,
 2.0,
 1.0,
 3.0,
 2.0,
 1.0,
 1.0,
 1.0,
 2.0,
 1.0,
 2.0,
 4.0,
 2.0,
 1.0,
 1.0,
 4.0,
 3.0,
 2.0,
 1.0,
 1.0,
 1.0,
 4.0,
 1.0,
 2.0,
 4.0,
 1.0,
 2.0,
 3.0,
 2.0,
 2.0,
 2.0,
 2.0,
 4.0,
 1.0,
 1.0,
 3.0,
 2.0,
 3.0,
 2.0,
 2.0,
 2.0,
 3.0,
 2.0,
 2.0,
 3.0,
 4.0,
 4.0,
 4.0,
 4.0,
 1.0,
 1.0,
 3.0

In [13]:
# tokenizer([1,2,3,2,1,3,2])

In [14]:
# getting the categorical values
r0_values = numpyViralData["Estimated Transmission Level"].values
print(f"Possible Transmssion Values: {np.unique(r0_values)}")

# converting those r0 values to a numbers (so it's more categorical)
value_dict = {
    "2": 0,
    "3": 1,
    "4a": 2,
    "4b": 3,
}

transmission_levels = np.array(numpyViralData["Estimated Transmission Level"])

for i in range(len(transmission_levels)):
   transmission_levels[i] = value_dict[transmission_levels[i].strip()]

# display(transmission_levels)

# transform these levels into a stacked tensor
# stacked_transmission_levels = np.stack([transmission_levels]).astype(int)
stacked_transmission_levels = np.array(transmission_levels).astype(int)
tensor_transmission_levels = torch.tensor(
    stacked_transmission_levels, dtype=torch.int64
)
display(tensor_transmission_levels)

Possible Transmssion Values: ['2' '3' '4a' '4b']


tensor([3, 0, 0, 3, 1, 1, 1, 3, 2, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 0, 3,
        0, 1, 1, 0, 2, 2, 3, 0, 0, 0, 1, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 2, 0, 3, 0, 0, 0,
        3, 1, 1, 1, 2, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 2, 0, 2,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 3, 3, 2, 2, 2, 2, 0, 0, 0, 2, 1, 3,
        1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 3, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 3, 3, 3, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 3, 0, 2,
        0, 2, 2, 1, 0, 0, 2, 3, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 3, 2, 2, 0, 0, 0, 0, 0, 0, 1, 1, 3, 2, 2, 2, 2, 2, 2, 2,
        2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 2, 1, 1, 3, 3, 3, 2, 0, 2, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 3, 3, 2, 2, 2, 2, 2, 2,
        2, 2, 0, 2, 2, 3, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 2, 1, 1, 1, 1, 2,

In [15]:
# pytorch dataloaders: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
# dataloaders take in a pytorch.Dataset as an argument (like the one we defined above!)


dataset = viralDataset(sequences, tensor_transmission_levels, tokenizer)
print(viralData["Genome"].tolist())
# lets instantiate that GenomicR0ValueDataset we defined above

# define some train test split
train_test_ratio = 0.9
train_size = int(
    train_test_ratio * len(dataset)
)  # how convenient that we can called `len` on our dataset!!
test_size = len(dataset) - train_size

# now use random_split from torch.utils.data to define our two datasets
train_dataset = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)
test_dataset = random_split(
    dataset, [test_size, train_size], generator=torch.Generator().manual_seed(42)
)


# define our batch_size
batch_size = (
    32  # batch size defines how many sequences our model will be processing at once
)
# higher batch sizes mean training will be faster, but will make updates slightly less precisely (will talk about this in person)
# another thing to keep in mind is the amount of MEMORY that we have! our batches can't get too big!

train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

[[1, 1, 2, 1, 1, 2, 2, 1, 2, 2, 3, 3, 1, 3, 4, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 1, 2, 1, 3, 4, 1, 1, 2, 1, 1, 4, 1, 1, 3, 2, 1, 4, 1, 3, 4, 4, 2, 4, 4, 1, 1, 1, 3, 4, 4, 3, 4, 3, 2, 3, 2, 4, 3, 2, 1, 4, 4, 4, 4, 1, 3, 4, 3, 1, 4, 1, 4, 3, 1, 2, 2, 1, 3, 3, 2, 2, 2, 2, 4, 3, 2, 3, 4, 3, 3, 2, 2, 1, 2, 3, 4, 1, 4, 1, 4, 4, 1, 2, 1, 1, 4, 3, 4, 4, 3, 3, 2, 1, 3, 2, 1, 1, 1, 2, 1, 2, 4, 2, 1, 1, 4, 3, 2, 1, 1, 1, 4, 1, 2, 4, 1, 2, 3, 2, 2, 2, 2, 4, 1, 1, 3, 2, 3, 2, 2, 2, 3, 2, 2, 3, 4, 4, 4, 4, 1, 1, 3, 2, 4, 2, 3, 2, 2, 3, 2, 2, 2, 3, 3, 4, 1, 4, 2, 1, 2, 1, 1, 2, 2, 1, 3, 2, 1, 2, 2, 1, 2, 1, 4, 3, 4, 1, 3, 4, 2, 4, 1, 1, 1, 1, 2, 2, 4, 1, 1, 4, 1, 1, 1, 3, 1, 4, 3, 3, 4, 1, 3, 2, 2, 1, 1, 3, 3, 3, 2, 1, 4, 3, 2, 2, 1, 4, 4, 1, 2, 1, 3, 1, 3, 2, 1, 3, 2, 2, 4, 1, 4, 1, 1, 3, 3, 4, 4, 1, 1, 4, 2, 4, 4, 1, 4, 3, 4, 4, 3, 4, 3, 3, 2, 4, 1, 3, 1, 3, 4, 1, 1, 1, 1, 1, 1, 4, 3, 4, 1, 2, 2, 3, 2, 2, 3, 4, 1, 1, 3, 2, 1, 3, 4, 1, 1, 4, 2, 1, 2, 2, 3, 2, 4, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 3, 1, 4,

In [16]:
dataset[1]

(tensor([1., 4., 2.,  ..., 4., 2., 3.]), tensor(0.))

In [17]:
# print(type(train_dataset[0]))
#

## Define our model architecture

### Positional Encoding

In [18]:
# positional encoding explanation
# https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/
import math
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=0.1)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[: x.size(0), None, :]
        return self.dropout(x)


# positional encoding is critical for the transformer model
# d_model is the embedding dimension of the model
# max_len here is our context window size

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [20]:
import math

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class GenomeR0ValueModel(nn.Module):
    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int):
        super().__init__()
        self.model_type = "Transformer"
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, 1)
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        assert src.numel() > 0, "src is empty"
        src = self.embedding(src.to(torch.long)) * math.sqrt(self.d_model)
        src = self.pos_encoder(src.to(torch.long))
        if src_mask is None:
            src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device)
        src = src.view(src.size(0), -1, self.d_model)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        print(output.size())
        assert output.numel() > 0, "output is empty"
        return output[-1, -1, 0]


# here, we define the model class, it is based on a transformer architecture
# and it takes in a ntoken, which is the size of the vocabulary (number of unique characters in our input)
# d_model, which is the embedding size of the tokens
# nhead is the number of heads in our self-attention setup
# d_hid is the dimension of the hidden layer
# n_layers is the number of hidden layers

## Create our train and test loop and begin trainin

In [21]:
# initialize our model
import torch
import torch.nn as nn

model = GenomeR0ValueModel(
    ntoken=len(vocabs), d_model=64, nhead=8, d_hid=256, nlayers=8
)



In [22]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm  # Progress bar (optional, but very helpful)


def train_model(
    model,
    train_loader,
    test_loader,
    criterion,
    optimizer,
    num_epochs,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    # move the
    model.to(device)
    best_val_loss = float("inf")
    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        train_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for inputs, targets in pbar:
            # print(inputs.shape)
            # print(targets.shape)
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(
                outputs, targets.unsqueeze(1)
            )  # Ensure target shape matches output
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            pbar.set_postfix({"Train Loss": loss.item()})

        train_loss /= len(train_loader)

        # Test Phase
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets.unsqueeze(1))
                val_loss += loss.item()
            val_loss /= len(test_loader)
            
            # Save the best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), "best_model.pth")

In [23]:
from torch.utils.data import ConcatDataset

# Assuming `dataset` is an instance of your `viralDataset` class...
train_size = int(0.8 * len(sequences))  # 80% of the dataset for training
val_size = len(sequences) - train_size  # the rest for validation


In [24]:
total_train = ConcatDataset(train_dataset)
total_test = ConcatDataset(test_dataset)

In [25]:
total_train[0][1].size()

torch.Size([])

In [26]:
train_loader = DataLoader(total_train, batch_size=1, shuffle=True)
val_loader = DataLoader(total_test, batch_size=1, shuffle=True)


In [27]:
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

## Begin training

In [28]:
# Train the model
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=75)

Epoch 1/75:   0%|          | 0/380 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)
Epoch 1/75: 100%|██████████| 380/380 [02:05<00:00,  3.02it/s, Train Loss=1.08]    
Epoch 2/75:   2%|▏         | 8/380 [00:04<03:45,  1.65it/s, Train Loss=0.21]   


KeyboardInterrupt: 

In [None]:
model.eval()

In [None]:
import dill as pickle

savedModel = "/content/drive/MyDrive/Transformer-Viruses/Genome_Transmissivity_model_dill.pkl"
with open(savedModel, "wb") as f:
    pickle.dump(model, f)