In [None]:
import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl;
import importlib
import matplotlib.pyplot as plt
import numpy as np
import data
import utils
import sys
import wandb
import copy
import pickle

from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from abc import ABC,abstractmethod
from torch.utils.data import DataLoader, Subset
from copy import deepcopy
from torch.special import logit
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.loggers import WandbLogger
from timeit import default_timer as timer

# Re-import the local helper modules so edits made on disk are picked up
# without restarting the kernel.
importlib.reload(data)
importlib.reload(utils)

# Fix every RNG that Lightning knows about so runs are repeatable.
pl.seed_everything(42)

# When True, later cells use a small fraction of the data / batches.
debug=False

# Use the GPU when one is visible, otherwise fall back to the CPU.
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
importlib.reload(data)

batch_size=128
epochs=2
subset_frac=0.1

custom_dataset_test_split=0.2

data_path='datasets'

# Plain MNIST train / validation loaders (used for pretraining).
train_dataset = MNIST(data_path,train=True, transform=transforms.ToTensor())
train_dataloader = DataLoader(train_dataset,batch_size=batch_size)

val_dataset=MNIST(data_path,train=False,transform=transforms.ToTensor())
val_dataloader = DataLoader(val_dataset,batch_size=batch_size)

# `n` values handed to data.CustomDataModule for the two custom datasets.
n_task=5
n_task_=[1,2,3,4,6,7,8,9]

# In debug mode only a fraction of each custom dataset is requested;
# otherwise the data module is built with its own default.
datamodule_options={'dataset_frac':subset_frac} if debug else {}

custom1_datamodule=data.CustomDataModule(n=n_task,**datamodule_options)
custom1_datamodule.setup()
custom1_train_dataloader=custom1_datamodule.train_dataloader()
custom1_test_dataloader=custom1_datamodule.test_dataloader()

custom2_datamodule=data.CustomDataModule(n=n_task_,**datamodule_options)
custom2_datamodule.setup()
custom2_train_dataloader=custom2_datamodule.train_dataloader()
custom2_test_dataloader=custom2_datamodule.test_dataloader()



# Pull one batch from each loader for quick inspection.
X,y=next(iter(train_dataloader))
X_c,y_c=next(iter(custom1_train_dataloader))
X_c2,y_c2=next(iter(custom2_train_dataloader))


In [None]:
#regular MNIST Model

class SimpleModel(nn.Module):
    """Single linear layer mapping 28*28 input features to 10 outputs."""

    def __init__(self):
        super().__init__()
        self.l1=nn.Linear(in_features=28*28,out_features=10)

    def forward(self,x):
        """Apply the linear layer to `x` and return the result."""
        return self.l1(x)

class MNISTModel(pl.LightningModule):
    """Small convolutional classifier with 10 output logits.

    Two 3x3 convolutions (each followed by a 2x2 max-pool and ReLU) feed a
    two-layer fully connected head. Dropout is applied after the second
    convolution and after the first fully connected layer.

    Args:
        img_size: (height, width) of the single-channel input images.
    """

    def __init__(self,img_size):
        super().__init__()
        H,W=img_size
        nf1=10
        nf2=20

        self.conv1 = nn.Conv2d(1,nf1,kernel_size=3,stride=1,padding=1)
        # conv2 consumes conv1's output, so its input width is tied to nf1.
        self.conv2 = nn.Conv2d(nf1,nf2,kernel_size=3,stride=1,padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2,2))
        self.conv2_drop = nn.Dropout2d()

        # The padded 3x3 convs keep the spatial size; each 2x2 max-pool then
        # floor-halves it. Mirror that floor division so the flattened size
        # is also right when H or W is not a multiple of 4.
        pooled_h=int(H)//2//2
        pooled_w=int(W)//2//2
        lin_size=pooled_h*pooled_w*nf2
        self.fc1 = nn.Linear(in_features=lin_size,out_features=50)
        self.fc2 = nn.Linear(50,10)

        self.save_hyperparameters()


    def forward(self,x):
        """Return class logits of shape (batch, 10) for the image batch `x`."""

        N=x.size()[0]

        x=F.relu(self.maxpool2(self.conv1(x)))
        x=F.relu(self.maxpool2(self.conv2_drop(self.conv2(x)))) #dropout in conv layers
        x=x.view(N,-1)
        x=F.relu(self.fc1(x))
        x=F.dropout(x,training=self.training) #dropout in FFN layers
        x=self.fc2(x)
        return x

    def training_step(self,batch,batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x,y=batch
        y_hat=self(x)
        loss=F.cross_entropy(y_hat,y)
        self.log('train-loss',loss.item(),on_epoch=True)
        return loss

    def validation_step(self,batch,batch_idx,on_epoch=True):
        """Log cross-entropy loss and accuracy (in percent) for one validation batch."""
        X,y=batch
        y_hat=self(X)
        loss=F.cross_entropy(y_hat,y)
        probs=F.softmax(y_hat,dim=1)
        pred=torch.argmax(probs,axis=1)
        acc=(len(torch.nonzero(pred==y))/len(pred))*100

        self.log('val-loss',loss.item(),on_epoch=True)
        self.log('val-acc',acc,on_epoch=True)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 0.02."""
        return torch.optim.Adam(self.parameters(),lr=0.02)


class MNISTFFN(pl.LightningModule):
    """Feed-forward classifier: flatten followed by four stacked linear layers.

    The sub-modules live in `self.layers` at indices 0 (flatten) to 4, so the
    linear layers are named `layers.1` ... `layers.4`.
    """

    def __init__(self):
        super().__init__()

        stack=[
            nn.Flatten(start_dim=1,end_dim=-1),
            nn.Linear(28*28,256),
            nn.Linear(256,128),
            nn.Linear(128,64),
            nn.Linear(64,10),
        ]
        self.layers=nn.Sequential(*stack)

    def forward(self,x):
        """Return the 10 output logits for each sample in `x`."""
        return self.layers(x)

    def _loss_and_accuracy(self,batch):
        """Return (cross-entropy loss, accuracy in percent) for one batch."""
        inputs,targets=batch
        logits=self(inputs)
        loss=F.cross_entropy(logits,targets)
        predicted=torch.argmax(F.softmax(logits,dim=1),axis=1)
        n_correct=len(torch.nonzero(predicted==targets))
        accuracy=(n_correct/len(predicted))*100
        return loss,accuracy

    def training_step(self,batch,batch_idx):
        """Log training loss / accuracy and return the loss to optimise."""
        loss,acc=self._loss_and_accuracy(batch)

        #logging
        self.log('train-loss',loss.item())
        self.log('train-acc',acc)

        return loss

    def validation_step(self,batch,batch_idx,on_epoch=True):
        """Log validation loss / accuracy for one batch."""
        loss,acc=self._loss_and_accuracy(batch)

        #logging
        self.log('val-loss',loss.item())
        self.log('val-acc',acc)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 0.02."""
        optimiser=torch.optim.Adam(self.parameters(),lr=0.02)
        return optimiser



In [None]:
#pretraining model - should not need to be re-run often


checkpoint_dir='checkpoints'

# The run name is asked for interactively and reused as the wandb version.
name=str(input('Log name'))
wandb_logger=WandbLogger(project='AVR',name=name,version=name)

# Debug runs only touch a small fraction of the train / validation batches.
if debug:
    limit_train_batches=0.05
    limit_val_batches=0.01
else:
    limit_train_batches=1.0
    limit_val_batches=1.0

check_val_every=1

epochs=10
model=MNISTFFN()

#callbacks
ES_callback=EarlyStopping(monitor='val-loss',patience=100)
ckpt_callback=ModelCheckpoint(
    save_top_k=1, #save top 1 checkpoint,
    monitor='val-acc',
    mode='max',
    filename='max-val-acc-ckpt'
)


# ckpt_callback has to be registered with the Trainer; previously it was
# built but never passed in, so the best-val-acc checkpoint was never saved.
trainer=pl.Trainer(max_epochs=epochs,
                    check_val_every_n_epoch=check_val_every,
                    limit_train_batches=limit_train_batches,limit_val_batches=limit_val_batches,
                    logger=wandb_logger,callbacks=[ES_callback,ckpt_callback],
                    default_root_dir=checkpoint_dir)

trainer.fit(model,
            train_dataloaders=train_dataloader,val_dataloaders=val_dataloader)

wandb.finish()

In [None]:

# load_from_checkpoint is a classmethod that builds and returns the model
# itself, so call it on the class rather than on a throw-away instance.
modelA=MNISTFFN.load_from_checkpoint('logs/lightning_logs/version_7/checkpoints/epoch=4-step=4690.ckpt')
modelB=MNISTModel(img_size=(28,28)) #with conv enc layer

def get_children(model: torch.nn.Module):
    """Collect the leaf sub-modules of `model`, depth first.

    A module without children is returned as-is (not wrapped in a list);
    otherwise a flat list of all leaf modules found below it is returned.
    """
    direct_children = list(model.children())
    if not direct_children:
        # Nothing below this module: it is itself a leaf.
        return model

    leaves = []
    for sub_module in direct_children:
        found = get_children(sub_module)
        try:
            # `found` is a list when the child has children of its own ...
            leaves.extend(found)
        except TypeError:
            # ... and a single, non-iterable module when the child is a leaf.
            leaves.append(found)
    return leaves



def get_named_children(model):
    '''
    Map qualified module name -> module for every leaf module of `model`.

    IMPORTANT: a leaf is taken to be any module that has no children of its
    own. This assumption still needs checking.
    '''
    return {
        qualified_name: sub_module
        for qualified_name,sub_module in model.named_modules()
        if next(sub_module.children(),None) is None
    }

class MaskedModel():
    """Trains a binary mask over the parameters of a frozen model.

    One trainable logit is kept per parameter element. On every forward pass
    the logits are turned into hard 0/1 masks (with a straight-through
    gradient) and multiplied onto the frozen weights. Only models whose leaf
    modules can be applied one after another are supported, and only
    `nn.Linear` layers are masked.

    Args:
        model: pretrained model whose parameters are frozen and masked.
        train_dataloader: batches used to train the mask logits.
        test_dataloader1: batches of the task the mask is trained for.
        test_dataloader2: batches of the contrasting task (used by `eval`).
        tau: temperature used when sampling the relaxed mask.
    """

    def __init__(self,model,train_dataloader,test_dataloader1,test_dataloader2,tau=1):

        self.model=model
        self.train_dataloader=train_dataloader
        self.test_dataloader1=test_dataloader1
        self.test_dataloader2=test_dataloader2

        # One logit per parameter element of the model that was passed in
        # (not of whichever model happens to be bound to a notebook global).
        self.logit_tensors_dict={k:torch.nn.Parameter(data=torch.full_like(p,0.9)) for k,p in model.named_parameters()}
        self.alpha=None
        self.tau=tau
        self.logging=False

        self.train_epoch=0
        self.global_step=0

        #freeze model parameters
        for p in model.parameters():
            p.requires_grad=False
        self.param_dict=dict(model.named_parameters())

        self.leaf_modules=get_named_children(self.model)

        self.optimiser=torch.optim.Adam(self.logit_tensors_dict.values())

    def forward(self,x):
        '''
        Forward through masked model
        Masked model - Frozen pretrained model * binaries
        '''

        binaries=self.transform_logit_tensors()

        #apply mask to (frozen) tensors, leaf module by leaf module
        for layer_name,layer in self.leaf_modules.items():
            if isinstance(layer,nn.Linear):
                weight_=self.param_dict[layer_name+'.weight']*binaries[layer_name+'.weight']
                bias_=self.param_dict[layer_name+'.bias']*binaries[layer_name+'.bias']
                x=F.linear(x,weight=weight_,bias=bias_)
            else:
                x=layer(x)

        return x

    def train(self,alpha,n_batches=10,n_epochs=5,logging=False,
            val_every_n_steps=10,
            eval_every=10,n_eval_batches=5,norm_freq=5,set_log_name=False):
        """Optimise the mask logits.

        Args:
            alpha: weight of the regulariser (sum of all logits).
            n_batches: number of training batches per epoch; any value that
                never equals a batch index (e.g. 'full') uses every batch.
            n_epochs: number of epochs to run.
            logging: switch wandb logging on (it stays on once enabled).
            val_every_n_steps: run `validation` every this many steps.
            eval_every: run the ablation `eval` every this many epochs.
            n_eval_batches: number of batches used by `eval`.
            norm_freq: epoch interval at which logit / gradient stats are logged.
            set_log_name: ask for a wandb run name interactively.
        """

        run=None
        if logging:
            self.logging=True
        if self.logging:
            if set_log_name:
                log_name=str(input('Enter log name'))
                if log_name=='':
                    sys.exit()
            else:
                log_name=None
            run=wandb.init(project='AVR',name=log_name)

        loss=None
        for ep in range(n_epochs):

            #set class attributes to be used in other methods called in 'train()'
            self.train_epoch+=1
            self.alpha=alpha

            for batch_idx,batch in enumerate(self.train_dataloader):
                if batch_idx==n_batches:
                    break
                x,y=batch

                # Gradients accumulate across backward() calls, so clear them
                # before every optimisation step.
                self.optimiser.zero_grad()
                y_hat=self.forward(x)

                crossent_loss=F.cross_entropy(y_hat,y)
                reg_loss=alpha*torch.sum(torch.stack([torch.sum(logit_tens) for logit_tens in list(self.logit_tensors_dict.values())]))
                loss=crossent_loss+reg_loss
                acc=utils.calculate_accuracy(y_hat,y)
                loss.backward()
                self.optimiser.step()
                self.global_step+=1

                if self.logging:
                    wandb.log({'epoch':ep,
                                'train_loss':loss.item(),
                                'train_crossent_loss':crossent_loss.item(),
                                'train_reg_loss':reg_loss.item(),
                                'train_accuracy':acc
                                })

                    if ep%norm_freq==0 and ep!=0:
                        data=self.param_grad_norms()
                        table=wandb.Table(data=data,columns=['names','key','values'])
                        wandb.log({'Norm data':table})

                # Use the wandb step while logging; without a run fall back
                # to our own counter so validation still works.
                step=run.step if run is not None else self.global_step
                if (step%val_every_n_steps==0) and (step!=0):
                    self.validation()

            #evaluate via ablation and comparison to other tasks
            if (ep%eval_every==0):
                self.eval(self.test_dataloader1,self.test_dataloader2,n_batches=n_eval_batches)

            if loss is not None:
                print(f'Epoch: {ep}, Loss:{loss.item()}')

        if self.logging:
            wandb.finish()

    def transform_logit_tensors(self):
        """Sample hard 0/1 masks from the logits with a straight-through gradient.

        Returns:
            dict mapping parameter name -> mask tensor whose forward value is
            0 or 1 and whose gradient is that of the relaxed sigmoid sample.
        """

        tau=self.tau

        # NOTE(review): a single scalar noise draw is shared by every mask
        # element; confirm this is intended rather than per-element noise.
        U1 = torch.rand(1)
        U2 = torch.rand(1)
        noise=torch.log(torch.log(U1) / torch.log(U2))

        binaries={}
        for k,v in self.logit_tensors_dict.items():
            sample=torch.sigmoid((v - noise) / tau)
            with torch.no_grad():
                # hard threshold minus the soft value, without gradient
                hard_minus_soft=(sample>0.5).float()-sample
            binaries[k]=hard_minus_soft+sample

        return binaries

    def validation(self):
        '''
        Calculate accuracy of mask on validation set
        For now, we use test_dataset=val_dataset
        '''

        x,y=next(iter(self.test_dataloader1))
        with torch.no_grad():
            y_hat=self.forward(x)
            crossent_loss=F.cross_entropy(y_hat,y)
            reg_loss=self.alpha*torch.sum(torch.stack([torch.sum(logit_tens) for logit_tens in list(self.logit_tensors_dict.values())]))
            loss=crossent_loss+reg_loss
        acc=utils.calculate_accuracy(y_hat,y)

        if self.logging:
            wandb.log({
                        'validation_loss':loss.item(),
                        'validation_crossent_loss':crossent_loss.item(),
                        'validation_reg_loss':reg_loss.item(),
                        'validation_accuracy':acc
                        })

    def eval(self,task_eval_dataloader,_task_eval_dataloader,n_batches):
        '''
        Evaluated mask via ablation

        Ablation - frozen_parameters * ~binaries (inverted mask)
        '''
        print('start eval')
        #create masked model: a copy of the frozen model with the masked-in weights zeroed
        masked_model=copy.deepcopy(self.model)
        with torch.no_grad():
            binaries=self.transform_logit_tensors()

            #only linear layer compatibility at the moment
            for n,p in masked_model.named_parameters():
                invert_mask=(~(binaries[n].bool())).int() #to ablate TASK weights, we invert mask
                p.copy_(p*invert_mask) #copy in masked params

        acc1s=[]
        acc2s=[]

        for batch_idx,(batch1,batch2) in enumerate(zip(task_eval_dataloader,_task_eval_dataloader)):
            if batch_idx==n_batches:
                break
            x1,y1=batch1
            x2,y2=batch2

            pred_logits_1=masked_model(x1)
            pred_logits_2=masked_model(x2)

            acc1s.append(utils.calculate_accuracy(pred_logits_1,y1))
            acc2s.append(utils.calculate_accuracy(pred_logits_2,y2))

        acc1=round(np.mean(acc1s),2)
        acc2=round(np.mean(acc2s),2)

        if self.logging:
            wandb.define_metric("Eval accuracies",step_metric='epoch')
            wandb.log({'Eval accuracies':{"Task":acc1,"NOT task":acc2}})
        else:
            print({"Acc task'":acc1,"Acc not task":acc2})

        print('end eval')

    def param_grad_norms(self):
        """Return rows [name, key, value] with the mean of each logit tensor and of its gradient.

        The rows are labelled 'param norm' / 'grad norm' although the value
        is a mean. A logit tensor that has no gradient yet reports NaN.
        """
        data=[]
        for name,logits in self.logit_tensors_dict.items():
            data.append([name,'param norm',logits.mean().item()])
            grad_mean=logits.grad.mean().item() if logits.grad is not None else float('nan')
            data.append([name,'grad norm',grad_mean])

        return data



           


# Task the mask is trained on, and the contrasting task used for ablation.
task_train_dataloader=custom1_train_dataloader
task_test_dataloader=custom1_test_dataloader
_task_test_dataloader=custom2_test_dataloader


# Arguments forwarded to MaskedModel.train
kwargs=dict(
    alpha=1e-5,
    n_epochs=5,
    n_batches=10,
    val_every_n_steps=10,
    eval_every=2,
    n_eval_batches=1,
    norm_freq=100,
    logging=False,
    set_log_name=True,
)



# Disabled on purpose: flip the condition to train the first-draft MaskedModel.
if False:
    mm=MaskedModel(modelA,train_dataloader=task_train_dataloader,
        test_dataloader1=task_test_dataloader,test_dataloader2=_task_test_dataloader)
    mm.train(**kwargs)

    






In [None]:
class AbstractMaskedModel(ABC):
    """Base class for learning binary masks over a frozen, pretrained model.

    One trainable logit is kept for every parameter element of `model`. The
    `Masked*` helpers sample hard 0/1 masks from those logits (with a
    straight-through gradient), multiply them onto the frozen parameters and
    apply the matching functional layer. Subclasses implement `forward` by
    chaining these helpers in the order of the wrapped model's layers.

    Args:
        model: pretrained model; its parameters are frozen in place.
        train_dataloader: batches used to train the mask logits.
        test_dataloader1: batches of the task the mask is trained for.
        test_dataloader2: batches of the contrasting task (used by `eval`).
        savedir: directory for checkpoints; None disables saving.
    """

    def __init__(self,model,train_dataloader,test_dataloader1,test_dataloader2,savedir=None):

        self.model=model
        self.train_dataloader=train_dataloader
        self.test_dataloader1=test_dataloader1
        self.test_dataloader2=test_dataloader2
        #freeze model parameters
        for p in model.parameters():
            p.requires_grad=False
        self.param_dict=dict(model.named_parameters())
        self.leaf_modules=utils.get_named_children(self.model)
        self.savedir=savedir

        self.logit_tensors_dict={k:torch.nn.Parameter(data=torch.full_like(p,0.9)) for k,p in model.named_parameters()}
        self.alpha=None
        # Same default as train(); lets forward() work before training starts.
        self.tau=1
        self.logging=False #this attribute and below are set during training/loading
        self.logger=None
        self.run_id=None
        self.optimiser=torch.optim.Adam(self.logit_tensors_dict.values())

        self.global_step=0
        self.train_epoch=0

    @abstractmethod
    def forward(self,x,invert_mask=False):
        """Run `x` through the masked model; `invert_mask=True` ablates the masked-in weights."""
        pass

    def calculate_loss(self,y_hat,y):
        """Return (cross-entropy, regulariser, total loss, accuracy).

        The regulariser is `alpha` times the sum of all mask logits.
        """
        crossent_loss=F.cross_entropy(y_hat,y)
        reg_loss=self.alpha*torch.sum(torch.stack([torch.sum(logit_tens) for logit_tens in list(self.logit_tensors_dict.values())]))
        loss=crossent_loss+reg_loss
        acc=utils.calculate_accuracy(y_hat,y)

        return crossent_loss,reg_loss,loss,acc

    def _init_logger(self,set_log_name):
        """Start a wandb run, or resume the one recorded in a loaded checkpoint."""
        if set_log_name:
            log_name=str(input('Enter log name'))
            if log_name=='':
                sys.exit()
        else:
            log_name=None

        if self.run_id is not None:
            self.logger=wandb.init(id=self.run_id,project='AVR',resume='must')
        else:
            self.logger=wandb.init(project='AVR',name=log_name)
        # 'global_step' is logged as an ordinary value and used as the x-axis.
        wandb.define_metric('global_step')
        wandb.define_metric('train_loss',step_metric='global_step')

    def train(self,alpha,tau=1,n_epochs=5,n_batches=5,batch_split=4,
                    val_every_n_steps=10,n_val_batches=100,
                    eval_every=10,n_eval_batches=5,
                    logging=False,set_log_name=False,save_freq=10):
        """Optimise the mask logits.

        Args:
            alpha: weight of the regulariser (sum of all logits).
            tau: temperature used when sampling the relaxed mask.
            n_epochs: epoch index to stop at; training starts from
                `self.train_epoch`, so a loaded checkpoint continues.
            n_batches: number of training batches per epoch; any value that
                never equals a batch index (e.g. 'full') uses every batch.
            batch_split: number of chunks each batch is split into; the
                gradients of all chunks are accumulated into one step.
            val_every_n_steps: run `validation` every this many steps.
            n_val_batches: number of batches used by `validation`.
            eval_every: run the ablation `eval` every this many epochs.
            n_eval_batches: number of batches used by `eval`.
            logging: switch wandb logging on (it stays on once enabled).
            set_log_name: ask for a wandb run name interactively.
            save_freq: save a checkpoint every this many epochs.
        """

        #set class attributes for use in rest of class
        self.alpha=alpha
        self.tau=tau

        if logging:
            self.logging=True
        if self.logging:
            self._init_logger(set_log_name)

        last_loss=None
        for epoch in range(self.train_epoch,n_epochs):
            start_time=timer()
            for batch_idx,batch in enumerate(self.train_dataloader):
                if batch_idx==n_batches:
                    break

                # Gradients accumulate across backward() calls: clear the
                # previous step's gradients, then accumulate over the chunks.
                self.optimiser.zero_grad()

                train_loss=0
                split_X,split_y=torch.chunk(batch[0],batch_split),torch.chunk(batch[1],batch_split)
                for x,y in zip(split_X,split_y):
                    y_hat=self.forward(x)
                    crossent_loss,reg_loss,loss,acc=self.calculate_loss(y_hat,y)
                    train_loss+=loss.item()
                    loss.backward()
                    last_loss=loss.item()

                self.optimiser.step()

                if self.logging:
                    wandb.log({'train_loss':train_loss,'global_step':self.global_step})

                if (self.global_step%val_every_n_steps==0) and (self.global_step!=0):
                    self.validation(n_batches=n_val_batches)

                self.global_step+=1

            end_train_time=timer()

            #run ablation every n_ablation epochs
            if (epoch%eval_every==0) and (epoch!=0):
                self.eval(self.test_dataloader1,self.test_dataloader2,n_batches=n_eval_batches)
                end_eval_time=timer()

                train_time=end_train_time-start_time
                eval_time=end_eval_time-end_train_time
                print(f'Train time: {train_time} \n Eval time:{eval_time}')

            #save every n_save epochs
            # NOTE(review): the checkpoint is written before train_epoch is
            # incremented, so a resumed run repeats this epoch - confirm intent.
            if (self.savedir is not None) and (epoch%save_freq==0) and (epoch!=0):
                self.save()

            if last_loss is not None:
                print(f'Epoch: {epoch}, Loss:{last_loss}')
            self.train_epoch+=1

        if self.logging:
            wandb.finish()
        print('Training finished')

    def validation(self,n_batches):
        """Average loss / accuracy of the masked model over `n_batches` batches of test_dataloader1."""
        losses=[]
        val_accs=[]

        with torch.no_grad():
            for batch_idx,batch in enumerate(self.test_dataloader1):
                if batch_idx==n_batches:
                    break

                x,y=batch
                y_hat=self.forward(x)
                crossent_loss,reg_loss,loss,acc=self.calculate_loss(y_hat,y)
                losses.append((crossent_loss.item(),reg_loss.item(),loss.item()))
                val_accs.append(acc)

        val_crossent_loss=np.mean([_[0] for _ in losses])
        val_reg_loss=np.mean([_[1] for _ in losses])
        val_loss=np.mean([_[2] for _ in losses])
        val_accuracy=np.mean(val_accs)

        if self.logging:
            wandb.log({
                        'validation_loss':val_loss,
                        'validation_crossent_loss':val_crossent_loss,
                        'validation_reg_loss':val_reg_loss,
                        'validation_accuracy':val_accuracy
                        })

    def eval(self,task_eval_dataloader,_task_eval_dataloader,n_batches):
        '''
        Evaluated mask via ablation

        Ablation - frozen_parameters * ~binaries (inverted mask)
        '''
        acc1s=[]
        acc2s=[]

        with torch.no_grad():
            for batch_idx,(batch1,batch2) in enumerate(zip(task_eval_dataloader,_task_eval_dataloader)):
                if batch_idx==n_batches:
                    break
                x1,y1=batch1
                x2,y2=batch2

                pred_logits_1=self.forward(x1,invert_mask=True)
                pred_logits_2=self.forward(x2,invert_mask=True)

                acc1s.append(utils.calculate_accuracy(pred_logits_1,y1))
                acc2s.append(utils.calculate_accuracy(pred_logits_2,y2))

        acc1=round(np.mean(acc1s),2)
        acc2=round(np.mean(acc2s),2)

        if self.logging:
            # Log against the same 'global_step' value as the training loss.
            # Passing step= here would mix with wandb's implicit step counter
            # used by the other log calls and could get the data dropped.
            wandb.define_metric("Eval accuracies",step_metric='global_step')
            wandb.log({'Eval accuracies':{"Task":acc1,"NOT task":acc2},'global_step':self.global_step})
        else:
            print({"Acc task'":acc1,"Acc not task":acc2})

    def _masked_params(self,name,invert=False,use_bias=True):
        """Return (masked weight, masked bias or None) for the layer called `name`.

        With `invert=True` the mask is flipped through a boolean cast, which
        cuts the gradient to the logits - only use it for evaluation.
        """
        binaries=self.transform_logit_tensors()

        def masked(key):
            mask=binaries[key]
            if invert:
                mask=(~(mask.bool())).int()
            return self.param_dict[key]*mask

        weight=masked(name+'.weight')
        bias=masked(name+'.bias') if use_bias else None
        return weight,bias

    def MaskedLinear(self,x,name,invert=False):
        '''
        Linear layer `name` with masked weight and bias.
        invert detaches the mask from the comp graph, so should only be used during val
        '''
        masked_weight,masked_bias=self._masked_params(name,invert=invert)
        return F.linear(x,weight=masked_weight,bias=masked_bias)

    def MaskedConv2d(self,x,name,bias=False,invert=False):
        '''
        Conv2d layer `name` with masked weight (and masked bias when bias=True;
        with bias=False the convolution is applied without a bias term).
        invert detaches the mask from the comp graph, so should only be used during val
        '''
        module=self.leaf_modules[name]
        masked_weight,masked_bias=self._masked_params(name,invert=invert,use_bias=bias)
        return F.conv2d(x,weight=masked_weight,bias=masked_bias,
                        stride=module.stride,padding=module.padding,
                        dilation=module.dilation,groups=module.groups)

    def MaskedBatchNorm2d(self,x,name,invert=False):
        """BatchNorm2d layer `name` in inference mode with masked affine parameters."""
        module=self.leaf_modules[name]

        #these are approximations to feature mean + variance over whole dataset, calculated during training
        running_mean=module.running_mean
        running_var=module.running_var

        masked_weight,masked_bias=self._masked_params(name,invert=invert)
        # The masked bias is passed, not the raw 0/1 mask.
        return F.batch_norm(x,running_mean=running_mean,running_var=running_var,
                            weight=masked_weight,bias=masked_bias,eps=module.eps)

    def MaskedLayerNorm(self,x,name,invert=False):
        """LayerNorm layer `name` with masked affine parameters."""
        module=self.leaf_modules[name]
        masked_weight,masked_bias=self._masked_params(name,invert=invert)
        return F.layer_norm(x,normalized_shape=module.normalized_shape,
                            weight=masked_weight,bias=masked_bias,eps=module.eps)

    def transform_logit_tensors(self):
        """Sample hard 0/1 masks from the logits with a straight-through gradient.

        Returns:
            dict mapping parameter name -> mask tensor whose forward value is
            0 or 1 and whose gradient is that of the relaxed sigmoid sample.
        """

        tau=self.tau

        # NOTE(review): a single scalar noise draw is shared by every mask
        # element; confirm this is intended rather than per-element noise.
        U1 = torch.rand(1)
        U2 = torch.rand(1)
        noise=torch.log(torch.log(U1) / torch.log(U2))

        binaries={}
        for k,v in self.logit_tensors_dict.items():
            sample=torch.sigmoid((v - noise) / tau)
            with torch.no_grad():
                # hard threshold minus the soft value, without gradient
                hard_minus_soft=(sample>0.5).float()-sample
            binaries[k]=hard_minus_soft+sample

        return binaries

    def save(self):
        """Pickle the mask logits, optimiser and counters into `self.savedir`."""

        # savedir may be nested (e.g. 'model_ckpts/FFN'), so create parents too.
        os.makedirs(self.savedir,exist_ok=True)

        save_dict={}
        save_dict['alpha']=self.alpha
        save_dict['tau']=self.tau
        save_dict['global_step']=self.global_step
        save_dict['train_epoch']=self.train_epoch
        save_dict['logit_tensors_dict']=self.logit_tensors_dict
        save_dict['optimiser']=self.optimiser
        if self.logging:
            save_dict['run_id']=self.logger.id
        else:
            save_dict['run_id']=None

        fname=os.path.join(self.savedir,f'checkpoint_step={self.global_step}_epoch={self.train_epoch}')
        with open(fname,'wb') as f:
            pickle.dump(save_dict,f)
            print(f'Checkpoint step={self.global_step}, epoch={self.train_epoch} saved')

    def load(self,path):
        """Restore the state written by `save` from the file at `path`."""

        # NOTE: pickle can execute arbitrary code - only load checkpoints
        # written by this notebook.
        with open (path,'rb') as f:
            load_dict=pickle.load(f)

        self.alpha=load_dict.get('alpha')
        self.tau=load_dict.get('tau')
        self.global_step=load_dict.get('global_step')
        self.train_epoch=load_dict.get('train_epoch')
        self.logit_tensors_dict=load_dict.get('logit_tensors_dict')
        self.optimiser=load_dict.get('optimiser')
        self.run_id=load_dict.get('run_id')
    



class MaskedMNISTFFN(AbstractMaskedModel):
    """Masked counterpart of MNISTFFN: flatten, then the four masked linear layers."""

    def __init__(self,kwargs):
        super().__init__(**kwargs)

        #layer without trainable mask
        self.layer0=nn.Flatten(start_dim=1,end_dim=-1)

    def forward(self, x, invert_mask=False):
        """Apply `layers.1` ... `layers.4` in order, each with its sampled mask."""
        out=self.layer0(x)
        for layer_name in ('layers.1','layers.2','layers.3','layers.4'):
            out=self.MaskedLinear(out,name=layer_name,invert=invert_mask)
        return out



class MaskedMNISTConv(AbstractMaskedModel):
    """Masked counterpart of MNISTModel (conv1, conv2, fc1, fc2)."""

    def __init__(self,kwargs):
        super().__init__(**kwargs)

        #initialise layers that mask not trained on
        #should implement method to check if we've done this right
        self.maxpool_2=nn.MaxPool2d(kernel_size=(2,2))
        # Channel-wise dropout, matching MNISTModel.conv2_drop.
        self.conv2_drop=nn.Dropout2d()

    def forward(self,x,invert_mask=False):
        """Mirror MNISTModel.forward with every conv / linear layer masked."""

        N=x.size()[0]

        # This wrapper is not an nn.Module and has no train/eval switch of its
        # own, so dropout follows the mode of the wrapped (frozen) model
        # instead of being always on.
        use_dropout=self.model.training
        self.conv2_drop.train(use_dropout)

        # The conv layers of MNISTModel have biases, so they are masked too.
        x=F.relu(self.maxpool_2(self.MaskedConv2d(x,name='conv1',bias=True,invert=invert_mask)))
        x=F.relu(self.maxpool_2(self.conv2_drop(self.MaskedConv2d(x,name='conv2',bias=True,invert=invert_mask))))
        x=x.view(N,-1)
        x=F.relu(self.MaskedLinear(x,name='fc1',invert=invert_mask))
        x=F.dropout(x,training=use_dropout)
        x=self.MaskedLinear(x,name='fc2',invert=invert_mask)
        return x


class MaskedSCLModel(AbstractMaskedModel):
    """Work-in-progress masked version of the SCL model.

    Only the vision network is implemented. The attribute, feed-forward
    residual, relation and logit stages raise NotImplementedError until they
    are written, so `forward` cannot run end to end yet.
    """

    # Leaf module types the vision network is expected to consist of.
    _KNOWN_VISION_MODULES=(nn.Conv2d,nn.Linear,nn.BatchNorm2d,nn.LayerNorm,nn.Flatten,nn.ReLU)

    def __init__(self,kwargs):
        super().__init__(**kwargs)

        self.flatten_layer=nn.Flatten(1)
        self.relu=nn.ReLU(inplace=True)

    def _check_vision_modules(self):
        """Raise if the vision network contains a leaf module type that is not handled."""
        for name,module in self.leaf_modules.items():
            if 'vision' in name and not isinstance(module,self._KNOWN_VISION_MODULES):
                raise Exception('Unrecognised module')

    def MaskedVisionNet(self,x,invert_mask=False):
        """Masked forward pass through the `vision.net.*` layers.

        The layers are applied exactly once, in the hard-coded order below
        (the earlier generic loop ran them a second time and sent batch-norm
        layers through MaskedConv2d).
        NOTE(review): no activation is applied between the conv / batch-norm
        pairs here - confirm against the wrapped vision network.
        """
        self._check_vision_modules()

        # four conv -> batch-norm pairs: vision.net.0/1 ... vision.net.6/7
        for conv_idx in (0,2,4,6):
            x=self.MaskedConv2d(x,name=f'vision.net.{conv_idx}',bias=True,invert=invert_mask)
            x=self.MaskedBatchNorm2d(x,name=f'vision.net.{conv_idx+1}',invert=invert_mask)

        x_conv_out=self.MaskedConv2d(x,name='vision.net.8',bias=True,invert=invert_mask)

        x1=self.flatten_layer(x_conv_out)
        x1=self.MaskedLinear(x1,name='vision.net.10',invert=invert_mask)
        x1=self.relu(x1)

        #feedforward residual layer
        x2=self.MaskedLinear(x1,name='vision.net.12.net.0',invert=invert_mask)
        x2=self.MaskedLayerNorm(x2,name='vision.net.12.net.1',invert=invert_mask)
        x2=self.relu(x2)
        x2=self.MaskedLinear(x2,name='vision.net.12.net.3',invert=invert_mask)
        out=x2+x1

        return out

    def MaskedAttrNet(self,x,invert_mask=False):
        """Masked attribute network (not implemented yet)."""
        raise NotImplementedError('MaskedAttrNet is not implemented yet')

    def MaskedFFResidual(self,x,invert_mask=False):
        """Masked feed-forward residual block (not implemented yet)."""
        raise NotImplementedError('MaskedFFResidual is not implemented yet')

    def MaskedRelNet(self,x,invert_mask=False):
        """Masked relation network (not implemented yet)."""
        raise NotImplementedError('MaskedRelNet is not implemented yet')

    def MaskedToLogit(self,x,invert_mask=False):
        """Masked logit head (not implemented yet)."""
        raise NotImplementedError('MaskedToLogit is not implemented yet')

    def forward(self, x, invert_mask=False):
        """Chain the masked stages: vision -> attributes -> relations -> logits.

        `x` must be 6-dimensional (b, m, n, c, h, w); the first three
        dimensions are folded together before the vision network.
        """

        b,m,n,c,h,w=x.shape
        images=x.view(-1,c,h,w)

        # Each stage consumes the previous stage's output.
        features=self.MaskedVisionNet(images,invert_mask=invert_mask)
        attrs=self.MaskedAttrNet(features,invert_mask=invert_mask)
        attrs=self.MaskedFFResidual(attrs,invert_mask=invert_mask)

        rels=self.MaskedRelNet(attrs,invert_mask=invert_mask)
        rels=rels.flatten(2)

        logits=self.MaskedToLogit(rels,invert_mask=invert_mask).flatten(1)

        return logits

from scattering_transform import SCLTrainingWrapper, SCL

# Constructor arguments for the SCL network.
scl_kwargs = {
    "image_size": 160,                         # size of image
    "set_size": 9,                             # number of questions + 1 answer
    "conv_channels": [1, 16, 16, 32, 32, 32],  # convolutional channel progression, 1 for greyscale, 3 for rgb
    "conv_output_dim": 80,                     # model dimension, the output dimension of the vision net
    "attr_heads": 10,                          # number of attribute heads
    "attr_net_hidden_dims": [128],             # attribute scatter transform MLP hidden dimension(s)
    "rel_heads": 80,                           # number of relationship heads
    "rel_net_hidden_dims": [64, 23, 5],
}

# Short aliases for the dataloaders built in the data-setup cell.
task_train_dataloader = custom1_train_dataloader
task_test_dataloader = custom1_test_dataloader
_task_test_dataloader = custom2_test_dataloader

# NOTE(review): the masked forward addresses layers as 'vision.net.*', which
# looks like the attribute layout of a bare SCL rather than the training
# wrapper - confirm that passing the wrapped model is intended.
model = SCLTrainingWrapper(SCL(**scl_kwargs))

# Single dict argument consumed by the masked-model constructor.
kwargs = {
    'model': model,
    'train_dataloader': task_train_dataloader,
    'test_dataloader1': task_test_dataloader,
    'test_dataloader2': _task_test_dataloader,
    'savedir': 'model_ckpts/FFN',
}

# A second, independent wrapped SCL instance (not the one stored in `kwargs`).
SCL_model = SCLTrainingWrapper(SCL(**scl_kwargs))
masked_scl = MaskedSCLModel(kwargs)

In [None]:

# Reminder: re-run the cell above first if the AMM class has been changed.

# Short aliases for the dataloaders built in the data-setup cell.
task_train_dataloader = custom1_train_dataloader
task_test_dataloader = custom1_test_dataloader
_task_test_dataloader = custom2_test_dataloader

# NOTE(review): absolute, machine-specific checkpoint path; it will not
# resolve on another machine.
model = MNISTFFN.load_from_checkpoint(
    '/Users/iyngkarrankumar/Documents/AI/AVR-functional-modularity/model_ckpts/MNISTFFNepoch=9-step=9380.ckpt'
)

# Single dict argument consumed by the masked-model constructor.
kwargs = {
    'model': model,
    'train_dataloader': custom1_train_dataloader,
    'test_dataloader1': custom1_test_dataloader,
    'test_dataloader2': custom2_test_dataloader,
    'savedir': 'model_ckpts/FFN',
}

# Keyword arguments forwarded to `mm1.train` when training is enabled below.
train_kwargs = {
    'alpha': 1e-5,
    'n_epochs': 10,
    'n_batches': 1,
    'val_every_n_steps': 100,
    'n_val_batches': 1,
    'eval_every': 1,
    'n_eval_batches': 2,
    'logging': True,
    'set_log_name': True,
    'batch_split': 10,
    'save_freq': 4,
}

mm1 = MaskedMNISTFFN(kwargs)
mm1.load('/Users/iyngkarrankumar/Documents/AI/AVR-functional-modularity/model_ckpts/FFN/checkpoint_step=5_epoch=4')

# Training is switched off; flip the condition to run it.
if False:
    mm1.train(**train_kwargs)

In [None]:

test=False

# Default checkpoint of the pretrained MNIST FFN used by the sweep.
# NOTE(review): absolute, machine-specific path; pass `checkpoint_path` to override.
DEFAULT_MNISTFFN_CKPT='/Users/iyngkarrankumar/Documents/AI/AVR-functional-modularity/model_ckpts/MNISTFFNepoch=9-step=9380.ckpt'

def sweep_function(test=test,model_type='FFN',checkpoint_path=DEFAULT_MNISTFFN_CKPT):
    """Run one wandb sweep trial: train a masked MNIST model.

    ``alpha`` and ``n_epochs`` are read from ``wandb.config`` after the run is
    initialised; everything else is fixed here.

    Args:
        test: when truthy, train on 5 batches instead of ``'full'``.
        model_type: ``'FFN'`` wraps the model in ``MaskedMNISTFFN``; ``'Conv'``
            wraps it in ``MaskedMNISTConv``.
        checkpoint_path: Lightning checkpoint the ``MNISTFFN`` is loaded from.

    Raises:
        ValueError: if ``model_type`` is neither ``'FFN'`` nor ``'Conv'``.
            Raised before a wandb run is opened or a checkpoint is loaded.
    """
    # Fail fast so an invalid argument does not leave an empty wandb run behind.
    if model_type not in ('FFN','Conv'):
        raise ValueError('Enter valid model type')

    wandb.init(project='AVR')

    # load_from_checkpoint is a classmethod: call it on the class. Calling it
    # on a fresh instance builds a throwaway model and is rejected by
    # Lightning 2.x.
    # NOTE(review): the MNISTFFN checkpoint is loaded for 'Conv' as well -
    # confirm that MaskedMNISTConv is meant to wrap this model.
    model=MNISTFFN.load_from_checkpoint(checkpoint_path)

    alpha=wandb.config.alpha
    n_epochs=wandb.config.n_epochs

    mm_kwargs={
        'model':model,
        'train_dataloader':custom1_train_dataloader,
        'test_dataloader1':custom1_test_dataloader,
        'test_dataloader2':custom2_test_dataloader
    }

    train_kwargs={
        'n_batches':5 if test else 'full',
        'val_every_n_steps':10,
        'eval_every':2,
        'n_eval_batches':100,
        'norm_freq':100,
        'logging':True,
    }

    print(f'Model type:{model_type}')
    if model_type=='FFN':
        masked_model=MaskedMNISTFFN(mm_kwargs)
    else:
        masked_model=MaskedMNISTConv(mm_kwargs)

    masked_model.train(alpha=alpha,n_epochs=n_epochs,**train_kwargs)





In [None]:
# Grid sweep over `alpha`; `n_epochs` is fixed (2 when `test` is set, else 10).
# The sweep name is asked for interactively while the dict is being built.
sweep_configuration = {
    'method': 'grid',
    'name': input('Enter sweep name'),
    'metric': {
        'goal': 'maximize',
        'name': 'validation_accuracy',
    },
    'parameters': {
        'alpha': {'values': [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0]},
        'n_epochs': {'value': 2 if test else 10},
    },
}

# Register the sweep, run trials through `sweep_function`, then close wandb.
sweep_id = wandb.sweep(sweep=sweep_configuration, project='AVR')
wandb.agent(sweep_id, function=sweep_function)
wandb.finish()

In [None]:
# SCL: build a masked model around a freshly constructed SCL network.
import scattering_transform
from scattering_transform import SCLTrainingWrapper
import utils
import torch

# Constructor arguments for the SCL network.
scl_kwargs = {
    "image_size": 160,                         # size of image
    "set_size": 9,                             # number of questions + 1 answer
    "conv_channels": [1, 16, 16, 32, 32, 32],  # convolutional channel progression, 1 for greyscale, 3 for rgb
    "conv_output_dim": 80,                     # model dimension, the output dimension of the vision net
    "attr_heads": 10,                          # number of attribute heads
    "attr_net_hidden_dims": [128],             # attribute scatter transform MLP hidden dimension(s)
    "rel_heads": 80,                           # number of relationship heads
    "rel_net_hidden_dims": [64, 23, 5],
}

# Bare SCL network (not wrapped in SCLTrainingWrapper).
SCL_model = scattering_transform.SCL(**scl_kwargs)

# Reminder: re-run the cell above first if the AMM class has been changed.

# Short aliases for the dataloaders built in the data-setup cell.
task_train_dataloader = custom1_train_dataloader
task_test_dataloader = custom1_test_dataloader
_task_test_dataloader = custom2_test_dataloader

# Single dict argument consumed by the masked-model constructor.
# NOTE(review): 'savedir' points at the FFN checkpoint directory although the
# model here is an SCL - confirm this is intended.
kwargs = {
    'model': SCL_model,
    'train_dataloader': custom1_train_dataloader,
    'test_dataloader1': custom1_test_dataloader,
    'test_dataloader2': custom2_test_dataloader,
    'savedir': 'model_ckpts/FFN',
}

# Keyword arguments forwarded to `mm1.train` when training is enabled below.
train_kwargs = {
    'alpha': 1e-5,
    'n_epochs': 10,
    'n_batches': 1,
    'val_every_n_steps': 100,
    'n_val_batches': 1,
    'eval_every': 1,
    'n_eval_batches': 2,
    'logging': True,
    'set_log_name': True,
    'batch_split': 10,
    'save_freq': 4,
}

mm1 = MaskedSCLModel(kwargs)

# Training is switched off; flip the condition to run it.
if False:
    mm1.train(**train_kwargs)