In [None]:
import importlib
import models.cnn_mlp
import data
import utils
import torch
import torch.nn.functional as F
import wandb
import numpy as np
import os
import pickle

In [None]:
importlib.reload(data)
importlib.reload(models.cnn_mlp)
importlib.reload(utils)

# Re-import after reloading so the name is bound to the freshly reloaded class.
from data import IRAVENDataModule

#run flags
load=False     # resume model/optimiser state from the checkpoint at load_path
train=True     # run the training loop
logging=False  # log metrics (and store the run id) with Weights & Biases
debug=True     # short run: fewer epochs and batches per epoch

#load,save,log hyperparams
load_path=None
save_freq=1e10   # checkpoint every save_freq epochs (never at epoch 0); 1e10 effectively disables saving
savedir='model_ckpts/pretrain_SCL'
watch_freq=1e10  # log_freq passed to wandb.watch; a huge value keeps gradient logging minimal
device='cpu'
seed=0           # seeds the torch and numpy RNGs so runs are repeatable

#train hyperparams
n_epochs=2 if debug else 10
n_batches=10 if debug else 'full'     # max train batches per epoch, or 'full' for all
n_val_batches=2 if debug else 'full'  # max validation batches per epoch, or 'full' for all
grad_clip_value=0.5                   # max global L2 norm of the gradients
lr=1e-4

#data hyperparams
batch_size=16
split=(90,10,0)  # passed to IRAVENDataModule; presumably (train, val, test) percentages — TODO confirm

# Seed before the datamodule and model are built so torch/numpy-driven
# randomness (e.g. default weight init) repeats across runs.
torch.manual_seed(seed)
np.random.seed(seed)

#----------------------------------------


# Resolve the device string once; the model and the batches are moved here later.
device=torch.device(device)
print(f'Using device {device}')


#data
# Build the I-RAVEN datamodule, then fetch each of its loaders.
data_module=IRAVENDataModule(
    batch_size=batch_size,
    split=split,
)
data_module.prepare_data()
data_module.setup()
train_dataloader=data_module.train_dataloader()
test_dataloader=data_module.test_dataloader()
val_dataloader=data_module.val_dataloader()

# Peek at one training batch for interactive inspection: the first two
# elements are unpacked as x and y, anything further lands in `rest`.
x,y,*rest=next(iter(train_dataloader))



#setup model
if load:
    print(f'Loading model from checkpoint {load_path}')
    # NOTE(security): checkpoints are plain pickle files and pickle.load can
    # execute arbitrary code — only load checkpoints you created yourself.
    with open(load_path,'rb') as f:
        load_dict=pickle.load(f)
    kwargs=load_dict['kwargs']
else:
    # Constructor kwargs for CNN_MLP; empty means the model's own defaults.
    kwargs={}

CNN_MLP_model=models.cnn_mlp.CNN_MLP(**kwargs)
# Move the model to the target device BEFORE building the optimiser and
# loading its state: Optimizer.load_state_dict casts restored state tensors to
# the device of the parameters at load time, so doing it after the move keeps
# the Adam moments on the same device as the weights.
CNN_MLP_model.to(device)
optimiser=torch.optim.Adam(CNN_MLP_model.parameters(),lr=lr)

if load:
    CNN_MLP_model.load_state_dict(load_dict['model state dict'])
    optimiser.load_state_dict(load_dict['Optimiser state dict'])
    # The checkpoint is written after epoch `Epoch` has finished, so resume
    # from the following epoch rather than re-training the saved one.
    start_epoch=load_dict['Epoch']+1
    run_id=load_dict['Run ID']
else:
    start_epoch=0


#train loop
if train:
    if logging:
        if load:
            # Continue the wandb run whose id was stored in the checkpoint.
            run=wandb.init(id=run_id,project='AVR',resume='must')
        else:
            name=input('Log name: ')
            run=wandb.init(project='AVR',name=name)
        wandb.watch(CNN_MLP_model,log='all',log_freq=watch_freq)

    for epoch in range(start_epoch,n_epochs):
        print(f'Starting epoch {epoch}')

        #train
        CNN_MLP_model.train()
        train_losses=[]      # per-batch training losses for this epoch
        train_accuracies=[]  # per-batch training accuracies for this epoch
        for batch_idx,batch in enumerate(train_dataloader):
            # n_batches is either an int cap on batches per epoch or 'full'.
            if n_batches!='full' and batch_idx>=n_batches:
                break

            x,y,*rest=batch
            x,y=x.to(device),y.to(device)

            optimiser.zero_grad()
            logits=CNN_MLP_model(x)
            train_loss=F.cross_entropy(logits,y)
            train_accuracy=utils.calculate_accuracy(logits,y)
            train_loss.backward()
            # Rescale gradients so their global L2 norm is at most grad_clip_value.
            torch.nn.utils.clip_grad_norm_(CNN_MLP_model.parameters(),grad_clip_value)
            optimiser.step()

            train_losses.append(train_loss.item())
            train_accuracies.append(train_accuracy)

        #val every epoch (on the validation split, not the test split)
        CNN_MLP_model.eval()
        val_losses=[]      # per-batch validation losses
        val_accuracies=[]  # per-batch validation accuracies
        with torch.no_grad():
            for batch_idx,batch in enumerate(val_dataloader):
                if n_val_batches!='full' and batch_idx>=n_val_batches:
                    break

                x,y,*rest=batch
                x,y=x.to(device),y.to(device)
                logits=CNN_MLP_model(x)
                val_loss=F.cross_entropy(logits,y)
                val_accuracy=utils.calculate_accuracy(logits,y)
                val_losses.append(val_loss.item())
                val_accuracies.append(val_accuracy)

        # Epoch-level means; nan when a loader yielded no batches (avoids
        # np.mean's empty-slice warning and undefined per-batch names).
        # Assumes utils.calculate_accuracy returns something np.mean accepts
        # (Python/NumPy scalar or CPU tensor) — TODO confirm.
        mean_train_loss=np.mean(train_losses) if train_losses else float('nan')
        mean_train_accuracy=np.mean(train_accuracies) if train_accuracies else float('nan')
        mean_val_loss=np.mean(val_losses) if val_losses else float('nan')
        mean_val_accuracy=np.mean(val_accuracies) if val_accuracies else float('nan')

        #logging
        if logging:
            wandb.log({
                'epoch':epoch,
                'Loss/train':mean_train_loss,
                'Accuracy/train':mean_train_accuracy,
                'Loss/val':mean_val_loss,
                'Accuracy/val':mean_val_accuracy
                })

        #save
        if (epoch%save_freq==0) and (epoch!=0):
            # makedirs (not mkdir) so a missing parent such as 'model_ckpts' is created too.
            os.makedirs(savedir,exist_ok=True)

            save_dict={
                'kwargs':kwargs,
                'model state dict':CNN_MLP_model.state_dict(),
                'Optimiser state dict':optimiser.state_dict(),
                # Written after the epoch completes; resuming starts at Epoch+1.
                'Epoch':epoch,
                'Run ID':run.id if logging else None,
            }

            fname=os.path.join(savedir,f'epoch={epoch}_accuracy={mean_val_accuracy}.ckpt')
            with open(fname,'wb') as f:
                pickle.dump(save_dict,f)
            print('Checkpoint saved')

        print(f'Finished epoch {epoch} - validation loss {mean_val_loss}')





In [None]:
# Scratch cell: builds a fresh CNN_MLP with its default constructor arguments.
# NOTE(review): this rebinds CNN_MLP_model, discarding the model trained in the
# training cell and ignoring any `kwargs` loaded from a checkpoint — rename the
# variable if the trained model is still needed.
CNN_MLP_model=models.cnn_mlp.CNN_MLP()