# 2.- Implement the training process for a DL model using an off-the-shelf PyTorch model

## Load an off-the-shelf model

PyTorch provides several state-of-the-art models, together with their pre-trained parameters, through the **torchvision** package.

In [None]:
import torchvision

In [None]:
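# Take an off-the-shelf model (Inception V3) and load its ImageNet pre-trained parameters
model = torchvision.models.inception_v3(weights=torchvision.models.Inception_V3_Weights.IMAGENET1K_V1)
# Switch to inference mode (e.g. BatchNorm uses its running statistics, dropout is disabled)
model.eval()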

To train a model from scratch, let's download some data and use it as our *training set*.

In [None]:
#@title Install FiftyOne package to use open source image datasets
!pip install fiftyone 
import fiftyone as fo
import fiftyone.zoo as foz

We'll train a DL model to identify cats in images. For that, we gather some images that contain cats and other images that do not.
Seeing both kinds of examples makes the model more robust when deciding whether an image contains a cat.

In [None]:
#@title Download some images from the Fiftyone dataset
dataset = foz.load_zoo_dataset(
    "open-images-v7",
    split="train",
    label_types=["classifications"],
    classes = ["Cat"],
    max_samples=1000,
    dataset_dir="sample_data",
    download_if_necessary=True
)

To load the downloaded images into Python, we use a PyTorch **Dataset** (`torch.utils.data.Dataset`). This object loads each image together with its corresponding label, so we can use them to feed our model.
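A map-style PyTorch `Dataset` only has to implement `__len__` (how many samples there are) and `__getitem__` (return the sample at a given index). A `DataLoader` then uses those two methods to assemble shuffled mini-batches. The sketch below shows that contract with a hypothetical in-memory `ToyDataset` (not part of this notebook) whose "images" are small constant tensors.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset: item i is a 3x4x4 tensor filled with i, labelled i % 2."""

    def __init__(self, n_items):
        self.labels = [float(i % 2) for i in range(n_items)]

    def __len__(self):
        # The DataLoader uses this to know which indices it can request
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one (input, target) pair, like FiftyOneTorchDataset does
        img = torch.full((3, 4, 4), float(idx))
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img, target


ds = ToyDataset(10)
loader = DataLoader(ds, batch_size=4, shuffle=False)

# The default collate function stacks the individual samples into batch tensors
x, y = next(iter(loader))
print(len(ds), len(loader), x.shape, y.tolist())
```

With 10 items and `batch_size=4`, the loader yields three batches (4, 4 and 2 samples).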

In [None]:
#@title This is a wrapper that allows us to use FiftyOne datasets with PyTorch
import matplotlib.pyplot as plt
import torch
from PIL import Image


class FiftyOneTorchDataset(torch.utils.data.Dataset):
    """A class to construct a PyTorch dataset from a FiftyOne dataset.

    Args:
        fiftyone_dataset: a FiftyOne dataset or view that will be used for training or testing
        transforms (None): a PyTorch transform (e.g. a torchvision.transforms.Compose) applied
            to each image when it is loaded
        classes (None): a list of class names that count as positive. A sample gets the target
            1.0 if any of its positive labels is in this list, and 0.0 otherwise. If None, it
            will use all classes present in the given fiftyone_dataset.
    """

    def __init__(
        self,
        fiftyone_dataset,
        transforms=None,
        classes=None,
    ):
        self.samples = fiftyone_dataset
        self.transforms = transforms
        self.img_paths = self.samples.values("filepath")

        if classes is None:
            classes = self.samples.distinct("positive_labels.classifications.label")
        self.classes = classes

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        sample = self.samples[img_path]
        img = Image.open(img_path).convert("RGB")

        # Binary target: 1.0 if any positive label of the image is one of `classes`.
        # Some images have no positive labels at all, in which case the field is None.
        positives = sample["positive_labels"]
        label = positives is not None and any(
            lab["label"] in self.classes for lab in positives["classifications"]
        )
        target = torch.as_tensor(label, dtype=torch.float32)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.img_paths)

    def get_classes(self):
        return self.classes

## Explore the dataset

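Each image can have an arbitrary shape, so we resize all of them to a standard shape of 299x299 pixels, which is the input size *Inception V3* expects. Resizing to a fixed (height, width) pair stretches images whose aspect ratio is not 1:1. We then convert each image to a tensor with values in [0, 1] and normalize every color channel with the ImageNet mean and standard deviation, the statistics used to pre-train the torchvision models.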
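`Normalize(mean, std)` maps each channel value `v` to `(v - mean) / std`, and the inverse `v * std + mean` recovers the original [0, 1] value, which is handy when displaying a normalized image. The plain-Python sketch below applies that arithmetic to a single RGB pixel using the same ImageNet statistics as `preprocess_fun`. The helper names are hypothetical and only illustrate the formula.

```python
# ImageNet channel statistics (R, G, B), the same values passed to Normalize above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Per-channel (value - mean) / std, the operation Normalize applies to every pixel."""
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))


def denormalize_pixel(rgb):
    """Inverse mapping back to the [0, 1] range produced by ToTensor."""
    return tuple(v * s + m for v, m, s in zip(rgb, MEAN, STD))


white = (1.0, 1.0, 1.0)  # ToTensor maps an 8-bit value of 255 to 1.0
z = normalize_pixel(white)
print([round(c, 3) for c in z])
print([round(c, 3) for c in denormalize_pixel(z)])
```

Normalized values are no longer in [0, 1], which is why an image must be rescaled (or de-normalized) before `plt.imshow` can display it correctly.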

In [None]:
preprocess_fun = torchvision.transforms.Compose([
    torchvision.transforms.Resize((299, 299)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


In [None]:
pt_dataset = FiftyOneTorchDataset(dataset,
                                  classes=["Cat"],
                                  transforms=preprocess_fun
                                  )

In [None]:
x, l = pt_dataset[100]
x.shape

In [None]:
#@title Show one example image
im = (x - x.min()) / (x.max() - x.min())
plt.imshow(im.permute(1, 2, 0))

Pass this image through the pre-trained Inception V3 model and look up the index of its highest-scoring output in the ImageNet class map here: https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/



---



## Train an off-the-shelf model from scratch

In this exercise we'll train the Inception V3 model to identify cats.
This is a binary classification task: we only need the model to predict 1 if a cat is present in an image and 0 otherwise, so a single output is enough.
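With a single output, the model produces one raw score (a logit) per image. A common decision rule, assumed here rather than prescribed by the exercise, is to predict "cat" when the sigmoid of that logit exceeds 0.5, which is equivalent to the logit being positive. The plain-Python sketch below applies that rule and computes accuracy on made-up logits and labels.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function: maps a logit to a probability in (0, 1)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)  # z < 0, so this cannot overflow
    return ez / (1.0 + ez)


def predict_cat(logit, threshold=0.5):
    """1 (cat) if P(cat) = sigmoid(logit) exceeds the threshold, else 0."""
    return 1 if sigmoid(logit) > threshold else 0


def accuracy(logits, labels):
    """Fraction of images whose thresholded prediction matches the label."""
    preds = [predict_cat(z) for z in logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


logits = [2.3, -0.7, 0.1, -4.0]  # made-up raw model outputs
labels = [1, 0, 0, 0]            # made-up ground truth
print([predict_cat(z) for z in logits], accuracy(logits, labels))
```

Because sigmoid is monotonic with sigmoid(0) = 0.5, comparing the logit against 0 gives the same predictions without computing the sigmoid at all.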

In [None]:
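# Load an off-the-shelf model (Inception V3) without pre-trained parameters,
# with a single output: the cat/no-cat logit
model = torchvision.models.inception_v3(weights=None, num_classes=1, init_weights=True)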

This time we use the *Binary Cross Entropy (BCE)* loss function because there are only two possible outcomes: 0 (no cat) and 1 (cat).

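In PyTorch this loss is implemented as **nn.BCEWithLogitsLoss**. The *WithLogits* part of the name means that the loss expects the raw, unbounded outputs of the model (logits) and internally applies a sigmoid to map them to a probability in [0, 1] before computing the BCE, in a numerically stable way. Therefore, the model itself should not apply a sigmoid to its output.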
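For a logit `z` and a target `y` in {0, 1}, BCE on the sigmoid output is `-[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]`. Substituting the definition of the sigmoid gives the algebraically identical form `max(z, 0) - z*y + log(1 + exp(-|z|))`, which never evaluates `log(0)` or overflows `exp`. The plain-Python sketch below compares both forms. It illustrates the kind of stabilization a fused "with logits" loss relies on; PyTorch's actual kernel may be implemented differently.

```python
import math


def bce_naive(z, y):
    """BCE computed on sigmoid(z) directly; breaks when sigmoid(z) rounds to 0 or 1."""
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def bce_with_logits(z, y):
    """Same loss rewritten as max(z, 0) - z*y + log(1 + exp(-|z|)); stable for any z."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


print(bce_naive(0.0, 1), bce_with_logits(0.0, 1))  # both equal log(2)
print(bce_with_logits(100.0, 0))                   # large but finite

try:
    bce_naive(100.0, 0)
except ValueError as err:
    # 1 - sigmoid(100) rounds to 0.0 in float64, so math.log(0.0) fails
    print("naive BCE failed:", err)
```

The second form is why the model should output raw logits: a very confident wrong prediction yields a large but finite loss and a usable gradient instead of an error.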

In [None]:
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader

In [None]:
criterion = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)

trn_queue = DataLoader(pt_dataset, batch_size=16, shuffle=True, pin_memory=True)

In [None]:
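# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the selected device and put it in training mode
model.to(device)
model.train()

for e in range(10):
  for i, (x, y) in enumerate(trn_queue):
    # Empty the accumulated gradients from any previous iteration
    optimizer.zero_grad()

    # Move the input images and their respective labels to the same device as the model
    x = x.to(device)
    y = y.to(device)

    # In training mode Inception V3 returns a named tuple with `logits` and `aux_logits`
    y_hat = model(x)

    # Compute the error/loss function; the logits have shape (batch, 1), so reshape the targets
    loss = criterion(y_hat.logits, y.view(-1, 1))

    # Perform the backward pass to compute the gradients of the loss with respect to the model parameters
    loss.backward()

    # Update the model parameters
    optimizer.step()

    # Log the progress of the model
    if i % 10 == 0:
      # A positive logit means sigmoid(logit) > 0.5, i.e. "cat"
      preds = (y_hat.logits.detach().view(-1) > 0).float()
      acc = (preds == y).float().mean()

      print(f"Epoch {e}, step {i}: loss={loss.item():.4f}, acc={acc.item():.3f}")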