In [1]:
# adding parent directory to sys.path to import modules
import sys
sys.path.append('..')

In [2]:
from model.model import FineTuner
from model.dataloader import DataModule
from utils import LogPredictionSamples
from utils import EarlyStopping

In [3]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

In [4]:
dm_hparams = dict(
        train_path='../data/train.csv',
        val_path='../data/val.csv',
        test_path='../data/test.csv',
        tokenizer_name_or_path='google/mt5-small',
        max_source_length=128,
        max_target_length=128,
        train_batch_size=2,
        val_batch_size=2,
        test_batch_size=2
    )
dm = DataModule(**dm_hparams)

In [5]:
dm.setup()

In [6]:
for batch in dm.val_dataloader():
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']
    break

In [7]:
model_hparams = dict(
        learning_rate=2e-5,
        model_name_or_path='google/mt5-small',
        eval_beams=4,
        tgt_max_seq_len=128,
        tokenizer=dm.tokenizer
    )
model=FineTuner(**model_hparams)

In [8]:
output = model(input_ids, attention_mask, labels)

In [9]:
output['logits'].shape

torch.Size([2, 128, 250112])

In [10]:
loss = output['loss']

In [11]:
loss.item()

33.90782928466797

In [12]:
input_text, pred_text, ref_text = model._generative_step(batch)

In [13]:
input_text

['2017 actual total 154983.0',
 'black male workers transportation and warehousing 11.1']

In [14]:
pred_text

['<extra_id_0>', '<extra_id_0>']

In [15]:
ref_text

['if the pre-production development activities were to be included, the fy 2017 r&d budget authority would have been $155.0 billion instead of the $125.3 billion in actual budget authority',
 'among black male workers, 11% were employed in the transportation and warehousing sector']

In [16]:
for key in batch:
    batch[key] = batch[key].cuda()

In [17]:
batch['input_ids'].device

device(type='cuda', index=0)

In [18]:
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')
trainer_hparams = dict(
    gpus=1,
    strategy='dp',
    max_epochs=5,
    num_sanity_val_steps=3,
    logger=WandbLogger(save_dir='../experiments'),
    callbacks=[checkpoint_callback]
)
trainer = pl.Trainer(**trainer_hparams)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [19]:
trainer.fit(model, dm)

  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshivprasad[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name  | Type                        | Params
------------------------------------------------------
0 | model | MT5ForConditionalGeneration | 300 M 
------------------------------------------------------
300 M     Trainable params
0         Non-trainable params
300 M     Total params
1,200.707 Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [20]:
trainer.validate(model, dm)

  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Validating: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_bleu': 0.01892796717584133,
 'val_loss': 27.61751365661621,
 'val_loss_epoch': 27.61751365661621}
--------------------------------------------------------------------------------


[{'val_loss': 27.61751365661621,
  'val_loss_epoch': 27.61751365661621,
  'val_bleu': 0.01892796717584133}]

In [21]:
from sacrebleu.metrics import BLEU

In [22]:
bleu = BLEU()

In [23]:
ref_text = ["howru"]
pred_text = ["how r u"]

In [24]:
bleu.corpus_score(pred_text, [ref_text])

BLEU = 0.00 0.0/0.0/0.0/0.0 (BP = 1.000 ratio = 3.000 hyp_len = 3 ref_len = 1)

In [25]:
a = [1, 2, 3]

In [26]:
print([zip(a, a)])

[<zip object at 0x7f5b8bedac30>]
