In [None]:
import numpy as np
import tarfile
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torchvision.models import resnet18
# Seed PyTorch's global random number generator.
torch.manual_seed(100)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
# Download the MNIST-as-PNG archive into the current working directory.
dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/mnist_png.tgz"
download_url(dataset_url, '.')

# Unpack the gzip-compressed tarball into ./mnist_png
archive = tarfile.open('./mnist_png.tgz', 'r:gz')
try:
    archive.extractall(path='./mnist_png')
finally:
    archive.close()
    

In [None]:
# Look into the data directory
# NOTE(review): the archive above is extracted to ./mnist_png, while this cell
# reads ./mnist_global/ -- confirm that directory is prepared separately.
data_dir = './mnist_global/'
print(os.listdir(data_dir))

# os.listdir() returns entries in arbitrary order; sort them so the position of
# each class name is stable across runs and file systems.
classes = sorted(os.listdir(os.path.join(data_dir, "training")))
print(classes)

In [None]:
# NOTE(review): transform_train / transform_test are not defined in any visible
# cell of this notebook -- they must exist in the kernel before this cell runs.
batch_size = 256

# Training and validation images are read from the class-per-folder layout
# under data_dir.
train_ds = ImageFolder(root=data_dir+'/training', transform=transform_train)
valid_ds = ImageFolder(root=data_dir+'/testing', transform=transform_test)

# The validation loader is not shuffled and uses twice the training batch size.
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
    pin_memory=True,
)
valid_dl = DataLoader(
    valid_ds,
    batch_size=batch_size*2,
    num_workers=3,
    pin_memory=True,
)

In [None]:
# Run on the first GPU when CUDA is available; otherwise fall back to the CPU
# instead of failing on a hard-coded "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = resnet18(num_classes = 10).to(device = device)

# Hyper-parameters passed to fit_one_cycle() in the next cell.
epochs = 1
max_lr = 0.01
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam

In [None]:
%%time
# NOTE(review): fit_one_cycle is not defined in any visible cell of this
# notebook -- it must already exist in the kernel.
history = fit_one_cycle(
    epochs,
    max_lr,
    model,
    train_dl,
    valid_dl,
    grad_clip=grad_clip,
    weight_decay=weight_decay,
    opt_func=opt_func,
)

# Persist the weights of the model that was just trained.
torch.save(model.state_dict(), "30dec-og.pt")

In [None]:
# Rebind `model` to a freshly initialised network and store its weights as
# "30dec_dp.pt", the checkpoint the federated cells below load as the global model.
model = resnet18(num_classes=10)
model = model.to(device=device)
torch.save(model.state_dict(), "30dec_dp.pt")


def eval(gb):
    """Print the per-class accuracy of ``gb`` on ``mnist_global/testing/``.

    NOTE: the name shadows Python's built-in ``eval``; it is kept because other
    cells of the notebook call the function by this name.

    Args:
        gb: model to evaluate. It is switched to eval mode and moved to the GPU
            when CUDA is available, otherwise to the CPU.

    Returns:
        None. One ``Accuracy for class <i>: <value>`` line is printed for each
        of the 10 class indices (``nan`` if a class has no test samples).
    """
    num_classes = 10
    gb.eval()

    # Preprocessing applied to every test image.
    # NOTE(review): these per-channel mean/std values look like the commonly
    # used CIFAR-10 statistics -- confirm they match the normalisation that was
    # used when the model was trained.
    transform = tt.Compose([
        tt.ToTensor(),
        tt.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Load the test split from its class-per-folder directory.
    test_ds = ImageFolder('mnist_global/testing/', transform)
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gb.to(device)

    # Per-class counters, kept on the CPU.
    class_correct = torch.zeros(num_classes)
    class_total = torch.zeros(num_classes)

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = gb(images)
            _, predicted = torch.max(outputs, 1)

            # Keep the comparison 1-D (no squeeze) so a final batch holding a
            # single sample is handled like any other batch.
            hits = (predicted == labels).cpu()
            labels_cpu = labels.cpu()

            # Count samples per class in one vectorised step instead of a
            # Python loop over every sample.
            class_total += torch.bincount(labels_cpu, minlength=num_classes)
            class_correct += torch.bincount(labels_cpu[hits], minlength=num_classes)

    # Calculate accuracy per class
    class_accuracies = class_correct / class_total
    for i, acc in enumerate(class_accuracies):
        print(f'Accuracy for class {i}: {acc.item():.4f}')


In [None]:
import copy
def fedavg(local_models):
    """Return a new model whose state is the element-wise mean of ``local_models``.

    Args:
        local_models: non-empty sequence of models with identical architecture
            (the same ``state_dict`` keys and tensor shapes).

    Returns:
        A deep copy of ``local_models[0]`` loaded with the averaged state. The
        input models are left unmodified.

    Raises:
        ValueError: if ``local_models`` is empty.
    """
    num_models = len(local_models)
    if num_models == 0:
        raise ValueError("fedavg() needs at least one model to average")

    global_model = copy.deepcopy(local_models[0])
    local_state_dicts = [local.state_dict() for local in local_models]

    avg_state_dict = {}
    for layer, reference in global_model.state_dict().items():
        # Sum in the entry's own dtype and device, then divide by the actual
        # number of models rather than a hard-coded constant.
        total = torch.zeros_like(reference)
        for state in local_state_dicts:
            total += state[layer].to(reference.device)
        avg_state_dict[layer] = total / num_models

    global_model.load_state_dict(avg_state_dict)
    return global_model


In [None]:
# Federated training. Every round trains `model` on each client's data, saves
# one checkpoint per client, and then averages those checkpoints together with
# the previous global checkpoint ("30dec_dp.pt") via fedavg().
clients=['client_min_1','client_min_2','client_min_3']
ll=5  # epochs of local training per client per round

noise_multiplier=0.1  # NOTE(review): never read in this cell

for round_idx in range(5):
    print(f'Round {round_idx+1}')
    print('Global Model')
    print('------------')
    model = resnet18(num_classes = 10).to(device = device)

    # The first round starts from a freshly initialised network; later rounds
    # resume from the global checkpoint written at the end of the previous one.
    if round_idx!=0:
        model.load_state_dict(torch.load("30dec_dp.pt", map_location=device))

    # NOTE(review): `model` is not reset between clients, so each client
    # continues from the weights left by the previous client -- confirm this is
    # intended rather than every client starting from the global model.
    for client in clients:
        other_samples = ImageFolder(client+'/training', transform_train)
        heal_loader = DataLoader(other_samples, batch_size=256, shuffle = True)
        optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)

        for epoch in range(ll):
            # eval() below leaves the model in eval mode, so switch back to
            # training mode at the start of every epoch.
            model.train()
            for inputs, labels in heal_loader:
                # The loader already yields tensors; move them to the model's
                # device without re-wrapping them.
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                optimizer.step()

        print(client)
        eval(model)
        torch.save(model.state_dict(), client+".pt")
        print('------------')
        print('------------')

    # Reload the previous global checkpoint and the three client checkpoints
    # into separate networks so they can be averaged.
    model=resnet18(num_classes = 10).to(device = device)
    model1=resnet18(num_classes = 10).to(device = device)
    model2=resnet18(num_classes = 10).to(device = device)
    model3=resnet18(num_classes = 10).to(device = device)

    model.load_state_dict(torch.load("30dec_dp.pt", map_location=device))
    model1.load_state_dict(torch.load("client_min_1.pt", map_location=device))
    model2.load_state_dict(torch.load("client_min_2.pt", map_location=device))
    model3.load_state_dict(torch.load("client_min_3.pt", map_location=device))
    gb= fedavg([model,model1,model2,model3])
    print('Global Model')
    print('------------')

    # NOTE(review): evaluate() is not defined in any visible cell of this
    # notebook -- it must already exist in the kernel.
    history = [evaluate(gb, valid_dl)]
    print(history)
    eval(gb)
    print('------------')
    # The averaged model becomes the global checkpoint for the next round.
    torch.save(gb.state_dict(), "30dec_dp.pt")






    


In [None]:
# After the training cell, `model` holds the weights loaded from "30dec_dp.pt"
# during the last round, i.e. the global checkpoint from before that round's
# averaging step (the newest averaged model is `gb`).
eval(model)

In [None]:
# `model2` holds the weights loaded from "client_min_2.pt" in the last round of
# the training cell.
eval(model2)

In [None]:
# Second run of the federated loop (same logic as the previous training cell).
# Every round trains `model` on each client's data, saves one checkpoint per
# client, and then averages those checkpoints together with the current global
# checkpoint ("30dec_dp.pt") via fedavg().
clients=['client_min_1','client_min_2','client_min_3']
ll=5  # epochs of local training per client per round

noise_multiplier=0.1  # NOTE(review): never read in this cell

for round_idx in range(5):
    print(f'Round {round_idx+1}')
    print('Global Model')
    print('------------')
    model = resnet18(num_classes = 10).to(device = device)

    # The first round starts from a freshly initialised network; later rounds
    # resume from the global checkpoint written at the end of the previous one.
    if round_idx!=0:
        model.load_state_dict(torch.load("30dec_dp.pt", map_location=device))

    # NOTE(review): `model` is not reset between clients, so each client
    # continues from the weights left by the previous client -- confirm this is
    # intended rather than every client starting from the global model.
    for client in clients:
        other_samples = ImageFolder(client+'/training', transform_train)
        heal_loader = DataLoader(other_samples, batch_size=256, shuffle = True)
        optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)

        for epoch in range(ll):
            # eval() below leaves the model in eval mode, so switch back to
            # training mode at the start of every epoch.
            model.train()
            for inputs, labels in heal_loader:
                # The loader already yields tensors; move them to the model's
                # device without re-wrapping them.
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                optimizer.step()

        print(client)
        eval(model)
        torch.save(model.state_dict(), client+".pt")
        print('------------')
        print('------------')

    # Reload the current global checkpoint and the three client checkpoints
    # into separate networks so they can be averaged.
    model=resnet18(num_classes = 10).to(device = device)
    model1=resnet18(num_classes = 10).to(device = device)
    model2=resnet18(num_classes = 10).to(device = device)
    model3=resnet18(num_classes = 10).to(device = device)

    model.load_state_dict(torch.load("30dec_dp.pt", map_location=device))
    model1.load_state_dict(torch.load("client_min_1.pt", map_location=device))
    model2.load_state_dict(torch.load("client_min_2.pt", map_location=device))
    model3.load_state_dict(torch.load("client_min_3.pt", map_location=device))
    gb= fedavg([model,model1,model2,model3])
    print('Global Model')
    print('------------')

    # NOTE(review): evaluate() is not defined in any visible cell of this
    # notebook -- it must already exist in the kernel.
    history = [evaluate(gb, valid_dl)]
    print(history)
    eval(gb)
    print('------------')
    # The averaged model becomes the global checkpoint for the next round.
    torch.save(gb.state_dict(), "30dec_dp.pt")






    


In [None]:
# `model` here holds the weights loaded from "30dec_dp.pt" in the last round of
# the preceding training cell, i.e. the global checkpoint from before that
# round's averaging step (the newest averaged model is `gb`).
eval(model)