In [2]:
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

import torch
import torch.nn as nn
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, BlipForConditionalGeneration



In [3]:
from datasets import load_dataset
ds = load_dataset('gorovuha/ru_image_captioning')

In [4]:
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text=item['capt2'], padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        return encoding

In [5]:
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

In [6]:
train_dataset = ImageCaptioningDataset(ds['train'], processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=4, num_workers=9)

val_dataset = ImageCaptioningDataset(ds['test'], processor)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=4, num_workers=9)

test_dataset = ImageCaptioningDataset(ds['validation'], processor)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=4, num_workers=9)

In [7]:
@dataclass
class CFG:
    # dataset
    mode: str = "train"
    num_workes: int = 9
    description: str = '''baseline'''

    # training
    batch_size: int = 4
    wandb_project: str = 'BLIP-FineTune-Ru'
    default_root_dir: str = 'weights'
    checkpoints_dir: str = 'weights/checkpoints'
    lr: float = 5e-6
    weight_decay: float = 1e-1
    max_epochs: int = 10

In [8]:
CFG = CFG()

In [9]:
class BLIP(pl.LightningModule):

    def __init__(self, config, pretrain=None, **kwargs):
        super().__init__()
        self.args = config
        self.learning_rate = config.lr
        self.save_hyperparameters()
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

        if pretrain is not None:
            self.model.load_state_dict(torch.load(pretrain), strict=False)
            print(f'resumed from {pretrain}')


    def forward(self, batch):
        input_ids = batch["input_ids"]
        pixel_values = batch["pixel_values"]
        attention_mask = batch['attention_mask']
        outputs = self.model(input_ids=input_ids, pixel_values=pixel_values,
                             attention_mask=attention_mask, labels=input_ids)

        return outputs

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        pixel_values = batch["pixel_values"]
        attention_mask = batch['attention_mask']
        outputs = self.model(input_ids=input_ids, pixel_values=pixel_values,
                             attention_mask=attention_mask, labels=input_ids)
            
        loss = outputs.loss

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        pixel_values = batch["pixel_values"]
        attention_mask = batch['attention_mask']
        outputs = self.model(input_ids=input_ids, pixel_values=pixel_values,
                             attention_mask=attention_mask, labels=input_ids)        
        loss = outputs.loss

        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay)
        lr_scheduler = {
                        'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                                'min',
                                                                                factor=0.5,
                                                                                patience=100,
                                                                                threshold=0.07,),
                        'interval': 'epoch',
                        'frequency': 1,
                        'monitor': "train_loss",
                        'name': 'lr/reduce_on_plateau'
                        }
        return [optimizer], [lr_scheduler]

In [17]:
# 65cbe958b1656ff763c0cd06331c4548953c2806
wandb_name = f"{CFG.description}_try_2"

wandb.login()
wandb_logger = WandbLogger(project = CFG.wandb_project, name=wandb_name)

wandb.init()
wandb.config.update({k: v for k, v in CFG.__dict__.items() if not k.startswith("__")})



VBox(children=(Label(value='0.002 MB of 0.006 MB uploaded\r'), FloatProgress(value=0.34401524337432876, max=1.…

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇█
lr/reduce_on_plateau,▁▁▁▁▁▁▁▁▁▁▁
train_loss_epoch,█▃▂▂▂▂▁▁▁▁
train_loss_step,█▂▂▂▂▂▂▂▂▂▁▂▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss_epoch,█▅▃▁▁▂▁▂▂▄
val_loss_step,▆▅█▆▄▇▅▅▃▃▃▄▅▃▄▃▃▂▂▃▆▆▁▃▂▂▂▃▅▄▂▃▆▅▄▂▅▅▄▃

0,1
epoch,10.0
lr/reduce_on_plateau,5e-05
train_loss_epoch,0.07096
train_loss_step,0.04457
trainer/global_step,3949.0
val_loss_epoch,0.27335
val_loss_step,0.50788


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112736311123526, max=1.0…

In [18]:
model = BLIP(config=CFG, pretrain=None)

lr_monitor = LearningRateMonitor(logging_interval='epoch')

checkpoint_callback = ModelCheckpoint(
    dirpath=CFG.checkpoints_dir,
    monitor='val_loss',
    save_top_k=3,
    filename=wandb_name,
    mode='min',
    save_weights_only=True,
)

In [19]:
trainer = pl.Trainer(logger = wandb_logger,
                     default_root_dir = CFG.default_root_dir,
                     accelerator = 'gpu',
                     callbacks = [checkpoint_callback, lr_monitor],
                     max_epochs = CFG.max_epochs,
                     check_val_every_n_epoch = 1,
                     )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [20]:
trainer.fit(model, train_dataloaders = train_dataloader, val_dataloaders = val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                         | Params | Mode
--------------------------------------------------------------
0 | model | BlipForConditionalGeneration | 247 M  | eval
--------------------------------------------------------------
247 M     Trainable params
0         Non-trainable params
247 M     Total params
989.656   Total estimated model params size (MB)


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

In [11]:
path_to_model = CFG.checkpoints_dir+'/baseline_try_2-v4.ckpt'

model = BLIP(config=CFG, pretrain=path_to_model)

  return self.fget.__get__(instance, owner)()


resumed from weights/checkpoints/baseline_try_2-v4.ckpt


In [14]:
from tqdm.notebook import tqdm

av_loss, count = 0, 0
with torch.no_grad():
    for idx, batch in tqdm(enumerate(test_dataloader)):
        outputs = model(batch)
        loss = outputs.loss
        av_loss += loss.item()
        count += 1
        print(av_loss/count)

0it [00:00, ?it/s]

12.328524589538574
12.132808208465576


KeyboardInterrupt: 

In [None]:
trainer.push_to_hub()