In [None]:
import logging
import os
import sys

import datasets
import h5py
import numpy as np
import pandas as pd
import torch
from bokeh.io import output_notebook, show
from bokeh.palettes import Bright5 as palette
from bokeh.plotting import figure
from lightning import pytorch as pl
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Send log records to stdout so they show up in the notebook cell output.
LOG_FORMAT = "%(asctime)s [%(name)s] - %(levelname)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, stream=sys.stdout)

# Fixed seed (also forwarded to dataloader workers) for repeatable runs.
pl.seed_everything(101588, workers=True)
# Render bokeh figures inline in the notebook.
output_notebook()


class ZipDataset:
    """Map-style dataset that indexes several NumPy arrays in lockstep.

    Each keyword argument is converted with ``torch.from_numpy`` and wrapped
    in its own ``TensorDataset``. Indexing returns a dict with one entry per
    keyword; each value is the 1-tuple produced by ``TensorDataset``.
    """

    def __init__(self, **tensors):
        self.datasets = {}
        for name, array in tensors.items():
            as_tensor = torch.from_numpy(array)
            self.datasets[name] = torch.utils.data.TensorDataset(as_tensor)

    def __len__(self):
        # The length is taken from the first wrapped dataset only.
        members = tuple(self.datasets.values())
        return len(members[0])

    def __getitem__(self, i):
        item = {}
        for name, wrapped in self.datasets.items():
            item[name] = wrapped[i]
        return item


class Dataset(pl.LightningDataModule):
    """Data module that tokenizes ``yelp_review_full`` and caches it as HDF5.

    Args:
        model_name: Hugging Face model id whose tokenizer is used. It is also
            recorded in the HDF5 file so a cache built with a different
            tokenizer is rejected.
        dataset: Path of the HDF5 cache file.
        batch_size: Training batch size; prediction uses four times this.
        subset_size: If given, number of training rows sampled at random
            (without replacement) in ``setup``. Values at or above the split
            size select the whole split.
        force: If truthy, rebuild the cache even when the file exists.
    """

    def __init__(
        self,
        model_name: str,
        dataset: str,
        batch_size: int,
        subset_size: int | None = None,
        force: bool | None = None
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        # Filled by ``setup``: split name -> {feature name -> numpy array}.
        self.datasets = {}
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize(self, examples):
        """Tokenize ``examples`` with max-length padding and truncation."""
        return self.tokenizer(examples, padding="max_length", truncation=True)

    def prepare_data(self):
        """Build the HDF5 cache unless a compatible one already exists.

        Raises:
            ValueError: If the existing cache was tokenized with a different
                model and ``force`` is not set.
        """
        exists = os.path.exists(self.hparams.dataset)
        if exists and not self.hparams.force:
            with h5py.File(self.hparams.dataset, "r") as f:
                cached_model = f.attrs["model_name"]
            if cached_model != self.hparams.model_name:
                raise ValueError(
                    "Dataset {} was tokenized with model {}, "
                    "doesn't match specified model {}".format(
                        self.hparams.dataset,
                        cached_model,
                        self.hparams.model_name
                    )
                )
            logging.info("Dataset already exists and not forcing overwrite")
            return
        if exists:
            # Only claim an overwrite when there is actually a file to replace.
            logging.info("Dataset already exists but forcing overwrite")

        logging.info("Downloading and preprocessing dataset")
        dataset = datasets.load_dataset("yelp_review_full")
        dataset = dataset.map(lambda x: self.tokenize(x["text"]), batched=True)
        dataset.set_format("numpy")

        logging.info(f"Writing data to {self.hparams.dataset}")
        with h5py.File(self.hparams.dataset, "w") as f:
            f.attrs["model_name"] = self.hparams.model_name

            for split, dset in dataset.items():
                group = f.create_group(split)
                for feature in dset.features:
                    X = dset[feature]
                    if feature == "text":
                        # HDF5 attributes are meant for small metadata, so the
                        # raw reviews go into a variable-length string dataset
                        # (one entry per review) instead of one joined
                        # attribute.
                        group.create_dataset(
                            "text",
                            data=np.asarray(X, dtype=object),
                            dtype=h5py.string_dtype()
                        )
                    else:
                        # Chunk along the row axis; 2D features keep full rows
                        # together in each chunk.
                        chunks = (1000,)
                        if X.ndim == 2:
                            chunks = chunks + (X.shape[-1],)
                        group.create_dataset(feature, data=X, chunks=chunks)
                    f.flush()
        del dataset

    def setup(self, stage):
        """Load the split for ``stage`` into memory and rescale its labels.

        ``"fit"`` loads the ``train`` split (optionally subsampled); every
        other stage loads ``test``.
        """
        split = "test" if stage != "fit" else "train"

        logging.info(f"Loading {split} data")
        with h5py.File(self.hparams.dataset, "r") as f:
            # Raw text is archival only; it cannot be turned into a tensor.
            features = {k: v for k, v in f[split].items() if k != "text"}

            idx = slice(None, None)
            subset_size = self.hparams.subset_size
            if split == "train" and subset_size is not None:
                size = len(next(iter(features.values())))
                # Sampling without replacement needs subset_size <= size;
                # larger requests simply use the whole split.
                if subset_size < size:
                    idx = np.random.choice(
                        size, size=subset_size, replace=False
                    )
                    # h5py index-list selection requires increasing order.
                    idx = np.sort(idx)

            self.datasets[split] = {k: v[idx] for k, v in features.items()}

        # Rescale the integer class labels linearly; labels 0..4 map to
        # [-1, 1]. The stored key becomes "labels", which Model.forward pops.
        labels = self.datasets[split].pop("label")
        labels = labels / 2 - 1
        self.datasets[split]["labels"] = labels

    def on_after_batch_transfer(self, batch, _):
        """Unwrap the 1-tuples from ``ZipDataset`` and cast features to int64.

        The ``labels`` entry keeps its dtype; every other entry is cast.
        """
        converted = {}
        for k, v in batch.items():
            # Each value is the single-element tuple from TensorDataset.
            v = v[0]
            if k != "labels":
                v = v.type(torch.int64)
            converted[k] = v
        return converted

    def train_dataloader(self):
        """Return a shuffled loader over the loaded ``train`` split."""
        dataset = ZipDataset(**self.datasets["train"])
        return torch.utils.data.DataLoader(
            dataset,
            shuffle=True,
            batch_size=self.hparams.batch_size,
            pin_memory=True
        )

    def predict_dataloader(self):
        """Return an ordered loader over the loaded ``test`` split.

        Uses four times the training batch size.
        """
        dataset = ZipDataset(**self.datasets["test"])
        return torch.utils.data.DataLoader(
            dataset,
            shuffle=False,
            batch_size=4 * self.hparams.batch_size,
            pin_memory=True
        )


class Model(pl.LightningModule):
    """Sequence-classification model with one output, trained as a regressor.

    Args:
        model_name: Hugging Face model id loaded with ``num_labels=1``.
        learning_rate: Learning rate passed to AdamW and used as the peak
            rate of the one-cycle schedule.
    """

    def __init__(self, model_name: str, learning_rate: float = 1e-3) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=1
        )
        self.loss = torch.nn.MSELoss()

    def forward(self, X):
        """Return ``(predictions, targets)`` for a batch dict.

        ``X`` must contain a ``"labels"`` entry; the remaining entries are
        passed to the wrapped model as keyword arguments.
        """
        # Work on a shallow copy so the caller's batch keeps its labels.
        inputs = dict(X)
        y = inputs.pop("labels").type(torch.float32)
        return self.model(**inputs).logits[:, 0], y

    def training_step(self, batch):
        """Compute and log the MSE between predictions and targets."""
        loss = self.loss(*self(batch))
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def configure_optimizers(self):
        """Return AdamW with a one-cycle schedule stepped per optimizer step."""
        params = self.model.parameters()
        optimizer = torch.optim.AdamW(params, self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            self.hparams.learning_rate,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.1
        )
        # total_steps counts optimizer steps, so the scheduler must advance
        # every step; Lightning's default interval is once per epoch.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def configure_callbacks(self) -> list[pl.Callback]:
        """Return a checkpoint callback that also saves the last checkpoint."""
        checkpoint = pl.callbacks.ModelCheckpoint(
            filename="weights.pt",
            save_last=True,
            auto_insert_metric_name=False,
        )
        return [checkpoint]

In [None]:
# Experiment configuration. The cache path gets its own name so that
# re-running this cell never feeds the data module back in as the path.
dataset_path = "dataset.hdf5"
model_name = "bert-base-cased"
batch_size = 16
learning_rate = 1e-3
max_epochs = 1
subset_size = 200000

logger = pl.loggers.CSVLogger("logs", name=model_name)
dataset = Dataset(model_name, dataset_path, batch_size, subset_size)
model = Model(model_name, learning_rate)
# Gradients are accumulated over 8 batches before each optimizer step.
trainer = pl.Trainer(
    max_epochs=max_epochs,
    devices=[0],
    precision="16-mixed",
    accumulate_grad_batches=8,
    logger=logger
)

In [None]:
# Train the model with the trainer and data module configured above.
trainer.fit(model, dataset)

In [None]:
# Each prediction batch is the (scores, targets) pair from Model.forward.
predictions = trainer.predict(model, dataset)

y_hat = torch.cat([scores for scores, _ in predictions]).cpu().numpy()
y = torch.cat([targets for _, targets in predictions]).cpu().numpy()

# Persist predictions next to the training metrics for the cells below.
df = pd.DataFrame({"predictions": y_hat, "labels": y})
results_path = os.path.join(logger.log_dir, "results.csv")
df.to_csv(results_path, index=False)

In [None]:
# Load the metrics file written by the CSV logger and display it.
metrics_path = os.path.join(logger.log_dir, "metrics.csv")
metrics = pd.read_csv(metrics_path)
metrics

In [None]:
# Per-step training loss, with a hover tooltip showing step and loss.
hover_fields = [("Step", "@step"), ("Loss", "@train_loss_step")]
loss_line_style = dict(
    line_color=palette[0],
    line_width=1.5,
    line_alpha=0.9,
)

p = figure(
    width=700,
    height=300,
    tools="",
    tooltips=hover_fields,
    x_axis_label="Step",
    y_axis_label="Train loss"
)
p.line(x="step", y="train_loss_step", source=metrics, **loss_line_style)
show(p)

In [None]:
# Summary statistics of predictions against the rescaled labels.
results = pd.read_csv(os.path.join(logger.log_dir, "results.csv"))
# Slope of the least-squares line regressing labels on predictions.
slope, _ = np.polyfit(results.predictions, results.labels, 1)
# Population variance (ddof=0) so it is normalized the same way as the
# mean squared residual below; R^2 is then exactly 1 - SS_res / SS_tot.
label_variance = results.labels.var(ddof=0)
residuals = ((results.labels - results.predictions) ** 2).mean()
r2 = 1 - residuals / label_variance

In [None]:
# Distribution of predicted scores, one curve per true label.
p = figure(
    title=rf"$$\text{{Slope: {slope:0.3f}, R}}^2\text{{: {r2:0.4f}}}$$",
    height=300,
    width=700,
    x_axis_label="Predicted Score",
    y_axis_label="Density"
)
p.title.text_font_style = "normal"
p.toolbar_location = None

for i, (label, group) in enumerate(results.groupby("labels")):
    hist, bins = np.histogram(group.predictions, bins=50)
    # Normalize counts so each curve integrates to one over its bins.
    hist = hist / hist.sum() / (bins[1] - bins[0])
    # Bin centers are the midpoints of adjacent edges (right + left) / 2.
    centers = (bins[1:] + bins[:-1]) / 2

    # Undo the `label / 2 - 1` rescaling applied when the data was loaded,
    # for both the legend entry and the x-axis values.
    label = (label + 1) * 2
    p.line(
        (centers + 1) * 2,
        hist,
        line_color=palette[i],
        line_width=2,
        line_alpha=0.8,
        legend_label=str(int(label))
    )
p.legend.title = "True Label"
show(p)