In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from config import *
from data import JSONLDataset, ImageDataset, ImageCaptionDataset, CLASS_MAP, get_image_UIDs, crop_image_preprocess_image_text_batch
from path import SPLITS_PATH
from models.vl_encoders import VLE_REGISTRY, VLEncoder
from viz import print_layer_numel
from utils import get_compute_capability

from torch import nn
from torchvision.models import segmentation as segmodels
from functools import partial
from torchvision.transforms._presets import SemanticSegmentation
import torchvision.transforms.v2 as T
from torch.utils.data import DataLoader
from collections import OrderedDict
import math

In [3]:
from vendors.flair.src.flair.train import backward, unwrap_model

---

In [5]:
offset = ...

image_ds = ImageDataset(Path('/home/olivieri/exp/data/data_gen/VOC2012/train_no_aug/images'))
caption_ds = JSONLDataset(Path('/home/olivieri/exp/data/data_gen/VOC2012/train_no_aug/captions.jsonl'))

image_caption_ds = ImageCaptionDataset(
    image_ds,
    caption_ds
)
len(image_caption_ds)

2699

In [6]:
vle: VLEncoder = VLE_REGISTRY.get("flair", device=CONFIG['device'])
# vle.set_vision_trainable_params('visual_proj')
vle.set_vision_trainable_params('mlp+visual_proj')
print_layer_numel(vle.model, print_only_total=True, only_trainable=True)

visual_proj.attn.in_proj_weight: 786,432
visual_proj.attn.in_proj_bias : 1,536
visual_proj.attn.out_proj.weight: 262,144
visual_proj.attn.out_proj.bias: 512
visual_proj.ln_q.weight       : 512
visual_proj.ln_q.bias         : 512
visual_proj.ln_k.weight       : 512
visual_proj.ln_k.bias         : 512
visual_proj.ln_v.weight       : 512
visual_proj.ln_v.bias         : 512
image_post.proj               : 393,216
Total: 1,446,912


In [7]:
center_crop_fn = T.CenterCrop(CONFIG['vle']['image_size'])

collate_fn = partial(
    crop_image_preprocess_image_text_batch,
    crop_fn=center_crop_fn,
    preprocess_images_fn=vle.preprocess_images,
    preprocess_texts_fn=vle.preprocess_texts
)

image_caption_dl = DataLoader(
    image_caption_ds,
    batch_size=CONFIG["vle"]['train']["batch_size"],
    shuffle=False,
    generator=TORCH_GEN.clone_state(),
    collate_fn=collate_fn,
)

In [8]:
# TODO investigate what this rank and world_size is.
criterion = vle.create_loss(
      add_mps_loss = True,
      rank = 0,
      world_size = 1,
      num_caps_per_img = 1
)

In [9]:
exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
include = lambda n, p: not exclude(n, p)

named_parameters = list(vle.model.named_parameters())
gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

optimizer = torch.optim.AdamW(
    [
        {"params": gain_or_bias_params, "weight_decay": 0.},
        {"params": rest_params, "weight_decay": 1e-2}, # authors use 0.5
    ],
    lr=5e-4,
    betas=(0.9, 0.98),
    eps=1e-8,
)

In [None]:
lr = 1e-4
optimizer = torch.optim.AdamW(vle.model.parameters(), lr=lr)

if get_compute_capability() >= 7.0:
    vle.model = torch.compile(vle.model)

for epoch in range(CONFIG["vle"]['train']["num_epochs"]):

    for step, (images, texts) in enumerate(image_caption_dl):
            
            # scs_img = (diff_img*255).to(torch.uint8) # for viewable images

            # TODO handle AMP

            vle.model.train()

            vle_output = vle.encode_and_project(images, texts, broadcast=False)

            optimizer.zero_grad()

            losses = criterion(
                    image_features=vle_output.global_image_token,
                    image_tokens=vle_output.local_image_tokens.clone(),
                    text_features=vle_output.global_text_token.squeeze(1),
                    logit_scale=vle.model.logit_scale,
                    visual_proj=vle.model.visual_proj,
                    logit_bias=vle.model.logit_bias,
                    output_dict=True
            )
            total_loss = sum(losses.values())

            scaler = None # for AMP
            backward(total_loss, scaler)

            grad_clip_norm = None
            if grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(vle.model.parameters(), grad_clip_norm, norm_type=2.0)
            
            optimizer.step()

            if (step+1) % 50 == 0:
                print(f"step {step+1}/{len(image_caption_dl)}, {total_loss=}")

            with torch.no_grad():
                unwrap_model(vle.model).logit_scale.clamp_(0, math.log(100))

    print(f"Epoch {epoch+1}/{CONFIG['vle']['train']['num_epochs']}, {total_loss=}")

step 50/169, total_loss=tensor(9.3256, device='cuda:0', grad_fn=<AddBackward0>)
step 100/169, total_loss=tensor(9.1323, device='cuda:0', grad_fn=<AddBackward0>)
step 150/169, total_loss=tensor(9.1443, device='cuda:0', grad_fn=<AddBackward0>)


KeyboardInterrupt: 