### Example is based on CIFAR-100 dataset.

In [1]:
import os
os.environ["IS_DECOMPOSED"] = "0" # not decomposed model
from copy import deepcopy

import torch
import torch.nn as nn
from torchvision import datasets, transforms

import warnings
warnings.filterwarnings("ignore")

In [2]:
from td import Conv2dTD
from train_utils import ConvBnAct, train_model, eval_model

In [3]:
device = torch.device("mps")

In [4]:
tt = transforms.Compose([transforms.ToTensor()])

In [5]:
train_loader = torch.utils.data.DataLoader(datasets.CIFAR100('data_train', train=True, download=True, transform=tt), batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.CIFAR100('data_test', train=False, download=True, transform=tt), batch_size=1000, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
dataloaders = {'train': train_loader, 'test': test_loader}

In [7]:
# simple model with different kernel_size
class AModel(nn.Module):
    def __init__(self, cls=100, conv_layer=Conv2dTD):
        super().__init__()
        self.conv1 = ConvBnAct(3, 64, kernel_size=(5, 3), stride=2, padding=(2, 1),
                               conv_layer=conv_layer)
        self.conv2 = ConvBnAct(64, 128, kernel_size=(3, 4), stride=2, padding=(1, 1), bias=True, 
                               conv_layer=conv_layer)
        self.conv3 = ConvBnAct(128, 256, kernel_size=(2, 3), stride=2, padding=(0, 1), 
                               conv_layer=conv_layer)
        # conv with defined ranks
        self.conv4 = ConvBnAct(256, 512, kernel_size=(5, 2), stride=2, padding=(2, 0), bias=True, 
                               conv_layer=conv_layer, core_ranks=[32, 32], stick_rank=32)
        self.conv5 = ConvBnAct(512, 256, kernel_size=(2, 2), padding=(1, 1), stride=1, 
                               conv_layer=conv_layer)
        self.linear = nn.Sequential(nn.Flatten(), nn.Linear(2304, cls))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.linear(x)
        return x

In [8]:
model = AModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.0001, lr=0.0004)

In [9]:
# _ = train_model(model, device, optimizer, dataloaders, num_epochs=7)

In [10]:
# torch.save(model.state_dict(), 'model_original.pth')
model.load_state_dict(torch.load('model_original.pth'))
# model.load_state_dict(torch.load('model_dec.pth'))

<All keys matched successfully>

In [11]:
eval_model(model, device, dataloaders["test"])

Model test accuracy: 0.4147


In [12]:
def decompose_model(model, mode=1):
    """Recursion to call .decompose method in conv layers"""
    for name, layers in model.named_children():
        try:
            if not layers.is_decomposed:
                layers.decompose(mode=mode)
                # # Example of how you might set core ranks and stick ranks
                # in_ch = layers.in_channels
                # out_ch = layers.out_channels
                # core_ranks = [in_ch // 2, out_ch // 4]
                # stick_rank = min(in_ch, out_ch)
                # layers.decompose(core_ranks=core_ranks, stick_rank=stick_rank)
        except AttributeError:
            pass

        if layers is not None:
            decompose_model(layers, mode)

In [13]:
def check_decomp(mode=1, num_epochs=1):
    model = AModel().to(device)
    model.load_state_dict(torch.load('model_original.pth'))
    print("Model parameters: ", sum(tensor.numel() for tensor in model.parameters()))
    print("Model Eval:")
    eval_model(model, device, dataloaders['test'])
    # decomposition
    decompose_model(model, mode=mode)
    print("\nDecomposed model parameters: ", sum(tensor.numel() for tensor in model.parameters()))
    print("Decomposed model Eval:")
    eval_model(model, device, dataloaders['test'])
    # finetune
    print(f"\nFinetune {num_epochs} epoch:")
    _ = train_model(model, device, optimizer, dataloaders, num_epochs=num_epochs)
    torch.save(model.state_dict(), 'model_dec.pth')

In [14]:
check_decomp(mode=1)

Model parameters:  2366372
Model Eval:
Model test accuracy: 0.4147

Decomposed model parameters:  863140
Decomposed model Eval:
Model test accuracy: 0.2223

Finetune 1 epoch:


Train process:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch-1 train loss: 2.860                   <-> accuracy: 0.358


                                                            

Epoch-1 test loss: 3.604                   <-> accuracy: 0.295




In [15]:
check_decomp(mode=2)

Model parameters:  2366372
Model Eval:
Model test accuracy: 0.4147

Decomposed model parameters:  1010596
Decomposed model Eval:
Model test accuracy: 0.2955

Finetune 1 epoch:


Train process:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch-1 train loss: 1.696                   <-> accuracy: 0.474


                                                            

Epoch-1 test loss: 3.246                   <-> accuracy: 0.346




In [16]:
check_decomp(mode=3)

Model parameters:  2366372
Model Eval:
Model test accuracy: 0.4147

Decomposed model parameters:  878500
Decomposed model Eval:
Model test accuracy: 0.2078

Finetune 1 epoch:


Train process:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch-1 train loss: 2.892                   <-> accuracy: 0.334


                                                            

Epoch-1 test loss: 3.672                   <-> accuracy: 0.285




In [17]:
"""
To upload decomposed weights from checkpoint, you need to export 
the environment variable with the previously chosen mode.

export IS_DECOMPOSED=2 
or
import os
os.environ["IS_DECOMPOSED"] = 2 # mode

Then, all decomposed weights will be loaded correctly from the checkpoint.
model.load_state_dict(torch.load('model_dec.pth'))
"""

'\nTo upload decomposed weights from checkpoint, you need to export \nthe environment variable with the previously chosen mode.\n\nexport IS_DECOMPOSED=2 \nor\nimport os\nos.environ["IS_DECOMPOSED"] = 2 # mode\n\nThen, all decomposed weights will be loaded correctly from the checkpoint.\nmodel.load_state_dict(torch.load(\'model_dec.pth\'))\n'