## Converting to a PyTorch Lightning fitting routine

In [None]:
import numpy as np
import bisect
import tensorboard
import torch
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl

# from kws.eval import Metric


class Routine(pl.LightningModule):
    """Lightning train/validation/test routine wrapped around ``model``.

    Every ``*_step`` buffers a dict of per-batch metrics; the matching
    ``on_*_epoch_end`` hook averages the buffer, logs the means and then
    empties the buffer so the next epoch starts from scratch.

    The EM/F1 values are still hard-coded placeholders.
    """

    def __init__(self, model, cfg_fitting=None):
        """Store the wrapped model and the optional fitting configuration.

        Args:
            model: module invoked as ``model(x)`` by :meth:`forward`.
            cfg_fitting: optional configuration object. When given,
                :meth:`configure_optimizers` reads ``learning_rate``,
                ``lr_rate`` and ``lr_scheduler_epoch`` from it. When ``None``
                the optimizer runs at the fixed ``self.lr`` without a
                scheduler.
        """
        super().__init__()
        self.model = model
        self.lr = 1e-3
        self.cfg_fitting = cfg_fitting
        self.validation_step_outputs = []
        self.training_step_outputs = []
        self.test_step_outputs = []

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    @staticmethod
    def _epoch_mean(outputs, key):
        """Average ``step[key]`` over all buffered step dicts as a tensor."""
        return torch.tensor([float(step[key]) for step in outputs]).mean()

    def _log_epoch(self, prefix, results):
        """Log every entry of ``results`` under the name ``{prefix}_{key}``."""
        for name, value in results.items():
            self.log(
                f"{prefix}_{name}",
                value,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

    def training_step(self, batch, batch_idx):
        """Run one training batch and return the loss Lightning optimizes.

        Args:
            batch: mapping with the entries ``"x"`` and ``"y"``.
            batch_idx: index of the batch (unused).

        Returns:
            Dict with the tensor ``"loss"`` plus placeholder EM/F1 values.
        """
        x = batch["x"]
        y = batch["y"]
        y_hat = self(x).squeeze()

        # Lightning's automatic optimization needs a tensor it can call
        # backward() on; the same loss is already used in validation_step.
        # assumes y is a float tensor shaped like the squeezed logits — TODO confirm
        loss = F.binary_cross_entropy_with_logits(y_hat, y)

        # dummy metrics
        metrics_dict = {"loss": loss, "train_EM": 0.9, "train_F1": 0.9}

        # Buffer a detached copy so the autograd graph is not kept alive
        # until the end of the epoch.
        self.training_step_outputs.append({**metrics_dict, "loss": loss.detach()})
        return metrics_dict

    def on_train_epoch_end(self):
        """Log the epoch means of the training metrics and reset the buffer."""
        outputs = self.training_step_outputs
        results = {
            "loss": self._epoch_mean(outputs, "loss"),
            "F1": self._epoch_mean(outputs, "train_F1"),
            "EM": self._epoch_mean(outputs, "train_EM"),
        }
        self._log_epoch("train", results)
        outputs.clear()

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and buffer its metrics.

        Args:
            batch: mapping with the entries ``"x"`` and ``"y"``.
            batch_idx: index of the batch (unused).

        Returns:
            Dict with ``"val_loss"`` plus placeholder EM/F1 values.
        """
        x = batch["x"]
        y = batch["y"]
        y_hat = self(x).squeeze()
        loss = F.binary_cross_entropy_with_logits(y_hat, y)

        # The key must be "val_loss": on_validation_epoch_end reads it, and
        # the logged "val_loss" is what the scheduler config monitors.
        # EM/F1 are dummy metrics.
        metrics_dict = {"val_loss": loss.detach(), "val_EM": 0.9, "val_F1": 0.9}
        self.validation_step_outputs.append(metrics_dict)
        return metrics_dict

    def on_validation_epoch_end(self):
        """Log the epoch means of the validation metrics and reset the buffer."""
        outputs = self.validation_step_outputs
        results = {
            "loss": self._epoch_mean(outputs, "val_loss"),
            "EM": self._epoch_mean(outputs, "val_EM"),
            "F1": self._epoch_mean(outputs, "val_F1"),
        }
        self._log_epoch("val", results)
        outputs.clear()

    def test_step(self, batch, batch_idx):
        """Run one test batch through the model and buffer placeholder metrics.

        Args:
            batch: mapping with at least the entry ``"x"``.
            batch_idx: index of the batch (unused).

        Returns:
            Dict with placeholder ``"test_EM"`` and ``"test_F1"`` values.
        """
        # The forward pass is kept so the model is exercised; its output is
        # not used yet because the metrics below are placeholders.
        self(batch["x"])

        metrics_dict = {
            "test_EM": 0.9,
            "test_F1": 0.8,
        }
        self.test_step_outputs.append(metrics_dict)
        return metrics_dict

    def on_test_epoch_end(self):
        """Log the epoch means of the test metrics and reset the buffer."""
        outputs = self.test_step_outputs
        results = {
            "F1": self._epoch_mean(outputs, "test_F1"),
            "EM": self._epoch_mean(outputs, "test_EM"),
        }
        self._log_epoch("test", results)
        outputs.clear()

    def configure_optimizers(self):
        """Build the AdamW optimizer and, if configured, its LR schedule.

        Returns:
            The bare optimizer when no ``cfg_fitting`` was supplied, otherwise
            a dict with the optimizer, a ``LambdaLR`` scheduler and the
            monitored metric name ``"val_loss"``.
        """
        cfg = self.cfg_fitting

        optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.lr if cfg is None else cfg.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0.05,
        )
        if cfg is None:
            # Without a config there are no schedule tables to read.
            return optimizer

        def lr_scheduler_lambda_1(epoch):
            """Return the LR multiplier for ``epoch``."""
            if epoch < 3:
                # warm up lr: the first three epochs index lr_rate directly
                return cfg.lr_rate[epoch]

            # Negative index into lr_rate, derived from how many entries of
            # lr_scheduler_epoch lie strictly below the current epoch.
            lr_pos = int(-1 - bisect.bisect_left(cfg.lr_scheduler_epoch, epoch))
            if lr_pos < -3:
                # exponential decay of the first rate, floored at 0.03
                return max(cfg.lr_rate[0] * (0.98**epoch), 0.03)
            return cfg.lr_rate[lr_pos]

        scheduler_1 = optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lr_scheduler_lambda_1
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler_1,
            "monitor": "val_loss",
        }

In [2]:
import os
import torch.nn.functional as F
from pytorch_lightning import Trainer
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from nlp import Dataset
from functools import partial
from pathlib import Path
from babl.data import convert_to_features
import random
import json
from babl.data import T2TDataCollator
from dataclasses import dataclass
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    LearningRateMonitor,
)


class CallbackCollection:
    """Factory for the Lightning callbacks used during fitting.

    Calling an instance returns a dict with a checkpoint callback, a
    learning-rate monitor and an early-stopping callback.
    """

    def __init__(self, model_name, data_path) -> None:
        """Store the data path and derive the callback settings.

        Args:
            model_name: name appended to the output directory to form the
                checkpoint directory.
            data_path: stored on the instance as ``self.data_path``.
        """

        ####################################################################################
        @dataclass
        class FittingArgs:
            es_patience: int = 5
            model_dir: str = str(Path("/home/ola/Code/babl/outputs") / model_name)
            # Only reference metrics that are actually logged by the fitting
            # routine in this file (val_loss, val_EM, val_F1).
            ckpt_filename: str = "{epoch}-{val_loss:.2f}-{val_EM:.2f}-{val_F1:.2f}"

        ####################################################################################

        self.data_path = data_path

        self.args = FittingArgs()

    def __call__(self):
        """Build the callbacks.

        Returns:
            Dict with the keys ``"checkpoint"``, ``"lr"`` and ``"es"``.
        """
        lr_monitor = LearningRateMonitor(logging_interval="epoch")

        early_stopping = EarlyStopping(
            mode="min", monitor="val_loss", patience=self.args.es_patience
        )
        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss",
            dirpath=self.args.model_dir,
            save_top_k=2,
            save_last=True,
            mode="min",
            filename=self.args.ckpt_filename,
        )

        return {
            "checkpoint": checkpoint_callback,
            "lr": lr_monitor,
            "es": early_stopping,
        }


# class TextDataset(Dataset):
#     def __init__(self, dath_path, tokenizer):
#         # super().__init__()
#         self.data_path = Path(dath_path)
#         self.tokenizer = tokenizer

#     def read(
#         self,
#     ):
#         examples = []
#         with open(self.data_path, "r") as json_file:
#             x = list(json_file)
#             # logger.debug(f"[data.py::build_dataset]{x=}")
#             for json_str in x:
#                 examples.append(json.loads(json_str))
#         return examples

#     def extract_valid_pairs(self, samples):
#         valid_questions = []
#         for l in samples:
#             # clear all docs with more or less than one answer
#             # clean all docs with no annotations
#             if len(l["annotations"][0]["short_answers"]) == 1:
#                 if len(l["long_answer_candidates"]) > 2:
#                     valid_questions.append(l)
#         return valid_questions

#     def encode(self, dataset):
#         ####################################################################################
#         @dataclass
#         class Args:
#             input_max_len: int = 64
#             output_max_len: int = 64

#         ####################################################################################

#         txt2feats = partial(convert_to_features, args=Args(), tokenizer=self.tokenizer)
#         # map convert_to_features batch wise
#         ds = dataset.map(txt2feats, batched=True)
#         # set the tensor type and the columns which the dataset should return
#         columns = ["input_ids", "target_ids", "attention_mask", "target_attention_mask"]
#         ds.set_format(type="torch", columns=columns)
#         return ds

#     def construct_ds(self, examples):
#         datapoints = {}
#         datapoints["input_text"] = []
#         datapoints["target_text"] = []
#         for i, q in enumerate(examples):

#             # fitting dataset; positive and negative fitting examples
#             if random.randint(0, 1) == 1:
#                 # Construct positive example
#                 datapoints["input_text"].append(
#                     f"question: {q['question_text']}  context: {self.get_long_answer(q)} </s>"
#                 )
#                 datapoints["target_text"].append(self.get_short_answer(q))
#                 # if i % 10000 == 0:
#                 #     print("-"*100)
#                 #     print("Positive fitting example:")
#                 #     print(f"[input_text]: question: {q['question_text']}  context: {get_long_answer(q)} </s>")
#                 #     print(f"[target_text]: {get_short_answer(q)}")
#                 #     print("-"*100)
#             else:
#                 # Construct negative example
#                 datapoints["input_text"].append(
#                     f"question: {q['question_text']}  context: {self.get_random_negative(q)} </s>"
#                 )
#                 datapoints["target_text"].append("None </s>")
#                 # if i % 10000 == 0:
#                 #     print("-"*100)
#                 #     print("negative fitting example:")
#                 #     print(f"[input_text]: question: {q['question_text']}  context: {get_random_negative(q)} </s>")
#                 #     print(f"[target_text]: None </s>")
#                 #     print("-"*100)
#         assert len(datapoints["target_text"]) == len(
#             datapoints["input_text"]
#         ), "incorrect data distribution"

#         return self.encode(super().from_dict(datapoints))

#     def get_exert(self, doc, start_token, end_token):
#         return " ".join(doc.split(" ")[start_token:end_token])

#     def get_short_answer(self, q):
#         answer_indx = q["annotations"][0]["short_answers"][0]
#         return self.get_exert(
#             q["document_text"], answer_indx["start_token"], answer_indx["end_token"]
#         )

#     def get_long_answer(self, q):
#         answer_indx = q["annotations"][0]["long_answer"]
#         return self.get_exert(
#             q["document_text"], answer_indx["start_token"], answer_indx["end_token"]
#         )

#     def get_random_negative(self, q):
#         long_answer_indx = q["annotations"][0]["long_answer"]

#         for i in range(len(q["long_answer_candidates"])):
#             if (
#                 q["long_answer_candidates"][i]["start_token"]
#                 == long_answer_indx["start_token"]
#             ):
#                 del q["long_answer_candidates"][i]
#                 break

#         answer_indx = random.choice(q["long_answer_candidates"])
#         return self.get_exert(
#             q["document_text"], answer_indx["start_token"], answer_indx["end_token"]
#         )


class TextDataModule(pl.LightningDataModule):
    """Data module that builds the train/validation/test loaders.

    Each loader wraps a ``TextDataset`` over one JSON-lines file located in
    ``data_path`` and collates batches with ``T2TDataCollator``.
    """

    def __init__(self, data_path, tokenizer, batch_size=32):
        """Resolve the split files and store the loader settings.

        Args:
            data_path: directory containing ``50k.jsonl`` and ``10k.jsonl``.
            tokenizer: tokenizer handed to every ``TextDataset``.
            batch_size: number of examples per batch for all splits.
        """
        super().__init__()

        root = Path(data_path)
        self.train_path = root / "50k.jsonl"
        self.val_path = root / "10k.jsonl"
        # NOTICE: the test split re-uses the validation dataset.
        self.test_path = root / "10k.jsonl"
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.pin_memory = False  # True if torch.cuda.is_available() else False

    def _build_loader(self, path, shuffle):
        """Return a ``DataLoader`` over the dataset stored at ``path``."""
        return DataLoader(
            TextDataset(path, self.tokenizer),
            batch_size=self.batch_size,
            shuffle=shuffle,
            # Kept for every split so all batches have the same size.
            # NOTE(review): this drops the trailing partial batch during
            # evaluation as well — confirm that is acceptable.
            drop_last=True,
            pin_memory=self.pin_memory,
            collate_fn=T2TDataCollator(),
        )

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._build_loader(self.train_path, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order for reproducible metrics)."""
        return self._build_loader(self.val_path, shuffle=False)

    def test_dataloader(self):
        """Return the test loader (fixed order for reproducible metrics)."""
        return self._build_loader(self.test_path, shuffle=False)


# class Predictor(nn.Module):
#     def __init__(self, model):
#         super().__init__()
#         self.model = model

#     def forward(self, x):
#         logits = self.model(x)
#         pred = F.sigmoid(logits)
#         return pred


class Fitter:
    """Wire model, data loaders, callbacks and trainer together and run them.

    Calling an instance fits the model, tests it and returns the trainer.
    """

    def __init__(
        self,
        model,
        tokenizer,
        model_name,
        data_path="/home/ola/Code/babl/inputs",
    ):
        """Store the collaborators.

        Args:
            model: model that gets wrapped in a ``Routine``.
            tokenizer: tokenizer forwarded to the data module.
            model_name: name used to derive the output directory.
            data_path: location handed to the data module and the callbacks.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.model_name = model_name
        self.data_path = data_path

    def setup(self):
        """Return the ``(train, val, test)`` data loaders."""
        data_module = TextDataModule(data_path=self.data_path, tokenizer=self.tokenizer)

        train_loader = data_module.train_dataloader()
        val_loader = data_module.val_dataloader()
        test_loader = data_module.test_dataloader()

        return train_loader, val_loader, test_loader

    def callbacks(self):
        """Return the callback dict produced by ``CallbackCollection``."""
        callback_collection = CallbackCollection(self.model_name, self.data_path)
        return callback_collection()

    @staticmethod
    def _count_devices(spec):
        """Count the non-empty comma-separated entries of a device spec."""
        return len([entry for entry in spec.split(",") if entry.strip()])

    def __call__(self):
        """Fit and test the model.

        Returns:
            The ``Trainer`` that executed the fit and test runs.
        """

        ####################################################################################
        @dataclass
        class FittingArgs:
            model_dir: str = str(Path("/home/ola/Code/babl/outputs") / self.model_name)
            max_epoch: int = 10
            fast_dev_run: bool = True

        ####################################################################################

        args = FittingArgs()

        logger = TensorBoardLogger(
            save_dir=args.model_dir,
            name="lightning_logs",
        )
        train_loader, val_loader, test_loader = self.setup()
        print("Created training, validating and test loaders .... ")

        # setup training, validating and testing routines for the model
        routine = Routine(self.model)

        # Init a trainer to execute routine
        callback_dict = self.callbacks()
        callback_list = list(callback_dict.values())
        number_devices = self._count_devices(os.getenv("CUDA_VISIBLE_DEVICES", "1,"))

        trainer = Trainer(
            accelerator="gpu",
            devices=number_devices,
            strategy=os.getenv("STRATEGY", "ddp"),
            sync_batchnorm=True,
            logger=logger,
            max_epochs=args.max_epoch,
            callbacks=callback_list,
            num_sanity_val_steps=2,
            gradient_clip_val=1.0,
            fast_dev_run=args.fast_dev_run,
        )

        trainer.fit(
            routine, train_dataloaders=train_loader, val_dataloaders=val_loader
        )

        checkpoint = callback_dict["checkpoint"]
        if args.fast_dev_run:
            # issue with finding best weights path in fast dev run: use the
            # last model weights instead
            model_ckpt_path = checkpoint.last_model_path
        else:
            model_ckpt_path = checkpoint.best_model_path

        # When no checkpoint was written the path is an empty string; pass
        # None together with the model so the in-memory weights are tested.
        trainer.test(
            routine,
            dataloaders=test_loader,
            ckpt_path=model_ckpt_path or None,
        )
        # Return the trainer of the model for exporting
        return trainer

In [24]:
from torch.utils.data import Dataset
from pathlib import Path
import random
import json
from functools import partial
from dataclasses import dataclass
from babl.data import convert_to_features


class TextDataset(Dataset):
    """Question/context dataset built from a JSON-lines file.

    Every kept sample becomes either a positive example (question plus its
    long answer, target is the short answer) or a negative example (question
    plus a different long-answer candidate, target ``"None </s>"``), chosen
    by a coin flip. With ``plain_text=False`` the texts are additionally
    tokenized into ids and attention masks.
    """

    # Maximum token lengths used when encoding inputs and targets.
    INPUT_MAX_LEN = 64
    OUTPUT_MAX_LEN = 64

    def __init__(self, dath_path, tokenizer, plain_text=False):
        """Read the file, build the examples and optionally tokenize them.

        Args:
            dath_path: path of the JSON-lines file to read.
            tokenizer: object providing ``batch_encode_plus``; only used when
                ``plain_text`` is false.
            plain_text: keep the raw ``input_text``/``target_text`` columns
                instead of encoded features.
        """
        self.data_path = Path(dath_path)
        self.tokenizer = tokenizer
        self.plain_text = plain_text
        self.ds = {}
        self.construct_ds(self.extract_valid_pairs(self.read()))
        if not plain_text:
            self.ds = self.convert_to_features(self.ds)

    def __len__(self):
        """Return the number of examples (length of the first column)."""
        # dict views are not subscriptable, so take the first column through
        # an iterator; an empty dataset has length 0.
        return len(next(iter(self.ds.values()), ()))

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of its columns."""
        if self.plain_text:
            columns = ("input_text", "target_text")
        else:
            columns = (
                "input_ids",
                "attention_mask",
                "target_ids",
                "target_attention_mask",
            )
        return {name: self.ds[name][idx] for name in columns}

    def read(
        self,
    ):
        """Parse the file into a list with one object per non-blank line."""
        with open(self.data_path, "r", encoding="utf-8") as json_file:
            return [json.loads(line) for line in json_file if line.strip()]

    def extract_valid_pairs(self, samples):
        """Keep the samples usable for building positive and negative pairs.

        A sample is kept when its first annotation has exactly one short
        answer and it offers more than two long-answer candidates.
        """
        return [
            sample
            for sample in samples
            if len(sample["annotations"][0]["short_answers"]) == 1
            and len(sample["long_answer_candidates"]) > 2
        ]

    def construct_ds(self, examples):
        """Fill ``self.ds`` with ``input_text``/``target_text`` columns."""
        inputs = []
        targets = []
        for q in examples:
            # fitting dataset; positive and negative fitting examples
            if random.randint(0, 1) == 1:
                # Construct positive example
                inputs.append(
                    f"question: {q['question_text']}  context: {self.get_long_answer(q)} </s>"
                )
                targets.append(self.get_short_answer(q))
            else:
                # Construct negative example
                inputs.append(
                    f"question: {q['question_text']}  context: {self.get_random_negative(q)} </s>"
                )
                targets.append("None </s>")

        self.ds["input_text"] = inputs
        self.ds["target_text"] = targets
        assert len(self.ds["target_text"]) == len(
            self.ds["input_text"]
        ), "incorrect data distribution"

    def convert_to_features(self, batch):
        """Tokenize the text columns of ``batch``.

        Args:
            batch: mapping with ``"input_text"`` and ``"target_text"`` lists.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``target_ids`` and
            ``target_attention_mask`` taken from the tokenizer output.
        """
        # NOTE(review): if self.tokenizer is a Hugging Face tokenizer,
        # `pad_to_max_length` is deprecated in favour of
        # `padding="max_length"` — confirm against the installed version.
        input_encodings = self.tokenizer.batch_encode_plus(
            batch["input_text"],
            truncation=True,
            pad_to_max_length=True,
            max_length=self.INPUT_MAX_LEN,
        )
        target_encodings = self.tokenizer.batch_encode_plus(
            batch["target_text"],
            truncation=True,
            pad_to_max_length=True,
            max_length=self.OUTPUT_MAX_LEN,
        )
        return {
            "input_ids": input_encodings["input_ids"],
            "attention_mask": input_encodings["attention_mask"],
            "target_ids": target_encodings["input_ids"],
            "target_attention_mask": target_encodings["attention_mask"],
        }

    def get_exert(self, doc, start_token, end_token):
        """Return the space-separated tokens ``start_token:end_token`` of ``doc``."""
        return " ".join(doc.split(" ")[start_token:end_token])

    def get_short_answer(self, q):
        """Return the text of the first short answer of the first annotation."""
        answer_indx = q["annotations"][0]["short_answers"][0]
        return self.get_exert(
            q["document_text"], answer_indx["start_token"], answer_indx["end_token"]
        )

    def get_long_answer(self, q):
        """Return the text of the long answer of the first annotation."""
        answer_indx = q["annotations"][0]["long_answer"]
        return self.get_exert(
            q["document_text"], answer_indx["start_token"], answer_indx["end_token"]
        )

    def get_random_negative(self, q):
        """Return the text of a random long-answer candidate that is not the answer.

        The first candidate whose ``start_token`` equals the annotated long
        answer's is excluded. ``q`` itself is left untouched.
        """
        gold_start = q["annotations"][0]["long_answer"]["start_token"]
        candidates = q["long_answer_candidates"]

        # Work on a copy so the caller's sample keeps all its candidates.
        pool = list(candidates)
        for position, candidate in enumerate(candidates):
            if candidate["start_token"] == gold_start:
                del pool[position]
                break

        answer_indx = random.choice(pool)
        return self.get_exert(
            q["document_text"], answer_indx["start_token"], answer_indx["end_token"]
        )

In [28]:
from babl.models import MODELS_CHOICES, MODELS
from pathlib import Path

# Notebook driver: load a tokenizer/model pair and build a plain-text dataset
# for manual inspection in the following cells.
model_name = "t5"
# First entry registered for this model name.
# presumably a pretrained checkpoint identifier — verify against babl.models
full_model_name = MODELS_CHOICES[model_name][0]
# Registry entry holding the "tok" and "model" objects for this model name.
t_w_m = MODELS[model_name]

tokenizer = t_w_m["tok"]
model = t_w_m["model"]

# Instantiate both through their from_pretrained constructors.
t = tokenizer.from_pretrained(full_model_name)
m = model.from_pretrained(full_model_name)

# TextDataset takes the path of a single JSON-lines file, not a directory.
data_path = Path("/home/ola/Code/babl/inputs") / "10k.jsonl"

# plain_text=True keeps the raw input/target strings (no tokenization).
ds = TextDataset(data_path, tokenizer=t, plain_text=True)

# data_module = TextDataModule(data_path, tokenizer)

# Fitter(model=m, model_name=full_model_name, tokenizer=tokenizer, data_path=data_path)()

In [29]:
ds[4]

('question: who played the devil in storm of the century  context: <Li> Richard Fitzpatrick as Jonas Stanhope </Li> </s>',
 'None </s>')

In [19]:
ds.ds["input_text"].__len__()

3194