In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `src` package) are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

import numpy as np
from skimage import data
from skimage.transform import resize

import matplotlib.pyplot as plt

import torch

from src.reader import KineticDataset, KineticDatasetVideo

from src.vqgan import ViTVQGAN
from src.my_model import MaskCode, Encoder, Decoder, MaskVideo
from src.trainer import AttentionMaskModeling


In [3]:
# Dataset root on the local machine and the split to read from it.
path = "/mnt/e/kinetics-dataset/k400"
split = "train"

# Build the dataset through the `get_ds` factory. The trailing 16 is
# presumably the number of frames per clip, mirroring the direct-constructor
# form `KineticDatasetVideo(path, split, n_frames=16)` -- TODO confirm.
ds = KineticDatasetVideo.get_ds(path, split, 16)

# Shuffled loader with batches of 2 clips, fed by 4 worker processes.
ds_loader = DataLoader(
    dataset=ds,
    batch_size=2,
    shuffle=True,
    num_workers=4,
)

100


In [None]:
# Grab one batch from the loader for the manual smoke tests in later cells.
batch = next(iter(ds_loader))

# Only the first element of the batch is used; the second is ignored.
# NOTE: the throw-away values are deliberately NOT bound to `_`, because
# assigning `_` shadows IPython's "last output" variable for the session.
code = batch[0]

# Same tensor moved to the GPU, paired with None in place of the second
# batch element, in the (input, target)-style tuple passed to training_step.
_batch = batch[0].cuda(), None

# The unpacking requires a rank-5 tensor. C is presumably the channel
# axis of (batch, time, channel, height, width) -- TODO confirm.
B, T, C, H, W = code.shape
code.shape

In [4]:
# Hyper-parameters shared by every model-construction cell in this notebook.
is_finetune = False

# Image / patch geometry and the main transformer sizes.
image_size = (256, 256)
patch_size = (8, 8)
depth = 12
heads = 12
dim = 768
embed_dim = 32

# Window and grid extents plus the temporal transformer sizes.
window_size = (3, 3)
length = 32
height = 32
width = 32
temporal_depth = 4
temporal_heads = 8
temporal_dim = 128

# Passed as `n_codes` to the models (presumably the codebook size -- TODO confirm).
n_codes = 8192

# Checkpoint handed to the models as `vitvq_path`.
vitvq_path = "./checkpoint/imagenet_vitvq_base.ckpt"

In [None]:
# Smoke test: build the Lightning module from the shared hyper-parameters,
# move it to the GPU and run a single training step by hand (no Trainer)
# on the `_batch` tuple prepared in an earlier cell.
lightning_model = AttentionMaskModeling(
    model=MaskVideo(
        is_finetune=is_finetune,
        image_size=image_size,
        patch_size=patch_size,
        depth=depth,
        heads=heads,
        dim=dim,
        n_codes=n_codes,
        embed_dim=embed_dim,
        window_size=window_size,
        length=length,
        height=height,
        width=width,
        temporal_depth=temporal_depth,
        temporal_heads=temporal_heads,
        temporal_dim=temporal_dim,
        drop_prob=0.1,
        depth_prob=0.1,
        vitvq_path=vitvq_path,
    ),
    top_p=0.95,
)
lightning_model.cuda()
# Last expression of the cell: the step's return value is displayed.
lightning_model.training_step(_batch, 0)

# Scratch code kept for reference (mask / attention-logit inspection):
# frames, _ = _batch  # (B, T, HW)
# mask = lightning_model.get_mask_from_logits(frames)  # (B, T', HW)

# _, _, attn_logits = lightning_model.teacher.encoder(frames, None, False)
# attn_logits.shape

In [None]:
# Stand-alone encoder built from the shared hyper-parameters, for
# inspecting its outputs in isolation from the full Lightning module.
encoder = Encoder(
    is_finetune=is_finetune,
    image_size=image_size,
    patch_size=patch_size,
    depth=depth,
    heads=heads,
    dim=dim,
    window_size=window_size,
    length=length,
    height=height,
    width=width,
    temporal_depth=temporal_depth,
    temporal_heads=temporal_heads,
    temporal_dim=temporal_dim,
    n_codes=n_codes,
    embed_dim=embed_dim,
    drop_prob=0.1,
    depth_prob=0.1,
    vitvq_path=vitvq_path,
)
# Move to the GPU. The trailing semicolon suppresses the module repr that
# would otherwise be displayed as the cell output, replacing the earlier
# `print()` workaround, which emitted a stray blank line.
encoder.cuda();

In [None]:
# Forward-only shape check of the encoder under fp16 autocast.
# Gradients are disabled: nothing here calls backward(), and keeping the
# autograd graph of the whole forward pass alive would only waste GPU memory.
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
    # The encoder returns three values; `code` is fed to the decoder check.
    code, logits, attn = encoder(_batch[0])
code.shape

In [None]:
# Stand-alone decoder built from the shared hyper-parameters, for
# inspecting its output in isolation from the full Lightning module.
decoder = Decoder(
    is_finetune=is_finetune,
    image_size=image_size,
    patch_size=patch_size,
    depth=depth,
    heads=heads,
    dim=dim,
    n_codes=n_codes,
    embed_dim=embed_dim,
    drop_prob=0.1,
    depth_prob=0.1,
    vitvq_path=vitvq_path,
)
# Move to the GPU. The trailing semicolon suppresses the module repr that
# would otherwise be displayed as the cell output, replacing the earlier
# `print()` workaround, which emitted a stray blank line.
decoder.cuda();

In [None]:
# Forward-only shape check of the decoder under fp16 autocast, applied to
# the `code` tensor produced by the encoder check.
# Gradients are disabled: nothing here calls backward(), so recording the
# autograd graph would only consume GPU memory.
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
    output = decoder(code)
output.shape

In [5]:
# Toggle for a quick CPU / fp32 / fast_dev_run debugging pass.
is_dev = False

# Seed the Python, NumPy and torch RNGs (and dataloader workers) before the
# model is created, so weight initialisation and the training run are
# repeatable across kernel restarts.
pl.seed_everything(42, workers=True)

# Fresh Lightning module built from the shared hyper-parameters.
lightning_model = AttentionMaskModeling(
    model=MaskVideo(
        is_finetune=is_finetune,
        image_size=image_size, patch_size=patch_size,
        depth=depth, heads=heads, dim=dim,
        n_codes=n_codes, embed_dim=embed_dim,

        window_size=window_size,
        length=length, height=height, width=width,
        temporal_depth=temporal_depth, temporal_heads=temporal_heads, temporal_dim=temporal_dim,

        drop_prob=0.1, depth_prob=0.1,
        vitvq_path=vitvq_path,
    ),
    top_p=0.95,
)

# Optional Weights & Biases logging (disabled):
# wandb_logger = WandbLogger(
#     project="semcom",
# )
# wandb_logger.experiment.config.update({
#     "dim": dim,
#     "depth": depth
# })

trainer = pl.Trainer(
    # training settings
    max_epochs=25,
    # No validation dataloader is supplied to fit() below.
    val_check_interval=1.0,
    accelerator="cpu" if is_dev else "gpu",
    precision="32-true" if is_dev else "16-mixed",
    accumulate_grad_batches=1,
    gradient_clip_val=1.0,
    # logging settings
    default_root_dir="./checkpoints",
    # logger=wandb_logger,
    callbacks=[
        # Keep the three checkpoints with the highest `train_acc`.
        # NOTE(review): assumes AttentionMaskModeling logs a metric named
        # `train_acc` -- confirm against src.trainer.
        ModelCheckpoint(
            monitor="train_acc",
            dirpath="./trained",
            filename="semcom-{epoch:02d}-{train_acc:.2f}",
            save_top_k=3,
            mode="max",
        )
    ],
    # dev setting
    fast_dev_run=is_dev,
)
# NOTE(review): Lightning's startup log recommends calling
# torch.set_float32_matmul_precision('medium' | 'high') on Tensor Core GPUs;
# it is left unset here so fp32 matmul numerics stay unchanged.
trainer.fit(
    lightning_model,
    train_dataloaders=ds_loader,
    # val_dataloaders=None,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/leonard/anaconda3/envs/semcom/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
You are using a CUDA device ('NVIDIA GeForce RTX 3090 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=25` reached.
