# Create Gunpla Classifier App

## Create Dataloader

In [1]:
import dataset
from torchvision import transforms

# Create training transform with TrivialAugment
train_transform = transforms.Compose([
    dataset.SquarePad(fill=255),
    transforms.Resize((224,224)),
    transforms.TrivialAugmentWide(num_magnitude_bins=31),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_validate_transform = transforms.Compose([
    dataset.SquarePad(fill=255),
    transforms.Resize((224,224)),
    #transforms.TrivialAugmentWide(num_magnitude_bins=31),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


train_dataloader, test_dataloader, validate_dataloader, names = dataset.create_dataloaders(
        train_transform=train_transform,
        test_transform=test_validate_transform,
        validate_transform=test_validate_transform,
    )

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(names)

['RG01 RX-78-2 Gundam', "RG02 MS-06S Char's Zaku II", 'RG03 GAT-X-105 Aile Strike Gundam', 'RG04 MS-06F Zaku II', 'RG05 ZGMF-X10A Freedom Gundam', 'RG06 FX-550 Sky Grasper', 'RG07 RX-178 Gundam Mk-II Titans', 'RG08 RX-178 Gundam Mk-II A.E.U.G.', 'RG09 ZGMF-X09A Justice Gundam', 'RG10 MSZ-006 Zeta Gundam', 'RG11 ZGMF-X42S Destiny Gundam', 'RG12 RX-78GP01 Zephyranthes', 'RG13 RX-78GP01fb Full Burnern', 'RG14 ZGMF-X20A Strike Freedom Gundam', 'RG15 GN-001 Gundam Exia', "RG16 MSM-07S Char's Z'gok", 'RG17 XXXG-00W0 Wing Gundam Zero EW', 'RG18 GN-0000-GNR-010 OO Raiser', 'RG19 MBF-P02 Gundam Astray Red Frame', 'RG20 XXXG-01W Wing Gundam EW', 'RG21 GNT-0000 OO Qan[T]', 'RG22 MSN-06S Sinanju', 'RG23 Build Strike Gundam Full Package', 'RG24 Gundam Astray Gold Frame Amatsu Mina', 'RG25 RX-0 Unicorn Gundam', "RG26 MS-06R-2 Johnny Ridden's Zaku II", 'RG27 RX-0[N] Unicorn Gundam 02 Banshee Norn', 'RG28 OZ-00MS Tallgeese EW', 'RG29 MSN-04 Sazabi', 'RG30 RX-0 Full Armor Unicorn Gundam', 'RG31 XM-X1 C

## Create ViT model


In [3]:
import torch
import torchvision
import dataset

from torch import nn
from torchvision import transforms

def create_vit_model(model_path: str,
                     num_classes:int=3, 
                     seed:int=42):
    model = torchvision.models.vit_b_16()
    model.heads = nn.Linear(in_features=768, out_features=num_classes) 

    model.load_state_dict(torch.load(f=model_path))
    
    transform = transforms.Compose([
        dataset.SquarePad(fill=255),
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    
    return model, transform

In [4]:
vit, transform = create_vit_model("./models/pretrained_vit2.pt", len(names))
device = "cuda" if torch.cuda.is_available() else "cpu"
_=vit.to(device)

## predict

In [5]:
from timeit import default_timer as timer 
from typing import Tuple, Dict

def predict(model, transform, names, device) -> Tuple[Dict, float]:
    def inner_func(img):
    	"""Transforms and performs a prediction on img and returns prediction and time taken.
    	"""
    	# Start the timer
    	start_time = timer()
    	
    	# Transform the target image and add a batch dimension
    	img = transform(img).unsqueeze(0).to(device)
    	
    	# Put model into evaluation mode and turn on inference mode
    	model.eval()
    	with torch.inference_mode():
    	    # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
    	    pred_probs = torch.softmax(model(img), dim=1)
    	
    	# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
    	pred_labels_and_probs = {names[i]: float(pred_probs[0][i]) for i in range(len(names))}
    	
    	# Calculate the prediction time
    	pred_time = round(timer() - start_time, 5)
    	
    	# Return the prediction dictionary and prediction time 
    	return pred_labels_and_probs, pred_time

    return inner_func

In [6]:
import data_utils
from PIL import Image

imgs=data_utils.get_random_images('./data/gunpla', k=3)

vit_predict = predict(vit, test_validate_transform, names, device)

for p in imgs:
    img = Image.open(p)
    pred_dict, pred_time = vit_predict(img)
    print(f"Prediction label and probability dictionary: \n{pred_dict}")
    print(f"Prediction time: {pred_time} seconds")

Prediction label and probability dictionary: 
{'RG01 RX-78-2 Gundam': 0.03296208009123802, "RG02 MS-06S Char's Zaku II": 0.00015252716548275203, 'RG03 GAT-X-105 Aile Strike Gundam': 0.03440404683351517, 'RG04 MS-06F Zaku II': 0.00029010840808041394, 'RG05 ZGMF-X10A Freedom Gundam': 0.0002843623806256801, 'RG06 FX-550 Sky Grasper': 0.00213455967605114, 'RG07 RX-178 Gundam Mk-II Titans': 0.00015522667672485113, 'RG08 RX-178 Gundam Mk-II A.E.U.G.': 0.006870567332953215, 'RG09 ZGMF-X09A Justice Gundam': 0.00010574361658655107, 'RG10 MSZ-006 Zeta Gundam': 0.0002595948171801865, 'RG11 ZGMF-X42S Destiny Gundam': 0.00013282686995808035, 'RG12 RX-78GP01 Zephyranthes': 0.8586981296539307, 'RG13 RX-78GP01fb Full Burnern': 0.053710177540779114, 'RG14 ZGMF-X20A Strike Freedom Gundam': 1.4190071624398115e-06, 'RG15 GN-001 Gundam Exia': 0.004067098721861839, "RG16 MSM-07S Char's Z'gok": 1.914706444949843e-05, 'RG17 XXXG-00W0 Wing Gundam Zero EW': 0.0001819964381866157, 'RG18 GN-0000-GNR-010 OO Raiser

## Gradio

In [7]:
# Import/install Gradio 
try:
    import gradio as gr
except: 
    !pip -q install gradio
    import gradio as gr
    
print(f"Gradio version: {gr.__version__}")

Gradio version: 3.16.2


In [8]:
import os

# Get example filepaths in a list of lists
example_list = [["./data/playground/" + example] for example in os.listdir("./data/playground/")]
example_list

[['./data/playground/rg01_02.jpg'],
 ['./data/playground/rg02_02.jpg'],
 ['./data/playground/rg04_02.jpg'],
 ['./data/playground/rg03_02.jpg']]

In [9]:
import gradio as gr

# Create title, description and article strings
title = "Gunpla classifier"
description = "Which gunpla is this?"
article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."

# Create the Gradio demo
demo = gr.Interface(fn=vit_predict, # mapping function from input to output
                    inputs=gr.Image(type="pil"), # what are the inputs?
                    outputs=[gr.Label(num_top_classes=len(names), label="Predictions"), # what are the outputs?
                             gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
                    examples=example_list, 
                    title=title,
                    description=description,
                    article=article)

# Launch the demo!
#demo.launch(debug=True, share=True)

## Model.py

In [10]:
%%writefile model.py

import torch
import torchvision
import dataset

from torch import nn
from torchvision import transforms

def create_vit_model(model_path: str,
                     device: str,
                     num_classes:int=3
                    ):
    model = torchvision.models.vit_b_16()
    model.heads = nn.Linear(in_features=768, out_features=num_classes) 

    model.load_state_dict(torch.load(f=model_path, map_location=device))
    
    transform = transforms.Compose([
        dataset.SquarePad(fill=255),
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    
    return model, transform


def create_efficientnet_model(model_path: str,
                     device: str,
                     num_classes:int=3
                    ):
    model = torchvision.models.efficientnet_b2()
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True), 
        torch.nn.Linear(in_features=1408,         # note: this is different from b0
                    out_features=num_classes, # same number of output units as our number of classes
                    bias=True))

    model.load_state_dict(torch.load(f=model_path, map_location=device))
    
    transform = transforms.Compose([
        dataset.SquarePad(fill=255),
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    
    return model, transform

Overwriting model.py


## App.py

In [11]:
%%writefile app.py

import gradio as gr
import os
import torch

import model

from timeit import default_timer as timer
from typing import Tuple, Dict

class_names=['RG01 RX-78-2 Gundam', "RG02 MS-06S Char's Zaku II", 'RG03 GAT-X-105 Aile Strike Gundam', 'RG04 MS-06F Zaku II', 'RG05 ZGMF-X10A Freedom Gundam', 'RG06 FX-550 Sky Grasper', 'RG07 RX-178 Gundam Mk-II Titans', 'RG08 RX-178 Gundam Mk-II A.E.U.G.', 'RG09 ZGMF-X09A Justice Gundam', 'RG10 MSZ-006 Zeta Gundam', 'RG11 ZGMF-X42S Destiny Gundam', 'RG12 RX-78GP01 Zephyranthes', 'RG13 RX-78GP01fb Full Burnern', 'RG14 ZGMF-X20A Strike Freedom Gundam', 'RG15 GN-001 Gundam Exia', "RG16 MSM-07S Char's Z'gok", 'RG17 XXXG-00W0 Wing Gundam Zero EW', 'RG18 GN-0000-GNR-010 OO Raiser', 'RG19 MBF-P02 Gundam Astray Red Frame', 'RG20 XXXG-01W Wing Gundam EW', 'RG21 GNT-0000 OO Qan[T]', 'RG22 MSN-06S Sinanju', 'RG23 Build Strike Gundam Full Package', 'RG24 Gundam Astray Gold Frame Amatsu Mina', 'RG25 RX-0 Unicorn Gundam', "RG26 MS-06R-2 Johnny Ridden's Zaku II", 'RG27 RX-0[N] Unicorn Gundam 02 Banshee Norn', 'RG28 OZ-00MS Tallgeese EW', 'RG29 MSN-04 Sazabi', 'RG30 RX-0 Full Armor Unicorn Gundam', 'RG31 XM-X1 Crossbone Gundam X1', 'RG32 RX-93 Nu Gundam', 'RG33 ZGMF-X56S_α Force Impulse Gundam', 'RG34 MSN-02 Zeong', 'RG35 XXXG-01W Wing Gundam', 'RG36 RX-93-υ2 Hi-Nu Gundam', 'RG37 GF13-017NJII God Gundam']

device = "cuda" if torch.cuda.is_available() else "cpu"

vit, vit_transform = model.create_vit_model("./models/pretrained_vit2.pt", device, len(class_names))
vit=vit.to(device)

efficientnet, efficientnet_transform = model.create_efficientnet_model("./models/efficentnet_b2_argument.pt", device, len(class_names))
efficientnet=efficientnet.to(device)


def predict_func(model, transform, names, device) -> Tuple[Dict, float]:
    def inner_func(img):

    	"""Transforms and performs a prediction on img and returns prediction and time taken.
    	"""
    	# Start the timer
    	start_time = timer()
    	
    	# Transform the target image and add a batch dimension
    	img = transform(img).unsqueeze(0).to(device)
    	
    	# Put model into evaluation mode and turn on inference mode
    	model.eval()
    	with torch.inference_mode():
    	    # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
    	    pred_probs = torch.softmax(model(img), dim=1)
    	
    	# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
    	pred_labels_and_probs = {names[i]: float(pred_probs[0][i]) for i in range(len(names))}
    	
    	# Calculate the prediction time
    	pred_time = round(timer() - start_time, 5)
    	
    	# Return the prediction dictionary and prediction time 
    	return pred_labels_and_probs, pred_time

    return inner_func

vit_predict=predict_func(vit, vit_transform, class_names, device)
efficientnet_predict=predict_func(efficientnet, efficientnet_transform, class_names, device)


def predict(img, model="EfficientNet"):
    pf = vit_predict if model == "ViT" else efficientnet_predict
    
    return pf(img)

# Gradio app

title="Gunpla Classifier"
description="Which gunpla is this?"
example_list = [["./data/playground/" + example] for example in os.listdir("./data/playground/")]

demo = gr.Interface(
            fn=predict, 
            inputs=[
                gr.Image(type='pil', label="Upload Image"),
                gr.inputs.Dropdown(["EfficientNet", "ViT"], default="EfficientNet", label="Select Model"),
            ],
            outputs=[
                gr.Label(num_top_classes=3, label="Predictions"),
                gr.Number(label="Prediction time (s)"),
            ],
            examples=example_list,
            title=title,
            description=description,
        )


if __name__ == "__main__":
    demo.launch()

Overwriting app.py


In [12]:
import app
import importlib
importlib.reload(app)

app.demo.launch(share=True)



Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://c83dd364-e482-45f1.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces




## requirements.txt

In [13]:
%%writefile requirements.txt
torch==1.13.0
torchvision==0.14.0
gradio==3.10.1

Overwriting requirements.txt


## Pack the files for deployment

In [15]:
%%writefile pack.sh
#!/bin/bash

tar zcvf gunpla_classifier.tgz app.py model.py dataset.py requirements.txt models/pretrained_vit2.pt models/efficentnet_b2_argument.pt data/playground data/gunpla.tgz

Overwriting pack.sh


In [16]:
!bash pack.sh

app.py
model.py
dataset.py
requirements.txt
models/pretrained_vit2.pt
models/efficentnet_b2_argument.pt
data/playground/
data/playground/rg01_02.jpg
data/playground/rg02_02.jpg
data/playground/rg04_02.jpg
data/playground/rg03_02.jpg
data/gunpla.tgz


The file gunpla_classifier.tgz is ready to deploy!

# Upload the app to huggingface

The app were already created [here](https://huggingface.co/spaces/4179e1/gunpla_classifier), clone it via
```
git clone https://huggingface.co/spaces/4179e1/gunpla_classifier
```

Move the file gunpla_classifier.tgz under gunpla_classifier and decompress it via
```
tar zxvf gunpla_classifier.tgz
```

Add the file to git

```
git add *
git lfs track data/gunpla.tgz
git lfs track models/efficentnet_b2_argument.pt
git lfs track models/pretrained_vit2.pt
git commit -a
```

Upload it to huggingface via

```
git push
```

Wait for a few minutes and checkout the results in https://huggingface.co/spaces/4179e1/gunpla_classifier