# GOALS
- Pretrain a sequence model on protein sequences (goal: a model that accurately predicts the next token of a protein sequence)
- Pretrain a sequence model on the SMILES strings of drugs

In [None]:
import json

import numpy as np
import plotly.graph_objects as go
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [None]:
class Vocab:
    """Two-way mapping between tokens and integer indices.

    The special tokens ``PAD``, ``SOS`` and ``EOS`` are appended after the
    caller-supplied tokens, so they take the three highest indices.
    """

    def __init__(self, tokens):
        """Build the token <-> index lookup tables.

        Args:
            tokens: list of vocabulary tokens (without the special tokens).
        """
        self.tokens = tokens + ["PAD", "SOS", "EOS"]
        self.token_ix = {token: index for index, token in enumerate(self.tokens)}
        self.ix_token = dict(enumerate(self.tokens))

    def encode(self, seq, max_len=None):
        """Return ``seq`` as a list of indices wrapped in ``SOS`` ... ``EOS``.

        Args:
            seq: iterable of tokens; an unknown token raises ``KeyError``.
            max_len: optional minimum output length. Shorter encodings are
                right-padded with ``PAD``; longer ones are NOT truncated.
        """
        lookup = self.token_ix
        encoded = [lookup["SOS"], *(lookup[token] for token in seq), lookup["EOS"]]
        if max_len and len(encoded) < max_len:
            missing = max_len - len(encoded)
            encoded.extend([lookup["PAD"]] * missing)
        return encoded

    def decode(self, seq):
        """Map an iterable of indices back to the list of their tokens."""
        return list(map(self.ix_token.__getitem__, seq))

class ProteinVocab(Vocab):
    """Vocabulary of one-letter amino-acid codes (plus the Vocab special tokens)."""

    def __init__(self):
        # 21 residue codes: the 20 standard amino acids plus 'U'
        # (selenocysteine in the IUPAC one-letter alphabet).
        super().__init__(list("ACDEFGHIKLMNPQRSTUVWY"))
        
class SMILEVocab(Vocab):
    """Character-level vocabulary for SMILES strings (plus the Vocab special tokens)."""

    def __init__(self):
        # One token per character, grouped as: punctuation and digits,
        # upper-case letters, brackets/backslash, lower-case letters.
        alphabet = (
            "#()+-./12345678=@"
            "ABCFGHIKLMNOPSTVWZ"
            "[\\]"
            "abdegilnorstu"
        )
        super().__init__(list(alphabet))

In [None]:
def generate_examples(vocab, sequence, seq_len):
    """Slice an encoded sequence into next-token prediction examples.

    The sequence is encoded with ``vocab.encode`` and a window of ``seq_len``
    tokens is slid over the result one step at a time. Each window is paired
    with the token that immediately FOLLOWS it, which is what a model reading
    the whole window has to predict.

    Args:
        vocab: object exposing ``encode(sequence)`` that returns a list of
            token indices.
        sequence: raw sequence to encode.
        seq_len: number of tokens in every input window.

    Returns:
        ``(x, y)`` where ``x`` is a list of windows (each a list of
        ``seq_len`` indices) and ``y`` is a list of single-element lists
        holding the target index of the matching window. Both are empty when
        the encoded sequence has no more than ``seq_len`` tokens.
    """
    encoded = vocab.encode(sequence)

    # A window starting at i needs a target at i + seq_len, so the last valid
    # start is len(encoded) - seq_len - 1. A negative count gives an empty range.
    n_windows = len(encoded) - seq_len

    x = [encoded[start:start + seq_len] for start in range(n_windows)]
    y = [[encoded[start + seq_len]] for start in range(n_windows)]
    return x, y

In [None]:
# Vocabulary instances used to encode protein sequences and SMILES strings.
protein_vocab = ProteinVocab()
smile_vocab = SMILEVocab()

In [None]:
# Length, in tokens, of one model input window.
# NOTE(review): presumably the `seq_len` handed to generate_examples — confirm.
SEQ_LEN = 16

In [None]:
# Accumulators for the generated examples: "x" collects the input windows and
# "y" the matching targets, kept separately for proteins and for SMILES.
protein_data = {"x": [], "y": []}
smile_data = {"x": [], "y": []}

In [None]:
# Load the id -> sequence mappings. The context managers close each file as
# soon as it has been parsed instead of leaving the handle to the garbage
# collector.
with open("../data/uniprotid_to_seq.json", "r", encoding="utf-8") as seq_file:
    ut_id_to_seq = json.load(seq_file)

with open("../data/databankid_to_smile.json", "r", encoding="utf-8") as smile_file:
    db_id_to_smile = json.load(smile_file)

In [None]:
# protein sequences
# Build sliding-window examples for every protein and append them to the
# accumulators. The window length is the notebook-level SEQ_LEN constant
# (the lower-case `seq_len` used before is not defined at this scope).
for protein_sequence in tqdm(ut_id_to_seq.values()):
    x, y = generate_examples(protein_vocab, protein_sequence, SEQ_LEN)
    protein_data["x"].extend(x)
    protein_data["y"].extend(y)

In [None]:
# Sanity check: number of protein input windows and targets collected.
len(protein_data["x"]), len(protein_data["y"])

In [None]:
# SMILE sequences
# Build sliding-window examples for every drug and append them to the
# accumulators. The window length is the notebook-level SEQ_LEN constant
# (the lower-case `seq_len` used before is not defined at this scope).
for smile_seq in tqdm(db_id_to_smile.values()):
    x, y = generate_examples(smile_vocab, smile_seq, SEQ_LEN)
    smile_data["x"].extend(x)
    smile_data["y"].extend(y)

In [None]:
# Sanity check: number of SMILES input windows and targets collected.
len(smile_data["x"]), len(smile_data["y"])

In [None]:
# Replace the accumulated Python lists with tensors, in place. torch.Tensor
# builds tensors of the default floating dtype; the training loop casts the
# batches back with .long() before use.
for _dataset in (protein_data, smile_data):
    for _field in ("x", "y"):
        _dataset[_field] = torch.Tensor(_dataset[_field])

In [None]:
# Random permutations of the example indices (one per dataset); they are
# sliced below to form the train/test splits.
protein_data_indices = torch.randperm(protein_data["x"].shape[0])
smile_data_indices = torch.randperm(smile_data["x"].shape[0])

In [None]:
# Fraction of the shuffled examples that goes to the training split; the
# remainder forms the test split.
train_pct = .8

# Proteins: cut the permutation once at the split point.
protein_split = int(train_pct * protein_data_indices.shape[0])
protein_train_indices = protein_data_indices[:protein_split]
protein_test_indices = protein_data_indices[protein_split:]

# SMILES: same procedure.
smile_split = int(train_pct * smile_data_indices.shape[0])
smile_train_indices = smile_data_indices[:smile_split]
smile_test_indices = smile_data_indices[smile_split:]

In [None]:
# Sizes of the protein train and test index splits.
protein_train_indices.shape, protein_test_indices.shape

In [None]:
# Sizes of the SMILES train and test index splits.
smile_train_indices.shape, smile_test_indices.shape

In [None]:
class SequenceDataset(Dataset):
    """Dataset over a mapping with parallel ``"x"`` and ``"y"`` collections.

    ``data["x"][i]`` is the i-th input and ``data["y"][i]`` a one-element
    container holding its target.
    """

    def __init__(self, data):
        """Store ``data``; nothing is copied or converted."""
        self.data = data

    def __len__(self):
        """Number of examples, taken from the length of ``data["y"]``."""
        return len(self.data["y"])

    def __getitem__(self, idx):
        """Return ``(input, target)`` for ``idx``, unwrapping the target."""
        return self.data["x"][idx], self.data["y"][idx][0]

## Step 1.5 Baseline models

In [None]:
class GRUModel(nn.Module):
    """Embedding -> GRU -> linear head that scores the vocabulary.

    Only the GRU output at the last time step is fed to the linear head, so
    one forward pass yields one prediction per input sequence.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, dropout=0, bidirectional=False):
        """Create the layers.

        Args:
            vocab_size: number of distinct tokens (embedding rows and output
                logits).
            embed_size: width of the token embeddings.
            hidden_size: GRU hidden size per direction.
            num_layers: number of stacked GRU layers.
            dropout: dropout value forwarded to ``nn.GRU``.
            bidirectional: whether the GRU runs in both directions.
        """
        super().__init__()

        # Keep the hyper-parameters on the instance for later inspection.
        self.vocab_size, self.embed_size, self.hidden_size = vocab_size, embed_size, hidden_size
        self.num_layers, self.dropout, self.bidirectional = num_layers, dropout, bidirectional

        # A bidirectional GRU concatenates both directions, doubling the
        # feature width seen by the output layer.
        directions = 2 if self.bidirectional else 1

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(
            embed_size,
            hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.out = nn.Linear(hidden_size * directions, vocab_size)

    def forward(self, x, hidden=None):
        """Return ``(logits, hidden)`` for a batch of token-index sequences.

        Args:
            x: integer tensor of token indices, batch dimension first.
            hidden: optional initial GRU hidden state.
        """
        embedded = self.embedding(x)
        outputs, hidden = self.gru(embedded, hidden)
        last_step = outputs[:, -1]
        logits = self.out(last_step)
        return logits, hidden

In [None]:
# Notebook-level training hyper-parameters.
num_epochs = 100  # number of training epochs
lr = 3e-4  # NOTE(review): presumably the optimizer learning rate — confirm
lossfn = nn.CrossEntropyLoss()  # loss applied to the model logits in train()

In [None]:
def _mean(values):
    """Return the arithmetic mean of ``values``, or ``nan`` when it is empty."""
    return sum(values) / len(values) if values else float("nan")


def _run_epoch(model, batches, optimizer=None):
    """Run one pass over ``batches`` and return ``(mean loss, mean accuracy)``.

    Weights are updated only when ``optimizer`` is given; without it the pass
    is a pure evaluation.
    """
    losses = []
    accuracies = []

    for x, y in batches:
        x = x.long()
        y = y.long()

        if optimizer is not None:
            optimizer.zero_grad()

        p, _ = model(x)
        loss = lossfn(p, y)

        if optimizer is not None:
            loss.backward()
            optimizer.step()

        losses.append(loss.item())
        # Fraction of the batch whose arg-max prediction matches the target.
        accuracies.append((p.argmax(-1) == y).float().mean().item())

    return _mean(losses), _mean(accuracies)


def train(model, train, test, optimizer, epochs):
    """Train ``model`` for ``epochs`` epochs, evaluating after each one.

    Uses the notebook-level ``lossfn`` as the criterion.

    Args:
        model: module whose call returns ``(logits, hidden)``.
        train: iterable yielding ``(x, y)`` training batches.
        test: iterable yielding ``(x, y)`` evaluation batches.
        optimizer: optimizer bound to ``model``'s parameters.
        epochs: number of passes over ``train``.

    Returns:
        dict with the trained ``"model"`` and four lists holding one value
        per epoch: ``"train_loss_over_time"``, ``"test_loss_over_time"``,
        ``"train_accuracy_over_time"`` and ``"test_accuracy_over_time"``.
        Each value is the mean of the per-batch metric (``nan`` if the
        corresponding loader yielded no batches).
    """
    train_loss_over_time = []
    test_loss_over_time = []
    train_accuracy_over_time = []
    test_accuracy_over_time = []

    # Iterate over the `epochs` argument and use the `model` argument; the
    # notebook globals `num_epochs` / `net` must not leak into this function.
    for epoch in tqdm(range(epochs)):
        model.train()
        train_loss_epoch, train_accuracy_epoch = _run_epoch(model, train, optimizer)

        # Evaluation: inference behaviour for dropout etc., no gradient tracking.
        model.eval()
        with torch.no_grad():
            test_loss_epoch, test_accuracy_epoch = _run_epoch(model, test)

        train_loss_over_time.append(train_loss_epoch)
        train_accuracy_over_time.append(train_accuracy_epoch)
        test_loss_over_time.append(test_loss_epoch)
        test_accuracy_over_time.append(test_accuracy_epoch)

        print(f"Epoch : {epoch+1} | Test Loss : {test_loss_epoch:.4f} | Test Accuracy : {test_accuracy_epoch:.4f} | Train Loss : {train_loss_epoch:.4f} | Train Accuracy : {train_accuracy_epoch:.4f}")

    return {
        "model": model,
        "train_loss_over_time": train_loss_over_time,
        # Report the test curve here, not a second copy of the train curve.
        "test_loss_over_time": test_loss_over_time,
        "train_accuracy_over_time": train_accuracy_over_time,
        "test_accuracy_over_time": test_accuracy_over_time
    }

### Model #1: single-layer unidirectional GRU

In [None]:
# Single-layer, unidirectional GRU trained on the protein examples.
net = GRUModel(
    vocab_size=len(protein_vocab.tokens),
    embed_size=64,
    hidden_size=128,
    num_layers=1,
    dropout=0,
    bidirectional=False
)

# The datasets and loaders get their own names: rebinding `train` here would
# shadow the `train` function that is called at the end of this cell.
train_dataset = SequenceDataset({
    "x": protein_data["x"][protein_train_indices],
    "y": protein_data["y"][protein_train_indices]
})

test_dataset = SequenceDataset({
    "x": protein_data["x"][protein_test_indices],
    "y": protein_data["y"][protein_test_indices]
})

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps the
# per-batch metrics comparable between epochs.
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# NOTE(review): no optimizer was defined anywhere in the notebook; Adam with
# the notebook-level `lr` is assumed here — confirm the intended choice.
optimizer = optim.Adam(net.parameters(), lr=lr)

optimized = train(model=net, train=train_loader, test=test_loader, optimizer=optimizer, epochs=30)

In [None]:
# Pull the trained model and the per-epoch curves out of the result dict so
# the plotting and saving cells can refer to them by name.
trained_model = optimized["model"]
(
    train_loss_over_time,
    train_accuracy_over_time,
    test_loss_over_time,
    test_accuracy_over_time,
) = (
    optimized[key]
    for key in (
        "train_loss_over_time",
        "train_accuracy_over_time",
        "test_loss_over_time",
        "test_accuracy_over_time",
    )
)

In [None]:
# Train vs. test loss per epoch for the protein GRU; epochs are numbered
# from 1 and the loss axis is logarithmic.
fig = go.Figure(
    data=[
        go.Scatter(
            x=np.arange(1, len(train_loss_over_time) + 1),
            y=train_loss_over_time,
            mode='lines',
            name='Train loss over time',
        ),
        go.Scatter(
            x=np.arange(1, len(test_loss_over_time) + 1),
            y=test_loss_over_time,
            mode='lines',
            name='Test loss over time',
        ),
    ]
)

fig.update_layout(
    title='Loss over time (Protein GRU)',
    xaxis_title='Epochs',
    yaxis_title='Cross Entropy Loss',
)
fig.update_yaxes(type="log")

fig.show()

In [None]:
# Train vs. test accuracy per epoch for the protein GRU; epochs are numbered
# from 1.
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=np.arange(1, len(train_accuracy_over_time) + 1),
        y=train_accuracy_over_time,
        mode='lines',
        name='Train accuracy over time'
    )
)

fig.add_trace(
    go.Scatter(
        x=np.arange(1, len(test_accuracy_over_time) + 1),
        y=test_accuracy_over_time,
        mode='lines',
        name='Test accuracy over time'
    )
)

fig.update_layout(
    title='Accuracy over time (Protein GRU)',
    xaxis_title='Epochs',
    # This figure plots accuracy, so the axis must not carry the loss label.
    yaxis_title='Accuracy'
)

fig.show()

In [None]:
# Persist the trained protein model weights, then each per-epoch curve.
torch.save(trained_model.state_dict(), "../checkpoints/pretraining/protein_gru.pth")

_scalar_outputs = [
    (train_loss_over_time, "../scalars/pretraining/protein_gru_train_loss_over_time.pth"),
    (test_loss_over_time, "../scalars/pretraining/protein_gru_test_loss_over_time.pth"),
    (train_accuracy_over_time, "../scalars/pretraining/protein_gru_train_accuracy_over_time.pth"),
    (test_accuracy_over_time, "../scalars/pretraining/protein_gru_test_accuracy_over_time.pth"),
]
for _values, _path in _scalar_outputs:
    torch.save(_values, _path)

In [None]:
# Single-layer, unidirectional GRU trained on the SMILES examples.
net = GRUModel(
    vocab_size=len(smile_vocab.tokens),
    embed_size=64,
    hidden_size=128,
    num_layers=1,
    dropout=0,
    bidirectional=False
)

# The datasets and loaders get their own names: rebinding `train` here would
# shadow the `train` function that is called at the end of this cell.
train_dataset = SequenceDataset({
    "x": smile_data["x"][smile_train_indices],
    "y": smile_data["y"][smile_train_indices]
})

test_dataset = SequenceDataset({
    "x": smile_data["x"][smile_test_indices],
    "y": smile_data["y"][smile_test_indices]
})

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps the
# per-batch metrics comparable between epochs.
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# NOTE(review): no optimizer was defined anywhere in the notebook; Adam with
# the notebook-level `lr` is assumed here — confirm the intended choice.
# A fresh optimizer is required because `net` is a new model.
optimizer = optim.Adam(net.parameters(), lr=lr)

optimized = train(model=net, train=train_loader, test=test_loader, optimizer=optimizer, epochs=30)

In [None]:
# Pull the trained model and the per-epoch curves out of the result dict so
# the plotting and saving cells can refer to them by name.
trained_model = optimized["model"]
(
    train_loss_over_time,
    train_accuracy_over_time,
    test_loss_over_time,
    test_accuracy_over_time,
) = (
    optimized[key]
    for key in (
        "train_loss_over_time",
        "train_accuracy_over_time",
        "test_loss_over_time",
        "test_accuracy_over_time",
    )
)

In [None]:
# Train vs. test loss per epoch for the drug (SMILES) GRU; epochs are
# numbered from 1 and the loss axis is logarithmic.
fig = go.Figure(
    data=[
        go.Scatter(
            x=np.arange(1, len(train_loss_over_time) + 1),
            y=train_loss_over_time,
            mode='lines',
            name='Train loss over time',
        ),
        go.Scatter(
            x=np.arange(1, len(test_loss_over_time) + 1),
            y=test_loss_over_time,
            mode='lines',
            name='Test loss over time',
        ),
    ]
)

fig.update_layout(
    title='Loss over time (Drug GRU)',
    xaxis_title='Epochs',
    yaxis_title='Cross Entropy Loss',
)
fig.update_yaxes(type="log")

fig.show()

In [None]:
# Train vs. test accuracy per epoch for the drug (SMILES) GRU; epochs are
# numbered from 1.
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=np.arange(1, len(train_accuracy_over_time) + 1),
        y=train_accuracy_over_time,
        mode='lines',
        name='Train accuracy over time'
    )
)

fig.add_trace(
    go.Scatter(
        x=np.arange(1, len(test_accuracy_over_time) + 1),
        y=test_accuracy_over_time,
        mode='lines',
        name='Test accuracy over time'
    )
)

fig.update_layout(
    title='Accuracy over time (Drug GRU)',
    xaxis_title='Epochs',
    # This figure plots accuracy, so the axis must not carry the loss label.
    yaxis_title='Accuracy'
)

fig.show()

In [None]:
# Persist the trained SMILES model weights, then each per-epoch curve.
torch.save(trained_model.state_dict(), "../checkpoints/pretraining/smile_gru.pth")

_scalar_outputs = [
    (train_loss_over_time, "../scalars/pretraining/smile_gru_train_loss_over_time.pth"),
    (test_loss_over_time, "../scalars/pretraining/smile_gru_test_loss_over_time.pth"),
    (train_accuracy_over_time, "../scalars/pretraining/smile_gru_train_accuracy_over_time.pth"),
    (test_accuracy_over_time, "../scalars/pretraining/smile_gru_test_accuracy_over_time.pth"),
]
for _values, _path in _scalar_outputs:
    torch.save(_values, _path)