In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell executes, so edits to the project's src/ modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

import numpy as np
from skimage import data
from skimage.transform import resize

import matplotlib.pyplot as plt

import torch

from src.reader import KineticDataset, KineticDatasetVideo

from src.vqgan import ViTVQGAN
from src.my_model import MaskCode, Encoder, Decoder, MaskVideo
from src.trainer import AttentionMaskModeling

# torch.set_float32_matmul_precision("high")

In [20]:
# Scratch check of float16 rounding: -6.55e4 is displayed as -65504 once cast to half.
torch.tensor([-6.55e4]).to(torch.float16)

tensor([-65504.], dtype=torch.float16)

In [3]:
# Dataset root directory and the split to read from it.
path = "/mnt/e/kinetics-dataset/k400"
split = "train"

# Build the video dataset through its factory helper. The third argument (16)
# presumably matches the n_frames=16 of the direct
# KineticDatasetVideo(path, split, n_frames=16) constructor -- TODO confirm.
ds = KineticDatasetVideo.get_ds(path, split, 16)

# Shuffled loader yielding 2 samples per batch, fed by 4 worker processes.
ds_loader = DataLoader(ds, batch_size=2, shuffle=True, num_workers=4)

100


In [4]:
# Take the first batch from the loader as a fixed sample for the cells below.
batch = next(iter(ds_loader))
frames, _ = batch  # second element of the batch is not used here

# Earlier GPU variant of the sample batch, kept for reference:
# _batch = batch[0].cuda(), None

# Unpack the five axes of the frame tensor (the third one is discarded), then
# display the full shape as the cell output.
B, T, _, H, W = frames.shape
frames.shape

torch.Size([2, 16, 3, 256, 256])

In [5]:
# Image / transformer hyperparameters shared by every model built in this notebook.
is_finetune = False
image_size = (256, 256)
patch_size = (8, 8)
depth = 12
heads = 12
dim = 768
embed_dim = 32

# Window, grid and temporal-transformer hyperparameters.
# NOTE(review): height/width of 32 look like image_size / patch_size (256 / 8) -- confirm.
window_size = (3, 3)
length = 32
height = 32
width = 32
temporal_depth = 4
temporal_heads = 8
temporal_dim = 128
n_codes = 8192

# Checkpoint path handed to the model constructors as `vitvq_path`.
vitvq_path = "./checkpoint/imagenet_vitvq_base.ckpt"

In [None]:
# Smoke test: wrap a freshly built MaskVideo in the Lightning module and run a
# single training step on the sample batch.
lightning_model = AttentionMaskModeling(
    model=MaskVideo(
        is_finetune=is_finetune,
        image_size=image_size, patch_size=patch_size,
        depth=depth, heads=heads, dim=dim,
        n_codes=n_codes, embed_dim=embed_dim,

        window_size=window_size,
        length=length, height=height, width=width,
        temporal_depth=temporal_depth, temporal_heads=temporal_heads, temporal_dim=temporal_dim,

        drop_prob=0.1, depth_prob=0.1,
        vitvq_path=vitvq_path,
    ),
    top_p=0.95,
)
lightning_model.cuda()

# `_batch` was never assigned anywhere in the notebook (its only definition is
# commented out), so this cell raised NameError on a fresh kernel. Build it here
# from the sampled `frames`: a (frames-on-GPU, None) pair, matching that
# commented-out definition.
_batch = (frames.cuda(), None)
lightning_model.training_step(_batch, 0)

# frames, _ = _batch  # (B, T, HW)
# mask = lightning_model.get_mask_from_logits(frames)  # (B, T', HW)

# _, _, attn_logits = lightning_model.teacher.encoder(frames, None, False)
# attn_logits.shape

In [6]:
# Stand-alone MaskVideo instance built from the notebook-level hyperparameters.
model = MaskVideo(
    is_finetune=is_finetune,
    image_size=image_size,
    patch_size=patch_size,
    depth=depth,
    heads=heads,
    dim=dim,
    n_codes=n_codes,
    embed_dim=embed_dim,
    window_size=window_size,
    length=length,
    height=height,
    width=width,
    temporal_depth=temporal_depth,
    temporal_heads=temporal_heads,
    temporal_dim=temporal_dim,
    drop_prob=0.1,
    depth_prob=0.1,
    vitvq_path=vitvq_path,
)

# Move to the GPU and switch to evaluation mode.
model.cuda()
model.eval()

# Ending on print() emits a blank line and keeps the cell from echoing a repr.
print()




In [None]:
# Stand-alone Encoder built from the same notebook-level hyperparameters.
encoder = Encoder(
    is_finetune=is_finetune,
    image_size=image_size,
    patch_size=patch_size,
    depth=depth,
    heads=heads,
    dim=dim,
    window_size=window_size,
    length=length,
    height=height,
    width=width,
    temporal_depth=temporal_depth,
    temporal_heads=temporal_heads,
    temporal_dim=temporal_dim,
    n_codes=n_codes,
    embed_dim=embed_dim,
    drop_prob=0.1,
    depth_prob=0.1,
    vitvq_path=vitvq_path,
)
encoder.cuda()

# Ending on print() emits a blank line and keeps the cell from echoing a repr.
print()

In [7]:
# Cached path: feed the clip to the model one frame at a time with
# use_cache=True and collect the newest time step of each output.
model.reset_cache()
code_a, logits_a = [], []

# no_grad: this is an inference-only check, so do not build (and keep alive)
# an autograd graph behind the large logits tensors.
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
    for t in range(T):
        # Slice instead of indexing so the time axis is kept (length 1).
        frame = frames[:, t:t + 1]

        code, logits, _, _ = model(frame.cuda(), None, use_cache=True)
        logits_a.append(logits[:, -1])
        code_a.append(code[:, -1])

# Re-assemble the per-step results along the time axis (dim=1).
code_a = torch.stack(code_a, dim=1)
logits_a = torch.stack(logits_a, dim=1)
logits_a.shape

torch.Size([2, 16, 1024, 8192])

In [8]:
# Non-cached path: run the whole clip through the model in one call.
model.reset_cache()

# no_grad: inference-only check, so skip building the autograd graph for the
# full forward pass.
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
    code_b, logits_b, _, _ = model(frames.cuda(), None, use_cache=False)

# NOTE(review): the recorded output shows logits_b with 17 time steps but
# code_b with 16 -- confirm how the extra step aligns before comparing
# against the per-frame (cached) results.
logits_b.shape, code_b.shape

(torch.Size([2, 17, 1024, 8192]), torch.Size([2, 16, 1024]))

In [None]:
# Run the stand-alone encoder on the sample frames under fp16 autocast.
# The cell previously read `_batch[0]`, but `_batch` is never assigned in the
# notebook (its definition is commented out), which raised NameError on a fresh
# kernel; use the sampled `frames` moved to the GPU, as the other cells do.
with torch.autocast(device_type='cuda', dtype=torch.float16):
    code, logits, attn = encoder(frames.cuda())

code.shape

In [None]:
# Stand-alone Decoder built from the notebook-level hyperparameters.
decoder = Decoder(
    is_finetune=is_finetune,
    image_size=image_size,
    patch_size=patch_size,
    depth=depth,
    heads=heads,
    dim=dim,
    n_codes=n_codes,
    embed_dim=embed_dim,
    drop_prob=0.1,
    depth_prob=0.1,
    vitvq_path=vitvq_path,
)
decoder.cuda()

# Ending on print() emits a blank line and keeps the cell from echoing a repr.
print()

In [None]:
# Pass `code` through the decoder under fp16 autocast, then show the result's shape.
with torch.autocast(dtype=torch.float16, device_type="cuda"):
    output = decoder(code)
output.shape

In [None]:
# Development switch: True selects a CPU fast_dev_run in full precision,
# False selects GPU training with 16-bit mixed precision.
is_dev = False

# Lightning module wrapping a freshly constructed MaskVideo.
lightning_model = AttentionMaskModeling(
    model=MaskVideo(
        is_finetune=is_finetune,
        image_size=image_size,
        patch_size=patch_size,
        depth=depth,
        heads=heads,
        dim=dim,
        n_codes=n_codes,
        embed_dim=embed_dim,
        window_size=window_size,
        length=length,
        height=height,
        width=width,
        temporal_depth=temporal_depth,
        temporal_heads=temporal_heads,
        temporal_dim=temporal_dim,
        drop_prob=0.1,
        depth_prob=0.1,
        vitvq_path=vitvq_path,
    ),
    top_p=0.95,
)

# Optional Weights & Biases logging (currently disabled):
# wandb_logger = WandbLogger(project="semcom")
# wandb_logger.experiment.config.update({"dim": dim, "depth": depth})

# Optional compilation (currently disabled):
# torch.compile(lightning_model)

# Checkpointing: monitor "train_acc" in "max" mode and keep the top 3.
checkpoint_callback = ModelCheckpoint(
    monitor="train_acc",
    dirpath="./trained",
    filename="semcom-{epoch:02d}-{train_acc:.2f}",
    save_top_k=3,
    mode="max",
)

# Hardware / precision depend on the development switch.
if is_dev:
    run_accelerator = "cpu"
    run_precision = "32-true"
else:
    run_accelerator = "gpu"
    run_precision = "16-mixed"

trainer = pl.Trainer(
    # training settings
    max_epochs=25,
    val_check_interval=1.0,
    accelerator=run_accelerator,
    precision=run_precision,
    accumulate_grad_batches=1,
    gradient_clip_val=1.0,
    # logging settings (pass logger=wandb_logger here to re-enable W&B)
    default_root_dir="./checkpoints",
    callbacks=[checkpoint_callback],
    # dev setting
    fast_dev_run=is_dev,
)

# Train on the notebook's loader; no validation loader is supplied.
trainer.fit(
    lightning_model,
    train_dataloaders=ds_loader,
)