In [None]:
import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl;
import importlib
import matplotlib.pyplot as plt
import numpy as np
import data
import utils
import sys
import wandb
import copy

from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Subset
from copy import deepcopy
from torch.special import logit
from pytorch_lightning import loggers as pl_loggers
from torch.utils.tensorboard import SummaryWriter

# Re-import the local helper modules so edits made to them on disk are picked up.
for _module in (data, utils):
    importlib.reload(_module)

# Seed the random number generators through Lightning's helper.
pl.seed_everything(42)

# When True, later cells use a fraction of the data / batches.
debug=False

# Prefer the GPU when one is visible to torch.
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:

def Gumbel_Sigmoid(tens,T=1):
    """Draw a relaxed Bernoulli sample for every element of ``tens``.

    Args:
        tens: tensor of logits.
        T: temperature dividing the noisy logits before the sigmoid.

    Returns:
        Tensor with the same shape as ``tens`` holding sigmoid outputs.
    """
    # Two independent uniform draws per element (first draw feeds the numerator).
    log_u_num=torch.rand_like(tens).log()
    log_u_den=torch.rand_like(tens).log()
    noise=torch.log(log_u_num/log_u_den)
    return torch.sigmoid((tens-noise)/T)

def calculate_accuracy(pred_logits,true_idxs):
    """Return the percentage of rows whose arg-max class equals the target index.

    Args:
        pred_logits: tensor of per-class scores; classes are along dim 1.
        true_idxs: tensor of target class indices, one per row.

    Returns:
        Accuracy as a Python float in the range 0-100.
    """
    class_probs=F.softmax(pred_logits,dim=1)
    predicted=class_probs.argmax(dim=1)
    n_correct=int((predicted==true_idxs).sum())
    return (n_correct/len(predicted))*100


def indicator(tens,threshold=0.5,below=0,above=1):
    """Element-wise step function.

    Args:
        tens: input tensor.
        threshold: cut-off value; elements ``>= threshold`` count as "above".
        below: value written where ``tens < threshold``.
        above: value written where ``tens >= threshold``.

    Returns:
        Tensor with the shape and dtype of ``tens`` containing only ``below``
        and ``above``.
    """
    # The previous double F.threshold construction leaked ``-tens`` for
    # negative inputs and overwrote ``above`` with ``below`` when above <= 0.
    return torch.where(tens>=threshold,
                       torch.full_like(tens,above),
                       torch.full_like(tens,below))


In [2]:
importlib.reload(data)

batch_size=64
epochs=2
subset_frac=0.1

custom_dataset_test_split=0.2

data_path='datasets'

# Plain MNIST train / validation loaders.
train_dataset = MNIST(data_path,train=True, transform=transforms.ToTensor())
val_dataset=MNIST(data_path,train=False,transform=transforms.ToTensor())
train_dataloader = DataLoader(train_dataset,batch_size=batch_size)
val_dataloader = DataLoader(val_dataset,batch_size=batch_size)

# In debug mode the custom data modules are built with a reduced dataset fraction;
# otherwise they are built with their defaults.
datamodule_kwargs={'dataset_frac':subset_frac} if debug else {}

custom1_datamodule=data.CustomDataModule(n=5,**datamodule_kwargs)
custom1_datamodule.setup()
custom1_train_dataloader=custom1_datamodule.train_dataloader()
custom1_test_dataloader=custom1_datamodule.test_dataloader()

custom2_datamodule=data.CustomDataModule(n=9,**datamodule_kwargs)
custom2_datamodule.setup()
custom2_train_dataloader=custom2_datamodule.train_dataloader()
custom2_test_dataloader=custom2_datamodule.test_dataloader()

# Pull one batch from each training loader for quick inspection.
X,y=next(iter(train_dataloader))
X_c,y_c=next(iter(custom1_train_dataloader))
X_c2,y_c2=next(iter(custom2_train_dataloader))


NameError: name 'importlib' is not defined

In [1]:
#regular MNIST Model

class SimpleModel(nn.Module):
    """Single linear layer mapping 28*28 input features to 10 outputs.

    The input is not flattened here; callers must pass tensors whose last
    dimension is 28*28.
    """

    def __init__(self):
        super().__init__()
        self.l1=nn.Linear(in_features=28*28,out_features=10)

    def forward(self,x):
        """Return the linear layer's output for ``x``."""
        return self.l1(x)

class MNISTModel(pl.LightningModule):
    """Small convolutional classifier: two conv layers followed by two linear layers."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=10,kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10,out_channels=20,kernel_size=5)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2,2))
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(in_features=320,out_features=50)
        self.fc2 = nn.Linear(in_features=50,out_features=10)

    def forward(self,x):
        """Return the 10 class logits for a batch of images."""
        h=self.conv1(x)
        h=F.relu(self.maxpool2(h))
        h=self.conv2_drop(self.conv2(h)) #dropout in conv layers
        h=F.relu(self.maxpool2(h))
        # 320 = 20 channels * 4 * 4, which is what 28x28 inputs reduce to here
        h=h.view(-1,320)
        h=F.relu(self.fc1(h))
        h=F.dropout(h,training=self.training) #dropout in FFN layers
        return self.fc2(h)

    def training_step(self,batch,batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        inputs,targets=batch
        loss=F.cross_entropy(self(inputs),targets)
        self.log('train loss',loss.item(),on_epoch=True)
        return loss

    def validation_step(self,batch,batch_idx,on_epoch=True):
        """Log cross-entropy loss and percentage accuracy for one validation batch."""
        inputs,targets=batch
        scores=self(inputs)
        loss=F.cross_entropy(scores,targets)
        predicted=F.softmax(scores,dim=1).argmax(dim=1)
        n_correct=len(torch.nonzero(predicted==targets))
        acc=(n_correct/len(predicted))*100

        self.log('val loss',loss.item(),on_epoch=True)
        self.log('val acc',acc,on_epoch=True)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 0.02."""
        return torch.optim.Adam(self.parameters(),lr=0.02)

MNISTConvModel=MNISTModel()

class MNISTFFN(pl.LightningModule):
    """Fully connected classifier: Flatten followed by four Linear layers.

    Note that there are no activation functions between the Linear layers.
    Metrics are written straight to the logger's ``experiment`` object via
    ``add_scalars``.
    """

    def __init__(self):
        super().__init__()

        # Flatten sits at index 0 and the Linear layers at indices 1-4 of self.layers.
        widths=[28*28,256,128,64,10]
        stack=[nn.Flatten(start_dim=1,end_dim=-1)]
        for fan_in,fan_out in zip(widths[:-1],widths[1:]):
            stack.append(nn.Linear(fan_in,fan_out))
        self.layers=nn.Sequential(*stack)

    def forward(self,x):
        """Return the 10 class logits for a batch."""
        return self.layers(x)

    def _loss_and_accuracy(self,batch):
        """Return (cross-entropy loss tensor, percentage accuracy) for ``batch``."""
        inputs,targets=batch
        scores=self(inputs)
        loss=F.cross_entropy(scores,targets)
        predicted=F.softmax(scores,dim=1).argmax(dim=1)
        n_correct=len(torch.nonzero(predicted==targets))
        return loss,(n_correct/len(predicted))*100

    def training_step(self,batch,batch_idx):
        """Compute the training loss and log loss / accuracy under the 'train' tag."""
        loss,acc=self._loss_and_accuracy(batch)

        #logging
        board=self.logger.experiment
        board.add_scalars('Pretrained loss',{'train':loss.item()},self.global_step)
        board.add_scalars('Pretrained accuracy',{'train':acc},self.global_step)

        return loss

    def validation_step(self,batch,batch_idx,on_epoch=True):
        """Log validation loss / accuracy under the 'val' tag."""
        loss,acc=self._loss_and_accuracy(batch)

        #logging
        board=self.logger.experiment
        board.add_scalars('Pretrained loss',{'val':loss.item()},self.global_step)
        board.add_scalars('Pretrained accuracy',{'val':acc},self.global_step)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 0.02."""
        return torch.optim.Adam(self.parameters(),lr=0.02)



NameError: name 'pl_loggers' is not defined

In [None]:
#Ablation study
# NOTE(review): this cell reads `maskedmodel` and `model`, which are not defined
# in any earlier cell of this notebook - it relies on existing kernel state.
logit_mask=maskedmodel.logit_mask
binarised_mask=[indicator(tens) for tens in logit_mask]

# Work on a copy so the original model weights stay available for comparison.
model2=copy.deepcopy(model)
model2.eval()

#ablate: zero every weight whose mask entry is 'on'
for idx,param in enumerate(model2.parameters()):
    on_positions=binarised_mask[idx]==1
    param.data[on_positions]=0

#accuracy of the ablated model on custom dataset 1
accs_c1=[]
for X,y in custom1_test_dataloader:
    pred_logits=model2(X)
    accs_c1.append(calculate_accuracy(pred_logits=pred_logits,true_idxs=y))
avg_acc_c1=sum(accs_c1)/len(accs_c1)

#accuracy of the ablated model on custom dataset 2
accs_c2=[]
for X,y in custom2_test_dataloader:
    pred_logits=model2(X)
    accs_c2.append(calculate_accuracy(pred_logits=pred_logits,true_idxs=y))
avg_acc_c2=sum(accs_c2)/len(accs_c2)

print(f'Ablated model accuracy on C1: {avg_acc_c1} \
    Ablated model accuracy on C2: {avg_acc_c2}')

In [None]:
#pretraining model - SHOULDN't BE RUNNING OFTEN

run_bool=input("Are you sure you want to retrain model? yes/no")

# Only the literal answers 'yes' and 'no' are accepted; 'no' calls sys.exit().
if run_bool not in ('yes','no'):
    raise Exception('Please enter either yes or no')
if run_bool=='no':
    sys.exit()

tb_logger=pl_loggers.TensorBoardLogger(save_dir='logs') #for Lightning

# Debug runs use only a small fraction of the train / validation batches.
limit_train_batches,limit_val_batches=(0.05,0.01) if debug==True else (1.0,1.0)

check_val_every=1

epochs=5
model=MNISTFFN()
trainer=pl.Trainer(
    logger=tb_logger,
    max_epochs=epochs,
    check_val_every_n_epoch=check_val_every,
    limit_train_batches=limit_train_batches,
    limit_val_batches=limit_val_batches,
)
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [None]:
#copying from https://discuss.pytorch.org/t/how-to-optimise-mask-on-a-fully-frozen-network/151438
# Learns one mask logit per weight / bias element of a frozen, pretrained MNISTFFN.
# Only the mask logits are optimised; the network parameters are never updated.

# load_from_checkpoint is a classmethod: call it on the class, not on a throwaway
# instance (Lightning 2.x raises TypeError when it is invoked on an instance).
model=MNISTFFN.load_from_checkpoint('logs/lightning_logs/version_7/checkpoints/epoch=4-step=4690.ckpt')

# NOTE(review): the log name says "10e-5" while alpha below is 0.005 - confirm which is intended.
log_name='Mask alpha-10e-5'
log_dir=os.path.join('logs',log_name)
writer=SummaryWriter(log_dir=log_dir) #for PyTorch

# Freeze every pretrained parameter (weights AND biases), not just the weights.
for param in model.parameters():
    param.requires_grad=False

# Logit initialisation: one trainable logit tensor per Linear weight and bias,
# stored in the order weight, bias, weight, bias, ...
logits=[]
for layer in model.layers:
    if isinstance(layer,nn.Linear):
        logits.append(torch.nn.Parameter(data=torch.full_like(layer.weight,0.9),requires_grad=True))
        logits.append(torch.nn.Parameter(data=torch.full_like(layer.bias,0.9),requires_grad=True))

criterion = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(logits, lr=0.01)

# Hyper-parameters
NUM_EPOCHS = 30 # NB: check for number of training epochs in paper
n_batches=5 # NOTE(review): not read anywhere in this cell
tau = 1  # temperature parameter, NB: check for value in paper
alpha = 0.005 # weight of the sum-of-logits regulariser
max_batches_per_epoch=30 #limited training for debugging

step=0
for e in range(NUM_EPOCHS):
    print(f'Epoch: {e}')
    for batch_idx,(X,y) in enumerate(custom1_train_dataloader):

        if batch_idx==max_batches_per_epoch:
            break

        #mask transform (logits->bin)
        # Independent noise per mask element, via the notebook's Gumbel_Sigmoid helper.
        # Previously one scalar noise value was shared by every element of every tensor.
        samples=[Gumbel_Sigmoid(logit_tens,T=tau) for logit_tens in logits]

        # Straight-through estimator: the forward value is the hard 0/1 mask,
        # the gradient is that of the relaxed sample.
        binaries=[sample+((sample>0.5).float()-sample).detach() for sample in samples]
        binaries_iter=iter(binaries)

        #manual inference
        activations=X
        for layer in model.layers:
            if isinstance(layer,nn.Linear):
                weight=layer.weight*next(binaries_iter) #manual application of mask
                bias=layer.bias*next(binaries_iter)
                activations=F.linear(activations,weight,bias)
            else:
                activations=layer(activations)
        pred_logits=activations

        # Sum of all mask logits, added to the loss scaled by alpha.
        reg_loss = torch.sum(torch.stack([torch.sum(logit_tens) for logit_tens in logits]))
        loss=criterion(pred_logits,y)+alpha*reg_loss
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        #logging
        writer.add_scalar('Mask loss',loss.item(),step)
        writer.add_scalars('A+S',
                           {'accuracy '+log_name:utils.calculate_accuracy(pred_logits,y),
                            'sparsity '+log_name:utils.sparsity(binaries)},
                           step)

        step+=1

# Make sure all pending events reach disk and release the event file.
writer.close()

print('Training complete')

In [None]:

# load_from_checkpoint is a classmethod: call it on the class rather than on a
# freshly built instance (Lightning 2.x raises TypeError for instance calls).
modelA=MNISTFFN.load_from_checkpoint('logs/lightning_logs/version_7/checkpoints/epoch=4-step=4690.ckpt')
modelB=MNISTModel() #with conv enc layer

def get_children(model: torch.nn.Module):
    """Collect the leaf modules of ``model`` in depth-first order.

    Args:
        model: module to walk.

    Returns:
        ``model`` itself when it has no children; otherwise a flat list of
        every descendant module that has no children.
    """
    children = list(model.children())
    if not children:
        # a module without children is itself the leaf
        return model

    flatt_children = []
    for child in children:
        found = get_children(child)
        # Decide by return type instead of catching TypeError from extend():
        # an iterable leaf (empty ModuleList/Sequential, ParameterList) would
        # otherwise be dropped or have its contents spliced into the result.
        if isinstance(found, list):
            flatt_children.extend(found)
        else:
            flatt_children.append(found)
    return flatt_children

def get_named_children(model):
    """Map qualified module names to the leaf modules of ``model``.

    IMPORTANT: a leaf is taken to be any module with zero children.
    This assumption still needs checking.

    Returns:
        dict of ``{name: module}`` in ``model.named_modules()`` order; a model
        without children is returned under the empty-string key.
    """
    return {
        module_name:module
        for module_name,module in model.named_modules()
        if not list(module.children())
    }

class MaskedModel():
    """Learns a binary mask over the parameters of a frozen model.

    One real-valued logit is kept per parameter element. Each forward pass
    samples a hard 0/1 mask from those logits (straight-through estimator)
    and multiplies it into the weights and biases of every ``nn.Linear`` leaf
    module. Only the logits are optimised; the model parameters are frozen.
    """

    def __init__(self,model,train_dataloader,test_dataloader1,test_dataloader2,alpha=0.4,tau=1,eval_freq=5):
        """
        Args:
            model: module whose parameters are masked (frozen in place here).
            train_dataloader: batches ``(x, y)`` used by ``train``.
            test_dataloader1: first evaluation loader passed to ``eval``.
            test_dataloader2: second evaluation loader passed to ``eval``.
            alpha: weight of the sum-of-logits regulariser.
            tau: temperature used when sampling the mask.
            eval_freq: run ``eval`` every ``eval_freq`` epochs (never at epoch 0).
        """
        self.model=model
        self.train_dataloader=train_dataloader
        self.test_dataloader1=test_dataloader1
        self.test_dataloader2=test_dataloader2

        # Build the logits from the model that was passed in (previously the
        # notebook-global `modelA` was read here by mistake).
        self.logit_tensors_dict={k:torch.nn.Parameter(data=torch.full_like(p,0.9)) for k,p in model.named_parameters()}
        self.alpha=alpha
        self.tau=tau
        self.logging=False

        self.eval_freq=eval_freq

        #freeze model parameters
        for p in model.parameters():
            p.requires_grad=False

        self.leaf_modules=get_named_children(self.model)

        self.optimiser=torch.optim.Adam(self.logit_tensors_dict.values())

    def forward(self,x):
        """Run ``x`` through the leaf modules with a freshly sampled mask applied.

        Leaf modules are applied one after another in ``named_modules`` order,
        so this only matches the wrapped model when it is purely sequential.
        Only ``nn.Linear`` leaves are masked; other leaves are called unchanged.
        """
        binaries=self.transform_logit_tensors()

        for layer_name,layer in self.leaf_modules.items():
            if isinstance(layer,nn.Linear):
                # mask * parameter (previously the mask itself was used as the weight)
                weight=layer.weight*binaries[layer_name+'.weight']
                bias=None
                if layer.bias is not None:
                    bias=layer.bias*binaries[layer_name+'.bias']
                x=F.linear(x,weight=weight,bias=bias)
            else:
                x=layer(x)

        return x

    def train(self,n_batches=10,n_epochs=5,logging=False,n_eval_batches=5):
        """Optimise the mask logits.

        Args:
            n_batches: batches per epoch, or ``'full'`` to use the whole loader.
            n_epochs: number of passes over ``self.train_dataloader``.
            logging: when True, prompt for a run name and log to wandb.
            n_eval_batches: batch limit forwarded to ``eval``.
        """
        if logging:
            self.logging=True
        if self.logging:
            log_name=str(input('Enter log name'))
            wandb.init(project='AVR',name=log_name)

        # 'full' means "no limit" (it used to skip every batch via `continue`)
        batch_limit=None if n_batches=='full' else n_batches

        for ep in range(n_epochs):
            for batch_idx,batch in enumerate(self.train_dataloader):
                if batch_limit is not None and batch_idx>=batch_limit:
                    break
                x,y=batch
                y_hat=self.forward(x)

                crossent_loss=F.cross_entropy(y_hat,y)
                reg_loss=torch.sum(torch.stack([torch.sum(logit_tens) for logit_tens in self.logit_tensors_dict.values()]))
                loss=crossent_loss+self.alpha*reg_loss

                # clear gradients from the previous step before accumulating new ones
                self.optimiser.zero_grad()
                loss.backward()
                self.optimiser.step()

                if self.logging:
                    wandb.log({'epoch':ep,
                                'loss':loss.item(),
                                'crossent_loss':crossent_loss.item(),
                                'reg_loss':reg_loss.item(),
                                })

            if (ep%self.eval_freq==0) and (ep!=0):
                self.eval(self.test_dataloader1,self.test_dataloader2,n_batches=n_eval_batches)

        if self.logging:
            wandb.finish()

    def transform_logit_tensors(self):
        """Sample a hard 0/1 mask for every logit tensor.

        Returns:
            dict with the keys of ``self.logit_tensors_dict``; each value is
            0/1 in the forward pass and carries the gradient of the relaxed
            sigmoid sample (straight-through estimator).
        """
        tau=self.tau

        binaries={}
        for k,v in self.logit_tensors_dict.items():
            # Independent noise per element, created on the logits' own device
            # (previously a single scalar pair was shared by every element).
            log_u1=torch.log(torch.rand_like(v))
            log_u2=torch.log(torch.rand_like(v))
            sample=torch.sigmoid((v-torch.log(log_u1/log_u2))/tau)

            # binarisation itself carries no gradient
            hard_minus_soft=((sample>0.5).float()-sample).detach()
            binaries[k]=hard_minus_soft+sample

        return binaries

    def eval(self,task_eval_dataloader,_task_eval_dataloader,n_batches='full'):
        """Evaluate a masked copy of the model on two loaders.

        Args:
            task_eval_dataloader: first loader of ``(x, y)`` batches.
            _task_eval_dataloader: second loader; iterated in lock-step with
                the first, so evaluation stops when the shorter one ends.
            n_batches: batch limit, or ``'full'`` for no limit.

        Returns:
            ``(acc1, acc2)`` mean accuracies rounded to 2 decimals; ``nan``
            for a loader pair that produced no batches.
        """
        #evaluate mask via ablation
        print('start eval')

        batch_limit=None if n_batches=='full' else n_batches

        #create masked model
        masked_model=copy.deepcopy(self.model)
        masked_model.eval()

        acc1s=[]
        acc2s=[]

        with torch.no_grad():
            binaries=self.transform_logit_tensors()
            #only linear layer compatibility at the moment
            for n,p in masked_model.named_parameters():
                p.copy_(p*binaries[n]) #copy in masked params

            for batch_idx,(batch1,batch2) in enumerate(zip(task_eval_dataloader,_task_eval_dataloader)):
                if batch_limit is not None and batch_idx>=batch_limit:
                    break
                x1,y1=batch1
                x2,y2=batch2

                pred_logits_1=masked_model(x1)
                pred_logits_2=masked_model(x2)

                acc1s.append(utils.calculate_accuracy(pred_logits_1,y1))
                acc2s.append(utils.calculate_accuracy(pred_logits_2,y2))

        acc1=round(np.mean(acc1s),2) if acc1s else float('nan')
        acc2=round(np.mean(acc2s),2) if acc2s else float('nan')

        if self.logging:
            wandb.log({"Acc task'":acc1,"Acc not task":acc2})

        print('end eval')

        return acc1,acc2



        


           


# Custom dataset 1 supplies the train loader and the first test loader;
# custom dataset 2 supplies the second test loader.
task_train_dataloader=custom1_train_dataloader
task_test_dataloader=custom1_test_dataloader
_task_test_dataloader=custom2_test_dataloader

mm=MaskedModel(
    modelA,
    train_dataloader=task_train_dataloader,
    test_dataloader1=task_test_dataloader,
    test_dataloader2=_task_test_dataloader,
)
mm.train(n_epochs=10,n_batches=10,logging=True)

    






In [None]:
#Ablation study
# Samples one hard mask from the trained logits, applies it to a copy of the
# model and reports accuracy on both custom test sets.
# Reads `logits`, `tau` and `model` from the mask-training cell.

#transform trained mask logits
with torch.no_grad():
    # Independent noise per mask element via the notebook's Gumbel_Sigmoid helper
    # (previously one scalar noise value was shared by every element).
    samples=[Gumbel_Sigmoid(layer,T=tau) for layer in logits]
    # No gradient is needed for evaluation, so the hard threshold is used directly.
    binaries=[(sample>0.5).float() for sample in samples]

#replace tens
masked_model=copy.deepcopy(model)
masked_model.eval()

masked_params=list(masked_model.parameters())
if len(masked_params)!=len(binaries):
    raise ValueError(f'{len(binaries)} masks for {len(masked_params)} parameter tensors')

with torch.no_grad():
    for p,mask in zip(masked_params,binaries):
        p.copy_(p*mask)

acc1s=[]
acc2s=[]
#eval on trained dataset

with torch.no_grad():
    # The loaders are iterated in lock-step; evaluation stops when the shorter one ends.
    for (x1,y1),(x2,y2) in zip(custom1_test_dataloader,custom2_test_dataloader):
        pred_logits_1=masked_model(x1)
        pred_logits_2=masked_model(x2)

        acc1s.append(utils.calculate_accuracy(pred_logits_1,y1))
        acc2s.append(utils.calculate_accuracy(pred_logits_2,y2))

acc1=round(np.mean(acc1s),2)
acc2=round(np.mean(acc2s),2)

print(f'Acc 1 - {acc1}, \n Acc 2 - {acc2}')

        