# Experimenting with low-rank model compression
## this time with vision transformers 

In [69]:
import os
import torch
import torchvision
from torchvision.transforms import v2
from torchvision import transforms
import time
from copy import deepcopy as copy

from sklearn.cluster import AgglomerativeClustering
import numpy as np

from tqdm import tqdm
import time 
import gc

from copy import deepcopy as copy
import time


def dino_model():
    os.environ['TORCH_HOME'] = './'
    os.environ['TORCH_HUB'] = './'
    # DINOv2 vit-s (14) with registers
    # model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
    # state = model.state_dict()
    # mymodel = vit_small(14, 4)
    # mymodel.load_state_dict(state)
    model.eval()

    return model.to('cpu')

def dino_transforms():
    return v2.Compose(
                    [
                        torchvision.transforms.ToTensor(),
                        transforms.Resize(size=(256, 256), antialias=True),
                        transforms.CenterCrop((224, 224)),
                        transforms.Normalize(
                                            mean=[0.485, 0.456, 0.406],
                                            std=[0.229, 0.224, 0.225]
                                            ),
                    ]
                    )

DINOv2 = dino_model()
DINOv2_transform = dino_transforms()

def compute_centroids(weights, assignment):
    # we are going to mean the neurons into the first index in the weights occuring in the assingment
    first_indices = []
    for i in range(int(assignment.max()) + 1):
        indices = (assignment == i).nonzero()

        first_index = indices[0]
        
        try:
            first_indices.append(first_index.item())
            weights[first_index, :] = weights[indices].mean(0)
        except:
            first_indices.append(first_index[0].item())
            weights[first_index[0], :] = weights[indices[0]].mean(0)
    first_indices.sort()

    return weights[first_indices]


def reduce_neurons(weight, bias=None, clusters=None, threshold=0.1):
    # function that does the neuron clustering - returns new weights and biases of reduced neurons layer
    if bias is None:
        bias = torch.zeros((weight.shape[0]))

    
    weight = torch.concat((weight, bias.unsqueeze(-1)), 1)

    normed = torch.nn.functional.normalize(weight)

    D = (1.0 - (normed @ normed.T)).relu()

    C = AgglomerativeClustering(clusters, metric='precomputed', linkage='complete', compute_full_tree=True, distance_threshold=threshold)
    assignment = C.fit_predict(D)

    centroids = compute_centroids(weight, assignment)

    bias, centroids = centroids[:, -1].squeeze(), centroids[:, :-1]

    return centroids, bias, assignment


def reduce_columns(weight, assignment):
    # function that compensates for neurons that were clustered in the previous layer by aggregating the input features
    # we are going to sum the columns into the first index in the weights occuring in the assignment
    first_indices = []
    for i in range(int(assignment.max())+1):
        indices = (assignment == i).nonzero()

        first_index = indices[0]

        try:
            first_indices.append(first_index.item())
            weight[:, first_index] = weight[:, indices].sum(1)
        except:
            first_indices.append(first_index[0].item())
            weight[:, first_index[0]] = weight[:, indices[0]].sum(1)

    first_indices.sort()

    return weight[:, first_indices]


def train(model, optimizer, loader):
    model.train()
    loss = torch.nn.CrossEntropyLoss()

    for i, (X, y) in tqdm(enumerate(loader)):
        out = model(X.to(0))
        optimizer.zero_grad()
        l = loss(out, y.to(0))
        l.backward()
        optimizer.step()
        


def accuracy(output, target, topk=(1,)):
    output = output.to(torch.device('cpu'))
    target = target.to(torch.device('cpu'))
    maxk = max(topk)
    batch_size = target.shape[0]

    _, idx = output.sort(dim=1, descending=True)
    pred = idx.narrow(1, 0, maxk).t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def epoch_accuracy(loader_s, student):
    student.eval()

    out_epoch_s = [accuracy(student(L.to(0)), y)[0].detach().cpu().item() for L, y in loader_s]

    student.train()

    return sum(out_epoch_s) / len(out_epoch_s)

def test(network, test_loader, dtype=torch.float32, silent=False):
    network.eval().to(0)
    test_loss = 0
    correct = 0
    test_losses=[]
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data.to(0).type(dtype))
            test_loss += torch.nn.CrossEntropyLoss()(output, target.to(0)).item()
            pred = output.data.max(1, keepdim=True)[1].cpu()
            correct += pred.eq(target.data.view_as(pred)).sum()
        test_loss /= len(test_loader.dataset)
        test_losses.append(test_loss)
        if not silent:
            print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    if silent:
        return 100. * correct / len(test_loader.dataset) 



def latency(f, x, trials = 100):
    f.cpu()
    total = 0.0
    for trial in range(trials):
        start = time.perf_counter()
        f(x)
        total += time.perf_counter() - start
    return total / trials


class LowRankLinear(torch.nn.Module):
    # takes in a linear layer and decomposes it into two low-rank linear layers
    def __init__(self, fc, rank):
        super(LowRankLinear, self).__init__()

        self.fc1 = torch.nn.Linear(fc.weight.shape[1], rank, bias = False)
        self.fc2 = torch.nn.Linear(rank, fc.weight.shape[0])
        
        weight1 = fc.weight

        self.fc2.bias = fc.bias

        W1 = weight1.cpu().detach().clone().numpy()

        U1, E1, V1 = np.linalg.svd(W1, False)

        rd1 = np.zeros((len(E1), len(E1)))

        for i, v in enumerate(E1):
            rd1[i, i] = v


        if fc.weight.shape[1] > fc.weight.shape[0]:
            # if the input dom of the fc is bigger than the output dim
            self.fc1.weight = torch.nn.parameter.Parameter(torch.tensor(rd1[:rank, :rank] @ V1[:rank, :]).to(fc.weight.device).type(fc.weight.dtype))
            self.fc2.weight = torch.nn.parameter.Parameter(torch.tensor(U1[:, :rank]).to(fc.weight.device).type(fc.weight.dtype))
        else:
            self.fc1.weight = torch.nn.parameter.Parameter(torch.tensor(V1[:rank, :]).to(fc.weight.device).type(fc.weight.dtype))
            self.fc2.weight = torch.nn.parameter.Parameter(torch.tensor(U1[:, :rank] @ rd1[:rank, :rank]).to(fc.weight.device).type(fc.weight.dtype))


    def forward(self, x):
        return self.fc2(self.fc1(x))
    
def linearleaves(module):
    # returns a list of pairs of (parent, submodule_name) pairs for all submodule leaves of the current module
    if isinstance(module, torch.nn.Linear):
        return [(module, None)]

    linear_children = []
    for name, mod in module.named_modules():
        if isinstance(mod, torch.nn.Linear):
            linear_children.append((name, module))
    return linear_children
        

def getattrrecur(mod, s):
    s = s.split('.')
    for substr in s:
        mod = getattr(mod, substr)
    return mod


def setattrrecur(mod, s, value):
    s = s.split('.')
    for substr in s[:-1]:
        mod = getattr(mod, substr)
    setattr(mod, s[-1], value)

Using cache found in ./hub/facebookresearch_dinov2_main


## we will use this custom module to assess the performance of the embedding model on a simple task

In [41]:
class LinearProbe(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.m = module
        self.linear = torch.nn.Linear(384, 10)
    
    def forward(self, x):
        x = self.m(x).detach()
        return self.linear(x) * 100

## training the linear probe

In [42]:
train_ds = torchvision.datasets.CIFAR10('./cifar10', train=True, transform=DINOv2_transform, download=True)
val_ds = torchvision.datasets.CIFAR10('./cifar10', train=False, transform=DINOv2_transform, download=True)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=128)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=128)

DINOv2 = dino_model()

model = LinearProbe(DINOv2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    train(model.to(0), optimizer, train_loader)
    gc.collect()
    test(model, val_loader)

Files already downloaded and verified
Files already downloaded and verified


Using cache found in ./hub/facebookresearch_dinov2_main
391it [00:47,  8.17it/s]



Test set: Avg. loss: 0.0207, Accuracy: 9385/10000 (94%)



391it [00:48,  8.01it/s]



Test set: Avg. loss: 0.0179, Accuracy: 9425/10000 (94%)



391it [00:49,  7.93it/s]



Test set: Avg. loss: 0.0169, Accuracy: 9425/10000 (94%)



391it [00:49,  7.96it/s]



Test set: Avg. loss: 0.0162, Accuracy: 9435/10000 (94%)



391it [00:49,  7.93it/s]



Test set: Avg. loss: 0.0161, Accuracy: 9450/10000 (94%)



391it [00:49,  7.93it/s]



Test set: Avg. loss: 0.0174, Accuracy: 9416/10000 (94%)



391it [00:48,  8.03it/s]



Test set: Avg. loss: 0.0189, Accuracy: 9431/10000 (94%)



391it [00:48,  8.01it/s]



Test set: Avg. loss: 0.0178, Accuracy: 9448/10000 (94%)



391it [00:48,  7.99it/s]



Test set: Avg. loss: 0.0180, Accuracy: 9453/10000 (95%)



391it [00:49,  7.90it/s]



Test set: Avg. loss: 0.0191, Accuracy: 9446/10000 (94%)



## clustering the neurons only occurs in the fc layers of the transformer for now (these are responsible for the majority of the parameters)
## rank reduction only occurs when it would reduce model latency (when the rank is reduced by more than half)

see the previous notebook for more details on these two approaches


https://arxiv.org/pdf/2206.06072.pdf

In [43]:
def clusteredlowrank(module, rank, threshold):
    module = copy(module)

    for (name, mod), (name_next, mod_next) in zip(linearleaves(module), linearleaves(module)[1:]):
        if 'fc1' in name:
            weights, bias, assignment = reduce_neurons(getattrrecur(mod, name).weight.detach().clone(), getattrrecur(mod, name).bias.detach().clone(), threshold=threshold)
            cols = reduce_columns(getattrrecur(mod_next, name_next).weight.detach().clone(), assignment)

            mod = torch.nn.Linear(weights.shape[1], weights.shape[0])
            mod.weight = torch.nn.Parameter(weights.detach().clone())
            mod.bias = torch.nn.Parameter(bias.detach().clone())

            replacement_next = torch.nn.Linear(cols.shape[1], cols.shape[0])
            replacement_next.weight = torch.nn.Parameter(cols.detach().clone())
            replacement_next.bias = getattrrecur(mod_next, name_next).bias

            setattrrecur(module, name, mod)
            setattrrecur(module, name_next, replacement_next)


    for name, mod in linearleaves(module):
        setattrrecur(module, name, LowRankLinear(getattrrecur(module, name), rank))
    test(module, val_loader)
    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))

def lowranklatency(module, rank):
    module = copy(module).cpu()
    start = time.time()
    layers_reduced = 0
    for i, (name, mod) in enumerate(linearleaves(module)):
        setattrrecur(module, name, LowRankLinear(getattrrecur(module, name), rank))
        layers_reduced += 1

    print(f'layers reduced: {layers_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))


def laterlowranklatency(module, rank, after=27):
    module = copy(module).cpu()
    start = time.time()
    layers_reduced = 0
    for i, (name, mod) in enumerate(linearleaves(module)):
        if i > after:
            setattrrecur(module, name, LowRankLinear(getattrrecur(module, name), rank))
            layers_reduced += 1

    print(f'layers reduced: {layers_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))


def earlierlowranklatency(module, rank, before=8):
    module = copy(module).cpu()
    start = time.time()
    layers_reduced = 0
    for i, (name, mod) in enumerate(linearleaves(module)):
        if i < before:
            setattrrecur(module, name, LowRankLinear(getattrrecur(module, name), rank))
            layers_reduced += 1

    print(f'layers reduced: {layers_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))


def clusteredlatency(module, threshold):
    module = copy(module)
    start = time.time()

    neurons_reduced = 0

    for (name, mod), (name_next, mod_next) in zip(linearleaves(module), linearleaves(module)[1:]):
        if 'fc1' in name:
            weights, bias, assignment = reduce_neurons(getattrrecur(mod, name).weight.detach().clone(), getattrrecur(mod, name).bias.detach().clone(), threshold=threshold)

            neurons_reduced += (torch.tensor(getattrrecur(mod, name).weight.shape) - torch.tensor(weights.shape)).sum()

            cols = reduce_columns(getattrrecur(mod_next, name_next).weight.detach().clone(), assignment)

            mod = torch.nn.Linear(weights.shape[1], weights.shape[0])
            mod.weight = torch.nn.Parameter(weights.detach().clone())
            mod.bias = torch.nn.Parameter(bias.detach().clone())

            replacement_next = torch.nn.Linear(cols.shape[1], cols.shape[0])
            replacement_next.weight = torch.nn.Parameter(cols.detach().clone())
            replacement_next.bias = getattrrecur(mod_next, name_next).bias

            setattrrecur(module, name, mod)
            setattrrecur(module, name_next, replacement_next)

    print(f'neurons reduced: {neurons_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))


def laterclusteredlatency(module, threshold, after=28):
    module = copy(module)
    start = time.time()

    neurons_reduced = 0

    for i, ((name, mod), (name_next, mod_next)) in enumerate(zip(linearleaves(module), linearleaves(module)[1:])):
        if 'fc1' in name and i > after:
            weights, bias, assignment = reduce_neurons(getattrrecur(mod, name).weight.detach().clone(), getattrrecur(mod, name).bias.detach().clone(), threshold=threshold)

            neurons_reduced += (torch.tensor(getattrrecur(mod, name).weight.shape) - torch.tensor(weights.shape)).sum()

            cols = reduce_columns(getattrrecur(mod_next, name_next).weight.detach().clone(), assignment)

            mod = torch.nn.Linear(weights.shape[1], weights.shape[0])
            mod.weight = torch.nn.Parameter(weights.detach().clone())
            mod.bias = torch.nn.Parameter(bias.detach().clone())

            replacement_next = torch.nn.Linear(cols.shape[1], cols.shape[0])
            replacement_next.weight = torch.nn.Parameter(cols.detach().clone())
            replacement_next.bias = getattrrecur(mod_next, name_next).bias

            setattrrecur(module, name, mod)
            setattrrecur(module, name_next, replacement_next)

    print(f'neurons reduced: {neurons_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))


def earlierclusteredlatency(module, threshold, before=28):
    module = copy(module)
    start = time.time()

    neurons_reduced = 0

    for i, ((name, mod), (name_next, mod_next)) in enumerate(zip(linearleaves(module), linearleaves(module)[1:])):
        if 'fc1' in name and i < before:
            weights, bias, assignment = reduce_neurons(getattrrecur(mod, name).weight.detach().clone(), getattrrecur(mod, name).bias.detach().clone(), threshold=threshold)

            neurons_reduced += (torch.tensor(getattrrecur(mod, name).weight.shape) - torch.tensor(weights.shape)).sum()

            cols = reduce_columns(getattrrecur(mod_next, name_next).weight.detach().clone(), assignment)

            mod = torch.nn.Linear(weights.shape[1], weights.shape[0])
            mod.weight = torch.nn.Parameter(weights.detach().clone())
            mod.bias = torch.nn.Parameter(bias.detach().clone())

            replacement_next = torch.nn.Linear(cols.shape[1], cols.shape[0])
            replacement_next.weight = torch.nn.Parameter(cols.detach().clone())
            replacement_next.bias = getattrrecur(mod_next, name_next).bias

            setattrrecur(module, name, mod)
            setattrrecur(module, name_next, replacement_next)

    print(f'neurons reduced: {neurons_reduced} ({time.time() - start}s)')
    test(module, val_loader)

    print(latency(module.eval(), torch.ones(1, 3, 224, 224)))

In [44]:
model.float().cpu()
print(latency(model.eval(), torch.ones(1, 3, 224, 224)))
test(model, val_loader)

0.022223653524415566

Test set: Avg. loss: 0.0191, Accuracy: 9446/10000 (94%)



In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

bfloats are actually much (~100x) slower on the cpu than float32s

In [50]:
model.bfloat16()
print(latency(model.eval(), torch.ones(1, 3, 224, 224).bfloat16()))
test(model, val_loader, torch.bfloat16)

0.015075424829265103

Test set: Avg. loss: 0.0190, Accuracy: 9458/10000 (95%)



In [77]:
model.float().cpu()
evaluating = copy(model)
clusteredlatency(evaluating, 0.1)
train(evaluating.to(0), optimizer, train_loader)
test(model, val_loader)

neurons reduced: 4268 (0.6930887699127197s)

Test set: Avg. loss: 0.1799, Accuracy: 5494/10000 (55%)

0.020800426495261492


391it [00:51,  7.53it/s]



Test set: Avg. loss: 0.0190, Accuracy: 9446/10000 (94%)



In [82]:
model.float().cpu()
evaluating = copy(model.eval())
inital_params = count_parameters(evaluating)
lowranklatency(evaluating, 95)
train(evaluating.to(0), optimizer, train_loader)
test(evaluating, val_loader)
final_params = count_parameters(evaluating)
print(f'parameters: {inital_params} -> {final_params} (- {inital_params - final_params})')

layers reduced: 49 (1.2263009548187256s)

Test set: Avg. loss: 0.9868, Accuracy: 1078/10000 (11%)

0.020608259348664434


391it [00:53,  7.34it/s]



Test set: Avg. loss: 0.0190, Accuracy: 9446/10000 (94%)

parameters: 22061962 -> 22061962 (- 0)


In [81]:
model.float().cpu()
evaluating = copy(model.eval())
inital_params = count_parameters(evaluating)
lowranklatency(evaluating, 190)
train(evaluating.to(0), optimizer, train_loader)
test(model, val_loader)
lowranklatency(evaluating, 95)
train(evaluating.to(0), optimizer, train_loader)
test(model, val_loader)
final_params = count_parameters(evaluating)
print(f'parameters: {inital_params} -> {final_params} (- {inital_params - final_params})')

layers reduced: 49 (1.237255573272705s)

Test set: Avg. loss: 0.8855, Accuracy: 991/10000 (10%)

0.021984323041979222


391it [00:52,  7.46it/s]


layers reduced: 49 (1.186415195465088s)

Test set: Avg. loss: 0.9868, Accuracy: 1078/10000 (11%)

0.020019255420193074


391it [00:52,  7.41it/s]



Test set: Avg. loss: 0.0190, Accuracy: 9446/10000 (94%)



In [65]:
model.float().cpu()
evaluating = copy(model.eval())
laterlowranklatency(evaluating, 192, 44)

layers reduced: 4 (0.07844781875610352s)

Test set: Avg. loss: 0.0349, Accuracy: 9039/10000 (90%)

0.021354977586306632


In [66]:
model.float().cpu()
evaluating = copy(model.eval())
laterlowranklatency(evaluating, 192, 40)

layers reduced: 8 (0.18579506874084473s)

Test set: Avg. loss: 0.1111, Accuracy: 6845/10000 (68%)

0.02141873525455594


In [67]:
model.float().cpu()
evaluating = copy(model.eval())
laterlowranklatency(evaluating, 192, 36)

layers reduced: 12 (0.26773953437805176s)

Test set: Avg. loss: 0.2637, Accuracy: 4445/10000 (44%)

0.021434388312045485


In [56]:
model.float().cpu()
evaluating = copy(model.eval())
earlierlowranklatency(evaluating, 382, 4)

layers reduced: 4 (0.10966730117797852s)

Test set: Avg. loss: 0.0189, Accuracy: 9454/10000 (95%)

0.021378000401891768


In [57]:
model.float().cpu()
evaluating = copy(model.eval())
earlierlowranklatency(evaluating, 382, 8)

layers reduced: 8 (0.21886730194091797s)

Test set: Avg. loss: 0.0190, Accuracy: 9445/10000 (94%)

0.021889947098679842


In [58]:
model.float().cpu()
evaluating = copy(model.eval())
earlierlowranklatency(evaluating, 382, 12)

layers reduced: 12 (0.32404375076293945s)

Test set: Avg. loss: 0.0190, Accuracy: 9447/10000 (94%)

0.022367348104016856


In [59]:
model.float().cpu()
evaluating = copy(model)
earlierclusteredlatency(evaluating, 0.1, 4)

neurons reduced: 475 (0.052613258361816406s)

Test set: Avg. loss: 0.0236, Accuracy: 9358/10000 (94%)

0.021404069961281493


In [60]:
model.float().cpu()
evaluating = copy(model)
earlierclusteredlatency(evaluating, 0.1, 8)

neurons reduced: 661 (0.11128783226013184s)

Test set: Avg. loss: 0.0241, Accuracy: 9330/10000 (93%)

0.02182576248771511


In [61]:
model.float().cpu()
evaluating = copy(model)
earlierclusteredlatency(evaluating, 0.1, 12)

neurons reduced: 798 (0.17936372756958008s)

Test set: Avg. loss: 0.0280, Accuracy: 9253/10000 (93%)

0.021892040913226083


In [62]:
model.float().cpu()
evaluating = copy(model)
laterclusteredlatency(evaluating, 0.1, 44)

neurons reduced: 638 (0.05638718605041504s)

Test set: Avg. loss: 0.0202, Accuracy: 9392/10000 (94%)

0.022875919700600208


In [63]:
model.float().cpu()
evaluating = copy(model)
laterclusteredlatency(evaluating, 0.1, 40)

neurons reduced: 1666 (0.08997321128845215s)

Test set: Avg. loss: 0.0528, Accuracy: 8298/10000 (83%)

0.02122095245984383


In [64]:
evaluating = copy(model)
clusteredlowrank(evaluating, rank=384, threshold=0.05)


Test set: Avg. loss: 0.0201, Accuracy: 9425/10000 (94%)

0.02682997809140943


Now we will test our hypothesis by lattice descent on the compression graph and then a more exhaustive grid search.

In [107]:
def svdlayer(module, rank, name=None):
    module = copy(module).cpu()
    start = time.time()
    
    setattrrecur(module, name, LowRankLinear(getattrrecur(module, name), rank))

    return module


def svdmodel(module, rank):
    module = copy(module).cpu()
    layers_reduced = 0
    for i, (name, mod) in enumerate(linearleaves(module)):
        setattrrecur(module, name, LowRankLinear(getattrrecur(module, name), rank))
        layers_reduced += 1

    return module


def clusterlayer(module, threshold, layer=0):
    module = copy(module)
    start = time.time()

    neurons_reduced = 0

    for i, ((name, mod), (name_next, mod_next)) in enumerate(zip(linearleaves(module), linearleaves(module)[1:])):
        if 'fc1' in name and i == layer:
            weights, bias, assignment = reduce_neurons(getattrrecur(mod, name).weight.detach().clone(), getattrrecur(mod, name).bias.detach().clone(), threshold=threshold)

            neurons_reduced += (torch.tensor(getattrrecur(mod, name).weight.shape) - torch.tensor(weights.shape)).sum()

            cols = reduce_columns(getattrrecur(mod_next, name_next).weight.detach().clone(), assignment)

            mod = torch.nn.Linear(weights.shape[1], weights.shape[0])
            mod.weight = torch.nn.Parameter(weights.detach().clone())
            mod.bias = torch.nn.Parameter(bias.detach().clone())

            replacement_next = torch.nn.Linear(cols.shape[1], cols.shape[0])
            replacement_next.weight = torch.nn.Parameter(cols.detach().clone())
            replacement_next.bias = getattrrecur(mod_next, name_next).bias

            setattrrecur(module, name, mod)
            setattrrecur(module, name_next, replacement_next)

    return module, neurons_reduced



def clustermodel(module, threshold):
    module = copy(module)
    start = time.time()

    neurons_reduced = 0

    for i, ((name, mod), (name_next, mod_next)) in enumerate(zip(linearleaves(module), linearleaves(module)[1:])):
        if 'fc1' in name:
            weights, bias, assignment = reduce_neurons(getattrrecur(mod, name).weight.detach().clone(), getattrrecur(mod, name).bias.detach().clone(), threshold=threshold)

            neurons_reduced += (torch.tensor(getattrrecur(mod, name).weight.shape) - torch.tensor(weights.shape)).sum()

            cols = reduce_columns(getattrrecur(mod_next, name_next).weight.detach().clone(), assignment)

            mod = torch.nn.Linear(weights.shape[1], weights.shape[0])
            mod.weight = torch.nn.Parameter(weights.detach().clone())
            mod.bias = torch.nn.Parameter(bias.detach().clone())

            replacement_next = torch.nn.Linear(cols.shape[1], cols.shape[0])
            replacement_next.weight = torch.nn.Parameter(cols.detach().clone())
            replacement_next.bias = getattrrecur(mod_next, name_next).bias

            setattrrecur(module, name, mod)
            setattrrecur(module, name_next, replacement_next)

    return module, neurons_reduced


def getclusterablelayers(module):
    cluster_idxs = []
    lowrank_idxs = []

    for i, ((name, mod), (name_next, mod_next)) in enumerate(zip(linearleaves(module), linearleaves(module)[1:])):
        lowrank_idxs.append(i)
        if 'fc1' in name:
            cluster_idxs.append(i)

    return cluster_idxs

def latticedescenteval(module, tolerance=4):
    cluster_idxs = getclusterablelayers(module)
    module = copy(module)
    inital_acc = test(module, val_loader, silent=True)
    initial_params = count_parameters(module)

    lattice = [(module, inital_acc)]

    reduced = []
    neurons_reduced = []
    
    while max(lattice, key=lambda k: k[-1])[-1] > inital_acc - tolerance:
        module = lattice[0][0]
        print(latency(module.eval(), torch.ones(1, 3, 224, 224)))
        lattice.clear()
        for i in cluster_idxs:
            m, neurons = clusterlayer(module, 0.1, layer=i)
            distance = 0.1
            while neurons < 50:
                distance += 0.1
                m, neurons = clusterlayer(module, distance, layer=i)
            else:
                acc = test(m.bfloat16(), val_loader, torch.bfloat16, silent=True)
                lattice.append((m.float(), i, neurons, acc))
        
        if lattice:
            m, i, neurons, acc = max(lattice, key=lambda k: k[-1])
        else:
            break

        train(m.to(0), optimizer, train_loader)

        reduced.append(i)
        neurons_reduced.append(neurons)

        lattice = [(m, acc)]

        current_params = count_parameters(m)
        print(f'Current accuracy: {float(lattice[0][1])}% ({float(lattice[0][1]) - float(inital_acc)}), layers reduced: {reduced}; neurons reduced: {sum(neurons_reduced)} (- {initial_params - current_params} parameters)')
    
    return lattice[0][0]


In [109]:
module = latticedescenteval(model)

0.02326177352690138


391it [00:50,  7.67it/s]


Current accuracy: 94.91999816894531% (0.45999908447265625), layers reduced: [26]; neurons reduced: 142 (- 109198 parameters)
0.023186166152590885


391it [00:50,  7.78it/s]


Current accuracy: 94.75% (0.29000091552734375), layers reduced: [26, 18]; neurons reduced: 279 (- 214551 parameters)
0.023874963453272356


391it [00:50,  7.73it/s]


Current accuracy: 94.55000305175781% (0.09000396728515625), layers reduced: [26, 18, 10]; neurons reduced: 416 (- 319904 parameters)
0.026018363183829933


391it [00:48,  8.11it/s]


Current accuracy: 93.87000274658203% (-0.589996337890625), layers reduced: [26, 18, 10, 46]; neurons reduced: 1054 (- 810526 parameters)
0.02249063953291625


391it [00:47,  8.31it/s]


Current accuracy: 93.83000183105469% (-0.6299972534179688), layers reduced: [26, 18, 10, 46, 46]; neurons reduced: 1159 (- 891271 parameters)
0.021818426897516473


391it [00:47,  8.23it/s]


Current accuracy: 93.52999877929688% (-0.9300003051757812), layers reduced: [26, 18, 10, 46, 46, 6]; neurons reduced: 1345 (- 1034305 parameters)
0.023826734247850253


391it [00:49,  7.96it/s]


Current accuracy: 92.5199966430664% (-1.94000244140625), layers reduced: [26, 18, 10, 46, 46, 6, 22]; neurons reduced: 1456 (- 1119664 parameters)
0.022900281874462962


391it [00:49,  7.93it/s]


Current accuracy: 91.81999969482422% (-2.6399993896484375), layers reduced: [26, 18, 10, 46, 46, 6, 22, 22]; neurons reduced: 1507 (- 1158883 parameters)
0.022624140278203413


391it [00:48,  8.10it/s]


Current accuracy: 91.06999969482422% (-3.3899993896484375), layers reduced: [26, 18, 10, 46, 46, 6, 22, 22, 46]; neurons reduced: 2057 (- 1581833 parameters)
0.020954165830044075


391it [00:48,  8.11it/s]

Current accuracy: 89.16000366210938% (-5.299995422363281), layers reduced: [26, 18, 10, 46, 46, 6, 22, 22, 46, 26]; neurons reduced: 2118 (- 1628742 parameters)





In [110]:
def getcompressablelayers(module):
    lowrank_names = []

    for i, ((name, mod), (name_next, mod_next)) in enumerate(zip(linearleaves(module), linearleaves(module)[1:])):
        lowrank_names.append(name)

    return lowrank_names


def lowranklatticedescenteval(module, tolerance=4):
    lowrank_names = getcompressablelayers(module)
    module = copy(module)
    inital_acc = test(module, val_loader, silent=True)
    initial_params = count_parameters(module)

    lattice = [(module, inital_acc)]

    reduced = []
    
    while max(lattice, key=lambda k: k[-1])[-1] > inital_acc - tolerance:
        lowrank_names = getcompressablelayers(module)
        module = lattice[0][0]
        print(latency(module.eval(), torch.ones(1, 3, 224, 224)))
        lattice.clear()

        for name in lowrank_names:
            
            try:
                min_weight = torch.tensor(getattrrecur(module, name).weight.shape).min()
            except:
                continue
            
            m = svdlayer(module, min_weight // 4, name=name)
            acc = test(m.bfloat16(), val_loader, torch.bfloat16, silent=True)
            lattice.append((m, name, acc))
        
        if lattice:
            m, i, acc = max(lattice, key=lambda k: k[-1])
        else:
            break

        train(m.float().to(0), optimizer, train_loader)

        reduced.append(i)

        lattice = [(m, acc)]

        current_params = count_parameters(m)
        print(f'Current accuracy: {float(lattice[0][1])}% ({float(lattice[0][1]) - float(inital_acc)}), layers reduced: {reduced} (- {initial_params - current_params} parameters)')
    
    return lattice[0][0]


In [111]:
module = lowranklatticedescenteval(module)

AttributeError: 'NoneType' object has no attribute 'named_modules'

In [86]:
model.float().cpu()
evaluating = copy(model.eval())
inital_params = count_parameters(evaluating)
evaluating = svdmodel(evaluating, 190)
train(evaluating.to(0), optimizer, train_loader)
test(evaluating, val_loader)
final_params = count_parameters(evaluating)
print(f'parameters: {inital_params} -> {final_params} (- {inital_params - final_params})')

391it [00:47,  8.25it/s]



Test set: Avg. loss: 0.8855, Accuracy: 991/10000 (10%)

parameters: 22061962 -> 14836718 (- 7225244)


In [90]:
model.float().cpu()
evaluating = copy(model.eval())
inital_params = count_parameters(evaluating)
evaluating, neurons = clustermodel(evaluating, 0.1)
test(evaluating, val_loader)
train(evaluating.to(0), optimizer, train_loader)
test(evaluating, val_loader)
final_params = count_parameters(evaluating)
print(f'parameters: {inital_params} -> {final_params} (- {inital_params - final_params})')


Test set: Avg. loss: 0.1799, Accuracy: 5494/10000 (55%)



391it [00:49,  7.93it/s]



Test set: Avg. loss: 0.1799, Accuracy: 5494/10000 (55%)

parameters: 22061962 -> 18779870 (- 3282092)
