## Train and Test

## TABLE VII + Table V
在三个不同的网络结构下的效果

In [None]:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import tempfile

import crypten
import crypten.communicator as comm
import torch
import torch.nn as nn
import torch.nn.functional as F
from examples.util import NoopContextManager
from torchvision import datasets, transforms
import crypten.mpc as mpc
from crypten.lpgan.Secfunctions import track_network_traffic, get_network_bytes, calculate_network_traffic, timeit, print_execution_times, print_communication_costs

@mpc.run_multiprocess(world_size=2)
def run_mpc_autograd_cnn(
    context_manager=None,
    num_epochs=1,
    learning_rate=0.001,
    batch_size=128,
    print_freq=5,
    num_samples=60000,
):
    """
    Args:
        context_manager: used for setting proxy settings during download.
    """
    crypten.init()

    data_alice, data_bob, train_labels, test_data_alice, test_data_bob, test_labels = preprocess_mnist(context_manager)
    rank = comm.get().get_rank()

    # assumes at least two parties exist
    # broadcast dummy data with same shape to remaining parties
    if rank == 0:
        x_alice = data_alice
    else:
        x_alice = torch.empty(data_alice.size())

    if rank == 1:
        x_bob = data_bob
    else:
        x_bob = torch.empty(data_bob.size())

    # encrypt
    x_alice_enc = crypten.cryptensor(x_alice, src=0)
    x_bob_enc = crypten.cryptensor(x_bob, src=1)

    # combine feature sets
    x_combined_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2)
    x_combined_enc = x_combined_enc.unsqueeze(1)

    # reduce training set to num_samples
    x_reduced = x_combined_enc[:num_samples]
    y_reduced = train_labels[:num_samples]

    # encrypt plaintext model
    model_plaintext = CNN()
    dummy_input = torch.empty((1, 1, 28, 28))
    model = crypten.nn.from_pytorch(model_plaintext, dummy_input)
    model.train()
    model.encrypt()

    # encrypted training
    train_encrypted(
        x_reduced, y_reduced, model, num_epochs, learning_rate, batch_size, print_freq
    )
    
########################################################################
#===============================Test====================================
########################################################################
    # Testing on MNIST test set

    if rank == 0:
        x_test_alice = test_data_alice
    else:
        x_test_alice = torch.empty(test_data_alice.size())

    if rank == 1:
        x_test_bob = test_data_bob
    else:
        x_test_bob = torch.empty(test_data_bob.size())

    # encrypt test data
    x_test_alice_enc = crypten.cryptensor(x_test_alice, src=0)
    x_test_bob_enc = crypten.cryptensor(x_test_bob, src=1)

    x_test_combined_enc = crypten.cat([x_test_alice_enc, x_test_bob_enc], dim=2).unsqueeze(1)
    x_test = x_test_combined_enc[:10000]
    y_test = test_labels[:10000]

    model.eval()
    # run test
    test_encrypted(x_test, y_test, model, batch_size, print_freq)
    print_execution_times() # offline 的时间成本
    
@timeit
def train_encrypted(
    x_encrypted,
    y_encrypted,
    encrypted_model,
    num_epochs,
    learning_rate,
    batch_size,
    print_freq,
):
    rank = comm.get().get_rank()
    loss = crypten.nn.CrossEntropyLoss()

    num_samples = x_encrypted.size(0)
    # label_eye = torch.eye(2)
    label_eye = torch.eye(10)


    for epoch in range(num_epochs):
        last_progress_logged = 0
        # only print from rank 0 to avoid duplicates for readability
        if rank == 0:
            print(f"Epoch {epoch} in progress:")

        for j in range(0, num_samples, batch_size):

            # define the start and end of the training mini-batch
            start, end = j, min(j + batch_size, num_samples)

            # switch on autograd for training examples
            x_train = x_encrypted[start:end]
            x_train.requires_grad = True
            y_one_hot = label_eye[y_encrypted[start:end]]
            y_train = crypten.cryptensor(y_one_hot, requires_grad=True)

            # perform forward pass:
            output = encrypted_model(x_train)
            loss_value = loss(output, y_train)

            # backprop
            encrypted_model.zero_grad()
            loss_value.backward() # CrypTen implements encrypted SGD by implementing its own `backward` function
            encrypted_model.update_parameters(learning_rate)

            # log progress
            if j + batch_size - last_progress_logged >= print_freq:
                last_progress_logged += print_freq
                print(f"Loss {loss_value.get_plain_text().item():.4f}")

        # compute accuracy every epoch
        pred = output.get_plain_text().argmax(1)
        correct = pred.eq(y_encrypted[start:end])
        correct_count = correct.sum(0, keepdim=True).float()
        accuracy = correct_count.mul_(100.0 / output.size(0))

        loss_plaintext = loss_value.get_plain_text().item()
        print(
            f"Epoch {epoch} completed: "
            f"Loss {loss_plaintext:.4f} Accuracy {accuracy.item():.2f}"
        )
@timeit
@track_network_traffic
def test_encrypted(
    x_encrypted,
    y_encrypted,
    encrypted_model,
    batch_size,
    print_freq,
):
    rank = comm.get().get_rank()
    loss = crypten.nn.CrossEntropyLoss()
    

    num_samples = x_encrypted.size(0)
    label_eye = torch.eye(10)

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for j in range(0, num_samples, batch_size):
        # define the start and end of the testing mini-batch
        start, end = j, min(j + batch_size, num_samples)

        # switch on autograd for testing examples
        x_test = x_encrypted[start:end]
        x_test.requires_grad = True
        y_one_hot = label_eye[y_encrypted[start:end]]
        y_test = crypten.cryptensor(y_one_hot, requires_grad=True)

        # perform forward pass:
        output = encrypted_model(x_test)
        loss_value = loss(output, y_test)

        # accumulate loss
        total_loss += loss_value.get_plain_text().item() * (end - start)

        # compute accuracy
        pred = output.get_plain_text().argmax(1)
        correct = pred.eq(y_encrypted[start:end])
        correct_count = correct.sum(0, keepdim=True).float()
        total_correct += correct_count.item()
        total_samples += (end - start)

        # log progress
        if (j + batch_size) % print_freq == 0:
            print(f"Processed {end}/{num_samples} samples. Loss: {loss_value.get_plain_text().item():.4f}")

    # compute average loss and accuracy
    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples * 100.0

    if rank == 0:
        print(f"Test completed: Loss {avg_loss:.4f} Accuracy {accuracy:.2f}%")


def preprocess_mnist(context_manager):
    if context_manager is None:
        context_manager = NoopContextManager()

    # with context_manager:
    #     # each party gets a unique temp directory
    #     with tempfile.TemporaryDirectory() as data_dir:
    #         mnist_train = datasets.MNIST(data_dir, download=True, train=True)
    #         mnist_test = datasets.MNIST(data_dir, download=True, train=False)
    mnist_train = datasets.MNIST('../data', download=True, train=True)
    mnist_test = datasets.MNIST('../data', download=True, train=False)

    # modify labels so all non-zero digits have class label 1
    mnist_train.targets[mnist_train.targets != 0] = 1
    mnist_test.targets[mnist_test.targets != 0] = 1
    mnist_train.targets[mnist_train.targets == 0] = 0
    mnist_test.targets[mnist_test.targets == 0] = 0

    # compute normalization factors
    data_all = torch.cat([mnist_train.data, mnist_test.data]).float()
    data_mean, data_std = data_all.mean(), data_all.std()
    tensor_mean, tensor_std = data_mean.unsqueeze(0), data_std.unsqueeze(0)

    # normalize data
    data_train_norm = transforms.functional.normalize(
        mnist_train.data.float(), tensor_mean, tensor_std
    )
    data_test_norm = transforms.functional.normalize(
        mnist_test.data.float(), tensor_mean, tensor_std
    )

    # partition features between Alice and Bob
    data_alice = data_train_norm[:, :, :20]
    data_bob = data_train_norm[:, :, 20:]
    train_labels = mnist_train.targets

    test_data_alice = data_test_norm[:, :, :20]
    test_data_bob = data_test_norm[:, :, 20:]
    test_labels = mnist_test.targets

    return data_alice, data_bob, train_labels, test_data_alice, test_data_bob, test_labels

# 加载测试集
def load_test_data():
    mnist_test = datasets.MNIST('../data', download=True, train=True)

    # 正常化测试数据
    data_all = torch.cat([mnist_test.data]).float()
    data_mean, data_std = data_all.mean(), data_all.std()
    tensor_mean, tensor_std = data_mean.unsqueeze(0), data_std.unsqueeze(0)

    # 对测试数据进行标准化
    data_test_norm = transforms.functional.normalize(
        mnist_test.data.float(), tensor_mean, tensor_std
    )
    
    test_labels = mnist_test.targets
    return data_test_norm, test_labels

# class CNN(nn.Module):
#     """
#     C: 1conv, 2FC with ReLU
#     """
#     def __init__(self):
#         super(CNN, self).__init__()
#         self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
#         self.fc1 = nn.Linear(16 * 24 * 24, 100)  # 调整输入维度
#         self.fc2 = nn.Linear(100, 10)

#     def forward(self, x):
#         out = self.conv1(x)
#         out = F.relu(out)
#         out = out.view(-1, 9216)
#         out = self.fc1(out)
#         out = F.relu(out)
#         out = self.fc2(out)
#         return out
     
class CNN(nn.Module):
    """
    B: 1-Conv, 2-FC with square activation
    """
    def __init__(self):
        super(CNN, self).__init__()
        # 1个卷积层，输入通道1，输出通道16，卷积核大小为5，零填充0
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)  
        
        # 两个全连接层，连接16个特征通道的 24x24 输出和128维度，然后64维度
        self.fc1 = nn.Linear(16 * 24 * 24, 128)
        self.fc2 = nn.Linear(128, 10)  # 输出维度调整为10类

    def square_activation(self, x):
        return x ** 2  # 方形激活函数

    def forward(self, x):
        # Conv 层
        print("before conv1 x'size:", x.size())  # 打印输入尺寸
        
        x = self.conv1(x)  # 经过卷积层
        print("after conv1 x'size:", x.size())  # 打印卷积后的尺寸
        
        x = self.square_activation(x)  # 应用方形激活函数
        x = x.view(-1, 16 * 24 * 24)  # 展平维度，变为(batch_size, 16*24*24)
        
        print("after view x'size:", x.size())  # 打印展平后的尺寸
        
        # 全连接层
        out = self.fc1(x)
        out = self.square_activation(out)  # 应用方形激活函数
        out = self.fc2(out)
        
        return out
    

# class CNN(nn.Module):
#     """
#     A: 3FC with square activation
#     """
#     def __init__(self):
#         super(CNN, self).__init__()
#         self.fc1 = nn.Linear(784, 128)
#         self.fc2 = nn.Linear(128, 64)
#         # self.fc3 = nn.Linear(64, 2)
#         self.fc3 = nn.Linear(64, 10)

#     def square_activation(self, x):
#         return x ** 2  # Square activation function

#     def forward(self, x):
#         print("before view x'size:", x.size())
#         x = x.view(-1, 784)
#         print("after view x'size:", x.size())
#         out = self.fc1(x)
#         out = self.square_activation(out)  # Apply square activation
#         out = self.fc2(out)
#         out = self.square_activation(out)  # Apply square activation
#         out = self.fc3(out)
#         return out
    
if __name__ == "__main__":
    run_mpc_autograd_cnn()



before conv1 x'size: 



before conv1 x'size: torch.Size([1, 1, 28, 28])




torch.Size([1, 1, 28, 28])
after conv1 x'size: 



torch.Size([1, 16, 24, 24])after conv1 x'size:
 



after view x'size:torch.Size([1, 16, 24, 24]) 




torch.Size([1, 9216])
after view x'size: 



torch.Size([1, 9216])
before conv1 x'size:before conv1 x'size: torch.Size([1, 1, 28, 28])
 torch.Size([1, 1, 28, 28])after conv1 x'size:
 torch.Size([1, 16, 24, 24])
after conv1 x'size:after view x'size:  torch.Size([1, 16, 24, 24])torch.Size([1, 9216])

after view x'size: torch.Size([1, 9216])


  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))


Epoch 0 in progress:

x_train shape after view: torch.Size([128, 1, 28, 28])

x_train shape after view: torch.Size([128, 1, 28, 28])



output's size:torch.Size([128, 10])
output's size:torch.Size([128, 10])


y_train's size:torch.Size([128, 10])
y_train's size:torch.Size([128, 10])

Loss 3.1206Loss 3.1206


x_train shape after view: torch.Size([128, 1, 28, 28])

x_train shape after view: torch.Size([128, 1, 28, 28])



output's size:torch.Size([128, 10])
output's size:torch.Size([128, 10])


y_train's size:torch.Size([128, 10])
y_train's size:torch.Size([128, 10])

Loss 2.2599Loss 2.2599


x_train shape after view: torch.Size([128, 1, 28, 28])

x_train shape after view: torch.Size([128, 1, 28, 28])



output's size:torch.Size([128, 10])
output's size:torch.Size([128, 10])


y_train's size:torch.Size([128, 10])
y_train's size:torch.Size([128, 10])

Loss 1.6512Loss 1.6512


x_train shape after view: torch.Size([128, 1, 28, 28])

x_train shape after view: torch.Size([128, 1, 28, 28])





## Table VI
SECURE TRAINING

NETWORK: A, BATCHSIZE: 128, EPOCH: 15

In [1]:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import tempfile

import crypten
import crypten.communicator as comm
import torch
import torch.nn as nn
import torch.nn.functional as F
from examples.util import NoopContextManager
from torchvision import datasets, transforms
import crypten.mpc as mpc
from crypten.lpgan.Secfunctions import SecAnd, SecOr, get_network_bytes, calculate_network_traffic, timeit, print_execution_times, print_communication_costs

@mpc.run_multiprocess(world_size=2)
@timeit
def run_mpc_autograd_cnn(
    context_manager=None,
    num_epochs=2,
    learning_rate=0.001,
    batch_size=128,
    print_freq=5,
    num_samples=60000,
):
    """
    Args:
        context_manager: used for setting proxy settings during download.
    """
    crypten.init()

    data_alice, data_bob, train_labels = preprocess_mnist(context_manager)
    rank = comm.get().get_rank()

    # assumes at least two parties exist
    # broadcast dummy data with same shape to remaining parties
    if rank == 0:
        x_alice = data_alice
    else:
        x_alice = torch.empty(data_alice.size())

    if rank == 1:
        x_bob = data_bob
    else:
        x_bob = torch.empty(data_bob.size())

    # encrypt
    x_alice_enc = crypten.cryptensor(x_alice, src=0)
    x_bob_enc = crypten.cryptensor(x_bob, src=1)

    # combine feature sets
    x_combined_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2)
    x_combined_enc = x_combined_enc.unsqueeze(1)

    # reduce training set to num_samples
    x_reduced = x_combined_enc[:num_samples]
    y_reduced = train_labels[:num_samples]

    # encrypt plaintext model
    model_plaintext = CNN()
    dummy_input = torch.empty((1, 1, 28, 28))
    model = crypten.nn.from_pytorch(model_plaintext, dummy_input)
    model.train()
    model.encrypt()

    # encrypted training
    train_encrypted(
        x_reduced, y_reduced, model, num_epochs, learning_rate, batch_size, print_freq
    )
    print_execution_times() # offline 的时间成本

def train_encrypted(
    x_encrypted,
    y_encrypted,
    encrypted_model,
    num_epochs,
    learning_rate,
    batch_size,
    print_freq,
):
    rank = comm.get().get_rank()
    loss = crypten.nn.CrossEntropyLoss()

    num_samples = x_encrypted.size(0)
    # label_eye = torch.eye(2)
    label_eye = torch.eye(10)


    for epoch in range(num_epochs):
        last_progress_logged = 0
        # only print from rank 0 to avoid duplicates for readability
        if rank == 0:
            print(f"Epoch {epoch} in progress:")

        for j in range(0, num_samples, batch_size):

            # define the start and end of the training mini-batch
            start, end = j, min(j + batch_size, num_samples)

            # switch on autograd for training examples
            x_train = x_encrypted[start:end]
            x_train.requires_grad = True
            y_one_hot = label_eye[y_encrypted[start:end]]
            y_train = crypten.cryptensor(y_one_hot, requires_grad=True)

            # perform forward pass:
            output = encrypted_model(x_train)

            loss_value = loss(output, y_train)

            # backprop
            encrypted_model.zero_grad()
            loss_value.backward()
            encrypted_model.update_parameters(learning_rate)

            # log progress
            if j + batch_size - last_progress_logged >= print_freq:
                last_progress_logged += print_freq
                print(f"Loss {loss_value.get_plain_text().item():.4f}")

        # compute accuracy every epoch
        pred = output.get_plain_text().argmax(1)
        correct = pred.eq(y_encrypted[start:end])
        correct_count = correct.sum(0, keepdim=True).float()
        accuracy = correct_count.mul_(100.0 / output.size(0))

        loss_plaintext = loss_value.get_plain_text().item()
        print(
            f"Epoch {epoch} completed: "
            f"Loss {loss_plaintext:.4f} Accuracy {accuracy.item():.2f}"
        )


def preprocess_mnist(context_manager):
    if context_manager is None:
        context_manager = NoopContextManager()

    mnist_train = datasets.MNIST('../data', download=True, train=True)
    mnist_test = datasets.MNIST('../data', download=True, train=False)

    # modify labels so all non-zero digits have class label 1
    mnist_train.targets[mnist_train.targets != 0] = 1
    mnist_test.targets[mnist_test.targets != 0] = 1
    mnist_train.targets[mnist_train.targets == 0] = 0
    mnist_test.targets[mnist_test.targets == 0] = 0

    # compute normalization factors
    data_all = torch.cat([mnist_train.data, mnist_test.data]).float()
    data_mean, data_std = data_all.mean(), data_all.std()
    tensor_mean, tensor_std = data_mean.unsqueeze(0), data_std.unsqueeze(0)

    # normalize data
    data_train_norm = transforms.functional.normalize(
        mnist_train.data.float(), tensor_mean, tensor_std
    )
    data_test_norm = transforms.functional.normalize(
        mnist_test.data.float(), tensor_mean, tensor_std
    )

    # partition features between Alice and Bob
    data_alice = data_train_norm[:, :, :20]
    data_bob = data_train_norm[:, :, 20:]
    train_labels = mnist_train.targets

    test_data_alice = data_test_norm[:, :, :20]
    test_data_bob = data_test_norm[:, :, 20:]
    test_labels = mnist_test.targets

    return data_alice, data_bob, train_labels

# 加载测试集
def load_test_data():
    mnist_test = datasets.MNIST('../data', download=True, train=True)

    # 正常化测试数据
    data_all = torch.cat([mnist_test.data]).float()
    data_mean, data_std = data_all.mean(), data_all.std()
    tensor_mean, tensor_std = data_mean.unsqueeze(0), data_std.unsqueeze(0)

    # 对测试数据进行标准化
    data_test_norm = transforms.functional.normalize(
        mnist_test.data.float(), tensor_mean, tensor_std
    )
    
    test_labels = mnist_test.targets
    return data_test_norm, test_labels

class CNN(nn.Module):
    """
    A: 3FC with square activation
    """
    def __init__(self):
        super(CNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        # self.fc3 = nn.Linear(64, 2)
        self.fc3 = nn.Linear(64, 10)

    def square_activation(self, x):
        return x ** 2  # Square activation function

    def forward(self, x):
        x = x.view(-1, 784)
        out = self.fc1(x)
        out = self.square_activation(out)  # Apply square activation
        out = self.fc2(out)
        out = F.softmax(out)  # Apply square activation
        out = self.fc3(out)
        return out
    
if __name__ == "__main__":
    run_mpc_autograd_cnn()

  out = F.softmax(out)  # Apply square activation
  out = F.softmax(out)  # Apply square activation
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))


Epoch 0 in progress:
Loss 2.6927Loss 2.6927

Loss 2.6995Loss 2.6995

Loss 2.7002Loss 2.7002

Loss 2.6963Loss 2.6963

Loss 2.6960Loss 2.6960

Loss 2.6950Loss 2.6950

Loss 2.6961Loss 2.6961

Loss 2.6945Loss 2.6945

Loss 2.6955Loss 2.6955

Loss 2.7007Loss 2.7007

Loss 2.6927Loss 2.6927

Loss 2.6934Loss 2.6934

Loss 2.6960Loss 2.6960

Loss 2.6929Loss 2.6929

Loss 2.6905Loss 2.6905

Loss 2.6941Loss 2.6941

Loss 2.6900Loss 2.6900

Loss 2.6929Loss 2.6929

Loss 2.6919Loss 2.6919

Loss 2.6865Loss 2.6865

Loss 2.6906Loss 2.6906

Loss 2.6896Loss 2.6896

Loss 2.6877Loss 2.6877

Loss 2.6880Loss 2.6880

Loss 2.6906Loss 2.6906

Loss 2.6901Loss 2.6901

Loss 2.6870Loss 2.6870

Loss 2.6879Loss 2.6879

Loss 2.6848Loss 2.6848

Loss 2.6861Loss 2.6861

Loss 2.6853Loss 2.6853

Loss 2.6877Loss 2.6877

Loss 2.6835Loss 2.6835

Loss 2.6806Loss 2.6806

Loss 2.6874Loss 2.6874

Loss 2.6794Loss 2.6794

Loss 2.6797Loss 2.6797

Loss 2.6829Loss 2.6829

Loss 2.6824Loss 2.6824

Loss 2.6790Loss 2.6790

Loss 2.6790Loss 2.6

In [4]:
import torch
import crypten

crypten.init() 

label = torch.Tensor([1, 2, 3, 4])
print(label)
# 输出: tensor([1., 2., 3., 4.])
label = crypten.cryptensor(label)
filled_tensor = crypten.cryptensor(torch.full(label.size(), 0))
crypten.print(f"filled_tensor_dec:{filled_tensor.get_plain_text()}")

tensor([1., 2., 3., 4.])
filled_tensor_dec:tensor([0., 0., 0., 0.])


