<a href="https://colab.research.google.com/github/ZahraDehghanian97/LenseCraft/blob/master/Cinematography_Instruction_Analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports and Helper Functions

## Imports

In [None]:
import json
import math
from enum import Enum
from typing import List, Dict
import random

from google.colab import files

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [None]:
!gdown -O /content/random_simulation_dataset.json 1VT2XfBj9LFWLUBjv65dzC4bVzH0zdNDU

Downloading...
From (original): https://drive.google.com/uc?id=1VT2XfBj9LFWLUBjv65dzC4bVzH0zdNDU
From (redirected): https://drive.google.com/uc?id=1VT2XfBj9LFWLUBjv65dzC4bVzH0zdNDU&confirm=t&uuid=1e4c29b9-0ef5-4191-8e9a-ad1390da2b81
To: /content/random_simulation_dataset.json
100% 105M/105M [00:00<00:00, 146MB/s] 


## Helpers

In [None]:
class CameraMovementType(Enum):
    """Camera movements that a single instruction can request.

    Every member's value is identical to its name. The dataset loader looks
    members up by name from the JSON ``cameraMovement`` field, and the
    ``.value`` string is used as the key into ``movement_descriptions`` and
    the CLIP embedding table built from it.
    """
    # Pan / tilt
    panLeft = "panLeft"
    panRight = "panRight"
    tiltUp = "tiltUp"
    tiltDown = "tiltDown"
    # Dolly / truck / pedestal
    dollyIn = "dollyIn"
    dollyOut = "dollyOut"
    truckLeft = "truckLeft"
    truckRight = "truckRight"
    pedestalUp = "pedestalUp"
    pedestalDown = "pedestalDown"
    # Zooms (full / half / short)
    fullZoomIn = "fullZoomIn"
    fullZoomOut = "fullZoomOut"
    halfZoomIn = "halfZoomIn"
    halfZoomOut = "halfZoomOut"
    shortZoomIn = "shortZoomIn"
    shortZoomOut = "shortZoomOut"
    # Arc shots (short / half / full)
    shortArcShotLeft = "shortArcShotLeft"
    shortArcShotRight = "shortArcShotRight"
    halfArcShotLeft = "halfArcShotLeft"
    halfArcShotRight = "halfArcShotRight"
    fullArcShotLeft = "fullArcShotLeft"
    fullArcShotRight = "fullArcShotRight"
    # Combined movements
    panAndTilt = "panAndTilt"
    dollyAndPan = "dollyAndPan"
    zoomAndTruck = "zoomAndTruck"

class EasingType(Enum):
    """Easing curves that can be attached to a camera movement.

    Every member's value is identical to its name. The dataset loader looks
    members up by name from the JSON ``movementEasing`` field, and the
    ``.value`` string is used as the key into ``easing_descriptions`` and the
    CLIP embedding table built from it.
    """
    linear = "linear"
    # Polynomial families: quad, cubic, quart, quint (in / out / in-out)
    easeInQuad = "easeInQuad"
    easeInCubic = "easeInCubic"
    easeInQuart = "easeInQuart"
    easeInQuint = "easeInQuint"
    easeOutQuad = "easeOutQuad"
    easeOutCubic = "easeOutCubic"
    easeOutQuart = "easeOutQuart"
    easeOutQuint = "easeOutQuint"
    easeInOutQuad = "easeInOutQuad"
    easeInOutCubic = "easeInOutCubic"
    easeInOutQuart = "easeInOutQuart"
    easeInOutQuint = "easeInOutQuint"
    # Sine
    easeInSine = "easeInSine"
    easeOutSine = "easeOutSine"
    easeInOutSine = "easeInOutSine"
    # Exponential
    easeInExpo = "easeInExpo"
    easeOutExpo = "easeOutExpo"
    easeInOutExpo = "easeInOutExpo"
    # Circular
    easeInCirc = "easeInCirc"
    easeOutCirc = "easeOutCirc"
    easeInOutCirc = "easeInOutCirc"
    # Bounce
    easeInBounce = "easeInBounce"
    easeOutBounce = "easeOutBounce"
    easeInOutBounce = "easeInOutBounce"
    # Elastic
    easeInElastic = "easeInElastic"
    easeOutElastic = "easeOutElastic"
    easeInOutElastic = "easeInOutElastic"

class CameraAngle(Enum):
    """Initial camera angle of a shot.

    Values equal member names; the dataset loader falls back to
    ``mediumAngle`` when an instruction has no ``initialCameraAngle`` field.
    """
    lowAngle = "lowAngle"
    mediumAngle = "mediumAngle"
    highAngle = "highAngle"
    birdsEyeView = "birdsEyeView"

class ShotType(Enum):
    """Initial shot framing.

    Values equal member names; the dataset loader falls back to
    ``mediumShot`` when an instruction has no ``initialShotType`` field.
    """
    closeUp = "closeUp"
    mediumShot = "mediumShot"
    longShot = "longShot"


# English phrase for every CameraMovementType value. The phrases (not the
# keys) are what gets passed to get_clip_embedding; the keys index the
# resulting embedding table, so they must stay in sync with the enum values.
movement_descriptions = {
    "panLeft": "panning left",
    "panRight": "panning right",
    "tiltUp": "tilting up",
    "tiltDown": "tilting down",
    "dollyIn": "moving closer",
    "dollyOut": "moving away",
    "truckLeft": "moving left",
    "truckRight": "moving right",
    "pedestalUp": "rising vertically",
    "pedestalDown": "descending vertically",
    "fullZoomIn": "zooming in fully",
    "fullZoomOut": "zooming out fully",
    "halfZoomIn": "zooming in halfway",
    "halfZoomOut": "zooming out halfway",
    "shortZoomIn": "zooming in slightly",
    "shortZoomOut": "zooming out slightly",
    "shortArcShotLeft": "moving in a short arc to the left",
    "shortArcShotRight": "moving in a short arc to the right",
    "halfArcShotLeft": "moving in a half arc to the left",
    "halfArcShotRight": "moving in a half arc to the right",
    "fullArcShotLeft": "moving in a full arc to the left",
    "fullArcShotRight": "moving in a full arc to the right",
    "panAndTilt": "panning and tilting",
    "dollyAndPan": "moving and panning",
    "zoomAndTruck": "zooming and moving sideways",
}

# English phrase for every EasingType value. The phrases are what gets
# passed to get_clip_embedding; the keys index the resulting embedding table,
# so they must stay in sync with the enum values.
easing_descriptions = {
    "linear": "at a constant speed",
    "easeInQuad": "slowly at first, then accelerating gradually",
    "easeInCubic": "slowly at first, then accelerating more rapidly",
    "easeInQuart": "very slowly at first, then accelerating dramatically",
    "easeInQuint": "extremely slowly at first, then accelerating very dramatically",
    "easeOutQuad": "quickly at first, then decelerating gradually",
    "easeOutCubic": "quickly at first, then decelerating more rapidly",
    "easeOutQuart": "very quickly at first, then decelerating dramatically",
    "easeOutQuint": "extremely quickly at first, then decelerating very dramatically",
    "easeInOutQuad": "gradually accelerating, then gradually decelerating",
    "easeInOutCubic": "slowly accelerating, then decelerating more rapidly",
    "easeInOutQuart": "slowly accelerating, then decelerating dramatically",
    "easeInOutQuint": "very slowly accelerating, then decelerating very dramatically",
    "easeInSine": "with a gentle start, gradually increasing in speed",
    "easeOutSine": "quickly at first, then gently decelerating",
    "easeInOutSine": "with a gentle start and end, faster in the middle",
    "easeInExpo": "starting very slowly, then accelerating exponentially",
    "easeOutExpo": "starting very fast, then decelerating exponentially",
    "easeInOutExpo": "starting and ending slowly, with rapid acceleration and deceleration in the middle",
    "easeInCirc": "starting slowly, then accelerating sharply towards the end",
    "easeOutCirc": "starting quickly, then decelerating sharply towards the end",
    "easeInOutCirc": "with sharp acceleration and deceleration at both ends",
    "easeInBounce": "with a bouncing effect that intensifies towards the end",
    "easeOutBounce": "quickly at first, then bouncing to a stop",
    "easeInOutBounce": "with a bouncing effect at both the start and end",
    "easeInElastic": "with an elastic effect that intensifies towards the end",
    "easeOutElastic": "quickly at first, then oscillating to a stop",
    "easeInOutElastic": "with an elastic effect at both the start and end",
}

# English phrase for every CameraAngle value; the phrases are encoded with
# CLIP and the keys index the resulting embedding table.
angle_descriptions = {
    "lowAngle": "from a low angle",
    "mediumAngle": "from a medium angle",
    "highAngle": "from a high angle",
    "birdsEyeView": "from a bird's eye view",
}

# English phrase for every ShotType value; the phrases are encoded with CLIP
# and the keys index the resulting embedding table.
shot_descriptions = {
    "closeUp": "in a close-up shot",
    "mediumShot": "in a medium shot",
    "longShot": "in a long shot",
}

In [None]:
def get_clip_embedding(text: str) -> torch.Tensor:
    """Return the CLIP text encoder's pooled output for ``text``.

    Uses the module-level ``clip_tokenizer``, ``clip_text_encoder`` and
    ``device``, so those must be defined before the first call. The call runs
    under ``torch.no_grad()``, so the result carries no autograd history.

    Note: despite the ``str`` annotation, this notebook also calls the
    function with a list of strings; the tokenizer is invoked with
    ``padding=True`` so such a batch is tokenized in a single call.
    """
    with torch.no_grad():
        tokens = clip_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        tokens = tokens.to(device)
        encoded = clip_text_encoder(**tokens)
    return encoded.pooler_output

In [None]:
def extract_camera_frame_data(camera_frames):
    """Flatten camera-frame dicts into 7-element feature rows.

    Each row is laid out as
    ``[pos.x, pos.y, pos.z, focalLength, angle.x, angle.y, angle.z]``.
    Returns one row per frame, in input order.
    """
    rows = []
    for frame in camera_frames:
        position = frame['position']
        row = [position['x'], position['y'], position['z'], frame['focalLength']]
        rotation = frame['angle']
        row.extend(rotation[axis] for axis in ('x', 'y', 'z'))
        rows.append(row)
    return rows

In [None]:
def circular_distance_loss(pred, target):
    """Mean squared shortest angular distance between ``pred`` and ``target``.

    The difference is wrapped into ``[-pi, pi)`` with a modular reduction, so
    angles that differ by a multiple of ``2*pi`` count as equal. The previous
    implementation clamped both inputs to ``[-pi, pi]`` instead of wrapping
    them, which (a) produced a zero gradient for any prediction outside that
    range and (b) reported a large distance for equivalent angles. For inputs
    already inside ``[-pi, pi]`` the result is unchanged.

    Args:
        pred: Tensor of predicted angles (assumed to be radians -- TODO
            confirm the dataset's angle unit).
        target: Tensor of reference angles, broadcastable with ``pred``.

    Returns:
        Scalar tensor holding the mean of the squared wrapped differences.
    """
    full_turn = 2 * math.pi
    # Shift by pi before the modulo so the wrapped difference is centred on 0.
    wrapped = torch.remainder(pred - target + math.pi, full_turn) - math.pi
    return torch.mean(wrapped ** 2)

In [None]:
def cosine_decay(initial_value, final_value, current_epoch, total_epochs):
    """Interpolate from ``initial_value`` to ``final_value`` on a half cosine.

    Returns ``initial_value`` at epoch 0 and ``final_value`` when
    ``current_epoch == total_epochs``.
    """
    phase = math.pi * current_epoch / total_epochs
    blend = 0.5 * (1 + math.cos(phase))
    return final_value + (initial_value - final_value) * blend

def linear_increase(initial_value, final_value, current_epoch, total_epochs):
    """Linearly interpolate from ``initial_value`` to ``final_value``.

    The interpolation fraction is ``current_epoch / total_epochs``.
    """
    fraction = current_epoch / total_epochs
    span = final_value - initial_value
    return initial_value + span * fraction

def get_noise_and_mask_values(current_epoch, total_epochs, config):
    """Return the ``(noise_std, mask_ratio)`` pair scheduled for an epoch.

    The noise standard deviation follows ``cosine_decay`` between
    ``config['initial_noise_std']`` and ``config['final_noise_std']``; the
    mask ratio follows ``linear_increase`` between
    ``config['initial_mask_ratio']`` and ``config['final_mask_ratio']``.
    """
    noise_bounds = (config['initial_noise_std'], config['final_noise_std'])
    noise_std = cosine_decay(*noise_bounds, current_epoch, total_epochs)

    mask_bounds = (config['initial_mask_ratio'], config['final_mask_ratio'])
    mask_ratio = linear_increase(*mask_bounds, current_epoch, total_epochs)

    return noise_std, mask_ratio

def apply_mask_and_noise(data, mask_ratio=0.0, noise_std=0.0):
    """Randomly zero whole time steps of ``data`` and add Gaussian noise.

    Tensors are allocated on ``data.device`` rather than on the notebook's
    global ``device``, so the function works for inputs on any device and no
    longer depends on a module-level name.

    Args:
        data: Tensor indexed as ``(batch, step, feature)``; the mask is drawn
            over the first two dimensions.
        mask_ratio: Probability that a given time step is dropped.
        noise_std: Standard deviation of the additive zero-mean noise. The
            noise is added to every position, including dropped ones.

    Returns:
        ``(noisy_data, mask, src_key_padding_mask)`` where ``mask`` is ``True``
        for kept steps and ``src_key_padding_mask`` is its negation.
    """
    # float() keeps the probability tensor floating-point even when an
    # integer ratio (e.g. 0) is passed; torch.bernoulli rejects integer input.
    keep_probability = torch.full(
        (data.shape[0], data.shape[1]), float(1 - mask_ratio), device=data.device
    )
    mask = torch.bernoulli(keep_probability).bool()

    masked_data = data.clone()
    masked_data[~mask] = 0

    noise = torch.normal(mean=0, std=noise_std, size=data.shape, device=data.device)
    noisy_data = masked_data + noise

    src_key_padding_mask = ~mask

    return noisy_data, mask, src_key_padding_mask

In [None]:
def generate_random_instruction():
    """Draw one random member from each instruction enum.

    Returns a dict with the keys ``movement``, ``easing``, ``camera_angle``
    and ``shot_type``, each holding an enum member chosen with
    ``random.choice``.
    """
    enum_by_field = (
        ('movement', CameraMovementType),
        ('easing', EasingType),
        ('camera_angle', CameraAngle),
        ('shot_type', ShotType),
    )
    return {field: random.choice(list(enum_cls)) for field, enum_cls in enum_by_field}

def instruction_to_latent(instruction, model):
    """Turn an instruction dict into a merged latent via ``model.merge_latents``.

    Each enum member in ``instruction`` is mapped to its precomputed CLIP
    embedding (module-level ``*_clip_dict`` tables), given a leading
    dimension with ``unsqueeze(0)``, and the four embeddings are merged under
    ``torch.no_grad()``.
    """
    lookups = (
        (movement_clip_dict, 'movement'),
        (easing_clip_dict, 'easing'),
        (angle_clip_dict, 'camera_angle'),
        (shot_clip_dict, 'shot_type'),
    )
    clips = [table[instruction[field].value].unsqueeze(0) for table, field in lookups]

    with torch.no_grad():
        return model.merge_latents(*clips)

In [None]:
def process_camera_trajectory(trajectory: torch.Tensor) -> List[Dict]:
    """Convert a ``(frames, 7)`` trajectory tensor into camera-frame dicts.

    The column order mirrors ``extract_camera_frame_data``:
    ``[pos.x, pos.y, pos.z, focalLength, angle.x, angle.y, angle.z]``.
    The previous implementation read columns 3-5 as the angle and column 6 as
    the focal length, which does not match that layout and scrambled the
    exported focal length and angles.

    Args:
        trajectory: Tensor with one 7-feature row per frame; it is detached
            and moved to the CPU before conversion.

    Returns:
        One dict per frame with ``position`` and ``angle`` sub-dicts keyed by
        ``x``/``y``/``z`` and a ``focalLength`` float, all as plain Python
        floats so the result is JSON-serializable.
    """
    rows = trajectory.cpu().detach().numpy()
    axes = ('x', 'y', 'z')
    camera_frames = []
    for row in rows:
        position, focal_length, angle = row[:3], row[3], row[4:7]
        camera_frames.append({
            "position": {axis: float(value) for axis, value in zip(axes, position)},
            "angle": {axis: float(value) for axis, value in zip(axes, angle)},
            "focalLength": float(focal_length),
        })
    return camera_frames

def save_to_json(trajectory: torch.Tensor, filename: str):
    """Write ``trajectory`` to ``filename`` as an indented JSON list of frames.

    The tensor is first converted with ``process_camera_trajectory``; the
    target file is opened in text write mode and overwritten.
    """
    frames = process_camera_trajectory(trajectory)
    with open(filename, 'w') as out_file:
        json.dump(frames, out_file, indent=2)

# Implementation

In [None]:
clip_model_name = "openai/clip-vit-large-patch14" #@param ["openai/clip-vit-base-patch32", "openai/clip-vit-large-patch14"]

# Embedding width used to size the autoencoder's latent space: 768 for the
# large checkpoint, 512 for any other selection.
clip_embedding_dim = 768 if clip_model_name == "openai/clip-vit-large-patch14" else 512

# Prefer the GPU when one is available; the CLIP text encoder lives on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
clip_text_encoder = CLIPTextModel.from_pretrained(clip_model_name).to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# Encode every description phrase once, one batch per instruction field.
movement_clip_embeddings = get_clip_embedding(list(movement_descriptions.values()))
easing_clip_embeddings = get_clip_embedding(list(easing_descriptions.values()))
angle_clip_embeddings = get_clip_embedding(list(angle_descriptions.values()))
shot_clip_embeddings = get_clip_embedding(list(shot_descriptions.values()))

# Pair each description key with the embedding row at the same position;
# dict iteration order matches the order the phrases were encoded in.
movement_clip_dict = dict(zip(movement_descriptions, movement_clip_embeddings))
easing_clip_dict = dict(zip(easing_descriptions, easing_clip_embeddings))
angle_clip_dict = dict(zip(angle_descriptions, angle_clip_embeddings))
shot_clip_dict = dict(zip(shot_descriptions, shot_clip_embeddings))

Processing batches: 100%|██████████| 140/140 [00:29<00:00,  4.73it/s]


In [None]:
class SimulationDataset(Dataset):
    """In-memory dataset of single-instruction, 30-frame camera simulations.

    The JSON file is read once in ``__init__``; every entry of its
    ``simulations`` list that passes ``_is_simulation_valid`` is converted to
    tensors and enum members up front.
    """

    def __init__(self, json_file_path: str):
        """Load and preprocess all valid simulations from ``json_file_path``."""
        with open(json_file_path, 'r') as file:
            raw_data = json.load(file)
        self.simulation_data = [
            self._process_single_simulation(simulation)
            for simulation in raw_data['simulations']
            if self._is_simulation_valid(simulation)
        ]

    def __len__(self):
        return len(self.simulation_data)

    def __getitem__(self, index):
        return self.simulation_data[index]

    def _is_simulation_valid(self, simulation):
        """Return True when a simulation can be turned into a training sample.

        Requires exactly one instruction with ``frameCount == 30``, exactly 30
        camera frames, and at least one subject. The subject check prevents
        ``_process_single_simulation`` from raising ``IndexError`` (and
        aborting the whole load) on a simulation without subjects.
        """
        instructions = simulation['instructions']
        return (len(instructions) == 1 and
                instructions[0]['frameCount'] == 30 and
                len(simulation['cameraFrames']) == 30 and
                bool(simulation.get('subjects')))

    def _process_single_simulation(self, simulation):
        """Convert one raw simulation dict into a sample dict.

        Only the first instruction and the first subject are used. Returns a
        dict with a ``camera_trajectory`` float tensor (one 7-feature row per
        frame), a 9-element ``subject`` float tensor (position, size,
        rotation) and the four instruction enum members.
        """
        instruction = simulation['instructions'][0]
        subject = simulation['subjects'][0]

        camera_trajectory = extract_camera_frame_data(simulation['cameraFrames'])

        movement_type = CameraMovementType[instruction['cameraMovement']]
        easing_type = EasingType[instruction['movementEasing']]
        # Angle and shot type are optional in the JSON; use the medium defaults.
        camera_angle = CameraAngle[instruction.get('initialCameraAngle', 'mediumAngle')]
        shot_type = ShotType[instruction.get('initialShotType', 'mediumShot')]

        subject_data = [
            subject['position']['x'], subject['position']['y'], subject['position']['z'],
            subject['size']['x'], subject['size']['y'], subject['size']['z'],
            subject['rotation']['x'], subject['rotation']['y'], subject['rotation']['z']
        ]

        return {
            'camera_trajectory': torch.tensor(camera_trajectory, dtype=torch.float32),
            'subject': torch.tensor(subject_data, dtype=torch.float32),
            'movement_type': movement_type,
            'easing_type': easing_type,
            'camera_angle': camera_angle,
            'shot_type': shot_type
        }

def batch_collate(batch):
    """Collate dataset samples into one batch dict.

    Tensor fields (``camera_trajectory``, ``subject``) are stacked along a
    new leading dimension; the enum fields are gathered into plain lists in
    sample order.
    """
    tensor_fields = ('camera_trajectory', 'subject')
    label_fields = ('movement_type', 'easing_type', 'camera_angle', 'shot_type')

    collated = {
        field: torch.stack([sample[field] for sample in batch])
        for field in tensor_fields
    }
    for field in label_fields:
        collated[field] = [sample[field] for sample in batch]
    return collated

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes to a batch-first sequence.

    A ``(1, max_len, d_model)`` table is precomputed and registered as the
    non-trainable buffer ``pe``: even feature indices hold sines, odd indices
    hold cosines. ``forward`` adds the first ``x.size(1)`` rows to ``x`` and
    applies dropout.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len).unsqueeze(1).float()
        frequencies = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * frequencies

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Leading singleton dimension lets the table broadcast over the batch.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        seq_len = x.size(1)
        return self.dropout(x + self.pe[:, :seq_len])

In [None]:
class Encoder(nn.Module):
    """Transformer encoder that summarizes a trajectory into four query slots.

    Four learned query tokens (movement, easing, camera_angle, shot_type) are
    prepended to the projected subject and trajectory tokens; after encoding,
    only the four query positions are returned.
    """

    def __init__(self, input_dim, latent_dim, nhead, num_encoder_layers, dim_feedforward, dropout_rate):
        super(Encoder, self).__init__()

        self.input_projection = nn.Linear(input_dim, latent_dim)
        self.pos_encoder = PositionalEncoding(latent_dim)

        encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        self.query_tokens = nn.ParameterDict({
            f"{qt}_query": nn.Parameter(torch.randn(1, 1, latent_dim))
            for qt in ['movement', 'easing', 'camera_angle', 'shot_type']
        })

    def forward(self, src, subject_embedded, src_key_padding_mask=None):
        """Encode a trajectory and return the encoded query tokens.

        Args:
            src: Batch-first trajectory tensor ``(batch, steps, input_dim)``.
            subject_embedded: Projected subject, ``(batch, 1, latent_dim)``;
                it is placed in front of the trajectory tokens.
            src_key_padding_mask: Optional bool tensor ``(batch, steps)``,
                ``True`` for steps the encoder should ignore.

        Returns:
            Tensor of shape ``(num_queries, batch, latent_dim)``.
        """
        src_embedded = self.input_projection(src)
        src_embedded = torch.cat([subject_embedded, src_embedded], dim=1)
        src_embedded = self.pos_encoder(src_embedded)
        # The transformer layers are sequence-first.
        src_embedded = src_embedded.permute(1, 0, 2)

        batch_size = src_embedded.shape[1]
        num_queries = len(self.query_tokens)

        query_tokens = torch.cat([query.repeat(1, batch_size, 1) for query in self.query_tokens.values()], dim=0)
        src_with_queries = torch.cat([query_tokens, src_embedded], dim=0)

        if src_key_padding_mask is not None:
            # Query tokens and the subject token are never masked. Allocate on
            # the incoming mask's device instead of a module-level `device`
            # global so the module works wherever its inputs live.
            always_visible = torch.zeros(
                (src_key_padding_mask.shape[0], num_queries + 1),
                dtype=torch.bool,
                device=src_key_padding_mask.device,
            )
            src_key_padding_mask = torch.cat([always_visible, src_key_padding_mask], dim=1)

        memory = self.transformer_encoder(src_with_queries, src_key_padding_mask=src_key_padding_mask)

        return memory[:num_queries]



class Decoder(nn.Module):
    """Transformer decoder that maps previous frames to predicted frames.

    ``seq_length`` is accepted for signature compatibility but is not used by
    this class.
    """

    def __init__(self, output_dim, latent_dim, nhead, num_decoder_layers, dim_feedforward, dropout_rate, seq_length):
        super().__init__()

        self.pos_encoder = PositionalEncoding(latent_dim)
        self.embedding = nn.Linear(output_dim, latent_dim)

        layer = nn.TransformerDecoderLayer(d_model=latent_dim, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate)
        self.transformer_decoder = nn.TransformerDecoder(layer, num_layers=num_decoder_layers)

        self.output_projection = nn.Linear(latent_dim, output_dim)

    def forward(self, memory, decoder_input, subject_embedded, tgt_mask):
        """Decode ``decoder_input`` against ``memory``.

        The projected subject is prepended as the first token, the sequence
        is position-encoded and decoded sequence-first, and the subject
        position is dropped again before the output projection.
        """
        tokens = torch.cat([subject_embedded, self.embedding(decoder_input)], dim=1)
        tokens = self.pos_encoder(tokens).transpose(0, 1)

        decoded = self.transformer_decoder(tokens, memory, tgt_mask=tgt_mask).transpose(0, 1)

        # Skip index 0: that position belongs to the subject token.
        return self.output_projection(decoded[:, 1:, :])

class MultiTaskAutoencoder(nn.Module):
    """Trajectory autoencoder with four instruction-aligned latent slots.

    The encoder produces one embedding per instruction field (movement,
    easing, camera angle, shot type); ``merge_latents`` fuses them into a
    single latent, and ``autoregressive_decode`` generates ``seq_length``
    frames from that latent and a projected subject.
    """

    def __init__(self, input_dim=7, subject_dim=9, nhead=4, num_encoder_layers=3, num_decoder_layers=3, dim_feedforward=2048,
                 dropout_rate=0.1, seq_length=30, latent_dim=512):
        super(MultiTaskAutoencoder, self).__init__()

        self.subject_projection = nn.Linear(subject_dim, latent_dim)
        self.encoder = Encoder(input_dim, latent_dim, nhead, num_encoder_layers, dim_feedforward, dropout_rate)
        self.decoder = Decoder(input_dim, latent_dim, nhead, num_decoder_layers, dim_feedforward, dropout_rate, seq_length)

        self.latent_merger = nn.Linear(latent_dim * 4, latent_dim)

        self.seq_length = seq_length
        self.input_dim = input_dim

    def generate_square_subsequent_mask(self, sz):
        """Return an ``(sz, sz)`` causal mask: 0 on/below the diagonal, -inf above."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def autoregressive_decode(self, latent, subject_embedded, target=None, teacher_forcing_ratio=0.5):
        """Generate ``seq_length`` frames one step at a time.

        Args:
            latent: Merged latent, ``(batch, latent_dim)``; it is repeated
                ``seq_length`` times to form the decoder memory.
            subject_embedded: Projected subject, ``(batch, 1, latent_dim)``.
            target: Optional ground-truth frames ``(batch, seq_length,
                input_dim)`` used for teacher forcing.
            teacher_forcing_ratio: Per-step probability of feeding the
                ground-truth frame instead of the model's own prediction
                (only consulted when ``target`` is given).

        Returns:
            Tensor ``(batch, seq_length, input_dim)`` of predicted frames.
        """
        # Work on the latent's device rather than a module-level `device`
        # global so the model runs wherever its inputs are.
        run_device = latent.device

        memory = latent.unsqueeze(1).repeat(1, self.seq_length, 1)
        memory = memory.transpose(0, 1)

        # All-zero start frame.
        decoder_input = torch.zeros(latent.shape[0], 1, self.input_dim, device=run_device)
        outputs = []

        for t in range(self.seq_length):
            # t + 1 frame tokens so far, plus the subject token the decoder prepends.
            tgt_mask = self.generate_square_subsequent_mask(t + 2).to(run_device)

            output = self.decoder(memory, decoder_input, subject_embedded, tgt_mask)
            outputs.append(output[:, -1:, :])

            if target is not None and torch.rand(1).item() < teacher_forcing_ratio:
                decoder_input = torch.cat([decoder_input, target[:, t:t+1, :]], dim=1)
            else:
                decoder_input = torch.cat([decoder_input, output[:, -1:, :]], dim=1)

        return torch.cat(outputs, dim=1)

    def merge_latents(self, movement_embedding, easing_embedding, camera_angle_embedding, shot_type_embedding):
        """Concatenate the four embeddings on the last dim and project to ``latent_dim``."""
        combined = torch.cat([movement_embedding, easing_embedding, camera_angle_embedding, shot_type_embedding], dim=-1)
        return self.latent_merger(combined)

    def forward(self, src, subject, src_key_padding_mask=None, target=None, teacher_forcing_ratio=0.5):
        """Encode ``src``, merge the four embeddings and reconstruct the trajectory.

        Returns a dict with the four per-field embeddings and the
        ``reconstructed`` trajectory.
        """
        subject_embedded = self.subject_projection(subject).unsqueeze(1)
        movement_embedding, easing_embedding, camera_angle_embedding, shot_type_embedding = self.encoder(src, subject_embedded, src_key_padding_mask)
        latent = self.merge_latents(movement_embedding, easing_embedding, camera_angle_embedding, shot_type_embedding)
        reconstructed = self.autoregressive_decode(latent, subject_embedded, target, teacher_forcing_ratio)

        return {
            'movement_embedding': movement_embedding,
            'easing_embedding': easing_embedding,
            'camera_angle_embedding': camera_angle_embedding,
            'shot_type_embedding': shot_type_embedding,
            'reconstructed': reconstructed,
        }

In [None]:
def init_losses():
    """Return a fresh accumulator dict with every tracked loss set to 0."""
    loss_names = (
        'total',
        'reconstruction',
        'clip_movement',
        'clip_easing',
        'clip_camera_angle',
        'clip_shot_type',
    )
    return dict.fromkeys(loss_names, 0)

def print_detailed_losses(phase, losses):
    """Print all entries of ``losses`` on one line, prefixed with ``phase``.

    Each value is formatted with four decimal places.
    """
    entries = (f"  {name}: {value:.4f}" for name, value in losses.items())
    # The default single-space separator plus the trailing " \n" reproduces
    # the spacing of printing every entry with end=' ' followed by a newline.
    print(f"{phase} Losses:", *entries, end=' \n')

In [None]:
def train_epoch(model, dataloader, optimizer, criterion, noise_std=0.0, mask_ratio=0.0, teacher_forcing_ratio=0.5):
    """Run one optimization pass over ``dataloader``.

    Gradients are cleared at the start of every step instead of after
    ``optimizer.step()``. Clearing afterwards leaves any gradients that exist
    when the function is entered (for example from a training run interrupted
    between ``backward()`` and ``zero_grad()``) to be silently added to the
    first update.

    Args:
        model: Model to train; switched to training mode.
        dataloader: Iterable of batches accepted by ``process_batch``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Dict of loss modules forwarded to ``process_batch``.
        noise_std: Input noise standard deviation for this epoch.
        mask_ratio: Input masking probability for this epoch.
        teacher_forcing_ratio: Teacher-forcing probability for this epoch.

    Returns:
        Dict mapping each tracked loss name to its mean over the batches.
    """
    model.train()
    total_losses = init_losses()

    for batch in tqdm(dataloader, desc="Training"):
        optimizer.zero_grad()

        losses = process_batch(model, batch, criterion, noise_std, mask_ratio, teacher_forcing_ratio)

        losses['total'].backward()
        optimizer.step()

        for key in total_losses:
            value = losses[key]
            total_losses[key] += value.item() if torch.is_tensor(value) else value

    return {k: v / len(dataloader) for k, v in total_losses.items()}

def validate(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    The model is put in eval mode and batches are processed with teacher
    forcing disabled and the default (zero) noise and masking. Returns a dict
    mapping each tracked loss name to its mean over the batches.
    """
    model.eval()
    running = init_losses()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            batch_losses = process_batch(model, batch, criterion, teacher_forcing_ratio=0)

            for name in running:
                value = batch_losses[name]
                running[name] += value.item() if torch.is_tensor(value) else value

    return {name: total / len(dataloader) for name, total in running.items()}


def process_batch(model, batch, criterion, noise_std=0.0, mask_ratio=0.0, teacher_forcing_ratio=0.5):
    """Run the model on one batch and compute all training losses.

    The reconstruction loss is split by feature type. Rows follow the layout
    produced by ``extract_camera_frame_data``
    (``[x, y, z, focalLength, angle.x, angle.y, angle.z]``), so the first four
    columns use the regular reconstruction criterion and the last three use
    ``circular_distance_loss``. The previous split at column 5 put ``angle.x``
    into the non-circular term.

    Args:
        model: The autoencoder; teacher-forcing targets are only supplied
            while ``model.training`` is true.
        batch: Dict produced by ``batch_collate``.
        criterion: Dict with a ``'reconstruction'`` loss module.
        noise_std: Standard deviation of the noise added to the input.
        mask_ratio: Probability of dropping each input time step.
        teacher_forcing_ratio: Forwarded to the model.

    Returns:
        Dict of loss tensors: ``reconstruction``, the four ``clip_*`` cosine
        losses, and ``total`` (their sum).
    """
    # Number of leading non-angular columns: x, y, z and focal length.
    num_linear_features = 4

    camera_trajectory = batch['camera_trajectory'].to(device)
    subject = batch['subject'].to(device)
    movement_clip = torch.stack([movement_clip_dict[t.value] for t in batch['movement_type']]).to(device)
    easing_clip = torch.stack([easing_clip_dict[t.value] for t in batch['easing_type']]).to(device)
    angle_clip = torch.stack([angle_clip_dict[t.value] for t in batch['camera_angle']]).to(device)
    shot_clip = torch.stack([shot_clip_dict[t.value] for t in batch['shot_type']]).to(device)

    noisy_trajectory, mask, src_key_padding_mask = apply_mask_and_noise(camera_trajectory, mask_ratio, noise_std)

    target = camera_trajectory if model.training else None

    output = model(noisy_trajectory, subject, src_key_padding_mask, target, teacher_forcing_ratio)

    # NOTE(review): `mask` is True for the steps that were KEPT, so the
    # reconstruction loss below covers only the visible steps, not the
    # dropped ones -- confirm this is the intended objective.
    kept_output = output['reconstructed'][mask]
    kept_target = camera_trajectory[mask]

    linear_output = kept_output[:, :num_linear_features]
    linear_target = kept_target[:, :num_linear_features]
    angle_output = kept_output[:, num_linear_features:]
    angle_target = kept_target[:, num_linear_features:]

    linear_loss = criterion['reconstruction'](linear_output, linear_target)
    angle_loss = circular_distance_loss(angle_output, angle_target)

    losses = {
        'reconstruction': linear_loss + angle_loss,
        'clip_movement': 1 - F.cosine_similarity(output['movement_embedding'], movement_clip).mean(),
        'clip_easing': 1 - F.cosine_similarity(output['easing_embedding'], easing_clip).mean(),
        'clip_camera_angle': 1 - F.cosine_similarity(output['camera_angle_embedding'], angle_clip).mean(),
        'clip_shot_type': 1 - F.cosine_similarity(output['shot_type_embedding'], shot_clip).mean(),
    }
    losses['total'] = sum(losses.values())

    return losses

def train_model(model, train_dataloader, val_dataloader, config):
    """Train ``model`` with scheduled augmentation and early stopping.

    Per epoch, the teacher-forcing ratio decays linearly from
    ``config['init_teacher_forcing_ratio']`` towards 0 and the noise/mask
    values come from ``get_noise_and_mask_values``. Whenever the total
    validation loss improves, the weights are written to ``best_model.pth``;
    training stops once ``config['patience']`` consecutive epochs bring no
    improvement.
    """
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
    criterion = {'reconstruction': nn.MSELoss()}

    best_val_loss = float('inf')
    stale_epochs = 0

    for epoch in range(config['num_epochs']):
        num_epochs = config['num_epochs']
        teacher_forcing = config['init_teacher_forcing_ratio'] * (1 - epoch / num_epochs)
        noise_std, mask_ratio = get_noise_and_mask_values(epoch, num_epochs, config)

        train_losses = train_epoch(model, train_dataloader, optimizer, criterion,
                                   noise_std, mask_ratio, teacher_forcing)
        val_losses = validate(model, val_dataloader, criterion)

        print(f"Epoch {epoch + 1}/{num_epochs}")
        print(f"Train Loss: {train_losses['total']:.4f}, Validation Loss: {val_losses['total']:.4f}")
        print_detailed_losses("Train", train_losses)
        print_detailed_losses("Valid", val_losses)

        improved = val_losses['total'] < best_val_loss
        if improved:
            best_val_loss = val_losses['total']
            stale_epochs = 0
            torch.save(model.state_dict(), 'best_model.pth')
            print("New best model saved!")
        else:
            stale_epochs += 1

        # Checked after every epoch, including improving ones.
        if stale_epochs >= config['patience']:
            print(f"Early stopping triggered after {epoch + 1} epochs")
            break

In [None]:
# Re-derive the device so this cell can run on its own.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Relative path: resolved against the current working directory.
dataset = SimulationDataset('random_simulation_dataset.json')

# 70/30 split with a fixed seed so the split is reproducible.
train_dataset, val_dataset = train_test_split(dataset, test_size=0.3, random_state=42)
# Alternative kept for debugging (presumably a single-sample overfitting check):
# train_dataset = val_dataset = [dataset[0]] * 32

batch_size = 32
num_workers = 2

# Only the training loader shuffles; both use the custom collate function
# because samples carry enum members next to tensors.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=batch_collate, num_workers=num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=batch_collate, num_workers=num_workers)

# The latent width matches the CLIP embedding width so encoder outputs can be
# compared to CLIP embeddings with cosine similarity.
model = MultiTaskAutoencoder(latent_dim=clip_embedding_dim, dropout_rate=0.2).to(device)

In [None]:
# Training hyperparameters consumed by train_model and
# get_noise_and_mask_values.
config = {
    'num_epochs': 2000,                 # upper bound; early stopping usually ends sooner
    'patience': 30,                     # epochs without validation improvement before stopping
    'learning_rate': 0.0001,
    'weight_decay': 1e-5,
    'initial_noise_std': 0.2,           # noise std at epoch 0 (cosine-decayed)
    'final_noise_std': 0.05,            # noise std reached at num_epochs
    'initial_mask_ratio': 0.3,          # mask ratio at epoch 0 (linearly increased)
    'final_mask_ratio': 0.7,            # mask ratio reached at num_epochs
    'init_teacher_forcing_ratio': 0.8   # decays linearly towards 0 over num_epochs
}

# Release cached GPU memory held by PyTorch's allocator before training starts.
torch.cuda.empty_cache()
train_model(model, train_dataloader, val_dataloader, config)

## Inference

In [None]:
# Restore the best checkpoint written by train_model. map_location remaps the
# stored tensors onto the current device, so a checkpoint saved on a GPU
# runtime also loads on a CPU-only runtime.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()
# Ends the cell with a None-valued expression so the module repr returned by
# model.eval() is not echoed as cell output.
print()

In [None]:
file_path = '/content/result.json'

# random_instruction = generate_random_instruction()
# print("Random Instruction:", random_instruction)
instruction = {
    'movement': CameraMovementType['dollyIn'],
    'easing': EasingType['easeInQuad'],
    'camera_angle': CameraAngle['highAngle'],
    'shot_type': ShotType['longShot']
}

latent = instruction_to_latent(instruction, model)

# autoregressive_decode requires a projected subject in addition to the
# latent (calling it with the latent alone raises TypeError). The first
# dataset sample's subject is used here; swap in any 9-value subject tensor.
subject = dataset[0]['subject'].unsqueeze(0).to(device)

with torch.no_grad():
    subject_embedded = model.subject_projection(subject).unsqueeze(1)
    camera_frames = model.autoregressive_decode(latent, subject_embedded).squeeze(0)
save_to_json(camera_frames, file_path)

files.download(file_path)

In [None]:
input_path = '/content/input.json'
output_path = '/content/output.json'

sample = dataset[0]
input_camera_tranjectory = sample['camera_trajectory'].to(device)
# The model's forward pass requires the subject as its second argument;
# calling it with the trajectory alone raises TypeError.
sample_subject = sample['subject'].to(device)

# Inference only: skip autograd bookkeeping. No target is passed, so the
# decoder always feeds back its own predictions.
with torch.no_grad():
    reconstructed_camera_tranjectory = model(
        input_camera_tranjectory.unsqueeze(0),
        sample_subject.unsqueeze(0),
    )['reconstructed'].squeeze(0)

save_to_json(input_camera_tranjectory, input_path)
save_to_json(reconstructed_camera_tranjectory, output_path)

files.download(input_path)
files.download(output_path)