# # Task

In [None]:
import torchvision.transforms as transforms
from xy_dataset import XYDataset

# Task name; it prefixes the name of every dataset created below
TASK = 'road_following'

# Categories that are labelled in the datasets
CATEGORIES = ['apex']

# Suffixes of the datasets that can be selected
DATASETS = ['A', 'B', "automatic", "automatic_mini", "automatic_loop", "automatic_observer"]

# Transform pipeline shared by all datasets: color jitter, resize to 224x224,
# conversion to a tensor, then per-channel normalization
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# One XYDataset per entry of DATASETS, keyed by its suffix
datasets = {
    name: XYDataset(f'{TASK}_{name}', CATEGORIES, TRANSFORMS, random_hflip=True)
    for name in DATASETS
}

# The first listed dataset is active until another one is selected
dataset = datasets[DATASETS[0]]

# # Dataset widget

In [None]:
import ipywidgets

# Dropdowns for picking the dataset and the category
dataset_widget = ipywidgets.Dropdown(description='Dataset', options=DATASETS)
category_widget = ipywidgets.Dropdown(description='Category', options=dataset.categories)

# The count box starts out showing the count for the initially selected category
count_widget = ipywidgets.IntText(
    description='Count',
    value=dataset.get_count(category_widget.value),
)

def set_dataset(change):
    """Make the dataset picked in the dropdown the active one.

    ``change`` is the change notification of ``dataset_widget``; its ``'new'``
    entry is the key into ``datasets``. The count box is refreshed for the
    currently selected category.
    """
    global dataset
    selected_name = change['new']
    dataset = datasets[selected_name]
    current_category = category_widget.value
    count_widget.value = dataset.get_count(current_category)
dataset_widget.observe(set_dataset, names='value')

def update_counts(change):
    """Show the active dataset's count for the newly selected category.

    ``change`` is the change notification of ``category_widget``; its ``'new'``
    entry is the category that was just selected.
    """
    new_category = change['new']
    count_widget.value = dataset.get_count(new_category)
category_widget.observe(update_counts, names='value')

# Stack the dataset controls vertically and render them
data_collection_widget = ipywidgets.VBox(
    children=[dataset_widget, category_widget, count_widget]
)

display(data_collection_widget)

# # Model

In [None]:
import torch
import torchvision
import ipywidgets

# Device the model is moved to
device = torch.device('cuda')

# The network predicts an x and a y coordinate for every category
output_dim = len(dataset.categories) * 2

# Uncomment the model you want to use

# ALEXNET
# model = torchvision.models.alexnet(pretrained=True)
# model.classifier[-1] = torch.nn.Linear(4096, output_dim)

# SQUEEZENET 
# model = torchvision.models.squeezenet1_1(pretrained=True)
# model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)
# model.num_classes = len(dataset.categories)

# RESNET 18: pretrained backbone with the final layer replaced by a
# 512 -> output_dim linear layer
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(in_features=512, out_features=output_dim)

# RESNET 34
# model = torchvision.models.resnet34(pretrained=True)
# model.fc = torch.nn.Linear(512, output_dim)

# DENSENET 121
# model = torchvision.models.densenet121(pretrained=True)
# model.classifier = torch.nn.Linear(model.num_features, output_dim)

model = model.to(device)

# Text box holding the checkpoint path, plus the buttons that act on it
model_path_widget = ipywidgets.Text(value='road_following_model.pth', description='Model Path')
model_load_button = ipywidgets.Button(description='Load Model')
model_save_button = ipywidgets.Button(description='Save Model')

def load_model(c):
    """Load the weights stored at the path in ``model_path_widget`` into ``model``.

    Args:
        c: Argument passed by the button's click handler; not used.
    """
    # map_location puts the stored tensors on the current device, so a
    # checkpoint that was saved on a different device still deserializes.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    state_dict = torch.load(model_path_widget.value, map_location=device)
    model.load_state_dict(state_dict)
model_load_button.on_click(load_model)

def save_model(c):
    """Write the model's state dict to the path in ``model_path_widget``.

    Args:
        c: Argument passed by the button's click handler; not used.
    """
    target_path = model_path_widget.value
    weights = model.state_dict()
    torch.save(weights, target_path)
model_save_button.on_click(save_model)

# Path box on top, load/save buttons side by side underneath
model_buttons_row = ipywidgets.HBox(children=[model_load_button, model_save_button])
model_widget = ipywidgets.VBox(children=[model_path_widget, model_buttons_row])

display(model_widget)

# # Creating a Database for Storing Training Process Information

In [None]:
import sqlite3, datetime

# Open (or create) the local SQLite database that stores the training log
connection = sqlite3.connect('train_database.db')

try:
    # The connection context manager commits on success and rolls back if
    # the statement fails
    with connection:
        cursor = connection.cursor()

        # Create table TrainDataProcess to save data during model training
        cursor.execute('''
CREATE TABLE IF NOT EXISTS TrainDataProcess (
id INTEGER PRIMARY KEY,
name_train TEXT NOT NULL,
model TEXT NOT NULL,
optimizer TEXT NOT NULL,
count_imgs INTEGER NOT NULL,
batch_size INTEGER NOT NULL,
epoch INTEGER NOT NULL,
epoch_all INTEGER NOT NULL,
datetime TEXT NOT NULL,
loss_epoch REAL NOT NULL
)
''')
finally:
    # Always release the connection, even when table creation raised
    connection.close()

# # Training 

In [None]:
import time
import sqlite3, datetime

# Connection to the local SQLite database
connection = sqlite3.connect('train_database.db')

# Number of samples per batch
BATCH_SIZE = 8

# Optimizer over all model parameters (swap the comment to use SGD instead)
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters())

# Controls and read-outs for the train / evaluate loop
epochs_widget = ipywidgets.IntText(value=1, description='Epochs')
train_button = ipywidgets.Button(description='Train')
eval_button = ipywidgets.Button(description='Evaluate')
progress_widget = ipywidgets.FloatProgress(description='Progress', min=0.0, max=1.0)
loss_widget = ipywidgets.FloatText(description='Loss')

# Text box whose value is stored with every logged row as the run name
name_train_model_widget = ipywidgets.Text(value='Enter Test Name', description='Test Name')

def train_eval(is_training):
    """Train the model for the requested epochs, or run one evaluation pass.

    Iterates over the active ``dataset`` in shuffled batches of ``BATCH_SIZE``,
    computes the mean squared error between the predicted and the labelled
    x, y coordinates, updates the progress and loss widgets, and appends one
    row per batch to the ``TrainDataProcess`` table of ``train_database.db``.

    Args:
        is_training: When true, gradients are computed, the optimizer is
            stepped, and the pass repeats until ``epochs_widget`` has counted
            down to zero. When false, a single pass is made without gradient
            tracking. In both cases nothing runs unless ``epochs_widget`` is
            positive.

    Exceptions raised during the run are printed instead of propagated. The
    database connection is always closed, the model is always returned to
    eval mode, and both buttons are always re-enabled.
    """
    global model  # rebound below by model.train() / model.eval()

    # Remember the initial number of epochs (the widget counts down while training)
    initial_epoch = epochs_widget.value

    # A fresh connection is opened for every run and closed in ``finally``.
    # Closing a shared module-level connection here would make every later
    # click fail with "Cannot operate on a closed database".
    db_connection = None

    try:
        db_connection = sqlite3.connect('train_database.db')
        cursor = db_connection.cursor()

        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        train_button.disabled = True
        eval_button.disabled = True
        time.sleep(1)

        if is_training:
            model = model.train()
        else:
            model = model.eval()

        # Autograd bookkeeping is only needed when parameters get updated
        with torch.set_grad_enabled(is_training):
            while epochs_widget.value > 0:
                i = 0
                sum_loss = 0.0
                for images, category_idx, xy in train_loader:
                    # Send data to device
                    images = images.to(device)
                    xy = xy.to(device)

                    if is_training:
                        # Zero gradients of parameters
                        optimizer.zero_grad()

                    # Execute model to get outputs
                    outputs = model(images)

                    # Compute MSE loss over x, y coordinates for associated categories
                    loss = 0.0
                    for batch_idx, cat_idx in enumerate(list(category_idx.flatten())):
                        loss += torch.mean((outputs[batch_idx][2 * cat_idx:2 * cat_idx+2] - xy[batch_idx])**2)
                    loss /= len(category_idx)

                    if is_training:
                        # Run backpropagation to accumulate gradients
                        loss.backward()

                        # Step optimizer to adjust parameters
                        optimizer.step()

                    # Increment progress
                    count = len(category_idx.flatten())
                    i += count
                    sum_loss += float(loss)
                    progress_widget.value = i / len(dataset)
                    # NOTE(review): sum_loss adds up per-batch mean losses while i
                    # counts samples, so this value is smaller than the mean
                    # per-sample loss. Kept as-is so new rows stay comparable with
                    # rows that were already logged.
                    loss_widget.value = sum_loss / i

                    # Add a new row of data about the training process (one row per batch)
                    cursor.execute('INSERT INTO TrainDataProcess (name_train, model, optimizer, count_imgs, batch_size, epoch, epoch_all, datetime, loss_epoch) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)',
                                   (name_train_model_widget.value,
                                    type(model).__name__,  # Model name
                                    type(optimizer).__name__,  # Optimizer name
                                    count_widget.value,  # Number of images
                                    BATCH_SIZE,  # Batch size
                                    epochs_widget.value,  # Epochs still remaining, including this one
                                    initial_epoch,  # Total epochs
                                    # Stored as text explicitly: the implicit sqlite3
                                    # datetime adapter is deprecated since Python 3.12.
                                    # This produces the same 'YYYY-MM-DD HH:MM:SS.ffffff' text.
                                    datetime.datetime.now().isoformat(sep=' '),
                                    loss_widget.value)  # Running loss value
                                   )

                    # Save changes to the local SQLite database
                    db_connection.commit()

                if is_training:
                    epochs_widget.value = epochs_widget.value - 1
                else:
                    break

    except Exception as e:
        print(e)

    finally:
        # Runs on success, on error and on kernel interrupt alike, so the
        # connection is released and the UI never stays locked
        if db_connection is not None:
            db_connection.close()

        model = model.eval()

        train_button.disabled = False
        eval_button.disabled = False

def _on_train_clicked(_button):
    """Click handler of the Train button: run ``train_eval`` in training mode."""
    train_eval(is_training=True)


def _on_eval_clicked(_button):
    """Click handler of the Evaluate button: run ``train_eval`` in evaluation mode."""
    train_eval(is_training=False)


train_button.on_click(_on_train_clicked)
eval_button.on_click(_on_eval_clicked)

# Epoch box, progress bar and loss on top, then the two buttons side by side,
# then the run-name box
train_buttons_row = ipywidgets.HBox(children=[train_button, eval_button])
train_eval_widget = ipywidgets.VBox(children=[
    epochs_widget,
    progress_widget,
    loss_widget,
    train_buttons_row,
    name_train_model_widget,
])

display(data_collection_widget, train_eval_widget, model_widget)