In [None]:
def is_notebook():
    """Report whether the code is running inside an IPython-based session.

    IPython (and therefore Jupyter kernels) makes the name ``__IPYTHON__``
    available as a builtin. Note that this is also true for a plain IPython
    terminal, not only for notebooks.

    Returns:
        bool: True if ``__IPYTHON__`` is defined, False otherwise.
    """
    try:
        __IPYTHON__
    except NameError:
        # Looking up an undefined name raises NameError (not ValueError), so
        # this is the exception that signals a plain Python interpreter.
        return False
    return True

In [None]:
import os
from hydra import initialize, compose
from omegaconf import OmegaConf, DictConfig
import hydra

import sys

# Build the Hydra config with the Compose API in both environments.
#
# The script branch used to wrap a function in @hydra.main and assign its
# return value to `cfg`. The hydra.main wrapper does not pass the task
# function's return value back to the caller, so `cfg` was None and the
# namespace update below raised. Composing directly avoids that; the trade-off
# is that Hydra's run-directory and logging setup (done by hydra.main) is not
# performed, and only plain `key=value` style overrides are understood.
if is_notebook():
    # No command line in a notebook; list overrides here when experimenting,
    # e.g. ["+db=mysql"].
    overrides = []
else:
    # Script mode: interpret the command-line arguments as Hydra overrides.
    overrides = sys.argv[1:]

with initialize(config_path="conf/"):
    cfg = compose(config_name="config.yaml", overrides=overrides)
print(cfg)
# Expose every top-level config key as a module-level variable (lr, gpu, ...).
# At module/cell scope locals() is the globals dict, so this matches the
# original `locals().update(cfg)` while being explicit about the target.
globals().update(cfg)

In [None]:
# Interactive-only tweak: when running under IPython, replace the `gpu` value
# (presumably injected from the config by the earlier namespace update) with a
# hard-coded device index. Script runs keep the configured value.
if is_notebook():
    # override variables to experiment in notebook
    gpu = 2

In [None]:
import os
# Restrict the visible CUDA devices to the selected GPU. This is done before
# torch is imported below so the restriction is already in place when CUDA is
# initialised; keep this statement above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

import pandas as pd
from PIL import Image
import numpy as np

import torch

# Base directory handed to the data module and used as the prefix of the W&B
# save directory later in this notebook.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
root_folder = "/raid/8wiehe/"

In [None]:
# Inspection cell: display the learning rate as stored in the composed config.
dict(cfg)["lr"]

In [None]:
# Inspection cell: display the module-level `lr` variable (presumably the value
# injected from the config; compare with cfg["lr"]).
lr

In [None]:
# When the config leaves batch_size unset, pick a default based on the model.
if batch_size is None:
    # max 32 for single GPU CL on VitB16, 4 for ViT-L/14 (9.2GB)
    default_batch_sizes = {
        "ViT-L/14": 4,
        "ViT-B/16": 32,
        "ViT-B/32": 64,
    }
    # Any model not listed above falls back to 32.
    batch_size = default_batch_sizes.get(model_name, 32)

# Rescale the configured validation interval by the ratio of the reference
# batch size (32) to the batch size actually used, truncating to an int.
batch_scale = 32 / batch_size
val_check_interval = int(cfg["val_check_interval"] * batch_scale)

In [None]:
from clip_utils import load_clip, FinetuneDataModule
from contrastive_learning_utils import LitCLCLIP

# Load the CLIP backbone on the CPU together with its preprocessing transform.
clip_base_model, transform, clip_name = load_clip(model_name, device="cpu")

# Data module for fine-tuning; contrastive-learning mode is always enabled here.
data_module = FinetuneDataModule(
    clip_base_model,
    transform,
    dataset_name=dataset_name,
    mode=mode,
    use_augs=use_augs,
    use_cl=True,
    sent_frac=sent_frac,
    batch_size=batch_size,
    root_folder=root_folder,
    use_ffcv=use_ffcv,
)

# Lightning module wrapping the CLIP model; it receives the number of steps
# per epoch as reported by the data module.
lit_model = LitCLCLIP(
    clip_base_model,
    mode,
    max_epochs,
    lr,
    data_module.steps_per_epoch,
    weight_decay=weight_decay,
    gen_freq=gen_freq,
)
# Attach the dataset's label names to the model.
lit_model.label_names = data_module.label_names

In [None]:
# Inspection cell: display the length of the validation dataloader
# (presumably the number of validation batches).
len(data_module.val_dataloader())

In [None]:

import pytorch_lightning
from pytorch_lightning.loggers import WandbLogger

# for ffcv
if use_ffcv:
    from types import MethodType
    import ffcv_custom_PTL_methods


# Weights & Biases logger; run files are stored under root_folder.
wandb_logger = pytorch_lightning.loggers.WandbLogger(
    name=None,
    save_dir=root_folder + "pytorch_lightning/",
    offline=False,
    id=None,
    anonymous=None,
    version=None,
    project="cl_early_tests",
    log_model=False,
    experiment=None,
    prefix='',
)
# Record the run configuration alongside the metrics.
wandb_logger.log_hyperparams({
    "mode": mode,
    "dataset_name": dataset_name,
    "sent_frac": sent_frac,
    "use_augs": use_augs,
    "batch_size": batch_size,
    "model_name": model_name,
    "use_ffcv": use_ffcv,
})
# log gradients and model topology
wandb_logger.watch(lit_model)


# Use one GPU when CUDA is available, otherwise run on the CPU (gpus=0).
trainer = pytorch_lightning.Trainer(
    val_check_interval=val_check_interval,
    precision=precision,
    logger=wandb_logger,
    max_epochs=max_epochs,
    gpus=int(torch.cuda.is_available()),
    #overfit_batches=1,
    benchmark=True,
)

if use_ffcv:
    # for ffcv: replace the epoch loop's on_run_start/advance with the custom
    # implementations, bound to the loop instance. The module is referenced by
    # the name it was imported under above; the previous code used the
    # undefined name `custom_PTL_methods`, which raised NameError.
    epoch_loop = trainer.fit_loop.epoch_loop
    epoch_loop.on_run_start = MethodType(ffcv_custom_PTL_methods.on_run_start, epoch_loop)
    epoch_loop.advance = MethodType(ffcv_custom_PTL_methods.advance, epoch_loop)


In [None]:
# Run training: fit the Lightning module using the data module.
trainer.fit(lit_model, data_module)
# remove wandb hooks
# NOTE(review): disabled; `model` is not defined in the visible cells —
# presumably this should refer to `lit_model` if it is re-enabled.
#wandb_logger.unwatch(model)

In [None]:
# Inspection cell: display the trainer's `val_check_batch` attribute
# (presumably the validation interval resolved to a batch count — confirm).
trainer.val_check_batch