In [1]:
import dagshub
dagshub.init(repo_owner='leocus4', repo_name='TinyFFF', mlflow=True)

In [2]:
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")

Training on cuda


In [3]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay

In [4]:
def print_size_fc(model, list_of_fc_layers, list_of_fc_sparsity, verbose=False):
    '''
    model has to be sublass of nn.Module
        check the subclass with: issubclass(sub, sup), return true if sub is sublcass of sup
                                 isinstance(sub_instance, sup), return true if is sub_instance is subclass of sup
    list_of_fc_layers: list of fully connected layer OF THE MODEL (should be a pointer to layer of model)
    list_of_fc_sparsity: list of the sparsity for each fully connected layer
    '''
    assert isinstance(model, nn.Module), "The model is not a subclass of torch.nn.Module"
    assert len(list_of_fc_layers) == len(list_of_fc_sparsity), "The lists should be of the same length"
    kb = 1000
    verbose and print("-------------------------------------------------------------------------------------------")
    model_size_no_sparsity = 0
    for param in model.parameters():
        model_size_no_sparsity += param.nelement() * param.element_size()
    for buffer in model.buffers():
        model_size_no_sparsity += buffer.nelement() * buffer.element_size()
    
    total_size_no_sparsity = 0
    total_size_with_sparsity = 0
    total_size_with_sparsity_CSC = 0
    
    size_layer_list = []
    num = 0
    for fc_layer, sparsity in zip(list_of_fc_layers, list_of_fc_sparsity):
        num += 1
        # get size
        verbose and print("Layer " + str(num), fc_layer)
        weight = fc_layer.nelement() * fc_layer.element_size()
        
        # save in no sparsity
        total_size_no_sparsity += weight
        
        # set sparsity
        weight = min(1, 2 * sparsity) * weight
        
        # FROM Representation
        if (sparsity <= 0.5): # Dipende dall'analisi che vuoi fare
            if (len(list(fc_layer.shape)) == 3):
                verbose and print("Layer require additional", fc_layer.shape[0], "variables, total size with 4 bytes:", fc_layer.shape[0]*4 / kb)
                total_size_with_sparsity_CSC += (fc_layer.shape[0]*4) # number of filter
            elif (len(list(fc_layer.shape)) == 2):
                total_size_with_sparsity_CSC += (fc_layer.shape[1] + 1)*4 # number of column
                verbose and print("Layer require additional", fc_layer.shape[1]+1, "variables, total size with 4 bytes:", (fc_layer.shape[1]+1)*4 / kb)
            
        total_size_with_sparsity_CSC += weight
        
        # save in with sparsity
        total_size_with_sparsity += weight
        
        size_layer_list.append(weight)
        
        # print total - print weight - print bias
        verbose and print("Layer "+str(num)+":\t\t", (weight) / kb,
              "KB, \tweight:\t", weight / kb, "KB")
    
    # print total no sparisty
    verbose and print("Size FC Layer (no sparsity):\t", total_size_no_sparsity / kb,"KB")
    
    # print total with sparsity
    verbose and print("Size FC Layer (with sparsity):\t", total_size_with_sparsity / kb,"KB")
    
    # print model total - total no sparsity
    verbose and print("Total Size no sparsity:\t\t", model_size_no_sparsity / kb ,"KB")
    
    # print model total - total no sparisty + total with sparsity
    model_size_with_sparsity = model_size_no_sparsity - total_size_no_sparsity + total_size_with_sparsity
    verbose and print("Total Size with sparsity:\t", model_size_with_sparsity / kb,"KB")
    
    # print model total - total no sparisty + total with sparsity and CSC
    model_size_with_sparsity_CSC = model_size_no_sparsity - total_size_no_sparsity + total_size_with_sparsity_CSC
    verbose and print("Total Size with sparsity and CSC representation:\t", model_size_with_sparsity_CSC / kb,"KB")
    
    verbose and print("-------------------------------------------------------------------------------------------")
    
    return model_size_with_sparsity, model_size_with_sparsity_CSC, size_layer_list

In [5]:
def perform_compression(model, list_of_fc_layers, list_of_fc_sparsity, learning_rate, num_epochs, train_loader,
                        test_loader,model_device,val_loader=None, model_name=None, given_criterion=None,
                        calculate_inputs=None,calculate_outputs=None, history=False, regularizerParam = 0):
    '''
    model has to be sublass of nn.Module
        check the subclass with: issubclass(sub, sup), return true if sub is sublcass of sup
                                 isinstance(sub_instance, sup), return true if is sub_instance is subclass of sup
    list_of_fc_layers: list of fully connected layer OF THE MODEL (should be a pointer to layer of model)
    list_of_fc_sparsity: list of the sparsity for each fully connected layer
    NOTE - Sparsity applied only to weight of FC, not on bias
    NOTE - The list are modified during execution, so are copied with list.copy() to avoid changing the original list
    '''
    assert isinstance(model, nn.Module), "The model is not a subclass of torch.nn.Module"
    assert len(list_of_fc_layers) == len(list_of_fc_sparsity), "The lists should be of the same length"
    # asset sparsity between 0 and 1
    valid_sparsity = True
    for sparsity in list_of_fc_sparsity:
        if (sparsity > 1) or (sparsity < 0):
            valid_sparsity = False
    assert valid_sparsity, "The sparsity value must be between 0 and 1"
    list_of_fc_layers = list_of_fc_layers.copy()
    list_of_fc_sparsity = list_of_fc_sparsity.copy()
    # The idea is get the model, set all parameter to not require gradient, set fully connected layer to require gradient,
    # perform training
    
    # disabling parameters
    for name, param in model.named_parameters():
        print("Disabling:", name)
        param.requires_grad = False
    
    # activating fully connected layers only if its sparsity is > 0
    # if a layer has sparsity equal to zero we can override with 0
    # if all sparsity is set to 1, compression is not requested
    sparseTraining = False
    for fc_layer, sparsity in zip(list_of_fc_layers, list_of_fc_sparsity):
        if (sparsity == 1):
            #if (sparseTraining):
            print("Activating:", fc_layer.shape)
            fc_layer.requires_grad = True
        elif (sparsity > 0):
            print("Activating:", fc_layer.shape)
            fc_layer.requires_grad = True
        else:
            fc_layer.weight = torch.nn.Parameter(torch.zeros_like(fc_layer.weight), requires_grad=False)
            # delete from the list (since no need to update them)
            list_of_fc_layers.remove(fc_layer)
            list_of_fc_sparsity.remove(sparsity)
            
        if (sparsity < 1):
            sparseTraining = True
    
    acc = 0
    # TEST - compute accuracy
    accuracyHistory = []
    lastCorrect = 0
    totalPredictions = 0
    numberOfUpdates = len(test_loader)
        
    if not (sparseTraining):
        print("No need to perform compression, all layers's sparsity is set to 1")
    else: # PERFORM TRAINING - COMPRESSION
        
        # set up
        criterion = nn.NLLLoss()
        if given_criterion:
            criterion = given_criterion
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        #optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
        n_total_steps = len(train_loader)
        
        # to save best results
        best_val_epoch, best_val_loss, best_val_acc, best_acc_epoch = 0, 1e6, 0, 0
        
        for epoch in range(num_epochs):
            
            model.train()
            for i, (inputs, labels) in enumerate(train_loader):
                
                inputs = inputs.to(model_device)
                labels = labels.to(model_device)
                                
                # Forward pass
                
                # preforward
                if calculate_inputs:
                    inputs = calculate_inputs(inputs)
                
                # forward
                if calculate_outputs:
                    outputs = calculate_outputs(inputs)
                else:
                    outputs = model.forward(inputs)
                
                # Regularization
                regularizer = 0
                if (regularizerParam != 0):
                    for layer in list_of_fc_layers:
                        regularizer += (torch.norm(layer.weight)**2)
                # Loss
                loss = criterion(outputs, labels) + (regularizer * regularizerParam)
                
                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # apply hardthreshold - in the list we have only layer with require_grad = True
                for fc_layer, sparsity in zip(list_of_fc_layers, list_of_fc_sparsity):
                    layer = fc_layer.data
                    new_layer = hardThreshold(layer, sparsity)
                    with torch.no_grad():
                        fc_layer.data = torch.FloatTensor(new_layer).to(model_device)
                
                # print Accuracy
                if (i+1) % 100 == 0:
                    print (f'Epoch [{epoch+1}/{num_epochs}], Step[{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
            
            print (f'Epoch [{epoch+1}/{num_epochs}], Step[{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
            
            # Use Validation Set at each epochs to pick the most 
            if (val_loader and model_name):
                model.eval()
                with torch.no_grad():
                    v_loss = 0
                    n_correct = 0
                    n_samples = 0
                    n_iterations = 0
                    for inputs, labels in test_loader:
                        inputs = inputs.to(model_device)
                        labels = labels.to(model_device)
                        # Forward pass
                
                        # preforward
                        if calculate_inputs:
                            inputs = calculate_inputs(inputs)
                        outputs = 0 
                        # forward
                        if calculate_outputs:
                            outputs = calculate_outputs(inputs)
                        else:
                            outputs = model.forward(inputs)
                        
                        # for calculating v_loss
                        loss = criterion(outputs, labels)                       
                        v_loss += loss.item()
                        n_iterations += 1
                        
                        # max returns (value, index)
                        _, predicted = torch.max(outputs.data, 1)
                        n_samples += labels.size(0)
                        n_correct += (predicted == labels).sum().item()
                    
                    # Val test completed, now checking the results
                    v_loss = v_loss/(n_iterations)
                    v_loss = round(v_loss, 5)
                    v_acc = round(100*(n_correct / n_samples), 5)
                    
                    if v_acc >= best_val_acc:
                        torch.save(model.state_dict(), model_name+"_acc.h5")
                        best_acc_epoch = epoch + 1
                        best_val_acc = v_acc
                    if v_loss <= best_val_loss:
                        torch.save(model.state_dict(), model_name+".h5")
                        best_val_epoch = epoch + 1
                        best_val_loss = v_loss
                    #print(f'Epoch[{epoch+1}]: t_loss: {t_loss} t_acc: {t_acc} v_loss: {v_loss} v_acc: {v_acc}')
                    print(f'Epoch[{epoch+1}]: v_loss: {v_loss} v_acc: {v_acc}')
        
        
        # Use Validation Set at each epochs to pick the most 
        if (val_loader and model_name):
            model.load_state_dict(torch.load(model_name+".h5", map_location='cpu'))
            print('Best model saved at epoch: ', best_val_epoch)
            print('Best acc model saved at epoch: ', best_acc_epoch)
        
        # USING TEST SET TO CHECK ACCURACY
        #model.eval()
        with torch.no_grad():
            n_correct = 0
            n_samples = 0
            for inputs, labels in test_loader:
                inputs = inputs.to(model_device)
                labels = labels.to(model_device)
                   # Forward pass
                
                # preforward
                if calculate_inputs:
                    inputs = calculate_inputs(inputs)
                outputs = 0 
                # forward
                if calculate_outputs:
                    outputs = calculate_outputs(inputs)
                else:
                    outputs = model.forward(inputs)
                # max returns (value, index)
                
                _, predicted = torch.max(outputs.data, 1)
                n_samples += labels.size(0)
                n_correct += (predicted == labels).sum().item()                
            acc = 100.0 * n_correct / n_samples
            totalPredictions = n_samples
            print(f'Accuracy of the network on the 10000 test images: {acc} %')

        
    result = {
        'correctPredictions': lastCorrect,
        'totalPredictions': totalPredictions,
        'accuracyThroughEpochs': accuracyHistory,
        'numberOfUpdate': numberOfUpdates,
    }
    
    return acc

In [6]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay


def print_full_model(model):
    assert isinstance(model, nn.Module), "The model is not a subclass of torch.nn.Module"
    kb = 1000
    model_size = 0
    for name, param in model.named_parameters():
        layer_size = param.nelement() * param.element_size()
        model_size += layer_size
        print(name,"\t", param.nelement(), "\t", param.element_size(),"\t", layer_size / kb, "KB")
        
    for name, buffer in model.named_buffers():
        layer_size = buffer.nelement() * buffer.element_size()
        model_size += layer_size
        print(name,"\t", layer_size / kb, "KB")
    print("Model Size:", model_size / kb, "KB")

def hardThreshold(A: torch.Tensor, sparsity):
    '''
    Given a Tensor A and the correponding sparsity, returns a copy in the
    format of numpy array with the constraint applied
    '''
    matrix_A = A.data.cpu().detach().numpy().ravel()    
    if len(matrix_A) > 0:
        threshold = np.percentile(np.abs(matrix_A), (1 - sparsity) * 100.0, method='higher')
        matrix_A[np.abs(matrix_A) < threshold] = 0.0
    matrix_A = matrix_A.reshape(A.shape)
    return matrix_A

def get_layers(model):
    """Recursively get all layers in a PyTorch model."""
    list_layers = []
    # for name, module in model.named_children():
    #     # check type of module
    #     is_conv1d = isinstance(module, torch.nn.Conv1d)
    #     is_conv2d = isinstance(module, torch.nn.Conv2d)
    #     is_linear = isinstance(module, torch.nn.Linear)
    #     is_sequential = isinstance(module, torch.nn.Sequential)
    #     if (is_conv1d or is_conv2d or is_linear):
    #         list_layers.append(module)
    #     if (is_sequential):
    #         for sub_name, sub_module in module.named_children():
    #             print(sub_name)
    #             # check type of module
    #             is_conv1d = isinstance(sub_module, torch.nn.Conv1d)
    #             is_conv2d = isinstance(sub_module, torch.nn.Conv2d)
    #             is_linear = isinstance(sub_module, torch.nn.Linear)
    #             if (is_conv1d or is_conv2d or is_linear):
    #                 list_layers.append(sub_module)
    for layer in model.children():
        if isinstance(layer, nn.Sequential):
            # If it's a sequential container, recursively get its layers
            list_layers.extend(get_layers(layer))
        else:
            # If it's a single layer, add it to the list
            if (isinstance(layer, torch.nn.Conv1d) or isinstance(layer, torch.nn.Conv2d) or isinstance(layer, torch.nn.Linear)):
                list_layers.append(layer)
    return list_layers

def apply_sparsity(model, list_of_fc_layers, list_of_fc_sparsity, model_device):
    assert isinstance(model, nn.Module), "The model is not a subclass of torch.nn.Module"
    assert len(list_of_fc_layers) == len(list_of_fc_sparsity), "The lists should be of the same length"
    # asset sparsity between 0 and 1
    valid_sparsity = True
    for sparsity in list_of_fc_sparsity:
        if (sparsity > 1) or (sparsity < 0):
            valid_sparsity = False
    assert valid_sparsity, "The sparsity value must be between 0 and 1"
    
    list_of_fc_layers = list_of_fc_layers.copy()
    list_of_fc_sparsity = list_of_fc_sparsity.copy()
    
    # apply hardthreshold - in the list we have only layer with require_grad = True
    for fc_layer, sparsity in zip(list_of_fc_layers, list_of_fc_sparsity):
        layer = fc_layer.weight.data
        new_layer = hardThreshold(layer, sparsity)
        with torch.no_grad():
            fc_layer.weight.data = torch.FloatTensor(new_layer).to(model_device)
    
def calculate_accuracy(model, train_loader, test_loader, model_device, calculate_inputs=None, calculate_outputs=None):
    assert isinstance(model, nn.Module), "The model is not a subclass of torch.nn.Module"

    acc = 0

    # TEST - compute accuracy
    with torch.no_grad():
        n_correct = 0
        n_samples = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(model_device)
            labels = labels.to(model_device)
            # preforward
            if calculate_inputs:
                inputs = calculate_inputs(inputs)
            outputs = 0 
            # forward
            if calculate_outputs:
                outputs = calculate_outputs(inputs)
            else:
                outputs = model.forward(inputs)
                    
            # max returns (value, index)
            _, predicted = torch.max(outputs.data, 1)
            n_samples += labels.size(0)
            n_correct += (predicted == labels).sum().item()

        acc = 100.0 * n_correct / n_samples
        print(f'Accuracy of the network on the train images: {acc} %')

        n_correct = 0
        n_samples = 0
        for inputs, labels in test_loader:
            inputs = inputs.to(model_device)
            labels = labels.to(model_device)
            # preforward
            if calculate_inputs:
                inputs = calculate_inputs(inputs)
            outputs = 0 
            # forward
            if calculate_outputs:
                outputs = calculate_outputs(inputs)
            else:
                outputs = model.forward(inputs)
            
            # max returns (value, index)
            _, predicted = torch.max(outputs.data, 1)
            n_samples += labels.size(0)
            n_correct += (predicted == labels).sum().item()

        acc = 100.0 * n_correct / n_samples
        print(f'Accuracy of the network on the 10000 test images: {acc} %')

    return acc

In [7]:
def compute_sparsity_for_layers(layer_list):
    """Compute sparsity for each layer in a list of layers."""
    sparsity_info = []

    for layer in layer_list:
        weight = layer.data
        total_elements = weight.numel()
        zero_elements = (weight == 0).sum().item()
        sparsity = zero_elements / total_elements
        sparsity_info.append((layer.__class__.__name__, sparsity, total_elements, zero_elements))
    
    # Print the sparsity information for each layer
    for layer, sparsity, total_elements, zero_elements in sparsity_info:
        print(f'Layer: {layer}, Sparsity: {1-sparsity:.4f}, Total Elements: {total_elements}, Zero Elements: {zero_elements}')

    return sparsity_info

In [8]:
from fastfeedforward import FFF

def train(net, trainloader, epochs, norm_weight=0.0):
    """Train the network on the training set."""
    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    # Train the network for the given number of epochs
    for _ in range(epochs):
        # Iterate over data
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            if norm_weight != 0:
                loss += norm_weight * net.fff.w1s.pow(2).sum()
                loss += norm_weight * net.fff.w2s.pow(2).sum()
            loss.backward()
            optimizer.step()


def test(net, testloader):
    """Validate the network on the entire test set."""
    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    # Train the network for the given number of epochs
    with torch.no_grad():
        # Iterate over data
        for data in testloader:
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy


class Net(torch.nn.Module):
    def __init__(self, input_width, leaf_width, output_width, depth, dropout, region_leak):
        super(Net, self).__init__()
        self.fff = FFF(input_width, leaf_width, output_width, depth, torch.nn.ReLU(), dropout, train_hardened=True, region_leak=region_leak)

    def forward(self, x):
        x = x.view(len(x), -1)
        x = self.fff(x)
        x = torch.nn.functional.softmax(x, -1)
        return x

    def parameters(self):
        return self.fff.parameters()


class FF(torch.nn.Module):
    def __init__(self, input_width, layer_width, output_width):
        super(FF, self).__init__()
        self.fc1 = torch.nn.Linear(input_width, layer_width)
        self.fc2 = torch.nn.Linear(layer_width, output_width)

    def forward(self, x):
        x = x.view(len(x), -1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.softmax(self.fc2(x), -1)
        return x

    def parameters(self):
        return [*self.fc1.parameters(), *self.fc2.parameters()]


def compute_n_params(input_width: int, l_w: int, depth: int, output_width: int):
    fff = Net(input_width, l_w, output_width, depth, 0, 0)
    ff = FF(input_width, l_w, output_width)

    n_ff = 0
    n_fff = 0
    for p in ff.parameters():
        n_ff += p.numel()
    for i, p in enumerate(fff.parameters()):
        print(f"[{i}-th layer]: {p.shape}")
        n_fff += p.numel()

    print(f"FFF: {n_fff}\nFF: {n_ff}")

In [9]:
import pickle
import mlflow
import numpy as np
import pandas as pd
from time import time
from matplotlib import pyplot as plt

def get_dist(net, testloader):
    """
    Returns the distribution of samples throughout the tree.
    """

    y = []
    l = []
    with torch.no_grad():
        # Iterate over data
        for data in testloader:
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            outputs, leaves = net.forward(images, return_nodes=True)
            y.append(labels)
            l.append(leaves)
    y = torch.concat(y, 0)
    l = torch.concat(l, 0)
    return y, l


class FFFWrapper(torch.nn.Module):
    def __init__(self, fff):
        super(FFFWrapper, self).__init__()
        self._fff = fff
        self._fastinference = [None for i in range(2 ** (self._fff.fff.depth.item()))]

    def forward(self, x, return_nodes=False):
        """
        Override the forward method in order to log the data distribution.
        """
        x = x.view(len(x), -1)
        original_shape = x.shape
        batch_size = x.shape[0]
        last_node = torch.zeros(len(x))

        current_nodes = torch.zeros((batch_size,), dtype=torch.long, device=x.device)
        for i in range(self._fff.fff.depth.item()):
            plane_coeffs = self._fff.fff.node_weights.index_select(dim=0, index=current_nodes)
            plane_offsets = self._fff.fff.node_biases.index_select(dim=0, index=current_nodes)
            plane_coeff_score = torch.bmm(x.unsqueeze(1), plane_coeffs.unsqueeze(-1))
            plane_score = plane_coeff_score.squeeze(-1) + plane_offsets
            plane_choices = (plane_score.squeeze(-1) >= 0).long()

            platform = torch.tensor(2 ** i - 1, dtype=torch.long, device=x.device)
            next_platform = torch.tensor(2 ** (i+1) - 1, dtype=torch.long, device=x.device)
            current_nodes = (current_nodes - platform) * 2 + plane_choices + next_platform

        leaves = current_nodes - next_platform
        new_logits = torch.empty((batch_size, self._fff.fff.output_width), dtype=torch.float, device=x.device)
        last_node = leaves

        for i in range(leaves.shape[0]):
            leaf_index = leaves[i]
            if self._fastinference[leaf_index] is not None:
                new_logits[i] = self._fastinference[leaf_index]
            else:
                logits = torch.matmul( x[i].unsqueeze(0), self._fff.fff.w1s[leaf_index])
                logits += self._fff.fff.b1s[leaf_index].unsqueeze(-2)
                activations = self._fff.fff.activation(logits)
                new_logits[i] = torch.matmul( activations, self._fff.fff.w2s[leaf_index]).squeeze(-2)

        if return_nodes:
            return new_logits.view(*original_shape[:-1], self._fff.fff.output_width), last_node
        return new_logits.view(*original_shape[:-1], self._fff.fff.output_width)


    def simplify_leaves(self, trainloader):
        y, leaves = (get_dist(self, trainloader))
        y = y.cpu().detach().numpy()
        outputs = y.max() + 1
        leaves = leaves.cpu().detach().numpy()

        n_simplifications = 0
        ratios = {}
        for l in np.unique(leaves):
            ratios[l] = torch.zeros(outputs)
            indices = leaves == l

            for i in range(outputs):
                ratios[l][i] = (np.sum(y[indices] == i) / np.sum(indices))

            argmax = np.argmax(ratios[l])
            if ratios[l][argmax] > 0.7:
                output = torch.zeros(outputs)
                output[argmax] = 1
                self._fastinference[l] = output
                n_simplifications += 1
                print(f"Leaf {l} has been replaced with {argmax}")
        print(self._fastinference)

In [10]:
from tqdm import trange
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

"""Load CIFAR-10 (training and test set)."""
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
trainset = MNIST("../data", train=True,  download=True, transform=transform)
testset = MNIST("../data",  train=False, download=True, transform=transform)

# Select class to keep 
trainloader = DataLoader(trainset, batch_size=1024, shuffle=True)
testloader = DataLoader(testset, batch_size=1024)

num_examples = {"trainset" : len(trainset), "testset" : len(testset)}

print(num_examples)

{'trainset': 60000, 'testset': 10000}


In [11]:
list_of_run = [
'543fe9acb34441a3a82b09ca2ef6046c',
'4ea2391b3d014e2fafff3accfb352d2c',
'612d1573468c4ffaa89bf54d60ce4508',
'24879e69bd174ef6a9973359ea1b9b6c',
'7c66090a514e4080b639b6f261d5a134',
'a258b245549043ef84a9beff50896872',
'57f3c5cd82dc480c94e516ab34620331',
'cef7d9cd3ea548b783a400984a7145fc',
'706ca101f6a2446db81d58abdf6e2815',
'2a116553f512425a9409bd968b9fe8ef',
'90b23d76307e425ea4ba7d647c7cb7f6',
'57e917cd181c456bbf3c365b5018a2d6',
'715fba635c6145b185c29e6aa6b6bcbf',
'652dc0e8fe0042309b990da5fc377d60',
'580862314acb40cbb6411391f6def1e1',
'0c4b701ba2c2434b99a83f0b771b3945',
'755af7caf9e74fbdbc2dee292dd8b3d1',
'e6c2200cd3a942b081e77a4fcbb21df6',
'9ac931f6215644bf9d22e6fcfc7f179f',
'69d271a25d744404ad63c43b575192a6',
'27f4eafb191340f592dfab6992d3700d',
'9d555e630619470c8e4a6d075ba0e65f',
'89e999f7bcce4b238cd3bca960ec27da',
'15d383cd1c044c8cb7cf5b5de6955b13',
]

# mlflow.artifacts.download_artifacts(run_id=run_id, dst_path=".")
# wrapped_model = pickle.load(open("./truncated_model.pkl", "rb"))
# wrapped_model._fff.fff.depth.item()
# wrapped_model._fff.fff.input_width
# wrapped_model._fff.fff.leaf_width
# wrapped_model._fff.fff.output_width

In [12]:
# Training variable
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
learning_rate = 0.001
num_epochs = 7
criterion = nn.CrossEntropyLoss()

# Dataset
batch_size = 1024
val_size = 5000
train_size = len(trainset) - val_size
train_ds, val_ds = random_split(trainset, [train_size, val_size])
train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size, num_workers=4)
test_loader = DataLoader(testset, batch_size, num_workers=4)

In [13]:
result = []
for i in range (0,len(list_of_run)):
    run_id = list_of_run[i]
    mlflow.artifacts.download_artifacts(run_id=run_id, dst_path=".")
    wrapped_model = pickle.load(open("./truncated_model.pkl", "rb"))
    depth = wrapped_model._fff.fff.depth.item()
    input_width = wrapped_model._fff.fff.input_width
    leaf_width = wrapped_model._fff.fff.leaf_width
    output_width = wrapped_model._fff.fff.output_width
    buffer_size = 2*(leaf_width + output_width + 3)
    print("Run:\t", run_id)
    print("Depth:\t", depth)
    print("Input:\t", input_width)
    print("Output:\t", output_width)
    print("Leaf:\t", leaf_width)
    print("Buffer:\t", buffer_size)
    
    # to reduce the sparsity and train only below a certain tresholds
    list_of_sizes = [100, 90, 80, 70, 60, 50]
    checked_sizes = [False for x in list_of_sizes]
    current_size_index = 0
    
    start = 0.5
    a = start
    b = start
    sizes=[]
    before_trunc_sizes=[]
    trunc_sizes=[]

    model = wrapped_model.to(device)
    
    layers_list = []
    for i, (name, p) in enumerate(model.named_parameters()):
        if (len(list(p.shape)) > 1 and p.requires_grad):
            layers_list.append(p)
    un, comp, layers = print_size_fc(model, layers_list, [1,1,1,1,1,1])
    
    MODEL_NAME_COMPRESSED = "mnist_" + run_id + "_compressed_full"
    
    layers_sizes = layers.copy()
    # Creating a list of 1s with the same length as original_list using multiplication
    list_of_sparsity = [1] * len(layers_sizes)
    
    # Testing time and accuracy
    model.eval()
    t = time()
    train_loss, train_acc = test(model, train_loader)
    test_loss, test_acc = test(model, test_loader)
    t = time() - t
    un, comp, layers_sizes = print_size_fc(model, layers_list, list_of_sparsity)
    sizes.append((comp/1000, test_acc, test_acc, t))
    
    for i in range(1, 100):
        # get index of max
        index_of_max = np.argmax(layers_sizes)
        current_sparsity = list_of_sparsity[index_of_max]

        # reduce
        if (current_sparsity == 1):
            list_of_sparsity[index_of_max] = start - (start * 0.1)
        else:
            list_of_sparsity[index_of_max] = current_sparsity - (current_sparsity * 0.1)

        un, comp, layers_sizes = print_size_fc(model, layers_list, list_of_sparsity)
        
        if (current_size_index >= len(list_of_sizes)):
            break
        elif (comp / 1000 < list_of_sizes[current_size_index]):
            # compress and save result
            MODEL_NAME_COMPRESSED = "mnist_" + run_id + "_compressed_" + str(list_of_sizes[current_size_index])
            current_size_index += 1
            model.train()
            accuracy = perform_compression(model, layers_list, list_of_sparsity, learning_rate, num_epochs,
                                           train_loader, test_loader, device,
                                           val_loader=val_loader, model_name=MODEL_NAME_COMPRESSED, given_criterion=criterion)
            model.load_state_dict(torch.load(MODEL_NAME_COMPRESSED+".h5", map_location='cpu'))
            model.eval()
            t = time()
            train_loss, train_acc = test(model, train_loader)
            test_loss, test_acc = test(model, test_loader)
            t = time() - t
            sizes.append((comp / 1000, test_acc, accuracy, t))
        # continue reducing sparsity
        print(i, "iteration - ", "Size:", comp, list_of_sparsity)
    result.append({'run_id': run_id, 'depth': depth,'input_width':input_width, 
                   'output': output_width, 'leaf_width':leaf_width,
                   'buffer_size': buffer_size, 'sizes': sizes})

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Training on cuda
Run:	 543fe9acb34441a3a82b09ca2ef6046c
Depth:	 4
Input:	 784
Output:	 10
Leaf:	 32
Buffer:	 90
1 iteration -  Size: 1515408.8 [1, 1, 0.45, 1, 1, 1]
2 iteration -  Size: 1370901.9200000002 [1, 1, 0.405, 1, 1, 1]
3 iteration -  Size: 1240845.7280000001 [1, 1, 0.36450000000000005, 1, 1, 1]
4 iteration -  Size: 1123795.1552000002 [1, 1, 0.32805000000000006, 1, 1, 1]
5 iteration -  Size: 1018449.6396800001 [1, 1, 0.29524500000000004, 1, 1, 1]
6 iteration -  Size: 923638.6757120001 [1, 1, 0.2657205, 1, 1, 1]
7 iteration -  Size: 838308.8081408 [1, 1, 0.23914845, 1, 1, 1]
8 iteration -  Size: 761511.92732672 [1, 1, 0.215233605, 1, 1, 1]
9 iteration -  Size: 692394.734594048 [1, 1, 0.1937102445, 1, 1, 1]
10 iteration -  Size: 630189.2611346432 [1, 1, 0.17433922005, 1, 1, 1]
11 iteration -  Size: 574204.3350211789 [1, 1, 0.156905298045, 1, 1, 1]
12 iteration -  Size: 523817.901519061 [1, 1, 0.1412147682405, 1, 1, 1]
13 iteration -  Size: 478470.1113671549 [1, 1, 0.1270932914164

Epoch [1/7], Step[54/54], Loss: 0.3888
Epoch[1]: v_loss: 0.37928 v_acc: 88.75
Epoch [2/7], Step[54/54], Loss: 0.3642
Epoch[2]: v_loss: 0.3628 v_acc: 89.15
Epoch [3/7], Step[54/54], Loss: 0.3520
Epoch[3]: v_loss: 0.35404 v_acc: 89.46
Epoch [4/7], Step[54/54], Loss: 0.3451
Epoch[4]: v_loss: 0.34731 v_acc: 89.72
Epoch [5/7], Step[54/54], Loss: 0.4075
Epoch[5]: v_loss: 0.34253 v_acc: 89.83
Epoch [6/7], Step[54/54], Loss: 0.2905
Epoch[6]: v_loss: 0.33877 v_acc: 90.08
Epoch [7/7], Step[54/54], Loss: 0.2909
Epoch[7]: v_loss: 0.33598 v_acc: 90.05
Best model saved at epoch:  7
Best acc model saved at epoch:  6
Accuracy of the network on the 10000 test images: 90.05 %
53 iteration -  Size: 59981.367950406675 [0.1937102445, 1, 0.005387631832152908, 1, 0.45, 1]
54 iteration -  Size: 58138.16795040668 [0.1937102445, 1, 0.005387631832152908, 1, 0.405, 1]
55 iteration -  Size: 56315.74197015068 [0.17433922005, 1, 0.005387631832152908, 1, 0.405, 1]
56 iteration -  Size: 54585.63115536601 [0.1743392200

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Run:	 4ea2391b3d014e2fafff3accfb352d2c
Depth:	 4
Input:	 784
Output:	 10
Leaf:	 32
Buffer:	 90
1 iteration -  Size: 1515408.8 [1, 1, 0.45, 1, 1, 1]
2 iteration -  Size: 1370901.9200000002 [1, 1, 0.405, 1, 1, 1]
3 iteration -  Size: 1240845.7280000001 [1, 1, 0.36450000000000005, 1, 1, 1]
4 iteration -  Size: 1123795.1552000002 [1, 1, 0.32805000000000006, 1, 1, 1]
5 iteration -  Size: 1018449.6396800001 [1, 1, 0.29524500000000004, 1, 1, 1]
6 iteration -  Size: 923638.6757120001 [1, 1, 0.2657205, 1, 1, 1]
7 iteration -  Size: 838308.8081408 [1, 1, 0.23914845, 1, 1, 1]
8 iteration -  Size: 761511.92732672 [1, 1, 0.215233605, 1, 1, 1]
9 iteration -  Size: 692394.734594048 [1, 1, 0.1937102445, 1, 1, 1]
10 iteration -  Size: 630189.2611346432 [1, 1, 0.17433922005, 1, 1, 1]
11 iteration -  Size: 574204.3350211789 [1, 1, 0.156905298045, 1, 1, 1]
12 iteration -  Size: 523817.901519061 [1, 1, 0.1412147682405, 1, 1, 1]
13 iteration -  Size: 478470.1113671549 [1, 1, 0.12709329141645, 1, 1, 1]
14 it

Epoch [1/7], Step[54/54], Loss: 0.7905
Epoch[1]: v_loss: 0.79924 v_acc: 88.44
Epoch [2/7], Step[54/54], Loss: 0.8346
Epoch[2]: v_loss: 0.79902 v_acc: 88.44
Epoch [3/7], Step[54/54], Loss: 0.8017
Epoch[3]: v_loss: 0.7976 v_acc: 88.51
Epoch [4/7], Step[54/54], Loss: 0.7902
Epoch[4]: v_loss: 0.79778 v_acc: 88.52
Epoch [5/7], Step[54/54], Loss: 0.7801
Epoch[5]: v_loss: 0.79663 v_acc: 88.48
Epoch [6/7], Step[54/54], Loss: 0.7725
Epoch[6]: v_loss: 0.79588 v_acc: 88.57
Epoch [7/7], Step[54/54], Loss: 0.7838
Epoch[7]: v_loss: 0.79486 v_acc: 88.61
Best model saved at epoch:  7
Best acc model saved at epoch:  7
Accuracy of the network on the 10000 test images: 88.61 %
53 iteration -  Size: 59981.367950406675 [0.1937102445, 1, 0.005387631832152908, 1, 0.45, 1]
54 iteration -  Size: 58138.16795040668 [0.1937102445, 1, 0.005387631832152908, 1, 0.405, 1]
55 iteration -  Size: 56315.74197015068 [0.17433922005, 1, 0.005387631832152908, 1, 0.405, 1]
56 iteration -  Size: 54585.63115536601 [0.1743392200

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Run:	 612d1573468c4ffaa89bf54d60ce4508
Depth:	 3
Input:	 784
Output:	 10
Leaf:	 32
Buffer:	 90
1 iteration -  Size: 756138.4 [1, 1, 0.45, 1, 1, 1]
2 iteration -  Size: 683884.9600000001 [1, 1, 0.405, 1, 1, 1]
3 iteration -  Size: 618856.8640000001 [1, 1, 0.36450000000000005, 1, 1, 1]
4 iteration -  Size: 560331.5776000001 [1, 1, 0.32805000000000006, 1, 1, 1]
5 iteration -  Size: 507658.81984000007 [1, 1, 0.29524500000000004, 1, 1, 1]
6 iteration -  Size: 460253.33785600006 [1, 1, 0.2657205, 1, 1, 1]
7 iteration -  Size: 417588.4040704 [1, 1, 0.23914845, 1, 1, 1]
8 iteration -  Size: 379189.96366336 [1, 1, 0.215233605, 1, 1, 1]
9 iteration -  Size: 344631.367297024 [1, 1, 0.1937102445, 1, 1, 1]
10 iteration -  Size: 313528.6305673216 [1, 1, 0.17433922005, 1, 1, 1]
11 iteration -  Size: 285536.16751058947 [1, 1, 0.156905298045, 1, 1, 1]
12 iteration -  Size: 260342.9507595305 [1, 1, 0.1412147682405, 1, 1, 1]
13 iteration -  Size: 237669.05568357746 [1, 1, 0.12709329141645, 1, 1, 1]
14 it

Epoch [1/7], Step[54/54], Loss: 0.2853
Epoch[1]: v_loss: 0.29964 v_acc: 91.34
Epoch [2/7], Step[54/54], Loss: 0.3000
Epoch[2]: v_loss: 0.28139 v_acc: 92.01
Epoch [3/7], Step[54/54], Loss: 0.2119
Epoch[3]: v_loss: 0.26964 v_acc: 92.36
Epoch [4/7], Step[54/54], Loss: 0.2580
Epoch[4]: v_loss: 0.26152 v_acc: 92.61
Epoch [5/7], Step[54/54], Loss: 0.2733
Epoch[5]: v_loss: 0.25582 v_acc: 92.73
Epoch [6/7], Step[54/54], Loss: 0.2480
Epoch[6]: v_loss: 0.25107 v_acc: 92.88
Epoch [7/7], Step[54/54], Loss: 0.2622
Epoch[7]: v_loss: 0.24909 v_acc: 92.88
Best model saved at epoch:  7
Best acc model saved at epoch:  7
Accuracy of the network on the 10000 test images: 92.88 %
39 iteration -  Size: 48850.663648162896 [0.405, 1, 0.010137779795222625, 1, 1, 1]


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Run:	 24879e69bd174ef6a9973359ea1b9b6c
Depth:	 3
Input:	 784
Output:	 10
Leaf:	 32
Buffer:	 90
1 iteration -  Size: 756138.4 [1, 1, 0.45, 1, 1, 1]
2 iteration -  Size: 683884.9600000001 [1, 1, 0.405, 1, 1, 1]
3 iteration -  Size: 618856.8640000001 [1, 1, 0.36450000000000005, 1, 1, 1]
4 iteration -  Size: 560331.5776000001 [1, 1, 0.32805000000000006, 1, 1, 1]
5 iteration -  Size: 507658.81984000007 [1, 1, 0.29524500000000004, 1, 1, 1]
6 iteration -  Size: 460253.33785600006 [1, 1, 0.2657205, 1, 1, 1]
7 iteration -  Size: 417588.4040704 [1, 1, 0.23914845, 1, 1, 1]
8 iteration -  Size: 379189.96366336 [1, 1, 0.215233605, 1, 1, 1]
9 iteration -  Size: 344631.367297024 [1, 1, 0.1937102445, 1, 1, 1]
10 iteration -  Size: 313528.6305673216 [1, 1, 0.17433922005, 1, 1, 1]
11 iteration -  Size: 285536.16751058947 [1, 1, 0.156905298045, 1, 1, 1]
12 iteration -  Size: 260342.9507595305 [1, 1, 0.1412147682405, 1, 1, 1]
13 iteration -  Size: 237669.05568357746 [1, 1, 0.12709329141645, 1, 1, 1]
14 it

Epoch [1/7], Step[54/54], Loss: 0.9144
Epoch[1]: v_loss: 0.92371 v_acc: 85.76
Epoch [2/7], Step[54/54], Loss: 0.8656
Epoch[2]: v_loss: 0.92091 v_acc: 85.94
Epoch [3/7], Step[54/54], Loss: 0.8467
Epoch[3]: v_loss: 0.91947 v_acc: 85.93
Epoch [4/7], Step[54/54], Loss: 0.9735
Epoch[4]: v_loss: 0.91822 v_acc: 85.91
Epoch [5/7], Step[54/54], Loss: 0.8814
Epoch[5]: v_loss: 0.91731 v_acc: 85.98
Epoch [6/7], Step[54/54], Loss: 0.8783
Epoch[6]: v_loss: 0.91669 v_acc: 85.99
Epoch [7/7], Step[54/54], Loss: 0.8753
Epoch[7]: v_loss: 0.91665 v_acc: 86.01
Best model saved at epoch:  7
Best acc model saved at epoch:  7
Accuracy of the network on the 10000 test images: 86.01 %
39 iteration -  Size: 48850.663648162896 [0.405, 1, 0.010137779795222625, 1, 1, 1]


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Run:	 7c66090a514e4080b639b6f261d5a134
Depth:	 2
Input:	 784
Output:	 10
Leaf:	 32
Buffer:	 90


KeyboardInterrupt: 

In [14]:
for r in result:
    print(r)

{'run_id': '543fe9acb34441a3a82b09ca2ef6046c', 'depth': 4, 'input_width': 784, 'output': 10, 'leaf_width': 32, 'buffer_size': 90, 'sizes': [(1675.908, 0.9274, 0.9274, 52.47797870635986), (96.90447921813978, 0.8963, 89.63, 48.06929326057434), (89.85803129632579, 0.9121, 91.21, 48.54480266571045), (77.8086053500239, 0.9095, 90.95, 49.80986571311951), (68.04857033351935, 0.9013, 90.13, 49.315375089645386), (59.98136795040668, 0.9005, 90.05, 47.807331800460815), (49.72946803982941, 0.8852, 88.52, 43.566579818725586)]}
{'run_id': '4ea2391b3d014e2fafff3accfb352d2c', 'depth': 4, 'input_width': 784, 'output': 10, 'leaf_width': 32, 'buffer_size': 90, 'sizes': [(1675.908, 0.8077, 0.8077, 34.018182039260864), (96.90447921813978, 0.8872, 88.72, 33.80422616004944), (89.85803129632579, 0.8957, 89.57, 34.04274368286133), (77.8086053500239, 0.8892, 88.92, 34.13646936416626), (68.04857033351935, 0.8897, 88.97, 34.019927740097046), (59.98136795040668, 0.8861, 88.61, 34.12063121795654), (49.7294680398294

In [None]:
for i in range (4,len(list_of_run)):
    run_id = list_of_run[i]
    mlflow.artifacts.download_artifacts(run_id=run_id, dst_path=".")
    wrapped_model = pickle.load(open("./truncated_model.pkl", "rb"))
    depth = wrapped_model._fff.fff.depth.item()
    input_width = wrapped_model._fff.fff.input_width
    leaf_width = wrapped_model._fff.fff.leaf_width
    output_width = wrapped_model._fff.fff.output_width
    buffer_size = 2*(leaf_width + output_width + 3)
    print("Run:\t", run_id)
    print("Depth:\t", depth)
    print("Input:\t", input_width)
    print("Output:\t", output_width)
    print("Leaf:\t", leaf_width)
    print("Buffer:\t", buffer_size)
    
    # to reduce the sparsity and train only below a certain tresholds
    list_of_sizes = [100, 90, 80, 70, 60, 50]
    checked_sizes = [False for x in list_of_sizes]
    current_size_index = 0
    
    start = 0.5
    a = start
    b = start
    sizes=[]
    before_trunc_sizes=[]
    trunc_sizes=[]

    model = wrapped_model.to(device)
    
    layers_list = []
    for i, (name, p) in enumerate(model.named_parameters()):
        if (len(list(p.shape)) > 1 and p.requires_grad):
            layers_list.append(p)
    un, comp, layers = print_size_fc(model, layers_list, [1,1,1,1,1,1])
    
    MODEL_NAME_COMPRESSED = "mnist_" + run_id + "_compressed_full"
    
    layers_sizes = layers.copy()
    # Creating a list of 1s with the same length as original_list using multiplication
    list_of_sparsity = [1] * len(layers_sizes)
    
    # Testing time and accuracy
    model.eval()
    t = time()
    train_loss, train_acc = test(model, train_loader)
    test_loss, test_acc = test(model, test_loader)
    t = time() - t
    un, comp, layers_sizes = print_size_fc(model, layers_list, list_of_sparsity)
    sizes.append((comp/1000, test_acc, test_acc, t))
    
    for i in range(1, 100):
        # get index of max
        index_of_max = np.argmax(layers_sizes)
        current_sparsity = list_of_sparsity[index_of_max]

        # reduce
        if (current_sparsity == 1):
            list_of_sparsity[index_of_max] = start - (start * 0.1)
        else:
            list_of_sparsity[index_of_max] = current_sparsity - (current_sparsity * 0.1)

        un, comp, layers_sizes = print_size_fc(model, layers_list, list_of_sparsity)
        
        if (current_size_index >= len(list_of_sizes)):
            break
        elif (comp / 1000 < list_of_sizes[current_size_index]):
            # compress and save result
            MODEL_NAME_COMPRESSED = "mnist_" + run_id + "_compressed_" + str(list_of_sizes[current_size_index])
            current_size_index += 1
            model.train()
            accuracy = perform_compression(model, layers_list, list_of_sparsity, learning_rate, num_epochs,
                                           train_loader, test_loader, device,
                                           val_loader=val_loader, model_name=MODEL_NAME_COMPRESSED, given_criterion=criterion)
            model.load_state_dict(torch.load(MODEL_NAME_COMPRESSED+".h5", map_location='cpu'))
            model.eval()
            t = time()
            train_loss, train_acc = test(model, train_loader)
            test_loss, test_acc = test(model, test_loader)
            t = time() - t
            sizes.append((comp / 1000, test_acc, accuracy, t))
        # continue reducing sparsity
        print(i, "iteration - ", "Size:", comp, list_of_sparsity)
    result.append({'run_id': run_id, 'depth': depth,'input_width':input_width, 
                   'output': output_width, 'leaf_width':leaf_width,
                   'buffer_size': buffer_size, 'sizes': sizes})

In [None]:
for r in result:
    print(r)

In [None]:
class FFFWrapper(torch.nn.Module):
    def __init__(self, fff):
        super(FFFWrapper, self).__init__()
        self._fff = fff
        self._fastinference = [None for i in range(2 ** (self._fff.fff.depth.item()))]

    def forward(self, x, return_nodes=False):
        """
        Override the forward method in order to log the data distribution.
        """
        x = x.view(len(x), -1)
        original_shape = x.shape
        batch_size = x.shape[0]
        last_node = torch.zeros(len(x))

        current_nodes = torch.zeros((batch_size,), dtype=torch.long, device=x.device)
        for i in range(self._fff.fff.depth.item()):
            plane_coeffs = self._fff.fff.node_weights.index_select(dim=0, index=current_nodes)
            plane_offsets = self._fff.fff.node_biases.index_select(dim=0, index=current_nodes)
            plane_coeff_score = torch.bmm(x.unsqueeze(1), plane_coeffs.unsqueeze(-1))
            plane_score = plane_coeff_score.squeeze(-1) + plane_offsets
            plane_choices = (plane_score.squeeze(-1) >= 0).long()

            platform = torch.tensor(2 ** i - 1, dtype=torch.long, device=x.device)
            next_platform = torch.tensor(2 ** (i+1) - 1, dtype=torch.long, device=x.device)
            current_nodes = (current_nodes - platform) * 2 + plane_choices + next_platform

        leaves = current_nodes - next_platform
        new_logits = torch.empty((batch_size, self._fff.fff.output_width), dtype=torch.float, device=x.device)
        last_node = leaves

        for i in range(leaves.shape[0]):
            leaf_index = leaves[i]
            if self._fastinference[leaf_index] is not None:
                new_logits[i] = self._fastinference[leaf_index]
            else:
                logits = torch.matmul( x[i].unsqueeze(0), self._fff.fff.w1s[leaf_index])
                logits += self._fff.fff.b1s[leaf_index].unsqueeze(-2)
                activations = self._fff.fff.activation(logits)
                new_logits[i] = torch.matmul( activations, self._fff.fff.w2s[leaf_index]).squeeze(-2)

        if return_nodes:
            return new_logits.view(*original_shape[:-1], self._fff.fff.output_width), last_node
        return new_logits.view(*original_shape[:-1], self._fff.fff.output_width)


    def simplify_leaves(self, trainloader):
        y, leaves = (get_dist(self, trainloader))
        y = y.cpu().detach().numpy()
        outputs = y.max() + 1
        leaves = leaves.cpu().detach().numpy()

        n_simplifications = 0
        ratios = {}
        for l in np.unique(leaves):
            ratios[l] = torch.zeros(outputs)
            indices = leaves == l

            for i in range(outputs):
                ratios[l][i] = (np.sum(y[indices] == i) / np.sum(indices))

            argmax = np.argmax(ratios[l])
            if ratios[l][argmax] >= 0:
                output = torch.zeros(outputs)
                output[argmax] = 1
                self._fastinference[l] = output
                n_simplifications += 1
                print(f"Leaf {l} has been replaced with {argmax}")
        print(self._fastinference)

In [None]:
import typer
import mlflow
import pickle
import numpy as np
import pandas as pd


def get_split_code(array, bias):
    code = """
    acc = """ + " + ".join(f"x[{i}] * {v}" for i, v in enumerate(array)) + """;
    acc += """ + str(bias[0]) + """;
    """
    return code


def get_output_code(w1, b1, w2, b2):
    code = """
    float hidden[""" + str(w1.shape[1]) + """];
    """
    for i in range(w1.shape[1]):
        code += f"hidden[{i}] = {b1[i]} + " + " + ".join(f"x[{j}] * {v}" for j, v in enumerate(w1[:, i])) + ";\n"
        code += f"hidden[{i}] = hidden[{i}] > 0 ? hidden[{i}] : 0;\n"
    code += """
    float logits[""" + str(w2.shape[1]) + """];
    """
    for j in range(w2.shape[1]):
        code += f"logits[{j}] = {b2[j]} + " + " + ".join(f"hidden[{i}] * {v}" for i, v in enumerate(w2[:, j])) + ";\n"
        code += f"logits[{j}] = logits[{j}] > 0 ? logits[{j}] : 0;\n"

    code += """
    max = 0.0;
    argmax = 0;
    for (int i = 0; i < """ + str(w2.shape[1]) + """; i++) {
        if (logits[i] > max) {
            max = logits[i];
            argmax = i;
        }
    }

    return argmax;
    """

    return code

def get_splits(weights, biases):
    code = """int perform_inference(float* x) {
    float acc;
    float max;
    int argmax;
        <replaceme>
}"""
    for index, (array, bias) in enumerate(zip(weights, biases)):
        code = code.replace("<replaceme>", get_split_code(array, bias), 1)
    return code


class Node:
    def __init__(self, array, bias, left, right):
        self._array = array
        self._bias = bias
        self._left = left
        self._right = right

    def __str__(self):
        code = get_split_code(self._array, self._bias)

        code += """
        if (acc >= 0) {
            """ + str(self._left) + """
        } else {
            """ + str(self._right) + """
        }
        """
        return code


class Leaf(Node):
    def __init__(self, w1, b1, w2, b2):
        self._w1 = w1
        self._b1 = b1
        self._w2 = w2
        self._b2 = b2

    def __str__(self):
        if self._w2 is None:
            return f"return {self._w1};\n"
        return get_output_code(self._w1, self._b1, self._w2, self._b2)

def print_parameters(params, key, lines, line, flit, skipTruncatedLeaves, sizes_print=False, index_print=True):
    # the first check on the sparsity on the weight also consider the non zero values of truncated leaves
    
    weight = params[key]
    dim = len(weight.shape)
    
    non_zero_values = np.count_nonzero(weight)
    first_dim = weight.shape[0]
    sparse = False
    
    if (dim == 2):
        weight_row = weight.shape[0]
        weight_col = weight.shape[1]
        
        sparse = non_zero_values < ((weight_row * (weight_col - 1) - 1) / 2)
    elif (dim == 3):
        weight_depth = weight.shape[0]
        weight_row = weight.shape[1]
        weight_col = weight.shape[2]
        
        size_requested = weight_depth * weight_row * weight_col
        sparse = non_zero_values < ((size_requested - weight_depth) / 2)
    if (not sparse):
        # print the parameters as usal
        lines.insert(
            line,
            "#define " + key + "_SPARSE " + str(0) + "\n"
        )
        line += 1
        line += 1
        
        param = weight
        if (key in ['LEAF_HIDDEN_WEIGHTS', 'LEAF_HIDDEN_BIASES', 'LEAF_OUTPUT_WEIGHTS', 'LEAF_OUTPUT_BIASES']):
            if (skipTruncatedLeaves):
                param = param[params['FASTINFERENCE'] == -1]
        param = param.flatten()
        tmp = ""
        if (flit and key not in ['FASTINFERENCE']):
            tmp = ", ".join(["F_LIT(" + str(x) + ")" for x in param])
        else:
            tmp = ", ".join([str(x) for x in param])
        lines.insert(
            line,
            tmp
        )
    else:
        # CSC or CSR format
        leaves_values = np.empty([0], dtype=float)
        leaves_offsets = np.empty([0], dtype=int)
        leaves_sizes = None
        printed_dim = key + "_DIM_1"
        if (sizes_print):
            leaves_sizes = np.empty([first_dim], dtype=int)
        elif(index_print):
            leaves_sizes = np.zeros([1], dtype=int)
            printed_dim += " + 1"
            
        value_position = 0
        actual_non_zero_values = 0
        
        for index, leaf_weight in enumerate(weight): # from 0 to first_dim
            
            if (key in ['LEAF_HIDDEN_WEIGHTS', 'LEAF_HIDDEN_BIASES', 'LEAF_OUTPUT_WEIGHTS', 'LEAF_OUTPUT_BIASES']):
                # insert filters non zero values into the fitler sizes
                if (params['FASTINFERENCE'][index] != -1 and skipTruncatedLeaves):
                    leaves_sizes[index] = 0
                    continue
            
            # insert filters non zero values into the fitler sizes
            non_zero_values_here = np.count_nonzero(leaf_weight)
            actual_non_zero_values += non_zero_values_here
            # flatten the filter
            flatten_leaf = leaf_weight.ravel()
            offset = 1
            
            if (sizes_print):
                # _sizes[DIM_1] = for each node/leaf we save the number of NNZ contained
                leaves_sizes[index] = non_zero_values_here
            elif (index_print):
                # _sizes[DIM_1 + 1] = for each node/leaf we save the index in _data array of the starting value
                # _sizes from {0 up to NNZ}
                leaves_sizes = np.append(leaves_sizes, actual_non_zero_values)
            
            for value in flatten_leaf: # from 0 to (n_depth * n_height * n_width)
                if (value == 0):
                    # increase offset
                    offset += 1
                else:
                    # save value, save index, reset offset, increase position
                    leaves_values = np.append(leaves_values, value)
                    leaves_offsets = np.append(leaves_offsets, offset)
                    leaves_values[value_position] = value
                    leaves_offsets[value_position] = offset
                    offset = 1
                    value_position += 1
                
            
        tmp = ""
        # substitute the definition
        lines[line] = "#define " + key + "_NNZ " + str(actual_non_zero_values) + "\n"
        line+=1
        lines.insert(
            line,
            "#define " + key + "_DIM " + str(dim) + "\n"
        )
        line+=1
        lines.insert(
            line,
            "#define " + key + "_SPARSE " + str(1) + "\n"
        )
        for d in range(0, dim):
            line+=1
            lines.insert(
                line,
                "#define " + key + "_DIM_" + str(d + 1) + " " + str(weight.shape[d]) + "\n"
            )
        line+=1
        lines.insert(
            line,
            "__hifram fixed " + key + "_data[" + key + "_NNZ] = {\n"
        )
        line+=1
        
        if flit:
            tmp = ", ".join(["F_LIT(" + str(x) + ")" for x in leaves_values])
        else:
            tmp = ", ".join([str(x) for x in leaves_values])
        lines.insert(
            line,
            tmp
        )
        line+=1
        line+=1
        line+=1
        lines.insert(
            line,
            "\n__hifram fixed " + key + "_offset[" + key + "_NNZ] = {\n"
        )
        line+=1
        if (flit and False):
            tmp = ", ".join(["F_LIT(" + str(x) + ")" for x in leaves_offsets])
        else:
            tmp = ", ".join([str(x) for x in leaves_offsets])
        lines.insert(
            line,
            tmp
        )
        line+=1
        lines.insert(
            line,
            "\n"
        )
        line+=1
        lines.insert(
            line,
            "};\n"
        )
        line+=1
        lines.insert(
            line,
            "\n__hifram fixed " + key + "_sizes[" + printed_dim + "] = {\n"
        )
        line+=1
        if (flit and False):
            tmp = ", ".join(["F_LIT(" + str(x) + ")" for x in leaves_sizes])
        else:
            tmp = ", ".join([str(x) for x in leaves_sizes])
        lines.insert(
            line,
            tmp
        )
        line+=1
        lines.insert(
            line,
            "\n"
        )
        line+=1
        lines.insert(
            line,
            "};\n\n"
        )
    
def make_program(wrapped_model, name, original_fastinference, flit=True, skipTruncatedLeaves=False):

    node_weights = wrapped_model._fff.fff.node_weights.cpu().detach().numpy()
    node_biases = wrapped_model._fff.fff.node_biases.cpu().detach().numpy()
    w1s = wrapped_model._fff.fff.w1s
    b1s = wrapped_model._fff.fff.b1s.cpu().detach().numpy()
    w2s = wrapped_model._fff.fff.w2s
    b2s = wrapped_model._fff.fff.b2s.cpu().detach().numpy()
    fastinference = wrapped_model._fastinference

    w1s = w1s.transpose(1, 2).cpu().detach().numpy()
    w2s = w2s.transpose(1, 2).cpu().detach().numpy()

    params = {}

    params['NODE_WEIGHTS'] = node_weights#.flatten()
    params['NODE_BIASES'] = node_biases#.flatten()
    params['FASTINFERENCE'] = np.array([-1 if x is None else int(x.argmax()) for x in fastinference])
    actual_leaves_weights = w1s #[params['FASTINFERENCE'] == -1]
    actual_leaves_biases = b1s #[params['FASTINFERENCE'] == -1]
    actual_leaves_out_weights = w2s #[params['FASTINFERENCE'] == -1]
    actual_leaves_out_biases = b2s #[params['FASTINFERENCE'] == -1]
    params['LEAF_HIDDEN_WEIGHTS'] = actual_leaves_weights#.flatten()
    params['LEAF_HIDDEN_BIASES'] = actual_leaves_biases#.flatten()
    params['LEAF_OUTPUT_WEIGHTS'] = actual_leaves_out_weights#.flatten()
    params['LEAF_OUTPUT_BIASES'] = actual_leaves_out_biases#.flatten()
    with open("weights_" + name + ".h", "w") as f:
        with open("weights_template.h") as in_f:
            lines = in_f.readlines()

            i = 0
            while i < len(lines):
                i += 1
                if "Add definitions here" in lines[i]:
                    break
            lines[i] = (f"""// name of the model  {name}

#define DEPTH {wrapped_model._fff.fff.depth.item()}
#define N_LEAVES {2 ** wrapped_model._fff.fff.depth.item()}
#define INPUT_SIZE {wrapped_model._fff.fff.input_width}
#define LEAF_WIDTH {wrapped_model._fff.fff.leaf_width}
#define OUTPUT_SIZE {wrapped_model._fff.fff.output_width}
#define SIMPLIFIED_LEAVES {sum([f is not None for f in fastinference])}
#define ORIGINAL_FASTINFERENCE {original_fastinference}""")
            i+=1
            lines.insert(i, "\n")

            # lines.insert(i+7, """
# float FASTINFERENCE[N_LEAVES] = {-1};
# float NODE_WEIGHTS[N_LEAVES-1][INPUT_SIZE];
# float NODE_BIASES[N_LEAVES-1];
# float LEAF_HIDDEN_WEIGHTS[N_LEAVES-SIMPLIFIED_LEAVES][LEAF_WIDTH][INPUT_SIZE];
# float LEAF_OUTPUT_WEIGHTS[N_LEAVES-SIMPLIFIED_LEAVES][OUTPUT_SIZE][LEAF_WIDTH];
# float LEAF_HIDDEN_BIASES[N_LEAVES-SIMPLIFIED_LEAVES][LEAF_WIDTH];
# float LEAF_OUTPUT_BIASES[N_LEAVES-SIMPLIFIED_LEAVES][OUTPUT_SIZE];
            # """)

            for key in params.keys():
                i = 0
                while i < len(lines):
                    if f"fixed {key}" in lines[i]:
                        break
                    i += 1
#                 i += 1
                # line of the definition of the parameter
                # calculate sparsity and choose representation
                # if dense keep the definition, i+=1, and print
                # if sparse
                print_parameters(params, key, lines, i, flit, skipTruncatedLeaves)
                # print Parameters
                
                # calculate sparsity and choose representation
                # if dense -> print as we already are doing
                # if not change the representation
                

            f.writelines(lines)
    return wrapped_model


def main(run_id, name, original_fastinference):
#     import torch
    net = make_program(run_id, name, original_fastinference)
#     net._fff.eval()
#     X = np.loadtxt('test.txt')
#     X = torch.Tensor(X)
#     with open('ref_outputs.txt', 'w') as f:
#         y = net(X).argmax(1)
#         y = [(str(x) + "\n") for x in y.detach().cpu().numpy()]
#         f.writelines(y)

In [None]:
for i in range (0, len(list_of_run)):
    run_id = list_of_run[i]
    
    mlflow.artifacts.download_artifacts(run_id=run_id, dst_path=".")
    wrapped_model = pickle.load(open("./truncated_model.pkl", "rb"))
    
    original_fastinference = str([-1 if x is None else int(x.argmax()) for x in wrapped_model._fastinference])
    
    # compress and save result
    MODEL_NAME_COMPRESSED = "mnist_" + run_id + "_compressed_" + str(list_of_sizes[0])
    wrapped_model.load_state_dict(torch.load(MODEL_NAME_COMPRESSED+".h5", map_location='cpu'))
    
    n_simplifications = wrapped_model.to(device).simplify_leaves(train_loader)
    main(wrapped_model, MODEL_NAME_COMPRESSED, original_fastinference)