<a href="https://colab.research.google.com/github/ZahraDehghanian97/LenseCraft/blob/master/Cinematography_Instruction_Analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports and Helper Functions

## Imports

In [1]:
import json
import math
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from google.colab import files

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [None]:
# Download a file from Google Drive by its file ID.
# NOTE(review): presumably this is 'random_simulation_dataset.json', which SimulationDataset loads further down — confirm.
!gdown 1VT2XfBj9LFWLUBjv65dzC4bVzH0zdNDU

## Helpers

In [3]:
class CameraMovementType(Enum):
    """Camera movement primitives; every member's value equals its name.

    Member order matters: the movement classifier's target is the member's position in
    ``list(CameraMovementType)``, and the description list is enumerated in this order.
    """
    # Rotations about a fixed camera position
    panLeft = "panLeft"
    panRight = "panRight"
    tiltUp = "tiltUp"
    tiltDown = "tiltDown"
    # Translations
    dollyIn = "dollyIn"
    dollyOut = "dollyOut"
    truckLeft = "truckLeft"
    truckRight = "truckRight"
    pedestalUp = "pedestalUp"
    pedestalDown = "pedestalDown"
    # Zooms of different extents
    fullZoomIn = "fullZoomIn"
    fullZoomOut = "fullZoomOut"
    halfZoomIn = "halfZoomIn"
    halfZoomOut = "halfZoomOut"
    shortZoomIn = "shortZoomIn"
    shortZoomOut = "shortZoomOut"
    # Arcs of different extents
    shortArcShotLeft = "shortArcShotLeft"
    shortArcShotRight = "shortArcShotRight"
    halfArcShotLeft = "halfArcShotLeft"
    halfArcShotRight = "halfArcShotRight"
    fullArcShotLeft = "fullArcShotLeft"
    fullArcShotRight = "fullArcShotRight"
    # Compound movements
    panAndTilt = "panAndTilt"
    dollyAndPan = "dollyAndPan"
    zoomAndTruck = "zoomAndTruck"

class EasingType(Enum):
    """Easing curves for a camera movement; every member's value equals its name.

    Member order matters: the easing classifier's target is the member's position in
    ``list(EasingType)``.
    """
    linear = "linear"
    # Polynomial ease-in / ease-out / ease-in-out of increasing degree
    easeInQuad = "easeInQuad"
    easeInCubic = "easeInCubic"
    easeInQuart = "easeInQuart"
    easeInQuint = "easeInQuint"
    easeOutQuad = "easeOutQuad"
    easeOutCubic = "easeOutCubic"
    easeOutQuart = "easeOutQuart"
    easeOutQuint = "easeOutQuint"
    easeInOutQuad = "easeInOutQuad"
    easeInOutCubic = "easeInOutCubic"
    easeInOutQuart = "easeInOutQuart"
    easeInOutQuint = "easeInOutQuint"
    # Sine, exponential and circular families
    easeInSine = "easeInSine"
    easeOutSine = "easeOutSine"
    easeInOutSine = "easeInOutSine"
    easeInExpo = "easeInExpo"
    easeOutExpo = "easeOutExpo"
    easeInOutExpo = "easeInOutExpo"
    easeInCirc = "easeInCirc"
    easeOutCirc = "easeOutCirc"
    easeInOutCirc = "easeInOutCirc"
    # Bounce and elastic families
    easeInBounce = "easeInBounce"
    easeOutBounce = "easeOutBounce"
    easeInOutBounce = "easeInOutBounce"
    easeInElastic = "easeInElastic"
    easeOutElastic = "easeOutElastic"
    easeInOutElastic = "easeInOutElastic"

class CameraAngle(Enum):
    """Optional initial camera angle of a simulation; every member's value equals its name."""
    lowAngle = "lowAngle"
    mediumAngle = "mediumAngle"
    highAngle = "highAngle"
    birdsEyeView = "birdsEyeView"

class ShotType(Enum):
    closeUp = "closeUp"
    mediumShot = "mediumShot"
    longShot = "longShot"

In [4]:
def get_movement_description(movement, easing, camera_angle=None, shot_type=None):
    """Build the English sentence describing one camera instruction.

    Each argument is an enum *value* string (e.g. ``"panLeft"``, ``"linear"``).
    Unknown movement/angle/shot values are inserted verbatim; an unknown easing becomes
    ``"with <easing> easing"``. ``camera_angle`` and ``shot_type`` are appended only
    when truthy.

    Example:
        ``get_movement_description("panLeft", "linear", "lowAngle", "closeUp")`` ->
        ``"The camera is panning left at a constant speed, from a low angle in a close-up shot"``
    """
    movement_phrases = {
        "panLeft": "panning left",
        "panRight": "panning right",
        "tiltUp": "tilting up",
        "tiltDown": "tilting down",
        "dollyIn": "moving closer",
        "dollyOut": "moving away",
        "truckLeft": "moving left",
        "truckRight": "moving right",
        "pedestalUp": "rising vertically",
        "pedestalDown": "descending vertically",
        "fullZoomIn": "zooming in fully",
        "fullZoomOut": "zooming out fully",
        "halfZoomIn": "zooming in halfway",
        "halfZoomOut": "zooming out halfway",
        "shortZoomIn": "zooming in slightly",
        "shortZoomOut": "zooming out slightly",
        "shortArcShotLeft": "moving in a short arc to the left",
        "shortArcShotRight": "moving in a short arc to the right",
        "halfArcShotLeft": "moving in a half arc to the left",
        "halfArcShotRight": "moving in a half arc to the right",
        "fullArcShotLeft": "moving in a full arc to the left",
        "fullArcShotRight": "moving in a full arc to the right",
        "panAndTilt": "panning and tilting",
        "dollyAndPan": "moving and panning",
        "zoomAndTruck": "zooming and moving sideways",
    }

    easing_phrases = {
        "linear": "at a constant speed",
        "easeInQuad": "slowly at first, then accelerating gradually",
        "easeInCubic": "slowly at first, then accelerating more rapidly",
        "easeInQuart": "very slowly at first, then accelerating dramatically",
        "easeInQuint": "extremely slowly at first, then accelerating very dramatically",
        "easeOutQuad": "quickly at first, then decelerating gradually",
        "easeOutCubic": "quickly at first, then decelerating more rapidly",
        "easeOutQuart": "very quickly at first, then decelerating dramatically",
        "easeOutQuint": "extremely quickly at first, then decelerating very dramatically",
        "easeInOutQuad": "gradually accelerating, then gradually decelerating",
        "easeInOutCubic": "slowly accelerating, then decelerating more rapidly",
        "easeInOutQuart": "slowly accelerating, then decelerating dramatically",
        "easeInOutQuint": "very slowly accelerating, then decelerating very dramatically",
        "easeInSine": "with a gentle start, gradually increasing in speed",
        "easeOutSine": "quickly at first, then gently decelerating",
        "easeInOutSine": "with a gentle start and end, faster in the middle",
        "easeInExpo": "starting very slowly, then accelerating exponentially",
        "easeOutExpo": "starting very fast, then decelerating exponentially",
        "easeInOutExpo": "starting and ending slowly, with rapid acceleration and deceleration in the middle",
        "easeInCirc": "starting slowly, then accelerating sharply towards the end",
        "easeOutCirc": "starting quickly, then decelerating sharply towards the end",
        "easeInOutCirc": "with sharp acceleration and deceleration at both ends",
        "easeInBounce": "with a bouncing effect that intensifies towards the end",
        "easeOutBounce": "quickly at first, then bouncing to a stop",
        "easeInOutBounce": "with a bouncing effect at both the start and end",
        "easeInElastic": "with an elastic effect that intensifies towards the end",
        "easeOutElastic": "quickly at first, then oscillating to a stop",
        "easeInOutElastic": "with an elastic effect at both the start and end",
    }

    angle_phrases = {
        "lowAngle": "from a low angle",
        "mediumAngle": "from a medium angle",
        "highAngle": "from a high angle",
        "birdsEyeView": "from a bird's eye view",
    }

    shot_phrases = {
        "closeUp": "in a close-up shot",
        "mediumShot": "in a medium shot",
        "longShot": "in a long shot",
    }

    # Each fragment carries its own leading separator, so a plain join rebuilds the sentence.
    fragments = [f"The camera is {movement_phrases.get(movement, movement)}"]

    if easing in easing_phrases:
        fragments.append(f" {easing_phrases[easing]}")
    else:
        fragments.append(f" with {easing} easing")

    if camera_angle:
        fragments.append(f", {angle_phrases.get(camera_angle, camera_angle)}")

    if shot_type:
        fragments.append(f" {shot_phrases.get(shot_type, shot_type)}")

    return "".join(fragments)

In [5]:
def get_clip_embedding(text: Union[str, List[str]]) -> torch.Tensor:
    """Encode one description, or a batch of descriptions, with the CLIP text encoder.

    Args:
        text: a single string or a list of strings. The feature-precomputation loop
            passes lists of up to ``batch_size`` descriptions.

    Returns:
        ``pooler_output`` of the notebook-level ``clip_text_encoder``, one row per input
        text (a single string yields a batch of one). It is computed without gradients
        on the notebook-level ``device``.
    """
    # padding=True pads a list of texts to its longest entry; truncation=True clips
    # over-long texts to the tokenizer's maximum length.
    inputs = clip_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        embedding = clip_text_encoder(**inputs).pooler_output
    return embedding

In [6]:
def get_all_movement_descriptions():
    """Enumerate the description of every (movement, easing, angle, shot) combination.

    For each movement and easing (in enum order), the variants are emitted in this order:
    no angle/shot, each shot alone, then each angle followed by that angle combined with
    each shot. The returned list's positions are the label indices used for CLIP targets.
    """
    # (camera_angle, shot_type) pairs in the exact order the label indices depend on.
    angle_shot_variants = [(None, None)]
    angle_shot_variants += [(None, shot.value) for shot in ShotType]
    for angle in CameraAngle:
        angle_shot_variants.append((angle.value, None))
        angle_shot_variants += [(angle.value, shot.value) for shot in ShotType]

    return [
        get_movement_description(movement.value, easing.value, camera_angle=angle_value, shot_type=shot_value)
        for movement in CameraMovementType
        for easing in EasingType
        for angle_value, shot_value in angle_shot_variants
    ]

def get_movement_index(movement, easing, camera_angle=None, shot_type=None):
    """Return the position of this instruction's description in ``all_movement_descriptions``.

    Raises ValueError (from ``list.index``) if the description is not in the list.
    """
    return all_movement_descriptions.index(
        get_movement_description(movement, easing, camera_angle, shot_type)
    )

In [7]:
def autoregressive_decode(decoder, latent, num_frames, initial_input):
    """Roll ``decoder`` out for ``num_frames`` steps, feeding each prediction back in.

    At each step the decoder sees the seed frame plus every frame predicted so far, and
    only the output at the last position is kept as the next frame.

    Args:
        decoder: callable ``decoder(latent, tgt)``; expected to return
            (batch, tgt_len, features).
        latent: conditioning latent, passed unchanged to every decoder call.
        num_frames: number of frames to generate. Must be >= 1, because ``torch.cat``
            of an empty list raises.
        initial_input: (batch, 1, features) seed frame. It is moved to ``latent``'s
            device so the rollout does not depend on a notebook-global device.

    Returns:
        (batch, num_frames, features) tensor of predicted frames, excluding the seed.
    """
    current_input = initial_input.to(latent.device)
    predicted_frames = []

    for _ in range(num_frames):
        next_frame = decoder(latent, current_input)[:, -1:, :]
        predicted_frames.append(next_frame)
        current_input = torch.cat([current_input, next_frame], dim=1)

    return torch.cat(predicted_frames, dim=1)

def generate_camera_movement(model: torch.nn.Module, text_input: str, num_frames: int = 30) -> List[Dict]:
    """Generate camera frames by decoding from the CLIP text embedding of ``text_input``.

    The CLIP embedding is fed straight to ``model.decoder`` as its latent, starting from
    an all-zero 7-value seed frame, without gradients.

    NOTE(review): the decoder is trained on encoder latents, which the CLIP losses align
    with the text features only in direction (cosine). Embedding magnitudes may differ;
    confirm this is intended.

    Returns:
        Per-frame dicts as produced by ``process_camera_trajectory``.
    """
    with torch.no_grad():
        text_latent = get_clip_embedding(text_input).to(device)
        seed_frame = torch.zeros(1, 1, 7).to(device)
        decoded = autoregressive_decode(model.decoder, text_latent, num_frames, seed_frame)

    return process_camera_trajectory(decoded)

def reconstruct_camera_movement(decoder: torch.nn.Module, camera_trajectory: torch.Tensor, mask: Optional[torch.Tensor] = None, num_frames: int = 30, encoder: Optional[torch.nn.Module] = None) -> List[Dict]:
    """Encode one trajectory and decode it back into camera frames (a reconstruction check).

    Args:
        decoder: module used for the autoregressive rollout.
        camera_trajectory: a single, unbatched flattened trajectory. A batch dimension of
            one is added here.
        mask: optional per-frame keep mask for that trajectory. A 1-D mask is given the
            same batch dimension as the trajectory; a 2-D mask is passed through as-is.
        num_frames: number of frames to decode.
        encoder: module producing the latent. Defaults to the notebook-level
            ``model.encoder``, which preserves the original behaviour. Pass it explicitly
            to reconstruct with a model other than the global ``model``.

    Returns:
        Per-frame dicts as produced by ``process_camera_trajectory``.
    """
    if encoder is None:
        encoder = model.encoder

    batched_trajectory = camera_trajectory.unsqueeze(0)
    if mask is not None and mask.dim() == 1:
        # The encoder slices mask[:, :max_seq_length], so it needs a batch dim too.
        mask = mask.unsqueeze(0)

    with torch.no_grad():
        latent = encoder(batched_trajectory, mask)
        initial_input = torch.zeros(1, 1, 7, device=latent.device)
        reconstructed = autoregressive_decode(decoder, latent, num_frames, initial_input)

    return process_camera_trajectory(reconstructed)

def process_camera_trajectory(trajectory: torch.Tensor) -> List[Dict]:
    """Convert the first batch entry of a (batch, frames, 7) tensor into JSON-ready dicts.

    Each frame's 7 values are read as position x/y/z, angle x/y/z and focal length, and
    converted to Python floats.
    """
    axis_names = ('x', 'y', 'z')
    first_sequence = trajectory.cpu().numpy()[0]
    return [
        {
            "position": dict(zip(axis_names, map(float, row[0:3]))),
            "angle": dict(zip(axis_names, map(float, row[3:6]))),
            "focalLength": float(row[6]),
        }
        for row in first_sequence
    ]

def save_to_json(camera_frames: List[Dict], filename: str):
    """Write ``camera_frames`` to ``filename`` as JSON indented by 2 spaces, overwriting the file."""
    with open(filename, mode='w') as output_file:
        json.dump(camera_frames, output_file, indent=2)

# Implementation

In [None]:
clip_model_name = "openai/clip-vit-large-patch14" #@param ["openai/clip-vit-base-patch32", "openai/clip-vit-large-patch14"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
clip_text_encoder = CLIPTextModel.from_pretrained(clip_model_name).to(device)

# Width of the text encoder's pooler_output, which is also used as the autoencoder's
# latent size. It is read from the loaded model's config rather than hard-coded per
# checkpoint name, so any CLIP checkpoint selected above gets the right width.
clip_embedding_dim = clip_text_encoder.config.hidden_size

In [None]:
# Embed every possible instruction description once; rows line up with
# all_movement_descriptions, so a label index selects its CLIP target feature.
all_movement_descriptions = get_all_movement_descriptions()

batch_size = 100
clip_feature_chunks = []

for start in tqdm(range(0, len(all_movement_descriptions), batch_size), desc="Processing batches"):
    description_chunk = all_movement_descriptions[start:start + batch_size]
    clip_feature_chunks.append(get_clip_embedding(description_chunk))
    torch.cuda.empty_cache()

all_clip_text_features = torch.cat(clip_feature_chunks, dim=0)

In [10]:
class SimulationDataset(Dataset):
    """Single-instruction, 30-frame camera simulations read from a JSON file.

    The file must contain a top-level ``"simulations"`` list. Simulations rejected by
    ``_is_simulation_valid`` are skipped. ``__getitem__`` returns a dict with:

    * ``camera_trajectory``: float32 tensor of the frames concatenated, 7 values each
      (position x/y/z relative to the first subject, angle x/y/z, focal length);
    * ``movement_type`` / ``easing_type``: long scalar tensors holding the member's
      position in ``CameraMovementType`` / ``EasingType``;
    * ``label_index``: int index into ``all_movement_descriptions``.
    """

    def __init__(self, json_file_path: str):
        with open(json_file_path, 'r') as file:
            raw_data = json.load(file)
        self.simulation_data = [self._process_single_simulation(sim) for sim in raw_data['simulations']
                                if self._is_simulation_valid(sim)]

    def __len__(self):
        return len(self.simulation_data)

    def __getitem__(self, index):
        simulation = self.simulation_data[index]
        return {
            'camera_trajectory': torch.tensor(simulation['camera_trajectory'], dtype=torch.float32),
            'movement_type': torch.tensor(simulation['movement_type'], dtype=torch.long),
            'easing_type': torch.tensor(simulation['easing_type'], dtype=torch.long),
            'label_index': simulation['label_index']
        }

    def _is_simulation_valid(self, simulation):
        """Keep only simulations with exactly one 30-frame instruction and 30 camera frames."""
        return (len(simulation['instructions']) == 1 and
                simulation['instructions'][0]['frameCount'] == 30 and
                len(simulation['cameraFrames']) == 30)

    @staticmethod
    def _optional_enum_value(enum_cls, name):
        """Return ``enum_cls[name].value``, or None when the field is missing or JSON null."""
        return None if name is None else enum_cls[name].value

    def _process_single_simulation(self, simulation):
        """Flatten one simulation into trajectory values, class indices and a label index."""
        instruction = simulation['instructions'][0]
        subject = simulation['subjects'][0]
        subject_center = np.array([subject['position']['x'], subject['position']['y'], subject['position']['z']])
        # Scale applied to subject-relative positions; 1 leaves them unscaled.
        subject_area = 1

        camera_trajectory = self._normalize_camera_trajectory(simulation['cameraFrames'], subject_center, subject_area)

        movement = CameraMovementType[instruction['cameraMovement']]
        easing = EasingType[instruction['movementEasing']]
        # .get() plus a None check: a key that is present but null is treated like a
        # missing key instead of raising KeyError on Enum[None].
        camera_angle = self._optional_enum_value(CameraAngle, instruction.get('initialCameraAngle'))
        shot_type = self._optional_enum_value(ShotType, instruction.get('initialShotType'))

        label_index = get_movement_index(movement.value, easing.value, camera_angle, shot_type)

        return {
            'camera_trajectory': camera_trajectory,
            'movement_type': list(CameraMovementType).index(movement),
            'easing_type': list(EasingType).index(easing),
            'label_index': label_index
        }

    def _normalize_camera_trajectory(self, camera_frames, subject_center, subject_area):
        """Flatten frames to [pos - center (scaled) x/y/z, angle x/y/z, focalLength] per frame."""
        trajectory = []
        for frame in camera_frames:
            relative_position = (np.array([frame['position']['x'], frame['position']['y'], frame['position']['z']]) - subject_center) * subject_area
            trajectory.extend(relative_position.tolist())
            trajectory.extend([frame['angle']['x'], frame['angle']['y'], frame['angle']['z']])
            trajectory.append(frame['focalLength'])
        return trajectory

def batch_collate(batch):
    """Collate SimulationDataset items into one batch dict.

    The tensor fields are stacked along a new leading dimension. The integer
    ``label_index`` values are gathered into a long tensor under ``positive_indices``.
    """
    def stack_field(field_name):
        return torch.stack([sample[field_name] for sample in batch])

    collated = {name: stack_field(name) for name in ('camera_trajectory', 'movement_type', 'easing_type')}
    collated['positive_indices'] = torch.tensor([sample['label_index'] for sample in batch], dtype=torch.long)
    return collated

In [11]:
class PositionalEncoding(nn.Module): # Sinusoidal Positional Encoding
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    Args:
        d_model: feature width of the inputs. Odd widths are supported: the last sine
            column has no cosine partner.
        max_len: longest sequence that can be encoded.

    ``forward`` takes ``x`` shaped (seq_len, batch, d_model) with seq_len <= max_len and
    returns ``x`` plus the encodings for positions 0..seq_len-1.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape (max_len, 1, d_model), so it broadcasts over the batch dim of sequence-first input.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        return x + self.pe[:x.size(0), :]


class Encoder(nn.Module):
    """Transformer encoder mapping a flattened trajectory to a single latent vector.

    ``src`` is viewed as (batch, max_seq_length, 7) frames. The frames are projected to
    ``d_model`` and position-encoded, then run through a sequence-first
    ``nn.TransformerEncoder``. All per-frame outputs are flattened and projected to
    ``latent_dim``.

    The optional ``mask`` is (batch, >= max_seq_length), with True meaning keep the frame.
    Masked frames are zeroed and also excluded from attention via ``src_key_padding_mask``.

    NOTE(review): ``input_dim`` is accepted but unused; the per-frame width is fixed at 7.
    """

    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, max_seq_length, latent_dim):
        super().__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length

        self.input_projection = nn.Linear(7, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_length)

        layer_template = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_encoder_layers)

        self.encoder_to_latent = nn.Linear(d_model * max_seq_length, latent_dim)

    def forward(self, src, mask=None):
        batch_size = src.size(0)
        frames = src.view(batch_size, self.max_seq_length, 7)
        # Sequence-first layout (seq, batch, d_model) as expected by the encoder layers.
        hidden = self.pos_encoder(self.input_projection(frames).permute(1, 0, 2))

        padding_mask = None
        if mask is not None:
            keep = mask[:, :self.max_seq_length]
            keep_per_feature = keep.permute(1, 0).unsqueeze(-1).expand(-1, -1, self.d_model)
            hidden = hidden * keep_per_feature
            # Padding mask is True where attention must ignore the frame, i.e. not kept.
            padding_mask = ~keep.bool()

        encoded = self.transformer_encoder(hidden, src_key_padding_mask=padding_mask)
        flattened = encoded.permute(1, 0, 2).reshape(batch_size, -1)
        return self.encoder_to_latent(flattened)

class Decoder(nn.Module):
    """Transformer decoder predicting 7-value camera frames conditioned on a latent.

    The latent is linearly expanded into ``max_seq_length`` memory tokens of width
    ``d_model``. ``tgt`` frames of shape (batch, tgt_len, 7) are projected, position-encoded
    and decoded against that memory. The output has shape (batch, tgt_len, 7).

    NOTE(review): no causal ``tgt_mask`` is applied. That is harmless for the
    step-by-step rollout used in this notebook (it only feeds past frames), but would
    leak future frames under teacher forcing — confirm before changing the training scheme.
    """

    def __init__(self, d_model, nhead, num_decoder_layers, max_seq_length, latent_dim):
        super().__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length

        self.latent_to_memory = nn.Linear(latent_dim, d_model * max_seq_length)
        self.input_projection = nn.Linear(7, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_length)

        layer_template = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_decoder = nn.TransformerDecoder(layer_template, num_layers=num_decoder_layers)

        self.output_projection = nn.Linear(d_model, 7)

    def forward(self, latent, tgt):
        memory_tokens = self.latent_to_memory(latent).view(-1, self.max_seq_length, self.d_model)
        memory = memory_tokens.permute(1, 0, 2)

        embedded = self.pos_encoder(self.input_projection(tgt).permute(1, 0, 2))
        decoded = self.transformer_decoder(embedded, memory)
        return self.output_projection(decoded.permute(1, 0, 2))

class Classifier(nn.Module):
    """Two independent linear heads that predict movement-type and easing-type logits from a latent."""

    def __init__(self, latent_dim, num_movement_types, num_easing_types):
        super().__init__()
        self.movement_type_classifier = nn.Linear(latent_dim, num_movement_types)
        self.easing_type_classifier = nn.Linear(latent_dim, num_easing_types)

    def forward(self, latent):
        return self.movement_type_classifier(latent), self.easing_type_classifier(latent)

class MultiTaskAutoencoder(nn.Module):
    """Encoder -> latent -> (movement/easing classifier, autoregressive trajectory decoder)."""

    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, num_decoder_layers,
                 num_movement_types, num_easing_types, max_seq_length, latent_dim):
        super(MultiTaskAutoencoder, self).__init__()
        self.encoder = Encoder(input_dim, d_model, nhead, num_encoder_layers, max_seq_length, latent_dim)
        self.decoder = Decoder(d_model, nhead, num_decoder_layers, max_seq_length, latent_dim)
        self.classifier = Classifier(latent_dim, num_movement_types, num_easing_types)
        self.max_seq_length = max_seq_length

    def forward(self, input, mask=None):
        """Encode, classify, and reconstruct a batch of flattened trajectories.

        Args:
            input: flattened trajectories; the encoder views them as
                (batch, max_seq_length, 7).
            mask: optional per-frame keep mask (True = keep), forwarded to the encoder.

        Returns:
            dict with ``latent``, ``movement_type_logits``, ``easing_type_logits`` and
            ``reconstructed`` of shape (batch, max_seq_length * 7).
        """
        latent = self.encoder(input, mask)

        movement_type_logits, easing_type_logits = self.classifier(latent)

        # Create the zero seed frame next to the input, not on a notebook-global device,
        # so the model works wherever it and its inputs have been moved.
        initial_input = torch.zeros(input.shape[0], 1, 7, device=input.device)
        reconstructed = autoregressive_decode(self.decoder, latent, self.max_seq_length, initial_input)

        return {
            'latent': latent,
            'movement_type_logits': movement_type_logits,
            'easing_type_logits': easing_type_logits,
            'reconstructed': reconstructed.view(-1, self.max_seq_length * 7),
        }

In [12]:
def create_mask(batch_size, seq_length, mask_ratio=0.5):
    """Draw a random boolean keep-mask of shape (batch_size, seq_length) on ``device``.

    Each entry is independently True (keep) with probability ``1 - mask_ratio``.
    """
    keep_probability = torch.full((batch_size, seq_length), 1 - mask_ratio)
    keep = torch.bernoulli(keep_probability).bool()
    return keep.to(device)

In [13]:
def contrastive_loss(latent, clip_text_features, positive_indices, temperature=0.5, num_samples=50):
    """InfoNCE-style contrastive loss between trajectory latents and CLIP text features.

    The softmax candidates are every distinct positive description in the batch plus up
    to ``num_samples`` randomly drawn descriptions. Each sample's own positive is therefore
    always present; other samples' positives and the random draws act as negatives.
    (Sampling candidates blindly from all descriptions almost never includes the positive,
    which leaves the loss constantly zero.)

    Args:
        latent: (batch, dim) encoder latents.
        clip_text_features: (num_descriptions, dim) CLIP features of all descriptions,
            on the same device as ``latent``.
        positive_indices: (batch,) index of each sample's matching description.
        temperature: divisor applied to the cosine similarities.
        num_samples: number of random descriptions added as extra candidates.

    Returns:
        Scalar mean cross-entropy; a zero tensor for an empty batch.
    """
    compute_device = latent.device
    if latent.size(0) == 0:
        return torch.tensor(0.0, device=compute_device, requires_grad=True)

    positive_indices = torch.as_tensor(positive_indices, device=compute_device).long()
    num_features = clip_text_features.size(0)
    random_indices = torch.randperm(num_features, device=compute_device)[:num_samples]

    # torch.unique returns the candidates sorted, so searchsorted gives the column of each
    # sample's positive within the candidate set.
    candidate_indices = torch.unique(torch.cat([positive_indices, random_indices]))
    targets = torch.searchsorted(candidate_indices, positive_indices)

    # Row-wise normalisation, so normalising only the selected rows is equivalent.
    latent = F.normalize(latent, p=2, dim=1)
    candidates = F.normalize(clip_text_features[candidate_indices], p=2, dim=1)

    logits = torch.matmul(latent, candidates.T) / temperature
    return F.cross_entropy(logits, targets)

In [14]:
def init_losses():
    """Return a fresh accumulator with every tracked loss name set to 0."""
    loss_names = ('total', 'movement_type', 'easing_type', 'reconstruction', 'clip', 'clip_contrastive')
    return dict.fromkeys(loss_names, 0)

def print_detailed_losses(phase, losses):
    """Print one line with the per-component losses of ``phase``, to 4 decimal places."""
    labelled_keys = (
        ("Movement", 'movement_type'),
        ("Easing", 'easing_type'),
        ("Reconstruction", 'reconstruction'),
        ("CLIP", 'clip'),
        ("CLIP Contrastive", 'clip_contrastive'),
    )
    details = ", ".join(f"{label}: {losses[key]:.4f}" for label, key in labelled_keys)
    print(f"{phase} Losses - {details}")

In [15]:
def train_epoch(model, dataloader, optimizer, criterion, all_clip_text_features):
    """Run one optimisation pass over ``dataloader`` and return per-batch-averaged losses.

    ``process_batch`` already returns the combined objective under ``'total'``, and that
    is the value back-propagated. Summing every entry of its dict would count each
    component twice: once on its own and once inside ``'total'``.

    Returns:
        dict with the same keys as ``init_losses()``, each averaged over batches.
    """
    model.train()
    total_losses = init_losses()

    for batch in tqdm(dataloader, desc="Training"):
        losses = process_batch(model, batch, criterion, all_clip_text_features)

        optimizer.zero_grad()
        losses['total'].backward()
        optimizer.step()

        for key in total_losses:
            value = losses[key]
            total_losses[key] += value.item() if torch.is_tensor(value) else value

    return {k: v / len(dataloader) for k, v in total_losses.items()}

def validate(model, dataloader, criterion, all_clip_text_features):
    """Evaluate on ``dataloader`` without gradients and return per-batch-averaged losses.

    NOTE(review): ``process_batch`` draws a fresh random frame mask on every call, so this
    validation loss, which drives checkpointing and early stopping, is itself stochastic.

    Returns:
        dict with the same keys as ``init_losses()``, each averaged over batches.
    """
    model.eval()
    total_losses = init_losses()

    with torch.no_grad():
        for batch in dataloader:
            losses = process_batch(model, batch, criterion, all_clip_text_features)
            for key in total_losses:
                value = losses[key]
                total_losses[key] += value.item() if torch.is_tensor(value) else value

    return {k: v / len(dataloader) for k, v in total_losses.items()}

def process_batch(model, batch, criterion, all_clip_text_features):
    """Run ``model`` on one collated batch and compute every loss term.

    Returns:
        dict with:
        * 'movement_type', 'easing_type': classification losses;
        * 'reconstruction': loss on the flattened trajectory;
        * 'clip': 1 - mean cosine similarity between each latent and its description's
          CLIP feature;
        * 'clip_contrastive': the contrastive loss;
        * 'total': the sum of those five terms.
    """
    camera_trajectory = batch['camera_trajectory'].to(device)
    movement_type = batch['movement_type'].to(device)
    easing_type = batch['easing_type'].to(device)
    positive_indices = batch['positive_indices'].to(device)

    # One keep/drop flag per frame. camera_trajectory is flattened (frames * 7 values),
    # so sizing by shape[1] drew 7x too many flags and relied on the encoder truncating them.
    mask = create_mask(camera_trajectory.shape[0], model.max_seq_length)
    output = model(camera_trajectory, mask)

    losses = {
        'movement_type': criterion['classification'](output['movement_type_logits'], movement_type),
        'easing_type': criterion['classification'](output['easing_type_logits'], easing_type),
        'reconstruction': criterion['reconstruction'](output['reconstructed'], camera_trajectory),
        'clip': 1 - F.cosine_similarity(output['latent'], all_clip_text_features[positive_indices]).mean(),
        'clip_contrastive': contrastive_loss(output['latent'], all_clip_text_features, positive_indices)
    }
    losses['total'] = sum(losses.values())

    return losses

def train_model(model, train_dataloader, val_dataloader, config):
    """Train ``model`` with Adam and early stopping on the total validation loss.

    ``config`` keys read: 'learning_rate', 'num_epochs', 'patience',
    'all_clip_text_features'. Whenever the validation total improves, the weights are
    written to 'best_model.pth'. Training stops once 'patience' consecutive epochs pass
    without improvement.
    """
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    criterion = {
        'classification': nn.CrossEntropyLoss(),
        'reconstruction': nn.MSELoss()
    }

    best_val_loss = float('inf')
    stale_epochs = 0
    num_epochs = config['num_epochs']

    for epoch_number in range(1, num_epochs + 1):
        train_losses = train_epoch(model, train_dataloader, optimizer, criterion, config['all_clip_text_features'])
        val_losses = validate(model, val_dataloader, criterion, config['all_clip_text_features'])

        print(f"Epoch {epoch_number}/{num_epochs}")
        print(f"Train Loss: {train_losses['total']:.4f}, Validation Loss: {val_losses['total']:.4f}")
        print_detailed_losses("Train", train_losses)
        print_detailed_losses("Valid", val_losses)

        improved = val_losses['total'] < best_val_loss
        if not improved:
            stale_epochs += 1
        else:
            best_val_loss = val_losses['total']
            stale_epochs = 0
            torch.save(model.state_dict(), 'best_model.pth')
            print("New best model saved!")

        if stale_epochs >= config['patience']:
            print(f"Early stopping triggered after {epoch_number} epochs")
            break

In [None]:
# Seed torch's global RNG so weight initialisation, DataLoader shuffling and the random
# training masks are reproducible on Restart & Run All. The split is pinned by random_state.
SEED = 42
torch.manual_seed(SEED)

dataset = SimulationDataset('random_simulation_dataset.json')

train_dataset, val_dataset = train_test_split(dataset, test_size=0.2, random_state=SEED)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=batch_collate)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=batch_collate)

frame_features = 7  # position x/y/z, angle x/y/z, focal length per frame
max_seq_length = 30  # frames per trajectory; SimulationDataset keeps only 30-frame simulations
input_dim = max_seq_length * frame_features
d_model = 256
nhead = 8
num_encoder_layers = 3
num_decoder_layers = 3
num_movement_types = len(CameraMovementType)
num_easing_types = len(EasingType)
latent_dim = clip_embedding_dim

model = MultiTaskAutoencoder(input_dim, d_model, nhead, num_encoder_layers, num_decoder_layers,
                             num_movement_types, num_easing_types, max_seq_length, latent_dim).to(device)

In [None]:
# Training hyper-parameters plus the precomputed CLIP description features used as targets.
config = dict(
    num_epochs=100,
    patience=10,
    learning_rate=1e-3,
    all_clip_text_features=all_clip_text_features,
)

train_model(model, train_dataloader, val_dataloader, config)

## Inference

In [None]:
# map_location lets the checkpoint load even if it was saved from a different device,
# e.g. trained on a GPU runtime and reloaded in a CPU-only one.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

In [None]:
# Generate a camera trajectory from a free-text instruction and download it as JSON.
# Another prompt to try: "The camera is panning left at a constant speed, from a low angle in a close-up shot"
input_text = "The camera is moving in a full arc to the left with a bouncing effect that intensifies towards the end in a close-up shot"
file_path = '/content/result.json'

camera_frames = generate_camera_movement(model, input_text)
save_to_json(camera_frames, file_path)
files.download(file_path)

In [None]:
# Round-trip one dataset trajectory through the encoder and decoder, then download the result.
file_path = '/content/result.json'
source_trajectory = dataset[1]['camera_trajectory'].to(device)

camera_frames = reconstruct_camera_movement(model.decoder, source_trajectory)
save_to_json(camera_frames, file_path)
files.download(file_path)