## 1 pytorch 训练线性模型

In [1]:
import torch
from torch import nn
from torch import optim

In [6]:

# A toy dataset: the target equals the first input feature, so a linear model
# can fit it exactly (weight = [1, 0], bias = 0).
data = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1.]])
target = torch.tensor([[0], [0], [1], [1.]])

# A toy model: a single linear layer mapping 2 input features to 1 output.
model = nn.Linear(2, 1)


def train(n_iters=20, lr=0.1):
    """Fit the global ``model`` to ``data``/``target`` with full-batch SGD.

    A fresh SGD optimizer over ``model.parameters()`` is created on every call.
    The summed squared error is printed once per iteration; the printed value
    is measured before that iteration's weight update.

    Args:
        n_iters: number of gradient steps to take (default 20).
        lr: SGD learning rate (default 0.1).
    """
    opt = optim.SGD(params=model.parameters(), lr=lr)
    for _ in range(n_iters):
        # 1) erase gradients accumulated by the previous iteration
        opt.zero_grad()

        # 2) make a prediction
        pred = model(data)

        # 3) calculate how much we missed (sum of squared errors)
        loss = ((pred - target) ** 2).sum()

        # 4) backpropagate: compute the gradient of the loss w.r.t. each weight
        loss.backward()

        # 5) update the weights
        opt.step()

        # 6) print progress; detach() drops autograd history and prints the
        # same thing `loss.data` did, via the supported API
        print(loss.detach())

<generator object Module.parameters at 0x000002F6522F9580>


In [3]:
# Run the local (non-federated) training loop; it prints the loss once per iteration.
train()

tensor(0.2655)
tensor(0.1714)
tensor(0.1110)
tensor(0.0721)
tensor(0.0469)
tensor(0.0306)
tensor(0.0201)
tensor(0.0132)
tensor(0.0087)
tensor(0.0058)
tensor(0.0039)
tensor(0.0026)
tensor(0.0018)
tensor(0.0012)
tensor(0.0008)
tensor(0.0006)
tensor(0.0004)
tensor(0.0003)
tensor(0.0002)
tensor(0.0001)


## 2 syft训练

> 1. 注意：这里的 `model.send(worker)` 和 `model.get()` 会对 model 中的每一个参数分别执行 `p = p.send(worker)` 和 `p = p.get()`。
> 2. 因此，每次训练时，模型都会被完整地发送到客户端；训练完成后，模型再被完整地取回到服务端（本地）。
> 3. 也就是说，这里是两个客户端依次（串行）训练：一个客户端训练完成后，模型再交给另一个客户端继续训练，并没有做联邦平均。

In [7]:
import syft as sy
# Hook PyTorch with syft; this is presumably what provides the .send()/.get()
# methods used on tensors and on the nn.Module later in this notebook.
hook = sy.TorchHook(torch)

In [8]:
# Create the two virtual workers ("bob", then "alice") that the data shards
# are sent to.
bob, alice = [
    sy.VirtualWorker(hook, id=worker_id) for worker_id in ("bob", "alice")
]

In [9]:
# A toy dataset (same values as in part 1). Inputs and targets are plain data,
# so they do not need requires_grad: only the model's parameters are trained.
data = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1.]])
target = torch.tensor([[0], [0], [1], [1.]])

# Initialize a toy model.
model = nn.Linear(2, 1)

# Split the data into two shards and send one shard to each worker. `.send()`
# returns a pointer to the remote tensor, so every *_bob / *_alice name below
# is a pointer, not a local tensor.
data_bob = data[0:2].send(bob)
target_bob = target[0:2].send(bob)
data_alice = data[2:].send(alice)
target_alice = target[2:].send(alice)

# Organize the (data, target) pointer pairs into a list, one entry per worker.
datasets = [(data_bob, target_bob), (data_alice, target_alice)]

In [10]:
from syft.federated.floptimizer import Optims
# Optims receives the worker ids plus an Adam optimizer (lr=0.1) over the
# model's parameters; train() fetches a per-worker optimizer from it with
# optims.get_optim(<worker id>).
workers = ["bob", "alice"]
optims = Optims(
    workers,
    optim=optim.Adam(lr=0.1, params=model.parameters()),
)

In [11]:
def train(n_iters=10):
    """Train the global ``model`` sequentially on every worker's data shard.

    In each of ``n_iters`` rounds, the model visits every (data, target)
    pointer pair in ``datasets`` in turn. At each worker it is sent there,
    takes one step with the optimizer ``optims`` returns for that worker's id,
    and is then fetched back. The workers therefore train the same model one
    after another. Each step's loss is fetched and printed.

    Args:
        n_iters: number of rounds over all workers (default 10).
    """
    for _ in range(n_iters):
        # iterate through each worker's (data, target) pointer pair
        for data_ptr, target_ptr in datasets:
            worker = data_ptr.location

            # send the model to the worker that holds this shard
            model.send(worker)

            # fetch the optimizer associated with this worker's id
            opt = optims.get_optim(worker.id)

            # 1) erase previous gradients (if they exist)
            opt.zero_grad()

            # 2) make a prediction on the remote data
            pred = model(data_ptr)

            # 3) calculate how much we missed
            loss = ((pred - target_ptr) ** 2).sum()

            # 4) figure out which weights caused us to miss
            loss.backward()

            # 5) change those weights
            opt.step()

            # get the model (with gradients) back from the worker
            model.get()

            # 6) print our progress; `loss` is a pointer, so fetch it first
            print(loss.get())


# NOTE: the loop above is sequential training (one worker after another), not
# federated averaging; averaging per-worker updates is still a TODO.

In [12]:
# Run sequential training over bob's and alice's shards; prints each step's loss.
train()

tensor(0.2523, requires_grad=True)
tensor(1.1397, requires_grad=True)
tensor(0.0076, requires_grad=True)
tensor(0.3003, requires_grad=True)
tensor(0.1095, requires_grad=True)
tensor(0.0394, requires_grad=True)
tensor(0.2870, requires_grad=True)
tensor(0.0093, requires_grad=True)
tensor(0.3682, requires_grad=True)
tensor(0.0296, requires_grad=True)
tensor(0.3454, requires_grad=True)
tensor(0.0340, requires_grad=True)
tensor(0.2584, requires_grad=True)
tensor(0.0178, requires_grad=True)
tensor(0.1542, requires_grad=True)
tensor(0.0018, requires_grad=True)
tensor(0.0689, requires_grad=True)
tensor(0.0040, requires_grad=True)
tensor(0.0196, requires_grad=True)
tensor(0.0250, requires_grad=True)
