## Import Libraries

In [1]:
%load_ext autoreload
%autoreload 2
from tqdm import tqdm
from typing import List
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers import LightningLoggerBase
import pyrootutils
# Locate the project root by searching for a marker file, starting from the
# project checkout. Per the pyrootutils docs, `pythonpath=True` puts the root
# on sys.path and `dotenv=True` loads the root's `.env` file; presumably this
# is why the call runs before `from src import utils`.
# NOTE(review): the search path is machine-specific, so the notebook will not
# run from another checkout without editing it.
PROJECT_DIR = '/vol/aimspace/users/rohn/vlp/'
ROOT_MARKERS = [".git", "pyproject.toml"]
root = pyrootutils.setup_root(
    search_from=PROJECT_DIR,
    indicator=ROOT_MARKERS,
    pythonpath=True,
    dotenv=True,
)
from matplotlib import pyplot as plt
from src import utils

## Load Hydra Config

In [None]:
import hydra
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf

# Step 1: the kernel keeps Hydra's global singleton alive between cell runs,
# so clear any previous initialization before initializing again.
hydra_state = GlobalHydra.instance()
if hydra_state.is_initialized():
    hydra_state.clear()

# Step 2: point Hydra at the config directory and compose the pretraining
# config with the `maskvlm` experiment appended. The `hydra` node is kept in
# the result because it is registered with HydraConfig below.
initialize(config_path="../configs")
cfg = compose(
    config_name="pretraining",
    overrides=["+experiment=maskvlm"],
    return_hydra_config=True,
)

# Step 3: register the composed config with HydraConfig; this is required for
# interpolation to work outside of a regular Hydra application run.
HydraConfig.instance().set_config(cfg)

# print(OmegaConf.to_yaml(cfg))

## Initialize components

In [None]:
# Dump the datamodule section of the composed config as it was loaded.
datamodule_cfg = cfg.datamodule
print(OmegaConf.to_yaml(datamodule_cfg))

# Override the train/val split for testing; this edits `cfg` in place.
# NOTE(review): [20, 20] presumably means 20 train and 20 val samples — confirm
# against how the datamodule interprets `train_val_split`.
datamodule_cfg.train_val_split = [20, 20]

# Build the datamodule from its config node and run its setup for the 'fit' stage.
datamodule: LightningDataModule = hydra.utils.instantiate(datamodule_cfg)
datamodule.setup('fit')

In [None]:
# Dump the model section of the config, then build the LightningModule from
# that same node via Hydra's instantiate.
model_yaml = OmegaConf.to_yaml(cfg.model)
print(model_yaml)
model: LightningModule = hydra.utils.instantiate(cfg.model)

In [None]:
# Dump the callbacks section of the config, then let the project helper
# `utils.instantiate_callbacks` turn it into the callbacks passed to the Trainer.
callbacks_cfg = cfg.callbacks
print(OmegaConf.to_yaml(callbacks_cfg))
callbacks: List[Callback] = utils.instantiate_callbacks(callbacks_cfg)

In [None]:
# Show the logger configuration that the experiment would normally use.
print(OmegaConf.to_yaml(cfg.logger))

# Disable logging for now by dropping the logger config.
# NOTE(review): this mutates `cfg` in place; after it runs `cfg.logger` is None,
# so re-running this cell no longer sees the original logger config.
cfg.logger = None

# NOTE(review): annotated as a list to match the plural name and the sibling
# `callbacks: List[Callback]` — confirm against `utils.instantiate_loggers`,
# including what it returns when it is given `None`.
loggers: List[LightningLoggerBase] = utils.instantiate_loggers(cfg.logger)

In [None]:
# Dump the trainer section of the config as it was loaded.
trainer_cfg = cfg.trainer
print(OmegaConf.to_yaml(trainer_cfg))

# Overrides for this test run (applied to `cfg` in place): one epoch, on CPU.
trainer_cfg.max_epochs = 1
trainer_cfg.accelerator = 'cpu'

# Build the Trainer from its config node, handing it the callbacks and loggers
# created in the earlier cells.
trainer: Trainer = hydra.utils.instantiate(
    trainer_cfg,
    callbacks=callbacks,
    logger=loggers,
)

## Testing Data Module

In [None]:
# Print the sizes of the train and val datasets. These are sample counts taken
# from each dataloader's `.dataset`, not numbers of batches.
# Each split's dataloader is built once and its dataset length reused, rather
# than constructing a new dataloader for every print.
train_size = len(datamodule.train_dataloader().dataset)
val_size = len(datamodule.val_dataloader().dataset)
print('total number of samples:', train_size + val_size)
print('train dataloader size:', train_size)
print('val dataloader size:', val_size)

In [9]:
# for batch in tqdm(datamodule.train_dataloader()):
#     break

In [None]:
# Get one batch: build a fresh training dataloader and take its first item.
batch = next(iter(datamodule.train_dataloader()))
# Show which fields the batch provides (it must expose `.keys()`, i.e. be dict-like).
print(batch.keys())

In [11]:
# print some sentences and their masked versions
# tokenizer = datamodule.train_dataset.subset.dataset.tokenizer
# for i in range(5):
#     # print the text
#     n = batch['tokenized_data']['attention_mask'][i].sum().item()
#     print(str(i) + ' #'*20)
#     print(f"Original: {tokenizer.decode(batch['tokenized_data']['input_ids'][i][:n])}")
#     print(f"Masked: {tokenizer.decode(batch['masked_tokenized_data']['input_ids'][i][:n])}")
#     print(f'Mask: {batch["text_mask"][i][:n]}')

#     # show the images
#     fig, axs = plt.subplots(1, 3, figsize=(15, 5))
#     # set axis off
#     for ax in axs:
#         ax.axis('off')
#     axs[0].imshow(batch['image'][i].permute(1, 2, 0))
#     axs[0].set_title('Original')
#     axs[1].imshow(batch['masked_img'][i].permute(1, 2, 0))
#     axs[1].set_title('Masked')
#     axs[2].imshow(batch['img_mask'][i].permute(1, 2, 0), cmap='gray')
#     axs[2].set_title('Mask')
#     plt.show()

## Testing Model

In [12]:
# txt_logits, reconstructed_img, z_img, z_txt, itm_labels, itm_gt = model.net(batch)

## Testing Module

In [13]:
# model.validation_step(batch, 0)

## Testing Trainer

In [None]:
# Run the full fit loop with the model and datamodule built in the earlier
# cells; those cells limit the run to a single epoch on CPU.
trainer.fit(model=model, datamodule=datamodule)