# TRANSTAILOR METHOD

## Import

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

## Load model and dataset
* Model: VGG16
* Dataset: CIFAR10

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE: " + str(device))

# Load the VGG16 model with pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of
# `weights=torchvision.models.VGG16_Weights.DEFAULT`; switch once the pinned version supports it.
model = torchvision.models.vgg16(pretrained=True)
batch_size = 256

# CIFAR10 images are upsampled to 224x224 and normalized with the standard
# ImageNet mean/std that the pretrained weights expect.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the CIFAR10 training split (downloaded into ./data on first run).
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# `device` is a torch.device, not a str: comparing it to 'cuda' is always False,
# so check `.type` to actually enable worker processes and pinned memory on GPU.
kwargs = {'num_workers': 10, 'pin_memory': True} if device.type == 'cuda' else {}
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, **kwargs)

# Freeze every pretrained parameter; only the replacement head created below stays trainable.
model.requires_grad_(False)

# Replace the last layer of the model with a new layer that matches the number of classes in CIFAR10
num_classes = 10
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes)

model = model.to(device)

DEVICE: cuda
Files already downloaded and verified


## Fine-tune the model on the target data (CIFAR10)

In [None]:
num_epochs = 10

# Fine-tune the pre-trained model to generate W_s*.
# Only parameters that still require grad (the new classifier head) are optimized;
# handing frozen parameters to SGD is pointless bookkeeping.
print("\n===Fine-tune the pre-trained model to generate W_s*===")
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.001, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()

# Explicit train mode so any Dropout in the classifier is active, even if an
# earlier (re-)run left the model in eval mode.
model.train()
for epoch in range(num_epochs):
    running_loss, seen = 0.0, 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
    # 1-based epoch counter so the last line reads "10/10"; report the mean training loss.
    print(f"Epoch {epoch + 1}/{num_epochs} - loss {running_loss / max(seen, 1):.4f}")


===Fine-tune the pre-trained model to generate W_s*===
Epoch 0/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10


## Define and train the scaling factors `α`

1.   Train the scaling factors using the target data (CIFAR10 in this case).
2.   Transform the scaling factors into filter-importance scores using the (first-order) Taylor expansion method.
3.   Prune the filters based on the filter importance.
4.   Fine-tune the pruned model using the target data.



In [None]:
def generate_scaling_factors(num_layers, num_filters):
    """Create a zero-padded matrix of random per-filter scaling factors.

    Args:
        num_layers: Number of rows in the result.
        num_filters: Sequence of length ``num_layers``; ``num_filters[i]`` is the
            number of filters in layer ``i`` (0 for layers without filters).

    Returns:
        Float tensor of shape ``(num_layers, max(num_filters))``. Row ``i`` holds
        ``num_filters[i]`` values from ``torch.rand`` (uniform on [0, 1)) followed
        by zero padding. An empty ``num_filters`` yields a ``(0, 0)`` tensor.
    """
    # default=0 keeps an empty layer list from raising "max() arg is an empty sequence".
    width = max(num_filters, default=0)
    scaling_factors = torch.zeros(num_layers, width)
    for i in range(num_layers):
        scaling_factors[i, :num_filters[i]] = torch.rand(num_filters[i])
    return scaling_factors

# One entry per module in model.features: output channels for Conv2d layers,
# 0 for everything else (activations, pooling).
num_layers = len(model.features)
num_filters = [
    layer.out_channels if isinstance(layer, torch.nn.Conv2d) else 0
    for layer in model.features
]

# alpha is the trainable parameter of the next stage, so it must be a leaf tensor
# with requires_grad=True on the model's device. A plain tensor never receives
# .grad: SGD silently skips it and any autograd.grad(..., alpha) call fails.
alpha = generate_scaling_factors(num_layers, num_filters).to(device).requires_grad_(True)


In [None]:
# Grab one batch from the loader; it is used to measure activation shapes.
image, _ = next(iter(train_loader))

# Max-pool layers shrink H x W as the network deepens, so each conv's output size
# must be measured, not copied from the input image. A single-sample probe suffices.
conv_output_shapes = []
with torch.no_grad():
    probe = image[:1].to(device)
    for layer in model.features:
        probe = layer(probe)
        if isinstance(layer, torch.nn.Conv2d):
            conv_output_shapes.append(tuple(probe.shape[1:]))  # (out_channels, H, W)

# Initialize a list to store the output of each filter: one zero buffer per conv
# layer, shaped (batch_size, out_channels, H, W).
# NOTE: total size is batch_size * sum(C*H*W) float32 values on the host - many GB
# for VGG16 at 224x224 with batch_size=256.
filter_outputs = [torch.zeros(batch_size, *shape) for shape in conv_output_shapes]


In [None]:
# Show the shape of every preallocated per-layer buffer, in layer order.
for buffer in filter_outputs:
    print(buffer.shape)

In [None]:
num_epochs = 10
criterion = torch.nn.CrossEntropyLoss()

# alpha must be a leaf tensor that requires grad and lives on the model's device.
# Otherwise backward never fills alpha.grad, SGD silently skips it, and the importance
# computation below fails with "element 0 of tensors does not require grad".
alpha = alpha.detach().to(device).requires_grad_(True)


def _scaled_forward(inputs):
    """Run `model` on `inputs`, scaling each Conv2d output channel by its alpha factor.

    Row i of `alpha` holds the factors for `model.features[i]`; only the first
    `num_filters[i]` entries of a conv row are used (the rest is zero padding).
    Returns the classifier logits.
    """
    x = inputs
    for i, layer in enumerate(model.features):
        x = layer(x)
        if isinstance(layer, torch.nn.Conv2d):
            # Broadcast a (1, C, 1, 1) multiply instead of writing channel slices in place:
            # in-place writes clobber tensors autograd saved for backward, and a Python
            # loop over every filter is very slow.
            x = x * alpha[i, :num_filters[i]].reshape(1, -1, 1, 1)
    # Same tail as VGG's own forward: adaptive average pool, flatten, classifier.
    x = model.avgpool(x)
    x = torch.flatten(x, 1)
    return model.classifier(x)


# Train the factors alpha by optimizing the loss function (model weights are not updated).
print("\n===Train the factors alpha by optimizing the loss function===")
optimizer_alpha = torch.optim.SGD([alpha], lr=0.001, momentum=0.9)
for epoch in range(num_epochs):
    running_loss, seen = 0.0, 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_alpha.zero_grad()
        loss = criterion(_scaled_forward(inputs), labels)
        # Restrict backward to alpha so gradients are not accumulated (and never
        # cleared) on the still-trainable classifier head.
        loss.backward(inputs=[alpha])
        optimizer_alpha.step()
        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
    print(f"Epoch {epoch + 1}/{num_epochs} - loss {running_loss / max(seen, 1):.4f}")

# Transform alpha to the filter importance beta with a first-order Taylor expansion:
# driving alpha_j to 0 changes the loss by approximately |dL/dalpha_j * alpha_j|.
# The loss graph must be freshly built here: backward() above already freed earlier graphs.
# NOTE(review): estimated from a single batch; averaging over more batches gives a steadier ranking.
inputs, labels = next(iter(train_loader))
inputs, labels = inputs.to(device), labels.to(device)
loss = criterion(_scaled_forward(inputs), labels)
(grad_alpha,) = torch.autograd.grad(loss, alpha)
beta = (grad_alpha * alpha).abs().detach()


===Train the factors alpha by optimizing the loss function===
Epoch 0/10
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn