In [1]:
class Arguments:
    """Run configuration for the ResNet-18 inference notebook.

    Later cells read these settings as ``args.<name>``.
    """

    # When True, the model and each input batch are passed through
    # ``.fix_precision(...)`` and ``.share(...)`` before inference.
    encrypted = True

    # Fixed-precision settings forwarded to every ``.fix_precision()`` call.
    dtype = 'long'
    precision_fractional = 4

    # Protocol forwarded to ``.share()`` through the shared ``kwargs`` dict.
    protocol = "fss"


args = Arguments()

# ResNet-18: inference on the PPPP test split, optionally with a secret-shared (encrypted) model and inputs

In [2]:
import os
import time

import torch
from torchvision import transforms

import syft as sy

# Hook PyTorch so its tensors and modules gain the PySyft methods used below
# (.fix_precision, .share, .get, .float_prec, ...).
hook = sy.TorchHook(torch)

# In-process virtual workers, created in the order bob, alice, crypto_provider:
# two share holders plus the crypto provider.
bob, alice, crypto_provider = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("bob", "alice", "crypto_provider")
)

# Share holders, passed positionally to every `.share(*workers, ...)` call.
workers = [alice, bob]
sy.local_worker.clients = workers

# Keyword arguments reused by every `.share()` call in the notebook.
kwargs = {
    "crypto_provider": crypto_provider,
    "protocol": args.protocol,
    "requires_grad": False,
}
snapshots = []

In [4]:
from torchlib.models import resnet18
from torchlib.dataloader import PPPP
from sklearn.metrics import classification_report

# Build ResNet-18 with the configuration the checkpoint was saved with.
# NOTE(review): pretrained=True may fetch weights that load_state_dict then
# overwrites; kept as-is because torchlib's resnet18 semantics are not visible here.
model = resnet18(
    pretrained=True,
    num_classes=3,
    in_channels=3,
    pooling="max",
    adptpool=False,
    input_size=224,
)

# Checkpoint location. Set RESNET_CHECKPOINT to point elsewhere instead of
# editing this developer-specific absolute path.
checkpoint_path = os.environ.get(
    "RESNET_CHECKPOINT", "/Users/tryffel/code/4P/resnet_maxpool.pt"
)
# torch.load unpickles the file, so only load checkpoints from a trusted source.
state = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state["model_state_dict"])

# Swap the ReLU and pooling modules. Max pooling commutes with ReLU
# (relu(maxpool(x)) == maxpool(relu(x)) because ReLU is monotonic), so outputs
# are unchanged, but ReLU then runs on the pooled, smaller tensor. Presumably
# this is done to reduce the cost of private comparisons under secret sharing.
# NOTE(review): assumes forward() calls self.relu and then self.pool in the stem;
# confirm in torchlib.models.resnet.
model.pool, model.relu = model.relu, model.pool

dataloader = torch.utils.data.DataLoader(
    PPPP(
        train=False,
        transform=transforms.Compose(
            [
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                # Single-channel mean/std (presumably dataset statistics; verify).
                transforms.Normalize((0.57282609,), (0.17427578,)),
                # Replicate the single channel to the 3 input channels of the model.
                transforms.Lambda(lambda x: torch.repeat_interleave(x, 3, dim=0)),
            ]
        ),
    ),
    # Must stay 1: the loop below calls .item() on each prediction and target.
    batch_size=1,
)

model.eval()
if args.encrypted:
    # The return value is discarded, so these module-level calls are relied on
    # to encode and share the model's parameters in place.
    model.fix_precision(
        precision_fractional=args.precision_fractional, dtype=args.dtype
    ).share(*workers, **kwargs)

predictions, targets = [], []
n_correct = 0
with torch.no_grad():
    for data, target in dataloader:
        if args.encrypted:
            # Timed span covers encoding/sharing the input, the private forward
            # pass, and reconstructing the output.
            t_start = time.perf_counter()
            data = data.fix_precision(
                precision_fractional=args.precision_fractional, dtype=args.dtype
            ).share(*workers, **kwargs)
        prediction = model(data)
        if args.encrypted:
            # Reconstruct the plaintext logits and decode them from fixed precision.
            prediction = prediction.get().float_prec()
            print(time.perf_counter() - t_start)
        print(prediction)
        prediction = prediction.argmax(dim=1)
        predictions.append(prediction.item())
        targets.append(target.item())
        # Running accuracy in percent. A counter avoids re-comparing every past
        # prediction on each iteration, and it yields the same value.
        n_correct += predictions[-1] == targets[-1]
        print(round(n_correct / len(predictions) * 100))


136.06766986846924
tensor([[-4.9330,  3.2527,  2.2385]])
100
137.87865114212036
tensor([[-4.5752,  4.6929,  0.4374]])
100
148.39177703857422
tensor([[ 8.4596, -5.5816, -1.2273]])
100
