In [1]:
import torch
from tqdm import tqdm

# Image-related utilities
from torchvision.io import decode_image, read_image
from torchvision.transforms import ToTensor
from torchvision import transforms
from PIL import Image

# Import models
from torchvision.models import vgg19, VGG19_Weights
from torchvision.models import vgg16, VGG16_Weights
from torchvision.models import vgg11, VGG11_Weights

import torch.nn as nn
import torch.optim as optim

# Dataset
from torchvision.datasets import Imagenette, ImageFolder
from torch.utils.data import DataLoader

# Plotting utility
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
# Define device: use the GPU when CUDA is available, otherwise the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Dataset location and loader settings, gathered here so they are edited in one place
DATA_DIR = '/home/yi/Downloads/imagewoof2'
batch_size = 32
NUM_WORKERS = 4

# Define transformations applied to every image before it reaches the model
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize for VGG19
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # VGG preprocessing
])

# Read the Imagewoof train / val splits from DATA_DIR
imagewoof_train = ImageFolder(root=f'{DATA_DIR}/train', transform=transform)
imagewoof_val = ImageFolder(root=f'{DATA_DIR}/val', transform=transform)

# Only the training split is shuffled; validation keeps a fixed order
train_loader = DataLoader(imagewoof_train, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(imagewoof_val, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS)

In [3]:
# Derive the class count from the training dataset's class list
class_names = imagewoof_train.classes
num_classes = len(class_names)
print(f"Number of classes: {num_classes}")

Number of classes: 10


In [4]:
# Build VGG19 with its default pretrained weights. The model is moved to the
# device only once, below, after the classifier head has been replaced.
model_vgg19 = vgg19(weights=VGG19_Weights.DEFAULT)
# model_vgg11 = vgg11(weights=VGG11_Weights.DEFAULT)
# model_vgg16 = vgg16(weights=VGG16_Weights.DEFAULT)

# Replace the last classifier layer so the output size equals num_classes.
# The input width is read from the layer being replaced instead of hard-coding 4096.
model_vgg19.classifier[6] = nn.Linear(
    in_features=model_vgg19.classifier[6].in_features,
    out_features=num_classes,
)
# model_vgg11.classifier[6] = nn.Linear(in_features=model_vgg11.classifier[6].in_features, out_features=num_classes)
# model_vgg16.classifier[6] = nn.Linear(in_features=model_vgg16.classifier[6].in_features, out_features=num_classes)

# Move to device (this also moves the freshly created final layer)
model_vgg19 = model_vgg19.to(device)
# model_vgg11 = model_vgg11.to(device)
# model_vgg16 = model_vgg16.to(device)

In [5]:
# Loss used for training: cross-entropy between the model outputs and the labels
criterion = nn.CrossEntropyLoss()

# Adam over every parameter of the model (the whole network is fine-tuned)
# optimizer_vgg16 = optim.Adam(params=model_vgg16.parameters(), lr=0.0001)
# optimizer_vgg11 = optim.Adam(params=model_vgg11.parameters(), lr=0.0001)
optimizer_vgg19 = optim.Adam(params=model_vgg19.parameters(), lr=0.0001)

In [6]:
num_epochs = 10  # Default number of passes over the training data; adjust as needed

def train_model(model, train_loader, criterion, optimizer, epochs=None):
    """Train ``model`` on ``train_loader``, showing one tqdm progress bar per epoch.

    The bar's postfix shows the mean loss over the batches processed so far in
    the current epoch and the running accuracy (in percent) over the samples
    processed so far.

    Args:
        model: Network to train; it is called as ``model(images)``.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called once
            per batch.
        epochs: Number of epochs to run. ``None`` (the default) uses the
            module-level ``num_epochs`` as it is at call time.

    Returns:
        None. ``model`` is updated in place through ``optimizer``.
    """
    total_epochs = num_epochs if epochs is None else epochs

    for epoch in range(total_epochs):
        model.train()
        running_loss = 0.0
        correct, total = 0, 0

        loop = tqdm(train_loader, leave=True)
        # `step` counts the batches seen so far (1-based) for the running mean below
        for step, (images, labels) in enumerate(loop, start=1):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Compute accuracy: the predicted class is the arg-max along dim 1
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            running_loss += loss.item()
            loop.set_description(f"Epoch [{epoch+1}/{total_epochs}]")
            # Divide by the batches processed so far, not by len(train_loader):
            # dividing by the full epoch length under-reports the loss until
            # the very last step. The end-of-epoch value is the same either way.
            loop.set_postfix(loss=running_loss / step, acc=100 * correct / total)

In [7]:
def save_model(model, PATH):
    """Write the ``state_dict()`` of ``model`` to ``PATH`` using ``torch.save``."""
    weights = model.state_dict()
    torch.save(weights, PATH)

In [8]:
# train_model(model_vgg11, train_loader, criterion, optimizer_vgg11)
# save_model(model_vgg11, "vgg11_Imagewoof.pth")

In [9]:
# train_model(model_vgg16, train_loader, criterion, optimizer_vgg16)
# save_model(model_vgg16, "vgg16_Imagewoof.pth")

In [10]:
# Fine-tune VGG19 on the training loader; saving the weights is left disabled below
train_model(
    model=model_vgg19,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer_vgg19,
)
# save_model(model_vgg19, "vgg19_Imagewoof.pth")

Epoch [1/10]: 100%|█████| 283/283 [02:37<00:00,  1.80it/s, acc=84.5, loss=0.487]
Epoch [2/10]: 100%|█████| 283/283 [02:38<00:00,  1.79it/s, acc=91.3, loss=0.273]
Epoch [3/10]: 100%|█████| 283/283 [02:40<00:00,  1.76it/s, acc=93.6, loss=0.193]
Epoch [4/10]: 100%|███████| 283/283 [02:41<00:00,  1.75it/s, acc=96, loss=0.129]
Epoch [5/10]: 100%|█████| 283/283 [02:40<00:00,  1.76it/s, acc=96.8, loss=0.102]
Epoch [6/10]: 100%|████| 283/283 [02:40<00:00,  1.76it/s, acc=96.8, loss=0.0957]
Epoch [7/10]: 100%|████| 283/283 [02:38<00:00,  1.78it/s, acc=97.5, loss=0.0755]
Epoch [8/10]: 100%|██████| 283/283 [02:41<00:00,  1.75it/s, acc=98, loss=0.0666]
Epoch [9/10]: 100%|████| 283/283 [02:39<00:00,  1.77it/s, acc=97.3, loss=0.0843]
Epoch [10/10]: 100%|███| 283/283 [02:40<00:00,  1.76it/s, acc=97.7, loss=0.0664]
