# ## Camera

In [None]:
# Frame source for data collection: the on-board CSI camera at 224x224
# (the same size the Task cell resizes training images to).
# To use a USB camera instead, swap in the commented import and constructor.
from jetcam.csi_camera import CSICamera
# from jetcam.usb_camera import USBCamera

camera = CSICamera(height=224, width=224)
# camera = USBCamera(height=224, width=224)

# Turn on continuous capture; later cells read the latest frame from `camera.value`.
camera.running = True


# ## Task

In [None]:
import torchvision.transforms as transforms
from xy_dataset import XYDataset

# Prefix of every dataset identifier ('<TASK>_<name>', passed to XYDataset).
TASK = 'road_following'

# Single label category; each saved sample carries an (x, y) target for it.
CATEGORIES = ['apex']

# Datasets that can be selected for recording in the Data Widgets cell.
DATASETS = ['A', 'B', "automatic", "automatic_mini", "automatic_observer", "automatic_loop"]

# Preprocessing pipeline: random colour jitter (augmentation), resize to
# 224x224, tensor conversion, then normalization with the standard ImageNet
# per-channel mean / std.
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# One XYDataset per entry of DATASETS, keyed by its short name.
# NOTE(review): random_hflip presumably also mirrors the x label; confirm in xy_dataset.py.
datasets = {
    name: XYDataset(f'{TASK}_{name}', CATEGORIES, TRANSFORMS, random_hflip=True)
    for name in DATASETS
}


# ## Controller

In [None]:
import ipywidgets.widgets as widgets

# Gamepad input through the ipywidgets Controller widget. Change `index` to
# select a different connected gamepad.
# NOTE(review): this widget's axes are filled in by its browser-side view, so it
# may not receive any input unless it is displayed. Confirm before relying on
# `controller.axes` with the display line below commented out.
controller = widgets.Controller(index=0)

# Uncomment the line below to display the controller widget
# display(controller)
print("Controller is made!")


# ## Teleoperations

In [None]:
from jetracer.nvidia_racecar import NvidiaRacecar
import traitlets

# Create the racecar handle. This cell does NOT set throttle gain or steering
# offset; they keep whatever defaults NvidiaRacecar provides.
# TODO: set them explicitly here if your car needs calibration (attribute
# names such as `throttle_gain` / `steering_offset`; confirm in jetracer).
car = NvidiaRacecar()




In [None]:
# Import the JoystickController class from the joystick_control.py script
from joystick_control import JoystickController

# Create an instance of the JoystickController.
# NOTE(review): this rebinds `controller`, replacing the ipywidgets Controller
# created in the "Controller" cell. The data-collection cells below use
# `controller.axes[i].value` and `controller.axes[i].observe(...)`, which is the
# ipywidgets Controller API. Confirm JoystickController exposes the same
# interface; otherwise bind it to a different name.
controller = JoystickController()

# Start processing joystick input, presumably to drive the car. What start()
# does is defined in joystick_control.py.
controller.start()

# ## Data Widgets

In [None]:
import ipywidgets.widgets as widgets
import cv2
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget

# Active dataset; starts on the first entry, matching the dropdown's default.
dataset = datasets[DATASETS[0]]

# Drop observers left on the camera by a previous run of this cell (e.g. the
# preview link below), so re-running does not keep feeding stale widgets.
camera.unobserve_all()

# Live preview: each new camera frame is JPEG-encoded into the image widget.
camera_widget = ClickableImageWidget(width=camera.width, height=camera.height)
traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)

# Data collection controls
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='Dataset')
category_widget = ipywidgets.Dropdown(options=dataset.categories, description='Category')
count_widget = ipywidgets.IntText(description='Count')
state_widget = ipywidgets.ToggleButtons(options=['stop', 'start'], description='Start Record Data', value='stop')

# Show the sample count for the initial dataset/category.
count_widget.value = dataset.get_count(category_widget.value)


def set_dataset(change):
    """Dropdown observer: make the selected dataset active and refresh the count."""
    global dataset
    dataset = datasets[change['new']]
    count_widget.value = dataset.get_count(category_widget.value)


def update_counts(change):
    """Dropdown observer: refresh the count for the newly selected category."""
    count_widget.value = dataset.get_count(change['new'])


# Each observer is attached exactly once. Previously set_dataset was defined
# and attached twice, so every dataset switch ran the callback twice.
dataset_widget.observe(set_dataset, names='value')
category_widget.observe(update_counts, names='value')

# Sliders that mirror the recorded joystick values
x_slider = widgets.FloatSlider(min=-1.0, max=1.0, step=0.001, description='x (left joystick)')
y_slider = widgets.FloatSlider(min=-1.0, max=1.0, step=0.001, description='y (right joystick)', orientation='vertical')

# Display the data collection widgets
display(
    ipywidgets.VBox([
        ipywidgets.HBox([camera_widget, x_slider, y_slider]),
        dataset_widget,
        category_widget,
        count_widget,
        state_widget
    ])
)


# ## Database for Training Data

In [None]:
import contextlib
import datetime
import sqlite3

# Create the TrainData table that the data-collection cells below insert into.
# contextlib.closing guarantees the connection is closed even if a statement
# fails; sqlite3's own connection context manager only commits or rolls back.
with contextlib.closing(sqlite3.connect('train_database.db')) as connection:
    cursor = connection.cursor()
    # file_name stores the value returned by dataset.save_entry (presumably the
    # image's file name string), so it is declared TEXT rather than INTEGER.
    # x and y stay TEXT because the inserts store str(x) / str(y).
    # NOTE: CREATE TABLE IF NOT EXISTS does not alter an existing table; an old
    # train_database.db keeps its previous schema until it is migrated or removed.
    cursor.execute('''
CREATE TABLE IF NOT EXISTS TrainData (
id INTEGER PRIMARY KEY,
task TEXT NOT NULL,
category TEXT NOT NULL,
dataset TEXT NOT NULL,
file_name TEXT NOT NULL,
x TEXT NOT NULL,
y TEXT NOT NULL,
datetime TEXT NOT NULL
)
''')
    connection.commit()


# ## Automatic Data Collection using Observer Pattern

In [None]:
import datetime
import sqlite3

# Make this cell safe to re-run: detach the handler a previous run attached
# (otherwise every axis change is recorded once per run) and close the
# connection that handler used.
_previous_axes_handler = globals().get('call_back_for_observer_axes')
if _previous_axes_handler is not None:
    for _axis_index in (0, 3):
        try:
            controller.axes[_axis_index].unobserve(_previous_axes_handler, names='value')
        except ValueError:
            pass  # not attached to this axis (e.g. `controller` was replaced)
    connection.close()

# Connection used by the observer below for the lifetime of the kernel.
connection = sqlite3.connect('train_database.db')
cursor = connection.cursor()


def call_back_for_observer_axes(change):
    """Record one training sample when a watched controller axis changes.

    Reads axis 0 as x and axis 3 (sign-inverted) as y and mirrors both onto the
    sliders. Saves the current camera frame to the active ``dataset`` and
    appends a matching row to the TrainData table.

    Args:
        change: traitlets change dict; only ``change['new']`` is inspected.
    """
    # NOTE(review): a change to exactly 0.0 (e.g. a stick returning to centre)
    # is falsy and is not recorded. Confirm this filtering is intended.
    # NOTE(review): unlike the threaded loop below, this ignores the
    # "Start Record Data" toggle and records whenever an axis moves.
    if not change['new']:
        return

    x = controller.axes[0].value
    y = controller.axes[3].value * -1  # inverted, presumably so "stick up" is positive
    x_slider.value = x
    y_slider.value = y

    # Read the category once so the saved image and the DB row agree.
    category = category_widget.value

    # Save to disk
    filename = dataset.save_entry(category, camera.value, x, y)

    cursor.execute(
        'INSERT INTO TrainData (task, category, dataset, file_name, x, y, datetime) VALUES (?, ?, ?, ?, ?, ?, ?)',
        (
            TASK,
            category,
            dataset_widget.value,
            filename,
            str(x),
            str(y),
            # Stored as explicit text: sqlite3's implicit datetime adapter is
            # deprecated since Python 3.12; str() yields the same text it stored.
            str(datetime.datetime.now()),
        ),
    )
    connection.commit()

    count_widget.value = dataset.get_count(category)


# Record whenever either watched axis changes.
controller.axes[0].observe(call_back_for_observer_axes, names='value')
controller.axes[3].observe(call_back_for_observer_axes, names='value')


# ## Automatic Data Collection in a Loop with a Dedicated Thread

In [None]:
import threading
import time
import datetime
import sqlite3

# Remember the observer registered by a previous run of this cell (if any) so
# it can be detached below; otherwise each switch to 'start' would launch one
# recording thread per run of this cell.
_previous_start_live = globals().get('start_live')


def loop_func_record_data(state_widget, controller, dataset, x_slider, y_slider, count_widget, task, category_widget, dataset_widget):
    """Record (camera frame, joystick x/y) samples until recording is stopped.

    Intended to run on a worker thread. About every 0.1 s, while
    ``state_widget.value == "start"``, it reads axis 0 as x and axis 3
    (sign-inverted) as y and mirrors them onto the sliders. It saves the
    global ``camera``'s current frame into ``dataset`` and appends a matching
    row to the TrainData table.

    Args:
        state_widget: toggle; the loop runs while its value is ``"start"``.
        controller: object exposing ``axes[i].value`` (axes 0 and 3 are read).
        dataset: dataset that frames are saved into (fixed for this run).
        x_slider, y_slider: widgets that display the recorded x / y values.
        count_widget: widget showing the number of samples in the category.
        task: task name stored in every database row.
        category_widget: selects the category each sample is recorded under.
        dataset_widget: dataset dropdown; its value when the loop starts is
            stored as the dataset name of every row.
    """
    # `dataset` is fixed for the lifetime of this thread, so capture its name
    # once. Reading dataset_widget on every iteration would label rows with a
    # different dataset than the one the images go to if the dropdown is
    # changed mid-recording.
    dataset_name = dataset_widget.value

    # sqlite3 connections are tied to the creating thread, so each recording
    # thread opens its own, and always closes it.
    connection = sqlite3.connect('train_database.db')
    try:
        cursor = connection.cursor()
        while True:
            time.sleep(0.1)  # caps recording at about 10 samples per second
            # Check after sleeping so nothing is recorded once "stop" is pressed.
            if state_widget.value != "start":
                break

            x = controller.axes[0].value
            y = controller.axes[3].value * -1
            x_slider.value = x
            y_slider.value = y

            # Read the category once so the saved image and the DB row agree.
            category = category_widget.value

            # Save to disk
            filename = dataset.save_entry(category, camera.value, x, y)

            cursor.execute(
                'INSERT INTO TrainData (task, category, dataset, file_name, x, y, datetime) VALUES (?, ?, ?, ?, ?, ?, ?)',
                (
                    task,
                    category,
                    dataset_name,
                    filename,
                    str(x),
                    str(y),
                    # Stored as explicit text: sqlite3's implicit datetime adapter
                    # is deprecated since Python 3.12; str() yields the same text.
                    str(datetime.datetime.now()),
                ),
            )
            connection.commit()

            count_widget.value = dataset.get_count(category)
    finally:
        connection.close()


def start_live(change):
    """State-widget observer: start a recording thread when switched to 'start'.

    The thread exits on its own once the widget leaves 'start'.
    """
    if change['new'] == 'start':
        execute_thread = threading.Thread(target=loop_func_record_data,
                                          args=(state_widget, controller, dataset, x_slider, y_slider, count_widget, TASK, category_widget, dataset_widget))
        execute_thread.start()


# Detach the observer from a previous run, then start/stop recording whenever
# the state widget changes.
if _previous_start_live is not None:
    try:
        state_widget.unobserve(_previous_start_live, names='value')
    except ValueError:
        pass  # it was attached to an earlier state_widget instance
state_widget.observe(start_live, names='value')
