# Setup

## Imports

In [1]:
from IPython.display import clear_output

!pip3 install pyprind

clear_output()

In [2]:
import cv2
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon as poly

import sklearn
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import gzip

import os
import glob
import time
import random
import json
import copy
import pyprind
import tqdm
import itertools
import pickle as pkl
from dataclasses import dataclass, field
from collections import Counter
from typing import Union, List, Dict, Any, Optional, cast

import torch
import torchvision

from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## Utility

In [3]:
def one_hot(x, num_classes=2):
    """One-hot encode a tensor of integer class indices.

    Args:
        x: integer (long) tensor of class indices.
        num_classes: size of the one-hot axis appended as the last dimension.

    Returns:
        The encoded tensor with dimension 1 squeezed away when that dimension
        has size 1 (e.g. labels of shape (N, 1) come back as (N, num_classes)).
    """
    encoded = torch.nn.functional.one_hot(x, num_classes=num_classes)
    squeezed = encoded.squeeze(1)
    return squeezed

In [4]:
def binary(x, width, channel_size):
    """Unpack integers into big-endian bit patterns laid out as (width, channel_size).

    Args:
        x: 1-D integer tensor of shape (N,). Each entry is read as an unsigned
            number of ``width * channel_size`` bits. The bit masks are built in
            ``x.dtype``, so the dtype must be wide enough to hold
            ``2 ** (width * channel_size - 1)``.
        width: number of rows in each unpacked pattern.
        channel_size: number of bits per row.

    Returns:
        uint8 tensor of shape (N, width, channel_size) whose flattened last two
        axes list the bits of each input from most to least significant.
    """
    bits = width * channel_size
    # One mask per bit position, most significant bit first.
    mask = 2 ** torch.arange(bits - 1, -1, -1).to(x.device, x.dtype)
    flat_bits = x.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
    # Reshape with the function's own arguments (not the global `args`), so the
    # result always matches the number of bits that were actually extracted.
    return flat_bits.reshape(x.size(0), width, channel_size)

In [5]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Besides seeding, this switches cuDNN to deterministic mode, turns off its
    benchmark mode, and stores the seed in ``os.environ['PYTHONHASHSEED']``.

    Args:
        seed: integer seed handed to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [6]:
class Histogram_Counter:
    """Streaming histogram over a fixed range, persisted to ``<experiment>.pkl``.

    Values passed to :meth:`add` are floored to a multiple of the bin width and
    tallied in a ``Counter``. Every key of that counter is a plain Python
    ``float`` rounded to ``self.round`` decimals, so counts coming from float32
    tensors, float64 arrays and previously pickled files all land on the same
    key.

    Args:
        low: lower edge of the histogram range.
        high: upper edge of the histogram range.
        bins: number of equally wide bins between ``low`` and ``high``.
        experiment: name used as plot title and as stem of the pickle file.
    """

    def __init__(self, low, high, bins=10, experiment='counter'):
        self.experiment = experiment
        self.delta = (high-low)/bins
        # Number of decimals used when rounding bin edges and keys.
        self.round = int(np.ceil(np.log10(bins)))
        self.ranges = [round(low+i*self.delta, self.round) for i in range(bins+1)]

        if os.path.exists(f'{self.experiment}.pkl'):
            # Resume from counts accumulated by an earlier run.
            self.counts = self.load()
        else:
            self.counts = Counter({self._key(edge): 0 for edge in self.ranges})

        # Start from the restored total so that a reloaded histogram is not
        # immediately re-written by the first call to `add`.
        self.last_save = sum(self.counts.values())

    def _key(self, value):
        """Normalise a bin value (str, numpy scalar or float) to a rounded Python float."""
        return round(float(value), self.round)

    def add(self, x):
        """Tally the values in ``x`` (tensor or array-like) into their bins.

        The counts are written to disk once at least 1000 values have been
        added since the previous save.
        """
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        x = np.asarray(x)
        binned = self.delta * (x//self.delta)
        values, freqs = np.unique(binned, return_counts=True)
        # Accumulate per key: distinct raw values can round to the same key, and
        # building a dict from them would silently keep only the last count.
        for value, freq in zip(values, freqs):
            self.counts[self._key(value)] += int(freq)

        total = sum(self.counts.values())
        if total-self.last_save >= 1000:
            self.save()
            self.last_save = total

    def get_bins(self, scale=1000):
        """Expand the counts into a flat list of bin values.

        Args:
            scale: each bin value is repeated ``count // scale`` times, which
                keeps the list small for very large counts; bins holding fewer
                than ``scale`` samples are omitted.

        Returns:
            List of floats suitable for ``plt.hist``.
        """
        values = []
        for key, count in self.counts.items():
            values.extend([float(key)]*(count//scale))
        return values

    def plot(self):
        """Draw the (down-scaled) histogram with the experiment name as title."""
        plt.hist(self.get_bins(), bins=self.ranges, ec="k")
        plt.yticks([])
        plt.xlabel('Range')
        plt.title(self.experiment)
        plt.show()

    def save(self):
        """Pickle the current counts to ``<experiment>.pkl``."""
        with open(f'{self.experiment}.pkl', 'wb') as handle:
            pkl.dump(self.counts, handle, protocol=pkl.HIGHEST_PROTOCOL)

    def load(self):
        """Read counts from ``<experiment>.pkl`` and normalise their keys.

        Returns:
            ``Counter`` keyed by rounded Python floats.
        """
        # SECURITY: unpickling can execute arbitrary code; only load files that
        # were written by this notebook.
        with open(f'{self.experiment}.pkl', 'rb') as handle:
            stored = pkl.load(handle)
        counts = Counter()
        # Files written by older versions may hold str or numpy-scalar keys.
        for key, value in stored.items():
            counts[self._key(key)] += int(value)
        return counts

## Arguments

In [7]:
@dataclass
class TrainingArgs():
    """Configuration shared by the cells of this notebook.

    ``channels`` is derived from ``width`` and ``channel_size`` when it is not
    given explicitly, so overriding either of them at construction time yields
    a matching layer layout.
    """

    # Optimisation / runtime settings.
    seed: int = 1
    lr: float = 1e-4
    batch_size: int = 16384
    # os.cpu_count() may return None; fall back to 0 so the field stays an int.
    num_workers: int = os.cpu_count() or 0
    max_epochs: int = 200
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    size: int = int(1e4)
    bound: int = 10
    # Model shape: `width` gated layers of `channel_size` units each, with
    # `path_size` random masks per layer.
    input_size: int = 1
    path_size: int = 2
    width: int = 4
    channel_size: int = 8
    # Units per gated layer; None means "(channel_size,) * width" (see __post_init__).
    channels: Optional[tuple] = None
    use_bias: bool = False
    k: int = 1
    architecture: str = 'DNN'
    mode: str = 'Random'

    data: Optional[tuple] = None

    root_dir: str = '/content/'
    checkpoint: str = '/content/'
    # Names of the histograms collected during evaluation.
    experiments: tuple = ('overlap_0', 'overlap_1', 'intersection', 'all')

    def __post_init__(self):
        # Derive the layout from the instance values rather than the class
        # defaults, so TrainingArgs(width=..., channel_size=...) stays consistent.
        if self.channels is None:
            self.channels = (self.channel_size, ) * self.width

args = TrainingArgs()

# Models

## Shallow

In [8]:
class Shallow(torch.nn.Module):
    """Four gated linear layers followed by a linear read-out.

    The output of hidden layer ``i`` is multiplied element-wise by a gate. Each
    layer owns a fixed random 0/1 mask table ``act{i}f`` of shape
    (path_size, channels[i-1]); how the gate is built from that table and from
    the optional ``gate`` argument depends on ``mode`` (see :meth:`forward`).

    Args:
        args: configuration object providing ``path_size``, ``channels``
            (at least four entries), ``input_size``, ``use_bias`` and ``device``.
    """

    def __init__(self, args):
        super(Shallow, self).__init__()
        self.args = args

        # Masks and layers are created alternately, in this order, so the
        # random stream is consumed the same way on every construction.
        self.act1f = self._random_mask(self.args.channels[0])
        self.fc1v = nn.Linear(in_features=self.args.input_size, out_features=self.args.channels[0], bias=self.args.use_bias)

        self.act2f = self._random_mask(self.args.channels[1])
        self.fc2v = nn.Linear(in_features=self.args.channels[0], out_features=self.args.channels[1], bias=self.args.use_bias)

        self.act3f = self._random_mask(self.args.channels[2])
        self.fc3v = nn.Linear(in_features=self.args.channels[1], out_features=self.args.channels[2], bias=self.args.use_bias)

        self.act4f = self._random_mask(self.args.channels[3])
        self.fc4v = nn.Linear(in_features=self.args.channels[2], out_features=self.args.channels[3], bias=self.args.use_bias)

        self.fc5v = nn.Linear(in_features=self.args.channels[3], out_features=self.args.input_size, bias=self.args.use_bias)

        self.relu = nn.ReLU(inplace=True)

    def _random_mask(self, channels):
        """Return a random 0/1 float mask of shape (path_size, channels) on args.device."""
        return torch.where(torch.rand((self.args.path_size, channels))>0.5, torch.ones(1,), torch.zeros(1,)).to(self.args.device)

    def _layer_gates(self, x, path, gate, mode):
        """Build the four per-layer gates for the requested mode."""
        masks = (self.act1f, self.act2f, self.act3f, self.act4f)

        if mode=='train':
            # One mask row per sample, picked by that sample's path index.
            return [torch.stack([mask[p, :] for p in path]) for mask in masks]

        is_path_index = isinstance(mode, int) and mode<self.args.path_size
        if not (mode=='intersection' or mode=='all' or is_path_index):
            raise ValueError(
                f"Unknown mode {mode!r}: expected 'train', 'intersection', 'all' "
                f"or an int smaller than path_size={self.args.path_size}"
            )
        if gate is None:
            raise ValueError(f"mode={mode!r} requires a `gate` tensor")
        assert gate.size(0)==x.size(0)

        if mode=='intersection':
            # Units that are active on every path, further restricted by `gate`.
            return [torch.prod(mask, dim=0) * gate[:, i, :] for i, mask in enumerate(masks)]
        if mode=='all':
            # `gate` alone decides which units are active.
            return [gate[:, i, :] for i in range(len(masks))]
        # Integer mode: the mask of that single path, restricted by `gate`.
        return [mask[mode, :] * gate[:, i, :] for i, mask in enumerate(masks)]

    def forward(self, x, path, gate=None, mode='train'):
        """Run the gated network.

        Args:
            x: input of shape (batch, input_size).
            path: per-sample path indices; only read when ``mode == 'train'``.
            gate: 0/1 tensor indexed as ``gate[:, layer, :]`` with one slice of
                width ``channels[layer]`` per gated layer; required for every
                mode except ``'train'``.
            mode: ``'train'`` (gate from the mask row of each sample's path),
                ``'intersection'`` (product of all path masks times ``gate``),
                ``'all'`` (``gate`` only), or an int path index (that path's
                mask times ``gate``).

        Returns:
            Tensor of shape (batch, input_size).

        Raises:
            ValueError: if ``mode`` is not recognised, or ``gate`` is missing
                for a mode that needs it.
        """
        gates = self._layer_gates(x, path, gate, mode)
        layers = (self.fc1v, self.fc2v, self.fc3v, self.fc4v)

        out = x
        # Every hidden layer is gated by its OWN mask (layer i <-> gate i).
        for layer, layer_gate in zip(layers, gates):
            out = layer(out) * layer_gate

        return self.fc5v(out)

In [9]:
# Scratch instance used by the smoke test in the next cell.
shallow = Shallow(args).to(args.device)

In [10]:
# Smoke test: push a batch of two samples through the network, one per path.
x = torch.ones(2, 1).to(args.device)
# The model indexes gates as gate[:, layer, :], so build the tensor as
# (batch, width, channel_size). 'train' mode does not read it.
gate = torch.where(torch.rand((x.size(0), args.width, args.channel_size))>0.5, torch.ones(1,), torch.zeros(1,)).to(args.device)

z = shallow(x, [0, 1], gate, mode='train')
assert z.shape == (x.size(0), args.input_size)

# Train

In [11]:
# Fix all RNGs before drawing the model's random masks and weights. The
# histograms below are accumulated on disk across runs, so every run has to
# evaluate the same model.
set_seed(args.seed)

model = Shallow(args).to(args.device)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
criterion = nn.MSELoss().to(args.device)

# One random scalar input, duplicated so that each of the two paths sees it
# with a different target. (`input` shadows the Python builtin; the name is
# kept because the evaluation cell below reads it.)
input = torch.rand(1, 1).to(args.device)
x = input.repeat(2, 1)
y = torch.tensor([[-1.], [1.]]).to(args.device)

I = torch.eye(2)

In [12]:
# model.train()
# for _ in tqdm.trange(500000):
#     optimizer.zero_grad()
    
#     y_pred = model(x, [0, 1], gate=None, mode='train')
#     loss = criterion(y_pred, y)

#     loss.backward()
#     optimizer.step()

# model.eval()
# y_pred = torch.where(model(x, [0, 1], gate=None, mode='train').detach().cpu()>0, torch.ones(1,), -1*torch.ones(1,)).to(args.device)
# accuracy = ((y_pred==y).sum()/y.size(0))*100

# time.sleep(2)
# print(f'Training | Loss = {round(loss.item(), 4)} Accuracy = {round(accuracy.item(), 4)}')

In [None]:
# One persistent histogram of model outputs per experiment.
metrics = {exp: Histogram_Counter(-1, 1, 10000, exp) for exp in args.experiments}

# Experiment name -> `mode` argument passed to the model.
experiment_modes = {
    'overlap_0': 0,
    'overlap_1': 1,
    'intersection': 'intersection',
    'all': 'all',
}

# Every possible 0/1 gate assignment over width * channel_size units.
num_patterns = 2**(args.width*args.channel_size)
# The input is the same for every pattern, so build the full batch once.
x_full = input.repeat(args.batch_size, 1).to(args.device)

model.eval()
# Pure inference: skip building the autograd graph for every forward pass.
with torch.no_grad():
    for start in tqdm.trange(0, num_patterns, args.batch_size):
        # Clamp the final batch so no pattern index beyond the range is evaluated.
        stop = min(start+args.batch_size, num_patterns)
        gate = binary(torch.arange(start, stop, 1), args.width, args.channel_size).to(args.device)
        x = x_full[:gate.size(0)]
        # `path` is only read in 'train' mode; it is passed as a placeholder.
        path = [None]*x.size(0)

        for name, mode in experiment_modes.items():
            metrics[name].add(model(x, path, gate, mode=mode))

 28%|██▊       | 74011/262144 [10:38<26:40, 117.55it/s]

In [None]:
# Histogram of model outputs collected with mode=0.
metrics['overlap_0'].plot()

In [None]:
# Histogram of model outputs collected with mode=1.
metrics['overlap_1'].plot()

In [None]:
# Histogram of model outputs collected with mode='intersection'.
metrics['intersection'].plot()

In [None]:
# Histogram of model outputs collected with mode='all'.
metrics['all'].plot()