In [56]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch.quantization
import os

# Set device (use CUDA if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST dataset with normalization
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

# Data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)


In [57]:
N = 20

# Define the transformation to normalize MNIST data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load MNIST training dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Load MNIST testing dataset
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Combine datasets
combined_dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])

import random

def split_dataset(combined_dataset, N):
    # Total number of samples in the dataset
    M = len(combined_dataset)
    indices = list(range(M))
    split_size = M // N
    
    # Shuffle indices to ensure randomness in splitting
    random.shuffle(indices)
    
    # Split indices into N parts
    user_data = [indices[i * split_size:(i + 1) * split_size] for i in range(N)]
    
    # Create subsets for each user
    user_datasets = [torch.utils.data.Subset(combined_dataset, user_data[i]) for i in range(N)]
    
    return user_datasets


def split_train_test(user_dataset, test_ratio=0.2):
    # Total number of samples
    M = len(user_dataset)
    test_size = int(M * test_ratio)
    train_size = M - test_size
    
    # Split the dataset into training and testing sets
    train_subset, test_subset = random_split(user_dataset, [train_size, test_size])
    
    return train_subset, test_subset


user_datasets = split_dataset(combined_dataset, N)

batch_size = 64  # Adjust batch size as needed

# Split user-specific dataset into training and testing sets
user_train_loaders = []
user_test_loaders = []
for user_dataset in user_datasets:
    train_data, test_data = split_train_test(user_dataset)
    user_train_loaders.append(DataLoader(train_data, batch_size=batch_size, shuffle=True))
    user_test_loaders.append(DataLoader(test_data, batch_size=batch_size, shuffle=False))
    
def aggregate_updates(local_updates):
    # A naive method to aggregate model weights
    new_state_dict = {}
    for key in local_updates[0].keys():
        new_state_dict[key] = torch.mean(torch.stack([update[key] for update in local_updates]), dim=0)
    return new_state_dict

In [58]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.quant(x)
        x = self.pool(self.relu1(self.conv1(x)))
        x = self.pool(self.relu2(self.conv2(x)))
        x = x.reshape(-1, 64 * 7 * 7)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x


In [65]:
global_model = Net().to(device)

client_models = [Net().to(device) for i in range(N)]

for i, model in enumerate(client_models + [global_model]):
    # Fuse layers (for better optimization during quantization)
    model.fuse_model = lambda: torch.quantization.fuse_modules(model, [["conv1", "relu1"], ["conv2", "relu2"]])
    model.fuse_model()
    
    # Specify quantization configuration
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    
    # Prepare for QAT
    model = torch.quantization.prepare_qat(model)
    model = torch.quantization.convert(model)
    # Update the corresponding model
    if i < len(client_models):
        client_models[i] = model
    else:
        global_model = model





In [66]:
def print_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" %(os.path.getsize("tmp.pt")/1e6))
    os.remove('tmp.pt')

In [70]:
print_model_size(global_model)

0.43 MB


In [71]:
# Extract quantized parameters
for name, param in global_model.state_dict().items():
    print(f"{name}: {param}")
    # Quantized weights have scale and zero-point for reconstruction


quant.scale: tensor([1.])
quant.zero_point: tensor([0])
conv1.weight: tensor([[[[-0.1022, -0.0723,  0.2443],
          [ 0.0324, -0.3191,  0.2369],
          [-0.3166,  0.2568,  0.1870]]],


        [[[-0.1043,  0.2284,  0.0496],
          [-0.3078, -0.1390, -0.3178],
          [ 0.3078, -0.1167, -0.0025]]],


        [[[ 0.3239,  0.0612,  0.2117],
          [-0.0944, -0.2678,  0.0332],
          [ 0.0332, -0.2933, -0.2219]]],


        [[[-0.3064,  0.1293, -0.2059],
          [-0.2011, -0.1771,  0.1077],
          [-0.2059,  0.1532, -0.2202]]],


        [[[ 0.0491,  0.0801, -0.1550],
          [ 0.1860,  0.3151,  0.3280],
          [-0.0594,  0.0207,  0.1705]]],


        [[[-0.1298,  0.2621,  0.1578],
          [-0.1628,  0.1247, -0.3231],
          [-0.2239, -0.2799,  0.2722]]],


        [[[ 0.2907, -0.3153, -0.0788],
          [-0.2932,  0.2882,  0.0591],
          [ 0.1380,  0.2611,  0.3104]]],


        [[[-0.2488, -0.1009, -0.1367],
          [-0.0650,  0.1972,  0.2241],
     

In [75]:
# Initialize and quantize the model
model = Net()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # Quantization config
model = torch.quantization.prepare(model)  # Prepare for quantization
quantized_model = torch.quantization.convert(model)  # Convert to 8-bit



In [None]:
full_precision_model = Net()


In [83]:
print_model_size(quantized_model)
print_model_size(full_precision_model)

0.43 MB
1.69 MB


In [76]:
# Define a new full-precision model

# Dequantize each layer manually
for name, param in quantized_model.named_parameters():
    if hasattr(param, 'q_per_channel_scales'):
        # Handle per-channel quantization
        dequantized_weight = param.dequantize()
    elif hasattr(param, 'q_scale'):
        # Handle per-tensor quantization
        scale, zero_point = param.q_scale(), param.q_zero_point()
        dequantized_weight = (param.int_repr().float() - zero_point) * scale
    else:
        dequantized_weight = param.float()

    # Assign back to the full-precision model
    target_layer = dict(full_precision_model.named_parameters())[name]
    target_layer.data.copy_(dequantized_weight)

In [82]:
for j, (x, y) in zip(range(2), user_train_loaders[0]):
    # Check the outputs
    quantized_output = model(x)
    dequantized_output = full_precision_model(x)
    
    print("Quantized Output:", quantized_output.shape)
    print("Dequantized Output:", dequantized_output.shape)

Quantized Output: torch.Size([64, 10])
Dequantized Output: torch.Size([64, 10])
Quantized Output: torch.Size([64, 10])
Dequantized Output: torch.Size([64, 10])
