# Transfer Learning Example
Here is some code demonstrating how to do transfer learning on the cifar dataset using a vision transformer and a CNN (Resnet). The code takes a bit over an hour to run on a GPU in google colab so don't feel like you need to run this, but this is a fun little example of how you would apply transfer learning to some more powerful models.

In [1]:
!pip install torch torchvision transformers

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

# Pretrained ViT

In [2]:
# Import necessary libraries
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTImageProcessor, Trainer, TrainingArguments

def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define transformations for the training and testing datasets
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Initialize the feature extractor
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Define a custom dataset class to apply the feature extractor
class CustomCIFAR10Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset, feature_extractor):
        self.dataset = dataset
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # Ensure the image is in the correct format (e.g., PIL Image)
        if isinstance(image, torch.Tensor):
            # Convert the tensor to a PIL Image
            image = transforms.ToPILImage()(image)
        encoding = self.feature_extractor(images=image, return_tensors="pt")
        pixel_values = encoding['pixel_values'].squeeze()
        return {'pixel_values': pixel_values, 'labels': label}

num_labels = 10
id2label = {i: f"label_{i}" for i in range(num_labels)}
label2id = {f"label_{i}": i for i in range(num_labels)}
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)

initial_trainable_params = count_trainable_parameters(model)
print(f"Trainable parameters before freezing: {initial_trainable_params}")

# Modify the classifier to match the number of classes in CIFAR-10
in_features = model.classifier.in_features
model.classifier = torch.nn.Linear(in_features, 10)

# Freeze all layers except the classifier
for name, param in model.named_parameters():
    if 'classifier' not in name:
        param.requires_grad = False

model.to(device)

final_trainable_params = count_trainable_parameters(model)
print(f"Trainable parameters after freezing: {final_trainable_params}")

# Create DataLoader instances
train_loader = DataLoader(CustomCIFAR10Dataset(train_dataset, feature_extractor), batch_size=32, shuffle=True)
test_loader = DataLoader(CustomCIFAR10Dataset(test_dataset, feature_extractor), batch_size=32, shuffle=False)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    eval_strategy="epoch",
    run_name='vit-cifar-transfer-learning',
    report_to="none"
)

# Define a function to compute metrics
def compute_metrics(p):
    preds = p.predictions.argmax(-1)
    return {'accuracy': (preds == p.label_ids).astype(float).mean().item()}

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_loader.dataset,
    eval_dataset=test_loader.dataset,
    compute_metrics=compute_metrics
)

# Train the model
trainer.train()

# Evaluate the model
eval_results = trainer.evaluate()
print(f"Test Accuracy: {eval_results['eval_accuracy']:.4f}")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:24<00:00, 6.83MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([10, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters before freezing: 85806346
Trainable parameters after freezing: 7690


Epoch,Training Loss,Validation Loss,Accuracy
1,1.1965,1.207386,0.6102
2,0.9754,1.060054,0.6472
3,1.0442,1.031363,0.6556


Test Accuracy: 0.6556


# Pretrained CNN

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ResNetForImageClassification, AutoImageProcessor, Trainer, TrainingArguments

def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define transformations for the training and testing datasets
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Initialize the feature extractor
feature_extractor = AutoImageProcessor.from_pretrained('microsoft/resnet-50')

# Define a custom dataset class to apply the feature extractor
class CustomCIFAR10Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset, feature_extractor):
        self.dataset = dataset
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # Ensure the image is in the correct format (e.g., PIL Image)
        if isinstance(image, torch.Tensor):
            # Convert the tensor to a PIL Image
            image = transforms.ToPILImage()(image)
        encoding = self.feature_extractor(images=image, return_tensors="pt")
        pixel_values = encoding['pixel_values'].squeeze()
        return {'pixel_values': pixel_values, 'labels': label}

num_labels = 10
id2label = {i: f"label_{i}" for i in range(num_labels)}
label2id = {f"label_{i}": i for i in range(num_labels)}
model = ResNetForImageClassification.from_pretrained(
    'microsoft/resnet-50',
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)

initial_trainable_params = count_trainable_parameters(model)
print(f"Trainable parameters before freezing: {initial_trainable_params}")

# Access the classifier's input features
in_features = model.classifier[1].in_features
# in_features = 65536

# Modify the classifier to match the number of classes in CIFAR-10
model.classifier = torch.nn.Sequential(
    torch.nn.Flatten(start_dim=1, end_dim=-1),
    torch.nn.Linear(in_features, 10)
)

# Freeze all layers except the classifier
for name, param in model.named_parameters():
    if 'classifier' not in name:
        param.requires_grad = False

model.to(device)

final_trainable_params = count_trainable_parameters(model)
print(f"Trainable parameters after freezing: {final_trainable_params}")

# Create DataLoader instances
train_loader = DataLoader(CustomCIFAR10Dataset(train_dataset, feature_extractor), batch_size=32, shuffle=True)
test_loader = DataLoader(CustomCIFAR10Dataset(test_dataset, feature_extractor), batch_size=32, shuffle=False)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    eval_strategy="epoch",
    run_name='resnet-cifar-transfer-learning',
    report_to="none"
)

# Define a function to compute metrics
def compute_metrics(p):
    preds = p.predictions.argmax(-1)
    return {'accuracy': (preds == p.label_ids).astype(float).mean().item()}

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_loader.dataset,
    eval_dataset=test_loader.dataset,
    compute_metrics=compute_metrics
)

# Train the model
trainer.train()

# Evaluate the model
eval_results = trainer.evaluate()
print(f"Test Accuracy: {eval_results['eval_accuracy']:.4f}")


Files already downloaded and verified
Files already downloaded and verified


preprocessor_config.json:   0%|          | 0.00/266 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/69.6k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters before freezing: 23528522
Trainable parameters after freezing: 20490


Epoch,Training Loss,Validation Loss,Accuracy
1,1.9854,1.988141,0.5168
2,1.8418,1.839359,0.5361


# Subset data to go faster

In [None]:
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from transformers import ResNetForImageClassification, AutoImageProcessor, Trainer, TrainingArguments

def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define transformations for the training and testing datasets
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create subsets containing only the first 3200 samples
train_subset = Subset(train_dataset, indices=range(320))
test_subset = Subset(test_dataset, indices=range(64))

# Initialize the feature extractor
feature_extractor = AutoImageProcessor.from_pretrained('microsoft/resnet-50')

# Define a custom dataset class to apply the feature extractor
class CustomCIFAR10Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset, feature_extractor):
        self.dataset = dataset
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # Ensure the image is in the correct format (e.g., PIL Image)
        if isinstance(image, torch.Tensor):
            # Convert the tensor to a PIL Image
            image = transforms.ToPILImage()(image)
        encoding = self.feature_extractor(images=image, return_tensors="pt")
        pixel_values = encoding['pixel_values'].squeeze()
        return {'pixel_values': pixel_values, 'labels': label}

num_labels = 10
id2label = {i: f"label_{i}" for i in range(num_labels)}
label2id = {f"label_{i}": i for i in range(num_labels)}
model = ResNetForImageClassification.from_pretrained(
    'microsoft/resnet-50',
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)

initial_trainable_params = count_trainable_parameters(model)
print(f"Trainable parameters before freezing: {initial_trainable_params}")

# Access the classifier's input features
in_features = model.classifier[1].in_features
# in_features = 65536

# Modify the classifier to match the number of classes in CIFAR-10
model.classifier = torch.nn.Sequential(
    torch.nn.Flatten(start_dim=1, end_dim=-1),
    torch.nn.Linear(in_features, 10)
)

# Freeze all layers except the classifier
for name, param in model.named_parameters():
    if 'classifier' not in name:
        param.requires_grad = False

model.to(device)

final_trainable_params = count_trainable_parameters(model)
print(f"Trainable parameters after freezing: {final_trainable_params}")

# Create DataLoader instances for the subsets
train_loader = DataLoader(CustomCIFAR10Dataset(train_subset, feature_extractor), batch_size=32, shuffle=True)
test_loader = DataLoader(CustomCIFAR10Dataset(test_subset, feature_extractor), batch_size=32, shuffle=False)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    eval_strategy="epoch",
    run_name='resnet-cifar-transfer-learning',
    report_to="none"
)

# Define a function to compute metrics
def compute_metrics(p):
    preds = p.predictions.argmax(-1)
    return {'accuracy': (preds == p.label_ids).astype(float).mean().item()}

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_loader.dataset,
    eval_dataset=test_loader.dataset,
    compute_metrics=compute_metrics
)

# Train the model
trainer.train()

# Evaluate the model
eval_results = trainer.evaluate()
print(f"Test Accuracy: {eval_results['eval_accuracy']:.4f}")
