In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# import syft as sy

In [3]:
# hook = sy.TorchHook(torch)  
# bob = sy.VirtualWorker(hook, id="bob") 
# alice = sy.VirtualWorker(hook, id="alice")  

In [4]:
class Arguments():
    def __init__(self):
        self.batch_size = 256
        self.test_batch_size = 1000
        self.epochs = 10
        self.lr = 0.01
        self.momentum = 0.5
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 30
        self.save_model = False

args = Arguments()

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

In [5]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

# federated_train_loader = sy.FederatedDataLoader(
federated_train_loader = DataLoader(
        datasets.CIFAR10(root='./dataset', 
                     train=True,
                     download=True, 
                     transform=transform), 
    batch_size=args.batch_size, 
    shuffle=True)

test_loader = DataLoader(
    datasets.CIFAR10(root='./dataset', 
                     train=False,
                     download=True, 
                     transform=transform), 
    batch_size=args.test_batch_size,
    shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [6]:
import torchvision.models as models
googlenet = models.googlenet(pretrained=True)


use_cuda = torch.cuda.is_available()

# move tensors to GPU if CUDA is available
# if use_cuda:
#     googlenet.cuda()
    
print(googlenet)

GoogLeNet(
  (conv1): BasicConv2d(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): BasicConv2d(
    (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception3a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(googlenet.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

In [8]:
# Training
def train(model, device, federated_train_loader, optimizer, epoch):
    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(federated_train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % args.log_interval == 0:
#             loss = loss.get() # <-- NEW: get the loss back
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
                100. * batch_idx / len(federated_train_loader), loss.item()))

In [9]:
def test(model, device, test_loader):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
                test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

In [10]:
%%time
model = googlenet.to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr) 

for epoch in range(1, args.epochs + 1):
    train(model, device, federated_train_loader, optimizer, epoch)
    test(model, device, test_loader)

if (args.save_model):
    torch.save(model.state_dict(), "./models/CIFAR10_cnn.pt")


Epoch: 1
Test set: Average loss: 1.3669, Accuracy: 534/10000 (5%)
Test set: Average loss: 2.8142, Accuracy: 1060/10000 (11%)
Test set: Average loss: 4.1787, Accuracy: 1586/10000 (16%)
Test set: Average loss: 5.5149, Accuracy: 2137/10000 (21%)
Test set: Average loss: 6.8553, Accuracy: 2695/10000 (27%)
Test set: Average loss: 8.2457, Accuracy: 3219/10000 (32%)
Test set: Average loss: 9.6673, Accuracy: 3743/10000 (37%)
Test set: Average loss: 11.0824, Accuracy: 4273/10000 (43%)
Test set: Average loss: 12.4203, Accuracy: 4829/10000 (48%)
Test set: Average loss: 13.7572, Accuracy: 5363/10000 (54%)

Epoch: 2
Test set: Average loss: 1.1929, Accuracy: 603/10000 (6%)
Test set: Average loss: 2.4900, Accuracy: 1210/10000 (12%)
Test set: Average loss: 3.7119, Accuracy: 1801/10000 (18%)
Test set: Average loss: 4.8601, Accuracy: 2427/10000 (24%)
Test set: Average loss: 6.0280, Accuracy: 3045/10000 (30%)
Test set: Average loss: 7.2496, Accuracy: 3644/10000 (36%)
Test set: Average loss: 8.5321, Accur

Test set: Average loss: 7.0975, Accuracy: 4922/10000 (49%)
Test set: Average loss: 8.1298, Accuracy: 5596/10000 (56%)
Test set: Average loss: 9.1301, Accuracy: 6289/10000 (63%)
Test set: Average loss: 10.1788, Accuracy: 6982/10000 (70%)

Epoch: 10
Test set: Average loss: 1.0289, Accuracy: 706/10000 (7%)
Test set: Average loss: 2.0888, Accuracy: 1401/10000 (14%)
Test set: Average loss: 3.1669, Accuracy: 2095/10000 (21%)
Test set: Average loss: 4.1480, Accuracy: 2814/10000 (28%)
Test set: Average loss: 5.1715, Accuracy: 3539/10000 (35%)
Test set: Average loss: 6.0992, Accuracy: 4257/10000 (43%)
Test set: Average loss: 7.1038, Accuracy: 4954/10000 (50%)
Test set: Average loss: 8.1454, Accuracy: 5648/10000 (56%)
Test set: Average loss: 9.1564, Accuracy: 6348/10000 (63%)
Test set: Average loss: 10.2269, Accuracy: 7043/10000 (70%)
Wall time: 5min 48s
