In [1]:
from dotenv import load_dotenv
load_dotenv()

from argparse import ArgumentParser
import warnings
from collections import OrderedDict
import json
import io
import os
import sys
import pickle
import base64
from traceback import print_exc

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data.sampler as sampler
from torchvision.transforms import Compose, Normalize, ToTensor

from sklearn.cluster import KMeans

from tqdm import tqdm
from time import sleep, time

import logging

from non_iid_generator.customDataset import CustomDataset

DEVICE = os.environ["TORCH_DEVICE"]
DEVICE = "cuda"


In [2]:
def load_data(train_dataset_path, test_dataset_path):
    """Load CIFAR-10 (training and test set)."""

    batch_size = 128
    momentum = 0.9
    weight_decay = 1e-4
    finetune_lr = 0.001

    train_data = pickle.load(open(train_dataset_path, "rb"))
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True)

    test_data = pickle.load(open(test_dataset_path, "rb"))
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=True)

    return train_loader, test_loader

def fine_tune(model, iterations, train_loader, print_frequency=100):
    '''
        short-term fine-tune a simplified model
        
        Input:
            `model`: model to be fine-tuned.
            `iterations`: (int) num of short-term fine-tune iterations.
            `print_frequency`: (int) how often to print fine-tune info.
        
        Output:
            `model`: fine-tuned model.
    '''

    # Data loaders for fine tuning and evaluation.
    batch_size = 128
    momentum = 0.9
    weight_decay = 1e-4
    finetune_lr = 0.001

    criterion = torch.nn.BCEWithLogitsLoss()
    
    _NUM_CLASSES = 10
    optimizer = torch.optim.SGD(
        model.parameters(),
        finetune_lr, 
        momentum=momentum,
        weight_decay=weight_decay)

    model = model.to(DEVICE)
    model.train()
    dataloader_iter = iter(train_loader)
    for i in range(iterations):
        try:
            (input, target) = next(dataloader_iter)
        except:
            dataloader_iter = iter(train_loader)
            (input, target) = next(dataloader_iter)
            
        if i % print_frequency == 0:
            print('Fine-tuning iteration {}'.format(i))
            sys.stdout.flush()
        
        # Ensure the target shape is sth like torch.Size([batch_size])
        if len(target.shape) > 1: target = target.reshape(len(target))

        target.unsqueeze_(1)
        target_onehot = torch.FloatTensor(target.shape[0], _NUM_CLASSES)
        target_onehot.zero_()
        target_onehot.scatter_(1, target, 1)
        target.squeeze_(1)
        input, target = input.to(DEVICE), target.to(DEVICE)
        target_onehot = target_onehot.to(DEVICE)

        pred = model(input)
        loss = criterion(pred, target_onehot)
        optimizer.zero_grad()
        loss.backward()  # compute gradient and do SGD step
        optimizer.step()

    return model

In [3]:
model = torch.load("models/alexnet/model_cpu.pth.tar")

NO_CLIENTS = 20
NO_CLASSES = 10

In [10]:
weights_list = []
for client_id in range(NO_CLIENTS):
    print(f">> Dataset No: {client_id}")
    train_dataset_path = f"./data/Cifar10/train/{client_id}.pkl"
    test_dataset_path = f"./data/Cifar10/test/{client_id}.pkl"
    trainloader, testloader = load_data(train_dataset_path, test_dataset_path)

    temp_model = fine_tune(model, 20, trainloader, print_frequency=10)
    weights_model = np.concatenate([param.data.cpu().numpy().flatten() for param in temp_model.parameters()])

    weights_list.append(weights_model)

>> Dataset No: 0
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 1
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 2
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 3
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 4
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 5
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 6
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 7
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 8
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 9
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 10
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 11
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 12
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 13
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Dataset No: 14
Fine-tuning iteration 0
Fine-tuning iteration 10
>> Da

In [11]:
origin_model_weights = weights_model = np.concatenate([param.data.cpu().numpy().flatten() for param in model.parameters()])

weight_diff = weights_list - origin_model_weights

all_weights = np.vstack(weights_list)

In [16]:
kmeans = KMeans(n_clusters=3, random_state=42)
cluster_assignments = kmeans.fit_predict(weight_diff)

  super()._check_params_vs_input(X, default_n_init=10)


In [17]:
cluster_assignments

array([0, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2],
      dtype=int32)

In [18]:
client_id = 1
import pandas as pd
df = pd.DataFrame(columns=[str(i) for i in range(10)])

for client_id in range(NO_CLIENTS):
    row = np.zeros(NO_CLASSES, dtype=int)
    train_dataset_path = f"./data/Cifar10/train/{client_id}.pkl"
    train_data = pickle.load(open(train_dataset_path, "rb"))
    counts = np.asarray(np.unique(train_data.labels, return_counts=True))
    counts = dict(zip(counts[0].tolist(),counts[1].tolist()))
    for label, count in counts.items():
        row[label] = count
    df.loc[client_id] = row


In [19]:
df["labels"] = cluster_assignments
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,labels
0,0,0,0,8,0,1,5,122,3,0,0
1,0,1,0,33,0,0,194,19,2426,0,0
2,0,0,63,3,1266,309,0,0,0,0,0
3,45,8,2,1,35,3,1259,0,103,0,0
4,13,63,335,179,467,0,0,98,142,0,0
5,5,68,23,0,0,17,81,0,4,3,0
6,13,0,0,17,0,0,1,34,12,2848,0
7,0,214,457,216,10,0,0,185,0,0,2
8,1403,35,176,0,0,0,0,0,0,0,1
9,1157,229,1,1,1,77,29,764,0,0,1
