In [None]:
import numpy as np
import torch
import os
import math
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from huggingface_hub import interpreter_login, logout
from datasets import load_dataset
from transformers import ViTConfig, ViTImageProcessor, ViTForImageClassification
from transformers import image_utils as hf_image_utils
from PIL import Image, ImageOps
import scipy
import pywt
import pywt.data
from torch_cka import CKA
from torchinfo import summary
import pickle
from scipy.fftpack import dct, idct
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import pickle
# Use a grayscale default colormap for single-channel image plots.
plt.gray()
# Preferred compute device name: 'cuda' when a GPU is visible, otherwise 'cpu'.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
import os
import random
from PIL import Image

def class_sampler(class_id, n_samples=1):
    """Load up to ``n_samples`` randomly chosen images of one class.

    Images are read from ``./dataset_classes/{class_id}/`` and must be named
    ``{class_id}_<k>.jpg``. Sampling is done over the matching files that actually
    exist, so the function never asks for more images than are present and never
    opens a missing file (the previous version sampled indices from a hard-coded
    ``range(50)`` regardless of the directory contents).

    Args:
        class_id: class identifier; used both as the sub-directory name and as
            the file-name prefix.
        n_samples: maximum number of images to return.

    Returns:
        A list of PIL images, fully read into memory (file handles are closed).
    """
    class_dir = f'./dataset_classes/{class_id}/'
    prefix = f'{class_id}_'
    # Sorted so that, for a fixed random seed, the same files are drawn.
    candidates = sorted(
        name for name in os.listdir(class_dir)
        if name.startswith(prefix) and name.endswith('.jpg')
    )
    chosen = random.sample(candidates, min(n_samples, len(candidates)))
    images = []
    for name in chosen:
        # Image.open is lazy and keeps the file open; copy() forces the pixel data
        # into memory so the handle can be closed by the context manager.
        with Image.open(os.path.join(class_dir, name)) as img:
            images.append(img.copy())
    return images


In [None]:
# Hugging Face checkpoints compared in this notebook.
model_paths = {
    'base': 'google/vit-base-patch16-224',
    'dino': 'facebook/dino-vitb16',
    'mae': 'facebook/vit-mae-base',
    'large': 'google/vit-large-patch16-224'
}
from transformers import ViTForImageClassification
# Models are kept on the CPU.
# NOTE(review): dino-vitb16 and vit-mae-base are presumably backbone-only checkpoints,
# so ViTForImageClassification would attach a freshly initialised classifier head
# to them -- check the loading warnings before comparing their predictions.
vit_base = ViTForImageClassification.from_pretrained(model_paths['base'])
vit_dino = ViTForImageClassification.from_pretrained(model_paths['dino'])
vit_mae = ViTForImageClassification.from_pretrained(model_paths['mae'])
vit_large = ViTForImageClassification.from_pretrained(model_paths['large'])

base_processor = ViTImageProcessor.from_pretrained(model_paths['base'])
dino_processor = ViTImageProcessor.from_pretrained(model_paths['dino'])
# Previously loaded model_paths['base'] by mistake; use the MAE checkpoint's own
# preprocessing config (its normalisation statistics may differ from vit-base's).
mae_processor = ViTImageProcessor.from_pretrained(model_paths['mae'])
large_processor = ViTImageProcessor.from_pretrained(model_paths['large'])


In [None]:
def wt_decomposition(img,level,wavelet='haar'):
    """
    Split ``img`` into the images contributed by each wavelet sub-band.

    The image is decomposed with ``pywt.wavedec2`` and every sub-band is then
    reconstructed on its own (all other sub-bands set to zero) with
    ``pywt.waverec2``. Because the inverse transform is linear, the returned
    images add up to the reconstruction of the full coefficient set.

    Args
      img: a 2-D array, or anything pywt accepts as one (e.g. a grayscale PIL image).
        Extra *leading* axes (e.g. channels-first ``(C, H, W)``) are handled
        independently; pywt transforms the last two axes, so a channels-last RGB
        array ``(H, W, 3)`` would NOT be decomposed per channel.
      level: The level of decomposition
      wavelet: wavelet name passed to pywt (default ``'haar'``, as before)

    Returns:
    A list of ``1 + 3*level`` arrays: the approximation band first, then the
    (horizontal, vertical, diagonal) detail bands of each level, from the
    coarsest level to the finest.
    """
    coeffs = pywt.wavedec2(img, wavelet, level=level)

    def _zero_coeffs():
        # Fresh all-zero coefficient structure. Shapes come from the actual
        # coefficients so any wavelet / leading axes work; detail bands are
        # float64 zeros and the approximation keeps its own dtype, as before.
        return [np.zeros_like(coeffs[0])] + [
            [np.zeros(band.shape) for band in details] for details in coeffs[1:]
        ]

    basis_imgs = []

    # Approximation band alone.
    only_band = _zero_coeffs()
    only_band[0] = coeffs[0]
    basis_imgs.append(pywt.waverec2(coeffs=only_band, wavelet=wavelet))

    # Each detail band alone, coarsest level first.
    for lvl in range(1, len(coeffs)):
        for band_idx, band in enumerate(coeffs[lvl]):
            only_band = _zero_coeffs()
            only_band[lvl][band_idx] = band
            basis_imgs.append(pywt.waverec2(coeffs=only_band, wavelet=wavelet))
    return basis_imgs

In [None]:
def wav_decomposition_patched(img,level, patch_dim=16):
    """Patch-wise wavelet decomposition of a channels-first image.

    Every ``patch_dim x patch_dim`` patch of every channel is decomposed on its own
    with ``wt_decomposition``. The resulting ``1 + 3*level`` component images are
    written back at the patch's position and stacked along a new last axis.

    Args:
        img: array-like indexed as ``(channel, row, column)``; the row and column
            sizes must be multiples of ``patch_dim``.
        level: wavelet decomposition level used for each patch.
        patch_dim: patch side length.

    Returns:
        float64 array of shape ``(C, H, W, 1 + 3*level)``.

    Raises:
        ValueError: if ``img`` is not 3-D or its spatial size is not a multiple
            of ``patch_dim`` (ragged edge patches cannot be stored).
    """
    np_img = np.asarray(img)
    if np_img.ndim != 3:
        raise ValueError(f"expected a (C, H, W) array, got shape {np_img.shape}")
    n_channels, height, width = np_img.shape
    if height % patch_dim or width % patch_dim:
        raise ValueError(
            f"spatial size {(height, width)} is not a multiple of patch_dim={patch_dim}")

    n_components = 1 + 3 * level
    patches = np.zeros(shape=(n_channels, height, width, n_components))
    for k in range(n_channels):
        for i in range(0, height, patch_dim):
            for j in range(0, width, patch_dim):
                components = wt_decomposition(np_img[k, i:i + patch_dim, j:j + patch_dim], level)
                # (patch_dim, patch_dim) x n_components -> (patch_dim, patch_dim, n_components)
                patches[k, i:i + patch_dim, j:j + patch_dim, :] = np.stack(components, axis=-1)
    return patches

def wav_channel_decomposition_patched(img, level, patch_dim=16):
    """Patch-wise wavelet decomposition of a single-channel (2-D) image.

    Each ``patch_dim x patch_dim`` patch is decomposed with ``wt_decomposition``,
    and its ``1 + 3*level`` component images are stacked along a new last axis.

    NOTE(review): the previous body was an unfinished copy of a DCT variant. It
    referenced undefined names (``patch_dct``, ``idct2``, ``patch_weights``),
    ignored the wavelet decomposition it computed, and hard-coded a 224x224 output,
    so every call raised NameError. The meaning of the second return value did not
    survive. It is now the per-component weights that recombine the components into
    the image: all ones, since the components of a patch sum to the patch up to
    floating-point error. TODO: confirm against the intended use.

    Args:
        img: 2-D array-like ``(H, W)``; ``H`` and ``W`` must be multiples of ``patch_dim``.
        level: wavelet decomposition level used for each patch.
        patch_dim: patch side length.

    Returns:
        ``(patches, patch_weights)``: a float64 array of shape ``(H, W, 1 + 3*level)``
        and a float64 array of ones with shape ``(1 + 3*level,)``.

    Raises:
        ValueError: if ``img`` is not 2-D or its size is not a multiple of ``patch_dim``.
    """
    np_img = np.asarray(img)
    if np_img.ndim != 2:
        raise ValueError(f"expected a 2-D (H, W) array, got shape {np_img.shape}")
    height, width = np_img.shape
    if height % patch_dim or width % patch_dim:
        raise ValueError(
            f"image size {(height, width)} is not a multiple of patch_dim={patch_dim}")

    n_components = 1 + 3 * level
    patches = np.zeros(shape=(height, width, n_components))
    for i in range(0, height, patch_dim):
        for j in range(0, width, patch_dim):
            components = wt_decomposition(np_img[i:i + patch_dim, j:j + patch_dim], level)
            patches[i:i + patch_dim, j:j + patch_dim, :] = np.stack(components, axis=-1)
    patch_weights = np.ones(n_components)
    return patches, patch_weights

In [None]:
## Import pickle files
with open('./Composition_fn_data/level_1_data_all_classes.pkl','rb') as f:
   data =  pickle.load(f)
with open('./Composition_fn_data/level_1_labels_all_classes.pkl','rb') as f:
    labels = pickle.load(f)

In [None]:
## Import pickle files
with open('./Composition_fn_data/level_2_data_all_classes.pkl','rb') as f:
   data_2 =  pickle.load(f)
with open('./Composition_fn_data/level_2_labels_all_classes.pkl','rb') as f:
    labels_2 = pickle.load(f)

In [None]:
#import pickle files, layer 6 data using dwt
with open('./Composition_fn_data/level_1_data_all_classes_layer_5.pkl','rb') as f:
   data_4 =  pickle.load(f)
with open('./Composition_fn_data/level_1_labels_all_classes_layer_5.pkl','rb') as f:
    labels_4 = pickle.load(f)

In [None]:
#impot pickle files, layer 2 using dwt
with open('./Composition_fn_data/level_1_data_all_classes_layer_1.pkl','rb') as f:
   data_3 =  pickle.load(f)
with open('./Composition_fn_data/level_1_labels_all_classes_layer_1.pkl','rb') as f:
    labels_3 = pickle.load(f)

In [None]:
#import pickle files, level 1 layer 12 db4
with open('./Composition_fn_data/level_1_data_all_classes_db4.pkl','rb') as f:
   data_5 =  pickle.load(f)
with open('./Composition_fn_data/level_1_labels_all_classes_db4.pkl','rb') as f:
    labels_5 = pickle.load(f)

In [None]:
#import pickle files, level 2 layer 12 db4
with open('./Composition_fn_data/level_2_data_all_classes_layer_-1_wav_db4.pkl','rb') as f:
   data_6 =  pickle.load(f)
with open('./Composition_fn_data/level_2_labels_all_classes_layer_-1_wavdb4.pkl','rb') as f:
    labels_6 = pickle.load(f)

In [None]:
with open('./Composition_fn_data/vit_large_level_2_data_all_classes.pkl','rb') as f:
   data_6 =  pickle.load(f)
with open('./Composition_fn_data/vit_large_level_2_labels_all_classes.pkl','rb') as f:
    labels_6 = pickle.load(f)

In [None]:
with open('./Composition_fn_data/vit_large_level_1_data_all_classes.pkl','rb') as f:
   data_7 =  pickle.load(f)
with open('./Composition_fn_data/vit_large_level_1_labels_all_classes.pkl','rb') as f:
    labels_7 = pickle.load(f)

In [None]:
with open('./Composition_fn_data/level_2_data_all_classes_layer_1_wav_haar.pkl','rb') as f:
   data_ =  pickle.load(f)
with open('./Composition_fn_data/level_2_labels_all_classes_layer_1_wavhaar.pkl','rb') as f:
    labels_ = pickle.load(f)

In [None]:
with open('./Composition_fn_data/level_2_data_all_classes_layer_5_wav_haar.pkl','rb') as f:
   data_9 =  pickle.load(f)
with open('./Composition_fn_data/level_2_labels_all_classes_layer_5_wavhaar.pkl','rb') as f:
    labels_9 = pickle.load(f)

In [None]:
class Data(Dataset):
    """Minimal map-style dataset pairing pre-computed inputs with their labels.

    Args:
        data: indexable collection of inputs (one per sample).
        labels: indexable collection of targets; its length defines the dataset length.
        transform: optional callable applied to each input when it is fetched.
            The default ``True``, or any other non-callable value, applies nothing,
            which matches the class's previous behaviour (it stored but ignored it).
    """
    def __init__(self,data,labels,transform = True):
        self.imgs = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self,idx):
        img = self.imgs[idx]
        if callable(self.transform):
            img = self.transform(img)
        return img, self.labels[idx]


In [None]:
class Approximator(torch.nn.Module):
    """Debug variant of the composition model (the next cell redefines this class).

    Mixes per-wavelet-component features with one learned scalar per component
    (``1 + 3*level`` weights) and scores the mixture with the reused classifier
    head of ``model``.

    Args:
        level: wavelet decomposition level of the input features.
        model: model whose ``classifier`` head is reused; its hidden size is read
            from ``model.config.hidden_size`` (falling back to the head's
            ``in_features``), not just for vit_base/vit_large as before.
        verbose: when True, print the input shape and the mixed features on every
            forward pass. This cell used to print them unconditionally.
    """
    def __init__(self,level,model=vit_base,verbose=False):
        super().__init__()
        # One mixing weight per wavelet component, initialised uniformly in [0, 1).
        self.weight = Parameter(torch.rand(level*3 + 1))
        self.model = model.classifier
        # Not used in forward; kept so the attribute still exists.
        self.softmax = torch.nn.Softmax(dim=1)
        config = getattr(model, 'config', None)
        hidden_size = getattr(config, 'hidden_size', None)
        if hidden_size is None:
            # No HF config: fall back to the classifier's input width (nn.Linear).
            hidden_size = getattr(self.model, 'in_features', None)
        self.hidden_size = hidden_size
        self.verbose = verbose

    def forward(self,input: torch.Tensor)-> torch.Tensor:
        """Mix the component axis of a 3-D ``input`` with ``weight`` and classify.

        If the last axis has length ``hidden_size``, axes 1 and 2 are swapped so the
        component axis is last; ``F.linear`` then contracts it with ``weight``.
        """
        if self.verbose:
            print(input.shape)
        if input.shape[-1] == self.hidden_size:
            input = input.permute(0,-1,1)
        out = F.linear(input,self.weight)
        if self.verbose:
            print(out.detach())
        return self.model(out)

In [None]:
class Approximator(torch.nn.Module):
    """Learned linear composition of wavelet-component features, scored by a ViT head.

    ``weight`` holds one scalar per wavelet component (``1 + 3*level`` of them).
    ``forward`` mixes the per-component features with these weights and passes the
    result to the reused ``classifier`` of ``model``.

    Args:
        level: wavelet decomposition level of the input features.
        model: model whose ``classifier`` head is reused. Its hidden size is read
            from ``model.config.hidden_size``, falling back to the head's
            ``in_features``. Previously it was only set for vit_base (768) and
            vit_large (1024), which left the attribute missing for any other model.
    """
    def __init__(self,level,model=vit_base):
        super().__init__()
        # One mixing weight per wavelet component, initialised uniformly in [0, 1).
        self.weight = Parameter(torch.rand(level*3 + 1))
        self.model = model.classifier
        # Not used in forward; kept so the attribute still exists.
        self.softmax = torch.nn.Softmax(dim=1)
        config = getattr(model, 'config', None)
        hidden_size = getattr(config, 'hidden_size', None)
        if hidden_size is None:
            # No HF config: fall back to the classifier's input width (nn.Linear).
            hidden_size = getattr(self.model, 'in_features', None)
        self.hidden_size = hidden_size

    def forward(self,input: torch.Tensor)-> torch.Tensor:
        """Mix the component axis of a 3-D ``input`` with ``weight`` and classify.

        If the last axis has length ``hidden_size``, axes 1 and 2 are swapped so the
        component axis is last; ``F.linear`` then contracts it with ``weight``,
        leaving one feature vector per sample for the classifier.
        """
        if input.shape[-1] == self.hidden_size:
            input = input.permute(0,-1,1)
        out = F.linear(input,self.weight)
        return self.model(out)


In [None]:
def train_one_epoch(epoch_index,model,train_loader,loss_fn,optimizer,softmax,constraint=False):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Only the model's first registered parameter (the mixing weights) is trained.
    Every other parameter (e.g. the reused classifier head) is frozen *before* any
    backward pass. The previous version froze them after ``loss.backward()``, so
    the first batch's gradients had already been written and ``optimizer.step()``
    still updated the head.

    Args:
        epoch_index: unused; kept for call compatibility.
        model: module whose first parameter is the trainable weight, also exposed
            as ``model.weight`` (used by the constraint).
        train_loader: iterable of ``(inputs, labels)`` batches. ``labels[:, 0, :]``
            holds the reference logits.
        loss_fn: loss taking ``(logits, target_probabilities)``, e.g. CrossEntropyLoss.
        optimizer: optimizer over the model's parameters.
        softmax: converts the reference logits into target probabilities.
        constraint: ``'Non-negative'`` clamps the weights at 0 after every step;
            any other value leaves them unconstrained here.

    Returns:
        Mean loss over the batches (0.0 if the loader is empty).
    """
    device = next(model.parameters()).device
    for idx, param in enumerate(model.parameters()):
        if idx != 0:
            param.requires_grad = False

    running_loss = 0.0
    n_batches = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        target = softmax(labels[:, 0, :].to(device))
        # CrossEntropyLoss applies log-softmax to its input itself, so it must get
        # raw logits; softmaxing them first squashed the loss and its gradients.
        loss = loss_fn(outputs, target)
        loss.backward()
        optimizer.step()
        if constraint == 'Non-negative':
            with torch.no_grad():
                model.weight.clamp_(min=0)
        running_loss += loss.item()
        n_batches += 1
    return running_loss / n_batches if n_batches else 0.0

In [None]:
def train(model,epochs,train_loader,val_loader,loss_fn,optimizer,softmax,constraint=False):
    """Train ``model`` for ``epochs`` epochs, reporting validation loss/accuracy each epoch.

    Args:
        model: composition model exposing ``weight``.
        epochs: number of passes over ``train_loader``.
        train_loader, val_loader: iterables of ``(inputs, labels)`` batches;
            ``labels[:, 0, :]`` holds the reference logits.
        loss_fn: loss taking ``(logits, target_probabilities)``.
        optimizer: optimizer used by ``train_one_epoch``.
        softmax: converts reference logits into target probabilities.
        constraint: ``'Non-negative'`` (applied in ``train_one_epoch``) or
            ``'convex'``, which replaces the weights by their softmax once,
            after the last epoch.

    Returns:
        ``(model, acc, avg_loss, avg_vloss)``. ``acc`` is the number of correctly
        classified validation samples in the last epoch, and the two losses are
        floats from the last epoch. With ``epochs == 0`` these are ``0, None, None``.
    """
    device = next(model.parameters()).device
    acc = 0
    avg_loss = None
    avg_vloss = None
    for epoch in range(epochs):
        print('EPOCH {}:'.format(epoch + 1))
        model.train(True)
        avg_loss = train_one_epoch(epoch,model,train_loader,loss_fn,optimizer,softmax,constraint)
        print(model.weight)

        model.eval()
        running_vloss = 0.0
        n_batches = 0
        acc = 0
        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for vinputs, vlabels in val_loader:
                com_cls = model(vinputs.to(device))
                vlabels = vlabels.to(device)
                # Raw logits into the loss: CrossEntropyLoss does its own log-softmax.
                vloss = loss_fn(com_cls, softmax(vlabels[:, 0, :]))
                running_vloss += vloss.item()
                n_batches += 1
                # Per-sample argmax, so the count is right for any batch size
                # (the flattened argmax used before only worked for batches of 1).
                pred = com_cls.reshape(com_cls.shape[0], -1).argmax(dim=-1)
                ref = vlabels.reshape(vlabels.shape[0], -1).argmax(dim=-1)
                acc += (pred == ref).sum().item()
        avg_vloss = running_vloss / n_batches if n_batches else 0.0
        print('LOSS train {} valid {} acc {}'.format(avg_loss, avg_vloss, acc))

    if constraint == 'convex':
        # NOTE(review): the softmax is applied once after training, so the
        # accuracies/losses above were measured with the unconstrained weights.
        with torch.no_grad():
            model.weight.copy_(torch.softmax(model.weight, dim=0))
    return model,acc,avg_loss,avg_vloss
    

In [None]:
def test(model,test_loader):
    with torch.no_grad():
        acc = 0
        for i, vdata in enumerate(test_loader):
            inputs,labels = vdata
            com_cls = model(inputs.to('cpu'))
            compose_cls = torch.argmax(com_cls).item()
            org_cls = torch.argmax(labels).item()
            if compose_cls == org_cls:
                acc+=1
        return acc
        

In [None]:
def train_model(data,labels,vit = vit_base,test_size = 0.3,level = 1,batch_size = 100,loss = 'CE',optim = 'SGD',epochs = 200,lr=0.001,constraint=False):
    """Fit an ``Approximator`` on pre-computed wavelet-component features.

    The data are split with ``random_state=42``: ``test_size`` is held out and then
    halved into validation and test sets (70/15/15 with the default).

    Args:
        data: sequence of per-sample input features.
        labels: sequence of per-sample reference outputs (tensors, moved to the CPU).
        vit: model whose classifier head the ``Approximator`` reuses.
        test_size: fraction held out for validation + test.
        level: wavelet level of ``data``; the Approximator learns ``1 + 3*level`` weights.
        batch_size: training batch size (validation and test use batches of 1).
        loss: currently ignored; ``torch.nn.CrossEntropyLoss`` is always used.
        optim: ``'SGD'`` or ``'Adam'``.
        epochs: number of training epochs.
        lr: learning rate.
        constraint: ``False``, ``'Non-negative'`` or ``'convex'``.

    Returns:
        ``(model, validation accuracy of the last epoch, test accuracy)``, with the
        accuracies as fractions of samples.

    Raises:
        ValueError: if ``optim`` is not ``'SGD'`` or ``'Adam'`` (previously this
            surfaced as an UnboundLocalError after the data had been split).
    """
    if optim == 'SGD':
        optimizer_cls = torch.optim.SGD
    elif optim == 'Adam':
        optimizer_cls = torch.optim.Adam
    else:
        raise ValueError(f"optim must be 'SGD' or 'Adam', got {optim!r}")

    #Split the dataset into train,test,val
    X_train,X_test,y_train,y_test = train_test_split(data,labels,test_size = test_size,random_state=42)
    X_val,X_test,y_val,y_test = train_test_split(X_test,y_test,test_size = 0.5,random_state=42)
    # Labels may live on the GPU; build CPU copies instead of overwriting the
    # split lists element by element.
    y_train = [y.cpu() for y in y_train]
    y_val = [y.cpu() for y in y_val]
    y_test = [y.cpu() for y in y_test]

    train_data = Data(X_train,y_train)
    val_data = Data(X_val,y_val)
    test_data = Data(X_test,y_test)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Evaluation order does not affect the accuracy, so no shuffling here.
    val_loader = DataLoader(val_data, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=1, shuffle=False)

    model = Approximator(level=level,model=vit)

    loss_fn = torch.nn.CrossEntropyLoss()
    # Only the mixing weights are optimised; the reused classifier head is never
    # handed to the optimizer, so it cannot drift.
    optimizer = optimizer_cls([model.weight], lr=lr)

    #Softmax for the reference labels
    soft = torch.nn.Softmax(dim=1)

    model,val_acc,avg_loss,avg_vloss = train(model,epochs,train_loader,val_loader,loss_fn,optimizer,soft,constraint)

    test_acc = test(model,test_loader)
    return model,val_acc/len(val_data),test_acc/len(test_data)

In [None]:
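### To use constraints (argument `constraint` of train_model)
## 1) Non-negative - pass constraint='Non-negative' (weights are clamped at 0 after every optimizer step)
## 2) Convex - pass constraint='convex' (softmax is applied to the weights once, after training)
## constraint=False (the default) leaves the weights unconstrained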

In [None]:
model,val_acc,test_acc = train_model(data_6, labels_6, level = 2, epochs = 100, batch_size= 100, optim='SGD',lr=0.001,constraint=False, vit = vit_large) 
print("Test_acc", test_acc,"Val_acc", val_acc)

In [None]:
torch.save(model.state_dict(), './Composition_fn_data/level_2_layer_2_SGD.pt')

In [None]:
model = Approximator(level=2, model = vit_large)
model.load_state_dict(torch.load('./Composition_fn_data/level_2_layer_2_SGD.pt'))

In [None]:
model.weight

In [None]:
model,val_acc,test_acc = train_model(data_, labels_, level = 1, epochs = 100, batch_size= 100, optim='SGD',lr=0.001,constraint='Non-negative') 
print("Test_acc", test_acc,"Val_acc", val_acc)

In [None]:
torch.save(model.state_dict(), './Composition_fn_data/level_2_layer_2_SGD_NonNeg.pt')

In [None]:
model_nocon = Approximator(level=2)
model_nocon.load_state_dict(torch.load('./Composition_fn_data/level_2_db4_Adam.pt'))

In [None]:
model_nocon.weight

In [None]:
model,val_acc,test_acc = train_model(data_6, labels_6, level = 2, epochs = 100, batch_size= 100, optim='Adam',lr=0.001,constraint='Non-negative') 
print("Test_acc", test_acc,"Val_acc", val_acc)

In [None]:
torch.save(model.state_dict(), './Composition_fn_data/level_2_db4_Adam_NonNeg.pt')

In [None]:
model_NonNeg = Approximator(level=2)
model_NonNeg.load_state_dict(torch.load('./Composition_fn_data/level_2_db4_Adam_NonNeg.pt'))

In [None]:
model_NonNeg.weight

In [None]:
model,val_acc,test_acc = train_model(data_6, labels_6, level = 2, epochs = 100, batch_size= 100, optim='Adam',lr=0.001,constraint='convex') 
print("Test_acc", test_acc,"Val_acc", val_acc)

In [None]:
torch.save(model.state_dict(), './Composition_fn_data/level_2_db4_Adam_con.pt')

In [None]:
model_con = Approximator(level=2)
model_con.load_state_dict(torch.load('./Composition_fn_data/level_2_db4_Adam_con.pt'))

In [None]:
model_con.weight