In [None]:
import syft as sy
import torch
from torchvision import datasets
from torchvision import transforms
from collections import OrderedDict
import tenseal as ts
sy.load("tenseal")

In [None]:
# Start this notebook's Duet session. loopback=True presumably pairs it with a
# Data Scientist notebook running on the same machine — confirm against the
# PySyft Duet documentation.
duet = sy.duet(loopback=True)

♫♫♫ > DUET LIVE STATUS  -  Objects: 1  Requests: 0   Messages: 27  Request Handlers: 1                                

In [4]:
# Register a handler with action="accept" and no filtering arguments, so
# incoming requests (presumably including the Data Scientist's .get() calls)
# are approved automatically. Convenient for this tutorial; a real data owner
# would review requests instead of auto-accepting them.
duet.requests.add_handler(action="accept")

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="he-black-box" width="100"/> Checkpoint 1: Now STOP and run the Data Scientist notebook up to the same checkpoint.

In [5]:
# MNIST normalisation statistics; presumably these must match the preprocessing
# used when the convolutional base was trained.
mnist_mean = 0.1307
mnist_std = 0.3081
# NOTE(review): the hyperparameters below are not referenced anywhere else in
# this notebook; they appear to be carried over from the training setup.
batch_size = 64
epochs = 10
lr = 0.1
sigma = 1.0
max_per_sample_grad_norm = 1.0
delta = 1e-5
root = "."
weights_filename = "mnist_cnn_weights.pt"
device = torch.device("cpu")

# Held-out MNIST test images; a fixed offset into the first batch later becomes
# the "secret" sample. shuffle=False keeps that choice reproducible across
# Restart & Run All (with shuffling and no seed it changed on every run).
# pin_memory only speeds up host-to-GPU copies, so enable it only when `device`
# is CUDA; on CPU it appears to be what triggers the CUDA-availability warning
# seen when the loader is iterated.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        root,
        train=False,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((mnist_mean,), (mnist_std,)),
            ]
        ),
    ),
    batch_size=100,
    shuffle=False,
    num_workers=1,
    pin_memory=device.type == "cuda",
)

In [6]:
from torch import nn

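# NOTE: ideally this class would not be redefined in this notebook, but sharing it over Duet could not be made to work.
# Its layer names and shapes must match the weights received below, or load_state_dict will fail.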
class ConvolutionalBase(torch.nn.Module):
    """Plaintext convolutional feature extractor evaluated in this notebook.

    Two stages of conv -> ReLU -> 2x2 max-pool (stride 1) map a single-channel
    image to a flat activation of 32 * 4 * 4 = 512 values per sample.
    The attribute names ``conv1`` / ``conv2`` determine the state-dict keys,
    so they must stay in sync with the weights loaded into this module.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=16, kernel_size=8, stride=2, padding=3
        )
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=4, stride=2
        )

    def forward(self, x):
        # For 28x28 inputs: conv1 -> 14x14, pool -> 13x13,
        # conv2 -> 5x5, pool -> 4x4.
        functional = nn.functional
        for conv in (self.conv1, self.conv2):
            x = functional.max_pool2d(functional.relu(conv(x)), 2, 1)
        return x.view(-1, 32 * 4 * 4)

In [7]:
# Pull the convolutional base's parameter names and tensors from the Duet store,
# in that order.
conv_base_names, conv_base_weights = (
    duet.store[key].get() for key in ("conv_base_names", "conv_base_weights")
)

In [8]:
# Pair each received parameter name with its tensor to rebuild a state dict,
# then load it into a local copy of the architecture. load_state_dict runs in
# strict mode by default and reports whether every key matched.
conv_base_state_dict = OrderedDict(zip(conv_base_names, conv_base_weights))

conv_base = ConvolutionalBase()
conv_base.load_state_dict(conv_base_state_dict)

<All keys matched successfully>

In [9]:
# The "secret" input: `test_batch_size` consecutive images starting at
# `offset` within the first test batch.
test_batch_size = 1
offset = 20

secret_data, secret_labels = next(iter(test_loader))
data_sample, label_sample = (
    secret_data[offset : offset + test_batch_size],
    secret_labels[offset : offset + test_batch_size],
)

# Share the sample count with the Data Scientist (tagged "batch_size").
batch_size_ptr = sy.lib.python.Int(test_batch_size).send(
    duet, searchable=True, tags=["batch_size"]
)

  return torch._C._cuda_getDeviceCount() > 0


In [10]:
# Run the plaintext convolutional base locally on the secret sample; only this
# activation (not the raw image) is encrypted and shared later. This is
# inference only, so skip building an autograd graph.
with torch.no_grad():
    intermediary_activation_data_sample = conv_base(data_sample)

In [11]:
def create_ctx():
    """Create the CKKS TenSEAL context used to encrypt the activations.

    Uses a polynomial modulus degree of 8192, a 40/21/21/21/21/40-bit
    coefficient modulus chain, and a global scale of 2**21. This notebook
    later uses the returned context to decrypt the result, so it holds the
    secret key.
    """
    ctx = ts.context(
        ts.SCHEME_TYPE.CKKS,
        8192,                      # polynomial modulus degree
        -1,                        # plain modulus placeholder (only meaningful for BFV)
        [40, 21, 21, 21, 21, 40],  # coefficient modulus bit sizes
    )
    # Scale matches the 21-bit intermediate primes of the modulus chain.
    ctx.global_scale = 2 ** 21
    return ctx

def prepare_encrypted_activation(context, activation_sample):
    """Encrypt an activation tensor as a CKKS tensor under ``context``.

    Args:
        context: TenSEAL context whose keys are used for encryption.
        activation_sample: torch tensor to encrypt. It is detached from any
            autograd graph and copied to host memory before conversion.

    Returns:
        The ``ts.ckks_tensor`` built from the tensor's NumPy values.
    """
    # .numpy() needs a CPU tensor without grad tracking; .cpu() is a no-op for
    # tensors already on the CPU, so this also works when `device` is CUDA.
    return ts.ckks_tensor(context, activation_sample.detach().cpu().numpy())

In [12]:
context = create_ctx()
# Share only a public copy of the context. Serialize it without the secret key
# and rebuild it, so the Data Scientist can compute on the ciphertexts but never
# holds the key that decrypts them. Don't rely on the transport to strip the key.
# The local `context` keeps its secret key for decrypting the result later.
public_context = ts.context_from(context.serialize(save_secret_key=False))
ctx_ptr = public_context.send(duet, searchable=True, tags=["context"])

In [13]:
# Encrypt the intermediate activation locally, then publish only the ciphertext.
encrypted_activation = prepare_encrypted_activation(
    context,
    intermediary_activation_data_sample,
)
encrypted_activation_ptr = encrypted_activation.send(
    duet,
    tags=["encrypted_activation"],
    searchable=True,
)

In [14]:
# Show what this notebook has published to the Duet store so far (rendered as a table).
duet.store.pandas

Unnamed: 0,ID,Tags,Description,object_type
0,<UID: 0bb0c69c82ba43eabdd8cf77159492f7>,[batch_size],,<class 'syft.lib.python.Int'>
1,<UID: db5832272cfd4dea8f41c29d474a1e78>,[context],,<class 'tenseal.enc_context.Context'>
2,<UID: de970c286e4342f49c90e75244817713>,[encrypted_activation],,<class 'tenseal.tensors.ckkstensor.CKKSTensor'>


### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="he-black-box" width="100"/> Checkpoint 2: Now STOP and run the Data Scientist notebook up to the same checkpoint.

In [None]:
# Fetch the encrypted result (presumably computed by the Data Scientist) without
# deleting the remote copy. Attach the local context, which holds the secret
# key, and decrypt it.
encrypted_result = duet.store["result"].get(delete_obj=False)
encrypted_result.link_context(context)
result = encrypted_result.decrypt()

In [None]:
# Turn the decrypted scores into class probabilities and pick the most likely
# label for each sample, then compare it with the true label.
logits = torch.tensor(result.tolist())
probs = logits.softmax(dim=1)
label_max = probs.argmax(dim=1)

print(f"Maximum probability for label {label_max} with true_label {label_sample}")

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="he-black-box" width="100"/> Checkpoint 3: Well done!

# Congratulations!!! - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

### Star PySyft on GitHub

The easiest way to help our community is just by starring the repos! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org). #lib_tenseal and #code_tenseal are the main channels for the TenSEAL project.

### Donate

If you don't have time to contribute to our codebase but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

[OpenMined's Open Collective Page](https://opencollective.com/openmined)