# Setup
Run the cell below to import the dependencies needed for the project (they must already be installed in this environment).

In [None]:
# Functions to use along widgets
##General Imports
from IPython.display import HTML, display, Markdown, Video, clear_output
from ipywidgets import Layout
# Faciliate file selection
from tkinter import *
from tkinter import filedialog
# Widget Packages
import ipywidgets as widgets
# jupyter nbextension enable --py widgetsnbextension
# Used for local directory
import os
import sys
# import TSU.json_util as json_util
# from TSU.json_util import test_train_json

from IPython.display import HTML
from util_modules.ui_util.progress_bar import progress_bar

##Imports for captioning
import pandas
import cv2
import re

##Imports for Charting
import wandb
from dotenv import load_dotenv

##Imports for Feature Extraction
from os import walk
from video_features.utils.utils import build_cfg_path
from omegaconf import OmegaConf
import torch
from video_features.features_models.i3d.extract_i3d import ExtractI3D

##Turn off warnings for python
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

##Imports for TSU
from util_modules.json_util.test_train_json import inference_json

##Imports for Second Algorithm (STEP/3D3)

# Inject a small script and a "Show/hide code" button so the notebook can be
# presented with its code cells hidden. Because code_show starts as true and
# code_toggle runs on document ready, code cells start out hidden.
# NOTE(review): relies on jQuery (`$`) and the classic Notebook `div.input`
# selector; likely a no-op in JupyterLab — confirm.
# This HTML(...) object is the cell's last expression, so Jupyter renders it.
HTML('''<script>
code_show=true; 
function code_toggle() {
 if (code_show){
 $('div.input').hide();
 } else {
 $('div.input').show();
 }
 code_show = !code_show
} 
$( document ).ready(code_toggle);
</script>
<form action="javascript:code_toggle()"><input type="submit" value="Show/hide code"></form>''')

# Welcome to the HOI Interactive Notebook Tool 

### Quick-Start:
- Refer to the instructions and documentation provided in the Markdown cells of each of the sections below to better understand the functions of each feature.

### Pre-Requisites:
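- Input videos (`.mp4`) in `data/rgbVideos/`, with ground-truth annotation CSVs in `data/annotations/<first three characters of the video name>/` (e.g. `data/annotations/P02/P02T01C06.csv`)
- Pre-trained TSU models in `TSU/models/`; models trained in this notebook are read from `TSU/PDAN/`
- Optional: a [Weights & Biases](https://wandb.ai/) account for training charts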

## 1. Initialisation
Check if your device has a [CUDA-supported GPU](https://developer.nvidia.com/cuda-gpus#compute).

In [None]:
# layout init. — shared widget layouts reused by every section below
btn_layout = widgets.Layout(width='45%', height='40px')
btn_sm_layout = widgets.Layout(width='20%', height='40px')
label_layout = widgets.Layout(width='200px', height='auto')
ddl_layout = widgets.Layout(width='45%', height='auto')


# Widget
btn_check_if_cuda = widgets.Button(description="Check if your device is CUDA supported",
    layout=btn_layout, button_style='info')

cuda_output = widgets.Output()

def check_if_cuda(b):
    """Button callback: report whether PyTorch can use a CUDA-capable GPU.

    Args:
        b: the clicked ``widgets.Button`` (unused; required by ``on_click``).

    Returns:
        True, kept for backward compatibility (the button does not use it).
    """
    with cuda_output:
        cuda_output.clear_output()
        if torch.cuda.is_available():
            print('Your computer is CUDA supported.')
            # Name each visible device so the user knows which GPU(s) will be used.
            for device_idx in range(torch.cuda.device_count()):
                print('  GPU {}: {}'.format(device_idx, torch.cuda.get_device_name(device_idx)))
        else:
            print('Your computer is not CUDA supported.')
    return True

btn_check_if_cuda.on_click(check_if_cuda)

# Display
display(btn_check_if_cuda, cuda_output)

## 2. Select Pipeline
The following Machine Learning (ML) pipelines have been integrated with this project so that you can try your hand at Activity Detection with ease. If you're interested, click on any of the following links to learn more about their source repositories.
- [Toyota Smarthome (TSU)](https://project.inria.fr/toyotasmarthome/)
- [NVIDIA STEP](https://github.com/NVlabs/STEP/)

In [None]:
# Pipelines offered in the dropdown.
# TODO: consider keeping this list in a text file so it can be edited without touching code.
ml_list = ['Toyota Smart Home', 'Nvidia STEP']

# Output area that echoes the confirmed choice
widgetset = widgets.Output()

ddl_pipeline = widgets.Dropdown(
    options=ml_list,
    value=ml_list[0],
    layout=btn_layout,
    description='Select ML Pipeline:',
    style={'description_width': 'initial'},
)

btn_select_pipeline = widgets.Button(
    description="Confirm",
    layout=btn_layout,
    button_style='info',
)


def selectWidgetSet(b):
    """Button callback: print which pipeline is currently chosen in ``ddl_pipeline``."""
    chosen_pipeline = ddl_pipeline.value
    with widgetset:
        widgetset.clear_output()
        print(chosen_pipeline + " is selected as the pipeline.")


btn_select_pipeline.on_click(selectWidgetSet)

# Display (the VBox is the cell's last expression, so Jupyter renders it)
tooltip = widgets.Label("Supported Video Types: MP4, WebM, and OGG.")
mlBox = widgets.VBox([ddl_pipeline, btn_select_pipeline, widgetset])
mlBox

## 3. Data Exploration
Select any video in the `data/rgbVideos` folder and play it back.

In [None]:
# Logic
def populateList(fileDirectory, fileType=""):
    """Return the names of the entries in ``fileDirectory`` that end with ``fileType``.

    Also prints how many entries the directory holds in total (before filtering).

    Args:
        fileDirectory: directory to list (relative paths resolve against the kernel's CWD).
        fileType: required name suffix such as ``".mp4"``; ``""`` (default) keeps every entry.

    Returns:
        Sorted list of matching entry names (names only, not paths). Sorting makes the
        default selection of dropdowns built from this list deterministic.

    Raises:
        FileNotFoundError: if ``fileDirectory`` does not exist.
    """
    folder_files = os.listdir(fileDirectory)
    # Name the directory actually listed: callers pass several different folders.
    print("`{folder}` directory contains {len_folder} file(s).".format(
        folder=fileDirectory, len_folder=len(folder_files)))
    # str.endswith("") is always True, so an empty fileType keeps every entry.
    return sorted(name for name in folder_files if name.endswith(fileType))

# Videos available for playback: the .mp4 files in data/rgbVideos.
fileList = populateList('data/rgbVideos', ".mp4")

# Widget
label_data_files = widgets.Label("Select Video(s):", layout=label_layout)
ddl_data_selected = widgets.Dropdown(
    options=fileList,
    # None instead of an IndexError when the folder holds no .mp4 file.
    value=fileList[0] if fileList else None,
    layout=ddl_layout)

btn_video = widgets.Button(
    description='Play video', layout=btn_sm_layout, button_style='info'
)

btn_clear = widgets.Button(
    description='Clear Playback', layout=btn_sm_layout
)

# Display output
video_output = widgets.Output()

def explore_video(b):
    """Button callback: embed an HTML5 player for the selected video in ``video_output``."""
    from urllib.parse import quote

    with video_output:
        video_output.clear_output()
        if not ddl_data_selected.value:
            print("No .mp4 files found in data/rgbVideos.")
            return
        # URL-quote the name so files with spaces, '#', quotes, etc. still resolve and
        # cannot break out of the src attribute.
        html_video_code = '<video width="80%" height="80%" controls><source src="./data/rgbVideos/{fileName}" type="video/mp4"></video>'.format(fileName=quote(ddl_data_selected.value))
        video_output.append_display_data(HTML(html_video_code))

btn_video.on_click(explore_video)

def clear_video(b):
    """Button callback: remove any embedded player from ``video_output``."""
    with video_output:
        video_output.clear_output()

btn_clear.on_click(clear_video)

# Display
data_box = widgets.VBox([widgets.HBox([label_data_files, ddl_data_selected])])
data_btn_box = widgets.VBox([widgets.HBox([btn_clear, btn_video]), video_output])

display(data_box)
display(data_btn_box)



## 4. Feature Extraction
This section extracts features from videos into NumPy format so that they can be used for inference, training and testing.

### Parameters:
- Folder & File
    - Choose the folder under `data/` where the videos for feature extraction are stored
    - Select one or more videos
- Stream Type
    - Select between `null, rgb, flow` stream types

In [None]:
# Widgets

# Candidate folders: sub-directories of ./data, skipping hidden ones such as
# .ipynb_checkpoints (which previously had to be removed by hand).
sub_folders = sorted(
    (name for name in os.listdir('./data')
     if os.path.isdir(os.path.join('./data', name)) and not name.startswith('.')),
    reverse=True)

# Files found in the confirmed folder; refilled (not appended to) on every click.
list_of_videos = []

label_video_folder = widgets.Label("Select Folder:", layout=label_layout)
ddl_video_folder = widgets.Dropdown(
    options=sub_folders,
    layout=btn_layout)

label_stream = widgets.Label("Select Stream:", layout=label_layout)
ddl_stream = widgets.Select(
    options=['null', 'rgb', 'flow'],
    value='null',
    layout=ddl_layout)

btn_confirm_folder = widgets.Button(description="Confirm Folder", layout=btn_layout, button_style='info')
# NOTE(review): no on_click handler is attached to this button in this notebook.
btn_extract = widgets.Button(description="Extract Features", layout=btn_layout, button_style='info')

# ipywidgets only shows output produced inside a button callback when it is captured
# by an Output widget, so the pickers are rendered here.
feature_output = widgets.Output()

def selectvideos(b):
    """Button callback: list the files of the chosen ./data sub-folder and show the
    video picker, stream picker and extract button in ``feature_output``.
    """
    with feature_output:
        clear_output()
        if ddl_video_folder.value is None:
            print("No folders found in ./data.")
            return
        folder = os.path.join('./data', ddl_video_folder.value)
        # Top level of the folder only (walk's first step); empty if it vanished.
        _, _, filenames = next(walk(folder), (folder, [], []))
        list_of_videos[:] = sorted(filenames)
        label_feature_files = widgets.Label("Select Video(s):", layout=label_layout)
        ddl_feature_files = widgets.SelectMultiple(
            options=list_of_videos,
            layout=ddl_layout)
        test_box = widgets.VBox([widgets.HBox([label_feature_files, ddl_feature_files]),
                                 widgets.HBox([label_stream, ddl_stream]), btn_extract])
        display(test_box)

btn_confirm_folder.on_click(selectvideos)

# Display
feature_box = widgets.VBox([widgets.HBox([label_video_folder, ddl_video_folder]),
                            btn_confirm_folder, feature_output])
feature_box
    

## 5. Inference
In this section you can run a model on a video and judge how accurate it is on the video itself: the ground-truth and the generated (predicted) annotations are captioned side by side onto the video.

### Parameters:
- Model 
  - Select from one of the pre-trained ML models
- Input Video 
  - Select an input video from the `/data/rgbVideos` directory

In [None]:
# Logic
# Model choices: pre-trained models in ./TSU/models followed by the entries of
# ./TSU/PDAN (the folder the Testing section loads trained models from).
model_dir = './TSU/models/'
train_modelList = os.listdir("./TSU/PDAN")
modelList = os.listdir(model_dir) + train_modelList

## retrieve all TSU videos to dropdown list (also prints how many files were found)
fileList = populateList('data/rgbVideos')

# Widgets
label_inference_model = widgets.Label("Select Model:", layout=label_layout)
ddl_inference_model = widgets.Dropdown(
    options=modelList,
    # None instead of an IndexError when no model is available
    value=modelList[0] if modelList else None,
    layout=btn_layout)

label_inference_video = widgets.Label("Select Input Video:", layout=label_layout)
selected_input_video = widgets.Dropdown(
    options=fileList,
    value=fileList[0] if fileList else None,
    layout=btn_layout)

btn_inference = widgets.Button(description="Run Inferencing", layout=btn_layout, button_style='info')

# Display
model_box = widgets.HBox([label_inference_model, ddl_inference_model])
inference_box = widgets.VBox([model_box, widgets.HBox([label_inference_video, selected_input_video])])
inference_output = widgets.Output()
display(inference_box)
display(btn_inference)
display(inference_output)

###### Generate and Save captions Function ###################
class Captioning:
    """Render ground-truth vs. predicted action captions onto a copy of a video.

    Usage: ``Captioning(annotation_file_path, truth_file_path, video_file).saveVideo()``, e.g.
    ``Captioning("./data/generatedAnnotations/PDAN_TSU_RGB_P02T01C06.csv",
    "./data/annotations/P02/P02T01C06.csv", "./data/rgbVideos/P02T01C06.mp4")``.

    Args:
        annotation_file_path: generated-annotation CSV. Its first row is treated as a
            header and replaced by the names ``action, start, end, accuracy``.
        truth_file_path: ground-truth CSV; its ``event`` and ``start_frame`` columns are read.
        video_file: path of the video to caption.
    """

    def __init__(self, annotation_file_path, truth_file_path, video_file):
        self.sf = pandas.read_csv(annotation_file_path, header=0, names=['action','start','end','accuracy'], usecols=['action','start','end','accuracy'])[['action','start','end','accuracy']]
        self.df = pandas.read_csv(truth_file_path)
        self.video_path = video_file
        self.time = 0              # current frame index, set per frame by saveVideo
        self.nextTime = 0          # start_frame at which the ground-truth label next changes
        self.nextActionTime = 0    # 'end' value at which the predicted label next changes
        self.eventCounter = 0      # next row of self.df to show
        self.actionCounter = 0     # row of self.sf currently shown
        self.prediction = 0.0      # 'accuracy' value of the shown prediction
        self.ground_truth = ""
        self.caption = ""
        self.output_path = "./videosCaptionOutput/"

    def getGroundTruth(self):
        """Advance ``self.ground_truth`` to the ground-truth event for ``self.time``.

        Events are consumed in row order from ``self.df``; '_', ',' and '.' in the
        label are replaced by spaces for display.
        """
        if self.time == 0:
            # NOTE(review): the first switch point is the *second* row's start_frame, so
            # the first event's label is only shown briefly when the second event starts.
            # Confirm whether iloc[0] was intended.
            self.nextTime = self.df['start_frame'].iloc[1]

        if self.time >= self.nextTime:
            try:
                self.ground_truth = self.df['event'].iloc[self.eventCounter]
                self.ground_truth = re.sub('[_,.]', ' ', self.ground_truth)
                self.nextTime = self.df['start_frame'].iloc[self.eventCounter + 1]
                self.eventCounter += 1
            except Exception:
                # Past the last event (IndexError) or a non-string label: keep the
                # current caption instead of aborting the render.
                pass

    def getCaption(self):
        """Advance ``self.caption`` / ``self.prediction`` to the predicted action for ``self.time``."""
        if self.actionCounter == 0:
            # Predictions follow the row whose 'end' cell is the literal string 'end'.
            # NOTE(review): raises ValueError unless exactly one such row exists — confirm
            # the generated CSV format. Whether 'end' is in frames (like self.time) is
            # also unverified here.
            self.actionCounter, = self.sf.index[self.sf['end'] == 'end']

        if self.time >= self.nextActionTime:
            self.actionCounter += 1
            try:
                self.nextActionTime = float(self.sf['end'].iloc[self.actionCounter])
                self.caption = self.sf['action'].iloc[self.actionCounter]
                self.caption = re.sub('[_,.]', ' ', self.caption)
                self.prediction = float(self.sf['accuracy'].iloc[self.actionCounter])
            except Exception:
                # Past the last prediction: keep showing the previous one.
                pass

    def saveVideo(self):
        """Write the captioned video to ``self.output_path`` and preview it in an OpenCV window.

        The output keeps the input's base name (``.../P02T01C06.mp4`` ->
        ``./videosCaptionOutput/P02T01C06.mp4``). Press ``q`` or ``Esc`` in the
        preview window to stop early.
        """
        os.makedirs(self.output_path, exist_ok=True)
        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            print("Error opening video stream or file")
            cap.release()
            return

        # Use the source frame rate for the output file and the frame index; fall back to
        # 25 fps (the rate previously hard-coded) when the container does not report one.
        fps = cap.get(cv2.CAP_PROP_FPS) or 25
        ms_per_frame = 1000.0 / fps
        frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        # Name the output after the input file, whatever the length of its name.
        video_name = os.path.splitext(os.path.basename(self.video_path))[0]
        out_file = os.path.join(self.output_path, video_name + '.mp4')
        # NOTE(review): 'H264' (rather than 'mp4v') needs an OpenCV build with an H.264
        # encoder — confirm on the target machine.
        fourcc = cv2.VideoWriter_fourcc(*'H264')
        out = cv2.VideoWriter(out_file, fourcc, fps, frame_size)

        progress_bar_instance = progress_bar(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
        progress_bar_instance.display_bar()
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    print('Video annotation process is complete.')
                    break

                frame_height, frame_width = frame.shape[:2]
                self.time = int(cap.get(cv2.CAP_PROP_POS_MSEC) / ms_per_frame)
                self.getGroundTruth()
                self.getCaption()
                self._drawCaptions(frame, frame_width, frame_height)

                out.write(frame)
                cv2.imshow('frame', frame)

                # waitKey returns an int key code (-1 if no key): compare its low byte with
                # ord('q') / Esc (27). A comparison with the string 'q' never matches.
                key = cv2.waitKey(1) & 0xFF
                if key == ord('q') or key == 27:
                    break
                progress_bar_instance.update_bar()
        finally:
            # Release handles even if captioning raises, so the output file is finalised.
            cap.release()
            out.release()
            cv2.destroyAllWindows()

    def _drawCaptions(self, frame, width, height):
        """Draw the grey caption panel (ground truth, prediction, accuracy) onto ``frame`` in place."""
        cv2.rectangle(frame, (int(width * 0.05), int(height * 0.8)), (int(width * 0.95), int(height*0.95)), (159,159,159), -1)
        cv2.putText(frame, "Ground truth", (int(width*0.1),int(height*0.85)), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,0), 2)
        cv2.putText(frame, str(self.ground_truth), (int(width*0.3),int(height*0.85)), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,0), 2)
        cv2.putText(frame, "Prediction", (int(width*0.1),int(height*0.9)), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,0), 2)
        cv2.putText(frame, str(self.caption), (int(width*0.3),int(height*0.9)), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,0), 2)
        cv2.putText(frame, "Accuracy", (int(width*0.65),int(height*0.9)), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,0), 2)
        cv2.putText(frame, str(self.prediction), (int(width*0.8),int(height*0.9)), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,0), 2)
    
# Usage Captioning(annotation_file_path, truth_file_path, video_file) 
# annotation_file_path (e.g. "./data/generatedAnnotations/PDAN_TSU_RGB_P02T01C06.csv" )
# truth_file_path (e.g. "./data/annotations/P02/P02T01C06.csv" )
# video_file (e.g. "./data/rgbVideos/P02T01C06.mp4")

###### Generate CSV After running Inferencing ###########################
def run_inferencing(b):
    with inference_output:
        clear_output()
        video_data = selected_input_video.value.replace(".mp4","")
        inference_json(video_data, "TSU")
        if (ddl_inference_model.value == "PDAN_TSU_RGB"):
            sub_dir = "./TSU/models/"
        else:
            sub_dir = "./TSU/PDAN/"
        inf_dir = sub_dir + ddl_inference_model.value
        inference_dataset = "inference_smarthome_CS_51.json"
        dataset_type = "TSU"
        %run ./TSU/inferencing.py -load_model $inf_dir -video_name $selected_input_video.value -dataset $inference_dataset -dataset_type $dataset_type
        sub_folder = selected_input_video.value[:3]
        gt_annotation_directory = "./data/annotations/{folder}/".format(folder=sub_folder)
        generated_annotation_directory = "./data/generatedAnnotations/"
        annotation_file = selected_input_video.value.replace(".mp4",".csv")
        gen_annotation_file = "{loaded_model}_{vid_name}".format(loaded_model=ddl_inference_model.value, vid_name=annotation_file)
        video_directory = "./data/rgbVideos/{vid_name}".format(vid_name=selected_input_video.value)
        print("video is processing now..")
        test1 = Captioning(generated_annotation_directory+gen_annotation_file, gt_annotation_directory+annotation_file, video_directory)
        test1.saveVideo()
        html_video_code = '<video width="80%" height="80%" controls><source src="./videosCaptionOutput/{video}" type="video/mp4"></video>'.format(video=selected_input_video.value)

        inference_output.append_display_data(HTML(html_video_code))

btn_inference.on_click(run_inferencing)



## 6. Training
In this section you train models with your chosen parameters. It is split into four parts:

#### WandB Setup
  - It is recommended to set up [Weights & Biases](https://wandb.ai/) (wandb.ai), as it is used to output training charts

#### Dataset Selection
  - Select an existing dataset from the `TSU/tsu_data` folder or upload your own

#### Initialising Training
  - Enter custom training parameters

#### Model Training
  - Start model training

## WandB Setup
This section sets up the WandB project that training charts are logged to.
It also asks for your WandB API key.

Your API key can be found at https://wandb.ai/authorize

### Continuing an existing project
To continue logging to a previously created project, enter that project's name again.

In [None]:
# Try to load a .env file from the working directory so a previously stored
# API key / project name pre-fills the widgets below.
basedir = os.path.abspath(os.getcwd())
load_dotenv(os.path.join(basedir, '.env'))


def _mask_secret(secret):
    """Return ``secret`` with all but its last 4 characters replaced by '*'.

    Secrets of 4 characters or fewer are masked completely.
    """
    if len(secret) <= 4:
        return '*' * len(secret)
    return '*' * (len(secret) - 4) + secret[-4:]


def _write_env_file():
    """Overwrite <basedir>/.env with the WANDB_API_KEY / WANDB_PROJECTNAME now in os.environ.

    Variables that are not set are left out of the file.
    """
    entries = [env_name + "=" + os.environ[env_name]
               for env_name in ('WANDB_API_KEY', 'WANDB_PROJECTNAME')
               if os.environ.get(env_name) is not None]
    with open(os.path.join(basedir, '.env'), "w") as envfile:
        envfile.write("\n".join(entries))


# Enter WandB API Key. A Password widget keeps the key from being shown on screen.
API_key_wand = widgets.Password(
    value=os.environ.get('WANDB_API_KEY') or '',
    placeholder='Type something..',
    disabled=False
)
API_label = widgets.Label('WandB API Key', layout=label_layout)

# Initiate new Project
desired_label = widgets.Label('Project Name', layout=label_layout)
desired_Project_name = widgets.Text(
    value=os.environ.get('WANDB_PROJECTNAME') or '',
    placeholder='Type something..',
    disabled=False
)

# Create Project Button
initiate_WandB_Project = widgets.Button(
    description='Create Project',
    button_style='info',
    layout=btn_layout,
    disabled=False
)
# Store API Key Button
initiate_WandB_Project0 = widgets.Button(
    description='Store API Key',
    button_style='info',
    layout=btn_layout,
    disabled=False
)

# Stored Project Name Output
wandbOutput = widgets.Output()
if os.environ.get('WANDB_PROJECTNAME'):
    with wandbOutput:
        print("Previous Project detected: " + os.environ.get('WANDB_PROJECTNAME'))

# Stored API Output (masked: never print secrets into cell output)
apiOutput = widgets.Output()
if os.environ.get('WANDB_API_KEY'):
    with apiOutput:
        print("WANDB_API_KEY detected: " + _mask_secret(os.environ.get('WANDB_API_KEY')))

storedApi = widgets.Label('Stored API Key', layout=label_layout)
storedProject = widgets.Label('Stored Project Name', layout=label_layout)


def store_wandB_API(b):
    """Button callback: export the typed API key to this kernel's environment and save it to .env."""
    with apiOutput:
        apiOutput.clear_output()
        # Same effect as `%env WANDB_API_KEY=...`, without %env echoing the key to the output.
        os.environ['WANDB_API_KEY'] = API_key_wand.value
        _write_env_file()
        print("Your key has been stored in .env: " + _mask_secret(API_key_wand.value))


def create_WandB(b):
    """Button callback: set the WandB project name for this kernel and save it to .env."""
    with wandbOutput:
        wandbOutput.clear_output()
        # Silence Weights&Biases info messages
        os.environ['WANDB_SILENT'] = 'TRUE'
        os.environ['WANDB_PROJECTNAME'] = desired_Project_name.value
        _write_env_file()
        print("Your project is:" + desired_Project_name.value)


# Link the buttons
initiate_WandB_Project.on_click(create_WandB)
initiate_WandB_Project0.on_click(store_wandB_API)

display(widgets.HBox([API_label, API_key_wand]))
display(widgets.HBox([storedApi, apiOutput]))
display(initiate_WandB_Project0)
display(widgets.HBox([desired_label, desired_Project_name]))
display(widgets.HBox([storedProject, wandbOutput]))
display(initiate_WandB_Project)

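## Dataset Selection
Select a dataset JSON from `TSU/tsu_data/` (or upload your own), set the **No. of Training** and **No. of Testing** values, then click **Confirm**.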

In [None]:
from util_modules.json_util.test_train_json import create_subset_json
from util_modules.json_util.test_train_json import count_train_test

# Datasets (.json files in ./TSU/tsu_data) that can be used for training.
datasetList = populateList("./TSU/tsu_data", ".json")

if not datasetList:
    print("No existing datasets found")

# Allow user to select dataset; None (instead of an IndexError) when none exist.
selectedDataset = widgets.Dropdown(
    options=datasetList,
    value=datasetList[0] if datasetList else None,
    disabled=False,
)

# Output areas written to by the callbacks below
training_upload_output = widgets.Output()
dataset_ddl_output = widgets.Output()


def on_dataset_change(change):
    """Observer for ``selectedDataset.value``: print the train/test counts of the new choice."""
    with dataset_ddl_output:
        clear_output()
        if change['new'] is None:
            return
        train, test = count_train_test(change['new'], "TSU")
        print("Selected: %s " % change['new'])
        print("Total Number of train: %s" % train)
        print("Total Number of test: %s" % test)

# names='value' (with the default type 'change') replaces a manual filter on the event.
selectedDataset.observe(on_dataset_change, names='value')

# Allow user to upload own dataset
training_upload_dataset = widgets.Button(
    description='Upload dataset',
    layout=btn_layout,
    button_style='info',
    disabled=False
)

# Allow user to cancel upload
training_upload_cancel = widgets.Button(
    description='Cancel',
    disabled=False
)


def upload_dataset_cancel(b):
    """Button callback: close the upload panel by clearing ``training_upload_output``."""
    with training_upload_output:
        clear_output()

training_upload_cancel.on_click(upload_dataset_cancel)


def upload_dataset(file_val):
    """Save the file held by a FileUpload widget into ./TSU/tsu_data/.

    Args:
        file_val: ``FileUpload.value``. Its first item is read as a dict with ``"name"``
            and ``"content"`` keys (the ipywidgets 8 layout — TODO confirm against the
            installed ipywidgets version).
    """
    if not file_val:
        print("No file uploaded")
        return

    clear_output(wait=True)
    uploaded_file = next(iter(file_val))
    # Keep only the base name so the file can only be written inside ./TSU/tsu_data/.
    uploaded_filename = os.path.basename(uploaded_file["name"])
    content = uploaded_file["content"]
    print("File name: " + uploaded_filename)

    try:
        completeName = os.path.join('./TSU/tsu_data/', uploaded_filename)
        with open(completeName, 'wb') as f:
            f.write(content)
        print(uploaded_filename + " uploaded to 'TSU/tsu_data' successfully!")
    except Exception:
        # Report the failure and offer a way to close the panel instead of raising.
        print(sys.exc_info())
        display(training_upload_cancel)


def show_upload(b):
    """Button callback: show a .json FileUpload plus Confirm/Cancel buttons."""
    with training_upload_output:
        clear_output()
        uploader = widgets.FileUpload(accept='.json',  # Currently only accepts .json
                                      multiple=False,
                                      description="Browse",
                                      _counter=0
        )
        upload_button = widgets.Button(
            description='Confirm Upload',
            disabled=False
        )

        def on_button_clicked(b):
            # Capture upload_dataset's messages in the panel; output printed in a
            # widget callback is otherwise not shown in the notebook.
            with training_upload_output:
                upload_dataset(uploader.value)

        upload_button.on_click(on_button_clicked)
        display(widgets.HBox([uploader, upload_button, training_upload_cancel]))

# Link the buttons
training_upload_dataset.on_click(show_upload)


style = {'description_width': 'initial'}

# Value for "No. of Training". No max is given, so ipywidgets' default max (100) applies.
training_no = widgets.BoundedIntText(
    min=1,
    step=1,
    style=style,
    disabled=False
)

# Value for "No. of Testing" (same bounds).
testing_no = widgets.BoundedIntText(
    min=1,
    step=1,
    style=style,
    disabled=False
)

# Confirm Button to provide user with virtual commit
dataset_Confirm_Button = widgets.Button(
    description='Confirm',
    layout=btn_layout,
    button_style='info',
    disabled=False
)
# Output to contain printout for ALL user input
datasetOutput = widgets.Output()


training_UserInput_label = widgets.Label("Current User Input")

def confirm_dataset_split(b):
    """Button callback: build the train/test subset JSON for the selected dataset.

    ``create_subset_json`` receives './TSU/tsu_data/' and ``"train_" + <selected name>`` —
    presumably the output location and name (confirm in util_modules.json_util); its
    return value is printed.
    """
    with datasetOutput:
        if selectedDataset.value is None:
            datasetOutput.clear_output()
            print("No dataset selected.")
            return
        json_output = create_subset_json(testing_no.value, training_no.value, selectedDataset.value, "./TSU/tsu_data/", "train_" + selectedDataset.value, "./TSU/tsu_data/")
        datasetOutput.clear_output()  # Clear previous output
        print("Total train and test video data for training: ", json_output)

dataset_Confirm_Button.on_click(confirm_dataset_split)

# Labels
training_dataset_label = widgets.Label("Dataset", layout=label_layout)
training_no_label = widgets.Label("No. of Training", layout=label_layout)
training_testing_no_label = widgets.Label("No. of Testing", layout=label_layout)

# User Interface
uploadBox = widgets.HBox([training_upload_dataset, training_upload_output])
dataset_input_box = widgets.VBox([widgets.HBox([training_dataset_label, selectedDataset, dataset_ddl_output]),
                                  uploadBox,
                                  widgets.HBox([training_no_label, training_no]),
                                  widgets.HBox([training_testing_no_label, testing_no]), datasetOutput])
display(dataset_input_box)
display(dataset_Confirm_Button)

In [None]:
%%html
<!-- Left-align every table on the page (e.g. the Markdown parameter tables); !important overrides any other margin-left rule. -->
<style>
  table {margin-left: 0 !important;}
</style>

## Initialise Training

| Variable | Description |
| :- | :- |
| Video Type | Select from the available types (`rgb`, `skeleton`); however, only RGB is supported for now |
|Model Name|Name of the model|
|Batch Size|Number of samples per training batch|
|Epoch|Number of training iterations|
|Kernel|Number of Concurrent Processes|

In [None]:
# Training parameters collected from the widgets below and passed to ./TSU/train.py.
# Note left by the author: "UI" marks values that come from user input.
'''
-load_model = ./models/PDAN_TSU_RGB [ui]
-mode = RGB [hardcoded]
'''

class user_input:
    """Holds the arguments that the "Train Model" button passes to ./TSU/train.py.

    Fields left as None are filled in by the "Confirm Input" callback. That callback
    prints the fields by iterating ``vars(user)``, so the assignment order below is
    also the printed order.
    """
    def __init__(self):
        self.model_name = None  # "Model Name" text with spaces replaced by '_'
        self.batch_size = None  # "Batch Size" widget value
        self.epoch = None  # "Epoch" widget value
        self.kernelsize = None  # "Kernel" widget value
        self.mode = 'RGB'  # replaced by the "Video Type" dropdown value ('rgb' / 'skeleton')
        self.comp_info = 'TSU_CS_RGB_PDAN'
        self.train = True
        self.dataset_type = 'TSU'
        self.dataset = None  # "train_" + the selected dataset file name
        self.num_channel = 512
        self.APtype = 'map'
        self.model = 'PDAN'
        self.lr = 0.0002  # passed as -lr (learning rate)

        
modes = ['rgb', 'skeleton']
# Allow user to select type of TSU evaluation
training_video_mode = widgets.Dropdown(
    options=modes,
    value=modes[0],
    disabled=False,
)


# Specify a name for this new model using appropriate UI elements.
desired_model_name = widgets.Text(
    placeholder='Type something..',
    disabled=False
)

# Integer inputs below give no max, so ipywidgets' default max (100) applies to each.
# Provide user with batch size
user_batch_size = widgets.BoundedIntText(
    min=1,
    step=1,
    disabled=False
)

# Provide user with epoch selection
epochs = widgets.BoundedIntText(
    min=1,
    step=1,
    disabled=False
)

# Provide user with kernel input
kernel = widgets.BoundedIntText(
    min=2,
    step=1,
    disabled=False
)

# Confirm Button to provide user with virtual commit
training_Confirm_Button = widgets.Button(
    description='Confirm Input',
    layout=btn_layout,
    button_style='info',
    disabled=False
)

# Output to contain printout for ALL user input
trainingParamOutput = widgets.Output()
with trainingParamOutput:
    print("Currently Empty")

user = user_input()  # Instantiate class here
training_UserInput_label = widgets.Label("Currently Stored Input")

def store_training_params(b):
    """Button callback: copy the widget values into ``user`` and print everything stored."""
    with trainingParamOutput:
        trainingParamOutput.clear_output()  # Clear previous output
        if selectedDataset.value is None:
            print("Select a dataset in the Dataset Selection section first.")
            return
        train_data = "train_" + selectedDataset.value
        user.dataset = train_data
        # Spaces would split the name into several arguments on the %run command line.
        user.model_name = desired_model_name.value.replace(" ", "_")
        user.batch_size = user_batch_size.value
        user.epoch = epochs.value
        user.kernelsize = kernel.value
        user.mode = training_video_mode.value
        print(train_data)
        for attr_name, attr_value in vars(user).items():
            print("{}:{}".format(attr_name, attr_value))

training_Confirm_Button.on_click(store_training_params)
# Labels
training_video_label = widgets.Label("Video Type", layout=label_layout)
training_modelName_label = widgets.Label("Model Name", layout=label_layout)
training_batchSize_label = widgets.Label("Batch Size", layout=label_layout)
training_epochs_label = widgets.Label("Epoch", layout=label_layout)
training_kernel_label = widgets.Label("Kernel", layout=label_layout)

inputBox = widgets.VBox([widgets.VBox([widgets.HBox([training_video_label, training_video_mode]),
                                       widgets.HBox([training_modelName_label, desired_model_name]),
                                       widgets.HBox([training_batchSize_label, user_batch_size]),
                                       widgets.HBox([training_epochs_label, epochs]),
                                       widgets.HBox([training_kernel_label, kernel])]),
                         widgets.VBox([training_UserInput_label, trainingParamOutput])])
display(inputBox)

# Select Dataset
split_setting = ""
if "CS" in selectedDataset.value:
    split_setting = "CS"
elif "CV" in selectedDataset.value:
    split_setting = "CV"

bigRedButton = widgets.Button(
    description='Train Model',
    disabled = False,
    button_style='info',
    layout=btn_layout
)

# Output train
trainOutput = widgets.Output()

# Functionality for user input
def runTrain(b):
    with trainOutput:
        clear_output()
        if user.model_name == "" or user.model_name == None:
            print("Model name required.")
        elif desired_Project_name.value == "" or desired_Project_name.value == None:
            print("Project name required.")
        elif os.environ.get('WANDB_API_KEY') == None:
            print("WandB API Key required.")
        else:
            %run ./TSU/train.py -batch_size $user.batch_size \
            -model_name $user.model_name -epoch $user.epoch  \
            -kernelsize $user.kernelsize \
            -mode $user.mode -comp_info $user.comp_info -train $user.train \
            -dataset_type $user.dataset_type -num_channel $user.num_channel -APtype $user.APtype\
            -model $user.model -lr $user.lr -dataset $user.dataset -split_setting $split_setting -wandb_project $desired_Project_name.value

bigRedButton.on_click(runTrain)
display(widgets.VBox([training_Confirm_Button,bigRedButton,trainOutput]))

## 7. Testing
This section takes the models trained in the Training section and evaluates their results to see how accurate each model is.

Your dataset has already been selected in the Training section.

Parameters needed for running the test:
1. `-load_model`: the trained model (generated in the Training section) to evaluate
2. `-epoch`
3. `-dataset`: e.g. `train_smarthome_cs.json`

**TODO:** list all the test videos from the JSON file.

In [None]:
# Training Dataset Upload Label
training_dataset_label = widgets.Label("Select dataset or upload your own")

# Create list for use with user dataset selection
model_list = populateList("./TSU/PDAN/",) # populate list with files from dataset directory

if not model_list:
    print ("No existing datasets found")

# Allow user to select dataset
selected_model = widgets.Dropdown(
    options = model_list,
    value= model_list[0],
    disabled=False,
)


def on_change(change):
    with model_ddl_output:
        if change['type'] == 'change' and change['name'] == 'value':
            clear_output()
            print("Selected: %s " % change['new'])

selected_model.observe(on_change)

def run_test(b):
    with test_output:
        clear_output()
        model_w_path = "./TSU/PDAN/" + selected_model.value
        %run ./TSU/test.py -load_model $model_w_path -dataset $user.dataset 

# Widget
btn_testing = widgets.Button(description="Run Testing", layout = btn_layout, button_style='info')

btn_testing.on_click(run_test)

##output widget
test_output = widgets.Output()
model_ddl_output = widgets.Output()

# Display

display(widgets.VBox([selected_model,model_ddl_output,btn_testing,test_output]))

## View WandB Charts

To view the charts produced within WandB here, make sure the project's privacy setting is PUBLIC.
<br> 
URL Format: https://wandb.ai/username/project/overview
<br>
Example: https://wandb.ai/2002133sit/TSU-project/overview

In [None]:
# Embed the WandB project page in the notebook.
# NOTE(review): the "<entity>/<project>" is hard-coded to the author's account; replace it
# with your own (the project must be public, as noted above).
%wandb 2002133sit/TSU-project