In [None]:
# %pip install -q git+https://github.com/openai/CLIP.git
# %pip install -q timm

In [1]:
import os
import glob

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau


from transformers import CLIPProcessor, CLIPModel
from utils.data import Transform, ImageTextDataset, collate_fn

In [2]:
device = 0 if torch.cuda.is_available() else "cpu"
device

0

In [3]:
DATA_ROOT = 'data'

In [4]:
clip_checkpoint = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(clip_checkpoint).to(device)
processor = CLIPProcessor.from_pretrained(clip_checkpoint)
tokenizer = processor.tokenizer

In [5]:
train_loader = DataLoader(
    ImageTextDataset(DATA_ROOT, "train", transform=Transform(224, True)),
    batch_size=32,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn
)

eval_loader = DataLoader(
    ImageTextDataset(DATA_ROOT, "eval", transform=Transform(224, False)),
    batch_size=32,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn
)

In [6]:
class AvgMeter:
    def __init__(self):
        self.reset()
    def reset(self):
        self.avg, self.sum, self.count = 0, 0, 0
    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count
    def __repr__(self):
        return f'{self.avg:.4f}'

def train_epoch(model, loader, optimizer):
    loss_meter = AvgMeter()
    pbar = tqdm(loader, total=len(loader))
    for batch in pbar:
        optimizer.zero_grad()
        
        batch = {k: v.to(device) for k, v in batch.items()}
        output = model(**batch, return_loss=True)
        
        loss = output.loss
        loss.backward()
        optimizer.step()
        
        loss_meter.update(loss.item(), batch['pixel_values'].size(0))
        
        pbar.set_postfix(train_loss=loss_meter.avg)
    return loss_meter
    
def eval_epoch(model, loader):
    loss_meter = AvgMeter()
    pbar = tqdm(loader, total=len(loader))
    for batch in pbar:
        batch = {k: v.to(device) for k, v in batch.items()}
        output = model(**batch, return_loss=True)
        loss = output.loss
        loss_meter.update(loss.item(), batch['pixel_values'].size(0))
        pbar.set_postfix(eval_loss=loss_meter.avg)
    return loss_meter

In [7]:
lr = 2e-6
num_epochs = 10
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.2)
lr_scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=2, factor=0.5)

for epoch in range(num_epochs):
    cur_lr = optimizer.param_groups[0]['lr']
    model.train()
    train_loss = train_epoch(model, train_loader, optimizer)
    
    model.save_pretrained(f'./out/lr{cur_lr}_{epoch}/')
    model.eval()
    with torch.no_grad():
        eval_loss = eval_epoch(model, eval_loader)
        print(f'EPOCH{epoch} eval loss: {eval_loss}')
        lr_scheduler.step(eval_loss.avg)

  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]



EPOCH0 eval loss: 0.4128


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

EPOCH1 eval loss: 0.3843


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

EPOCH2 eval loss: 0.3828


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

EPOCH3 eval loss: 0.3866


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

EPOCH4 eval loss: 0.3955


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

EPOCH5 eval loss: 0.3969


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

EPOCH6 eval loss: 0.3921


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

EPOCH7 eval loss: 0.3896


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

EPOCH8 eval loss: 0.4261


  0%|          | 0/621 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

EPOCH9 eval loss: 0.4507
