<a href="https://colab.research.google.com/github/ZahraDehghanian97/LensCraft/blob/master/Translator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dataset Generation (Using Templates)

In [1]:

CAMERA_PARAMETERS = {
    "CameraVerticalAngle": {
        "values": ["low", "eye", "high", "overhead", "birdsEye"],
        "synonyms": {
            "low": ["low-angle", "from below", "upward-facing", "worm's eye view", "ground level", "looking up"],
            "eye": ["eye-level", "neutral", "straight-on", "level view", "natural angle", "standard height"],
            "high": ["high-angle", "from above", "downward-facing", "elevated view", "raised perspective", "looking down"],
            "overhead": ["overhead", "from directly above", "top-down", "ceiling view", "direct overhead", "vertical down"],
            "birdsEye": ["bird's eye", "aerial", "far overhead", "extreme overhead", "elevated overhead", "sky view"]
        }
    },
    "ShotSize": {
        "values": [
            "extremeCloseUp",
            "closeUp",
            "mediumCloseUp",
            "mediumShot",
            "fullShot",
            "longShot",
            "veryLongShot",
            "extremeLongShot",
        ],
        "synonyms": {
            "extremeCloseUp": ["extreme close-up", "macro shot", "detail view", "intimate detail", "super close-up", "microscopic view"],
            "closeUp": ["close-up", "tight shot", "near view", "facial shot", "intimate frame", "detailed view"],
            "mediumCloseUp": ["medium close-up", "head and shoulders", "bust shot", "upper body frame", "shoulder shot", "partial upper body"],
            "mediumShot": ["medium shot", "mid-shot", "waist shot", "half body", "waist-up view", "mid-frame"],
            "fullShot": ["full shot", "full body", "head to toe", "complete view", "entire figure", "full frame"],
            "longShot": ["long shot", "wide shot", "full view", "environmental view", "contextual shot", "scene-setting shot"],
            "veryLongShot": ["very long shot", "very wide shot", "establishing shot", "master shot", "broad view", "expansive shot"],
            "extremeLongShot": ["extreme long shot", "extreme wide", "panoramic view", "vista shot", "grand view", "epic scale"]
        }
    },
    "MovementSpeed": {
        "values": [
            "slowToFast",
            "fastToSlow",
            "constant",
            "stopAndGo",
            "deliberateStartStop",
            "erratic",
            "pulsing"
        ],
        "synonyms": {
            "slowToFast": ["gradually accelerating", "increasing speed", "picking up pace", "building momentum", "ramping up", "progressive acceleration"],
            "fastToSlow": ["gradually decelerating", "decreasing speed", "slowing down", "easing to stop", "winding down", "tapering speed"],
            "constant": ["steady", "uniform", "consistent speed", "unchanging pace", "maintained velocity", "even movement"],
            "stopAndGo": ["intermittent", "start and stop", "punctuated movement", "staccato motion", "interrupted flow", "periodic pauses"],
            "deliberateStartStop": ["measured pauses", "intentional stops", "rhythmic stopping", "choreographed pauses", "planned breaks", "controlled stops"],
            "erratic": ["unpredictable", "varying", "irregular", "random speeds", "sporadic", "changeable pace"],
            "pulsing": ["rhythmic", "beating", "oscillating", "cyclic motion", "wave-like", "periodic"]
        }
    },
    "SubjectInFramePosition": {
        "values": [
            "left", "right", "top", "bottom", "center",
            "topLeft", "topRight", "bottomLeft", "bottomRight",
            "outerLeft", "outerRight", "outerTop", "outerBottom",
            "offsetCenter",
            "diagonal"
        ],
        "synonyms": {
            "left": ["on the left", "left side", "left portion", "leftward", "port side", "left zone"],
            "right": ["on the right", "right side", "right portion", "rightward", "starboard side", "right zone"],
            "top": ["at the top", "upper portion", "top area", "upper zone", "superior position", "high position"],
            "bottom": ["at the bottom", "lower portion", "bottom area", "lower zone", "inferior position", "low position"],
            "center": ["in the center", "middle", "central position", "dead center", "bull's eye", "epicenter"],
            "topLeft": ["upper left", "top left corner", "northwest position", "high left", "upper port", "northeast corner"],
            "topRight": ["upper right", "top right corner", "northeast position", "high right", "upper starboard", "northwest corner"],
            "bottomLeft": ["lower left", "bottom left corner", "southwest position", "low left", "lower port", "southwest corner"],
            "bottomRight": ["lower right", "bottom right corner", "southeast position", "low right", "lower starboard", "southeast corner"],
            "outerLeft": ["far left", "leftmost edge", "extreme left", "peripheral left", "margin left", "border left"],
            "outerRight": ["far right", "rightmost edge", "extreme right", "peripheral right", "margin right", "border right"],
            "outerTop": ["very top", "topmost edge", "extreme top", "peripheral top", "margin top", "border top"],
            "outerBottom": ["very bottom", "bottommost edge", "extreme bottom", "peripheral bottom", "margin bottom", "border bottom"],
            "offsetCenter": ["slightly off-center", "near center", "just off middle", "asymmetric center", "shifted center", "biased center"],
            "diagonal": ["diagonal position", "angular placement", "oblique position", "slanted position", "diagonal offset", "cross-frame"]
        }
    },
    "SubjectView": {
        "values": [
            "front",
            "back",
            "left",
            "right",
            "threeQuarterFrontLeft",
            "threeQuarterFrontRight",
            "threeQuarterBackLeft",
            "threeQuarterBackRight",
            "overhead",
            "silhouette"
        ],
        "synonyms": {
            "front": ["front view", "facing camera", "direct front", "head-on", "straight ahead", "forward facing"],
            "back": ["back view", "from behind", "rear view", "posterior view", "reverse angle", "backing"],
            "left": ["left side", "left profile", "from the left", "port side view", "left aspect", "sinistral view"],
            "right": ["right side", "right profile", "from the right", "starboard side view", "right aspect", "dextral view"],
            "threeQuarterFrontLeft": ["angled front left", "partial front left", "diagonal front left", "oblique front left", "left forward angle", "left front perspective"],
            "threeQuarterFrontRight": ["angled front right", "partial front right", "diagonal front right", "oblique front right", "right forward angle", "right front perspective"],
            "threeQuarterBackLeft": ["angled back left", "partial back left", "diagonal back left", "oblique back left", "left rear angle", "left back perspective"],
            "threeQuarterBackRight": ["angled back right", "partial back right", "diagonal back right", "oblique back right", "right rear angle", "right back perspective"],
            "overhead": ["from above", "top view", "bird's perspective", "downward view", "superior view", "zenith angle"],
            "silhouette": ["shadow form", "outlined shape", "backlit profile", "contour view", "rim lit", "shape outline"]
        }
    },
    "CameraMovementType": {
        "values": [
            "static", "panLeft", "panRight", "tiltUp", "tiltDown",
            "dollyIn", "dollyOut", "truckLeft", "truckRight",
            "pedestalUp", "pedestalDown", "arcLeft", "arcRight",
            "craneUp", "craneDown", "dollyOutZoomIn", "dollyInZoomOut",
            "dutchLeft", "dutchRight", "follow",
            "spiral",
            "snakeTrack",
            "boomerang"
        ],
        "synonyms": {
            "static": ["stationary", "fixed", "still", "locked off", "immobile", "stable"],
            "panLeft": ["pan left", "sweep left", "rotate left", "horizontal left", "left scan", "leftward pan"],
            "panRight": ["pan right", "sweep right", "rotate right", "horizontal right", "right scan", "rightward pan"],
            "tiltUp": ["tilt upward", "look up", "angle up", "vertical up", "upward pivot", "ascend view"],
            "tiltDown": ["tilt downward", "look down", "angle down", "vertical down", "downward pivot", "descend view"],
            "dollyIn": ["move forward", "push in", "track forward", "advance", "forward track", "approach"],
            "dollyOut": ["move backward", "pull out", "track backward", "retreat", "backward track", "withdraw"],
            "truckLeft": ["slide left", "track left", "lateral left", "crab left", "sideways left", "parallel left"],
            "truckRight": ["slide right", "track right", "lateral right", "crab right", "sideways right", "parallel right"],
            "pedestalUp": ["raise up", "elevate", "lift up", "vertical rise", "upward boost", "ascend"],
            "pedestalDown": ["lower down", "descend", "move down", "vertical drop", "downward sink", "descend"],
            "arcLeft": ["curve left", "orbit left", "circular left", "left circle", "rounded left", "left orbit"],
            "arcRight": ["curve right", "orbit right", "circular right", "right circle", "rounded right", "right orbit"],
            "craneUp": ["boom up", "jib up", "rise up", "sweep up", "ascending arc", "upward boom"],
            "craneDown": ["boom down", "jib down", "lower", "sweep down", "descending arc", "downward boom"],
            "dollyOutZoomIn": ["pull back and zoom", "compensating pullback", "contra-zoom out", "reverse dolly zoom", "backward zoom", "vertigo effect"],
            "dollyInZoomOut": ["push in and zoom out", "compensating push", "contra-zoom in", "forward dolly zoom", "forward zoom", "inverse vertigo"],
            "dutchLeft": ["roll left", "tilt left", "diagonal left", "left rotation", "left cant", "oblique left"],
            "dutchRight": ["roll right", "tilt right", "diagonal right", "right rotation", "right cant", "oblique right"],
            "follow": ["track subject", "maintain follow", "chase movement", "pursuit shot", "accompany motion", "shadow movement"],
            "spiral": ["helical movement", "corkscrew motion", "spiral track", "circular descent", "winding path", "coil movement"],
            "snakeTrack": ["serpentine movement", "winding track", "meandering motion", "curved path", "s-curve movement", "flowing track"],
            "boomerang": ["return movement", "back-and-forth", "pendulum motion", "swing track", "reversing path", "loop movement"]
        }
    }
}


##Generate a Dataset with Random Number Parameter

In [2]:
import random
import json

# Expanded template components with more natural variations
TEMPLATE_COMPONENTS = {
    "angle": [
        "from a {angle} angle",
        "with a {angle} perspective",
        "maintaining a {angle} view",
        "using a {angle} vantage point",
        "positioned at a {angle} level",
        "set up with a {angle} viewpoint",
        "utilizing a {angle} camera position",
        "with the camera {angle}",
        "capturing from {angle}",
        "at a {angle} height"
    ],
    "shot_size": [
        "capture a {shot_size}",
        "frame a {shot_size}",
        "execute a {shot_size}",
        "create a {shot_size} shot",
        "compose a {shot_size}",
        "establish a {shot_size}",
        "set up a {shot_size}",
        "design a {shot_size} composition",
        "deliver a {shot_size}",
        "aim for a {shot_size}"
    ],
    "movement_type":[
        "as the camera {movement_type}",
        "with a {movement_type} movement",
        "using a {movement_type} motion",
        "implementing a {movement_type}",
        "executing a {movement_type}",
        "performing a {movement_type}",
        "following through with a {movement_type}",
        "{movement_type} the camera {movement_speed}",
        "in a {movement_type} pattern",
        "with camera movement {movement_type}"
    ],
    "movement_speed": [
        "with speed {movement_speed}",
        "at {movement_speed}",
        "with a {movement_speed} pace",
        "{movement_type} the camera",
        "moving {movement_speed}",
        "{movement_speed}",
        "with a {movement_speed} speed",
        "at a {movement_speed} rate"
    ],
    "frame_position": [
        "with the subject positioned {frame_position}",
        "keeping the subject {frame_position} in frame",
        "placing the subject {frame_position}",
        "maintaining the subject {frame_position}",
        "featuring the subject {frame_position}",
        "with subject placement {frame_position}",
        "composing the subject {frame_position}",
        "arranging the subject {frame_position}",
        "positioning our focus {frame_position}",
        "with the main element {frame_position}"
    ],
    "subject_view": [
        "showing their {subject_view}",
        "capturing their {subject_view}",
        "emphasizing their {subject_view}",
        "highlighting their {subject_view}",
        "revealing their {subject_view}",
        "displaying the {subject_view}",
        "featuring their {subject_view}",
        "presenting the {subject_view}",
        "focusing on their {subject_view}",
        "accentuating the {subject_view}"
    ]
}

# Opening phrases to add variety
OPENING_PHRASES = [
    "The shot requires",
    "Set up",
    "The scene calls for",
    "We need",
    "This shot demands",
    "Let's capture",
    "Plan to get",
    "The sequence needs",
    "We're looking for",
    "Arrange",
    "Position the camera to",
    "The frame should",
    "We want to",
    "The goal is to",
    "Focus on"
]

# Connecting phrases for more natural flow
CONNECTING_PHRASES = [
    "while",
    "as",
    "and",
    ", then",
    ". Also,",
    ". Meanwhile,",
    ", with",
    ". At the same time,",
    ". Additionally,",
    ", making sure to"
]

def generate_dynamic_template(num_params):
    """Generate a more natural template with a specific number of parameters"""
    param_mapping = {
        "CameraVerticalAngle": "angle",
        "ShotSize": "shot_size",
        "MovementSpeed": "movement_speed",
        "CameraMovementType": "movement_type",
        "SubjectInFramePosition": "frame_position",
        "SubjectView": "subject_view"
    }

    # Select random parameters
    available_params = list(set(param_mapping.keys()))
    selected_params = random.sample(available_params, min(num_params, len(available_params)))

    # Build template with more natural language structure
    template_parts = []
    used_components = set()

    # Randomly decide whether to use an opening phrase
    if random.random() < 0.7:  # 70% chance to use opening phrase
        template_parts.append(random.choice(OPENING_PHRASES))

    # Generate component parts
    component_parts = []
    for param in selected_params:
        component_key = param_mapping[param]
        if component_key not in used_components:
            component_parts.append(random.choice(TEMPLATE_COMPONENTS[component_key]))
            used_components.add(component_key)

    # Randomly arrange components with connecting phrases
    while component_parts:
        template_parts.append(component_parts.pop(random.randint(0, len(component_parts)-1)))
        if component_parts and random.random() < 0.7:  # 70% chance to add connector
            template_parts.append(random.choice(CONNECTING_PHRASES))

    # Join all parts and clean up any double spaces or awkward punctuation
    template = " ".join(template_parts)
    template = template.replace(" ,", ",")
    template = template.replace("  ", " ")
    template = template.strip()

    # Ensure proper ending punctuation
    if not template.endswith((".","!")):
        template += "."
    return template, selected_params

def generate_prompt(num_params):
    """Generate a single prompt with specified number of parameters"""
    template, selected_params = generate_dynamic_template(num_params)

    params = {}
    original_params = {}

    # Generate values only for selected parameters
    for param_name in selected_params:
        value = random.choice(CAMERA_PARAMETERS[param_name]["values"])
        synonym = random.choice(CAMERA_PARAMETERS[param_name]["synonyms"][value])

        params[param_name.lower()] = synonym
        original_params[param_name] = value

    # Prepare template parameters
    template_params = {
        "angle": params.get("cameraverticalangle", ""),
        "shot_size": params.get("shotsize", ""),
        "movement_type": params.get("cameramovementtype", ""),
        "movement_speed": params.get("movementspeed", ""),
        "frame_position": params.get("subjectinframeposition", ""),
        "subject_view": params.get("subjectview", "")
    }

    try:
        prompt = template.format(**template_params)
        # Clean up any potential double spaces
        prompt = " ".join(prompt.split())
    except KeyError as e:
        print(f"Template error: {e}")
        print(f"Template: {template}")
        print(f"Params: {template_params}")
        return None

    return {
        "prompt": prompt,
        "parameters": original_params,
        "template": template
    }

def generate_dataset(num_samples=10000):
    """Generate a dataset of prompts with varying numbers of parameters"""
    dataset = []
    for _ in range(num_samples):
        num_params = random.randint(2, 6)
        entry = generate_prompt(num_params)
        if entry:
            dataset.append(entry)
    return dataset

# Generate and save the dataset
if __name__ == "__main__":
    dataset = generate_dataset()

    with open("generated_variable_camera_prompts.json", "w") as f:
        json.dump(dataset, f, indent=2)

    print("\nExample generated prompts:")
    for i in range(min(3, len(dataset))):
        print(f"\nPrompt {i+1}:")
        print("Text:", dataset[i]["prompt"])
        print("Parameters:", dataset[i]["parameters"])
        print("Template:", dataset[i]["template"])


Example generated prompts:

Prompt 1:
Text: with a decreasing speed pace . Additionally, upward boom the camera decreasing speed and using a high-angle vantage point presenting the partial back left with the subject positioned topmost edge.
Parameters: {'CameraVerticalAngle': 'high', 'MovementSpeed': 'fastToSlow', 'CameraMovementType': 'craneUp', 'SubjectView': 'threeQuarterBackLeft', 'SubjectInFramePosition': 'outerTop'}
Template: with a {movement_speed} pace . Additionally, {movement_type} the camera {movement_speed} and using a {angle} vantage point presenting the {subject_view} with the subject positioned {frame_position}.

Prompt 2:
Text: Plan to get s-curve movement the camera maintained velocity s-curve movement the camera, making sure to compose a master shot highlighting their left profile with the camera neutral, with maintaining the subject border right.
Parameters: {'MovementSpeed': 'constant', 'CameraVerticalAngle': 'eye', 'CameraMovementType': 'snakeTrack', 'ShotSize': '

In [3]:
dataset[900]

{'prompt': 'We need implementing a sweep up as positioning our focus extreme right, making sure to featuring their diagonal back left design a super close-up composition at a irregular rate as positioned at a top-down level.',
 'parameters': {'CameraMovementType': 'craneUp',
  'MovementSpeed': 'erratic',
  'CameraVerticalAngle': 'overhead',
  'SubjectInFramePosition': 'outerRight',
  'ShotSize': 'extremeCloseUp',
  'SubjectView': 'threeQuarterBackLeft'},
 'template': 'We need implementing a {movement_type} as positioning our focus {frame_position}, making sure to featuring their {subject_view} design a {shot_size} composition at a {movement_speed} rate as positioned at a {angle} level.'}

In [4]:
import json

def analyze_template_parameters(dataset_path):
    """
    Load a dataset generated by the template-based system and create binary arrays
    indicating which parameters are present in each entry.

    Args:
        dataset_path (str): Path to the JSON dataset file

    Returns:
        list: List of binary arrays where 1 indicates parameter presence and 0 indicates absence
    """
    # Define the ordered list of all possible parameters
    all_parameters = [
        "CameraVerticalAngle",
        "ShotSize",
        "MovementSpeed",
        "SubjectInFramePosition",
        "SubjectView",
        "CameraMovementType"
    ]

    # Load the dataset
    with open(dataset_path, 'r') as f:
        dataset = json.load(f)

    # Initialize the result list
    binary_representations = []

    # Process each entry in the dataset
    for entry in dataset:
        # Get the parameters present in this entry
        present_parameters = entry['parameters'].keys()
        # Create binary array for this entry
        binary_array = [1 if param in present_parameters else 0 for param in all_parameters]
        binary_representations.append(binary_array)

    return binary_representations

# Example usage and verification
if __name__ == "__main__":
    # Analyze the dataset
    result = analyze_template_parameters("generated_variable_camera_prompts.json")

    # Print results with parameter names for verification
    all_parameters = [
        "CameraVerticalAngle",
        "ShotSize",
        "MovementSpeed",
        "SubjectInFramePosition",
        "SubjectView",
        "CameraMovementType"
    ]

    print(f"\nAnalyzed {len(result)} entries")
    print("\nExample entries:")

    # Print first 3 entries with detailed information
    mask = []
    for i, binary_array in enumerate(result[:3]):
        print(f"\nEntry {i + 1}:")
        print(len(dataset[i]['parameters']))
        print("Parameter Presence:")
        for param, present in zip(all_parameters, binary_array):
            status = "Present" if present == 1 else "Absent"
            print(f"{param}: {status}")
        print(f"Binary representation: {binary_array}")
        mask.append(binary_array)


    # Print some statistics
    param_counts = [sum(x) for x in zip(*result)]
    print("\nParameter usage statistics:")
    for param, count in zip(all_parameters, param_counts):
        percentage = (count / len(result)) * 100
        print(f"{param}: used in {count} entries ({percentage:.1f}%)")



Analyzed 10000 entries

Example entries:

Entry 1:
5
Parameter Presence:
CameraVerticalAngle: Present
ShotSize: Absent
MovementSpeed: Present
SubjectInFramePosition: Present
SubjectView: Present
CameraMovementType: Present
Binary representation: [1, 0, 1, 1, 1, 1]

Entry 2:
6
Parameter Presence:
CameraVerticalAngle: Present
ShotSize: Present
MovementSpeed: Present
SubjectInFramePosition: Present
SubjectView: Present
CameraMovementType: Present
Binary representation: [1, 1, 1, 1, 1, 1]

Entry 3:
2
Parameter Presence:
CameraVerticalAngle: Absent
ShotSize: Present
MovementSpeed: Absent
SubjectInFramePosition: Absent
SubjectView: Present
CameraMovementType: Absent
Binary representation: [0, 1, 0, 0, 1, 0]

Parameter usage statistics:
CameraVerticalAngle: used in 6626 entries (66.3%)
ShotSize: used in 6660 entries (66.6%)
MovementSpeed: used in 6757 entries (67.6%)
SubjectInFramePosition: used in 6685 entries (66.8%)
SubjectView: used in 6668 entries (66.7%)
CameraMovementType: used in 665

##Training (BERT)

In [5]:
# Define all possible parameter keys and values
parameter_keys = {
    "CameraVerticalAngle": ["low", "eye", "high", "overhead", "birdsEye", "Not_Specified"],
    "ShotSize": [
        "extremeCloseUp",
            "closeUp",
            "mediumCloseUp",
            "mediumShot",
            "fullShot",
            "longShot",
            "veryLongShot",
            "extremeLongShot"
       , "Not_Specified"],
    "MovementSpeed": [
        "slowToFast",
            "fastToSlow",
            "constant",
            "stopAndGo",
            "deliberateStartStop",
            "erratic",
            "pulsing",
        "Not_Specified"],
    "SubjectInFramePosition": [
         "left", "right", "top", "bottom", "center",
            "topLeft", "topRight", "bottomLeft", "bottomRight",
            "outerLeft", "outerRight", "outerTop", "outerBottom",
            "offsetCenter",
            "diagonal",
        "Not_Specified"],
    "SubjectView": [
        "front",
            "back",
            "left",
            "right",
            "threeQuarterFrontLeft",
            "threeQuarterFrontRight",
            "threeQuarterBackLeft",
            "threeQuarterBackRight",
            "overhead",
            "silhouette",
        "Not_Specified"],
    "CameraMovementType": [
        "static", "panLeft", "panRight", "tiltUp", "tiltDown",
            "dollyIn", "dollyOut", "truckLeft", "truckRight",
            "pedestalUp", "pedestalDown", "arcLeft", "arcRight",
            "craneUp", "craneDown", "dollyOutZoomIn", "dollyInZoomOut",
            "dutchLeft", "dutchRight", "follow",
            "spiral",
            "snakeTrack",
            "boomerang",
        "Not_Specified"]
}


In [6]:
class ParameterEncoder:
    def __init__(self, parameter_keys):
        self.parameter_keys = parameter_keys
        self.num_classes = sum(len(values) for values in parameter_keys.values())

    def encode(self, parameters):
        encoded = []
        for key, values in self.parameter_keys.items():
            class_index = values.index(parameters.get(key, "Not_Specified"))
            encoded.append((key, class_index))  # Store (key, class_index) pairs
        return encoded

    def decode(self, encoded_labels):
        decoded_params = {}
        for key, class_index in encoded_labels:
            decoded_params[key] = self.parameter_keys[key][class_index]
        return decoded_params

In [7]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
import json
import numpy as np

# Load the dataset
with open("generated_variable_camera_prompts.json", "r") as f:
    dataset = json.load(f)

# Extract prompts and parameters
prompts = [item["prompt"] for item in dataset]
parameters = [item["parameters"] for item in dataset]

# Instantiate the encoder
encoder = ParameterEncoder(parameter_keys)

# Encode parameters using the new encoder
encoded_parameters = [encoder.encode(p) for p in parameters]
print(len(encoded_parameters))

# Train/test split
X_train, X_test, y_train, y_test = train_test_split(prompts, encoded_parameters, test_size=0.2, random_state=42)

10000


In [8]:
y_train[0]

[('CameraVerticalAngle', 5),
 ('ShotSize', 0),
 ('MovementSpeed', 0),
 ('SubjectInFramePosition', 15),
 ('SubjectView', 10),
 ('CameraMovementType', 12)]

In [12]:
class CameraDataset(Dataset):
    def __init__(self, prompts, labels, tokenizer, parameter_keys, max_length=128):
        self.prompts = prompts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.parameter_keys = parameter_keys  # Pass parameter_keys to the dataset

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        prompt = self.prompts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(prompt, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")

        # Create a one-hot encoded label vector for each parameter
        label_vector = []
        for key, values in self.parameter_keys.items():
            one_hot = [0] * len(values)
            for param_key, class_index in label:
                if param_key == key:
                    one_hot[class_index] = 1
                    break  # Move to the next parameter after finding a match
            label_vector.extend(one_hot)

        return encoding["input_ids"].squeeze(0), encoding["attention_mask"].squeeze(0), torch.tensor(label_vector, dtype=torch.float)


# Initialize tokenizer and model
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")

# Create datasets and dataloaders
train_dataset = CameraDataset(X_train, y_train, tokenizer, parameter_keys)
test_dataset = CameraDataset(X_test, y_test, tokenizer, parameter_keys)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16)

# Define the model
class CameraPredictor(nn.Module):
    def __init__(self, bert_model, parameter_keys):
        super(CameraPredictor, self).__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(0.3)
        self.feature_classifiers = nn.ModuleDict()  # Dictionary to hold feature-specific MLPs

        for key, values in parameter_keys.items():
            self.feature_classifiers[key] = nn.Sequential(
                nn.Linear(bert_model.config.hidden_size, 64),  # Hidden layer
                nn.ReLU(),
                nn.Linear(64, len(values))  # Output layer
            )

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)

        feature_outputs = {}
        for key, classifier in self.feature_classifiers.items():
            feature_outputs[key] = torch.softmax(classifier(x), dim=1)  # Apply softmax for single-value prediction
        return feature_outputs

# Instantiate the model
model = CameraPredictor(bert_model, parameter_keys)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss and Optimizer (modified for multi-task learning)
criterion = nn.CrossEntropyLoss()  # Use CrossEntropyLoss for single class
optimizer = optim.AdamW(model.parameters(), lr=1e-4)


def calculate_accuracy(outputs, labels):
    total_correct = 0
    total_predictions = 0

    for key in outputs:
        feature_outputs = outputs[key]
        start_index = sum(len(parameter_keys[k]) for k in parameter_keys if k < key)
        end_index = start_index + len(parameter_keys[key])
        feature_labels = labels[:, start_index:end_index]

        _, predicted_indices = torch.max(feature_outputs, dim=1)  # Get predicted indices
        _, true_indices = torch.max(feature_labels, dim=1)  # Get true indices

        correct_predictions = (predicted_indices == true_indices).float()
        total_correct += correct_predictions.sum().item()
        total_predictions += correct_predictions.numel()

    accuracy = total_correct / total_predictions if total_predictions else 0
    return accuracy

def calculate_loss(outputs, labels):
    total_loss = 0
    for key, values in parameter_keys.items():
        start_index = sum(len(parameter_keys[k]) for k in parameter_keys if k < key)
        end_index = start_index + len(values)
        feature_labels = labels[:, start_index:end_index].to(device)  # Extract labels for the feature
        feature_outputs = outputs[key]  # Get output for the feature
        loss = criterion(feature_outputs, feature_labels)
        total_loss += loss
    return total_loss

def train_model(model, train_loader, test_loader, epochs=5):
    best_val_accuracy = 0
    best_model_state_dict = None  # Store the state dict of the best model

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        train_accuracy = 0
        num_train_batches = 0

        for input_ids, attention_mask, labels in train_loader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            total_loss = calculate_loss(outputs, labels)
            total_loss.backward()
            optimizer.step()

            # Calculate training accuracy
            train_accuracy += calculate_accuracy(outputs, labels)
            train_loss += total_loss.item()
            num_train_batches += 1

        avg_train_loss = train_loss / num_train_batches
        avg_train_accuracy = train_accuracy / num_train_batches

        # Evaluate
        model.eval()
        test_loss = 0
        test_accuracy = 0
        num_test_batches = 0

        with torch.no_grad():
            for input_ids, attention_mask, labels in test_loader:
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)
                outputs = model(input_ids, attention_mask)
                total_loss = calculate_loss(outputs, labels)

                # Calculate validation accuracy
                test_accuracy += calculate_accuracy(outputs, labels)
                test_loss += total_loss.item()
                num_test_batches += 1

        avg_test_loss = test_loss / num_test_batches
        avg_test_accuracy = test_accuracy / num_test_batches

        # Track best validation accuracy
        if avg_test_accuracy > best_val_accuracy:
            best_val_accuracy = avg_test_accuracy
            best_model_state_dict = model.state_dict()  # Save the model's state dict


        print(f"Epoch {epoch+1}")
        print(f"Training Loss: {avg_train_loss:.4f}, Training Accuracy: {avg_train_accuracy:.4f}")
        print(f"Validation Loss: {avg_test_loss:.4f}, Validation Accuracy: {avg_test_accuracy:.4f}")
        print("-" * 50)

    # Load the best model after training
    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)  # Load the best model's state dict
        print("Loaded best model with accuracy:", best_val_accuracy)

# Train the model and evaluate
train_model(model, train_loader, test_loader, epochs=2)


Epoch 1
Training Loss: 15.8146, Training Accuracy: 0.2743
Validation Loss: 15.7807, Validation Accuracy: 0.2857
--------------------------------------------------
Epoch 2
Training Loss: 15.7880, Training Accuracy: 0.2881
Validation Loss: 15.7820, Validation Accuracy: 0.2857
--------------------------------------------------
Loaded best model with accuracy: 0.28566666666666674


In [50]:
import torch
import numpy as np
from torch.nn import functional as F

def predict(prompts, labels, model, parameter_keys, encoder, return_raw=False):

    model.eval()
    all_predicted_params = []
    all_raw_outputs = []
    total_loss = 0
    total_accuracy = 0
    num_prompts = 0

    with torch.no_grad():
      for i, prompt in enumerate(prompts):
        encoded_label = encoder.encode(labels[i])

        encoding = tokenizer(prompt, max_length=128, padding="max_length", truncation=True, return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        outputs = model(input_ids, attention_mask)
        raw_outputs = {key: output.cpu().numpy() for key, output in outputs.items()}
        if return_raw:
            all_raw_outputs.append(raw_outputs)
        else:
            predicted_params = {}
            start_idx = 0
            for key, values in parameter_keys.items():
                predicted_value_idx = np.argmax(raw_outputs[key][0])
                predicted_params[key] = values[predicted_value_idx]
            all_predicted_params.append(predicted_params)

        # Prepare the label_vector in one-hot encoded format from the encoder
        label_vector = []
        for key, values in parameter_keys.items():
            one_hot = [0] * len(values)
            for param_key, class_index in encoded_label:
                if key == param_key:
                    one_hot[class_index] = 1
                else:
                    one_hot[values.index("Not_Specified")] = 1
            label_vector.extend(one_hot)

        label_tensor = torch.tensor(label_vector, dtype=torch.float).to(device)

        # Compute loss and accuracy using the provided functions
        batch_loss = calculate_loss(outputs, torch.unsqueeze(label_tensor,0))
        batch_accuracy = calculate_accuracy(outputs, torch.unsqueeze(label_tensor,0))
        total_loss += batch_loss.item()
        total_accuracy += batch_accuracy
        num_prompts +=1

    avg_loss = total_loss / num_prompts if num_prompts > 0 else 0
    avg_accuracy = total_accuracy / num_prompts if num_prompts > 0 else 0

    return all_predicted_params if not return_raw else None, all_raw_outputs, avg_accuracy, avg_loss

In [51]:
# Assuming model, tokenizer, parameter_keys, device, encoder, calculate_loss and calculate_accuracy are already defined
prompts_to_predict = [
    "The shot requires a low-angle perspective and a close-up.",
    "Set up a medium shot with the subject on the right.",
    "The scene calls for the camera to pan left and use a full shot.",
    "We need a very long shot of the subject with a pulsing speed."
]

# Dummy labels for demonstration, replace with your actual labels
labels_to_predict = [
    {"CameraVerticalAngle": "low", "ShotSize": "closeUp"},
    {"ShotSize": "mediumShot", "SubjectInFramePosition": "right"},
    {"CameraMovementType": "panLeft", "ShotSize": "fullShot"},
    {"ShotSize": "veryLongShot", "MovementSpeed": "pulsing"}
]


predicted_params, raw_outputs, accuracy, loss = predict(prompts_to_predict, labels_to_predict, model, parameter_keys, encoder)

if predicted_params:
    for i, prompt in enumerate(prompts_to_predict):
        print(f"Prompt: {prompt}")
        print(f"Predicted Parameters: {predicted_params[i]}")

print(f"Overall Accuracy: {accuracy:.4f}")
print(f"Overall Loss: {loss:.4f}")

Prompt: The shot requires a low-angle perspective and a close-up.
Predicted Parameters: {'CameraVerticalAngle': 'high', 'ShotSize': 'extremeCloseUp', 'MovementSpeed': 'fastToSlow', 'SubjectInFramePosition': 'top', 'SubjectView': 'Not_Specified', 'CameraMovementType': 'dollyIn'}
Prompt: Set up a medium shot with the subject on the right.
Predicted Parameters: {'CameraVerticalAngle': 'high', 'ShotSize': 'extremeCloseUp', 'MovementSpeed': 'fastToSlow', 'SubjectInFramePosition': 'top', 'SubjectView': 'Not_Specified', 'CameraMovementType': 'dollyIn'}
Prompt: The scene calls for the camera to pan left and use a full shot.
Predicted Parameters: {'CameraVerticalAngle': 'high', 'ShotSize': 'extremeCloseUp', 'MovementSpeed': 'fastToSlow', 'SubjectInFramePosition': 'top', 'SubjectView': 'Not_Specified', 'CameraMovementType': 'dollyIn'}
Prompt: We need a very long shot of the subject with a pulsing speed.
Predicted Parameters: {'CameraVerticalAngle': 'high', 'ShotSize': 'extremeCloseUp', 'Movement

Evaluation on LLM generated Prompt

In [None]:
# Load the dataset
with open("focused_camera_prompts_dataset.json", "r") as f:
    dataset = json.load(f)

# Extract prompts and parameters
prompts = [item["prompt"] for item in dataset]
parameters = [item["parameters"] for item in dataset]

def standardize_parameters(params):
    """
    Add 'Not_Specified' for any missing parameters from the set of 6 standard parameters.
    """
    all_parameters = [
        "CameraVerticalAngle",
        "ShotSize",
        "MovementSpeed",
        "SubjectInFramePosition",
        "SubjectView",
        "CameraMovementType"
    ]

    standardized_params = params.copy()
    for param in all_parameters:
        if param not in standardized_params:
            standardized_params[param] = "Not_Specified"

    return standardized_params
parameters = np.array([standardize_parameters(p) for p in parameters])


In [None]:
for i in range(len(prompts)):
    predicted = predict( prompts[i])
    print("Prompt:",  prompts[i])
    print("Predicted Parameters:", predicted)
    print("Actual Parameters:", parameters[i])
    print(" \n")

Prompt: The camera begins with a low vertical angle, capturing a full shot of the scene. As it moves swiftly, it glides towards the subject positioned at the outer top of the frame. The rapid pace creates an exhilarating feeling, but as it approaches, the movement slows, allowing viewers to take in the details. The framing emphasizes the subject while revealing the context surrounding them, blending action and ambiance seamlessly.
Predicted Parameters: {'CameraVerticalAngle': 'low', 'ShotSize': 'fullShot', 'MovementSpeed': 'Not_Specified', 'SubjectInFramePosition': 'top', 'SubjectView': 'Not_Specified', 'CameraMovementType': 'Not_Specified'}
Actual Parameters: {'CameraVerticalAngle': 'low', 'ShotSize': 'fullShot', 'MovementSpeed': 'fastToSlow', 'SubjectInFramePosition': 'outerTop', 'SubjectView': 'Not_Specified', 'CameraMovementType': 'Not_Specified'}
 

Prompt: The camera hovers overhead, capturing an extreme close-up of the subject nestled in the bottom right of the frame. As it foll

In [None]:
def compute_accuracy(prompts, parameters, predict_fn):
    """
    Compute accuracy for parameter predictions.
    """
    all_parameters = [
        "CameraVerticalAngle",
        "ShotSize",
        "MovementSpeed",
        "SubjectInFramePosition",
        "SubjectView",
        "CameraMovementType"
    ]

    # Initialize counters for each parameter
    param_correct = {param: 0 for param in all_parameters}
    param_total = {param: 0 for param in all_parameters}
    total_correct = 0
    total_predictions = 0

    # Track mismatches for error analysis
    mismatches = []

    for i in range(len(prompts)):
        example_prompt = prompts[i]
        predicted = predict_fn(example_prompt)
        actual = standardize_parameters(parameters[i])

        # Track mismatches for this prompt
        prompt_mismatches = []

        # Check each parameter
        for param in all_parameters:
            param_total[param] += 1
            total_predictions += 1

            pred_value = predicted.get(param, "Not_Specified")
            actual_value = actual.get(param, "Not_Specified")

            if pred_value == actual_value:
                param_correct[param] += 1
                total_correct += 1
            else:
                prompt_mismatches.append({
                    'parameter': param,
                    'predicted': pred_value,
                    'actual': actual_value
                })

        if prompt_mismatches:
            mismatches.append({
                'prompt': example_prompt,
                'mismatches': prompt_mismatches
            })

    # Calculate accuracies
    overall_accuracy = total_correct / total_predictions
    param_accuracies = {param: param_correct[param] / param_total[param]
                       for param in all_parameters}

    # Print results
    print("\nOverall Accuracy:", f"{overall_accuracy:.4f}")
    print("\nPer-Parameter Accuracies:")
    for param, acc in param_accuracies.items():
        print(f"{param}: {acc:.4f}")

    # Print example mismatches
    print("\nExample Mismatches (first 5):")
    for mismatch in mismatches[:5]:
        print(f"\nPrompt: {mismatch['prompt']}")
        print("Mismatched Parameters:")
        for error in mismatch['mismatches']:
            print(f"- {error['parameter']}: predicted '{error['predicted']}' instead of '{error['actual']}'")

    return {
        'overall_accuracy': overall_accuracy,
        'parameter_accuracies': param_accuracies,
        'mismatches': mismatches
    }

# Use the function
results = compute_accuracy(prompts, parameters, predict)


Overall Accuracy: 0.5250

Per-Parameter Accuracies:
CameraVerticalAngle: 0.5500
ShotSize: 0.6800
MovementSpeed: 0.3800
SubjectInFramePosition: 0.6000
SubjectView: 0.7200
CameraMovementType: 0.2200

Example Mismatches (first 5):

Prompt: The camera begins with a low vertical angle, capturing a full shot of the scene. As it moves swiftly, it glides towards the subject positioned at the outer top of the frame. The rapid pace creates an exhilarating feeling, but as it approaches, the movement slows, allowing viewers to take in the details. The framing emphasizes the subject while revealing the context surrounding them, blending action and ambiance seamlessly.
Mismatched Parameters:
- MovementSpeed: predicted 'Not_Specified' instead of 'fastToSlow'
- SubjectInFramePosition: predicted 'top' instead of 'outerTop'

Prompt: The camera hovers overhead, capturing an extreme close-up of the subject nestled in the bottom right of the frame. As it follows the subject's movements, it glides smoothly

## Dataset Generation (Using LLM)

In [None]:
! pip install --upgrade openai




In [None]:
from openai import OpenAI
import random
import json

# Initialize OpenAI client with your API key
client = OpenAI(api_key="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")

# Parameters for generating prompts
parameters = {
    "CameraVerticalAngle": ["low", "eye", "high", "overhead", "birdsEye"],
    "ShotSize": [
        "extremeCloseUp",
        "closeUp",
        "mediumCloseUp",
        "mediumShot",
        "fullShot",
        "longShot",
        "veryLongShot",
        "extremeLongShot",
    ],
    "MovementSpeed": [
        "slowToFast",
        "fastToSlow",
        "constant",
        "stopAndGo",
        "deliberateStartStop",
    ],
    "SubjectInFramePosition": [
        "left",
        "right",
        "top",
        "bottom",
        "center",
        "topLeft",
        "topRight",
        "bottomLeft",
        "bottomRight",
        "outerLeft",
        "outerRight",
        "outerTop",
        "outerBottom",
    ],
    "SubjectView": [
        "front",
        "back",
        "left",
        "right",
        "threeQuarterFrontLeft",
        "threeQuarterFrontRight",
        "threeQuarterBackLeft",
        "threeQuarterBackRight",
    ],
    "CameraMovementType": [
        "static",
        "panLeft",
        "panRight",
        "tiltUp",
        "tiltDown",
        "dollyIn",
        "dollyOut",
        "truckLeft",
        "truckRight",
        "pedestalUp",
        "pedestalDown",
        "arcLeft",
        "arcRight",
        "craneUp",
        "craneDown",
        "dollyOutZoomIn",
        "dollyInZoomOut",
        "dutchLeft",
        "dutchRight",
        "follow",
    ],
}

# Function to interact with OpenAI API
def chat_gpt(prompt):
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # Use the gpt-4o-mini model
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content.strip()

# Function to generate a single diverse user prompt
def generate_single_prompt():
    input_params = {
        "CameraVerticalAngle": random.choice(parameters["CameraVerticalAngle"]),
        "ShotSize": random.choice(parameters["ShotSize"]),
        "MovementSpeed": random.choice(parameters["MovementSpeed"]),
        "SubjectInFramePosition": random.choice(parameters["SubjectInFramePosition"]),
        "SubjectView": random.choice(parameters["SubjectView"]),
        "CameraMovementType": random.choice(parameters["CameraMovementType"]),
    }

    # Allow random omission of some parameters
    params_to_include = random.sample(list(input_params.keys()), random.randint(3, len(input_params)))
    filtered_params = {key: value for key, value in input_params.items() if key in params_to_include}

    # Create a dynamic and focused prompt
    prompt = (
        f"You are describing a camera setup and movement. Focus only on the camera's movements, "
        f"angles, framing, and motion. Here are the camera "
        f"parameters to consider:\n\n"
        + "\n".join([f"- {key.replace('Camera', '').replace('InFrame', ' In Frame')} is {value}" for key, value in filtered_params.items()])
        + "\n\n"
        f"Please describe the shot naturally and realistically, using varied and human-like expressions. Be concise. "
    )

    #print (f"Prompt: \n {prompt} \n")

    # Call GPT to generate a prompt
    user_prompt = chat_gpt(prompt)

    #print (f"Output: \n {user_prompt} \n")
    return {"prompt": user_prompt, "parameters": filtered_params}

# Generate a dataset
def generate_dataset(num_samples=100):
    dataset = []
    for _ in range(num_samples):
        dataset.append(generate_single_prompt())
    return dataset

# Save the dataset
dataset = generate_dataset()
with open("focused_camera_prompts_dataset.json", "w") as f:
    json.dump(dataset, f, indent=2)


In [None]:
dataset[0].keys()

dict_keys(['prompt', 'parameters'])

In [None]:

def analyze_parameters(dataset_path):
    """
    Load a dataset and create binary arrays indicating which parameters are present in each entry.

    Args:
        dataset_path (str): Path to the JSON dataset file

    Returns:
        list: List of binary arrays where 1 indicates parameter presence and 0 indicates absence
    """
    # Define the ordered list of all possible parameters
    all_parameters = [
        "CameraVerticalAngle",
        "ShotSize",
        "MovementSpeed",
        "SubjectInFramePosition",
        "SubjectView",
        "CameraMovementType"
    ]

    # Load the dataset
    with open(dataset_path, 'r') as f:
        dataset = json.load(f)

    # Initialize the result list
    binary_representations = []

    # Process each entry in the dataset
    for entry in dataset:
        # Get the parameters present in this entry
        present_parameters = entry['parameters'].keys()

        # Create binary array for this entry
        binary_array = [1 if param in present_parameters else 0 for param in all_parameters]
        binary_representations.append(binary_array)

    return binary_representations

# Example usage:
if __name__ == "__main__":
    # Analyze the dataset
    result = analyze_parameters("focused_camera_prompts_dataset.json")

    # Print results with parameter names for verification
    all_parameters = [
        "CameraVerticalAngle",
        "ShotSize",
        "MovementSpeed",
        "SubjectInFramePosition",
        "SubjectView",
        "CameraMovementType"
    ]
    mask = []
    for i, binary_array in enumerate(result):
        print(f"\nEntry {i + 1}:")
        print(len(dataset[i]['parameters']))
        print("Parameter Presence:")
        for param, present in zip(all_parameters, binary_array):
            status = "Present" if present == 1 else "Absent"
            print(f"{param}: {status}")
        print(f"Binary representation: {binary_array}")
        mask.append(binary_array)


Entry 1:
4
Parameter Presence:
CameraVerticalAngle: Present
ShotSize: Absent
MovementSpeed: Present
SubjectInFramePosition: Present
SubjectView: Absent
CameraMovementType: Present
Binary representation: [1, 0, 1, 1, 0, 1]

Entry 2:
6
Parameter Presence:
CameraVerticalAngle: Present
ShotSize: Present
MovementSpeed: Present
SubjectInFramePosition: Present
SubjectView: Present
CameraMovementType: Present
Binary representation: [1, 1, 1, 1, 1, 1]

Entry 3:
6
Parameter Presence:
CameraVerticalAngle: Present
ShotSize: Present
MovementSpeed: Present
SubjectInFramePosition: Present
SubjectView: Present
CameraMovementType: Present
Binary representation: [1, 1, 1, 1, 1, 1]

Entry 4:
3
Parameter Presence:
CameraVerticalAngle: Absent
ShotSize: Absent
MovementSpeed: Present
SubjectInFramePosition: Present
SubjectView: Present
CameraMovementType: Absent
Binary representation: [0, 0, 1, 1, 1, 0]

Entry 5:
5
Parameter Presence:
CameraVerticalAngle: Present
ShotSize: Present
MovementSpeed: Present
Sub

In [None]:
mask

[[1, 0, 1, 1, 0, 1],
 [1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1],
 [0, 0, 1, 1, 1, 0],
 [1, 1, 1, 1, 1, 0],
 [1, 1, 1, 1, 1, 1],
 [0, 0, 1, 1, 0, 1],
 [1, 1, 0, 1, 1, 1],
 [1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1]]

## Traning (CLIP)

In [None]:
! pip install git+https://github.com/openai/CLIP.git

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import clip
import json
import numpy as np
from transformers import CLIPTokenizer

# Load the dataset
with open("generated_camera_prompts.json", "r") as f:
    dataset = json.load(f)

# Extract prompts and parameters
prompts = [item["prompt"] for item in dataset]
parameters = [item["parameters"] for item in dataset]

# Define all possible parameter keys and values (same as before)
parameter_keys = {
    "CameraVerticalAngle": ["low", "eye", "high", "overhead", "birdsEye"],
    "ShotSize": [
        "extremeCloseUp", "closeUp", "mediumCloseUp", "mediumShot",
        "fullShot", "longShot", "veryLongShot", "extremeLongShot",
    ],
    "MovementSpeed": [
        "slowToFast", "fastToSlow", "constant", "stopAndGo",
        "deliberateStartStop",
    ],
    "SubjectInFramePosition": [
        "left", "right", "top", "bottom", "center", "topLeft", "topRight",
        "bottomLeft", "bottomRight", "outerLeft", "outerRight", "outerTop",
        "outerBottom",
    ],
    "SubjectView": [
        "front", "back", "left", "right", "threeQuarterFrontLeft",
        "threeQuarterFrontRight", "threeQuarterBackLeft", "threeQuarterBackRight",
    ],
    "CameraMovementType": [
        "static", "panLeft", "panRight", "tiltUp", "tiltDown", "dollyIn",
        "dollyOut", "truckLeft", "truckRight", "pedestalUp", "pedestalDown",
        "arcLeft", "arcRight", "craneUp", "craneDown", "dollyOutZoomIn",
        "dollyInZoomOut", "dutchLeft", "dutchRight", "follow",
    ],
}

# One-hot encode parameters (same as before)
def encode_parameters(parameters):
    encoded = []
    for key, values in parameter_keys.items():
        vec = [0] * len(values)
        if key in parameters and parameters[key] in values:
            vec[values.index(parameters[key])] = 1
        encoded.extend(vec)
    return encoded

encoded_parameters = np.array([encode_parameters(p) for p in parameters])

# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
    prompts, encoded_parameters, test_size=0.2, random_state=42
)

# Define a custom dataset for CLIP
class CLIPCameraDataset(Dataset):
    def __init__(self, prompts, labels, tokenizer, max_length=77):  # CLIP's max token length is 77
        self.prompts = prompts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        prompt = self.prompts[idx]
        label = self.labels[idx]

        # Tokenize with truncation and padding
        encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        return (
            encoding.input_ids.squeeze(0),
            encoding.attention_mask.squeeze(0),
            torch.tensor(label, dtype=torch.float)
        )

# Initialize CLIP model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model, _ = clip.load("ViT-B/32", device=device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Create datasets and dataloaders
train_dataset = CLIPCameraDataset(X_train, y_train, tokenizer)
test_dataset = CLIPCameraDataset(X_test, y_test, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16)

# Define the model using CLIP's text encoder
class CLIPCameraPredictor(nn.Module):
    def __init__(self, clip_model, num_labels):
        super(CLIPCameraPredictor, self).__init__()
        self.clip = clip_model
        self.text_projection = clip_model.text_projection

        # Freeze CLIP parameters
        for param in self.clip.parameters():
            param.requires_grad = False

        # New layers for parameter prediction
        self.dropout = nn.Dropout(0.3)
        self.fc1 = nn.Linear(512, 256)  # CLIP's text embedding dimension is 512
        self.fc2 = nn.Linear(256, num_labels)

    def forward(self, input_ids, attention_mask):
        # Get CLIP text embeddings
        text_features = self.clip.encode_text(input_ids)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # Process through our layers
        x = self.dropout(text_features)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.sigmoid(self.fc2(x))
        return x

# Instantiate the model
num_labels = y_train.shape[1]
model = CLIPCameraPredictor(clip_model, num_labels)
model = model.to(device)

# Define loss and optimizer
criterion = nn.BCELoss()
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-4
)

# Training loop
def train_model(model, train_loader, test_loader, epochs=5):
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for input_ids, attention_mask, labels in train_loader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {train_loss / len(train_loader):.4f}")

        # Evaluate
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for input_ids, attention_mask, labels in test_loader:
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)
                outputs = model(input_ids, attention_mask)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
        print(f"Validation Loss: {test_loss / len(test_loader):.4f}")

# Predict function
def predict(prompt):
    model.eval()
    encoding = tokenizer(
        prompt,
        max_length=77,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    input_ids = encoding.input_ids.to(device)
    attention_mask = encoding.attention_mask.to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask).cpu().numpy()[0]

    predicted_params = {}
    start_idx = 0
    for key, values in parameter_keys.items():
        end_idx = start_idx + len(values)
        predicted_value_idx = np.argmax(outputs[start_idx:end_idx])
        if outputs[start_idx:end_idx][predicted_value_idx] > 0.1:
            predicted_params[key] = values[predicted_value_idx]
        start_idx = end_idx
    return predicted_params

# Train the model
train_model(model, train_loader, test_loader, epochs=50)

# Test prediction
example_prompt = X_test[0]
predicted = predict(example_prompt)
print("Prompt:", example_prompt)
print("Predicted Parameters:", predicted)
print("Actual Parameters:", y_test[0])

RuntimeError: mat1 and mat2 must have the same dtype, but got Half and Float

In [None]:
# Evaluate the model
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch_X, batch_y in test_loader:
        outputs = torch.sigmoid(model(batch_X)) > 0.5
        correct += (outputs.numpy() == batch_y.numpy()).sum()
        total += batch_y.numel()

print(f"Accuracy: {correct / total:.2%}")
