# 0x00 Introduction

This is a manual for using TileTrans to reparameterize a DNN model. We walk through the reparameterization workflow from scratch, in the following steps:

1. Train a model
2. Prune the model
3. Reparameterize and prune the model

# 0x01 Train a model

We train the model in the usual way. For quick training, we select **AlexNet** as the model and **CIFAR10** as the dataset.



In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
import time
import train

# Preprocessing applied to every image: convert to a tensor, resize, center-crop
# to 224x224, make sure the dtype is float, then normalize each channel with
# the mean/std listed below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

# Training hyperparameters.
EPOCH = 32
BATCH_SIZE = 256
LR = 0.001

# CIFAR10 train/test splits (downloaded to ./data if missing); only the
# training loader is shuffled.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Start from the pretrained AlexNet and fine-tune it on CIFAR10.
net = models.alexnet(pretrained=True)

optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
loss_func = nn.CrossEntropyLoss()

# Report wall-clock time. time.process_time() only counts CPU time of this
# process, so it leaves out time spent waiting and the work done in the
# DataLoader worker processes, and under-reports how long training takes.
start_time = time.perf_counter()
# NOTE(review): the last argument is presumably the checkpoint name; the later
# cells read it back with torch.load("trained_model") -- confirm in train.py.
train.train(net, trainloader, testloader, loss_func, optimizer, EPOCH, "trained_model")
print("training time = {}".format(time.perf_counter() - start_time))

Files already downloaded and verified
Files already downloaded and verified

--------------------------------------------------
Epoch: 0  Train Loss: 0.9452  Valid Loss: 245.7380  Correct Rate: 0.0004
--------------------------------------------------

Maximum Validation Accuracy of 0.0004 at epoch 0/10
Model Saved


--------------------------------------------------
Epoch: 1  Train Loss: nan  Valid Loss: nan  Correct Rate: 0.0004
--------------------------------------------------

--------------------------------------------------
Epoch: 2  Train Loss: nan  Valid Loss: nan  Correct Rate: 0.0004
--------------------------------------------------

--------------------------------------------------
Epoch: 3  Train Loss: nan  Valid Loss: nan  Correct Rate: 0.0004
--------------------------------------------------

--------------------------------------------------
Epoch: 4  Train Loss: nan  Valid Loss: nan  Correct Rate: 0.0004
--------------------------------------------------

---------

Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/usr/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/usr/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/usr/lib/python3.8/shutil.py", line 722, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/usr/lib/python3.8/shutil.py", line 720, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-o0o5ueke'



--------------------------------------------------
Epoch: 6  Train Loss: nan  Valid Loss: nan  Correct Rate: 0.0004
--------------------------------------------------

--------------------------------------------------
Epoch: 7  Train Loss: nan  Valid Loss: nan  Correct Rate: 0.0004
--------------------------------------------------

--------------------------------------------------
Epoch: 8  Train Loss: nan  Valid Loss: nan  Correct Rate: 0.0004
--------------------------------------------------

--------------------------------------------------
Epoch: 9  Train Loss: nan  Valid Loss: nan  Correct Rate: 0.0004
--------------------------------------------------
training time = 4.396561160999994


# 0x02 Prune the model

After training the model, we prune it using a pruning shape of $1 \times 128$.

In [None]:
from metrics import MetricsL1
import pruner

# Pruning target and fine-tuning hyperparameters (EPOCH and LR are rebound
# here for the fine-tuning run).
SPARSITY = 0.9
EPOCH = 10
LR = 0.001

# Pruner configured with the L1 metric and a [1, 128] pruning shape.
metric = MetricsL1
method = pruner.TW_pruning([1, 128], metric)
pru = pruner.Pruner(method)

# Build the bare architecture only: load_state_dict() is strict by default and
# overwrites every parameter with the trained checkpoint, so fetching the
# ImageNet-pretrained weights first would be wasted work.
net = models.alexnet()
net.load_state_dict(torch.load("trained_model"))

pru.prune(net, SPARSITY)

# A new optimizer is required because `net` is a new model object.
optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)

# Report wall-clock time. time.process_time() only counts CPU time of this
# process, so it leaves out time spent waiting and the work done in the
# DataLoader worker processes.
start_time = time.perf_counter()
# trainloader, testloader and loss_func come from the training cell above.
train.train(net, trainloader, testloader, loss_func, optimizer, EPOCH, "pruned_model")
print("training time = {}".format(time.perf_counter() - start_time))

According to the result, we get a model with a sparsity of 0.9 and an accuracy of 0.1.

# 0x03 Reparameterize and prune the model

Now let's try reparameterizing the model with TileTrans before pruning.

In [None]:
from reconstructor import Reconstructor, ReconMethodL1Sort

# Same pruning target and fine-tuning hyperparameters as the plain pruning run.
SPARSITY = 0.9
EPOCH = 10
LR = 0.001

# Pruner configured with the L1 metric and a [1, 128] pruning shape, plus a
# reconstructor that uses the same metric with the L1-sort method.
metric = MetricsL1
method = pruner.TW_pruning([1, 128], metric)
pru = pruner.Pruner(method)
recon = Reconstructor(metrics=metric, method=ReconMethodL1Sort)

# Build the bare architecture only: load_state_dict() is strict by default and
# overwrites every parameter with the trained checkpoint, so fetching the
# ImageNet-pretrained weights first would be wasted work.
net = models.alexnet()
net.load_state_dict(torch.load("trained_model"))

# Reparameterize first, then prune the reparameterized model.
recon(net)
pru.prune(net, SPARSITY)

# A new optimizer is required because `net` is a new model object.
optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)

# Report wall-clock time. time.process_time() only counts CPU time of this
# process, so it leaves out time spent waiting and the work done in the
# DataLoader worker processes.
start_time = time.perf_counter()
# NOTE(review): this reuses the "pruned_model" name from the plain pruning
# cell; if train.train saves a checkpoint under that name, the earlier result
# is overwritten -- confirm whether a separate name is wanted.
train.train(net, trainloader, testloader, loss_func, optimizer, EPOCH, "pruned_model")
print("training time = {}".format(time.perf_counter() - start_time))

Finally, we get a model with a sparsity of 0.9 and an accuracy of 0.12.