# In[1]:
#!/usr/bin/env python
# coding: utf-8


# !pip install voila
# !jupyter serverextension enable --sys-prefix voila 




import ipywidgets as widgets
from ipywidgets import FileUpload
from PIL import Image
import numpy as np
import torch.nn as nn
import torch
from io import BytesIO



class BasicCnn(nn.Module):
    """Five-stage convolutional feature extractor plus a three-layer linear head.

    Every stage is a 3x3 convolution (stride 1, padding 1, so spatial size is
    kept) followed by a 3x3 max-pool with stride 2 and padding 1, which halves
    the spatial size (rounding up): 224 -> 112 -> 56 -> 28 -> 14 -> 7.
    A final adaptive average pool reduces each of the 128 channels to 1x1.

    Args:
        output_shape: number of output logits produced by the head (default 8).

    NOTE(review): there are no nonlinearities between any of the layers, so
    the linear head collapses to one affine map. This is left untouched
    because the saved weights were trained with exactly this layout.
    """

    def __init__(self, output_shape = 8):
        super().__init__()

        # (in_channels, out_channels) of each conv + max-pool stage, in order.
        stage_channels = [(3, 32), (32, 64), (64, 64), (64, 128), (128, 128)]
        feature_layers = []
        for in_ch, out_ch in stage_channels:
            feature_layers.append(nn.Conv2d(in_ch, out_ch, kernel_size = 3, stride = 1, padding = 1))
            feature_layers.append(nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1))
        # Global average pool: (N, 128, H, W) -> (N, 128, 1, 1).
        feature_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.cnnModel = nn.Sequential(*feature_layers)

        # Linear head widths: 128 -> 64 -> 32 -> output_shape.
        head_widths = [128, 64, 32, output_shape]
        self.dnnModel = nn.Sequential(
            *(nn.Linear(width_in, width_out)
              for width_in, width_out in zip(head_widths, head_widths[1:]))
        )

    def forward(self, x):
        """Return logits for the image batch ``x``.

        ``squeeze`` removes *every* size-1 dimension, so for a batch of one
        the batch axis is dropped as well and the result has shape
        ``(output_shape,)``; for larger batches it is ``(N, output_shape)``.
        """
        pooled = torch.squeeze(self.cnnModel(x))
        return self.dnnModel(pooled)
    
# Load the trained classifier. torch.load returns the pickled model object
# directly, so the former `model = BasicCnn(2)` line was dead code: its result
# was overwritten immediately. The BasicCnn class above must still be defined
# under the same name, because pickle resolves the class by name.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine; without it torch.load tries to restore tensors onto CUDA and fails
# before model.cpu() is ever reached.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# files. On torch >= 2.6 the default weights_only=True rejects full-module
# pickles; pass weights_only=False there (trusted files only).
model = torch.load('model.pkl', map_location='cpu')
model.cpu()
model.eval()  # inference mode for the widget callback



# Human-readable labels, looked up by the index of the model's largest logit.
# NOTE(review): the order must match the label order used during training -- confirm.
classes = ['lao song', 'not lao song']

# UI widgets: a file picker, an output area for the image preview,
# a label for the prediction text, and the button that triggers classification.
upload = widgets.FileUpload(multiple=True)
out_pl = widgets.Output()
lbl_pred = widgets.Label()
btn_run = widgets.Button(description='Classify')



def _latest_upload_content(value):
    """Return the bytes of the last file in a ``FileUpload.value``, or None if empty.

    Accepts both layouts of ``FileUpload.value``: a dict mapping file name to
    an info dict (ipywidgets 7) and a tuple of info dicts (ipywidgets 8). In
    both, the file bytes live under the ``'content'`` key.
    """
    files = list(value.values()) if isinstance(value, dict) else list(value)
    if not files:
        return None
    # bytes() normalises ipywidgets 8's memoryview content to plain bytes.
    return bytes(files[-1]['content'])


def _image_to_batch(img):
    """Convert an RGB PIL image into a float tensor of shape (1, 3, H, W) in [0, 1]."""
    # np.array (not asarray) copies, giving torch.from_numpy a writable buffer.
    pixels = torch.from_numpy(np.array(img))        # (H, W, 3), uint8
    return pixels.permute(2, 0, 1).unsqueeze(0).float() / 255


def on_click_classify(change):
    """Classify the most recently uploaded image and show the result.

    Button callback registered on ``btn_run``. ``change`` is the widget passed
    by ``on_click`` and is not used. Reads the module globals ``upload``,
    ``model``, ``classes``, ``out_pl`` and ``lbl_pred``. It previews the
    resized image in ``out_pl`` and writes the predicted class and its
    softmax probability to ``lbl_pred``.
    """
    content = _latest_upload_content(upload.value)
    if content is None:
        # Previously an IndexError escaped the callback when nothing was uploaded.
        lbl_pred.value = 'Please upload an image first.'
        return

    # Force 3 channels: RGBA (e.g. PNG with alpha) or grayscale images would
    # otherwise produce 4 or 1 channels, but the first Conv2d expects 3.
    img = Image.open(BytesIO(content)).convert('RGB').resize((224, 224))
    data = _image_to_batch(img)

    out_pl.clear_output()
    with out_pl:
        # NOTE(review): `display` is assumed to come from the IPython/Jupyter
        # namespace (it is not imported in this file).
        display(img)

    with torch.no_grad():
        predict = model(data)
    # BasicCnn.forward squeezes away the batch axis for a batch of one, giving
    # shape (len(classes),). Flattening makes this also correct if the model
    # keeps the batch axis and returns (1, len(classes)).
    probs = torch.softmax(predict.reshape(-1), dim=0)
    best = int(probs.argmax())
    lbl_pred.value = f'Prediction: {classes[best]}; Probability: {probs[best].item():.04f}'

# Classify the latest upload whenever the button is pressed.
btn_run.on_click(on_click_classify)



#hide_output
# Stack the UI vertically: prompt, file picker, button, image preview, result.
vbox = widgets.VBox(
    [
        widgets.Label('Select your lao song!'),
        upload,
        btn_run,
        out_pl,
        lbl_pred,
    ]
)
vbox






# Out[1]: VBox(children=(Label(value='Select your lao song!'), FileUpload(value={}, description='Upload', multiple=True)…