In [None]:
import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl;
import importlib
import matplotlib.pyplot as plt
import numpy as np
import data
import utils
import sys
import importlib
import wandb
import copy
import pickle

from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from abc import ABC,abstractmethod
from torch.utils.data import DataLoader, Subset
from copy import deepcopy
from torch.special import logit
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.loggers import WandbLogger
from timeit import default_timer as timer

# Re-import the local helper modules so edits made on disk are picked up
# without restarting the kernel.
importlib.reload(data)
importlib.reload(utils)

# Seed via Lightning's helper before anything stochastic runs.
pl.seed_everything(42)

# Notebook-wide debug switch.
debug = False

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
#testing masked scl

import masking,data,utils
importlib.reload(masking)
importlib.reload(data)
importlib.reload(utils)

# Bind names with `from ... import` only AFTER the reload: a name imported
# beforehand keeps pointing at the pre-reload object, so on-disk edits to
# data.py would not be reflected in IRAVENDataModule.
from data import IRAVENDataModule
from models.SCL_model import SCLTrainingWrapper,SCL

# NOTE(review): absolute path tied to a single machine; consider making the
# checkpoint location configurable.
load_path='/Users/iyngkarrankumar/Documents/AI/AVR-functional-modularity/SCL_pretrain_80.ckpt'


#I-RAVEN dataset
batch_size=8
# Named `data_module` (as in the later cells) so the imported `data` module
# is not shadowed by the datamodule instance.
data_module=IRAVENDataModule(batch_size=batch_size)
data_module.prepare_data()
data_module.setup()
train_dataloader,test_dataloader,val_dataloader=data_module.train_dataloader(),data_module.test_dataloader(),data_module.val_dataloader()


#setup model
scl_kwargs={
    "image_size":160,                            # size of image
    "set_size": 9,                               # number of questions + 1 answer
    "conv_channels": [1, 16, 16, 32, 32, 32],    # convolutional channel progression, 1 for greyscale, 3 for rgb
    "conv_output_dim": 80,                       # model dimension, the output dimension of the vision net
    "attr_heads": 10,                            # number of attribute heads
    "attr_net_hidden_dims": [128],               # attribute scatter transform MLP hidden dimension(s)
    "rel_heads": 80,                             # number of relationship heads
    "rel_net_hidden_dims": [64, 32, 5] 
}
model=SCL(**scl_kwargs)

#load pretrained weights from the checkpoint into the freshly built model
state_dict=utils.get_SCL_state_dict(load_path)
model.load_state_dict(state_dict)
wrapped_model=SCLTrainingWrapper(model)



# Constructor arguments for the masked model: the validation loader is passed
# as 'test_dataloader1' and the test loader as 'test_dataloader2'.
kwargs={
    'model':model,
    'train_dataloader':train_dataloader,
    'test_dataloader1':val_dataloader,
    'test_dataloader2':test_dataloader,
    'savedir':'model_ckpts/FFN',
}

# Pull a single batch for interactive inspection.
X,y,*rest=next(iter(train_dataloader))

masked_scl=masking.MaskedSCLModel(kwargs)

## $\alpha$ sweep

In [None]:

# When True, the sweep below uses tiny settings so the pipeline can be
# smoke-tested quickly.
test=True

import masking,data,models.SCL_model as SCL_model
importlib.reload(masking)
importlib.reload(data)
importlib.reload(SCL_model)

# Bind names with `from ... import` only AFTER the reload: names imported
# beforehand keep pointing at the pre-reload classes.
from data import IRAVENDataModule
from models.SCL_model import SCL

#setup
# NOTE(review): absolute path tied to a single machine; consider making the
# checkpoint location configurable.
model_ckpt='/Users/iyngkarrankumar/Documents/AI/AVR-functional-modularity/SCL_pretrain_80.ckpt'
task_path='datasets/squares'
save_freq= 5 if test else 10000
batch_size=8

#dataset setup
#task dataset
path=task_path
data_module=IRAVENDataModule(batch_size=batch_size)
data_module.prepare_data()
data_module.setup(root_dir=path)
train_dataloader_task,test_dataloader_task=data_module.train_dataloader(),data_module.test_dataloader()

#NOT task dataset
path_='datasets/originals_masking'
data_module_=IRAVENDataModule(batch_size=batch_size)
data_module_.prepare_data()
data_module_.setup(root_dir=path_)
test_dataloader_not_task=data_module_.test_dataloader()

#model setup
scl_kwargs={
    "image_size":160,                            # size of image
    "set_size": 9,                               # number of questions + 1 answer
    "conv_channels": [1, 16, 16, 32, 32, 32],    # convolutional channel progression, 1 for greyscale, 3 for rgb
    "conv_output_dim": 80,                       # model dimension, the output dimension of the vision net
    "attr_heads": 10,                            # number of attribute heads
    "attr_net_hidden_dims": [128],               # attribute scatter transform MLP hidden dimension(s)
    "rel_heads": 80,                             # number of relationship heads
    "rel_net_hidden_dims": [64,32,5] 
}
model=SCL(**scl_kwargs)
# Load the pretrained weights from the checkpoint into the model.
state_dict=utils.get_SCL_state_dict(model_ckpt)
model.load_state_dict(state_dict)



def sweep_function(test=test,debug=False):
    """Train one masked SCL model for a single ``alpha`` value.

    Meant to be handed to ``wandb.agent``. It reads the module-level
    ``model``, ``train_dataloader_task``, ``test_dataloader_task``,
    ``test_dataloader_not_task``, ``task_path`` and ``save_freq`` that the
    setup code above defines.

    Args:
        test: when truthy, use the small smoke-test values for ``n_batches``,
            ``n_val_batches`` and ``n_eval_batches`` (2 each) and 5 epochs;
            otherwise 10 epochs on the full data with 100 val/eval batches.
            The default is captured from the global ``test`` at definition
            time.
        debug: when truthy, skip ``wandb.init``, use a fixed ``alpha`` of
            1e-5 instead of ``wandb.config.alpha``, and pass
            ``logging=False`` to training.
    """
    if not debug:
        # Start the W&B run; wandb.config is read below for the swept alpha.
        wandb.init(project='AVR')

    alpha=1e-5 if debug else wandb.config.alpha

    # Checkpoints are grouped by task name (last component of task_path).
    # NOTE(review): savedir does not depend on alpha, so every run of the
    # sweep targets the same directory -- confirm nothing gets overwritten.
    task_name=os.path.basename(task_path)
    savedir=os.path.join('masks/SCL',task_name)

    init_kwargs={
        'model':model,
        'train_dataloader':train_dataloader_task,
        'test_dataloader1':test_dataloader_task,
        'test_dataloader2':test_dataloader_not_task,
        'savedir':savedir,
    }

    #train kwargs setup
    train_kwargs={
        'n_epochs':5 if test else 10,
        'n_batches':2 if test else 'full',
        'val_every_n_steps':10,
        'eval_every':2,
        'n_val_batches':2 if test else 100,
        'n_eval_batches':2 if test else 100,
        'save_freq':save_freq,
        'logging':not debug,
        }

    masked_scl=masking.MaskedSCLModel(init_kwargs)
    masked_scl.train(alpha=alpha,**train_kwargs)




In [None]:
# alpha values to sweep over: two in test mode, the full grid otherwise.
if test:
    alpha_values=[1e-10,1e-5]
else:
    alpha_values=[1e-10,1e-6,1e-5,1e-4,1e-3,1e-2]

# The sweep name is asked for interactively (input() already returns a str).
sweep_name=input('Enter sweep name')

# Grid search over alpha; the sweep's metric is validation accuracy (maximise).
sweep_configuration={
    'method':'grid',
    'name':sweep_name,
    'metric':{'goal':'maximize','name':'validation_accuracy'},
    'parameters':{'alpha':{'values':alpha_values}},
}

# Register the sweep, run the agent with sweep_function, then close W&B.
sweep_id=wandb.sweep(sweep=sweep_configuration,project='AVR')
wandb.agent(sweep_id,function=sweep_function)
wandb.finish()

## Dataset sweep

In [None]:
# SCL_model lives in the `models` package (as imported in the earlier cells).
import masking,data,models.SCL_model as SCL_model
importlib.reload(masking)
importlib.reload(data)
importlib.reload(SCL_model)

# Bind names with `from ... import` only AFTER the reload so they refer to the
# reloaded classes; SCL is imported here so this cell does not depend on an
# earlier cell having brought it into scope.
from data import IRAVENDataModule
from models.SCL_model import SCL

# When True, the sweep function defined below uses tiny smoke-test settings.
test=True



#setup model
scl_kwargs={
    "image_size":160,                            # size of image
    "set_size": 9,                               # number of questions + 1 answer
    "conv_channels": [1, 16, 16, 32, 32, 32],    # convolutional channel progression, 1 for greyscale, 3 for rgb
    "conv_output_dim": 80,                       # model dimension, the output dimension of the vision net
    "attr_heads": 10,                            # number of attribute heads
    "attr_net_hidden_dims": [128],               # attribute scatter transform MLP hidden dimension(s)
    "rel_heads": 80,                             # number of relationship heads
    "rel_net_hidden_dims": [64,32,5] 
}
# NOTE(review): unlike the alpha-sweep cell, no checkpoint is loaded into the
# model here, so it keeps its freshly constructed weights -- confirm that is
# intended.
model=SCL(**scl_kwargs)

#NOT task dataloaders
nt_rootdir='datasets/originals'

data_module_nt=IRAVENDataModule()
data_module_nt.prepare_data()
data_module_nt.setup(root_dir=nt_rootdir)
val_dataloader_not_task=data_module_nt.val_dataloader()


#sweep function - this runs for each of the parameters set in sweep config
def sweep_function(test=test,model_type='FFN'):
    """Train one masked SCL model on the dataset chosen by the current sweep run.

    Reads ``wandb.config.dataset_name`` and the module-level ``model`` and
    ``val_dataloader_not_task``.

    Args:
        test: when truthy, run 2 epochs of 5 batches and save every step;
            otherwise 50 epochs on the full data, saving every 10. The
            default is captured from the global ``test`` at definition time.
        model_type: accepted for the caller's convenience; not used in the
            body.
    """
    wandb.init(project='AVR')

    # Build the dataloaders for the task dataset selected by the sweep.
    dataset_name=wandb.config.dataset_name
    task_module=IRAVENDataModule()
    task_module.prepare_data()
    task_module.setup(root_dir=os.path.join('datasets',dataset_name))
    task_train_loader=task_module.train_dataloader()
    task_val_loader=task_module.val_dataloader()

    # Smoke-test settings versus the full-length run.
    if test:
        epochs,batches,ckpt_every=2,5,1
    else:
        epochs,batches,ckpt_every=50,'full',10

    mask_trainer=masking.MaskedSCLModel({
        'model':model,
        'train_dataloader':task_train_loader,
        'test_dataloader1':task_val_loader,
        'test_dataloader2':val_dataloader_not_task,
        'savedir':os.path.join('mask_ckpts',dataset_name),
    })
    mask_trainer.train(
        alpha=1e-5, #input alpha from alpha sweep
        n_epochs=epochs,
        n_batches=batches,
        val_every_n_steps=10,
        eval_every=2,
        n_val_batches=100,
        n_eval_batches=100,
        save_freq=ckpt_every,
        logging=True,
    )




In [None]:
#we can also get dataset names by reading path where datasets are stored
dataset_names=['squares','circles','triangles','max rotation'] 

# Grid search over the task datasets; the sweep name is asked for interactively.
sweep_configuration={
    'method':'grid',
    
    'name':str(input('Enter sweep name')),
    'metric':{
        'goal':'maximize',
        'name':'validation_accuracy',
        },
    'parameters':{
        # Each sweep parameter must be a dict describing its search space
        # (here an explicit list under 'values'), not a bare list.
        'dataset_name':{'values':dataset_names},
        }
    }

# Register the sweep, run the agent with sweep_function, then close W&B.
sweep_id=wandb.sweep(sweep=sweep_configuration,project='AVR')
wandb.agent(sweep_id,function=sweep_function)
wandb.finish()