In [None]:
#| include: false
import timm
from fastai.vision.all import *
from fasterai.quantize.all import *

In [None]:
#| include: false
# Resolve the dataset referenced by URLs.PETS to a local path
# (untar_data presumably downloads/extracts it on first use — see fastai docs).
path = untar_data(URLs.PETS)
# Gather the image files located under the dataset's "images" sub-directory.
files = get_image_files(path/"images")

def label_func(f):
    """Return True when the first character of ``f`` is an uppercase letter.

    ``f`` must be non-empty and indexable to a string character; in this
    notebook it is used as the labelling function for
    ``ImageDataLoaders.from_name_func``, so the boolean becomes the class label.
    """
    leading_char = f[0]
    return leading_char.isupper()

# Build the DataLoaders used by every later cell: labels are produced by
# label_func, and each item is passed through Resize(64) as an item transform.
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

In [None]:
# Model that will be trained with the quantization callback attached.
# Start from timm's pretrained resnet34 weights.
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
learn = Learner(dls, pretrained_resnet_34, metrics=accuracy)
# Swap the classification head for a 2-output linear layer (512 input features)
# before training starts, so the new head is what gets fitted.
learn.model.fc = nn.Linear(512, 2)
# One-cycle training: 3 epochs with 1e-3 as the learning-rate argument;
# QuantizeCallback is passed as a callback for this fit only.
learn.fit_one_cycle(3, 1e-3, cbs=QuantizeCallback())

In [None]:
from tqdm import tqdm

def get_model_size(model):
    """Return the on-disk size of ``model``'s serialized state dict in megabytes.

    The state dict is written with ``torch.save`` to a file inside a fresh
    temporary directory, measured, and removed again together with the
    directory.

    Args:
        model: Object exposing ``state_dict()`` whose result can be serialized
            by ``torch.save`` (e.g. a ``torch.nn.Module``).

    Returns:
        float: Size of the serialized file in MB, where 1 MB = 1e6 bytes.
    """
    import os
    import tempfile

    # A private temporary directory (instead of a fixed "temp.p" in the working
    # directory) avoids overwriting and then deleting an unrelated file of the
    # same name, and guarantees cleanup even if saving or measuring raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), tmp_path)
        return os.path.getsize(tmp_path) / 1e6  # bytes -> MB
    
def compute_validation_accuracy(model, valid_dataloader, device=None):
    """Compute the top-1 classification accuracy of ``model`` over a dataloader.

    Args:
        model: Module to evaluate. It is switched to evaluation mode and moved
            to ``device``; for ``torch.nn.Module`` instances both changes remain
            in effect on the passed object after the call.
        valid_dataloader: Iterable yielding ``(inputs, labels)`` pairs. The
            model output for ``inputs`` is expected to have the class scores
            along dimension 1.
        device: Device (``torch.device`` or device string) to evaluate on.
            ``None`` selects the CPU, which is what this function always used
            before (presumably because the quantized model in this notebook is
            meant to be executed on CPU).

    Returns:
        float: Percentage (0-100) of samples whose predicted class equals the label.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    # Honor an explicitly requested device; fall back to CPU otherwise.
    device = torch.device('cpu') if device is None else torch.device(device)

    # Evaluation mode disables training-only behavior such as dropout.
    model.eval()
    model = model.to(device)

    total_correct = 0
    total_samples = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, labels in tqdm(valid_dataloader):
            # Hand the model a plain torch.Tensor. `as_subclass` drops any
            # tensor subclass without the restrictions of the legacy
            # `torch.Tensor(...)` constructor, which rejects tensors that are
            # not on the CPU or not of the default dtype.
            if isinstance(inputs, torch.Tensor):
                inputs = inputs.as_subclass(torch.Tensor)
            else:
                inputs = torch.Tensor(inputs)
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)

            # Predicted class = index of the highest score along dim 1.
            predicted = outputs.argmax(dim=1)

            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    # Named `accuracy_pct` so the module-level `accuracy` metric is not shadowed.
    accuracy_pct = (total_correct / total_samples) * 100

    return accuracy_pct

In [None]:
# Baseline for the size comparison below: the same resnet34 architecture with
# the same 2-output head, but created without QuantizeCallback and not fitted
# in this cell.
# NOTE: this rebinds `pretrained_resnet_34`, so the name no longer refers to the
# model held by `learn`.
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
learn_original = Learner(dls, pretrained_resnet_34, metrics=accuracy)
learn_original.model.fc = nn.Linear(512, 2)

In [None]:
# Report the serialized state-dict size (MB, two decimals) of the baseline
# model and of the model trained with QuantizeCallback.
print(f'Size of the original model: {get_model_size(learn_original.model):.2f} MB')
print(f'Size of the quantized model: {get_model_size(learn.model):.2f} MB')

In [None]:
# Accuracy (in percent) of the model trained with QuantizeCallback on the
# validation DataLoader; as the last expression it is shown as the cell output.
compute_validation_accuracy(learn.model, dls.valid)