In [None]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp
import pandas as pd

In [None]:
# Locate the machop sources two directories above the current working
# directory and put them on the import path so the `chop` imports below work.
machop_path = Path(".").resolve().parent.parent / "rs1923" / "mase_real" / "machop"
assert machop_path.exists(), f"Failed to find machop at: {machop_path}"
sys.path.append(str(machop_path))

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity, get_logger

from chop.passes.graph.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model

from chop.tools.checkpoint_load import load_model

# Experiment constants shared by the cells below.
model_name = "vgg7"
dataset_name = "cifar10"
batch_size = 512

# Enable debug verbosity through the chop helper, then cap the "chop"
# logger itself at INFO.
set_logging_verbosity("debug")
logger = get_logger("chop")
logger.setLevel(logging.INFO)

# Build the data module and run its prepare/setup steps.
data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)
data_module.prepare_data()
data_module.setup()

# Instantiate the classification model without pretrained weights or a
# checkpoint; the dataset description comes from the data module.
model_info = get_model_info(model_name)
ori_model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False,
    checkpoint=None,
)

# Input generator over the training dataloader.
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

# Dummy input: element 0 of the first training batch, keyed as "x".
dummy_in = {"x": next(iter(data_module.train_dataloader()))[0]}

In [None]:
from copy import deepcopy
from torch import nn
from chop.actions.search.search_space.base import SearchSpaceBase
from chop.passes.graph.transforms.quantize import (
    QUANTIZEABLE_OP,
    quantize_transform_pass,
)
from chop.ir.graph.mase_graph import MaseGraph
from chop.passes.graph import (
    init_metadata_analysis_pass,
    add_common_metadata_analysis_pass,
)

from collections import defaultdict
from chop.passes.graph.utils import get_parent_name

In [None]:
# Default per-node configuration for the network (channel) transforms.
# Each "seq_blocks_<i>" key is matched against a graph node name; the inner
# "name" selects which transform pass handles the node, and the channel_*
# entries give the widths that pass should apply.
DEFAULT_NETWORK_CONFIG = {
    "config": {
        "by": "name",
        "default": {"config": {"name": None}},
        # conv block 1
        "seq_blocks_0": {"config": {"name": "output_only", "channel_output": 128}},
        "seq_blocks_1": {"config": {"name": "bn", "channel": 128}},
        "seq_blocks_2": {"config": {"name": "relu"}},
        # conv block 2
        "seq_blocks_3": {"config": {"name": "both", "channel_input": 128, "channel_output": 128}},
        "seq_blocks_4": {"config": {"name": "bn", "channel": 128}},
        "seq_blocks_5": {"config": {"name": "relu"}},
        "seq_blocks_6": {"config": {"name": "maxpool"}},
        # conv block 3
        "seq_blocks_7": {"config": {"name": "both", "channel_input": 128, "channel_output": 256}},
        "seq_blocks_8": {"config": {"name": "bn", "channel": 256}},
        "seq_blocks_9": {"config": {"name": "relu"}},
        # conv block 4
        "seq_blocks_10": {"config": {"name": "both", "channel_input": 256, "channel_output": 256}},
        "seq_blocks_11": {"config": {"name": "bn", "channel": 256}},
        "seq_blocks_12": {"config": {"name": "relu"}},
        "seq_blocks_13": {"config": {"name": "maxpool"}},
        # conv block 5
        "seq_blocks_14": {"config": {"name": "both", "channel_input": 256, "channel_output": 512}},
        "seq_blocks_15": {"config": {"name": "bn", "channel": 512}},
        "seq_blocks_16": {"config": {"name": "relu"}},
        # conv block 6
        "seq_blocks_17": {"config": {"name": "both", "channel_input": 512, "channel_output": 512}},
        "seq_blocks_18": {"config": {"name": "bn", "channel": 512}},
        "seq_blocks_19": {"config": {"name": "relu"}},
        "seq_blocks_20": {"config": {"name": "maxpool"}},
        # first linear layer
        "seq_blocks_22": {"config": {"name": "linear", "channel_input": 8192}},
    }
}


   
def instantiate_conv2d(in_channels, out_channels, kernel_size=3, padding=1):
    """Create a fresh ``nn.Conv2d`` with the given channel counts.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size (default 3).
        padding: Padding applied to the input (default 1).

    Returns:
        A newly initialised ``nn.Conv2d``; every other option keeps the
        PyTorch default.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
    return conv

def redefine_conv2d_transform_pass(graph, pass_args=None):
    """Replace configured ``nn.Conv2d`` nodes with freshly built convolutions.

    For every node of ``graph.fx_graph`` the per-node config is looked up by
    node name (falling back to the ``default`` entry). Nodes whose config
    ``name`` is ``"output_only"`` or ``"both"`` and whose module is an
    ``nn.Conv2d`` are replaced by a new layer from ``instantiate_conv2d``
    (called with its default kernel size and padding):

    * ``"output_only"`` keeps the original input channels and takes the
      output channels from ``channel_output``.
    * ``"both"`` takes input channels from ``channel_input`` and output
      channels from ``channel_output``.

    A channel key missing from the config falls back to the corresponding
    value of the original layer instead of raising ``KeyError``.

    Args:
        graph: Graph object exposing ``fx_graph.nodes`` and a ``modules``
            mapping from target name to module.
        pass_args: Dict with a ``"config"`` entry; that entry must contain a
            ``"default"`` item. The dict is deep-copied, not mutated.

    Returns:
        Tuple ``(graph, {})``; the graph's modules are modified in place.

    Raises:
        ValueError: If the config has no ``"default"`` entry.
    """
    main_config = deepcopy(pass_args).pop("config")
    default = main_config.pop("default", None)
    if default is None:
        raise ValueError("default value must be provided.")

    for node in graph.fx_graph.nodes:
        config = main_config.get(node.name, default)["config"]
        mode = config.get("name", None)
        if mode != "output_only" and mode != "both":
            continue

        ori_module = graph.modules[node.target]
        if not isinstance(ori_module, nn.Conv2d):
            # Only genuine Conv2d layers are rebuilt.
            continue

        # "output_only" never touches the input width, even if the config
        # happens to carry a channel_input entry.
        in_channels = ori_module.in_channels
        if mode == "both":
            in_channels = config.get("channel_input", in_channels)
        out_channels = config.get("channel_output", ori_module.out_channels)

        new_module = instantiate_conv2d(in_channels, out_channels)
        # Use a dedicated local for the attribute name rather than reusing
        # the variable that holds the config's transform name.
        parent_name, attr_name = get_parent_name(node.target)
        setattr(graph.modules[parent_name], attr_name, new_module)
    return graph, {}
    

def instantiate_bn(num_features):
    """Create a fresh ``nn.BatchNorm2d`` over ``num_features`` channels.

    All other options keep the PyTorch defaults.
    """
    batch_norm = nn.BatchNorm2d(num_features)
    return batch_norm

def redefine_bn_transform_pass(graph, pass_args=None):
    """Replace configured ``nn.BatchNorm2d`` nodes with new batch-norm layers.

    For every node of ``graph.fx_graph`` the per-node config is looked up by
    node name (falling back to the ``default`` entry). Nodes whose config
    ``name`` is ``"bn"`` and whose module is an ``nn.BatchNorm2d`` are
    replaced by a new layer whose feature count comes from the config's
    ``channel`` entry. When ``channel`` is absent the original layer's
    ``num_features`` is kept, so the pass no longer fails on an unbound
    local in that case.

    Args:
        graph: Graph object exposing ``fx_graph.nodes`` and a ``modules``
            mapping from target name to module.
        pass_args: Dict with a ``"config"`` entry; that entry must contain a
            ``"default"`` item. The dict is deep-copied, not mutated.

    Returns:
        Tuple ``(graph, {})``; the graph's modules are modified in place.

    Raises:
        ValueError: If the config has no ``"default"`` entry.
    """
    main_config = deepcopy(pass_args).pop("config")
    default = main_config.pop("default", None)
    if default is None:
        raise ValueError("default value must be provided.")

    for node in graph.fx_graph.nodes:
        config = main_config.get(node.name, default)["config"]
        if config.get("name", None) != "bn":
            continue

        ori_module = graph.modules[node.target]
        if not isinstance(ori_module, nn.BatchNorm2d):
            # Only genuine BatchNorm2d layers are rebuilt.
            continue

        # Fall back to the existing width when the config gives none.
        num_features = config.get("channel", ori_module.num_features)
        new_module = instantiate_bn(num_features)
        parent_name, attr_name = get_parent_name(node.target)
        setattr(graph.modules[parent_name], attr_name, new_module)
    return graph, {}


def instantiate_linear(in_features, out_features, bias):
    """Create a fresh ``nn.Linear`` layer.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample. (Previously this argument
            was ignored and the layer was always built with 1024 outputs.)
        bias: Either a bool, used as-is, or any other object (for example an
            existing layer's bias tensor) where only "is it ``None``?"
            decides whether the new layer gets a bias.

    Returns:
        A newly initialised ``nn.Linear``.
    """
    # A real bool is taken at face value; anything else only signals
    # presence/absence. Truth-testing a multi-element tensor would raise,
    # hence the explicit ``is not None`` check.
    use_bias = bias if isinstance(bias, bool) else bias is not None
    return nn.Linear(
        in_features=in_features,
        out_features=out_features,
        bias=use_bias,
    )

def redefine_linear_transform_pass(graph, pass_args=None):
    """Replace configured ``nn.Linear`` nodes with a layer of new input width.

    For every node of ``graph.fx_graph`` the per-node config is looked up by
    node name (falling back to the ``default`` entry). Nodes whose config
    ``name`` is ``"linear"``, whose module is an ``nn.Linear`` and whose
    config carries ``channel_input`` are replaced by a new layer built with
    ``instantiate_linear``. The new layer takes its input width from
    ``channel_input``; the original layer's ``out_features`` and bias are
    passed through rather than a hard-coded ``1024`` / ``True``. Nodes
    without ``channel_input`` are left untouched.

    Args:
        graph: Graph object exposing ``fx_graph.nodes`` and a ``modules``
            mapping from target name to module.
        pass_args: Dict with a ``"config"`` entry; that entry must contain a
            ``"default"`` item. The dict is deep-copied, not mutated.

    Returns:
        Tuple ``(graph, {})``; the graph's modules are modified in place.

    Raises:
        ValueError: If the config has no ``"default"`` entry.
    """
    main_config = deepcopy(pass_args).pop("config")
    default = main_config.pop("default", None)
    if default is None:
        raise ValueError("default value must be provided.")

    for node in graph.fx_graph.nodes:
        config = main_config.get(node.name, default)["config"]
        if config.get("name", None) != "linear":
            continue

        ori_module = graph.modules[node.target]
        if not isinstance(ori_module, nn.Linear):
            # Only genuine Linear layers are rebuilt.
            continue
        if "channel_input" not in config:
            # Nothing to change for this layer.
            continue

        # Pass the original bias object (tensor or None) so the helper can
        # decide on presence; the layers that follow expect the original
        # output width, so it is forwarded too.
        new_module = instantiate_linear(
            config["channel_input"],
            ori_module.out_features,
            ori_module.bias,
        )
        parent_name, attr_name = get_parent_name(node.target)
        setattr(graph.modules[parent_name], attr_name, new_module)
    return graph, {}


def instantiate_relu(boolean):
    """Create an ``nn.ReLU``; ``boolean`` is forwarded as its ``inplace`` flag."""
    activation = nn.ReLU(inplace=boolean)
    return activation

def redefine_relu_pass(graph, pass_args=None):
    """Replace configured ``nn.ReLU`` nodes with a fresh in-place ReLU.

    For every node of ``graph.fx_graph`` the per-node config is looked up by
    node name (falling back to the ``default`` entry). Nodes whose config
    ``name`` is ``"relu"`` and whose module is an ``nn.ReLU`` are replaced by
    ``instantiate_relu(True)``. The module-type check mirrors the other
    transform passes in this file, so a mis-named config entry cannot
    overwrite a non-ReLU layer.

    Args:
        graph: Graph object exposing ``fx_graph.nodes`` and a ``modules``
            mapping from target name to module.
        pass_args: Dict with a ``"config"`` entry; that entry must contain a
            ``"default"`` item. The dict is deep-copied, not mutated.

    Returns:
        Tuple ``(graph, {})``; the graph's modules are modified in place.

    Raises:
        ValueError: If the config has no ``"default"`` entry.
    """
    main_config = deepcopy(pass_args).pop("config")
    default = main_config.pop("default", None)
    if default is None:
        raise ValueError("default value must be provided.")

    for node in graph.fx_graph.nodes:
        config = main_config.get(node.name, default)["config"]
        if config.get("name", None) != "relu":
            continue

        ori_module = graph.modules[node.target]
        if not isinstance(ori_module, nn.ReLU):
            # Only genuine ReLU layers are rebuilt.
            continue

        new_module = instantiate_relu(True)
        parent_name, attr_name = get_parent_name(node.target)
        setattr(graph.modules[parent_name], attr_name, new_module)
    return graph, {}


def instantiate_maxpool():
    """Create an ``nn.MaxPool2d`` with a 2x2 kernel and default options."""
    pool = nn.MaxPool2d(kernel_size=(2, 2))
    return pool

def redefine_pooling_transform_pass(graph, pass_args=None):
    """Replace configured ``nn.MaxPool2d`` nodes with a fresh 2x2 max-pool.

    Each node of ``graph.fx_graph`` is matched by name against the config
    (falling back to the ``default`` entry). A node is rebuilt with
    ``instantiate_maxpool`` when its config ``name`` is ``"maxpool"`` and its
    current module is an ``nn.MaxPool2d``.

    Args:
        graph: Graph object exposing ``fx_graph.nodes`` and a ``modules``
            mapping from target name to module.
        pass_args: Dict with a ``"config"`` entry; that entry must contain a
            ``"default"`` item. The dict is deep-copied, not mutated.

    Returns:
        Tuple ``(graph, {})``; the graph's modules are modified in place.

    Raises:
        ValueError: If the config has no ``"default"`` entry.
    """
    node_configs = deepcopy(pass_args).pop('config')
    fallback = node_configs.pop('default', None)
    if fallback is None:
        raise ValueError("default value must be provided.")

    for node in graph.fx_graph.nodes:
        node_config = node_configs.get(node.name, fallback)['config']
        if node_config.get("name", None) != "maxpool":
            continue

        current = graph.modules[node.target]
        if not isinstance(current, nn.MaxPool2d):
            continue

        replacement = instantiate_maxpool()
        owner_name, attr_name = get_parent_name(node.target)
        setattr(graph.modules[owner_name], attr_name, replacement)

    return graph, {}

# Static conv2d pass config. Every "both" entry now carries channel_input as
# well as channel_output: redefine_conv2d_transform_pass reads both keys for
# "both" nodes, and the previous entries (output only) raised KeyError there.
# The input widths match the preceding layer's output width, as in
# DEFAULT_NETWORK_CONFIG.
pass_config_conv2d = {
    "by": "name",
    "default": {"config": {"name": None}},
    "seq_blocks_0": {
        "config": {
            "name": "output_only",
            "channel_output": 128,
        }
    },
    "seq_blocks_3": {
        "config": {
            "name": "both",
            "channel_input": 128,
            "channel_output": 128,
        }
    },
    "seq_blocks_7": {
        "config": {
            "name": "both",
            "channel_input": 128,
            "channel_output": 256,
        }
    },
    "seq_blocks_10": {
        "config": {
            "name": "both",
            "channel_input": 256,
            "channel_output": 256,
        }
    },
    "seq_blocks_14": {
        "config": {
            "name": "both",
            "channel_input": 256,
            "channel_output": 512,
        }
    },
    "seq_blocks_17": {
        "config": {
            "name": "both",
            "channel_input": 512,
            "channel_output": 512,
        }
    },
}

# Static linear pass config: only seq_blocks_22 is handled, with 8192 inputs.
pass_config_linear = {
    "by": "name",
    "default": {"config": {"name": None}},
    "seq_blocks_22": {"config": {"name": "linear", "channel_input": 8192}},
}

# Static batch-norm pass config, built from (block index, channel count)
# pairs so each entry gets its own independent dict.
pass_config_bn = {
    "by": "name",
    "default": {"config": {"name": None}},
    **{
        f"seq_blocks_{index}": {"config": {"name": "bn", "channel": width}}
        for index, width in (
            (1, 128),
            (4, 128),
            (8, 256),
            (11, 256),
            (15, 512),
            (18, 512),
        )
    },
}

# Static ReLU pass config: one {"name": "relu"} entry per listed block index,
# each with its own independent dict.
pass_config_relu = {
    "by": "name",
    "default": {"config": {"name": None}},
    **{
        f"seq_blocks_{index}": {"config": {"name": "relu"}}
        for index in (2, 5, 9, 12, 16, 19, 23, 25)
    },
}

# Static max-pool pass config: one {"name": "maxpool"} entry per listed block
# index, each with its own independent dict.
pass_config_maxpool = {
    "by": "name",
    "default": {"config": {"name": None}},
    **{
        f"seq_blocks_{index}": {"config": {"name": "maxpool"}}
        for index in (6, 13, 20)
    },
}


def design_pass_config_conv2d(a, b, c, d, e, f):
    """Build the conv2d pass config for six chained output widths.

    ``a`` .. ``f`` are the output channel counts of the six conv blocks, in
    order. The first block (``seq_blocks_0``) only changes its output width;
    every later block takes the previous block's output width as its input
    width.

    Returns:
        A new config dict keyed by node name, with ``by`` and ``default``
        entries first.
    """
    widths = (a, b, c, d, e, f)
    chained_blocks = (
        "seq_blocks_3",
        "seq_blocks_7",
        "seq_blocks_10",
        "seq_blocks_14",
        "seq_blocks_17",
    )

    config = {
        "by": "name",
        "default": {"config": {"name": None}},
        "seq_blocks_0": {"config": {"name": "output_only", "channel_output": a}},
    }
    # Pair each later block with (previous width, own width).
    for block_name, fan_in, fan_out in zip(chained_blocks, widths, widths[1:]):
        config[block_name] = {
            "config": {
                "name": "both",
                "channel_input": fan_in,
                "channel_output": fan_out,
            }
        }
    return config

def design_pass_config_linear(f):
    """Build the linear pass config for a final conv width of ``f``.

    The input width of ``seq_blocks_22`` is set to ``f * 4 * 4``.

    Returns:
        A new config dict with ``by``, ``default`` and ``seq_blocks_22``.
    """
    flattened_width = f * 4 * 4
    linear_entry = {"config": {"name": "linear", "channel_input": flattened_width}}
    return {
        "by": "name",
        "default": {"config": {"name": None}},
        "seq_blocks_22": linear_entry,
    }

def design_pass_config_bn(a, b, c, d, e, f):
    """Build the batch-norm pass config for six channel widths.

    ``a`` .. ``f`` are assigned, in order, as the ``channel`` value of
    ``seq_blocks_1``, ``_4``, ``_8``, ``_11``, ``_15`` and ``_18``.

    Returns:
        A new config dict keyed by node name, with ``by`` and ``default``
        entries first.
    """
    bn_blocks = (
        "seq_blocks_1",
        "seq_blocks_4",
        "seq_blocks_8",
        "seq_blocks_11",
        "seq_blocks_15",
        "seq_blocks_18",
    )
    config = {
        "by": "name",
        "default": {"config": {"name": None}},
    }
    for block_name, width in zip(bn_blocks, (a, b, c, d, e, f)):
        config[block_name] = {"config": {"name": "bn", "channel": width}}
    return config



In [None]:
# train
#
# Brute-force search over the six conv-layer widths: every non-decreasing
# 6-tuple drawn from `multipliers` is turned into a network, trained for a
# few optimisation steps and scored on the training dataloader.

from itertools import combinations_with_replacement

import torch
import torch.optim as optim
from torchmetrics.classification import MulticlassAccuracy

multipliers = [64, 128, 256, 512]
layer_sequence = ["seq_blocks_0", "seq_blocks_3", "seq_blocks_7", "seq_blocks_10", "seq_blocks_14", "seq_blocks_17"]

metric = MulticlassAccuracy(num_classes=10)

max_epoch = 10
batch_size = 512


def init_mg():
    """Wrap ``ori_model`` in a new MaseGraph and initialise its metadata."""
    mg = MaseGraph(model=ori_model)
    mg, _ = init_metadata_analysis_pass(mg, None)
    return mg


# combinations_with_replacement over the sorted widths yields exactly the
# tuples with a <= b <= c <= d <= e <= f, in the same order the nested loops
# with that filter produced them.
for a, b, c, d, e, f in combinations_with_replacement(sorted(multipliers), 6):
    pass_config_conv2d = design_pass_config_conv2d(a, b, c, d, e, f)
    pass_config_bn = design_pass_config_bn(a, b, c, d, e, f)
    pass_config_linear = design_pass_config_linear(f)

    print(pass_config_conv2d)

    mg = init_mg()
    mg, _ = redefine_conv2d_transform_pass(mg, pass_args={"config": pass_config_conv2d})
    mg, _ = redefine_bn_transform_pass(mg, pass_args={"config": pass_config_bn})
    mg, _ = redefine_relu_pass(mg, pass_args={"config": pass_config_relu})
    mg, _ = redefine_pooling_transform_pass(mg, pass_args={"config": pass_config_maxpool})
    mg, _ = redefine_linear_transform_pass(mg, pass_args={"config": pass_config_linear})

    # Build the optimizer only AFTER the passes have swapped in the new
    # layers. An optimizer created earlier holds the parameters of the
    # modules that were just replaced, so the new layers would never be
    # updated.
    optimizer = optim.Adam(mg.model.parameters(), lr=1e-5)

    # NOTE: each "epoch" is a single optimisation step on the first batch of
    # a freshly created training iterator, not a full pass over the data.
    for epoch in range(max_epoch):
        xs, ys = next(iter(data_module.train_dataloader()))
        optimizer.zero_grad()
        preds = mg.model(xs)
        loss = torch.nn.functional.cross_entropy(preds, ys)
        loss.backward()
        optimizer.step()

    # Accumulate accuracy over every batch and read the aggregate once.
    # Calling the metric directly returns only the current batch's value,
    # and without a reset its state would carry over between candidates.
    # NOTE(review): evaluation uses the training dataloader and leaves the
    # model in whatever train/eval mode it already had - confirm intended.
    metric.reset()
    with torch.no_grad():
        for xs, ys in data_module.train_dataloader():
            metric.update(mg.model(xs), ys)
    acc = metric.compute()

    print(f"acc_{a}_{b}_{c}_{d}_{e}_{f}",acc)