In [10]:
# Imports
from datasets import load_dataset
from transformers import AutoTokenizer,AutoModelForSequenceClassification, DataCollatorWithPadding

import torch 
from torch import nn
from torch.utils.data import DataLoader

import lightning.pytorch as pl
import torchmetrics

import os

from tqdm.notebook import tqdm

In [2]:
# Silence warnings
TOKENIZERS_PARALLELISM= False

In [3]:
class MRCPDataLoader(pl.LightningDataModule):
    def __init__(self,batch_size,checkpoint="bert-base-uncased",num_workers=4):
        super(MRCPDataLoader,self).__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.checkpoint = checkpoint
        self.tokenizer= AutoTokenizer.from_pretrained(self.checkpoint)
        self.data_collector = DataCollatorWithPadding(self.tokenizer)
        
        
    def prepare_data(self):
        load_dataset('glue','mrpc')
        

    def _tokenizer_func(self,sample):
        return self.tokenizer(sample['sentence1'],sample['sentence2'],truncation=True)


    def setup(self, stage):
        dataset = load_dataset('glue','mrpc')
        tokenized_dataset = dataset.map(self._tokenizer_func,batched=True,batch_size=100)
        data_collector = DataCollatorWithPadding(self.tokenizer)
        
        # Preprocessing 
        tokenized_dataset = tokenized_dataset.remove_columns(['sentence1', 'sentence2','idx'])  # remove text columns
        tokenized_dataset = tokenized_dataset.rename_column('label','labels')
        tokenized_dataset.set_format('torch')
        
        self.tokenized_dataset = tokenized_dataset
        

    def train_dataloader(self):
        return DataLoader(dataset=self.tokenized_dataset['train'],batch_size=self.batch_size,shuffle=True,num_workers=self.num_workers,collate_fn=self.data_collector)
    
    def val_dataloader(self):
        return DataLoader(dataset=self.tokenized_dataset['validation'],batch_size=self.batch_size,shuffle=False,num_workers=self.num_workers,collate_fn=self.data_collector)
    
    def test_dataloader(self):
        return DataLoader(dataset=self.tokenized_dataset['test'],batch_size=self.batch_size,shuffle=False,num_workers=self.num_workers,collate_fn=self.data_collector)
   

In [4]:
batch_size = 10
checkpoint = "bert-base-uncased"
ds = MRCPDataLoader(batch_size,checkpoint=checkpoint,num_workers=8)

In [5]:
## Inspecting
ds.setup(stage='fit')
for batch in ds.train_dataloader():
    break
    
{k:v.shape for k,v in batch.items()}

Map:   0%|          | 0/408 [00:00<?, ? examples/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `_

{'labels': torch.Size([10]),
 'input_ids': torch.Size([10, 80]),
 'token_type_ids': torch.Size([10, 80]),
 'attention_mask': torch.Size([10, 80])}

In [6]:
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
## Inspect single training cycle 
model(**batch).logits.shape

torch.Size([10, 2])

In [8]:
# import torchvision
# from torchview import draw_graph

# model_graph = draw_graph(model(),input_data=batch, expand_nested=True)
# model_graph.visual_graph

In [11]:

logger = pl.loggers.TensorBoardLogger(save_dir='./log/', name='mrpc', version=0.1)

profiler = pl.profilers.PyTorchProfiler(
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/',),
    schedule=torch.profiler.schedule(skip_first=10, wait=10, warmup=1, active=2)
)

# saves top-K checkpoints based on "val_loss" metric
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=1,
#     save_last=True,
    monitor="val_loss",
    mode="max",
    dirpath="checkpoints/mrpc/",
    filename="{epoch}-{val_f1score:.3f}",
)


trainer = pl.Trainer(
    logger=logger,
    accelerator='auto',
    devices=[0],
    min_epochs=1,
    max_epochs=5,
    precision='16-mixed',
#     enable_model_summary=True,
#     profiler=profiler,
    callbacks=[checkpoint_callback,
               pl.callbacks.EarlyStopping('val_loss',mode='min',patience=5,verbose=True,min_delta=0.00)],
    enable_checkpointing  = True,
)
if os.path.exists(checkpoint_callback.dirpath):
    best_checkpoint_filename = os.listdir(checkpoint_callback.dirpath)
else: 
    best_checkpoint_filename = None

if best_checkpoint_filename:
    print('Loading model from checkpoints : ',best_checkpoint_filename[0])
    trainer.fit(model, ds, ckpt_path=os.path.join(checkpoint_callback.dirpath, best_checkpoint_filename[0]))
else : 
    trainer.fit(model,datamodule=ds)

trainer.validate(model, ds)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [12]:
## TODO : Make it work with pytorch lightning