In [None]:
# Finalizing

In [None]:
import wandb

# Start a Weights & Biases run for this experiment.
wandb.init(project="DNABERT", name="first-run")

# Hyperparameters for the run. Later cells read `codons`, `batch_size`,
# `embed_dim`, `num_heads`, `num_layers`, `ff_dim` and `learning_rate` from
# this dict; the remaining entries are descriptive metadata for the W&B run.
config = {
    "epochs": 30,
    "dataset_size": "10K",  # was misspelled "datasate_size", which logged a typo'd key to W&B
    "codons": 3,            # window length k used by get_codon()
    "batch_size": 32,
    "embed_dim": 128,
    "num_heads": 4,
    "num_layers": 2,
    "ff_dim": 256,
    "learning_rate": 1e-3,
    "model": "transformer",
}

# Attach the hyperparameters to the active run so they appear in the W&B UI.
wandb.config.update(config)

In [None]:
import pandas as pd

# Load the dataset from CSV (path is relative to the notebook's working directory).
data = pd.read_csv("Dataset/data.csv")
# Show (rows, columns) as the cell output.
data.shape

(1859483, 10)

In [131]:
# Work on a random subset of at most 10,000 rows. The fixed random_state makes
# the subset (and everything derived from it) reproducible across kernel
# restarts; min() avoids a ValueError when the frame has fewer rows than asked.
data = data.sample(n=min(10000, len(data)), random_state=42)

## Preparing

In [132]:
# Set `mutation_pos` to 0 for every row (creates the column, or overwrites it
# if it already exists).
data['mutation_pos'] = 0

In [133]:
# Replace the nucleotide letters in the `ref` and `alt` columns with integer
# class ids. Values missing from the mapping become NaN, so this cell must be
# run exactly once on the raw letters.
base_to_id = {'G': 0, 'C': 1, 'A': 2, 'T': 3}
for allele_column in ('ref', 'alt'):
    data[allele_column] = data[allele_column].map(base_to_id)

In [134]:
from sklearn.preprocessing import LabelEncoder

# Encode `mutation_type` as integer class ids. The fitted encoder is kept so
# predictions can be decoded with mut_type_encoder.inverse_transform(...).
mut_type_encoder = LabelEncoder()
data['mutation_type'] = mut_type_encoder.fit_transform(data['mutation_type'])

In [135]:
from sklearn.preprocessing import LabelEncoder

# Encode `chrom` as integer class ids. The fitted encoder is kept so
# predictions can be decoded with chrom_encoder.inverse_transform(...).
# NOTE(review): the number of distinct ids depends on the sampled rows —
# confirm it does not exceed the size of the model's chromosome head.
chrom_encoder = LabelEncoder()
data['chrom'] = chrom_encoder.fit_transform(data['chrom'])

In [136]:
# Preview the first rows after encoding.
data.head()

Unnamed: 0,sequence,label,mutation_pos,ref,alt,mutation_type,chrom,genomic_pos,context_left,context_right
820871,GTTATTTTGTTTTATTCTCACTGCTTCTGGGCAGAGGGAGCTGGGA...,2,0,1,3,5,7,1232766,GTTATTTTGTTTTATTCTCACTGCTTCTGGGCAGAGGGAGCTGGGA...,GCTGTCCTTCCTGATCAGCCTGTACTACAACACCATCGTGGCGTGG...
1646142,TGGTTCTTCGTGGAGCTCGGCTGGGCCAGGCAGTATTGAGCGATGT...,2,0,1,0,4,2,44126932,TGGTTCTTCGTGGAGCTCGGCTGGGCCAGGCAGTATTGAGCGATGT...,TCTGAAGTTCTTGACTGGAAGAGGTGGGTAGTACCTCCTAGTAAAC...
1280029,TCAGTATGACACCAATGAATATAGTATTAACAGGTAAGATGAGTGG...,1,0,1,3,5,10,47858894,TCAGTATGACACCAATGAATATAGTATTAACAGGTAAGATGAGTGG...,AGGGCCGTCAGGGGCGCCATGCACTCATGCCGATTGAGCTCGTCCA...
1594122,CCCTCTGAGCTTGGACCCCAGCCCCACCTGCGCTGGCCAGGGAGGC...,2,0,0,2,6,1,133269495,CCCTCTGAGCTTGGACCCCAGCCCCACCTGCGCTGGCCAGGGAGGC...,CGTGGCAGTGGCACTCCTGCTTGTGGTTGCACACCTGCGGGGACGG...
411670,AAGTTCCATCTTTTACTTCAAAAACAAAACTAACATTGCATATTAC...,4,0,1,3,5,4,157774114,AAGTTCCATCTTTTACTTCAAAAACAAAACTAACATTGCATATTAC...,GAGCCACTGTTCTTTGTACCAGAAAAGGAAGACCAGAGCCACTTCC...


In [137]:
from sklearn.preprocessing import MinMaxScaler

y_genom_scaler = MinMaxScaler() # use y_genom_scaler.inverse_transform(y_pred_scaled) once prediction is done.

In [138]:
# Pull the model input and the six targets out of the frame as NumPy arrays,
# so later code indexes them by position rather than by the DataFrame index.
x = data['sequence'].values
y_labels = data['label'].values
y_ref = data['ref'].values
y_alt = data['alt'].values
y_mut_type = data['mutation_type'].values
y_chrom = data['chrom'].values
# Min-max scale genomic_pos and flatten back to 1-D.
# NOTE(review): the scaler is fit on all rows before the train/test split, so
# the test rows influence the scaling range.
y_gemon_pos = y_genom_scaler.fit_transform(data['genomic_pos'].values.reshape(-1, 1)).reshape(-1)

## Text to Tensor Pipeline

In [None]:
def get_codon(seq, k=config['codons']):
    """Split `seq` into every contiguous window of length `k`.

    Windows advance one character at a time, so consecutive windows overlap
    (these are k-mers, not non-overlapping codons). The default for `k` is read
    from `config['codons']` once, when this function is defined.

    Args:
        seq: Sliceable sequence, e.g. a DNA string.
        k: Window length.

    Returns:
        List of `len(seq) - k + 1` slices; empty when `seq` is shorter than `k`.
    """
    window_count = len(seq) - k + 1
    kmers = []
    for start in range(window_count):
        kmers.append(seq[start:start + k])
    return kmers

In [140]:
# Token vocabulary: ids 0 and 1 are reserved for padding and unknown tokens;
# every distinct lower-cased k-mer gets the next free id in first-seen order.
vocab = {'<PAD>': 0, '<UNK>': 1}

for seq in data['sequence']:
    for codons in get_codon(seq.lower()):
        # setdefault only inserts when the k-mer is new; len(vocab) is evaluated
        # before the insertion, so it is the next unused id.
        vocab.setdefault(codons, len(vocab))

In [141]:
def get_tensor(text):
    """Convert a sequence string into a list of vocabulary ids.

    The text is split with `get_codon`; each window is lower-cased and looked
    up in `vocab`, falling back to the `<UNK>` id when it is not present.

    Args:
        text: Sequence string to tokenize.

    Returns:
        List of integer token ids, one per window.
    """
    token_ids = []
    for kmer in get_codon(text):
        token_ids.append(vocab.get(kmer.lower(), vocab['<UNK>']))
    return token_ids

## Custom Dataset Pipeline

In [142]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split

In [182]:
class CustomDataset(Dataset):
    """Map-style dataset pairing each sequence with its six prediction targets.

    Every constructor argument is stored as-is and indexed by integer position
    in `__getitem__`, so all of them must support `obj[i]` and share one length.
    """

    def __init__(
            self,
            x_frame,
            y_labels_frame,
            y_ref_frame,
            y_alt_frame,
            y_mut_type_frame,
            y_chrom_frame,
            y_gemon_pos_frame
    ):
        # Model input: raw sequence strings, tokenized lazily per item.
        self.x_frame = x_frame
        # Classification targets (converted to int64 tensors per item).
        self.y_labels_frame = y_labels_frame
        self.y_ref_frame = y_ref_frame
        self.y_alt_frame = y_alt_frame
        self.y_mut_type_frame = y_mut_type_frame
        self.y_chrom_frame = y_chrom_frame
        # Regression target (converted to a float32 tensor per item).
        self.y_gemon_pos_frame = y_gemon_pos_frame

    def __len__(self):
        """Number of samples, taken from the sequence container."""
        return len(self.x_frame)

    def __getitem__(self, index):
        """Return `(tokens, label, ref, alt, mut_type, chrom, genomic_pos)` for one sample.

        `tokens` is a 1-D int64 tensor produced by `get_tensor`; the next five
        entries are int64 scalar tensors and the last is a float32 scalar tensor.
        """
        tokens = torch.tensor(get_tensor(self.x_frame[index]), dtype=torch.long)

        class_sources = (
            self.y_labels_frame,
            self.y_ref_frame,
            self.y_alt_frame,
            self.y_mut_type_frame,
            self.y_chrom_frame,
        )
        class_targets = tuple(
            torch.tensor(source[index], dtype=torch.long) for source in class_sources
        )

        position_target = torch.tensor(self.y_gemon_pos_frame[index], dtype=torch.float32)

        return (tokens, *class_targets, position_target)

In [183]:
# Wrap the sequence array and the six target arrays in one Dataset so each
# sample is tokenized and converted to tensors on access.
dataset = CustomDataset(
    x_frame=x,
    y_labels_frame=y_labels,
    y_ref_frame=y_ref,
    y_alt_frame=y_alt,
    y_mut_type_frame=y_mut_type,
    y_chrom_frame=y_chrom,
    y_gemon_pos_frame=y_gemon_pos
)

In [184]:
# Sanity check: inspect the first sample (token ids followed by the six targets).
dataset[0]

(tensor([ 2,  3,  4,  5,  6,  6,  7,  8,  2,  6,  6,  3,  4,  5,  9, 10, 11, 12,
         13, 14, 15, 16, 17, 18,  9, 10, 15, 19, 20, 21, 22, 23, 24, 25, 26, 20,
         27, 25, 28, 17, 15, 19, 20, 27, 29, 30, 26, 27, 25, 28, 31, 32, 32, 33,
         34, 20, 20, 21, 31, 35, 13, 36, 37, 15, 38, 39, 40, 41, 42, 19, 43, 44,
         45, 32, 37, 15,  8, 44, 45, 35, 13, 40, 23, 26, 20, 21, 17, 15, 19, 20,
         21, 17, 15,  8, 46,  8, 44, 12, 13, 47, 48, 17, 15,  8, 44, 45, 37, 18,
          9, 45, 37, 15, 38, 49, 50, 12, 23, 28, 31, 37, 15,  8, 51, 52, 14, 53,
         52, 40, 54, 55, 40, 13, 36, 35, 41, 50, 56, 57, 46, 19, 21, 58, 57, 46,
         19, 20, 43, 46, 16, 17, 15,  8, 46, 19, 43, 51, 52, 36, 37, 11, 45, 37,
         11, 12, 54, 55, 14, 11, 45, 37, 18,  9, 45, 35, 23, 28, 22, 13, 36, 32,
         33, 48, 17, 15, 16, 31, 32, 37, 15, 19, 27, 25, 28, 17, 11, 45, 37, 15,
         16]),
 tensor(2),
 tensor(1),
 tensor(3),
 tensor(5),
 tensor(7),
 tensor(0.0049))

In [185]:
# 80/20 train/test split. The test size is the remainder so the two parts
# always sum to len(dataset) regardless of rounding.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# A dedicated, seeded generator makes the partition identical on every run,
# independent of whatever else has consumed the global torch RNG.
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

In [186]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Collate a list of dataset samples into batched tensors.

    Args:
        batch: Sequence of 7-tuples
            `(tokens, label, ref, alt, mut_type, chrom, genomic_pos)`.

    Returns:
        A 7-tuple: the token sequences right-padded with `vocab['<PAD>']` into
        one `(batch, longest_sequence)` tensor, five int64 target tensors and
        one float32 target tensor, each with one entry per sample.
    """
    sequences, labels, refs, alts, mut_types, chroms, positions = zip(*batch)

    def as_long(values):
        # One int64 entry per sample.
        return torch.tensor(values, dtype=torch.long)

    label_batch = as_long(labels)
    ref_batch = as_long(refs)
    alt_batch = as_long(alts)
    mut_type_batch = as_long(mut_types)
    chrom_batch = as_long(chroms)
    position_batch = torch.tensor(positions, dtype=torch.float32)

    # Sequences may differ in length; pad them to the longest one in the batch.
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=vocab['<PAD>'])

    return (
        padded_sequences,
        label_batch,
        ref_batch,
        alt_batch,
        mut_type_batch,
        chrom_batch,
        position_batch,
    )

In [None]:
# Batch both splits with the padding-aware collate function; only the training
# loader is shuffled.
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], collate_fn=collate_fn)

In [188]:
import math

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    A `(1, max_len, embed_dim)` table is precomputed and stored as the
    non-trainable buffer `pe`: even feature columns hold sines and odd columns
    hold cosines of `position * 10000^(-2i / embed_dim)`.

    Args:
        embed_dim: Size of the last dimension of the inputs. Both even and odd
            values are supported.
        max_len: Longest sequence length the module can encode.
    """

    def __init__(self, embed_dim, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1)

        div_term = torch.exp((torch.arange(0, embed_dim, 2)) * (-math.log(10000.0) / embed_dim))

        # (max_len, ceil(embed_dim / 2)) matrix of angles.
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # For odd embed_dim there is one fewer odd column than even column, so
        # trim the angle matrix; for even embed_dim the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return `x` plus the encodings of its first `x.size(1)` positions.

        Args:
            x: Tensor of shape `(batch, seq_len, embed_dim)`.

        Raises:
            ValueError: If `seq_len` exceeds the `max_len` the table was built for.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len={self.pe.size(1)} of the positional encoding"
            )
        x = x + self.pe[:, :seq_len, :].to(x.device)
        return x

class Transformer(nn.Module):
    """Multi-task Transformer encoder over token-id sequences.

    Token ids are embedded, combined with sinusoidal position encodings, passed
    through a stack of `nn.TransformerEncoderLayer`s and mean-pooled over the
    sequence dimension. Six linear heads read the pooled vector: five
    classification heads and one single-value regression head.

    Padding handling: positions whose id equals `pad_idx` are excluded from
    attention (via `src_key_padding_mask`) and from the mean-pool, so the
    prediction for a sequence does not depend on how much padding its batch
    needed. Pass `pad_idx=None` to treat every position as a real token.

    Args:
        embed_dim: Embedding / model width.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        ff_dim: Hidden width of each layer's feed-forward block.
        vocab_size: Number of rows in the embedding table.
        max_len: Longest sequence the positional encoding supports.
        num_labels: Output size of the label head.
        num_ref: Output size of the reference-allele head.
        num_alt: Output size of the alternate-allele head.
        num_mut_types: Output size of the mutation-type head.
        num_chroms: Output size of the chromosome head.
        pad_idx: Token id used for padding, or None to disable masking.
    """

    def __init__(
        self,
        embed_dim=512,
        num_heads=8,
        num_layers=6,
        ff_dim=2048,
        vocab_size=10000,
        max_len=5000,
        num_labels=5,
        num_ref=4,
        num_alt=4,
        num_mut_types=12,
        num_chroms=12,
        pad_idx=0,
    ):
        super(Transformer, self).__init__()
        self.pad_idx = pad_idx

        # padding_idx keeps the PAD embedding at zero and out of gradient updates.
        self.embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.position_encoding = PositionalEncoding(embed_dim=embed_dim, max_len=max_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True
        )
        # Nested-tensor conversion is disabled so the encoder output always has
        # the same (batch, seq_len, embed_dim) shape as its input, which the
        # masked mean-pool in forward() relies on.
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=False
        )

        self.y_labels_out = nn.Linear(embed_dim, num_labels)
        self.y_ref_out = nn.Linear(embed_dim, num_ref)
        self.y_alt_out = nn.Linear(embed_dim, num_alt)
        self.y_mut_type_out = nn.Linear(embed_dim, num_mut_types)
        self.y_chrom_out = nn.Linear(embed_dim, num_chroms)
        self.y_gemon_pos_out = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Run the encoder and all six heads.

        Args:
            x: Integer tensor of token ids with shape `(batch, seq_len)`.

        Returns:
            Tuple `(labels, ref, alt, mut_type, chrom, genomic_pos)`; the first
            five are `(batch, num_classes)` logits and the last is `(batch, 1)`.
        """
        # True where the token is padding; None disables all masking.
        pad_mask = None if self.pad_idx is None else x.eq(self.pad_idx)

        x = self.embeddings(x)
        x = self.position_encoding(x)

        if pad_mask is None:
            x = self.encoder(x)
            x = x.mean(dim=1)
        else:
            x = self.encoder(x, src_key_padding_mask=pad_mask)
            # Zero the padded positions, then divide by the number of real
            # tokens (at least 1 to avoid dividing by zero for all-PAD rows).
            x = x.masked_fill(pad_mask.unsqueeze(-1), 0.0)
            real_tokens = (~pad_mask).sum(dim=1, keepdim=True).clamp(min=1)
            x = x.sum(dim=1) / real_tokens.to(x.dtype)

        y_labels_out = self.y_labels_out(x)
        y_ref_out = self.y_ref_out(x)
        y_alt_out = self.y_alt_out(x)
        y_mut_type_out = self.y_mut_type_out(x)
        y_chrom_out = self.y_chrom_out(x)
        y_gemon_pos_out = self.y_gemon_pos_out(x)

        return y_labels_out, y_ref_out, y_alt_out, y_mut_type_out, y_chrom_out, y_gemon_pos_out

In [None]:
# Build the model from the logged hyperparameters; the embedding table is sized
# to the vocabulary built above.
# NOTE(review): max_len=200 is hard-coded — tokenized sequences longer than 200
# k-mers cannot be position-encoded; confirm this covers every sequence.
model = Transformer(
    embed_dim=config['embed_dim'],
    num_heads=config['num_heads'],
    num_layers=config['num_layers'],
    ff_dim=config['ff_dim'],
    vocab_size=len(vocab),
    max_len=200
)

In [190]:
# Count the elements of every parameter that receives gradients.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {num_params}")

Total trainable parameters: 278310


In [None]:
# One cross-entropy criterion per classification head (label, ref, alt,
# mutation type, chromosome), in the order train()/validation() expect them.
ce1 = nn.CrossEntropyLoss()
ce2 = nn.CrossEntropyLoss()
ce3 = nn.CrossEntropyLoss()
ce4 = nn.CrossEntropyLoss()
ce5 = nn.CrossEntropyLoss()

# Mean-squared error for the scaled genomic-position regression head.
mse = nn.MSELoss()

# A single Adam optimizer updates all model parameters.
optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

In [192]:
def train(model, loader, ce1, ce2, ce3, ce4, ce5, mse, optimizer):
    """Train `model` for one pass over `loader` and return the mean loss.

    Each batch must unpack to `(x, y1, y2, y3, y4, y5, y6)` and `model(x)` must
    return six outputs. Outputs 1-5 are scored with `ce1`..`ce5` against
    `y1`..`y5`; output 6 is squeezed on dim 1 and scored with `mse` against
    `y6`. The six losses are summed without weighting and back-propagated.

    All reported averages are weighted by batch size and divided by the number
    of samples actually processed, so they stay correct when the loader skips
    samples (for example with `drop_last=True` or a sampler).

    Args:
        model: Module returning six outputs for a batch of inputs.
        loader: Iterable of 7-tuples of batched tensors.
        ce1, ce2, ce3, ce4, ce5: Criteria for the five classification heads.
        mse: Criterion for the regression head.
        optimizer: Optimizer over the model's parameters.

    Returns:
        Mean summed loss per sample over the epoch, as a float.

    Raises:
        ValueError: If `loader` yields no samples.
    """
    model.train()

    running_loss = 0.0
    task_loss_sums = [0.0] * 6  # per-task, sample-weighted loss totals
    seen = 0                    # samples actually processed this epoch

    for x, y1, y2, y3, y4, y5, y6 in loader:
        batch_size = len(x)
        optimizer.zero_grad()

        out1, out2, out3, out4, out5, out6 = model(x)

        loss1 = ce1(out1, y1)
        loss2 = ce2(out2, y2)
        loss3 = ce3(out3, y3)
        loss4 = ce4(out4, y4)
        loss5 = ce5(out5, y5)
        # The regression head emits (batch, 1); drop the trailing dim to match y6.
        loss6 = mse(out6.squeeze(1), y6)

        loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * batch_size
        for task, task_loss in enumerate((loss1, loss2, loss3, loss4, loss5, loss6)):
            task_loss_sums[task] += task_loss.item() * batch_size
        seen += batch_size

    if seen == 0:
        raise ValueError("train() received a loader that yields no samples")

    # Epoch-averaged per-task losses (previously only the final batch was shown).
    avg1, avg2, avg3, avg4, avg5, avg6 = (total / seen for total in task_loss_sums)
    print(f"{avg1} : {avg2} : {avg3} : {avg4} : {avg5} : {avg6}")

    return (running_loss / seen)

In [193]:
# Sanity check: pull one collated batch to inspect the padded inputs and targets.
next(iter(train_loader))

(tensor([[37, 15, 19,  ...,  7,  8, 46],
         [31, 35, 13,  ..., 33, 48, 22],
         [ 5,  6,  3,  ..., 29, 64, 64],
         ...,
         [31, 37, 11,  ..., 16, 31, 32],
         [55, 47, 60,  ..., 65, 42,  8],
         [ 5,  6,  3,  ..., 22, 23, 28]]),
 tensor([1, 2, 2, 3, 2, 1, 1, 2, 1, 1, 2, 3, 1, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 1, 1, 0, 2, 2, 2]),
 tensor([1, 0, 0, 0, 0, 1, 0, 0, 2, 0, 1, 1, 1, 3, 1, 0, 0, 2, 1, 0, 3, 2, 1, 1,
         1, 3, 2, 0, 2, 3, 3, 1]),
 tensor([3, 2, 3, 3, 1, 3, 2, 2, 0, 1, 0, 0, 3, 2, 3, 1, 3, 3, 3, 1, 1, 0, 2, 2,
         3, 1, 0, 1, 3, 1, 0, 3]),
 tensor([ 5,  6,  8,  8,  7,  5,  6,  6,  1,  7,  4,  4,  5,  9,  5,  7,  8,  2,
          5,  7, 10,  1,  3,  3,  5, 10,  1,  7,  2, 10, 11,  5]),
 tensor([ 1,  5,  0,  6,  3, 10,  4,  6,  8,  3,  8, 11,  4,  2,  1,  7,  8, 10,
          8,  7, 11,  4,  7,  4,  7,  1,  4,  1,  6,  9,  4, 11]),
 tensor([0.4059, 0.7400, 0.2577, 0.2090, 0.0623, 0.5807, 0.6092, 0.0225, 0.2274,
         0.02

In [None]:
def validation(model, loader, ce1, ce2, ce3, ce4, ce5, mse):
    """Evaluate `model` on `loader` without updating any weights.

    Each batch must unpack to `(x, y1, y2, y3, y4, y5, y6)` and `model(x)` must
    return six outputs. Outputs 1-5 are scored with `ce1`..`ce5` and their
    argmax over dim 1 is compared with `y1`..`y5`; output 6 is squeezed on
    dim 1 and scored with `mse` against `y6`.

    All averages are weighted by batch size and divided by the number of
    samples actually evaluated.

    Args:
        model: Module returning six outputs for a batch of inputs.
        loader: Iterable of 7-tuples of batched tensors.
        ce1, ce2, ce3, ce4, ce5: Criteria for the five classification heads.
        mse: Criterion for the regression head.

    Returns:
        A 7-tuple `(mean_total_loss, accuracy1, accuracy2, accuracy3,
        accuracy4, accuracy5, mean_regression_loss)`. The last entry is the
        regression loss averaged over every evaluated sample (previously it was
        the value of the final batch only).

    Raises:
        ValueError: If `loader` yields no samples.
    """
    model.eval()

    running_loss, running_mse = 0.0, 0.0
    correct = [0, 0, 0, 0, 0]  # correct predictions per classification head
    total = 0

    with torch.no_grad():
        for x, y1, y2, y3, y4, y5, y6 in loader:
            batch_size = len(x)
            out1, out2, out3, out4, out5, out6 = model(x)

            loss1 = ce1(out1, y1)
            loss2 = ce2(out2, y2)
            loss3 = ce3(out3, y3)
            loss4 = ce4(out4, y4)
            loss5 = ce5(out5, y5)
            # The regression head emits (batch, 1); drop the trailing dim to match y6.
            loss6 = mse(out6.squeeze(1), y6)

            loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            running_loss += loss.item() * batch_size
            running_mse += loss6.item() * batch_size

            heads = ((out1, y1), (out2, y2), (out3, y3), (out4, y4), (out5, y5))
            for head, (logits, target) in enumerate(heads):
                prediction = torch.argmax(logits, dim=1)
                correct[head] += (prediction == target).sum().item()

            total += batch_size

    if total == 0:
        raise ValueError("validation() received a loader that yields no samples")

    accuracy1, accuracy2, accuracy3, accuracy4, accuracy5 = (hits / total for hits in correct)

    return (
        running_loss / total,
        accuracy1,
        accuracy2,
        accuracy3,
        accuracy4,
        accuracy5,
        running_mse / total
    )

In [None]:
# Train for the number of epochs logged to W&B (this was a separate hard-coded
# 10 that disagreed with config["epochs"]); early stopping below can end sooner.
epochs = config['epochs']
patience = 3                    # epochs without improvement tolerated before stopping
best_val_loss = float('inf')
counter = 0                     # consecutive epochs without validation-loss improvement
early_stop = False

for epoch in range(epochs):
    train_loss = train(
        model,
        train_loader,
        ce1,
        ce2,
        ce3,
        ce4,
        ce5,
        mse,
        optimizer
    )

    # validation() returns seven values: the total loss, five accuracies and
    # the genomic-position regression loss. Unpacking only six raised
    # "ValueError: too many values to unpack" on the first epoch.
    val_loss, acc1, acc2, acc3, acc4, acc5, val_pos_mse = validation(
        model,
        test_loader,
        ce1,
        ce2,
        ce3,
        ce4,
        ce5,
        mse
    )

    print(f"Epoch ({epoch+1}/{epochs}): Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, Acc1 = {acc1:.4f}, Acc2 = {acc2:.4f}, Acc3 = {acc3:.4f}, Acc4 = {acc4:.4f}, Acc5 = {acc5:.4f}, Pos MSE = {val_pos_mse:.6f}")

    wandb.log({"epoch": epoch,
               "Training Loss": train_loss,
               "Validation Loss": val_loss,
               "Label Accuracy": acc1,
               "Reference Accuracy": acc2,
               "Alternate Accuracy": acc3,
               "Mutation Type Accuracy": acc4,
               "Chrom Accuracy": acc5,
               "Genomic Position MSE": val_pos_mse,
               })

    # Early stopping on the validation loss: reset the counter on improvement,
    # otherwise stop once `patience` consecutive epochs fail to improve.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        print("No-Early Stopping Triggered!")
    else:
        counter += 1
        print(f"No improvement in val loss. Counter = {counter}/{patience}")
        if counter >= patience:
            print("Early stopping triggered!")
            early_stop = True
            break