<a href="https://colab.research.google.com/github/ZahraDehghanian97/LenseCraft/blob/master/Cinematography_Instruction_Analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import json
import math
from enum import Enum
from typing import List, Dict, Any

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import CLIPTextModel, CLIPTokenizer

In [None]:
# Download the simulation dataset from Google Drive; the captured output shows it is saved as
# random_simulation_dataset.json, the file later read by SimulationDataset.
!gdown 1VT2XfBj9LFWLUBjv65dzC4bVzH0zdNDU

Downloading...
From: https://drive.google.com/uc?id=1VT2XfBj9LFWLUBjv65dzC4bVzH0zdNDU
To: /content/random_simulation_dataset.json
  0% 0.00/1.06M [00:00<?, ?B/s]100% 1.06M/1.06M [00:00<00:00, 136MB/s]


In [None]:
class CameraMovement(Enum):
    """Camera movement types; each member's value equals its name.

    Members are looked up by name from an instruction's ``cameraMovement``
    field. Declaration order is significant: a member's position in this enum
    is used as its classification label, and ``len(CameraMovement)`` sizes the
    camera-movement classifier head, so do not reorder or insert members
    without retraining.
    """
    panLeft = "panLeft"
    panRight = "panRight"
    tiltUp = "tiltUp"
    tiltDown = "tiltDown"
    dollyIn = "dollyIn"
    dollyOut = "dollyOut"
    truckLeft = "truckLeft"
    truckRight = "truckRight"
    pedestalUp = "pedestalUp"
    pedestalDown = "pedestalDown"
    fullZoomIn = "fullZoomIn"
    fullZoomOut = "fullZoomOut"
    halfZoomIn = "halfZoomIn"
    halfZoomOut = "halfZoomOut"
    shortZoomIn = "shortZoomIn"
    shortZoomOut = "shortZoomOut"
    shortArcShotLeft = "shortArcShotLeft"
    shortArcShotRight = "shortArcShotRight"
    halfArcShotLeft = "halfArcShotLeft"
    halfArcShotRight = "halfArcShotRight"
    fullArcShotLeft = "fullArcShotLeft"
    fullArcShotRight = "fullArcShotRight"
    # Compound movements combining two basic ones.
    panAndTilt = "panAndTilt"
    dollyAndPan = "dollyAndPan"
    zoomAndTruck = "zoomAndTruck"

class MovementEasing(Enum):
    """Easing curves applied to a camera movement; each member's value equals its name.

    Members are looked up by name from an instruction's ``movementEasing``
    field. Declaration order is significant: a member's position in this enum
    is used as its classification label, and ``len(MovementEasing)`` sizes the
    easing classifier head, so do not reorder or insert members without
    retraining.
    """
    linear = "linear"
    easeInQuad = "easeInQuad"
    easeInCubic = "easeInCubic"
    easeInQuart = "easeInQuart"
    easeInQuint = "easeInQuint"
    easeOutQuad = "easeOutQuad"
    easeOutCubic = "easeOutCubic"
    easeOutQuart = "easeOutQuart"
    easeOutQuint = "easeOutQuint"
    easeInOutQuad = "easeInOutQuad"
    easeInOutCubic = "easeInOutCubic"
    easeInOutQuart = "easeInOutQuart"
    easeInOutQuint = "easeInOutQuint"
    easeInSine = "easeInSine"
    easeOutSine = "easeOutSine"
    easeInOutSine = "easeInOutSine"
    easeInExpo = "easeInExpo"
    easeOutExpo = "easeOutExpo"
    easeInOutExpo = "easeInOutExpo"
    easeInCirc = "easeInCirc"
    easeOutCirc = "easeOutCirc"
    easeInOutCirc = "easeInOutCirc"
    easeInBounce = "easeInBounce"
    easeOutBounce = "easeOutBounce"
    easeInOutBounce = "easeInOutBounce"
    easeInElastic = "easeInElastic"
    easeOutElastic = "easeOutElastic"
    easeInOutElastic = "easeInOutElastic"

In [None]:
class SimulationDataset(Dataset):
    """Single-instruction camera simulations flattened into fixed-length frame vectors.

    Each kept sample holds:
      * ``frames``: ``frame_count * 7`` floats; per frame: position x/y/z,
        angle x/y/z, then ``focalLength``.
      * ``camera_movement`` / ``movement_easing``: the member's position in
        ``CameraMovement`` / ``MovementEasing`` declaration order, used as a
        class label.
      * ``instruction_text``: ``"<cameraMovement> with <movementEasing>"``.

    Simulations with more than one instruction, or whose instruction
    ``frameCount`` or number of ``cameraFrames`` differs from ``frame_count``,
    are skipped.
    """

    def __init__(self, file_path: str, frame_count: int = 30):
        """Load and filter simulations from the JSON file at *file_path*.

        Args:
            file_path: JSON document with a top-level ``simulations`` list.
            frame_count: number of frames a simulation must have to be kept.

        Raises:
            KeyError: if an instruction names a camera movement or easing that
                is not a member of the corresponding enum, or a required field
                is missing.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)

        # Label = declaration index; built once instead of list(Enum).index per sample.
        movement_index = {member.name: i for i, member in enumerate(CameraMovement)}
        easing_index = {member.name: i for i, member in enumerate(MovementEasing)}

        self.processed_data = []

        for simulation in data['simulations']:
            if len(simulation['instructions']) != 1:
                continue

            instruction = simulation['instructions'][0]
            if instruction['frameCount'] != frame_count:
                continue

            camera_frames = simulation['cameraFrames']
            if len(camera_frames) != frame_count:
                continue

            flattened_frames = []
            for frame in camera_frames:
                position, angle = frame['position'], frame['angle']
                flattened_frames.extend([
                    position['x'], position['y'], position['z'],
                    angle['x'], angle['y'], angle['z'],
                    frame['focalLength'],
                ])

            movement_name = instruction['cameraMovement']
            easing_name = instruction['movementEasing']
            self.processed_data.append({
                'frames': flattened_frames,
                'camera_movement': movement_index[movement_name],
                'movement_easing': easing_index[easing_name],
                'instruction_text': f"{movement_name} with {easing_name}",
            })

    def __len__(self):
        return len(self.processed_data)

    def __getitem__(self, idx):
        """Return sample *idx* with frames as float32 and labels as int64 tensors."""
        item = self.processed_data[idx]
        return {
            'frames': torch.tensor(item['frames'], dtype=torch.float32),
            'camera_movement': torch.tensor(item['camera_movement'], dtype=torch.long),
            'movement_easing': torch.tensor(item['movement_easing'], dtype=torch.long),
            'instruction_text': item['instruction_text']
        }


def collate_fn(batch):
    """Merge a list of per-sample dicts into a single batch dict.

    The tensor fields are stacked along a new leading batch dimension; the
    instruction strings are gathered into a plain list in batch order.
    """
    stacked_keys = ('frames', 'camera_movement', 'movement_easing')
    collated = {key: torch.stack([sample[key] for sample in batch]) for key in stacked_keys}
    collated['instruction_text'] = [sample['instruction_text'] for sample in batch]
    return collated

In [None]:
# Sinusoidal Positional Encoding

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``. An odd
    ``d_model`` is supported: the last (even) column then has no cosine partner.

    Args:
        d_model: feature size of the inputs.
        max_len: longest sequence ``forward`` can encode.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so trim div_term.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape (max_len, 1, d_model): broadcasts over the batch axis of a (seq, batch, d_model) input.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions.

        ``x`` is indexed sequence-first: (seq_len, batch, d_model).
        """
        return x + self.pe[:x.size(0), :]

In [None]:
class MultiTaskAutoencoder(nn.Module):
    """Transformer autoencoder over camera trajectories with two classification heads.

    The input is a flat vector of ``max_seq_length`` frames, each with
    ``input_dim // max_seq_length`` values. The encoder's per-frame outputs are
    flattened and projected to a ``latent_dim`` vector that feeds the
    camera-movement and movement-easing classifiers. The decoder reconstructs
    the frames from the encoder memory with shifted, causally masked teacher
    forcing.

    Raises:
        ValueError: if ``input_dim`` is not a positive multiple of ``max_seq_length``.
    """

    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, num_decoder_layers,
                 num_camera_movements, num_movement_easings, max_seq_length, latent_dim):
        super().__init__()

        frame_dim, remainder = divmod(input_dim, max_seq_length)
        if remainder or frame_dim == 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be a positive multiple of max_seq_length ({max_seq_length})"
            )

        self.input_dim = input_dim
        self.d_model = d_model
        self.max_seq_length = max_seq_length
        self.latent_dim = latent_dim
        self.frame_dim = frame_dim  # values per frame (7 with the notebook's settings)

        self.input_projection = nn.Linear(frame_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_length)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        self.latent_projection = nn.Linear(d_model * max_seq_length, latent_dim)

        self.camera_movement_classifier = nn.Linear(latent_dim, num_camera_movements)
        self.movement_easing_classifier = nn.Linear(latent_dim, num_movement_easings)

        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.output_projection = nn.Linear(d_model, frame_dim)

    def encode(self, src):
        """Encode frames into sequence-first memory of shape (max_seq_length, batch, d_model)."""
        src = src.view(-1, self.max_seq_length, self.frame_dim)
        src = self.input_projection(src)
        # The transformer layers use the default batch_first=False, i.e. (seq, batch, feature).
        src = src.permute(1, 0, 2)
        src = self.pos_encoder(src)
        return self.transformer_encoder(src)

    def decode(self, memory, tgt, tgt_mask=None):
        """Decode sequence-first target frames (seq, batch, frame_dim) against ``memory``.

        ``tgt_mask`` is passed to the transformer decoder; give a causal mask so
        that position i cannot attend to later target positions.
        Returns per-position frame predictions of shape (seq, batch, frame_dim).
        """
        tgt = self.input_projection(tgt)
        tgt = self.pos_encoder(tgt)
        output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)
        return self.output_projection(output)

    @staticmethod
    def _causal_mask(size, device, dtype):
        """Additive attention mask: -inf strictly above the diagonal, 0 elsewhere."""
        return torch.triu(torch.full((size, size), float('-inf'), device=device, dtype=dtype), diagonal=1)

    def forward(self, x):
        """Run encoder, classifiers and decoder on flat frames.

        Returns a dict with ``latent`` (batch, latent_dim), the two logits
        tensors, and ``reconstructed`` (batch, max_seq_length * frame_dim).
        """
        x = x.view(-1, self.max_seq_length, self.frame_dim)

        memory = self.encode(x)

        latent = memory.permute(1, 0, 2).reshape(-1, self.d_model * self.max_seq_length)
        latent = self.latent_projection(latent)

        camera_movement_logits = self.camera_movement_classifier(latent)
        movement_easing_logits = self.movement_easing_classifier(latent)

        # Teacher forcing with a one-step shift: the decoder sees a zero start frame
        # followed by frames 0..T-2, and the causal mask hides later positions, so the
        # prediction for frame i never sees frame i itself (which would let the decoder
        # copy its input instead of using the encoder memory).
        # NOTE(review): the decoder attends to the full encoder memory, not to `latent`,
        # so reconstruction does not pass through the latent bottleneck.
        start_frame = x.new_zeros(x.size(0), 1, self.frame_dim)
        decoder_input = torch.cat([start_frame, x[:, :-1, :]], dim=1).permute(1, 0, 2)
        tgt_mask = self._causal_mask(self.max_seq_length, x.device, x.dtype)
        decoded = self.decode(memory, decoder_input, tgt_mask=tgt_mask)

        reconstructed = decoded.permute(1, 0, 2).reshape(-1, self.max_seq_length * self.frame_dim)

        return {
            'latent': latent,
            'camera_movement_logits': camera_movement_logits,
            'movement_easing_logits': movement_easing_logits,
            'reconstructed': reconstructed,
        }

In [None]:
# Build the data pipeline, the model, the CLIP text encoder (not optimized), and the losses.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = SimulationDataset('random_simulation_dataset.json')
if len(dataset) == 0:
    raise RuntimeError("no simulations in random_simulation_dataset.json passed SimulationDataset's filters")
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

max_seq_length = 30   # frames per trajectory
values_per_frame = 7  # position x/y/z, angle x/y/z, focal length
input_dim = max_seq_length * values_per_frame
d_model = 256
nhead = 8
num_encoder_layers = 3
num_decoder_layers = 3
num_camera_movements = len(CameraMovement)
num_movement_easings = len(MovementEasing)
clip_model_name = "openai/clip-vit-base-patch32"
latent_dim = 512

model = MultiTaskAutoencoder(input_dim, d_model, nhead, num_encoder_layers, num_decoder_layers,
                             num_camera_movements, num_movement_easings, max_seq_length, latent_dim).to(device)

clip_tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
clip_text_encoder = CLIPTextModel.from_pretrained(clip_model_name).to(device)

# The CLIP loss is a cosine similarity between the model latent and CLIP text
# features, so both must have the same width; fail here rather than mid-epoch.
if clip_text_encoder.config.hidden_size != latent_dim:
    raise ValueError(
        f"latent_dim ({latent_dim}) must equal the CLIP text hidden size "
        f"({clip_text_encoder.config.hidden_size})"
    )

# Only the autoencoder's parameters are optimized.
optimizer = optim.Adam(model.parameters(), lr=0.001)

classification_criterion = nn.CrossEntropyLoss()
reconstruction_criterion = nn.MSELoss()

In [None]:
# Multi-task training: two classification heads, frame reconstruction, and
# alignment of the latent with CLIP text features of the instruction.
num_epochs = 100

# Per-epoch averages divide by the batch count, which must be non-zero.
num_batches = len(dataloader)
if num_batches == 0:
    raise RuntimeError("dataloader yields no batches; check that the dataset is non-empty")

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_camera_movement_loss = 0
    total_movement_easing_loss = 0
    total_reconstruction_loss = 0
    total_clip_loss = 0

    for batch in dataloader:
        frames = batch['frames'].to(device)
        camera_movement = batch['camera_movement'].to(device)
        movement_easing = batch['movement_easing'].to(device)
        instruction_text = batch['instruction_text']

        optimizer.zero_grad()

        output = model(frames)

        camera_movement_loss = classification_criterion(output['camera_movement_logits'], camera_movement)
        movement_easing_loss = classification_criterion(output['movement_easing_logits'], movement_easing)
        reconstruction_loss = reconstruction_criterion(output['reconstructed'], frames)

        clip_input = clip_tokenizer(instruction_text, padding=True, truncation=True, return_tensors="pt").to(device)
        with torch.no_grad():
            # CLIP's text transformer uses a causal attention mask, so the first (BOS)
            # token's hidden state cannot see the prompt and is identical for every
            # instruction. pooler_output is the final hidden state at the EOS token,
            # which summarizes the whole text.
            clip_text_features = clip_text_encoder(**clip_input).pooler_output

        # Pull each latent towards the CLIP embedding of its instruction text.
        clip_loss = 1 - F.cosine_similarity(output['latent'], clip_text_features).mean()

        # NOTE(review): the terms are summed unweighted; in the logged run the
        # reconstruction term (~250-330) dwarfs the others (~1-17).
        loss = camera_movement_loss + movement_easing_loss + reconstruction_loss + clip_loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_camera_movement_loss += camera_movement_loss.item()
        total_movement_easing_loss += movement_easing_loss.item()
        total_reconstruction_loss += reconstruction_loss.item()
        total_clip_loss += clip_loss.item()

    avg_loss = total_loss / num_batches
    avg_camera_movement_loss = total_camera_movement_loss / num_batches
    avg_movement_easing_loss = total_movement_easing_loss / num_batches
    avg_reconstruction_loss = total_reconstruction_loss / num_batches
    avg_clip_loss = total_clip_loss / num_batches

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"Total Loss: {avg_loss:.4f}")
    print(f"Camera Movement Loss: {avg_camera_movement_loss:.4f}")
    print(f"Movement Easing Loss: {avg_movement_easing_loss:.4f}")
    print(f"Reconstruction Loss: {avg_reconstruction_loss:.4f}")
    print(f"CLIP Loss: {avg_clip_loss:.4f}")
    print("-" * 50)

Epoch [1/100]
Total Loss: 341.3379
Camera Movement Loss: 6.5060
Movement Easing Loss: 8.3289
Reconstruction Loss: 325.5145
CLIP Loss: 0.9884
--------------------------------------------------
Epoch [2/100]
Total Loss: 310.5421
Camera Movement Loss: 15.0435
Movement Easing Loss: 16.9723
Reconstruction Loss: 277.5394
CLIP Loss: 0.9869
--------------------------------------------------
Epoch [3/100]
Total Loss: 332.9131
Camera Movement Loss: 14.4351
Movement Easing Loss: 13.8457
Reconstruction Loss: 303.7138
CLIP Loss: 0.9184
--------------------------------------------------
Epoch [4/100]
Total Loss: 285.4020
Camera Movement Loss: 11.6725
Movement Easing Loss: 8.5087
Reconstruction Loss: 264.2918
CLIP Loss: 0.9289
--------------------------------------------------
Epoch [5/100]
Total Loss: 260.6937
Camera Movement Loss: 10.1781
Movement Easing Loss: 9.7378
Reconstruction Loss: 239.8354
CLIP Loss: 0.9423
--------------------------------------------------
Epoch [6/100]
Total Loss: 244.6559

KeyboardInterrupt: 