<a href="https://colab.research.google.com/github/ZahraDehghanian97/LenseCraft/blob/master/Cinematography_Instruction_Analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports and Helper Functions

## Imports

In [None]:
import json
import math
from enum import Enum
from typing import List, Dict, Any

from google.colab import files

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [None]:
!gdown -O /content/random_simulation_dataset.json 1VT2XfBj9LFWLUBjv65dzC4bVzH0zdNDU

Downloading...
From (original): https://drive.google.com/uc?id=1VT2XfBj9LFWLUBjv65dzC4bVzH0zdNDU
From (redirected): https://drive.google.com/uc?id=1VT2XfBj9LFWLUBjv65dzC4bVzH0zdNDU&confirm=t&uuid=1e4c29b9-0ef5-4191-8e9a-ad1390da2b81
To: /content/random_simulation_dataset.json
100% 105M/105M [00:00<00:00, 146MB/s] 


## Helpers

In [None]:
class CameraMovementType(Enum):
    """Camera movement names that an instruction's 'cameraMovement' field may carry.

    Every member's value is identical to its name. Declaration order matters:
    SimulationDataset uses a member's position in ``list(CameraMovementType)``
    as the integer movement class label, so reordering members changes labels.
    """
    panLeft = "panLeft"
    panRight = "panRight"
    tiltUp = "tiltUp"
    tiltDown = "tiltDown"
    dollyIn = "dollyIn"
    dollyOut = "dollyOut"
    truckLeft = "truckLeft"
    truckRight = "truckRight"
    pedestalUp = "pedestalUp"
    pedestalDown = "pedestalDown"
    fullZoomIn = "fullZoomIn"
    fullZoomOut = "fullZoomOut"
    halfZoomIn = "halfZoomIn"
    halfZoomOut = "halfZoomOut"
    shortZoomIn = "shortZoomIn"
    shortZoomOut = "shortZoomOut"
    shortArcShotLeft = "shortArcShotLeft"
    shortArcShotRight = "shortArcShotRight"
    halfArcShotLeft = "halfArcShotLeft"
    halfArcShotRight = "halfArcShotRight"
    fullArcShotLeft = "fullArcShotLeft"
    fullArcShotRight = "fullArcShotRight"
    panAndTilt = "panAndTilt"
    dollyAndPan = "dollyAndPan"
    zoomAndTruck = "zoomAndTruck"

class EasingType(Enum):
    """Easing names that an instruction's 'movementEasing' field may carry.

    Every member's value is identical to its name. Declaration order matters:
    SimulationDataset uses a member's position in ``list(EasingType)`` as the
    integer easing class label, so reordering members changes labels.
    """
    linear = "linear"
    easeInQuad = "easeInQuad"
    easeInCubic = "easeInCubic"
    easeInQuart = "easeInQuart"
    easeInQuint = "easeInQuint"
    easeOutQuad = "easeOutQuad"
    easeOutCubic = "easeOutCubic"
    easeOutQuart = "easeOutQuart"
    easeOutQuint = "easeOutQuint"
    easeInOutQuad = "easeInOutQuad"
    easeInOutCubic = "easeInOutCubic"
    easeInOutQuart = "easeInOutQuart"
    easeInOutQuint = "easeInOutQuint"
    easeInSine = "easeInSine"
    easeOutSine = "easeOutSine"
    easeInOutSine = "easeInOutSine"
    easeInExpo = "easeInExpo"
    easeOutExpo = "easeOutExpo"
    easeInOutExpo = "easeInOutExpo"
    easeInCirc = "easeInCirc"
    easeOutCirc = "easeOutCirc"
    easeInOutCirc = "easeInOutCirc"
    easeInBounce = "easeInBounce"
    easeOutBounce = "easeOutBounce"
    easeInOutBounce = "easeInOutBounce"
    easeInElastic = "easeInElastic"
    easeOutElastic = "easeOutElastic"
    easeInOutElastic = "easeInOutElastic"

class CameraAngle(Enum):
    """Camera angle names; looked up from an instruction's optional 'initialCameraAngle' field.

    Every member's value is identical to its name.
    """
    lowAngle = "lowAngle"
    mediumAngle = "mediumAngle"
    highAngle = "highAngle"
    birdsEyeView = "birdsEyeView"

class ShotType(Enum):
    """Shot type names; looked up from an instruction's optional 'initialShotType' field.

    Every member's value is identical to its name.
    """
    closeUp = "closeUp"
    mediumShot = "mediumShot"
    longShot = "longShot"


# Phrase for each camera movement name. get_movement_description places it
# right after "The camera is "; names missing from this table are used verbatim.
movement_descriptions = {
    "panLeft": "panning left",
    "panRight": "panning right",
    "tiltUp": "tilting up",
    "tiltDown": "tilting down",
    "dollyIn": "moving closer",
    "dollyOut": "moving away",
    "truckLeft": "moving left",
    "truckRight": "moving right",
    "pedestalUp": "rising vertically",
    "pedestalDown": "descending vertically",
    "fullZoomIn": "zooming in fully",
    "fullZoomOut": "zooming out fully",
    "halfZoomIn": "zooming in halfway",
    "halfZoomOut": "zooming out halfway",
    "shortZoomIn": "zooming in slightly",
    "shortZoomOut": "zooming out slightly",
    "shortArcShotLeft": "moving in a short arc to the left",
    "shortArcShotRight": "moving in a short arc to the right",
    "halfArcShotLeft": "moving in a half arc to the left",
    "halfArcShotRight": "moving in a half arc to the right",
    "fullArcShotLeft": "moving in a full arc to the left",
    "fullArcShotRight": "moving in a full arc to the right",
    "panAndTilt": "panning and tilting",
    "dollyAndPan": "moving and panning",
    "zoomAndTruck": "zooming and moving sideways",
}

# Phrase for each easing name, appended after the movement phrase by
# get_movement_description. An easing missing from this table is rendered
# as "with <name> easing" instead.
easing_descriptions = {
    "linear": "at a constant speed",
    "easeInQuad": "slowly at first, then accelerating gradually",
    "easeInCubic": "slowly at first, then accelerating more rapidly",
    "easeInQuart": "very slowly at first, then accelerating dramatically",
    "easeInQuint": "extremely slowly at first, then accelerating very dramatically",
    "easeOutQuad": "quickly at first, then decelerating gradually",
    "easeOutCubic": "quickly at first, then decelerating more rapidly",
    "easeOutQuart": "very quickly at first, then decelerating dramatically",
    "easeOutQuint": "extremely quickly at first, then decelerating very dramatically",
    "easeInOutQuad": "gradually accelerating, then gradually decelerating",
    "easeInOutCubic": "slowly accelerating, then decelerating more rapidly",
    "easeInOutQuart": "slowly accelerating, then decelerating dramatically",
    "easeInOutQuint": "very slowly accelerating, then decelerating very dramatically",
    "easeInSine": "with a gentle start, gradually increasing in speed",
    "easeOutSine": "quickly at first, then gently decelerating",
    "easeInOutSine": "with a gentle start and end, faster in the middle",
    "easeInExpo": "starting very slowly, then accelerating exponentially",
    "easeOutExpo": "starting very fast, then decelerating exponentially",
    "easeInOutExpo": "starting and ending slowly, with rapid acceleration and deceleration in the middle",
    "easeInCirc": "starting slowly, then accelerating sharply towards the end",
    "easeOutCirc": "starting quickly, then decelerating sharply towards the end",
    "easeInOutCirc": "with sharp acceleration and deceleration at both ends",
    "easeInBounce": "with a bouncing effect that intensifies towards the end",
    "easeOutBounce": "quickly at first, then bouncing to a stop",
    "easeInOutBounce": "with a bouncing effect at both the start and end",
    "easeInElastic": "with an elastic effect that intensifies towards the end",
    "easeOutElastic": "quickly at first, then oscillating to a stop",
    "easeInOutElastic": "with an elastic effect at both the start and end",
}

# Phrase for each camera angle name; get_movement_description appends it
# after a comma when a camera angle is supplied.
angle_descriptions = {
    "lowAngle": "from a low angle",
    "mediumAngle": "from a medium angle",
    "highAngle": "from a high angle",
    "birdsEyeView": "from a bird's eye view",
}

# Phrase for each shot type name; get_movement_description appends it
# after a space when a shot type is supplied.
shot_descriptions = {
    "closeUp": "in a close-up shot",
    "mediumShot": "in a medium shot",
    "longShot": "in a long shot",
}

In [None]:
def get_clip_embedding(text: str) -> torch.Tensor:
    """Return the CLIP text encoder's pooled output for ``text``, computed without gradients.

    Relies on the module-level ``clip_tokenizer``, ``clip_text_encoder`` and
    ``device``. The tokenized input is padded, truncated and moved to
    ``device`` before encoding.

    NOTE(review): the feature-extraction cell passes a list of strings here
    rather than a single ``str`` -- confirm and widen the annotation if so.
    """
    tokenized = clip_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        encoder_output = clip_text_encoder(**tokenized)
    return encoder_output.pooler_output

In [None]:
def get_movement_description(movement, easing, camera_angle=None, shot_type=None):
    """Compose the sentence describing one camera instruction.

    The sentence is "The camera is <movement phrase> <easing phrase>",
    optionally followed by ", <angle phrase>" and " <shot phrase>". Phrases
    come from the module-level description tables; a movement, angle or shot
    missing from its table is inserted verbatim, and an unknown easing is
    rendered as "with <easing> easing". Falsy ``camera_angle`` / ``shot_type``
    values are left out.
    """
    movement_phrase = movement_descriptions.get(movement, movement)

    if easing in easing_descriptions:
        easing_phrase = easing_descriptions[easing]
    else:
        easing_phrase = f"with {easing} easing"

    pieces = [f"The camera is {movement_phrase}", f" {easing_phrase}"]

    if camera_angle:
        pieces.append(f", {angle_descriptions.get(camera_angle, camera_angle)}")

    if shot_type:
        pieces.append(f" {shot_descriptions.get(shot_type, shot_type)}")

    return "".join(pieces)

def get_all_movement_descriptions():
    """List the description of every movement/easing pair with every optional angle and shot.

    For each (movement, easing) pair the order is: no angle (without a shot,
    then with each shot type), followed by each camera angle in turn (without
    a shot, then with each shot type). Positions in the returned list are
    what get_movement_index reports, so the order must stay stable.
    """
    # None stands for "not specified" and comes first, matching the order above.
    angle_choices = [None] + [angle.value for angle in CameraAngle]
    shot_choices = [None] + [shot.value for shot in ShotType]

    return [
        get_movement_description(movement.value, easing.value, camera_angle=angle_choice, shot_type=shot_choice)
        for movement in CameraMovementType
        for easing in EasingType
        for angle_choice in angle_choices
        for shot_choice in shot_choices
    ]

def get_movement_index(movement, easing, camera_angle=None, shot_type=None):
    """Return the position of this instruction's description in ``all_movement_descriptions``.

    Raises ValueError (from ``list.index``) when the composed description is
    not in that module-level list.
    """
    return all_movement_descriptions.index(
        get_movement_description(movement, easing, camera_angle, shot_type)
    )

In [None]:
def get_negative_indices(positive_indices, num_features, num_negatives):
    """Sample, for each row, ``num_negatives`` feature indices that differ from that row's positive.

    Args:
        positive_indices: 1-D integer tensor holding one positive feature
            index per batch row; values are expected to lie in
            ``[0, num_features)`` (not validated here).
        num_features: total number of candidate features; must be at least 2
            so that a non-positive index exists.
        num_negatives: number of negatives drawn per row (with replacement).

    Returns:
        Integer tensor of shape ``(batch, num_negatives)`` on the same device
        as ``positive_indices``; no entry equals its row's positive index.
    """
    batch_size = positive_indices.size(0)
    # Draw from the num_features - 1 slots left once the positive is removed.
    indices = torch.randint(0, num_features - 1, (batch_size, num_negatives), device=positive_indices.device)
    # Slots at or past the positive map to the next feature index. This skips the
    # positive without materializing a (batch, num_features) index grid and mask,
    # and yields the same values as gathering from the filtered grid.
    shift = (indices >= positive_indices.unsqueeze(1)).to(indices.dtype)
    return indices + shift


def get_text_features(clip_text_features, positive_indices, negative_indices):
    """Gather each row's positive feature followed by its negative features.

    Returns a tensor of shape (batch, 1 + negatives per row, feature_dim)
    whose slot 0 along dim 1 is the positive feature.
    """
    num_rows = len(positive_indices)
    feature_dim = clip_text_features.size(1)

    positives = clip_text_features[positive_indices][:, None]
    flat_negatives = clip_text_features.index_select(0, negative_indices.view(-1))
    negatives = flat_negatives.view(num_rows, -1, feature_dim)

    return torch.cat((positives, negatives), dim=1)

def compute_similarities(text_features, latent, temperature):
    """Dot each row's candidate text features with that row's latent, divided by ``temperature``.

    ``text_features`` is (batch, candidates, dim) and ``latent`` is
    (batch, dim); the result is (batch, candidates).
    """
    latent_column = latent.unsqueeze(-1)
    scores = torch.bmm(text_features, latent_column).squeeze(-1)
    return scores / temperature

def contrastive_loss(latent, clip_text_features, positive_indices, temperature=0.3, num_negatives=50):
    """Cross-entropy loss that ranks each latent's positive text feature above sampled negatives.

    Both the latents and the text features are L2-normalized along dim 1.
    For every row, ``num_negatives`` other text features are sampled and the
    temperature-scaled similarities are scored with the positive in slot 0.
    """
    unit_latent = F.normalize(latent, p=2, dim=1)
    unit_text = F.normalize(clip_text_features, p=2, dim=1)

    sampled_negatives = get_negative_indices(positive_indices, unit_text.size(0), num_negatives)
    candidates = get_text_features(unit_text, positive_indices, sampled_negatives)
    logits = compute_similarities(candidates, unit_latent, temperature)

    # get_text_features puts the positive first, so the correct class is always 0.
    positive_slot = torch.zeros(unit_latent.size(0), dtype=torch.long, device=unit_latent.device)
    return F.cross_entropy(logits, positive_slot)

In [None]:
def normalize_camera_trajectory(camera_frames, subject_center, subject_area):
    """Convert camera frames into 7-value rows expressed relative to the subject.

    Each row is ``[dx, dy, dz, 0, 0, 0, focalLength]`` where (dx, dy, dz) is
    the frame position minus ``subject_center``, multiplied by
    ``subject_area``. The three angle slots are filled with zeros; the
    frame's own 'angle' values are not read.
    """
    rows = []
    for frame in camera_frames:
        coords = frame['position']
        offset = (np.array([coords['x'], coords['y'], coords['z']]) - subject_center) * subject_area
        rows.append([*offset.tolist(), 0, 0, 0, frame['focalLength']])
    return rows

In [None]:
def generate_square_subsequent_mask(sz):
    """Return an (sz, sz) float mask: 0.0 on and below the diagonal, -inf above it."""
    blocked = torch.full((sz, sz), float('-inf'))
    # triu with diagonal=1 keeps the strictly-upper -inf entries and zeroes the rest.
    return torch.triu(blocked, diagonal=1)

def process_camera_trajectory(trajectory: torch.Tensor) -> List[Dict]:
    """Turn a tensor of 7-value frames into a list of camera-frame dictionaries.

    Columns 0-2 become 'position' (x, y, z), columns 3-5 become 'angle'
    (x, y, z) and column 6 becomes 'focalLength'. All values are converted
    to Python floats.
    """
    axes = ('x', 'y', 'z')
    rows = trajectory.cpu().detach().numpy()
    return [
        {
            "position": dict(zip(axes, map(float, row[:3]))),
            "angle": dict(zip(axes, map(float, row[3:6]))),
            "focalLength": float(row[6]),
        }
        for row in rows
    ]

def save_to_json(trajectory: torch.Tensor, filename: str):
    """Write ``trajectory`` to ``filename`` as a 2-space-indented JSON list of camera frames."""
    frames = process_camera_trajectory(trajectory)
    with open(filename, 'w') as out_file:
        out_file.write(json.dumps(frames, indent=2))

# Implementation

In [None]:
clip_model_name = "openai/clip-vit-large-patch14" #@param ["openai/clip-vit-base-patch32", "openai/clip-vit-large-patch14"]

# Embedding width paired with the selected checkpoint; the autoencoder's
# latent size is set to this value so latents can be compared to CLIP text features.
clip_embedding_dim = 768 if clip_model_name == "openai/clip-vit-large-patch14" else 512

# Use the GPU when one is available; the text encoder lives on that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
clip_text_encoder = CLIPTextModel.from_pretrained(clip_model_name).to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# Build every candidate description once; a description's position in this
# list is the label index returned by get_movement_index.
all_movement_descriptions = get_all_movement_descriptions()

# Number of descriptions encoded per CLIP forward pass.
batch_size = 1000
all_clip_text_features = []

# Encode the descriptions chunk by chunk, releasing cached GPU memory after each chunk.
for i in tqdm(range(0, len(all_movement_descriptions), batch_size), desc="Processing batches"):
    batch = all_movement_descriptions[i:i+batch_size]
    batch_features = get_clip_embedding(batch)
    all_clip_text_features.append(batch_features)
    torch.cuda.empty_cache()

# Stack the per-chunk results into one tensor with a row per description.
all_clip_text_features = torch.cat(all_clip_text_features, dim=0)
torch.cuda.empty_cache()

Processing batches: 100%|██████████| 140/140 [00:29<00:00,  4.73it/s]


In [None]:
class SimulationDataset(Dataset):
    """Dataset of single-instruction, 30-frame camera simulations loaded from a JSON file.

    The file must hold a top-level 'simulations' list. Simulations that fail
    ``_is_simulation_valid`` are dropped; the rest are preprocessed once at
    construction and converted to tensors on access.
    """

    def __init__(self, json_file_path: str):
        """Load ``json_file_path`` and preprocess every valid simulation."""
        with open(json_file_path, 'r') as file:
            raw_data = json.load(file)
        self.simulation_data = [
            self._process_single_simulation(sim)
            for sim in raw_data['simulations']
            if self._is_simulation_valid(sim)
        ]

    def __len__(self):
        """Return the number of retained simulations."""
        return len(self.simulation_data)

    def __getitem__(self, index):
        """Return one sample: float32 trajectory, long class labels and the raw label index."""
        simulation = self.simulation_data[index]
        return {
            'camera_trajectory': torch.tensor(simulation['camera_trajectory'], dtype=torch.float32),
            'movement_type': torch.tensor(simulation['movement_type'], dtype=torch.long),
            'easing_type': torch.tensor(simulation['easing_type'], dtype=torch.long),
            'label_index': simulation['label_index']
        }

    def _is_simulation_valid(self, simulation):
        """Keep only simulations with exactly one 30-frame instruction and 30 camera frames."""
        return (len(simulation['instructions']) == 1 and
                simulation['instructions'][0]['frameCount'] == 30 and
                len(simulation['cameraFrames']) == 30)

    def _process_single_simulation(self, simulation):
        """Convert one raw simulation into a trajectory plus its class labels.

        The trajectory is expressed relative to the first subject's position.
        'movement_type' and 'easing_type' are the members' positions in their
        enums; 'label_index' is the position of the instruction's description
        in the global description list.

        Raises:
            KeyError: if a movement, easing, angle or shot name is not a
                member of its enum.
        """
        instruction = simulation['instructions'][0]
        subject = simulation['subjects'][0]
        subject_center = np.array([subject['position']['x'], subject['position']['y'], subject['position']['z']])

        camera_trajectory = normalize_camera_trajectory(simulation['cameraFrames'], subject_center, subject_area=1)

        movement = CameraMovementType[instruction['cameraMovement']]
        easing = EasingType[instruction['movementEasing']]

        # Optional fields: treat a missing key and an explicit JSON null alike.
        # Indexing the enum with None would otherwise raise KeyError.
        angle_name = instruction.get('initialCameraAngle')
        shot_name = instruction.get('initialShotType')
        camera_angle = CameraAngle[angle_name].value if angle_name is not None else None
        shot_type = ShotType[shot_name].value if shot_name is not None else None

        label_index = get_movement_index(movement.value, easing.value, camera_angle, shot_type)

        return {
            'camera_trajectory': camera_trajectory,
            'movement_type': list(CameraMovementType).index(movement),
            'easing_type': list(EasingType).index(easing),
            'label_index': label_index
        }

def batch_collate(batch):
    """Merge dataset samples into one batch dictionary.

    Tensor fields are stacked along a new leading dimension; the samples'
    'label_index' values are gathered into a long tensor stored under
    'positive_indices'.
    """
    def stack_field(field):
        return torch.stack([sample[field] for sample in batch])

    label_indices = [sample['label_index'] for sample in batch]
    return {
        'camera_trajectory': stack_field('camera_trajectory'),
        'movement_type': stack_field('movement_type'),
        'easing_type': stack_field('easing_type'),
        'positive_indices': torch.tensor(label_indices, dtype=torch.long),
    }

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a (batch, seq, d_model) input, then apply dropout.

    Args:
        d_model: feature width of the inputs; even and odd widths both work.
        dropout: dropout probability applied after the signal is added.
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine columns,
        # so trim the frequencies to fit; for an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading batch dimension of 1 so the table broadcasts over the batch.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions, after dropout."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [None]:
class Classifier(nn.Module):
    """Two independent MLP heads that classify a latent into movement type and easing type.

    Both heads share one layout: two hidden layers of width 256, each
    followed by ReLU and dropout, then a linear output layer.
    """

    def __init__(self, latent_dim, num_movement_types, num_easing_types, dropout_rate=0.1):
        super().__init__()
        hidden_dim = 256

        def build_head(num_classes):
            # Layer positions inside the Sequential fix the state_dict key names.
            return nn.Sequential(
                nn.Linear(latent_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dim, num_classes),
            )

        self.movement_type_classifier = build_head(num_movement_types)
        self.easing_type_classifier = build_head(num_easing_types)

    def forward(self, latent):
        """Return ``(movement_type_logits, easing_type_logits)`` for ``latent``."""
        return self.movement_type_classifier(latent), self.easing_type_classifier(latent)

In [None]:
class Encoder(nn.Module):
    """Transformer encoder that condenses a trajectory into one latent vector.

    A learned query token is placed in front of the projected,
    position-encoded frames; the transformer output at that token's position
    is returned as the latent.
    """

    def __init__(self, input_dim, latent_dim, nhead, num_encoder_layers, dim_feedforward, dropout_rate):
        super().__init__()

        self.input_projection = nn.Linear(input_dim, latent_dim)
        self.pos_encoder = PositionalEncoding(latent_dim)

        encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        self.latent_query = nn.Parameter(torch.randn(1, 1, latent_dim))

    def forward(self, src, src_key_padding_mask=None):
        """Encode ``src`` (batch, seq, input_dim) into a latent of shape (batch, latent_dim).

        Args:
            src: batch-first trajectory tensor.
            src_key_padding_mask: optional (batch, seq) boolean mask; True
                marks frames the attention should ignore.
        """
        src_embedded = self.input_projection(src)
        src_embedded = self.pos_encoder(src_embedded)
        # The transformer layers are sequence-first, so switch to (seq, batch, dim).
        src_embedded = src_embedded.permute(1, 0, 2)

        latent_query = self.latent_query.repeat(1, src_embedded.shape[1], 1)
        src_with_query = torch.cat([latent_query, src_embedded], dim=0)

        if src_key_padding_mask is not None:
            # The query token is never masked. Build its column on the mask's own
            # device rather than a module-level global so the concatenation works
            # wherever the caller placed the tensors.
            query_mask = torch.zeros(
                (src_key_padding_mask.shape[0], 1), dtype=torch.bool, device=src_key_padding_mask.device
            )
            src_key_padding_mask = torch.cat([query_mask, src_key_padding_mask], dim=1)

        memory = self.transformer_encoder(src_with_query, src_key_padding_mask=src_key_padding_mask)

        # Position 0 holds the output for the query token.
        latent = memory[0]

        return latent

class AutoregressiveDecoder(nn.Module):
    """Transformer decoder that emits a ``seq_length``-frame trajectory from a latent, one frame per step."""

    def __init__(self, output_dim, latent_dim, nhead, num_decoder_layers, dim_feedforward, dropout_rate, seq_length):
        super().__init__()
        self.seq_length = seq_length
        self.output_dim = output_dim

        self.pos_encoder = PositionalEncoding(latent_dim)
        self.embedding = nn.Linear(output_dim, latent_dim)

        decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.output_projection = nn.Linear(latent_dim, output_dim)

    def forward(self, latent, target=None, teacher_forcing_ratio=0.5):
        """Decode ``latent`` (batch, latent_dim) into (batch, seq_length, output_dim).

        Args:
            latent: one latent vector per batch row.
            target: optional ground-truth trajectory (batch, seq_length,
                output_dim) used for teacher forcing.
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth frame, not the model's own output, as the
                next input. Ignored when ``target`` is None.
        """
        # Create helper tensors on the latent's device rather than a module-level
        # global, so decoding works wherever the model and its inputs live.
        run_device = latent.device

        # The latent is repeated seq_length times to form the sequence-first memory.
        memory = latent.unsqueeze(1).repeat(1, self.seq_length, 1)
        memory = memory.transpose(0, 1)

        # Decoding starts from an all-zero frame.
        decoder_input = torch.zeros(latent.shape[0], 1, self.output_dim, device=run_device)

        outputs = []

        # NOTE(review): decoder_input is replaced (not extended) at every step, so
        # each step sees only a single previous frame plus the memory. Confirm that
        # this is intended instead of attending over the whole generated prefix.
        for t in range(self.seq_length):
            embedded = self.embedding(decoder_input)
            embedded = self.pos_encoder(embedded)
            embedded = embedded.transpose(0, 1)

            tgt_mask = nn.Transformer.generate_square_subsequent_mask(embedded.size(0), device=run_device)

            output = self.transformer_decoder(embedded, memory, tgt_mask=tgt_mask)
            output = output.transpose(0, 1)
            output = self.output_projection(output)

            outputs.append(output)

            if target is not None and torch.rand(1).item() < teacher_forcing_ratio:
                decoder_input = target[:, t:t+1, :]
            else:
                decoder_input = output

        return torch.cat(outputs, dim=1)


class MultiTaskAutoencoder(nn.Module):
    """Trajectory autoencoder whose latent also feeds movement and easing classifiers.

    The encoder compresses a trajectory into a latent, the classifier
    predicts movement/easing logits from it, and the decoder reconstructs
    the trajectory from the same latent.
    """

    def __init__(self, input_dim=7, nhead=4, num_encoder_layers=3, num_decoder_layers=3, dim_feedforward=2048,
                 dropout_rate=0.1, seq_length=30, latent_dim=512, num_movement_types=len(CameraMovementType),
                 num_easing_types=len(EasingType)):
        super().__init__()

        # Settings shared by the encoder and decoder transformer stacks.
        transformer_kwargs = dict(nhead=nhead, dim_feedforward=dim_feedforward, dropout_rate=dropout_rate)

        self.encoder = Encoder(input_dim, latent_dim, num_encoder_layers=num_encoder_layers, **transformer_kwargs)
        self.classifier = Classifier(latent_dim, num_movement_types, num_easing_types, dropout_rate=dropout_rate)
        self.decoder = AutoregressiveDecoder(input_dim, latent_dim, num_decoder_layers=num_decoder_layers,
                                             seq_length=seq_length, **transformer_kwargs)

    def forward(self, src, src_key_padding_mask=None, target=None, teacher_forcing_ratio=0.5):
        """Run encoder, classifier and decoder; return the latent, both logit sets and the reconstruction."""
        latent = self.encoder(src, src_key_padding_mask)
        movement_type_logits, easing_type_logits = self.classifier(latent)

        result = {
            'latent': latent,
            'movement_type_logits': movement_type_logits,
            'easing_type_logits': easing_type_logits,
        }
        result['reconstructed'] = self.decoder(latent, target, teacher_forcing_ratio)
        return result

In [None]:
def init_losses():
    """Return a fresh accumulator holding 0 for every tracked loss term."""
    loss_names = (
        'total',
        'movement_type',
        'easing_type',
        'reconstruction',
        'clip',
        'clip_contrastive',
    )
    return dict.fromkeys(loss_names, 0)

def print_detailed_losses(phase, losses):
    """Print every loss term on one line, prefixed by ``phase``, with four decimals each."""
    segments = [f"{phase} Losses:"]
    segments.extend(f"  {name}: {amount:.4f}" for name, amount in losses.items())
    # Each segment is followed by a single space, including the last one.
    print(' '.join(segments) + ' ')

In [None]:
def apply_mask_and_noise(data, mask_ratio=0.0, noise_std=0.0):
    """Randomly zero whole frames of ``data`` and add Gaussian noise to the result.

    Each (batch, frame) position is kept with probability ``1 - mask_ratio``.
    Dropped frames are zeroed on a copy, then zero-mean noise with standard
    deviation ``noise_std`` is added to every value.

    Returns:
        ``(noisy_data, mask, src_key_padding_mask)`` where ``mask`` is True
        for kept frames and ``src_key_padding_mask`` is its negation.
    """
    keep_probability = torch.full((data.shape[0], data.shape[1]), 1 - mask_ratio, device=data.device)
    mask = torch.bernoulli(keep_probability).bool()
    src_key_padding_mask = ~mask

    # Work on a copy so the caller's tensor stays untouched.
    masked_data = data.clone()
    masked_data[src_key_padding_mask] = 0

    noise = torch.normal(mean=0, std=noise_std, size=data.shape, device=data.device)
    return masked_data + noise, mask, src_key_padding_mask

In [None]:
def train_epoch(model, dataloader, optimizer, criterion, noise_std=0.0, mask_ratio=0.0, teacher_forcing_ratio=0.5):
    """Train ``model`` for one pass over ``dataloader`` and return the per-batch mean of each loss term."""
    model.train()
    running = init_losses()

    for batch in tqdm(dataloader, desc="Training"):
        batch_losses = process_batch(model, batch, criterion, noise_std, mask_ratio, teacher_forcing_ratio)

        batch_losses['total'].backward()
        optimizer.step()
        # Gradients are cleared after the step, leaving them empty for the next batch.
        optimizer.zero_grad()

        for name in running:
            term = batch_losses[name]
            running[name] += term.item() if torch.is_tensor(term) else term

    num_batches = len(dataloader)
    return {name: summed / num_batches for name, summed in running.items()}

def validate(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without gradients; return the per-batch mean of each loss term.

    Batches are processed with the default noise and mask settings and a
    teacher forcing ratio of 0.
    """
    model.eval()
    running = init_losses()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            batch_losses = process_batch(model, batch, criterion, teacher_forcing_ratio=0)

            for name in running:
                term = batch_losses[name]
                running[name] += term.item() if torch.is_tensor(term) else term

    num_batches = len(dataloader)
    return {name: summed / num_batches for name, summed in running.items()}


def process_batch(model, batch, criterion, noise_std=0.0, mask_ratio=0.0, teacher_forcing_ratio=0.5):
    """Run one batch through ``model`` and return its weighted loss terms plus their sum under 'total'.

    The input trajectory is masked and noised before encoding. The decoder
    receives the clean trajectory as a teacher-forcing target only while the
    model is in training mode. Reconstruction is scored only on frames that
    were kept by the mask. Term weights: reconstruction x2, CLIP cosine
    distance x5, CLIP contrastive x0.3. Uses the module-level ``device`` and
    ``all_clip_text_features``.
    """
    clean_trajectory = batch['camera_trajectory'].to(device)
    movement_labels = batch['movement_type'].to(device)
    easing_labels = batch['easing_type'].to(device)
    positive_indices = batch['positive_indices'].to(device)

    corrupted, kept_frames, padding_mask = apply_mask_and_noise(clean_trajectory, mask_ratio, noise_std)

    decoder_target = clean_trajectory if model.training else None
    output = model(corrupted, padding_mask, decoder_target, teacher_forcing_ratio)

    classify = criterion['classification']
    latent = output['latent']

    movement_loss = classify(output['movement_type_logits'], movement_labels)
    easing_loss = classify(output['easing_type_logits'], easing_labels)
    reconstruction_loss = criterion['reconstruction'](
        output['reconstructed'][kept_frames], clean_trajectory[kept_frames]
    )
    cosine_to_positive = F.cosine_similarity(latent, all_clip_text_features[positive_indices])
    contrastive = contrastive_loss(latent, all_clip_text_features, positive_indices)

    losses = {
        'movement_type': movement_loss,
        'easing_type': easing_loss,
        'reconstruction': reconstruction_loss * 2,
        'clip': (1 - cosine_to_positive.mean()) * 5,
        'clip_contrastive': contrastive * 0.3,
    }
    losses['total'] = sum(losses.values())

    return losses

def train_model(model, train_dataloader, val_dataloader, config):
    """Train ``model`` with Adam, early stopping and a linearly decaying teacher forcing ratio.

    ``config`` supplies 'learning_rate', 'num_epochs', 'noise_std',
    'mask_ratio' and 'patience'. Whenever the validation total loss improves,
    the weights are written to 'best_model.pth'. Training stops once
    'patience' consecutive epochs pass without an improvement.
    """
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    criterion = {
        'classification': nn.CrossEntropyLoss(),
        'reconstruction': nn.MSELoss()
    }

    num_epochs = config['num_epochs']
    best_val_loss = float('inf')
    stale_epochs = 0

    for epoch in range(num_epochs):
        # Teacher forcing starts at 1.0 and shrinks linearly as the epochs advance.
        current_teacher_forcing_ratio = (1 - epoch / num_epochs)
        train_losses = train_epoch(model, train_dataloader, optimizer, criterion,
                                   config['noise_std'], config['mask_ratio'], current_teacher_forcing_ratio)
        val_losses = validate(model, val_dataloader, criterion)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_losses['total']:.4f}, Validation Loss: {val_losses['total']:.4f}")
        print_detailed_losses("Train", train_losses)
        print_detailed_losses("Valid", val_losses)

        improved = val_losses['total'] < best_val_loss
        if improved:
            best_val_loss = val_losses['total']
            stale_epochs = 0
            torch.save(model.state_dict(), 'best_model.pth')
            print("New best model saved!")
        else:
            stale_epochs += 1

        if stale_epochs >= config['patience']:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

In [None]:
# Select the compute device and load the simulations that pass the dataset's validity filter.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = SimulationDataset('random_simulation_dataset.json')

# 70/30 train/validation split with a fixed seed so the split is repeatable.
train_dataset, val_dataset = train_test_split(dataset, test_size=0.3, random_state=42)
# Debug aid (disabled): train and validate on 32 copies of the first sample.
# train_dataset = val_dataset = [dataset[0]] * 32

batch_size = 32
num_workers = 2

# Only the training loader shuffles; both loaders merge samples with batch_collate.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=batch_collate, num_workers=num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=batch_collate, num_workers=num_workers)

# The latent width equals the CLIP embedding width so latents can be compared to text features.
model = MultiTaskAutoencoder(latent_dim=clip_embedding_dim).to(device)

In [None]:
# Training hyperparameters read by train_model.
config = {
    'num_epochs': 2000,  # upper bound on epochs; also sets the teacher forcing decay
    'patience': 30,  # epochs without validation improvement before stopping early
    'learning_rate': 0.0001,  # Adam learning rate
    'noise_std': 0.1,  # std of the Gaussian noise added to training inputs
    'mask_ratio': 0.5,  # probability of zeroing each input frame during training
}

# Release cached GPU memory before training starts.
torch.cuda.empty_cache()
train_model(model, train_dataloader, val_dataloader, config)

## Inference

In [None]:
# Restore the best checkpoint written by train_model. map_location remaps the
# stored tensors onto the current device, so a checkpoint saved on a GPU
# runtime still loads when only the CPU is available.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()
print()

In [None]:
file_path = '/content/result.json'

input_text = "The camera is moving in a full arc to the left with a bouncing effect that intensifies towards the end in a close-up shot" # The camera is panning left at a constant speed, from a low angle in a close-up shot
latent = get_clip_embedding(input_text)
# MultiTaskAutoencoder defines no `autoregressive_decode` method; its `decoder`
# submodule is what turns a latent into a trajectory. With no target given it
# feeds its own outputs back in. Gradients are not needed for inference.
with torch.no_grad():
    camera_frames = model.decoder(latent).squeeze(0)
save_to_json(camera_frames, file_path)

files.download(file_path)

In [None]:
input_path = '/content/input.json'
output_path = '/content/output.json'

# Reconstruct the first dataset sample so the input and output can be compared.
input_camera_tranjectory = dataset[0]['camera_trajectory'].to(device)
# Inference only: skip autograd so the 30 decoding steps do not build a graph.
with torch.no_grad():
    reconstructed_camera_tranjectory = model(input_camera_tranjectory.unsqueeze(0))['reconstructed'].squeeze(0)

save_to_json(input_camera_tranjectory, input_path)
save_to_json(reconstructed_camera_tranjectory, output_path)

files.download(input_path)
files.download(output_path)