# This is fine tuning on compressed data

In [9]:
#Import and basic definitions
import os
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
import torch.optim as optim
from PIL import Image
import numpy as np
import pywt
from torchvision.models import mobilenet_v2

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_dir = r""
data_dir = r".\mini-imagenet"  # Adjust to your dataset path

In [10]:
#Definition of wavelet compression from part 2
def apply_wavelet_compression(images, retain):
    images_np = images.cpu().numpy()  # [N, 3, 84, 84]
    compressed = []
    total_coeffs = 0
    retained_coeffs = 0
    for img in images_np:
        img_recon = np.zeros_like(img)
        for c in range(3):
            coeffs = pywt.wavedec2(img[c], 'db1', level=2)
            coeff_arr, coeff_slices = pywt.coeffs_to_array(coeffs)
            total_coeffs += coeff_arr.size
            thresh = np.percentile(np.abs(coeff_arr), 100 * (1 - retain))
            coeff_arr[np.abs(coeff_arr) < thresh] = 0
            retained_coeffs += np.sum(coeff_arr != 0)
            coeffs_recon = pywt.array_to_coeffs(coeff_arr, coeff_slices, output_format='wavedec2')
            img_recon[c] = pywt.waverec2(coeffs_recon, 'db1')
        compressed.append(img_recon)
    compressed = np.stack(compressed)
    comp_ratio = total_coeffs / retained_coeffs if retained_coeffs > 0 else 1
    mse = np.mean((images_np - compressed) ** 2)
    psnr = 10 * np.log10(1 / mse) if mse > 0 else 100
    return torch.from_numpy(compressed).to(device).float(), psnr, comp_ratio


In [None]:
#Define and Modify the Model
from torchvision import models


mobilenet_v2 = models.mobilenet_v2(num_classes=100, pretrained=True)
mobilenet_v2.classifier[1] = torch.nn.Linear(mobilenet_v2.classifier[1].in_features, 1280)
# Move to appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = mobilenet_v2.to(device)

In [12]:
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = ImageFolder(root=os.path.join(data_dir, "train"), transform=transform)
val_dataset = ImageFolder(root=os.path.join(data_dir, "val"), transform=transform)
test_dataset = ImageFolder(root=os.path.join(data_dir, "test"), transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)  # num_workers=0 for local
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)


In [15]:
import torch.optim as optim

#Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

#Training loop
for epoch in range(10): # loop over the dataset multiple times

 running_loss = 0.0
 for i, data in enumerate(train_loader, 0):
 
    # get the inputs; data is a list of [inputs, labels]
    inputs, labels = data
    images_comp, psnr, comp_ratio = apply_wavelet_compression(inputs, 0.25) #25% compression

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = model(images_comp)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # print statistics
    running_loss += loss.item()
    if i % 2000 == 1999: # print every 2000 mini-batches
        print('[%d, %5d] loss: %.3f' %
        (epoch + 1, i + 1, running_loss / 2000))
    running_loss = 0.0

print('Finished Training')

#This code is adapted from the following guide https://toxigon.com/fine-tuning-pre-trained-models-with-pytorch

Finished Training


In [17]:
torch.save(model.state_dict(), "compressed_trained.pth")

In [18]:
def eval_model(compr):
    correct = 0
    total = 0

    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 uncompressed test images: %d %%' % (100 * correct / total))

    correct = 0
    total = 0

    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            images_comp, psnr, comp_ratio = apply_wavelet_compression(inputs, compr) #25% compression
            outputs = model(images_comp)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 compressed test images: %d %%' % (100 * correct / total))
#This code is adapted from the following guide https://toxigon.com/fine-tuning-pre-trained-models-with-pytorch

eval_model(0.25)

Accuracy of the network on the 10000 uncompressed test images: 1 %
Accuracy of the network on the 10000 compressed test images: 1 %


### While the accuracy of mobilevnet2 trained on compressed images is extremely poor at 1%, it is better than the mobilenetv2 trained on uncompressed images (0.37%)