In [1]:
# Enable IPython's autoreload extension in mode 2: all imported modules are
# re-imported before each cell executes, so edits to the project's `src`
# package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

import matplotlib.pyplot as plt

from src.reader import KineticDataset, KineticDatasetVideo

from src.my_model import MaskVideo, Encoder
from src.trainer import AttentionMaskModeling

import torch

In [3]:
# Dataset root on disk and the split to read from it.
path = "/mnt/e/kinetics-dataset/k400"
split = "train"

# The third argument is forwarded to `get_ds` unchanged
# (presumably the number of frames per clip -- TODO confirm in src.reader).
ds = KineticDatasetVideo.get_ds(path, split, 16)

# One clip per batch, shuffled, fetched by four worker processes.
ds_loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=4)

100


In [4]:
# --- Hyper-parameters consumed by the MaskVideo constructor -----------------
is_finetune = False

# Per-frame ViT-VQ settings.
image_size = (256, 256)
patch_size = (8, 8)
depth = 12
heads = 12
dim = 768
embed_dim = 32

# Spatio-temporal settings.
window_size = (3, 3)
length = 32
height = 32
width = 32
temporal_depth = 4
temporal_heads = 8
temporal_dim = 128

# Number of entries passed as `n_codes`.
n_codes = 8192

# Checkpoint path handed to the model as `vitvq_path`.
vitvq_path = "./checkpoint/imagenet_vitvq_base.ckpt"

In [5]:
# Build the MaskVideo network from the hyper-parameters defined above, then
# wrap it in the AttentionMaskModeling Lightning module.
mask_video = MaskVideo(
    is_finetune=is_finetune,
    # per-frame ViT-VQ configuration
    image_size=image_size,
    patch_size=patch_size,
    vit_depth=depth,
    vit_heads=heads,
    vit_dim=dim,
    n_codes=n_codes,
    embed_dim=embed_dim,
    # spatio-temporal transformer configuration
    window_size=window_size,
    length=length,
    height=height,
    width=width,
    temporal_depth=temporal_depth,
    temporal_heads=temporal_heads,
    temporal_dim=temporal_dim,
    # regularisation rates and the ViT-VQ checkpoint path
    drop_prob=0.1,
    depth_prob=0.1,
    vitvq_path=vitvq_path,
)
lightning_model = AttentionMaskModeling(model=mask_video, top_p=0.9)
# lightning_model.cuda()
print()  # emits a single blank line as the cell's only output




In [6]:
# Pull exactly one batch from the loader for the experiments below.
# Using next() with a sentinel makes an empty loader fail here with a clear
# message; the previous `for ... break` form would silently leave `frames`
# undefined (or stale from an earlier run of the kernel).
batch = next(iter(ds_loader), None)
if batch is None:
    raise RuntimeError("ds_loader yielded no batches; check `path` and `split`")

# Each batch is a 2-tuple; only the first element (the frames) is used.
frames, _ = batch

In [7]:
# Assemble a standalone Encoder out of sub-modules owned by the MaskVideo
# network, switch it to eval mode on the GPU, and run its export step.
video_model = lightning_model.model
encoder = Encoder(
    vit=video_model.vit,
    pre_quant=video_model.pre_quant,
    quantizer=video_model.quantizer,
    temporal_transformer=video_model.transformer,
)
encoder.eval().cuda()
encoder.export()
print()  # emits a single blank line

ONNX model saved: encoder.onnx



In [None]:
# Encode the whole clip in one call. This is inference only, so autograd is
# disabled to avoid keeping the computation graph and activations on the GPU.
with torch.no_grad():
    code, mask, ks, vs = encoder(frames.cuda())

mask.shape

In [8]:
# Feed the clip to the encoder one frame at a time. The `ks` / `vs` values
# returned by each call are passed back into the next call (presumably cached
# attention keys/values -- TODO confirm in src.my_model); they start as None.
ks, vs = None, None

# Inference only: without no_grad() every step's outputs would keep the graph
# of all previous steps alive through `ks` / `vs`, growing GPU memory per frame.
with torch.no_grad():
    for t in range(frames.shape[1]):
        # Slice with t:t+1 so the indexed axis is kept (length 1).
        frame = frames[:, t:t + 1]
        code, mask, ks, vs = encoder(frame.cuda(), ks, vs)
        print(len(ks), ks[0].shape, ks.shape)
        print(code.shape, mask.shape)

4 torch.Size([1024, 8, 1, 16]) torch.Size([4, 1024, 8, 1, 16])
torch.Size([1, 1024]) torch.Size([1, 1024])
4 torch.Size([1024, 8, 2, 16]) torch.Size([4, 1024, 8, 2, 16])
torch.Size([1, 1024]) torch.Size([1, 1024])
4 torch.Size([1024, 8, 3, 16]) torch.Size([4, 1024, 8, 3, 16])
torch.Size([1, 1024]) torch.Size([1, 1024])
4 torch.Size([1024, 8, 4, 16]) torch.Size([4, 1024, 8, 4, 16])
torch.Size([1, 1024]) torch.Size([1, 1024])
4 torch.Size([1024, 8, 5, 16]) torch.Size([4, 1024, 8, 5, 16])
torch.Size([1, 1024]) torch.Size([1, 1024])
4 torch.Size([1024, 8, 6, 16]) torch.Size([4, 1024, 8, 6, 16])
torch.Size([1, 1024]) torch.Size([1, 1024])
4 torch.Size([1024, 8, 7, 16]) torch.Size([4, 1024, 8, 7, 16])
torch.Size([1, 1024]) torch.Size([1, 1024])
4 torch.Size([1024, 8, 8, 16]) torch.Size([4, 1024, 8, 8, 16])
torch.Size([1, 1024]) torch.Size([1, 1024])
4 torch.Size([1024, 8, 9, 16]) torch.Size([4, 1024, 8, 9, 16])
torch.Size([1, 1024]) torch.Size([1, 1024])
4 torch.Size([1024, 8, 10, 16]) torch

./trtexec --onnx=model.onnx --minShapes=input:1x1x3x256x256 --optShapes=input:1x16x3x256x256 --maxShapes=input:1x32x3x256x256 --shapes=input:1x5x3x256x256

In [None]:
# Load the trained weights INTO the existing module.
# `load_from_checkpoint` is a classmethod that constructs and returns a new
# module; calling it on the instance and discarding the result left
# `lightning_model` (and the `encoder` built from its sub-modules) with their
# previous weights, and recent Lightning releases reject the call outright.
# Lightning checkpoints keep the module weights under the "state_dict" key.
ckpt_path = "./trained/semcom-epoch=21-train_acc=0.94.ckpt"
# NOTE(review): newer torch versions default to weights_only=True; if the
# checkpoint holds non-tensor objects this may need weights_only=False.
checkpoint = torch.load(ckpt_path, map_location="cpu")
lightning_model.load_state_dict(checkpoint["state_dict"])

In [None]:
# Call the MaskVideo model's own encoder-export routine.
# NOTE(review): its output location/format is defined in src.my_model and is
# not visible from this notebook -- confirm before relying on it.
lightning_model.model.export_encoder()

In [None]:
# Re-run the Encoder wrapper's export step (an earlier run of this call
# reported "ONNX model saved: encoder.onnx").
encoder.export()