In [1]:
import pickle
import torch.nn as nn
from torch.utils import data
from tqdm import tqdm
import torch

In [2]:
import sys

# Make the parent directory importable so the local `models` and `datasets` packages
# (imported in the next cell) resolve. Guarded so that re-running this cell does not
# append a duplicate "../" entry to sys.path every time.
if "../" not in sys.path:
    sys.path.append("../")

In [3]:
from models import models as mdls
from datasets import datasets

In [4]:
from pathlib import Path

# Directory holding the pickled, pre-augmented train/test splits.
AUGMENTED_DATA_DIR = Path("../data/augmented_data")

# NOTE: pickle.load can execute arbitrary code while unpickling — only load these files
# if they were produced by this project's own augmentation step, never from an untrusted source.
# `with` closes each file handle deterministically (bare open(...) left them open until GC).
with open(AUGMENTED_DATA_DIR / "training_set.pkl", "rb") as training_file:
    training_set = pickle.load(training_file)

with open(AUGMENTED_DATA_DIR / "testing_set.pkl", "rb") as testing_file:
    testing_set = pickle.load(testing_file)

In [5]:
# Fail fast if either pickle does not hold the project's dataset type: the training and
# evaluation cells below call methods of FaceRecognitionDataset-based objects.
_dataset_checks = (
    (training_set, "Invalid training set type"),
    (testing_set, "Invalid testing set type"),
)
for _candidate, _error_message in _dataset_checks:
    if isinstance(_candidate, datasets.FaceRecognitionDataset):
        continue
    raise ValueError(_error_message)

Model Initialization

In [6]:
from torchvision import models
from torch import backends

# NOTE(review): `loss` is the CrossEntropyLoss *class*, not an instance. This is only correct
# if FaceRecognitionNet instantiates `loss_function` itself — confirm in the project's
# `models.models` module.
loss = nn.CrossEntropyLoss

# Pick the compute device: Apple MPS when available, otherwise CUDA, otherwise CPU.
if backends.mps.is_available():
    main_device = torch.device('mps')
elif torch.cuda.is_available():
    main_device = torch.device('cuda')
else:
    main_device = torch.device('cpu')

# Training hyperparameters, gathered in one place so they are easy to read and tune.
HYPERPARAMETERS = {
    'num_classes': 3,
    'max_epochs': 50,
    'learning_rate': 3e-4,
    'weight_decay': 0.01,
    'batch_size': 32,
}

# torchvision's DEFAULT ResNet-50 weights — presumably used by FaceRecognitionNet to
# initialise a ResNet-50 backbone; confirm in the project's models module.
model = mdls.FaceRecognitionNet(
    loss_function=loss,
    weights=models.ResNet50_Weights.DEFAULT,
    main_device=main_device,
    **HYPERPARAMETERS,
)

In [7]:
# Quick look at the training labels (the repr below is a NumPy-style array).
# NOTE(review): the visible values are only 0 and 1, while the model is built with
# num_classes=3 — confirm that class 2 actually occurs in the data.
training_set.labels

array([1, 0, 1, ..., 1, 1, 0])

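Quantizing the Neural Network with Dynamic Quantization

Dynamic quantization (PyTorch `quantize_dynamic`) is a post-training, CPU-only step. It converts supported layers (`nn.Linear` and recurrent layers, but not `nn.Conv2d`) to int8, and the quantized layers cannot be trained further.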

In [8]:
import copy

from torch.quantization import quantize_dynamic


def quantize_for_cpu_inference(module: nn.Module) -> nn.Module:
    """Return an int8 dynamically-quantized copy of ``module`` for CPU inference.

    Dynamic quantization is a post-training step. The quantized layers run on CPU only,
    and their packed int8 weights receive no gradients. Call this on the *trained*
    network (e.g. ``quantize_for_cpu_inference(model.model)``), never before training.

    Only ``nn.Linear`` is requested because ``quantize_dynamic`` has no dynamic
    ``nn.Conv2d`` implementation, so listing Conv2d would silently do nothing.
    The input module is left untouched.

    Args:
        module: a float ``nn.Module`` on any device.

    Returns:
        A new, CPU-resident, eval-mode module whose ``nn.Linear`` layers are quantized.
    """
    cpu_copy = copy.deepcopy(module).to("cpu").eval()
    return quantize_dynamic(cpu_copy, qconfig_spec={nn.Linear}, dtype=torch.qint8)


# Keep the trainable, full-precision FaceRecognitionNet wrapper bound to `quantized_model`,
# the name the following cells use. Quantizing here would break training, for three reasons:
# - quantize_dynamic returns a bare nn.Module that lacks the wrapper's
#   `enable_gradient_trace()` and `train(image_dataset=...)`;
# - its int8 layers are CPU-only, which clashes with a CUDA/MPS main_device;
# - those int8 layers cannot be trained.
# Quantize after training instead, with quantize_for_cpu_inference.
quantized_model = model

In [9]:
# enable_gradient_trace() is a method of the project's FaceRecognitionNet wrapper, so this
# requires `quantized_model` to be that wrapper, not a bare nn.Module returned by
# quantize_dynamic. NOTE(review): its exact effect is defined in the models module — confirm.
quantized_model.enable_gradient_trace()

Training the Neural Network

In [None]:
# Run the wrapper's own training loop over the training split.
# NOTE(review): judging by the variable name, the return value is the mean training loss —
# confirm against FaceRecognitionNet.train in the models module.
mean_loss = quantized_model.train(image_dataset=training_set)

  0%|                                                                                           | 0/436 [00:00<?, ?it/s]

In [None]:
# Display the value returned by quantized_model.train(...).
mean_loss

Evaluating the Model on the Testing Dataset

In [None]:
# Evaluate on the testing split.
# NOTE(review): training ran on `quantized_model`, but this evaluates `model`. The two are
# the same object only when `quantized_model = model`. If `quantized_model` were a
# quantize_dynamic result (a separate copy), this would measure the untrained original,
# so use one handle consistently.
eval_loss = model.evaluate(image_dataset=testing_set)

In [None]:
# Display the value returned by model.evaluate(...).
eval_loss

Exporting the Neural Network in ONNX Format

In [None]:
# Export the network as 'neural_net' into ../prod_models.
# NOTE(review): the heading says ONNX; the actual format and file name are decided by
# FaceRecognitionNet.export in the project's models module — confirm.
model.export(model_name='neural_net', model_path='../prod_models')