# Imports

In [None]:
import os
import yaml
import json

from tqdm import tqdm

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torch.nn.functional import cosine_similarity, normalize

import numpy as np

from trainer import Trainer


# Loss

In [None]:
class ContrastiveLoss(nn.Module):
    """Similarity-based contrastive loss for paired embeddings.

    For every pair the absolute cosine similarity ``s`` of the two
    L2-normalized embeddings is computed, the per-pair loss is
    ``0.5 * (label * (1 - s) + (1 - label) * s)``, and the batch mean is
    returned. Label-1 pairs are pushed towards ``|cos| = 1`` and label-0
    pairs towards ``|cos| = 0``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, labels):
        """Return the mean contrastive loss over the batch.

        Args:
            pred: ``(first, second)`` pair of embedding tensors, presumably
                ``(N, E)`` each (both are normalized along dim 1).
            labels: tensor of per-pair labels (1 = match, 0 = no match).
        """
        first, second = pred
        unit_first = normalize(first, p=2, dim=1)
        unit_second = normalize(second, p=2, dim=1)

        # Absolute value: anti-parallel embeddings count as "similar" too.
        similarity = cosine_similarity(unit_first, unit_second).abs()

        per_pair = 0.5 * (labels * (1 - similarity) + (1 - labels) * similarity)
        return per_pair.mean()

# Models

In [None]:
class GRUModel(nn.Module):
    """Token-level GRU model: embedding -> GRU -> linear vocabulary head.

    Only the GRU output at the final time step is projected to vocabulary
    logits. The ``embedding`` submodule is later pulled out and reused as a
    pretrained token embedding, so submodule names must stay stable for
    checkpoint loading.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, dropout=0, bidirectional=False):
        super().__init__()

        # Hyper-parameters are kept as attributes for introspection.
        self.vocab_size, self.embed_size, self.hidden_size = vocab_size, embed_size, hidden_size
        self.num_layers, self.dropout, self.bidirectional = num_layers, dropout, bidirectional

        directions = 2 if bidirectional else 1
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(
            embed_size,
            hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.out = nn.Linear(directions * hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        """Return ``(logits_at_last_step, hidden)`` for a batch-first id tensor ``x``."""
        embedded = self.embedding(x)
        outputs, hidden = self.gru(embedded, hidden)
        last_step = outputs[:, -1]
        return self.out(last_step), hidden
    
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to a batch-first embedding tensor.

    Even feature columns hold ``cos(pos / 10000**(2i / d_model))`` and odd
    columns the matching ``sin`` term. This is the reverse of the usual
    sin/cos order, kept so trained models see the same encoding.

    The table is registered as a *non-persistent* buffer. ``module.to(device)``
    therefore moves it with the model, while ``state_dict()`` is unchanged and
    existing checkpoints still load.
    """

    def __init__(self, d_model, max_len):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        positions = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        # One frequency per (cos, sin) column pair: ceil(d_model / 2) of them.
        frequencies = 10000 ** (torch.arange(0, d_model, 2) / d_model)
        angles = positions / frequencies  # (max_len, ceil(d_model / 2))

        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.cos(angles)
        # With an odd d_model there is one fewer odd column than frequencies.
        encoding[:, 1::2] = torch.sin(angles[:, : d_model // 2])
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encoding for its first ``x.shape[1]`` positions.

        ``x`` is assumed to be ``(N, S, d_model)``.

        Raises:
            ValueError: if ``S`` exceeds ``max_len``.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len={self.max_len}"
            )
        # .to() is a no-op once the module has been moved to x's device.
        return x + self.encoding[:seq_len].to(x.device)


class CrossLayer(nn.Module):
    """One round of bidirectional cross-attention followed by self-attention.

    Each stream queries the *other* stream with single-head cross-attention.
    The result is then refined by a single-head ``TransformerEncoderLayer``.
    The cross-attention output takes the query's sequence length, so the
    returned tensors have their sequence lengths swapped relative to the
    inputs.
    """

    def __init__(self, embed_size, hidden_size, dropout):
        super().__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.dropout = dropout

        def encoder_layer():
            return nn.TransformerEncoderLayer(
                d_model=embed_size,
                nhead=1,
                dim_feedforward=hidden_size,
                dropout=dropout,
                batch_first=True,
            )

        def cross_attention():
            # NOTE(review): ``dropout`` is not forwarded here, so attention
            # weights use MultiheadAttention's default dropout — confirm intent.
            return nn.MultiheadAttention(embed_dim=embed_size, num_heads=1, batch_first=True)

        # Creation order is kept so initialization under a fixed seed is reproducible.
        self.self_attention_encoder_1 = encoder_layer()
        self.self_attention_encoder_2 = encoder_layer()
        self.cross_attention1 = cross_attention()
        self.cross_attention2 = cross_attention()

    def forward(self, x1, x2):
        """Cross-attend the two streams, then refine each with self-attention.

        Args:
            x1, x2: batch-first tensors, presumably ``(N, S1, E)`` and
                ``(N, S2, E)``. The sequence lengths may differ.

        Returns:
            ``(alignment_1, alignment_2)`` shaped ``(N, S2, E)`` and ``(N, S1, E)``.
        """
        # x2 attends over x1; x1 attends over x2.
        attended_by_x2, _ = self.cross_attention1(query=x2, key=x1, value=x1)
        attended_by_x1, _ = self.cross_attention2(query=x1, key=x2, value=x2)

        return (
            self.self_attention_encoder_1(attended_by_x2),
            self.self_attention_encoder_2(attended_by_x1),
        )


class CrossNet(nn.Module):
    """Drug–target interaction model built from stacked ``CrossLayer`` blocks.

    The model works in five steps:

    1. Drug and target token ids are embedded with externally supplied
       embedding modules (see ``set_drug_protein_embeddings``).
    2. Both streams are positionally encoded.
    3. They pass through ``num_layers`` cross-attention layers.
    4. Each stream is mean-pooled over its sequence axis.
    5. In ``"contrastive"`` mode the two pooled ``(N, embed_size)`` vectors
       are returned. Otherwise they are concatenated and mapped to two logits.
    """

    def __init__(self, hidden_size, num_layers, dropout, mode="classification", embed_size=64, max_len=20000):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.embed_size = embed_size
        self.is_contrastive = mode == "contrastive"

        # Set later by set_drug_protein_embeddings().
        self.drug_embedding = None
        self.protein_embedding = None

        self.cross_layers = nn.ModuleList([
            CrossLayer(embed_size=embed_size, hidden_size=hidden_size, dropout=dropout)
            for _ in range(num_layers)
        ])

        if not self.is_contrastive:
            self.out = nn.Linear(embed_size * 2, 2)

        self.positional_encoding = PositionalEncoding(d_model=embed_size, max_len=max_len)

    def set_drug_protein_embeddings(self, drug, protein):
        """Attach pretrained embedding modules and freeze their parameters.

        The modules become registered submodules, so they are moved by
        ``.to()`` and appear in ``state_dict()``. Their parameters stop
        receiving gradients.
        """
        self.drug_embedding = drug
        self.protein_embedding = protein

        # ``module.requires_grad = False`` would only set a plain attribute on the
        # module; requires_grad_() flips the flag on every parameter.
        self.drug_embedding.requires_grad_(False)
        self.protein_embedding.requires_grad_(False)

    def forward(self, x):
        """Run the model on ``x = (drug_ids, protein_ids)``.

        Raises:
            RuntimeError: if the embeddings have not been set yet.
        """
        if self.drug_embedding is None or self.protein_embedding is None:
            raise RuntimeError("call set_drug_protein_embeddings() before forward()")

        drug_ids, protein_ids = x

        x1 = self.positional_encoding(self.drug_embedding(drug_ids))  # (N, S1, E)
        x2 = self.positional_encoding(self.protein_embedding(protein_ids))  # (N, S2, E)

        for layer in self.cross_layers:
            x1, x2 = layer(x1, x2)

        # Mean-pool over the sequence axis. NOTE: padding positions are included.
        x1 = x1.mean(1)  # (N, E)
        x2 = x2.mean(1)  # (N, E)

        if self.is_contrastive:
            return x1, x2
        return self.out(torch.cat([x1, x2], 1))  # (N, 2)
        
def _config_flag(value):
    """Return ``value`` as a bool, parsing strings such as "false"/"0"/"no" as False.

    ``bool("False")`` is True, so string-valued flags (e.g. quoted YAML values)
    need explicit parsing.
    """
    if isinstance(value, str):
        return value.strip().lower() in {"1", "true", "yes", "y", "on"}
    return bool(value)


def build_model(name, config=None):
    """Build a model from a config dict.

    Args:
        name: model name. Any name containing "cross" builds a ``CrossNet``.
            The config dict may instead be passed as the only argument, in
            which case the name is read from ``config["NAME"]``.
        config: dict with "HIDDEN_SIZE", "NUM_LAYERS", "DROPOUT" and
            "CONTRASTIVE" entries.

    Raises:
        TypeError: if no config dict is supplied.
    """
    if config is None:
        # Support build_model(config): take the model name from the config itself.
        if not isinstance(name, dict):
            raise TypeError("build_model() requires a config dict")
        config, name = name, name["NAME"]

    hidden_size = int(config["HIDDEN_SIZE"])
    num_layers = int(config["NUM_LAYERS"])
    dropout = float(config["DROPOUT"])
    mode = "contrastive" if _config_flag(config["CONTRASTIVE"]) else "classification"

    if "cross" in name:
        return CrossNet(hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, mode=mode)

    # NOTE(review): ``Transformer`` is not defined in this notebook; this branch
    # raises NameError unless it is provided elsewhere.
    return Transformer(hidden_size=hidden_size, num_layers=num_layers, dropout=dropout)

# Vocab


In [None]:
class Vocab:
    """Bidirectional token <-> index mapping with PAD/SOS/EOS special tokens.

    The specials are appended after ``tokens``. Assuming the base tokens are
    unique, ``k`` base tokens give PAD = k, SOS = k + 1 and EOS = k + 2.
    """

    def __init__(self, tokens):
        self.tokens = tokens + ["PAD", "SOS", "EOS"]
        self.token_ix = {}
        self.ix_token = {}
        for index, token in enumerate(self.tokens):
            self.token_ix[token] = index
            self.ix_token[index] = token

    def encode(self, seq, max_len=None):
        """Return ``[SOS] + ids(seq) + [EOS]``, right-padded with PAD up to ``max_len``.

        A falsy ``max_len`` disables padding. Sequences already at least
        ``max_len`` long are returned unchanged, not truncated. Unknown tokens
        raise KeyError.
        """
        lookup = self.token_ix
        ids = [lookup["SOS"], *(lookup[token] for token in seq), lookup["EOS"]]
        missing = (max_len or 0) - len(ids)
        if missing > 0:
            ids.extend([lookup["PAD"]] * missing)
        return ids

    def decode(self, seq):
        """Map each index in ``seq`` back to its token (KeyError on unknown ids)."""
        return list(map(self.ix_token.__getitem__, seq))

    def __len__(self):
        return len(self.tokens)
    

class ProteinVocab(Vocab):
    """Vocabulary over single-letter amino-acid codes.

    The 20 standard amino acids plus ``U`` (selenocysteine) give 21 base
    tokens. The PAD token therefore gets index 21, which is the value
    ``dataloader_collate_fn`` pads protein sequences with.
    """

    def __init__(self):
        # 21 letters: the 20 standard amino acids plus U (selenocysteine).
        tokens = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Y']
        super().__init__(tokens)


class SMILEVocab(Vocab):
    """Character-level vocabulary for SMILES strings.

    There are 51 base characters, so PAD gets index 51. That is the value
    ``dataloader_collate_fn`` pads drug sequences with.
    """

    def __init__(self):
        smiles_chars = [
            # punctuation, bond/charge symbols and ring-closure digits
            '#', '(', ')', '+', '-', '.', '/', '1', '2', '3', '4', '5', '6', '7', '8', '=', '@',
            # upper-case letters
            'A', 'B', 'C', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'S', 'T', 'V', 'W', 'Z',
            # brackets, backslash and lower-case letters
            '[', '\\', ']', 'a', 'b', 'd', 'e', 'g', 'i', 'l', 'n', 'o', 'r', 's', 't', 'u',
        ]
        super().__init__(smiles_chars)


class DrugBankDB:
    """In-memory lookup from drug ids (DrugBank, per the file name) to SMILES strings.

    Args:
        path: JSON file containing an ``id -> SMILES`` object.
    """

    def __init__(self, path="../data/databankid_to_smile.json"):
        # ``with`` closes the file handle instead of leaking it.
        with open(path, "r", encoding="utf-8") as handle:
            self.id_to_smile = json.load(handle)

    def get_smile_from_id(self, id):
        """Return the SMILES string for ``id``; raises KeyError if it is unknown."""
        return self.id_to_smile[id]


class UniProtDB:
    """In-memory lookup from protein ids (UniProt, per the file name) to amino-acid sequences.

    Args:
        path: JSON file containing an ``id -> sequence`` object.
    """

    def __init__(self, path="../data/uniprotid_to_seq.json"):
        # ``with`` closes the file handle instead of leaking it.
        with open(path, "r", encoding="utf-8") as handle:
            self.id_to_amino_seq = json.load(handle)

    def get_amino_seq_from_id(self, id):
        """Return the amino-acid sequence for ``id``; raises KeyError if it is unknown."""
        return self.id_to_amino_seq[id]

# Data

In [None]:
from vocab import DrugBankDB, UniProtDB, ProteinVocab, SMILEVocab


class DTIDataset(Dataset):
    """Drug–target interaction samples as ``(drug_ids, target_ids, label)``.

    Each row of the ``x`` JSON file has a ``"drug"`` and a ``"target"`` id. Each id
    is resolved to a sequence (SMILES / amino acids) and tokenized with the
    given vocab into a 1-D long tensor on ``device``. Labels are taken
    unchanged from the ``y`` JSON file.

    Args:
        train: load the train split when truthy, otherwise the test split.
        smile_vocab, protein_vocab: vocabs providing ``encode``.
        smile_embedding, amino_embedding: stored for interface compatibility;
            not used when building samples.
        device: device the returned token tensors are created on.
    """

    def __init__(self, train, smile_vocab, protein_vocab, smile_embedding, amino_embedding, device="cuda"):
        self.device = device
        split = "train" if train else "test"
        # ``with`` closes the file handles instead of leaking them.
        with open(f"../data/dti_{split}_x.json", "r") as handle:
            self.x = json.load(handle)
        with open(f"../data/dti_{split}_y.json", "r") as handle:
            self.y = json.load(handle)

        self.train = train
        self.drugbankdb = DrugBankDB()
        self.uniprotdb = UniProtDB()
        self.smile_vocab = smile_vocab
        self.protein_vocab = protein_vocab
        self.smile_embedding = smile_embedding
        self.amino_embedding = amino_embedding

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        row = self.x[idx]

        smiles = self.drugbankdb.get_smile_from_id(row["drug"])
        # Build the long tensor directly (no float round-trip, single device move).
        drug = torch.tensor(self.smile_vocab.encode(smiles), dtype=torch.long, device=self.device)

        sequence = self.uniprotdb.get_amino_seq_from_id(row["target"])
        target = torch.tensor(self.protein_vocab.encode(sequence), dtype=torch.long, device=self.device)

        return drug, target, self.y[idx]

# Utils

In [None]:
def parse_yaml(f):
    """Load the YAML file at path ``f`` with ``yaml.safe_load`` and return the result."""
    # ``with`` closes the file handle instead of leaking it.
    with open(f, "r") as handle:
        return yaml.safe_load(handle)

def get_predictions_from_similarity(repr1, repr2, thresh=.5):
    """Predict pair labels from embedding similarity.

    Returns a long tensor that is 1 where ``|cos(repr1, repr2)| > thresh``
    (row-wise, after L2 normalization along dim 1) and 0 elsewhere.
    """
    unit1 = normalize(repr1, p=2, dim=1)
    unit2 = normalize(repr2, p=2, dim=1)
    is_match = cosine_similarity(unit1, unit2).abs() > thresh
    return is_match.long()

def get_predictions_from_prob(prob, thresh=.5):
    """Binarize probabilities: 1 where ``prob > thresh`` (strictly), else 0, as a long tensor."""
    above = prob > thresh
    return above.long()


def _pad_batch(seqs, pad_value):
    """Right-pad 1-D id sequences into a ``(len(seqs), longest)`` long CPU tensor."""
    longest = max(len(seq) for seq in seqs)
    batch = torch.full((len(seqs), longest), pad_value, dtype=torch.long)
    for row, seq in enumerate(seqs):
        batch[row, : len(seq)] = seq
    return batch


def dataloader_collate_fn(data, pad_x1=51, pad_x2=21):
    """Pad ``(x1, x2)`` in a batch of samples to a common length per field.

    Args:
        data: sequence of ``(x1, x2, y)`` samples. ``x1`` and ``x2`` are 1-D
            token-id tensors of varying length; ``y`` is a number.
        pad_x1: padding id for ``x1``. 51 is PAD in the 51-symbol SMILES vocab.
        pad_x2: padding id for ``x2``. 21 is PAD in the 21-letter amino-acid vocab.

    Returns:
        ``(x1_padded, x2_padded, y)``: two ``(B, max_len)`` long tensors on the
        CPU and a float tensor of labels.
    """
    x1, x2, y = zip(*data)
    labels = torch.Tensor(list(y))
    return _pad_batch(x1, pad_x1), _pad_batch(x2, pad_x2), labels

def accuracy_from_contrastive_model(p, y):
    """Fraction of pairs whose similarity-based prediction equals the label.

    Args:
        p: ``(repr1, repr2)`` embedding pair produced by a contrastive model.
        y: tensor of 0/1 pair labels.
    """
    repr1, repr2 = p
    predicted = get_predictions_from_similarity(repr1, repr2)
    correct = predicted == y
    return correct.float().mean()


def accuracy_from_classification_model(p, y):
    """Fraction of rows whose arg-max index along the last dim equals the label in ``y``."""
    hits = p.argmax(dim=-1) == y
    return hits.float().mean()

# Trainer

In [None]:
import logging

class RunningAverager:
    """Exponentially smoothed tracker for scalar training statistics.

    Each update sets ``value = smooth * new + (1 - smooth) * old``. Every
    tracked value starts at 0.

    Args:
        track: names to track up front (``None`` means none).
        smooth: weight given to each new value.
    """

    def __init__(self, track=None, smooth=.6):
        # ``None`` default instead of a shared mutable ``[]``.
        self.track = {name: 0 for name in (track or [])}
        self.smooth = smooth

    def add_new(self, values):
        """Fold each ``name -> value`` of ``values`` into its smoothed value.

        A name not declared in ``track`` is logged with a warning and then
        tracked from 0 like the others.
        """
        for key, new_value in values.items():
            if key not in self.track:
                logging.warning(f"{key} not in tracked values.")
                self.track[key] = 0

            old_value = self.track[key]
            self.track[key] = self.smooth * new_value + (1 - self.smooth) * old_value

    def get_tracked(self):
        """Return the live dict of smoothed values (not a copy)."""
        return self.track


class Trainer:
    """Epoch-based training loop with smoothed metrics and best-test-loss checkpointing.

    Args:
        model: module called as ``model((x1, x2))``.
        train, test: iterables (e.g. DataLoaders) yielding ``(x1, x2, y)`` batches.
        epochs: number of epochs to run.
        optimizer: optimizer over the model's parameters.
        lossfn: callable ``lossfn(prediction, y)`` returning a scalar tensor.
        metrics: dict ``name -> fn(prediction, y)``. The ``"Accuracy"`` entry,
            if present, is stored in checkpoints.
        config_file: config dict; ``config_file["RUN_NAME"]`` names the checkpoint directory.
        smooth: smoothing factor for ``RunningAverager``.
        device: device the model and every batch are moved to.
        writer: optional TensorBoard ``SummaryWriter``. When given, per-epoch
            values are logged to it.
    """

    def __init__(
        self,
        model,
        train,
        test,
        epochs,
        optimizer,
        lossfn,
        metrics,
        config_file,
        smooth=.6,
        device="cuda",
        writer=None,
    ):
        self.model = model
        self.model.to(device)
        self.optimizer = optimizer
        self.lossfn = lossfn
        self.metrics = metrics
        self.train = train
        self.test = test
        self.epochs = epochs
        self.writer = writer
        self.smooth = smooth
        self.device = device
        self.config_file = config_file

        self.lowest_loss = {
            "train": float("inf"),
            "test": float("inf")
        }

    def _new_averager(self):
        """Return a RunningAverager tracking the loss and every metric."""
        return RunningAverager(track=["loss"] + list(self.metrics.keys()), smooth=self.smooth)

    @staticmethod
    def _as_number(value):
        """Convert single-element tensors to Python numbers; pass anything else through."""
        if torch.is_tensor(value) and value.numel() == 1:
            return value.item()
        return value

    def _record(self, averager, loss, p, y):
        """Fold one batch's loss and metric values into ``averager``."""
        averager.add_new({"loss": loss.item()})
        # Metrics are bookkeeping only; don't build autograd graphs for them.
        with torch.no_grad():
            for metric_name, metric_fn in self.metrics.items():
                averager.add_new({metric_name: self._as_number(metric_fn(p, y))})

    def _to_device(self, x1, x2, y):
        return x1.to(self.device), x2.to(self.device), y.to(self.device)

    def train_one_epoch(self, epoch):
        """Run one optimization pass over ``self.train``; return the smoothed values."""
        self.model.train()
        averager = self._new_averager()

        for x1, x2, y in self.train:
            x1, x2, y = self._to_device(x1, x2, y)

            self.optimizer.zero_grad()
            p = self.model((x1, x2))
            loss = self.lossfn(p, y)
            loss.backward()
            self.optimizer.step()

            self._record(averager, loss, p, y)

        return averager.get_tracked()

    @torch.no_grad()
    def test_one_epoch(self, epoch):
        """Evaluate on ``self.test`` without gradients; return the smoothed values."""
        self.model.eval()
        averager = self._new_averager()

        for x1, x2, y in self.test:
            x1, x2, y = self._to_device(x1, x2, y)
            p = self.model((x1, x2))
            loss = self.lossfn(p, y)
            self._record(averager, loss, p, y)

        return averager.get_tracked()

    @staticmethod
    def _format(split, averaged):
        return f"[{split}] " + " ".join(f"{name}: {value}" for name, value in averaged.items())

    def _log_to_writer(self, split, averaged, epoch):
        if self.writer is None:
            return
        for name, value in averaged.items():
            self.writer.add_scalar(f"{split}/{name}", float(value), epoch)

    def run(self):
        """Train for ``self.epochs`` epochs, checkpointing whenever the test loss improves."""
        for epoch in tqdm(range(self.epochs)):
            train_averaged = self.train_one_epoch(epoch)
            test_averaged = self.test_one_epoch(epoch)

            if test_averaged["loss"] < self.lowest_loss["test"]:
                self.lowest_loss["test"] = test_averaged["loss"]
                # .get(): the metrics dict may not define "Accuracy".
                self.save_checkpoint(accuracy=test_averaged.get("Accuracy"), epoch=epoch)

            if train_averaged["loss"] < self.lowest_loss["train"]:
                self.lowest_loss["train"] = train_averaged["loss"]

            print(self._format("train", train_averaged))
            print(self._format("test", test_averaged))
            self._log_to_writer("train", train_averaged, epoch)
            self._log_to_writer("test", test_averaged, epoch)

    def save_checkpoint(self, accuracy=None, epoch=None):
        """Save model and optimizer state to ``../checkpoints/classification/<RUN_NAME>/model.pt``.

        ``epoch`` is stored 1-based (``None`` if not given). The directory is
        created if it does not exist.
        """
        directory = f"../checkpoints/classification/{self.config_file['RUN_NAME']}"
        os.makedirs(directory, exist_ok=True)
        torch.save({
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "loss": self.lowest_loss["test"],
            "accuracy": accuracy,
            "epoch": None if epoch is None else epoch + 1,
        }, f"{directory}/model.pt")


# Training

In [None]:
config = {
    "SEED": 42,
    "NAME": "crosstransformer",
    "RUN_NAME": "crosstransformer-contrastive-baseline",
    "LR": 1e-4,
    "EPOCHS": 3000,
    "BATCH_SIZE": 16,
    "DROPOUT": 0,
    "NUM_LAYERS": 2,
    "HIDDEN_SIZE": 256,
    "CONTRASTIVE": True,
}

# Use the GPU when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# exist_ok lets this cell be re-run without crashing.
os.makedirs(f"../checkpoints/classification/{config['RUN_NAME']}", exist_ok=True)

np.random.seed(int(config["SEED"]))
torch.manual_seed(int(config["SEED"]))

amino_vocab = ProteinVocab()
smile_vocab = SMILEVocab()

pretrained_smile_embeddings = GRUModel(
    vocab_size=len(smile_vocab),
    embed_size=64,
    hidden_size=128,
    num_layers=1,
    dropout=0,
    bidirectional=False
).to(device)

pretrained_amino_embeddings = GRUModel(
    vocab_size=len(amino_vocab),
    embed_size=64,
    hidden_size=128,
    num_layers=1,
    dropout=0,
    bidirectional=False
).to(device)

# map_location lets GPU-saved checkpoints load on a CPU-only machine.
pretrained_smile_embeddings.load_state_dict(
    torch.load("../checkpoints/pretraining/smile_gru.pth", map_location=device)
)
pretrained_amino_embeddings.load_state_dict(
    torch.load("../checkpoints/pretraining/protein_gru.pth", map_location=device)
)

# Only the token-embedding tables of the pretrained GRUs are reused.
pretrained_smile_embeddings = pretrained_smile_embeddings.embedding
pretrained_amino_embeddings = pretrained_amino_embeddings.embedding

net = build_model(config["NAME"], config)

net.set_drug_protein_embeddings(
    pretrained_smile_embeddings.to(device),
    pretrained_amino_embeddings.to(device)
)

net.to(device)

# Samples stay on the CPU: the collate function pads into CPU tensors and the
# Trainer moves each batch to ``device``, so per-sample GPU copies are wasted.
train = DTIDataset(
    train=True,
    smile_vocab=smile_vocab,
    protein_vocab=amino_vocab,
    smile_embedding=pretrained_smile_embeddings,
    amino_embedding=pretrained_amino_embeddings,
    device="cpu"
)

test = DTIDataset(
    train=False,
    smile_vocab=smile_vocab,
    protein_vocab=amino_vocab,
    smile_embedding=pretrained_smile_embeddings,
    amino_embedding=pretrained_amino_embeddings,
    device="cpu"
)

batch_size = int(config["BATCH_SIZE"])
train = DataLoader(train, batch_size=batch_size, shuffle=True, collate_fn=dataloader_collate_fn)
# Evaluation does not need shuffling.
test = DataLoader(test, batch_size=batch_size, shuffle=False, collate_fn=dataloader_collate_fn)

lossfn = ContrastiveLoss()

optimizer = optim.Adam(net.parameters(), lr=float(config["LR"]))

trainer = Trainer(
    model=net,
    train=train,
    test=test,
    epochs=int(config["EPOCHS"]),
    optimizer=optimizer,
    lossfn=lossfn,
    metrics={
        "Accuracy": accuracy_from_contrastive_model
    },
    smooth=.6,
    device=device,
    config_file=config
)

trainer.run()