We have seen that complex networks require significant computational resources, such as GPUs, for training and also for fast inference. However, it turns out that a model with significantly fewer parameters can, in most cases, still be trained to perform reasonably well. In other words, an increase in model complexity typically results in only a small (non-proportional) increase in model performance.

In the previous notebooks, we saw that the accuracy of a simple dense model was not significantly worse than that of a powerful CNN. **Increasing the number of CNN layers and/or the number of neurons in the classifier allowed us to gain a few percent of accuracy at most**.

This leads us to the idea that we can experiment with *lightweight network architectures* in order to train faster models. This is especially important if we want to be able to execute our models on mobile devices.

This module will rely on the Cats and Dogs dataset. First we will make sure that the dataset is available.


In [None]:
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# torchinfo provides the `summary` function imported below
%pip install torchinfo

In [None]:
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
from torchinfo import summary
import os, glob, zipfile

In [None]:
# Pick the compute device: a CUDA GPU if one is available (e.g. on Kaggle or Linux),
# the MPS backend on Apple Silicon, and the CPU otherwise
import os, platform

torch_device = "cpu"

if torch.cuda.is_available():
    torch_device = 'cuda'
elif platform.system() == 'Darwin' and torch.backends.mps.is_available():
    torch_device = 'mps'

In [None]:
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

In [None]:
torch_device

In [None]:
if not os.path.exists('data/kagglecatsanddogs_5340.zip'):
    !wget -P data https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip

In [None]:
from PIL import Image

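def check_image(fn):
    # Return True if PIL can open and verify the file, False if it is corrupt
    try:
        im = Image.open(fn)
        im.verify()
        return True
    except Exception:
        return False

def check_image_dir(path):
    # Delete every file matching the glob pattern `path` that fails verification
    for fn in glob.glob(path):
        if not check_image(fn):
            print("Corrupt image: {}".format(fn))
            os.remove(fn)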

def common_transform():
    # torchvision.transforms.Normalize normalizes each channel of a tensor image with the given mean and standard deviation.
    std_normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])
    # torchvision.transforms.Compose chains several transforms into a single preprocessing pipeline.
    trans = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256), # resize so that the shorter side is 256 pixels (aspect ratio is preserved)
        torchvision.transforms.CenterCrop(224), # crop the central 224x224 region
        torchvision.transforms.ToTensor(), # convert the image to a tensor with pixel values in the range [0, 1]
        std_normalize])
    return trans

def load_cats_dogs_dataset():
    if not os.path.exists('data/PetImages'):
        with zipfile.ZipFile('data/kagglecatsanddogs_5340.zip', 'r') as zip_ref:
            zip_ref.extractall('data')
    
    check_image_dir('data/PetImages/Cat/*.jpg')
    check_image_dir('data/PetImages/Dog/*.jpg')

    dataset = torchvision.datasets.ImageFolder('data/PetImages', transform=common_transform())
    trainset, testset = torch.utils.data.random_split(dataset, [20000, len(dataset) - 20000])
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2) # num_workers: how many subprocesses to use for data loading
    testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)
    return dataset, trainloader, testloader 
    
    

In [None]:
dataset, trainloader, testloader = load_cats_dogs_dataset()
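
To make the preprocessing in `common_transform()` more concrete, here is a minimal pure-Python sketch of what `ToTensor()` followed by `Normalize()` does to a single pixel. The helper name `normalize_pixel` and the two sample pixels are hypothetical and only serve as an illustration; the mean and standard deviation values are the ones used above.

```python
# Per-channel statistics copied from common_transform() above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb_uint8):
    """Mimic ToTensor() followed by Normalize() for one RGB pixel (hypothetical helper)."""
    scaled = [c / 255.0 for c in rgb_uint8]                      # ToTensor: [0, 255] -> [0, 1]
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]   # Normalize: (x - mean) / std

# Illustrative pixels: pure black and pure white
black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

After normalization the values are no longer confined to [0, 1]: dark pixels become negative and bright pixels positive, which matches the input range the pre-trained network was trained on.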

## MobileNet

In the previous notebook, we have seen the [**ResNet** architecture](https://www.kaggle.com/code/aisuko/pre-trained-models-and-transfer-learning) for image classification. A more lightweight analog of ResNet is **MobileNet**, which uses so-called *Inverted Residual Blocks*. Let's load a pre-trained MobileNet and see how it works:

In [None]:
print(torch.__version__)
# https://pytorch.org/hub/pytorch_vision_mobilenet_v2/
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
model.eval()
print(model)
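
The inverted residual blocks in the printed architecture are built around depthwise separable convolutions: a 3x3 convolution that filters each channel separately, combined with 1x1 convolutions that mix channels. The sketch below is a rough weight count (biases and batch normalization ignored) showing why this factorization is cheaper than a standard convolution. The function names and the channel width of 192 are assumptions chosen for illustration, not values read from the model.

```python
def standard_conv_params(c_in, c_out, k=3):
    """Weights in a standard k x k convolution (bias ignored)."""
    return k * k * c_in * c_out

def depthwise_separable_params(c_in, c_out, k=3):
    """Weights in a k x k depthwise conv followed by a 1 x 1 pointwise conv."""
    depthwise = k * k * c_in   # one k x k filter per input channel
    pointwise = c_in * c_out   # 1 x 1 convolution that mixes channels
    return depthwise + pointwise

c_in, c_out = 192, 192  # hypothetical channel widths
std = standard_conv_params(c_in, c_out)
sep = depthwise_separable_params(c_in, c_out)
print(std, sep, round(std / sep, 1))
```

For these widths the separable version needs roughly 8.6 times fewer weights, which is where most of MobileNet's savings come from.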

### Apply the model to the dataset and visualize the results

In [None]:
sample_image = dataset[0][0].unsqueeze(0) # add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
with torch.no_grad(): # inference only, so gradients are not needed
    res = model(sample_image) # apply the model to the sample image
print(res[0].argmax()) # index of the highest-scoring ImageNet class (the outputs are raw scores, not probabilities)
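
At this point the model still has its original 1000-class ImageNet classifier, so the printed index refers to an ImageNet class rather than to "cat" or "dog". The sketch below maps such an index to a coarse label. It assumes the standard ImageNet-1k class ordering used by the torchvision weights, in which indices 151 to 268 are dog breeds and 281 to 285 are domestic cats; both ranges and the helper name `coarse_label` are assumptions for illustration and are worth checking against the official class list.

```python
# Assumed index ranges in the standard ImageNet-1k class ordering
DOG_RANGE = range(151, 269)   # dog breeds
CAT_RANGE = range(281, 286)   # domestic cats

def coarse_label(imagenet_index):
    """Map an ImageNet class index to 'cat', 'dog' or 'other' (hypothetical helper)."""
    idx = int(imagenet_index)  # int() also accepts a 0-dim tensor such as res[0].argmax()
    if idx in DOG_RANGE:
        return "dog"
    if idx in CAT_RANGE:
        return "cat"
    return "other"

print(coarse_label(281), coarse_label(207), coarse_label(0))
```

In the notebook this could be called as `coarse_label(res[0].argmax())`.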

## Using MobileNet for transfer learning

Now let's perform the same transfer learning process as in the previous notebook, but using MobileNet as the base model.

### Freeze all parameters of the model

In [None]:
for x in model.parameters():
    x.requires_grad = False

### Replace the final classifier

We also transfer the model to our default training device (GPU or CPU).

In [None]:
model.classifier = nn.Linear(1280,2)  # change the last layer to a linear layer with 2 outputs
model = model.to(torch_device)
summary(model, input_size=(1, 3, 224, 224))
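
Because all pre-trained parameters were frozen earlier and the new `nn.Linear` layer is created with trainable parameters by default, only the new classifier is updated during training. A quick count shows how small that trainable part is compared with the 1000-class ImageNet head it replaces. The helper `linear_params` is a hypothetical name used only for this sketch.

```python
def linear_params(in_features, out_features, bias=True):
    """Number of parameters in a fully connected layer: weights plus optional biases."""
    return in_features * out_features + (out_features if bias else 0)

original_head = linear_params(1280, 1000)  # ImageNet classifier that was replaced
new_head = linear_params(1280, 2)          # cat/dog classifier defined above
print(original_head, new_head)
```

The trainable parameter count reported by `summary` should correspond to the second number.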

### Doing the actual training

In [None]:
# The network outputs raw scores (logits), so CrossEntropyLoss is the matching default loss;
# NLLLoss would expect log-probabilities instead
def validate(net, dataloader, loss_fn=nn.CrossEntropyLoss()):
    net.eval() # evaluation mode: dropout is disabled and batch norm uses its running statistics
    count,acc,loss =0,0,0
    with torch.no_grad(): # deactivate autograd to save memory and speed up computations
        for features, labels in dataloader:
            features,labels = features.to(torch_device), labels.to(torch_device)
            out=net(features) # forward pass of the mini-batch through the network to obtain the outputs
            loss += loss_fn(out,labels) # accumulate the mean loss of each mini-batch
            preds=torch.max(out,dim=1)[1] # predicted class = index of the largest output
            acc+=(preds==labels).sum() # accumulate the correct predictions
            count+=len(labels) # accumulate the total number of examples
    return loss.item()/count, acc.item()/count # return the accumulated loss and the accuracy, both divided by the number of examples

def train_long(net, train_loader, test_loader, epochs=5, lr=0.01, optimizer=None, loss_fn=nn.CrossEntropyLoss(), print_freq=10):
    optimizer = optimizer or torch.optim.Adam(net.parameters(), lr=lr) # use Adam optimizer if not provided
    for epoch in range(epochs):
        net.train() # training mode: enables dropout and batch norm statistics updates
        total_loss,acc,count =0,0,0
        for i, (features, labels) in enumerate(train_loader):
            lbls = labels.to(torch_device)
            optimizer.zero_grad() # reset the gradients to zero before each batch to avoid accumulation
            out=net(features.to(torch_device)) # forward pass of the mini-batch through the network to obtain the outputs
            loss = loss_fn(out, lbls) # compute the loss
            loss.backward() # compute the gradients of the loss with respect to the trainable parameters
            optimizer.step() # update the parameters of the network using the gradients to minimize the loss
            total_loss+=loss.detach() # accumulate the loss for inspection, detached from the autograd graph
            _,preds=torch.max(out,dim=1) # predicted class = index of the largest output
            acc+=(preds==lbls).sum() # accumulate the correct predictions
            count+=len(lbls) # accumulate the total number of examples
            if i%print_freq==0:
                print(f'Epoch {epoch}, iter {i}, loss={total_loss.item()/count:.3f}, acc={acc.item()/count:.3f}')
        vl, va = validate(net, test_loader, loss_fn=loss_fn)
        print(f'Epoch {epoch}, val_loss={vl:.3f}, val_acc={va:.3f}')

train_long(model, trainloader, testloader, loss_fn=torch.nn.CrossEntropyLoss(),epochs=1, print_freq=90)
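
The training call passes `torch.nn.CrossEntropyLoss()` because the new classifier outputs raw scores (logits) rather than log-probabilities. Cross-entropy on logits is equivalent to a log-softmax followed by the negative log-likelihood that `nn.NLLLoss` computes. The pure-Python sketch below illustrates this relationship for a single example; the function names and the logit values are hypothetical, and class 0 is assumed to be "Cat" and class 1 "Dog", following the alphabetical folder order used by `ImageFolder`.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def nll(log_probs, target):
    """Negative log-likelihood of the target class, given log-probabilities."""
    return -log_probs[target]

def cross_entropy(logits, target):
    """Cross-entropy on raw logits = log-softmax followed by NLL."""
    return nll(log_softmax(logits), target)

logits = [2.0, 0.5]  # hypothetical scores for (cat, dog)
loss_if_cat = cross_entropy(logits, 0)
loss_if_dog = cross_entropy(logits, 1)
print(round(loss_if_cat, 4), round(loss_if_dog, 4))
```

The loss is small when the true class has the largest score and grows when the network favours the wrong class.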

## Summary

Notice that MobileNet achieves almost the same accuracy as VGG-16, and only slightly lower accuracy than the full-scale ResNet.

The main advantage of small models, such as MobileNet or ResNet-18, is that they can be used on mobile devices.