<a href="https://colab.research.google.com/github/TurkuNLP/Deep_Learning_in_LangTech_course/blob/master/ex4_parameters.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
!pip3 install -q transformers datasets evaluate accelerate

In [None]:
from pprint import pprint
import logging

# Silence all log messages of severity INFO and below (process-wide),
# so that library progress chatter does not clutter the notebook output
logging.disable(logging.INFO)

---
# Download and prepare data

In [None]:
import datasets

# Load the IMDB dataset and shuffle it right away: a dataset may be stored
# in some particular order, which is not what we want for training
dataset = datasets.load_dataset('imdb').shuffle()
# Drop the unlabeled split so that the later steps have less data to process
del dataset["unsupervised"]

---

# Tokenize and vectorize data

In [None]:
import transformers

model_name = "bert-base-cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

# Thin wrapper around the tokenizer, in the form expected by dataset.map()
def tokenize(example):
    """Tokenize the "text" field of `example`.

    The text is truncated to at most 128 tokens; no padding is applied here.
    Returns whatever the module-level `tokenizer` returns for that text.
    """
    text = example["text"]
    encoded = tokenizer(text, truncation=True, max_length=128)
    return encoded

# Apply the tokenizer to the whole dataset using .map()
# batched=True passes a list of texts to each tokenize() call instead of a
# single text, which is considerably faster and produces the same columns
dataset = dataset.map(tokenize, batched=True)

---

# Define model

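(Note that here we only define the model structure and computation. The hyperparameter values, such as the embedding dimension and the number of filters, are read from the config and are not set yet!)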

In [None]:
import torch


# This gives a new name to the config class, just for convenience
BasicConfig = transformers.PretrainedConfig


# This is the model
class SimpleCNN(transformers.PreTrainedModel):
    """A minimal convolutional text classifier.

    Computation: token embeddings -> 1D convolution -> ReLU -> global max
    pooling over the sequence -> linear output layer.

    The following values are read from the config: `vocab_size`,
    `embedding_dim`, `num_filters`, `filter_size` and `num_labels`.
    """

    config_class = BasicConfig

    # In the initialization method, one instantiates the layers
    # these will be the parameters of the model
    def __init__(self, config):
        super().__init__(config)
        # Embedding layer: vocab size x embedding dim
        self.embeddings = torch.nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.embedding_dim
        )
        # Convolution layer: embedding dim (input channels) to num filters
        # (output channels), each filter looking at filter_size tokens at a time.
        # padding="same" keeps the output as long as the input for any
        # filter_size (for filter_size=3 it is identical to padding=1)
        self.convolution = torch.nn.Conv1d(
            in_channels=config.embedding_dim,
            out_channels=config.num_filters,
            kernel_size=config.filter_size,
            padding="same"
        )
        # Activation function following convolution
        self.activation = torch.nn.ReLU()
        # Pooling layer: global max pooling, regardless of input length
        self.pooling_layer = torch.nn.AdaptiveMaxPool1d(
            output_size=1
        )
        # Output layer: num filters to output size
        self.output_layer = torch.nn.Linear(
            in_features=config.num_filters,
            out_features=config.num_labels
        )
        # Loss function: standard loss for classification
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, labels=None, attention_mask=None):
        """Run the model on a batch.

        Args:
            input_ids: token ids, shape [batch_size, maxlen].
            labels: optional gold labels; when given, the loss is returned too.
            attention_mask: optional tensor of shape [batch_size, maxlen],
                expected to be nonzero (1) for real tokens and 0 for padding.
                When given, padded positions are excluded from the
                computation, so the output for a text does not depend on how
                much padding its batch happened to need.

        Returns:
            `(loss, output)` if `labels` is given, otherwise `(output,)`.
        """
        #shape of input: [batch_size, maxlen]
        x = self.embeddings(input_ids)
        #shape of x: [batch_size, maxlen, embedding_dim]
        x = x.permute((0,2,1))
        #shape of x: [batch_size, embedding_dim, maxlen]
        mask = None
        if attention_mask is not None:
            #shape of mask: [batch_size, 1, maxlen], broadcasts over channels
            mask = attention_mask.unsqueeze(1).to(x.dtype)
            # Zero out the embeddings of padding tokens, so that for the
            # convolution they look just like its own zero padding
            x = x * mask
        x = self.convolution(x)
        #shape of x: [batch_size, filters, maxlen]
        x = self.activation(x)
        #shape of x: [batch_size, filters, maxlen]
        if mask is not None:
            # ReLU outputs are never negative, so setting the padded positions
            # to zero leaves the maximum over the real positions unchanged
            x = x * mask
        x = self.pooling_layer(x)
        #shape of x: [batch_size, filters, 1]
        x = x.flatten(start_dim=1)
        #shape of x: [batch_size, filters]
        output = self.output_layer(x)

        # Return value computed as in the MLP:
        if labels is not None:
            # We have labels, so we can calculate the loss
            return (self.loss(output,labels), output)
        else:
            # No labels, so just return the output
            return (output,)

---
# Define training support

(Collator, evaluation, Callbacks)

In [None]:
import evaluate

# evaluation: the accuracy metric is loaded once and reused for every call
accuracy = evaluate.load("accuracy")

def compute_accuracy(outputs_and_labels):
    """Compute accuracy from a (model outputs, gold labels) pair.

    The predicted label of each example is the index with the highest value
    along the last axis of the outputs.
    """
    scores, gold = outputs_and_labels
    return accuracy.compute(
        predictions=scores.argmax(axis=-1),
        references=gold,
    )

# collator: builds batches, padding the examples with the help of the tokenizer
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)

# Callbacks / logging
from collections import defaultdict

class LogSavingCallback(transformers.TrainerCallback):
    """Trainer callback that stores everything logged during training.

    `self.logs` maps each logged key to the list of values logged for it,
    in the order they were logged. It is reset at the start of each training
    run, and nothing is recorded outside of training.
    """

    def __init__(self):
        super().__init__()
        # Create the state up front: on_log can fire before on_train_begin
        # (for example when evaluating without training first), and would
        # otherwise fail on the missing attributes
        self.logs = defaultdict(list)
        self.training = False

    def on_train_begin(self, *args, **kwargs):
        # Start every training run with empty logs
        self.logs = defaultdict(list)
        self.training = True

    def on_train_end(self, *args, **kwargs):
        self.training = False

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if self.training:
            for k, v in logs.items():
                # The same "epoch" value can be logged more than once;
                # store each distinct value only once
                if k != "epoch" or v not in self.logs[k]:
                    self.logs[k].append(v)

---
# Hyperparameter search – First option

In [None]:
# The number of distinct labels is the same for every learning rate, so it is
# computed once here rather than re-reading the label column in each iteration
num_labels = len(set(dataset['train']['label']))

# Train and evaluate one fresh model per learning rate
for lr in [0.000005, 0.00005, 0.0005, 0.005, 0.05, 0.5]:

    # create the model
    config = BasicConfig(
        vocab_size = tokenizer.vocab_size,
        num_labels = num_labels,
        embedding_dim = 64,
        filter_size = 3,
        num_filters = 10,
    )

    model = SimpleCNN(config)

    # Set training arguments
    trainer_args = transformers.TrainingArguments(
        "checkpoints",
        eval_strategy="steps",
        logging_strategy="steps",
        load_best_model_at_end=True,
        eval_steps=500,
        logging_steps=500,
        learning_rate=lr, # <--- parameter goes here
        per_device_train_batch_size=8,
        max_steps=2500,
        report_to="none", # skip wandb login
    )

    # NOTE(review): max_steps=2500 with eval_steps=500 gives only 5 evaluations,
    # so an early stopping patience of 5 probably never triggers - confirm intent
    trainer = transformers.Trainer(
        model=model,
        args=trainer_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        compute_metrics=compute_accuracy,
        data_collator=data_collator,
        callbacks=[transformers.EarlyStoppingCallback(early_stopping_patience=5), LogSavingCallback()]
    )

    trainer.train()
    # Evaluate the model that is left in the trainer after training
    eval_results = trainer.evaluate(dataset["test"])
    print('Learning rate:', lr, 'Accuracy:', eval_results['eval_accuracy'])

---
# Hyperparameter search – Second option

* Hyperparameter search using [Optuna](https://optuna.org/)

In [None]:
!pip install optuna

In [None]:
import optuna

def objective(trial):
    """Train and evaluate one SimpleCNN with hyperparameters chosen by Optuna.

    Two hyperparameters are sampled from `trial`: the learning rate
    (log scale, between 5e-4 and 5e-2) and the number of filters
    (one of 10, 16, 24). Returns the accuracy on dataset["test"], which is
    the value the study optimizes.
    """
    # Search space
    lr = trial.suggest_float("learning_rate", 5e-4, 5e-2, log=True)
    n_filters = trial.suggest_categorical("num_filters", [10, 16, 24])

    # Model built from the sampled number of filters
    cnn_config = BasicConfig(
        vocab_size=tokenizer.vocab_size,
        num_labels=len(set(dataset['train']['label'])),
        embedding_dim=64,
        filter_size=3,
        num_filters=n_filters,  # <--- sampled value goes here
    )
    cnn = SimpleCNN(cnn_config)

    # Training setup using the sampled learning rate
    training_args = transformers.TrainingArguments(
        "checkpoints",
        eval_strategy="steps",
        logging_strategy="steps",
        load_best_model_at_end=True,
        eval_steps=500,
        logging_steps=500,
        learning_rate=lr,  # <--- sampled value goes here
        per_device_train_batch_size=8,
        max_steps=2500,
        report_to="none",  # skip wandb login
    )
    trainer_callbacks = [
        transformers.EarlyStoppingCallback(early_stopping_patience=5),
        LogSavingCallback(),
    ]
    trainer = transformers.Trainer(
        model=cnn,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        compute_metrics=compute_accuracy,
        data_collator=data_collator,
        callbacks=trainer_callbacks,
    )

    trainer.train()
    test_accuracy = trainer.evaluate(dataset["test"])['eval_accuracy']
    print('Learning rate:', lr, 'Filters:', n_filters, 'Accuracy:', test_accuracy)
    return test_accuracy



# Create a study that looks for the hyperparameters maximizing the value
# returned by objective() (the accuracy)
study = optuna.create_study(direction="maximize")
# Run the search: every trial calls objective() once, i.e. trains one model
study.optimize(objective, n_trials=3) # <--- Number of trials to run; a real search would need many more!