In [None]:
from udl_project.training import config

EPOCHS: int = 10  # number of training epochs handed to ResNetModelTrainer in the next cell

In [None]:
from udl_project.training.resnet_model_trainer import ResNetModelTrainer

# Build the ResNet trainer from this notebook's epoch count and the shared
# learning rate in `config`, then call its `train()` method.
res_net_mode_trainer = ResNetModelTrainer(
    epochs=EPOCHS,
    learning_rate=config.LEARNING_RATE,
)
res_net_mode_trainer.train()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from udl_project.data_handling.data_loader_flowers import DataLoaderFlowers
from udl_project.data_handling.flower_dataset import FlowerDataset

# Build the flower dataset and wrap it in the project's dataloader helper.
flower_dataset = FlowerDataset(
    train_test_spilt=0.8,  # keyword spelled as in the FlowerDataset API; presumably the train fraction
)
dataloader = DataLoaderFlowers.create_dataloader(flower_dataset)
train_loader = dataloader.get_train_dataloader()

# Draw a single (images, labels) batch from the training loader for inspection.
data_iter = iter(train_loader)
images, labels = next(data_iter)


# Function to show images
def show_augmented_images(images, labels, class_names, num_images=8, ncols=4):
    """Display a grid of (augmented) images from one batch, titled by class name.

    Args:
        images: Indexable batch of channels-first image tensors; each ``images[i]``
            is permuted ``(1, 2, 0)`` for matplotlib, so presumably shape (C, H, W).
        labels: Per-image class indices, looked up in ``class_names``.
        class_names: Sequence mapping a class index to a display name.
        num_images: Maximum number of images to draw; capped at ``len(images)``.
        ncols: Maximum number of grid columns. The default of 4 keeps the original
            2x4 layout for 8 images; the row count is derived from the image count.
    """
    n_shown = min(num_images, len(images))
    if n_shown <= 0:
        return  # nothing to plot

    n_cols = min(ncols, n_shown)
    n_rows = (n_shown + n_cols - 1) // n_cols  # ceiling division
    # squeeze=False keeps `axes` 2-D even for a single row/column, so flatten() always works;
    # the grid now grows with num_images instead of overflowing a fixed 2x4 layout.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    for i in range(n_shown):
        # detach/cpu so .numpy() also works for tensors that require grad or live on a GPU.
        img = images[i].detach().cpu()

        # Clamp to [0, 1]: matplotlib expects float RGB data in that range.
        # NOTE(review): if the transform normalizes with mean/std, clamping distorts
        # colours and the image should be denormalized first; not visible from here.
        img = torch.clamp(img, 0, 1)

        # Channels-first -> channels-last for imshow.
        axes[i].imshow(img.permute(1, 2, 0).numpy())
        axes[i].set_title(f"Class: {class_names[int(labels[i])]}")
        axes[i].axis("off")

    # Hide grid cells that received no image (possible when n_shown % n_cols != 0).
    for ax in axes[n_shown:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Get class names from the dataset.
# A torch DataLoader has no `.classes` attribute (this cell previously raised
# AttributeError), so walk the wrapped `.dataset` chain (loader -> split -> base
# dataset) until some object exposes `.classes`.
def _find_class_names(source):
    """Return the first ``classes`` attribute found on ``source`` or its nested ``.dataset`` wrappers.

    Raises:
        AttributeError: if no object in the ``.dataset`` chain has ``classes``.
    """
    seen = set()  # guards against a wrapper whose `.dataset` points back at itself
    node = source
    while node is not None and id(node) not in seen:
        seen.add(id(node))
        classes = getattr(node, "classes", None)
        if classes is not None:
            return classes
        node = getattr(node, "dataset", None)
    raise AttributeError(
        f"No `classes` attribute found on {type(source).__name__} or its wrapped datasets"
    )


class_names = _find_class_names(train_loader)
print(f"Classes: {class_names}")
print(f"Image batch shape: {images.shape}")
print(f"Labels batch shape: {labels.shape}")

# Show the augmented images
show_augmented_images(images, labels, class_names)

  from .autonotebook import tqdm as notebook_tqdm


Data directory: /home/jannes/.cache/kagglehub/datasets/lara311/flowers-five-classes/versions/1




AttributeError: 'DataLoader' object has no attribute 'classes'

In [None]:
# Show multiple augmentations of the same image
def show_multiple_augmentations(dataset, image_idx=0, num_augmentations=8, transform=None, ncols=4):
    """Show several independently augmented versions of one raw image.

    Args:
        dataset: Train split exposing its underlying dataset as ``dataset.dataset``
            (e.g. ``flower_dataset.get_train_dataset()``).
        image_idx: Index into ``dataset.dataset`` of the image to augment.
        num_augmentations: Number of augmented copies to draw.
        transform: Augmentation pipeline to apply. Defaults to the training
            transform stored on ``dataloader.train_data.dataset`` (notebook global).
        ncols: Maximum number of grid columns; rows are derived from the count.
            The defaults reproduce the original 2x4 layout.
    """
    if num_augmentations <= 0:
        return  # nothing to plot

    # Resolve the transform BEFORE the raw read below, which temporarily clears
    # `base.transform` — possibly on the very object the default is read from.
    train_transform = dataloader.train_data.dataset.transform if transform is None else transform

    # Fetch the raw (untransformed) sample. Indexing the base dataset normally
    # applies its own `transform`, so the "original" would already be augmented
    # and then augmented a second time in the loop below. Disable it for this read.
    # NOTE(review): assumes base.__getitem__ applies `base.transform`, as
    # torchvision's ImageFolder does — confirm for FlowerDataset.
    base = dataset.dataset
    has_transform = hasattr(base, "transform")
    saved_transform = getattr(base, "transform", None)
    if has_transform:
        base.transform = None
    try:
        original_img, label = base[image_idx]
    finally:
        if has_transform:
            base.transform = saved_transform  # always restore the shared dataset state

    n_cols = min(ncols, num_augmentations)
    n_rows = (num_augmentations + n_cols - 1) // n_cols  # ceiling division
    # squeeze=False + flatten: one axis per augmentation, never an IndexError past 8.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    for i in range(num_augmentations):
        # Re-apply the pipeline each time; random augmentations differ per panel.
        augmented_img = train_transform(original_img)

        # Convert to a displayable channels-last array in [0, 1].
        img_display = torch.clamp(augmented_img.detach().cpu(), 0, 1)
        img_np = img_display.permute(1, 2, 0).numpy()

        axes[i].imshow(img_np)
        axes[i].set_title(f"Augmentation {i + 1}")
        axes[i].axis("off")

    # Hide grid cells left over when num_augmentations is not a multiple of n_cols.
    for ax in axes[num_augmentations:]:
        ax.axis("off")

    plt.suptitle(f"Multiple augmentations of the same {class_names[int(label)]} image", fontsize=16)
    plt.tight_layout()
    plt.show()


# Show multiple augmentations of the same image
print("Showing multiple augmentations of the same image:")
show_multiple_augmentations(flower_dataset.get_train_dataset())

In [None]:
# Compare original vs augmented images
def compare_original_vs_augmented(dataset, num_samples=4, transform=None):
    """Plot raw images (top row) above one augmented version of each (bottom row).

    Args:
        dataset: Train split exposing its underlying dataset as ``dataset.dataset``.
        num_samples: Number of images (columns) to compare, taken from indices
            ``0 .. num_samples - 1`` of ``dataset.dataset``.
        transform: Augmentation pipeline. Defaults to the training transform stored
            on ``dataloader.train_data.dataset`` (notebook global).
    """
    if num_samples <= 0:
        return  # plt.subplots cannot build a grid with zero columns

    # Resolve the transform before the raw read below temporarily clears `base.transform`.
    augment = dataloader.train_data.dataset.transform if transform is None else transform

    # Read raw samples with the base dataset's own transform disabled; otherwise the
    # "Original" row would already be transformed and the bottom row would be
    # transformed twice.
    # NOTE(review): assumes base.__getitem__ applies `base.transform`, as
    # torchvision's ImageFolder does — confirm for FlowerDataset.
    base = dataset.dataset
    has_transform = hasattr(base, "transform")
    saved_transform = getattr(base, "transform", None)
    if has_transform:
        base.transform = None
    try:
        samples = [base[i] for i in range(num_samples)]
    finally:
        if has_transform:
            base.transform = saved_transform  # always restore the shared dataset state

    # squeeze=False keeps `axes` 2-D (2 x num_samples) even when num_samples == 1,
    # so `axes[row, col]` indexing below is always valid.
    _, axes = plt.subplots(2, num_samples, figsize=(4 * num_samples, 8), squeeze=False)

    for i, (original_img, label) in enumerate(samples):
        # Apply augmentation to the raw image.
        augmented_img = augment(original_img)

        # Show original image
        if isinstance(original_img, torch.Tensor):
            orig_np = torch.clamp(original_img.detach().cpu(), 0, 1).permute(1, 2, 0).numpy()
        else:
            # If PIL image, convert to numpy and scale 0-255 pixel values to [0, 1].
            orig_np = np.array(original_img)
            orig_np = orig_np / 255.0 if orig_np.max() > 1 else orig_np

        axes[0, i].imshow(orig_np)
        axes[0, i].set_title(f"Original: {class_names[int(label)]}")
        axes[0, i].axis("off")

        # Show augmented image (clamped to the [0, 1] range imshow expects for floats).
        aug_display = torch.clamp(augmented_img.detach().cpu(), 0, 1)
        aug_np = aug_display.permute(1, 2, 0).numpy()

        axes[1, i].imshow(aug_np)
        axes[1, i].set_title(f"Augmented: {class_names[int(label)]}")
        axes[1, i].axis("off")

    plt.suptitle("Original vs Augmented Images", fontsize=16)
    plt.tight_layout()
    plt.show()


print("Comparing original vs augmented images:")
compare_original_vs_augmented(flower_dataset.get_train_dataset())