# Host model on a grid node

<h2>Import dependencies</h2>

In [1]:
import pickle
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import grid as gr
import torch.nn.functional as F
import torch.optim as optim
import syft as sy
from torch.utils.data import TensorDataset, DataLoader
import time

<h2>Setup config</h2>
Define the configuration parameters, initialise the PySyft hook, and seed the random number generator.

In [2]:
# Create the PySyft hook for torch; it is handed to the grid client further down.
hook = sy.TorchHook(torch)
class Parser:
    """Plain settings holder exposing the notebook's hyper-parameters as attributes.

    Attributes set on every instance: ``epochs``, ``lr``, ``test_batch_size``,
    ``batch_size``, ``log_interval`` and ``seed``.
    """

    def __init__(self):
        # Defaults are listed once as (name, value) pairs and copied onto the
        # instance in that order.
        defaults = (
            ("epochs", 1),
            ("lr", 0.001),
            ("test_batch_size", 8),
            ("batch_size", 8),
            ("log_interval", 10),
            ("seed", 1),
        )
        for option, value in defaults:
            setattr(self, option, value)
    
# Single configuration object read by the cells below.
args = Parser()

# Seed torch's global RNG from the configuration.
torch.manual_seed(args.seed)
# NOTE(review): kwargs is not referenced anywhere in the visible notebook — confirm before removing.
kwargs = {}

<h2>Load dataset</h2>

In [3]:
# Loader over the MNIST test split (downloaded to ./data on first use).
# Images are converted to tensors and normalised with mean 0.1307 / std 0.3081.
# shuffle=False keeps the batch order fixed, so the sample batch below is
# always the same one.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data",
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=args.test_batch_size,
    shuffle=False,
    drop_last=False,
)

# Take the first batch; `data` is reused later as the example input for
# tracing and for building the plan.
# Use the builtin next(iter(...)) rather than the iterator's `.next()` method:
# `.next()` is a Python 2-style alias that current PyTorch DataLoader iterators
# do not provide, whereas the builtin works on every version.
data, target = next(iter(test_loader))

<h2>Define Model</h2>

In [4]:
class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by two linear layers.

    ``forward`` returns log-probabilities over 10 classes, one row per sample.
    The flatten step hard-codes 4 * 4 * 50 features, which is what two
    5x5 convolutions and two 2x2 poolings leave from a 1x28x28 input.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # parameter initialisation is unchanged.
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Run the network on ``x`` and return per-class log-probabilities."""
        # Convolutional feature extractor: conv -> ReLU -> 2x2 max-pool, twice.
        features = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        features = F.max_pool2d(F.relu(self.conv2(features)), kernel_size=2, stride=2)

        # Flatten each sample to one row before the fully connected head.
        flat = features.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

<h2>Connect with remote workers</h2>
<strong><em>Before this step, the workers must be initialized separately.</em></strong>

In [5]:
# Create a websocket client, with id "Bob", for the grid node at localhost:3000
# and open the connection. The node has to be started separately beforehand.
bob = gr.WebsocketGridClient(hook, "http://localhost:3000", id="Bob")
bob.connect()




In [6]:
# Instantiate the nn.Module network and trace it with torch.jit, using the
# sample batch loaded earlier as the example input.
model = Net()
traced_model = torch.jit.trace(model, data)

In [7]:
# Sanity check: run the traced model on the sample batch. No training has been
# done in this notebook, so the weights are still at their initial values.
traced_model(data)

tensor([[-2.3152, -2.2881, -2.3457, -2.2915, -2.2961, -2.2715, -2.3289, -2.2913,
         -2.3501, -2.2519],
        [-2.2820, -2.3071, -2.3152, -2.3174, -2.2503, -2.1874, -2.3410, -2.3673,
         -2.3569, -2.3143],
        [-2.3563, -2.3302, -2.2750, -2.3400, -2.2548, -2.3052, -2.3479, -2.3081,
         -2.2579, -2.2576],
        [-2.3317, -2.2224, -2.3543, -2.3109, -2.3305, -2.2643, -2.3547, -2.3406,
         -2.2966, -2.2308],
        [-2.3190, -2.2563, -2.2857, -2.2515, -2.3096, -2.3297, -2.3566, -2.3890,
         -2.2868, -2.2513],
        [-2.3548, -2.3496, -2.2873, -2.3433, -2.2369, -2.2714, -2.3297, -2.3322,
         -2.2808, -2.2484],
        [-2.3192, -2.3576, -2.2496, -2.2803, -2.3042, -2.2563, -2.3850, -2.3073,
         -2.3136, -2.2614],
        [-2.2928, -2.3074, -2.3077, -2.3260, -2.2726, -2.2433, -2.4308, -2.3079,
         -2.2909, -2.2580]], grad_fn=<DifferentiableGraphBackward>)

In [8]:
class Net(sy.Plan):
    """The same CNN architecture as the nn.Module version, declared as a PySyft Plan.

    NOTE(review): this class reuses the name ``Net`` and therefore shadows the
    ``nn.Module`` class defined in an earlier cell; consider a distinct name
    once all users of ``Net`` are known.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

        # Register the four layers as plan state (presumably so their
        # parameters travel with the plan — confirm against the PySyft docs).
        self.add_to_state(["fc1", "fc2", "conv1", "conv2"])

    def forward(self, x):
        """Run the network on ``x`` and return per-class log-probabilities."""
        # conv -> ReLU -> 2x2 max-pool, applied twice.
        pooled = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        pooled = F.max_pool2d(F.relu(self.conv2(pooled)), 2, 2)

        # Flatten each sample to one row, then apply the two linear layers.
        hidden = F.relu(self.fc1(pooled.view(-1, 4 * 4 * 50)))
        return F.log_softmax(self.fc2(hidden), dim=1)
# Rebinds `model` to an instance of the Plan-based Net defined just above.
# `traced_model` was already created from the earlier nn.Module instance.
model = Net()

In [9]:
# Build the plan with the sample batch as its example input
# (presumably this records the forward pass — confirm against the PySyft docs).
model.build(data)

In [10]:
# Disabled alternative: send the plan and the sample batch to Bob, run the plan
# there and fetch the result. Kept for reference only; nothing below depends on it.
# ptr_model = model.send(bob)
# ptr_data = data.send(bob)
# ptr_model(ptr_data).get()

In [11]:
# Show the sample batch shape: 8 single-channel 28x28 images.
data.shape

torch.Size([8, 1, 28, 28])

## Serve model

In [12]:
# List the models hosted on Bob before anything is served.
bob.models

[]

In [13]:
# Host both variants on Bob under separate ids: the built plan as "model" and
# the torch.jit trace as "traced_model". Only the second call's return value is
# displayed, because it is the last expression in the cell.
bob.serve_model(model, model_id="model")
bob.serve_model(traced_model, model_id="traced_model")

'{"success": true}'

In [14]:
# List the hosted models again to confirm both ids are now present.
bob.models

['model', 'traced_model']

## Query model

In [15]:
# Query the hosted traced model with one all-zero 1x28x28 image.
bob.run_inference(model_id="traced_model", data=torch.zeros(1, 1, 28, 28))

{'prediction': [[-2.322812080383301,
   -2.303769111633301,
   -2.293975830078125,
   -2.3327364921569824,
   -2.306673049926758,
   -2.274453639984131,
   -2.29142165184021,
   -2.3145298957824707,
   -2.3104403018951416,
   -2.2766435146331787]]}

In [16]:
# Query the hosted plan with the same all-zero input. Its predictions differ
# from the traced model's because the two networks were instantiated separately.
bob.run_inference(model_id="model", data=torch.zeros(1, 1, 28, 28))

{'prediction': [[-2.2801671028137207,
   -2.3129661083221436,
   -2.2660160064697266,
   -2.313047170639038,
   -2.3106632232666016,
   -2.3073132038116455,
   -2.294689655303955,
   -2.2897844314575195,
   -2.339111089706421,
   -2.3140127658843994]]}