In [None]:
!pip install -qr "../requirements.txt"

In [None]:
from datasets import load_dataset

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import TensorDataset, DataLoader

import word_embedding_classifier as model
import word_embedding_processor as processor
import word_embedding_evaluation as evaluation

from tqdm import tqdm

import wandb

from sklearn.metrics import accuracy_score
import csv

import optuna

import heapq
import math

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Load the 'ARC-Easy' configuration of the 'ai2_arc' dataset; the cells below
# read its 'train' and 'validation' splits.
dataset = load_dataset(path='ai2_arc', name='ARC-Easy')

In [None]:
def _embed_split(split_name, drop_stop_words, embedding_name):
    """Run ``processor.process_dataset`` on one split of ``dataset``.

    Returns whatever ``process_dataset`` returns; every caller below unpacks
    it into an ``X, y`` pair.
    """
    return processor.process_dataset(
        dataset[split_name],
        remove_stop_words=drop_stop_words,
        embedding_model=embedding_name,
    )


# fastText embeddings: train/validation with stop words, then without.
X_train_with_stopwords_fastText, y_train_with_stopwords_fastText = _embed_split('train', False, 'fasttext')
X_val_with_stopwords_fastText, y_val_with_stopwords_fastText = _embed_split('validation', False, 'fasttext')
X_train_without_stopwords_fastText, y_train_without_stopwords_fastText = _embed_split('train', True, 'fasttext')
X_val_without_stopwords_fastText, y_val_without_stopwords_fastText = _embed_split('validation', True, 'fasttext')

# word2vec embeddings: train/validation with stop words, then without.
X_train_with_stopwords_word2vec, y_train_with_stopwords_word2vec = _embed_split('train', False, 'word2vec')
X_val_with_stopwords_word2vec, y_val_with_stopwords_word2vec = _embed_split('validation', False, 'word2vec')
X_train_without_stopwords_word2vec, y_train_without_stopwords_word2vec = _embed_split('train', True, 'word2vec')
X_val_without_stopwords_word2vec, y_val_without_stopwords_word2vec = _embed_split('validation', True, 'word2vec')

# GloVe embeddings: train/validation with stop words, then without.
X_train_with_stopwords_glove, y_train_with_stopwords_glove = _embed_split('train', False, 'glove')
X_val_with_stopwords_glove, y_val_with_stopwords_glove = _embed_split('validation', False, 'glove')
X_train_without_stopwords_glove, y_train_without_stopwords_glove = _embed_split('train', True, 'glove')
X_val_without_stopwords_glove, y_val_without_stopwords_glove = _embed_split('validation', True, 'glove')

In [None]:
#print(f'Length given train dataset: {len(dataset["train"])}')
#print(f'Length processed train dataset: {len(X_train_with_stopwords)}')
#print()
#print(f'Length given test dataset: {len(dataset["validation"])}')
#print(f'Length processed test dataset: {len(X_val_with_stopwords)}')

In [None]:
def _full_batch_loader(features, labels):
    """Wrap one ``(features, labels)`` pair as a dataset plus a shuffled loader.

    Both arrays are converted to tensors and moved to ``device``. The loader's
    batch size equals the dataset length, so every pass over it yields the
    whole training set as a single batch.

    Returns:
        Tuple ``(tensor_dataset, dataloader)``.
    """
    tensor_dataset = TensorDataset(
        torch.tensor(features).to(device),
        torch.tensor(labels).to(device),
    )
    loader = DataLoader(tensor_dataset, batch_size=len(tensor_dataset), shuffle=True)
    return tensor_dataset, loader


# fastText
train_dataset_with_stopwords_fastText, train_dataloader_with_stopwords_fastText = _full_batch_loader(
    X_train_with_stopwords_fastText, y_train_with_stopwords_fastText
)
train_dataset_without_stopwords_FastText, train_dataloader_without_stopwords_fastText = _full_batch_loader(
    X_train_without_stopwords_fastText, y_train_without_stopwords_fastText
)

# word2vec
train_dataset_with_stopwords_word2vec, train_dataloader_with_stopwords_word2vec = _full_batch_loader(
    X_train_with_stopwords_word2vec, y_train_with_stopwords_word2vec
)
train_dataset_without_stopwords_word2vec, train_dataloader_without_stopwords_word2vec = _full_batch_loader(
    X_train_without_stopwords_word2vec, y_train_without_stopwords_word2vec
)

# GloVe
train_dataset_with_stopwords_glove, train_dataloader_with_stopwords_glove = _full_batch_loader(
    X_train_with_stopwords_glove, y_train_with_stopwords_glove
)
train_dataset_without_stopwords_glove, train_dataloader_without_stopwords_glove = _full_batch_loader(
    X_train_without_stopwords_glove, y_train_without_stopwords_glove
)

In [None]:
# Smoke test of nn.CrossEntropyLoss with integer class targets:
# random scores of shape (10, 120) against 10 random class indices in [0, 120).
loss_fn = nn.CrossEntropyLoss()

output = torch.randn((10, 120)).float()
target = torch.randint(low=0, high=120, size=(10,)).long()
loss = loss_fn(output, target)

In [None]:
#wandb.init(project="hslu-stableconfusion-nlp")
# Scratch check: an ordering comparison against NaN evaluates to False, so a
# NaN loss never counts as "smaller" in the `<` checks used by objective().
10 < math.nan

In [None]:
def _select_trial_data(word_embedding, dataset_choice):
    """Pick the pre-built training loader and validation arrays for one trial.

    Args:
        word_embedding: 'word2vec', 'fasttext' or 'glove'.
        dataset_choice: 'with_stopwords' selects the variant processed with
            ``remove_stop_words=False``; any other value selects the variant
            processed with ``remove_stop_words=True``.

    Returns:
        Tuple ``(train_dataloader, X_val, y_val)`` taken from the notebook-level
        variables created in the earlier cells.

    Raises:
        ValueError: if ``word_embedding`` is not one of the known names.
    """
    keep_stopwords = dataset_choice == 'with_stopwords'

    if word_embedding == 'word2vec':
        if keep_stopwords:
            return (train_dataloader_with_stopwords_word2vec,
                    X_val_with_stopwords_word2vec,
                    y_val_with_stopwords_word2vec)
        return (train_dataloader_without_stopwords_word2vec,
                X_val_without_stopwords_word2vec,
                y_val_without_stopwords_word2vec)

    if word_embedding == 'fasttext':
        if keep_stopwords:
            return (train_dataloader_with_stopwords_fastText,
                    X_val_with_stopwords_fastText,
                    y_val_with_stopwords_fastText)
        return (train_dataloader_without_stopwords_fastText,
                X_val_without_stopwords_fastText,
                y_val_without_stopwords_fastText)

    if word_embedding == 'glove':
        if keep_stopwords:
            return (train_dataloader_with_stopwords_glove,
                    X_val_with_stopwords_glove,
                    y_val_with_stopwords_glove)
        return (train_dataloader_without_stopwords_glove,
                X_val_without_stopwords_glove,
                y_val_without_stopwords_glove)

    raise ValueError(f'Unknown word embedding: {word_embedding!r}')


def _mean_of_lowest_losses(losses, fraction=0.1):
    """Average the smallest ``fraction`` of ``losses`` as a Python float.

    At least one value is always averaged, so a short loss history does not
    produce an empty slice (whose mean would be NaN). An empty ``losses`` list
    still yields NaN.
    """
    sorted_losses, _ = torch.sort(torch.tensor(losses, dtype=torch.float32))
    keep = max(1, int(len(sorted_losses) * fraction))
    return sorted_losses[:keep].mean().item()


def _run_trial(trial):
    """Sample hyperparameters, train one classifier and return its score.

    Side effects: writes ``fast_text_classifier_trial_<n>.csv`` with the sampled
    hyperparameters, saves the weights with the lowest validation loss to
    ``fast_text_classifier_trial_<n>.pth`` and logs metrics to the active
    wandb run.

    Returns:
        Mean of the lowest 10% of per-epoch validation losses (float).
    """
    # Define hyperparameters using the trial object
    # NOTE(review): embedding_dim is fixed at 100 for every embedding model -
    # confirm this matches what word_embedding_processor produces.
    embedding_dim = 100
    hidden_dim = trial.suggest_int('hidden_dim', 400, 5000)
    output_dim = 4
    num_layers = trial.suggest_int('num_layers', 1, 6)
    dataset_choice = trial.suggest_categorical('dataset', ['with_stopwords', 'without_stopwords'])
    optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD'])
    lr = trial.suggest_float('lr', 1e-4, 1e-1, log=True)
    activation_function_name = trial.suggest_categorical('activation_function', ['relu', 'sigmoid', 'tanh'])
    # 'glove' data is prepared and handled by _select_trial_data, but it is
    # not part of the search space here.
    word_embedding = trial.suggest_categorical('word_embedding', ['word2vec', 'fasttext'])
    activation_function = getattr(F, activation_function_name)

    wandb.config.update({
        'hidden_dim': hidden_dim,
        'num_layers': num_layers,
        'dataset': dataset_choice,
        'optimizer': optimizer_name,
        'lr': lr,
        'activation_function': activation_function_name,
        'embedding': word_embedding
    })

    # Keep a per-trial record of the sampled hyperparameters on disk.
    with open(f'fast_text_classifier_trial_{trial.number}.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Trial number', 'hidden_dim', 'num_layers', 'dataset', 'optimizer', 'lr', 'activation_function', 'word_embedding'])
        writer.writerow([trial.number, hidden_dim, num_layers, dataset_choice, optimizer_name, lr, activation_function_name, word_embedding])

    # Choose the dataset
    train_dataloader, X_val, y_val = _select_trial_data(word_embedding, dataset_choice)

    # The validation set does not change between epochs: convert it once.
    val_inputs = torch.asarray(X_val).to(device).type(torch.float32)
    val_labels = torch.asarray(y_val).to(device).type(torch.float32)

    # Create the model
    net = model.Word_embedding_classifier(embedding_dim, hidden_dim, output_dim, num_layers, activation_function).to(device)

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    if optimizer_name == 'Adam':
        optimizer = optim.Adam(net.parameters(), lr=lr)
    else:
        optimizer = optim.SGD(net.parameters(), lr=lr)

    best_train_loss = float('inf')
    best_val_loss = float('inf')

    # Stop after this many consecutive epochs whose validation loss is at
    # least 0.05 above the best one seen so far.
    val_los_exit_counter = 0
    val_los_exit_threshold = 20

    # Stop after this many consecutive epochs without a new best train loss.
    epochs_since_improvement = 0
    improvment_limit = 100

    # Stop after this many consecutive epochs in which the gap between
    # validation and train loss grew compared to the previous epoch.
    previous_val_train_loss_divergence = float('inf')
    epochs_since_divergence = 0
    divergence_limit = 100

    numb_epochs = 1500

    complete_val_loss = []

    for epoch in range(numb_epochs):
        net.train()

        total_train_loss = 0
        num_iterations = 0
        sum_accuracy = 0

        progress_bar = tqdm(train_dataloader)

        for inputs, label in progress_bar:
            label = label.type(torch.float32)
            # The five slices along dim 1 are passed to the model as separate
            # arguments (presumably the question and the four answer options -
            # confirm against word_embedding_processor).
            output = net(inputs[:, 0], inputs[:, 1], inputs[:, 2], inputs[:, 3], inputs[:, 4])

            loss = criterion(output, label)

            sum_accuracy += evaluation.calculate_accuracy_from_predicitions(output, label)

            total_train_loss += loss.item()
            num_iterations += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Iterating the tqdm wrapper already advances the bar, so no
            # explicit update() call is needed here.
            progress_bar.set_description('Train loss: %.6f' % loss.item())

        accuracy = sum_accuracy / num_iterations

        train_loss = total_train_loss / num_iterations

        wandb.log({"Train Loss": train_loss, "Train Accuracy": accuracy})

        net.eval()

        with torch.no_grad():
            val_outputs = net(val_inputs[:, 0], val_inputs[:, 1], val_inputs[:, 2], val_inputs[:, 3], val_inputs[:, 4])

            # Work with a plain float from here on so the bookkeeping below
            # does not mix 0-dim device tensors with Python numbers.
            val_loss = criterion(val_outputs, val_labels).item()

            val_accuracy_score = evaluation.calculate_accuracy_from_predicitions(val_outputs, val_labels)

        if val_loss < best_val_loss:
            # New best validation loss: checkpoint the weights.
            torch.save(net.state_dict(), f'fast_text_classifier_trial_{trial.number}.pth')
            best_val_loss = val_loss
            val_los_exit_counter = 0
        elif val_loss < best_val_loss + 0.05:
            # Within tolerance of the best loss: not counted as a bad epoch.
            val_los_exit_counter = 0
        else:
            val_los_exit_counter += 1

        if train_loss < best_train_loss:
            best_train_loss = train_loss
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1

        val_train_loss_divergence = abs(val_loss - train_loss)
        if val_train_loss_divergence > previous_val_train_loss_divergence:
            epochs_since_divergence += 1
        else:
            epochs_since_divergence = 0
        previous_val_train_loss_divergence = val_train_loss_divergence

        wandb.log({"Validation Loss": val_loss, "Validation Accuracy": val_accuracy_score})
        complete_val_loss.append(val_loss)

        if val_los_exit_counter >= val_los_exit_threshold or epochs_since_improvement >= improvment_limit or epochs_since_divergence >= divergence_limit:
            print("Early stopping due to validation loss not impoving, train loss not improving or train and validation loss diverging")
            break

    return _mean_of_lowest_losses(complete_val_loss, fraction=0.1)


def objective(trial):
    """Optuna objective: train one classifier configuration and score it.

    Starts a wandb run for the trial, delegates sampling, training and
    evaluation to ``_run_trial`` and always finishes the wandb run afterwards,
    including when training raises.

    Args:
        trial: the ``optuna`` trial used to sample hyperparameters.

    Returns:
        Mean of the lowest 10% of per-epoch validation losses as a float
        (lower is better).
    """
    run = wandb.init(project="hslu-stableconfusion-nlp")
    try:
        return _run_trial(trial)
    finally:
        # Close the run so the next trial starts a fresh one and a failed
        # trial does not leave a run open.
        run.finish()


# Run the hyperparameter search: 70 trials, minimising the value returned by objective().
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=70)

# Console summary of the best trial.
print('Best trial:')
trial = study.best_trial
print(f'  Value:  {trial.value}')
print('  Params: ')
for key, value in trial.params.items():
    print(f'    {key}: {value}')

# Write the same summary, plus the trial number, to best_config.txt.
with open('best_config.txt', 'w') as f:
    f.writelines([
        'Best trial:\n',
        f'  Trial number: {trial.number}\n',
        f'  Value: {trial.value}\n',
        '  Params:\n',
    ])
    f.writelines(f'    {name}: {setting}\n' for name, setting in trial.params.items())