## Import Libraries

In [None]:
%load_ext autoreload
%autoreload 2
from tqdm import tqdm
from typing import List
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers import LightningLoggerBase
import pyrootutils
# Resolve the project root with pyrootutils: the search starts at
# `search_from` and looks for a directory containing one of the `indicator`
# entries (".git" or "pyproject.toml").
# `pythonpath=True` / `dotenv=True` presumably add the root to sys.path and
# load a `.env` file -- confirm against the installed pyrootutils version.
# NOTE(review): `search_from` is a hard-coded, user-specific absolute path,
# so this cell only works on the machine/account it was written for.
root = pyrootutils.setup_root(
    search_from='/vol/aimspace/users/rohn/vlp/',
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

from src import utils

## Load Hydra Config

In [None]:
import hydra
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf

# Hydra state is global to the process: drop any previous initialization so
# this cell can be re-run without restarting the kernel.
global_hydra = GlobalHydra.instance()
if global_hydra.is_initialized():
    global_hydra.clear()

# Compose the "pretraining" config from ../configs (path is relative to this
# notebook), including the `hydra` node itself.
initialize(config_path="../configs")
cfg = compose(config_name="pretraining", return_hydra_config=True)

# Register the composed config with HydraConfig; per the original author this
# is required for interpolation to work.
HydraConfig.instance().set_config(cfg)

# Dump the full composed config for inspection.
print(OmegaConf.to_yaml(cfg))

## Load all Lightning Modules

In [None]:
# Show the datamodule section of the config, then build the object it
# describes via Hydra's instantiate utility.
print(OmegaConf.to_yaml(cfg.datamodule))
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule)

In [None]:
# Show the model section of the config, then build the object it describes
# via Hydra's instantiate utility.
print(OmegaConf.to_yaml(cfg.model))
model: LightningModule = hydra.utils.instantiate(cfg.model)

In [None]:
# Print the PROJECT_ROOT environment variable through an IPython shell escape.
# root set by pyrootutils
!echo $PROJECT_ROOT

In [None]:
# Show the callbacks section of the config, then build the callback list with
# the project's own helper (`src.utils.instantiate_callbacks`).
print(OmegaConf.to_yaml(cfg.callbacks))
callbacks: List[Callback] = utils.instantiate_callbacks(cfg.callbacks)

## Testing Data Module

In [None]:
# Run the datamodule's setup hook manually (normally the Trainer calls it).
# NOTE(review): an empty string is passed as the stage; confirm that this
# datamodule's `setup` treats '' the way it would treat 'fit' / None.
datamodule.setup('')

In [None]:
# Print the sizes of the datasets behind the train and val dataloaders.
# Each dataloader is built once and its dataset length reused, instead of
# constructing a new dataloader for every print statement.
# Note: the reported numbers are sample counts (len of `.dataset`), not the
# number of batches.
train_size = len(datamodule.train_dataloader().dataset)
val_size = len(datamodule.val_dataloader().dataset)
print('total number of samples:', train_size + val_size)
print('train dataloader size:', train_size)
print('val dataloader size:', val_size)

In [None]:
# Fetch a single batch from the train dataloader for the inspection cells
# below. `next(iter(...))` states that intent directly and avoids leaving an
# abandoned tqdm progress bar behind, as the previous `for ... break` did.
batch = next(iter(datamodule.train_dataloader()))

In [None]:
# Display the shapes of the batch's image entry and of the tokenized input ids.
batch['image'].shape, batch['tokenized_data']['input_ids'].shape

In [None]:
# Integer-divide the train dataset length by 32.
# NOTE(review): 32 is presumably the batch size -- confirm against
# cfg.datamodule rather than relying on the literal.
len(datamodule.train_dataset) // 32

In [None]:
# Sum of the lengths of the train and val dataloaders (for a PyTorch
# DataLoader, the length is its number of batches).
len(datamodule.train_dataloader()) + len(datamodule.val_dataloader())

In [None]:
# Scratch arithmetic on two hard-coded counts.
# NOTE(review): the literals are presumably copied from earlier cell outputs;
# their meaning is not recorded here -- confirm or replace with the variables.
11835 - 11271

In [None]:
# Length of the val dataloader (for a PyTorch DataLoader, its number of batches).
len(datamodule.val_dataloader())

## Testing Model

In [None]:
# Call the wrapped text model directly, unpacking the batch's tokenized data
# as keyword arguments, and display whatever it returns.
model.net.text_model.model(**batch['tokenized_data'])

In [None]:
# Run the full model on the batch; it returns two values.
# NOTE(review): presumably v is the image embedding and u the text embedding
# -- confirm against the model's forward().
v, u = model(batch)

In [None]:
# Display the shape of the first model output.
v.shape

In [None]:
# Display the shape of the second model output.
u.shape

## Testing Callbacks (Validating Config)

In [None]:
# Dump the callbacks section of the config as YAML so it can be validated by eye.
print(OmegaConf.to_yaml(cfg.callbacks))

## Testing Paths

## Testing Trainer Fit

## ....