# ICT3104 Project Team 12

## Setup
These libraries are required to run the tool.

In [2]:
# Import libraries
import os
import ipywidgets as widgets
from IPython.display import display, clear_output, Video
import yaml

## Data Exploration
This section lets you select and view the videos in the `data/input_files/training_videos` folder, so you can verify that you have chosen the correct video.
To view a video, select it from the dropdown menu and wait for it to be displayed. Only `.mp4` files can be previewed.

In [None]:
# Populate dropdown with files in folder.
# NOTE(review): the path is relative to the notebook's working directory —
# presumably the directory that contains the repo checkout; confirm.
folder_path = "./ict3104-team12-2023/data/input_files/training_videos/"
# os.listdir() returns entries in arbitrary order; sort them so the dropdown
# order (and therefore its default selection) is stable across machines.
video_options = sorted(os.listdir(folder_path))

dropdown = widgets.Dropdown(options=video_options)

# Define a function to update and display the video based on the selected option
def update_video(change):
    selected_option = change['new']

    # Clear the previous video output before displaying the new one
    clear_output(wait=True)
    display(dropdown)

    # Check if file is in valid mp4 format
    if ".mp4" in selected_option:
        video_path = f"{folder_path}{selected_option}"
        video_display = Video(video_path, width=400, height=400, embed=True)
        display(video_display)
    else:
        print("Please choose a file with .mp4 format.")

# Hook the preview renderer up to selection changes on the dropdown.
dropdown.observe(handler=update_video, names='value')

# Render a preview straight away for the dropdown's starting selection,
# mimicking the shape of a change notification.
initial_option = dropdown.value
update_video(change={'new': initial_option})

## Inference

In [None]:
# Editable path to the pretrained model checkpoint directory.
pretrained_model_path_input = widgets.Text(
    description="Pretrained Model Path:",
    value="./checkpoints/stable-diffusion-v1-4",
)

# Default validation settings. Within this section only 'prompts' is read,
# to seed the prompts textarea below.
validation_data = dict(
    prompts=[
        "Iron man on the beach",
        "Stormtrooper on the sea",
        "Astronaut on the beach",
    ],
    video_length=32,
    width=512,
    height=512,
    num_inference_steps=50,
    guidance_scale=12.5,
    use_inv_latent=False,
    num_inv_steps=50,
    dataset_set='val',
)

# Multi-line prompt editor: one prompt per line.
prompts_input = widgets.Textarea(
    description="Prompts:",
    rows=4,
    value="\n".join(validation_data['prompts']),
)

# Button that writes the edited values into the YAML config, the config file
# it targets, and an output area used to preview the written YAML.
update_config_button = widgets.Button(description="Update Config")
yaml_file_path = "./ict3104-team12-2023/configs/pose_sample.yaml"
output_widget = widgets.Output()

def update_config(button_click):
    updated_path = pretrained_model_path_input.value
    updated_prompts = prompts_input.value.split("\n")

    # Read the existing YAML content from the file
    with open(yaml_file_path, 'r') as yaml_file:
        config_data = yaml.safe_load(yaml_file)

    config_data['pretrained_model_path'] = updated_path
    config_data['validation_data']['prompts'] = updated_prompts

    # Write the updated content back to the file
    with open(yaml_file_path, 'w') as yaml_file:
        yaml.dump(config_data, yaml_file, default_flow_style=None, sort_keys=False)

    # Display the updated YAML content in the output widget
    with output_widget:
        clear_output()
        print(yaml.dump(config_data, default_flow_style=None, sort_keys=False))

# Clicking the button rewrites the YAML config from the current inputs.
update_config_button.on_click(update_config)

# Show the editable fields first, then the button and its YAML preview area.
display(pretrained_model_path_input, prompts_input, update_config_button, output_widget)

## Training

In [3]:
# Text input widget for the user's name
name_widget = widgets.Text(description='Name:')

# Dropdown widget for batch size
batch_size_options = [16, 32, 64, 128]
batch_size_widget = widgets.Dropdown(description='Batch Size:', options=batch_size_options)

# Dropdown widget for the number of epochs
epoch_options = [10, 20, 30, 40]
epoch_widget = widgets.Dropdown(description='Epochs:', options=epoch_options)

# Submit button
submit_button = widgets.Button(description='Submit')

# Output area for this section. It deliberately does NOT reuse the global name
# `output_widget`: the Inference section's update_config() looks that name up
# when its button is clicked, so rebinding it here would redirect the YAML
# preview into this cell's output instead.
training_output_widget = widgets.Output()


def handle_submit(button_click):
    """Echo the selected training settings into ``training_output_widget``.

    Parameters
    ----------
    button_click : ipywidgets.Button
        The clicked button, passed by ``on_click``; unused.
    """
    user_name = name_widget.value.strip()
    batch_size = batch_size_widget.value
    epochs = epoch_widget.value

    with training_output_widget:
        # Replace the previous submission's output rather than appending.
        training_output_widget.clear_output()
        print(f"User Name: {user_name}")
        print(f"Batch Size: {batch_size}")
        print(f"Epochs: {epochs}")


# Attach the event handler to the submit button
submit_button.on_click(handle_submit)

# Display widgets
display(name_widget, batch_size_widget, epoch_widget, submit_button, training_output_widget)

Text(value='', description='Name:')

Dropdown(description='Batch Size:', options=(16, 32, 64, 128), value=16)

Dropdown(description='Epochs:', options=(10, 20, 30, 40), value=10)

Button(description='Submit', style=ButtonStyle())

Output()

## Testing