In [None]:
import os

# Move to the parent directory (presumably the repository root) so the `conv_cp`
# imports and the relative `data/` and `checkpoints/` paths used below resolve.
# Guarded so that re-running this cell does not keep climbing the directory tree.
if not os.path.isdir("conv_cp"):
    os.chdir("..")
print(os.getcwd())

In [2]:
from collections import deque
from functools import partial
from pathlib import Path
from statistics import mean
from typing import Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.models import alexnet, AlexNet_Weights
from tqdm import tqdm

from conv_cp.conv_cp import decompose_model
from conv_cp.imagenet.dataset import ImageNet

In [3]:
def acc_fn_1(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Top-1 accuracy of a batch.

    Args:
        y_pred: per-class scores, one row per sample (class axis is dim 1).
        y_true: integer class label per sample.

    Returns:
        Fraction of samples whose highest-scoring class equals the label.
    """
    predicted_class = torch.argmax(y_pred, dim=1)
    n_hits = torch.eq(predicted_class, y_true).sum().item()
    return n_hits / y_true.size(0)


def acc_fn_5(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Top-5 accuracy of a batch.

    Args:
        y_pred: per-class scores, one row per sample (class axis is dim 1);
            needs at least 5 classes for ``topk`` to succeed.
        y_true: integer class label per sample.

    Returns:
        Fraction of samples whose label is among the 5 highest-scoring classes.
    """
    _, top5_classes = torch.topk(y_pred, k=5, dim=1)
    # Compare every top-5 candidate of a row against that row's label.
    hit_matrix = top5_classes.eq(y_true[:, None])
    return hit_matrix.sum().item() / y_true.size(0)


def loss_fn(
    model: nn.Module,
    dataloader: DataLoader,
    device: Union[str, torch.device],
    verbose: bool = False,
    num_steps: int = 50,
) -> float:
    """Mean cross-entropy of ``model`` over at most ``num_steps`` batches.

    The model is put in eval mode and moved to ``device`` for evaluation. It is
    moved back to the CPU afterwards, including when evaluation raises.

    Args:
        model: classifier whose output is passed to ``F.cross_entropy``.
        dataloader: iterable of ``(x, y)`` batches; a fresh iterator is created
            on every call, so evaluation always starts from its first batch.
        device: device the evaluation runs on.
        verbose: show a tqdm progress bar with the running mean loss.
        num_steps: maximum number of batches to evaluate. Fewer are used when
            the dataloader runs out first.

    Returns:
        Mean of the per-batch losses over the batches actually evaluated.

    Raises:
        ValueError: if no batch was evaluated (empty dataloader or
            ``num_steps <= 0``).
    """
    model.eval()
    model.to(device)

    total_loss = 0.0
    num_evaluated = 0
    loop = range(num_steps)
    if verbose:
        loop = tqdm(loop, desc="Validation")
    data_iter = iter(dataloader)
    try:
        with torch.no_grad():
            for _ in loop:
                try:
                    x, y = next(data_iter)
                except StopIteration:
                    break
                x, y = x.to(device), y.to(device)
                total_loss += F.cross_entropy(model(x), y).item()
                num_evaluated += 1

                if verbose:
                    loop.set_postfix(loss=total_loss / num_evaluated)
    finally:
        model.cpu()

    if num_evaluated == 0:
        raise ValueError(
            "loss_fn evaluated no batches: empty dataloader or num_steps <= 0"
        )
    # Divide by the batches actually seen, not by num_steps: the dataloader may
    # have fewer than num_steps batches.
    return total_loss / num_evaluated


def train_fn(
    model: nn.Module,
    dataloader: DataLoader,
    device: Union[str, torch.device],
    lr: float,
    num_steps: int = 1000,
    metric_len: int = 20,
    loss_tol: float = 1e-3,
    acc_tol: float = 0.95,
    verbose: bool = False,
):
    """Fine-tune ``model`` in place with Adam for at most ``num_steps`` batches.

    The dataloader is cycled: when it is exhausted a fresh iterator is created,
    so ``num_steps`` may exceed the number of batches per epoch.

    Early stopping applies once ``metric_len`` batches have been seen. Training
    then stops as soon as the mean training loss over the last ``metric_len``
    batches is below ``loss_tol``, or their mean top-1 training accuracy is
    above ``acc_tol``. The criterion is only evaluated on a full window, so one
    unusually easy batch cannot end training.

    On return, or if an exception is raised, the model is back on the CPU, in
    eval mode, with its gradients cleared.

    Args:
        model: classifier whose output is passed to ``F.cross_entropy``;
            updated in place.
        dataloader: iterable of ``(x, y)`` batches.
        device: device the training runs on.
        lr: Adam learning rate.
        num_steps: maximum number of optimizer steps.
        metric_len: size of the sliding window for the running metrics.
        loss_tol: early-stop threshold on the windowed mean loss.
        acc_tol: early-stop threshold on the windowed mean top-1 accuracy.
        verbose: show a tqdm bar with windowed loss / top-1 / top-5 accuracy.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.train()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    running_loss = deque(maxlen=metric_len)
    running_acc_1 = deque(maxlen=metric_len)
    running_acc_5 = deque(maxlen=metric_len)

    loop = range(num_steps)
    if verbose:
        loop = tqdm(loop, desc="Training")
    data_iter = iter(dataloader)
    try:
        for _ in loop:
            try:
                x, y = next(data_iter)
            except StopIteration:
                # Epoch finished: start another pass over the data.
                data_iter = iter(dataloader)
                try:
                    x, y = next(data_iter)
                except StopIteration:
                    raise ValueError("train_fn: dataloader yielded no batches") from None

            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = F.cross_entropy(y_pred, y)
            loss.backward()
            optimizer.step()

            running_loss.append(loss.item())
            running_acc_1.append(acc_fn_1(y_pred, y))
            running_acc_5.append(acc_fn_5(y_pred, y))

            if verbose:
                loop.set_postfix(
                    loss=mean(running_loss),
                    acc_1=mean(running_acc_1),
                    acc_5=mean(running_acc_5),
                )

            # Only judge convergence on a full window of recent batches.
            window_full = len(running_loss) == running_loss.maxlen
            if window_full and (
                mean(running_loss) < loss_tol or mean(running_acc_1) > acc_tol
            ):
                break
    finally:
        optimizer.zero_grad()
        model.cpu()
        model.eval()

In [4]:
# Pretrained AlexNet and the preprocessing that matches its weights.
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
model = model.cuda()
transform = AlexNet_Weights.IMAGENET1K_V1.transforms()

# Seed torch's global RNG so that, on a fresh kernel, the shuffled train loader
# (and the split, if it draws from torch's RNG) is reproducible.
SEED = 0
torch.manual_seed(SEED)

# NOTE(review): the ImageNet *validation* images are split 90/10 into a
# fine-tuning set and a held-out set; assumes split(0.9) returns (train, val).
dataset = ImageNet(root_dir="data/val-images", transform=transform)
train_dataset, val_dataset = dataset.split(0.9)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

# The decomposition cells save into checkpoints/; torch.save does not create
# missing directories.
Path("checkpoints").mkdir(exist_ok=True)

In [5]:
# Freeze the data/device/schedule arguments, so the decomposition cells only
# have to pass these callables to decompose_model.
# NOTE: this rebinds the names loss_fn / train_fn to the partials.
loss_fn = partial(
    loss_fn,
    device="cuda",
    dataloader=val_loader,
    num_steps=20,
    verbose=True,
)
train_fn = partial(
    train_fn,
    device="cuda",
    dataloader=train_loader,
    num_steps=100,
    metric_len=10,
    lr=1e-7,  # NOTE(review): very small learning rate for Adam; confirm intended.
    verbose=True,
)

In [6]:
# Decompose a fresh pretrained AlexNet with conv_rank=2000, fc_rank=1500.
# Reload the weights here, as every later cell does, so this cell does not
# depend on whatever state `model` was left in by previously executed cells.
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
model = model.cuda()
cp_model = decompose_model(
    model,
    conv_rank=2000,
    fc_rank=1500,
    loss_fn=loss_fn,
    train_fn=train_fn,
    trial_rank=5,
    layer_size_regularization=0.8,
    linear_decomp_type="svd",
    freeze_decomposed=False,
    verbose=True,
)
cp_model = cp_model.cpu()
torch.save(cp_model, "checkpoints/model@2000@1500.pt")
del cp_model

Computing losses for conv layers
Processing module features.0


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.51s/it, loss=8.32]


Processing module features.3


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, loss=7.47]


Processing module features.6


Validation: 100%|██████████| 20/20 [00:29<00:00,  1.45s/it, loss=9.19]


Processing module features.8


Validation: 100%|██████████| 20/20 [00:29<00:00,  1.50s/it, loss=7.13]


Processing module features.10


Validation: 100%|██████████| 20/20 [00:31<00:00,  1.59s/it, loss=6.64]


Computing losses for fc layers
Processing module classifier.1


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.51s/it, loss=6.79]


Processing module classifier.4


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.43s/it, loss=6.77]


Processing module classifier.6


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.40s/it, loss=5.47]


Initializing CP model
Decomposing features.0 with rank 156
Training model with features.0 decomposed


Training: 100%|██████████| 100/100 [02:57<00:00,  1.77s/it, acc_1=0.508, acc_5=0.756, loss=2.21]


Decomposing features.3 with rank 286
Training model with features.3 decomposed


Training: 100%|██████████| 100/100 [02:58<00:00,  1.78s/it, acc_1=0.499, acc_5=0.752, loss=2.22]


Decomposing features.6 with rank 547
Training model with features.6 decomposed


Training: 100%|██████████| 100/100 [03:01<00:00,  1.82s/it, acc_1=0.509, acc_5=0.748, loss=2.18]


Decomposing features.8 with rank 552
Training model with features.8 decomposed


Training: 100%|██████████| 100/100 [02:54<00:00,  1.74s/it, acc_1=0.497, acc_5=0.743, loss=2.19]


Decomposing features.10 with rank 459
Training model with features.10 decomposed


Training: 100%|██████████| 100/100 [02:54<00:00,  1.74s/it, acc_1=0.516, acc_5=0.766, loss=2.11]


Decomposing classifier.1 with rank 704
Training model with classifier.1 decomposed


Training: 100%|██████████| 100/100 [02:48<00:00,  1.69s/it, acc_1=0.526, acc_5=0.755, loss=2.13]


Decomposing classifier.4 with rank 487
Training model with classifier.4 decomposed


Training: 100%|██████████| 100/100 [02:55<00:00,  1.76s/it, acc_1=0.52, acc_5=0.754, loss=2.16]


Decomposing classifier.6 with rank 309


In [8]:
# Fresh pretrained AlexNet on the GPU for this run (conv_rank=1800, fc_rank=1400).
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1).cuda()
cp_model = decompose_model(
    model,
    conv_rank=1800,
    fc_rank=1400,
    loss_fn=loss_fn,
    train_fn=train_fn,
    trial_rank=5,
    layer_size_regularization=0.8,
    linear_decomp_type="svd",
    freeze_decomposed=False,
    verbose=True,
)
# Save a CPU copy of the decomposed model, then drop the reference.
torch.save(cp_model.cpu(), "checkpoints/model@1800@1400.pt")
del cp_model

Computing losses for conv layers
Processing module features.0


Validation: 100%|██████████| 20/20 [00:31<00:00,  1.58s/it, loss=8.32]


Processing module features.3


Validation: 100%|██████████| 20/20 [00:27<00:00,  1.36s/it, loss=7.47]


Processing module features.6


Validation: 100%|██████████| 20/20 [00:29<00:00,  1.49s/it, loss=8.94]


Processing module features.8


Validation: 100%|██████████| 20/20 [00:29<00:00,  1.47s/it, loss=7.13]


Processing module features.10


Validation: 100%|██████████| 20/20 [00:31<00:00,  1.58s/it, loss=6.65]


Computing losses for fc layers
Processing module classifier.1


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.54s/it, loss=6.78]


Processing module classifier.4


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it, loss=6.79]


Processing module classifier.6


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.53s/it, loss=5.47]


Initializing CP model
Decomposing features.0 with rank 141
Training model with features.0 decomposed


Training: 100%|██████████| 100/100 [02:44<00:00,  1.65s/it, acc_1=0.521, acc_5=0.752, loss=2.19]


Decomposing features.3 with rank 258
Training model with features.3 decomposed


Training: 100%|██████████| 100/100 [02:50<00:00,  1.71s/it, acc_1=0.506, acc_5=0.744, loss=2.22]


Decomposing features.6 with rank 491
Training model with features.6 decomposed


Training: 100%|██████████| 100/100 [02:46<00:00,  1.66s/it, acc_1=0.512, acc_5=0.75, loss=2.18]


Decomposing features.8 with rank 497
Training model with features.8 decomposed


Training: 100%|██████████| 100/100 [02:44<00:00,  1.64s/it, acc_1=0.504, acc_5=0.75, loss=2.22]


Decomposing features.10 with rank 413
Training model with features.10 decomposed


Training: 100%|██████████| 100/100 [02:57<00:00,  1.77s/it, acc_1=0.516, acc_5=0.751, loss=2.14]


Decomposing classifier.1 with rank 657
Training model with classifier.1 decomposed


Training: 100%|██████████| 100/100 [02:48<00:00,  1.69s/it, acc_1=0.507, acc_5=0.739, loss=2.2]


Decomposing classifier.4 with rank 454
Training model with classifier.4 decomposed


Training: 100%|██████████| 100/100 [02:57<00:00,  1.77s/it, acc_1=0.486, acc_5=0.741, loss=2.25]


Decomposing classifier.6 with rank 289


In [9]:
# Fresh pretrained AlexNet on the GPU for this run (conv_rank=1600, fc_rank=1350).
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1).cuda()
cp_model = decompose_model(
    model,
    conv_rank=1600,
    fc_rank=1350,
    loss_fn=loss_fn,
    train_fn=train_fn,
    trial_rank=5,
    layer_size_regularization=0.8,
    linear_decomp_type="svd",
    freeze_decomposed=False,
    verbose=True,
)
# Save a CPU copy of the decomposed model, then drop the reference.
torch.save(cp_model.cpu(), "checkpoints/model@1600@1350.pt")
del cp_model

Computing losses for conv layers
Processing module features.0


Validation: 100%|██████████| 20/20 [00:31<00:00,  1.57s/it, loss=8.32]


Processing module features.3


Validation: 100%|██████████| 20/20 [00:27<00:00,  1.36s/it, loss=7.47]


Processing module features.6


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.42s/it, loss=9.3] 


Processing module features.8


Validation: 100%|██████████| 20/20 [00:26<00:00,  1.33s/it, loss=7.17]


Processing module features.10


Validation: 100%|██████████| 20/20 [00:29<00:00,  1.45s/it, loss=6.65]


Computing losses for fc layers
Processing module classifier.1


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it, loss=6.79]


Processing module classifier.4


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.43s/it, loss=6.77]


Processing module classifier.6


Validation: 100%|██████████| 20/20 [00:29<00:00,  1.48s/it, loss=5.47]


Initializing CP model
Decomposing features.0 with rank 125
Training model with features.0 decomposed


Training: 100%|██████████| 100/100 [02:45<00:00,  1.66s/it, acc_1=0.509, acc_5=0.757, loss=2.21]


Decomposing features.3 with rank 229
Training model with features.3 decomposed


Training: 100%|██████████| 100/100 [02:42<00:00,  1.63s/it, acc_1=0.508, acc_5=0.743, loss=2.22]


Decomposing features.6 with rank 438
Training model with features.6 decomposed


Training: 100%|██████████| 100/100 [02:52<00:00,  1.73s/it, acc_1=0.511, acc_5=0.76, loss=2.13]


Decomposing features.8 with rank 441
Training model with features.8 decomposed


Training: 100%|██████████| 100/100 [02:57<00:00,  1.77s/it, acc_1=0.496, acc_5=0.734, loss=2.3]


Decomposing features.10 with rank 367
Training model with features.10 decomposed


Training: 100%|██████████| 100/100 [02:45<00:00,  1.66s/it, acc_1=0.496, acc_5=0.742, loss=2.26]


Decomposing classifier.1 with rank 633
Training model with classifier.1 decomposed


Training: 100%|██████████| 100/100 [02:53<00:00,  1.73s/it, acc_1=0.5, acc_5=0.743, loss=2.22] 


Decomposing classifier.4 with rank 439
Training model with classifier.4 decomposed


Training: 100%|██████████| 100/100 [02:50<00:00,  1.71s/it, acc_1=0.508, acc_5=0.747, loss=2.22]


Decomposing classifier.6 with rank 278


In [10]:
# Fresh pretrained AlexNet on the GPU for this run (conv_rank=1400, fc_rank=1250).
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1).cuda()
cp_model = decompose_model(
    model,
    conv_rank=1400,
    fc_rank=1250,
    loss_fn=loss_fn,
    train_fn=train_fn,
    trial_rank=5,
    layer_size_regularization=0.8,
    linear_decomp_type="svd",
    freeze_decomposed=False,
    verbose=True,
)
# Save a CPU copy of the decomposed model, then drop the reference.
torch.save(cp_model.cpu(), "checkpoints/model@1400@1250.pt")
del cp_model

Computing losses for conv layers
Processing module features.0


Validation: 100%|██████████| 20/20 [00:32<00:00,  1.64s/it, loss=9.24]


Processing module features.3


Validation: 100%|██████████| 20/20 [00:27<00:00,  1.40s/it, loss=7.48]


Processing module features.6


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.43s/it, loss=8.83]


Processing module features.8


Validation: 100%|██████████| 20/20 [00:29<00:00,  1.48s/it, loss=7.13]


Processing module features.10


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.51s/it, loss=6.64]


Computing losses for fc layers
Processing module classifier.1


Validation: 100%|██████████| 20/20 [00:29<00:00,  1.47s/it, loss=6.77]


Processing module classifier.4


Validation: 100%|██████████| 20/20 [00:29<00:00,  1.49s/it, loss=6.77]


Processing module classifier.6


Validation: 100%|██████████| 20/20 [00:26<00:00,  1.35s/it, loss=5.47]


Initializing CP model
Decomposing features.0 with rank 115
Training model with features.0 decomposed


Training: 100%|██████████| 100/100 [02:42<00:00,  1.62s/it, acc_1=0.511, acc_5=0.755, loss=2.22]


Decomposing features.3 with rank 200
Training model with features.3 decomposed


Training: 100%|██████████| 100/100 [02:41<00:00,  1.61s/it, acc_1=0.51, acc_5=0.753, loss=2.21]


Decomposing features.6 with rank 380
Training model with features.6 decomposed


Training: 100%|██████████| 100/100 [02:44<00:00,  1.64s/it, acc_1=0.506, acc_5=0.772, loss=2.13]


Decomposing features.8 with rank 385
Training model with features.8 decomposed


Training: 100%|██████████| 100/100 [02:46<00:00,  1.66s/it, acc_1=0.506, acc_5=0.748, loss=2.24]


Decomposing features.10 with rank 320
Training model with features.10 decomposed


Training: 100%|██████████| 100/100 [02:51<00:00,  1.72s/it, acc_1=0.51, acc_5=0.74, loss=2.24] 


Decomposing classifier.1 with rank 586
Training model with classifier.1 decomposed


Training: 100%|██████████| 100/100 [02:50<00:00,  1.70s/it, acc_1=0.507, acc_5=0.762, loss=2.12]


Decomposing classifier.4 with rank 406
Training model with classifier.4 decomposed


Training: 100%|██████████| 100/100 [02:46<00:00,  1.66s/it, acc_1=0.527, acc_5=0.761, loss=2.12]


Decomposing classifier.6 with rank 258


In [11]:
# Fresh pretrained AlexNet on the GPU for this run (conv_rank=1200, fc_rank=1150).
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1).cuda()
cp_model = decompose_model(
    model,
    conv_rank=1200,
    fc_rank=1150,
    loss_fn=loss_fn,
    train_fn=train_fn,
    trial_rank=5,
    layer_size_regularization=0.8,
    linear_decomp_type="svd",
    freeze_decomposed=False,
    verbose=True,
)
# Save a CPU copy of the decomposed model, then drop the reference.
torch.save(cp_model.cpu(), "checkpoints/model@1200@1150.pt")
del cp_model

Computing losses for conv layers
Processing module features.0


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.52s/it, loss=9.25]


Processing module features.3


Validation: 100%|██████████| 20/20 [00:26<00:00,  1.33s/it, loss=7.47]


Processing module features.6


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.52s/it, loss=9.3] 


Processing module features.8


Validation: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, loss=7.11]


Processing module features.10


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.53s/it, loss=6.65]


Computing losses for fc layers
Processing module classifier.1


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, loss=6.78]


Processing module classifier.4


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, loss=6.78]


Processing module classifier.6


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.54s/it, loss=5.47]


Initializing CP model
Decomposing features.0 with rank 98
Training model with features.0 decomposed


Training: 100%|██████████| 100/100 [02:42<00:00,  1.62s/it, acc_1=0.501, acc_5=0.75, loss=2.21]


Decomposing features.3 with rank 171
Training model with features.3 decomposed


Training: 100%|██████████| 100/100 [02:41<00:00,  1.61s/it, acc_1=0.509, acc_5=0.743, loss=2.24]


Decomposing features.6 with rank 328
Training model with features.6 decomposed


Training: 100%|██████████| 100/100 [02:48<00:00,  1.69s/it, acc_1=0.501, acc_5=0.743, loss=2.25]


Decomposing features.8 with rank 329
Training model with features.8 decomposed


Training: 100%|██████████| 100/100 [02:51<00:00,  1.72s/it, acc_1=0.486, acc_5=0.739, loss=2.27]


Decomposing features.10 with rank 274
Training model with features.10 decomposed


Training: 100%|██████████| 100/100 [02:39<00:00,  1.59s/it, acc_1=0.487, acc_5=0.727, loss=2.33]


Decomposing classifier.1 with rank 540
Training model with classifier.1 decomposed


Training: 100%|██████████| 100/100 [02:43<00:00,  1.64s/it, acc_1=0.483, acc_5=0.735, loss=2.33]


Decomposing classifier.4 with rank 373
Training model with classifier.4 decomposed


Training: 100%|██████████| 100/100 [02:40<00:00,  1.60s/it, acc_1=0.489, acc_5=0.739, loss=2.26]


Decomposing classifier.6 with rank 237


In [None]:
# Decompose a fresh pretrained AlexNet with conv_rank=1000, fc_rank=1050.
# Reload the weights like the sibling cells do; without this the cell would
# decompose the model object left over from the previously executed cell.
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
model = model.cuda()
cp_model = decompose_model(
    model,
    conv_rank=1000,
    fc_rank=1050,
    loss_fn=loss_fn,
    train_fn=train_fn,
    trial_rank=5,
    layer_size_regularization=0.8,
    linear_decomp_type="svd",
    freeze_decomposed=False,
    verbose=True,
)
cp_model = cp_model.cpu()
torch.save(cp_model, "checkpoints/model@1000@1050.pt")
del cp_model

In [12]:
# Fresh pretrained AlexNet on the GPU for this run (conv_rank=900, fc_rank=950).
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1).cuda()
cp_model = decompose_model(
    model,
    conv_rank=900,
    fc_rank=950,
    loss_fn=loss_fn,
    train_fn=train_fn,
    trial_rank=5,
    layer_size_regularization=0.8,
    linear_decomp_type="svd",
    freeze_decomposed=False,
    verbose=True,
)
# Save a CPU copy of the decomposed model, then drop the reference.
torch.save(cp_model.cpu(), "checkpoints/model@900@950.pt")
del cp_model

Computing losses for conv layers
Processing module features.0


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it, loss=8.32]


Processing module features.3


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.54s/it, loss=7.48]


Processing module features.6


Validation: 100%|██████████| 20/20 [00:31<00:00,  1.55s/it, loss=9.04]


Processing module features.8


Validation: 100%|██████████| 20/20 [00:31<00:00,  1.59s/it, loss=7.13]


Processing module features.10


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.54s/it, loss=6.64]


Computing losses for fc layers
Processing module classifier.1


Validation: 100%|██████████| 20/20 [00:28<00:00,  1.42s/it, loss=6.79]


Processing module classifier.4


Validation: 100%|██████████| 20/20 [00:32<00:00,  1.61s/it, loss=6.78]


Processing module classifier.6


Validation: 100%|██████████| 20/20 [00:30<00:00,  1.53s/it, loss=5.47]


Initializing CP model
Decomposing features.0 with rank 71
Training model with features.0 decomposed


Training: 100%|██████████| 100/100 [02:34<00:00,  1.55s/it, acc_1=0.498, acc_5=0.751, loss=2.21]


Decomposing features.3 with rank 129
Training model with features.3 decomposed


Training: 100%|██████████| 100/100 [02:39<00:00,  1.59s/it, acc_1=0.486, acc_5=0.735, loss=2.31]


Decomposing features.6 with rank 246
Training model with features.6 decomposed


Training: 100%|██████████| 100/100 [02:41<00:00,  1.61s/it, acc_1=0.471, acc_5=0.729, loss=2.37]


Decomposing features.8 with rank 248
Training model with features.8 decomposed


Training: 100%|██████████| 100/100 [02:45<00:00,  1.66s/it, acc_1=0.468, acc_5=0.712, loss=2.4]


Decomposing features.10 with rank 206
Training model with features.10 decomposed


Training: 100%|██████████| 100/100 [02:37<00:00,  1.57s/it, acc_1=0.479, acc_5=0.721, loss=2.34]


Decomposing classifier.1 with rank 446
Training model with classifier.1 decomposed


Training: 100%|██████████| 100/100 [02:46<00:00,  1.66s/it, acc_1=0.477, acc_5=0.73, loss=2.31]


Decomposing classifier.4 with rank 308
Training model with classifier.4 decomposed


Training: 100%|██████████| 100/100 [02:39<00:00,  1.60s/it, acc_1=0.478, acc_5=0.737, loss=2.31]


Decomposing classifier.6 with rank 196
