In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os

# Make project modules importable: insert the directory two levels above sys.path[0] at index 1.
# NOTE(review): this relies on sys.path[0] pointing at the notebook's directory — confirm for the kernel in use.
sys.path.insert(1, os.path.join(sys.path[0], '../..')) # Adds '<sys.path[0]>/../..' (two levels up) to the import path

## Set GPUs

In [None]:
# Select which GPU(s) CUDA may see. The variable is set before torch is
# imported in this cell; keep that order when editing.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

import torch
print(torch.cuda.is_available())

from utils import read_yaml
from parse_arguments import ConfigParser
from train import set_up_trainer

# Pre-training FP Models

## Load Config

In [None]:
# Inputs for the FP model run.
config_fname = r"../configs/S4D_small.yaml"  # path to the FP model config file
resume_checkpoint = None  # or a checkpoint path to resume from, e.g. r"../log/S4D_small/fp32_16heads/checkpoint/ckpt.pth"
modification = None  # or an optional modification to the config, e.g. {"model;type": 32}

# Read the YAML file and wrap it in the config object that set_up_trainer receives.
config_yaml = read_yaml(config_fname)
config = ConfigParser(
    config_yaml,
    resume_checkpoint,
    modification=modification,
    save_log=False,  # set to True to save the trained model and logs
)

## Set up FP trainer

Hyperparameters are set in the config file.

In [None]:
# Build the trainer for the FP model from the config object created above.
trainer = set_up_trainer(config)

## Train and Eval

In [None]:
# Train the FP model, then evaluate it. eval() is the last expression,
# so its return value (if any) is displayed as the cell output.
trainer.train()
trainer.eval()

# Quantization

We call the `set_up_quantizers` function to perform quantization. Quantization parameters are set in the quantization config file (`../configs/q_config.yaml`).

In [None]:
from quantize import set_up_quantizers
from pathlib import Path

# Read the quantization config. No checkpoint is passed and save_log is off.
q_config_fname = r"../configs/q_config.yaml"
q_modification = None  # or an optional modification to the quantization config, e.g. {...}
q_config_yaml = read_yaml(q_config_fname)
q_config = ConfigParser(
    q_config_yaml,
    None,
    modification=q_modification,
    save_log=False,
)

# Create a second trainer from the FP config and attach quantizers to it.
# NOTE(review): q_trainer is built from `config`, not from the `trainer` object
# trained earlier in this notebook — confirm whether the pre-trained weights are
# expected to carry over (e.g. by setting resume_checkpoint in the config cell).
q_trainer = set_up_trainer(config)
set_up_quantizers(q_trainer, q_config)

# Display the trainer's model after set_up_quantizers has run.
q_trainer.model

### Evaluating quantized models

By default, the model is dynamically quantized.

In [None]:
# Dynamically quantised: default
# In notebook order, no calibration has been run on q_trainer at this point.
q_trainer.eval()

For static quantization, run calibration with `calibrate` and then call `set_static_quantization`.

In [None]:
# Calibration for static quantization
q_trainer.calibrate(num_batches=1) # num_batches is the number of batches to use for calibration
# Switch to static quantization, then evaluate in that mode.
q_trainer.set_static_quantization()
q_trainer.eval()

# You can return to dynamic quantization if needed
# NOTE(review): with the line below commented out, q_trainer presumably stays in
# static mode for any cell run after this one — confirm.
# q_trainer.set_dynamic_quantization()

For quantization-aware training (QAT), simply call `train`.

In [None]:
# QAT: train the trainer that has quantizers attached, then evaluate it.
# Both return values are kept for later inspection.
# NOTE(review): q_trainer is used in whatever quantization mode earlier cells
# left it in (static, if the calibration cell was run) — confirm that is intended.
train_log = q_trainer.train()
QAT_eval_log = q_trainer.eval()