In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
import numpy as np

# Seed every RNG this notebook draws from so runs are reproducible:
# torch's global generator (used by torchvision's random transforms and by
# DataLoader shuffling) and numpy's global generator (used by show_images
# to pick which samples to display).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Standard ImageNet per-channel mean/std, shared by both pipelines so the
# train and eval normalization cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop/flip/colour/rotation augmentation, then
# conversion to a normalized tensor.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: deterministic resize to exactly 224x224 (a two-element
# size does not preserve aspect ratio), then the same normalization as training.
test_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

DATA_ROOT = './dataset'
BATCH_SIZE = 64

train_dataset = torchvision.datasets.Flowers102(root=DATA_ROOT, split='train', transform=train_transform, download=True)
# NOTE(review): despite its name, this is Flowers102's 'val' split, not its
# separate 'test' split -- confirm which one evaluation is meant to use.
test_dataset = torchvision.datasets.Flowers102(root=DATA_ROOT, split='val', transform=test_transform, download=True)

# drop_last=True discards the final partial training batch each epoch, so
# every training batch holds exactly BATCH_SIZE samples.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

# Function to display images with labels
def show_images(dataset, num_images=10, mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225), rng=None):
    """Plot ``num_images`` randomly chosen samples from ``dataset`` in one row.

    Each sample is read as an ``(image, label)`` pair where ``image`` is a
    normalized channel-first tensor. The normalization is undone with ``mean``
    and ``std`` and clipped to [0, 1] before plotting.

    Args:
        dataset: Indexable dataset supporting ``len()`` whose items are
            ``(image_tensor, label)`` pairs.
        num_images: Number of samples to show; must be >= 1.
        mean: Per-channel mean that was used to normalize the images.
        std: Per-channel std that was used to normalize the images.
        rng: Source of randomness with ``choice``, e.g. a
            ``numpy.random.RandomState``. Defaults to numpy's global
            generator, so ``np.random.seed`` controls which samples are shown.

    Raises:
        ValueError: If ``num_images`` < 1 or ``dataset`` is empty.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError("cannot show images from an empty dataset")
    if rng is None:
        rng = np.random

    mean = np.asarray(mean)
    std = np.asarray(std)

    # Avoid showing the same sample twice unless more images are requested
    # than the dataset holds.
    indices = rng.choice(n_samples, size=num_images, replace=num_images > n_samples)

    # squeeze=False always yields a 2-D array of Axes, so indexing also works
    # when num_images == 1 (plt.subplots would otherwise return a bare Axes).
    # Width scales with the count; the default of 10 gives the original (20, 4).
    _, axes = plt.subplots(1, num_images, figsize=(2 * num_images, 4), squeeze=False)
    for ax, idx in zip(axes[0], indices):
        image, label = dataset[int(idx)]
        # CHW tensor -> HWC array, then invert Normalize: x * std + mean.
        image = image.numpy().transpose((1, 2, 0))
        image = np.clip(std * image + mean, 0, 1)
        ax.imshow(image)
        ax.set_title(f"Label: {label}")
        ax.axis('off')
    plt.show()

# Preview ten randomly chosen training samples (as produced by train_transform,
# i.e. augmented and then denormalized for display) together with their labels.
show_images(dataset=train_dataset, num_images=10)

