# Setup

## Imports

In [1]:
from IPython.display import clear_output

!pip3 install pyprind

clear_output()

In [2]:
import cv2
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon as poly

import sklearn
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
import numpy as np
import pylab as pl
import pandas as pd
import gzip

import os
import glob
import time
import random
import json
import copy
import pyprind
import tqdm
import itertools
import pickle as pkl
from dataclasses import dataclass, field
from collections import Counter
from typing import Union, List, Dict, Any, Optional, cast

import torch
import torchvision

from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## Utility

In [3]:
def one_hot(x, num_classes=2):
    """One-hot encode an integer tensor and squeeze axis 1 of the result.

    Args:
        x: integer (long) tensor of class indices.
        num_classes: size of the one-hot axis appended by the encoding.

    Returns:
        The one-hot tensor with dimension 1 removed when that dimension has
        size 1 (e.g. a ``(batch, 1)`` input yields ``(batch, num_classes)``);
        otherwise the encoding is returned with its shape unchanged.
    """
    encoded = torch.nn.functional.one_hot(x, num_classes=num_classes)
    return torch.squeeze(encoded, 1)

In [4]:
def binary(x, width, channel_size):
    """Expand each integer in `x` into its big-endian bit pattern.

    Args:
        x: 1-D integer tensor; each entry is decomposed into
            ``width * channel_size`` bits (most significant bit first).
        width: size of the second output axis.
        channel_size: size of the third output axis.

    Returns:
        A ``uint8`` tensor of shape ``(x.size(0), width, channel_size)`` whose
        entries are the 0/1 bits of the corresponding input value.
    """
    bits = width * channel_size
    # Powers of two from 2**(bits-1) down to 2**0, on x's device and dtype.
    mask = 2 ** torch.arange(bits - 1, -1, -1).to(x.device, x.dtype)
    # Reshape with the function's own arguments; the previous version read
    # `args.width` / `args.channel_size` from the global config instead, which
    # ignored the parameters (and `args` has no `width` field at all).
    return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte().reshape(x.size(0), width, channel_size)

In [5]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator; it is also exported
            (as a string) through the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for every random source used in this notebook.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Disable cuDNN autotuning and ask for deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [6]:
class Histogram_Counter:
    """Running tally of values snapped to a fixed-width grid, persisted via pickle.

    The grid step is ``(high - low) / bins``. Counts live in a
    ``collections.Counter`` stored in ``<experiment>.pkl`` in the current
    working directory; an existing file is loaded on construction.
    """

    def __init__(self, low, high, bins=10, experiment='counter'):
        """Set up the grid and load previously saved counts if present.

        Args:
            low: lower end of the histogram range.
            high: upper end of the histogram range.
            bins: number of equal-width bins between `low` and `high`.
            experiment: base name of the pickle file and title of the plot.
        """
        self.experiment = experiment
        self.delta = (high - low) / bins
        # Number of decimals used when rounding bin edges / snapped values.
        self.round = int(np.ceil(np.log10(bins)))
        self.ranges = [round(low + step * self.delta, self.round) for step in range(bins + 1)]

        if os.path.exists(f'{self.experiment}.pkl'):
            self.counts = self.load()
        else:
            # Start with one zero entry per edge, keyed by the edge's string form.
            self.counts = Counter({str(edge): 0 for edge in self.ranges})

        self.last_save = 0

    def add(self, x):
        """Snap the values in `x` down to the grid and add them to the tally.

        Args:
            x: NumPy array or torch tensor of values (tensors are detached and
                moved to the CPU first).

        The counter is written to disk whenever at least 1000 new items have
        been accumulated since the last save.
        """
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        snapped = self.delta * (x // self.delta)
        values, occurrences = np.unique(snapped, return_counts=True)
        # Counter's in-place addition also drops entries whose count is not positive.
        self.counts += Counter(dict(zip(values.round(self.round), occurrences)))

        total = sum(self.counts.values())
        if total - self.last_save >= 1000:
            self.save()
            self.last_save = total

    def get_bins(self):
        """Return a flat list with each key (as float) repeated by its count."""
        scale = 1
        values = []
        for key, count in self.counts.items():
            values.extend([float(key)] * (count // scale))
        return values

    def plot(self):
        """Draw the accumulated values as a histogram over `self.ranges`."""
        plt.hist(self.get_bins(), bins=self.ranges, ec="k")
        plt.yticks([])
        plt.xlabel('Range')
        plt.title(self.experiment)
        plt.show()

    def save(self):
        """Pickle the current counts to ``<experiment>.pkl``."""
        with open(f'{self.experiment}.pkl', 'wb') as sink:
            pkl.dump(self.counts, sink, protocol=pkl.HIGHEST_PROTOCOL)

    def load(self):
        """Read and return the counts stored in ``<experiment>.pkl``.

        NOTE: unpickling executes arbitrary code; only load files this
        notebook wrote itself.
        """
        with open(f'{self.experiment}.pkl', 'rb') as source:
            return pkl.load(source)

## Arguments

In [7]:
@dataclass
class TrainingArgs():
    """Configuration shared by the cells of this notebook.

    A single instance is created right after the class and exposed as the
    module-level name `args`.
    """

    # --- optimisation / runtime ---
    seed: int = 1
    lr: float = 1e-4
    batch_size: int = 32
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to 0 so the field always holds an int.
    num_workers: int = os.cpu_count() or 0
    # Annotated as int (was `str`) to match the integer default.
    max_epochs: int = 200
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    # --- problem / model shape (input_size, output_size, path_size, depth,
    #     channel_size and use_bias are read by the Shallow model) ---
    size: int = int(1e4)
    bound: int = 10
    input_size: int = 1
    output_size: int = 1
    path_size: int = 2
    depth: int = 4
    channel_size: int = 128
    use_bias: bool = False
    k: int = 1
    architecture: str = 'DNN'
    mode: str = 'Random'

    # No data attached by default.
    data: Optional[tuple] = None

    # --- filesystem locations and experiment names ---
    root_dir: str = '/content/'
    checkpoint: str = '/content/'
    experiments: tuple = ('overlap_0', 'overlap_1', 'intersection', 'all')

args = TrainingArgs()

# Models

## Shallow

In [8]:
class Shallow(torch.nn.Module):
    """Stack of linear layers whose outputs are multiplied by fixed per-path gates.

    For each of `args.depth` layers the constructor creates:
      * a linear layer whose weights are set to +/- 1/out_features (the sign
        comes from a uniform random draw),
      * a gate table of shape (args.path_size, out_features): random 0/1
        entries for hidden layers, all ones for the last layer,
      * an all-ones mask shaped like the layer's weight,
      * the scalar 1/out_features, stored in `self.init`.

    `self.relu` is created but not used by `forward`.
    """

    def __init__(self, args):
        """Build the layers, gates and masks from `args`.

        Args:
            args: config object providing depth, input_size, output_size,
                channel_size, path_size, use_bias and device.
        """
        super().__init__()
        self.args = args

        self.act, self.init, self.masks = [], [], []
        layers = []
        final = self.args.depth - 1
        for depth in range(self.args.depth):
            is_final = depth == final
            in_features = self.args.channel_size if depth else self.args.input_size
            out_features = self.args.output_size if is_final else self.args.channel_size

            # Gate table: one row per path, one column per output unit.
            if is_final:
                act = torch.ones(self.args.path_size, out_features)
            else:
                act = torch.from_numpy(
                    np.random.choice([0, 1], size=(self.args.path_size, out_features), p=[1./2, 1./2])
                )
            self.act.append(act.to(self.args.device))

            layer = nn.Linear(in_features, out_features, bias=self.args.use_bias)
            magnitude = 1/out_features
            # The uniform draw only decides the sign; magnitudes are fixed below.
            torch.nn.init.uniform_(layer.weight, -magnitude, magnitude)
            self.init.append(magnitude)
            self.masks.append(torch.ones_like(layer.weight).to(self.args.device))
            layer.weight.data = torch.where(layer.weight.data>0, magnitude, -magnitude)
            layers.append(layer)
        self.fc = nn.Sequential(*layers)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, gate):
        """Run `x` through every layer, gating each layer's output.

        Args:
            x: input batch fed to the first linear layer.
            gate: sequence with one path index per batch row; row j of each
                layer's output is multiplied by that layer's gate row gate[j].

        Returns:
            The gated output of the last layer.
        """
        out = x
        for i in range(self.args.depth):
            rows = [self.act[i][path, :] for path in gate]
            out = self.fc[i](out) * torch.stack(rows)

        return out

In [9]:
# Build a gated network from the shared config and move its parameters to the configured device.
shallow = Shallow(args).to(args.device)

In [10]:
# Smoke test of the forward pass on a two-row batch of ones.
x = torch.ones(2, 1).to(args.device)

# Row 0 is gated with path 0 and row 1 with path 1.
z = shallow(x, [0, 1])

# Train

In [53]:
# Seed every RNG first so the random gates, weight signs and the random input
# drawn below are reproducible across kernel restarts.
set_seed(args.seed)

model = Shallow(args).to(args.device)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
criterion = nn.MSELoss().to(args.device)

# One random scalar input, duplicated into two rows: the same input is sent
# through path 0 (target -1) and path 1 (target +1).
# NOTE: `input` shadows the Python builtin; the name is kept because a later
# cell reads it.
input = torch.rand(1, 1).to(args.device)
x = input.repeat(2, 1).to(args.device)
y = torch.tensor([[-1.], [1.]]).to(args.device)

I = torch.eye(2)

z = input.squeeze(-1)

# Snapshot of the weights before any training, one entry per layer.
init_weights = [copy.deepcopy(layer.weight) for layer in model.fc]
# Per-layer gate vectors: path 0 only, path 1 only, units open on both paths,
# and an all-open reference.
gates_0 = [gate_table[0] for gate_table in model.act]
gates_1 = [gate_table[1] for gate_table in model.act]
gates_in = [gate_table[0]*gate_table[1] for gate_table in model.act]
gates_all = [torch.ones_like(gate_table[1]) for gate_table in model.act]

In [55]:
# Two-phase procedure:
#   1. `sparse` iterations that only compute gradients (no optimizer step) and
#      edit the weights by hand based on the sign of weight vs. gradient;
#   2. `train` iterations of plain SGD.
train = 100000
sparse = 1
model.train()
# Placeholder; overwritten by the first loss computed in the loop below.
loss = torch.tensor(100)

epochs = 0
while epochs<sparse:
    optimizer.zero_grad()
    
    y_pred = model(x, [0, 1])
    loss = criterion(y_pred, y)

    loss.backward()
    # optimizer.step()

    for i in range(model.args.depth):
        # Index sets by (sign of weight, sign of gradient). ind_pp / ind_nn are
        # only referenced by the commented-out lines below.
        ind_pp = ((model.fc[i].weight>0)*(model.fc[i].weight.grad>0)).nonzero(as_tuple=True)
        ind_pn = ((model.fc[i].weight>0)*(model.fc[i].weight.grad<0)).nonzero(as_tuple=True)
        ind_nn = ((model.fc[i].weight<0)*(model.fc[i].weight.grad<0)).nonzero(as_tuple=True)
        ind_np = ((model.fc[i].weight<0)*(model.fc[i].weight.grad>0)).nonzero(as_tuple=True)

        # model.fc[i].weight.data[ind_pp] += model.init[i]
        # model.fc[i].weight.data[ind_nn] += model.init[i]
        # Weights whose sign is opposite to their gradient's sign are moved by
        # model.init[i] towards zero and clamped at zero.
        model.fc[i].weight.data[ind_pn] = torch.maximum(model.fc[i].weight.data[ind_pn]-model.init[i], torch.zeros_like(model.fc[i].weight.data[ind_pn]).to(model.args.device))
        model.fc[i].weight.data[ind_np] = torch.minimum(model.fc[i].weight.data[ind_np]+model.init[i], torch.zeros_like(model.fc[i].weight.data[ind_np]).to(model.args.device))
        # Record every exactly-zero weight in the layer mask and re-apply the mask.
        zeros = ((model.fc[i].weight==0)).nonzero(as_tuple=True)
        model.masks[i][zeros] = 0
        model.fc[i].weight.data *=  model.masks[i]

    epochs += 1
    print(loss.item())

print('\n#####\n')

# NOTE(review): the loop below never multiplies the weights by model.masks, so
# weights zeroed above are free to become non-zero again — confirm this is intended.
epochs = 0
while epochs<train:
    optimizer.zero_grad()
    
    y_pred = model(x, [0, 1])
    loss = criterion(y_pred, y)

    loss.backward()
    optimizer.step()

    epochs += 1
    # Checked after the increment, so with train=100000 the loss is printed
    # when epochs is 9999, 19999, ..., 99999.
    if (epochs+1)%(train//10)==0:
        print(loss.item())

model.eval()
# Sign-threshold the outputs into +/-1 predictions and compare with the targets.
y_pred = torch.where(model(x, [0, 1]).detach().cpu()>0, torch.ones(1,), -1*torch.ones(1,)).to(args.device)
accuracy = ((y_pred==y).sum()/y.size(0))*100

time.sleep(2)
# `loss` here is the value from the last training iteration; it is not recomputed after model.eval().
print(f'Training | Loss = {round(loss.item(), 4)} Accuracy = {round(accuracy.item(), 4)}')

# Snapshot of the weights after training, one entry per layer.
trained_weights = [copy.deepcopy(model.fc[i].weight) for i in range(len(model.fc))]

1.0002206563949585

#####

1.0000629425048828
0.9987169504165649
0.9968386888504028
0.9933593273162842
0.9851056933403015
0.9585400223731995
0.8218845129013062
0.17003804445266724
0.00027216499438509345
9.983367732502302e-08
Training | Loss = 0.0 Accuracy = 100.0


In [56]:
# Per-layer fraction of exactly-zero weights: first for the initial snapshot,
# then (after a separator) for the trained snapshot.
for group_idx, weight_group in enumerate((init_weights, trained_weights)):
    if group_idx == 1:
        print('\n######\n')
    for i in range(len(init_weights)):
        scaled = weight_group[i]/model.init[i]
        t = (scaled==0).sum()/weight_group[i].flatten().shape[0]
        print(t)

tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')

######

tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')


In [57]:
# Re-run the evaluation: threshold the two outputs at zero into +/-1
# predictions and report accuracy next to the last training loss.
model.eval()
outputs = model(x, [0, 1]).detach().cpu()
y_pred = torch.where(outputs>0, torch.ones(1,), -1*torch.ones(1,)).to(args.device)
n_correct = (y_pred==y).sum()
accuracy = (n_correct/y.size(0))*100

time.sleep(2)
print(f'Training | Loss = {round(loss.item(), 4)} Accuracy = {round(accuracy.item(), 4)}')

# Refresh the per-layer snapshot of the trained weights.
trained_weights = [copy.deepcopy(layer.weight) for layer in model.fc]

Training | Loss = 0.0 Accuracy = 100.0


In [58]:
# Display the raw outputs for the two paths (targets are -1 and +1).
model(x, [0, 1])

tensor([[-0.9996],
        [ 1.0001]], device='cuda:0', grad_fn=<MulBackward0>)

In [None]:
# !rm -f *.pkl
# metrics = dict([[exp, Histogram_Counter(-1, 1, 10000, exp)] for exp in args.experiments])

---

In [None]:
# Drop the trailing axis of the (1, 1) input so it has shape (1,).
z = input.squeeze(-1)

# Per-layer handles on the initial weights and the all-open gates.
w0, g0 = init_weights[0], gates_all[0]
w1, g1 = init_weights[1], gates_all[1]
w2, g2 = init_weights[2], gates_all[2]
w3, g3 = init_weights[3], gates_all[3]

In [None]:
def find_paths(z, weights, gates):
    """Enumerate the value of every open input-to-output path of a gated network.

    A path picks one unit per layer. Its value is `z` multiplied, layer by
    layer, by the connecting weight and the gate of the chosen unit. Only
    paths whose gates are all non-zero are reported.

    Args:
        z: one-element tensor holding the network input.
        weights: per-layer weight matrices indexed as ``w[out_unit][in_unit]``;
            the first layer is read at input index 0.
        gates: per-layer gate vectors. Hidden layers are indexed by unit; the
            last layer always uses its entry 0 (as in the original four-layer
            code, which assumed a single output gate).

    Returns:
        List of Python floats, one per open path, ordered with the first
        layer's unit index varying slowest and the last layer's fastest.

    Works for any number of layers (the previous version was hard-wired to
    exactly four) and skips a whole subtree as soon as a gate on the way is
    closed instead of evaluating every product first.
    """
    depth = len(weights)
    paths = []
    if depth == 0:
        return paths

    def _walk(layer, prev_unit, prefix):
        # Extend the partial product `prefix` through `layer` and recurse.
        w, g = weights[layer], gates[layer]
        is_last = layer == depth - 1
        for unit in range(len(w)):
            gate = g[0] if is_last else g[unit]
            if gate == 0:
                continue  # closed gate: no path through this unit contributes
            # Same left-to-right multiplication order as the fully written-out product.
            value = prefix * w[unit][prev_unit] * gate
            if is_last:
                paths.append(value.item())
            else:
                _walk(layer + 1, unit, value)

    _walk(0, 0, z)
    return paths

---

In [None]:
# Path values under each gate selection (path 0, path 1, intersection, all
# open), first with the initial weights and then with the trained weights.
paths_0_init, paths_1_init, paths_in_init, paths_all_init = (
    find_paths(z, init_weights, gate_set)
    for gate_set in (gates_0, gates_1, gates_in, gates_all)
)

paths_0_train, paths_1_train, paths_in_train, paths_all_train = (
    find_paths(z, trained_weights, gate_set)
    for gate_set in (gates_0, gates_1, gates_in, gates_all)
)

In [None]:
# Convert every list of path values into a NumPy array (same names, same order).
paths_0_init, paths_1_init, paths_in_init, paths_all_init = map(
    np.array, (paths_0_init, paths_1_init, paths_in_init, paths_all_init)
)

paths_0_train, paths_1_train, paths_in_train, paths_all_train = map(
    np.array, (paths_0_train, paths_1_train, paths_in_train, paths_all_train)
)

In [None]:
# Smallest and largest path value across the four initial-weight path sets.
l = min(min(values) for values in (paths_0_init, paths_1_init, paths_in_init, paths_all_init))
r = max(max(values) for values in (paths_0_init, paths_1_init, paths_in_init, paths_all_init))

---

In [None]:
# Histogram the magnitudes of the negative and of the positive trained path-0
# values separately, over log-spaced bin edges from 1e-10 to 1 (500 edges).
h0n, x0n = np.histogram(-paths_0_train[paths_0_train<0], bins=np.logspace(np.log10(1e-10), 0, 500))
h0p, x0p = np.histogram(paths_0_train[paths_0_train>0], bins=np.logspace(np.log10(1e-10), 0, 500))
# Reverse the positive side so it runs from the largest magnitude downwards.
# NOTE(review): after reversal x0p[i] is the upper edge of bin i, whereas x0n[i]
# is the lower edge of its bin — confirm this asymmetry is acceptable.
h0p, x0p = h0p[::-1], x0p[::-1]

In [None]:
# Print, per bin: count, signed bin edge, the bin's contribution (count * edge)
# and the running total of contributions. Negative side first, then positive.
# The accumulator is named `running_total` rather than `sum` so the builtin
# `sum` (used by Histogram_Counter.add, among others) is not shadowed for the
# rest of the session.
running_total = 0
for i in range(len(h0n)):
    contribution = -h0n[i]*x0n[i]
    running_total += contribution
    print(h0n[i], '\t', np.format_float_scientific(-x0n[i], precision=3), '\t', np.format_float_scientific(contribution, precision=2), '\t', np.format_float_scientific(running_total, precision=2))

print('\n######\n')
running_total = 0
for i in range(len(h0p)):
    contribution = h0p[i]*x0p[i]
    running_total += contribution
    print(h0p[i], '\t', np.format_float_scientific(x0p[i], precision=3), '\t', np.format_float_scientific(contribution, precision=2), '\t', np.format_float_scientific(running_total, precision=2))

0 	 -1.e-10 	 0.e+00 	 0.e+00
0 	 -1.047e-10 	 0.e+00 	 0.e+00
0 	 -1.097e-10 	 0.e+00 	 0.e+00
0 	 -1.148e-10 	 0.e+00 	 0.e+00
0 	 -1.203e-10 	 0.e+00 	 0.e+00
0 	 -1.26e-10 	 0.e+00 	 0.e+00
0 	 -1.319e-10 	 0.e+00 	 0.e+00
0 	 -1.381e-10 	 0.e+00 	 0.e+00
11 	 -1.447e-10 	 -1.59e-09 	 -1.59e-09
1 	 -1.515e-10 	 -1.51e-10 	 -1.74e-09
0 	 -1.586e-10 	 0.e+00 	 -1.74e-09
0 	 -1.661e-10 	 0.e+00 	 -1.74e-09
0 	 -1.74e-10 	 0.e+00 	 -1.74e-09
0 	 -1.822e-10 	 0.e+00 	 -1.74e-09
0 	 -1.908e-10 	 0.e+00 	 -1.74e-09
0 	 -1.998e-10 	 0.e+00 	 -1.74e-09
0 	 -2.092e-10 	 0.e+00 	 -1.74e-09
0 	 -2.191e-10 	 0.e+00 	 -1.74e-09
0 	 -2.295e-10 	 0.e+00 	 -1.74e-09
0 	 -2.403e-10 	 0.e+00 	 -1.74e-09
0 	 -2.517e-10 	 0.e+00 	 -1.74e-09
0 	 -2.635e-10 	 0.e+00 	 -1.74e-09
2 	 -2.76e-10 	 -5.52e-10 	 -2.29e-09
8 	 -2.890e-10 	 -2.31e-09 	 -4.61e-09
14 	 -3.027e-10 	 -4.24e-09 	 -8.84e-09
1 	 -3.17e-10 	 -3.17e-10 	 -9.16e-09
0 	 -3.319e-10 	 0.e+00 	 -9.16e-09
0 	 -3.476e-10 	 0.e+00 	 -9.16e-09
0 	