In [1]:
from difflogic import LogicLayer, GroupSum
import torch
from chop import MaseGraph

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# import os
# import sys

# "Add dot (pydot) to path if needed:"

# new_path = os.path.expanduser("~/miniforge3/envs/adls-project/bin/")
# if new_path not in sys.path:
#     sys.path.append(new_path)

# # Add to environment PATH as well
# os.environ["PATH"] = new_path + os.pathsep + os.environ["PATH"]

In [3]:
# Difflogic classifier: flatten the input, feed its 400 features through six
# logic layers of width 8_000 (one 400 -> 8_000, then five 8_000 -> 8_000),
# and reduce with a GroupSum head (k=10, tau=20).
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    LogicLayer(400, 8_000),
    *[LogicLayer(8_000, 8_000) for _ in range(5)],
    GroupSum(k=10, tau=20),
)

In [4]:
# Wrap the Sequential model in a MaseGraph; later cells load weights through
# `mg.model` and run chop passes on `mg`.
mg = MaseGraph(model)

# Optional: draw the freshly built graph (disabled).
# mg.draw("DLG.svg")

In [5]:
def eval(model, loader, mode):
    """Return the mean per-batch classification accuracy of `model` on `loader`.

    NOTE: this name shadows Python's builtin `eval`; it is kept unchanged
    because other cells call it by this name.

    Args:
        model: a `torch.nn.Module`; called on each input batch.
        loader: iterable of `(x, y)` batches of inputs and integer labels.
        mode: value passed to `model.train(mode=...)` for the duration of the
            evaluation (`False` evaluates in eval mode).

    Returns:
        A Python float: the unweighted mean over batches of the fraction of
        samples whose `argmax(-1)` prediction equals the label. Inputs are
        moved to CPU and rounded before the forward pass.

    The model's original training flag is restored even if the forward pass
    or the loader raises.
    """
    orig_mode = model.training
    model.train(mode=mode)
    try:
        with torch.no_grad():
            batch_accuracies = [
                (model(x.to('cpu').round()).argmax(-1) == y.to('cpu')).to(torch.float32).mean().item()
                for x, y in loader
            ]
    finally:
        # Always put the model back in the mode the caller left it in.
        model.train(mode=orig_mode)
    return np.mean(batch_accuracies).item()

In [6]:
import mnist_dataset
import numpy as np

# Build the MNIST train/test datasets from the project-local `mnist_dataset`
# module. Only the training split passes download=True.
# NOTE(review): remove_border=True presumably crops the images so that the
# flattened input has the 400 features expected by the first LogicLayer --
# TODO confirm against mnist_dataset.
train_set = mnist_dataset.MNIST('./data-mnist', train=True, download=True, remove_border=True)
test_set = mnist_dataset.MNIST('./data-mnist', train=False, remove_border=True)

# Batches of 100 with drop_last=True on both loaders; only the training loader
# shuffles and uses worker processes.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True, pin_memory=True, drop_last=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, pin_memory=True, drop_last=True)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
state_dict = torch.load("best_model_full.pth", map_location=torch.device('cpu'))  # Load the dictionary
mg.model.load_state_dict(state_dict)  # Apply the weights

# Switch the wrapped model to eval mode before measuring accuracy.
mg.model.eval()

# `eval` here is the accuracy helper defined earlier in this notebook (it
# shadows the builtin); mode=False keeps the model in eval mode while scoring.
print("Accuracy: ", eval(mg.model, test_loader, mode=False))

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 14.3MB/s]



Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 382kB/s]



Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 3.58MB/s]



Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 6.59MB/s]



Accuracy:  0.9751000040769577


In [11]:
import chop.passes as passes

# Run chop's metadata-initialisation pass on the graph. The pass returns a
# (graph, info) pair; the info part is discarded here.
mg, _ = passes.init_metadata_analysis_pass(mg)
# Disabled: common-metadata pass that would take a dummy input batch drawn
# from test_loader.
# mg, _ = passes.add_common_metadata_analysis_pass(
#     mg,
#     pass_args={
#         "dummy_in": [ x[0] for x in iter(test_loader)][0],
#         "add_value": False,
#     },
# )

In [12]:
# Draw the current graph to "DLG_test.svg". The same filename is passed to
# draw() again after test_pass runs later in the notebook, so this file is
# presumably overwritten then.
mg.draw("DLG_test.svg")

In [46]:
import types

@torch.fx.wrap
def indices_placeholder(inp, indices_a, indices_b):
    """Gather two selections from the last dimension of `inp`.

    Registered with `torch.fx.wrap` so symbolic tracing records it as a
    single call_function node.

    Args:
        inp: indexable tensor-like input.
        indices_a: index applied to the last dimension for the first output.
        indices_b: index applied to the last dimension for the second output.

    Returns:
        A tuple `(inp[..., indices_a], inp[..., indices_b])`.
    """
    return inp[..., indices_a], inp[..., indices_b]

@torch.fx.wrap
def neuron_placeholder(a, b, op):
    """Stand-in for a logic neuron: returns `a` unchanged.

    `b` and `op` are accepted but ignored; no logic operation is applied.
    """
    return a

@torch.fx.wrap
def groupsum_placeholder(inp):
    """Stand-in for the GroupSum head: ignores `inp` and returns a fixed string."""
    return "Just placeholder"


def test_pass(
    graph,
    pass_args={"model":"None", "state_dict":{}},
):
    model = pass_args["model"]
    
    new_graph = torch.fx.Graph()
    given_params = model.state_dict()
    if (pass_args["state_dict"]):
        given_params = pass_args["state_dict"]
    
    last_node = None
    
    # Step 1: Copy input placeholder nodes
    for node in graph.nodes:
        if node.op == "placeholder":  # Detect input nodes
            last_node = new_graph.placeholder(node.name)  # Copy placeholder
            break
    
    for name, module in model.named_modules():
        
        if isinstance(module, torch.nn.Sequential):  
            continue  # Skip Sequential containers
        
        elif isinstance(module, torch.nn.Flatten):
            continue
        
        elif isinstance(module, LogicLayer):
            name_param = "indices"
            param_id = name + "." + name_param
            indices_a = given_params[param_id][0]
            indices_b = given_params[param_id][1]
            last_node = new_graph.call_function(indices_placeholder, args=(last_node,), kwargs={"indices_a":indices_a, "indices_b":indices_b})
            
            name_param = "weights"
            param_id = name + "." + name_param
            w = given_params[param_id]
            last_node = new_graph.call_function(neuron_placeholder, args=(last_node,), kwargs={"op":w})
        elif isinstance(module, GroupSum):
            last_node = new_graph.call_function(groupsum_placeholder, args=(last_node,))
        
        else:
            raise NotImplementedError(name, module)
        
    
    for node in graph.nodes:
        if node.op == "output":  # Detect input nodes
            last_node = new_graph.output(last_node)  # Copy placeholder
            break 
    
    mg.fx_graph = new_graph
    
    return mg, None
                    
                    
                    
#     nodes = set({})
#     for node in mg.fx_graph.nodes:
#         if (node.op == "call_module"):
#             pass
#         elif (node.op == "call_method" and long not in node.target ):
#             if len(node.args) > 1:
#                 raise "What?"
#             node.replace_all_uses_with(node.args[0])
#         elif (isinstance(node.target, types.BuiltinFunctionType):
#              node.replace_all_uses_with(node.args[0]) 
            
            
#     return graph, None

In [47]:
# Rebuild the graph with test_pass using the loaded weights, then draw it.
# test_pass returns its second value as None, so `test` is None here.
mg_test, test = test_pass(mg, pass_args={"model":model, "state_dict":state_dict})
mg_test.draw("DLG_test.svg")