## Import Packages

In [None]:
from fastbook import *
from fastai.vision.widgets import *
import torch
from fastai.tabular.all import *

## Constants

In [None]:
# Notebook-wide settings; the cells below read these names directly.
img_size: int = 128  # side length handed to Resize(), so images become img_size x img_size
training_data_path: str = "Images"  # dataset root; labels come from each image's parent folder name
model_path: str = "Model/"  # prefix for saved artifacts (keep the trailing slash: paths are concatenated)

## Data Loading & Augmentation

In [None]:
# Describe how to turn the image folder into labelled (image, category) samples:
# every image file found becomes an input, and its parent folder name is its label.
model_data = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    # Hold out 20% for validation; the fixed seed keeps the split identical across runs.
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    item_tfms=Resize(img_size),
)

In [None]:
# Materialise train/valid DataLoaders without augmentation (rebuilt with augmentation in the following cell).
dls = model_data.dataloaders(source=training_data_path)

In [None]:
# Same pipeline as model_data, plus batch-level augmentation.
# Per fastai docs, mult=2 scales the default rotation/lighting/warp strength.
batch_augmentations = aug_transforms(mult=2)
model_aug = model_data.new(item_tfms=Resize(img_size), batch_tfms=batch_augmentations)
dls = model_aug.dataloaders(training_data_path)

# Preview a few augmented training samples (unique=True repeats one image under different augmentations).
dls.train.show_batch(max_n=4, nrows=2, unique=True)

## Model Training

In [2]:
# Transfer learning on a ResNet-34 backbone (fastai loads pretrained weights by default) with a new
# classification head; accuracy and error rate are reported on the validation set each epoch.
learn = vision_learner(dls, arch=resnet34, metrics=(accuracy, error_rate))

In [None]:
# Run fastai's learning-rate range test (also plots loss vs. learning rate).
# lr_find() returns a SuggestedLRs namedtuple, not a float. Pull the suggested value out explicitly
# instead of relying on fastai silently coercing the tuple when it is later passed as `lr`.
lr_suggestions = learn.lr_find()
learning_rate = lr_suggestions.valley
learning_rate

In [None]:
# Train for 3 epochs at a constant learning rate.
# `learning_rate` may be a plain float or the SuggestedLRs namedtuple returned by lr_find();
# normalise it to a float so `fit` receives a documented value type.
base_lr = getattr(learning_rate, "valley", learning_rate)
# NOTE(review): per fastai docs, a pretrained vision_learner starts frozen, so `fit` here only trains
# the new head. Consider `learn.fine_tune(...)` if the backbone should be updated too.
learn.fit(3, lr=base_lr)

## Results

In [None]:
# Collect predictions and losses on the validation set (fastai's default dataset for interpretation),
# then draw the confusion matrix.
interp = ClassificationInterpretation.from_learner(learn=learn)
interp.plot_confusion_matrix()

In [None]:
# Show the 5 highest-loss validation samples, one per row, to spot mislabelled or confusing images.
interp.plot_top_losses(k=5, nrows=5)

In [None]:
# Print a per-class precision / recall / F1 / support report for the interpreted (validation) predictions.
interp.print_classification_report()

In [None]:
from pathlib import Path

# Create the output folder first: both calls below open the target file directly and raise
# FileNotFoundError if model_path does not exist yet.
Path(model_path).mkdir(parents=True, exist_ok=True)

# fastai export: pickles the Learner (without its data) for later use with load_learner().
# The fname is resolved against learn.path, which defaults to '.' for dataloaders built as above.
learn.export(model_path + "Model.pkl")

# NOTE(review): torch.save(learn, ...) pickles the entire Learner object, so loading it still needs
# fastai and this notebook's code. Save learn.model (or its state_dict) if a plain PyTorch artifact is intended.
torch.save(learn, model_path + "Pt_Model.pt")

## Convert the Model to ONNX

In [None]:
import torch
import torch.onnx

# NOTE(review): the export cell writes Model.pkl, but this loads Model2.pkl (and the ONNX output is
# named ...2.onnx). Confirm this is the intended checkpoint. load_learner unpickles the file, so only
# load trusted files.
learn = load_learner(model_path + 'Model2.pkl')
# eval() puts layers such as BatchNorm/Dropout into inference mode before the graph is traced.
model = learn.model.eval()

# Example input used only to trace the graph: one 3-channel image at the training resolution.
# torch.autograd.Variable is a deprecated no-op wrapper (PyTorch >= 0.4), so the plain tensor is used directly.
dummy_data = torch.randn(1, 3, img_size, img_size)
dummy_input = dummy_data

output_model_name = 'onnx_greenscape2.onnx'
input_name = 'image'
output_name = 'warn'

# NOTE(review): learn.model contains only the network. The normalisation fastai applies in the
# DataLoader batch transforms, and the final softmax, are not part of the exported graph.
# ONNX consumers presumably need to resize, normalise and apply softmax themselves — confirm.
torch.onnx.export(
    model,
    dummy_input,
    model_path + output_model_name,
    input_names=[input_name],
    output_names=[output_name],
)

In [None]:
import onnx

# Reload the exported file and run ONNX's structural validator (check_model raises if the model is invalid).
onnx_model = onnx.load(f"{model_path}{output_model_name}")
onnx.checker.check_model(onnx_model)

# Display the source PyTorch module as this cell's output for visual inspection.
model