In [1]:
import html
import os

import ipywidgets
#from jinja2 import Template
from IPython.display import display, Image, HTML

# --- Home page: choose between shop management and shopping ----------------
ModeSelectPageTitle_widget = ipywidgets.HTML(value="<b>Home</b>")
Manage_widget = ipywidgets.Button(description='Shop Manage')
Shopping_widget = ipywidgets.Button(description='Start Shopping')

ModeSelect_widget = ipywidgets.VBox([
    ModeSelectPageTitle_widget,
    ipywidgets.HBox([Manage_widget, Shopping_widget])
])

# --- Add-product page --------------------------------------------------------
AddProductPageTitle_widget = ipywidgets.HTML(value="<b>Add Products</b>")
ProductName_widget = ipywidgets.Text(description='Prod.Name')
ProductPrice_widget = ipywidgets.Text(description='Price')
SaveProductTitle_widget = ipywidgets.HTML(value="Save Product")
Save_button_widget = ipywidgets.Button(description='Save')
Finish_button_widget = ipywidgets.Button(description='Start Training')
HomeButton_widget = ipywidgets.Button(description='Return Home')
UploadImageTitle_widget = ipywidgets.HTML(value="Upload Image")
PictureUpload_widget = ipywidgets.FileUpload(
    accept='image/*',
    multiple=False     # only a single file can be uploaded at a time
)

# Set up product image directory location if it is not there already
DSPIMG_DIR = '/nvdli-nano/data/shopping_cart/display_image/'
os.makedirs(DSPIMG_DIR, exist_ok=True)

# Placeholder preview shown until a product image is uploaded. Read through a
# context manager so the file handle is closed, and kept under its own name so
# the `Image` class imported from IPython.display is not shadowed.
with open("NoImage.jpg", "rb") as no_image_file:
    NO_IMAGE_BYTES = no_image_file.read()
ImagePreview_widget = ipywidgets.Image(value=NO_IMAGE_BYTES, format='png', width=225, height=225)

AddProduct_widget = ipywidgets.VBox([
    AddProductPageTitle_widget,
    ProductName_widget,
    ProductPrice_widget,
    ipywidgets.HBox([UploadImageTitle_widget,PictureUpload_widget]),
    ImagePreview_widget,
    ipywidgets.HBox([SaveProductTitle_widget, Save_button_widget]),
    ipywidgets.HBox([Finish_button_widget, HomeButton_widget])
])

# Single output area into which every "page" of the app is rendered.
out = ipywidgets.Output()

def retunToHome(x):
    """Swap the page shown in ``out`` for the home (mode-select) page.

    ``x`` is the button passed in by ``on_click``; it is not used.
    """
    out.clear_output()
    with out:
        display(ModeSelect_widget)

HomeButton_widget.on_click(retunToHome)

def handleUpload(change):
    """Show the freshly uploaded file's bytes in the preview image widget.

    ``change`` is the traitlets change record for ``data``; the bytes are read
    back from the upload widget itself.
    """
    uploaded_bytes = PictureUpload_widget.data[0]
    ImagePreview_widget.value = uploaded_bytes

PictureUpload_widget.observe(handleUpload, names='data')
        
# Update product master with product name and price
# Also save display image to display_image directory
def saveToFile(x):
    """Register the product entered on the add-product page.

    The form is validated first: the name must be non-empty and comma-free,
    the price must be numeric, and an image must have been uploaded. Any
    problem is reported in ``out`` and nothing is written, so no partial
    record is created. On success the uploaded image is written to
    ``DSPIMG_DIR/<name>.jpg`` and a ``name,price`` line is appended to
    ProductMaster.csv.

    Args:
        x: The clicked button (supplied by ``on_click``; unused).
    """
    name = ProductName_widget.value.strip()
    price = ProductPrice_widget.value.strip()

    # ProductMaster.csv rows are parsed by splitting on the first comma, so a
    # comma in the name would corrupt the row. The price is later fed to float().
    problem = None
    if not name or ',' in name:
        problem = "Please enter a product name (commas are not allowed)."
    else:
        try:
            float(price)
        except ValueError:
            problem = "Please enter a numeric price."
    if problem is None and not PictureUpload_widget.data:
        problem = "Please upload a product image before saving."
    if problem is not None:
        with out:
            print(problem)
        return

    # Write the image before the CSV row so the master list never names a
    # product whose display image is missing.
    with open(DSPIMG_DIR + name + '.jpg', "wb") as image_file:
        image_file.write(PictureUpload_widget.data[0])
    with open('ProductMaster.csv', 'a', newline='') as write_obj:
        write_obj.write(name + ',' + price + '\n')

Save_button_widget.on_click(saveToFile)

# The app opens on the home page.
with out:
    display(ModeSelect_widget)

def switchMode(x):
    """Swap the page shown in ``out`` for the add-product (shop management) page.

    ``x`` is the button passed in by ``on_click``; it is not used.
    """
    out.clear_output()
    with out:
        display(AddProduct_widget)

Manage_widget.on_click(switchMode)
#out


In [2]:
import torchvision.transforms as transforms
from dataset import ImageClassificationDataset

TASK = 'products'

DATASETS = ['Primary', 'Secondary']

# Class names for the classifier, one per product in ProductMaster.csv.
CATEGORIES = []

def setCategories():
    """Refresh CATEGORIES in place from the product names in ProductMaster.csv.

    The first comma-separated field of every non-blank line (whitespace
    stripped) becomes one category. Repeated products keep only their first
    position, so a product saved twice does not turn into two classes.

    The list object is mutated rather than rebound because it is handed to
    other code (e.g. the dataset constructors), which may keep a reference.

    If the file cannot be opened, or lists no products, a single 'Dummy'
    category is used. This ensures the classifier always has at least one
    output.
    """
    CATEGORIES.clear()
    try:
        with open("ProductMaster.csv", "r") as table:
            for row in table:
                name = row.split(",", 1)[0].strip()
                if name and name not in CATEGORIES:
                    CATEGORIES.append(name)
    except OSError:
        print("File open error!")
    if not CATEGORIES:
        CATEGORIES.append('Dummy')

setCategories()
        
# Preprocessing for every sample:
# - random colour jitter,
# - resize to 224x224,
# - tensor conversion,
# - per-channel normalisation with the usual ImageNet mean/std values.
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Dataset name (from DATASETS) -> ImageClassificationDataset.
datasets = {}

def setdatasets():
    """(Re)build one dataset per name in DATASETS from the current CATEGORIES.

    Each dataset is constructed with the path ``data/<TASK>_<name>``. The
    ``datasets`` dict is updated in place rather than replaced, so other
    references to it stay valid.
    """
    for dataset_name in DATASETS:
        dataset_root = 'data/' + TASK + '_' + dataset_name
        datasets[dataset_name] = ImageClassificationDataset(dataset_root, CATEGORIES, TRANSFORMS)

setdatasets()

In [3]:
# Check device number
#!ls -ltrh /dev/video*
from jetcam.usb_camera import USBCamera

# USB camera (e.g. a Logitech C270 webcam) configured for 225x225 frames.
# If creation fails, check & confirm the capture_device number (see `ls /dev/video*`).
camera = USBCamera(
    width=225,
    height=225,
    capture_device=0,
)

# Start capturing; camera.value is read by the preview link and the live loop.
camera.running = True

In [4]:
import traitlets
from jetcam.utils import bgr8_to_jpeg
import torch.nn.functional as F

# The first entry of DATASETS is the active dataset when the notebook starts.
dataset = datasets[DATASETS[0]]

# Drop observers left on the camera by an earlier run of this cell.
camera.unobserve_all()

# Live preview: every new camera frame is passed through bgr8_to_jpeg and
# pushed into the image widget.
camera_widget = ipywidgets.Image()
traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)

# Controls for collecting training images.
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='Dataset')
product_widget = ipywidgets.Dropdown(options=dataset.categories, description='Product')
count_widget = ipywidgets.IntText(description='Count')
save_widget = ipywidgets.Button(description='Add')

# Default checkpoint location; update it if the project lives in a different folder.
modelPath = '/nvdli-nano/data/shopping_cart/data/my_model.pth'

# No observer has fired yet, so fill in the initial count by hand.
count_widget.value = dataset.get_count(product_widget.value)

def set_dataset(change):
    """Make the dataset picked in the dropdown the active one and refresh its count.

    ``change`` is the traitlets change record; ``change['new']`` is the
    selected dataset name.
    """
    global dataset
    selected_name = change['new']
    dataset = datasets[selected_name]
    current_product = product_widget.value
    count_widget.value = dataset.get_count(current_product)

dataset_widget.observe(set_dataset, names='value')

def update_counts(change):
    """Show how many images the active dataset holds for the newly selected product.

    ``change['new']`` is the product name just chosen in the dropdown.
    """
    selected_product = change['new']
    count_widget.value = dataset.get_count(selected_product)

product_widget.observe(update_counts, names='value')

def save(c):
    """Store the current camera frame under the selected product, then refresh the count.

    ``c`` is the button passed in by ``on_click``; it is not used.
    """
    product = product_widget.value
    dataset.save_entry(camera.value, product)
    count_widget.value = dataset.get_count(product)

save_widget.on_click(save)

#Model - Begin -
import torch
import torchvision

device = torch.device('cuda')

# Classifier: torchvision ResNet-18 with pretrained weights, its final
# fully-connected layer replaced by one output per product category.
# Other backbones that can be swapped in:
#   ALEXNET:    model = torchvision.models.alexnet(pretrained=True)
#               model.classifier[-1] = torch.nn.Linear(4096, len(dataset.categories))
#   SQUEEZENET: model = torchvision.models.squeezenet1_1(pretrained=True)
#               model.classifier[1] = torch.nn.Conv2d(512, len(dataset.categories), kernel_size=1)
#               model.num_classes = len(dataset.categories)
#   RESNET 34:  model = torchvision.models.resnet34(pretrained=True)
#               model.fc = torch.nn.Linear(512, len(dataset.categories))
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, len(dataset.categories))
model = model.to(device)

# Save/load controls; the path box is pre-filled with modelPath.
SaveModelHeader_widget = ipywidgets.HTML(value="<b><em>Save Model</em></b>")
model_save_button = ipywidgets.Button(description='save model')
model_load_button = ipywidgets.Button(description='load model')
model_path_widget = ipywidgets.Text(description='model path', value=modelPath)

def load_model(c=None):
    """Load saved weights into ``model`` from the path in the model-path box.

    The path that was loaded is remembered in the module-level ``modelPath``.

    Args:
        c: Button passed by ``on_click`` (unused). It defaults to ``None`` so
            the function can also be called directly, e.g. ``load_model()``.
            Without the parameter, clicking the button raised TypeError.

    Raises:
        FileNotFoundError: if no checkpoint exists at the path.
        RuntimeError: if the checkpoint does not match the model's layers
            (e.g. it was trained for a different number of products).
    """
    global modelPath
    modelPath = model_path_widget.value
    # map_location lets a checkpoint saved on another device load onto ours.
    model.load_state_dict(torch.load(modelPath, map_location=device))

model_load_button.on_click(load_model)
    
def save_model(c):
    """Save ``model``'s weights to the path in the model-path box and remember it.

    Missing parent directories are created first.

    Args:
        c: Button passed by ``on_click`` (unused).
    """
    # `global` is required: without it the assignment below only creates a
    # local variable and the module-level modelPath never changes.
    global modelPath
    path = model_path_widget.value
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), path)
    modelPath = path

model_save_button.on_click(save_model)

# Model persistence panel: heading, path box, then load/save buttons side by side.
model_widget = ipywidgets.VBox(children=[
    SaveModelHeader_widget,
    model_path_widget,
    ipywidgets.HBox(children=[model_load_button, model_save_button]),
])
#Model - End -

In [5]:
#Training - Begin -
BATCH_SIZE = 8

# Adam with its default hyper-parameters over every model weight.
# Alternative: torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters())

# Headings
TrainingPageTitle_widget = ipywidgets.HTML(value="<b>Training</b>")
AddTrainingDataHeader_widget = ipywidgets.HTML(value="<b><em>Add Training Data</em></b>")
TrainHeader_widget = ipywidgets.HTML(value="<b><em>Train Model</em></b>")

# Training inputs and action buttons
epochs_widget = ipywidgets.IntText(description='Epochs', value=1)
eval_button = ipywidgets.Button(description='Evaluate')
train_button = ipywidgets.Button(description='Train')

# Read-outs updated while training runs
loss_widget = ipywidgets.FloatText(description='Loss')
accuracy_widget = ipywidgets.FloatText(description='Accuracy')
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='Progress')

PageSeperator_widget = ipywidgets.HTML(value="<b>---</b>")

def _train_eval_pass(loader, is_training):
    """Run one pass over ``loader``; update progress/loss/accuracy widgets after every batch."""
    sample_count = 0
    sum_loss = 0.0
    error_count = 0
    for images, labels in loader:
        # send data to device
        images = images.to(device)
        labels = labels.to(device)

        if is_training:
            optimizer.zero_grad()

        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)

        if is_training:
            # backpropagate, then let the optimizer adjust the parameters
            loss.backward()
            optimizer.step()

        # number of samples whose top-scoring class differs from the label
        error_count += int((outputs.argmax(1) != labels).sum())
        sample_count += len(labels.flatten())
        sum_loss += float(loss)
        progress_widget.value = sample_count / len(dataset)
        loss_widget.value = sum_loss / sample_count
        accuracy_widget.value = 1.0 - error_count / sample_count


def train_eval(is_training):
    """Train the model on the active dataset, or run one evaluation pass over it.

    While running, the live-recognition loop is stopped (``state_widget`` is
    set to 'stop') and the Train/Evaluate buttons are disabled. The following
    always happen afterwards, even on error:
      - the model is put back in eval mode,
      - the buttons are re-enabled,
      - the state is set back to 'live'.
    Errors are reported with ``print`` instead of being raised.

    Args:
        is_training: If True, run ``epochs_widget.value`` epochs, decrementing
            the widget after each one. If False, run a single pass with no
            gradient updates.
    """
    global model
    import time  # imported here as well: the notebook's own `import time` is in a later cell

    try:
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        state_widget.value = 'stop'
        train_button.disabled = True
        eval_button.disabled = True
        time.sleep(1)  # give the live loop a moment to notice 'stop' and exit

        model = model.train() if is_training else model.eval()
        # Only build autograd graphs when we are going to back-propagate.
        with torch.set_grad_enabled(is_training):
            if is_training:
                while epochs_widget.value > 0:
                    _train_eval_pass(train_loader, True)
                    epochs_widget.value = epochs_widget.value - 1
            else:
                _train_eval_pass(train_loader, False)
    except Exception as e:
        # Previously `except e:` here raised NameError and hid the real error.
        print("Model Train Error : {}".format(e))
    finally:
        model = model.eval()
        train_button.disabled = False
        eval_button.disabled = False
        state_widget.value = 'live'

train_button.on_click(lambda c: train_eval(is_training=True))
eval_button.on_click(lambda c: train_eval(is_training=False))
    
# Training panel: epoch count, progress/metric read-outs and the Train button.
# The Evaluate button exists but is deliberately kept out of the layout for now.
train_eval_widget = ipywidgets.VBox(children=[
    TrainHeader_widget,
    epochs_widget,
    progress_widget,
    loss_widget,
    accuracy_widget,
    ipywidgets.HBox(children=[train_button]),
])
#Training - End -

In [6]:
# Data-collection controls shown beside the live camera preview.
dataAdd_Widget = ipywidgets.VBox(children=[
    dataset_widget,
    product_widget,
    count_widget,
    save_widget,
])

# Full training page, top to bottom.
trainingPage_widget = ipywidgets.VBox(children=[
    TrainingPageTitle_widget,
    AddTrainingDataHeader_widget,
    ipywidgets.HBox(children=[dataAdd_Widget, camera_widget]),
    train_eval_widget,
    model_widget,
    PageSeperator_widget,
    ipywidgets.HBox(children=[Shopping_widget, HomeButton_widget]),
])

def trainingPage(x):
    """Show the training page after reloading the product catalogue.

    The steps are:
      - re-read the categories from ProductMaster.csv,
      - rebuild the datasets,
      - point the module-level ``dataset`` at the fresh dataset selected in
        ``dataset_widget``.
    This way image counts, saved samples and training all use the refreshed
    category list.

    Args:
        x: The clicked button (supplied by ``on_click``; unused).
    """
    global dataset
    out.clear_output()
    with out:
        setCategories()
        setdatasets()
        # setdatasets() replaces the entries of `datasets`; without rebinding,
        # `dataset` would still be the object built before this refresh.
        dataset = datasets[dataset_widget.value]
        product_widget.options = dataset.categories
        if product_widget.value is not None:
            count_widget.value = dataset.get_count(product_widget.value)
        # NOTE(review): model.fc is sized from len(dataset.categories) when the
        # model cell runs. If products were added since then, the head no longer
        # matches the category count until that cell is re-run.
        display(trainingPage_widget)

Finish_button_widget.on_click(trainingPage)
#out

In [7]:
import threading
import time
from utils import preprocess
import torch.nn.functional as F

# Run state of the recognition loop. It starts as 'live'; train_eval() switches
# it to 'stop' to halt the loop while training. No stop button is displayed.
state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='live')

# Read-outs for the recognised product and its price.
prediction_widget = ipywidgets.Text(description='', layout=ipywidgets.Layout(width='225px'))
price_widget = ipywidgets.Text(description='', layout=ipywidgets.Layout(width='225px'))

# Placeholder for the cart table; it is replaced by the rendered table later.
TempTable_widget = ipywidgets.HTML(value="<b>Product Price</b>", layout=ipywidgets.Layout(width='450px'))
debug_widget = ipywidgets.Text(description='', layout=ipywidgets.Layout(width='450px'))

# Per-category score sliders (not displayed). The loop variable `score_widget`
# is intentionally left bound at module level: start_live() passes it on.
score_widgets = []
for category in dataset.categories:
    score_widget = ipywidgets.FloatSlider(
        min=0.0, max=1.0, description=category, orientation='vertical'
    )
    score_widgets.append(score_widget)

ProdList_dict = {}   # products recognised for the current shopper: name -> price
ProdTable_dict = {}  # every product in the shop, read from ProductMaster.csv: name -> price

# Live execution
def live(state_widget, model, camera, prediction_widget, score_widget, threshold=0.97, idle_delay=0.1):
    """Continuously classify camera frames and update the shopping cart.

    Runs until ``state_widget.value`` is no longer 'live'. Each iteration
    classifies the current frame, then does one of two things.

    If the top softmax probability is <= ``threshold``:
      - the placeholder image is shown;
      - the loop waits ``idle_delay`` seconds before retrying.

    Otherwise:
      - the predicted product name is shown;
      - its display image is previewed (the placeholder if that file is
        missing);
      - if the product has a known price, it is added to ``ProdList_dict``,
        the cart table is redrawn and the price is shown;
      - the loop waits one second.

    Args:
        state_widget: Toggle whose value must stay 'live' for the loop to continue.
        model: Classifier; it is switched to eval mode before the loop starts.
        camera: Camera whose ``value`` is the frame to classify.
        prediction_widget: Text widget that receives the predicted product name.
        score_widget: Unused (per-category scores are not displayed).
        threshold: Minimum top-class probability required to accept a prediction.
        idle_delay: Seconds to wait after a rejected prediction. This keeps the
            loop from spinning and flooding the widgets.
    """
    # Read the placeholder once, with the handle closed, instead of reopening
    # it on every rejected frame.
    with open("NoImage.jpg", "rb") as placeholder_file:
        placeholder_image = placeholder_file.read()

    # A freshly built (or just-trained) model may be in train mode; inference
    # must use eval-mode BatchNorm/Dropout.
    model.eval()
    while state_widget.value == 'live':
        preprocessed = preprocess(camera.value)
        with torch.no_grad():
            output = model(preprocessed)
        output = F.softmax(output, dim=1).detach().cpu().numpy().flatten()

        if output.max() <= threshold:
            # Prediction not confident enough: show the placeholder and retry shortly.
            ImagePreview_widget.value = placeholder_image
            time.sleep(idle_delay)
            continue

        productName = dataset.categories[output.argmax()]
        prediction_widget.value = productName

        # Display preview image (not the actual camera image)
        try:
            with open(DSPIMG_DIR + productName + ".jpg", "rb") as image_file:
                ImagePreview_widget.value = image_file.read()
        except OSError:
            ImagePreview_widget.value = placeholder_image

        price = ProdTable_dict.get(productName)
        if price is None:
            # Recognised a category with no price entry (e.g. 'Dummy'); don't add it to the cart.
            price_widget.value = ''
        else:
            ProdList_dict[productName] = price
            print_dict_as_html_table(ProdList_dict)
            price_widget.value = price
        time.sleep(1)

# Thread running live(); kept so repeated "Start Shopping" clicks don't stack threads.
_live_thread = None

def start_live(change):
    """Start the background recognition loop unless one is already running.

    Args:
        change: Ignored. It is kept so the function can also serve as a widget
            observer callback.
    """
    global _live_thread
    if _live_thread is not None and _live_thread.is_alive():
        return
    _live_thread = threading.Thread(
        target=live,
        args=(state_widget, model, camera, prediction_widget, score_widget),
        daemon=True,  # never keep the kernel process alive just for this loop
    )
    _live_thread.start()
    
def updateProductTable():
    """Load ``name -> price`` pairs from ProductMaster.csv into ProdTable_dict.

    Each line is split on its first comma. Surrounding whitespace, including
    the trailing newline, is stripped from both fields. Blank lines and lines
    without a comma are skipped. A later line for the same product overwrites
    an earlier one.

    Raises:
        OSError: if ProductMaster.csv cannot be opened.
    """
    with open("ProductMaster.csv", "r") as table:
        for row in table:
            name, separator, price = row.partition(",")
            name = name.strip()
            if not separator or not name:
                continue
            ProdTable_dict[name] = price.strip()

def print_dict_as_html_table(prod_dict):
    """Render the shopping cart as an HTML price table on the shopping page.

    Builds a table with one row per product (price shown to two decimals) and
    a footer row with the total of the displayed prices. The table then
    replaces the last child of ``shoppingPage_widget``, and the replaced
    widget is closed.

    Args:
        prod_dict: Mapping of product name -> price. Prices may be anything
            ``float()`` accepts, e.g. the strings read from ProductMaster.csv.
    """
    total = 0.0
    # Table styling
    html_list = ["<style> table, th, td { border: 1px solid black; border-collapse: collapse; border-style: dotted; padding-left: 15px;}  tr:nth-child(even) {  background-color: rgba(150, 212, 212, 0.4); } tfoot td { background-color: rgba(250, 210, 110, 0.4); font-weight:bold;}</style>"]

    html_list.append("<table style=\"width:450px\">")
    html_list.append("<tr>")
    html_list.append("<th style=\"width:70%\">Product</th>")
    html_list.append("<th>Price</th>")
    html_list.append("</tr>")

    for name, raw_price in prod_dict.items():
        price = format(float(raw_price), ".2f")
        html_list.append("<tr>")
        # Product names are user-entered text; escape them before embedding in markup.
        html_list.append("<td>{0}</td>".format(html.escape(str(name))))
        html_list.append("<td>{0}</td>".format(price))
        html_list.append("</tr>")
        # Sum the rounded prices so the total matches the rows shown.
        total += float(price)

    html_list.append("<tfoot> <tr> <td>Total</td> <td>{}</td> </tr> </tfoot>".format(format(total, ".2f")))
    html_list.append("</table>")

    html_string = ' '.join(str(elem) for elem in html_list)

    new_table = ipywidgets.HTML(value=html_string)
    old_table = shoppingPage_widget.children[-1]
    # Swap the table in a single children update, then release the old widget.
    shoppingPage_widget.children = shoppingPage_widget.children[:-1] + (new_table,)
    old_table.close()
        
#state_widget.observe(start_live, names='value')

# Shopping page layout. The cart table must remain the LAST child, because
# print_dict_as_html_table() replaces children[-1] on every redraw.
# (The score sliders, state toggle and debug box are intentionally not shown.)
shoppingPage_widget = ipywidgets.VBox(children=[
    ipywidgets.HBox(children=[camera_widget, ImagePreview_widget], layout=ipywidgets.Layout(height='225px')),
    ipywidgets.HBox(children=[prediction_widget, price_widget]),
    HomeButton_widget,
    PageSeperator_widget,
    TempTable_widget,
])

def shoppingPage(x):
    """Show the shopping page and start live product recognition.

    Before recognition starts, the saved model weights and the product price
    table are loaded. If either fails, the reason is printed in ``out`` and
    the home page is shown instead of a blank screen. This covers a missing
    checkpoint or CSV, and a checkpoint trained for a different number of
    products.

    Args:
        x: The clicked button (supplied by ``on_click``; unused).
    """
    out.clear_output()
    with out:
        try:
            load_model()
            updateProductTable()
        except (OSError, RuntimeError) as e:
            print("Cannot start shopping: {}. Add products, then train and save a model first.".format(e))
            display(ModeSelect_widget)
            return
        start_live('live')
        display(shoppingPage_widget)

Shopping_widget.on_click(shoppingPage)
out

Output()

In [8]:
#List = [0.045, 0.98, 0.85]
#print("List : ", List)
#print("Max : ", List.max())