In [None]:
import numpy as np
import copy
from scipy.optimize import linear_sum_assignment

from keras.models import model_from_json
import keras
import pickle

## SaSD to compare (Convolutional) Neural Networks

In [None]:
def getWeightsArray(mod):
    """Collect the first weight array of every Dense and Conv2D layer.

    Parameters
    ----------
    mod : keras model
        Model whose ``layers`` are scanned in order.

    Returns
    -------
    list of np.ndarray
        One array per Dense/Conv2D layer, in layer order, holding the first
        entry of ``layer.get_weights()``; all other layer types are skipped.
    """
    # Use the public ``keras.layers`` names: the private module paths
    # (``keras.layers.core`` / ``keras.layers.convolutional``) are not
    # available in every Keras release, while the public names refer to the
    # same classes.
    weighted_layer_types = (keras.layers.Dense, keras.layers.Conv2D)
    return [np.array(layer.get_weights()[0])
            for layer in mod.layers
            if isinstance(layer, weighted_layer_types)]

In [None]:
def editDistanceSigns(a, b):
    """Return the fraction of compared positions whose weight signs differ.

    Both arguments hold the actual weights (not just connection indices).
    Positions where both vectors are zero are ignored; every remaining
    position counts as a match only if both entries are strictly positive or
    both are strictly negative.

    Parameters
    ----------
    a, b : array-like
        Weight vectors of identical shape.

    Returns
    -------
    float or int
        ``(compared - matching) / compared``, or ``0`` when no position is
        left to compare (e.g. both vectors are entirely zero).

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        raise ValueError(
            "editDistanceSigns needs inputs of equal shape, got %s and %s"
            % (a.shape, b.shape))

    # Boolean masks count the co-occurrences directly instead of building and
    # sort-intersecting index arrays.
    both_zero = np.count_nonzero((a == 0) & (b == 0))
    nums = a.size - both_zero
    if nums == 0:
        return 0

    same = (np.count_nonzero((a > 0) & (b > 0))
            + np.count_nonzero((a < 0) & (b < 0)))
    return (nums - same) / nums

In [None]:
def _signMasks(units):
    """Return 0/1 float indicator matrices (positive, negative, zero) of ``units``."""
    return ((units > 0).astype(np.float64),
            (units < 0).astype(np.float64),
            (units == 0).astype(np.float64))


def _signDistanceMatrix(units1, units2):
    """Sign edit distance between every row of ``units1`` and every row of ``units2``.

    Entry ``[i, j]`` is the value ``editDistanceSigns(units1[i], units2[j])``
    would give: positions where both rows are zero are ignored, and the result
    is the share of the remaining positions whose signs differ (0 if nothing
    remains). Dot products of the 0/1 indicator matrices count the
    co-occurrences for all pairs at once instead of looping over pairs.
    """
    pos1, neg1, zero1 = _signMasks(units1)
    pos2, neg2, zero2 = _signMasks(units2)

    both_zero = np.dot(zero1, zero2.T)
    same_sign = np.dot(pos1, pos2.T) + np.dot(neg1, neg2.T)
    compared = units1.shape[1] - both_zero

    distances = np.zeros_like(compared)
    # Pairs with nothing to compare keep distance 0 (avoids dividing by zero).
    np.divide(compared - same_sign, compared, out=distances, where=compared > 0)
    return distances


def compareLayers(layer1, layer2, num_outputs=10):
    """Compare two weight arrays of identical shape unit by unit.

    The last axis indexes the units (kernels of a conv layer, neurons of a
    dense layer); all other axes are flattened into one weight vector per
    unit. Units of ``layer2`` are matched to units of ``layer1`` so that the
    summed sign edit distance is minimal (linear sum assignment).

    Parameters
    ----------
    layer1, layer2 : np.ndarray
        Weight arrays of the same shape: 4-D ``(height, width, channels,
        kernels)`` for conv layers, otherwise ``(inputs, neurons)``.
    num_outputs : int, optional
        A non-conv layer with exactly this many neurons is treated as the
        output layer: its neurons are compared in place, without reordering.
        Defaults to 10.

    Returns
    -------
    tuple
        ``(mean distance per unit, indices into layer1, indices into layer2)``
        describing the matching; for the output layer both index sequences are
        ``range(num_outputs)``.
    """
    assert layer1.shape == layer2.shape

    # k is nr of kernels/neurons
    k = layer1.shape[-1]

    # One row per unit; for conv layers this equals flattening
    # layer[:, :, :, kernel], for dense layers it is the column layer[:, j].
    units1 = layer1.reshape(-1, k).T
    units2 = layer2.reshape(-1, k).T

    editMatrix = _signDistanceMatrix(units1, units2)

    # Output layer: the order of the output neurons is fixed, so only the
    # distances between equally indexed neurons count.
    if layer1.ndim != 4 and layer1.shape[1] == num_outputs:
        summed_dist = sum(np.diag(editMatrix).tolist())
        return summed_dist / num_outputs, range(num_outputs), range(num_outputs)

    row_ind, col_ind = linear_sum_assignment(editMatrix)
    minCost = editMatrix[row_ind, col_ind].sum()

    return minCost / k, row_ind, col_ind

In [None]:
def compareModels(mod1, mod2):
    """Compute the layer-wise sign edit distance between two models.

    The conv/dense weights of both models are compared layer by layer. After
    each layer, the matching found by ``compareLayers`` is used to reorder the
    inputs of the next layer of ``mod2`` (channels of a conv layer, rows of a
    dense layer) so that they line up with ``mod1`` before that next layer is
    compared.

    Parameters
    ----------
    mod1, mod2 : keras model
        Models with the same number of conv/dense layers and equal weight
        shapes per layer.

    Returns
    -------
    np.ndarray
        One distance per compared layer, in layer order.

    Notes
    -----
    NOTE(review): for the first dense layer after the conv part, the rows are
    assumed to be grouped in contiguous blocks, one block per kernel of the
    last conv layer. Whether this matches the flattening order of the model
    (data format) is not visible here -- confirm.
    """
    # get array of weights of conv and dense layers
    weightsNN1 = getWeightsArray(mod1)
    weightsNN2 = getWeightsArray(mod2)

    numLayers = len(weightsNN1)
    assert len(weightsNN2) == numLayers

    editDistance = np.zeros(numLayers)

    # first layer: compare as is and remember the matched order of NN2 units
    layerNN1 = weightsNN1[0]
    layerNN2 = weightsNN2[0]
    editDistance[0], _, hid_layerNN2 = compareLayers(layerNN1, layerNN2)

    # True while we are in the conv part that started with the first layer;
    # the first dense layer reached afterwards needs the block-wise reordering
    firstDenseAfterConv = layerNN1.ndim == 4
    # number of kernels of the most recent conv layer; also recorded for the
    # first layer so a single conv layer followed by a dense layer does not
    # end up dividing by zero below
    lastConvLen = layerNN1.shape[-1] if firstDenseAfterConv else 0

    for k in range(1, numLayers):
        layerNN1 = weightsNN1[k]
        originalNN2 = weightsNN2[k]
        # reordered copy of NN2's layer; the model's own weights stay untouched
        layerNN2 = originalNN2.copy()

        # nr 1: conv layer -> reorder its input channels
        if layerNN1.ndim == 4:
            for j in range(originalNN2.shape[-2]):
                layerNN2[:, :, j, :] = originalNN2[:, :, hid_layerNN2[j], :]
            # save number of kernels in case it is the last conv layer
            lastConvLen = originalNN2.shape[-1]

        # nr 2: first dense layer after the conv part -> reorder whole blocks
        elif firstDenseAfterConv:
            block_size = layerNN2.shape[0] // lastConvLen
            for i in range(lastConvLen):
                dst = i * block_size
                src = hid_layerNN2[i] * block_size
                # slice ends are exclusive, so the block spans
                # [start, start + block_size) -- including its last row
                layerNN2[dst:dst + block_size, :] = originalNN2[src:src + block_size, :]
            firstDenseAfterConv = False

        # nr 3: dense layer after dense layer -> reorder its input rows
        else:
            for j in range(originalNN2.shape[0]):
                layerNN2[j, :] = originalNN2[hid_layerNN2[j], :]

        editDistance[k], _, hid_layerNN2 = compareLayers(layerNN1, layerNN2)

    return editDistance

## Extract saved Models

In [None]:
def getModelFromFile(json_file, h5_file):
    """Rebuild a model from its JSON architecture file and load its weights.

    Parameters
    ----------
    json_file : str
        Path of the file containing the model structure as JSON.
    h5_file : str
        Path of the weights file passed to ``model.load_weights``.

    Returns
    -------
    The model created by ``model_from_json`` with the weights loaded.
    """
    # the context manager closes the file even if reading fails
    with open(json_file, "r") as structure_file:
        loaded_json = structure_file.read()
    model = model_from_json(loaded_json)

    # load weights in model
    model.load_weights(h5_file)
    return model

In [None]:
def is_WT(his_WT, his_orig):
    """Decide from two training histories whether a ticket counts as a winning ticket.

    Both arguments are mappings with a ``"val_loss"`` sequence. The ticket
    qualifies when its lowest validation loss is reached no later than in the
    original history and is below 1.02 times the original's lowest
    validation loss.
    """
    ticket_losses = his_WT["val_loss"]
    reference_losses = his_orig["val_loss"]

    reaches_min_in_time = np.argmin(ticket_losses) <= np.argmin(reference_losses)
    loss_close_enough = np.min(ticket_losses) < 1.02*np.min(reference_losses)
    return reaches_min_in_time and loss_close_enough

In [None]:
def _loadHistory(path):
    """Unpickle one training history; the file is closed again afterwards."""
    # NOTE: pickle can execute arbitrary code while loading -- only use it on
    # history files you created yourself.
    with open(path, "rb") as history_file:
        return pickle.load(history_file)


def _collectTickets(folder, kind, keep_WTs):
    """Load the candidate tickets nr 0..19 of one folder that pass the WT filter.

    folder:   directory with the histories and weights, e.g. '../tickets/WTs_CIFAR'
    kind:     file name stem of the ticket ('WT' or 'random')
    keep_WTs: True keeps the tickets for which is_WT holds, False keeps the others
    """
    kept = []
    for nr in range(0, 20):
        # extract history
        his_orig = _loadHistory(folder + '/his_orig_s0.1_nr' + str(nr))
        his_ticket = _loadHistory(folder + '/his_' + kind + '_s0.1_nr' + str(nr))
        # check if it is a WT (min epoch same or equal, min val_loss smaller or only 2%(?) higher)
        if bool(is_WT(his_ticket, his_orig)) == keep_WTs:
            kept.append(getModelFromFile("../tickets/conv2.json",
                                         folder + "/" + kind + "_s0.1_nr" + str(nr) + ".h5"))
    return kept


# arrays with WTs (one per dataset): every candidate that is a WT is kept
WTs_CIFAR = _collectTickets('../tickets/WTs_CIFAR', 'WT', True)
WTs_CINIC = _collectTickets('../tickets/WTs_CINIC', 'WT', True)
WTs_SVHN = _collectTickets('../tickets/WTs_SVHN', 'WT', True)

# array with random tickets: every candidate that is not a WT is kept
randoms = _collectTickets('../tickets/random', 'random', False)

# take same amount of subnetworks for each condition (minimum of WTs of each type, in my case 14)
min_len = min(len(randoms), len(WTs_SVHN), len(WTs_CIFAR), len(WTs_CINIC))

randoms = randoms[:min_len]
WTs_SVHN = WTs_SVHN[:min_len]
WTs_CIFAR = WTs_CIFAR[:min_len]
WTs_CINIC = WTs_CINIC[:min_len]

## Compare each group to itself and to all others (10 conditions in total)

In [None]:
# With 14 tickets per group:
# comparing two different groups offers 14*14 = 196 possible ticket pairs,
# comparing a group with itself offers 13+12+...+1 = 91 possible ticket pairs.

### Each group to itself

In [None]:
# for each group of tickets, compare to itself (4 conditions in total)
for name, group in zip(["random", "WTs_SVHN", "WTs_CINIC", "WTs_CIFAR"],
                       [randoms, WTs_SVHN, WTs_CINIC, WTs_CIFAR]):

    print("starting group", name)

    # number of unordered ticket pairs in this group (91 for 14 tickets);
    # derived from the group size so the mask below always has one entry per pair
    n_pairs = len(group) * (len(group) - 1) // 2
    # compare at most 30 randomly chosen pairs
    n_chosen = min(30, n_pairs)

    # mask with n_chosen values True, other ones False, in random order
    # (uses numpy's global random state; no seed is set in this cell)
    bool_arr = np.concatenate((np.ones(n_chosen, dtype = bool),
                               np.zeros(n_pairs - n_chosen, dtype = bool)))
    np.random.shuffle(bool_arr)

    # s runs over the pairs (i, j) with i < j in row-major order
    s = 0
    dists_all_tickets = []

    for i, ticket1 in enumerate(group):
        for j in range(i + 1, len(group)):
            if bool_arr[s]:
                dists = compareModels(ticket1, group[j])
                dists_all_tickets.append(dists)
                print("distance of ticket", i, "and ticket", j, "is:", dists)
                print("mean distance:", np.mean(dists))
            s += 1

    # one row per compared pair, one column per layer
    np.savetxt("../dists/dists_" + name + ".csv", dists_all_tickets, delimiter = ",")

### Each WT-group to random group (mixed condition)

In [None]:
# for each group of tickets, compare to random (3 conditions in total)
for name, group in zip(["SVHN", "CINIC", "CIFAR"],
                       [WTs_SVHN, WTs_CINIC, WTs_CIFAR]):

    print("starting group", name, "and random")

    # number of (ticket, random ticket) pairs (196 for 14 tickets each);
    # derived from the group sizes so the mask below always has one entry per pair
    n_pairs = len(group) * len(randoms)
    # compare at most 30 randomly chosen pairs
    n_chosen = min(30, n_pairs)

    # mask with n_chosen values True, other ones False, in random order
    # (uses numpy's global random state; no seed is set in this cell)
    bool_arr = np.concatenate((np.ones(n_chosen, dtype = bool),
                               np.zeros(n_pairs - n_chosen, dtype = bool)))
    np.random.shuffle(bool_arr)

    # s runs over all pairs (i, j) in row-major order
    s = 0
    dists_all_tickets = []

    for i, ticket1 in enumerate(group):
        for j, ticket2 in enumerate(randoms):
            if bool_arr[s]:
                dists = compareModels(ticket1, ticket2)
                dists_all_tickets.append(dists)
                print("distance of ticket", i, "and random", j, "is:", dists)
                print("mean distance:", np.mean(dists))
            s += 1

    # one row per compared pair, one column per layer
    np.savetxt("../dists/dists_mixed_" + name + ".csv", dists_all_tickets, delimiter = ",")

### Each WT-group with all other WT-groups

In [None]:
# for each group of WTs, compare to other groups of WTs (3 conditions in total:
# SVHN-CINIC, SVHN-CIFAR and CINIC-CIFAR)
for name, group in zip(["SVHN", "CINIC"],
                       [WTs_SVHN, WTs_CINIC]):
    for name2, group2 in zip(["CINIC", "CIFAR"],
                       [WTs_CINIC, WTs_CIFAR]):
        # skip comparing a group with itself
        if name == name2:
            continue

        print("starting group", name, "and", name2)

        # number of ticket pairs across both groups (196 for 14 tickets each);
        # derived from the group sizes so the mask below always has one entry per pair
        n_pairs = len(group) * len(group2)
        # compare at most 30 randomly chosen pairs
        n_chosen = min(30, n_pairs)

        # mask with n_chosen values True, other ones False, in random order
        # (uses numpy's global random state; no seed is set in this cell)
        bool_arr = np.concatenate((np.ones(n_chosen, dtype = bool),
                                   np.zeros(n_pairs - n_chosen, dtype = bool)))
        np.random.shuffle(bool_arr)

        # s runs over all pairs (i, j) in row-major order
        s = 0
        dists_all_tickets = []

        for i, ticket1 in enumerate(group):
            for j, ticket2 in enumerate(group2):
                if bool_arr[s]:
                    dists = compareModels(ticket1, ticket2)
                    dists_all_tickets.append(dists)
                    print("distance of ticket", i, "and ticket", j, "is:", dists)
                    print("mean distance:", np.mean(dists))
                s += 1

        # one row per compared pair, one column per layer
        np.savetxt("../dists/dists_between_" + name + "_" + name2 + ".csv",
                   dists_all_tickets, delimiter = ",")