<a href="https://colab.research.google.com/github/ashwindasr/Federated-Learning/blob/master/federated-cats-vs-dogs-classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, transforms, models

In [0]:
# Uncomment this cell and run it only once.

# !git clone https://github.com/OpenMined/PySyft.git
# !python ./PySyft/setup.py test
# !pip install syft

In [0]:
import syft as sy  # PySyft library

# Hook PyTorch to PySyft; the hook object is handed to every worker created below.
hook = sy.TorchHook(torch)

# Define the two remote workers ("bob" first, then "alice") that will hold the training data.
bob, alice = (sy.VirtualWorker(hook, id=worker_id) for worker_id in ("bob", "alice"))

In [0]:
# Hyper-parameters and run-time switches shared by the cells below
class Arguments:
    """Plain settings holder; every option is a public instance attribute.

    Attributes:
        batch_size: Mini-batch size; used by ``train`` for its progress log.
        test_batch_size: Evaluation batch size (not referenced by the visible cells).
        epochs: Epoch count read by the training loop.
        lr: Learning rate handed to the SGD optimizer.
        momentum: SGD momentum value (not referenced by the visible cells).
        no_cuda: When True, the GPU is not used even if one is available.
        seed: Value passed to ``torch.manual_seed``.
        log_interval: ``train`` prints the loss every ``log_interval`` batches.
        save_model: When True, the final weights are written to disk.
    """

    def __init__(self):
        # Batch sizes
        self.batch_size, self.test_batch_size = 64, 1000
        # Optimisation schedule
        self.epochs, self.lr, self.momentum = 20, 0.01, 0.5
        # Runtime switches
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 10
        self.save_model = False

args = Arguments()

# Use the GPU only when it has not been disabled in the settings AND one is present.
if args.no_cuda:
    use_cuda = False
else:
    use_cuda = torch.cuda.is_available()

# Seed PyTorch's RNG from the configured seed.
torch.manual_seed(args.seed)

# Target device for the model and the batches.
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Extra DataLoader options that are only populated when running on the GPU.
kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else {}

In [0]:
# Run this cell only once

!unzip /content/data.zip

# Make sure your data folder has the following  folder hierarchy:
# 
#        data
#          |- train
#          |     |- cats
#          |     |- dogs
#          |- test
#               |- cats
#               |- dogs

In [0]:
# Set the path to the extracted data folder
data_dir = '/content/data/data'

# Define transforms for the training data and testing data.
# Training images are randomly rotated/cropped/flipped; test images get a
# fixed resize + centre crop. Both end up 224x224 and share the same
# per-channel normalisation.
train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])

test_transforms = transforms.Compose([transforms.Resize(255),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])])

# Load data using Pytorch's ImageFolder dataloader. 
train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)

# Use PySyft's Federated Data Loader to load the federated data.
# The batch size comes from `args` because train() multiplies batch indices by
# args.batch_size in its progress log; a separately hard-coded value could drift.
federated_train_loader = sy.FederatedDataLoader( 
    train_data.federate((bob, alice)), # .federate(()) function distributes the data to the virtual workers
    batch_size=args.batch_size, shuffle=True)

# If you encounter an error in the above command, comment out the cell which contains the
# installation files of PySyft, restart runtime and run again.

# We use the normal non-federated Dataloader for testing.
# `kwargs` carries the worker / pinned-memory options prepared in the config cell
# (it is an empty dict when running on the CPU).
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, **kwargs)

In [0]:
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of federated training over ``train_loader``.

    For every batch the model is sent to the worker that holds the data,
    one optimisation step is taken there, and the updated model is fetched
    back before moving on to the next batch.

    Args:
        args: Settings object; ``log_interval`` and ``batch_size`` are read
            for progress logging.
        model: Model to train. Must provide ``send``/``get`` in addition to
            the usual ``nn.Module`` interface.
        device: Device the batch tensors are moved to.
        train_loader: Iterable of ``(data, target)`` batches where ``data``
            exposes a ``location`` attribute naming the worker that holds it.
            It must also support ``len()`` for the progress log.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Epoch number, used only in the log line.
    """
    model.train()
    # Iterate the loader that was passed in rather than a notebook-level global,
    # so the function trains on whatever the caller hands it.
    for batch_idx, (data, target) in enumerate(train_loader):
        model.send(data.location)         # Send the model to where the data is. (Bob or Alice)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)  # model is expected to emit log-probabilities
        loss.backward()
        optimizer.step()
        model.get()                       # Get the updated weights from the virtual workers
        if batch_idx % args.log_interval == 0:
            loss = loss.get()             # Get the new loss from the virtual workers
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(train_loader) * args.batch_size, 
                100. * batch_idx / len(train_loader), loss.item()))

In [0]:
# Test the data in an un-federated way as we normally do. 
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(1, keepdim=True)  
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [0]:
# Load and modify the model of your convenience, here we use Resnet50
model = torchvision.models.resnet50(pretrained=True)

# Size the classifier head from the dataset itself so the number of outputs
# always matches the class folders found by ImageFolder (cats and dogs here).
num_classes = len(train_data.classes)

# Replace the final fully connected layer with a small classifier.
# LogSoftmax is used because train() and test() compute F.nll_loss on the output.
model.fc = nn.Sequential(nn.Linear(2048, 512),
                         nn.ReLU(),
                         nn.Dropout(0.2),
                         nn.Linear(512, num_classes),
                         nn.LogSoftmax(dim=1))

# Move the model to the target device first, then build the optimizer from the
# moved parameters (the order recommended by the torch.optim documentation).
model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr)

In [0]:
# Train the model and save the weights.
# Epochs are numbered 1..args.epochs inclusive; the `+ 1` is required because
# range() excludes its stop value, otherwise one epoch fewer than configured runs.
for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

# Persist only the weights (state dict), and only when requested in the settings.
if args.save_model:
    torch.save(model.state_dict(), "model.pt")