# Experimenting with low-rank model compression
## this time with vision transformers 

In [1]:
import os
import torch
import torchvision
from torchvision.transforms import v2
from torchvision import transforms
import time
from copy import deepcopy as copy

from sklearn.cluster import AgglomerativeClustering
import numpy as np

from tqdm import tqdm
import time 
import gc

from copy import deepcopy as copy
import time


def dino_model():
    os.environ['TORCH_HOME'] = './'
    os.environ['TORCH_HUB'] = './'
    # DINOv2 vit-s (14) with registers
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
    # state = model.state_dict()
    # mymodel = vit_small(14, 4)
    # mymodel.load_state_dict(state)
    model.eval()

    return model.to('cpu')

def dino_transforms():
    return v2.Compose(
                    [
                        torchvision.transforms.ToTensor(),
                        transforms.Resize(size=(256, 256), antialias=True),
                        transforms.CenterCrop((224, 224)),
                        transforms.Normalize(
                                            mean=[0.485, 0.456, 0.406],
                                            std=[0.229, 0.224, 0.225]
                                            ),
                    ]
                    )

DINOv2 = dino_model()
DINOv2_transform = dino_transforms()

def compute_centroids(weights, assignment):
    # we are going to mean the neurons into the first index in the weights occuring in the assingment
    first_indices = []
    for i in range(int(assignment.max()) + 1):
        indices = (assignment == i).nonzero()

        first_index = indices[0]
        
        try:
            first_indices.append(first_index.item())
            weights[first_index, :] = weights[indices].mean(0)
        except:
            first_indices.append(first_index[0].item())
            weights[first_index[0], :] = weights[indices[0]].mean(0)
    first_indices.sort()

    return weights[first_indices]


def reduce_neurons(weight, bias=None, clusters=None, threshold=0.1):
    # function that does the neuron clustering - returns new weights and biases of reduced neurons layer
    if bias is None:
        bias = torch.zeros((weight.shape[0]))

    
    weight = torch.concat((weight, bias.unsqueeze(-1)), 1)

    normed = torch.nn.functional.normalize(weight)

    D = (1.0 - (normed @ normed.T)).relu()

    C = AgglomerativeClustering(clusters, affinity='precomputed', linkage='complete', compute_full_tree=True, distance_threshold=threshold)
    assignment = C.fit_predict(D)

    centroids = compute_centroids(weight, assignment)

    bias, centroids = centroids[:, -1].squeeze(), centroids[:, :-1]

    return centroids, bias, assignment


def reduce_columns(weight, assignment):
    # function that compensates for neurons that were clustered in the previous layer by aggregating the input features
    # we are going to sum the columns into the first index in the weights occuring in the assignment
    first_indices = []
    for i in range(int(assignment.max())+1):
        indices = (assignment == i).nonzero()

        first_index = indices[0]

        try:
            first_indices.append(first_index.item())
            weight[:, first_index] = weight[:, indices].sum(1)
        except:
            first_indices.append(first_index[0].item())
            weight[:, first_index[0]] = weight[:, indices[0]].sum(1)

    first_indices.sort()

    return weight[:, first_indices]


def train(model, optimizer, loader):
    model.train()
    loss = torch.nn.CrossEntropyLoss()

    for i, (X, y) in tqdm(enumerate(loader)):
        out = model(X.to(0))
        optimizer.zero_grad()
        l = loss(out, y.to(0))
        l.backward()
        optimizer.step()
        


def accuracy(output, target, topk=(1,)):
    output = output.to(torch.device('cpu'))
    target = target.to(torch.device('cpu'))
    maxk = max(topk)
    batch_size = target.shape[0]

    _, idx = output.sort(dim=1, descending=True)
    pred = idx.narrow(1, 0, maxk).t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def epoch_accuracy(loader_s, student):
    student.eval()

    out_epoch_s = [accuracy(student(L.to(0)), y)[0].detach().cpu().item() for L, y in loader_s]

    student.train()

    return sum(out_epoch_s) / len(out_epoch_s)

def test(network, test_loader):
    network.eval().to(0)
    test_loss = 0
    correct = 0
    test_losses=[]
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data.to(0))
            test_loss += torch.nn.CrossEntropyLoss()(output, target.to(0)).item()
            pred = output.data.max(1, keepdim=True)[1].cpu()
            correct += pred.eq(target.data.view_as(pred)).sum()
        test_loss /= len(test_loader.dataset)
        test_losses.append(test_loss)
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def latency(f, x, trials = 100):
    f.cpu()
    total = 0.0
    for trial in range(trials):
        start = time.perf_counter()
        f(x)
        total += time.perf_counter() - start
    return total / trials


class LowRankLinear(torch.nn.Module):
    # takes in a linear layer and decomposes it into two low-rank linear layers
    def __init__(self, fc, rank):
        super(LowRankLinear, self).__init__()

        self.fc1 = torch.nn.Linear(fc.weight.shape[1], rank, bias = False)
        self.fc2 = torch.nn.Linear(rank, fc.weight.shape[0])
        
        weight1 = fc.weight

        self.fc2.bias = fc.bias

        W1 = weight1.cpu().detach().clone().numpy()

        U1, E1, V1 = np.linalg.svd(W1, False)

        rd1 = np.zeros((len(E1), len(E1)))

        for i, v in enumerate(E1):
            rd1[i, i] = v


        if fc.weight.shape[1] > fc.weight.shape[0]:
            # if the input dom of the fc is bigger than the output dim
            self.fc1.weight = torch.nn.parameter.Parameter(torch.tensor(rd1[:rank, :rank] @ V1[:rank, :]).to(fc.weight.device).type(fc.weight.dtype))
            self.fc2.weight = torch.nn.parameter.Parameter(torch.tensor(U1[:, :rank]).to(fc.weight.device).type(fc.weight.dtype))
        else:
            self.fc1.weight = torch.nn.parameter.Parameter(torch.tensor(V1[:rank, :]).to(fc.weight.device).type(fc.weight.dtype))
            self.fc2.weight = torch.nn.parameter.Parameter(torch.tensor(U1[:, :rank] @ rd1[:rank, :rank]).to(fc.weight.device).type(fc.weight.dtype))


    def forward(self, x):
        return self.fc2(self.fc1(x))
    
def linearleaves(module):
    # returns a list of pairs of (parent, submodule_name) pairs for all submodule leaves of the current module
    if isinstance(module, torch.nn.Linear):
        return [(module, None)]

    linear_children = []
    for name, mod in module.named_modules():
        if isinstance(mod, torch.nn.Linear):
            linear_children.append((name, module))
    return linear_children
        

def getattrrecur(mod, s):
    s = s.split('.')
    for substr in s:
        mod = getattr(mod, substr)
    return mod


def setattrrecur(mod, s, value):
    s = s.split('.')
    for substr in s[:-1]:
        mod = getattr(mod, substr)
    setattr(mod, s[-1], value)

Using cache found in ./hub\facebookresearch_dinov2_main


## we will use this custom module to assess the performance of the embedding model on a simple task

In [2]:
class LinearProbe(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.m = module
        self.linear = torch.nn.Linear(768, 102)
    
    def forward(self, x):
        x = self.m(x).detach()
        return self.linear(x) * 100

## training the linear probe

In [3]:
train_ds = torchvision.datasets.Flowers102('./Flowers102', split='train', transform=DINOv2_transform, download=True)
val_ds = torchvision.datasets.Flowers102('./Flowers102', split='val', transform=DINOv2_transform, download=True)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=16)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=48)

DINOv2 = dino_model()

model = LinearProbe(DINOv2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    train(model.to(0), optimizer, train_loader)
    gc.collect()
test(model, val_loader)

Using cache found in ./hub\facebookresearch_dinov2_main
64it [00:34,  1.83it/s]
64it [00:25,  2.55it/s]
64it [00:25,  2.54it/s]
64it [00:25,  2.52it/s]
64it [00:25,  2.48it/s]



Test set: Avg. loss: 0.0176, Accuracy: 1011/1020 (99%)



## clustering the neurons only occurs in the fc layers of the transformer for now (these are responsible for the majority of the parameters)
## rank reduction only occurs when it would reduce model latency (when the rank is reduced by more than half)

see the previous notebook for more details on these two approaches


https://arxiv.org/pdf/2206.06072.pdf

In [19]:
def clusteredlowrank(module, rank, threshold):
    module = copy(module)

    for (name, mod), (name_next, mod_next) in zip(linearleaves(module), linearleaves(module)[1:]):
        if 'fc1' in name:
            weights, bias, assignment = reduce_neurons(getattrrecur(mod, name).weight.detach().clone(), getattrrecur(mod, name).bias.detach().clone(), threshold=threshold)
            cols = reduce_columns(getattrrecur(mod_next, name_next).weight.detach().clone(), assignment)

            mod = torch.nn.Linear(weights.shape[1], weights.shape[0])
            mod.weight = torch.nn.Parameter(weights.detach().clone())
            mod.bias = torch.nn.Parameter(bias.detach().clone())

            replacement_next = torch.nn.Linear(cols.shape[1], cols.shape[0])
            replacement_next.weight = torch.nn.Parameter(cols.detach().clone())
            replacement_next.bias = getattrrecur(mod_next, name_next).bias

            setattrrecur(module, name, mod)
            setattrrecur(module, name_next, replacement_next)


    for name, mod in linearleaves(module):
        setattrrecur(module, name, LowRankLinear(getattrrecur(module, name), rank))
    test(module, val_loader)
    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))

def lowranklatency(module, rank):
    module = copy(module).cpu()
    start = time.time()
    layers_reduced = 0
    for i, (name, mod) in enumerate(linearleaves(module)):
        setattrrecur(module, name, LowRankLinear(getattrrecur(module, name), rank))
        layers_reduced += 1

    print(f'layers reduced: {layers_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))


def laterlowranklatency(module, rank, after=27):
    module = copy(module).cpu()
    start = time.time()
    layers_reduced = 0
    for i, (name, mod) in enumerate(linearleaves(module)):
        if i > after:
            setattrrecur(module, name, LowRankLinear(getattrrecur(module, name), rank))
            layers_reduced += 1

    print(f'layers reduced: {layers_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))


def earlierlowranklatency(module, rank, before=8):
    module = copy(module).cpu()
    start = time.time()
    layers_reduced = 0
    for i, (name, mod) in enumerate(linearleaves(module)):
        if i < before:
            setattrrecur(module, name, LowRankLinear(getattrrecur(module, name), rank))
            layers_reduced += 1

    print(f'layers reduced: {layers_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))


def clusteredlatency(module, threshold):
    module = copy(module)
    start = time.time()

    neurons_reduced = 0

    for (name, mod), (name_next, mod_next) in zip(linearleaves(module), linearleaves(module)[1:]):
        if 'fc1' in name:
            weights, bias, assignment = reduce_neurons(getattrrecur(mod, name).weight.detach().clone(), getattrrecur(mod, name).bias.detach().clone(), threshold=threshold)

            neurons_reduced += (torch.tensor(getattrrecur(mod, name).weight.shape) - torch.tensor(weights.shape)).sum()

            cols = reduce_columns(getattrrecur(mod_next, name_next).weight.detach().clone(), assignment)

            mod = torch.nn.Linear(weights.shape[1], weights.shape[0])
            mod.weight = torch.nn.Parameter(weights.detach().clone())
            mod.bias = torch.nn.Parameter(bias.detach().clone())

            replacement_next = torch.nn.Linear(cols.shape[1], cols.shape[0])
            replacement_next.weight = torch.nn.Parameter(cols.detach().clone())
            replacement_next.bias = getattrrecur(mod_next, name_next).bias

            setattrrecur(module, name, mod)
            setattrrecur(module, name_next, replacement_next)

    print(f'neurons reduced: {neurons_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))


def laterclusteredlatency(module, threshold, after=28):
    module = copy(module)
    start = time.time()

    neurons_reduced = 0

    for i, ((name, mod), (name_next, mod_next)) in enumerate(zip(linearleaves(module), linearleaves(module)[1:])):
        if 'fc1' in name and i > after:
            weights, bias, assignment = reduce_neurons(getattrrecur(mod, name).weight.detach().clone(), getattrrecur(mod, name).bias.detach().clone(), threshold=threshold)

            neurons_reduced += (torch.tensor(getattrrecur(mod, name).weight.shape) - torch.tensor(weights.shape)).sum()

            cols = reduce_columns(getattrrecur(mod_next, name_next).weight.detach().clone(), assignment)

            mod = torch.nn.Linear(weights.shape[1], weights.shape[0])
            mod.weight = torch.nn.Parameter(weights.detach().clone())
            mod.bias = torch.nn.Parameter(bias.detach().clone())

            replacement_next = torch.nn.Linear(cols.shape[1], cols.shape[0])
            replacement_next.weight = torch.nn.Parameter(cols.detach().clone())
            replacement_next.bias = getattrrecur(mod_next, name_next).bias

            setattrrecur(module, name, mod)
            setattrrecur(module, name_next, replacement_next)

    print(f'neurons reduced: {neurons_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))


def earlierclusteredlatency(module, threshold, before=28):
    module = copy(module)
    start = time.time()

    neurons_reduced = 0

    for i, ((name, mod), (name_next, mod_next)) in enumerate(zip(linearleaves(module), linearleaves(module)[1:])):
        if 'fc1' in name and i < before:
            weights, bias, assignment = reduce_neurons(getattrrecur(mod, name).weight.detach().clone(), getattrrecur(mod, name).bias.detach().clone(), threshold=threshold)

            neurons_reduced += (torch.tensor(getattrrecur(mod, name).weight.shape) - torch.tensor(weights.shape)).sum()

            cols = reduce_columns(getattrrecur(mod_next, name_next).weight.detach().clone(), assignment)

            mod = torch.nn.Linear(weights.shape[1], weights.shape[0])
            mod.weight = torch.nn.Parameter(weights.detach().clone())
            mod.bias = torch.nn.Parameter(bias.detach().clone())

            replacement_next = torch.nn.Linear(cols.shape[1], cols.shape[0])
            replacement_next.weight = torch.nn.Parameter(cols.detach().clone())
            replacement_next.bias = getattrrecur(mod_next, name_next).bias

            setattrrecur(module, name, mod)
            setattrrecur(module, name_next, replacement_next)

    print(f'neurons reduced: {neurons_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))

In [5]:
model.float().cpu()
print(latency(model.eval(), torch.ones(1, 3, 224, 224)))
test(model, val_loader)

0.258001792999998

Test set: Avg. loss: 0.0176, Accuracy: 1011/1020 (99%)



bfloats are actually much (~100x) slower on the cpu than float32s

In [6]:
model.bfloat16()
print(latency(model.eval(), torch.ones(1, 3, 224, 224).bfloat16()))

27.134958201000014


In [7]:
model.float().cpu()
evaluating = copy(model)
clusteredlatency(evaluating, 0.1)

  weight[:, first_index] = weight[:, indices].sum(1)


neurons reduced: 3927 (12.91650128364563s)

Test set: Avg. loss: 0.2549, Accuracy: 774/1020 (76%)

0.24229154299999664


In [10]:
model.float().cpu()
evaluating = copy(model.eval())
lowranklatency(evaluating, 382)

layers reduced: 49 (199.4038803577423s)

Test set: Avg. loss: 2.9148, Accuracy: 11/1020 (1%)

0.17873593699999787


In [11]:
model.float().cpu()
evaluating = copy(model.eval())
laterlowranklatency(evaluating, 382, 48)

layers reduced: 1 (0.032996177673339844s)

Test set: Avg. loss: 0.0176, Accuracy: 1011/1020 (99%)

0.23260513299999502


In [12]:
model.float().cpu()
evaluating = copy(model.eval())
laterlowranklatency(evaluating, 382, 44)

layers reduced: 5 (17.37299418449402s)

Test set: Avg. loss: 0.3358, Accuracy: 728/1020 (71%)

0.24638663100000713


In [13]:
model.float().cpu()
evaluating = copy(model.eval())
laterlowranklatency(evaluating, 382, 40)

layers reduced: 9 (34.76936674118042s)

Test set: Avg. loss: 0.9883, Accuracy: 387/1020 (38%)

0.24058260899999367


In [15]:
model.float().cpu()
evaluating = copy(model.eval())
earlierlowranklatency(evaluating, 382, 4)

layers reduced: 4 (17.93900990486145s)

Test set: Avg. loss: 0.0198, Accuracy: 1007/1020 (99%)

0.24273931499998752


In [16]:
model.float().cpu()
evaluating = copy(model.eval())
earlierlowranklatency(evaluating, 382, 8)

layers reduced: 8 (34.327433824539185s)

Test set: Avg. loss: 0.0213, Accuracy: 1000/1020 (98%)

0.2465916739999966


In [17]:
model.float().cpu()
evaluating = copy(model.eval())
earlierlowranklatency(evaluating, 382, 12)

layers reduced: 12 (57.95983099937439s)

Test set: Avg. loss: 0.3101, Accuracy: 794/1020 (78%)

0.2322501379999949


In [20]:
model.float().cpu()
evaluating = copy(model)
earlierclusteredlatency(evaluating, 0.1, 4)

  weight[:, first_index] = weight[:, indices].sum(1)


neurons reduced: 819 (0.8727188110351562s)

Test set: Avg. loss: 0.0173, Accuracy: 1012/1020 (99%)

0.2387627739999948


In [22]:
model.float().cpu()
evaluating = copy(model)
earlierclusteredlatency(evaluating, 0.1, 8)

neurons reduced: 1094 (1.9740779399871826s)

Test set: Avg. loss: 0.0171, Accuracy: 1011/1020 (99%)

0.2560333930000161


In [23]:
model.float().cpu()
evaluating = copy(model)
earlierclusteredlatency(evaluating, 0.1, 12)

neurons reduced: 1290 (2.8597729206085205s)

Test set: Avg. loss: 0.0175, Accuracy: 1010/1020 (99%)

0.24357506200003173


In [24]:
model.float().cpu()
evaluating = copy(model)
laterclusteredlatency(evaluating, 0.1, 44)

neurons reduced: 1103 (0.8631227016448975s)

Test set: Avg. loss: 0.1698, Accuracy: 840/1020 (82%)

0.26188703099998745


In [25]:
model.float().cpu()
evaluating = copy(model)
laterclusteredlatency(evaluating, 0.1, 40)

neurons reduced: 1709 (1.8653233051300049s)

Test set: Avg. loss: 0.2302, Accuracy: 809/1020 (79%)

0.2436918040000046


In [9]:
evaluating = copy(model)
clusteredlowrank(evaluating, rank=384, threshold=0.05)


Test set: Avg. loss: 3.3223, Accuracy: 10/1020 (1%)

0.2702627630000006
