In [None]:
import torch
import torchvision

from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.models import vgg16, VGG16_Weights
from torch.quantization import quantize_dynamic, convert
from sklearn.metrics import precision_score, recall_score, f1_score

In [None]:
# Set seed for reproducibility.
torch.manual_seed(1234)

# CIFAR-10 has 10 target classes.
NUM_CLASSES = 10

# Define transformation: resize/crop the images to 224x224 and normalize each
# channel with the mean/std values listed below.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),  # 224x224 crop to match the VGG16 input size.
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load CIFAR-10.
train_dataset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Check if GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the ImageNet-pre-trained VGG16 model.
model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

# The pre-trained head predicts the ImageNet classes; swap the last classifier
# layer for a fresh NUM_CLASSES-way layer so predictions are restricted to the
# CIFAR-10 label set (otherwise out-of-range class indices can be predicted and
# distort the macro-averaged metrics).
model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, NUM_CLASSES)
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
# Build the optimizer after the head replacement so the new layer is optimized.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training function.
def train(model, train_loader, criterion, optimizer, device):
    """Run one training pass over ``train_loader`` and return the mean loss.

    Args:
        model: Network to optimize; it is switched to training mode.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The loss averaged over the samples actually processed, or ``0.0`` when
        the loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for images, labels in tqdm(train_loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so a smaller final batch is not
        # over-represented (assumes the criterion returns a per-batch mean).
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # Divide by the number of samples seen rather than the dataset length: the
    # two differ when the loader drops or subsamples data.
    if samples_seen == 0:
        return 0.0
    return running_loss / samples_seen

# Evaluation function.
def evaluate(model, data_loader, device):
    """Score ``model`` on ``data_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A tuple ``(accuracy, precision, recall, f1)`` where accuracy is a
        percentage and the other three are macro-averaged scores.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    predictions, targets = [], []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(data_loader, desc="Evaluating"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # Predicted class = index of the largest score along dim 1.
            batch_preds = model(batch_images).argmax(dim=1)

            n_seen += batch_labels.size(0)
            n_correct += (batch_preds == batch_labels).sum().item()
            predictions += batch_preds.tolist()
            targets += batch_labels.tolist()

    accuracy = 100 * n_correct / n_seen
    precision = precision_score(targets, predictions, average='macro')
    recall = recall_score(targets, predictions, average='macro')
    f1 = f1_score(targets, predictions, average='macro')
    return accuracy, precision, recall, f1

In [None]:
# Fine-tune the model: one optimization pass per epoch, followed by a
# test-set evaluation so progress can be tracked epoch by epoch.
num_epochs = 1
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')

    # Optimization pass over the training loader.
    train_loss = train(model, train_loader, criterion, optimizer, device)
    print(f'Train Loss: {train_loss:.4f}')

    # Evaluate the updated weights on the test loader and report all metrics.
    acc, pre, rec, f1 = evaluate(model, test_loader, device)
    print(
        f'Accuracy after epoch {epoch+1}: {acc:.2f}%',
        f'Precision after epoch {epoch+1}: {pre:.2f}',
        f'Recall after epoch {epoch+1}: {rec:.2f}',
        f'F1 after epoch {epoch+1}: {f1:.2f}',
        sep='\n',
    )

In [None]:
# The function aggressive_quantization_and_pruning() is designed to aggressively
# reduce the computational complexity of a pre-trained VGG16 model by simultaneously 
# applying quantization and pruning techniques to its weights. 
# Quantization involves reducing the precision of weight values to a smaller set
# of discrete levels, thereby reducing memory and computational requirements. 
# In this function, weights are quantized by rounding them to the nearest multiple 
# of a specified quantization level. Additionally, pruning involves removing 
# connections with negligible weights, effectively reducing the model's parameter 
# count and computational cost. In this function, weights are pruned by setting 
# to zero those weights whose absolute values fall below a certain percentage of 
# the maximum absolute weight value. By combining these two aggressive techniques, 
# the function aims to produce a more efficient model without significantly sacrificing 
# performance, ultimately facilitating deployment on resource-constrained platforms.

# Define a function for aggressive quantization and pruning.
def aggressive_quantization_and_pruning(model, q_level=0.9, prune_level=0.05):
    """Quantize and prune every floating-point tensor of ``model`` in place.

    For each floating-point entry of ``model.state_dict()``:
      1. Quantize: round each value to the nearest multiple of ``q_level``.
      2. Prune: zero every value whose magnitude is below ``prune_level``
         times the largest magnitude in that (already quantized) tensor.

    Args:
        model: Module whose parameters/buffers are modified in place.
        q_level: Quantization step; values become multiples of it.
        prune_level: Fraction of the per-tensor maximum magnitude below which
            values are set to zero.

    Returns:
        None. ``model`` is mutated.

    NOTE(review): with the default ``q_level=0.9`` any value with magnitude
    below 0.45 is rounded to zero, which will likely wipe out most trained
    weights — confirm the intended default before relying on the results.
    """
    with torch.no_grad():
        # state_dict() hands back tensors that share storage with the module's
        # parameters and buffers, so the update must be written *into* them.
        # Assigning to a key of the returned dict (as a plain `d[k] = ...`)
        # only changes that temporary dict and leaves the model untouched.
        for tensor in model.state_dict().values():
            # Skip integer buffers (e.g. step counters) and empty tensors:
            # rounding to a float grid / taking a max is meaningless there.
            if not torch.is_floating_point(tensor) or tensor.numel() == 0:
                continue

            # Quantize weights aggressively.
            tensor.copy_(torch.round(tensor / q_level) * q_level)

            # Prune weights aggressively.
            threshold = tensor.abs().max() * prune_level
            tensor.masked_fill_(tensor.abs() < threshold, 0.0)

In [None]:
# Apply aggressive quantization and pruning (modifies `model` in place).
aggressive_quantization_and_pruning(model)

# Evaluate the pruned and quantized model.
# The labels below describe this evaluation stage directly instead of reusing
# the fine-tuning loop's `epoch` variable, which was misleading here and made
# this cell fail when the fine-tuning cell had not been run.
print('Quantized and Pruned Model Evaluation:')
acc, pre, rec, f1 = evaluate(model, test_loader, device)
print(f'Accuracy after quantization and pruning: {acc:.2f}%')
print(f'Precision after quantization and pruning: {pre:.2f}')
print(f'Recall after quantization and pruning: {rec:.2f}')
print(f'F1 after quantization and pruning: {f1:.2f}')

In [None]:
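# # Dynamic Quantization (disabled experiment).
# # NOTE(review): the evaluation below is run on `model`, not `quantized_model`,
# # so re-enabling this cell as written would not measure the quantized network.
# # Dynamically quantized modules are presumably CPU-only — TODO confirm, and
# # move the model and batches to the CPU before evaluating if so.
# quantized_model = quantize_dynamic(model, {torch.nn.Conv2d, torch.nn.Linear}, dtype=torch.qint8)

# # Evaluate accuracy after quantization.
# acc, pre, rec, f1 = evaluate(model, test_loader, device)
# print(f'Accuracy after epoch {epoch+1}: {acc:.2f}%')
# print(f'Precision after epoch {epoch+1}: {pre:.2f}')
# print(f'Recall after epoch {epoch+1}: {rec:.2f}')
# print(f'F1 after epoch {epoch+1}: {f1:.2f}')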