In [3]:
#Relevant packages
import numpy as np
import pandas as pd
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchtext
import torchtext.vocab as vocab
import os
from datetime import date
today = date.today()

from misc import get_split_indices
from misc import export_results
from data_preprocessing import data_loader
from build_vocabulary import vocab_geno
from build_vocabulary import vocab_pheno
from create_dataset import NCBIDataset
from bert_builder import BERT_pt
from bert_builder import BERT_ft
from trainer import BertTrainer_ft
from trainer import BertTrainer_pt
from misc import get_paths

####################################################
# Run configuration
#
# Directories returned by get_paths(): data_dir and ab_dir are handed to
# data_loader, save_directory is where the model/results are exported.
data_dir, ab_dir, save_directory = get_paths()

# Data selection and masking (consumed by data_loader / NCBIDataset)
threshold_year = 1970
max_length = [51, 81]
mask_prob = 0.15
limit_data = True
reduced_samples = 1000  # TODO: remove this later

# Model hyperparameters
dim_emb = 128
dim_hidden = 128
attention_heads = 8
num_encoders = 2
drop_prob = 0.2

# Training hyperparameters
epochs = 10
batch_size = 32
lr = 1e-4
stop_patience = 3

# WandB settings (forwarded to the trainer)
wandb_mode = True
wandb_project = "TestingMain"
# NOTE(review): the run name ends in "ModePT" although mode_ft is set to
# True (fine tuning) further down -- confirm the name matches the run.
wandb_run_name = "Main_ReducedSamples_ModePT"

# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
####################################################

# Run mode switch: True selects fine tuning, False selects pretraining.
mode_ft = True

# Fine tuning enables include_pheno and disables export_model;
# pretraining does the reverse.
include_pheno = mode_ft
export_model = not mode_ft
print("Fine tuning mode" if mode_ft else "Pretraining mode")

# Report which device is used; on GPU also clear the CUDA cache first.
if device.type != "cuda":
    print("Using CPU")
else:
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    torch.cuda.empty_cache()
    
# --- Load the data ---
print(f"\n Retrieving data from: {data_dir}")
print("Loading data...")
NCBI, ab_df = data_loader(include_pheno, threshold_year, data_dir, ab_dir)
print(f"Data correctly loaded, {len(NCBI)} samples found")

# --- Build the vocabularies (done before any sample reduction below) ---
print("Creating vocabulary...")
vocabulary_geno = vocab_geno(NCBI, include_pheno)
vocabulary_pheno = vocab_pheno(ab_df)
print("Vocabulary created with number of elements:", len(vocabulary_geno))
if include_pheno:
    print("Number of antibiotics:", len(vocabulary_pheno))

# --- Optionally keep only the first `reduced_samples` rows ---
if limit_data:
    print(f"Reducing samples to {reduced_samples}")
    NCBI = NCBI.head(reduced_samples)

# --- Split and wrap into datasets ---
# 0.2 is presumably the validation fraction -- confirm in get_split_indices.
train_indices, val_indices = get_split_indices(len(NCBI), 0.2)

# Both datasets share every argument except the rows they are built from.
_dataset_args = (vocabulary_geno, vocabulary_pheno, max_length, mask_prob, include_pheno)
train_set = NCBIDataset(NCBI.iloc[train_indices], *_dataset_args)
val_set = NCBIDataset(NCBI.iloc[val_indices], *_dataset_args)
print(f"Datasets has been created with {len(train_set)} samples in the training set and {len(val_set)} samples in the validation set")

# --- Build the model ---
# Both variants take their sizes from the hyperparameter section so that
# changing `attention_heads` / `num_encoders` there affects either mode
# (previously BERT_ft hard-coded 8 heads and BERT_pt hard-coded 2 encoders).
print("Creating model...")
if mode_ft:
    model = BERT_ft(
        vocab_size=len(vocabulary_geno),
        dim_embedding=dim_emb,
        dim_hidden=dim_hidden,
        attention_heads=attention_heads,
        num_encoders=num_encoders,
        dropout_prob=drop_prob,
        num_ab=len(vocabulary_pheno),
        device=device,
    )
else:
    model = BERT_pt(
        vocab_size=len(vocabulary_geno),
        dim_embedding=dim_emb,
        attention_heads=attention_heads,
        num_encoders=num_encoders,
        dropout_prob=drop_prob,
    )
print("Model successfully loaded")
print("---------------------------------------------------------")

# --- Train ---
# The two trainers are constructed with the same positional arguments, so
# only the class differs between fine tuning and pretraining.
print("Starting training...")
trainer_cls = BertTrainer_ft if mode_ft else BertTrainer_pt
trainer = trainer_cls(
    model,
    train_set,
    val_set,
    epochs,
    batch_size,
    lr,
    device,
    stop_patience,
    wandb_mode,
    wandb_project,
    wandb_run_name,
)

# Calling the trainer runs the training; its return value is what gets exported.
results = trainer()
print("---------------------------------------------------------")

# --- Export ---
# File names are prefixed with today's date; the model is only saved when
# export_model is set (i.e. in pretraining mode).
if export_model:
    print("Exporting model...")
    export_model_label = f"{today}model.pkl"
    trainer._save_model(f"{save_directory}/{export_model_label}")
print("Exporting results...")
export_results_label = f"{today}runMode{mode_ft}.pkl"
export_results(results, f"{save_directory}/{export_results_label}")



Fine tuning mode
Using CPU

 Retrieving data from: c:\Users\erikw\Desktop\ExjobbKod\data
Loading data...
Data correctly loaded, 6483 samples found
Creating vocabulary...
Vocabulary created with number of elements: 475
Number of antibiotics: 81
Reducing samples to 1000
Datasets has been created with 800 samples in the training set and 200 samples in the validation set
Creating model...
Model successfully loaded
---------------------------------------------------------
Starting training...


0,1
Accuracies/val_acc,▁▁▃▄▄▄▆▆▇█
Losses/train_loss,█▆▅▄▄▃▂▂▂▁
Losses/val_loss,█▆▅▄▄▃▂▂▁▁
epoch,▁▂▃▃▄▅▆▆▇█

0,1
Accuracies/val_acc,0.36527
Losses/train_loss,3.46872
Losses/val_loss,3.40149
epoch,10.0


Epoch 1/10
Epoch completed in 0.1 min
Evaluating on validation set...
Elapsed time: 00:00:08
Epoch 2/10
Epoch completed in 0.1 min
Evaluating on validation set...
Elapsed time: 00:00:14
Epoch 3/10
Epoch completed in 0.1 min
Evaluating on validation set...
Elapsed time: 00:00:21
Epoch 4/10
Epoch completed in 0.1 min
Evaluating on validation set...
Elapsed time: 00:00:28
Epoch 5/10
Epoch completed in 0.1 min
Evaluating on validation set...
Elapsed time: 00:00:34
Epoch 6/10
Epoch completed in 0.1 min
Evaluating on validation set...
Elapsed time: 00:00:41
Epoch 7/10
Epoch completed in 0.1 min
Evaluating on validation set...
Elapsed time: 00:00:47
Epoch 8/10
Epoch completed in 0.1 min
Evaluating on validation set...
Elapsed time: 00:00:54
Epoch 9/10
Epoch completed in 0.1 min
Evaluating on validation set...
Elapsed time: 00:01:00
Epoch 10/10
Epoch completed in 0.1 min
Evaluating on validation set...
Elapsed time: 00:01:08
-=Training completed=-
----------------------------------------------