In [1]:
import numpy as np
import pandas as pd
import scipy.stats as sps
from tqdm import tqdm
from torchinfo import summary # DEBUG

from utils.utils import *
from utils.dataset_loaders import *

import torch
from torch.utils.data import DataLoader
from diffusers import UNet3DConditionModel
from diffusers.optimization import get_cosine_schedule_with_warmup

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme(style="whitegrid")

Instruments for manual noising and denoising.

In [3]:
# А нужно ли?

class CorrellatedNoiseVideoScheduler():
    __slots__ = "betas", "alphas", "alphas_cumprod"

    def __init__(
        self,
        num_steps=1000,
        beta_start=1e-4,
        beta_end=2e-2
    ):
        self.betas = torch.linspace(beta_start, beta_end, num_steps, dtype=torch.float32)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
    
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DDPMSchedulerOutput, Tuple]:
        t = timestep

        prev_t = self.previous_timestep(t)

        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 3. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise
        variance = 0
        if t > 0:
            device = model_output.device
            variance_noise = randn_tensor(
                model_output.shape, generator=generator, device=device, dtype=model_output.dtype
            )
            if self.variance_type == "fixed_small_log":
                variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
            elif self.variance_type == "learned_range":
                variance = self._get_variance(t, predicted_variance=predicted_variance)
                variance = torch.exp(0.5 * variance) * variance_noise
            else:
                variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample,)

        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

NameError: name 'Union' is not defined

Creating dataset and dataloader for UCF-101.

In [3]:
UCF_dataset = UCFDataset("./datasets/UCF-101/")

batch_size = 1
UCF_dataloader = DataLoader(UCF_dataset, shuffle=True, batch_size=batch_size)

Trying default DDPMScheduler for working with videos.

In [None]:
total_num_steps = 1000

def train_simple_new(
    model,
    dataloader,
    noise_scheduler,
    optimizer,
    lr_scheduler,
    criterion,
    num_epochs,
    device="cuda:0",
    noise_cov=lambda x: torch.eye(x),
):
    """
    noise_cov -- matrix with the shape of video length or callable that receives video length and 
                 returns matrix
    """

    losses = []
    for epoch in range(num_epochs):
        pbar = tqdm(dataloader)
        for i, (videos, _) in enumerate(pbar):
            videos = videos.to(device)
            steps = torch.randint(low=0, high=total_num_steps + 1, size=(videos.shape[0],), device=device)
            if callable(noise_cov):
                noise_gen = MultivarNorm(cov_matrix = noise_cov(videos.shape[1]))
            else:
                noise_gen = MultivarNorm(cov_matrix = noise_cov)
            noise = noise_gen.sample(videos.shape)
            noised_videos = noise_scheduler.add_noise(videos, noise, steps)
            predicted_noise = model(noised_videos, steps.to(device)).sample
            loss = criterion(noise, predicted_noise)
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            pbar.set_postfix(MSE=loss.item())

    return losses

def init_training_modules_new(
    lr_warmup_steps,
    num_epochs,
    device="cuda:0"
):
    model = UNet3DConditionModel(
        sample_size=img_size,
        in_channels=4,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(64, 128),
        down_block_types=(
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types=(
            "UpBlock3D",
            "CrossAttnUpBlock3D",
          ),
    )
    model.to(device)
    model.train()

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=total_num_steps, beta_start=beta_start, beta_end=beta_end)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=(len(dataloader) * num_epochs),
    )

    criterion = nn.MSELoss()

    output = (model, noise_scheduler, optimizer, lr_scheduler, criterion)

    return output

def sample_videos(
    model,
    num_videos,
    video_length,
    prompts
):
    with torch.no_grad():
        sample = torch.randn(num_videos, video_length, 1, img_size, img_size).to(device)
        for i, t in enumerate(noise_scheduler.timesteps):
            residual = model(sample, t).sample
            sample = noise_scheduler.step(residual, t, sample).prev_sample
    return sample

In [57]:
model = UNet3DConditionModel(
        sample_size=(240, 320),
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(32, 32),
        down_block_types=(
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types=(
            "UpBlock3D",
            "CrossAttnUpBlock3D",
          ),
    )
# model = model.to("cuda:0")

In [59]:
model = model.cpu()

In [None]:
summary(
    model,
    input_data = {
        "sample": torch.randn(1, 3, 30, 240, 320),
        "timestep": 500,
        "encoder_hidden_states": torch.ones(1, 200, 1) * 3.0,
    }
)