
Thytu/ConvNet-CIFAR10


Code Interview

An old code interview, reworked to try out multiple PyTorch concepts.
I was supposed to write a CIFAR10 classifier using a ConvNet; you can still find the model here.


Setup

To use the project:

  • Install DVC
  • Run dvc repro

Usage

To reproduce the whole DVC pipeline: dvc repro
To download the dataset: python3 src/data_handler.py
To search for the best learning rate and optimizer with Optuna: python3 src/hyperparameters.py
To rerun only the model training: python3 src/main.py

Architecture

This project can be read in multiple parts:

  • The DVC pipeline, which orders every step
  • The log handler, which sets the log level from an environment variable
  • The YAML handler, which loads and writes YAML files
  • The data handler, which downloads and loads CIFAR10
  • The hyperparameter tuner, which uses Optuna to find the best learning rate and optimizer
  • The model and its layers, which classify CIFAR10
    The model uses multiple PyTorch concepts:
    • hooks, which cast every input to a torch.Tensor and send it to the selected device
    • TorchScript (torch.jit), which optimizes the model for deployment by scripting it.
      Note: the model is scripted, not traced (torch.jit.trace is not used).
    • pruning, which optimizes the model for deployment by removing (zeroing out) weights.
    • quantization, which optimizes the model for deployment by converting tensors to 8-bit integers (uint8).
  • The main script, which runs the training loop and saves the resulting models
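The log handler is described only as loading the log level from an environment variable. A minimal sketch using the standard logging module, assuming a variable named LOG_LEVEL (hypothetical; the README does not name it) and falling back to INFO when the variable is missing or invalid:

```python
import logging
import os


def get_logger(name: str, env_var: str = "LOG_LEVEL") -> logging.Logger:
    """Return a logger whose level is read from an environment variable.

    `env_var` defaults to the hypothetical name LOG_LEVEL; missing or
    unknown values fall back to INFO.
    """
    level_name = os.environ.get(env_var, "INFO").upper()
    # Non-level names resolve to None (or a non-int), so they are rejected
    level = getattr(logging, level_name, None)
    if not isinstance(level, int):
        level = logging.INFO
    logger = logging.getLogger(name)
    logger.setLevel(level)
    if not logger.handlers:
        # Attach one stream handler only, so repeated calls do not duplicate output
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s"))
        logger.addHandler(handler)
    return logger
```

With this design, LOG_LEVEL=debug python3 src/main.py would enable debug output without touching the code.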
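For the YAML handler, a minimal pair of helpers, assuming PyYAML is the parser (the README does not say which library is used); load_yaml and write_yaml are hypothetical names:

```python
from pathlib import Path
from typing import Any, Union

import yaml  # PyYAML


def load_yaml(path: Union[str, Path]) -> Any:
    """Parse a YAML file into Python objects (dicts, lists, scalars)."""
    with open(path, "r", encoding="utf-8") as f:
        # safe_load only builds plain Python types, never arbitrary objects
        return yaml.safe_load(f)


def write_yaml(path: Union[str, Path], data: Any) -> None:
    """Serialize `data` to a YAML file, preserving dict insertion order."""
    with open(path, "w", encoding="utf-8") as f:
        yaml.safe_dump(data, f, sort_keys=False)
```

For example, the tuner could store its best parameters with write_yaml("params.yaml", {"lr": 0.001, "optimizer": "Adam"}) and the training script could read them back with load_yaml (file name and values are illustrative).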
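The README does not list the model's layers. As an illustration only, here is a small ConvNet with the input and output shapes CIFAR10 requires (3x32x32 RGB images, 10 classes); SmallConvNet and its layer sizes are hypothetical, not the project's architecture:

```python
import torch
from torch import nn


class SmallConvNet(nn.Module):
    """Hypothetical CIFAR10 classifier: two conv blocks and a linear head."""

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),   # 3x32x32 -> 32x32x32
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> 32x16x16
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # -> 64x16x16
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> 64x8x8
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                                 # -> 4096
            nn.Linear(64 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),                  # one logit per class
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))
```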
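The input-casting behavior can be implemented with a forward pre-hook, which runs before forward and may return replacement inputs. A sketch, assuming the project uses register_forward_pre_hook (the README only says "hooks"); to_tensor_on_device is a hypothetical helper name:

```python
import torch
from torch import nn


def to_tensor_on_device(device: torch.device):
    """Build a forward pre-hook that turns every positional input into a
    float32 tensor on `device` (hypothetical helper)."""

    def hook(module: nn.Module, args: tuple) -> tuple:
        # A tuple returned by a forward pre-hook replaces the module's inputs
        return tuple(torch.as_tensor(a, dtype=torch.float32, device=device) for a in args)

    return hook


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
handle = model.register_forward_pre_hook(to_tensor_on_device(device))

# A nested Python list (not a tensor) is now accepted directly
batch = torch.rand(4, 3, 32, 32).tolist()
logits = model(batch)
```

NumPy arrays are converted the same way by torch.as_tensor. Because the hook lives on the module, callers never convert or move inputs themselves; handle.remove() detaches it.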
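Scripting (torch.jit.script) compiles the Python code of forward, so data-dependent control flow survives; tracing (torch.jit.trace) only records the operations executed for one example input. A sketch on a hypothetical TinyNet whose forward branches on the input's rank, a branch that tracing with a batched example would silently freeze into the batched path:

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical model used only to illustrate scripting."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.head = nn.Linear(8, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent branch: kept by scripting, lost by tracing a 4D example
        if x.dim() == 3:
            x = x.unsqueeze(0)  # accept a single unbatched image
        x = torch.relu(self.conv(x))
        x = x.mean(dim=[2, 3])  # global average pooling -> (N, 8)
        return self.head(x)


model = TinyNet().eval()
scripted = torch.jit.script(model)     # compile forward to TorchScript
scripted.save("tiny_net_scripted.pt")  # archive loadable without the Python class
restored = torch.jit.load("tiny_net_scripted.pt")
```

The saved archive can also be loaded from C++ through LibTorch, which is the usual reason to script a model for deployment.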
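A sketch of magnitude pruning with torch.nn.utils.prune, assuming L1 unstructured pruning of 30% of each Conv2d and Linear weight (the README states neither the method nor the amount):

```python
import torch
from torch import nn
from torch.nn.utils import prune

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 32 * 32, 10),
)

for module in model.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        # Zero the 30% of weights with the smallest absolute value
        prune.l1_unstructured(module, name="weight", amount=0.3)
        # Make it permanent: drop weight_orig/weight_mask, keep the zeroed weight
        prune.remove(module, "weight")


def sparsity(module: nn.Module) -> float:
    """Fraction of exactly-zero entries in a module's weight."""
    w = module.weight
    return float((w == 0).sum()) / w.numel()
```

Unstructured pruning only zeroes entries: the dense tensors keep their shape, so real size or speed gains need sparse storage, compression, or structured pruning (for example prune.ln_structured) on top.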
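The README mentions uint8. PyTorch's quantized dtypes are torch.quint8 (unsigned) and torch.qint8 (signed); the sketch below uses post-training dynamic quantization, which stores nn.Linear weights as qint8 and quantizes activations on the fly. This is one possible approach, not necessarily the project's. quantize_dynamic targets nn.Linear (and recurrent layers) by default, so convolutional layers are usually handled with static quantization instead.

```python
import torch
from torch import nn
from torch.ao.quantization import quantize_dynamic

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
).eval()

# Returns a copy in which every nn.Linear is replaced by a dynamically
# quantized Linear: int8 weights, activations quantized at inference time
quantized = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
```

The original float model is left untouched, which makes it easy to compare the accuracy of both versions before shipping the quantized one.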
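The main script's training loop is not shown in the README. A generic sketch of one follows, run on random tensors shaped like CIFAR10 batches so it stays self-contained; train, the Adam optimizer, the stand-in linear model and the model_state.pt output path are all assumptions:

```python
from typing import List

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train(model: nn.Module, loader: DataLoader, epochs: int = 1,
          lr: float = 1e-3, device: str = "cpu") -> List[float]:
    """Minimal supervised loop; returns the mean training loss of each epoch."""
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    history = []
    for _ in range(epochs):
        model.train()
        total, count = 0.0, 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is exact
            total += loss.item() * images.size(0)
            count += images.size(0)
        history.append(total / count)
    return history


# Random stand-in for CIFAR10: 256 images of shape 3x32x32, labels in [0, 10)
torch.manual_seed(0)
data = TensorDataset(torch.rand(256, 3, 32, 32), torch.randint(0, 10, (256,)))
loader = DataLoader(data, batch_size=64, shuffle=True)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
losses = train(model, loader, epochs=3)
torch.save(model.state_dict(), "model_state.pt")  # hypothetical output path
```

Saving the state_dict rather than the whole module keeps the checkpoint independent of the Python class location; the scripted, pruned and quantized variants can then be derived from the reloaded weights.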

TODO

  • Add torch profiler to the project
  • Use CML with DVC
  • Use Optuna to find the best hyperparameters
    • Write module docstring
  • Use MLflow

Contact

Created by @Thytu - feel free to contact me!