In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell execution and source edits are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Switch the working directory to the parent of the directory the notebook
# was started in. The starting directory is recorded on first execution so
# that re-running this cell does not climb one level higher each time.
# NOTE(review): presumably the parent is the repository root, which makes the
# local `embeddings` package importable by the imports below — confirm.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))

from pathlib import Path
from tempfile import TemporaryDirectory

import pytorch_lightning as pl
import torch
from embeddings.defaults import RESULTS_PATH
from embeddings.task.lightning_task.text_classification import TextClassificationTask
from embeddings.model.lightning_module.text_classification import (
    TextClassificationModule,
)
from embeddings.pipeline.lightning_classification import LightningClassificationPipeline
from embeddings.utils.utils import build_output_path, format_eval_result

### Train simple downsampled pipeline

In [None]:
# Model, dataset and column configuration for the demo run.
embedding_name_or_path = "hf-internal-testing/tiny-albert"
dataset_name = "clarin-pl/polemo2-official"
input_columns_name = "text"
target_column_name = "target"
# Outputs are written relative to the current working directory.
# Alternative: point this at `TemporaryDirectory().name` to keep the run's
# artifacts out of the working tree.
output_path = "."

# Every split is limited to the same two domains.
selected_domains = ["hotels", "medicine"]
load_dataset_kwargs = {
    "train_domains": list(selected_domains),
    "dev_domains": list(selected_domains),
    "test_domains": list(selected_domains),
    "text_cfg": "text",
}
datamodule_kwargs = {"max_seq_length": 64}
# Single CPU epoch with the deterministic flag set.
task_train_kwargs = {
    "max_epochs": 1,
    "devices": "auto",
    "accelerator": "cpu",
    "deterministic": True,
}

pipeline = LightningClassificationPipeline(
    embedding_name_or_path=embedding_name_or_path,
    dataset_name_or_path=dataset_name,
    input_column_name=input_columns_name,
    target_column_name=target_column_name,
    output_path=output_path,
    finetune_last_n_layers=0,
    load_dataset_kwargs=load_dataset_kwargs,
    datamodule_kwargs=datamodule_kwargs,
    task_train_kwargs=task_train_kwargs,
)
result = pipeline.run()

### Load the model from the checkpoint automatically generated by the Trainer

In [None]:
# Locate the checkpoint written by the Trainer during the run above.
# The run directory ("version_N") and the file name ("epoch=...-step=....ckpt")
# can differ between runs, so pick the most recently written checkpoint under
# the expected layout instead of hard-coding both.
checkpoint_candidates = sorted(
    Path(output_path).glob("lightning_logs/version_*/checkpoints/*.ckpt"),
    key=lambda candidate: candidate.stat().st_mtime,
)
if not checkpoint_candidates:
    raise FileNotFoundError(
        f"No checkpoint found under {Path(output_path) / 'lightning_logs'}; "
        "run the training cell first."
    )
ckpt_path = checkpoint_candidates[-1]
# Display the absolute location as the cell output.
ckpt_path.resolve()

In [None]:
# Restore the task wrapper from the checkpoint; no extra trainer or
# early-stopping options are passed (both dicts are empty).
restore_kwargs = {
    "checkpoint_path": ckpt_path,
    "output_path": output_path,
    "task_train_kwargs": {},
    "early_stopping_kwargs": {},
}
task_from_ckpt = TextClassificationTask.from_checkpoint(**restore_kwargs)

In [None]:
# Restore the bare module directly from the checkpoint file; the path is
# passed as a plain string and strict=True is forwarded to the loader.
checkpoint_file = str(ckpt_path)
model_from_ckpt = TextClassificationModule.load_from_checkpoint(
    checkpoint_file,
    strict=True,
)

### Validate model states (because of the warning)

In [None]:
# Grab the inner model of the freshly trained pipeline and of the task
# restored from the checkpoint, then snapshot their state dicts for comparison.
trained_network = pipeline.model.task.model.model
restored_network = task_from_ckpt.model.model
model_state_dict = trained_network.state_dict()
model_from_ckpt_state_dict = restored_network.state_dict()

In [None]:
# Both state dicts should expose exactly the same set of entry names
# (the boolean result is shown as the cell output).
set(model_state_dict) == set(model_from_ckpt_state_dict)

In [None]:
# Every tensor of the trained model must equal its counterpart restored
# from the checkpoint.
for entry_name, trained_tensor in model_state_dict.items():
    restored_tensor = model_from_ckpt_state_dict[entry_name]
    assert torch.equal(trained_tensor, restored_tensor)

### Manually save and load the model

In [None]:
# Save the trained model to an explicit checkpoint file, then restore both the
# bare module and the task wrapper from that same file.
manual_ckpt_path = Path("example.ckpt")
pipeline.model.task.trainer.save_checkpoint(str(manual_ckpt_path))
new_model = TextClassificationModule.load_from_checkpoint(
    checkpoint_path=str(manual_ckpt_path)
)
# Load the manually saved file (not the Trainer-generated `ckpt_path`) so this
# section actually round-trips the checkpoint written just above.
new_task_from_ckpt = TextClassificationTask.from_checkpoint(
    checkpoint_path=manual_ckpt_path,
    output_path=output_path,
    task_train_kwargs={},
    early_stopping_kwargs={},
)

### Use model from checkpoint for predictions

In [None]:
# Run the checkpoint-restored task over the pipeline's test dataloader.
pipeline_datamodule = pipeline.datamodule
test_dataloader = pipeline_datamodule.test_dataloader()
preds = task_from_ckpt.predict(test_dataloader)

In [None]:
# Cross-check: predict with a default-configured Trainer and the bare module
# restored from the checkpoint, using the same test dataloader.
trainer = pl.Trainer()
prediction_inputs = {"dataloaders": test_dataloader}
preds_other = trainer.predict(model_from_ckpt, **prediction_inputs)