# Setup on Google Colab




In [None]:
# !git clone https://github.com/JuliaXxj/ICLR2023.git
# !mv -v ./ICLR2023/* ./
# !rm -r ./ICLR2023/

Cloning into 'ICLR2023'...
remote: Enumerating objects: 13, done.[K
remote: Counting objects: 100% (13/13), done.[K
remote: Compressing objects: 100% (11/11), done.[K
remote: Total 13 (delta 3), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (13/13), done.
renamed './ICLR2023/FFN19-44-07-ep25' -> './FFN19-44-07-ep25'
renamed './ICLR2023/models.py' -> './models.py'
renamed './ICLR2023/README.md' -> './README.md'
renamed './ICLR2023/utils.py' -> './utils.py'


# Code

In [None]:
!pip3 install imageio

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from models import FeedforwardNeuralNetModel, TinyCNN, PatternClassifier, NewTinyCNN
from torchvision import datasets, transforms
from torch import optim
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import copy
import utils as CFI_utils
from matplotlib import pyplot as plt
import numpy as np
from collections import defaultdict
#from frozendict import frozendict
from datetime import datetime
import seaborn as sns
import pandas as pd
from absl import app, flags
from easydict import EasyDict
import torch.nn as nn
import json
import os


colors = sns.color_palette("tab10")
%matplotlib inline

#Logging stuffs
import logging
import sys
# Create logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Create STDERR handler
handler = logging.StreamHandler(sys.stderr)
# ch.setLevel(logging.DEBUG)

# Create formatter and add it to the handler
formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# Set STDERR handler as the only handler 
logger.handlers = [handler]

#configs
epochs = 10
batch_size = 1000
test_batch_size = 10000
stable_batch_size = 60000
use_cuda = torch.cuda.is_available()
print("use_cuda: ",use_cuda)

lr = 0.01
log_interval = 100

#torch specific configs
torch.manual_seed(1)

device = torch.device("cuda" if use_cuda else "cpu")

train_kwargs = {'batch_size': batch_size}

test_kwargs = {'batch_size': test_batch_size}
stable_kwargs = {'batch_size': stable_batch_size}

if use_cuda:
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True,
                   'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)
    stable_kwargs.update(cuda_kwargs)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])


use_cuda:  True


## Helper functions

In [None]:
class Shift:
    """Callable placeholder that stores a shift amount and returns its input unchanged.

    Construction and every call are traced on stdout.
    """

    def __init__(self, shift = 0):
        self.shift = shift
        print("alive")

    def __call__(self, arr):
        # No shifting is performed here; the argument is handed back as-is.
        print("running")
        return arr

    def __repr__(self) -> str:
        return "{}()".format(type(self).__name__)
    
class Rescale(object):
    """Resize the image of a sample and scale its landmarks accordingly.

    Args:
        output_size (tuple or int): Requested output size. A tuple is used
            directly as ``(new_h, new_w)``. An int is assigned to the shorter
            image edge and the other edge is scaled to keep the aspect ratio.

    NOTE(review): ``transform.resize`` is looked up on the module-level name
    ``transform`` at call time. In this notebook that name is bound to a
    ``transforms.Compose`` pipeline, which is not expected to offer ``resize``;
    presumably an image-processing ``transform`` module was intended - confirm
    before using this class.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image = sample['image']
        landmarks = sample['landmarks']
        h, w = image.shape[:2]

        if not isinstance(self.output_size, int):
            # Explicit (height, width) request.
            new_h, new_w = self.output_size
        elif h > w:
            # Portrait image: the width is the shorter edge.
            new_h, new_w = self.output_size * h / w, self.output_size
        else:
            new_h, new_w = self.output_size, self.output_size * w / h

        new_h = int(new_h)
        new_w = int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # Landmarks are scaled by (width ratio, height ratio): h and w are swapped
        # relative to the image because x and y are axes 1 and 0 of the image.
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}

def shift_and_roll(arr : torch.Tensor, x : int , y : int) -> torch.Tensor:
    """Circularly shift ``arr`` by ``x`` along dim 0 and by ``y`` along dim 1.

    Elements pushed past one edge re-enter at the opposite edge.
    """
    rolled_dim0 = torch.roll(arr, shifts = x, dims = 0)
    return torch.roll(rolled_dim0, shifts = y, dims = 1)
def shift(arr : torch.Tensor, x : int , y : int) -> torch.Tensor:
    """Translate an image by ``(x, y)`` pixels, filling the vacated area with zeros.

    The output has the same height and width as the input; content pushed past
    an edge is discarded (no wrap-around).

    Args:
        arr: Tensor whose last two dimensions are (height, width).
        x: Horizontal offset; positive moves content right, negative left.
        y: Vertical offset; positive moves content down, negative up.

    Returns:
        The shifted tensor, with the same shape as ``arr``.
    """
    height, width = arr.shape[-2], arr.shape[-1]

    # Pad on the side the content moves away from ...
    left, right = (x, 0) if x >= 0 else (0, -x)
    top, bottom = (y, 0) if y >= 0 else (0, -y)
    # F.pad takes the padding of the last dimension first: (left, right, top, bottom).
    padded = torch.nn.functional.pad(arr, (left, right, top, bottom))

    # ... then crop a window of the original size from the opposite side.
    row_start = 0 if y >= 0 else padded.shape[-2] - height
    col_start = 0 if x >= 0 else padded.shape[-1] - width
    return padded[..., row_start:row_start + height, col_start:col_start + width]

def noisify(arr : torch.Tensor , distribution : torch.distributions.Distribution) -> torch.Tensor: #randomly add noise
    """Return ``arr`` plus noise drawn from ``distribution``.

    ``distribution.sample`` is called with ``arr.size()`` and the draw is
    reshaped to the shape of ``arr`` before being added; ``arr`` itself is not
    modified.
    """
    sampled = distribution.sample(arr.size())
    return arr + sampled.reshape(arr.shape)

def transform_dataset(dataset, shift_x = 1, shift_y = 0, mu = 1, sigma = 1,
                      do_shift = False, do_noisify = True):
    """Perturb ``dataset.data`` IN PLACE and describe what was done.

    Every image ``dataset.data[i]`` is optionally shifted with ``shift`` and
    optionally has Gaussian noise added with ``noisify``. The defaults
    reproduce the previously hard-coded behaviour (noise only, mean 1, stddev 1).

    Args:
        dataset: Object exposing a 3-D tensor attribute ``data`` indexed as
            ``data[i, :, :]``.
        shift_x, shift_y: Offsets forwarded to ``shift`` when ``do_shift`` is set.
        mu, sigma: Mean and standard deviation of the added Gaussian noise.
        do_shift: Whether to shift every image.
        do_noisify: Whether to add noise to every image.

    Returns:
        A string starting with ``"base"`` followed by one fragment per enabled
        modification.
    """
    modification_string = "base"
    if do_shift:
        modification_string += " -shift {} {} - ".format(shift_x, shift_y)
    if do_noisify:
        modification_string += " -noise added using gaussian using mean {} and stddev {}- ".format(mu, sigma)

    # The distribution does not change between images, so build it once.
    gaussian = torch.distributions.Normal(loc = mu, scale = sigma) if do_noisify else None

    # Noise is floating point. When the storage is an integer type, writing
    # out-of-range values back would overflow / wrap around, so clamp them to
    # the representable range first.
    data_dtype = dataset.data.dtype
    clamp_range = None
    if data_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
        dtype_info = torch.iinfo(data_dtype)
        clamp_range = (dtype_info.min, dtype_info.max)

    for i in range(dataset.data.shape[0]):
        if do_shift:
            dataset.data[i,:,:] = shift(dataset.data[i,:,:], shift_x, shift_y)

        if do_noisify:
            noisy = noisify(dataset.data[i,:,:], gaussian)
            if clamp_range is not None:
                noisy = noisy.clamp(clamp_range[0], clamp_range[1])
            dataset.data[i,:,:] = noisy
    return modification_string

def check_gradient(grad, label, last_sorted_grads, plot = False):
    """Rank the columns of ``grad[label]`` by summed absolute gradient.

    Args:
        grad: Mapping (or sequence) indexed by ``label`` whose entries are
            array-likes; absolute values are summed over axis 0.
        label: Key selecting the entry of ``grad`` to inspect.
        last_sorted_grads: Unused; kept so existing callers keep working.
        plot: When true, draw a bar chart of the summed magnitudes and print
            their max, argmax and min.

    Returns:
        Column indices ordered from the largest to the smallest summed
        absolute gradient.
    """
    logging.info("CHECKING GRADIENT FOR LABEL {}".format(label))
    sum_abs_grad = np.sum(abs(grad[label]), axis = 0)

    current_sorted_grad = (-sum_abs_grad).argsort()

    # Only pay for the sort when the records will actually be emitted, and skip
    # ranks that do not exist for short inputs (indexing them raised IndexError).
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        descending = np.sort(sum_abs_grad)[::-1]
        for k in [0, 9, 99, 199]:
            if k < descending.shape[0]:
                logging.debug('{}th biggest gradient = {}'.format(k, descending[k]))
    if plot:
        plt.figure(figsize=(30, 1))
        plt.bar(range(sum_abs_grad.shape[0]), sum_abs_grad)
        plt.show()
        print(sum_abs_grad.max(), sum_abs_grad.argmax(), sum_abs_grad.min())

    return current_sorted_grad

class Patterns:
    """Per-label activation patterns of ``model`` over the data of ``dataloader``.

    After construction:
        label2patterns: dict mapping each label to the concatenated (and
            squeezed) output of ``model.get_pattern`` for the samples carrying
            that label.
        label2idx: dict mapping each label to the positions of those samples,
            counted in the order the dataloader yields them.

    Args:
        model: Module exposing ``get_pattern(data, layers, device, flatten=True)``.
        dataloader: Iterable of ``(data, target)`` batches.
        labels: Re-iterable collection of label values to collect patterns for.
        layers: Layer names forwarded to ``model.get_pattern``.
    """

    def __init__(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, labels, layers):
        self._model = model
        self.label2patterns = {}
        self.label2idx = {}
        self._labels = labels
        self._layers = layers
        self._dataloader = dataloader
        self._populate()

    def _populate(self):
        """Fill ``label2patterns`` / ``label2idx`` with one pass over the dataloader."""
        # Use the device the model lives on instead of a notebook-level global.
        try:
            model_device = next(self._model.parameters()).device
        except StopIteration:
            model_device = torch.device("cpu")

        labels = list(self._labels)
        patterns_by_label = {label: [] for label in labels}
        idx_by_label = {label: [] for label in labels}

        # Visit every batch once and split it by label, rather than re-reading
        # the whole dataloader for each label.
        offset = 0
        for data, target in self._dataloader:
            for label in labels:
                local_idx = np.where(target == label)[0]
                # Shift batch-local positions by the number of samples already seen
                # so that indices from different batches do not collide.
                idx_by_label[label].append(local_idx + offset)
                label_data = data[local_idx]
                logging.debug(label_data.shape[0])
                pattern = self._model.get_pattern(label_data, self._layers, model_device, flatten = True)
                logging.debug(pattern.shape)
                patterns_by_label[label].append(pattern)
            offset += len(target)

        label2patterns = {}
        label2idx = {}
        for label in labels:
            patterns = np.squeeze(np.concatenate(patterns_by_label[label], axis = 0))
            label2patterns[label] = patterns
            label2idx[label] = np.concatenate(idx_by_label[label], axis = 0)

            logging.info(patterns.shape)

        #freeze
        self.label2patterns = dict(label2patterns)
        self.label2idx = dict(label2idx)

    def apply_filter(self, f):
        """Placeholder; not implemented."""
        pass

    def unique(self):
        """Placeholder; not implemented."""
        pass

    def query_pattern(self):
        """Placeholder; not implemented."""
        pass

## Load Model & Datasets

In [None]:
# MNIST train / test splits, downloaded on first use and preprocessed with `transform`.
mnist_options = dict(root='./data', download=True, transform=transform)
dataset1 = datasets.MNIST(train=True, **mnist_options)
dataset2 = datasets.MNIST(train=False, **mnist_options)

# Optional in-place perturbation of the datasets (currently disabled):
#transform_dataset(dataset1)
#modification_string = transform_dataset(dataset2)
modification_string = ""

train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [None]:
LOADPATH = 'FFN19-44-07-ep25'

LAST_N_EPOCHS = 10

model = FeedforwardNeuralNetModel(28*28, 128, 10).to(device)
# NOTE(review): torch.load unpickles the checkpoint - only load files from a trusted source.
model.load_state_dict(torch.load(LOADPATH, map_location=device))
# The model is only used for inference from here on: put layers whose behaviour
# differs between training and evaluation (if the model has any) in eval mode.
model.eval()

# Accuracy of the loaded model on the test split, one sample at a time.
with torch.no_grad():
    total = 0
    correct = 0
    for data, target in dataset2:
        # The model already lives on `device`; only the sample has to be moved.
        prediction = int(torch.argmax(model(data.to(device)), dim = 1))
        if prediction == target:
            correct += 1
        total += 1
    modification_string += " -accuracy {}-".format( correct/total)
    print(modification_string)

 -accuracy 0.898-


In [None]:
# Layers whose activation patterns are collected.
layers = ['fc1', 'fc2', 'fc3', 'fc4']
# layers = ['conv1', 'conv2','fc1', 'fc2']

labels = range(10)
K = 25
stable_loader = torch.utils.data.DataLoader(dataset1, **stable_kwargs)

# Same model / labels / layers for both splits; only the data source differs.
pattern_options = dict(model = model, labels = labels, layers = layers)
all_patterns = Patterns(dataloader = stable_loader, **pattern_options)
all_test_patterns = Patterns(dataloader = test_loader, **pattern_options)

root - INFO - (5923, 458)
root - INFO - (6742, 458)
root - INFO - (5958, 458)
root - INFO - (6131, 458)
root - INFO - (5842, 458)
root - INFO - (5421, 458)
root - INFO - (5918, 458)
root - INFO - (6265, 458)
root - INFO - (5851, 458)
root - INFO - (5949, 458)
root - INFO - (980, 458)
root - INFO - (1135, 458)
root - INFO - (1032, 458)
root - INFO - (1010, 458)
root - INFO - (982, 458)
root - INFO - (892, 458)
root - INFO - (958, 458)
root - INFO - (1028, 458)
root - INFO - (974, 458)
root - INFO - (1009, 458)


In [None]:
log_dir = "log_file"

# Create the output directory if it is missing; a single call avoids the
# check-then-create gap and makes re-running the cell a no-op.
os.makedirs(log_dir, exist_ok=True)

class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in Python values.

    NumPy integers and booleans become ``int``, NumPy floats become ``float``
    and arrays become (nested) lists. Anything else is delegated to the base
    encoder, which raises ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        scalar_casts = ((np.integer, int), (np.floating, float), (np.bool_, int))
        for numpy_type, cast in scalar_casts:
            if isinstance(obj, numpy_type):
                return cast(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
    


all_stable_relus = []

all_alpha_patterns = {"model": LOADPATH}

write_log = True

# Number of trailing pattern columns treated as the prediction layer; the
# per-label prediction neuron is located relative to these columns below.
NUM_OUTPUT_NEURONS = 10

# One timestamp for the whole run so that all output files share the same suffix.
run_stamp = datetime.now().strftime("%H-%M-%S")


def _most_common_pattern(patterns):
    """Return the most frequent row of ``patterns`` as a list.

    Rows are counted with a dictionary in a single pass. Ties are resolved in
    favour of the row that occurs first, because dictionaries keep insertion
    order and ``max`` returns the first maximal key.
    """
    counts = defaultdict(int)
    for row in patterns.tolist():
        counts[tuple(row)] += 1
    return list(max(counts, key=counts.get))


# The most common full pattern of a label does not depend on epsilon, so it is
# computed once per label instead of once per (epsilon, label) pair.
label_to_most_common_pattern = dict()
for label, patterns in all_patterns.label2patterns.items():
    most_common_pattern = _most_common_pattern(patterns)
    # Positions of the truthy entries of the most common pattern.
    pattern_indices = [idx for idx, is_on in enumerate(most_common_pattern) if is_on]
    label_to_most_common_pattern[label] = (pattern_indices, most_common_pattern)

epsilon_to_patterns = dict()
# epsilons = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05]
epsilons = np.linspace(0, 0.5, num=50)

ReLU_exp_log = None
if write_log:
    ReLU_exp_log = open(os.path.join(log_dir,"relu_exp_log{}.csv".format(run_stamp)), "w")

try:
    if write_log:
        ReLU_exp_log.write("Epsilon,Label,NumStableReLU,NumUniqueAP,Alpha Pattern Cover\n")

    for epsilon in epsilons:
        alpha_patterns = {}
        for label, patterns in all_patterns.label2patterns.items():
            print(patterns.shape)

            num_samples = patterns.shape[0]
            # Per neuron: number of samples of this label for which it is on.
            relu_sum = np.sum(patterns, axis = 0).squeeze()

            low_threshold = epsilon*num_samples
            high_threshold = (1-epsilon)*num_samples
            print("threshold: ", low_threshold, high_threshold)
            print("relu_sum, prediction", relu_sum[-NUM_OUTPUT_NEURONS:])

            # A neuron is "stable" when it is off for at most an epsilon fraction
            # of the samples, or on for at least a (1 - epsilon) fraction.
            non_active_neurons = np.where(relu_sum<=low_threshold)
            active_neurons = np.where(relu_sum>=high_threshold)
            print("non active neurons: ", non_active_neurons)
            print("active neurons: ", active_neurons)

            # Concatenate the index arrays directly (no squeeze), so a single
            # stable neuron still yields a 1-D array that can be sorted.
            stable_idx = np.concatenate([non_active_neurons[0], active_neurons[0]])
            neuro_idx = patterns.shape[1] - NUM_OUTPUT_NEURONS + label
            if neuro_idx not in stable_idx:
                print(f"WARN: neuro_idx = {neuro_idx} for label {label} is not stable, let's include it anyway")
                stable_idx = np.append(stable_idx, neuro_idx)
            stable_idx = sorted(stable_idx) #sort the indices of the stable ReLUs.

            # The alpha pattern is the most frequent pattern restricted to the stable neurons.
            unique_patterns, freq = np.unique(patterns[:, stable_idx ], axis = 0, return_counts=True)
            alpha_p = unique_patterns[np.argmax(freq)]
            coverage = freq.max()/freq.sum()
            print("unique patterns:", unique_patterns)
            print()
            print("Label is ", label, "-epsilon ", epsilon)
            print("Stable ReLUs", stable_idx)
            print("how many unique paths in the filtered pattern?", unique_patterns.shape)
            print("their freq\n", freq, freq.shape)

            assert(len(stable_idx) == alpha_p.shape[-1])
            assert(freq.shape[0]==unique_patterns.shape[0])
            alpha_patterns[label] = {"stable_idx": stable_idx,
                                     "active_neurons": active_neurons[0],
                                     "non_active_neurons": non_active_neurons[0],
                                     "alpha_pattern": alpha_p,
                                     "alpha_pattern_coverage": coverage,
                                     "pattern_frequency": freq}

            print("primary pattern coverage: ", coverage,)

            if write_log:
                ReLU_exp_log.write("{},{},{},{},{}\n".format(epsilon, label, len(stable_idx), unique_patterns.shape[0], coverage))
        all_alpha_patterns[epsilon] = alpha_patterns
        epsilon_to_patterns[epsilon] = dict(label_to_most_common_pattern)
finally:
    # Close the CSV log even if the sweep is interrupted by an error.
    if ReLU_exp_log is not None:
        ReLU_exp_log.close()

with open(os.path.join(log_dir,"most_common_patterns{}.json".format(run_stamp)), "w") as most_common_json:
    json.dump(epsilon_to_patterns, most_common_json)

if write_log:
    with open(os.path.join(log_dir,"relu_exp_data{}.json".format(run_stamp)), "w") as ReLU_exp_json:
        json.dump(all_alpha_patterns, fp = ReLU_exp_json, indent=2, cls=NpEncoder)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Stable ReLUs [0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 75, 76, 77, 78, 79, 80, 81, 82, 83, 85, 86, 87, 89, 90, 91, 93, 94, 95, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 147, 148, 149, 150, 151, 152, 154, 155, 156, 157, 158, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 177, 178, 180, 181, 182, 183, 184, 185, 186, 187, 189, 190, 191, 192, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223

In [None]:
def _collect_batches(loader):
    """Return all data and all targets yielded by ``loader`` as two tensors.

    The size of every batch is printed. Batches are concatenated along
    dimension 0, so the result covers the whole loader even when it yields
    more than one batch.
    """
    data_batches, target_batches = [], []
    for batch_data, batch_target in loader:
        print(batch_data.size())
        print(batch_target.size())
        data_batches.append(batch_data)
        target_batches.append(batch_target)
    if len(data_batches) == 1:
        # Single batch: hand it back directly and avoid copying it.
        return data_batches[0], target_batches[0]
    return torch.cat(data_batches), torch.cat(target_batches)


train_data, train_target = _collect_batches(stable_loader)
test_data, test_target = _collect_batches(test_loader)

# Samples and targets of each class, moved to `device`, for both splits.
train_by_cls = {}
test_by_cls = {}

for picked_cls in range(10):
    train_cls_idx = train_target==picked_cls
    test_cls_idx = test_target==picked_cls

    sub_train_dataset = train_data[train_cls_idx].to(device)
    sub_train_targets = train_target[train_cls_idx].to(device)

    sub_test_dataset = test_data[test_cls_idx].to(device)
    sub_test_targets = test_target[test_cls_idx].to(device)

    train_by_cls[picked_cls] = {"data": sub_train_dataset, "targets": sub_train_targets}
    test_by_cls[picked_cls] = {"data": sub_test_dataset, "targets": sub_test_targets}

    print(f"Train: for class {picked_cls}, total = {len(sub_train_targets)}")
    print(f"Test: for class {picked_cls}, total = {len(sub_test_targets)}")
    print()




## For each $\epsilon$, count how many test images of each class satisfy the alpha pattern

In [None]:
csv_filename = "test_alpha_pattern_count.csv"

# One row per class, labelled with the number of test samples of that class.
df_row_index = ["label{} (total={})".format(i, len(test_by_cls[i]["targets"])) for i in range(10)]
print(df_row_index)

epsilons = np.linspace(0, 0.5, num=50)

# epsilon -> per-class counts; turned into a DataFrame once after the loop
# instead of growing a DataFrame column by column.
counts_by_epsilon = {}

for epsilon in epsilons:
    print(f"For eps={epsilon} checking the number of test images of each class that satisfying the alpha pattern.")

    eps_test_count = []
    for picked_cls in range(10):
        cls_test_pattern = all_test_patterns.label2patterns[picked_cls]

        cls_eps_stable_idx = all_alpha_patterns[epsilon][picked_cls]["stable_idx"]
        cls_eps_alpha_pattern = all_alpha_patterns[epsilon][picked_cls]["alpha_pattern"]

        # Restrict the test patterns to the stable neurons of this class and epsilon.
        test_patterns_stable_neurons = cls_test_pattern[:, cls_eps_stable_idx]

        # A test image satisfies the alpha pattern when it agrees on every stable neuron.
        matches_alpha_pattern = (test_patterns_stable_neurons == cls_eps_alpha_pattern).all(axis=-1)
        num_test_same_alpha_pattern = matches_alpha_pattern.sum()

        eps_test_count.append(num_test_same_alpha_pattern)
        print(f"class={picked_cls}, total = {num_test_same_alpha_pattern} / {len(cls_test_pattern)}")

    counts_by_epsilon[epsilon] = eps_test_count

# Rows are classes, columns are the epsilon values in sweep order.
test_count_df = pd.DataFrame(counts_by_epsilon, index=df_row_index)
print(test_count_df.shape)
test_count_df.to_csv(os.path.join(log_dir, csv_filename))



## For each $\epsilon$, check the overlap between the alpha patterns of each pair of classes

In [None]:
json_filename = "alpha_pattern_overlap.json"
epsilons = np.linspace(0, 0.5, num=50)

# overlap[epsilon]["(c1, c2)"] lists the neurons that are stable-active
# ("active") or stable-inactive ("nonactive") for both classes of the pair.
overlap = {}

for epsilon in epsilons:
    overlap[epsilon] = {}
    for c1 in range(10):
        c1_entry = all_alpha_patterns[epsilon][c1]
        c1_active = c1_entry["active_neurons"]
        c1_nonactive = c1_entry["non_active_neurons"]
        # Only pairs with c1 < c2, so every unordered pair appears once.
        for c2 in range(c1+1, 10):
            pair_key = f"({c1}, {c2})"

            c2_entry = all_alpha_patterns[epsilon][c2]
            c2_active = c2_entry["active_neurons"]
            c2_nonactive = c2_entry["non_active_neurons"]

            overlap[epsilon][pair_key] = {
                "active": np.intersect1d(c1_active, c2_active).tolist(),
                "nonactive": np.intersect1d(c1_nonactive, c2_nonactive).tolist(),
            }

# The context manager closes the file even if serialisation fails.
with open(os.path.join(log_dir,json_filename), "w") as out_file:
    json.dump(overlap, out_file, indent = 4)
            

In [None]:
extension_zip = ".zip"

zip_file = log_dir + extension_zip
!zip -r $zip_file $log_dir

# from google.colab import files
# files.download(zip_file)