In [30]:
from argparse import ArgumentParser, Namespace
import os
import sys
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.backends.cudnn as cudnn
import pickle
import numpy as np
import json
import random
from datetime import datetime

from copy import deepcopy
import shutil

_NUM_CLASSES = 100
DEVICE = torch.device("mps")
DT_FORMAT = "%Y-%m-%d %H:%M:%S"

In [28]:

def update_progress(index, length, **kwargs):
    '''
        display progress
        
        Input:
            `index`: (int) shows the index of current progress
            `length`: (int) total length of the progress
            `**kwargs`: info to display (e.g. accuracy)
    '''
    barLength = 10 # Modify this to change the length of the progress bar
    progress = float(index/length)
    if isinstance(progress, int):
        progress = float(progress)
    if not isinstance(progress, float):
        progress = 0
    if progress >= 1:
        progress = 1
    block = int(round(barLength*progress))
    text = "\rPercent: [{0}] {1:.2f}% ({2}/{3}) ".format( 
            "#"*block + "-"*(barLength-block), round(progress*100, 3), \
            index, length)
    for key, value in kwargs.items():
        text = text + str(key) + ': ' + str(value) + ', '
    if len(kwargs) != 0:
        text = text[:-2:]
    sys.stdout.write(text)
    sys.stdout.flush()


def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 50))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def get_avg(self):
        return self.avg
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
def compute_accuracy(output, target):
    output = output.argmax(dim=1)
    acc = 0.0
    acc = torch.sum(target == output).item()
    acc = acc/output.size(0)*100
    return acc
    
def train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    if args.fedprox: global_model = deepcopy(model)

    # switch to train mode
    model.train()
    
    print('===================================================================')
    end = time.time()
    
    for i, (images, target) in enumerate(train_loader):
        
        # # Ensure the target shape is sth like torch.Size([batch_size])
        if len(target.shape) > 1: target = target.reshape(len(target))

        target.unsqueeze_(1)
        target_onehot = torch.FloatTensor(target.shape[0], _NUM_CLASSES)
        target_onehot.zero_()
        target_onehot.scatter_(1, target, 1)
        target.squeeze_(1)
        
        images = images.to(DEVICE)
        target_onehot = target_onehot.to(DEVICE)
        target = target.to(DEVICE)

        output = model(images)
        
        if args.fedprox:
            proximal_term = 0
            for local_weights, global_weights in zip(model.parameters(), global_model.parameters()):
                proximal_term += (local_weights - global_weights).norm(2)
            loss = criterion(output, target_onehot) + (args.mu / 2) * proximal_term
        
        else:
            loss = criterion(output, target_onehot)
        
        # measure accuracy and record loss
        batch_acc = compute_accuracy(output, target)
        
        losses.update(loss.item(), images.size(0))
        acc.update(batch_acc, images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        
        # Update statistics
        estimated_time_remained = batch_time.get_avg()*(len(train_loader)-i-1)
        update_progress(i, len(train_loader), 
            ESA='{:8.2f}'.format(estimated_time_remained)+'s',
            loss='{:4.2f}'.format(loss.item()),
            acc='{:4.2f}%'.format(float(batch_acc))
            )

    print()
    print('Finish epoch {}: time = {:8.2f}s, loss = {:4.2f}, acc = {:4.2f}%'.format(
            epoch+1, batch_time.get_avg()*len(train_loader), 
            float(losses.get_avg()), float(acc.get_avg())))
    print('===================================================================')
    return


def eval(test_loader, model, args):
    batch_time = AverageMeter()
    acc = AverageMeter()

    # switch to eval mode
    model.eval()

    end = time.time()
    for i, (images, target) in enumerate(test_loader):

        if len(target.shape) > 1: target = target.reshape(len(target))

        images = images.to(DEVICE)
        target = target.to(DEVICE)
        
        output = model(images)
        batch_acc = compute_accuracy(output, target)
        acc.update(batch_acc, images.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        # Update statistics
        estimated_time_remained = batch_time.get_avg()*(len(test_loader)-i-1)
        update_progress(i, len(test_loader), 
            ESA='{:8.2f}'.format(estimated_time_remained)+'s',
            acc='{:4.2f}'.format(float(batch_acc))
            )
    print()
    print('Test accuracy: {:4.2f}% (time = {:8.2f}s)'.format(
            float(acc.get_avg()), batch_time.get_avg()*len(test_loader)))
    print('===================================================================')
    return float(acc.get_avg())
      

In [3]:
class AlexNet_reduced(nn.Module):

    def __init__(self):
        super(AlexNet_reduced, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=3, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def alexnet_reduced(reducing_rate=1, pretrained=False, progress=True, num_classes=1000):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = AlexNet_reduced()
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['alexnet'], progress=progress)
        model.load_state_dict(state_dict)
    if num_classes != 1000:
        num_in_feature = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(num_in_feature, num_classes)
    return model

In [4]:
pre_args = {
  "project_folder": "./projects/test/test_other_datasets/",
  "model_name": "alexnet.pth.tar",
  "no_clients": 56,
  "no_rounds": 300,
  "fine_tuning_epochs": 50,
  "use_server_data": False,
  "fedprox": False,
  "mu": 0.01,
  "client_selection": True,
  "pretrained": False,
  "generate_dataset": False,
  "niid": None,
  "b": None,
  "p": None,
  "alpha": None,
  "dataset_path": "./data/32_Cifar10_NIID_56c_a03",
  "epochs": 200,
  "arch": "alexnet_reduced",
  "workers": 4,
  "batch_size": 128,
  "lr": 0.001,
  "momentum": 0.9,
  "weight_decay": 0.0005
}
args = Namespace(**pre_args)
args

Namespace(project_folder='./projects/test/test_other_datasets/', model_name='alexnet.pth.tar', no_clients=56, no_rounds=300, fine_tuning_epochs=50, use_server_data=False, fedprox=False, mu=0.01, client_selection=True, pretrained=False, generate_dataset=False, niid=None, b=None, p=None, alpha=None, dataset_path='./data/32_Cifar10_NIID_56c_a03', epochs=200, arch='alexnet_reduced', workers=4, batch_size=128, lr=0.001, momentum=0.9, weight_decay=0.0005)

In [6]:
alexmodel = alexnet_reduced(num_classes=100)

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

train_dataset = datasets.CIFAR100(root="../data", train=True, download=True,
        transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)
test_dataset = datasets.CIFAR100(root="../data", train=False, download=True,
    transform=transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [31]:
# Network
cudnn.benchmark = True
num_classes = _NUM_CLASSES
criterion = nn.BCEWithLogitsLoss()
# criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(alexmodel.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
# model = nn.DataParallel(model)
alexmodel = alexmodel.to(DEVICE)
criterion = criterion.to(DEVICE)
# Train & evaluation
best_acc = 0
for epoch in range(args.epochs):
    print('Epoch [{}/{}]'.format(epoch+1, args.epochs))
    adjust_learning_rate(optimizer, epoch, args)
    # train for one epoch
    train(train_loader, alexmodel, criterion, optimizer, epoch, args)
    acc = eval(test_loader, alexmodel, args)
    
    if acc > best_acc:
        best_acc = acc
    print(' ')
print('Best accuracy:', best_acc)
    
best_acc = eval(test_loader, alexmodel, args)
print('Best accuracy:', best_acc)

Epoch [1/200]
Percent: [##########] 99.74% (390/391) ESA:     0.00s, loss: 0.54, acc: 0.00%
Finish epoch 1: time =    63.35s, loss = 0.67, acc = 0.93%
Percent: [##########] 98.73% (78/79) ESA:     0.00s, acc: 0.00
Test accuracy: 1.00% (time =     7.48s)
 
Epoch [2/200]
Percent: [##########] 99.74% (390/391) ESA:     0.00s, loss: 0.06, acc: 0.00%
Finish epoch 2: time =    62.68s, loss = 0.09, acc = 0.93%
Percent: [##########] 98.73% (78/79) ESA:     0.00s, acc: 0.00
Test accuracy: 0.98% (time =     7.87s)
 
Epoch [3/200]
Percent: [##########] 99.74% (390/391) ESA:     0.00s, loss: 0.06, acc: 2.50%
Finish epoch 3: time =    63.49s, loss = 0.06, acc = 1.02%
Percent: [##########] 98.73% (78/79) ESA:     0.00s, acc: 0.00
Test accuracy: 1.04% (time =     7.73s)
 
Epoch [4/200]
Percent: [##########] 99.74% (390/391) ESA:     0.00s, loss: 0.06, acc: 1.25%
Finish epoch 4: time =    63.55s, loss = 0.06, acc = 1.05%
Percent: [##########] 98.73% (78/79) ESA:     0.00s, acc: 0.00
Test accuracy: 1.0