In [1]:
# Imports for the whole notebook: numerics, project helpers, PyTorch/diffusers,
# and plotting.
import numpy as np
import pandas as pd
import scipy.stats as sps
from tqdm import tqdm
from torchinfo import summary # DEBUG: only used by the UNet layer-summary cell

# Project helpers. These star imports are the only visible source of
# UCFDataset, init_basic_model, train_simple_new, progressive_noise and
# sample_videos used below; which module defines which is not visible here.
# TODO: import those names explicitly instead of using `*`.
from utils.utils import *
from utils.dataset_loaders import *
from models.basic_model import *

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.io import write_video  # NOTE(review): not used in the visible cells
from diffusers import UNet3DConditionModel, DDPMScheduler

import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's white-grid theme to every figure drawn in this notebook.
sns.set_theme(style="whitegrid")

Create the UCF-101 dataset and a shuffled `DataLoader` over it.

In [2]:
# Seed for the DataLoader's shuffle order, so that Restart & Run All visits
# the videos in the same order every time.
DATALOADER_SEED = 0

UCF_dataset = UCFDataset("./datasets/UCF-101/")

batch_size = 1
UCF_dataloader = DataLoader(
    UCF_dataset,
    shuffle=True,
    batch_size=batch_size,
    # A dedicated, seeded generator makes the shuffle reproducible without
    # reseeding the global torch RNG.
    generator=torch.Generator().manual_seed(DATALOADER_SEED),
)

Try `DDPMScheduler` with its default beta range (`beta_start=1e-4`, `beta_end=2e-2`) on video data.

In [3]:
# Build the UNet, its DDPM noise scheduler, the optimizer, the LR scheduler and
# the training loss in one call. beta_start/beta_end match DDPMScheduler's
# default beta range.
# object_cnt receives len(UCF_dataloader), i.e. the number of batches per
# epoch; presumably init_basic_model uses it to size the LR schedule — confirm.
model, noise_scheduler, optimizer, lr_scheduler, criterion = init_basic_model(
    object_cnt=len(UCF_dataloader),
    num_epochs=1,
    lr_warmup_steps=100,
    beta_start=1e-4,
    beta_end=2e-2,
)

In [4]:
# Train for one epoch and keep the returned loss list in `train_losses`.
# Relying on the cell's Out[] / `_` value would break on re-run.
# NOTE(review): the recorded run printed "done once", left the progress bar at
# 0/13320 after ~5.5 min and returned a single MSE value. train_simple_new
# therefore appears to stop after one batch (debug mode?) — confirm before a
# real run.
# device="cpu" is hard-coded; at the recorded speed a full epoch is impractical.
train_losses = train_simple_new(
    model=model,
    dataloader=UCF_dataloader,
    noise_scheduler=noise_scheduler,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    criterion=criterion,
    num_epochs=1,
    device="cpu",
    noise_cov=progressive_noise,
)
train_losses

  0%|                                                                                                             | 0/13320 [05:37<?, ?it/s, MSE=1.09]

done once





[1.0888562202453613]

In [None]:
# Draw a single video of length 25 from the trained model via the DDPM scheduler.
# NOTE: the recorded run reached only 2/1000 denoising steps at ~113 s/step
# (ETA > 31 h), so sampling at this size is impractical in that environment.
vid = sample_videos(
    noise_scheduler=noise_scheduler,
    video_length=25,
    num_videos=1,
    model=model,
)

  0%|▏                                                                                                           | 2/1000 [03:50<31:19:35, 113.00s/it]

In [3]:
# Print a layer-by-layer summary of the UNet for one dummy full-size clip.
# sample is (batch, channels, frames, height, width) = (1, 3, 75, 240, 320),
# the layout diffusers' UNet3DConditionModel expects. encoder_hidden_states is
# (batch, sequence, dim); dim 24 presumably matches the model's
# cross_attention_dim set in init_basic_model — confirm.
# Tensor values do not matter here; only the shapes drive the summary.
summary(
    model,
    input_data={
        "sample": torch.randn(1, 3, 75, 240, 320),
        "timestep": 500,
        "encoder_hidden_states": torch.full((1, 75, 24), 3.0),
    },
)

Layer (type:depth-idx)                                                      Output Shape              Param #
UNet3DConditionModel                                                        [1, 3, 75, 240, 320]      --
├─Timesteps: 1-1                                                            [1, 12]                   --
├─TimestepEmbedding: 1-2                                                    [1, 48]                   2,352
│    └─LoRACompatibleLinear: 2-1                                            [1, 48]                   624
├─SiLU: 1-3                                                                 [1, 48]                   --
├─TimestepEmbedding: 1-4                                                    --                        (recursive)
│    └─LoRACompatibleLinear: 2-2                                            [1, 48]                   2,352
├─Conv2d: 1-5                                                               [75, 12, 240, 320]        336
├─TransformerTemporalModel: 1-6  