In [None]:
import abc
import dataclasses
import functools
import inspect
from abc import ABC
from dataclasses import InitVar, dataclass, field
from typing import Any, Dict, Final, List, Optional, Set, Tuple, Type, TypeVar, Union

import datasets
import numpy as np
import optuna
import pytorch_lightning as pl
import torch
from embeddings.data.huggingface_datamodule import (
    HuggingFaceDataset,
    TextClassificationDataModule,
)
from embeddings.embedding.auto_flair import AutoFlairWordEmbedding
from embeddings.embedding.flair_embedding import FlairTransformerEmbedding
from embeddings.embedding.static.embedding import StaticEmbedding
from embeddings.evaluator.text_classification_evaluator import (
    TextClassificationEvaluator,
)
from embeddings.hyperparameter_search.configspace import (
    BaseConfigSpace,
    Parameter,
    ParsedParameters,
    SampledParameters,
)
from embeddings.hyperparameter_search.parameters import (
    ConstantParameter,
    SearchableParameter,
)
from embeddings.model.lightning.auto_lightning import AutoTransformer
from embeddings.utils.utils import PrimitiveTypes
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchmetrics import F1, Accuracy, MetricCollection, Precision, Recall
from transformers import AutoConfig, AutoModelForSequenceClassification

In [None]:
# print(inspect.getsource(TextClassificationDataModule))

In [None]:
# print(inspect.getsource(AutoTransformer))

### Define necessary model classes

In [None]:
class SequenceClassificationModule(pl.LightningModule, abc.ABC):
    """Lightning base module for fine-tuning a sequence classifier.

    This class does not define ``forward``; ``shared_step`` expects the
    ``forward`` of a concrete subclass to return an output whose first two items
    are ``(loss, logits)``. On top of that it provides metric tracking for the
    train/val/test stages, logging, and the AdamW optimizer with an optional
    linear warmup / linear decay learning-rate schedule.
    """

    def __init__(
        self,
        num_labels: int,
        metrics: Optional[MetricCollection] = None,
        learning_rate: float = 1e-4,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        use_scheduler: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store hyperparameters and create per-stage metric collections.

        Args:
            num_labels: Number of target classes.
            metrics: Metric collection cloned for each stage; when ``None`` the
                result of ``get_default_metrics`` is used.
            learning_rate: Learning rate passed to AdamW.
            adam_epsilon: ``eps`` passed to AdamW.
            warmup_steps: Number of linear warmup steps of the LR schedule.
            weight_decay: Weight decay for parameters that are not biases or
                ``LayerNorm.weight``.
            train_batch_size: Per-device training batch size; only used to
                estimate the number of optimizer steps in ``setup``.
            eval_batch_size: Stored as a hyperparameter; not read by this class.
            use_scheduler: Whether to attach the warmup/decay LR schedule.
            **kwargs: Extra values; they are captured by
                ``save_hyperparameters`` and otherwise unused here.
        """
        super().__init__()
        self.save_hyperparameters()

        if metrics is None:
            metrics = self.get_default_metrics(num_labels=num_labels)
        # Independent clones so that stage statistics never mix.
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

    @staticmethod
    def get_default_metrics(num_labels: int) -> MetricCollection:
        """Return accuracy/precision/recall/F1 for ``num_labels`` classes.

        For more than two classes precision, recall and F1 are macro-averaged;
        otherwise the metrics' own default averaging is used.
        """
        averaging: Dict[str, Any] = {"average": "macro"} if num_labels > 2 else {}
        return MetricCollection(
            [
                Accuracy(num_classes=num_labels),
                Precision(num_classes=num_labels, **averaging),
                Recall(num_classes=num_labels, **averaging),
                F1(num_classes=num_labels, **averaging),
            ]
        )

    def shared_step(self, **batch: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``forward`` on a batch and return ``(loss, logits)``."""
        outputs = self.forward(**batch)
        loss, logits = outputs[:2]
        return loss, logits

    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Compute the training loss, update train metrics and log loss and LR."""
        batch, batch_idx = args
        loss, logits = self.shared_step(**batch)
        self.train_metrics(logits, batch["labels"])
        # With a scheduler the current LR comes from it; otherwise it is constant.
        lr = (
            self.lr_scheduler.get_last_lr()[-1]
            if self.hparams.use_scheduler
            else self.hparams.learning_rate
        )
        self.log("train/Loss", loss, on_step=True, on_epoch=True)
        self.log("train/LR", lr, on_step=True, on_epoch=True)
        return {"loss": loss, "lr": lr}

    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Update validation metrics and log the validation loss."""
        batch, batch_idx = args
        loss, logits = self.shared_step(**batch)
        self.val_metrics(logits, batch["labels"])
        self.log("val/Loss", loss, on_epoch=True)
        return None

    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Update test metrics and log the test loss for labelled batches.

        Batches containing a ``-1`` label are skipped entirely (presumably ``-1``
        marks examples without a gold label -- TODO confirm against the data
        module).
        """
        batch, batch_idx = args
        loss, logits = self.shared_step(**batch)
        if -1 not in batch["labels"]:
            self.test_metrics(logits, batch["labels"])
            self.log("test/Loss", loss, on_epoch=True)
        return None

    def training_epoch_end(self, outputs: List[Any]) -> None:
        """Log and reset the aggregated training metrics."""
        self._aggregate_and_log_metrics(self.train_metrics)

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        """Log (also to the progress bar) and reset the validation metrics."""
        self._aggregate_and_log_metrics(self.val_metrics, prog_bar=True)

    def test_epoch_end(self, outputs: List[Any]) -> None:
        """Log and reset the aggregated test metrics."""
        self._aggregate_and_log_metrics(self.test_metrics)

    def _aggregate_and_log_metrics(
        self, metrics: MetricCollection, prog_bar: bool = False
    ) -> Dict[str, float]:
        """Compute the collection, reset its state, log and return the values."""
        metric_values = metrics.compute()
        metrics.reset()
        self.log_dict(metric_values, prog_bar=prog_bar)
        return metric_values

    def setup(self, stage: Optional[str] = None) -> None:
        """Estimate the total number of optimizer steps for the LR schedule.

        Only runs for the ``"fit"`` stage when ``use_scheduler`` is enabled and
        stores the estimate in ``self.total_steps``.
        """
        if stage != "fit" or not self.hparams.use_scheduler:
            return

        # train_dataloader() is normally called after setup(), so call it here.
        train_loader = self.trainer.datamodule.train_dataloader()

        # `trainer.gpus` may be None for CPU runs; treat that as a single device.
        num_devices = max(1, self.trainer.gpus or 0)
        # Samples consumed per optimizer step across devices and accumulation.
        # NOTE: relies on `train_batch_size` matching the data module batch size.
        samples_per_step = (
            self.hparams.train_batch_size
            * num_devices
            * self.trainer.accumulate_grad_batches
        )
        steps_per_epoch = len(train_loader.dataset) // samples_per_step
        # Steps accumulate over epochs, hence multiply (not divide) by max_epochs.
        self.total_steps = max(1, int(steps_per_epoch * self.trainer.max_epochs))

    def configure_optimizers(self) -> Tuple[List[Optimizer], List[Any]]:
        """Prepare optimizer and schedule (linear warmup and decay).

        Returns:
            A one-element optimizer list and a list with either no scheduler or
            one scheduler configuration that is stepped after every optimizer
            step.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        decayed: List[Any] = []
        not_decayed: List[Any] = []
        for name, param in self.named_parameters():
            if any(marker in name for marker in no_decay):
                not_decayed.append(param)
            else:
                decayed.append(param)

        optimizer = AdamW(
            [
                {"params": decayed, "weight_decay": self.hparams.weight_decay},
                {"params": not_decayed, "weight_decay": 0.0},
            ],
            lr=self.hparams.learning_rate,
            eps=self.hparams.adam_epsilon,
        )

        if not self.hparams.use_scheduler:
            return [optimizer], []

        def lr_lambda(step: int) -> float:
            warmup_steps = self.hparams.warmup_steps
            if warmup_steps > 0 and step < warmup_steps:
                return step / warmup_steps
            # Guard against a zero/negative decay span and never go below 0.
            decay_span = max(1, self.total_steps - warmup_steps)
            return max(0.0, (self.total_steps - step) / decay_span)

        self.lr_scheduler = LambdaLR(optimizer, lr_lambda)
        # `lr_lambda` is defined per optimizer step, so the scheduler has to be
        # stepped every step rather than at Lightning's default of every epoch.
        scheduler_config = {
            "scheduler": self.lr_scheduler,
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [scheduler_config]

In [None]:
class AutoTransformerForSequenceClassification(
    AutoTransformer, SequenceClassificationModule
):
    """Sequence classifier built on ``AutoModelForSequenceClassification``.

    Combines ``AutoTransformer`` with the training/evaluation logic of
    ``SequenceClassificationModule`` and adds a batch ``predict`` helper.
    """

    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        unfreeze_from: Optional[int] = None,
        **kwargs: Any
    ):
        """Forward the model name, label count and freezing option to the parents.

        Args:
            model_name_or_path: Model identifier or path passed to the parent.
            num_labels: Number of target classes.
            unfreeze_from: Passed through to ``AutoTransformer`` unchanged.
            **kwargs: Remaining keyword arguments for the parent constructors.
        """
        super().__init__(
            model_name_or_path,
            AutoModelForSequenceClassification,
            num_labels,
            unfreeze_from,
            **kwargs
        )

    def predict(
        self, dataloader: DataLoader[HuggingFaceDataset]
    ) -> Dict[str, np.ndarray]:
        """Predict class indices for every batch of ``dataloader``.

        The model is switched to eval mode and gradients are disabled for the
        duration of the call; the previous train/eval mode is restored after.
        Labels are collected in the same pass as the predictions so both arrays
        stay aligned even if the loader does not yield a fixed order.

        Args:
            dataloader: Loader yielding mapping batches that contain ``"labels"``
                and can be unpacked into ``forward``.

        Returns:
            ``{"y_pred": ..., "y_true": ...}`` as NumPy arrays.
        """
        was_training = self.training
        self.eval()
        logits_chunks: List[torch.Tensor] = []
        label_chunks: List[torch.Tensor] = []
        try:
            with torch.no_grad():
                for batch in dataloader:
                    # Move tensors to the model device (it may live on a GPU).
                    on_device = {
                        key: (
                            batch[key].to(self.device)
                            if isinstance(batch[key], torch.Tensor)
                            else batch[key]
                        )
                        for key in batch.keys()
                    }
                    outputs = self.forward(**on_device)
                    logits_chunks.append(outputs.logits.detach().cpu())
                    label_chunks.append(on_device["labels"].detach().cpu())
        finally:
            self.train(was_training)

        if not logits_chunks:
            # Empty loader: nothing to concatenate.
            empty = np.empty(0, dtype=np.int64)
            return {"y_pred": empty, "y_true": empty.copy()}

        predictions = torch.argmax(torch.cat(logits_chunks), dim=1).numpy()
        ground_truth = torch.cat(label_chunks).numpy()
        return {"y_pred": predictions, "y_true": ground_truth}

In [None]:
# print(inspect.getsource(AutoTransformer))

## Test run (loss should decrease)

In [None]:
# --- Settings for the quick test run -----------------------------------------
# Pretrained model and dataset identifiers.
embedding_name = "allegro/herbert-base-cased"
dataset_name = "clarin-pl/polemo2-official"

# Dataset columns holding the input text(s) and the target label.
input_column_name = ["text"]
target_column_name = "target"

# Passed to the trainer as `default_root_dir`.
output_path = None

# Extra keyword arguments for loading the dataset (domain subsets per split).
load_dataset_kwargs = dict(
    train_domains=["hotels", "medicine"],
    dev_domains=["hotels", "medicine"],
    test_domains=["hotels", "medicine"],
    text_cfg="text",
)

# Keyword arguments for the trainer and the model respectively.
task_train_kwargs = dict(max_epochs=5, gpus=1)
task_model_kwargs = dict(learning_rate=1e-3)

In [None]:
# Build the data module for the test run from the settings defined above.
dm = TextClassificationDataModule(
    model_name_or_path=embedding_name,
    dataset_name=dataset_name,
    text_fields=input_column_name,
    target_field=target_column_name,
    load_dataset_kwargs=load_dataset_kwargs,
)
# NOTE(review): `initalize` looks like a misspelling of `initialize`; it is used
# consistently in this notebook -- confirm it is the name the library exposes.
dm.initalize()

In [None]:
# Instantiate the classifier; `input_dim` and `num_labels` are read from the
# data module, the remaining model options come from `task_model_kwargs`.
model = AutoTransformerForSequenceClassification(
    model_name_or_path=embedding_name,
    input_dim=dm.input_dim,
    num_labels=dm.num_labels,
    **task_model_kwargs,
)
# Trainer configured from `task_train_kwargs`, writing under `output_path`.
trainer = pl.Trainer(default_root_dir=output_path, **task_train_kwargs)
# Evaluator applied to the prediction dictionary returned by `model.predict`.
evaluator = TextClassificationEvaluator()

In [None]:
# Prepare the "fit" stage of the data module explicitly, then train.
dm.setup("fit")
trainer.fit(model, dm)
# Run the test loop on the data module (no model is passed explicitly).
trainer.test(datamodule=dm)
# Predict on the test split and score the predictions with the evaluator.
model_result = model.predict(dataloader=dm.test_dataloader())
metrics = evaluator.evaluate(model_result)

## Define config space

In [None]:
@dataclass
class ConfigSpace(BaseConfigSpace):
    """Hyperparameter search space for the sequence classification pipeline.

    Every field is a ``SearchableParameter``; ``parse_parameters`` splits a flat
    dictionary of sampled values into data-module, model and trainer arguments.
    """

    # --- training length / data loading ---
    max_epochs: Parameter = SearchableParameter(
        name="max_epochs", type="categorical", choices=[20, 25, 30]
    )
    # The field is called `mini_batch_size` but the sampled key is "batch_size".
    mini_batch_size: Parameter = SearchableParameter(
        name="batch_size", type="categorical", choices=[8, 16, 32, 48, 64]
    )
    max_seq_length: Parameter = SearchableParameter(
        name="max_seq_length", type="categorical", choices=[128, 256]
    )

    # --- optimisation ---
    optimizer: Parameter = SearchableParameter(
        name="optimizer", type="categorical", choices=["Adam", "AdamW"]
    )
    use_scheduler: Parameter = SearchableParameter(
        name="use_scheduler", type="categorical", choices=[False, True]
    )
    warmup_steps: Parameter = SearchableParameter(
        name="warmup_steps",
        type="int_uniform",
        low=0,
        high=200,
        step=10,
    )
    learning_rate: Parameter = SearchableParameter(
        name="learning_rate",
        type="uniform",
        low=1e-5,
        high=1e-3,
    )
    adam_epsilon: Parameter = SearchableParameter(
        name="adam_epsilon",
        type="uniform",
        low=1e-9,
        high=1e-6,
    )
    weight_decay: Parameter = SearchableParameter(
        name="weight_decay",
        type="uniform",
        low=0.0,
        high=1e-2,
    )

    # --- model freezing ---
    unfreeze_from: Parameter = SearchableParameter(
        name="unfreeze_from", type="categorical", choices=[-1, 4, 7, 9, 11, None]
    )

    @staticmethod
    def parse_parameters(parameters: Dict[str, PrimitiveTypes]) -> SampledParameters:
        """Split sampled parameters into per-component keyword arguments.

        Args:
            parameters: Flat mapping of sampled parameter names to values; the
                three key groups below are popped from it in order.

        Returns:
            ``(dl_model_kwargs, task_model_kwargs, task_trainer_kwargs)``. The
            batch size is copied into the model arguments as
            ``train_batch_size`` and ``eval_batch_size``. The trainer arguments
            are always returned empty.
        """
        key_groups = [
            # data module
            {"batch_size", "max_seq_length"},
            # model
            {
                "learning_rate",
                "unfreeze_from",
                "optimizer",
                "use_scheduler",
                "warmup_steps",
                "adam_epsilon",
                "weight_decay",
            },
            # trainer
            {"max_epochs"},
        ]
        dl_model_kwargs, task_model_kwargs, _sampled_trainer_kwargs = [
            BaseConfigSpace._pop_parameters(parameters=parameters, parameters_keys=keys)
            for keys in key_groups
        ]

        # NOTE(review): the sampled trainer arguments ("max_epochs") are popped
        # but then discarded, so the sampled value never reaches the trainer.
        task_trainer_kwargs = {}

        batch_size = dl_model_kwargs["batch_size"]
        task_model_kwargs["train_batch_size"] = batch_size
        task_model_kwargs["eval_batch_size"] = batch_size

        return dl_model_kwargs, task_model_kwargs, task_trainer_kwargs

In [None]:
def get_best_paramaters(study):
    """Return the best trial's parameters merged with its user attributes.

    A new dictionary is built from ``study.best_params``; entries of
    ``study.best_trial.user_attrs`` are added on top and win on key clashes.
    """
    return {**study.best_params, **study.best_trial.user_attrs}


def objective(trial, cs, dataset) -> float:
    """Train and evaluate one sampled configuration for an optuna trial.

    Args:
        trial: Trial object handed to ``cs.sample_parameters``.
        cs: Config space providing ``sample_parameters`` and ``parse_parameters``.
        dataset: Not used by this function; kept so existing callers still work.

    Returns:
        The macro-averaged F1 score measured on the validation split.

    Raises:
        Any exception raised while training, predicting or evaluating; GPU memory
        held by the trial is released before it propagates.
    """
    parameters = cs.sample_parameters(trial=trial)
    dl_model_kwargs, task_model_kwargs, task_trainer_kwargs = cs.parse_parameters(
        parameters
    )
    print("params", dl_model_kwargs, task_model_kwargs, task_trainer_kwargs)

    dm = TextClassificationDataModule(
        model_name_or_path="allegro/herbert-base-cased",
        dataset_name="clarin-pl/polemo2-official",
        text_fields=["text"],
        target_field="target",
        load_dataset_kwargs={
            "train_domains": ["hotels", "medicine"],
            "dev_domains": ["hotels", "medicine"],
            "test_domains": ["hotels", "medicine"],
            "text_cfg": "text",
        },
        ignore_test_split=True,
        **dl_model_kwargs,
    )
    dm.initalize()

    model = AutoTransformerForSequenceClassification(
        model_name_or_path="allegro/herbert-base-cased",
        input_dim=dm.input_dim,
        num_labels=dm.num_labels,
        **task_model_kwargs,
    )

    # Defaults first, sampled values second: a sampled "max_epochs" overrides the
    # default instead of raising a duplicate-keyword TypeError.
    trainer_kwargs = {"max_epochs": 20, **task_trainer_kwargs}
    trainer = pl.Trainer(
        gpus=1,
        callbacks=[
            EarlyStopping(
                monitor="val/F1",
                verbose=True,
                patience=5,
                min_delta=0.01,
                mode="max",
            )
        ],
        **trainer_kwargs,
    )

    try:
        trainer.fit(model, dm)
        model_result = model.predict(dataloader=dm.val_dataloader())
        # A local evaluator: there is no `self` inside a plain function.
        evaluator = TextClassificationEvaluator()
        metrics = evaluator.evaluate(model_result)
        return metrics["f1__average_macro"]["f1"]
    finally:
        # Drop the references and free cached GPU memory on success and failure.
        del model
        del dm
        del trainer
        torch.cuda.empty_cache()


def run(n_trials):
    """Run an optuna study and return its trials table and best parameters.

    Args:
        n_trials: Number of trials passed to ``study.optimize``.

    Returns:
        ``(study.trials_dataframe(), get_best_paramaters(study))``.
    """
    load_kwargs = {
        "train_domains": ["hotels", "medicine"],
        "dev_domains": ["hotels", "medicine"],
        "test_domains": ["hotels", "medicine"],
        "text_cfg": "text",
    }
    dataset = datasets.load_dataset("clarin-pl/polemo2-official", **load_kwargs)

    cs = ConfigSpace()
    tpe_sampler = optuna.samplers.TPESampler()
    median_pruner = optuna.pruners.MedianPruner()
    study = optuna.create_study(
        direction="maximize", sampler=tpe_sampler, pruner=median_pruner
    )

    def _objective_for_trial(trial):
        # Bind the config space and dataset for this study.
        return objective(trial, cs, dataset)

    # NOTE: `catch=(Exception,)` marks a raising trial as failed and continues,
    # so errors inside `objective` do not stop the study.
    study.optimize(
        _objective_for_trial,
        n_trials=n_trials,
        show_progress_bar=True,
        catch=(Exception,),
    )

    trials_table = study.trials_dataframe()
    return trials_table, get_best_paramaters(study)

## Run hyperparameter search

In [None]:
# Release cached CUDA memory before starting the search.
torch.cuda.empty_cache()

In [None]:
# Run the search with 100 trials; `df` is the trials table and `metadata` the
# merged best parameters returned by `run`.
df, metadata = run(n_trials=100)

In [None]:
# Display the best parameters found by the search.
metadata

### Allegro default config

In [None]:
# Train once with a fixed configuration and score it on the validation split.
pl.seed_everything(42)

dm = TextClassificationDataModule(
    model_name_or_path="allegro/herbert-base-cased",
    dataset_name="clarin-pl/polemo2-official",
    text_fields=["text"],
    target_field="target",
    load_dataset_kwargs={
        "train_domains": ["hotels", "medicine"],
        "dev_domains": ["hotels", "medicine"],
        "test_domains": ["hotels", "medicine"],
        "text_cfg": "text",
    },
    ignore_test_split=False,
    batch_size=16,
    max_seq_length=256,
)
dm.initalize()

model = AutoTransformerForSequenceClassification(
    model_name_or_path="allegro/herbert-base-cased",
    input_dim=dm.input_dim,
    num_labels=dm.num_labels,
    learning_rate=2e-5,
    unfreeze_from=-1,
    weight_decay=0.0,
    adam_epsilon=1e-8,
    use_scheduler=True,
    warmup_steps=100,
    # Keep in sync with the data module's batch_size above: the model uses
    # train_batch_size to estimate the number of scheduler steps.
    train_batch_size=16,
    eval_batch_size=16,
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=4,
    accumulate_grad_batches=2,
)

trainer.fit(model, dm)
evaluator = TextClassificationEvaluator()
model_result = model.predict(dataloader=dm.val_dataloader())
# Use the evaluator created above; there is no `self` at notebook top level.
metrics = evaluator.evaluate(model_result)

# NOTE(review): confirm that `reset_seed` is exposed on the top-level
# `pytorch_lightning` namespace in the installed version.
pl.reset_seed()

In [None]:
# Display the evaluation metrics computed in the previous cell.
metrics

## klejbenchmark-baselines data

**To run the following cells, move this notebook (or the `klejbenchmark_baselines` folder) so that the package is importable, or change the imports below accordingly!**

In [None]:
import os
import random

from klejbenchmark_baselines.config import Config
from klejbenchmark_baselines.dataset import Datasets
from klejbenchmark_baselines.model import KlejTransformer
from klejbenchmark_baselines.task import TASKS
from klejbenchmark_baselines.trainer import TrainerWithPredictor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

In [None]:
# Arguments for the klejbenchmark-baselines configuration.
# Entries marked "input ..." have to be filled in before running.
args_dict = dict(
    # task and paths
    task_name="polemo2.0-in",
    run_id=None,  # input run_id
    task_path=None,  # input task_path
    predict_path=None,  # input predict_path
    logger_path="output/tb/",
    checkpoint_path="output/checkpoints/",
    # tokenizer and model
    tokenizer_name_or_path="allegro/herbert-base-cased",
    max_seq_length=256,
    do_lower_case=None,
    model_name_or_path="allegro/herbert-base-cased",
    # optimisation
    learning_rate=None,
    adam_epsilon=None,
    warmup_steps=None,
    batch_size=16,
    gradient_accumulation_steps=2,
    num_train_epochs=None,
    weight_decay=None,
    max_grad_norm=None,
    # runtime
    seed=None,
    num_workers=None,
    num_gpu=1,
)

In [None]:
def set_seed(seed: int, num_gpu: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: Seed value handed to every generator.
        num_gpu: When greater than zero, all CUDA generators are seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if num_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)

In [None]:
# Resolve the configuration, the task selected by `task_name`, and its data.
config = Config.from_argparse(args_dict)
task = TASKS[args_dict["task_name"]](config)
# NOTE(review): this rebinds `datasets`, shadowing the `datasets` module
# imported at the top of the notebook; `run()` will fail if it is called after
# this cell has been executed.
datasets = Datasets(task)

# Seed before building the model.
set_seed(config.seed, config.num_gpu)
model = KlejTransformer(task, datasets)

# train
logger = TensorBoardLogger(
    save_dir=config.logger_path, name=config.run_id, version=config.task_name
)

trainer_kwargs = dict(
    weights_summary=None,
    logger=logger,
    accumulate_grad_batches=config.gradient_accumulation_steps,
    gradient_clip_val=config.max_grad_norm,
    max_epochs=config.num_train_epochs,
    gpus=config.num_gpu,
)
# The "ddp" backend is only requested when more than one GPU is configured.
if config.num_gpu > 1:
    trainer_kwargs["distributed_backend"] = "ddp"

trainer = TrainerWithPredictor(**trainer_kwargs)