<a href="https://colab.research.google.com/github/GenAIUnplugged/pytorch/blob/main/TransferLearning_Flowers102.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
import torch.nn as nn
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import numpy as np
from torchvision.models import resnet50,ResNet50_Weights
from tqdm import tqdm
import torch.optim.lr_scheduler as lr_scheduler # Import scheduler

# %%
# Define transformations
# Added Data Augmentation for training
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224), # Randomly crop and resize
    transforms.RandomHorizontalFlip(), # Randomly flip horizontally
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize with ImageNet stats
])

# No augmentation for validation or testing, only resize and normalize
eval_transform = transforms.Compose([
    transforms.Resize(256), # Resize first
    transforms.CenterCrop(224), # Then crop the center
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


# %%
# Load Datasets (using separate train, validation, and test sets for Flowers102)
train_ds = torchvision.datasets.Flowers102(
    root="./data",
    split="train",
    transform=train_transform, # Use train_transform
    download=True
)

val_ds = torchvision.datasets.Flowers102( # Load the validation set
    root="./data",
    split="val", # Specify validation split
    transform=eval_transform, # Use eval_transform
    download=True
)

test_ds = torchvision.datasets.Flowers102( # Load the test set
    root="./data",
    split="test", # Specify test split
    transform=eval_transform, # Use eval_transform
    download=True
)

# %%
# Data Loaders
train_loader = torch.utils.data.DataLoader(dataset=train_ds, shuffle=True, batch_size=32, num_workers=2) # Added num_workers
val_loader = torch.utils.data.DataLoader(dataset=val_ds, shuffle=False, batch_size=32, num_workers=2) # Added num_workers and shuffle=False for consistent evaluation
test_loader = torch.utils.data.DataLoader(dataset=test_ds, shuffle=False, batch_size=32, num_workers=2) # Added num_workers and shuffle=False for consistent evaluation


# %%
# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}") # Print device being used

# %%
# Load pre-trained ResNet50 model
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) # Specify weights version
model = model.to(device)

# %%
# Modify the final fully connected layer for 102 classes
num_classes = 102 # Set number of classes based on your dataset (Flowers102 has 102)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
model = model.to(device)

# %%
# Define Loss function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Adam optimizer
# Add a learning rate scheduler
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) # Decrease learning rate by 10% every 30 epochs

# %%
# Training loop
epochs = 50 # Increased number of epochs

for epoch in range(epochs):
    model.train() # Set model to training mode
    total_loss = 0
    total_correct = 0
    total_samples = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")

    for batch, (image, label) in enumerate(progress_bar):
        image, label = image.to(device), label.to(device)

        # Forward pass
        logits = model(image)
        loss = criterion(logits, label)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Calculate accuracy
        pred_probs = torch.softmax(input=logits, dim=1).argmax(dim=1)
        correct = (pred_probs == label).sum().item()
        total_correct += correct
        total_loss += loss.item()
        total_samples += label.size(0)

        # Update progress bar
        progress_bar.set_postfix(loss=total_loss / (batch + 1), accuracy=total_correct / total_samples)

    # Step the scheduler
    scheduler.step()

    print(f"\nEpoch {epoch+1}/{epochs} [Train] - Loss: {total_loss / len(train_loader):.4f}, Accuracy: {total_correct / total_samples:.4f}")

    # Validation loop
    model.eval() # Set model to evaluation mode
    total_loss = 0
    total_correct = 0
    total_samples = 0
    progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Validation]") # Use val_loader

    with torch.no_grad(): # Disable gradient calculation for evaluation
        for batch, (image, label) in enumerate(progress_bar):
            image, label = image.to(device), label.to(device)

            # Forward pass
            val_logits = model(image)
            val_loss = criterion(val_logits, label)

            # Calculate accuracy
            pred_probs = torch.softmax(input=val_logits, dim=1).argmax(dim=1)
            correct = (pred_probs == label).sum().item()
            total_correct += correct
            total_loss += val_loss.item()
            total_samples += label.size(0)

            # Update progress bar
            progress_bar.set_postfix(loss=total_loss / (batch + 1), accuracy=total_correct / total_samples)

    print(f"Epoch {epoch+1}/{epochs} [Validation] - Loss: {total_loss / len(val_loader):.4f}, Accuracy: {total_correct / total_samples:.4f}") # Use val_loader

print("\nTraining finished.")

# Final evaluation on the separate test set after all epochs
model.eval()
total_loss = 0
total_correct = 0
total_samples = 0
print("\nEvaluating on the full test set...")
with torch.no_grad():
    for batch, (image, label) in enumerate(test_loader): # Use test_loader for final evaluation
        image, label = image.to(device), label.to(device)
        test_logits = model(image)
        test_loss = criterion(test_logits, label)
        pred_probs = torch.softmax(input=test_logits, dim=1).argmax(dim=1)
        correct = (pred_probs == label).sum().item()
        total_correct += correct
        total_loss += test_loss.item()
        total_samples += label.size(0)

print(f"Final Test Loss: {total_loss / len(test_loader):.4f}, Final Test Accuracy: {total_correct / total_samples:.4f}") # Use test_loader

100%|██████████| 345M/345M [00:02<00:00, 124MB/s]
100%|██████████| 502/502 [00:00<00:00, 816kB/s]
100%|██████████| 15.0k/15.0k [00:00<00:00, 28.8MB/s]


Using device: cuda


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 155MB/s]
Epoch 1/50 [Train]: 100%|██████████| 32/32 [00:12<00:00,  2.54it/s, accuracy=0.0402, loss=4.86]



Epoch 1/50 [Train] - Loss: 4.8564, Accuracy: 0.0402


Epoch 1/50 [Validation]: 100%|██████████| 32/32 [00:06<00:00,  4.73it/s, accuracy=0.0294, loss=13.7]


Epoch 1/50 [Validation] - Loss: 13.7219, Accuracy: 0.0294


Epoch 2/50 [Train]: 100%|██████████| 32/32 [00:10<00:00,  2.95it/s, accuracy=0.0412, loss=4.32]



Epoch 2/50 [Train] - Loss: 4.3218, Accuracy: 0.0412


Epoch 2/50 [Validation]: 100%|██████████| 32/32 [00:07<00:00,  4.16it/s, accuracy=0.0912, loss=3.97]


Epoch 2/50 [Validation] - Loss: 3.9745, Accuracy: 0.0912


Epoch 3/50 [Train]: 100%|██████████| 32/32 [00:10<00:00,  2.98it/s, accuracy=0.0873, loss=3.81]



Epoch 3/50 [Train] - Loss: 3.8052, Accuracy: 0.0873


Epoch 3/50 [Validation]: 100%|██████████| 32/32 [00:06<00:00,  4.57it/s, accuracy=0.106, loss=3.92]


Epoch 3/50 [Validation] - Loss: 3.9232, Accuracy: 0.1059


Epoch 4/50 [Train]: 100%|██████████| 32/32 [00:11<00:00,  2.88it/s, accuracy=0.13, loss=3.55]



Epoch 4/50 [Train] - Loss: 3.5457, Accuracy: 0.1304


Epoch 4/50 [Validation]: 100%|██████████| 32/32 [00:06<00:00,  4.75it/s, accuracy=0.146, loss=3.44]


Epoch 4/50 [Validation] - Loss: 3.4391, Accuracy: 0.1461


Epoch 5/50 [Train]: 100%|██████████| 32/32 [00:11<00:00,  2.87it/s, accuracy=0.152, loss=3.28]



Epoch 5/50 [Train] - Loss: 3.2842, Accuracy: 0.1520


Epoch 5/50 [Validation]: 100%|██████████| 32/32 [00:07<00:00,  4.18it/s, accuracy=0.182, loss=3.39]


Epoch 5/50 [Validation] - Loss: 3.3889, Accuracy: 0.1824


Epoch 6/50 [Train]: 100%|██████████| 32/32 [00:10<00:00,  2.95it/s, accuracy=0.211, loss=3.04]



Epoch 6/50 [Train] - Loss: 3.0420, Accuracy: 0.2108


Epoch 6/50 [Validation]: 100%|██████████| 32/32 [00:06<00:00,  4.85it/s, accuracy=0.25, loss=3.21]


Epoch 6/50 [Validation] - Loss: 3.2140, Accuracy: 0.2500


Epoch 7/50 [Train]: 100%|██████████| 32/32 [00:10<00:00,  2.92it/s, accuracy=0.262, loss=2.82]



Epoch 7/50 [Train] - Loss: 2.8207, Accuracy: 0.2618


Epoch 7/50 [Validation]: 100%|██████████| 32/32 [00:07<00:00,  4.15it/s, accuracy=0.232, loss=3.02]


Epoch 7/50 [Validation] - Loss: 3.0204, Accuracy: 0.2324


Epoch 8/50 [Train]:   0%|          | 0/32 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [None]:
model.save()

In [None]:
# Define the path where you want to save the model
model_save_path = "resnet50_flowers102.pth" # Use a .pth or .pt extension

# Save the model's state dictionary
torch.save(model.state_dict(), model_save_path)

print(f"Model state dictionary saved to {model_save_path}")

In [None]:
# %%
# Import necessary libraries
import torch
import torchvision
from torchvision.transforms import transforms
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image # To open images

# %%
# Device configuration (must be the same as or compatible with training)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# %%
# Define the model architecture (must match the saved model)
# Start with a standard ResNet50
loaded_model = resnet50(weights=None) # No pre-trained weights here, we'll load our own

# Modify the final fully connected layer to match the number of classes trained
num_classes = 102 # Must match the number of classes the saved model was trained on
loaded_model.fc = nn.Linear(in_features=loaded_model.fc.in_features, out_features=num_classes)

# Move the model to the device
loaded_model.to(device)

# %%
# Define the path to the saved model state dictionary
model_save_path = "resnet50_flowers102.pth" # Replace with the actual path to your saved file

# Load the saved state dictionary
try:
    loaded_model.load_state_dict(torch.load(model_save_path, map_location=device))
    print(f"Model state dictionary loaded successfully from {model_save_path}")
except FileNotFoundError:
    print(f"Error: Model file not found at {model_save_path}")
    # Handle this error, perhaps exit or skip testing
except Exception as e:
    print(f"Error loading model state dictionary: {e}")
    # Handle other loading errors

# %%
# Set the model to evaluation mode
# This is crucial for inference, it disables dropout and batch normalization tracking
loaded_model.eval()
print("Model set to evaluation mode.")

# %%
# Prepare a single image for testing

# Option 1: Use an image from your dataset (e.g., from the test_ds)
# This requires having the dataset loaded again
# Assuming test_ds is available or you reload it
# Example: Get the first image and label from the test dataset
# image, label = test_ds[0]
# print(f"Original label (index): {label}")

# Option 2: Load an external image file
# Replace 'path/to/your/image.jpg' with the actual path to an image file
try:
    image_path = "./data/flowers-102/test/1/image_06740.jpg" # Example path, replace with yours
    img = Image.open(image_path).convert("RGB") # Open and convert to RGB
    print(f"Loaded image from {image_path}")

    # Apply the same transformations used for testing data during training
    # Use the test_transform defined earlier (or redefine it here)
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image_tensor = test_transform(img).unsqueeze(0) # Add a batch dimension (batch size of 1)
    image_tensor = image_tensor.to(device) # Move image to the device

except FileNotFoundError:
    print(f"Error: Image file not found at {image_path}")
    # Handle this error
except Exception as e:
    print(f"Error loading or transforming image: {e}")
    # Handle other image processing errors

# %%
# Perform inference
if 'image_tensor' in locals(): # Check if image was loaded successfully
    print("Performing inference...")
    with torch.no_grad(): # Disable gradient calculation for inference
        output = loaded_model(image_tensor)

    # Get the predicted class probabilities
    probabilities = torch.softmax(output, dim=1)

    # Get the predicted class index
    predicted_class_index = torch.argmax(probabilities, dim=1).item()

    # Optional: If you have a list of class names, you can get the class name
    # For Flowers102, the dataset object has .classes or similar
    # If you are using Flowers102 dataset object:
    # class_names = test_ds.classes
    # predicted_class_name = class_names[predicted_class_index]
    # print(f"Predicted class: {predicted_class_name} (Index: {predicted_class_index})")

    # If you don't have class names readily available:
    print(f"Predicted class index: {predicted_class_index}")

    # Optional: Print the probabilities for all classes
    # print("Class probabilities:")
    # print(probabilities)

else:
    print("Skipping inference due to image loading error.")

In [None]:
# %%
# Assuming the necessary imports, device setup, and loaded_model are already defined

# Ensure the model is in evaluation mode
loaded_model.eval()
print("Model set to evaluation mode.")

# %%
# Get a single batch from the test loader
try:
    # Iterate through the first batch of the test loader
    # You can stop after getting the first batch
    print("Getting a batch from the test loader...")
    for batch, (image, label) in enumerate(test_loader):
        # We only need the first batch for this example
        image_batch = image
        label_batch = label
        break # Exit the loop after the first batch

    print(f"Obtained a batch of size {image_batch.size(0)} from the test loader.")

    # Select a single image from the batch for inference
    # You can change the index (e.g., 0, 1, 2...) to pick a different image from the batch
    single_image_tensor = image_batch[0].unsqueeze(0) # Get the first image and add a batch dimension
    original_label_for_single_image = label_batch[0].item() # Get the label for that image

    # Move the single image tensor to the device
    single_image_tensor = single_image_tensor.to(device)

    print(f"Selected image at index 0 from the batch. Original label: {original_label_for_single_image}")

except NameError:
    print("Error: test_loader not found. Please run the Data Loaders cell first.")
    # Handle this error if test_loader is not available
except IndexError:
     print("Error: The test_loader is empty. Check your dataset and loader setup.")
except Exception as e:
    print(f"Error getting data from test loader: {e}")
    # Handle other errors

# %%
# Perform inference on the single image
if 'single_image_tensor' in locals() and single_image_tensor is not None: # Check if image was successfully obtained
    print("Performing inference on the single image...")
    with torch.no_grad(): # Disable gradient calculation for inference
        output = loaded_model(single_image_tensor)

    # Get the predicted class probabilities
    probabilities = torch.softmax(output, dim=1)

    # Get the predicted class index
    predicted_class_index = torch.argmax(probabilities, dim=1).item()

    # Optional: If you have a list of class names, you can get the class name
    # If you are using Flowers102 dataset object and have access to test_ds:
    # class_names = test_ds.classes # Assuming test_ds is available
    # predicted_class_name = class_names[predicted_class_index]
    # print(f"Predicted class: {predicted_class_name} (Index: {predicted_class_index})")

    # If you don't have class names readily available:
    print(f"Predicted class index: {predicted_class_index}")
    print(f"True class index: {original_label_for_single_image}")

    # Optional: Print the top N probabilities
    # top_prob, top_cat = torch.topk(probabilities[0], 5) # Get top 5
    # print("\nTop 5 Predicted Probabilities and Indices:")
    # for i in range(top_prob.size(0)):
    #     print(f"  Index: {top_cat[i].item()}, Probability: {top_prob[i].item():.4f}")

else:
    print("Skipping inference as a single image tensor could not be obtained.")