# Model

In [None]:
import sys
import configparser
sys.path.append("../../../../deps/brevitas/src/brevitas_examples") #bnn_pynq
from bnn_pynq.models.CNV import*

# Model configuration read by bnn_pynq's `cnv` factory:
# 1-bit weights and activations, 8-bit input, 2 output classes, 3 input channels.
config = configparser.ConfigParser()
config.read_dict({
    'QUANT': {
        'WEIGHT_BIT_WIDTH': '1',
        'ACT_BIT_WIDTH': '1',
        'IN_BIT_WIDTH': '8',
    },
    'MODEL': {
        'NUM_CLASSES': '2',
        'IN_CHANNELS': '3',
    },
})
model = cnv(config)

# Dataset

In [None]:
# Dependencies
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Lambda, Compose
from torchvision.io import read_image
from IPython.display import display
from ipywidgets import IntProgress
from torchvision.models import *
import random
import time
import gc 

# Select the compute device once. `device` stays a plain string ('cuda' or
# 'cpu') because later cells concatenate it and pass it to `.to(...)`.
# (The former bare `torch.cuda.device(device)` call was removed: it only
# builds a context manager and does nothing unless used in a `with` block.)
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
device = dev
if device == 'cuda':
    # Release cached GPU memory, e.g. left over from an earlier run in this kernel.
    gc.collect()
    torch.cuda.empty_cache()
print("using " + device)

# dataset
path_dataset = "../../../../notebooks/dataset/SugarWeed"
batch_size = 16
train_set = 0.8  # NOTE(review): unused below; the split comes from the train/ and val/ folders
test_set = 0.2   # NOTE(review): unused below
img_size = 32
rand_rotation = 30

data_dir = path_dataset

# Training-time augmentation: random rotation, a random crop resized to
# img_size x img_size, random horizontal flip, then conversion to a tensor.
train_transforms = transforms.Compose([
    transforms.RandomRotation(rand_rotation),
    transforms.RandomResizedCrop(img_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Validation must be deterministic (no random augmentation) so that the
# validation loss used for early stopping is comparable across epochs.
# Resize + CenterCrop yields the same img_size x img_size input shape.
test_transforms = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
])

# Data Loading
train_dataset = datasets.ImageFolder(f"{data_dir}/train", transform=train_transforms)
test_dataset = datasets.ImageFolder(f"{data_dir}/val", transform=test_transforms)

# ImageFolder enumerates samples class by class, so the training loader must
# shuffle; otherwise most mini-batches would contain a single class.
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

# Train

In [None]:
# import EarlyStopping
from pytorchtools import EarlyStopping
import numpy as np

In [None]:
# Training hyper-parameters: maximum epoch count and the early-stopping patience.
epochs, patience = 50, 10

# Multi-class cross-entropy loss.
loss_fn = nn.CrossEntropyLoss()

# Plain SGD (no momentum / weight decay) over every model parameter.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)

In [None]:
def _train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``loader``; return the list of per-batch losses."""
    progress = IntProgress(min=0, max=len(loader), description='Train:')
    display(progress)

    model.train()
    losses = []
    for data, target in loader:
        progress.value += 1
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


def _validate_one_epoch(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader``.

    Returns:
        (losses, accuracy, durations): per-batch losses, the fraction of
        correctly classified samples, and per-batch forward-pass wall times
        in seconds.
    """
    progress = IntProgress(min=0, max=len(loader), description='Validate:')
    display(progress)

    model.eval()
    losses, durations = [], []
    correct = 0
    # No gradients are needed for validation; avoids building autograd graphs.
    with torch.no_grad():
        for data, target in loader:
            progress.value += 1
            data, target = data.to(device), target.to(device)
            start = time.perf_counter()
            output = model(data)
            if output.is_cuda:
                # CUDA kernels run asynchronously; wait so the timing is real.
                torch.cuda.synchronize()
            durations.append(time.perf_counter() - start)
            losses.append(loss_fn(output, target).item())
            correct += (output.argmax(1) == target).sum().item()
    accuracy = correct / len(loader.dataset)
    return losses, accuracy, durations


def train_model(model, batch_size, patience, n_epochs):
    """Train ``model`` with early stopping and return the reloaded checkpoint.

    Uses notebook globals defined in earlier cells: ``trainloader``,
    ``testloader``, ``optimizer``, ``loss_fn`` and ``device``.

    Each epoch does one training pass over ``trainloader`` and one validation
    pass over ``testloader``. The mean validation loss is fed to
    ``EarlyStopping`` (pytorchtools). After the loop, weights are reloaded
    from ``checkpoint.pt``, presumably written by ``EarlyStopping``.

    Args:
        model: network to train; moved to ``device`` in place.
        batch_size: unused; kept for call compatibility (the loaders already
            carry their own batch size).
        patience: forwarded to ``EarlyStopping``.
        n_epochs: maximum number of epochs.

    Returns:
        (model, avg_train_losses, avg_valid_losses, total_epoch): the model with
        the checkpoint weights loaded, the per-epoch mean train and validation
        batch losses, and the last epoch that ran.
    """
    # Train on the device selected in the setup cell. nn.Module.to modifies the
    # module in place, so the existing Parameter objects are kept.
    # NOTE(review): optimizer was built before this move; verify it still tracks
    # the same parameters on your PyTorch version.
    model.to(device)

    avg_train_losses = []
    avg_valid_losses = []
    avg_dur = []
    total_epoch = 0
    epoch_len = len(str(n_epochs))

    early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(1, n_epochs + 1):
        print(f"Epoch {epoch}")
        train_losses = _train_one_epoch(model, trainloader, optimizer, loss_fn, device)
        valid_losses, accuracy, dur = _validate_one_epoch(model, testloader, loss_fn, device)

        # Average batch losses over the epoch.
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        avg_dur.append(np.average(dur))

        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')
        print(print_msg)
        print(f"Accuracy: {(100*accuracy):>0.1f}%, AVg dur: {np.average(dur)} \n")

        early_stopping(valid_loss, model)
        total_epoch = epoch

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # map_location keeps the load working whatever device saved the checkpoint.
    model.load_state_dict(torch.load('checkpoint.pt', map_location=device))

    return model, avg_train_losses, avg_valid_losses, total_epoch

In [None]:
# Run training with early stopping. `model` is rebound to the returned model
# (weights reloaded from checkpoint.pt); the per-epoch loss curves and the
# number of epochs actually run are kept for later cells.
model, avg_train_losses, avg_valid_losses, total_epoch = train_model(
    model,
    batch_size,
    patience,
    epochs,
)

# Export

In [None]:
from brevitas.export import export_qonnx
import os

# Export the trained network to QONNX. The file name records how many epochs
# actually ran (total_epoch from the training cell).
export_model_path = f"../../../../notebooks/GitHub/M_Project/Model/tmp/cnv_e{total_epoch}_1bit_trained.onnx"
# The tmp/ folder may not exist yet on a fresh checkout.
os.makedirs(os.path.dirname(export_model_path), exist_ok=True)

# Tracing input: one image with the configured channel count at the training
# resolution. It is created on the same device as the model's weights, so
# export works whether training left the model on the CPU or the GPU.
model.eval()
export_device = next(model.parameters()).device
input_tensor = torch.randn(
    1, config.getint('MODEL', 'IN_CHANNELS'), img_size, img_size, device=export_device
)
export_qonnx(module=model, input_t=input_tensor, export_path=export_model_path);