# Baseline

- Text tower: pretrained ELECTRA-Small discriminator
- Image tower: ViT-Tiny pretrained with DINO on ImageNet-100
- Loss: the standard (classic) CLIP contrastive loss
- Fine-tune on Flickr30k
 

In [1]:
import os, sys, random
import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path

from functools import partial
import itertools
from tqdm.autonotebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import ElectraTokenizerFast, ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining
from torchvision.datasets import Flickr30k
from torchvision import transforms 
from hugdatafast import datasets
from hugdatafast.fastai import HF_Datasets
from fastai.text.all import *
from timm.models.layers import PatchEmbed
from timm.models.vision_transformer import VisionTransformer, _cfg

from utils import ELECTRADataProcessor

import wandb
from fastai.callback.wandb import WandbCallback


  from tqdm.autonotebook import tqdm


In [2]:
# Sanity check: is a CUDA device visible to PyTorch? (CFG.device falls back to CPU otherwise.)
torch.cuda.is_available()

True

## Config

In [8]:
class CFG:
    """Run-wide settings for the Flickr30k CLIP baseline, read as class attributes.

    Size-dependent ELECTRA hyperparameters (``lr_elec``, ``layer_lr_decay_elec``,
    ``max_length_elec``) are not declared here; they are attached to this class
    after it is defined.
    """

    # --- data paths & hardware ---
    debug = False
    image_path = "./flickr30k-images"
    captions_path = "./flickr30k/results_20130124.token"
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- experiment naming & seed ---
    base_name = "vanilla"
    seed = 2022

    # --- text tower (ELECTRA) ---
    pretrained_checkpoint_electra = None  # None to use the model from HuggingFace
    adam_bias_correction_elec = False
    size_elec = "small"

    # --- logging ---
    logger = "wandb"

    # --- data loading ---
    datas = ["flickr30k"]
    num_workers = 3


In [9]:
# Size-dependent ELECTRA fine-tuning hyperparameters. `i` is the column of
# CFG.size_elec in the parallel lists below.
try:
    i = ['small', 'base', 'large'].index(CFG.size_elec)
except ValueError:
    # list.index's own message ("'x' is not in list") does not say which setting is wrong.
    raise ValueError(
        f"CFG.size_elec must be one of 'small', 'base', 'large'; got {CFG.size_elec!r}"
    ) from None
CFG.lr_elec = [3e-4, 1e-4, 5e-5][i]
CFG.layer_lr_decay_elec = [0.8, 0.8, 0.9][i]
CFG.max_length_elec = [128, 512, 512][i]

# All public (HuggingFace) checkpoints are ++ models, which use max_length 512,
# so override the per-size value when no local checkpoint is given.
if CFG.pretrained_checkpoint_electra is None:
    CFG.max_length_elec = 512


In [10]:
# huggingface/transformers
# Text tower: tokenizer and model config of the HuggingFace ELECTRA
# discriminator whose size is selected by CFG.size_elec.
hf_tokenizer_elec, electra_config = (
    ElectraTokenizerFast.from_pretrained(f"google/electra-{CFG.size_elec}-discriminator"),
    ElectraConfig.from_pretrained(f"google/electra-{CFG.size_elec}-discriminator"),
)


In [11]:
class LightWandbCallback(Callback):
    """Minimal fastai callback that logs only the final epoch's metrics to W&B.

    After the last epoch it logs every Recorder metric except ``train_loss``,
    ``epoch`` and ``time`` to the given W&B run; when training ends it flushes
    one last step and finishes that run.

    Args:
        run: the W&B run to log to (presumably the object returned by
            ``wandb.init``; it must provide ``log`` and ``finish``).
    """
    def __init__(self, run):
        # Keep the W&B run under its own name. fastai's Callback reserves
        # `self.run` as its enable flag (events are only dispatched while it is
        # truthy) and resets it to True after every fit, which would overwrite
        # the W&B run object and break a second `fit` (see fastai.callback.core).
        self.wandb_run = run

    def after_epoch(self):
        "Log the Recorder's metrics, but only once the final epoch has finished."
        if self.epoch != self.n_epoch - 1:
            return
        skipped = {'train_loss', 'epoch', 'time'}
        metrics = {name: value
                   for name, value in zip(self.recorder.metric_names, self.recorder.log)
                   if name not in skipped}
        # Log through the stored run so logging and `finish` target the same run.
        self.wandb_run.log(metrics)

    def after_fit(self):
        "Flush the last step and close the W&B run."
        self.wandb_run.log({})  # ensure sync of last step
        self.wandb_run.finish()