In [1]:
import os
# Tu si treba vybrat jedno GPU (cize bud 0 alebo 1)
# Je to cislovane naopak ako v nvidia-smi, because ..., that's why
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
os.environ["OMP_NUM_THREADS"] = "1"

In [2]:
IMAGENET_PATH = "/data/imagenet/imagenet/"

In [3]:
import torch
import torch.nn as nn
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import time
import copy
import sys

import random
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Variable

import sys
import numpy as np
import timm
from functools import partial
from tqdm.notebook import tqdm

In [4]:
def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)

random_seed(47)

In [5]:
device = torch.device("cuda")
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=torch.float16)

In [6]:
# Very small models
model_name = "test_vit3.r160_in1k"
#model_name = "test_vit2.r160_in1k"
#model_name = "test_vit.r160_in1k"

# Bigger models
#model_name = "vit_wee_patch16_reg1_gap_256.sbb_in1k" 
#model_name = "vit_medium_patch16_reg4_gap_256.sbb_in1k"

# Even bigger
# See bigger than tiny from: https://huggingface.co/timm/convnext_tiny.fb_in1k
# Or others at https://huggingface.co/timm/vit_wee_patch16_reg1_gap_256.sbb_in1k

model = timm.create_model(model_name, pretrained=True)
model.cuda() 
print(model)

sum(p.numel() for p in model.parameters())

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=96, out_features=288, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=96, out_features=96, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=96, out_features=192, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
     

930280

In [7]:
def validate(model, val_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for input, target in tqdm(val_loader):
            with amp_autocast():
                output = model(input)
                correct += (target == output.argmax(dim=1)).sum().item()
                total += target.numel()
    print("val acc", correct / total)

In [8]:
data_config = timm.data.resolve_data_config(model=model)

print(data_config)

val_dataset = timm.data.create_dataset(
    name="imagenet",
    split="validation",
    root=IMAGENET_PATH
)

val_loader = timm.data.create_loader(
    val_dataset,
    input_size=data_config['input_size'],                                                 
    batch_size=128,
    use_prefetcher=True,                                                       
    interpolation=data_config['interpolation'],                                           
    mean=data_config['mean'],
    std=data_config['std'],
    num_workers=8,                                                             
    crop_pct=data_config["crop_pct"],
    crop_mode=data_config['crop_mode'],                                                   
    crop_border_pixels=False,  
    pin_memory=True,
    device=device,
)

for bx, by in val_loader:
    print(bx.shape, by.shape)
    print(by)
    break

{'input_size': (3, 160, 160), 'interpolation': 'bicubic', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'crop_pct': 0.95, 'crop_mode': 'center'}
torch.Size([128, 3, 160, 160]) torch.Size([128])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')


In [9]:
train_dataset = timm.data.create_dataset(
    name="imagenet",
    split="train",
    root=IMAGENET_PATH
)

# TODO: augmentations
train_loader = timm.data.create_loader(
    train_dataset,
    input_size=data_config['input_size'],                                                 
    batch_size=128,
    use_prefetcher=True,                                                       
    interpolation="random",  
    mean=data_config['mean'],
    std=data_config['std'],
    num_workers=12,                                                             
    crop_pct=data_config["crop_pct"],
    crop_mode=data_config['crop_mode'],                                                   
    crop_border_pixels=False,  
    pin_memory=True,
    device=device,
    is_training=True,
)



In [10]:
# Run validation on full model (reports number in range 0-1)

validate(model, val_loader)

  0%|          | 0/391 [00:00<?, ?it/s]

val acc 0.56908


In [11]:
# Function for gathering input statistics for a layer
def update_cov(m, i, o):
    x = i[0].detach().flatten(0, -2).float()
    with torch.autocast(device_type="cuda", enabled=False):
        m.XX.data += x.square().sum(dim=0)

# Selection functions for picking which layers to compress
# In general, we do not compress first and last layer
# Name of the layer is n.n2 (where there are no dots in n2)
# Actual layer is m
select_all_linear = lambda n, n2, m: type(m) == nn.Linear and "head" not in n and "head" not in n2
select_only_attn_proj = lambda n, n2, m: type(m) == nn.Linear and "head" not in n and "head" not in n2 and "proj" in n2


def process_model(model_name=model_name, replace_filter=select_only_attn_proj, 
                  replace_fn=lambda x: x, finetune_epochs=2, finetune_maxlr=1e-5,
                  wanda=True):
    """
    :model_name: Name of model
    :replace_filter: Which layers to replace
    :replace_fn: Takes a layer, outputs new layer (or sequence of layers)
    :finetune_epochs: How many epochs to use for finetuning
    :finetune_maxlr: Maximal learning rate for finetuning
    :wanda: Whether to collect wanda statistics (if False, norms of inputs will be one)
    """
    model = timm.create_model(model_name, pretrained=True)
    model.cuda()
    
    print("pars before", sum(p.numel() for p in model.parameters()))
    
    # need to dump, because changing during iteration would make a mess
    to_change = []
    for n, m in model.named_modules():
        for n2, m2 in m.named_modules():
            if "." not in n2 and replace_filter(n, n2, m2) and len(n2) > 0:
                to_change.append((m, n2, m2))
                if wanda:
                    m2.XX = torch.zeros((m2.weight.shape[1]), device=m2.weight.device)
                    m2.register_forward_hook(update_cov)
                else:
                    m2.XX = torch.ones((m2.weight.shape[1]), device=m2.weight.device)
                
    
    # Gather input stats
    if wanda:
        step = 0
        for input, _ in train_loader:
            with amp_autocast():
                with torch.no_grad():
                    model(input)
            step += 1
            if step == 50:
                break
    
    for m, n2, m2 in to_change:
        setattr(m, n2, replace_fn(m2))
    
    print("pars after", sum(p.numel() for p in model.parameters()))
    print("first validation")
    validate(model, val_loader)
    
    if finetune_epochs == 0:
        return
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=finetune_maxlr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=len(train_loader)*finetune_epochs)
    scaler = torch.GradScaler("cuda")

    for ep in range(finetune_epochs):
        model.train()
        loss_sum = 0
        loss_cc = 0
        for input, target in (pbar := tqdm(train_loader)):
            with amp_autocast():
                output = model(input)
                loss = F.cross_entropy(output, target)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
                loss_sum += loss.item()
                loss_cc += 1
                pbar.set_description("loss: %.3f %.3f" % (loss.item(), loss_sum / loss_cc))

        validate(model, val_loader)

In [12]:
class BlockLinear2(nn.Module):
    def __init__(self, nb=4, bs=(128,512), bias=False, dbg=False):
        super().__init__()
        
        self.weight = nn.Parameter(torch.zeros(nb, bs[0], bs[1]))
        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.zeros(bs[1]))
    
    def extra_repr(self):
        return "w_shape %s" % str(self.weight.shape)
    
    def forward(self, x):
        orig_shape = x.shape
        x = x.flatten(0, -2)
        x = x.reshape(x.shape[0], self.weight.shape[0], -1).permute(1, 0, 2)
        x = x.matmul(self.weight)
        x = x.permute(1, 0, 2)
        x = x.reshape(tuple(list(orig_shape[0:-1]) + [-1]))
        if self.bias is not None:
            x = x + self.bias
        return x

class MonarchPerm2(nn.Module):
    def __init__(self, nb=4, size=128):
        super().__init__()
        out = []
        perm_base = torch.arange(size)
        parts = perm_base.chunk(nb*nb)
        for i in range(nb):
            out += parts[i::nb]
        self.perm = torch.cat(out).cuda()
        
    def forward(self, x):
        orig_shape = x.shape
        x = x.flatten(0, -2)
        out = x.T[self.perm].T
        return out.reshape(orig_shape)
    
class LambdaLayer(nn.Module):
    def __init__(self, f):
        super().__init__()
        self.f = f
        
    def forward(self, x):
        return self.f(x)

def replace_with_monarch(m, n_blocks=4, rank=None):
    assert m.weight.shape[0] % n_blocks == 0
    assert m.weight.shape[1] % n_blocks == 0
    if rank == None:
        # ~50% of original weights
        bs = m.weight.shape[0] // n_blocks, m.weight.shape[1] // n_blocks
        rank = int((bs[0] * bs[1]) / (bs[0] + bs[1]) / 2)
        print("auto rank", rank)
    
    norm = m.XX.sqrt() + 1e-8
    norm = norm / norm.mean()
    print(m, norm.amin(), norm.amax())
    W = m.weight.detach() * norm
    mid = rank * n_blocks
        
    layer2 = nn.Sequential(
        BlockLinear2(n_blocks, (W.shape[1]//n_blocks, mid)), 
        MonarchPerm2(n_blocks, mid*n_blocks),
        BlockLinear2(n_blocks, (mid, W.shape[0]//n_blocks), bias=(m.bias is not None))
    ).cuda()
    
    print("density", sum(p.numel() for p in layer2.parameters()) / W.numel())

    s0 = W.shape[0] // n_blocks
    s1 = W.shape[1] // n_blocks
    svd_err = 0
    for i in range(n_blocks):
        for j in range(n_blocks):
            part = W[i*s0:i*s0+s0, j*s1:j*s1+s1]

            U, s, Vh = torch.linalg.svd(part, full_matrices=False)
            s = s[:rank]
            U = U[:,:rank] * s.sqrt()
            Vh = Vh[:rank] * s.sqrt().unsqueeze(1)
            svd_err += (part - (U @ Vh)).square().sum().item()

            assert layer2[0].weight.data[j, :,rank*i:rank*i+rank].numel() > 0
            assert layer2[2].weight.data[i, rank*j:rank*j+rank,:].numel() > 0
            layer2[0].weight.data[j, :, rank*i:rank*i+rank] = (Vh / norm[j*s1:j*s1+s1]).T
            layer2[2].weight.data[i, rank*j:rank*j+rank] = U.T

    
    
    if m.bias is not None:
        layer2[2].bias.data = m.bias
    
    test = torch.eye(m.weight.shape[1], device="cuda")
    if m.bias is not None:
        print("err", (layer2(test) - m(test)).square().mean().item(), "triv err", (m(test) - m.bias.detach()).square().mean().item())
    else:
        print("err", (layer2(test) - m(test)).square().mean().item(), "triv err", (m(test)).square().mean().item())
    test = torch.eye(m.weight.shape[1], device="cuda") * norm
    if m.bias is not None:
        print("err norm", (layer2(test) - m(test)).square().mean().item(), "triv err", (m(test) - m.bias.detach()).square().mean().item())
    else:
        print("err norm", (layer2(test) - m(test)).square().mean().item(), "triv err", (m(test)).square().mean().item())
    
    return layer2

process_model(replace_fn=replace_with_monarch, finetune_epochs=0, finetune_maxlr=1e-4)

pars before 930280
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(0.3823, device='cuda:0') tensor(3.6373, device='cuda:0')
density 0.5026041666666666
err 0.0008873476181179285 triv err 0.00307566300034523
err norm 0.0007022766512818635 triv err 0.0034382150042802095
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(0.3820, device='cuda:0') tensor(2.1747, device='cuda:0')
density 0.5026041666666666
err 0.0007720351568423212 triv err 0.002187454840168357
err norm 0.0005869496962986887 triv err 0.00246794824488461
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(0.3452, device='cuda:0') tensor(2.2589, device='cuda:0')
density 0.5026041666666666
err 0.0009386584861204028 triv err 0.0021895181853324175
err norm 0.0006904646870680153 triv err 0.0023023795802146196
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(0.5048, device='cuda:0') tensor(2.6526, device='cuda:0')
density 0.5026041666666666
err 0.0009713

  0%|          | 0/391 [00:00<?, ?it/s]

val acc 0.10512


In [13]:
process_model(replace_fn=replace_with_monarch, finetune_epochs=0, finetune_maxlr=1e-4, wanda=False)

pars before 930280
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(1., device='cuda:0') tensor(1., device='cuda:0')
density 0.5026041666666666
err 0.0008055444923229516 triv err 0.00307566300034523
err norm 0.0008055444923229516 triv err 0.00307566300034523
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(1., device='cuda:0') tensor(1., device='cuda:0')
density 0.5026041666666666
err 0.0006608118419535458 triv err 0.002187454840168357
err norm 0.0006608118419535458 triv err 0.002187454840168357
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(1., device='cuda:0') tensor(1., device='cuda:0')
density 0.5026041666666666
err 0.0007813608972355723 triv err 0.0021895181853324175
err norm 0.0007813608972355723 triv err 0.0021895181853324175
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(1., device='cuda:0') tensor(1., device='cuda:0')
density 0.5026041666666666
err 0.0008158438722603023 triv err 0.002281213

  0%|          | 0/391 [00:00<?, ?it/s]

val acc 0.03074


In [15]:
# Toto trva dlho
# process_model(replace_fn=replace_with_monarch, finetune_epochs=1, finetune_maxlr=1e-4)

pars before 930280
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(0.3776, device='cuda:0') tensor(3.6518, device='cuda:0')
density 0.5026041666666666
err 0.0008888047304935753 triv err 0.00307566300034523
err norm 0.0007026286912150681 triv err 0.0034289523027837276
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(0.3807, device='cuda:0') tensor(2.1671, device='cuda:0')
density 0.5026041666666666
err 0.0007724976167082787 triv err 0.002187454840168357
err norm 0.0005873910849913955 triv err 0.0024656588211655617
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(0.3455, device='cuda:0') tensor(2.2605, device='cuda:0')
density 0.5026041666666666
err 0.0009388193138875067 triv err 0.0021895181853324175
err norm 0.0006902138120494783 triv err 0.002301169093698263
auto rank 6
Linear(in_features=96, out_features=96, bias=True) tensor(0.5165, device='cuda:0') tensor(2.6154, device='cuda:0')
density 0.5026041666666666
err 0.000969

  0%|          | 0/391 [00:00<?, ?it/s]

val acc 0.10654


  0%|          | 0/10009 [00:00<?, ?it/s]



  0%|          | 0/391 [00:00<?, ?it/s]

val acc 0.49554
