# Session 6: Advanced PyTorch: DataLoaders and Pre-trained Models

**Objective:** To learn the standard workflow for handling datasets and leverage the power of pre-trained models for advanced tasks.

## Part 1: Concepts

### 1. The Problem with Manual Batching
In our last session, we fed the entire dataset to the model at once. This is not feasible for large datasets that don't fit in memory. The solution is to process data in small batches. PyTorch provides elegant tools for this: `Dataset` and `DataLoader`.

### 2. `Dataset` and `DataLoader`
- **`torch.utils.data.Dataset`:** An abstract class that represents a dataset. You typically create a custom class that inherits from it and implements two key methods:
    - `__len__(self)`: Returns the total number of samples in the dataset.
    - `__getitem__(self, idx)`: Returns the sample (e.g., an image and its label) at a given index `idx`.
- **`torch.utils.data.DataLoader`:** A data loader that wraps a `Dataset` and provides an iterable over it. It handles batching, shuffling, and, with `num_workers > 0`, loading batches in parallel worker processes.

```python
# What does "Inherit" mean?
# Parent class
class Animal:
    def __init__(self, name):
        self.name = name

    def eat(self):
        return f"{self.name} is eating."

# Child class inherits from Animal
class Dog(Animal):
    def bark(self):
        return f"{self.name} says Woof!"

# Create an object of the Dog class
my_dog = Dog("Rex")

# Call method from the parent class (Animal)
print(my_dog.eat())  # Output: Rex is eating.

# Call method from the child class (Dog)
print(my_dog.bark()) # Output: Rex says Woof!
```
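The `Dog` example only adds a new method to the child class. `Dataset` subclasses rely on a second feature of inheritance, *overriding*: the child defines a method with the same name as one in the parent, and the child's version is used instead. The sketch below (with a hypothetical `speak` method and `Cat` class) mimics the pattern `Dataset.__getitem__` follows: the parent's version only raises `NotImplementedError`, so each subclass is expected to supply its own.

```python
# Parent class whose method is a placeholder meant to be overridden
class Animal:
    def __init__(self, name):
        self.name = name

    def speak(self):
        raise NotImplementedError("Subclasses should implement speak().")

# Child class overrides speak() with a real implementation
class Cat(Animal):
    def speak(self):
        return f"{self.name} says Meow!"

print(Cat("Tom").speak())  # Output: Tom says Meow!

# Calling the parent's placeholder directly raises an error
try:
    Animal("Generic").speak()
except NotImplementedError as err:
    print(err)  # Output: Subclasses should implement speak().
```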

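#### What does a custom class inherit from `Dataset`?

Surprisingly little. The base `Dataset` class defines `__getitem__` only as a placeholder that raises `NotImplementedError`, so every subclass must override it. `__len__` is not implemented in the base class at all, but subclasses are expected to provide it, because `DataLoader`'s default sampling needs to know the dataset size. The base class does contribute one utility: `__add__`, so `dataset_a + dataset_b` returns a `ConcatDataset` that chains the two datasets end to end.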
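A small sketch of the `+` operator in action. `RangeDataset` is a hypothetical toy class written for this illustration; it holds a run of integers so the concatenated indices are easy to follow.

```python
import torch
from torch.utils.data import Dataset, ConcatDataset

# Hypothetical toy dataset holding the integers start, start+1, ..., stop-1
class RangeDataset(Dataset):
    def __init__(self, start, stop):
        self.values = torch.arange(start, stop)

    def __len__(self):
        return len(self.values)

    def __getitem__(self, idx):
        return self.values[idx]

a = RangeDataset(0, 3)    # 0, 1, 2
b = RangeDataset(10, 15)  # 10, 11, 12, 13, 14

# Dataset.__add__ wraps both datasets in a ConcatDataset
combined = a + b
print(type(combined).__name__)  # ConcatDataset
print(len(combined))            # 8
print(combined[3])              # tensor(10): index 3 is the first item of b
```

`ConcatDataset` calls `len()` on each part to map a global index to the right dataset, which is one more reason to always implement `__len__`.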

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Create a simple custom dataset
class MyCustomDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        # Total number of samples
        return len(self.features)

    def __getitem__(self, idx):
        # One (features, label) pair
        return self.features[idx], self.labels[idx]

# Sample data
X = torch.randn(100, 10)  # 100 samples, 10 features each
y = torch.randint(0, 2, (100,))  # 100 labels (0 or 1)

# Instantiate the dataset and dataloader
dataset = MyCustomDataset(X, y)
data_loader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

# Fetch only the first batch to inspect its shape
first_batch_features, first_batch_labels = next(iter(data_loader))
print(f"Feature batch shape: {first_batch_features.shape}")  # torch.Size([16, 10])
print(f"Label batch shape: {first_batch_labels.shape}")      # torch.Size([16])
```
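Because 100 is not a multiple of 16, a loader over these 100 samples produces six full batches and one final batch of 4. The sketch below iterates over a whole epoch to show this, then uses `drop_last=True` to discard the incomplete batch instead. It also swaps in `torch.utils.data.TensorDataset`, a built-in class that does the same job as `MyCustomDataset` for tensors sharing a first dimension.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

X = torch.randn(100, 10)          # 100 samples, 10 features each
y = torch.randint(0, 2, (100,))   # 100 binary labels

# TensorDataset indexes each tensor along its first dimension, like MyCustomDataset
dataset = TensorDataset(X, y)

loader = DataLoader(dataset, batch_size=16, shuffle=True)
batch_sizes = [features.shape[0] for features, labels in loader]
print(len(loader), batch_sizes)  # 7 [16, 16, 16, 16, 16, 16, 4]

# drop_last=True discards the final, incomplete batch
loader_drop = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)
print(len(loader_drop))  # 6
```

Dropping the last batch is mainly useful when a layer or loss behaves badly on small batches; otherwise keeping it means every sample is seen each epoch.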

### 3. GPU Acceleration with `.to(device)`
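To speed up training, move your model and data to a GPU if one is available. The model's parameters and every batch you feed it must live on the same device; mixing CPU and GPU tensors in one operation raises an error.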

```python
import torch

# 1. Check for a GPU and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# 2. Move your model to the device (example)
# model = YourModelClass()
# model.to(device)

# 3. Inside your training loop, move each batch of data to the device
# for features, labels in data_loader:
#     features = features.to(device)
#     labels = labels.to(device)
#     ...
```
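Putting the pieces together gives a minimal training loop. This is a sketch under stated assumptions: a single `nn.Linear` layer stands in for `YourModelClass`, trained with cross-entropy and plain SGD on random data, so the loss values only show that the loop runs. Note that `model.to(device)` moves a module's parameters in place, whereas `tensor.to(device)` returns a new tensor, which is why each batch is reassigned.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(0)
X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)

# Hypothetical stand-in for YourModelClass: one linear layer, 10 features -> 2 classes
model = nn.Linear(10, 2).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(3):
    total_loss = 0.0
    for features, labels in loader:
        # Move each batch to the same device as the model
        features = features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = loss_fn(model(features), labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per sample
        total_loss += loss.item() * features.shape[0]
    print(f"Epoch {epoch}: mean loss {total_loss / len(loader.dataset):.4f}")
```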

### 4. Loading Pre-trained Models (Transfer Learning)
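Why train a huge model from scratch when you can start from one that experts have already trained on a massive dataset? Reusing a model trained on one task as the starting point for a new task is called **transfer learning**. `torchvision.models` provides many well-known architectures together with pre-trained weights.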

```python
import torch
import torchvision.models as models

# Load a pre-trained ResNet-18 model.
# The weights argument downloads parameters trained on the ImageNet dataset
# (torchvision >= 0.13; older releases used pretrained=True, now deprecated).
weights = models.ResNet18_Weights.DEFAULT
resnet18 = models.resnet18(weights=weights)

# Set the model to evaluation mode (changes the behavior of BatchNorm and Dropout)
resnet18.eval()

# print(resnet18) # You can inspect all the layers!

# To use it, you need to provide an input tensor with the correct shape
# and apply the same transformations the model was trained on.
# For now, let's create a dummy image tensor.
# (Batch size, Channels, Height, Width)
dummy_image = torch.randn(1, 3, 224, 224)

# Disable gradient tracking: we are only running inference
with torch.no_grad():
    output = resnet18(dummy_image)

# The output is a tensor of shape [1, 1000] because ImageNet has 1000 classes.
print(f"Output shape: {output.shape}")
```

## Part 2: Exercises & Debugging (90 mins)

### Lab 6.1: Load and Inspect a Pre-trained Model
* **Task:** Load the pre-trained `alexnet` model from `torchvision.models`. Print the model architecture to the screen. Create a dummy input tensor of the correct size for a single image and pass it through the model to get the output.
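* **Hint:** AlexNet, like ResNet, expects a 224x224 RGB image. In torchvision 0.13 and later, select pre-trained weights with the `weights` argument (the enum for AlexNet is `models.AlexNet_Weights`) rather than the deprecated `pretrained=True`.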

```python
import torch
import torchvision.models as models

# TODO: Load pre-trained AlexNet

# TODO: Print the model

# TODO: Create a dummy input tensor (1 image, 3 channels, 224x224)

# TODO: Get the output (store it in a variable named `output`)

# The output has 1000 classes for ImageNet
print(f"\nOutput shape: {output.shape}")
print(f"Predicted class index: {torch.argmax(output, 1).item()}")
```