<a href="https://colab.research.google.com/github/arkeodev/pytorch-tutorial/blob/main/Transfer_Learning/transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

*"Standing on the shoulders of giants" (Isaac Newton)*

# Unlocking the Power of Pretrained Models with PyTorch Transfer Learning

In machine learning and deep learning, building a model from scratch to solve a complex problem can be a daunting task. It requires substantial data, significant computational resources, and considerable time. But what if you could stand on the shoulders of giants and reuse what has already been learned? This is where transfer learning comes in, offering a shortcut to advanced deep learning capabilities.


## Introduction to Transfer Learning

Transfer learning is akin to learning a new skill with the knowledge you already possess. It involves taking a model trained on one task and repurposing it for another related task. This method is not only efficient but also reduces the need for a large dataset, which is often a bottleneck in machine learning projects.


### Why Transfer Learning?

- **Efficiency**: Jumpstart your project by leveraging existing models that have already learned robust features from large datasets.
- **Performance**: Pretrained models often yield surprisingly good results, even with a relatively small amount of data.

## Pre-trained Model Sources

Pretrained models can be found in the following places:

| Location | Link(s) |
|----------|---------|
| PyTorch domain libraries | [torchvision.models](https://pytorch.org/vision/stable/models.html), [torchtext.models](https://pytorch.org/text/stable/models.html), [torchaudio.models](https://pytorch.org/audio/stable/models.html), [torchrec.models](https://pytorch.org/torchrec/stable/models.html) |
| Hugging Face Hub | [https://huggingface.co/models](https://huggingface.co/models), [https://huggingface.co/datasets](https://huggingface.co/datasets) |
| timm (PyTorch Image Models) library | [https://github.com/rwightman/pytorch-image-models](https://github.com/rwightman/pytorch-image-models) |
| Papers with Code | [https://paperswithcode.com/](https://paperswithcode.com/) |

## Transfer Learning with PyTorch's Updated API



PyTorch's TorchVision library provides a wide range of models for tasks including image classification, segmentation, and object detection. However, the older API, built around a boolean `pretrained` parameter, had several limitations.

### Limitations of the Legacy API

While the legacy `pretrained=True` interface of `torchvision.models` offered a reasonable starting point for transfer learning, it came with limitations:

- **Limited Pre-trained Weight Options**: The binary nature of the `pretrained` parameter restricts models to a single set of weights, impeding the adoption of improved or alternative pre-trained weights.
- **Manual Inference Transform Definition**: Users must manually specify the preprocessing steps necessary for model inference, which can be error-prone and reduce model accuracy if done incorrectly.
- **Lack of Metadata**: Essential information regarding the weights, such as category labels and training recipes, is not readily accessible, complicating model utilization and experimentation.

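### Introducing the Multi-Weight API

TorchVision addresses these challenges with a multi-weight API that improves usability and flexibility. It was introduced as a prototype under `torchvision.prototype.models` and has been part of the stable `torchvision.models` namespace since TorchVision 0.13. Let's explore the improvements it offers.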

### Step-by-Step Implementation with the New API

```python
import torch
from PIL import Image

from torchvision import models as PM

# Load the image. This path is relative to a checkout of the torchvision
# repository; replace it with the path to any local image, otherwise
# Image.open raises FileNotFoundError.
img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with pre-defined weights
weights = PM.ResNet50_Weights.IMAGENET1K_V1
model = PM.resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms and add a batch dimension
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
with torch.no_grad():  # gradients are not needed for inference
    prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")
```

This approach simplifies model usage while addressing previous limitations.

### Features of the New API

- **Multi-Weight Support**: By associating each model with an Enum class (e.g., `ResNet50_Weights`), the API now supports multiple sets of pre-trained weights.
  
- **Integrated Metadata and Preprocessing Transforms**: Each set of weights is linked with its corresponding metadata and preprocessing transforms, streamlining the model inference process.

### Utilizing Different Weights

Here's how to use the API to select from different available weights:

```python
# The multi-weight API lives in torchvision.models since TorchVision 0.13
from torchvision.models import resnet50, ResNet50_Weights

# Using different sets of pre-trained weights
model_a = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)  # Old weights
model_b = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)  # Improved weights
model_c = resnet50(weights=ResNet50_Weights.DEFAULT)        # Best available weights
model_d = resnet50(weights=None)                            # No weights, random initialization
```

### Accessing Metadata and Preprocessing Transforms

The API makes it easy to access essential metadata and initialize the necessary preprocessing transforms for your data:

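```python
from torchvision.models import ResNet50_Weights

# Accessing metadata (available keys vary by TorchVision version;
# "categories" is the key used in the inference example)
categories = ResNet50_Weights.IMAGENET1K_V2.meta["categories"]

# Initializing and applying preprocessing transforms
# (`img` is a PIL image, such as the one loaded in the inference example)
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()
img_preprocessed = preprocess(img)
```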

### Retrieving Weights by Name

When only the name of the weights is known, for example from a configuration file, the API provides a convenient retrieval function (available from `torchvision.models` in TorchVision 0.14 and later):

```python
from torchvision.models import get_weight

# Retrieving weights by name
weights = get_weight("ResNet50_Weights.IMAGENET1K_V1")
```

## Implementing Transfer Learning with PyTorch

## Conclusion



Transfer learning opens up a realm of possibilities in deep learning projects. By utilizing pretrained models, you can accelerate development, conserve resources, and achieve remarkable results, even with limited data. Explore the myriad of available models and find the perfect fit for your next project with PyTorch.