# Building an AI Text-to-Video Model from Scratch

## What is being built?

<img src='https://miro.medium.com/v2/resize:fit:2000/format:webp/1*6h3oJzGEH0xrER2Tv8M7KQ.gif' width='600'>

Training a text-to-video model on large real-world video datasets requires extremely high computational power. Instead, we will work with a synthetic dataset of moving objects generated with Python code.

We use the GAN (Generative Adversarial Network) architecture to build the final model because it is easier and quicker to train and test.

## GAN Architecture

### What is GAN?

A Generative Adversarial Network (GAN) is a deep learning model in which two neural networks compete:
- **generator**: creates new data (like images or music) that resembles a given dataset.
- **discriminator**: tells whether a piece of data is real or fake.

This process continues until the generated data is indistinguishable from the original.

### Real-World Applications

1. **Generate Images**: GANs create realistic images from text prompts or modify existing images, such as enhancing resolution or adding color to black-and-white photos.
2. **Data Augmentation**: They generate synthetic data to train other machine learning models, such as creating fraudulent transaction data for fraud detection systems.
3. **Complete Missing Information**: GANs can fill in missing data, like generating sub-surface images from terrain maps for energy applications.
4. **Generate 3D Models**: They convert 2D images into 3D models, useful in fields like healthcare for creating realistic organ images for surgical planning.

### How does a GAN work?

1. **Training Set Analysis**: The discriminator is trained on samples from the training set so that it learns the attributes of real data. The generator never sees the real data directly. It learns about the data only through the discriminator's feedback.
2. **Data Generation**: The generator takes random noise as input and transforms it into a synthetic sample.
3. **Data Passing**: The generated data is then passed to the discriminator, together with real samples.
4. **Probability Calculation**: The discriminator calculates the probability that each sample comes from the original dataset.
5. **Feedback Loop**: The discriminator's error is backpropagated into the generator. This guides the generator to produce samples that the discriminator is more likely to accept in the next cycle.
6. **Adversarial Training**: The generator tries to maximize the discriminator’s mistakes, while the discriminator tries to minimize its own errors. Through many training iterations, both networks improve and evolve.
7. **Equilibrium State**: Training continues until the discriminator can no longer distinguish between real and synthesized data (in theory, it outputs 0.5 for every sample). This indicates that the generator has successfully learned to produce realistic data. At this point, the training process is complete.

<img src='https://miro.medium.com/v2/resize:fit:2000/format:webp/1*2HsK-UFPRvCdAmQyS3Ol1Q.jpeg' width=700>
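Steps 6 and 7 can be summarized by the minimax objective of the original GAN formulation. In it, $D(x)$ is the discriminator's estimated probability that $x$ is real, and $G(z)$ is the generator's output for a noise vector $z$. The data and noise distributions are written $p_{\text{data}}$ and $p_z$:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

For a fixed generator with output distribution $p_g$, the best discriminator is:

$$
D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}
$$

This equals $\tfrac{1}{2}$ everywhere once $p_g = p_{\text{data}}$, which is the equilibrium described in step 7. In practice, the training loop later in this notebook uses the common non-saturating variant for the generator. Instead of minimizing the second term, it maximizes $\log D(G(z))$ by labelling fake samples as real in the BCE loss.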

### GAN training example

Consider an example of image-to-image translation that modifies a human face.

1. Input Image: The input is a real image of a human face.
2. Attribute Modification: The generator modifies attributes of the face, like adding sunglasses to the eyes.
3. Generated Images: The generator creates a set of images with sunglasses added.
4. Discriminator’s Task: The discriminator receives a mix of real images (people with sunglasses) and generated images (faces where sunglasses were added).
5. Evaluation: The discriminator tries to differentiate between real and generated images.
6. Feedback Loop:
   - If the discriminator correctly identifies fake images, the generator adjusts its parameters to produce more convincing images.
   - If the generator successfully fools the discriminator, the discriminator updates its parameters to improve its detection.

Through this adversarial process, both networks continuously improve:
- The generator gets better at creating realistic images.
- The discriminator gets better at identifying fakes.

This continues until equilibrium is reached, where the discriminator can no longer tell the difference between real and generated images. At this point, the GAN has successfully learned to produce realistic modifications.

## Libs

In [1]:
import os
import random
import numpy as np

import cv2
from PIL import Image, ImageDraw, ImageFont

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torchvision.utils import save_image

import matplotlib.pyplot as plt
from tqdm import tqdm

## Creating the Training Data

Our training dataset consists of 10,000 one-second videos (10 frames each) of a circle moving in different directions with different motions.

In [2]:
os.makedirs('training_dataset', exist_ok=True)
num_videos = 10000
frames_per_video = 10
img_size = (64, 64) # the size of each image in the dataset
shape_size = 10 # the size of the Circle
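As a quick sanity check on scale, the sketch below works out what these settings produce. It assumes the batch size of 16 used by the `DataLoader` further below. It also treats each frame as one training sample, which mirrors how `TextToVideoDataset` indexes the data.

```python
import math

num_videos = 10000
frames_per_video = 10
img_size = (64, 64)
batch_size = 16  # assumption: matches the DataLoader defined later in the notebook

# Every frame becomes one training sample
total_frames = num_videos * frames_per_video
# Raw (uncompressed) size of one RGB frame in bytes
bytes_per_frame = img_size[0] * img_size[1] * 3
# Number of Discriminator/Generator update steps per epoch
batches_per_epoch = math.ceil(total_frames / batch_size)

print(total_frames, bytes_per_frame, batches_per_epoch)
```

The result is 100,000 frames and 6,250 update steps per epoch. At the 10 fps used when the final video is written, each 10-frame clip lasts one second.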

Define the text prompts from which the training videos will be generated.

In [3]:
# Define text prompts and corresponding movements for circles
prompts_and_movements = [
    ("circle moving down", "circle", "down"),  # Move circle downward
    ("circle moving left", "circle", "left"),  # Move circle leftward
    ("circle moving right", "circle", "right"),  # Move circle rightward
    ("circle moving diagonally up-right", "circle", "diagonal_up_right"),  # Move circle diagonally up-right
    ("circle moving diagonally down-left", "circle", "diagonal_down_left"),  # Move circle diagonally down-left
    ("circle moving diagonally up-left", "circle", "diagonal_up_left"),  # Move circle diagonally up-left
    ("circle moving diagonally down-right", "circle", "diagonal_down_right"),  # Move circle diagonally down-right
    ("circle rotating clockwise", "circle", "rotate_clockwise"),  # Rotate circle clockwise
    ("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"),  # Rotate circle counter-clockwise
    ("circle shrinking", "circle", "shrink"),  # Shrink circle
    ("circle expanding", "circle", "expand"),  # Expand circle
    ("circle bouncing vertically", "circle", "bounce_vertical"),  # Bounce circle vertically
    ("circle bouncing horizontally", "circle", "bounce_horizontal"),  # Bounce circle horizontally
    ("circle zigzagging vertically", "circle", "zigzag_vertical"),  # Zigzag circle vertically
    ("circle zigzagging horizontally", "circle", "zigzag_horizontal"),  # Zigzag circle horizontally
    ("circle moving up-left", "circle", "up_left"),  # Move circle up-left
    ("circle moving down-right", "circle", "down_right"),  # Move circle down-right
    ("circle moving down-left", "circle", "down_left"),  # Move circle down-left
]

Next, write the equations that move the circle according to each prompt.

In [None]:
def create_image_with_moving_shape(size, frame_num, shape, direction):
    # Create a new RGB image with specified size and white background
    img = Image.new('RGB', size, color=(255, 255, 255))

    # Create a drawing context for the image
    draw = ImageDraw.Draw(img)

    # Calculate the center coordinates of the image
    center_x, center_y = size[0] // 2, size[1] // 2
    position = (center_x, center_y)

    # Map each direction to an (x, y) offset that is added to the image center.
    # Python's % never returns a negative value for a positive divisor, so negative moves wrap around
    # (e.g. "left" first jumps towards the right edge), which can push the circle out of the frame.
    offset_map = {
        "down": (0, frame_num * 5 % size[1]),
        "left": (-frame_num * 5 % size[0], 0),
        "right": (frame_num * 5 % size[0], 0),
        "diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        "diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]),
        "diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        "diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),
        # Bounce: the offset grows and then shrinks back as the frame number increases
        "bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)),
        "bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0),
        # Zigzag: even and odd frames step in opposite directions. These offsets already include
        # center_x/center_y, so the circle often lands near or beyond the frame edge.
        "zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]),
        "zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y),
        # "up_right" is not used by any training prompt (same offsets as "diagonal_up_right")
        "up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        "up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        "down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),
        "down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]),
    }

    # Rotation angles in degrees; PIL's Image.rotate turns counter-clockwise for positive angles
    rotation_map = {
        "rotate_clockwise": -frame_num * 10 % 360,
        "rotate_counter_clockwise": frame_num * 10 % 360,
    }

    # Update the position for translation-type movements
    if direction in offset_map:
        position = tuple(np.add(position, offset_map[direction]))

    # Draw the shape (circle) at the calculated position
    if shape == "circle":
        r = shape_size // 2
        draw.ellipse([position[0] - r, position[1] - r, position[0] + r, position[1] + r], fill=(0, 0, 255))

    # Rotate only after drawing, so the circle is part of the returned image.
    # (A circle centered on the rotation point looks the same at every angle.)
    if direction in rotation_map:
        img = img.rotate(rotation_map[direction], center=(center_x, center_y), fillcolor=(255, 255, 255))

    return np.array(img)

The function above positions the circle in a single frame according to the selected direction. Call it once per frame for each of the `num_videos` videos to generate the full dataset.

In [None]:
for i in range(num_videos):
    # Randomly choose a prompt and movement from the predefined list
    prompt, shape, direction = random.choice(prompts_and_movements)

    # Create a directory for the current video
    video_dir = f'training_dataset/video_{i}'
    os.makedirs(video_dir, exist_ok=True)

    # Write the chosen prompt to a text file in the video directory
    with open(f'{video_dir}/prompt.txt', 'w') as f:
        f.write(prompt)

    # Generate frames for the current video
    for frame_num in range(frames_per_video):
        # Create an image with a moving shape based on the current frame number, shape, and direction
        img = create_image_with_moving_shape(img_size, frame_num, shape, direction)

        # PIL gives RGB arrays but cv2.imwrite expects BGR, so convert before saving the PNG
        cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

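The training dataset does not include the prompt for the **circle moving up-right**, so we will use it as the test prompt for unseen data. Note, however, that the `up_right` offsets in `create_image_with_moving_shape` are identical to those of `diagonal_up_right`. What is new is therefore the prompt wording rather than the motion itself.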

The training dataset also contains many samples in which the circle moves out of the frame or is only partially visible, because the position offsets wrap around with the `%` operator. This helps test whether the model can keep the circle's shape consistent when it enters the scene from the very corner.

<img src='https://miro.medium.com/v2/resize:fit:1100/format:webp/1*RP5M_TEt2H4Mo6OhnlcRLA.gif' width='500'>
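To see where these partial frames come from, the sketch below recomputes the x-coordinate of the circle's center for the `"left"` and `"right"` directions. It uses the same `% size[0]` wrap-around as `create_image_with_moving_shape`, with `img_size = (64, 64)` and `shape_size = 10`. The helpers `x_position` and `is_visible` are hypothetical names introduced for illustration. Because Python's `%` returns a non-negative result for a positive divisor, a "left" step of -5 becomes +59. The circle therefore first jumps off the right edge and only later drifts back into view.

```python
# Hypothetical helpers reproducing the horizontal offsets for "left" and "right"
WIDTH = 64              # img_size[0]
RADIUS = 10 // 2        # shape_size // 2
CENTER_X = WIDTH // 2

def x_position(frame_num, step):
    # Same expression as the tutorial, e.g. -frame_num * 5 % size[0] for "left" (step = -5)
    return CENTER_X + (step * frame_num) % WIDTH

def is_visible(x):
    # x never drops below the center here, so only the right edge matters:
    # the circle's left edge (x - RADIUS) must still fall inside columns 0..WIDTH-1
    return x - RADIUS < WIDTH

left = [x_position(f, -5) for f in range(10)]
right = [x_position(f, 5) for f in range(10)]
print(left)
print([f for f, x in enumerate(left) if not is_visible(x)])   # hidden frames for "left"
print([f for f, x in enumerate(right) if not is_visible(x)])  # hidden frames for "right"
```

For `"left"`, frames 1 to 5 contain no circle at all. For `"right"`, the circle leaves the frame in the last two frames.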

## Data Pre-Processing

In [None]:
class TextToVideoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        # Initialize the dataset with root directory and optional transform
        self.root_dir = root_dir
        self.transform = transform
        # List all subdirectories in the root directory
        self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        # Initialize lists to store frame paths and corresponding prompts
        self.frame_paths = []
        self.prompts = []

        for video_dir in self.video_dirs:
            # List all PNG files in the video directory and store their paths
            frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]
            self.frame_paths.extend(frames)
            # Read the prompt text file in the video directory and store its content
            with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:
                prompt = f.read().strip()
            # Repeat the prompt for each frame in the video and store in prompts list
            self.prompts.extend([prompt] * len(frames))

    def __len__(self):
        return len(self.frame_paths)

    # Retrieve a sample from the dataset given an index
    def __getitem__(self, idx):
        # Get the frame path and prompt corresponding to the given index
        frame_path = self.frame_paths[idx]
        image = Image.open(frame_path).convert('RGB')  # ensure 3 channels for the 3-channel Discriminator
        prompt = self.prompts[idx]

        # Apply transformation if specified
        if self.transform:
            image = self.transform(image)

        return image, prompt

In [None]:
# Define a set of transformations to be applied to the data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)) # Normalize image with mean and standard deviation
])

# Load the dataset
dataset = TextToVideoDataset(root_dir='training_dataset', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
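Together, these two transforms rescale pixel values to the range [-1, 1], the same range as the generator's final `Tanh`. Below is a NumPy re-implementation of that arithmetic. It is a sketch, not torchvision itself. The single-element mean and std are broadcast to all three RGB channels.

```python
import numpy as np

# Three example 8-bit pixel values: black, mid-grey, white
pixels = np.array([0, 128, 255], dtype=np.uint8)

# transforms.ToTensor() divides uint8 images by 255 -> range [0, 1]
scaled = pixels.astype(np.float32) / 255.0
# transforms.Normalize((0.5,), (0.5,)) computes (x - mean) / std -> range [-1, 1]
normalized = (scaled - 0.5) / 0.5
# Inverse mapping, needed to turn a Tanh output back into a viewable image
restored = normalized * 0.5 + 0.5

print(normalized)
```

The inverse mapping `x * 0.5 + 0.5` is what a generated frame needs before it is saved as an image.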

## Text Embedding Layer

In [8]:
class TextEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(TextEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)

    def forward(self, x):
        return self.embedding(x)
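`TextEmbedding` maps each word index to a learned vector. The training loop further below then averages the vectors of a prompt's words (`.mean(dim=0)`) to get one fixed-size conditioning vector. The NumPy sketch below mimics that lookup-and-average step, using a hypothetical four-word vocabulary and random weights in place of `nn.Embedding`. One consequence of averaging is that word order is discarded: prompts with the same words in a different order map to the same vector.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = {"circle": 0, "moving": 1, "up-right": 2, "diagonally": 3}  # hypothetical tiny vocabulary
embed_size = 10
# Stand-in for nn.Embedding's weight matrix: one row per word
embedding_table = rng.normal(size=(len(vocab), embed_size))

def embed_prompt(prompt):
    # Look up one vector per word, then average them into a single prompt vector
    indices = [vocab[word] for word in prompt.split()]
    return embedding_table[indices].mean(axis=0)

a = embed_prompt("circle moving up-right")
b = embed_prompt("up-right moving circle")
print(a.shape)            # (10,)
print(np.allclose(a, b))  # True: averaging ignores word order
```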

## Generator Layer

The `Generator` is responsible for creating video frames from a combination of random noise and text embeddings. It aims to produce realistic video frames conditioned on the given text descriptions.

In [None]:
class Generator(nn.Module):
    def __init__(self, text_embed_size):
        super(Generator, self).__init__()

        # Fully connected layer that takes noise and text embedding as input
        self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)

        # Transposed convolutional layers to upsample the input
        self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1)  # Output has 3 channels for RGB images

        # Activation functions
        self.relu = nn.ReLU(True)
        self.tanh = nn.Tanh()

    def forward(self, noise, text_embed):
        # Concatenate noise and text embedding along the feature dimension
        x = torch.cat((noise, text_embed), dim=1)

        # Fully connected layer followed by reshaping to 4D tensor
        x = self.fc1(x).view(-1, 256, 8, 8)

        # Upsampling through transposed convolution layers with ReLU activation
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))

        # Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
        x = self.tanh(self.deconv3(x))

        return x
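Each `ConvTranspose2d(in, out, 4, 2, 1)` doubles the spatial size. That is why three of them take the 8×8 map produced by `fc1` to a 64×64 frame matching `img_size`. The plain-Python sketch below applies the output-size formula from the PyTorch `ConvTranspose2d` documentation, assuming the defaults `dilation=1` and `output_padding=0`.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    # PyTorch ConvTranspose2d output size with dilation=1 and output_padding=0
    return (size - 1) * stride - 2 * padding + kernel

size = 8  # spatial size after fc1(...).view(-1, 256, 8, 8)
sizes = [size]
for _ in range(3):  # deconv1, deconv2, deconv3
    size = conv_transpose_out(size)
    sizes.append(size)
print(sizes)  # [8, 16, 32, 64]
```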


## Discriminator Layer

The `Discriminator` class is a binary classifier that distinguishes between real and generated video frames. It sees only the frame itself, not the text prompt. It therefore judges how realistic a frame looks, not whether the frame matches the prompt.

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Convolutional layers to process input images
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)    # 3 input channels (RGB), 64 output channels, kernel size 4x4, stride 2, padding 1
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)  # 64 input channels, 128 output channels, kernel size 4x4, stride 2, padding 1
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # 128 input channels, 256 output channels, kernel size 4x4, stride 2, padding 1

        # Fully connected layer for classification
        self.fc1 = nn.Linear(256 * 8 * 8, 1)  # Input size 256x8x8 (output size of last convolution), output size 1 (binary classification)

        # Activation functions
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)  # negative slope 0.2
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        # Pass input through convolutional layers with LeakyReLU activation
        x = self.leaky_relu(self.conv1(input))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))

        # Flatten the output of convolutional layers
        x = x.view(-1, 256 * 8 * 8)

        # Pass through fully connected layer with Sigmoid activation for binary classification
        x = self.sigmoid(self.fc1(x))

        return x
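The discriminator mirrors the generator: each `Conv2d(in, out, 4, 2, 1)` halves the spatial size. A 64×64 frame therefore shrinks to 8×8, and the flattened feature vector has 256 × 8 × 8 = 16,384 entries, which is the input size of `fc1`. The sketch below uses the `Conv2d` output-size formula from the PyTorch documentation with the default `dilation=1`.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    # PyTorch Conv2d output size with dilation=1
    return (size + 2 * padding - kernel) // stride + 1

size = 64  # input frames are 64x64
for _ in range(3):  # conv1, conv2, conv3
    size = conv_out(size)
flat_features = 256 * size * size
print(size, flat_features)  # 8 16384
```

This is also why `x.view(-1, 256 * 8 * 8)` only works for 64×64 inputs. Other frame sizes would need a differently sized `fc1`.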

## Training

### Training Parameters

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a simple vocabulary for text prompts
all_prompts = [prompt for prompt, _, _ in prompts_and_movements]  # Extract all prompts from prompts_and_movements list
# sorted() makes the word -> index mapping identical across Python sessions (plain set order is not)
vocab = {word: idx for idx, word in enumerate(sorted(set(" ".join(all_prompts).split())))}
vocab_size = len(vocab)  # Size of the vocabulary
embed_size = 10  # Size of the text embedding vector

def encode_text(prompt):
    # Encode a given prompt into a tensor of indices using the vocabulary
    # (raises KeyError for words that never appear in the training prompts)
    return torch.tensor([vocab[word] for word in prompt.split()])

# Initialize models, loss function, and optimizers
text_embedding = TextEmbedding(vocab_size, embed_size).to(device)
netG = Generator(embed_size).to(device)
netD = Discriminator().to(device)
criterion = nn.BCELoss().to(device)
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Note: only netG's parameters are optimized here; text_embedding keeps its random initial weights
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
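Two details of the vocabulary matter at generation time.

- **Every word must be known.** Each word of a prompt must already be in `vocab`, otherwise `encode_text` raises a `KeyError`. The test prompt "circle moving up-right" works only because "up-right" also occurs in "circle moving diagonally up-right".
- **Index order must be stable.** Python randomizes string hashing per process, so iterating over a plain `set` can assign different indices in a new session. Wrapping the set in `sorted()` keeps the mapping stable.

The pure-Python sketch below shows both points. It uses a hypothetical abbreviated prompt list, and this `encode_text` returns a list instead of a tensor.

```python
# Hypothetical subset of the prompts in prompts_and_movements
all_prompts = [
    "circle moving down",
    "circle moving diagonally up-right",
    "circle shrinking",
]

# sorted() gives the same word -> index mapping in every Python session
vocab = {word: idx for idx, word in enumerate(sorted(set(" ".join(all_prompts).split())))}

def encode_text(prompt):
    # List-based stand-in for the tensor version used in training
    return [vocab[word] for word in prompt.split()]

print(vocab)
print(encode_text("circle moving up-right"))  # [0, 3, 5]
try:
    encode_text("circle moving upward")
except KeyError as err:
    print("unknown word:", err)
```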

### Training Loop

In [None]:
num_epochs = 13
for epoch in tqdm(range(num_epochs)):
    for i, (data, prompts) in enumerate(dataloader):
        real_data = data.to(device)

        # Convert prompts to list
        prompts = [prompt for prompt in prompts]

        # Update Discriminator
        netD.zero_grad()
        batch_size = real_data.size(0)
        labels = torch.ones(batch_size, 1).to(device)  # Create labels for real data (ones)
        output = netD(real_data)  # Forward pass real data through Discriminator
        lossD_real = criterion(output, labels)  # Calculate loss on real data
        lossD_real.backward()  # Calculate gradients

        # Generate fake data
        noise = torch.randn(batch_size, 100).to(device)  # Generate random noise
        text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts])  # Encode prompts into text embeddings
        fake_data = netG(noise, text_embeds)  # Generate fake data from noise and text embeddings
        labels = torch.zeros(batch_size, 1).to(device)  # Create labels for fake data (zeros)
        output = netD(fake_data.detach())  # Forward pass fake data through Discriminator (detach to avoid gradients flowing back to Generator)
        lossD_fake = criterion(output, labels)  # Calculate loss on fake data
        lossD_fake.backward()  # Calculate gradients
        optimizerD.step()  # Update Discriminator parameters

        # Update Generator
        netG.zero_grad()  # Zero the gradients of the Generator
        labels = torch.ones(batch_size, 1).to(device)  # Label fake data as real (ones) to fool the Discriminator
        output = netD(fake_data)  # Re-score the fake data with the just-updated Discriminator (no detach, so gradients reach the Generator)
        lossG = criterion(output, labels)  # Calculate loss for Generator based on Discriminator's response
        lossG.backward()  # Calculate gradients
        optimizerG.step()  # Update Generator parameters

    # Report the losses of the last batch of the epoch as plain numbers
    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {(lossD_real + lossD_fake).item()}, Loss G: {lossG.item()}")

Epoch [1/13] Loss D: 1.000091552734375, Loss G: 1.2070484161376953

Epoch [2/13] Loss D: 1.1845459938049316, Loss G: 1.0169572830200195

Epoch [3/13] Loss D: 1.1252405643463135, Loss G: 0.9279786348342896

Epoch [4/13] Loss D: 1.079437255859375, Loss G: 0.9062992334365845

Epoch [5/13] Loss D: 1.0015060901641846, Loss G: 1.0048640966415405

Epoch [6/13] Loss D: 1.0681159496307373, Loss G: 1.0402806997299194

Epoch [7/13] Loss D: 1.1377320289611816, Loss G: 0.9768543243408203

Epoch [8/13] Loss D: 1.1480485200881958, Loss G: 0.9558865427970886

Epoch [9/13] Loss D: 1.3059004545211792, Loss G: 0.930357038974762

Epoch [10/13] Loss D: 0.951496958732605, Loss G: 1.018438458442688

Epoch [11/13] Loss D: 0.9797308444976807, Loss G: 0.9381852746009827

Epoch [12/13] Loss D: 1.0248430967330933, Loss G: 0.9506300687789917

Epoch [13/13] Loss D: 0.9496458768844604, Loss G: 1.087607502937317

100%|██████████| 13/13 [1:15:16<00:00, 347.43s/it]

Through backpropagation, the weights of both the generator and the discriminator are updated to reduce their respective losses.

There is a high risk of overfitting on this small synthetic dataset, so the number of epochs is kept low. With a more diverse dataset, training for more epochs would be worth considering.
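To read the logged losses, it helps to know their values at the theoretical equilibrium, where the discriminator outputs 0.5 for every frame. The sketch below computes them with a hand-written binary cross-entropy. For a single prediction, this matches what `nn.BCELoss` computes.

```python
import math

def bce(p, target):
    # Binary cross-entropy for one prediction p in (0, 1)
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

# Theoretical equilibrium: the discriminator outputs 0.5 for every real and fake frame
p = 0.5
loss_d = bce(p, 1.0) + bce(p, 0.0)   # lossD_real + lossD_fake, as printed in the loop
loss_g = bce(p, 1.0)                 # generator loss with "real" labels
print(round(loss_d, 3), round(loss_g, 3))  # 1.386 0.693

# With mean reduction, exp(-lossG) is the geometric mean of D's scores on the fake batch
implied_score = math.exp(-1.0)       # for a logged generator loss of about 1.0
print(round(implied_score, 3))       # 0.368
```

The logged values stay near 1.0 for both networks. A discriminator loss below 1.386 suggests it still separates real from fake frames better than chance. A generator loss of about 1.0 suggests the discriminator's (geometric-mean) score on fakes is around 0.37 rather than 0.5. Only the last batch of each epoch is printed, so these readings are noisy.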

## Saving the Model

In [14]:
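os.makedirs('./model', exist_ok=True)  # torch.save does not create missing directories
torch.save(netG.state_dict(), './model/generator.pth')
torch.save(netD.state_dict(), './model/discriminator.pth')
torch.save(text_embedding.state_dict(), './model/text_embedding.pth')  # needed to encode prompts later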

## Generating AI Video

The prompt "circle moving up-right" does not appear in the training data, so the model has never seen this exact text. It has, however, been trained on related motions, so we use this prompt to test the trained model.

In [15]:
def generate_video(text_prompt, num_frames=10):
    # Create a directory for the generated video frames based on the text prompt
    out_dir = f'generated_video_{text_prompt.replace(" ", "_")}'
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        # Encode the text prompt into a single averaged embedding of shape (1, embed_size)
        text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)

        # Generate frames for the video
        for frame_num in range(num_frames):
            # Fresh noise per frame: each frame is sampled independently of the previous one
            noise = torch.randn(1, 100).to(device)

            # Generate a fake frame using the Generator network
            fake_frame = netG(noise, text_embed)

            # Map the Tanh output from [-1, 1] back to [0, 1] before saving it as an image file
            save_image(fake_frame * 0.5 + 0.5, f'{out_dir}/frame_{frame_num}.png')

generate_video('circle moving up-right')

Merge all the frames of the generated video into a single short video.

In [None]:
# Define the path to folder containing the PNG frames
folder_path = 'generated_video_circle_moving_up-right'

# Get the list of all PNG files in the folder
image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]

# Sort the images by frame number (a plain string sort would put frame_10 before frame_2)
image_files.sort(key=lambda f: int(f.split('_')[-1].split('.')[0]))

# Create a list to store the frames
frames = []
for image_file in image_files:
    image_path = os.path.join(folder_path, image_file)
    frame = cv2.imread(image_path)
    frames.append(frame)

frames = np.array(frames)
fps = 10 # frames per second

# Create a video writer object
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))

# Write each frame to the video
for frame in frames:
    out.write(frame)

# Release the video writer
out.release()
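Frame order matters when writing the video. A plain `sort()` orders file names as strings, which is only correct for up to 10 frames (indices 0 to 9). With `num_frames=12`, `frame_10.png` would be placed right after `frame_1.png`. Sorting by the numeric part of the name avoids this. In the sketch below, `frame_index` is a hypothetical helper.

```python
import re

# File names as written by generate_video, but for 12 frames instead of 10
image_files = [f"frame_{i}.png" for i in range(12)]

def frame_index(name):
    # Extract the integer from names like "frame_11.png"
    return int(re.search(r"\d+", name).group())

lexicographic = sorted(image_files)
numeric = sorted(image_files, key=frame_index)
print(lexicographic[:4])  # ['frame_0.png', 'frame_1.png', 'frame_10.png', 'frame_11.png']
print(numeric[:4])        # ['frame_0.png', 'frame_1.png', 'frame_2.png', 'frame_3.png']
```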

## Reference

[Building an AI Text-to-Video Model from Scratch Using Python](https://levelup.gitconnected.com/building-an-ai-text-to-video-model-from-scratch-using-python-35b4eb4002de)