In [1]:
import torch
import torch.utils.data
from torch import nn
from torch.nn import functional as F

from ignite.engine import Events, Engine
from ignite.metrics import Accuracy, Loss

import numpy as np
import sklearn.datasets

import matplotlib.pyplot as plt
import seaborn as sns

sns.set()

In [2]:
noise = 0.1
X_train, y_train = sklearn.datasets.make_moons(n_samples=1500, noise=noise)
X_test, y_test = sklearn.datasets.make_moons(n_samples=200, noise=noise)
ds_train = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(), F.one_hot(torch.from_numpy(y_train)).float())
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=64, shuffle=True, drop_last=True)

In [3]:
class Network(nn.Module):
    def __init__(self, features):
        super().__init__()
        
        self.fc1 = nn.Linear(2, features)
        self.fc2 = nn.Linear(features, features)
        self.fc3 = nn.Linear(features, features)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
Net = Network(20)
for name, param in Net.named_buffers():
    print(name)
print("------------")
state_dict = Net.state_dict()
for name, value in state_dict.items():
    print(name)

------------
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias


In [4]:
import copy
from torch.quantization import quantize_fx

quantise = True
if quantise:
    dataiter = iter(dl_train)
    images, labels = next(dataiter)

    m = copy.deepcopy(Net)
    m.to("cpu")
    m.eval()
    qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
    model_prep = quantize_fx.prepare_fx(m, qconfig_dict, images)

    with torch.inference_mode():
        for _ in range(10):
            images, labels = next(dataiter)
            model_prep(images)
    model_quant = quantize_fx.convert_fx(model_prep)
    
for name, param in model_quant.named_buffers():
    print(name)
    
for name, param in model_quant.named_parameters():
    print(name)
print("-----------")
state_dict = model_quant.state_dict()
for name, value in state_dict.items():
    print(name)

-----------


