In [1]:
import torch
from tqdm import tqdm

# Image-related utilities
from torchvision.io import decode_image, read_image
from torchvision.transforms import ToTensor
from torchvision import transforms
from PIL import Image

# Import models
from torchvision.models import vgg19, VGG19_Weights
import torch.nn as nn
import torch.optim as optim

# Dataset
from torchvision.datasets import Imagenette, ImageFolder
from torch.utils.data import DataLoader

# Plotting utility
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
# Define device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize for VGG19
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # VGG preprocessing
])

# Read imagenette data into data loader
imagenette_train = ImageFolder(root='/home/yi/Downloads/imagenette2/train', transform=transform)
imagenette_val = ImageFolder(root='/home/yi/Downloads/imagenette2/val', transform=transform)

batch_size = 32
train_loader = DataLoader(imagenette_train, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(imagenette_val, batch_size=batch_size, shuffle=False, num_workers=4)

In [3]:
# Get number of classes
num_classes = len(imagenette_train.classes)
print(f"Number of classes: {num_classes}")

Number of classes: 10


In [4]:
model = vgg19(weights=VGG19_Weights.DEFAULT).to(device)

# Limit the last output features to 10
model.classifier[6] = nn.Linear(in_features=4096, out_features=num_classes)

# Move to device
model = model.to(device)

In [5]:
# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer (fine-tuning the whole network)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [6]:
num_epochs = 10  # Adjust as needed

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct, total = 0, 0

    loop = tqdm(train_loader, leave=True)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute accuracy
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        running_loss += loss.item()
        loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
        loop.set_postfix(loss=running_loss/len(train_loader), acc=100 * correct / total)

Epoch [1/10]: 100%|█████| 296/296 [02:40<00:00,  1.84it/s, acc=91.6, loss=0.288]
Epoch [2/10]: 100%|█████| 296/296 [02:41<00:00,  1.83it/s, acc=96.1, loss=0.139]
Epoch [3/10]: 100%|████| 296/296 [02:42<00:00,  1.82it/s, acc=97.5, loss=0.0824]
Epoch [4/10]: 100%|█████| 296/296 [02:42<00:00,  1.82it/s, acc=97.2, loss=0.102]
Epoch [5/10]: 100%|████| 296/296 [02:43<00:00,  1.81it/s, acc=97.8, loss=0.0731]
Epoch [6/10]: 100%|████| 296/296 [02:43<00:00,  1.81it/s, acc=98.5, loss=0.0524]
Epoch [7/10]: 100%|█████| 296/296 [02:43<00:00,  1.81it/s, acc=98.5, loss=0.051]
Epoch [8/10]: 100%|████| 296/296 [02:44<00:00,  1.80it/s, acc=98.7, loss=0.0454]
Epoch [9/10]: 100%|████| 296/296 [02:46<00:00,  1.78it/s, acc=97.9, loss=0.0726]
Epoch [10/10]: 100%|███| 296/296 [02:48<00:00,  1.76it/s, acc=98.9, loss=0.0308]


In [8]:
PATH = 'vgg19_imagenette.pth'
torch.save(model.state_dict(), PATH)