In [13]:
import numpy as np
import pandas as pd
import json

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.quantization

In [30]:
# Define the ANN class
class ANN(nn.Module):
    """
    Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
    in_features (int): Size of each input sample. Defaults to 2.
    hidden_features (int): Width of the hidden layer. Defaults to 4.
    out_features (int): Size of each output sample. Defaults to 1.
    """

    def __init__(self, in_features=2, hidden_features=4, out_features=1):
        super().__init__()
        # fc1 is built before fc2 so that default construction draws random
        # initial weights in the same order as the fixed-size version did.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """
        Apply fc1, a ReLU, then fc2 (no activation on the final layer).

        Args:
        x (torch.Tensor): Input whose last dimension equals ``in_features``.

        Returns:
        torch.Tensor: Output whose last dimension equals ``out_features``.
        """
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

# Function to compute global min and max
def get_global_min_max(tensors):
    """
    Find the smallest and largest scalar value across a collection of tensors.

    Args:
    tensors (iterable of torch.Tensor): The tensors to scan. Tensors with no
        elements are skipped because they have no minimum or maximum.

    Returns:
    tuple: (min_val, max_val) as Python numbers. If no element is seen at all
        (empty iterable, or only empty tensors) the initial sentinels
        (float('inf'), float('-inf')) are returned unchanged.
    """
    min_val = float('inf')
    max_val = float('-inf')
    for tensor in tensors:
        # tensor.min()/max() raise on a zero-element tensor, so skip those.
        if tensor.numel() == 0:
            continue
        min_val = min(min_val, tensor.min().item())
        max_val = max(max_val, tensor.max().item())
    return min_val, max_val

# Function to quantize tensors using common bins
def quantize_tensor_to_int(tensor, bin_edges):
    """
    Quantize the tensor into integer bins.

    Each value is replaced by the zero-based index of the bin it falls into.
    Values below the first edge map to bin 0, and values at or above the last
    edge map to the last bin, ``len(bin_edges) - 2``.

    Args:
    tensor (torch.Tensor): The tensor to be quantized. It may require grad or
        live on a non-CPU device; it is detached and copied to the CPU first.
    bin_edges (np.ndarray): The edges of the bins.

    Returns:
    torch.Tensor: An int32 CPU tensor of bin indices with the same shape as
        ``tensor``.
    """
    # .numpy() refuses tensors that require grad or sit on another device,
    # so detach and move to the CPU before converting.
    values = tensor.detach().cpu().numpy()

    # np.digitize returns 1-based positions relative to the edges; shift them
    # to 0-based bin indices.
    bin_indices = np.digitize(values, bin_edges) - 1

    # Fold out-of-range values (and the value exactly on the last edge) into
    # the first/last valid bin.
    bin_indices = np.clip(bin_indices, 0, len(bin_edges) - 2)

    return torch.tensor(bin_indices, dtype=torch.int32).reshape(tensor.shape)

# Function to capture intermediate outputs and prepare for quantization
def get_intermediate_outputs(model, input_data):
    """
    Run one forward pass and record what flows through each direct child layer.

    A forward hook is attached to every module returned by
    ``model.children()`` (direct children only, not nested sub-modules). For
    each hooked layer call, the first positional input and then the output are
    appended to the result, in call order.

    Args:
    model (nn.Module): The model to inspect.
    input_data (torch.Tensor): Input passed to ``model``.

    Returns:
    list: ``[input_0, output_0, input_1, output_1, ...]``, one pair per hooked
        layer call.
    """
    intermediate_outputs = []

    def hook_fn(module, input, output):
        # `input` is the tuple of positional arguments; only the first is kept.
        intermediate_outputs.append(input[0])
        intermediate_outputs.append(output)

    hooks = []
    try:
        for layer in model.children():
            hooks.append(layer.register_forward_hook(hook_fn))

        # Forward pass without building an autograd graph
        with torch.no_grad():
            model(input_data)
    finally:
        # Always detach the hooks, even when the forward pass raises, so the
        # model is never left with stale hooks that fire on later calls.
        for hook in hooks:
            hook.remove()

    return intermediate_outputs

# Save a tensor to disk as a JSON nested list
def save_tensor_as_list(tensor, filename):
    """
    Write the tensor's values to a JSON file as a (nested) list.

    Args:
    tensor (torch.Tensor): The tensor to serialise. It is detached and copied
        to the CPU first, so tensors that require grad are accepted.
    filename (str): Path of the JSON file to create or overwrite.
    """
    # detach() first: .numpy() raises on a tensor that requires grad.
    tensor_list = tensor.detach().cpu().numpy().tolist()  # Convert tensor to list
    with open(filename, 'w') as f:
        json.dump(tensor_list, f)

In [31]:
# Seed the global RNG so the random weights and the random input below are
# identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Instantiate the model
model = ANN()

# Define an input tensor with appropriate shape: 3 samples of 2 features,
# integer-valued floats in [0, 10)
input_tensor = torch.randint(0, 10, (3, 2), dtype=torch.float)

# Capture intermediate outputs (input and output of every direct child layer)
intermediate_outputs = get_intermediate_outputs(model, input_tensor)

# Include input_tensor in the tensors to be quantized
all_tensors = [input_tensor] + intermediate_outputs

# Compute global min and max so every tensor shares the same bins
global_min, global_max = get_global_min_max(all_tensors)

# Number of bins for quantization
num_bins = 16

# Create bin edges (num_bins equal-width bins need num_bins + 1 edges)
bin_edges = np.linspace(global_min, global_max, num_bins + 1)

# quantize tensors
quantized_input_tensor = quantize_tensor_to_int(input_tensor, bin_edges)
save_tensor_as_list(quantized_input_tensor, 'input_tensor.json')

# Quantize each intermediate tensor once; keep the results so the verification
# printout below reuses them instead of quantizing a second time.
quantized_outputs = []
for i, tensor in enumerate(intermediate_outputs):
    quantized_tensor = quantize_tensor_to_int(tensor, bin_edges)
    quantized_outputs.append(quantized_tensor)
    save_tensor_as_list(quantized_tensor, f"intermediate_output_{i}.json")

print("Tensors have been saved as JSON files.")

# Print verification
for i, quantized_tensor in enumerate(quantized_outputs):
    tensor_list = quantized_tensor.cpu().numpy().tolist()

    print(f"Intermediate Output {i}: {tensor_list}")

Tensors have been saved as JSON files.
Intermediate Output 0: [[9, 7], [6, 6], [4, 15]]
Intermediate Output 1: [[3, 3, 3, 4], [4, 3, 4, 4], [8, 3, 0, 1]]
Intermediate Output 2: [[4, 4, 4, 4], [4, 4, 4, 4], [8, 4, 4, 4]]
Intermediate Output 3: [[5], [5], [5]]
