In [58]:
import torch
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.datasets import MNIST, CIFAR10 
import torchvision.transforms as transforms
from tqdm import tqdm

In [59]:
torch.manual_seed(0)

<torch._C.Generator at 0x1115a3d30>

In [60]:
queries = list(range(2000,10001,2000))
print(queries)

[2000, 4000, 6000, 8000, 10000]


In [61]:
pretrained_model = resnet50(weights=ResNet50_Weights.DEFAULT)

knockoff_model = resnet50(weights=None)

In [62]:
transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

mnist_train = MNIST('./data', train=True, transform=transform, download=True)
mnist_test = MNIST('./data', train=False, transform=transform, download=True)
cifar_train= CIFAR10('./data', train=True, transform=transform, download=True)
cifar_test= CIFAR10('./data', train=False, transform=transform, download=True)


Files already downloaded and verified
Files already downloaded and verified


In [63]:
mnist_train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=4, shuffle=True)
mnist_test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=4, shuffle=True)
cifar_train_loader = torch.utils.data.DataLoader(cifar_train, batch_size=4, shuffle=True)
cifar_test_loader = torch.utils.data.DataLoader(cifar_test, batch_size=4, shuffle=True)

In [64]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(knockoff_model.parameters())

In [65]:
def training_knockoff(pretrained_model, knockoff_model, num_queries, epoch, data_loader):
    for i in range(1, epoch+1):
        print("epoch", i)
        for _ in tqdm(range(num_queries)):
            # Every data instance is an input + label pair
            inputs, _ = next(iter(data_loader))

            # Zero your gradients for every batch!
            optimizer.zero_grad()
            labels = pretrained_model(inputs)

            # Make predictions for this batch
            outputs = knockoff_model(inputs)

            # Compute the loss and its gradients
            loss = loss_fn(outputs, labels)
            loss.backward()

            # Adjust learning weights
            optimizer.step()
    return loss


In [66]:
def testing_knockoff(pretrained_model, knockoff_model, data_loader):
    correct = 0
    count = 0
    for data in data_loader:
        inputs, _ = data
        pretrained_output = pretrained_model(inputs)
        knockoff_output = knockoff_model(inputs)
        if pretrained_output == knockoff_output:
            correct += 1
        count += 1
    return correct/count

In [68]:
accuracy = []

for num_query in queries:
    print("Queries:", num_query)
    training_knockoff(pretrained_model, knockoff_model, num_query, 10, cifar_train_loader)
    acc = testing_knockoff(pretrained_model, knockoff_model, cifar_test_loader)
    print(acc)
    accuracy.append(acc)
    # Reset knockoff model to new resnet
    knockoff_model = resnet50(weights=None)


Queries: 2000
epoch 1


100%|██████████| 2000/2000 [11:55<00:00,  2.79it/s]


epoch 2


 35%|███▌      | 706/2000 [04:09<07:37,  2.83it/s]


KeyboardInterrupt: 