<center><a href="https://www.nvidia.com/dli"> <img src="images/DLI_Header.png" alt="Header" style="width: 400px;"/> </a></center>

In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as transforms
import torchvision.io as tv_io

import glob
from PIL import Image

import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available()

## The Dataset

<img src="./images/fruits.png" style="width: 600px;">

## Load ImageNet Base Model

In [None]:
from torchvision.models import vgg16
from torchvision.models import VGG16_Weights

weights = VGG16_Weights.FIXME
vgg_model = vgg16(weights=weights)

## Freeze Base Model

In [None]:
# Freeze base model
vgg_model.requires_grad_(FIXME)
next(iter(vgg_model.parameters())).requires_grad

## Add Layers to Model

In [None]:
vgg_model.classifier[0:3]

Once we've taken what we need from VGG16, we can add our own layers. Whatever modules we add, the model must still end with one output value for each class.
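
As a rough illustration of "one output value per class", the sketch below builds a small stand-alone head and checks the shape of its output. The class count `N_EXAMPLE_CLASSES` and the random input batch are hypothetical stand-ins chosen only for this example; they are not the values to use in the assessment.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
N_FEATURES = 4096        # width of the feature vector fed to the new layers
N_EXAMPLE_CLASSES = 3    # made-up class count, not the assessment's value

head = nn.Sequential(
    nn.Linear(N_FEATURES, 500),
    nn.ReLU(),
    nn.Linear(500, N_EXAMPLE_CLASSES),  # final layer: one logit per class
)

batch = torch.randn(8, N_FEATURES)   # stand-in for 8 feature vectors
logits = head(batch)
print(logits.shape)                  # (batch size, number of classes)

# The predicted class for each sample is the index of its largest logit.
predicted = logits.argmax(dim=1)
```

The last `nn.Linear` layer is what fixes the number of outputs, so its second argument has to match the number of labels in the dataset.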

In [None]:
N_CLASSES = FIXME

my_model = nn.Sequential(
    vgg_model.features,
    vgg_model.avgpool,
    nn.Flatten(),
    vgg_model.classifier[0:3],
    nn.Linear(4096, 500),
    nn.ReLU(),
    nn.Linear(500, N_CLASSES)
)
my_model

## Compile Model

In [None]:
loss_function = nn.FIXME()
optimizer = Adam(my_model.parameters())
my_model = torch.compile(my_model.to(device))

## Data Transforms

To preprocess our input images, we will use the transforms included with the VGG16 weights.

In [None]:
pre_trans = weights.transforms()

In [None]:
IMG_WIDTH, IMG_HEIGHT = (224, 224)

random_trans = transforms.Compose([
    FIXME
])

## Load Dataset

Now it's time to load the train and validation datasets.

In [None]:
DATA_LABELS = ["freshapples", "freshbanana", "freshoranges", "rottenapples", "rottenbanana", "rottenoranges"]

class MyDataset(Dataset):
    def __init__(self, data_dir):
        self.imgs = []
        self.labels = []

        for l_idx, label in enumerate(DATA_LABELS):
            # The pattern has no '**', so a recursive search is not needed here
            data_paths = glob.glob(data_dir + label + '/*.png')
            for path in data_paths:
                img = tv_io.read_image(path, tv_io.ImageReadMode.RGB)
                self.imgs.append(pre_trans(img).to(device))
                self.labels.append(torch.tensor(l_idx).to(device))


    def __getitem__(self, idx):
        img = self.imgs[idx]
        label = self.labels[idx]
        return img, label

    def __len__(self):
        return len(self.imgs)

Select the batch size `n`, and set `shuffle` to `True` or `False` depending on whether the loader is used for training (`train`) or validation (`valid`).
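
To see what these two `DataLoader` arguments do, here is a minimal sketch on a toy dataset (ten integers, a made-up stand-in for the fruit images). It only demonstrates the behavior of `batch_size` and `shuffle`; which values suit each loader is left to the exercise.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 10 samples whose value equals their index.
toy_data = TensorDataset(torch.arange(10))

# shuffle=False keeps the dataset order; the last batch may be smaller.
ordered_loader = DataLoader(toy_data, batch_size=4, shuffle=False)
ordered = [batch[0].tolist() for batch in ordered_loader]
print(ordered)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# shuffle=True draws a new random order each time the loader is iterated.
shuffled_loader = DataLoader(toy_data, batch_size=4, shuffle=True)
shuffled = [batch[0].tolist() for batch in shuffled_loader]
print(shuffled)
```

Either way, every sample is visited exactly once per pass; only the order and grouping change.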

In [None]:
n = FIXME

train_path = "data/fruits/train/"
train_data = MyDataset(train_path)
train_loader = DataLoader(train_data, batch_size=n, shuffle=FIXME)
train_N = len(train_loader.dataset)

valid_path = "data/fruits/valid/"
valid_data = MyDataset(valid_path)
valid_loader = DataLoader(valid_data, batch_size=n, shuffle=FIXME)
valid_N = len(valid_loader.dataset)

## Train the Model

Time to train the model! We've moved the `train` and `validate` functions to our [utils.py](./utils.py) file. Before running the cell below, make sure all your variables are correctly defined.

It may help to rerun this cell or change the number of `epochs`.

In [None]:
epochs = 10

for epoch in range(epochs):
    print('Epoch: {}'.format(epoch))
    utils.train(my_model, train_loader, train_N, random_trans, optimizer, loss_function)
    utils.validate(my_model, valid_loader, valid_N, loss_function)
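
The contents of `utils.py` are not shown in this notebook. As a hedged sketch only, a validation pass with the same argument order as the `utils.validate` call could look roughly like the function below. The name `validate_sketch`, the toy model, the toy data, and the stand-in loss (`nn.MultiMarginLoss`, picked arbitrarily for the demo) are all assumptions made for illustration; the real helper may differ.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def validate_sketch(model, loader, n_samples, loss_function):
    """Hypothetical validation pass; the real utils.validate may differ."""
    model.eval()                      # switch off dropout and similar layers
    total_loss = 0.0
    correct = 0
    with torch.no_grad():             # no gradients needed for evaluation
        for x, y in loader:
            output = model(x)
            total_loss += loss_function(output, y).item()
            correct += (output.argmax(dim=1) == y).sum().item()
    accuracy = correct / n_samples
    print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(total_loss, accuracy))
    return total_loss, accuracy

# Toy stand-ins so the sketch runs on its own.
torch.manual_seed(0)
toy_x = torch.randn(12, 5)
toy_y = torch.randint(0, 3, (12,))
toy_loader = DataLoader(TensorDataset(toy_x, toy_y), batch_size=4)
toy_model = nn.Linear(5, 3)

toy_loss, toy_acc = validate_sketch(
    toy_model, toy_loader, len(toy_loader.dataset), nn.MultiMarginLoss()
)
```

Dividing the number of correct predictions by the dataset size, rather than by the number of batches, gives a per-sample accuracy that can be compared directly against a target such as the 92% mentioned in this notebook.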

## Unfreeze Model for Fine-Tuning

If you have already reached 92% validation accuracy, this next step is optional. If not, we suggest fine-tuning the model with a very low learning rate.

In [None]:
# Unfreeze the base model
vgg_model.requires_grad_(FIXME)
optimizer = Adam(my_model.parameters(), lr=.0001)

In [None]:
epochs = 1

for epoch in range(epochs):
    print('Epoch: {}'.format(epoch))
    utils.train(my_model, train_loader, train_N, random_trans, optimizer, loss_function)
    utils.validate(my_model, valid_loader, valid_N, loss_function)

## Evaluate the Model

In [None]:
utils.validate(my_model, valid_loader, valid_N, loss_function)

<img src="./images/assess_task.png" style="width: 800px;">
