In [33]:
import crypten
import crypten.optim
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch_nn_modules import ExampleNet, test
from torch.utils.data import DataLoader
import crypten.mpc as mpc
import time
import torch.nn as nn
import csv

In [34]:
# Initialise CrypTen and pin down run-to-run variation.
crypten.init()

# Seed the global torch RNG so the torch.randn draws in the loss-test cell are
# identical on every Restart & Run All (the recorded outputs were otherwise unreproducible).
SEED = 0
torch.manual_seed(SEED)

# Single thread keeps timing/numerics comparable across runs; last statement
# returns None, so the cell renders no output.
torch.set_num_threads(1)



In [35]:
# The two losses under comparison: PyTorch's reference implementation and
# CrypTen's encrypted-tensor counterpart (both default-constructed).
plain_loss_fn = nn.CrossEntropyLoss()
enc_loss_fn = crypten.nn.CrossEntropyLoss()

In [36]:
# Loss tests: random logits against random soft (probability) targets.
#
# LOGIT_SCALE deliberately stretches the logits so that most softmax
# probabilities become tiny. In the recorded outputs below, CrypTen's
# log(softmax(...)) bottoms out at a constant for those values (see the
# `log_enc` output) while the plaintext log-softmax does not — this is
# where the CrypTen and PyTorch losses diverge.
BATCH_SIZE, NUM_CLASSES = 3, 5
LOGIT_SCALE = 100

# plain text
plain_rand = torch.randn(BATCH_SIZE, NUM_CLASSES) * LOGIT_SCALE
print("base rand", plain_rand)
# Each row of the target sums to 1 (softmax over dim=1).
plain_target = torch.randn(BATCH_SIZE, NUM_CLASSES).softmax(dim=1)
print("target", plain_target)

# crypten: encrypt exactly the same tensors
rand_enc = crypten.cryptensor(plain_rand)
target_enc = crypten.cryptensor(plain_target)

base rand tensor([[ 101.1814,  -46.6721,  -87.3501,  -11.3377, -127.6126],
        [ -83.2465, -133.9910, -169.2950,  -97.4745,  -75.6243],
        [-105.4851,   14.7071,   93.2591,  -73.9188,  -17.2318]])
target tensor([[0.0763, 0.2423, 0.2675, 0.2477, 0.1662],
        [0.1838, 0.1963, 0.2546, 0.1342, 0.2310],
        [0.1922, 0.1662, 0.3351, 0.1565, 0.1500]])


In [37]:
# Softmax along dim=1: plaintext vs. the decrypted CrypTen result.
# In the recorded output, tiny probabilities that plaintext keeps
# (e.g. ~1e-10) come back as exactly 0 from CrypTen.
softmax_plain = plain_rand.softmax(dim=1)
softmax_enc = rand_enc.softmax(dim=1).get_plain_text()
print("softmax plain", softmax_plain)
print("softmax enc", softmax_enc)

softmax plain tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [4.8922e-04, 4.4818e-26, 2.0851e-41, 3.2388e-10, 9.9951e-01],
        [0.0000e+00, 7.6784e-35, 1.0000e+00, 0.0000e+00, 0.0000e+00]])
softmax enc tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [4.1199e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.9959e-01],
        [0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00]])


In [38]:
# Evaluate the same soft-target cross-entropy both ways. With strongly scaled
# logits the two results can disagree badly; the following cells break the
# CrypTen computation down step by step to locate the gap.
plain_loss = plain_loss_fn(plain_rand, plain_target)
enc_loss = enc_loss_fn(rand_enc, target_enc)

# Decrypt only the final scalar for display.
dec_loss = enc_loss.get_plain_text()

print("plaintext loss", plain_loss)
crypten.print("crypten loss", dec_loss)

plaintext loss tensor(95.2656)
crypten loss tensor(10.9005)


In [39]:
# Rebuild CrypTen's cross-entropy by hand:
# softmax -> log -> * target -> negate -> sum -> / batch size.
# The "sum, div" value printed here matches the "crypten loss" printed earlier.
rand_enc_softmax = rand_enc.softmax(dim=1)

loss_values = rand_enc_softmax.log(input_in_01=True).mul_(target_enc).neg_()
crypten.print("log, mul, neg:", loss_values.get_plain_text())
final_values = loss_values.sum().div_(target_enc.size(0))
crypten.print("sum, div:", final_values.get_plain_text())

# Encrypted log_softmax computed directly (not as log of softmax); in the
# recorded output it tracks the plaintext log-softmax closely.
log_softmax = rand_enc.log_softmax(dim=1)
print("enc log_softmax", log_softmax.get_plain_text())

# Plaintext LOG-softmax. `plain_softmax` is a legacy, misleading name kept only
# because a later cell reads it; prefer `plain_log_softmax`.
plain_log_softmax = plain_rand.log_softmax(dim=1)
plain_softmax = plain_log_softmax
print("plain log_softmax", plain_log_softmax)

# Same loss, but built from log_softmax instead of log(softmax(...)).
# Expected to land near the plaintext loss — TODO confirm on re-run.
log_softmax_loss = log_softmax.mul(target_enc).neg().sum().div(target_enc.size(0))
crypten.print("log_softmax-based loss:", log_softmax_loss.get_plain_text())

log, mul, neg: tensor([[2.9449e-03, 3.4815e+00, 3.8447e+00, 3.5598e+00, 2.3890e+00],
        [1.4373e+00, 2.8215e+00, 3.6585e+00, 1.9289e+00, 8.9874e-03],
        [2.7627e+00, 2.3881e+00, 1.2939e-02, 2.2491e+00, 2.1555e+00]])
sum, div: tensor(10.9005)
log_softmax tensor([[-4.1199e-04, -1.4785e+02, -1.8853e+02, -1.1252e+02, -2.2879e+02],
        [-7.6233e+00, -5.8368e+01, -9.3672e+01, -2.1851e+01, -1.0986e-03],
        [-1.9874e+02, -7.8552e+01, -4.1199e-04, -1.6718e+02, -1.1049e+02]])
plain_softmax tensor([[ 0.0000e+00, -1.4785e+02, -1.8853e+02, -1.1252e+02, -2.2879e+02],
        [-7.6227e+00, -5.8367e+01, -9.3671e+01, -2.1851e+01, -4.8935e-04],
        [-1.9874e+02, -7.8552e+01,  0.0000e+00, -1.6718e+02, -1.1049e+02]])


In [40]:
# log of the encrypted softmax; the flag declares the input lies in [0, 1].
# The recorded output shows it flattening to a constant for near-zero probabilities.
log_enc = rand_enc_softmax.log(input_in_01=True)
# Plaintext reference: log_softmax is the numerically stable form of log(softmax(x)).
log_plain = plain_rand.log_softmax(dim=1)

print("log_enc", log_enc.get_plain_text())
print("log_plain", log_plain)

log_enc tensor([[ -0.0386, -14.3718, -14.3718, -14.3718, -14.3718],
        [ -7.8191, -14.3718, -14.3718, -14.3718,  -0.0389],
        [-14.3718, -14.3718,  -0.0386, -14.3718, -14.3718]])
log_plain tensor([[ 0.0000e+00, -1.4785e+02, -1.8853e+02, -1.1252e+02, -2.2879e+02],
        [-7.6227e+00, -5.8367e+01, -9.3671e+01, -2.1851e+01, -4.8935e-04],
        [-1.9874e+02, -7.8552e+01,  0.0000e+00, -1.6718e+02, -1.1049e+02]])


In [41]:

# Weight each log-probability by its soft target (elementwise product).
log_mul_enc = log_enc.mul(target_enc)
log_mul_plain = log_plain * plain_target

print("log_mul_enc", log_mul_enc.get_plain_text())
print("log_mul_plain", log_mul_plain)

log_mul_enc tensor([[-2.9449e-03, -3.4815e+00, -3.8447e+00, -3.5598e+00, -2.3890e+00],
        [-1.4373e+00, -2.8215e+00, -3.6585e+00, -1.9289e+00, -8.9874e-03],
        [-2.7627e+00, -2.3881e+00, -1.2939e-02, -2.2491e+00, -2.1555e+00]])
log_mul_plain tensor([[ 0.0000e+00, -3.5819e+01, -5.0436e+01, -2.7871e+01, -3.8034e+01],
        [-1.4013e+00, -1.1459e+01, -2.3846e+01, -2.9330e+00, -1.1306e-04],
        [-3.8206e+01, -1.3054e+01,  0.0000e+00, -2.6165e+01, -1.6572e+01]])


In [42]:
# Flip the sign: cross-entropy is -sum(target * log p).
log_mul_enc_neg, log_mul_plain_neg = log_mul_enc.neg(), log_mul_plain.neg()

print("log_mul_enc_neg", log_mul_enc_neg.get_plain_text())
print("log_mul_plain_neg", log_mul_plain_neg)

log_mul_enc_neg tensor([[2.9449e-03, 3.4815e+00, 3.8447e+00, 3.5598e+00, 2.3890e+00],
        [1.4373e+00, 2.8215e+00, 3.6585e+00, 1.9289e+00, 8.9874e-03],
        [2.7627e+00, 2.3881e+00, 1.2939e-02, 2.2491e+00, 2.1555e+00]])
log_mul_plain_neg tensor([[-0.0000e+00, 3.5819e+01, 5.0436e+01, 2.7871e+01, 3.8034e+01],
        [1.4013e+00, 1.1459e+01, 2.3846e+01, 2.9330e+00, 1.1306e-04],
        [3.8206e+01, 1.3054e+01, -0.0000e+00, 2.6165e+01, 1.6572e+01]])


In [43]:

# Collapse every element into a single scalar total.
log_mul_enc_neg_sum = log_mul_enc_neg.sum()
log_mul_plain_neg_sum = log_mul_plain_neg.sum()

print("log_mul_enc_neg_sum", log_mul_enc_neg_sum.get_plain_text())
print("log_mul_plain_neg_sum", log_mul_plain_neg_sum)

log_mul_enc_neg_sum tensor(32.7015)
log_mul_plain_neg_sum tensor(285.7968)


In [44]:

# Average over the batch: divide each scalar total by the number of target rows.
final_enc = log_mul_enc_neg_sum.div(target_enc.size(0))
final_plain = log_mul_plain_neg_sum.div(plain_target.size(0))

print("final_enc", final_enc.get_plain_text())
print("final_plain", final_plain)

final_enc tensor(10.9005)
final_plain tensor(95.2656)
