In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


from tqdm import tqdm
import matplotlib.pyplot as plt
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from torch import nn
from torch.utils.data import DataLoader

from utils import get_dataset
# from options import args_parser
from update import test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar

# python src/baseline_main.py --model=mlp --dataset=mnist --epochs=10


In [2]:
class Args:
    """Attribute bag holding the run configuration.

    All settings are class-level attributes, so ``Args()`` gives an object
    whose fields are read as ``args.<name>`` by the rest of the notebook.
    """

    # ---- federated parameters (defaults) ----
    epochs = 10
    num_users = 10
    # fraction of clients
    frac = 1
    # number of local epochs
    local_ep = 5
    # batch size
    local_bs = 128
    lr = 0.001
    momentum = 0.9

    # ---- model arguments ----
    model = 'cnn'
    # number of each kind of kernel
    kernel_num = 9
    # comma-separated kernel sizes to use for convolution
    kernel_sizes = '3,4,5'
    # number of image channels
    num_channels = 1
    # one of: batch_norm, layer_norm, None
    norm = 'batch_norm'
    # number of filters for conv nets -- 32 for mini-imagenet, 64 for omniglot
    num_filters = 32
    # whether to use max pooling rather than strided convolutions
    # (note: stored as the string 'True', not a bool)
    max_pool = 'True'

    # ---- other arguments ----
    dataset = "fmnist"
    num_classes = 10
    gpu = 0
    optimizer = 'sgd'
    # 1 for iid, 0 for non-iid
    iid = 1
    # whether to use unequal data splits in the non-iid setting (0 = equal splits)
    unequal = 0
    # rounds of early stopping
    stopping_rounds = 10
    verbose = 1
    seed = 1

    # ---- malicious arguments ----
    n_attackers = [0]
    # e.g. 'untargeted_med' or 'untargeted_mkrum'; None disables the attack setting
    attack_type = None

In [3]:
# bsz : batch size (number of positive pairs)
# d   : latent dim
# x   : Tensor, shape=[bsz, d]
#       latents for one side of positive pairs
# y   : Tensor, shape=[bsz, d]
#       latents for the other side of positive pairs

def align_loss(x, y, alpha=2):
    """Mean of the row-wise L2 distance between ``x`` and ``y``, raised to ``alpha``.

    Args:
        x: Tensor of shape [bsz, d], latents for one side of the positive pairs.
        y: Tensor of shape [bsz, d], latents for the other side.
        alpha: Exponent applied to each pair's distance before averaging.

    Returns:
        A scalar tensor.
    """
    # One Euclidean distance per row (i.e. per positive pair).
    pair_dist = (x - y).norm(p=2, dim=1)
    return (pair_dist ** alpha).mean()

def uniform_loss(x, t=2):
    """Log of the mean of ``exp(-t * d**2)`` over all pairwise distances ``d`` in ``x``.

    Args:
        x: Tensor of shape [bsz, d]; distances are taken between every pair of rows.
        t: Scale applied (negated) to the squared distances before exponentiation.

    Returns:
        A scalar tensor.
    """
    # Squared Euclidean distance for every unordered pair of rows.
    sq_dists = torch.pdist(x, p=2).pow(2)
    potentials = (sq_dists * -t).exp()
    return potentials.mean().log()

In [4]:

if __name__ == '__main__':
    # Baseline (non-federated) run: build a single model, train it on the
    # whole training set, plot the per-epoch loss and report test accuracy.
    args = Args()
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Apply the configured seed so torch-driven randomness (weight init,
    # DataLoader shuffling) is repeatable between runs.
    torch.manual_seed(args.seed)

    # load datasets
    train_dataset, test_dataset, _ = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network: one architecture per supported dataset.
        cnn_builders = {
            'mnist': CNNMnist,
            'fmnist': CNNFashion_Mnist,
            'cifar': CNNCifar,
        }
        if args.dataset not in cnn_builders:
            # Fail here with a clear message instead of a NameError on
            # `global_model` further down.
            raise ValueError(
                'Error: unrecognized dataset {!r} for cnn model'.format(
                    args.dataset))
        global_model = cnn_builders[args.dataset](args=args)
    elif args.model == 'mlp':
        # Multi-layer perceptron: input size is the product of the sample's dims.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after the full input size is known.
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        # Raise rather than call exit(): `exit` is an interactive-session
        # helper, and an exception shows the message in the cell output.
        raise ValueError('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # Training
    # Set optimizer and criterion
    # NOTE(review): momentum is fixed at 0.5 here and the batch size below at
    # 64; args.momentum / args.local_bs are not used by this baseline --
    # confirm that is intended.
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(global_model.parameters(), lr=args.lr,
                                    momentum=0.5)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(global_model.parameters(), lr=args.lr,
                                     weight_decay=1e-4)
    else:
        # Without this branch `optimizer` would be undefined in the loop.
        raise ValueError(
            'Error: unrecognized optimizer {!r}'.format(args.optimizer))

    trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    criterion = nn.CrossEntropyLoss().to(device)
    epoch_loss = []

    for epoch in tqdm(range(args.epochs)):
        batch_loss = []

        for batch_idx, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = global_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Read the scalar once; it is used for both logging and averaging.
            loss_value = loss.item()
            if batch_idx % 50 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch+1, batch_idx * len(images), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss_value))
            batch_loss.append(loss_value)

        loss_avg = sum(batch_loss)/len(batch_loss)
        print('\nTrain loss:', loss_avg)
        epoch_loss.append(loss_avg)

    # Plot loss
    plt.figure()
    plt.plot(range(len(epoch_loss)), epoch_loss)
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    # Make sure the output directory exists; savefig does not create it.
    os.makedirs('../save', exist_ok=True)
    plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model,
                                                 args.epochs))

    # testing
    test_acc, test_loss = test_inference(args, global_model, test_dataset)
    print('Test on', len(test_dataset), 'samples')
    print("Test Accuracy: {:.2f}%".format(100*test_acc))

CNNFashion_Mnist(
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=25600, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


  0%|                                                                                           | 0/10 [00:00<?, ?it/s]



 10%|████████▎                                                                          | 1/10 [00:15<02:17, 15.32s/it]


Train loss: 1.003701526663705


 20%|████████████████▌                                                                  | 2/10 [00:27<01:50, 13.75s/it]


Train loss: 0.6446382510128306


 20%|████████████████▌                                                                  | 2/10 [00:29<01:56, 14.61s/it]


KeyboardInterrupt: 