In [None]:
import sys
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import SwinModel, SwinConfig

# Add TimeSformer path to Python path
sys.path.append('/scratch/sharath/TimeSformer')
from timesformer.models.vit import TimeSformer  # Import TimeSformer directly

# Constants
IMAGE_SIZE = (224, 224)  # target size handed to transforms.Resize
BATCH_SIZE = 4
EPOCHS = 150
LEARNING_RATE = 1e-4
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# Dataset for loading images
class VideoGenerationDataset(Dataset):
    """Dataset of short clips stored as ``<name>_frame_<i>.png`` files.

    Every file in ``root_dir`` matching ``*_frame_0.png`` marks one clip; the
    remaining frames of the clip are looked up next to it by frame index.

    Args:
        root_dir: Directory that contains the frame images.
        transform: Callable applied to each frame (an RGB ``PIL.Image``).
            Defaults to a resize to ``IMAGE_SIZE`` followed by ``ToTensor``.
        num_frames: Number of consecutive frames loaded per clip.
    """

    _FIRST_FRAME_SUFFIX = '_frame_0'

    def __init__(self, root_dir, transform=None, num_frames=5):
        self.root_dir = Path(root_dir)
        self.num_frames = num_frames
        # glob() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> clip mapping stable between runs.
        self.samples = sorted(self.root_dir.glob('*' + self._FIRST_FRAME_SUFFIX + '.png'))
        self.transform = transform or transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of clips (one per ``*_frame_0.png`` file)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load clip ``idx`` and stack its transformed frames along dim 1.

        With the default transform every frame is a ``(C, H, W)`` tensor, so
        the result has shape ``(C, T, H, W)``.
        """
        stem = self.samples[idx].stem
        # Strip only the trailing marker: str.replace would also remove a
        # '_frame_0' that happens to occur earlier in the clip name.
        clip_name = stem[:-len(self._FIRST_FRAME_SUFFIX)]

        frames = []
        for i in range(self.num_frames):
            frame_path = self.root_dir / f"{clip_name}_frame_{i}.png"
            # The context manager closes the file handle promptly, and
            # convert('RGB') guarantees three channels even for RGBA,
            # palette or grayscale PNGs.
            with Image.open(frame_path) as img:
                frames.append(self.transform(img.convert('RGB')))

        return torch.stack(frames, dim=1)  # Shape: (C, T, H, W)

# Define the enhanced video generation model
class EnhancedVideoGenerator(nn.Module):
    """Clip reconstruction model built from a Swin frame encoder and a TimeSformer.

    ``forward`` takes a clip of shape ``(B, 3, T, 224, 224)`` and returns
    generated frames of shape ``(B, 5, 3, 224, 224)`` with values in
    ``[-1, 1]`` (the decoder ends in ``Tanh``).
    """

    def __init__(self):
        super(EnhancedVideoGenerator, self).__init__()

        # Swin Transformer used as per-frame encoder. It is instantiated from
        # a bare config, so its weights are randomly initialised (no
        # checkpoint is loaded here).
        # NOTE(review): num_labels presumably has no effect on the headless
        # SwinModel - confirm before relying on it.
        swin_config = SwinConfig(image_size=224, num_labels=768)
        self.swin = SwinModel(swin_config)
        # Channel width after the last Swin stage: the embedding width doubles
        # at every stage transition.
        swin_dim = swin_config.embed_dim * 2 ** (len(swin_config.depths) - 1)

        # TimeSformer over the whole clip; no pretrained_model path is passed.
        self.timesformer = TimeSformer(img_size=224, num_classes=400, num_frames=5, attention_type='divided_space_time')

        # Projects the pooled per-frame Swin features into the 400-d
        # TimeSformer output space so both encoders feed the decoder.
        self.frame_proj = nn.Linear(swin_dim, 400)

        # Linear decoder: maps the 400-d clip embedding to every pixel of the
        # five output frames (400 * 752640 weights, i.e. ~301M parameters).
        self.decoder = nn.Sequential(
            nn.Linear(400, 224 * 224 * 3 * 5),  # Map 400 to the full pixel count for (5, 3, 224, 224)
            nn.Unflatten(1, (5, 3, 224, 224)),  # Reshape to (5, 3, 224, 224)
            nn.Tanh()
        )

    def forward(self, images):
        """Generate frames for a batch of clips shaped ``(B, C, T, H, W)``."""
        batch, channels, num_frames, height, width = images.shape

        # Fold time into the batch axis so Swin encodes all frames in one pass.
        frames = images.permute(0, 2, 1, 3, 4).reshape(batch * num_frames, channels, height, width)
        frame_features = self.swin(frames).last_hidden_state.mean(dim=1)  # (B*T, D)
        frame_features = frame_features.reshape(batch, num_frames, -1)    # (B, T, D)

        # TimeSformer consumes the raw clip in (B, C, T, H, W) layout; it
        # cannot be fed per-frame feature vectors.
        video_features = self.timesformer(images)  # Shape: (B, 400)

        # Fuse the time-averaged frame features with the clip embedding.
        video_features = video_features + self.frame_proj(frame_features.mean(dim=1))

        # Decode to generate a sequence of frames
        generated_frames = self.decoder(video_features)

        return generated_frames  # Shape: (batch_size, 5, 3, 224, 224)

def train_model(model, dataloader, criterion, optimizer, epochs=None, device=None):
    """Train ``model`` to reconstruct the clips yielded by ``dataloader``.

    Args:
        model: Module mapping a batch of clips to generated frames.
        dataloader: Sized iterable yielding clip tensors.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of passes over ``dataloader``; defaults to ``EPOCHS``.
        device: Device each batch is moved to; defaults to ``DEVICE``.

    Raises:
        ValueError: If ``dataloader`` contains no batches.
    """
    epochs = EPOCHS if epochs is None else epochs
    device = DEVICE if device is None else device

    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail early instead of dividing by zero when averaging the loss.
        raise ValueError("dataloader is empty; nothing to train on")

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images in dataloader:
            images = images.to(device)
            optimizer.zero_grad()

            # Get model outputs
            outputs = model(images)

            # The dataset yields clips as (B, C, T, H, W) while the generator
            # emits (B, T, C, H, W); swap the two axes of the target so the
            # loss compares matching pixels. This shape test cannot detect a
            # layout difference when T == C.
            targets = images
            if (outputs.shape != targets.shape and targets.dim() >= 3
                    and outputs.shape == targets.transpose(1, 2).shape):
                targets = targets.transpose(1, 2)

            loss = criterion(outputs, targets)  # MSE loss comparing generated vs actual frames
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / num_batches:.4f}")

# Load Data
train_loader = DataLoader(
    VideoGenerationDataset("/scratch/sharath/MArketing/processed_dataset/train"),
    batch_size=BATCH_SIZE,
    # Reshuffle the clips every epoch so batches are not always built from
    # the same samples in the same order.
    shuffle=True,
)

# Initialize model and training components
model = EnhancedVideoGenerator().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()



In [4]:
# Run the training schedule on the training split
train_model(
    model=model,
    dataloader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
)

# Write the learned parameters to disk
torch.save(
    model.state_dict(),
    "enhanced_video_generator.pth",
)





Epoch 0   Train Loss     :21.5668000000  ; Validation Loss:22.4565000000  
Epoch 1   Train Loss     :21.4698700833  ; Validation Loss:22.3562238591  
Epoch 2   Train Loss     :21.3729401666  ; Validation Loss:22.2559477181  
Epoch 3   Train Loss     :21.2760102499  ; Validation Loss:22.1556715772  
Epoch 4   Train Loss     :21.1790803332  ; Validation Loss:22.0553954362  
Epoch 5   Train Loss     :21.0821504165  ; Validation Loss:21.9551192953  
Epoch 6   Train Loss     :20.9852204998  ; Validation Loss:21.8548431544  
Epoch 7   Train Loss     :20.8882905831  ; Validation Loss:21.7545670134  
Epoch 8   Train Loss     :20.7913606664  ; Validation Loss:21.6542908725  
Epoch 9   Train Loss     :20.6944307497  ; Validation Loss:21.5540147315  
Epoch 10  Train Loss     :20.5975008330  ; Validation Loss:21.4537385906  
Epoch 11  Train Loss     :20.5005709164  ; Validation Loss:21.3534624497  
Epoch 12  Train Loss     :20.4036409997  ; Validation Loss:21.2531863087  
Epoch 13  Train Loss     

In [None]:
#  Inference function
def generate_video(model, input_images):
    """Run ``model`` on ``input_images`` in eval mode with gradients disabled.

    The input is moved to ``DEVICE`` before the forward pass; whatever the
    model returns is passed back unchanged.
    """
    model.eval()
    with torch.no_grad():
        return model(input_images.to(DEVICE))

# Build the test loader and generate frames for its first clip
test_loader = DataLoader(
    VideoGenerationDataset("/scratch/sharath/MArketing/processed_dataset/test"),
    batch_size=1,
)
input_images = next(iter(test_loader))
generated_frames = generate_video(model, input_images)

# Save frames as video
def save_video(frames, filename="generated_video.avi", fps=10):
    """Write a sequence of RGB frames to ``filename`` as a DIVX-encoded video.

    Args:
        frames: Tensor (or sequence of tensors) of ``(C, H, W)`` frames whose
            values are expected to lie in ``[0, 1]``; anything outside that
            range is clipped.
        filename: Output path of the video file.
        fps: Frame rate passed to the video writer.

    Raises:
        ValueError: If ``frames`` is empty.
        RuntimeError: If OpenCV cannot open ``filename`` for writing.
    """
    if len(frames) == 0:
        raise ValueError("frames is empty; nothing to write")

    height, width = frames[0].shape[1], frames[0].shape[2]
    out = cv2.VideoWriter(filename, cv2.VideoWriter_fourcc(*'DIVX'), fps, (width, height))
    if not out.isOpened():
        # Without this check a missing codec or unwritable path is silently
        # ignored and success is still reported.
        raise RuntimeError(f"Could not open video writer for {filename!r}")

    try:
        for frame in frames:
            # Clamp before scaling: the generator ends in Tanh, so values can
            # be negative, and casting those to uint8 would wrap around.
            pixels = frame.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy()
            frame_np = (pixels * 255).astype(np.uint8)
            out.write(cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalise the file, even if a frame fails to convert.
        out.release()
    print('Video saved in designated path')

# Write the frames of the first (and only) clip in the batch to disk
save_video(generated_frames[0], filename="generated_video.avi")


Video saved in designated path


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from transformers import CLIPModel
from PIL import Image
from pathlib import Path
import numpy as np
import cv2
from einops import rearrange
import sys
import torchaudio

# Add TimeSformer path to Python path
sys.path.append('/scratch/sharath/TimeSformer')
from timesformer.models.vit import TimeSformer  # Import TimeSformer directly

In [None]:
# Constants
IMAGE_SIZE = (224, 224)  # target size handed to transforms.Resize
BATCH_SIZE = 4
EPOCHS = 50
LEARNING_RATE = 1e-4
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

In [None]:
class VideoGenerationDataset(Dataset):
    """Dataset of short clips stored as ``<name>_frame_<i>.png`` files.

    Every file in ``root_dir`` matching ``*_frame_0.png`` marks one clip; the
    remaining frames of the clip are looked up next to it by frame index.

    Args:
        root_dir: Directory that contains the frame images.
        transform: Callable applied to each frame (an RGB ``PIL.Image``).
            Defaults to a resize to ``IMAGE_SIZE`` followed by ``ToTensor``.
        num_frames: Number of consecutive frames loaded per clip.
    """

    _FIRST_FRAME_SUFFIX = '_frame_0'

    def __init__(self, root_dir, transform=None, num_frames=5):
        self.root_dir = Path(root_dir)
        self.num_frames = num_frames
        # glob() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> clip mapping stable between runs.
        self.samples = sorted(self.root_dir.glob('*' + self._FIRST_FRAME_SUFFIX + '.png'))
        self.transform = transform or transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of clips (one per ``*_frame_0.png`` file)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load clip ``idx`` and stack its transformed frames along dim 0.

        With the default transform every frame is a ``(C, H, W)`` tensor, so
        the result has shape ``(T, C, H, W)``.
        """
        stem = self.samples[idx].stem
        # Strip only the trailing marker: str.replace would also remove a
        # '_frame_0' that happens to occur earlier in the clip name.
        clip_name = stem[:-len(self._FIRST_FRAME_SUFFIX)]

        frames = []
        for i in range(self.num_frames):
            frame_path = self.root_dir / f"{clip_name}_frame_{i}.png"
            # The context manager closes the file handle promptly, and
            # convert('RGB') guarantees three channels even for RGBA,
            # palette or grayscale PNGs.
            with Image.open(frame_path) as img:
                frames.append(self.transform(img.convert('RGB')))

        return torch.stack(frames)  # Shape: (T, C, H, W)


In [None]:
# Model combining pre-trained components
class MultiPretrainedVideoGenerator(nn.Module):
    """Frame generator combining a CLIP vision encoder with a TimeSformer.

    ``forward`` takes clips of shape ``(B, T, 3, H, W)`` (the layout produced
    by batching ``VideoGenerationDataset`` items) and returns generated frames
    of the same shape with values in ``[-1, 1]`` (the decoder ends in
    ``Tanh``).
    """

    def __init__(self):
        super(MultiPretrainedVideoGenerator, self).__init__()

        # Pre-trained CLIP vision tower used as per-frame encoder.
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").vision_model

        # TimeSformer over the whole clip; no pretrained_model path is passed.
        # num_frames matches the five frames per clip loaded by the dataset.
        self.timesformer = TimeSformer(img_size=224, num_classes=400, num_frames=5, attention_type='divided_space_time')

        # Lifts the 400-d TimeSformer output to the 768 channels the decoder
        # expects so it can be added to every frame's feature map.
        self.temporal_proj = nn.Linear(400, 768)

        # Convolutional decoder: each transposed convolution doubles the
        # spatial size (8x in total).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(768, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, images):
        """Generate one output frame per input frame of ``(B, T, C, H, W)`` clips."""
        batch, num_frames, channels, height, width = images.shape

        # Fold time into the batch axis: CLIP encodes 4-D image batches, so
        # iterating over the first axis of a batched clip would hand it 5-D
        # tensors.
        frames = images.reshape(batch * num_frames, channels, height, width)
        tokens = self.clip(pixel_values=frames).last_hidden_state  # (B*T, 1 + P, 768)

        # Drop the leading class token and lay the patch tokens out as a
        # square feature map, giving the decoder real spatial structure.
        patch_tokens = tokens[:, 1:, :]
        grid = int(round(patch_tokens.shape[1] ** 0.5))
        feature_maps = patch_tokens.transpose(1, 2).reshape(batch * num_frames, -1, grid, grid)

        # TimeSformer consumes the raw clip in (B, C, T, H, W) layout and
        # returns one vector per clip; broadcast it over that clip's frames.
        video_context = self.temporal_proj(self.timesformer(images.permute(0, 2, 1, 3, 4)))  # (B, 768)
        video_context = video_context.repeat_interleave(num_frames, dim=0)                   # (B*T, 768)
        feature_maps = feature_maps + video_context[:, :, None, None]

        # Decode, then resize to the input resolution so the output can be
        # compared pixel-wise with the input frames.
        decoded = self.decoder(feature_maps)  # (B*T, 3, 8*grid, 8*grid)
        decoded = nn.functional.interpolate(decoded, size=(height, width), mode='bilinear', align_corners=False)

        return decoded.reshape(batch, num_frames, 3, height, width)

In [None]:
# Training loop
def train_model(model, dataloader, criterion, optimizer, epochs=None, device=None):
    """Train ``model`` to reconstruct the clips yielded by ``dataloader``.

    Args:
        model: Module mapping a batch of clips to generated frames.
        dataloader: Sized iterable yielding clip tensors.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of passes over ``dataloader``; defaults to ``EPOCHS``.
        device: Device each batch is moved to; defaults to ``DEVICE``.

    Raises:
        ValueError: If ``dataloader`` contains no batches.
    """
    epochs = EPOCHS if epochs is None else epochs
    device = DEVICE if device is None else device

    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail early instead of dividing by zero when averaging the loss.
        raise ValueError("dataloader is empty; nothing to train on")

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images in dataloader:
            images = images.to(device)
            optimizer.zero_grad()
            outputs = model(images)

            # Accept generators that emit the channel and time axes in the
            # opposite order from the input by swapping them on the target.
            # This shape test cannot detect a layout difference when T == C.
            targets = images
            if (outputs.shape != targets.shape and targets.dim() >= 3
                    and outputs.shape == targets.transpose(1, 2).shape):
                targets = targets.transpose(1, 2)

            loss = criterion(outputs, targets)  # MSE loss comparing generated vs actual frames
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / num_batches:.4f}")

# Load Data
train_loader = DataLoader(
    VideoGenerationDataset("/scratch/sharath/MArketing/processed_dataset/train"),
    batch_size=BATCH_SIZE,
    # Reshuffle the clips every epoch so batches are not always built from
    # the same samples in the same order.
    shuffle=True,
)

# Initialize model and training components
model = MultiPretrainedVideoGenerator().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()




In [None]:
# Bare expression: the notebook shows the model's repr as this cell's output.
model

In [None]:
# Run the training schedule on the training split
train_model(
    model=model,
    dataloader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
)

In [None]:
# Training already ran in the previous cell; repeating the train_model call
# here would silently run a second full schedule on every "Run All".

# Save the trained model
torch.save(model.state_dict(), "pretrained_video_generator.pth")

# Inference function
def generate_video(model, input_images, input_audio=None, device=None):
    """Run ``model`` in eval mode with gradients disabled.

    Args:
        model: Module to run.
        input_images: Tensor passed to the model as first argument.
        input_audio: Optional tensor passed as second argument. When omitted
            the model is called with the images only, which is what an
            image-only ``forward(images)`` requires.
        device: Device the inputs are moved to; defaults to ``DEVICE``.

    Returns:
        Whatever ``model`` returns for the given inputs.
    """
    device = DEVICE if device is None else device
    model.eval()
    with torch.no_grad():
        if input_audio is None:
            return model(input_images.to(device))
        return model(input_images.to(device), input_audio.to(device))

# Test DataLoader and video generation
test_loader = DataLoader(VideoGenerationDataset("/scratch/sharath/MArketing/processed_dataset/test"), batch_size=1)
# The dataset yields image clips only, so each batch is a single tensor;
# there is no audio component to unpack.
input_images = next(iter(test_loader))
generated_frames = generate_video(model, input_images)

# Save frames as video
def save_video(frames, filename="generated_video.avi", fps=10):
    """Write a sequence of RGB frames to ``filename`` as a DIVX-encoded video.

    Args:
        frames: Tensor (or sequence of tensors) of ``(C, H, W)`` frames whose
            values are expected to lie in ``[0, 1]``; anything outside that
            range is clipped.
        filename: Output path of the video file.
        fps: Frame rate passed to the video writer.

    Raises:
        ValueError: If ``frames`` is empty.
        RuntimeError: If OpenCV cannot open ``filename`` for writing.
    """
    if len(frames) == 0:
        raise ValueError("frames is empty; nothing to write")

    height, width = frames[0].shape[1], frames[0].shape[2]
    out = cv2.VideoWriter(filename, cv2.VideoWriter_fourcc(*'DIVX'), fps, (width, height))
    if not out.isOpened():
        # Without this check a missing codec or unwritable path is silently
        # ignored and an empty or absent file is left behind.
        raise RuntimeError(f"Could not open video writer for {filename!r}")

    try:
        for frame in frames:
            # Clamp before scaling: the generator ends in Tanh, so values can
            # be negative, and casting those to uint8 would wrap around.
            pixels = frame.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy()
            frame_np = (pixels * 255).astype(np.uint8)
            out.write(cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalise the file, even if a frame fails to convert.
        out.release()

# Write the frames of the first (and only) clip in the batch to disk
save_video(generated_frames[0], filename="generated_video.avi")
