In [None]:
from typing import Any, Dict, List, Optional, TypeVar, Union

import datasets
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BatchEncoding

from embeddings.utils.loggers import get_logger

# Placeholder type variable; only used to parameterize the DataLoader return annotations.
HuggingFaceDataset = TypeVar("HuggingFaceDataset")

# Module-level logger (used when output_dim is queried before the datamodule is initialized).
_logger = get_logger(__name__)


class TextClassificationDataModule(pl.LightningDataModule):
    """Lightning data module that loads a HuggingFace dataset and tokenizes it for classification.

    The dataset is loaded lazily by :meth:`initalize` and tokenized once, the first time
    :meth:`setup` runs. Tokenized splits are exposed as torch tensors restricted to the
    columns listed in :attr:`loader_columns`.
    """

    # Columns kept (when present) after tokenization and handed to the DataLoaders.
    loader_columns = [
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    ]

    def __init__(
        self,
        model_name_or_path: str,
        dataset_name: str,
        text_fields: Union[str, List[str]],
        target_field: str,
        max_seq_length: int = 128,
        batch_size: int = 128,
        load_dataset_kwargs: Optional[Dict[str, Any]] = None,
        ignore_test_split: bool = True,
        **kwargs: Any,
    ):
        """
        Args:
            model_name_or_path: checkpoint passed to ``AutoTokenizer.from_pretrained``.
            dataset_name: dataset identifier passed to ``datasets.load_dataset``.
            text_fields: one column name, or two column names for text-pair inputs.
            target_field: label column; copied to ``labels`` and removed during tokenization.
            max_seq_length: truncation / padding length used by the tokenizer.
            batch_size: batch size used for both the training and evaluation loaders.
            load_dataset_kwargs: extra keyword arguments for ``datasets.load_dataset``.
            ignore_test_split: if True, the ``test`` split is neither tokenized nor served.
            **kwargs: accepted for compatibility and ignored.
        """
        # Initialise the LightningDataModule base-class state before setting our own attributes.
        super().__init__()
        if isinstance(text_fields, str):
            text_fields = [text_fields]
        assert 1 <= len(text_fields) <= 2, "text_fields must contain one or two column names"
        self.model_name_or_path = model_name_or_path
        self.dataset_name = dataset_name
        self.text_fields = text_fields
        self.target_field = target_field
        self.max_seq_length = max_seq_length
        self.train_batch_size = batch_size
        self.eval_batch_size = batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
        self.load_dataset_kwargs = {} if load_dataset_kwargs is None else load_dataset_kwargs
        self.ignore_test_split = ignore_test_split
        self.dataset = None
        self.initialized = False
        # Tokenization removes the target column, so it must run at most once per split.
        self._features_converted = False

    @property
    def input_dim(self) -> int:
        # TODO: hardcoded for now because text pairs are encoded together
        return 768

    @property
    def output_dim(self) -> Optional[int]:
        """Number of distinct labels, or None (with a warning) before :meth:`initalize` ran."""
        if not self.initialized:
            _logger.warning("Datamodule not initialized. Returning None.")
            return None
        return self.num_labels

    def setup(self, stage: Optional[str] = None) -> None:
        """Load the dataset if needed and tokenize its splits.

        Tokenization happens once, on the first call, whatever the stage (``None``, ``"fit"``,
        ``"test"``, ...). Later calls are no-ops, so repeated setup does not re-map splits
        whose target column was already removed.
        """
        if not self.initialized:
            self.initalize()
        if not self._features_converted:
            self._convert_splits()

    def _convert_splits(self) -> None:
        """Tokenize every split (skipping an ignored test split) and set torch formatting."""
        for split in list(self.dataset.keys()):
            if split == "test" and self.ignore_test_split:
                continue
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                remove_columns=[self.target_field],
            )
            self.columns = [
                c for c in self.dataset[split].column_names if c in self.loader_columns
            ]
            self.dataset[split].set_format(type="torch", columns=self.columns)
        self._features_converted = True

    def initalize(self) -> None:
        """Load the dataset and count the distinct labels in its ``train`` split."""
        self.dataset = datasets.load_dataset(self.dataset_name, **self.load_dataset_kwargs)
        # Dataset.unique runs in Arrow instead of iterating over every example in Python.
        self.num_labels = len(self.dataset["train"].unique(self.target_field))
        self.initialized = True

    def prepare_data(self) -> None:
        pass

    def train_dataloader(self) -> DataLoader[HuggingFaceDataset]:
        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size)

    def val_dataloader(self) -> Optional[DataLoader[HuggingFaceDataset]]:
        """Loader over the ``validation`` split, or None when the dataset has no such split."""
        if "validation" not in self.dataset:
            return None
        return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size)

    def test_dataloader(self) -> DataLoader[HuggingFaceDataset]:
        """Loader over the ``test`` split.

        Raises:
            AttributeError: if there is no test split or ``ignore_test_split`` is set.
        """
        if "test" in self.dataset and not self.ignore_test_split:
            return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size)
        raise AttributeError("Test dataset not available")

    def convert_to_features(
        self, example_batch: Dict[str, Any], indices: Optional[List[int]] = None
    ) -> BatchEncoding:
        """Tokenize a batch of examples and attach its labels under ``labels``.

        Args:
            example_batch: column name -> list of values (a batched ``datasets.map`` input).
            indices: unused; kept so the signature matches ``datasets.map(with_indices=...)``.
        """
        # Either encode single sentence or sentence pairs
        if len(self.text_fields) > 1:
            texts_or_text_pairs = list(
                zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])
            )
        else:
            texts_or_text_pairs = example_batch[self.text_fields[0]]

        # padding="max_length" is the non-deprecated spelling of pad_to_max_length=True.
        features = self.tokenizer.batch_encode_plus(
            texts_or_text_pairs,
            max_length=self.max_seq_length,
            padding="max_length",
            truncation=True,
        )

        # Copy the target column to ``labels`` so the batch can be passed to forward directly.
        features["labels"] = example_batch[self.target_field]

        return features

In [None]:
import abc
from typing import Any, Dict, List, Optional, Tuple

import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.optim import AdamW, Optimizer
from torchmetrics import F1, Accuracy, MetricCollection, Precision, Recall
from transformers import get_linear_schedule_with_warmup


class TextClassificationTransformer(pl.LightningModule, abc.ABC):
    """Base LightningModule for text classification with cross-entropy loss and torchmetrics.

    Subclasses implement ``forward`` returning logits that can be viewed as
    ``(-1, num_labels)`` (see :meth:`shared_step`).
    """

    def __init__(
        self,
        num_labels: int,
        metrics: Optional[MetricCollection] = None,
        learning_rate: float = 1e-4,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 64,
        eval_batch_size: int = 32,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            num_labels: number of target classes.
            metrics: metric collection to clone per stage; defaults to accuracy/precision/recall/F1.
            learning_rate, adam_epsilon, weight_decay: AdamW settings.
            warmup_steps: stored in hparams; NOTE(review): not used by configure_optimizers.
            train_batch_size, eval_batch_size: stored in hparams.
            **kwargs: stored by save_hyperparameters and otherwise ignored.
        """
        super().__init__()
        self.save_hyperparameters()

        if metrics is None:
            metrics = self.get_default_metrics(num_labels=num_labels)
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

    @staticmethod
    def get_default_metrics(num_labels: int) -> MetricCollection:
        """Accuracy/Precision/Recall/F1; macro-averaged when there are more than two classes."""
        if num_labels > 2:
            metrics = MetricCollection(
                [
                    Accuracy(num_classes=num_labels),
                    Precision(num_classes=num_labels, average="macro"),
                    Recall(num_classes=num_labels, average="macro"),
                    F1(num_classes=num_labels, average="macro"),
                ]
            )
        else:
            metrics = MetricCollection(
                [
                    Accuracy(num_classes=num_labels),
                    Precision(num_classes=num_labels),
                    Recall(num_classes=num_labels),
                    F1(num_classes=num_labels),
                ]
            )
        return metrics

    def shared_step(self, **batch: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run forward on the batch and return ``(cross-entropy loss, logits)``."""
        logits = self.forward(**batch)
        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(logits.view(-1, self.hparams.num_labels), batch["labels"].view(-1))
        return loss, logits

    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        # args = (batch, batch_idx[, ...]); only the batch is needed.
        batch = args[0]
        loss, preds = self.shared_step(**batch)
        self.train_metrics(preds, batch["labels"])
        self.log("train/Loss", loss, on_step=True, on_epoch=True)
        return {"loss": loss}

    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        # args = (batch, batch_idx[, dataloader_idx]); only the batch is needed.
        batch = args[0]
        loss, preds = self.shared_step(**batch)
        self.val_metrics(preds, batch["labels"])
        self.log("val/Loss", loss, on_epoch=True)
        return None

    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        batch = args[0]
        loss, preds = self.shared_step(**batch)
        # A label of -1 marks unlabeled test data; such batches are not scored.
        if -1 not in batch["labels"]:
            self.test_metrics(preds, batch["labels"])
            self.log("test/Loss", loss, on_epoch=True)
        return None

    def training_epoch_end(self, outputs: List[Any]) -> None:
        self._aggregate_and_log_metrics(self.train_metrics)

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        self._aggregate_and_log_metrics(self.val_metrics, prog_bar=True)

    def test_epoch_end(self, outputs: List[Any]) -> None:
        self._aggregate_and_log_metrics(self.test_metrics)

    def _aggregate_and_log_metrics(
        self, metrics: MetricCollection, prog_bar: bool = False
    ) -> Dict[str, torch.Tensor]:
        """Compute the accumulated epoch metrics, reset them, log and return the values."""
        metric_values = metrics.compute()
        metrics.reset()
        self.log_dict(metric_values, prog_bar=prog_bar)
        return metric_values

    def setup(self, stage: Optional[str] = None) -> None:
        """On ``"fit"``, estimate the total number of optimizer steps into ``self.train_steps``.

        NOTE(review): ``train_steps`` is not read anywhere in this class; it is meant for an
        LR scheduler.
        """
        if stage != "fit":
            return
        # Get dataloader by calling it - train_dataloader() is called after setup() by default
        train_loader = self.trainer.datamodule.train_dataloader()

        # Use the loader's real batch size; the hparam may disagree with the datamodule.
        batch_size = train_loader.batch_size or self.hparams.train_batch_size
        gpus = self.trainer.gpus
        num_devices = len(gpus) if isinstance(gpus, (list, tuple)) else int(gpus or 0)
        samples_per_step = batch_size * max(1, num_devices)
        steps_per_epoch = (
            len(train_loader.dataset) // samples_per_step
        ) // self.trainer.accumulate_grad_batches
        # Total steps grow with the number of epochs (multiply, not divide).
        self.train_steps = steps_per_epoch * self.trainer.max_epochs

    def configure_optimizers(self) -> Tuple[List[Optimizer], List[Any]]:
        """Build AdamW with weight decay disabled for biases and LayerNorm weights.

        No LR scheduler is returned. TODO: wire ``get_linear_schedule_with_warmup`` with
        ``warmup_steps`` / ``train_steps`` if linear warmup and decay are wanted.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [
                    p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.learning_rate,
            eps=self.hparams.adam_epsilon,
        )

        return [optimizer], []

In [None]:
from collections import ChainMap
from typing import Any, Literal

from torch import nn
from transformers import AutoConfig, AutoModel

from embeddings.embedding.document_embedding import DocumentPoolEmbedding


class AutoTransformerForSequenceClassification(TextClassificationTransformer):
    """Classifier made of a pretrained ``AutoModel`` encoder, a pooling step and an MLP head."""

    def __init__(
        self,
        model_name_or_path: str,
        input_dim: int,
        num_labels: int,
        pool_strategy: Literal["cls", "mean", "max"] = "cls",
        dropout_rate: float = 0.5,
        freeze_transformer: bool = True,
        **kwargs: Any
    ):
        """
        Args:
            model_name_or_path: checkpoint for ``AutoConfig`` / ``AutoModel``.
            input_dim: size of the pooled embedding fed to the head.
            num_labels: number of target classes.
            pool_strategy: pooling passed to ``DocumentPoolEmbedding``.
            dropout_rate: dropout probability used in the classification head.
            freeze_transformer: if True, encoder parameters get ``requires_grad=False``.
            **kwargs: forwarded to ``TextClassificationTransformer``. A ``dropout`` entry is
                accepted as an alias for ``dropout_rate``.
        """
        # ConfigSpace and the notebook cells pass the rate as ``dropout``. Without this alias it
        # would be swallowed by **kwargs and the head would silently keep the default rate.
        dropout_alias = kwargs.pop("dropout", None)
        if dropout_alias is not None:
            dropout_rate = dropout_alias
        super().__init__(num_labels=num_labels, **kwargs)
        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        self.model = AutoModel.from_pretrained(model_name_or_path, config=self.config)
        self.doc_embedder = DocumentPoolEmbedding(strategy=pool_strategy)
        if freeze_transformer:
            self.freeze_transformer()

        self.layers = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(input_dim, input_dim),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
            nn.Linear(input_dim, num_labels),
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Return classification logits for a tokenized batch.

        The batch is given either as keyword arguments or as positional dicts. Positional
        dicts are merged with ChainMap, so earlier dicts win on key clashes. Any ``labels``
        entry is dropped before the encoder is called.
        """
        assert not (args and kwargs)
        assert args or kwargs
        inputs = kwargs if kwargs else args
        if isinstance(inputs, tuple):
            inputs = dict(ChainMap(*inputs))
        inputs.pop("labels", None)

        outputs = self.model(**inputs)
        pooled_output = self.doc_embedder(outputs.last_hidden_state)
        return self.layers(pooled_output)

    def unfreeze_transformer(self, unfreeze_from: int = -1) -> None:
        """Re-enable gradients for the encoder.

        Args:
            unfreeze_from: -1 unfreezes every base-model parameter. Otherwise, parameters
                before the first ``encoder.layer.<N>`` with ``N >= unfreeze_from`` stay frozen,
                and that parameter and every one after it (in ``named_parameters`` order) are
                unfrozen. NOTE(review): assumes BERT-style ``encoder.layer.<N>.*`` names —
                TODO confirm for other architectures.
        """
        if unfreeze_from == -1:
            for param in self.model.base_model.parameters():
                param.requires_grad = True
        else:
            requires_grad = False
            for name, param in self.model.base_model.named_parameters():
                if not requires_grad:
                    if name.startswith("encoder.layer"):
                        no_layer = int(name.split(".")[2])
                        if no_layer >= unfreeze_from:
                            requires_grad = True
                param.requires_grad = requires_grad

    def freeze_transformer(self) -> None:
        """Disable gradients for every base-model (encoder) parameter."""
        for param in self.model.base_model.parameters():
            param.requires_grad = False

In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

In [None]:
import abc
import dataclasses
import functools
from abc import ABC
from dataclasses import InitVar, dataclass, field
from typing import Any, Dict, Final, List, Set, Tuple, Type, TypeVar, Union

import optuna

from embeddings.embedding.auto_flair import AutoFlairWordEmbedding
from embeddings.embedding.flair_embedding import FlairTransformerEmbedding
from embeddings.embedding.static.embedding import StaticEmbedding
from embeddings.hyperparameter_search.parameters import (
    ConstantParameter,
    SearchableParameter,
)
from embeddings.utils.utils import PrimitiveTypes

# Type aliases used by the config-space definitions below.
# A config entry is either a SearchableParameter or a ConstantParameter.
Parameter = Union[SearchableParameter, ConstantParameter]
ParsedParameters = TypeVar("ParsedParameters")
# Parameter name -> primitive value, or a one-level nested dict of primitive values.
SampledParameters = Dict[str, Union[PrimitiveTypes, Dict[str, PrimitiveTypes]]]
from embeddings.hyperparameter_search.configspace import BaseConfigSpace

@dataclass
class ConfigSpace(BaseConfigSpace):
    """Optuna search space for fine-tuning the transformer text classifier."""

    max_epochs: Parameter = SearchableParameter(
        name="max_epochs",
        type="categorical",
        choices=[1, 2, 5, 10, 25, 30],
    )
    mini_batch_size: Parameter = SearchableParameter(
        name="batch_size",
        type="categorical",
        choices=[8, 16, 32, 48, 64],
    )
    learning_rate: Parameter = SearchableParameter(
        name="learning_rate",
        type="categorical",
        choices=[1e-5, 1e-4, 1e-3, 1e-2],
    )
    pool_strategy: Parameter = SearchableParameter(
        name="pool_strategy",
        type="categorical",
        choices=["cls", "mean", "max"],
    )
    dropout: Parameter = SearchableParameter(
        name="dropout", type="discrete_uniform", low=0.0, high=0.5, q=0.05
    )
    freeze_transformer: Parameter = SearchableParameter(
        name="freeze_transformer",
        type="categorical",
        choices=[False, True],
    )

    @staticmethod
    def parse_parameters(
        parameters: Dict[str, PrimitiveTypes]
    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        """Split sampled parameters into kwargs for the datamodule, the model and the trainer.

        Returns:
            ``(dl_model_kwargs, task_model_kwargs, task_trainer_kwargs)``. The model kwargs also
            receive ``train_batch_size``/``eval_batch_size`` copied from the datamodule's
            ``batch_size``. The sampled ``dropout`` is renamed to ``dropout_rate``, which is
            the model constructor's parameter name.
        """
        dl_model_keys: Final = {
            "batch_size",
        }
        task_model_keys: Final = {
            "learning_rate",
            "pool_strategy",
            "dropout",
            "freeze_transformer",
        }
        task_trainer_keys: Final = {
            "max_epochs",
        }
        dl_model_kwargs = BaseConfigSpace._pop_parameters(
            parameters=parameters, parameters_keys=dl_model_keys
        )
        task_model_kwargs = BaseConfigSpace._pop_parameters(
            parameters=parameters, parameters_keys=task_model_keys
        )
        task_trainer_kwargs = BaseConfigSpace._pop_parameters(
            parameters=parameters, parameters_keys=task_trainer_keys
        )
        if "dropout" in task_model_kwargs:
            task_model_kwargs["dropout_rate"] = task_model_kwargs.pop("dropout")
        task_model_kwargs["train_batch_size"] = dl_model_kwargs["batch_size"]
        task_model_kwargs["eval_batch_size"] = dl_model_kwargs["batch_size"]

        return dl_model_kwargs, task_model_kwargs, task_trainer_kwargs

In [None]:
import torch
# Release cached GPU memory that PyTorch no longer uses (presumably left over from
# earlier runs in this kernel) before starting the search.
torch.cuda.empty_cache()

In [None]:
from embeddings.evaluator.text_classification_evaluator import (
    TextClassificationEvaluator,
)


def get_best_paramaters(study):
    """Return the best trial's sampled parameters merged with its constant parameters.

    Constant parameters come from the best trial's ``user_attrs``. They override sampled
    values on key clashes. A new dict is returned; the study's data is left untouched.
    """
    merged = dict(study.best_params)
    merged.update(study.best_trial.user_attrs)
    return merged


def objective(trial, cs, dataset) -> float:
    """Optuna objective: train one sampled configuration and return validation macro-F1.

    Args:
        trial: optuna trial used to sample hyperparameters.
        cs: config space providing ``sample_parameters`` and ``parse_parameters``.
        dataset: NOTE(review): currently unused. The datamodule reloads the dataset itself,
            because its setup tokenizes splits in place and must not share a DatasetDict.

    Returns:
        The macro-averaged F1 on the validation split.
    """
    parameters = cs.sample_parameters(trial=trial)
    dl_model_kwargs, task_model_kwargs, task_trainer_kwargs = cs.parse_parameters(
        parameters
    )
    print("params", dl_model_kwargs, task_model_kwargs, task_trainer_kwargs)

    dm = model = trainer = None
    try:
        dm = TextClassificationDataModule(
            model_name_or_path="allegro/herbert-base-cased",
            dataset_name="clarin-pl/polemo2-official",
            text_fields=["text"],
            target_field="target",
            load_dataset_kwargs={
                "train_domains": ["hotels", "medicine"],
                "dev_domains": ["hotels", "medicine"],
                "test_domains": ["hotels", "medicine"],
                "text_cfg": "text",
            },
            ignore_test_split=True,
            **dl_model_kwargs,
        )
        dm.initalize()

        model = AutoTransformerForSequenceClassification(
            model_name_or_path="allegro/herbert-base-cased",
            input_dim=dm.input_dim,
            num_labels=dm.num_labels,
            **task_model_kwargs,
        )

        trainer = pl.Trainer(
            gpus=1,
            callbacks=[EarlyStopping(monitor="val/Loss", verbose=True, patience=5)],
            **task_trainer_kwargs,
        )

        trainer.fit(model, dm)

        evaluator = TextClassificationEvaluator()
        predictions = trainer.predict(dataloaders=dm.val_dataloader(), return_predictions=True)
        y_pred = torch.argmax(torch.cat(predictions), dim=1).cpu().numpy()
        y_true = torch.cat([batch["labels"] for batch in dm.val_dataloader()]).cpu().numpy()
        metrics = evaluator.evaluate({"y_pred": y_pred, "y_true": y_true})
    finally:
        # Always drop references and free GPU memory. study.optimize(catch=...) keeps running
        # after a failed trial, so a leak here would pile up across trials.
        del dm, model, trainer
        torch.cuda.empty_cache()
    return metrics["f1__average_macro"]["f1"]


def run(n_trials):
    """Run a TPE hyperparameter search and return ``(trials dataframe, best parameters)``.

    Trials that raise are recorded as failed and the search continues
    (``catch=(Exception,)``). NOTE(review): the MedianPruner has no effect because
    ``objective`` never calls ``trial.report``.
    """
    dataset = datasets.load_dataset(
        "clarin-pl/polemo2-official",
        train_domains=["hotels", "medicine"],
        dev_domains=["hotels", "medicine"],
        test_domains=["hotels", "medicine"],
        text_cfg="text",
    )

    cs = ConfigSpace()
    study = optuna.create_study(
        direction="maximize",
        sampler=optuna.samplers.TPESampler(),
        pruner=optuna.pruners.MedianPruner(),
    )

    study.optimize(
        lambda trial: objective(trial, cs, dataset),
        n_trials=n_trials,
        show_progress_bar=True,
        catch=(Exception,),
    )

    return study.trials_dataframe(), get_best_paramaters(study)

In [None]:
# Run the hyperparameter search: `df` is the study's trials dataframe, `metadata` the best parameters.
df, metadata = run(n_trials=100)

In [None]:
# Best sampled parameters merged with the best trial's constant parameters.
metadata

In [None]:
# Retrain with the selected hyperparameters and evaluate on the held-out test split.
dm = model = trainer = None
try:
    dm = TextClassificationDataModule(
        model_name_or_path="allegro/herbert-base-cased",
        dataset_name="clarin-pl/polemo2-official",
        text_fields=["text"],
        target_field="target",
        load_dataset_kwargs={
            "train_domains": ["hotels", "medicine"],
            "dev_domains": ["hotels", "medicine"],
            "test_domains": ["hotels", "medicine"],
            "text_cfg": "text",
        },
        ignore_test_split=False,
        batch_size=8,
    )
    dm.initalize()

    model = AutoTransformerForSequenceClassification(
        model_name_or_path="allegro/herbert-base-cased",
        input_dim=dm.input_dim,
        num_labels=dm.num_labels,
        learning_rate=1e-05,
        pool_strategy="cls",
        # Passed via the constructor's own parameter name so the rate is actually applied.
        dropout_rate=0.35,
        freeze_transformer=False,
        # Keep the stored hparams consistent with the datamodule's batch size.
        train_batch_size=8,
        eval_batch_size=8,
    )

    trainer = pl.Trainer(
        gpus=1,
        callbacks=[EarlyStopping(monitor="val/Loss", verbose=True, patience=5)],
        max_epochs=10,
    )

    trainer.fit(model, dm)

    evaluator = TextClassificationEvaluator()
    predictions = trainer.predict(dataloaders=dm.test_dataloader(), return_predictions=True)
    predictions = torch.argmax(torch.cat(predictions), dim=1).cpu().numpy()
    ground_truth = torch.cat([batch["labels"] for batch in dm.test_dataloader()]).cpu().numpy()
    metrics = evaluator.evaluate({"y_pred": predictions, "y_true": ground_truth})
finally:
    # Free GPU memory whether training/evaluation succeeded or raised.
    del dm, model, trainer
    torch.cuda.empty_cache()
metrics

In [None]:
# Run a comparable configuration through an end-to-end pipeline object.
# NOTE(review): `TorchClassificationPipeline` is not imported or defined anywhere in this
# notebook, so this cell raises NameError on a fresh kernel. TODO: add its import from the
# `embeddings` package (module path not visible here).
pipeline = TorchClassificationPipeline(
    embedding_name="allegro/herbert-base-cased",
    dataset_name="clarin-pl/polemo2-official",
    input_column_name=["text"],
    target_column_name="target",
    load_dataset_kwargs={
        "train_domains": ["hotels", "medicine"],
        "dev_domains": ["hotels", "medicine"],
        "test_domains": ["hotels", "medicine"],
        "text_cfg": "text",
    },
    task_train_kwargs={"max_epochs": 10, "gpus": 1},
    task_model_kwargs={"pool_strategy": "cls", "learning_rate": 5e-4}
)
result = pipeline.run()
print(result)