In [None]:
import torch
import torchvision.datasets as ds
import torchvision.transforms as transforms

In [None]:
# MNIST training split, stored under .data (download=True lets torchvision
# fetch it when missing). ToTensor maps each PIL image to a (C, H, W) float
# tensor with values in [0.0, 1.0].
train_set = ds.MNIST(
    ".data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [None]:
# Turn the task into binary classification by keeping only digits 0 and 1;
# X_ collects the image tensors and y_ the matching labels.
X_, y_ = zip(*(example for example in train_set if example[1] < 2))

In [None]:
import torchvision.utils as u
import matplotlib.pyplot as plt
import numpy as np

# Show the first 100 training examples, 10 images per row.
# make_grid returns a (C, H, W) tensor while imshow expects (H, W, C),
# hence the permute before converting to NumPy.
plt.imshow(u.make_grid(list(X_[:100]), nrow=10).permute(1, 2, 0).numpy())
plt.axis("off")
plt.title("First 100 training examples (digits 0 and 1)");

In [None]:
# MNIST images are 28x28 pixels; every image is flattened to a vector of length n.
n = 28 * 28

# Design matrix of shape (num_examples, n) and a (num_examples, 1) float column
# of 0/1 targets, matching the (num_examples, 1) output of the model for BCELoss.
X = torch.stack(X_).reshape(-1, n)
y = torch.tensor(y_, dtype=torch.float32).unsqueeze(1)

In [None]:
# tqdm.auto picks the notebook widget when available (tqdm_notebook is deprecated).
from tqdm.auto import tqdm

# Seed the global torch RNG so weight initialisation is reproducible
# (this also fixes the random draws made by later cells on Restart & Run All).
torch.manual_seed(0)

# Logistic regression: a linear layer producing one logit per example; the
# sigmoid applied in the loop below turns it into P(label == 1).
model = torch.nn.Linear(n, 1, bias=True)

# Binary cross-entropy on probabilities (expects sigmoid outputs in [0, 1]).
loss = torch.nn.BCELoss()

# torch.optim.SGD, used here as full-batch gradient descent: every step sees all of X.
opt = torch.optim.SGD(model.parameters(), lr=0.01)

costs = []
for i in tqdm(range(1000)):
    # Predicted probabilities for every training example.
    pred_y = torch.sigmoid(model(X))
    l = loss(pred_y, y)
    # Store a plain float: appending the tensor itself would keep each
    # iteration's autograd graph alive, and tensors that require grad cannot
    # be converted to NumPy when the learning curve is plotted.
    costs.append(l.item())
    # Compute the gradient and update the parameters.
    opt.zero_grad()
    l.backward()
    opt.step()

In [None]:
# Learning curve: training loss per iteration. The first 100 iterations are
# skipped so the steep initial drop does not dominate the y-axis; the x-axis
# keeps the true iteration numbers. float() accepts both plain numbers and
# 0-d tensors, so this works whatever the training cell stored in `costs`.
plt.plot(range(100, len(costs)), [float(c) for c in costs[100:]])
plt.xlabel("iteration")
plt.ylabel("BCE loss")
plt.title("Training loss (from iteration 100)");

# Test the model

In [None]:
# MNIST test split, loaded with the same ToTensor transform as the training data.
test_set = ds.MNIST(
    ".data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Keep digits 0 and 1 only and build tensors shaped like X (flattened) and y (float column).
X_test_, y_test_ = zip(*(example for example in test_set if example[1] < 2))
X_test = torch.stack(X_test_).reshape(-1, n)
y_test = torch.tensor(y_test_, dtype=torch.float32).unsqueeze(1)

In [None]:
# Inference only: no_grad avoids building an autograd graph for the test pass.
with torch.no_grad():
    # Probability that each test example is the digit 1.
    pred_y = torch.sigmoid(model(X_test))
# Round the probabilities (in [0, 1]) to class labels {0, 1}.
labels = torch.round(pred_y)

In [None]:
# Test accuracy: number of correct predictions divided by the number of test examples.
(labels == y_test).sum().item() / len(y_test)

# Steal model parameters

In [None]:
# The model is affine in its input: logit = w . q + b, i.e. 28*28 weights plus
# one bias = n + 1 unknowns. n + 1 random queries give (with probability 1)
# n + 1 independent linear equations, enough to solve for all of them.
k = n + 1

# k random query "images" with pixel values drawn uniformly from [0, 1).
queries = torch.rand((k, n))

# Query the model for its raw (pre-sigmoid) outputs, i.e. the logits.
# NOTE: the attack assumes logits are observable; if only probabilities p were
# exposed, they would first need converting back via log(p / (1 - p)).
# no_grad: only the values are needed, not an autograd graph.
with torch.no_grad():
    output = model(queries)

In [None]:
# Append a column of ones so the bias becomes an ordinary unknown:
# [queries | 1] @ [w; b] = output. Shape (k, n) -> (k, n + 1).
q = torch.cat((queries, torch.ones((k, 1))), 1)

# Coefficient matrix in float64: solving in double precision avoids adding
# float32 round-off on top of the float32 error already present in the outputs.
# detach() replaces the legacy .data accessor.
a = q.detach().double().numpy()

# Right-hand side: the model's (k, 1) logits as a float64 vector of length k.
b = output.detach().squeeze(1).double().numpy()

In [None]:
# Solve a @ x = b. Following the column order of a, x[:n] are the recovered
# weights (one per pixel) and x[n] is the recovered bias.
x = np.linalg.solve(a=a, b=b)

In [None]:
# Show the first 20 recovered weights (the recovered bias is x[n]).
x[0:20]

In [None]:
# Show the model's first 20 true weights for comparison with the recovered ones.
# detach() (rather than the legacy .data) yields a grad-free tensor that .numpy() accepts.
model.weight.detach().squeeze().numpy()[:20]