In [None]:
from data import IRAVENDataModule


# Build the I-RAVEN data module, run its lifecycle hooks, then grab the three loaders.
batch_size = 16
data_module = IRAVENDataModule(batch_size=batch_size)

# Lifecycle hooks are invoked before any dataloader is requested.
data_module.prepare_data()
data_module.setup()

# Requested in the same order as before: train, test, val.
train_dataloader = data_module.train_dataloader()
test_dataloader = data_module.test_dataloader()
val_dataloader = data_module.val_dataloader()




In [None]:
import scattering_transform
from scattering_transform import SCLTrainingWrapper
import utils
import torch
import torch.nn.functional as F
import wandb
import numpy as np
import os
import pickle

#
load=False
load_path='/Users/iyngkarrankumar/Documents/AI/AVR-functional-modularity/model_ckpts/pretrain_SCL/epoch=5_accuracy=12.5.ckpt'

#save 
save_freq=10000
savedir='model_ckpts/pretrain_SCL'

#logging
logging=False
watch_freq=1000

#training loop
train=True
n_epochs=10
n_batches=100
n_val_batches=100
val_every_n_steps=10
grad_clip_value=0.5
lr=1e-4

device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device {device}')


#setup model
# Either restore hyper-parameters + weights + optimiser state from a checkpoint, or start fresh.
if load:
    print(f'Loading model from checkpoint {load_path}')

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load checkpoints produced by this notebook.
    # NOTE(review): a checkpoint pickled from CUDA tensors may fail to load on a CPU-only machine.
    with open(load_path,'rb') as f:
        load_dict=pickle.load(f)

    # Rebuild the model with the hyper-parameters it was trained with.
    kwargs=load_dict['kwargs']
else:
    kwargs={
        "image_size":160,                            # size of image
        "set_size": 9,                               # number of questions + 1 answer
        "conv_channels": [1, 16, 16, 32, 32, 32],    # convolutional channel progression, 1 for greyscale, 3 for rgb
        "conv_output_dim": 80,                       # model dimension, the output dimension of the vision net
        "attr_heads": 10,                            # number of attribute heads
        "attr_net_hidden_dims": [128],               # attribute scatter transform MLP hidden dimension(s)
        "rel_heads": 80,                             # number of relationship heads
        "rel_net_hidden_dims": [64, 23, 5]
    }

SCL_model=SCLTrainingWrapper(scattering_transform.SCL(**kwargs))
# Move the model to the target device BEFORE building the optimiser and loading its state.
# Optimizer.load_state_dict casts the saved state (e.g. Adam moments) to the device of the
# parameters at load time. Loading while the parameters are still on the CPU would leave that
# state on the CPU while the parameters later live on CUDA.
SCL_model.to(device)
optimiser=torch.optim.Adam(SCL_model.parameters(),lr=lr)

if load:
    SCL_model.load_state_dict(load_dict['model state dict'])
    # This also restores the saved param_groups, so the checkpoint's learning rate replaces `lr`.
    optimiser.load_state_dict(load_dict['Optimiser state dict'])

    # The training loop writes the checkpoint after epoch `Epoch` has finished, so resume with the next one.
    start_epoch=load_dict['Epoch']+1
    run_id=load_dict['Run ID']
else:
    start_epoch=0




def _limit_reached(batch_idx, limit):
    """Return True once `batch_idx` reaches `limit`; a limit of 'full' never stops early."""
    return limit != 'full' and batch_idx >= limit


def _forward(batch):
    """Move one batch to `device`, split it into question/answer panels and run SCL_model.

    The batch is unpacked as (matrix, targets, _, _). Panels 0-7 are passed as questions and
    the remaining panels as answers.

    Returns:
        (logits, targets), both on `device`.
    """
    matrix, targets, _, _ = batch
    matrix, targets = matrix.to(device), targets.to(device)
    # Insert a singleton dim at axis 2. Presumably this is the greyscale channel matching
    # conv_channels[0] == 1 (TODO confirm).
    matrix = matrix.unsqueeze(2)
    questions, answers = matrix[:, 0:8], matrix[:, 8:]
    return SCL_model(questions, answers), targets


def _mean(values):
    """Mean of a list of floats, or NaN for an empty list (avoids numpy's empty-mean warning)."""
    return float(np.mean(values)) if values else float('nan')


if train:
    if logging:
        # Resume the original wandb run only when the checkpoint actually recorded one.
        if load and run_id is not None:
            run = wandb.init(id=run_id, project='AVR', resume='must')
        else:
            name = str(input('Log name: '))
            run = wandb.init(project='AVR', name=name)
        wandb.watch(SCL_model, log='all', log_freq=watch_freq)

    for epoch in range(start_epoch, n_epochs):
        print(f'Starting epoch {epoch}')

        #train
        # Re-enter train mode every epoch, because the validation pass below switches to eval().
        SCL_model.train()
        train_losses, train_accuracies = [], []
        for batch_idx, batch in enumerate(train_dataloader):
            if _limit_reached(batch_idx, n_batches):
                break
            print(f'Train step {batch_idx}')

            logits, targets = _forward(batch)
            train_loss = F.cross_entropy(logits, targets)
            train_accuracy = float(utils.calculate_accuracy(logits, targets))
            train_losses.append(train_loss.item())
            train_accuracies.append(train_accuracy)

            # Clear the previous step's gradients; otherwise backward() keeps accumulating them.
            optimiser.zero_grad()
            train_loss.backward()
            torch.nn.utils.clip_grad_norm_(SCL_model.parameters(), grad_clip_value)
            optimiser.step()

        #val
        # Validate on the validation split once at the end of every epoch, so the logged metrics
        # and the checkpoint name always describe this epoch's model.
        # NOTE(review): `val_every_n_steps` from the config no longer gates validation.
        SCL_model.eval()
        val_losses, val_accuracies = [], []
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_dataloader):
                if _limit_reached(batch_idx, n_val_batches):
                    break
                print(f'Validation step {batch_idx}')

                logits, targets = _forward(batch)
                val_loss = F.cross_entropy(logits, targets)
                val_accuracy = float(utils.calculate_accuracy(logits, targets))
                val_losses.append(val_loss.item())
                val_accuracies.append(val_accuracy)

        mean_val_accuracy = _mean(val_accuracies)

        # Train metrics are averaged over the epoch rather than taken from the last batch.
        if logging:
            wandb.log({'epoch': epoch,
                       'Loss/train': _mean(train_losses),
                       'Accuracy/train': _mean(train_accuracies),
                       'Loss/val': _mean(val_losses),
                       'Accuracy/val': mean_val_accuracy})

        #save
        # Save every `save_freq` epochs (skipping epoch 0), and always after the final epoch
        # so that a run never ends without a checkpoint.
        is_last_epoch = epoch == n_epochs - 1
        if (epoch % save_freq == 0 and epoch != 0) or is_last_epoch:
            os.makedirs(savedir, exist_ok=True)

            # Keys must match what the model-setup cell reads back when load=True.
            save_dict = {
                'kwargs': kwargs,
                'model state dict': SCL_model.state_dict(),
                'Optimiser state dict': optimiser.state_dict(),
                'Epoch': epoch,
                'Run ID': run.id if logging else None,
            }

            fname = os.path.join(savedir, f'epoch={epoch}_accuracy={mean_val_accuracy:.2f}.ckpt')
            with open(fname, 'wb') as f:
                pickle.dump(save_dict, f)
            print('Checkpoint saved')

        print(f'Finished epoch {epoch}')

    if logging:
        run.finish()


## Old code: single-batch forward-pass smoke test (not needed for training)

In [None]:
from data import IRAVENDataModule


# Scratch copy of the data setup: build the I-RAVEN data module and fetch its loaders.
batch_size = 16
data = IRAVENDataModule(batch_size=batch_size)

# Run the lifecycle hooks before requesting any dataloader.
data.prepare_data()
data.setup()

# Requested in the same order as before: train, test, val.
train_dataloader = data.train_dataloader()
test_dataloader = data.test_dataloader()
val_dataloader = data.val_dataloader()




import importlib

import torch

import scattering_transform

# Reload so that edits to scattering_transform.py are picked up without restarting the kernel.
importlib.reload(scattering_transform)


kwargs={
        "image_size":160,                            # size of image
        "set_size": 9,                               # number of questions + 1 answer
        "conv_channels": [1, 16, 16, 32, 32, 32],    # convolutional channel progression, 1 for greyscale, 3 for rgb
        "conv_output_dim": 80,                       # model dimension, the output dimension of the vision net
        "attr_heads": 10,                            # number of attribute heads
        "attr_net_hidden_dims": [128],               # attribute scatter transform MLP hidden dimension(s)
        "rel_heads": 80,                             # number of relationship heads
        "rel_net_hidden_dims": [64, 23, 5]
    }

model=scattering_transform.SCL(**kwargs)
wrapped_model=scattering_transform.SCLTrainingWrapper(model)

# Take one training batch and split it the same way the training loop does:
# panels 0-7 are questions, the rest are answers, each with a singleton dim inserted at axis 2.
X,y,_,_=next(iter(train_dataloader))
questions,answers=X[:,0:8].unsqueeze(2),X[:,8:].unsqueeze(2)

# This is only a smoke test of the forward pass, so skip building the autograd graph.
with torch.no_grad():
    out=wrapped_model(questions,answers)

out.shape
