## Task dataset generation

### I-RAVEN Originals

In [None]:
import utils 

# Preview the I-RAVEN originals directory, asking the helper for a single view.
utils.view_matrices(
    source_dir='datasets/originals/',
    n_view=1,
)

### Task dataset

In [None]:
# Directory holding the task dataset that is previewed below.
source_dir = "datasets/project_demo_hexagons"
utils.view_matrices(source_dir=source_dir, n_view=1)

## Binary weight masking demonstration

### Setup

In [None]:
import masking, data, utils
from data import IRAVENDataModule
from models.SCL_model import SCL,SCLTrainingWrapper
import torch
import seaborn as sns
import matplotlib.pyplot as plt
# Use CUDA when torch reports it as available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Which task dataset and which pretrained SCL checkpoint to use.
task_dataset_name = 'squares'
SCL_version_accuracy = 80

# --- Task dataset -----------------------------------------------------------
dataset_path = f'datasets/{task_dataset_name}'
data_module = IRAVENDataModule(batch_size=8, split=(90, 10, 0))
data_module.prepare_data()
data_module.setup(root_dir=dataset_path)
train_dataloader_task = data_module.train_dataloader()
test_dataloader_task = data_module.test_dataloader()

# --- I-RAVEN originals (only the test loader is used) -----------------------
originals_path = 'datasets/originals_masking'
data_module_originals = IRAVENDataModule(batch_size=8, split=(90, 10, 0))
data_module_originals.prepare_data()
data_module_originals.setup(root_dir=originals_path)
test_dataloader_originals = data_module_originals.test_dataloader()

# --- Unmasked SCL model -----------------------------------------------------
scl_kwargs = dict(
    image_size=160,                         # size of image
    set_size=9,                             # number of questions + 1 answer
    conv_channels=[1, 16, 16, 32, 32, 32],  # convolutional channel progression, 1 for greyscale, 3 for rgb
    conv_output_dim=80,                     # model dimension, the output dimension of the vision net
    attr_heads=10,                          # number of attribute heads
    attr_net_hidden_dims=[128],             # attribute scatter transform MLP hidden dimension(s)
    rel_heads=80,                           # number of relationship heads
    rel_net_hidden_dims=[64, 32, 5],        # presumably the relationship net hidden dimension(s) — TODO confirm
)
model = SCL(**scl_kwargs)

# Load the pretrained weights that match the selected accuracy version.
model_ckpt = f'model_ckpts/pretrain_SCL/SCL_pretrain_{SCL_version_accuracy}.ckpt'
state_dict = utils.get_SCL_state_dict(model_ckpt)
model.load_state_dict(state_dict)
model.eval()  # evaluation mode, needed for the batch-norm layers

# Constructor arguments shared by the masked-model cells below.
init_kwargs = {
    'model': model,
    'train_dataloader': train_dataloader_task,
    'test_dataloader1': test_dataloader_task,
    'test_dataloader2': test_dataloader_originals,
    'device': device,
    'savedir': None,
    'logit_init': 2.5,
}

### What do we mean by binary weight masking?

In [None]:
def plot_tens(tens):
    """Draw a tensor as an annotated greyscale heatmap.

    A new 4x6 inch figure is created and the values are rendered with
    ``sns.heatmap`` on a fixed 0..1 colour scale, with every cell annotated
    with its value and all tick labels hidden.

    Args:
        tens: torch tensor to display. It is detached from autograd and
            copied to host memory first, so tensors living on an accelerator
            are accepted as well as CPU tensors.

    Returns:
        None. The heatmap is drawn on the newly created figure.
    """
    # Tensor.numpy() only works on CPU tensors; .cpu() is a no-op when the
    # tensor is already on the CPU, so this is safe for both cases.
    values = tens.detach().cpu().numpy()
    _, ax = plt.subplots(figsize=(4, 6))

    sns.heatmap(
        data=values,
        xticklabels=False,
        yticklabels=False,
        cmap='Greys_r',
        annot=True,
        ax=ax,
        annot_kws={'fontsize': 20},
        vmin=0,
        vmax=1,
    )
  

In [None]:
# Random 5x1 column of values drawn by torch.rand, shown as a heatmap.
tens = torch.rand(size=(5, 1))
plot_tens(tens)

In [None]:
# Binary mask reshaped to a 5x1 column so it multiplies `tens` element-wise.
mask = torch.tensor([0, 1, 1, 0, 0]).unsqueeze(1)
masked_tens = tens * mask
# NOTE(review): masked_tens is computed but never displayed in this cell —
# confirm whether a plot_tens(masked_tens) call was intended here.



### Initial binaries tensor

In [None]:
#aux functions

def plot_binary(binary_tensor, annot=False, title='Title'):
    """Draw a mask tensor as a greyscale heatmap with a red frame.

    A new 10x6 inch figure is created, given ``title`` as its suptitle, and
    the tensor values are rendered with ``sns.heatmap`` with tick labels
    hidden. The axes patch is then outlined in red.

    Args:
        binary_tensor: torch tensor to display. It is detached from autograd
            and copied to host memory first, so a tensor held on the GPU
            (the setup cell selects CUDA when available) can be plotted too.
        annot: forwarded to ``sns.heatmap``; when true each cell is annotated
            with its value.
        title: text used for the figure suptitle.

    Returns:
        None. The heatmap is drawn on the newly created figure.
    """
    # Tensor.numpy() raises for non-CPU tensors; .cpu() is a no-op when the
    # tensor is already on the CPU.
    values = binary_tensor.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 6))
    fig.suptitle(title, fontsize=20)

    sns.heatmap(
        data=values,
        xticklabels=False,
        yticklabels=False,
        cmap='Greys_r',
        annot=annot,
        ax=ax,
    )

    # Red outline around the plotted area.
    ax.patch.set_edgecolor('red')
    ax.patch.set_linewidth(5)


# Build the masked model from the arguments assembled in the setup cell.
masked_scl = masking.MaskedSCLModel(init_kwargs)

# Layer whose mask is followed throughout the rest of the demonstration.
layer_name = 'rel_net.net.2.weight'

# Logit tensor for that layer, read before the logits are turned into binaries.
logit_1 = masked_scl.logit_tensors_dict[layer_name]

# Map logits to binaries, then pick out the binary mask of the chosen layer.
masked_scl.transform_logit_tensors()
binary_1 = masked_scl.binaries[layer_name]

plot_binary(binary_1, title='initial binary mask')



### Train mask

In [None]:
# Train the mask with a short demo configuration (1 epoch, 10 batches).
train_kwargs = dict(
    alpha=1e-6,
    lr=1e-3,
    n_epochs=1,
    n_batches=10,
    # The step intervals below are far larger than this run's 10 batches.
    val_every_n_steps=10000,
    eval_every_n_steps=1e10,
    n_val_batches=10,
    n_eval_batches=10,
    save_freq_epoch=100000,
    logging=False,
)

masked_scl.train(**train_kwargs)

### Updated binary tensor

In [None]:

# Same layer as before, read again after the training cell has run.
binary_2 = masked_scl.binaries[layer_name]
plot_binary(binary_2, title='binary mask Epoch 1')

In [None]:
import utils
from utils import CPU_Unpickler

# Saved mask checkpoint to visualise. The path points at SCL_90 / circles,
# while the setup cell uses SCL_version_accuracy=80 and the 'squares' task.
ckpt_file = 'masks/SCL_90/circles/alpha=2.1877616239495517e-06_checkpoint_step=15957_epoch=100'

# NOTE(review): this rebinds the global `device` chosen in the setup cell.
device = torch.device('cpu')

# NOTE(review): `data` rebinds the name of the `data` module imported in the
# setup cell. The load is pickle-based, so only open checkpoint files you trust.
with open(ckpt_file, 'rb') as f:
    unpickler = CPU_Unpickler(f)
    data = unpickler.load()

# Convert the stored logits to binary masks and select the tracked layer.
binary_mask = utils.transform_logit_tensors(data['logit_tensors_dict'])
binary_tensor_100 = binary_mask[layer_name]

plot_binary(binary_tensor_100, title='Binary mask epoch 100')



In [None]:
import seaborn as sns
import numpy as np

def plot_arr(arr, title):
    """Render an array as an annotated greyscale grid with a title.

    Creates a new 8x8 inch figure and draws ``arr`` with ``sns.heatmap`` on a
    fixed 0..1 colour scale: every cell is annotated, cells are separated by
    thin black lines, and tick labels and the colour bar are hidden.

    Args:
        arr: array of values passed to ``sns.heatmap`` as ``data``.
        title: text placed above the plot.

    Returns:
        None.
    """
    _, axis = plt.subplots(figsize=(8, 8))

    sns.heatmap(
        data=arr,
        xticklabels=False,
        yticklabels=False,
        cmap='Greys_r',
        annot=True,
        ax=axis,
        annot_kws={'fontsize': 15},
        vmin=0,
        vmax=1,
        cbar=False,
        linewidths=0.5,
        linecolor='black',
    )
    plt.title(title, fontsize=20, pad=20)
  


# Synthetic illustration only: the arrays and masks below are random draws,
# not masks read from a trained model.

# "Epoch 0": a dense 8x8 array of random values in [0, 1).
arr_0 = np.random.random_sample(size=(8, 8))
plot_arr(arr_0, title="Epoch 0")

# "Epoch 10": zero out entries using one random 0/1 mask.
bm_10 = np.random.randint(0, 2, size=(8, 8))
arr_10 = arr_0 * bm_10
plot_arr(arr_10, title="Epoch 10")

# "Epoch 50": the product of two random 0/1 masks is sparser than either one;
# it is applied on top of the previous result.
bm_50 = np.random.randint(0, 2, size=(8, 8)) * np.random.randint(0, 2, size=(8, 8))
arr_50 = arr_10 * bm_50
plot_arr(arr_50, title="Epoch 50")
