In [36]:
import os

import clip
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets

In [19]:
class ImageTextToVideoGenerator(nn.Module):
    """Toy generator that turns one image plus a text embedding into a video.

    The image is encoded to a 512-d vector by a ResNet-18 whose final layer is
    replaced, the text is encoded to a 512-d vector by a single-layer LSTM, the
    two vectors are summed, tiled over a small 3-D grid and upsampled by three
    stride-2 transposed 3-D convolutions.

    Args:
        latent_shape: ``(frames, height, width)`` of the grid the fused feature
            vector is tiled over before decoding. Every decoder layer doubles
            each of the three dimensions, so the output video has shape
            ``(batch, 3, 8 * frames, 8 * height, 8 * width)``. The default
            ``(4, 4, 4)`` yields 32 frames of 32x32; ``(2, 8, 8)`` yields
            16 frames of 64x64.

    Raises:
        ValueError: if ``latent_shape`` is not three positive integers.
    """

    def __init__(self, latent_shape=(4, 4, 4)):
        super(ImageTextToVideoGenerator, self).__init__()

        latent_shape = tuple(int(size) for size in latent_shape)
        if len(latent_shape) != 3 or min(latent_shape) < 1:
            raise ValueError(
                f"latent_shape must be three positive integers, got {latent_shape}"
            )
        self.latent_shape = latent_shape

        # Image Encoder (ResNet-18 for now, replaceable)
        # NOTE(review): recent torchvision releases prefer the `weights=`
        # argument over `pretrained=True` - confirm against the installed version.
        self.image_encoder = models.resnet18(pretrained=True)
        self.image_encoder.fc = nn.Linear(self.image_encoder.fc.in_features, 512)

        # Text Encoder (Simple LSTM for now, replaceable)
        self.text_encoder = nn.LSTM(input_size=512, hidden_size=512, num_layers=1, batch_first=True)

        # Video Decoder (Basic Conv3D for now, replaceable).
        # kernel 4 / stride 2 / padding 1 doubles every spatial-temporal dim.
        self.video_decoder = nn.Sequential(
            nn.ConvTranspose3d(512, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.ReLU(),
            nn.ConvTranspose3d(256, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 3, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.Tanh()
        )

    def forward(self, image, text):
        """Generate a video tensor with values in [-1, 1] (the decoder ends in Tanh).

        Args:
            image: batch of images accepted by the image encoder.
            text: either a ``(batch, seq_len, 512)`` sequence of token
                embeddings, or a ``(batch, 512)`` tensor holding one pooled
                embedding per sample (e.g. a sentence embedding).

        Returns:
            Tensor of shape ``(batch, 3, 8 * D, 8 * H, 8 * W)`` where
            ``(D, H, W)`` is ``self.latent_shape``.
        """
        # Encode image
        img_features = self.image_encoder(image)

        # A 2-D tensor would be read by the LSTM as ONE unbatched sequence,
        # mixing all samples together. When its first dimension matches the
        # image batch, treat each row as a length-1 sequence for its own sample.
        if text.dim() == 2 and text.size(0) == image.size(0):
            text = text.unsqueeze(1)

        # Encode text
        _, (text_features, _) = self.text_encoder(text)
        text_features = text_features[-1]  # Final hidden state of the last LSTM layer

        # Combine image and text features
        combined_features = img_features + text_features

        # Tile the fused vector over the latent (frames, height, width) grid
        depth, height, width = self.latent_shape
        video_latent = combined_features.view(-1, 512, 1, 1, 1).repeat(1, 1, depth, height, width)

        # Decode video
        video = self.video_decoder(video_latent)

        return video

In [21]:
# Smoke test: push random inputs through the generator and check the output shape.
model = ImageTextToVideoGenerator()
image = torch.randn(1, 3, 224, 224)  # Example input image
text = torch.randn(1, 10, 512)  # Example input text (10 tokens, 512-d vectors to match the LSTM input_size)
with torch.no_grad():  # Only the shape is inspected, so skip building the autograd graph
    output_video = model(image, text)
# The 4x4x4 latent grid is doubled by each of the three decoder layers -> (1, 3, 32, 32, 32)
print("Generated video shape:", output_video.shape)

Generated video shape: torch.Size([1, 3, 32, 32, 32])


In [22]:
# Build a fresh generator and put it in evaluation mode. model.eval() is the
# last expression of the cell, so the notebook also displays the module tree.
model = ImageTextToVideoGenerator()
model.eval()

ImageTextToVideoGenerator(
  (image_encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1

In [23]:
# CLIP models already loaded in this kernel, keyed by device string, so the
# checkpoint is read once instead of on every call.
_CLIP_MODELS_BY_DEVICE = {}


def get_clip_embeddings(text):
    """Encode one prompt (or several) with the CLIP ViT-B/32 text encoder.

    Args:
        text: a single string, or an iterable of strings.

    Returns:
        Tensor with one embedding row per prompt, located on CUDA when it is
        available and on the CPU otherwise. Computed without gradients.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Loading the checkpoint is by far the most expensive step; reuse it.
    if device not in _CLIP_MODELS_BY_DEVICE:
        clip_model, _ = clip.load("ViT-B/32", device=device)
        _CLIP_MODELS_BY_DEVICE[device] = clip_model
    clip_model = _CLIP_MODELS_BY_DEVICE[device]

    prompts = [text] if isinstance(text, str) else list(text)
    text_tokens = clip.tokenize(prompts).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens)
    return text_features

In [24]:
from torchvision import transforms

def preprocess_image(image_path, size=(224, 224)):
    """Load an image file and turn it into a normalized single-image batch.

    Args:
        image_path: path of the image file to read.
        size: ``(height, width)`` the image is resized to. Defaults to 224x224.

    Returns:
        Float tensor of shape ``(1, 3, height, width)``, normalized with the
        per-channel mean/std listed below.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # The context manager closes the underlying file deterministically;
    # convert() returns an independent in-memory copy that stays valid afterwards.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    image = transform(image).unsqueeze(0)  # Add batch dimension
    return image



In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Move model to device
model = ImageTextToVideoGenerator().to(device)
model.eval()
image = preprocess_image("images/test.jpg").to(device).float()

# Load and move text features to device
text_features = get_clip_embeddings("A car is driving on the road").to(device).float()

# The text encoder is a batch_first LSTM expecting (batch, seq_len, features).
# Make the single pooled CLIP embedding an explicit length-1 sequence instead
# of relying on the LSTM's interpretation of a 2-D input.
text_sequence = text_features.unsqueeze(1)

with torch.no_grad():
    output_video = model(image, text_sequence)

# The 4x4x4 latent grid is doubled by each of the three decoder layers -> (1, 3, 32, 32, 32)
print("Generated video shape:", output_video.shape)


Device: cuda
Generated video shape: torch.Size([1, 3, 32, 32, 32])


In [26]:
import imageio

In [32]:
# Take the first video of the batch and reorder (channels, frames, H, W)
# -> (frames, H, W, channels), which is the layout the video writer expects.
video_frames = output_video[0].detach().permute(1, 2, 3, 0).cpu().numpy()
print("Video frames shape:", video_frames.shape)

# The decoder ends in Tanh, so values lie in [-1, 1]. Map that range linearly
# onto [0, 255]; scaling by 255 and clipping at 0 would turn every negative
# value into black and discard half of the signal.
video_frames = np.clip(np.rint((video_frames + 1.0) * 127.5), 0, 255).astype(np.uint8)

# Define the video writer
output_path = "outputs/raw.mp4"
fps = 30  # Frames per second, adjust as needed

# Make sure the target directory exists before the writer opens the file
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

# Save the video using imageio
with imageio.get_writer(output_path, fps=fps) as writer:
    for frame in video_frames:
        # Ensure the frame has the correct shape (height, width, channels)
        if frame.shape[-1] == 3:  # Make sure we have 3 channels (RGB)
            writer.append_data(frame)
        else:
            print(f"Skipping frame due to incorrect channel count: {frame.shape}")

print(f"Video saved to {output_path}")

Video frames shape: (32, 32, 32, 3)
Video saved to outputs/raw.mp4


In [37]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Per-image preprocessing: convert to a tensor, then normalize each channel
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# CIFAR-10 train and test splits (downloaded into ./data on first use)
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training split is shuffled
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

100%|██████████| 170M/170M [00:15<00:00, 10.9MB/s] 


In [38]:
# Stand-in for a real text encoder: Gaussian noise shaped like token embeddings
def generate_random_text_embeddings(batch_size, seq_length, embedding_size=512):
    """Return a (batch_size, seq_length, embedding_size) standard-normal tensor on the notebook's `device`."""
    shape = (batch_size, seq_length, embedding_size)
    noise = torch.randn(*shape)
    return noise.to(device)


In [42]:
import torch.nn.functional as F

def image_text_coherence_loss(image_features, text_features):
    """Return 1 minus the mean cosine similarity of paired feature rows.

    The value is 0 when every image/text pair points in the same direction
    and grows as the pairs diverge.
    """
    per_pair_cosine = F.cosine_similarity(image_features, text_features)
    mean_cosine = per_pair_cosine.mean()
    return 1 - mean_cosine

def temporal_coherence_loss(video_frames):
    """Penalize abrupt changes between consecutive frames.

    Args:
        video_frames: tensor whose dimension 2 indexes frames, e.g.
            ``(batch, channels, frames, height, width)``.

    Returns:
        Scalar tensor: the sum, over all consecutive frame pairs, of the mean
        squared difference between the two frames. A clip with a single frame
        yields a zero tensor (not a Python int), so callers can always call
        ``.backward()`` / ``.item()`` on the result.
    """
    # All consecutive differences at once: frame[i] - frame[i - 1]
    frame_deltas = video_frames[:, :, 1:] - video_frames[:, :, :-1]

    # Average over every dimension except the frame axis -> one MSE per pair
    non_frame_dims = [dim for dim in range(frame_deltas.dim()) if dim != 2]
    per_pair_mse = frame_deltas.pow(2).mean(dim=non_frame_dims)

    return per_pair_mse.sum()

# Frozen VGG-19 feature extractors already built in this kernel, keyed by
# device string, so the pretrained network is not rebuilt on every loss call.
_VGG_FEATURES_BY_DEVICE = {}


def _get_vgg_features(target_device):
    """Return a cached, frozen, eval-mode VGG-19 feature stack on `target_device`."""
    key = str(target_device)
    if key not in _VGG_FEATURES_BY_DEVICE:
        vgg = models.vgg19(pretrained=True).features.to(target_device).eval()
        # The extractor is a fixed measuring device: its own weights never train.
        for param in vgg.parameters():
            param.requires_grad_(False)
        _VGG_FEATURES_BY_DEVICE[key] = vgg
    return _VGG_FEATURES_BY_DEVICE[key]


def perceptual_loss(generated_frames, target_image):
    """Sum of VGG-19 feature-space MSEs between each generated frame and the target.

    Args:
        generated_frames: video tensor ``(batch, channels, frames, height, width)``.
        target_image: image batch every frame is compared against; it must
            produce VGG features of the same shape as a single frame does.

    Returns:
        The accumulated loss. Gradients flow back through the generated frames
        (but not into VGG or the target), so the term can train the generator.
        If the video has no frames the float ``0.0`` is returned.
    """
    vgg = _get_vgg_features(generated_frames.device)

    # The target is constant for all frames: encode it once, without gradients.
    with torch.no_grad():
        target_features = vgg(target_image)

    total_loss = 0.0
    for i in range(generated_frames.size(2)):  # Iterate over the frame (depth) dimension
        frame = generated_frames[:, :, i, :, :]  # One frame from the 5D tensor
        # No no_grad here: the loss must stay differentiable w.r.t. the frame.
        generated_features = vgg(frame)
        total_loss = total_loss + F.mse_loss(generated_features, target_features)

    return total_loss



In [43]:
# Fresh generator plus an Adam optimizer over all of its parameters
model = ImageTextToVideoGenerator().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train for a fixed number of passes over the training loader
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:  # class labels are not used by any loss term
        images = images.to(device)

        # Random stand-in text conditioning, one 10-step sequence per image
        text_embeddings = generate_random_text_embeddings(images.size(0), 10)

        optimizer.zero_grad()

        # Generate a video for every image/text pair
        outputs = model(images, text_embeddings)

        # Inputs of the image-text term: encoder output vs. last text step.
        # NOTE(review): this runs the image encoder a second time on the same batch.
        image_features = model.image_encoder(images)
        text_features = text_embeddings[:, -1]

        # Individual loss terms
        it_loss = image_text_coherence_loss(image_features, text_features)
        temporal_loss = temporal_coherence_loss(outputs)
        perceptual_loss_value = perceptual_loss(outputs, images)

        # Unweighted sum of the three terms, then one optimizer step
        total_loss = it_loss + temporal_loss + perceptual_loss_value
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()

    # Mean loss per batch for this epoch
    epoch_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')


KeyboardInterrupt: 