In [1]:
import os
import torch
import torchvision
from torchvision.transforms import v2
from torchvision import transforms


def dino_model():
    os.environ['TORCH_HOME'] = '/vagrant/pytorch_models/'
    os.environ['TORCH_HUB'] = '/vagrant/pytorch_models/'
    # DINOv2 vit-s (14) with registers
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
    # state = model.state_dict()
    # mymodel = vit_small(14, 4)
    # mymodel.load_state_dict(state)
    model.eval()

    return model.to('cpu')

def dino_transforms():
    return v2.Compose(
                    [
                        torchvision.transforms.ToTensor(),
                        transforms.Resize(size=(256, 256), antialias=True),
                        transforms.CenterCrop((224, 224)),
                        transforms.Normalize(
                                            mean=[0.485, 0.456, 0.406],
                                            std=[0.229, 0.224, 0.225]
                                            ),
                    ]
                    )

In [2]:
DINOv2 = dino_model()
DINOv2_transform = dino_transforms()

Using cache found in /vagrant/pytorch_models/hub\facebookresearch_dinov2_main


In [30]:
DINOv2.blocks

ModuleList(
  (0-23): 24 x NestedTensorBlock(
    (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (attn): MemEffAttention(
      (qkv): Linear(in_features=1024, out_features=3072, bias=True)
      (attn_drop): Dropout(p=0.0, inplace=False)
      (proj): Linear(in_features=1024, out_features=1024, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
    )
    (ls1): LayerScale()
    (drop_path1): Identity()
    (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=1024, out_features=4096, bias=True)
      (act): GELU(approximate='none')
      (fc2): Linear(in_features=4096, out_features=1024, bias=True)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (ls2): LayerScale()
    (drop_path2): Identity()
  )
)

In [16]:
for ele in DINOv2.blocks:
    if isinstance(ele, type(DINOv2.blocks[0])):
        print(ele.mlp.fc1)
        print(ele.mlp.fc2)
        print(ele.attn.proj)
        print(ele.attn.qkv)


Linear(in_features=384, out_features=1536, bias=True)
Linear(in_features=1536, out_features=384, bias=True)
Linear(in_features=384, out_features=384, bias=True)
Linear(in_features=384, out_features=1152, bias=True)
Linear(in_features=384, out_features=1536, bias=True)
Linear(in_features=1536, out_features=384, bias=True)
Linear(in_features=384, out_features=384, bias=True)
Linear(in_features=384, out_features=1152, bias=True)
Linear(in_features=384, out_features=1536, bias=True)
Linear(in_features=1536, out_features=384, bias=True)
Linear(in_features=384, out_features=384, bias=True)
Linear(in_features=384, out_features=1152, bias=True)
Linear(in_features=384, out_features=1536, bias=True)
Linear(in_features=1536, out_features=384, bias=True)
Linear(in_features=384, out_features=384, bias=True)
Linear(in_features=384, out_features=1152, bias=True)
Linear(in_features=384, out_features=1536, bias=True)
Linear(in_features=1536, out_features=384, bias=True)
Linear(in_features=384, out_feat

In [57]:
from sklearn.cluster import AgglomerativeClustering
import numpy as np

def compute_centroids(weights, assignment):
    # we are going to mean the neurons into the first index in the weights occuring in the assingment
    first_indices = []
    for i in range(int(assignment.max()) + 1):
        indices = (assignment == i).nonzero()

        first_index = indices[0]
        
        try:
            first_indices.append(first_index.item())
            weights[first_index, :] = weights[indices].mean(0)
        except:
            first_indices.append(first_index[0].item())
            weights[first_index[0], :] = weights[indices[0]].mean(0)
    first_indices.sort()

    return weights[first_indices]


def reduce_neurons(weight, bias=None, clusters=None, threshold=0.1):
    if bias is None:
        bias = torch.zeros((weight.shape[0]))

    
    weight = torch.concat((weight, bias.unsqueeze(-1)), 1)

    normed = torch.nn.functional.normalize(weight)

    D = (1.0 - (normed @ normed.T)).relu()
    print(D.shape)
    C = AgglomerativeClustering(clusters, affinity='precomputed', linkage='complete', compute_full_tree=True, distance_threshold=threshold)
    assignment = C.fit_predict(D)

    centroids = compute_centroids(weight, assignment)

    bias, centroids = centroids[:, -1].squeeze(), centroids[:, :-1]

    return centroids, bias, assignment


def reduce_columns(weight, assignment):
    # we are going to sum the columns into the first index in the weights occuring in the assignment
    first_indices = []
    for i in range(int(assignment.max())):
        indices = (assignment == i).nonzero()

        first_index = indices[0]

        try:
            first_indices.append(first_index.item())
            weight[:, first_index] = weight[:, indices].sum(1)
        except:
            first_indices.append(first_index[0].item())
            weight[:, first_index[0]] = weight[:, indices[0]].sum(1)

    first_indices.sort()

    return weight[:, first_indices]


In [58]:
DINOv2.cpu()
print(DINOv2.blocks[0].mlp.fc1.weight.shape)
weights, bias, assignment = reduce_neurons(DINOv2.blocks[0].mlp.fc1.weight.detach().clone(), threshold=-1)
print(assignment.max())
weights.shape

torch.Size([1536, 384])
torch.Size([1536, 1536])
1535


torch.Size([1536, 384])

In [59]:
cols = reduce_columns(DINOv2.blocks[0].mlp.fc2.weight.detach().clone(), assignment)
cols.shape

torch.Size([384, 1535])

In [88]:
from tqdm import tqdm
import time

def train(model, optimizer, loader):
    model.train()
    loss = torch.nn.CrossEntropyLoss()

    for i, (X, y) in tqdm(enumerate(loader)):
        out = model(X.to(0))
        optimizer.zero_grad()
        l = loss(out, y.to(0))
        l.backward()
        optimizer.step()
        


def accuracy(output, target, topk=(1,)):
    output = output.to(torch.device('cpu'))
    target = target.to(torch.device('cpu'))
    maxk = max(topk)
    batch_size = target.shape[0]

    _, idx = output.sort(dim=1, descending=True)
    pred = idx.narrow(1, 0, maxk).t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def epoch_accuracy(loader_s, student):
    student.eval()

    out_epoch_s = [accuracy(student(L.to(0)), y)[0].detach().cpu().item() for L, y in loader_s]

    student.train()

    return sum(out_epoch_s) / len(out_epoch_s)

def test(network, test_loader):
    network.eval()
    test_loss = 0
    correct = 0
    test_losses=[]
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data.to(0))
            test_loss += torch.nn.CrossEntropyLoss()(output, target.to(0)).item()
            pred = output.data.max(1, keepdim=True)[1].cpu()
            correct += pred.eq(target.data.view_as(pred)).sum()
        test_loss /= len(test_loader.dataset)
        test_losses.append(test_loss)
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def latency(f, x, trials = 100):
    total = 0.0
    for trial in range(trials):
        start = time.perf_counter()
        f(x)
        total += time.perf_counter() - start
    return total / trials


In [25]:
class LinearProbe(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.m = module
        self.linear = torch.nn.Linear(384, 102)
    
    def forward(self, x):
        x = self.m(x).detach()
        return self.linear(x) * 100

In [28]:
import gc

train_ds = torchvision.datasets.Flowers102('./Flowers102', split='train', transform=DINOv2_transform, download=True)
val_ds = torchvision.datasets.Flowers102('./Flowers102', split='val', transform=DINOv2_transform, download=True)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=48)

DINOv2 = dino_model()

model = LinearProbe(DINOv2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    train(model.to(0), optimizer, train_loader)
    gc.collect()
    test(model, val_loader)

Using cache found in /vagrant/pytorch_models/hub\facebookresearch_dinov2_main
32it [00:09,  3.51it/s]



Test set: Avg. loss: 1.7021, Accuracy: 285/1020 (28%)



32it [00:09,  3.44it/s]



Test set: Avg. loss: 0.3208, Accuracy: 738/1020 (72%)



32it [00:09,  3.27it/s]



Test set: Avg. loss: 0.1298, Accuracy: 870/1020 (85%)



32it [00:09,  3.29it/s]



Test set: Avg. loss: 0.1043, Accuracy: 898/1020 (88%)



32it [00:10,  3.12it/s]



Test set: Avg. loss: 0.0948, Accuracy: 907/1020 (89%)



32it [00:10,  3.10it/s]



Test set: Avg. loss: 0.0952, Accuracy: 906/1020 (89%)



32it [00:10,  3.05it/s]



Test set: Avg. loss: 0.0949, Accuracy: 906/1020 (89%)



32it [00:10,  3.07it/s]



Test set: Avg. loss: 0.0949, Accuracy: 906/1020 (89%)



32it [00:10,  3.05it/s]



Test set: Avg. loss: 0.0949, Accuracy: 906/1020 (89%)



32it [00:10,  3.06it/s]



Test set: Avg. loss: 0.0949, Accuracy: 906/1020 (89%)



In [101]:
batch_size_train = 128

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.CIFAR100('./cifar100/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.CIFAR100('./cifar100/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                             ])),
  batch_size=1024, shuffle=True)

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(3072, 2048)
        self.fc2 = torch.nn.Linear(2048, 1024)
        self.fc3 = torch.nn.Linear(1024, 512)
        self.fc4 = torch.nn.Linear(512, 100)

    def forward(self, x):
        x = x
        x = torch.nn.Flatten()(x)
        return torch.nn.functional.relu(self.fc4(torch.nn.functional.relu(self.fc3(torch.nn.functional.relu(self.fc2(torch.nn.functional.relu(self.fc1(x))))))))

Files already downloaded and verified
Files already downloaded and verified


In [91]:
import numpy as np

class LowRankLinear(torch.nn.Module):
    # takes in a linear layer and decomposes it into two low-rank linear layers
    def __init__(self, fc, rank):
        super(LowRankLinear, self).__init__()

        self.fc1 = torch.nn.Linear(fc.weight.shape[1], rank, bias = False)
        self.fc2 = torch.nn.Linear(rank, fc.weight.shape[0])
        
        weight1 = fc.weight

        self.fc2.bias = fc.bias

        W1 = weight1.cpu().detach().clone().numpy()

        U1, E1, V1 = np.linalg.svd(W1, False)

        rd1 = np.zeros((len(E1), len(E1)))

        for i, v in enumerate(E1):
            rd1[i, i] = v


        if fc.weight.shape[1] > fc.weight.shape[0]:
            # if the input dom of the fc is bigger than the output dim
            self.fc1.weight = torch.nn.parameter.Parameter(torch.tensor(rd1[:rank, :rank] @ V1[:rank, :]).to(fc.weight.device).type(fc.weight.dtype))
            self.fc2.weight = torch.nn.parameter.Parameter(torch.tensor(U1[:, :rank]).to(fc.weight.device).type(fc.weight.dtype))
        else:
            self.fc1.weight = torch.nn.parameter.Parameter(torch.tensor(V1[:rank, :]).to(fc.weight.device).type(fc.weight.dtype))
            self.fc2.weight = torch.nn.parameter.Parameter(torch.tensor(U1[:, :rank] @ rd1[:rank, :rank]).to(fc.weight.device).type(fc.weight.dtype))


    def forward(self, x):
        return self.fc2(self.fc1(x))

In [102]:
network = Net().to(0)
optimizer = torch.optim.SGD(network.parameters(), lr=1e-2,
                      momentum=0.5)

for epoch in range(5):
    train(network.to(0), optimizer, train_loader)
    gc.collect()
    test(network, test_loader)

391it [00:12, 32.34it/s]



Test set: Avg. loss: 0.0045, Accuracy: 511/10000 (5%)



391it [00:12, 32.09it/s]



Test set: Avg. loss: 0.0044, Accuracy: 824/10000 (8%)



391it [00:12, 31.64it/s]



Test set: Avg. loss: 0.0044, Accuracy: 1022/10000 (10%)



391it [00:12, 32.22it/s]



Test set: Avg. loss: 0.0043, Accuracy: 1180/10000 (12%)



391it [00:12, 32.11it/s]



Test set: Avg. loss: 0.0043, Accuracy: 1278/10000 (13%)



In [76]:
import time
from copy import deepcopy as copy

compressed_net = copy(network)


print(latency(compressed_net.eval(), torch.ones(1, 3, 32, 32).to(0)))

compressed_net.fc1 = LowRankLinear(network.fc1, 64)
compressed_net.fc2 = LowRankLinear(network.fc2, 64)


print(latency(compressed_net.eval(), torch.ones(1, 3, 32, 32)))

compressed_net.fc1 = LowRankLinear(network.fc1, 32)
compressed_net.fc2 = LowRankLinear(network.fc2, 32)


print(latency(compressed_net.eval(), torch.ones(1, 3, 32, 32)))

compressed_net.fc1 = LowRankLinear(network.fc1, 16)
compressed_net.fc2 = LowRankLinear(network.fc2, 16)


print(latency(compressed_net.eval(), torch.ones(1, 3, 32, 32)))

compressed_net.fc1 = LowRankLinear(network.fc1, 8)
compressed_net.fc2 = LowRankLinear(network.fc2, 8)

print(latency(compressed_net.eval(), torch.ones(1, 3, 32, 32)))

compressed_net.fc1 = LowRankLinear(network.fc1, 4)
compressed_net.fc2 = LowRankLinear(network.fc2, 4)


print(latency(compressed_net.eval(), torch.ones(1, 3, 32, 32)))

Net(
  (fc1): Linear(in_features=3072, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=100, bias=True)
)


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [96]:
import time
from copy import deepcopy as copy

network.cpu()
compressed_net = copy(network)

def lowranklatency(module, rank):
    module = copy(module)
    for mod in module.children():
        if isinstance(mod, torch.nn.Linear):
            mod = LowRankLinear(mod, rank)

    print(latency(module.eval(), torch.ones(64, 3, 32, 32)))

print(latency(compressed_net.eval(), torch.ones(64, 3, 32, 32)))

compressed_net.fc1 = LowRankLinear(network.fc1, 64)
compressed_net.fc2 = LowRankLinear(network.fc2, 64)

print(latency(compressed_net.eval(), torch.ones(64, 3, 32, 32)))

compressed_net.fc1 = LowRankLinear(network.fc1, 32)
compressed_net.fc2 = LowRankLinear(network.fc2, 32)


print(latency(compressed_net.eval(), torch.ones(64, 3, 32, 32)))

compressed_net.fc1 = LowRankLinear(network.fc1, 16)
compressed_net.fc2 = LowRankLinear(network.fc2, 16)


print(latency(compressed_net.eval(), torch.ones(64, 3, 32, 32)))

compressed_net.fc1 = LowRankLinear(network.fc1, 8)
compressed_net.fc2 = LowRankLinear(network.fc2, 8)


print(latency(compressed_net.eval(), torch.ones(64, 3, 32, 32)))

compressed_net.fc1 = LowRankLinear(network.fc1, 4)
compressed_net.fc2 = LowRankLinear(network.fc2, 4)

print(latency(compressed_net.eval(), torch.ones(64, 3, 32, 32)))

0.0020964660000026922
0.00046075300000666177
0.0003538460000891064
0.00036670999998023035
0.0002999140000065381
0.00028197599998748047


In [104]:
network.cpu()
compressed_net = copy(network)

def lowranklatency(module, rank):
    module = copy(module)
    for name, mod in module.named_modules():
        if isinstance(mod, torch.nn.Linear):
            module.add_module(name, LowRankLinear(mod, rank))
    print(module)
    print(latency(module.eval(), torch.ones(64, 3, 32, 32)))

for rank in range(2, 128, 2):
    lowranklatency(compressed_net, 2)