# Objective

This experiment checks the following for a simple 2 layer FC network on MNIST.
1. Verify Agop and NFM relations for the conv layers
2. Run RFM to construct similar matrices as the above.(TBD)

The model is taken from MNIST/model3

# Setup

In [1]:
import sys
parent_dir='C:\\Users\\garav\\AGOP\\DLR'
model_dir= 'C:\\Users\\garav\\AGOP\\DLR\\trained_models\\MNIST\\model3\\nn_models\\'
#parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)

In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
from utils import trainer as t
from utils import agop_fc as af
from torch.utils.data import Dataset
import random
import torch.backends.cudnn as cudnn
import rfm
import numpy as np
from trained_models.MNIST.model3 import model3
import numpy as np
from sklearn.model_selection import train_test_split
from torch.linalg import norm
from torchvision import models
import torch.nn as nn
from copy import deepcopy

Setting up a new session...
Without the incoming socket you cannot receive events from the server or register event handlers to your Visdom client.


In [5]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#device='cpu'
print(f"Using device: {device}")

Using device: cuda:0


In [42]:
pwd

'C:\\Users\\garav\\AGOP\\DLR\\experiments\\rfm_mnist'

In [7]:
def one_hot_data(dataset, num_samples=-1):
    labelset = {}
    for i in range(10):
        one_hot = torch.zeros(10)
        one_hot[i] = 1
        labelset[i] = one_hot

    subset = [(ex.flatten(), labelset[label]) for \
              idx, (ex, label) in enumerate(dataset) if idx < num_samples]
    return subset


def group_by_class(dataset):
    labelset = {}
    for i in range(10):
        labelset[i] = []
    for i, batch in enumerate(dataset):
        img, label = batch
        labelset[label].append(img.view(1, 3, 32, 32))
    return labelset


def split(trainset, p=.8):
    train, val = train_test_split(trainset, train_size=p)
    return train, val

def merge_data(mnist, n):
    #cifar_by_label = group_by_class(cifar)

    mnist_by_label = group_by_class(mnist)

    data = []
    labels = []

    labelset = {}

    for i in range(10):
        one_hot = torch.zeros(1, 10)
        one_hot[0, i] = 1
        labelset[i] = one_hot

    for l in mnist_by_label:

        #cifar_data = torch.cat(cifar_by_label[l])
        mnist_data = torch.cat(mnist_by_label[l])
        min_len = len(mnist_data)
        m = min(n, min_len)
        #cifar_data = cifar_data[:m]
        mnist_data = mnist_data[:m]

        merged = torch.cat([mnist_data], axis=-1)
        #for i in range(3):
           # vis.image(merged[i])
        data.append(merged.reshape(m, -1))
        print(merged.shape)
        labels.append(np.repeat(labelset[l], m, axis=0))
    data = torch.cat(data, axis=0)

    labels = np.concatenate(labels, axis=0)
    merged_labels = torch.from_numpy(labels)

    return list(zip(data, labels))



In [9]:
torch.cuda.empty_cache()

In [9]:

SEED = 5700
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.cuda.manual_seed(SEED)
#cudnn.benchmark = False

transform = transforms.Compose(
        [transforms.ToTensor()
        ])

def repeat_channel(x):
    return x.repeat(3, 1, 1)

mnist_transform = transforms.Compose(
    [transforms.Resize([32, 32]),
     transforms.ToTensor(),
     transforms.Lambda(repeat_channel)]
)

path= './data'  
    
mnist_trainset = torchvision.datasets.MNIST(root=path,
                                                train=True,
                                                transform=mnist_transform,
                                                download=True)

#trainset = group_by_class(mnist_trainset)
trainset = merge_data(mnist_trainset, 5000)
trainset, valset = split(trainset, p=.8)
print("Train Size: ", len(trainset), "Val Size: ", len(valset))

trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                              shuffle=True, num_workers=2)
valloader = torch.utils.data.DataLoader(valset, batch_size=100,
                                            shuffle=False, num_workers=1)


mnist_testset = torchvision.datasets.MNIST(root=path,
                                               train=False,
                                               transform=mnist_transform,
                                               download=True)

print("Test Size: ", len(mnist_testset))
testset = merge_data(mnist_testset, 900)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                             shuffle=False, num_workers=2)

name = 'mnist_fc'


torch.Size([5000, 3, 32, 32])
torch.Size([5000, 3, 32, 32])
torch.Size([5000, 3, 32, 32])
torch.Size([5000, 3, 32, 32])
torch.Size([5000, 3, 32, 32])
torch.Size([5000, 3, 32, 32])
torch.Size([5000, 3, 32, 32])
torch.Size([5000, 3, 32, 32])
torch.Size([5000, 3, 32, 32])
torch.Size([5000, 3, 32, 32])
Train Size:  40000 Val Size:  10000
Test Size:  10000
torch.Size([900, 3, 32, 32])
torch.Size([900, 3, 32, 32])
torch.Size([900, 3, 32, 32])
torch.Size([900, 3, 32, 32])
torch.Size([900, 3, 32, 32])
torch.Size([892, 3, 32, 32])
torch.Size([900, 3, 32, 32])
torch.Size([900, 3, 32, 32])
torch.Size([900, 3, 32, 32])
torch.Size([900, 3, 32, 32])


In [12]:
for idx, batch in enumerate(trainloader):
        inputs, labels = batch
        _, dim = inputs.shape
        break
print(dim)

3072


In [15]:
net = model3.Net(3072, num_classes=10)
init_net=deepcopy(net)

In [19]:
import os
if os.path.exists(model_dir+'mnist_fc_trained_nn.pth'):
    checkpoint = torch.load(model_dir+'mnist_fc_trained_nn.pth', map_location=torch.device(device))
    net.load_state_dict(checkpoint['state_dict'])  # Access the 'state_dict' within the loaded dictionary
    print("Model weights loaded successfully.")
else:
    t.train_network(trainloader, valloader, testloader,
                num_classes=10, root_path= model_dir, 
                optimizer=torch.optim.SGD(net.parameters(), lr=.1),
                lfn=  nn.MSELoss(), 
                num_epochs = 10,
                name=name, net=net)

Model weights loaded successfully.


# AGOP_FC.py

In [77]:
''' This module does the following
1. Scan the network for conv layers
2. For each FC layer compute W^TW of eq 3
3. For each FC layer compute the AGOP(AJOP in case of multiple outputs)
4. For each conv layer print the pearson correlation between 2 and 3
'''

import torch
import torch.nn as nn
import random
import numpy as np
#from functorch import jacrev, vmap
from torch.func import jacrev
from torch.nn.functional import pad
#import dataset
from numpy.linalg import eig
from copy import deepcopy
from torch.linalg import norm, svd
from torchvision import models
import visdom
from torch.linalg import norm, eig


SEED = 2323

torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.cuda.manual_seed(SEED)

vis = visdom.Visdom('http://127.0.0.1', use_incoming_socket=False)
vis.close(env='main')

def get_jacobian(net, data, c_idx=0, chunk=100):
    with torch.no_grad():
        def single_net(x):
            # x is (s)
            return net(x.unsqueeze(0))[:,c_idx*chunk:(c_idx+1)*chunk].squeeze(0)
        # Parallelize across the images.
        return torch.vmap(jacrev(single_net))(data) #(n,chunk,s)

def min_max(M):
    return (M - M.min()) / (M.max() - M.min())

def sqrt(G):
    U, s, Vt = svd(G)
    s = torch.pow(s, 1./2)
    G = U @ torch.diag(s) @ Vt
    return G


def correlation(M1, M2):
    M1 -= M1.mean()
    M2 -= M2.mean()

    norm1 = norm(M1.flatten())
    norm2 = norm(M2.flatten())

    return torch.sum(M1.cuda() * M2.cuda()) / (norm1 * norm2)

def egop(model, z, c=10, chunk_idxs=1):
    ajop = 0
    #Chunking is done to compute jacobian as chunks. This saves memory
    chunk = c // chunk_idxs
    for i in range(chunk_idxs):
        grads = get_jacobian(model, z, c_idx=i, chunk=chunk) #(n,chunk,s)
        grads_t = grads.transpose(1, 2) 
        ajop_matmul= torch.matmul(grads_t, grads) #(n,s,s)
        #Clarify: mean and sum are making no difference here. Check if trainloader has grouped images
        ajop += torch.mean(ajop_matmul, dim=0) #(s,s)
    return ajop



def get_grads(net, patchnet, trainloader, max_batch, classes, chunk_idx,
              kernel=(3,3), padding=(1,1),
              stride=(1,1), layer_idx=0):
    net.eval()
    net.cuda()
    patchnet.eval()
    patchnet.cuda()
    M = 0
    #M.cuda()
    
    # Num images for taking AGOP (Can be small for early layers)
    MAX_NUM_IMGS = max_batch

    for idx, batch in enumerate(trainloader):
        print("Computing GOP for sample " + str(idx) + \
              " out of " + str(MAX_NUM_IMGS))
        imgs, _ = batch
        #imgs=imgs[:]
        with torch.no_grad():
            imgs = imgs.cuda()        
            # Run the first half of the network wrt to the current layer 
            ip = net.features[:layer_idx](imgs).cpu() #(n,s)
            
        #print(patches.shape)
        M += egop(patchnet,ip.cuda(), classes, chunk_idx).cuda()
        del imgs
        torch.cuda.empty_cache()
        if idx >= MAX_NUM_IMGS:
            break
    net.cpu()
    patchnet.cpu()
    return M

def load_nn(net, init_net, layer_idx=0):
   
    count = 0
    
    # Get the layer_idx+1 th conv layer
    #TODO: Add functionality to access classifier layers too.
    for idx, m in enumerate(net.features):
        if isinstance(m, nn.Linear):
            count += 1
        if count-1 == layer_idx:
            l_idx = idx
            break
    
    patchnet = deepcopy(net)
    
    # Truncate all layers before l_idx.
    patchnet.features = net.features[l_idx:]
    
    M = net.features[l_idx].weight.data
    # Compute WW which is (s,s) matrix
    M =torch.matmul(M.T, M)
    M0 = init_net.features[l_idx].weight.data
    # Compute W0tW0 which is (s,s) matrix
    M0 =torch.matmul(M0.T, M0)
    return net, patchnet, M, M0, l_idx


def verify_NFA(net, init_net, trainloader, layer_idx=0, max_batch=10, classes=10, chunk_idx=1):


    net, patchnet, M, M0, l_idx = load_nn(net, init_net, layer_idx=layer_idx)

    i_val = correlation(M0.cuda(), M.cuda())
    print("Correlation between Initial and Trained CNFM: ", i_val)

    G = get_grads(net, patchnet, trainloader,  max_batch, classes, chunk_idx,
                  layer_idx=l_idx)
    print("Shape of grad matrix",G.shape)
    G = sqrt(G.cuda())
    Gop = G.clone()
    r_val = correlation(M.cuda(), G.cuda())
    print("Correlation between Trained CNFM and AGOP: ", r_val)
    print("Final: ", i_val, r_val)
    return Gop

def vis_transform_image(net, img, G, layer_idx=0):
   #TODO: What to visualise for the FC layers?
    count = -1
    
    # Computes WtW for the weights(ignoring its bias) of layer_idx+1 the conv layer
    for idx, p in enumerate(net.parameters()):
        if len(p.shape) > 1:
            count += 1
        if count == layer_idx:
            M = p.data
            _, ki, q, s = M.shape

            M = M.reshape(-1, ki*q*s)
            M = torch.einsum('nd, nD -> dD', M, M)
            break

    count = 0
    l_idx = None
    
    # Get the layer_idx+1 conv layer 
    for idx, m in enumerate(net.features):
        if isinstance(m, nn.Conv2d):
            print(m, count)
            count += 1

        if count-1 == layer_idx:
            l_idx = idx
            break

    net.eval()
    net.cuda()
    img = img.cuda()
    img = net.features[:l_idx](img).cpu()
    net.cpu()
    
    # If G is given which is expected to be the AGOP of layer_idx+1 conv layer then that is used.
    if G is not None:
        M = G

    patches = patchify(img, (q, s), (1, 1))
    
    print(patches.shape)
    # Patches should will be of the shape (n,w,h,c,q,s) not (n,w,h,q,s,c)
    n, w, h, q, s, c = patches.shape
    # Vectorize each patch
    patches = patches.reshape(n, w, h, q*s*c)
    # Apply either WtW or AGOP of the layer_idx+1 conv to each patch. D is c*q*s vector
    M_patch = torch.einsum('nwhd, dD -> nwhD', patches, M) #(n,w,h,c*q*s)
    
    M_patch = norm(M_patch, dim=-1) #(n,w,h)

    vis.image(min_max(M_patch[0])) #(w,h) image.




Setting up a new session...
Without the incoming socket you cannot receive events from the server or register event handlers to your Visdom client.


# Verify NFA for FC layers:

In [81]:

print(net)
verify_NFA(net, init_net, trainloader, max_batch= 20, classes=10, chunk_idx=1, layer_idx=0)

Net(
  (features): Sequential(
    (0): Linear(in_features=3072, out_features=1024, bias=True)
    (1): Nonlinearity()
    (2): Linear(in_features=1024, out_features=1024, bias=True)
    (3): Nonlinearity()
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=10, bias=True)
  )
)
Correlation between Initial and Trained CNFM:  tensor(0.2370, device='cuda:0')
Computing GOP for sample 0 out of 20
Computing GOP for sample 1 out of 20
Computing GOP for sample 2 out of 20
Computing GOP for sample 3 out of 20
Computing GOP for sample 4 out of 20
Computing GOP for sample 5 out of 20
Computing GOP for sample 6 out of 20
Computing GOP for sample 7 out of 20
Computing GOP for sample 8 out of 20
Computing GOP for sample 9 out of 20
Computing GOP for sample 10 out of 20
Computing GOP for sample 11 out of 20
Computing GOP for sample 12 out of 20
Computing GOP for sample 13 out of 20
Computing GOP for sample 14 out of 20
Computing GOP for sample 15 out of 20
Computing GOP fo

tensor([[ 2.2016e-02, -3.6660e-04,  6.4364e-04,  ..., -1.7765e-03,
          8.2688e-04,  7.2708e-06],
        [-3.6529e-04,  2.4448e-02, -1.4750e-03,  ...,  1.6157e-03,
         -6.3870e-04, -2.2715e-04],
        [ 6.4223e-04, -1.4761e-03,  2.2138e-02,  ..., -9.7591e-04,
          8.7749e-04, -5.8127e-04],
        ...,
        [-1.7785e-03,  1.6131e-03, -9.7762e-04,  ...,  2.5178e-02,
          4.6532e-04,  1.2917e-04],
        [ 8.2640e-04, -6.3988e-04,  8.7811e-04,  ...,  4.6324e-04,
          2.3380e-02,  4.3505e-04],
        [ 5.6838e-06, -2.2632e-04, -5.8024e-04,  ...,  1.3076e-04,
          4.3773e-04,  2.3445e-02]], device='cuda:0')

In [None]:
#TODO: How to meaningfully visualise? 

# RFM 

In [None]:
'''Warning: This is an extremely cpu intensive process since it uses solve function from linalg 
The rfm.py from utils is equipped with more memory efficient solvers. 
'''

rfm.rfm(trainloader, valloader, testloader, name=name,
            batch_size=10, iters=1, reg=1e-3)