In [None]:
%cd ..

In [2]:
import time
from typing import Union
from tqdm import tqdm
from collections import deque
from statistics import mean
from functools import partial

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

from torchvision.models import alexnet, AlexNet_Weights

from conv_cp.conv_cp import decompose_model
from conv_cp.imagenet.dataset import ImageNet

In [3]:
def acc_fn_1(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    y_pred = y_pred.argmax(dim=1)
    correct = (y_pred == y_true).sum().item()
    return correct / y_true.size(0)


def acc_fn_5(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    y_pred = y_pred.topk(5, dim=1).indices
    correct = (y_pred == y_true.unsqueeze(1)).sum().item()
    return correct / y_true.size(0)


def loss_fn(
    model: nn.Module,
    dataloader: DataLoader,
    device: Union[str, torch.device],
    verbose: bool = False,
    num_steps: int = 50,
) -> float:
    model.eval()
    model.to(device)

    total_loss = 0
    data_iter = iter(dataloader)
    loop = range(num_steps)
    if verbose:
        loop = tqdm(range(num_steps), desc="Validation")
    for step in loop:
        try:
            x, y = next(data_iter)
        except StopIteration:
            break
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            y_pred = model(x)
            loss = F.cross_entropy(y_pred, y)
            total_loss += loss.item()

        if verbose:
            loop.set_postfix(loss=total_loss / (step + 1))

    model.cpu()
    return total_loss / num_steps


def train_fn(
    model: nn.Module,
    dataloader: DataLoader,
    device: Union[str, torch.device],
    lr: float,
    num_steps: int = 1000,
    metric_len: int = 20,
    loss_tol: float = 1e-3,
    acc_tol: float = 0.95,
    verbose: bool = False,
):
    model.train()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    running_loss = deque(maxlen=metric_len)
    running_acc = deque(maxlen=metric_len)
    runngin_acc_5 = deque(maxlen=metric_len)

    data_iter = iter(dataloader)
    loop = range(num_steps)
    if verbose:
        loop = tqdm(range(num_steps), desc="Training")
    for _ in loop:
        try:
            x, y = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            x, y = next(data_iter)

        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = F.cross_entropy(y_pred, y)
        loss.backward()
        optimizer.step()

        running_loss.append(loss.item())
        running_acc.append(acc_fn_1(y_pred, y))
        runngin_acc_5.append(acc_fn_5(y_pred, y))

        if verbose:
            loop.set_postfix(
                loss=mean(running_loss),
                acc_1=mean(running_acc),
                acc_5=mean(runngin_acc_5),
            )

        if mean(running_loss) < loss_tol or mean(running_acc) > acc_tol:
            break

    optimizer.zero_grad()
    model.cpu()
    model.eval()

In [4]:
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
transform = AlexNet_Weights.IMAGENET1K_V1.transforms()

dataset = ImageNet(root_dir="data/val-images", transform=transform)
train_dataset, val_dataset = dataset.split(0.9)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)

In [5]:
loss_fn = partial(
    loss_fn,
    dataloader=val_loader,
    device="cuda",
    num_steps=20,
    verbose=True,
)
train_fn = partial(
    train_fn,
    dataloader=train_loader,
    device="cuda",
    lr=1e-7,
    metric_len=10,
    num_steps=100,
    verbose=True,
)

In [6]:
cp_model = decompose_model(
    model,
    conv_rank=750,
    fc_rank=900,
    loss_fn=loss_fn,
    train_fn=train_fn,
    trial_rank=5,
    layer_size_regularization=0.0,
    linear_decomp_type="svd",
    freeze_decomposed=False,
    verbose=True,
)

Computing losses for conv layers
Processing module features.0


Validation: 100%|██████████| 20/20 [00:08<00:00,  2.43it/s, loss=8.34]


Processing module features.3


Validation: 100%|██████████| 20/20 [00:08<00:00,  2.42it/s, loss=7.46]


Processing module features.6


Validation: 100%|██████████| 20/20 [00:08<00:00,  2.41it/s, loss=8.66]


Processing module features.8


Validation: 100%|██████████| 20/20 [00:08<00:00,  2.39it/s, loss=7.16]


Processing module features.10


Validation: 100%|██████████| 20/20 [00:07<00:00,  2.55it/s, loss=6.66]


Computing losses for fc layers
Processing module classifier.1


Validation: 100%|██████████| 20/20 [00:08<00:00,  2.43it/s, loss=6.68]


Processing module classifier.4


Validation: 100%|██████████| 20/20 [00:08<00:00,  2.30it/s, loss=6.77]


Processing module classifier.6


Validation: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s, loss=5.51]


Initializing CP model
Decomposing features.0 with rank 164
Training model with features.0 decomposed


Training: 100%|██████████| 100/100 [00:41<00:00,  2.40it/s, acc_1=0.507, acc_5=0.752, loss=2.28]


Decomposing features.3 with rank 147
Training model with features.3 decomposed


Training: 100%|██████████| 100/100 [00:41<00:00,  2.42it/s, acc_1=0.482, acc_5=0.725, loss=2.36]


Decomposing features.6 with rank 169
Training model with features.6 decomposed


Training: 100%|██████████| 100/100 [00:41<00:00,  2.43it/s, acc_1=0.464, acc_5=0.714, loss=2.42]


Decomposing features.8 with rank 140
Training model with features.8 decomposed


Training: 100%|██████████| 100/100 [00:41<00:00,  2.42it/s, acc_1=0.42, acc_5=0.681, loss=2.65]


Decomposing features.10 with rank 130
Training model with features.10 decomposed


Training: 100%|██████████| 100/100 [00:41<00:00,  2.42it/s, acc_1=0.418, acc_5=0.669, loss=2.66]


Decomposing classifier.1 with rank 318
Training model with classifier.1 decomposed


Training: 100%|██████████| 100/100 [00:41<00:00,  2.43it/s, acc_1=0.434, acc_5=0.687, loss=2.57]


Decomposing classifier.4 with rank 321
Training model with classifier.4 decomposed


Training: 100%|██████████| 100/100 [00:40<00:00,  2.45it/s, acc_1=0.421, acc_5=0.679, loss=2.64]


Decomposing classifier.6 with rank 261


In [7]:
def evaluate(
    model: nn.Module, dataloader: DataLoader, device: Union[str, torch.device]
):
    model.eval()
    model.to(device)
    accs = []
    accs_5 = []
    times = []
    loop = tqdm(dataloader, desc="Evaluation")
    for x, y in loop:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            t_start = time.time()
            y_pred = model(x)
            times.append(time.time() - t_start)
            accs.append(acc_fn_1(y_pred, y))
            accs_5.append(acc_fn_5(y_pred, y))
            loop.set_postfix(acc=mean(accs), acc_5=mean(accs_5))
    return mean(accs), mean(accs_5), sum(times)

In [8]:
acc_1, acc_5, inference_time = evaluate(cp_model, val_loader, "cuda")
print(f"Top-1 Accuracy: {acc_1} | Top-5 Accuracy: {acc_5} | Inference time: {inference_time}")

Evaluation: 100%|██████████| 20/20 [00:08<00:00,  2.39it/s, acc=0.395, acc_5=0.679]

Top-1 Accuracy: 0.3946001838235294 | Top-5 Accuracy: 0.6787913602941177 | Inference time: 0.056990623474121094





In [9]:
orig_model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
cp_model.cpu()
orig_model.cuda()
ref_acc_1, ref_acc_5, ref_inference_time = evaluate(orig_model, val_loader, "cuda")
print(f"Top-1 Accuracy: {ref_acc_1} | Top-5 Accuracy: {ref_acc_5} | Inference time: {ref_inference_time}")

Evaluation: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s, acc=0.566, acc_5=0.791]


Top-1 Accuracy: 0.5660960477941176 | Top-5 Accuracy: 0.7906135110294118 | Inference time: 0.025220155715942383


In [10]:
def get_size_ratio(model1: nn.Module, model2: nn.Module) -> float:
    size1 = sum(p.numel() for p in model1.parameters())
    size2 = sum(p.numel() for p in model2.parameters())
    return size1 / size2

def get_model_size(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())

size_ratio = get_size_ratio(cp_model, orig_model)
cp_size = get_model_size(cp_model)
orig_size = get_model_size(orig_model)
print(f"Orig model size: {orig_size} | CP model size: {cp_size} | Size Ratio: {size_ratio:.3f}" )

Orig model size: 61100840 | CP model size: 8513084 | Size Ratio: 0.139
