In [1]:
from pathlib import Path

import math
import numpy as np
import torch
from torch import nn
from matplotlib import pyplot as plt
from pprint import pprint
import time

import syft as sy
from syft import VirtualMachine
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.lib.python.collections.ordered_dict import OrderedDict
from syft.lib.python.list import List
from syft import logger
from syft import SyModule, SySequential

# transformers imports, not needed in AST
from transformers.models.distilbert.modeling_distilbert import DistilBertConfig, create_sinusoidal_embeddings

# Add in AST
from transformers.activations import gelu

# Presumably silences syft's logging so it does not clutter cell output
# (assumes syft's `logger` follows loguru, where `remove()` with no arguments
# drops every registered handler — TODO confirm).
logger.remove()

In [2]:
# Session handles:
#   alice / alice_client -> a VirtualMachine named "alice" and a client bound to it
#   remote_torch         -> the `torch` attribute of the plan builder's ROOT_CLIENT
alice = VirtualMachine(name="alice")
alice_client = alice.get_client()
remote_torch = ROOT_CLIENT.torch

In [3]:
class SyLinear(SyModule):
    """Minimal ``SyModule`` that wraps one 10 -> 10 ``nn.Linear`` layer.

    Keyword arguments (e.g. ``inputs=...``) are forwarded untouched to
    ``SyModule.__init__``.
    """

    def __init__(self, **kwargs):
        super(SyLinear, self).__init__(**kwargs)
        self.layer = nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        projected = self.layer(x)
        return projected

class SyMLP(SyModule):
    """Five ``SyLinear`` (10 -> 10) layers applied one after another.

    The children must be ``SyModule`` instances: ``forward`` calls each one with
    a keyword input and takes element ``[0]`` of the result. That convention does
    not work for a plain ``nn.Linear``, whose ``forward`` takes a positional
    ``input`` and returns a tensor (indexing it would drop the batch dimension).

    Keyword arguments (e.g. ``inputs={'x': ...}``) are forwarded to
    ``SyModule.__init__`` for this module and for every child layer.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Every layer maps 10 -> 10 features, so each child receives an input of
        # the same shape as this module's own; reuse the same kwargs (`inputs`)
        # so every SyLinear can be traced independently.
        self.layers = nn.ModuleList([SyLinear(**kwargs) for _ in range(5)])

    def forward(self, x):
        out = x
        # Index-based loop on purpose: iterating a ModuleList pointer has been
        # observed to yield a union pointer rather than a ModulePointer.
        for i in range(len(self.layers)):
            layer = self.layers[i]
            # NOTE(review): a SyModule call presumably returns a sequence of
            # outputs, hence the [0] — confirm against SyModule.__call__.
            out = layer(x=out)[0]
        return out
    
# Example input for tracing: a random (2, 10) tensor, matching the 10-feature
# input of the linear layers. SyMLP receives it keyed by its forward argument `x`.
dummy_x = torch.randn(2, 10)
model = SyMLP(inputs=dict(x=dummy_x))

RECOMPILING SyLinear
RECOMPILING SyLinear
RECOMPILING SyLinear
RECOMPILING SyLinear
RECOMPILING SyLinear
RECOMPILING SyLinear


In [4]:
"""
in allowlist.py

allowlist["torch.nn.ModuleList.__iter__"] = "syft.lib.python.Iterator"
allowlist["torch.nn.ModuleList.__len__"] = "syft.lib.python.Int"
allowlist["torch.nn.ModuleList.__getitem__"] = "torch.nn.Module"
"""

model = nn.ModuleList([nn.Linear(10, 10) for _ in range(2)])
model_ptr = model.send(alice_client)

>>> print(model_ptr[0]) # Works
<syft.proxy.torch.nn.ModulePointer object at 0x7f7bf151beb0>

>>> print(next(iter(model_ptr))) # Needs to be ModulePointer
<syft.proxy.syft.lib.misc.union.FloatIntStringTensorParameterUnionPointer object at 0x7f7beef96ee0>

Linear(in_features=1, out_features=2, bias=True)
Linear(in_features=1, out_features=2, bias=True)
<syft.proxy.torch.nn.ModulePointer object at 0x7f7bf151beb0>
<syft.proxy.syft.lib.misc.union.FloatIntStringTensorParameterUnionPointer object at 0x7f7beef96ee0>


In [7]:
# `m` is not defined by any cell above, so the original line raised NameError on
# a fresh kernel. Presumably it was the local ModuleList, now bound to `model`.
# Built-in len() always returns a plain Python int.
type(len(model))

int