In [1]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp
import pandas as pd

In [2]:
%pwd

'/mnt/d/imperial/second_term/adls/rs1923/mase_real'

In [3]:
# figure out the correct path
machop_path = Path(".").resolve().parent.parent /"new/mase/machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

In [4]:
# Turning you network to a graph

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import get_logger

from chop.passes.graph.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model
from chop.tools.checkpoint_load import load_model

from chop.passes.graph.transforms.utils import metadata_value_type_cast_transform_pass


logger = get_logger("chop")
logger.setLevel(logging.INFO)

batch_size = 256
model_name = "jsc-tiny"
dataset_name = "jsc"


data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
    # custom_dataset_cache_path="../../chop/dataset"
)
data_module.prepare_data()
data_module.setup()

CHECKPOINT_PATH = "/mnt/d/imperial/second_term/adls/rs1923/mase_real/mase_output/jsc-tiny_classification_jsc_2024-02-05/software/training_ckpts/best.ckpt"
model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False,
    checkpoint = None)
model = load_model(load_name=CHECKPOINT_PATH, load_type="pl", model=model)

input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

dummy_in = next(iter(input_generator))
_ = model(**dummy_in)

  from .autonotebook import tqdm as notebook_tqdm


Total number of JSC_tiny parameters: 117


In [5]:
mg = MaseGraph(model=model)

In [6]:
# search space

pass_args = {
"by": "type",
"default": {"config": {"name": None}},

"linear": {
    "config": {
        "name": "integer",
        # data
        "data_in_width": 8,
        "data_in_frac_width": 4,
        # weight
        "weight_width": 8,
        "weight_frac_width": 4,
        # bias
        "bias_width": 8,
        "bias_frac_width": 4,
        }
},}

import copy
# build a search space
data_in_frac_widths = [(16, 8), (8, 6), (8, 4), (4, 2)]
w_in_frac_widths = [(16, 8), (8, 6), (8, 4), (4, 2)]
search_spaces = []
for d_config in data_in_frac_widths:
    for w_config in w_in_frac_widths:
        pass_args['linear']['config']['data_in_width'] = d_config[0]
        pass_args['linear']['config']['data_in_frac_width'] = d_config[1]
        pass_args['linear']['config']['weight_width'] = w_config[0]
        pass_args['linear']['config']['weight_frac_width'] = w_config[1]
        # dict.copy() and dict(dict) only perform shallow copies
        # in fact, only primitive data types in python are doing implicit copy when a = b happens
        search_spaces.append(copy.deepcopy(pass_args))

In [7]:
# Q1, Q2

# train for QAT(Quantization-Aware Training)

# This is not suitable for our case of PTQ

# But we could make comparisons on the performance between PTQ and QAT

'''
import torch
from torchmetrics.classification import MulticlassAccuracy
from ptflops import get_model_complexity_info
from torch import optim

from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)

def init_mg(model):
    mg = MaseGraph(model=model)
    mg, _ = init_metadata_analysis_pass(mg, None)
    return mg

metric = MulticlassAccuracy(num_classes=5)

batch_size = 8
optimizer = optim.Adam(model.parameters(), lr=0.001)

max_epoch = 10
mg = init_mg(model)

for i, config in enumerate(search_spaces):
    mg, _ = quantize_transform_pass(mg, config)

    for epoch in range(max_epoch):
        for inputs in data_module.train_dataloader():
            xs, ys = inputs
            optimizer.zero_grad()
            preds = mg.model(xs)
            loss = torch.nn.functional.cross_entropy(preds, ys)  
            loss.backward()  
            optimizer.step()  
    
    torch.save({
        'model_state_dict': mg.model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config,
    }, f'model_and_optimizer_{i}.pth')

    mg = init_mg(model)
    optimizer = optim.Adam(mg.model.parameters(), lr=0.001)
    '''

"\nimport torch\nfrom torchmetrics.classification import MulticlassAccuracy\nfrom ptflops import get_model_complexity_info\nfrom torch import optim\n\nfrom chop.passes.graph.transforms import (\n    quantize_transform_pass,\n    summarize_quantization_analysis_pass,\n)\n\ndef init_mg(model):\n    mg = MaseGraph(model=model)\n    mg, _ = init_metadata_analysis_pass(mg, None)\n    return mg\n\nmetric = MulticlassAccuracy(num_classes=5)\n\nbatch_size = 8\noptimizer = optim.Adam(model.parameters(), lr=0.001)\n\nmax_epoch = 10\nmg = init_mg(model)\n\nfor i, config in enumerate(search_spaces):\n    mg, _ = quantize_transform_pass(mg, config)\n\n    for epoch in range(max_epoch):\n        for inputs in data_module.train_dataloader():\n            xs, ys = inputs\n            optimizer.zero_grad()\n            preds = mg.model(xs)\n            loss = torch.nn.functional.cross_entropy(preds, ys)  \n            loss.backward()  \n            optimizer.step()  \n    \n    torch.save({\n        'm

In [7]:
# Q1, Q2

# search (grid search)

import torch
from torchmetrics.classification import MulticlassAccuracy
from ptflops import get_model_complexity_info
import time
import gc

from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)

from chop.passes.graph.transforms.quantize.quantized_modules.linear import LinearInteger
from chop.passes.graph.utils import get_node_actual_target


def init_mg():
    mg = MaseGraph(model=model)
    mg, _ = init_metadata_analysis_pass(mg, None)
    mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
    return mg

metric = MulticlassAccuracy(num_classes=5)
batch_size = 256

# model storage size (unit: Byte)
def model_storage_size(model, weight_bit_width, bias_bit_width, data_bit_width):
    total_bits = 0 
    for name, param in model.named_parameters():
        if param.requires_grad and 'weight' in name:
            bits = param.numel() * weight_bit_width
            total_bits += bits

        elif param.requires_grad and 'bias' in name:
            bits = param.numel() * bias_bit_width
            total_bits += bits

    total_bits += data_bit_width*(1*16+1) # mean and variance

    total_bytes = total_bits / 8
    return total_bytes


# FLOP
def calculate_flop_for_linear(module, batch_size):
    in_features = module.in_features
    out_features = module.out_features
    return batch_size*(in_features * out_features)
def calculate_flop_for_batchnorm1d(module, batch_size):
    num_features = module.num_features
    # calculate the mean: num_features * batch_size  [for each element, (batch_size-1)add, 1division]
    # calculate the variance: (2*num_features+(num_features-1))*batch_size + (batch_size-1)  [for each element:2, for each sample: 2*num_features+(num_features-1)]
    # calculate the denominator (knowing variance): 2  [add bias & square root]
    # calculate for each sample xi: 4*num_features  [for each element, 4: 1*minus, 1*division, 1*multiply, 1*add]
    return num_features * batch_size + (2*num_features+(num_features-1))*batch_size + (batch_size-1) + 2 + batch_size*(4*num_features)
def calculate_flop_for_relu(module, input_features, batch_size):
    # per element comparison with 0 (in essence, a minus)
    input_features = input_features*batch_size
    return input_features
def add_flops_bitops_analysis_pass(graph):
    total_flops = 0
    for node in graph.fx_graph.nodes:
        if isinstance(get_node_actual_target(node), torch.nn.modules.Linear):
            flops = calculate_flop_for_linear(get_node_actual_target(node), batch_size)
            total_flops += flops
        elif isinstance(get_node_actual_target(node), torch.nn.modules.BatchNorm1d):
            flops = calculate_flop_for_batchnorm1d(get_node_actual_target(node), batch_size)
            total_flops += flops
    flops = calculate_flop_for_relu(get_node_actual_target(node), 16, batch_size)
    total_flops += flops
    flops = calculate_flop_for_relu(get_node_actual_target(node), 5, batch_size)
    total_flops += flops
    return total_flops



# bit-wise Operations (unit: number)
def bit_wise_op(model, input_res, weight_width, bias_width, data_width, batch_size):
    total_bitwise_ops = 0
    for name, module in model.named_modules():
        if isinstance(module, LinearInteger):
            bitwise_ops = calculate_bitwise_ops_for_linear(module, input_res, weight_width, bias_width, data_width, batch_size)
            total_bitwise_ops += bitwise_ops
    return total_bitwise_ops

def calculate_bitwise_ops_for_linear(module, input_res, weight_bit_width, bias_bit_width, data_bit_width, batch_size):
    in_features = module.in_features
    out_features = module.out_features
    bitwise_ops_per_multiplication = data_bit_width * weight_bit_width
    bitwise_ops_per_addition = data_bit_width * weight_bit_width
    bitwise_ops_per_output_feature = in_features * bitwise_ops_per_multiplication + (in_features - 1) * bitwise_ops_per_addition
    if module.bias is not None:
        bitwise_ops_per_output_feature += bias_bit_width
    total_bitwise_ops = out_features * bitwise_ops_per_output_feature
    return total_bitwise_ops*batch_size


# search (grid search)

mg = init_mg()
recorded_metrics = []
for i, config in enumerate(search_spaces):
    linear_config = config['linear']['config']
    data_width = linear_config['data_in_width']
    data_frac_width = linear_config['data_in_frac_width']
    weight_width = linear_config['weight_width']
    weight_frac_width = linear_config['weight_frac_width']
    bias_width = linear_config['bias_width']
    bias_frac_width = linear_config['bias_frac_width']

    mg, _ = quantize_transform_pass(mg, config)
    data_bit_width = config['linear']['config']['data_in_width']
    weight_bit_width = config['linear']['config']['weight_width']
    bias_bit_width = config['linear']['config']['bias_width']

    size = model_storage_size(mg.model, weight_bit_width, bias_bit_width, data_bit_width)  # model size after it has been quantized
    flop = add_flops_bitops_analysis_pass(mg)
    bit_op = bit_wise_op(mg.model, (16,), weight_bit_width, bias_bit_width, data_bit_width, batch_size)

    acc_avg, loss_avg = 0, 0
    accs, losses, latencies = [], [], []

    for inputs in data_module.train_dataloader():
        xs, ys = inputs
        start_time = time.time()
        preds = mg.model(xs)
        end_time = time.time()
        latency = end_time - start_time 
        latencies.append(latency*1000) # (unit: ms)

        acc = metric(preds, ys)
        accs.append(acc)
        loss = torch.nn.functional.cross_entropy(preds, ys)
        losses.append(loss)

    acc_avg = sum(accs) / len(accs)
    loss_avg = sum(losses) / len(losses)
    latency_avg = sum(latencies) / len(latencies) 

    recorded_metrics.append({
        "data_width": data_width,
        "data_frac_width": data_frac_width,
        "weight_width": weight_width,
        "weight_frac_width": weight_frac_width,
        "bias_width": bias_width,
        "bias_frac_width": bias_frac_width,
        
        "acc(%)": (acc_avg.item())*100,
        "loss": loss_avg.item(),
        "latency(ms)": latency_avg*1000,
        "model_size(Byte)": size,
        "flop(number)": flop,
        "bitwise_ops(number)": bit_op,
    }) 
    del accs, losses, latencies, acc_avg, loss_avg, latency_avg, size, flop, bit_op
    gc.collect()

    mg = init_mg()

gc.collect()

<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7bf7a1b1f0>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7d43cc7130>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7bf7a688b0>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7c3686ea70>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7d43e3f490>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7bf00d0b20>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7bf7a1a320>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7d43cc4dc0>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7ce4e98970>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7c2db65ed0>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7bf00d9b40>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7c3686ea70>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7bf7e06680>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7d43e3f490>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7bf7e06410>
<chop.ir.graph.mase_graph.MaseGraph object at 0x7f7bee682b60>
<chop.ir

229

In [8]:
recorded_metrics

[{'data_width': 16,
  'data_frac_width': 8,
  'weight_width': 16,
  'weight_frac_width': 8,
  'bias_width': 8,
  'bias_frac_width': 4,
  'acc(%)': 51.60192251205444,
  'loss': 1.325373888015747,
  'latency(ms)': 349.6421604923392,
  'model_size(Byte)': 247.0,
  'flop(number)': 58625,
  'bitwise_ops(number)': 10168320},
 {'data_width': 16,
  'data_frac_width': 8,
  'weight_width': 8,
  'weight_frac_width': 6,
  'bias_width': 8,
  'bias_frac_width': 4,
  'acc(%)': 51.70328617095947,
  'loss': 1.3258105516433716,
  'latency(ms)': 298.0583218130465,
  'model_size(Byte)': 151.0,
  'flop(number)': 58625,
  'bitwise_ops(number)': 5089280},
 {'data_width': 16,
  'data_frac_width': 8,
  'weight_width': 8,
  'weight_frac_width': 4,
  'bias_width': 8,
  'bias_frac_width': 4,
  'acc(%)': 51.464223861694336,
  'loss': 1.3287193775177002,
  'latency(ms)': 307.19731413746007,
  'model_size(Byte)': 151.0,
  'flop(number)': 58625,
  'bitwise_ops(number)': 5089280},
 {'data_width': 16,
  'data_frac_widt

In [9]:
metrics = pd.DataFrame(recorded_metrics)

metrics

Unnamed: 0,data_width,data_frac_width,weight_width,weight_frac_width,bias_width,bias_frac_width,acc(%),loss,latency(ms),model_size(Byte),flop(number),bitwise_ops(number)
0,16,8,16,8,8,4,51.601923,1.325374,349.64216,247.0,58625,10168320
1,16,8,8,6,8,4,51.703286,1.325811,298.058322,151.0,58625,5089280
2,16,8,8,4,8,4,51.464224,1.328719,307.197314,151.0,58625,5089280
3,16,8,4,2,8,4,51.015449,1.331339,306.457938,103.0,58625,2549760
4,8,6,16,8,8,4,51.639044,1.342087,301.281766,230.0,58625,5089280
5,8,6,8,6,8,4,51.695049,1.34306,312.072526,134.0,58625,2549760
6,8,6,8,4,8,4,51.477736,1.344443,316.188802,134.0,58625,2549760
7,8,6,4,2,8,4,50.961071,1.344631,332.394076,86.0,58625,1280000
8,8,4,16,8,8,4,51.612431,1.32547,342.234177,230.0,58625,5089280
9,8,4,8,6,8,4,51.707715,1.32586,342.820482,134.0,58625,2549760


In [1]:
%cd machop

/mnt/d/imperial/second_term/adls/rs1923/mase_real/machop


In [2]:
!./ch search --config configs/examples/jsc_toy_by_type.toml --load /mnt/d/imperial/second_term/adls/new/mase/mase_output/jsc-tiny_classification_jsc_2024-02-05/software/training_ckpts/best.ckpt

Total number of JSC_1923 parameters: 3285
Total number of JSC_Tiny parameters: 117
Seed set to 0
+-------------------------+--------------------------+--------------------------+--------------------------+--------------------------+
| Name                    |         Default          |       Config. File       |     Manual Override      |        Effective         |
+-------------------------+--------------------------+--------------------------+--------------------------+--------------------------+
| task                    |      [38;5;8mclassification[0m      |           cls            |                          |           cls            |
| load_name               |           [38;5;8mNone[0m           | [38;5;8m../mase_output/jsc-tiny_[0m | /mnt/d/imperial/second_t | /mnt/d/imperial/second_t |
|                         |                          | [38;5;8mclassification_jsc_2024-[0m | erm/adls/new/mase/mase_o | erm/adls/new/mase/mase_o |
|                         |        

In [None]:
# sample efficiency:
# uses as few trials(samples) as possible to achieve a certain(optimal) hyperparameters