In [1]:
import lightning.pytorch as pl
import torch
import torch.nn as nn
from bs4 import BeautifulSoup
from omegaconf import OmegaConf
from src.cells.utils.compile_utils import torch_compile
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

torch.set_float32_matmul_precision("medium")
torch._dynamo.config.suppress_errors = True
# Silence all warnings
import warnings

import torch.nn.functional as F
from lightning.pytorch.callbacks import (
    EarlyStopping,
    GradientAccumulationScheduler,
    LearningRateFinder,
    LearningRateMonitor,
    ModelCheckpoint,
    StochasticWeightAveraging,
)
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

warnings.filterwarnings("ignore")

In [2]:
class IMDBDataLoader(pl.LightningDataModule):
    def __init__(self, dataset_path, tokenizer_path, batch_size, num_workers, max_len):
        super().__init__()
        self.dataset_path = dataset_path

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.max_len = max_len

        self.tokenizer = self._load_tokenizer(tokenizer_path)

    def prepare_data(self):
        from datasets import load_dataset

        self.ds = load_dataset(self.dataset_path)
        # self.ds = self.ds.map(lambda example : {'text':self._remove_html_tags(example['text'])},num_proc=self.num_workers,)
        self.label_map = {0: "neg", 1: "pos"}

    @staticmethod
    def _load_tokenizer(tokenizer_path):
        from src.tokenize.tokenizer import Tokenizer

        return Tokenizer(tokenizer_path)

    @staticmethod
    def _remove_html_tags(text):
        soup = BeautifulSoup(text, "html.parser")
        # Get the text without HTML tags
        clean_text = soup.get_text()
        return clean_text

    def _collate_fn(self, batch):
        x, y = [self._remove_html_tags(en["text"]) for en in batch], [en["label"] for en in batch]
        x = [torch.tensor(tokens[: self.max_len]) for tokens in self.tokenizer.encode_as_ids(x)]
        x = pad_sequence(
            x,
            batch_first=True,
            padding_value=self.tokenizer.eos_id(),
        )
        y = torch.tensor(y)
        return x, y

    def setup(self, stage):

        self.train_data = self.ds["train"]
        self.val_data = self.ds["test"]
        self.test_data = self.ds["unsupervised"]

    def train_dataloader(self):
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self._collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self._collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self._collate_fn,
        )

In [3]:
ds = IMDBDataLoader(
    "stanfordnlp/imdb",
    "/home/pranav-pc/projects/OpenTransformer/multiformer/tokenizer_checkpoints/",
    16,
    25,
    1024,
)

In [4]:
# Sanity
# ds.prepare_data()
# ds.setup('train')
# for idx,label in ds.train_dataloader():
#     print(idx.shape,label)
#     break

In [5]:
import torchmetrics

In [6]:
class BLMClassifierModel(pl.LightningModule):
    def __init__(self, model, learning_rate=5e-5, num_classes=2, embedding_dim=768):
        super().__init__()
        self.learning_rate = learning_rate
        self.model = model

        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)

        self.val_f1_sccore = torchmetrics.F1Score(task="multiclass", num_classes=2)
        self.train_f1_sccore = torchmetrics.F1Score(task="multiclass", num_classes=2)

        self.classifer = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim, num_classes),
        )

    def forward(self, input_ids):
        return self.model(input_ids)

    def _common_step(self, batch):
        x, label = batch
        hidden_state = self.forward(x)
        logits = self.classifer(hidden_state[:, -1, :])

        return logits

    def training_step(self, batch, batch_idx):
        x, targets = batch
        logits = self._common_step(batch)
        loss = F.cross_entropy(logits, targets)
        preds = torch.argmax(logits, dim=1)
        val_acc = self.train_acc(preds, targets)
        val_f1score = self.train_f1_sccore(preds, targets)
        self.log_dict(
            {"train_loss": loss, "train_acc": val_acc, "train_f1score": val_f1score},
            prog_bar=False,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        x, targets = batch
        logits = self._common_step(batch)

        loss = F.cross_entropy(logits, targets)
        preds = torch.argmax(logits, dim=1)
        val_acc = self.val_acc(preds, targets)
        val_f1score = self.val_f1_sccore(preds, targets)
        self.log_dict(
            {"val_loss": loss, "val_acc": val_acc, "val_f1score": val_f1score}, prog_bar=True
        )
        return loss

    def predict_step(self, idx):
        batch = idx, None
        logits = self._common_step(batch)
        preds = torch.argmax(logits, dim=1)
        return preds

    def configure_optimizers(self):
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": 1e-2},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        lr_scheduler_init = {"T_max": 1e04, "eta_min": 1e-04}
        optimizer = torch.optim.AdamW(
            optim_groups, lr=self.learning_rate, betas=(0.9, 0.95), fused=False
        )
        scheduler = {
            "scheduler": CosineAnnealingLR(optimizer, **lr_scheduler_init),
            "interval": "step",
            "frequency": 10,
        }
        return [optimizer], [scheduler]

In [7]:
def main(args):

    ds = IMDBDataLoader(
        args.files.data_path,
        args.files.tokenizer_path,
        args.trainer_params.batch_size,
        args.trainer_params.num_workers,
        1024,
    )

    from src.models.blm.pl_training import Transformer

    MODEL_CHECKPOINT = args.paths.base_model_checkpoint
    base_model = Transformer.load_from_checkpoint(MODEL_CHECKPOINT)

    for param in base_model.parameters():
        param.requires_grad = False

    model = BLMClassifierModel(base_model)
    if args.trainer_params.resume_training:
        model.load_state_dict(torch.load(args.paths.resume_from_checkpoint)["state_dict"])
    # model = torch_compile(model, dynamic=True, TORCH_COMPILE_BACKEND="inductor")
    accumulator = GradientAccumulationScheduler(
        scheduling=args.trainer_params.gradient_accumulation_scheduler
    )

    logger = TensorBoardLogger(save_dir="./lightning-log-ft-imdb/", name="IMDB", version=0.3)

    if args.trainer_params.wandb_enabled:
        import wandb

        print("W&B")
        wandb.login()
        logger = WandbLogger(**args.trainer_params.wandb)

    checkpoint_callback = ModelCheckpoint(**args.trainer_params.checkpoint)
    early_stop = EarlyStopping(**args.trainer_params.earlystopping)
    stochastic_weight_avg = StochasticWeightAveraging(swa_lrs=1e-6)

    trainer = pl.Trainer(
        logger=logger,
        **args.trainer_params.trainer,
        callbacks=[
            early_stop,
            checkpoint_callback,
            accumulator,
            LearningRateMonitor(logging_interval="step"),
            stochastic_weight_avg,
            # DeviceStatsMonitor()
        ],
    )
    model.train()
    trainer.fit(model, ds)

    return model, trainer


config_path = "/home/pranav-pc/projects/OpenTransformer/multiformer/src/models/blm/conf/finetune-imdb-classifier.yaml"
args = OmegaConf.load(config_path)
model, trainer = main(args)

Seed set to 123
Seed set to 123
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | model           | Transformer        | 43.2 M | train
1 | val_acc         | MulticlassAccuracy | 0      | train
2 | train_acc       | MulticlassAccuracy | 0      | train
3 | val_f1_sccore   | MulticlassF1Score  | 0      | train
4 | train_f1_sccore | MulticlassF1Score  | 0      | train
5 | classifer       | Sequential         | 592 K  | train
---------------------------------------------------------------
592 K     Trainable params
43.2 M    Non-trainable params
43.8 M    Total params
175.227   Total estimated model params size (MB)


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val_f1score improved. New best score: 0.607


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Monitored metric val_f1score did not improve in the last 10 records. Best score: 0.607. Signaling Trainer to stop.


In [8]:
trainer.validate(model, ds)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |                                             | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.5857179164886475,
  'val_acc': 0.6981199979782104,
  'val_f1score': 0.6981199979782104}]

In [9]:
## Why acc and f1 score always return same value? There seems something wrong

In [15]:
model.model.eval()
model.model.cuda()

corpus = [
    "Jack was hungry, so he went looking for",
    "Once upon a time there was a pumpkin. It was a very special pumpkin, it could speak. It was sad because it couldn’t move. Every day, it would say",
]

# print(tokens)
for text in corpus:
    tokens = torch.LongTensor(ds.tokenizer.encode(text)).to("cuda:0").view(1, -1)[:, :-1]
    for _ in range(3):
        print(
            ds.tokenizer.decode_ids(
                model.model.predict_step(
                    tokens,
                    None,
                    max_new_tokens=1000,
                    temperature=1.1,
                    top_k=2,
                    conditional_break=[13, 13],
                )[0].tolist()
            )
        )
        print("==" * 20)
    print("==" * 30)

Jack was hungry, so he went looking for something to eat. He went to the kitchen and saw a plate on the table. He picked it up and looked around. Suddenly, he saw a big bowl of food. Jack was so excited that he started jumping up and down. He grabbed the plate and ran to the living room.


Jack was hungry, so he went looking for something to eat. He looked in the kitchen, but he couldn't find anything. Then he heard a noise coming from the living room. He went to investigate and saw a big box on the floor. He opened the box and found a delicious looking cake. He was so excited that he started to eat it right away. 


Jack was hungry, so he went looking for something to eat. He went to the kitchen and opened the cupboard. He looked around and saw a big bowl of food. He was so happy! He took the bowl and started to eat it.


Once upon a time there was a pumpkin. It was a very special pumpkin, it could speak. It was sad because it couldn’t move. Every day, it would say "hello" and "goodby