In [None]:
from collections import OrderedDict
import sys
import os
from pathlib import Path
import torch
import torchvision
from torch import nn

import snntorch as snn
from snntorch import functional as SF

# Make the parent directory importable *before* importing the project-local
# modules below (`nnt_cli`, `pro_temp`), so they also resolve on a fresh kernel.
# The membership check keeps re-runs of this cell from appending duplicates.
notebook_path = Path().resolve()
parent_path = os.path.join(notebook_path, "..")
if parent_path not in sys.path:
    sys.path.append(parent_path)

import nnt_cli.utils as nu
from pro_temp import _Project_Template


notebook_name=""
# Checks the notebook names in `notebook_path` (it prints the [OK] lines);
# its return value is not used here.
nu.settin.gen_settin.get_notebook_name(notebook_path)
# Name used when generating result/model file names.
notebook_name='L4_no_re'

[OK] D:\SoftProject\Python\SNN_demo\QSNN_NMNIST\L4_no_re.ipynb: Notebook name is correct.
[OK] D:\SoftProject\Python\SNN_demo\QSNN_NMNIST\L4_re.ipynb: Notebook name is correct.


In [None]:
# Parameters in this cell can be changed using script (e.g. papermill).
# But you have to pass these parameters to the model class.
script_mode = False
# Use script_mode to use separate settings between using script and direct running.
debug_mode = True

batch_size = 256
num_epochs = 1
lr = 1e-3

beta = 0.95
bit_width=[8,4,2,1]
# Bit widths for the quantized layers; a script may inject this as a str,
# which is parsed back into a list in a later cell.
num_hiddens=1024

# num_steps=1000

# `match_name` and `remark` will be used to generate the file name of the saved model.
# When using script, the names set here will be used.
# It's better for `match_name` to fit the variable setting in the script.
match_name=f"{notebook_name}_bit_width_"
remark="4_linear_layers"

learn_beta_threshold=True

train_method = "one"
# "one": Train for one time.
# "iter": Train in one iteration, like sweep.
# "diter": Train in two nested iterations.
# The settings of iteration are in the method `set_iter`.
# "cp": Load from the checkpoint at `cp_fpath` and continue from that state.

reset_mode=[1,1,1,1]
# Reset mode for Leaky layer.
# 0 is "zero", 1 is "subtract"
# Index of number is the index of Leaky layer.

cp_fpath=""
# Checkpoint path, used when train_method is "cp".

random_seed=0

In [None]:
# When the notebook is run directly (not from a script), drop the script-level
# names so the class falls back to its internal `set_name` method.
# For debug
if script_mode is False:
    match_name = remark = None

In [None]:
# papermill may inject list parameters as str (e.g. "[8,4,2,1]").
# Parse every list-valued parameter of this notebook back into a list.

import ast


def _parse_list_param(value):
    """Return `value` unchanged unless it is a str, in which case parse it into a list.

    Parsing uses `ast.literal_eval`, which only evaluates Python literals.

    Raises:
        ValueError: if the str cannot be parsed, or does not parse to a list/tuple.
    """
    if not isinstance(value, str):
        return value
    try:
        parsed = ast.literal_eval(value)
    except (SyntaxError, ValueError) as e:
        raise ValueError("Wrong format, can't transform it to list.") from e
    # e.g. "8" parses to an int, which would break later per-layer indexing.
    if not isinstance(parsed, (list, tuple)):
        raise ValueError("Wrong format, can't transform it to list.")
    return list(parsed)


bit_width = _parse_list_param(bit_width)
reset_mode = _parse_list_param(reset_mode)

In [None]:
class _Project_Run(_Project_Template):
    """Notebook-specific runner for the 4-linear-layer QSNN on NMNIST.

    Overrides the `_Project_Template` hooks this experiment customizes:
    result naming, extra parameters, sweep settings, the model and
    network initialization.
    """

    def set_name(self):
        """Set the default `match_name` and `remark` used to name saved results."""
        self.match_name=f"{self.notebook_name}_qsnn_NMNIST_manual_"
        # Plain string: there is nothing to interpolate.
        self.remark="4_linear_layers"

    def init_params(self):
        """
        **Essential**
        Set the global parameters for the class, including settings.
        Suggestion: Set init parameters here that are not necessary to be changed by script.
        """

        super().init_params()

        self.infer_size=0.01
        # Presumably a fraction of the NMNIST test dataset (10000 samples) — confirm in the template.
        # Because the quantized inference dataflow will be saved, a large inference dataset significantly increases training time.
        # On HPC, 1 sample of data may cost 3.42 MB and 3 min to write.

        # NOTE(review): this is the *string* "False", which is truthy; confirm that
        # `_Project_Template` expects a string here rather than the bool False.
        self.with_cache="False"

        self.quick_debug=False
        self.qb_dataset_size=0.1

        # self.sav_state = False
        # self.sav_checkpoint = False

        self.weight_decay=0.01
        # For optimizer

        self.dropout_prob=0.2

    def set_iter(self):
        """
        Set the iteration variables and values.
        The variables set here will be used for iteration training if you call the iteration training methods.
        `self.variable_name` will be used first when `iter` (sweep) mode is called.
        `self.variable_name` must be consistent with the variable names in the class defined.
        """


        # self.vary_list=[
        #     [8,8,8,8],
        #     [4,4,4,4],
        #     [2,2,2,2],
        #     [1,1,1,1],
        # ]
        # self.variable_name="bit_width"
        # # Must be consistent with the variable names in the class defined.

        # self.vary_list2=[1024,2048]
        # self.variable_name2="num_hiddens"
        # # The other variables, bit_width, are set by script, so the full performance of HPC could be used.

    def set_model(self):
        """
        **Essential**
        Set the training model.
        `self.net`, `self.loss`, `self.optimizer` must be set.
        In subclass, do NOT use super() to inherit optimizer!!!
        """
        # NOTE(review): the whole model definition below is commented out, so this
        # override currently sets none of `self.net`, `self.loss`, `self.optimizer`.
        # It also references `qnn` and `du`, which are not imported in this
        # notebook; import them before re-enabling the code.

        # super().set_model()
        # DO NOT USE super() to inherit optimizer!!!


        # num_inputs =2*34*34
        # num_hiddens=self.num_hiddens
        # num_outputs = 10

        # act_quant=nu.settin.qnn_settin.get_act_quant(self.bit_width[0])
        # weight_quant=[nu.settin.qnn_settin.get_weight_quant(bw) for bw in self.bit_width]

        # reset_mode_str=self.decode_reset_mode(self.reset_mode)
        # self.net = nn.Sequential(
        #         # unnamed layer, use layer1.parameters() or layer1.weight to access
        #         # named layer, use net.<name>.param to access
        #         OrderedDict([
        #             ('flaten', nn.Flatten()), # Already flatted in the dataset.
        #             # ('input_sav0', du.InputSaviorLayer(f"{self.sav_data_path}/{self.vary_title}","input_bf_quant",squeeze=True)),
        #             ('quant_ident', qnn.QuantIdentity(act_quant=act_quant,return_quant_tensor=True,
        #                                             )),
        #             ('input_sav1', du.InputSaviorLayer(f"{self.sav_data_path}/{self.vary_title}","input_af_quant",squeeze=True)),

        #             ('quant_linear_in', qnn.QuantLinear(num_inputs, num_hiddens,bias=False,device=self.device,
        #                                             weight_quant=weight_quant[0],)),
        #             ('dropout1', nn.Dropout(self.dropout_prob)),
        #             ('batch_norm1', nn.BatchNorm1d(num_hiddens)),
        #             ('leaky1', snn.Leaky(beta=self.beta, init_hidden=True,reset_mechanism=reset_mode_str[0],
        #                                  learn_beta=self.learn_beta_threshold, learn_threshold=self.learn_beta_threshold)),

        #             ('quant_linear2', qnn.QuantLinear(num_hiddens, num_hiddens,bias=False,device=self.device,
        #                                             weight_quant=weight_quant[1],)),
        #             ('dropout2', nn.Dropout(self.dropout_prob)),
        #             ('batch_norm2', nn.BatchNorm1d(num_hiddens)),
        #             ('leaky2', snn.Leaky(beta=self.beta, init_hidden=True,reset_mechanism=reset_mode_str[1],
        #                                  learn_beta=self.learn_beta_threshold, learn_threshold=self.learn_beta_threshold)),

        #             ('quant_linear3', qnn.QuantLinear(num_hiddens, num_hiddens,bias=False,device=self.device,
        #                                             weight_quant=weight_quant[2],)),
        #             ('dropout3', nn.Dropout(self.dropout_prob)),
        #             ('batch_norm3', nn.BatchNorm1d(num_hiddens)),
        #             ('leaky3', snn.Leaky(beta=self.beta, init_hidden=True,reset_mechanism=reset_mode_str[2],
        #                                  learn_beta=self.learn_beta_threshold, learn_threshold=self.learn_beta_threshold)),

        #             ('quant_linear_out',qnn.QuantLinear(num_hiddens, num_outputs,bias=False,device=self.device,
        #                                             weight_quant=weight_quant[3],)),
        #             ('leaky_out', snn.Leaky(beta=self.beta, init_hidden=True, output=True, reset_mechanism=reset_mode_str[3],
        #                                  learn_beta=self.learn_beta_threshold, learn_threshold=self.learn_beta_threshold))
        #         ])
        #     ).to(self.device)

        # self.loss=SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

        # self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.learning_rate, betas=(0.9, 0.999), weight_decay=self.weight_decay)

        # self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
        #     self.optimizer,
        #     max_lr=self.learning_rate,
        #     total_steps=self.num_epochs,
        #     pct_start=0.3,
        #     anneal_strategy='cos'
        # )

    def init_net(self):
        """
        Init net if you need to.
        Intentionally a no-op in this notebook.
        """
        # print("No init of the network!")
        # # No init, because I'm not sure how to init the quantized layer.

# Build the runner from the notebook parameters. The first ten arguments are
# positional, so their order must match `_Project_Template.__init__`.
run_class_i = _Project_Run(
    notebook_name,
    script_mode,
    debug_mode,
    batch_size,
    num_epochs,
    lr,
    beta,
    bit_width,
    num_hiddens,
    reset_mode,
    match_name=match_name,
    remark=remark,
    learn_beta_threshold=learn_beta_threshold,
)

In [None]:
# For reproducible results, set the random seed manually:
# NOTE(review): this runs after `_Project_Run(...)` is constructed; if the
# constructor already consumes randomness, consider seeding before it.
run_class_i.set_seed(random_seed)

In [None]:
# You can use this block to check dataset

# run_class_i.init_params()
# run_class_i.set_path()
# run_class_i.set_dataset()
# run_class_i.set_dataloader()

# frame,_ = next(iter(run_class_i.train_loader))

# print(frame.shape )

In [None]:
# Check dataset by iterating all dataset to find out whether there is a store/cache error.
# You can also use this to set cache, if you have set up cache mode first.

# run_class_i.check_dataset()


# # If cache has error, use this to clear the cache.
# run_class_i.remove_cache()

In [None]:
# An internal method to analyze the dataset

# run_class_i.dataset_analyze()

In [None]:
# A visualization method that calls several plotting functions to analyze the dataset

# run_class_i.get_visual_frame_distribution(quantile=0.98)

In [None]:
# Setting of the training method
# "one": Train for one time.
# "iter": Train in one iteration, like sweep. Only use run_class_i.variable_name and its values to do iteration training.
# "diter": Train in two nested iterations. 
# "cp": Load from checkpoint. And the model will start form that state.

match train_method:
    case "one":
        run_class_i.train_onetime(no_plot=True)
    case "iter":
        run_class_i.train_iter(no_plot=True)
    case "diter":
        run_class_i.train_double_iter(no_plot=True)
    case "cp":
        run_class_i.load_checkpoint(cp_fpath,no_plot=True)
    case _:
        raise ValueError("Invalid train method!")

In [None]:
# Plot results after training (training above ran with no_plot=True).
run_class_i.plot_final()

In [None]:
# Plot the records collected by the runner during training
# (presumably per-epoch metrics; verify in `_Project_Template`).
run_class_i.plot_record()

QSNN example training with `QuantTensor=True`.
Test different mixed-precision settings (1, 2, 4) for the QNN and QSNN.
Floating-point SNN.


In [None]:
# Analyze after training: inspect gradient norms.
# `check_gradient_norm` is not defined in `_Project_Run`, so it comes from `_Project_Template`.

run_class_i.check_gradient_norm()