In [2]:
import torch
import crypten

# Initialise CrypTen in the notebook process; the cells below create encrypted tensors after this call.
crypten.init()

In [3]:
# Encrypt the same plaintext list twice as CrypTensors.
a = crypten.cryptensor([1,2,3])
b = crypten.cryptensor([1,2,3])

# Display `a`: the repr hides the plaintext and shows the stored integers
# (1, 2, 3 appear scaled by 65536 in the output that follows).
a

MPCTensor(
	_tensor=tensor([ 65536, 131072, 196608])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)

In [4]:
import crypten.mpc as mpc


# Build the tensor through the MPCTensor constructor directly.
# Note: this rebinds the notebook-global `a` defined in an earlier cell.
a = mpc.MPCTensor([1,2,3])
a

MPCTensor(
	_tensor=tensor([ 65536, 131072, 196608])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)

In [6]:
import crypten.nn as c_nn

class BasicLinearModel(c_nn.Module):
    """Two stacked CrypTen linear layers (4 -> 4 -> 2) with no activation between them."""

    def __init__(self):
        super().__init__()

        # Layer sizes: 4 input features -> 4 hidden units -> 2 outputs.
        self.fc1 = c_nn.Linear(4, 4)
        self.fc2 = c_nn.Linear(4, 2)

    def forward(self, x):
        """Feed ``x`` through ``fc1`` and then ``fc2`` and return the result."""
        hidden = self.fc1(x)
        return self.fc2(hidden)

basic_lm = BasicLinearModel()
# Before encryption the state dict holds ordinary torch tensors.
print(basic_lm.state_dict())

# encrypt() converts the parameters of this same object; the second print
# therefore shows MPCTensor entries under the same keys.
basic_lm.encrypt()

print(basic_lm.state_dict())


OrderedDict([('fc1.weight', tensor([[-0.1309,  0.0450,  0.4528, -0.0454],
        [ 0.2319,  0.3383, -0.3561,  0.0178],
        [ 0.1768, -0.4243,  0.3830, -0.4094],
        [ 0.3077,  0.0947,  0.1158, -0.3072]])), ('fc1.bias', tensor([ 0.3522, -0.0236,  0.4168,  0.0478])), ('fc2.weight', tensor([[-0.0438, -0.1931,  0.1598,  0.2168],
        [ 0.0042,  0.2861,  0.3185, -0.2077]])), ('fc2.bias', tensor([-0.1493,  0.4654]))])
OrderedDict([('fc1.weight', MPCTensor(
	_tensor=tensor([[ -8577,   2946,  29674,  -2976],
        [ 15199,  22169, -23334,   1167],
        [ 11588, -27807,  25101, -26831],
        [ 20165,   6208,   7585, -20131]])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)), ('fc1.bias', MPCTensor(
	_tensor=tensor([23081, -1546, 27317,  3130])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)), ('fc2.weight', MPCTensor(
	_tensor=tensor([[ -2870, -12655,  10472,  14208],
        [   277,  18749,  20871, -13610]])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)), ('fc2.bias', MPCTensor(


In [11]:
import crypten.communicator as comm
# print(basic_lm.state_dict())

# Input for the model: a (1, 4) tensor from crypten.rand, matching the 4 input features of fc1.
# NOTE(review): this tensor is created in the notebook process, outside the
# run_multiprocess function that consumes it — confirm that is intended.
data = crypten.rand(1, 4)

# Fresh, not-yet-encrypted model; it is encrypted inside inference().
lm = BasicLinearModel()
@mpc.run_multiprocess(world_size=2)
def inference():
    """Encrypt the global ``lm``, run it once on ``data`` and print its parameters.

    Runs in two processes (``world_size=2``). The forward-pass result is
    discarded and nothing is returned; the only visible effect is the
    per-rank printout of ``lm.state_dict()``.

    NOTE(review): ``data`` is created outside this function — confirm it is
    meant to be shared that way across the two processes.
    """
    lm.encrypt()
    lm(data)

    # Identify which party this process is, then dump what it holds.
    party = comm.get().get_rank()
    held_state = lm.state_dict()
    print(f"\nRank {party}:\n {held_state}")


# Launch both processes. The decorated call yields one entry per process,
# which is why the cell output ends with [None, None].
inference()

# lm.encrypt()
# print(lm.state_dict())


Rank 1:
 OrderedDict([('fc1.weight', MPCTensor(
	_tensor=tensor([[ 9193766128126929858,  7435998648562898045,  2978194639409428068,
          3091264160307227314],
        [ 3637922814226346526,  4550751558590518337,  6253665929886393792,
          5507971438917322666],
        [-3629080603663076951, -6688369296969302426, -3611403768636257806,
         -6818580741022606956],
        [ 6914344991524009335, -9140518894752618360,  1928034181795118528,
          6135844120986623104]])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)), ('fc1.bias', MPCTensor(
	_tensor=tensor([-4177940190696220180, -6375031699338444806, -2915258432312533610,
         -539072888349598628])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)), ('fc2.weight', MPCTensor(
	_tensor=tensor([[ 4689715033362257508, -5431779375541915471,  -617007408401599662,
         -8069885339622820551],
        [-4618366722576853634, -3695145305524209521, -5526165037780514576,
         -6894400471154215559]])
	plain_text=HIDDEN
	ptype=ptype

[None, None]

In [37]:
import crypten.communicator as comm


@mpc.run_multiprocess(world_size=2)
def examine_arithmetic_shares():
    """Encrypt ``[1, 2, 3]`` as an arithmetic tensor and print each party's view.

    Runs in two processes (``world_size=2``); each one prints its rank and the
    tensor repr it holds. Nothing is returned.
    """
    secret = crypten.cryptensor([1, 2, 3], ptype=mpc.arithmetic)

    party = comm.get().get_rank()
    print(f"Rank {party}:\n {secret}")
    
# Launch both processes. The function has no return statement, so `x` carries no
# tensor data (presumably one None per process — verify if `x` is used later).
x = examine_arithmetic_shares()

Rank 0:
 MPCTensor(
	_tensor=tensor([ 2696400424640252575, -6461003361434373356,  6599021308001053197])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)Rank 1:
 MPCTensor(
	_tensor=tensor([-2696400424640187039,  6461003361434504428, -6599021308000856589])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)



In [13]:
import torch.nn as nn
import torch.nn.functional as F
import crypten
import crypten.mpc as mpc
import crypten.communicator as comm


#Define an example network
class ExampleNet(nn.Module):
    """Small CNN: one 5x5 convolution, 2x2 max-pooling, then two linear layers.

    The first linear layer expects 16 * 12 * 12 features, which is what a
    28x28 single-channel input yields after the unpadded 5x5 convolution
    (24x24) and the 2x2 pooling (12x12).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
        self.fc1 = nn.Linear(16 * 12 * 12, 100)
        # Two outputs: one score per class for binary classification.
        self.fc2 = nn.Linear(100, 2)

    def forward(self, x):
        """Return the two class scores for each sample in ``x``."""
        pooled = F.max_pool2d(F.relu(self.conv1(x)), 2)
        # Flatten every sample's feature maps into one row for the linear layers.
        flat = pooled.view(-1, 16 * 12 * 12)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)
    
# Register ExampleNet with crypten.common.serial's safe-class list
# (presumably so CrypTen's serialization accepts this class — confirm against the CrypTen docs).
crypten.common.serial.register_safe_class(ExampleNet)

In [9]:
crypten.init()

# Synthetic data set: 100 single-channel 28x28 images with random pixel values.
x_small = torch.rand(100, 1, 28, 28)
# torch.randint's upper bound is exclusive, so randint(1, ...) can only ever
# yield 0 and every sample would land in class 0. Draw labels from {0, 1}.
y_small = torch.randint(2, (100,))

# Transform labels into one-hot encoding (row i of the 2x2 identity encodes class i)
label_eye = torch.eye(2)
y_one_hot = label_eye[y_small]

# Transform all data to CrypTensors
x_train = crypten.cryptensor(x_small, src=0)
y_train = crypten.cryptensor(y_one_hot)

# Instantiate and encrypt a CrypTen model; dummy_input only provides the
# input shape that from_pytorch needs to convert the PyTorch module.
model_plaintext = ExampleNet()
dummy_input = torch.empty(1, 1, 28, 28)
model = crypten.nn.from_pytorch(model_plaintext, dummy_input)
model.encrypt()

  param = torch.from_numpy(numpy_helper.to_array(node))


Graph encrypted module

In [15]:

@mpc.run_multiprocess(world_size=2)
def main():
    """Train ExampleNet on encrypted data with plain SGD, across two processes.

    Reads the plaintext notebook globals ``x_small``, ``y_one_hot``,
    ``model_plaintext`` and ``dummy_input`` and encrypts them here, inside the
    decorated function. Prints the decrypted loss after every epoch and
    returns nothing.
    """
    # Encrypt inside the two-party context rather than reusing the globals
    # encrypted in the notebook process: those were created before the worker
    # processes (and their communicator) existed.
    # NOTE(review): confirm against the CrypTen multiprocess documentation.
    x_enc = crypten.cryptensor(x_small, src=0)
    y_enc = crypten.cryptensor(y_one_hot, src=0)
    private_model = crypten.nn.from_pytorch(model_plaintext, dummy_input)
    private_model.encrypt()

    private_model.train()  # Change to training mode
    loss = crypten.nn.MSELoss()  # Choose loss function

    # Set parameters: learning rate, num_epochs
    learning_rate = 0.001
    num_epochs = 10

    # Train the model: SGD on encrypted data
    for i in range(num_epochs):

        # forward pass
        output = private_model(x_enc)
        loss_value = loss(output, y_enc)

        # set gradients to zero
        private_model.zero_grad()

        # perform backward pass
        loss_value.backward()

        # update parameters
        private_model.update_parameters(learning_rate)

        # examine the loss after each epoch (decrypting is the only point
        # where a plaintext value is revealed)
        print("Epoch: {0:d} Loss: {1:.4f}".format(i, loss_value.get_plain_text()))


# Launch the two-process training run defined above.
main()

KeyboardInterrupt: 