In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import os
import random

In [2]:
#Target model parameters
target_model_batch_size_train = 64
target_model_batch_size_test = 1000
target_model_learning_rate = 0.001
target_model_epochs = 3
target_model_log_interval = 100

In [4]:
#Load digit data
train_set = torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

indices_01 = (train_set.targets == 0) | (train_set.targets == 1)
train_set.data, train_set.targets = train_set.data[indices_01], train_set.targets[indices_01]

train_set2 = torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))
indices_n01 = (train_set2.targets != 0) & (train_set2.targets != 1)
train_set2.data, train_set2.targets = train_set2.data[indices_n01], train_set2.targets[indices_n01]

train_loader = torch.utils.data.DataLoader(train_set,batch_size=target_model_batch_size_train, shuffle=True)

test_set = torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

indices_n01 = (test_set.targets == 0) | (test_set.targets == 1)
test_set.data, test_set.targets = test_set.data[indices_n01], test_set.targets[indices_n01]

test_loader = torch.utils.data.DataLoader(test_set,batch_size=target_model_batch_size_test, shuffle=True)



In [5]:
#Create synthetic data set probs
synth_probs = torch.zeros(784)
for t,l in train_set:
  synth_probs += torch.flatten(t,0,2)
synth_probs /= len(train_set)

In [6]:
#Target model definition
class TargetCNN(nn.Module):
    def __init__(self):
        super(TargetCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 2)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

In [7]:
#Create the target model and optimizer
target_model = TargetCNN()
target_model_optimizer = optim.Adam(target_model.parameters(), lr=target_model_learning_rate)

In [8]:
#Load target model if it exists
if os.path.exists("target_model.pt"):
    checkpoint = torch.load("target_model.pt")
    target_model.load_state_dict(checkpoint['state_dict'])
    target_model_optimizer.load_state_dict(checkpoint['optimizer'])

In [9]:
#Train the target model
target_model.train()
total_loss = 0
num_examples = 0
for e in range(target_model_epochs):
  for b, (data, target) in enumerate(train_loader):
    target_model_optimizer.zero_grad()
    output = target_model(data)
    loss = F.nll_loss(output, target)
    total_loss += loss.item()
    num_examples += len(data)
    loss.backward()
    target_model_optimizer.step()
    if b % target_model_log_interval == 0:
      print("Epoch {}/{} {:.0f}% with avg loss {:.6f}".format(e+1,target_model_epochs,b/len(train_loader)*100,total_loss/num_examples))
      checkpoint = {   
                'state_dict': target_model.state_dict(),
                'optimizer': target_model_optimizer.state_dict(),
      }
      torch.save(checkpoint, "target_model.pt")

print("Finished training with avg loss {:.6f}".format(total_loss/num_examples))
checkpoint = {   
          'state_dict': target_model.state_dict(),
          'optimizer': target_model_optimizer.state_dict(),
}
torch.save(checkpoint, "target_model.pt")



Epoch 1/3 0% with avg loss 0.012038
Epoch 1/3 51% with avg loss 0.001221
Epoch 2/3 0% with avg loss 0.000697
Epoch 2/3 51% with avg loss 0.000486
Epoch 3/3 0% with avg loss 0.000389
Epoch 3/3 51% with avg loss 0.000322
Finished training with avg loss 0.000281


In [11]:
#Test the target model
target_model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
  for data, target in test_loader:
    output = target_model(data)
    test_loss += F.nll_loss(output, target, size_average=False).item()
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  print(test_loss)



0.0011639842523556933


In [12]:
#Adaptive training parameters
adaptive_retraining_model_learning_rate = 0.01
adaptive_retraining_num_rounds = 64
adaptive_retraining_queries_per_round = 16
adaptive_retraining_uncertainty_threshold = 0.25

In [13]:
#Adaptive retraining model defintion
class AdaptiveRetrainingModel(nn.Module):
  def __init__(self):
      super(AdaptiveRetrainingModel, self).__init__()
      self.fc1 = nn.Linear(784,2)

  def forward(self, x):
      x = x.flatten(1,3)
      x = F.log_softmax(self.fc1(x))
      return x

In [14]:
#Create the adaptive retraining model
adaptive_retraining_model = AdaptiveRetrainingModel()
adaptive_retraining_model_optimizer = optim.Adam(adaptive_retraining_model.parameters(), lr=adaptive_retraining_model_learning_rate)

In [15]:
#Adaptive retraining query genereration defintion
def AdaptiveRetraining():
  adaptive_retraining_model.train()

  #Store sequence of queries
  queries = []

  #Train the copy model on uniform random points
  first_query_indecies = random.sample(range(0, len(train_set)-1), adaptive_retraining_queries_per_round)
  for q_idx in first_query_indecies:
    #Get query
    q, a = train_set[q_idx]
    q = q.unsqueeze(0)
    #Get output of target model
    r = target_model(q)
    #Train the adaptive model
    v = adaptive_retraining_model(q)
    loss = F.cross_entropy(v, r)
    loss.backward()
    target_model_optimizer.step()
    #Record the query
    q = torch.flatten(q.squeeze(0),-2,-1)
    queries.append(q)

  #Train the copy model on new points along the decision boundary for several rounds
  for r_num in range(adaptive_retraining_num_rounds-1):
    for q_num in range(adaptive_retraining_queries_per_round):

      #Find a query that the adaptive retraining model is uncertain about
      q_idx = random.randint(0, len(train_set)-1)
      num_attempts = 0
      while(True):
        #Get query
        q, a = train_set[q_idx]
        q = q.unsqueeze(0)
        num_attempts += 1
        #Get output of adaptive model
        v = adaptive_retraining_model(q)
        #Compute certainty heuristic
        c = torch.max(v)-torch.mean(v)
        #Keep query if unceratin or keep looking
        if(c < adaptive_retraining_uncertainty_threshold):
          break
        elif(num_attempts >= 50):
          break
        else:
          q_idx = random.randint(0, len(train_set)-1)

      #Get output of target model
      q, a = train_set[q_idx]
      q = q.unsqueeze(1)
      r = target_model(q)
      #Train the adaptive model
      v = adaptive_retraining_model(q)
      loss = F.cross_entropy(v, r)
      loss.backward()
      target_model_optimizer.step()
      #Record the query
      q = torch.flatten(q.squeeze(0),-2,-1)
      queries.append(q)
      
  
  #Record the sequence of queries
  queries = torch.stack(queries)
  return torch.transpose(queries,0,1)

In [16]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
#Benign queries parameters

In [17]:
#Benign queries definition
def BenignQueries(num_benign_queries):
  num_pd = num_benign_queries//2
  queries = []
  query_indecies = random.sample(range(0, len(train_set)), num_benign_queries-num_pd)
  for q_idx in query_indecies:
    q, a = train_set[q_idx]
    q = torch.flatten(q,-2,-1).squeeze(0)
    queries.append(q)
  query_indecies = random.sample(range(0, len(train_set)), num_pd)
  for q_idx in query_indecies:
    q, a = train_set2[q_idx]
    q = torch.flatten(q,-2,-1).squeeze(0)
    queries.append(q)
  queries = torch.stack(queries)
  return queries

In [18]:
#Synthetic queries definition
def SynthQueries(num_queries):
  queries = []
  for q_idx in range(num_queries):
    q = torch.clamp(synth_probs + torch.rand(784), min=0, max=1)
    queries.append(q)
  queries = torch.stack(queries)
  return queries

In [19]:
#Detection model definition
class DetectionModel(nn.Module):
  def __init__(self, in_dim, lstm_hidden_dim, mlp_hidden_dim, target_parameter_dim,out_dim):
    super(DetectionModel, self).__init__()
    # LSTM
    self.lstm = nn.LSTM(in_dim, lstm_hidden_dim, 1, batch_first = True)
    
    # MLP
    self.mlp1 = nn.Linear(lstm_hidden_dim,mlp_hidden_dim)
    self.mlp2 = nn.Linear(target_parameter_dim,mlp_hidden_dim)
    self.mlp3 = nn.Linear(mlp_hidden_dim*2,out_dim)
    
    # activation functions
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(0.5)

  def forward(self, x, p): #num_queries

    # compute encoding
    o, _ = self.lstm(x)

    # keep last layer
    o = o[:,-1]

    # pass through mlp
    o = self.relu(self.dropout(self.mlp1(o)))
    p = self.mlp2(p).unsqueeze(0)

    o = torch.cat([o,p],-1)
    o = self.mlp3(o)
    return o
      

In [20]:
ps = []
for param in target_model.parameters():
  ps.append(torch.flatten(param))
ps = torch.cat(ps).to(device)
print(ps.shape)

torch.Size([21432])


In [21]:
detection_model_learning_rate = 0.00000000001

In [23]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [24]:
#Create the detection model
detection_model = DetectionModel(784,2048,1024,21432,2).to(device)


In [25]:
detection_model_optimizer = optim.Adam(detection_model.parameters(), lr=detection_model_learning_rate)

In [26]:
#Load the detection model if it exists
if os.path.exists("detection_model.pt"):
    checkpoint = torch.load("detection_model.pt")
    detection_model.load_state_dict(checkpoint['state_dict'])
    detection_model_optimizer.load_state_dict(checkpoint['optimizer'])

In [27]:
#Train the detection model to distinguish between adaptive retraining and benign queries
for e in range(15):
  total_loss = 0
  #Train on benign
  detection_model_optimizer.zero_grad()
  Q = BenignQueries(1024).to(device).unsqueeze(0)
  A = torch.tensor([0]).to(device)
  O = detection_model(Q, ps)
  loss = F.cross_entropy(O, A)
  total_loss += loss.item()
  loss.backward()
  detection_model_optimizer.step()
  
  #Train on adaptive retraining
  detection_model_optimizer.zero_grad()
  Q = AdaptiveRetraining().to(device)
  A = torch.tensor([1]).to(device)
  O = detection_model(Q, ps)
  loss = F.cross_entropy(O, A)
  total_loss += loss.item()
  loss.backward()
  detection_model_optimizer.step()
  
  print(total_loss)
  checkpoint = {   
                'state_dict': detection_model.state_dict(),
                'optimizer': detection_model_optimizer.state_dict(),
      }
  torch.save(checkpoint, "detection_model.pt")


  if __name__ == '__main__':


1.3663349151611328
1.4260262846946716
1.3735675811767578
1.3972859382629395
1.3676543235778809
1.3922427296638489
1.3440635204315186
1.3561939001083374
1.3583173155784607
1.3806659579277039
1.3916305899620056
1.3516534566879272
1.3552793860435486
1.369852602481842
1.387883186340332


In [28]:
#Evaluate the detection model to distinguish between adaptive retraining and benign queries
total_accuracy = 0
false_positves = 0
false_negatives = 0
num_tests = 1000
for e in range(num_tests):

  #Pick random query tpye
  r = random.randint(0, 2)
  Q,A = 0,0
  if r == 0:
    Q = BenignQueries(1024).to(device).unsqueeze(0)
    A = torch.tensor([0]).to(device)
  elif r == 1:
    Q = AdaptiveRetraining().to(device)
    A = torch.tensor([1]).to(device)
  else:
    Q = SynthQueries(1024).to(device).unsqueeze(0)
    A = torch.tensor([1]).to(device)

  #Compute the output
  O = detection_model(Q, ps).squeeze(0)

  #Record accuracy
  r = 0
  if O[1] > O[0]:
    r = 1
    if A == 0:
      false_positves += 1
  else:
    if A == 1:
      false_negatives += 1
  if A == r:
    total_accuracy += 1

total_accuracy /= num_tests
total_accuracy *= 100
print(total_accuracy)

false_negatives /= num_tests
false_negatives *= 100
print(false_negatives)

false_positves /= num_tests
false_positves *= 100
print(false_positves)

 


  if __name__ == '__main__':


38.5
57.8
3.6999999999999997
