In [None]:
import pandas as pd
import os
import numpy as np
from pathlib import Path
import math
import torch
from copy import deepcopy
from itertools import chain 
from torch.utils.data import Dataset
from torchtext.vocab import vocab as Vocab

import warnings
warnings.filterwarnings("ignore")


Path configuration


In [1]:
# Local run ("Lokalt" = "local"): absolute Windows path to the data directory.
# NOTE(review): `data_dir` is assigned again in the next cell; whichever of the
# two cells is executed last determines the value used by the rest of the notebook.
data_dir = 'c:\\Users\\erika\\Desktop\\Exjobb\\data'

In [None]:
# "saga" run (presumably a remote server/cluster — TODO confirm): Unix path to the raw data.
# NOTE(review): this rebinds the `data_dir` assigned in the previous cell; run only
# the cell that matches the machine the notebook is executing on.
data_dir = "/home/aeerik/data/raw/"


Data import

In [None]:
from data_preprocessing import data_loader
# Arguments passed positionally to the project helper `data_loader`.
# NOTE(review): the meanings below are inferred from the names only — confirm
# against data_preprocessing.py.
include_pheno = False   # presumably: whether phenotype data is included
threshold_year = 1970   # presumably: a year cutoff applied to the records
path = data_dir         # data directory assigned in an earlier cell

NCBI = data_loader(include_pheno,threshold_year,path)
NCBI.head()  # last expression of the cell, so the notebook displays its result

Vocabulary

In [None]:
from build_vocabulary import make_vocabulary
# Build the vocabulary from the loaded NCBI data with the project helper and
# print how many entries it contains.
vocabulary = make_vocabulary(NCBI)
print(len(vocabulary))

In [2]:
from create_dataset import NCBIDataset
from build_vocabulary import make_vocabulary
from data_preprocessing import data_loader

# Self-contained rebuild of data, vocabulary and dataset: this repeats the
# loading and vocabulary steps of the two preceding cells with the same settings.
include_pheno = False
threshold_year = 1970
path = data_dir

NCBI = data_loader(include_pheno,threshold_year,path)

# Settings handed to NCBIDataset.
# NOTE(review): presumably the sequence length and the token-masking
# probability — confirm in create_dataset.py.
max_length = 20
mask_prob = 0.30
vocabulary = make_vocabulary(NCBI)

test_set = NCBIDataset(NCBI, vocabulary, max_length, mask_prob)
# BertTrainer (later in this notebook) also calls prepare_dataset() before it
# wraps a dataset in a DataLoader; what exactly it prepares is not visible here.
test_set.prepare_dataset()

Embedding 

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class JointEmbedding(nn.Module):
    """Token embedding followed by layer normalisation and dropout.

    Args:
        embedding_dim: Size of each embedding vector.
        vocab_size: Number of entries in the embedding table.
        max_length: Stored on the instance; not used by this module's own code.
        drop_prob: Dropout probability applied to the normalised embeddings.
    """

    def __init__(self, embedding_dim, vocab_size, max_length, drop_prob):
        super().__init__()

        # Keep the constructor arguments available as plain attributes.
        self.max_length, self.drop_prob = max_length, drop_prob
        self.vocab_size, self.embedding_dim = vocab_size, embedding_dim

        # Creation order matters: it fixes the order in which parameters are
        # initialised and in which they appear in the state dict.
        self.token_emb = nn.Embedding(vocab_size, embedding_dim)
        self.dropout = nn.Dropout(drop_prob)
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, input_tensor):
        """Look up the embeddings of ``input_tensor``, normalise them, apply dropout."""
        embedded = self.token_emb(input_tensor)
        return self.dropout(self.norm(embedded))


In [None]:
from embedding import JointEmbedding
# Smoke test: build an embedding layer and run it on element 0 of one dataset item.
embedding_dim = 32
voca_size = len(vocabulary)
max_length = 20
drop_prob = 0.2

emb_test = JointEmbedding(embedding_dim, voca_size, max_length, drop_prob)
# Call the module itself instead of `.forward()`: nn.Module.__call__ is the
# supported entry point and runs any registered hooks around the forward pass.
emb_test(test_set[1][0])

In [None]:
import torch

from torch import nn
import torch.nn.functional as f

class AttentionHead(nn.Module):
    """Single scaled dot-product self-attention head.

    Args:
        dim_inp: Size of the last dimension of the input tensor.
        dim_out: Size of the query/key/value projections, and therefore of
            the last dimension of the output.
        drop_prob: Dropout probability applied to the attention weights.
    """

    def __init__(self, dim_inp, dim_out, drop_prob):
        super(AttentionHead, self).__init__()

        self.dim_inp = dim_inp
        self.drop_prob = drop_prob
        self.dropout = nn.Dropout(drop_prob)
        self.q = nn.Linear(dim_inp, dim_out)
        self.k = nn.Linear(dim_inp, dim_out)
        self.v = nn.Linear(dim_inp, dim_out)

    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor = None):
        """Compute self-attention over ``input_tensor``.

        Args:
            input_tensor: Tensor whose last dimension is ``dim_inp``; the
                second-to-last dimension is the sequence axis. Works for both
                unbatched ``(seq, dim_inp)`` and batched ``(batch, seq, dim_inp)``
                input.
            attention_mask: Optional boolean tensor broadcastable to the
                ``(..., seq, seq)`` score matrix. Positions that are ``True``
                are filled with ``-1e9`` before the softmax. ``None`` (the
                default) disables masking.

        Returns:
            Tensor with the same leading dimensions as ``input_tensor`` and a
            last dimension of ``dim_out``.
        """
        query, key, value = self.q(input_tensor), self.k(input_tensor), self.v(input_tensor)

        # Scale by sqrt(d_k), i.e. the size of the LAST dimension. Using
        # size(1) only gives d_k for unbatched 2-D input; for batched 3-D
        # input it is the sequence length.
        scale = query.size(-1) ** 0.5
        scores = torch.matmul(query, key.transpose(-1, -2)) / scale

        # The signature allows attention_mask=None, so only mask when one is given.
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask, -1e9)
        attn = f.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        context = torch.matmul(attn, value)

        return context


class MultiHeadAttention(nn.Module):
    """Several ``AttentionHead`` modules run side by side and merged.

    The head outputs are concatenated along the last dimension, projected
    back to ``dim_inp`` by a linear layer and layer-normalised.

    Args:
        num_heads: Number of attention heads.
        dim_inp: Size of the last dimension of the input (and of the output).
        dim_out: Output size of each individual head.
        drop_prob: Dropout probability forwarded to every head.
    """

    def __init__(self, num_heads, dim_inp, dim_out,drop_prob):
        super().__init__()

        head_modules = []
        for _ in range(num_heads):
            head_modules.append(AttentionHead(dim_inp, dim_out, drop_prob))
        self.heads = nn.ModuleList(head_modules)

        # The concatenated head outputs have num_heads * dim_out features.
        self.linear = nn.Linear(dim_out * num_heads, dim_inp)
        self.norm = nn.LayerNorm(dim_inp)

    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):
        """Apply every head to the same input and mask, then merge the results."""
        per_head = []
        for head in self.heads:
            per_head.append(head(input_tensor, attention_mask))

        merged = torch.cat(per_head, dim=-1)
        projected = self.linear(merged)
        return self.norm(projected)
    
# Smoke tests on one dataset item (element 0 is embedded, element 2 is passed
# as the attention mask). The modules are called directly instead of through
# `.forward()`, so nn.Module.__call__ can run any registered hooks.
attention_test = AttentionHead(32, 32, 0.2)
attention_test(emb_test(test_set[1][0]), test_set[1][2])

multihead_test = MultiHeadAttention(8, 32, 32,0.2)
multihead_test(emb_test(test_set[1][0]), test_set[1][2])

In [None]:
import torch

from torch import nn
import torch.nn.functional as f

#dropout 
#num heads
#dim in, dim out

class Encoder(nn.Module):
    """One encoder layer: multi-head attention, feed-forward network, layer norm.

    Args:
        dim_inp: Size of the last dimension of the input (and of the output).
        dim_out: Per-head attention size and hidden size of the feed-forward network.
        attention_heads: Number of attention heads.
        dropout_prob: Dropout probability used in the attention heads and in
            the feed-forward network.
    """

    def __init__(self, dim_inp, dim_out, attention_heads, dropout_prob):
        super().__init__()
        self.dropout_prob, self.attention_heads = dropout_prob, attention_heads
        self.dim_inp, self.dim_out = dim_inp, dim_out

        # Sub-modules are created in a fixed order (attention, feed-forward,
        # norm); this determines parameter initialisation and state-dict order.
        # Original author's note on the attention output:
        # batch_size x sentence size x dim_inp
        self.attention = MultiHeadAttention(attention_heads, dim_inp, dim_out, dropout_prob)

        feed_forward_layers = [
            nn.Linear(dim_inp, dim_out),
            nn.Dropout(dropout_prob),
            nn.GELU(),
            nn.Linear(dim_out, dim_inp),
            nn.Dropout(dropout_prob),
        ]
        self.feed_forward = nn.Sequential(*feed_forward_layers)

        self.norm = nn.LayerNorm(dim_inp)

    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):
        """Run attention, then the feed-forward network, then layer normalisation."""
        attended = self.attention(input_tensor, attention_mask)
        transformed = self.feed_forward(attended)
        return self.norm(transformed)

# Smoke test of one encoder layer on a single dataset item. The modules are
# called directly instead of through `.forward()`, so nn.Module.__call__ can
# run any registered hooks.
encoder_test = Encoder(32, 32, 8, 0.2)
encoder_test(emb_test(test_set[1][0]), test_set[1][2])

In [None]:
class BERT(nn.Module):
    """Embedding layer, a stack of encoder layers and a token-prediction head.

    Args:
        vocab_size: Size of the vocabulary (embedding rows and output classes).
        max_length: Forwarded to ``JointEmbedding`` and stored on the instance.
        dim_inp: Embedding size and model width.
        dim_out: Forwarded to every ``Encoder`` as its ``dim_out``.
        attention_heads: Number of attention heads per encoder layer.
        num_encoders: Number of encoder layers.
        dropout_prob: Dropout probability used throughout the model.
    """

    def __init__(self, vocab_size, max_length, dim_inp, dim_out, attention_heads, num_encoders, dropout_prob):
        super().__init__()
        self.vocab_size, self.max_length = vocab_size, max_length
        self.dim_inp, self.dim_out = dim_inp, dim_out
        self.attention_heads, self.num_encoders = attention_heads, num_encoders
        self.dropout_prob = dropout_prob

        # Sub-modules are created in a fixed order (embedding, encoders, head);
        # this determines parameter initialisation and state-dict order.
        self.embedding = JointEmbedding(dim_inp, vocab_size, max_length, dropout_prob)

        encoder_stack = []
        for _ in range(num_encoders):
            encoder_stack.append(Encoder(dim_inp, dim_out, attention_heads, dropout_prob))
        self.encoders = nn.ModuleList(encoder_stack)

        self.token_prediction_layer = nn.Linear(dim_inp, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):
        """Return log-probabilities over the vocabulary for every input position."""
        hidden = self.embedding(input_tensor)
        for encoder in self.encoders:
            hidden = encoder(hidden, attention_mask)

        logits = self.token_prediction_layer(hidden)
        return self.softmax(logits)

In [None]:
# Smoke test of the full model on a single dataset item. The model is called
# directly instead of through `.forward()`, so nn.Module.__call__ can run any
# registered hooks.
bert_test = BERT(len(vocabulary), 20, 32, 32, 8, 2, 0.2)
bert_test(test_set[1][0], test_set[1][2])

In [None]:
from bert_builder import AttentionHead

# Smoke test of the AttentionHead imported from bert_builder. The modules are
# called directly instead of through `.forward()`, so nn.Module.__call__ can
# run any registered hooks.
attention_test = AttentionHead(32, 32, 0.2)
attention_test(emb_test(test_set[1][0]), test_set[1][2])

In [None]:
from bert_builder import MultiHeadAttention
# Smoke test of the MultiHeadAttention imported from bert_builder. The modules
# are called directly instead of through `.forward()`, so nn.Module.__call__
# can run any registered hooks.
multihead_test = MultiHeadAttention(num_heads=8, dim_inp=32, dim_out=32, drop_prob=0.2)
multihead_test(emb_test(test_set[1][0]), test_set[1][2])

In [None]:
from bert_builder import Encoder

# Smoke test of the Encoder imported from bert_builder. The modules are called
# directly instead of through `.forward()`, so nn.Module.__call__ can run any
# registered hooks.
encoder_test = Encoder(32, 32, 8, 0.2)
encoder_test(emb_test(test_set[1][0]), test_set[1][2])

In [None]:
from bert_builder import BERT

# Smoke test of the BERT model imported from bert_builder. The model is called
# directly instead of through `.forward()`, so nn.Module.__call__ can run any
# registered hooks.
bert_test = BERT(vocab_size=len(vocabulary), max_length=20, dim_inp=32, dim_out=32, attention_heads=8, num_encoders=2, dropout_prob=0.2)
bert_test(test_set[1][0], test_set[1][2])

---------------------------

Trainer

In [None]:
import time
from datetime import datetime
from torch import nn
from torch.utils.data import DataLoader
from create_dataset import NCBIDataset


class BertTrainer:
    """Training driver for a masked-token model.

    Args:
        model: Module called as ``model(input, attn_mask)``; its output is fed
            to the loss after swapping the last two dimensions.
        train_set: Training dataset. Must provide ``prepare_dataset()`` and be
            usable with ``torch.utils.data.DataLoader``.
        val_set: Validation dataset with the same requirements.
        epochs: Total number of epochs.
        batch_size: Batch size for both data loaders.
        lr: Learning rate for the Adam optimizer.
        device: Device the loss module is moved to; also stored on the instance.
        results_dir: Stored on the instance; not used by the code in this class.

    NOTE(review): the methods below use names that are not defined in the
    visible part of this file. They must be supplied before the corresponding
    code paths can run:
      * methods ``_init_wandb``, ``_report_loss_results``, ``_print_loss_summary``
      * attributes ``report_every`` and ``print_progress_every``
    NOTE(review): neither the model nor the batches are moved to ``device``
    here; the caller has to take care of that if a GPU is used.
    """

    def __init__(self, model, train_set, val_set, epochs, batch_size, lr, device, results_dir):

        self.model = model
        self.train_set = train_set
        self.val_set = val_set
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.weight_decay = 0.01
        self.current_epoch  = 0

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        # Targets equal to -1 do not contribute to the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index = -1).to(device)

        self.device = device
        self.results_dir = results_dir

    # The methods below were previously indented inside __init__, which made
    # them local functions that were never attached to the class. They are
    # defined at class level so that trainer(), trainer.train(...) and
    # trainer._init_result_lists() actually exist.

    def __call__(self):
        """Set up the loaders and bookkeeping and iterate over the epochs.

        NOTE(review): the epoch loop only prepares the training data and the
        loader; it does not call ``train`` or run validation yet.
        """
        self.wandb_run = self._init_wandb()
        self.val_set.prepare_dataset()
        self.val_loader = DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)

        start_time = time.time()
        self.best_val_loss = float('inf')
        self._init_result_lists()
        # The loop variable is the attribute itself, so the epoch reached so
        # far stays available on the instance.
        for self.current_epoch in range(self.current_epoch, self.epochs):
            self.model.train()
            # The training set is prepared again at the start of every epoch.
            self.train_set.prepare_dataset()
            self.train_loader = DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
            epoch_start_time = time.time()

    def _init_result_lists(self):
        """Reset the per-run result lists to empty lists."""
        self.losses = []
        self.val_losses = []
        self.val_accs = []
        self.val_iso_accs = []
        self.val_iso_stats = []

    def train(self, epoch: int):
        """Run one pass over ``self.train_loader`` and return the mean batch loss.

        Args:
            epoch: Zero-based epoch index, used only for the progress line.

        Returns:
            Sum of the per-batch losses divided by the number of batches.
        """
        print(f"Epoch {epoch+1}/{self.epochs}")
        time_ref = time.time()

        # Previously read but never assigned; it is the number of batches in
        # the loader that this pass iterates over.
        self.num_batches = len(self.train_loader)

        epoch_loss = 0
        reporting_loss = 0
        printing_loss = 0
        for i, batch in enumerate(self.train_loader):
            batch_index = i + 1
            # `inputs` rather than `input`, to avoid shadowing the builtin.
            inputs, token_target, attn_mask = batch

            self.optimizer.zero_grad() # zero out gradients
            tokens = self.model(inputs, attn_mask) # get predictions

            # CrossEntropyLoss expects the class dimension at index 1, so the
            # last two dimensions of the prediction are swapped.
            loss = self.criterion(tokens.transpose(-1, -2), token_target)

            epoch_loss += loss.item()
            reporting_loss += loss.item()
            printing_loss += loss.item()

            loss.backward()
            self.optimizer.step()
            if batch_index % self.report_every == 0:
                self._report_loss_results(batch_index, reporting_loss)
                reporting_loss = 0

            if batch_index % self.print_progress_every == 0:
                time_elapsed = time.gmtime(time.time() - time_ref)
                self._print_loss_summary(time_elapsed, batch_index, printing_loss)
                printing_loss = 0
        avg_epoch_loss = epoch_loss / self.num_batches
        return avg_epoch_loss