In [None]:
%%capture
!pip install torchmetrics
!pip install wandb
!pip install pytorch_lightning==1.6.0
!pip install transformers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import wandb
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from transformers import AutoModelForSequenceClassification
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from pytorch_lightning.loggers import WandbLogger

In [None]:
# Confirm which PyTorch Lightning build the kernel picked up (the install cell pins 1.6.0).
print(f"{pl.__version__}")

1.6.0


In [None]:
class colaModel(pl.LightningModule):
    """LightningModule that fine-tunes a Hugging Face sequence classifier on 2-class labels.

    Args:
        model: Hugging Face checkpoint id/path passed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        lr: learning rate for the Adam optimizer (also stored in ``self.hparams``).

    Batches are dicts with ``"input_ids"``, ``"attention_mask"`` and ``"label"``
    entries (the keys read in ``training_step`` / ``validation_step``).
    """

    def __init__(self, model="google/bert_uncased_L-2_H-128_A-2", lr=3e-5):
        super().__init__()
        self.lr = lr
        # Records `model` and `lr` in self.hparams (and in checkpoints).
        self.save_hyperparameters()

        self.num_classes = 2
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model, num_labels=self.num_classes
        )

        self.train_accuracy_metric = torchmetrics.Accuracy(task="binary")
        self.val_accuracy_metric = torchmetrics.Accuracy(task="binary")
        # F1 score of the positive class (label 1).
        self.f1_metric = torchmetrics.F1Score(task="binary")

        # With task="binary" torchmetrics ignores `average` and `num_classes`, so the
        # "macro" and "micro" metrics would all be the same positive-class score.
        # task="multiclass" makes the requested averaging actually take effect.
        self.precision_macro_metric = torchmetrics.Precision(
            task="multiclass", num_classes=self.num_classes, average="macro"
        )
        self.recall_macro_metric = torchmetrics.Recall(
            task="multiclass", num_classes=self.num_classes, average="macro"
        )
        # Note: for single-label data, micro-averaged precision/recall equal accuracy.
        self.precision_micro_metric = torchmetrics.Precision(
            task="multiclass", num_classes=self.num_classes, average="micro"
        )
        self.recall_micro_metric = torchmetrics.Recall(
            task="multiclass", num_classes=self.num_classes, average="micro"
        )

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped classifier and return its output object.

        The result exposes ``.logits``, plus ``.loss`` when ``labels`` is given.
        """
        return self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )

    def training_step(self, batch, batch_index):
        """Compute the training loss for one batch and log loss and accuracy."""
        outputs = self.forward(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )
        predictions = torch.argmax(outputs.logits, dim=1)
        # forward() updates the metric state and caches the batch value for on_step logging.
        self.train_accuracy_metric(predictions, batch["label"])
        self.log("train/loss", outputs.loss, prog_bar=True, on_epoch=True)
        # Logging the Metric object (not the batch tensor) lets Lightning compute the
        # true epoch-level accuracy and reset the metric state after each epoch.
        self.log("train/acc", self.train_accuracy_metric, prog_bar=True, on_epoch=True)
        return outputs.loss

    def validation_step(self, batch, batch_index):
        """Evaluate one validation batch, update/log metrics, and return labels and logits."""
        labels = batch["label"]
        outputs = self.forward(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=labels,
        )
        preds = torch.argmax(outputs.logits, dim=1)

        self.log("valid/loss", outputs.loss, prog_bar=True, on_step=True)

        # Update every metric with this batch, then log the Metric object itself so the
        # epoch value is computed over all validation samples rather than averaged
        # from per-batch precision/recall/F1 values.
        named_metrics = (
            ("valid/acc", self.val_accuracy_metric),
            ("valid/precision_macro", self.precision_macro_metric),
            ("valid/recall_macro", self.recall_macro_metric),
            ("valid/precision_micro", self.precision_micro_metric),
            ("valid/recall_micro", self.recall_micro_metric),
            ("valid/f1", self.f1_metric),
        )
        for name, metric in named_metrics:
            metric(preds, labels)
            self.log(name, metric, prog_bar=True, on_epoch=True)

        return {"labels": labels, "logits": outputs.logits}

    def validation_epoch_end(self, outputs):
        """Log a confusion matrix over the whole validation epoch to Weights & Biases.

        ``outputs`` is the list of dicts returned by ``validation_step``.
        """
        # Nothing to plot (e.g. no validation batches were run).
        if not outputs:
            return
        # The W&B plotting API is only available on a WandbLogger's experiment.
        if not isinstance(self.logger, WandbLogger):
            return

        labels = torch.cat([x["labels"] for x in outputs])
        logits = torch.cat([x["logits"] for x in outputs])
        # Move to CPU before converting: `.numpy()` fails on CUDA tensors.
        preds = torch.argmax(logits, dim=1).detach().cpu().tolist()
        y_true = labels.detach().cpu().tolist()

        # Plot the confusion matrix on W&B from hard predictions (logits are not probabilities).
        self.logger.experiment.log(
            {"conf": wandb.plot.confusion_matrix(preds=preds, y_true=y_true)}
        )

    def configure_optimizers(self):
        """Adam over the wrapped transformer's parameters, using the stored learning rate."""
        return torch.optim.Adam(self.model.parameters(), lr=self.hparams["lr"])

In [None]:
# Dummy Dataset
class DummyCoLADataset(Dataset):
    """Synthetic stand-in for CoLA: random token ids with alternating 0/1 labels.

    Args:
        length: number of examples reported by ``__len__`` (default 1000).
        seq_len: length of every ``input_ids`` / ``attention_mask`` vector (default 128).
        vocab_size: token ids are drawn uniformly from ``[0, vocab_size)``. The default
            30000 presumably stays inside a BERT tokenizer's vocabulary — confirm
            against ``tokenizer.vocab_size`` if you change the model.
        seed: if given, example ``idx`` is generated from ``seed + idx`` so repeated
            accesses return identical tokens; ``None`` (default) draws fresh random
            tokens on every access, as before.
    """

    def __init__(self, length=1000, seq_len=128, vocab_size=30000, seed=None):
        super().__init__()
        self.length = length
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.seed = seed

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Support negative indices like a sequence, and raise IndexError out of range so
        # plain iteration (which falls back to __getitem__) terminates.
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f"index {idx} out of range for dataset of length {self.length}")

        generator = None
        if self.seed is not None:
            generator = torch.Generator().manual_seed(self.seed + idx)
        input_ids = torch.randint(0, self.vocab_size, (self.seq_len,), generator=generator)
        # Every position is a real token, so nothing is masked out.
        attention_mask = torch.ones(self.seq_len)
        # Even indices -> label 0, odd indices -> label 1 (balanced classes).
        label = torch.tensor(idx % 2)
        return {"input_ids": input_ids, "attention_mask": attention_mask, "label": label}

class DummyDataModule(pl.LightningDataModule):
    """LightningDataModule serving ``DummyCoLADataset`` for both training and validation.

    Args:
        tokenizer: stored on the instance but not used by these loaders, since the
            dummy examples are generated directly as token ids.
        batch_size: batch size for both loaders (default 32).
    """

    def __init__(self, tokenizer, batch_size=32):
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def _make_loader(self):
        # A fresh dummy dataset wrapped in an unshuffled DataLoader on every call.
        dataset = DummyCoLADataset()
        return DataLoader(dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        return self._make_loader()

    def val_dataloader(self):
        return self._make_loader()

# Initialize the model, data, and trainer, then run!
def run_model(model_name="google/bert_uncased_L-2_H-128_A-2", max_epochs=3, seed=42):
    """Train ``colaModel`` on the dummy data module with an offline W&B logger.

    Args:
        model_name: Hugging Face checkpoint used for both the tokenizer and the model.
        max_epochs: number of training epochs (default 3).
        seed: passed to ``pl.seed_everything`` before model creation so weight init
            and data are reproducible; ``None`` skips seeding.

    Returns:
        ``(model, trainer)`` so the trained model and trainer state can be inspected.
    """
    if seed is not None:
        pl.seed_everything(seed)

    wandb_logger = WandbLogger(offline=True)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = colaModel(model=model_name)
    data = DummyDataModule(tokenizer=tokenizer)

    # `progress_bar_refresh_rate` is deprecated since Lightning 1.5 (removed in 1.7);
    # the TQDMProgressBar callback is the supported way to set the refresh rate.
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=20)],
        logger=wandb_logger,
    )
    trainer.fit(model, datamodule=data)
    return model, trainer


cola_model, trainer = run_model()



Downloading (…)lve/main/config.json:   0%|          | 0.00/382 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/17.7M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  rank_zero_deprecation(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                   | Type                          | Params
-------------------------------------------------------------------------
0 | model                  | BertForSequenceClassification | 4.4 M 
1 | train_accuracy_metric  | BinaryAccuracy                | 0     
2 | val_accuracy_metr

Sanity Checking: 0it [00:00, ?it/s]

-------------------------------


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

-------------------------------


Validation: 0it [00:00, ?it/s]

-------------------------------


Validation: 0it [00:00, ?it/s]

-------------------------------
