In [None]:
# Import necessary packages
import torch
import numpy as np
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
import json
from collections import OrderedDict
import time
from PIL import Image
import matplotlib.pyplot as plt

# Root folder that holds the image dataset
data_dir = 'flowers'
print(data_dir)

# One sub-folder per split: training, validation and testing
train_dir, valid_dir, test_dir = (
    f"{data_dir}/train",
    f"{data_dir}/valid",
    f"{data_dir}/test",
)
print(train_dir, valid_dir, test_dir)

# Preprocessing pipelines for the training, validation and testing splits.
# Channel-wise mean / std handed to Normalize in every pipeline
# (presumably the statistics the pretrained backbone expects -- TODO confirm).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training: random rotation, random resized crop and random horizontal flip
# are applied before the image is converted to a normalized tensor.
train_transforms = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(list(norm_mean), list(norm_std)),
])

# Validation / testing: resize to 256, center-crop to 224, then normalize.
# The generator builds two independent (but identical) Compose objects.
valid_transforms, test_transforms = (
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(list(norm_mean), list(norm_std)),
    ])
    for _ in range(2)
)

# Load each split with ImageFolder, pairing it with its own transform
# pipeline. The validation split uses valid_transforms (not test_transforms)
# so that a later change to either pipeline affects only the intended split.
train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
valid_data = datasets.ImageFolder(valid_dir, transform=valid_transforms)
test_data = datasets.ImageFolder(test_dir, transform=test_transforms)

# Print dataset information
print(train_data, test_data, valid_data)

# Wrap every dataset in a DataLoader. Only the training loader shuffles;
# the validation and test loaders use the default (no shuffle) ordering.
trainloader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
)
validloader = torch.utils.data.DataLoader(dataset=valid_data, batch_size=32)
testloader = torch.utils.data.DataLoader(dataset=test_data, batch_size=32)

# Load category label mapping from JSON file.
# The encoding is pinned to UTF-8 (the encoding JSON documents are expected
# to use) instead of relying on the platform's default text encoding.
with open('cat_to_name.json', 'r', encoding='utf-8') as f:
    cat_to_name = json.load(f)

# Printed after the file handle has been released
print(cat_to_name)

# Number of entries in the mapping; used as the size of the output layer
no_output_categories = len(cat_to_name)

# Build the network: a pretrained VGG16 (batch-norm variant) whose weights
# are frozen, topped with a newly created classifier.
hidden_units = 4096
model = models.vgg16_bn(weights='DEFAULT')

# Freeze the pretrained parameters so no gradients are tracked for them
model.requires_grad_(False)

# Assemble the classifier layer by layer; the final LogSoftmax means the
# model outputs log-probabilities.
classifier_layers = OrderedDict()
classifier_layers['fc1'] = nn.Linear(25088, hidden_units)
classifier_layers['relu'] = nn.ReLU()
classifier_layers['dropout1'] = nn.Dropout(0.05)
classifier_layers['fc2'] = nn.Linear(hidden_units, no_output_categories)
classifier_layers['output'] = nn.LogSoftmax(dim=1)
classifier = nn.Sequential(classifier_layers)

# Swap the pretrained classifier for the new (trainable) one
model.classifier = classifier

# Prefer the first CUDA device when one is available, otherwise use the CPU
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model.to(device)
print(f'The device in use is {device}.\n')

# Training hyperparameters
epochs = 10
print_every = 20
# NLLLoss pairs with the LogSoftmax output of the classifier
criterion = nn.NLLLoss()
# Only the classifier's parameters are handed to the optimizer
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-3)

# Running totals and loss histories used while training
running_loss = running_accuracy = 0
training_losses, validation_losses = [], []

# Training process
for e in range(epochs):
    batches = 0

    model.train()  # Set model to training mode

    for images, labels in trainloader:
        start = time.time()
        batches += 1

        images, labels = images.to(device), labels.to(device)

        # Forward pass
        log_ps = model(images)
        loss = criterion(log_ps, labels)
        loss.backward()
        optimizer.step()

        # Calculate metrics
        ps = torch.exp(log_ps)
        top_ps, top_class = ps.topk(1, dim=1)
        matches = (top_class == labels.view(*top_class.shape)).type(torch.FloatTensor)
        accuracy = matches.mean()

        optimizer.zero_grad()  # Reset gradients
        running_loss += loss.item()
        running_accuracy += accuracy.item()

        # Validation every print_every batches
        if batches % print_every == 0:
            end = time.time()
            training_time = end - start
            start = time.time()

            # Validation metrics
            validation_loss = 0
            validation_accuracy = 0

            model.eval() 