In [None]:
import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl;
import importlib
import numpy as np
import data
import copy

from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from copy import deepcopy
from torch.special import logit
# Re-import the local `data` module so edits to it are picked up without a kernel restart.
importlib.reload(data)

# Fix the global random seed so subset sampling and weight initialisation below are repeatable.
pl.seed_everything(seed=42)


In [None]:

def Gumbel_Sigmoid(tens,T=1):
    """Draw a relaxed-Bernoulli ("binary concrete") sample from logits.

    Adds the difference of two independent Gumbel(0, 1) samples to ``tens`` and
    squashes the result with a temperature-scaled sigmoid:
    ``sigmoid((tens + g1 - g2) / T)`` with ``g = -log(-log U)``, ``U ~ U[0, 1)``.

    Args:
        tens: Floating-point tensor of logits; the output has the same shape,
            dtype and device.
        T: Temperature. Smaller values push samples towards 0 or 1.

    Returns:
        Tensor with values in [0, 1].
    """
    # torch.rand_like samples from [0, 1), so an exact 0 is possible; log(0) = -inf
    # would turn the ratio below into inf/inf = nan. Clamp to the smallest positive
    # normal number of the dtype.
    tiny = torch.finfo(tens.dtype).tiny
    log_U1 = torch.log(torch.rand_like(tens).clamp_min(tiny))
    log_U2 = torch.log(torch.rand_like(tens).clamp_min(tiny))
    # For g = -log(-log U):  g1 - g2 = log(log_U2 / log_U1) = -log(log_U1 / log_U2).
    noisy_logits = (tens - torch.log(log_U1 / log_U2)) / T
    return torch.sigmoid(noisy_logits)

def indicator(tens,threshold=0.5,below=0,above=1):
    """Binarise ``tens`` elementwise: ``above`` where ``tens > threshold``, else ``below``.

    Args:
        tens: Input tensor of any shape.
        threshold: Strict cut-off; elements equal to it map to ``below``.
        below: Value for elements ``<= threshold`` (NaN elements also get ``below``).
        above: Value for elements ``> threshold``.

    Returns:
        int32 tensor with the shape of ``tens``. Non-integer ``below``/``above``
        values are truncated by the final ``.int()`` cast.
    """
    # Compare directly instead of chaining two F.threshold calls; that trick only
    # gave the right answer for below=0 with a positive threshold.
    above_t = torch.as_tensor(above, device=tens.device)
    below_t = torch.as_tensor(below, device=tens.device)
    return torch.where(tens > threshold, above_t, below_t).int()

def calculate_accuracy(pred_logits,true_idxs):
    """Percentage of rows whose arg-max class matches the target index.

    Args:
        pred_logits: Tensor of class scores with the class axis on dim 1,
            e.g. (N, C).
        true_idxs: Tensor of N target class indices.

    Returns:
        Accuracy in percent, as a Python float.

    Raises:
        ValueError: if there are no predictions.
    """
    n = len(pred_logits)
    if n == 0:
        raise ValueError('calculate_accuracy needs at least one prediction')
    # softmax is monotonic, so arg-max over the raw logits selects the same class.
    pred = torch.argmax(pred_logits, dim=1)
    n_correct = int((pred == true_idxs).sum().item())
    return (n_correct / n) * 100
    

class Indicator(nn.Module):
    """Elementwise binariser: ``above`` where ``x > threshold``, else ``below``.

    The output is an int32 tensor, so no gradient flows through this module.
    """

    def __init__(self,threshold=0.5,below=0,above=1):

        super().__init__()
        self.threshold=threshold
        self.below=below
        self.above=above

    def forward(self,x):
        """Return an int32 tensor shaped like ``x`` holding ``above``/``below`` per element.

        NaN elements compare False and therefore map to ``below``.
        """
        # A direct comparison is correct for any threshold/below/above, unlike the
        # negate-and-rethreshold trick with a magic epsilon.
        above=torch.as_tensor(self.above,device=x.device)
        below=torch.as_tensor(self.below,device=x.device)
        return torch.where(x>self.threshold,above,below).int()

    def extra_repr(self):
        return f'threshold={self.threshold}, below={self.below}, above={self.above}'

try:
    from torch.func import functional_call as _functional_call
except ImportError:  # older torch exposes it under nn.utils.stateless
    from torch.nn.utils.stateless import functional_call as _functional_call


class MaskedModel(nn.Module):
    """Wrap a trained model and learn a binary mask over its frozen weights.

    Every parameter ``w`` of ``model`` gets a same-shaped score tensor in
    ``logit_mask`` (initialised from U[0, 1)). ``forward`` runs a copy of ``model``
    with weights ``w * (logit_mask > 0.5)``. The hard threshold has no gradient,
    so a straight-through estimator routes the gradient of each masked weight to
    its score; the scores are the only tensors meant to be optimised.

    Notes:
        * ``logit_mask`` is a plain list, NOT registered with the module, so
          ``self.parameters()`` does not contain it and ``.to(device)`` does not
          move it -- hand it to the optimiser explicitly.
        * The parameters of the caller's ``model`` are frozen in place
          (``requires_grad = False``), and so are those of the internal copy.
        * ``self.masked_model`` keeps the unmasked weights; calling it directly
          bypasses the mask.
    """

    def __init__(self,model):

        super().__init__()
        self.masked_model=deepcopy(model)
        self.indicator=Indicator()
        self.trained_weights=list(model.parameters())
        for w in self.trained_weights: w.requires_grad=False #model params are frozen
        # Freeze the copy too: it only supplies the architecture, so an optimiser
        # built from self.parameters() must not move its weights.
        for p in self.masked_model.parameters(): p.requires_grad=False
        # Same traversal order as model.parameters(), so names line up with
        # trained_weights and logit_mask by position.
        self._param_names=[name for name,_ in self.masked_model.named_parameters()]
        self.logit_mask=[nn.Parameter(torch.rand_like(w)) for w in self.trained_weights]

    @property
    def binarised_mask(self):
        """Current hard masks (int32, 1 = weight kept), recomputed from ``logit_mask``."""
        return [self.indicator(tens) for tens in self.logit_mask]

    def _masked_params(self):
        # Straight-through estimator: the forward value is w * hard_mask, while the
        # backward pass treats d(mask)/d(score) as identity so the scores get gradients.
        masked={}
        for name,w,score in zip(self._param_names,self.trained_weights,self.logit_mask):
            hard=self.indicator(score).to(score.dtype)
            mask=hard+(score-score.detach())
            masked[name]=w*mask
        return masked

    def forward(self,x):
        """Run the wrapped architecture with masked weights and return its output."""
        return _functional_call(self.masked_model,self._masked_params(),(x,))

    def logit_l2_loss(self,mode='mean'):
        """Squared-L2 penalty on ``logit_mask``, summed over the per-parameter tensors.

        Args:
            mode: ``'mean'`` averages the squares inside each tensor before summing
                across tensors; ``'sum'`` sums every squared entry.

        Returns:
            Scalar tensor (or ``0`` when there are no parameters).

        Raises:
            ValueError: for any other ``mode`` (a subclass of ``Exception``).
        """
        if mode=='mean':
            reduce=torch.mean
        elif mode=='sum':
            reduce=torch.sum
        else:
            raise ValueError(f'{mode} is an invalid l2 mode')
        return sum((reduce(tens**2) for tens in self.logit_mask),0)

    def mask_sparsity(self):
        """Percentage of mask entries currently switched on (kept weights).

        Despite the name this is the mask *density*: 100 means nothing is pruned.
        """
        binarised_mask=self.binarised_mask
        numel=sum(btens.numel() for btens in binarised_mask)
        num_ones=sum(torch.count_nonzero(btens).item() for btens in binarised_mask)
        return (num_ones/numel)*100



In [None]:
batch_size=64
epochs=2  # NOTE(review): reassigned before use by the training cells below
subset_frac=0.1


data_path='datasets'

# download=True only fetches MNIST when it is not already under data_path, so the
# notebook also runs from a fresh checkout.
train_dataset = MNIST(data_path,train=True,download=True,transform=transforms.ToTensor())
# Random subset (without replacement) containing subset_frac of the training images.
idxs=np.random.choice(len(train_dataset),int(len(train_dataset)*subset_frac),replace=False)
subset=torch.utils.data.Subset(train_dataset,idxs)
train_dataloader = DataLoader(subset,batch_size=batch_size,shuffle=True)

val_dataset=MNIST(data_path,train=False,download=True,transform=transforms.ToTensor())
# Validation is not shuffled: the trainer only evaluates a fraction of the batches
# (limit_val_batches), and a shuffled loader would score a different random slice
# every epoch, making the accuracies incomparable.
val_dataloader = DataLoader(val_dataset,batch_size=batch_size,shuffle=False)

# NOTE(review): the meaning of n=5 is defined by data.MNISTCustomDataset -- confirm.
custom_dataset=data.MNISTCustomDataset(n=5,transform=transforms.ToTensor())
idxs=np.random.choice(len(custom_dataset),int(len(custom_dataset)*subset_frac),replace=False)
custom_subset=torch.utils.data.Subset(custom_dataset,idxs)
custom_trainloader=DataLoader(custom_subset,batch_size=batch_size)

# One example batch from each loader for quick inspection.
X,y=next(iter(train_dataloader))
X_c,y_c=next(iter(custom_trainloader))


In [None]:
class SimpleModel(nn.Module):
    """Single linear layer mapping a flattened 28x28 image to 10 class logits."""

    def __init__(self):
        super().__init__()
        self.l1=nn.Linear(28*28,10)


    def forward(self,x):
        """Return logits for ``x``.

        Accepts already-flattened inputs (``(784,)`` or ``(B, 784)``) as well as
        image batches such as ``(B, 1, 28, 28)`` from the MNIST loaders, which are
        flattened from dim 1 onwards.
        """
        if x.dim()>2:
            x=torch.flatten(x,start_dim=1)
        y=self.l1(x)
        return y

class MNISTModel(pl.LightningModule):
    """Small two-conv / two-linear CNN producing 10 class logits for 1x28x28 images."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,10,kernel_size=5)
        self.conv2 = nn.Conv2d(10,20,kernel_size=5)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2,2))
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320,50)
        self.fc2 = nn.Linear(50,10)

    def forward(self,x):
        """Map a (B, 1, 28, 28) batch to (B, 10) logits."""
        x=F.relu(self.maxpool2(self.conv1(x)))                   # (B, 10, 12, 12)
        x=F.relu(self.maxpool2(self.conv2_drop(self.conv2(x))))  # (B, 20, 4, 4); dropout in conv layers
        # flatten keeps the batch dim, so a wrongly-sized input fails at fc1 instead
        # of being silently re-batched as view(-1, 320) would do.
        x=torch.flatten(x,start_dim=1)                           # (B, 320)
        x=F.relu(self.fc1(x))
        x=F.dropout(x,training=self.training) #dropout in FFN layers
        x=self.fc2(x)
        return x

    def training_step(self,batch,batch_idx):
        """Cross-entropy loss on one training batch, logged as ``train_loss``."""
        x,y=batch
        y_hat=self(x)
        loss=F.cross_entropy(y_hat,y)
        self.log('train_loss',loss)
        return loss

    def validation_step(self,batch,batch_idx,on_epoch=True):
        """Log validation loss and accuracy (in %) for one batch and return the loss.

        ``on_epoch`` is forwarded to ``self.log``; with the default ``True`` the
        metrics are aggregated over the validation epoch.
        """
        X,y=batch
        y_hat=self(X)
        loss=F.cross_entropy(y_hat,y)
        # arg-max of the logits equals arg-max of their softmax.
        acc=(torch.argmax(y_hat,dim=1)==y).float().mean()*100
        self.log('val_loss',loss,on_epoch=on_epoch,prog_bar=True)
        self.log('val_acc',acc,on_epoch=on_epoch,prog_bar=True)
        return loss

    def configure_optimizers(self):
        # NOTE(review): lr=0.02 is on the high side for Adam -- confirm it is intended.
        return torch.optim.Adam(self.parameters(),lr=0.02)

debug=True
# Debug mode trains/validates on a small fraction of the batches so the cell is quick.
if debug:
    limit_train_batches=0.05
    limit_val_batches=0.01
else:
    limit_train_batches=1.0
    limit_val_batches=1.0

epochs=3
model=MNISTModel()
# max_epochs follows `epochs` instead of a second hard-coded 3.
trainer=pl.Trainer(max_epochs=epochs,
                    check_val_every_n_epoch=1,
                    limit_train_batches=limit_train_batches,limit_val_batches=limit_val_batches)
trainer.fit(model,
            train_dataloaders=train_dataloader,val_dataloaders=val_dataloader)

In [None]:

maskedmodel=MaskedModel(model)


epochs=3
criterion=torch.nn.CrossEntropyLoss()
# Only the mask scores are trained; the wrapped model's weights are meant to stay
# fixed. logit_mask is passed explicitly because it is a plain list of Parameters.
optimiser=torch.optim.Adam(maskedmodel.logit_mask)

print('Training masked model')
for epoch in range(epochs):
    running_loss=0.0
    num_batches=0
    for i,(X,y) in enumerate(custom_trainloader):
        print(f'Step {i} of epoch {epoch}')
        optimiser.zero_grad()
        logits=maskedmodel(X)

        l2_loss=maskedmodel.logit_l2_loss()
        CE_loss=criterion(logits,y)
        loss=CE_loss+l2_loss

        loss.backward()
        optimiser.step()

        # .item() detaches: accumulating the tensor itself would keep every
        # step's autograd graph alive for the whole epoch.
        running_loss+=loss.item()
        num_batches+=1

    # Mean over the batches actually seen; max() guards an empty loader.
    running_loss/=max(num_batches,1)
    print(f"\n Average epoch loss: {running_loss}")
    print(f'Non-zero proportion {maskedmodel.mask_sparsity()} \n')

print('\n Finished training')

In [None]:
#Ablation study
logit_mask=maskedmodel.logit_mask
binarised_mask=[indicator(tens) for tens in logit_mask]
model2=copy.deepcopy(model) #we need to check if masking has changed the original model weights

model2.eval()


for idx,param in enumerate(model2.parameters()):
    bool_idxs=binarised_mask[idx]==1 #the 'on' weights in mask
    param.data[bool_idxs]=0 #zero ablating 'on weights'



#calculate accuracy of model 2

accs=[]
for idx,batch in enumerate(custom_trainloader):
    X,y=batch
    pred_logits=model2(X)
    acc=calculate_accuracy(pred_logits=pred_logits,true_idxs=y)
    accs.append(acc)

avg_acc=sum(accs)/len(accs)