In [1]:
# %pip install -q git+https://github.com/openai/CLIP.git
# %pip install -q timm

In [2]:
import os
import glob

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

from transformers import CLIPProcessor, CLIPModel
from utils import (
    collate_fn,
    Transform, 
    ImageTextDataset, 
    CLIP_CHECKPOINT, 
    DATA_ROOT, 
    AvgMeter
)

In [3]:
device = 0 if torch.cuda.is_available() else "cpu"
device

0

In [4]:
def train_epoch(model, loader, optimizer) -> AvgMeter:
    loss_meter = AvgMeter()
    pbar = tqdm(loader, total=len(loader))
    for batch in pbar:
        optimizer.zero_grad()
        
        batch = {k: v.to(device) for k, v in batch.items()}
        output = model(**batch, return_loss=True)
        
        loss = output.loss
        loss.backward()
        optimizer.step()
        
        loss_meter.update(loss.item(), batch['pixel_values'].size(0))
        
        pbar.set_postfix(train_loss=loss_meter.avg)
    return loss_meter
    
def eval_epoch(model, loader) -> AvgMeter:
    loss_meter = AvgMeter()
    pbar = tqdm(loader, total=len(loader))
    for batch in pbar:
        batch = {k: v.to(device) for k, v in batch.items()}
        output = model(**batch, return_loss=True)
        loss = output.loss
        loss_meter.update(loss.item(), batch['pixel_values'].size(0))
        pbar.set_postfix(eval_loss=loss_meter.avg)
    return loss_meter

In [5]:
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT).to(device)

batch_size=32
train_loader = DataLoader(
    ImageTextDataset(DATA_ROOT, "train", transform=Transform(224, True)),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn
)

eval_loader = DataLoader(
    ImageTextDataset(DATA_ROOT, "eval", transform=Transform(224, False)),
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)

lr = 2e-6
num_epochs = 7
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.1)
lr_scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=2, factor=0.5)

model.eval()
with torch.no_grad():
    train_loss = eval_epoch(model, train_loader)
    eval_loss = eval_epoch(model, eval_loader)
    print(f'EPOCH -1; LR {lr}; LOSS train: {train_loss}, eval: {eval_loss}.')

for epoch in range(num_epochs):
    cur_lr = optimizer.param_groups[0]['lr']
    model.train()
    train_loss = train_epoch(model, train_loader, optimizer)
    
    model.save_pretrained(f'./out/lr{lr}w-1_b{batch_size}x{epoch}/')
    model.eval()
    with torch.no_grad():
        eval_loss = eval_epoch(model, eval_loader)
        print(f'EPOCH {epoch}; LR {cur_lr}; LOSS train: {train_loss}, eval: {eval_loss}.')
        lr_scheduler.step(eval_loss.avg)

  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]



EPOCH -1; LR 2e-06; LOSS train: 0.9393, eval: 0.7486.


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

EPOCH 0; LR 2e-06; LOSS train: 0.6291, eval: 0.4101.


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

EPOCH 1; LR 2e-06; LOSS train: 0.5267, eval: 0.3895.


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

EPOCH 2; LR 2e-06; LOSS train: 0.4740, eval: 0.3884.


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

EPOCH 3; LR 2e-06; LOSS train: 0.4236, eval: 0.3829.


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

EPOCH 4; LR 2e-06; LOSS train: 0.3969, eval: 0.3976.


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

EPOCH 5; LR 2e-06; LOSS train: 0.3717, eval: 0.3951.


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

EPOCH 6; LR 2e-06; LOSS train: 0.3474, eval: 0.4219.
