# Set up environment

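Don't run the following cell if you are using a local machine. It clones the repository and copies the code, dataset and plot folders into place; the `GITHUB_ACCESS_TOKEN` variable must be defined before running it.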

In [None]:
# Clone the project repository and copy its code packages into the working
# directory so the `data` / `model` / `training` imports later in this notebook resolve.
# GITHUB_ACCESS_TOKEN must already be defined as a Python variable: IPython
# substitutes {GITHUB_ACCESS_TOKEN} into the shell command before running it.
# NOTE(review): git records the clone URL (including the token) as the `origin`
# remote in ganbert-classifier/.git/config — don't share that folder as-is.
!git clone https://{GITHUB_ACCESS_TOKEN}@github.com/AliMohseninejad/ganbert-classifier.git
# Remove the cloned copy of main.ipynb (presumably to avoid confusion with the
# notebook being run — confirm).
!rm ganbert-classifier/Codes/main.ipynb
!cp -r ganbert-classifier/Codes/data/ ./
!cp -r ganbert-classifier/Codes/evaluation/ ./
!cp -r ganbert-classifier/Codes/model/ ./
!cp -r ganbert-classifier/Codes/training/ ./
# Dataset/ and Plots/ go one level up, matching the "../Dataset/" and
# "../Plots/..." paths used by the cells below.
!cp -r ganbert-classifier/Dataset/ ../
!cp -r ganbert-classifier/Plots/ ../

In [None]:
# Install or upgrade (-U) the Hugging Face transformers package with quiet output (-q).
!pip install -qU transformers

Run the following cell only if you are using Google Colab.
The dataset should be available on your Google Drive.

In [None]:
from google.colab import drive

# Mount Google Drive at an absolute location and derive the dataset path from it,
# so the path stays valid even if the notebook's working directory is not /content.
drive.mount('/content/drive')
dataset_path = "/content/drive/MyDrive/"

Run the following cell only if you are using Kaggle. The dataset must first be uploaded to Kaggle as the "subtaskB" dataset.

In [None]:
# Kaggle exposes attached datasets under /kaggle/input/<dataset-slug>/.
# NOTE(review): Kaggle dataset slugs are normally lowercase (e.g. "subtaskb") and
# paths are case-sensitive — confirm this directory name matches the uploaded dataset.
dataset_path = "/kaggle/input/subtaskB/"

Run the following cell only if you are using a local machine. The dataset should be in the repository's "Dataset" folder, one level above the notebook's working directory.

In [None]:
# Relative to the notebook's working directory (presumably the repository's Codes/
# folder, with Dataset/ one level up in the repository root — confirm).
dataset_path = "../Dataset/"

# Vanilla BERT

In [None]:
import torch
import os

from data.data_loader import generate_dataloader
from model.bert import get_bert_model, get_tokenizer
from model.discriminator import Discriminator
from model.generator1 import Generator
from training.train import train_vanilla_classier, train_gan

In [None]:
# Hyperparameters for the vanilla BERT baseline.
model_name = "bert-base-cased"
epochs = 10
learning_rate = 5e-5

# Training, validation and test loaders all use the same batch size.
train_batch_size = val_test_batch_size = 4

# The config returned alongside the tokenizer is not used in this section.
bert_tokenizer, bert_config = get_tokenizer(model_name=model_name)

In [None]:
# Create the checkpoint directory for the vanilla BERT runs. The name must match
# the hyphenated "vanilla-bert" directory used in the save paths of the training loop.
# makedirs also creates a missing ../Plots parent and is a no-op if the directory
# already exists.
os.makedirs("../Plots/vanilla-bert/", exist_ok=True)

In [None]:
for unsupervised_ratio in [0.99, 0.95, 0.90, 0.50]:
    # Percentage of labelled training data, used to tag the checkpoint files.
    # round(), not int(): 1 - 0.90 is 0.09999999999999998 in binary floating
    # point, so int(100 * (1 - 0.90)) would truncate to 9 instead of 10.
    supervised_percent = round(100 * (1 - unsupervised_ratio))
    bert_save_path = f"../Plots/vanilla-bert/bert_{supervised_percent}sup.pth"
    discriminator_save_path = (
        f"../Plots/vanilla-bert/discriminator_{supervised_percent}sup.pth"
    )

    # Build new BERT and classifier instances for every ratio instead of
    # continuing from the previous run's models.
    bert_model, _ = get_bert_model(model_name=model_name)
    classifier = Discriminator()

    # Get dataloaders (random_seed is fixed at 42, presumably for a reproducible
    # labelled/unlabelled split — confirm in generate_dataloader).
    train_dataloader, val_dataloader, test_dataloader = generate_dataloader(
        dataset_folder_path=dataset_path,
        unsupervised_ratio=unsupervised_ratio,
        tokenizer=bert_tokenizer,
        train_batch_size=train_batch_size,
        valid_batch_size=val_test_batch_size,
        test_batch_size=val_test_batch_size,
        use_bow_dataset=False,
        random_seed=42,
    )

    # Define optimizer: BERT and the classifier head are optimized jointly.
    model_params = [*bert_model.parameters(), *classifier.parameters()]
    optimizer = torch.optim.AdamW(model_params, lr=learning_rate)

    # Train the model
    bert_model, classifier, vanilla_training_results = train_vanilla_classier(
        transformer=bert_model,
        classifier=classifier,
        optimizer=optimizer,
        epochs=epochs,
        scheduler=None,
        train_dataloader=train_dataloader,
        validation_dataloader=val_dataloader,
        bert_save_path=bert_save_path,
        discriminator_save_path=discriminator_save_path,
    )

    # TODO: Test the model — test_dataloader is created above but not used yet.

    # TODO: Visualize results (vanilla_training_results is overwritten each iteration).


# GAN-BERT

## G1

In [None]:
# Hyperparameters for GAN-BERT with generator G1.
model_name = "bert-base-cased"
epochs = 10

# Generator and discriminator currently share the same learning rate.
learning_rate_discriminator = learning_rate_generator = 5e-5

# Training, validation and test loaders all use the same batch size.
train_batch_size = val_test_batch_size = 4

bert_tokenizer, bert_config = get_tokenizer(model_name=model_name)

In [None]:
# Create the checkpoint directory for the generator1 GAN-BERT runs.
# makedirs also creates a missing ../Plots parent (os.mkdir would raise) and is a
# no-op if the directory already exists.
os.makedirs("../Plots/generator1/", exist_ok=True)

In [None]:
for unsupervised_ratio in [0.99, 0.95, 0.90, 0.50]:
    # Percentage of labelled training data, used to tag the checkpoint files.
    # round(), not int(): 1 - 0.90 is 0.09999999999999998 in binary floating
    # point, so int(100 * (1 - 0.90)) would truncate to 9 instead of 10.
    supervised_percent = round(100 * (1 - unsupervised_ratio))
    bert_save_path = f"../Plots/generator1/bert_{supervised_percent}sup.pth"
    discriminator_save_path = (
        f"../Plots/generator1/discriminator_{supervised_percent}sup.pth"
    )
    generator_save_path = (
        f"../Plots/generator1/generator_{supervised_percent}sup.pth"
    )

    # Build new BERT, discriminator and generator instances for every ratio
    # instead of continuing from the previous run's models.
    bert_model, _ = get_bert_model(model_name=model_name)
    classifier = Discriminator()
    generator = Generator()

    # Get dataloaders (random_seed is fixed at 42, presumably for a reproducible
    # labelled/unlabelled split — confirm in generate_dataloader).
    train_dataloader, val_dataloader, test_dataloader = generate_dataloader(
        dataset_folder_path=dataset_path,
        unsupervised_ratio=unsupervised_ratio,
        tokenizer=bert_tokenizer,
        train_batch_size=train_batch_size,
        valid_batch_size=val_test_batch_size,
        test_batch_size=val_test_batch_size,
        use_bow_dataset=False,
        random_seed=42,
    )

    # Define optimizers: BERT is updated together with the discriminator head,
    # while the generator has its own optimizer and learning rate.
    discriminator_params = [*bert_model.parameters(), *classifier.parameters()]
    generator_params = list(generator.parameters())
    d_optimizer = torch.optim.AdamW(
        discriminator_params, lr=learning_rate_discriminator
    )
    g_optimizer = torch.optim.AdamW(generator_params, lr=learning_rate_generator)

    # Train the GAN (bow_mode=False matches use_bow_dataset=False above).
    bert_model, generator, classifier, gan1_training_results = train_gan(
        transformer=bert_model,
        generator=generator,
        discriminator=classifier,
        bow_mode=False,
        generator_optimizer=g_optimizer,
        discriminator_optimizer=d_optimizer,
        epochs=epochs,
        generator_scheduler=None,
        discriminator_scheduler=None,
        train_dataloader=train_dataloader,
        validation_dataloader=val_dataloader,
        bert_save_path=bert_save_path,
        discriminator_save_path=discriminator_save_path,
        generator_save_path=generator_save_path,
    )

    # TODO: Test the model — test_dataloader is created above but not used yet.

    # TODO: Visualize results (gan1_training_results is overwritten each iteration).

