In [1]:
#Relevant packages
import numpy as np
import pandas as pd
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchtext
import torchtext.vocab as vocab
import os
from datetime import date
today = date.today()

from misc import get_split_indices
from misc import export_results
from data_preprocessing import data_loader
from build_vocabulary import vocab_geno
from build_vocabulary import vocab_pheno
from create_dataset import NCBIDataset
from bert_builder import BERT_pt
from bert_builder import BERT_ft
from trainer import BertTrainer_ft
from trainer import BertTrainer_pt
from misc import get_paths

# ==================================================
# Run configuration
# ==================================================

# Input/output locations, resolved by the project helper get_paths().
# NOTE(review): the captured cell output at the end of this file shows
# "UnboundLocalError: local variable 'data_dir' referenced before assignment".
# data_dir is a module-level name here, so that error presumably originates
# inside get_paths() itself -- check that it assigns data_dir on every branch.
data_dir, ab_dir, save_directory = get_paths()

# --- Data / preprocessing ---
threshold_year = 1970      # forwarded to data_loader()
max_length = [51, 81]      # forwarded to NCBIDataset()
mask_prob = 0.15           # forwarded to NCBIDataset()
drop_prob = 0.2            # forwarded to the model as dropout_prob
limit_data = True          # when True, keep only the first `reduced_samples` rows
reduced_samples = 1000     # TODO: remove this later (temporary sample cap)

# --- Model architecture ---
dim_emb = 128
dim_hidden = 128
attention_heads = 8
num_encoders = 2

# --- Training ---
epochs = 2
batch_size = 32
lr = 1e-4
stop_patience = 3          # forwarded to the trainer; presumably early-stopping patience
export_model = True        # save the trained model after training

# --- Weights & Biases logging ---
wandb_mode = False
wandb_project = "Main"
wandb_run_name = "Main_ReducedSamples"

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# ==================================================
####################################################

# Run mode: True -> fine-tuning (BERT_ft / BertTrainer_ft),
#           False -> pretraining (BERT_pt / BertTrainer_pt).
mode_ft = True

# Phenotype data is only requested from the loader when fine-tuning.
include_pheno = bool(mode_ft)
print("Fine tuning mode" if mode_ft else "Pretraining mode")

# Report the compute device; on GPU also release unused cached memory.
if device.type != "cuda":
    print("Using CPU")
else:
    print("Using GPU:", torch.cuda.get_device_name(0))
    torch.cuda.empty_cache()
    
# --- Load raw data ---
print(f"\n Retrieving data from: {data_dir}")
print("Loading data...")
NCBI, ab_df = data_loader(include_pheno, threshold_year, data_dir, ab_dir)
print(f"Data correctly loaded, {len(NCBI)} samples found")

# --- Vocabularies ---
# Built from the full data set, i.e. before the optional sample cap below.
print("Creating vocabulary...")
vocabulary_geno = vocab_geno(NCBI, include_pheno)
vocabulary_pheno = vocab_pheno(ab_df)
print("Vocabulary created with number of elements:", len(vocabulary_geno))
print("Number of antibiotics:", len(vocabulary_pheno))

# --- Optional sample cap for quick runs ---
if limit_data:
    print(f"Reducing samples to {reduced_samples}")
    NCBI = NCBI.head(reduced_samples)

# --- Train/validation split ---
# 0.2 is passed to get_split_indices; presumably the validation fraction.
train_indices, val_indices = get_split_indices(len(NCBI), 0.2)
train_set, val_set = (
    NCBIDataset(NCBI.iloc[split], vocabulary_geno, vocabulary_pheno,
                max_length, mask_prob, include_pheno)
    for split in (train_indices, val_indices)
)
print(
    f"Datasets has been created with {len(train_set)} samples in the training set "
    f"and {len(val_set)} samples in the validation set"
)

# Build the network for the selected run mode. Architecture settings are taken
# from the configuration section (attention_heads, num_encoders, ...) instead
# of literals, so editing the configuration actually changes the model.
print("Creating model...")
if mode_ft:
    model = BERT_ft(
        vocab_size=len(vocabulary_geno),
        dim_embedding=dim_emb,
        dim_hidden=dim_hidden,
        attention_heads=attention_heads,  # configured value, not a literal 8
        num_encoders=num_encoders,
        dropout_prob=drop_prob,
        vocab_size_pheno=len(vocabulary_pheno),
        device=device,
    )
else:
    # NOTE(review): BERT_pt is given neither dim_hidden nor device here;
    # confirm against its signature in bert_builder whether that is intended.
    model = BERT_pt(
        vocab_size=len(vocabulary_geno),
        dim_embedding=dim_emb,
        attention_heads=attention_heads,
        num_encoders=num_encoders,  # configured value, not a literal 2
        dropout_prob=drop_prob,
    )
print("Model successfully loaded")
print("---------------------------------------------------------")
print("Starting training...")
# Both trainer classes take the same positional arguments; only the class
# depends on the run mode.
trainer_cls = BertTrainer_ft if mode_ft else BertTrainer_pt
trainer = trainer_cls(
    model, train_set, val_set, epochs, batch_size, lr, device,
    stop_patience, wandb_mode, wandb_project, wandb_run_name,
)
del trainer_cls  # keep the notebook namespace as before

# Calling the trainer object starts training; `results` is its return value.
results = trainer()
print("---------------------------------------------------------")
# Artifacts are named <today>model.pkl / <today>run.pkl, where <today> is
# str(date.today()), i.e. an ISO "YYYY-MM-DD" date.
# Ensure the output folder exists, so a finished run is not lost to a missing
# directory at the very last step. An empty save_directory means "current
# directory"; os.path.join then yields a relative path instead of "/<label>".
if save_directory:
    os.makedirs(save_directory, exist_ok=True)
if export_model:
    print("Exporting model...")
    export_model_label = str(today) + "model.pkl"
    # NOTE(review): _save_model is a private trainer method; consider exposing
    # a public save method in trainer.py.
    trainer._save_model(os.path.join(save_directory, export_model_label))
print("Exporting results...")
export_results_label = str(today) + "run.pkl"
export_results(results, os.path.join(save_directory, export_results_label))



c:\Users\erikw\Desktop\Exjobb kod\base


UnboundLocalError: local variable 'data_dir' referenced before assignment