In [1]:
import lightning.pytorch as pl
import torch
import torch.nn as nn
from bs4 import BeautifulSoup
from omegaconf import OmegaConf
from src.cells.utils.compile_utils import torch_compile
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

torch.set_float32_matmul_precision("medium")
torch._dynamo.config.suppress_errors = True
# Silence all warnings
import warnings

import torch.nn.functional as F
from lightning.pytorch.callbacks import (
    EarlyStopping,
    GradientAccumulationScheduler,
    LearningRateFinder,
    LearningRateMonitor,
    ModelCheckpoint,
    StochasticWeightAveraging,
)
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

warnings.filterwarnings("ignore")

In [2]:
class IMDBDataLoader(pl.LightningDataModule):
    def __init__(self, dataset_path, tokenizer_path, batch_size, num_workers, max_len):
        super().__init__()
        self.dataset_path = dataset_path

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.max_len = max_len

        self.tokenizer = self._load_tokenizer(tokenizer_path)

    def prepare_data(self):
        from datasets import load_dataset

        self.ds = load_dataset(self.dataset_path)
        # self.ds = self.ds.map(lambda example : {'text':self._remove_html_tags(example['text'])},num_proc=self.num_workers,)
        self.label_map = {0: "neg", 1: "pos"}

    @staticmethod
    def _load_tokenizer(tokenizer_path):
        from src.tokenize.tokenizer import Tokenizer

        return Tokenizer(tokenizer_path)

    @staticmethod
    def _remove_html_tags(text):
        soup = BeautifulSoup(text, "html.parser")
        # Get the text without HTML tags
        clean_text = soup.get_text()
        return clean_text

    def _collate_fn(self, batch):
        x, y = [self._remove_html_tags(en["text"]) for en in batch], [en["label"] for en in batch]
        x = [torch.tensor(tokens[: self.max_len]) for tokens in self.tokenizer.encode_as_ids(x)]
        x = pad_sequence(
            x,
            batch_first=True,
            padding_value=self.tokenizer.eos_id(),
        )
        y = torch.tensor(y)
        return x, y

    def setup(self, stage):

        self.train_data = self.ds["train"]
        self.val_data = self.ds["test"]
        self.test_data = self.ds["unsupervised"]

    def train_dataloader(self):
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self._collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self._collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self._collate_fn,
        )

In [3]:
# Sanity
# ds.prepare_data()
# ds.setup('train')
# for idx,label in ds.train_dataloader():
#     print(idx.shape,label)
#     break

In [4]:
import torchmetrics

In [5]:
class BLMClassifierModel(pl.LightningModule):
    def __init__(self, model, learning_rate=5e-5, num_classes=2, embedding_dim=768):
        super().__init__()
        self.learning_rate = learning_rate
        self.model = model

        self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
        self.f1_score = torchmetrics.F1Score(task="multiclass", num_classes=2)
        self.classifer = nn.Linear(embedding_dim, num_classes)

    def forward(self, input_ids):
        return self.model(input_ids)

    def _common_step(self, batch):
        x, label = batch
        hidden_state = self.forward(x)
        logits = self.classifer(hidden_state[:, -1, :])
        return logits

    def training_step(self, batch, batch_idx):
        x, targets = batch
        logits = self._common_step(batch)
        loss = F.cross_entropy(logits, targets)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, targets = batch
        logits = self._common_step(batch)

        loss = F.cross_entropy(logits, targets)
        preds = torch.argmax(logits, dim=1)
        val_acc = self.acc(preds, targets)
        val_f1score = self.f1_score(preds, targets)
        self.log_dict(
            {"val_loss": loss, "val_acc": val_acc, "val_f1score": val_f1score}, prog_bar=True
        )
        return loss

    # def test_step(self, batch, batch_idx):
    #     x, targets = batch
    #     logits = self._common_step(batch)
    #     loss = F.cross_entropy(logits, targets)
    #     preds = torch.argmax(logits, dim=1)
    #     test_acc = self.acc(preds, targets)
    #     test_f1score = self.f1_score(preds, targets)
    #     self.log_dict({"test_loss": loss,"test_acc":test_acc,"test_f1score":test_f1score}, prog_bar=False)
    #     return loss

    def configure_optimizers(self):
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": 1e-2},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        lr_scheduler_init = {"T_max": 1e04, "eta_min": 1e-04}
        optimizer = torch.optim.AdamW(
            optim_groups, lr=self.learning_rate, betas=(0.9, 0.95), fused=False
        )
        scheduler = {
            "scheduler": CosineAnnealingLR(optimizer, **lr_scheduler_init),
            "interval": "step",
            "frequency": 10,
        }
        return [optimizer], [scheduler]

In [6]:
def main(args, train=True):

    ds = IMDBDataLoader(
        args.files.data_path,
        args.files.tokenizer_path,
        args.trainer_params.batch_size,
        args.trainer_params.num_workers,
        1024,
    )

    from src.models.blm.pl_training import Transformer

    MODEL_CHECKPOINT = args.paths.base_model_checkpoint

    base_model = Transformer.load_from_checkpoint(MODEL_CHECKPOINT)

    model = BLMClassifierModel(base_model)
    if args.trainer_params.resume_training:
        model.load_state_dict(torch.load(args.paths.resume_from_checkpoint)["state_dict"])
    # model = torch_compile(model, dynamic=True, TORCH_COMPILE_BACKEND="inductor")
    accumulator = GradientAccumulationScheduler(
        scheduling=args.trainer_params.gradient_accumulation_scheduler
    )

    logger = TensorBoardLogger(save_dir="./lightning-log-ft-imdb/", name="IMDB", version=0.1)

    if args.trainer_params.wandb_enabled:
        import wandb

        print("W&B")
        wandb.login()
        logger = WandbLogger(**args.trainer_params.wandb)

    checkpoint_callback = ModelCheckpoint(**args.trainer_params.checkpoint)
    early_stop = EarlyStopping(**args.trainer_params.earlystopping)
    stochastic_weight_avg = StochasticWeightAveraging(swa_lrs=1e-6)

    trainer = pl.Trainer(
        logger=logger,
        **args.trainer_params.trainer,
        callbacks=[
            early_stop,
            checkpoint_callback,
            accumulator,
            LearningRateMonitor(logging_interval="step"),
            stochastic_weight_avg,
            # DeviceStatsMonitor()
        ],
    )
    if train:
        trainer.fit(model, ds)
    return ds, model, trainer


config_path = "/home/pranav-pc/projects/OpenTransformer/multiformer/src/models/blm/conf/finetune-imdb-classifier.yaml"
args = OmegaConf.load(config_path)
ds, model, trainer = main(args, train=False)

Seed set to 123
Seed set to 123
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [7]:
trainer.validate(model, ds)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |                                             | 0/? [00:00<?, ?it/s]

[{'val_loss': 1.1901986598968506,
  'val_acc': 0.8829200267791748,
  'val_f1score': 0.8829200267791748}]

In [31]:
model.model.eval()
model.model.cuda()

corpus = [
    "Jack was hungry, so he went looking for",
    "Once upon a time there was a pumpkin. It was a very special pumpkin, it could speak. It was sad because it couldn’t move. Every day, it would say",
]

# print(tokens)
for text in corpus:
    tokens = torch.LongTensor(ds.tokenizer.encode(text)).to("cuda:0").view(1, -1)[:, :-1]
    for _ in range(3):
        print(
            ds.tokenizer.decode_ids(
                model.model.predict_step(
                    tokens,
                    None,
                    max_new_tokens=1000,
                    temperature=0.9,
                    top_k=2,
                    conditional_break=[13, 13],
                )[0].tolist()
            )
        )
        print("==" * 20)
    print("==" * 30)

Jack was hungry, so he went looking for something to eat. He saw a big, red apple on the ground and he wanted to eat it. He picked up the apple and took a bite. It was so juicy and yummy! Jack was so happy. He ate the apple all up and then he started to play. He ran and jumped around the garden, and he even tried to catch the apple in his mouth.


Jack was hungry, so he went looking for something to eat. He looked around the kitchen and he saw a big bowl of cereal. He thought it would be a good snack. He opened the cupboard and saw a big bowl of cereal. He was so excited! He grabbed the big spoon and started to eat the cereal.


Jack was hungry, so he went looking for something. He looked in the kitchen and the cupboard, and he looked in his room, but he couldn't find what he was looking for. He looked in the kitchen and the cupboard, but there was nothing there either.


Once upon a time there was a pumpkin. It was a very special pumpkin, it could speak. It was sad because it couldn’t

In [28]:
## Let's make sure that inference is done on the new finetunned weights

In [26]:
for name, value in model.model.named_parameters():
    print(name, value)

tok_embd.weight Parameter containing:
tensor([[-4.0612e-02, -1.0990e-02, -2.8261e-02,  ...,  3.4319e-02,
         -3.4704e-02, -1.8701e-02],
        [-2.1640e-03,  1.2131e-02, -1.7759e-02,  ..., -1.2359e-04,
         -3.9344e-02,  2.9360e-02],
        [-8.2522e-02, -1.5970e-02, -5.1656e-02,  ..., -1.5299e-01,
         -4.8903e-02,  9.0692e-02],
        ...,
        [-4.0577e-02, -1.0985e-02, -2.8255e-02,  ...,  3.4327e-02,
         -3.4681e-02, -1.8714e-02],
        [-4.0572e-02, -1.0950e-02, -2.8243e-02,  ...,  3.4374e-02,
         -3.4712e-02, -1.8676e-02],
        [-4.0550e-02, -1.0993e-02, -2.8262e-02,  ...,  3.4358e-02,
         -3.4678e-02, -1.8740e-02]], device='cuda:0', requires_grad=True)
layers.0.norms.w Parameter containing:
tensor([0.1576, 0.1675, 0.1600, 0.1530, 0.1468, 0.1576, 0.1658, 0.2789, 0.1622,
        0.1383, 0.0581, 0.1467, 0.1353, 0.1373, 0.1606, 0.1574, 0.1651, 0.1377,
        0.1656, 0.1322, 0.1720, 0.1415, 0.1542, 0.1247, 0.1428, 0.1383, 0.1361,
        0.2108

In [27]:
from src.models.blm.pl_training import Transformer

MODEL_CHECKPOINT = args.paths.base_model_checkpoint

base_model = Transformer.load_from_checkpoint(MODEL_CHECKPOINT)

for name, value in base_model.named_parameters():
    print(name, value)

tok_embd.weight Parameter containing:
tensor([[-0.0407, -0.0110, -0.0283,  ...,  0.0344, -0.0348, -0.0187],
        [-0.0024,  0.0108, -0.0182,  ...,  0.0005, -0.0436,  0.0270],
        [-0.0827, -0.0160, -0.0518,  ..., -0.1533, -0.0490,  0.0909],
        ...,
        [-0.0407, -0.0110, -0.0283,  ...,  0.0344, -0.0348, -0.0188],
        [-0.0407, -0.0110, -0.0283,  ...,  0.0344, -0.0348, -0.0187],
        [-0.0406, -0.0110, -0.0283,  ...,  0.0344, -0.0347, -0.0188]],
       device='cuda:0', requires_grad=True)
layers.0.norms.w Parameter containing:
tensor([0.1570, 0.1649, 0.1578, 0.1534, 0.1473, 0.1549, 0.1634, 0.2792, 0.1523,
        0.1406, 0.0617, 0.1445, 0.1376, 0.1397, 0.1580, 0.1514, 0.1633, 0.1397,
        0.1655, 0.1390, 0.1695, 0.1402, 0.1481, 0.1264, 0.1438, 0.1410, 0.1371,
        0.2102, 0.1409, 0.1850, 0.1454, 0.1394, 0.1555, 0.1386, 0.1472, 0.1209,
        0.2077, 0.1555, 0.1558, 0.1519, 0.1542, 0.1612, 0.1654, 0.1531, 0.1833,
        0.1591, 0.1632, 0.1340, 0.1568, 0.145