In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import numpy as np
from torchvision import datasets
from tqdm import tqdm
import cv2
import os

# Define 3D convolutional UNet architecture for video generation
class VideoDDPM(nn.Module):
    """Small 3D convolutional encoder/decoder for video clips.

    The network maps a clip of shape ``(batch, 3, T, H, W)`` to a tensor of
    the same shape. The encoder halves T, H and W twice (two stride-2
    convolutions); the decoder upsamples twice with transposed convolutions
    and ends in a sigmoid, so outputs lie in ``[0, 1]``.

    NOTE: despite the name, the module has no timestep conditioning or skip
    connections; it is a plain convolutional autoencoder that can be swapped
    for a fuller UNet later.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # 3D convolutions extract joint spatial-temporal features.
            nn.Conv3d(3, 64, kernel_size=3, stride=1, padding=1),  # Input: (batch_size, 3, T, H, W)
            nn.ReLU(),
            nn.Conv3d(64, 128, kernel_size=3, stride=2, padding=1),  # Downsampling in time and space
            nn.ReLU(),
            nn.Conv3d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # Add more layers as necessary
        )

        self.decoder = nn.Sequential(
            # output_padding=1 makes each transposed convolution exactly
            # double its input size (2n instead of 2n - 1), mirroring the
            # stride-2 convolutions in the encoder.
            nn.ConvTranspose3d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv3d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # To output normalized images in the range [0, 1]
        )

    def forward(self, x):
        """Run the encoder and decoder and return a tensor shaped like ``x``.

        Args:
            x: Clip tensor whose last three dimensions are ``(T, H, W)`` and
                whose channel dimension has size 3.

        Returns:
            Tensor with the same shape as ``x`` and values in ``[0, 1]``.
        """
        frames, height, width = x.shape[-3:]
        x = self.encoder(x)
        x = self.decoder(x)
        # When T, H or W is not a multiple of 4 the decoder overshoots
        # (it produces 4 * ceil(ceil(n / 2) / 2) >= n); crop back so the
        # output always lines up with the input for the loss.
        return x[..., :frames, :height, :width]

# Custom dataset class for video frames
class VideoDataset(Dataset):
    """Sliding-window dataset over a folder of individual frame images.

    Every entry of ``video_folder`` is treated as one frame. Entries are
    ordered with ``sorted()``, i.e. lexicographically, so frame file names
    should be zero-padded to keep temporal order. Item ``idx`` is the clip
    made of frames ``idx .. idx + frame_count - 1``.

    Args:
        video_folder: Directory containing the frame image files.
        frame_count: Number of consecutive frames per clip (must be >= 1).
        transform: Optional callable applied to each frame (as a PIL image);
            it must return a tensor of identical shape for every frame.
            When omitted, frames are converted to float tensors in
            ``[0, 1]`` with layout ``(C, H, W)``.

    Raises:
        ValueError: If ``frame_count`` is smaller than 1.
    """

    def __init__(self, video_folder, frame_count=3, transform=None):
        if frame_count < 1:
            raise ValueError(f"frame_count must be >= 1, got {frame_count}")
        self.video_folder = video_folder
        self.frame_count = frame_count
        self.transform = transform
        self.video_files = sorted(os.listdir(video_folder))

    def __len__(self):
        """Return the number of full clips available (never negative)."""
        # N frames yield N - frame_count + 1 windows; clamp at zero because
        # len() rejects negative values.
        return max(0, len(self.video_files) - self.frame_count + 1)

    def __getitem__(self, idx):
        """Load clip ``idx`` as a tensor of shape ``(C, T, H, W)``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            FileNotFoundError: If OpenCV cannot read one of the frame files.
        """
        num_clips = len(self)
        if idx < 0:
            idx += num_clips
        if not 0 <= idx < num_clips:
            raise IndexError(f"clip index {idx} out of range for {num_clips} clips")

        frames = []
        for file_name in self.video_files[idx:idx + self.frame_count]:
            frame_path = os.path.join(self.video_folder, file_name)
            frame = cv2.imread(frame_path)
            if frame is None:
                # cv2.imread signals failure by returning None, not raising.
                raise FileNotFoundError(f"Could not read frame image: {frame_path}")
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if self.transform is not None:
                # torchvision transforms such as Resize expect a PIL image.
                frame = self.transform(T.ToPILImage()(frame))
            else:
                frame = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
            frames.append(frame)
        # Stack along dim 1 so time sits after channels: (C, T, H, W). After
        # batching this gives (B, C, T, H, W), the layout Conv3d expects.
        return torch.stack(frames, dim=1)

# Training loop
def train_ddpm(model, dataset, epochs=100, batch_size=8, device="cuda"):
    """Train ``model`` to reconstruct clean clips from noise-corrupted ones.

    Each batch is corrupted with ``add_noise``, passed through the model and
    compared against the clean batch with an MSE loss. This is a simplified
    denoising objective, not a full DDPM noise schedule.

    Args:
        model: Module mapping a batch of clips to a tensor of the same shape.
        dataset: Map-style dataset yielding clip tensors of identical shape.
        epochs: Number of passes over the dataset.
        batch_size: Number of clips per optimisation step.
        device: Target device; a CUDA request falls back to CPU when CUDA is
            not available.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        # Fail early with a clear message instead of dividing by zero when
        # averaging the epoch loss.
        raise ValueError("dataset is empty; nothing to train on")

    if str(device).startswith("cuda") and not torch.cuda.is_available():
        print("CUDA is not available; falling back to CPU.")
        device = "cpu"

    model.to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # For simplicity, a basic MSE loss between predicted frames and ground truth frames
    mse_loss = nn.MSELoss()

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        epoch_loss = 0.0
        for frames in tqdm(data_loader):
            frames = frames.to(device)

            optimizer.zero_grad()

            # Corrupt the clean clip; the sampled noise itself is not needed
            # because the target is the clean clip, not the noise.
            noisy_frames, _ = add_noise(frames)

            # Forward pass through the model
            predicted_frames = model(noisy_frames)

            # MSE between the reconstruction and the clean frames
            loss = mse_loss(predicted_frames, frames)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(data_loader)}")

# Noise addition function (this is a very simplified version)
def add_noise(frames, noise_factor=0.1):
    """Corrupt ``frames`` with zero-mean Gaussian noise.

    Args:
        frames: Tensor to corrupt.
        noise_factor: Multiplier applied to the standard-normal samples.

    Returns:
        A ``(noisy_frames, noise)`` tuple, where ``noise`` is the scaled
        sample that was added and both tensors share the shape of ``frames``.
    """
    gaussian = torch.randn_like(frames)
    scaled_noise = gaussian * noise_factor
    return frames + scaled_noise, scaled_noise

# Dataset and training setup
# Guarded so that importing this file does not start a training run; in a
# notebook or when run as a script, __name__ is "__main__" and this executes.
if __name__ == "__main__":
    # Resize every frame to 64x64 and convert it to a float tensor in [0, 1].
    transform = T.Compose([T.Resize((64, 64)), T.ToTensor()])
    video_dataset = VideoDataset(video_folder="path/to/video/frames", frame_count=16, transform=transform)

    model = VideoDDPM()

    # Use the GPU when one is present instead of failing on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Train the DDPM model
    train_ddpm(model, video_dataset, epochs=100, batch_size=8, device=device)


  from .autonotebook import tqdm as notebook_tqdm
