In [None]:
# Install necessary libraries
!pip install -q torch torchvision diffusers opencv-python moviepy ffmpeg-python numpy pillow ipython

# Import required libraries
import os
import torch
import numpy as np
from PIL import Image
from IPython.display import HTML
from pyramid_dit import PyramidDiTForVideoGeneration
from diffusers.utils import export_to_video
import pandas as pd
import cv2
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from pyramid_dit.flux_modules.modeling_text_encoder import FluxTextEncoderWithMask
from transformers import CLIPModel, CLIPProcessor
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from IPython.display import HTML, display
from pyramid_dit import PyramidDiTForVideoGeneration
from torch.distributions import Categorical
import torch.nn.functional as F

# Checkpoint location and precision option for Pyramid Flow
model_path = "/pyramid-flow-miniflux"  # Replace with your Pyramid Flow checkpoint directory
model_dtype = "bf16"  # Options: "bf16", "fp16", "fp32"

# Run on the GPU when PyTorch can see one, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the Pyramid Flow pipeline (flux backbone) from the checkpoint directory
variant = "diffusion_transformer_384p"  # Use 384p variant for now
model = PyramidDiTForVideoGeneration(
    model_path=model_path,
    model_dtype=model_dtype,
    model_name="pyramid_flux",
    model_variant=variant,
)

# Place each sub-model on `device` individually
model.vae.to(device)
model.dit.to(device)
model.text_encoder.to(device)
# Turn on tiled VAE processing
model.vae.enable_tiling()

# Translate the precision option string into a torch dtype; any unrecognised option falls back to float32
torch_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}.get(model_dtype, torch.float32)

In [None]:
class VideoTextDataset(Dataset):
    """Map-style dataset yielding ``(frames, caption)`` pairs from a video folder and a metadata CSV."""

    def __init__(self, video_dir, metadata_path, resolution=384, num_frames=16):
        """
        Args:
            video_dir (str): Directory containing video files.
            metadata_path (str): Path to the metadata CSV file with columns 'video_filename' and 'caption'.
            resolution (int): Target resolution for video frames (frames are resized to a square).
            num_frames (int): Number of frames to sample from each video.
        """
        self.video_dir = video_dir
        self.metadata = pd.read_csv(metadata_path)
        self.resolution = resolution
        self.num_frames = num_frames

    def __len__(self):
        """Number of rows in the metadata CSV."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``(frames, caption)`` for metadata row ``idx``.

        ``frames`` has shape [num_frames, 3, resolution, resolution] (see ``_process_video``).
        """
        row = self.metadata.iloc[idx]
        video_path = os.path.join(self.video_dir, row["video_filename"])
        return self._process_video(video_path), row["caption"]

    def _frame_to_tensor(self, frame):
        """Convert one BGR OpenCV frame to an RGB [3, resolution, resolution] float tensor scaled by 1/255."""
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, (self.resolution, self.resolution), interpolation=cv2.INTER_AREA)
        return torch.tensor(frame).permute(2, 0, 1) / 255.0  # Normalize to [0, 1]

    def _process_video(self, video_path):
        """Decode up to ``num_frames`` evenly spaced frames from ``video_path``.

        Clips with fewer readable frames are padded by repeating the last decoded frame.

        Returns:
            torch.Tensor: shape [num_frames, 3, resolution, resolution].

        Raises:
            ValueError: if no frame could be decoded (e.g. missing, empty or unreadable file).
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            step = max(1, total_frames // self.num_frames)

            for i in range(0, total_frames, step):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i)
                success, frame = cap.read()
                if not success:
                    break
                frames.append(self._frame_to_tensor(frame))
                if len(frames) == self.num_frames:
                    break
        finally:
            # Release the decoder even when reading or frame conversion raises.
            cap.release()

        if not frames:
            # Without this check, padding below would fail with an opaque IndexError on frames[-1].
            raise ValueError(f"Could not read any frames from video: {video_path}")

        # Pad short clips by repeating the last decoded frame.
        frames.extend([frames[-1]] * (self.num_frames - len(frames)))
        return torch.stack(frames)  # Shape: [num_frames, 3, resolution, resolution]


In [None]:
# Batch / clip settings (resolution matches the 384p model variant)
resolution = 384  # Video resolution (aligned with the model variant)
num_frames = 16  # Number of frames per video
batch_size = 4  # Batch size for training

# Where the fine-tuning clips and their captions live
video_dir = "./finetune_dataset/video/files"  # Replace with your video directory
metadata_path = "./finetune_dataset/metadata.csv"  # Replace with your metadata CSV file

dataset = VideoTextDataset(
    video_dir=video_dir,
    metadata_path=metadata_path,
    resolution=resolution,
    num_frames=num_frames,
)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Pull a single batch to confirm shapes: frames -> [batch_size, num_frames, 3, resolution, resolution]
for frames, captions in dataloader:
    print("Frames shape:", frames.shape)
    print("Captions:", captions)
    break


In [None]:
# Reuse the Flux text encoder that PyramidDiTForVideoGeneration already loaded and moved to `device`
# in the setup cell, instead of instantiating a second copy of the same T5/CLIP text encoders.
text_encoder = model.text_encoder

# Helper: encode text prompts with the Flux text encoder without tracking gradients
def encode_prompts(prompts, text_encoder, num_images_per_prompt=1, device=device):
    """
    Run ``prompts`` through ``text_encoder.encode_prompt`` under ``torch.no_grad()``.

    Args:
        prompts (list[str]): Text prompts to encode.
        text_encoder (FluxTextEncoderWithMask): The text encoder to use.
        num_images_per_prompt (int): How many generations each prompt will be used for.
        device (torch.device): Device the embeddings are produced on.

    Returns:
        tuple: ``(prompt_embeds, prompt_attention_mask, pooled_prompt_embeds)`` — the T5
        embeddings, their attention mask, and the pooled CLIP embeddings.
    """
    with torch.no_grad():
        encoded = text_encoder.encode_prompt(
            prompts,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
        )
    t5_embeds, t5_mask, clip_pooled = encoded
    return t5_embeds, t5_mask, clip_pooled

# Smoke-test the encoder on two example prompts
test_prompts = [
    "A cinematic video of a spaceship landing on Mars",
    "A dog running in the park",
]
prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = encode_prompts(test_prompts, text_encoder)

# Report the tensor shapes; expected [batch_size, seq_len, hidden_dim], [batch_size, seq_len], [batch_size, hidden_dim]
print("T5 Prompt Embeddings Shape:", prompt_embeds.shape)
print("Prompt Attention Mask Shape:", prompt_attention_mask.shape)
print("CLIP Pooled Embeddings Shape:", pooled_prompt_embeds.shape)


In [None]:
# Reward model for RAFT ranking.
# CLIP extracts an image embedding and a text embedding; they are concatenated into one feature
# vector, and a small MLP head maps that vector to a scalar reward score.
class RAFTRewardModel(nn.Module):
    """Score (frames, prompt) pairs with CLIP features followed by a small MLP reward head.

    NOTE(review): ``reward_head`` is freshly initialised and nothing in this class trains or loads it,
    so its scores carry no preference signal until the head is trained (or trained weights are loaded).
    """

    def __init__(self, clip_model_name="openai/clip-vit-large-patch14"):
        """Load the CLIP backbone and processor named ``clip_model_name`` and build the reward head.

        Args:
            clip_model_name (str): Hugging Face checkpoint id used for both CLIPModel and CLIPProcessor.
        """
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_model_name)
        # Turns caption strings into token ids (and raw images into CLIP pixel values).
        self.processor = CLIPProcessor.from_pretrained(clip_model_name)
        # Image and text features are each `projection_dim` wide, so their concatenation is twice that
        # (1536 for ViT-L/14, whose projection_dim is 768).
        feature_dim = 2 * self.clip.config.projection_dim
        self.reward_head = nn.Sequential(
            nn.Linear(feature_dim, 384),
            nn.ReLU(),
            nn.Linear(384, 1),
        )

    def forward(self, frames, prompts):
        """Return one reward per (frame, prompt) pair, shape [batch, 1].

        Args:
            frames: a tensor of CLIP ``pixel_values`` (already preprocessed), or images
                (PIL / numpy, single or list) that are preprocessed with the CLIP processor.
            prompts: a tensor of CLIP token ids, or a string / sequence of strings that is
                tokenized (padded and truncated to the tokenizer's maximum length). A single
                prompt is broadcast across all frames.
        """
        device = next(self.reward_head.parameters()).device

        if torch.is_tensor(frames):
            pixel_values = frames
        else:
            pixel_values = self.processor(images=frames, return_tensors="pt")["pixel_values"]
        image_features = self.clip.get_image_features(pixel_values=pixel_values.to(device))

        if torch.is_tensor(prompts):
            text_features = self.clip.get_text_features(input_ids=prompts.to(device))
        else:
            if isinstance(prompts, str):
                prompts = [prompts]
            # get_text_features needs token ids, not raw strings.
            text_inputs = self.processor(
                text=list(prompts), return_tensors="pt", padding=True, truncation=True
            )
            text_features = self.clip.get_text_features(
                **{key: value.to(device) for key, value in text_inputs.items()}
            )

        # Score many frames against one prompt by repeating the prompt's embedding.
        if text_features.shape[0] == 1 and image_features.shape[0] > 1:
            text_features = text_features.expand(image_features.shape[0], -1)

        combined_features = torch.cat([image_features, text_features], dim=-1)
        return self.reward_head(combined_features)

In [None]:
# CLIP ViT-L/14 (processor + model) scores how well generated frames match their prompt
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)

def evaluate_aesthetic_score(frames, text_prompt, clip_model, clip_processor, device):
    """
    Score how well video frames match a text prompt with CLIP.

    Despite the name, this is a CLIP text-image alignment score: the mean cosine similarity
    between each frame's image embedding and the prompt's text embedding.

    Args:
        frames (torch.Tensor): Video frames with shape [num_frames, 3, height, width] and values
            in [0, 1] (the range VideoTextDataset produces).
        text_prompt (str): The textual description for the video.
        clip_model (CLIPModel): Pre-trained CLIP model for scoring.
        clip_processor (CLIPProcessor): CLIP processor for preparing inputs.
        device (torch.device): The device for computation (e.g., "cuda").

    Returns:
        float: Mean frame/prompt cosine similarity.

    Raises:
        ValueError: If ``frames`` contains no frames.
    """
    if len(frames) == 0:
        raise ValueError("frames must contain at least one frame")

    # The CLIP image processor rescales by 1/255 by default, so give it uint8 HWC images.
    # Passing [0, 1] floats would scale every pixel down to ~0 before normalisation.
    frames_uint8 = (
        frames.detach().float().clamp(0, 1).mul(255).round().to(torch.uint8)
        .permute(0, 2, 3, 1).cpu().numpy()
    )
    pixel_values = clip_processor(images=list(frames_uint8), return_tensors="pt")["pixel_values"].to(device)

    # truncation=True keeps long captions within the tokenizer's maximum length.
    text_inputs = clip_processor(
        text=[text_prompt], return_tensors="pt", padding=True, truncation=True
    ).to(device)

    # Pure inference: no gradients are needed for either embedding.
    with torch.no_grad():
        text_features = clip_model.get_text_features(**text_inputs)
        image_features = clip_model.get_image_features(pixel_values)

    # Normalize features so the dot product is a cosine similarity.
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    return torch.matmul(image_features, text_features.T).mean().item()

# Demo on random noise frames (swap in real frames: 16 frames of 3 x 384 x 384)
sample_prompt = "A cinematic video of a spaceship landing on Mars"
sample_frames = torch.rand((16, 3, 384, 384))

aesthetic_score = evaluate_aesthetic_score(
    frames=sample_frames,
    text_prompt=sample_prompt,
    clip_model=clip_model,
    clip_processor=clip_processor,
    device=device,
)
print("Aesthetic Score:", aesthetic_score)


In [None]:
# Parameters for fine-tuning the text-to-video model

# @markdown Enter value for `resolution`.
resolution = 384  # @param {type:"integer"}

# @markdown Enter value for `num_frames` (number of frames per video).
num_frames = 16  # @param {type:"integer"}

# @markdown Enter value for `batch_size`.
batch_size = 4  # @param {type:"integer"}

# @markdown Enter value for `num_inference_steps`.
num_inference_steps = [20, 20, 20]  # @param {type:"raw"}  # Per pyramid level inference steps

# @markdown Enter value for `guidance_scale`.
guidance_scale = 7.0  # @param {type:"number"}  # Strength of text conditioning for initial frame

# @markdown Enter value for `video_guidance_scale`.
video_guidance_scale = 5.0  # @param {type:"number"}  # Strength of text conditioning for subsequent frames

# @markdown Enter value for `learning_rate`.
learning_rate = 1e-5  # @param {type:"number"}  # Learning rate for fine-tuning

# @markdown Enter value for `epochs`.
epochs = 10  # @param {type:"integer"}  # Number of training epochs

# @markdown Enter value for `fps`.
fps = 24  # @param {type:"integer"}  # Frames per second for generated videos

# @markdown Enter value for `num_candidates` (videos sampled per prompt for RAFT reward ranking).
num_candidates = 8  # @param {type:"integer"}  # Candidates generated for each input prompt

# Reject values the data pipeline / RAFT loop cannot work with before any expensive work starts.
assert resolution > 0 and num_frames > 0 and batch_size > 0, "resolution, num_frames and batch_size must be positive"
assert num_candidates >= 1, "num_candidates must be at least 1 to rank candidates"
assert epochs >= 1, "epochs must be at least 1"

# Display selected parameters
print(f"Resolution: {resolution}x{resolution}")
print(f"Number of Frames per Video: {num_frames}")
print(f"Batch Size: {batch_size}")
print(f"Number of Inference Steps per Pyramid Level: {num_inference_steps}")
print(f"Guidance Scale (Initial Frame): {guidance_scale}")
print(f"Video Guidance Scale (Subsequent Frames): {video_guidance_scale}")
print(f"Learning Rate: {learning_rate}")
print(f"Number of Epochs: {epochs}")
print(f"Frames per Second (FPS): {fps}")
print(f"Number of candidates for RAFT fine-tuning: {num_candidates}")


In [None]:
# Ensure the directory for saving the model exists
output_dir = "./fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)

# Initialize dataset and dataloader with the hyper-parameters chosen above
video_dataset = VideoTextDataset(
    video_dir=video_dir,
    metadata_path=metadata_path,
    resolution=resolution,
    num_frames=num_frames
)
video_dataloader = DataLoader(video_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Define loss function
reconstruction_criterion = nn.MSELoss()  # Reconstruction loss for frame quality
aesthetic_weight = 0.5  # Weight for aesthetic score; NOTE(review): not referenced by train_with_raft

# Fine-tune the Pyramid Flow model already loaded in the setup cell. A second instance would hold another
# full copy of the weights. The setup instance was built with the string dtype option (`model_dtype`)
# that the constructor takes, and it already has its vae / dit / text_encoder on `device`.
pyramid_model = model

# CLIP for aesthetic evaluation: load only if the evaluation cell has not already done so, so the
# same checkpoint is not kept on the device twice.
if "clip_model" not in globals() or "clip_processor" not in globals():
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")


# Iterative fine-tuning following RAFT (Reward rAnked FineTuning). Each batch goes through three stages:
# data generation (num_candidates samples per prompt), reward ranking, and fine-tuning on the best samples.


def _raft_generation_kwargs():
    """Sampling settings shared by every model call, read from the hyper-parameter cell's globals."""
    return {
        "video_guidance_scale": video_guidance_scale,
        "guidance_scale": guidance_scale,
        "height": resolution,
        "width": resolution,
        "num_inference_steps": num_inference_steps,
    }


def _raft_encode_captions(captions, device):
    """Encode captions with the notebook's `text_encoder` and return the prompt kwargs for model calls.

    Runs without gradients, so this loop never updates the text encoder.
    """
    with torch.no_grad():
        prompt_embeds, _attention_mask, pooled_prompt_embeds = text_encoder.encode_prompt(
            captions, num_images_per_prompt=1, device=device
        )
    return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}


def _raft_sample_and_score(model, reward_model, captions, prompt_kwargs, num_candidates):
    """Sample `num_candidates` generations for the batch without gradients and score each with `reward_model`.

    Returns the list of K generations and a [K, batch] tensor of rewards.
    """
    candidates, rewards = [], []
    with torch.no_grad():
        for _ in range(num_candidates):
            generated = model.forward(**prompt_kwargs, **_raft_generation_kwargs())
            candidates.append(generated)
            rewards.append(reward_model(generated, captions))
    # A per-candidate reward of shape [batch] or [batch, 1] becomes one row per candidate.
    return candidates, torch.stack(rewards).reshape(num_candidates, -1)


def _raft_select_best(candidates, rewards):
    """RAFT reward ranking: for every prompt, keep the candidate with the highest reward.

    Assumes each candidate is a batch-first tensor whose leading dim equals rewards.shape[1] — TODO
    confirm against the model's actual output.
    """
    stacked = torch.stack(candidates)  # [K, batch, ...]
    best_idx = rewards.argmax(dim=0).to(stacked.device)  # [batch]
    return stacked[best_idx, torch.arange(stacked.shape[1], device=stacked.device)]


def train_with_raft(model, reward_model, dataloader, num_candidates, epochs,
                    learning_rate, device):
    """Fine-tune ``model`` with a RAFT-style loop: sample, rank by reward, then fine-tune on the best.

    For every batch of captions:
      1. Data generation: sample ``num_candidates`` outputs per caption without gradients.
      2. Reward ranking: score every candidate with ``reward_model`` and keep, per caption,
         the highest-scoring one (top-1 selection).
      3. Fine-tuning: run one differentiable generation and regress it onto the selected
         candidates with the global ``reconstruction_criterion``.

    One checkpoint per epoch is written to ``output_dir/raft_model_epoch_{N}.pt``. It stores the model
    and optimizer state dicts, the epoch index and the mean batch loss of that epoch.

    Args:
        model: Model to fine-tune; ``model.parameters()`` are optimised with AdamW.
        reward_model: Callable ``(generated, captions) -> rewards`` with one reward per caption.
        dataloader: Iterable of ``(frames, captions)`` batches; only the captions are used.
        num_candidates (int): Candidates sampled per caption (must be >= 1).
        epochs (int): Number of passes over ``dataloader``.
        learning_rate (float): AdamW learning rate (cosine-annealed over all steps).
        device: Device used for prompt encoding.

    Raises:
        ValueError: If ``num_candidates`` < 1.

    NOTE(review): this reads the notebook globals ``text_encoder``, ``guidance_scale``,
    ``video_guidance_scale``, ``resolution``, ``num_inference_steps``, ``reconstruction_criterion``
    and ``output_dir``.
    NOTE(review): ``model.forward(...)``, ``model.parameters()`` and ``model.state_dict()`` assume an
    nn.Module-like model that returns batch-first tensors. The setup cell moves ``model.vae`` /
    ``model.dit`` / ``model.text_encoder`` separately, so the trainable network may really be
    ``model.dit`` — TODO confirm against the Pyramid Flow API.
    TODO: stage 3 regresses one stochastic sample onto another. RAFT proper would apply the model's
    own training (denoising / flow-matching) loss to the selected samples.
    """
    if num_candidates < 1:
        raise ValueError("num_candidates must be at least 1")

    num_batches = len(dataloader)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, epochs * num_batches)
    )

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}")

        # Count steps explicitly: tqdm's `n` is refreshed lazily and can lag the true iteration count.
        for step, (_frames, captions) in enumerate(progress_bar, start=1):
            prompt_kwargs = _raft_encode_captions(captions, device)

            # 1) data generation + 2) reward ranking
            candidates, rewards = _raft_sample_and_score(
                model, reward_model, captions, prompt_kwargs, num_candidates
            )
            best_candidates = _raft_select_best(candidates, rewards)

            # 3) fine-tuning on the selected samples (a single differentiable pass per batch)
            generated = model.forward(**prompt_kwargs, **_raft_generation_kwargs())
            loss = reconstruction_criterion(generated, best_candidates)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()
            progress_bar.set_postfix({
                "loss": running_loss / step,
                "best_score": rewards.max().item(),
            })

        epoch_loss = running_loss / max(1, num_batches)
        print(f"Epoch {epoch + 1} Loss: {epoch_loss:.4f}")

        # Save checkpoint
        model_save_path = os.path.join(output_dir, f"raft_model_epoch_{epoch + 1}.pt")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'loss': epoch_loss,
        }, model_save_path)
        print(f"Model checkpoint saved at {model_save_path}")

    print("Fine-tuning complete!")

# Build the CLIP-backed reward model that ranks the sampled candidates
reward_model = RAFTRewardModel()
reward_model = reward_model.to(device)

# Launch RAFT fine-tuning (arguments in signature order)
train_with_raft(
    pyramid_model,
    reward_model,
    video_dataloader,
    num_candidates,
    epochs,
    learning_rate,
    device,
)


In [None]:
# Function to display generated videos
def show_video(video_path, width="70%"):
    """
    Build an HTML5 ``<video>`` player for a video file so the notebook can display it.

    Args:
        video_path (str): Path to the video file (served relative to the notebook).
        width (str): CSS width of the player, e.g. "70%".

    Returns:
        IPython.display.HTML: The player markup, ready for ``display``.
    """
    from html import escape  # function-scope stdlib import keeps this cell self-contained

    # Escape both values so paths or widths containing quotes or angle brackets cannot break the markup.
    markup = f"""
    <video controls style="width: {escape(str(width))};">
        <source src="{escape(str(video_path))}" type="video/mp4">
        Your browser does not support the video tag.
    </video>
    """
    return HTML(markup)

# Load the fine-tuned model.
# train_with_raft writes one checkpoint per epoch to `output_dir` as `raft_model_epoch_{N}.pt`, holding the
# trained weights under "model_state_dict". By default, load the last epoch's checkpoint.
fine_tuned_model_path = os.path.join(output_dir, f"raft_model_epoch_{epochs}.pt")  # Replace with the checkpoint you want

# Rebuild the base model the same way as the setup cell (checkpoint dir + string dtype option), then
# overlay the fine-tuned weights.
# NOTE(review): load_state_dict mirrors the model.state_dict() call used when saving. If this runs in the
# same kernel right after training, the in-memory model already holds these weights.
pyramid_model = PyramidDiTForVideoGeneration(
    model_path,
    model_dtype,
    model_name="pyramid_flux",
    model_variant="diffusion_transformer_384p",
)
fine_tuned_checkpoint = torch.load(fine_tuned_model_path, map_location="cpu", weights_only=True)
pyramid_model.load_state_dict(fine_tuned_checkpoint["model_state_dict"])

# Move components to the device the same way as the setup cell
pyramid_model.vae.to(device)
pyramid_model.dit.to(device)
pyramid_model.text_encoder.to(device)
pyramid_model.vae.enable_tiling()

# Set generation parameters
test_prompts = [
    "A cinematic video of a spaceship landing on Mars",
    "A dog running in a beautiful park during sunset"
]
num_frames = 16
resolution = 384
num_inference_steps = [20, 20, 20]
guidance_scale = 7.0
video_guidance_scale = 5.0
fps = 24

# Generate and display one video per test prompt
for idx, prompt in enumerate(test_prompts):
    print(f"Generating video for prompt: '{prompt}'")

    # torch.autocast is the non-deprecated form of torch.cuda.amp.autocast. It is enabled only on CUDA,
    # matching the old call, which switched itself off when CUDA was unavailable.
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch_dtype, enabled=device.type == "cuda"):
        frames = pyramid_model.generate(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            height=resolution,
            width=resolution,
            # NOTE(review): `temp` appears to count latent temporal units rather than output frames
            # (the Pyramid Flow README pairs temp=16 with ~5 s of video) — confirm.
            temp=num_frames,
            guidance_scale=guidance_scale,
            video_guidance_scale=video_guidance_scale,
            output_type="pil",
        )

    # Export frames to a video file, then embed it in the notebook
    video_path = f"generated_video_{idx + 1}.mp4"
    export_to_video(frames, video_path, fps=fps)
    display(show_video(video_path))

print("Video generation complete!")