## Converting to pytorch lightning fitting routine 

In [45]:
import numpy as np
import bisect
import tensorboard
import torch
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl

# from kws.eval import Metric


class Routine(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.lr = 1e-3
        self.validation_step_outputs = []
        self.training_step_outputs = []
        self.test_step_outputs = []

    def forward(
        self,
        input_ids,
    ):
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        print(f"keys = {batch.keys()}")
        print(f"{batch=}")

        y_hat = self(**batch)
        print(f"{y_hat=}")

        y_hat = y_hat.squeeze()

        # dummy metrics
        metrics_dict = {"loss": 10, "train_EM": 0.9, "train_F1": 0.9}

        # loss = F.binary_cross_entropy_with_logits(y_hat, y)

        # y_hat = (F.sigmoid(y_hat) > 0.5).float()

        # metrics = self.metric(y_hat, y)()
        # metrics_dict = {
        #     "loss": loss,
        #     "train_ttr": metrics.ttr,
        #     "train_ftr": metrics.ftr,
        #     "train_acc": metrics.acc,
        # }
        self.training_step_outputs.append(metrics_dict)
        return metrics_dict

    def on_train_epoch_end(self):
        results = {
            "loss": torch.tensor(
                [x["loss"] for x in self.training_step_outputs]
            ).mean(),
            "F1": torch.tensor(
                [x["train_F1"] for x in self.training_step_outputs]
            ).mean(),
            "EM": torch.tensor(
                [x["train_EM"] for x in self.training_step_outputs]
            ).mean(),
        }
        # self.log(f"LR",self.lr, on_epoch=True, prog_bar=True, logger=True)
        for k, v in results.items():
            self.log(
                f"train_{k}",
                v,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

    def validation_step(self, batch, batch_idx):
        # x = batch["x"]
        # y = batch["y"]
        y_hat = self(**batch)
        # (batch, num_classes)
        y_hat = y_hat.squeeze()
        # (batch,)
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        # pred = F.sigmoid(y_hat)
        # y_hat = (pred > 0.5).float()

        # metrics = self.metric(y_hat, y)()
        # metrics_dict = {
        #     "val_loss": loss,
        #     "val_ttr": metrics.ttr,
        #     "val_ftr": metrics.ftr,
        #     "val_acc": metrics.acc,
        # }

        # dummy metrics
        metrics_dict = {"loss": 10, "val_EM": 0.9, "val_F1": 0.9}
        self.validation_step_outputs.append(metrics_dict)
        return metrics_dict

    def on_validation_epoch_end(self):
        results = {
            "loss": torch.tensor(
                [x["val_loss"] for x in self.validation_step_outputs]
            ).mean(),
            "EM": torch.tensor(
                [x["val_EM"] for x in self.validation_step_outputs]
            ).mean(),
            "F1": torch.tensor(
                [x["val_F1"] for x in self.validation_step_outputs]
            ).mean(),
        }
        for k, v in results.items():
            self.log(
                f"val_{k}", v, on_epoch=True, prog_bar=True, logger=True, sync_dist=True
            )
            # self.log(f"val_{k}", v, on_epoch=True, prog_bar=True) # , logger=True)

    def test_step(self, batch, batch_idx):
        # x = batch["x"]
        # y = batch["y"]
        y_hat = self(**batch)
        # (batch, num_classes)
        y_hat = y_hat.squeeze()
        # (batch,)
        pred = F.sigmoid(y_hat)

        # (batch_probabilities,)
        # y_hat = (pred > 0.5).float()
        # (batch_labels,)
        # metrics = self.metric(y_hat, y)()

        metrics_dict = {
            "test_EM": 0.9,
            "test_F1": 0.8,
        }
        self.test_step_outputs.append(metrics_dict)
        return metrics_dict

    def on_test_epoch_end(self):
        results = {
            "F1": torch.tensor([x["test_EM"] for x in self.test_step_outputs]).mean(),
            "EM": torch.tensor([x["test_F1"] for x in self.test_step_outputs]).mean(),
        }

        for k, v in results.items():
            self.log(
                f"test_{k}",
                v,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

    def configure_optimizers(self):

        # special scheduler for transformers
        optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=0.001,  # self.cfg_fitting.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0.05,
        )

        # def lr_scheduler_lambda_1(epoch):
        #     if epoch < 3:
        #         # warm up lr
        #         lr_scale = self.cfg_fitting.lr_rate[epoch]
        #     else:
        #         # warmup schedule
        #         lr_pos = int(
        #             -1 - bisect.bisect_left(self.cfg_fitting.lr_scheduler_epoch, epoch)
        #         )
        #         if lr_pos < -3:
        #             lr_scale = max(self.cfg_fitting.lr_rate[0] * (0.98**epoch), 0.03)
        #         else:
        #             lr_scale = self.cfg_fitting.lr_rate[lr_pos]
        #     return lr_scale

        # scheduler_1 = optim.lr_scheduler.LambdaLR(
        #     optimizer, lr_lambda=lr_scheduler_lambda_1
        # )

        return {
            "optimizer": optimizer,
            # "lr_scheduler": scheduler_1,
            "monitor": "val_loss",
        }

In [46]:
import os
import torch.nn.functional as F
from pytorch_lightning import Trainer
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader

# from nlp import Dataset

from torch.utils.data import Dataset
from pathlib import Path
import random
import json
from dataclasses import dataclass

# from functools import partial
from pathlib import Path

# from babl.data import convert_to_features
import random
import json
import torch
from babl.data import T2TDataCollator
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    LearningRateMonitor,
)


class CallbackCollection:
    def __init__(self, model_name, data_path) -> None:

        ####################################################################################
        @dataclass
        class FittingArgs:
            es_patience: int = 5
            model_dir: str = str(Path("/home/ola/Code/babl/outputs") / model_name)

        ####################################################################################

        self.data_path = data_path

        self.args = FittingArgs()

    def __call__(self):
        lr_monitor = LearningRateMonitor(logging_interval="epoch")

        early_stopping = EarlyStopping(
            mode="min", monitor="val_loss", patience=self.args.es_patience
        )
        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss",
            dirpath=self.args.model_dir,
            save_top_k=2,
            save_last=True,
            mode="min",
            filename="{epoch}-{val_loss:.2f}-{val_acc:.2f}-{val_ttr:.2f}-{val_ftr:.2f}",
        )

        callbacks = {
            "checkpoint": checkpoint_callback,
            "lr": lr_monitor,
            "es": early_stopping,
        }
        # callbacks = [checkpoint_callback, lr_monitor, early_stopping]
        return callbacks


class TextDataset(Dataset):

    def __init__(self, dath_path, tokenizer, plain_text=False):
        # super().__init__()
        self.data_path = Path(dath_path)
        self.tokenizer = tokenizer
        self.plain_text = plain_text
        self.ds = {}
        if plain_text:
            self.construct_ds(self.extract_valid_pairs(self.read()))
        else:

            self.construct_ds(self.extract_valid_pairs(self.read()))
            self.ds = self.convert_to_features(self.ds)

    def __len__(self):
        # print(self.ds)
        return list(self.ds.values())[0].__len__()
        # return len(self.ds.values()[0])

    def __getitem__(self, idx):
        if self.plain_text:
            return {
                "input_text": self.ds["input_text"][idx],
                "target_text": self.ds["target_text"][idx],
            }
        else:
            return {
                "input_ids": self.ds["input_ids"][idx],
                "attention_mask": self.ds["attention_mask"][idx],
                "target_ids": self.ds["input_ids"][idx],
                "target_attention_mask": self.ds["attention_mask"][idx],
            }

    def read(
        self,
    ):
        examples = []
        with open(self.data_path, "r") as json_file:
            x = list(json_file)
            # logger.debug(f"[data.py::build_dataset]{x=}")
            for json_str in x:
                examples.append(json.loads(json_str))
        return examples

    def extract_valid_pairs(self, samples):
        valid_questions = []
        for l in samples:
            # clear all docs with more or less than one answer
            # clean all docs with no annotations
            if len(l["annotations"][0]["short_answers"]) == 1:
                if len(l["long_answer_candidates"]) > 2:
                    valid_questions.append(l)
        return valid_questions

    def construct_ds(self, examples):

        self.ds["input_text"] = []
        self.ds["target_text"] = []
        for i, q in enumerate(examples):

            # fitting dataset; positive and negative fitting examples
            if random.randint(0, 1) == 1:
                # Construct positive example
                self.ds["input_text"].append(
                    f"question: {q['question_text']}  context: {self.get_long_answer(q)} </s>"
                )
                self.ds["target_text"].append(self.get_short_answer(q))
            else:
                # Construct negative example
                self.ds["input_text"].append(
                    f"question: {q['question_text']}  context: {self.get_random_negative(q)} </s>"
                )
                self.ds["target_text"].append("None </s>")
        assert len(self.ds["target_text"]) == len(
            self.ds["input_text"]
        ), "incorrect data distribution"

    def convert_to_features(self, batch):

        # tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path)  # "t5-small")
        ####################################################################################
        @dataclass
        class Args:
            input_max_len: int = 64
            output_max_len: int = 64

            ####################################################################################

        input_encodings = self.tokenizer.batch_encode_plus(
            batch_text_or_text_pairs=batch["input_text"],
            truncation=True,
            pad_to_max_length=True,
            max_length=Args().input_max_len,
        )
        target_encodings = self.tokenizer.batch_encode_plus(
            batch_text_or_text_pairs=batch["target_text"],
            truncation=True,
            pad_to_max_length=True,
            max_length=Args().output_max_len,
        )
        # print("input_encodings", input_encodings.keys())
        # print("target_encodings", target_encodings.keys())
        # encodings = {
        #     "input_ids": input_encodings["input_ids"],
        #     "attention_mask": input_encodings["attention_mask"],
        #     "target_ids": target_encodings["input_ids"],
        #     "target_attention_mask": target_encodings["attention_mask"],
        # }

        encodings = {
            "input_ids": torch.tensor(input_encodings["input_ids"]),
            "attention_mask": torch.tensor(input_encodings["attention_mask"]),
            "target_ids": torch.tensor(target_encodings["input_ids"]),
            "target_attention_mask": torch.tensor(target_encodings["attention_mask"]),
        }
        return encodings

    def get_exert(self, doc, start_token, end_token):
        return " ".join(doc.split(" ")[start_token:end_token])

    def get_short_answer(self, q):
        answer_indx = q["annotations"][0]["short_answers"][0]
        return self.get_exert(
            q["document_text"], answer_indx["start_token"], answer_indx["end_token"]
        )

    def get_long_answer(self, q):
        answer_indx = q["annotations"][0]["long_answer"]
        return self.get_exert(
            q["document_text"], answer_indx["start_token"], answer_indx["end_token"]
        )

    def get_random_negative(self, q):
        long_answer_indx = q["annotations"][0]["long_answer"]

        for i in range(len(q["long_answer_candidates"])):
            if (
                q["long_answer_candidates"][i]["start_token"]
                == long_answer_indx["start_token"]
            ):
                del q["long_answer_candidates"][i]
                break

        answer_indx = random.choice(q["long_answer_candidates"])
        return self.get_exert(
            q["document_text"], answer_indx["start_token"], answer_indx["end_token"]
        )


class TextDataModule(pl.LightningDataModule):
    def __init__(self, data_path, tokenizer):
        super().__init__()

        self.train_path = Path(data_path) / "50k.jsonl"
        self.val_path = Path(data_path) / "10k.jsonl"
        # NOTICE, we  re-use the validatio dataset
        self.test_path = Path(data_path) / "10k.jsonl"
        self.tokenizer = tokenizer
        self.pin_memory = False  # True if torch.cuda.is_available() else False

    def train_dataloader(self):
        ds_train = TextDataset(self.train_path, self.tokenizer)
        return DataLoader(
            ds_train,
            batch_size=32,
            shuffle=True,
            drop_last=True,
            pin_memory=self.pin_memory,
            collate_fn=T2TDataCollator(),
        )

    def val_dataloader(self):
        ds_val = TextDataset(self.val_path, self.tokenizer)
        return DataLoader(
            ds_val,
            batch_size=32,
            shuffle=True,
            drop_last=True,
            pin_memory=self.pin_memory,
            collate_fn=T2TDataCollator(),
        )

    def test_dataloader(self):

        ds_test = TextDataset(self.test_path, self.tokenizer)
        return DataLoader(
            ds_test,
            batch_size=32,
            shuffle=True,
            drop_last=True,
            pin_memory=self.pin_memory,
            collate_fn=T2TDataCollator(),
        )


# class Predictor(nn.Module):
#     def __init__(self, model):
#         super().__init__()
#         self.model = model

#     def forward(self, x):
#         logits = self.model(x)
#         pred = F.sigmoid(logits)
#         return pred


class Fitter:
    def __init__(
        self,
        model,
        tokenizer,
        model_name,
        data_path="/home/ola/Code/babl/inputs",
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.model_name = model_name
        self.data_path = data_path

    def setup(self):
        data_module = TextDataModule(data_path=self.data_path, tokenizer=self.tokenizer)

        train_loader = data_module.train_dataloader()
        val_loader = data_module.val_dataloader()
        test_loader = data_module.test_dataloader()

        return train_loader, val_loader, test_loader

    def callbacks(self):
        # cfg_fitting = self.cfg_fitting
        callback_collection = CallbackCollection(self.model_name, self.data_path)
        return callback_collection()

    def __call__(self):

        ####################################################################################
        @dataclass
        class FittingArgs:
            es_patience: int = 5
            model_dir: str = str(Path("/home/ola/Code/babl/outputs") / self.model_name)

        ####################################################################################

        args = FittingArgs()

        logger = TensorBoardLogger(
            save_dir=args.model_dir,
            name="lightning_logs",
        )
        Model = self.model
        # get loaders and datamodule to access input shape
        train_loader, val_loader, test_loader = self.setup()
        print("Created training, validating and test loaders .... ")
        # get input shape for onnx exporting
        # input_shape = data_module.input_shape
        # init model
        # kwargs = {}
        # model = Model(**kwargs)

        # setup training, validating and testing routines for the model
        routine = Routine(self.model)

        # Init a trainer to execute routine
        callback_dict = self.callbacks()
        callback_list = [v for (_, v) in callback_dict.items()]
        number_devices = os.getenv("CUDA_VISIBLE_DEVICES", "1,").split(",")
        try:
            number_devices.remove("")
        except ValueError:
            pass

        ####################################################################################
        @dataclass
        class FittingArgs:
            max_epoch: int = 10
            fast_dev_run: bool = True

        ####################################################################################

        args = FittingArgs()
        trainer = Trainer(
            accelerator="gpu",
            devices=len(number_devices),
            strategy=os.getenv("STRATEGY", "ddp_notebook"),
            sync_batchnorm=True,
            logger=logger,
            max_epochs=args.max_epoch,
            callbacks=callback_list,
            num_sanity_val_steps=2,
            # resume_from_checkpoint=self.cfg_fitting.resume_from_checkpoint,
            gradient_clip_val=1.0,
            fast_dev_run=args.fast_dev_run,
        )

        trainer.fit(
            routine, train_dataloaders=train_loader, val_dataloaders=val_loader
        )  # ,ckpt_path=PATH)

        if args.fast_dev_run:
            # issue with finding best weights path for in fast dev run using last model weights
            model_ckpt_path = callback_dict["checkpoint"].__dict__["last_model_path"]
        else:
            model_ckpt_path = callback_dict["checkpoint"].__dict__["best_model_path"]

        trainer.test(
            dataloaders=test_loader,
            ckpt_path=model_ckpt_path,
        )
        # Return the input_shapes and trainer of the model for exporting
        return trainer

In [None]:
from babl.models import MODELS_CHOICES, MODELS
from pathlib import Path

model_name = "t5"
full_model_name = MODELS_CHOICES[model_name][0]
t_w_m = MODELS[model_name]

tokenizer = t_w_m["tok"]
model = t_w_m["model"]

t = tokenizer.from_pretrained(full_model_name)
m = model.from_pretrained(full_model_name)

data_path_root = Path("/home/ola/Code/babl/inputs")


# data_path_val = data_path_root / "10k.jsonl"
# ds = TextDataset(data_path_val, tokenizer=t, plain_text=False)
# from babl.data import T2TDataCollator
# from torch.utils.data import DataLoader
# t_dl = DataLoader(ds, batch_size=64, shuffle=True, collate_fn=T2TDataCollator())
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)


# data_module = TextDataModule(data_path, tokenizer)

Fitter(model=m, model_name=full_model_name, tokenizer=t, data_path=data_path_root)()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


Created training, validating and test loaders .... 


Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type                       | Params | Mode
------------------------------------------------------------
0 | model | T5ForConditionalGeneration | 60.5 M | eval
------------------------------------------------------------
60.5 M    Trainable params
0         Non-trainable params
60.5 M    Total params
242.026   Total estimated model params size (MB)
0         Modules in train mode
277       Modules in eval mode
/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many 

Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] keys = dict_keys(['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask'])
batch={'input_ids': tensor([[  822,    10,   125,  ...,   376,     3,     1],
        [  822,    10,   149,  ...,     9,  9369,     1],
        [  822,    10,    46,  ...,  4401,   725,     1],
        ...,
        [  822,    10,   113,  ...,    87,   382,     1],
        [  822,    10,   125,  ...,   341,  3223,     1],
        [  822,    10,   125,  ..., 19282,     3,     1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'), 'labels': tensor([[  822,    10,   125,  ...,   376,     3,     1],
        [  822,    10,   149,  ...,     9,  9369,     1],
        [  822,    10,    46,  ...,  4401,   725,     1],
        ...,
        [  822, 

ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 90, in _wrap
    fn(i, *args)
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 173, in _wrapping_function
    results = function(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 575, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 982, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1026, in _run_stage
    self.fit_loop.run()
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 216, in run
    self.advance()
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 455, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 150, in run
    self.advance(data_fetcher)
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 320, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 192, in run
    self._optimizer_step(batch_idx, closure)
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 270, in _optimizer_step
    call._call_lightning_module_hook(
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 171, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/core/module.py", line 1302, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py", line 154, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/strategies/ddp.py", line 270, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 239, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py", line 123, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/optim/optimizer.py", line 487, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/optim/adamw.py", line 197, in step
    loss = closure()
           ^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py", line 109, in _wrap_closure
    closure_result = closure()
                     ^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 146, in __call__
    self._result = self.closure(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 131, in closure
    step_output = self._step_fn()
                  ^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 319, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 323, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 390, in training_step
    return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 641, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 634, in wrapped_forward
    out = method(*_args, **_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_61659/3088556754.py", line 30, in training_step
    y_hat = self(**batch)
            ^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ola/.local/share/virtualenvs/babl-f0iBC2ut/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Routine.forward() got an unexpected keyword argument 'input_ids'


In [24]:
for d in t_dl:
    print(d)

{'input_ids': tensor([[  822,    10,   149,  ...,  7365,  8244,     1],
        [  822,    10,   125,  ..., 25930,    35,     1],
        [  822,    10,   113,  ...,     0,     0,     0],
        ...,
        [  822,    10,   113,  ...,  2390,   789,     1],
        [  822,    10,   116,  ...,     5,   299,     1],
        [  822,    10,   116,  ...,    87,   382,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[  822,    10,   149,  ...,  7365,  8244,     1],
        [  822,    10,   125,  ..., 25930,    35,     1],
        [  822,    10,   113,  ...,  -100,  -100,  -100],
        ...,
        [  822,    10,   113,  ...,  2390,   789,     1],
        [  822,    10,   116,  ...,     5,   299,     1],
        [  822,    10,   116,  ...,    87,   382,     1]]), 'decoder_atte

In [None]:
ds.ds["input_text"].__len__()

In [15]:
list({"x": [1, 2, 3, 4]}.values())[0].__len__()

4