In [1]:
import copy
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# hyper-parameters about DNN model
input_size = 30
hidden_size = 180
num_classes = 2
# hyper-parameters about optimizer
learning_rate = 0.001
momentum = 0.9
# Hyper-parameters about training control
batch_size = 32
num_iters = 300
iters_retrain = 25
num_retrains = num_iters // iters_retrain
lambda_punish = 0.1

In [3]:
class NeuralNet(nn.Module):
    
    '''Fully connected neural network with one hidden layer
    '''
    
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size) 
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)  
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

In [4]:
class SurrogateModel(nn.Module):
    
    '''Fully connected neural network with one hidden layer
    '''
    
    def __init__(self, input_size):
        super(SurrogateModel, self).__init__()
        self.fc1 = nn.Linear(input_size, 7500) 
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(7500, 1)  
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = out.squeeze(-1) # if not squeeze, out has dim (n_samples, 1). After squeeze, out has dim (n_samples) same to the y_true
        return out

In [5]:
def get_jth_minibatach(j, batch_size, X_train, y_train):
    '''返回数据集中的第j个minibatch
       
       @param j: 第j次iters_retrain
       @param batch_size: int
       @param X_train: torch.tensor
       @param y_train: torch.tensor
    '''
    num_data = X_train.size(0)
    num_minibatches = num_data // batch_size + ((num_data % batch_size) > 0)
    j = j % num_minibatches
    start = j * batch_size
    stop = start + batch_size
    return X_train[start:stop], y_train[start:stop]

In [6]:
def get_num_weights(model):
    '''返回模型的weights参数个数
    '''
    num_weights = 0
    for key, value in model.state_dict().items():
        if key.endswith('weight'):
            num_weights += torch.prod(torch.tensor(value.size()))
    return num_weights.item()

In [7]:
def get_row_weights(model_state_dict):
    row_weights = []
    for key, value in model_state_dict.items():
        if key.endswith('weight'):
            row_weights.append(value.view(-1))
    return torch.cat(row_weights)

In [8]:
def get_APL_dataset(saved_model_state_dict, X_train):
    tmp_model = NeuralNet(input_size, hidden_size, num_classes)
    X_APL_train = torch.zeros(iters_retrain, get_num_weights(tmp_model))
    y_APL_train = torch.zeros(iters_retrain)
    for i in range(len(saved_model_state_dict)):
        tmp_model.load_state_dict(saved_model_state_dict[i])
        X_train.to(device)
        outputs = tmp_model(X_train)
        _, y_pred = torch.max(outputs.data, 1)
        tree = DecisionTreeClassifier()
        X_train.to(torch.device('cpu'))
        y_pred.to(torch.device('cpu'))
        tree.fit(X_train.numpy(), y_pred.numpy())
        decision_path_matrix = tree.decision_path(X_train.numpy())
        apl = decision_path_matrix.sum() / X_train.size(0)
        y_APL_train[i] = apl
        X_APL_train[i] = get_row_weights(saved_model_state_dict[i])
    return X_APL_train, y_APL_train

In [9]:
# dataset
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.25, random_state=2020)
X_train, X_test = torch.tensor(X_train, dtype=torch.float), torch.tensor(X_test, dtype=torch.float)
y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long)

In [50]:
# train
model = NeuralNet(input_size, hidden_size, num_classes)
# initialize model weights
# for m in model.modules():
#    if isinstance(m, (nn.Linear)):
#        nn.init.xavier_uniform_(m.weight)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

# surrogate model
surrogate_model = SurrogateModel(get_num_weights(model))
# initialize model weights
# for m in surrogate_model.modules():
#    if isinstance(m, (nn.Linear)):
#        nn.init.xavier_uniform_(m.weight)
surrogate_model.to(device)
criterion_surrogate = nn.MSELoss()
optimizer_surrogate = optim.SGD(surrogate_model.parameters(), lr=learning_rate, momentum=momentum)
for i in range(num_retrains):
    saved_model_state_dict = [] # save the model state dict in each iters_retrain
    # train DNN model
    print('Training DNN model......')
    for j in range(iters_retrain):
        trn_x, trn_y = get_jth_minibatach(j, batch_size, X_train, y_train)
        trn_x.to(device)
        trn_y.to(device)
        output = model(trn_x)
        path_length = surrogate_model(get_row_weights(model.state_dict()))
        loss = criterion(output, trn_y) + lambda_punish * path_length
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        saved_model_state_dict.append(copy.deepcopy(model.state_dict()))
        if (i*iters_retrain + j + 1) % 10 == 0:
            print('DNN iters: [{0}]/[{1}] loss: {2:.2f} APL: {3:.2f}'.format((i*iters_retrain + j + 1), num_iters, loss, path_length))
    # train Decision Tree to get {weights, APL} dataset
    X_APL_train, y_APL_train = get_APL_dataset(saved_model_state_dict, X_train)
    print('Train surrogate model......')
    # train surrogate model
    for j in range(50):
        trn_x, trn_y = get_jth_minibatach(j, batch_size, X_APL_train, y_APL_train)
        trn_x.to(device)
        trn_y.to(device)
        output = surrogate_model(trn_x)
        loss = criterion_surrogate(output, trn_y)
        optimizer_surrogate.zero_grad()
        loss.backward()
        optimizer_surrogate.step()
        if (j+1) % 10 == 0:
            print('Surrogate iters: [{0}]/[50] loss: {1:.2f}'.format(j+1, loss))

Training DNN model......
DNN iters: [10]/[300] loss: 0.58 APL: 0.03
DNN iters: [20]/[300] loss: 0.54 APL: 0.03
Train surrogate model......
Surrogate iters: [10]/[50] loss: 8.41
Surrogate iters: [20]/[50] loss: 4.95
Surrogate iters: [30]/[50] loss: 3.38
Surrogate iters: [40]/[50] loss: 3.44
Surrogate iters: [50]/[50] loss: 3.05
Training DNN model......
DNN iters: [30]/[300] loss: 1.39 APL: 7.22
DNN iters: [40]/[300] loss: 1.33 APL: 7.40
DNN iters: [50]/[300] loss: 1.29 APL: 7.47
Train surrogate model......
Surrogate iters: [10]/[50] loss: 3.27
Surrogate iters: [20]/[50] loss: 2.90
Surrogate iters: [30]/[50] loss: 2.36
Surrogate iters: [40]/[50] loss: 1.59
Surrogate iters: [50]/[50] loss: 1.63
Training DNN model......
DNN iters: [60]/[300] loss: 0.73 APL: 1.81
DNN iters: [70]/[300] loss: 0.89 APL: 1.81
Train surrogate model......
Surrogate iters: [10]/[50] loss: 3.81
Surrogate iters: [20]/[50] loss: 3.69
Surrogate iters: [30]/[50] loss: 3.61
Surrogate iters: [40]/[50] loss: 3.59
Surrogat

KeyboardInterrupt: 

In [51]:
# test
with torch.no_grad():
    correct = 0
    total = 0
    X_test.to(device)
    y_test.to(device)
    outputs = model(X_test)
    _, predicted = torch.max(outputs.data, 1)
    total += y_test.size(0)
    correct += (predicted == y_test).sum().item()

    print('Accuracy of the network on the Breast Cancer dataset: {0} %'.format(100 * correct / total))

Accuracy of the network on the Breast Cancer dataset: 57.34265734265734 %


In [103]:
torch.save(model.state_dict(), 'dnn_model.pth')