In [None]:
import os
import random
import time

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm

from customDataset import CustomDataset
from utils import *

# Kept after the star import, as in the original ordering, so that these names
# cannot be shadowed by anything `utils` happens to export.
from torchvision.models import vgg16, VGG16_Weights

# Seed the Python and PyTorch random generators so repeated runs are comparable.
manualSeed = 999
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# In deterministic mode, cuBLAS-backed ops (matrix multiplies, hence nn.Linear)
# raise a RuntimeError on CUDA >= 10.2 unless CUBLAS_WORKSPACE_CONFIG is set.
# It has to be in the environment before cuBLAS is first used, so set it here,
# without overriding a value the user may already have exported.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.use_deterministic_algorithms(True)

# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"runnning on {device}")

In [None]:
# Mini-batch size for both loaders, and the value handed to CustomDataset as
# `image_size`.
batch_size = 32
image_size = 224

# Training split: PNG files read from the `train` folder.
train_dataset = CustomDataset(
    image_size=image_size,
    image_extension=".png",
    image_folders=[r"..\data\black-and-white-rectangle\train"],
)

# Held-out split: PNG files read from the `val` folder.
test_dataset = CustomDataset(
    image_size=image_size,
    image_extension=".png",
    image_folders=[r"..\data\black-and-white-rectangle\val"],
)

# Both loaders reshuffle their dataset on every pass.
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=True, batch_size=batch_size)

# Cell output: number of samples in each split.
len(train_dataset), len(test_dataset)

In [None]:
def show_images_with_bboxes(dataset, num_images=12, num_cols=4):
    """
    Display sample images from a dataset with their bounding boxes drawn on top.

    Samples are read as ``dataset[0] .. dataset[num_images - 1]``. Each sample
    must expose ``'input_img'`` (a channel-first tensor, transposed to HWC for
    plotting) and ``'label'`` (a tensor reshaped here into rows of
    ``(class_id, x_center, y_center, width, height)``). Box coordinates are
    multiplied by the image width/height before drawing, i.e. they are treated
    as fractions of the image size.

    Args:
        dataset (Dataset): Instance of CustomDataset.
        num_images (int, optional): Number of images to display. Defaults to 12.
            Capped at ``len(dataset)``; nothing is drawn when the result is 0.
        num_cols (int, optional): Number of columns in the grid. Defaults to 4.
    """
    # Never index past the end of the dataset, and skip plotting entirely when
    # there is nothing to show (a 0-row grid is rejected by matplotlib).
    num_images = min(num_images, len(dataset))
    if num_images <= 0:
        return

    num_rows = (num_images + num_cols - 1) // num_cols  # ceiling division
    # squeeze=False guarantees a 2-D array of axes even for a 1x1 grid, so the
    # flatten() below always works.
    _, axes = plt.subplots(
        num_rows, num_cols, figsize=(num_cols * 5, num_rows * 5), squeeze=False
    )
    axes = axes.flatten()

    for index in range(num_images):
        ax = axes[index]
        sample = dataset[index]
        image = sample['input_img'].numpy().transpose((1, 2, 0))  # CHW -> HWC for imshow
        label = sample['label'].numpy().reshape(-1, 5)
        img_h, img_w = image.shape[0], image.shape[1]

        ax.imshow(image)
        ax.axis("off")

        # Draw every bounding box attached to this sample.
        for class_id, x_center, y_center, width, height in label:
            box_w = width * img_w
            box_h = height * img_h
            # Rectangle is anchored at its top-left corner, not at its center.
            x1 = x_center * img_w - box_w / 2
            y1 = y_center * img_h - box_h / 2

            rect = patches.Rectangle(
                (x1, y1), box_w, box_h, linewidth=2, edgecolor='g', facecolor='none'
            )
            ax.add_patch(rect)

    # Hide the grid cells that received no image.
    for unused_ax in axes[num_images:]:
        unused_ax.axis("off")

    plt.tight_layout()
    plt.show()

# Example usage: preview the first 12 training samples in a 4-column grid.
show_images_with_bboxes(train_dataset, num_images=12, num_cols=4)

In [None]:
# Input shape handed to torchinfo's `summary`: (batch, channels, height, width).
# Derived from `image_size` so it cannot drift from the dataset configuration.
input_size = (batch_size, 3, image_size, image_size)

# VGG16 initialised with torchvision's default pretrained weights.
model = vgg16(weights=VGG16_Weights.DEFAULT)

# Freeze every pretrained parameter. The attribute autograd reads is
# `requires_grad`; assigning to a misspelled name only creates an unused
# attribute and leaves the whole network trainable.
for param in model.parameters():
    param.requires_grad = False

In [None]:
# Width of the features feeding the last classifier layer.
n_inputs = model.classifier[6].in_features

# Replace that layer with a 4-output linear layer followed by a sigmoid, so
# every output lies in (0, 1).
model.classifier[6] = nn.Sequential(nn.Linear(n_inputs, 4), nn.Sigmoid())

model = model.to(device)

# Cell output: per-layer shapes, parameter counts and trainable flags.
summary(
    model,
    input_size=input_size,
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

In [None]:
def calculate_iou(bbox_pred, bbox_target):
    """
    Calculate Intersection over Union (IoU) between two bounding boxes.

    Both boxes are given as ``(x1, y1, width, height)``, where ``(x1, y1)`` is
    the corner with the smallest coordinates.

    Args:
        bbox_pred (torch.Tensor or sequence): Predicted box, 4 values.
        bbox_target (torch.Tensor or sequence): Target box, 4 values.
    Returns:
        The IoU score. For tensor inputs the result is always a 0-dim tensor
        (so ``.item()`` is safe on every path); for plain numbers it is a float.
        The score is 0 when the union has no positive area.
    """
    # Extract coordinates
    x1_pred, y1_pred, w_pred, h_pred = bbox_pred
    x1_target, y1_target, w_target, h_target = bbox_target

    # Corners of the intersection rectangle
    x1_inter = max(x1_pred, x1_target)
    y1_inter = max(y1_pred, y1_target)
    x2_inter = min(x1_pred + w_pred, x1_target + w_target)
    y2_inter = min(y1_pred + h_pred, y1_target + h_target)

    # Clamp each side at 0 so disjoint boxes give an empty intersection
    inter_area = max(0, x2_inter - x1_inter) * max(0, y2_inter - y1_inter)

    # Union = sum of both areas minus the part counted twice
    pred_area = w_pred * h_pred
    target_area = w_target * h_target
    union_area = pred_area + target_area - inter_area

    if union_area > 0:
        return inter_area / union_area

    # Degenerate boxes: return a zero of the same kind as the inputs. Returning
    # a bare float here for tensor inputs would break callers that use .item().
    if torch.is_tensor(union_area):
        return torch.zeros_like(union_area)
    return 0.0
def _split_bbox_targets(labels):
    """Return the (x_center, y_center, width, height) columns of a label batch."""
    # Label rows are laid out as (class_id, x_center, y_center, width, height),
    # the layout show_images_with_bboxes unpacks, so column 0 is skipped.
    return labels.squeeze(1)[:, 1:5]


def _batch_iou_sum(outputs, targets):
    """Sum the per-example IoU of two batches of center-format boxes."""

    def to_top_left(boxes):
        # calculate_iou expects (x1, y1, w, h); shift each center by half the size.
        # One device-to-host copy per batch instead of one sync per scalar.
        boxes = boxes.detach().cpu()
        return torch.cat([boxes[:, :2] - boxes[:, 2:4] / 2, boxes[:, 2:4]], dim=1)

    # float() accepts both the 0-dim tensors and the plain numbers that
    # calculate_iou can return.
    return sum(
        float(calculate_iou(pred, target))
        for pred, target in zip(to_top_left(outputs), to_top_left(targets))
    )


def _plot_history(history):
    """Plot the loss and IoU curves stored in a training history dict."""
    _, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = (
        (axes[0], "loss", "val_loss", "Loss"),
        (axes[1], "iou", "val_iou", "IoU"),
    )
    for ax, train_key, val_key, name in panels:
        ax.plot(history[train_key], label=f"Train {name}", color="blue", marker="o")
        ax.plot(history[val_key], label=f"Validation {name}", color="orange", marker="x")
        ax.set_xlabel("Number of Epochs", fontsize=12)
        ax.set_ylabel(name, fontsize=12)
        ax.set_title(f"Training and Validation {name} Over Epochs", fontsize=14)
        ax.legend(loc="best", fontsize=12)
        ax.grid(True, linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.show()


def train_bbox_regression_model(
    model,
    optimizer,
    criterion,
    train_dataloader,
    val_dataloader,
    num_epochs,
    device,
    scheduler=None,
    plot_figs=True,
    nb_batches_to_display=10,
    early_stopper=None,
    save_filepath=None
):
    """
    Train a bounding box regression model.

    Each batch is a dict with ``"input_img"`` and ``"label"`` entries. The
    regression target is the ``(x_center, y_center, width, height)`` part of
    the label; the leading class id is dropped. IoU is reported alongside the
    loss, after converting boxes to the top-left format ``calculate_iou``
    expects.

    Args:
        model (torch.nn.Module): The model to train.
        optimizer (torch.optim.Optimizer): The optimizer used for training.
        criterion (torch.nn.Module): The loss function.
        train_dataloader (torch.utils.data.DataLoader): DataLoader for the training set.
        val_dataloader (torch.utils.data.DataLoader): DataLoader for the validation set.
        num_epochs (int): Maximum number of epochs for training.
        device (torch.device): Device to use for training (e.g., 'cpu', 'cuda').
        scheduler (optional): Learning rate scheduler; stepped after every batch.
        plot_figs (bool): Whether to plot training and validation curves.
        nb_batches_to_display (int): Print the running loss every this many
            batches; 0 or less disables the per-batch log.
        early_stopper (optional): Object whose ``should_stop(val_loss)`` is
            called after each epoch; a truthy result ends training.
        save_filepath (optional): Path to save the model to after training.
    Returns:
        tuple: ``(model, history)`` where ``history`` maps ``"loss"``,
        ``"val_loss"``, ``"iou"`` and ``"val_iou"`` to per-epoch lists.
    """
    print(f"Train: model={type(model).__name__}, opt={type(optimizer).__name__}(lr={optimizer.param_groups[0]['lr']}), num_epochs={num_epochs}, device={device}\n")
    history = {"loss": [], "val_loss": [], "iou": [], "val_iou": []}
    start_time_sec = time.time()

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        print("=" * 60)

        # TRAINING
        model.train()
        train_loss = 0.0
        total_iou = 0.0
        num_train_examples = 0
        for batch_index, data in enumerate(train_dataloader):
            optimizer.zero_grad()
            inputs = data["input_img"].to(device)
            targets = _split_bbox_targets(data["label"]).to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            # Weight the batch loss by its size so the epoch value is a
            # per-example average even when the last batch is smaller.
            batch_len = inputs.size(0)
            train_loss += loss.item() * batch_len
            num_train_examples += batch_len

            with torch.no_grad():
                total_iou += _batch_iou_sum(outputs, targets)

            if (nb_batches_to_display > 0 and (batch_index + 1) % nb_batches_to_display == 0):
                print(f"\tBatch {batch_index+1}/{len(train_dataloader)}, loss: {train_loss / num_train_examples:.4f}")

        # Normalise by the examples actually seen; guards an empty loader.
        train_loss = train_loss / num_train_examples if num_train_examples > 0 else 0.0
        train_iou = total_iou / num_train_examples if num_train_examples > 0 else 0.0

        # VALIDATION
        model.eval()
        val_loss = 0.0
        total_val_iou = 0.0
        num_val_examples = 0
        with torch.no_grad():
            for data in val_dataloader:
                inputs = data["input_img"].to(device)
                targets = _split_bbox_targets(data["label"]).to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                batch_len = inputs.size(0)
                val_loss += loss.item() * batch_len
                num_val_examples += batch_len
                total_val_iou += _batch_iou_sum(outputs, targets)

        val_loss = val_loss / num_val_examples if num_val_examples > 0 else 0.0
        val_iou = total_val_iou / num_val_examples if num_val_examples > 0 else 0.0

        print(f"Epoch {epoch}/{num_epochs}, train loss: {train_loss:.4f}, train IoU: {train_iou:.4f}, val loss: {val_loss:.4f}, val IoU: {val_iou:.4f}\n")
        history["loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["iou"].append(train_iou)
        history["val_iou"].append(val_iou)

        # EARLY STOPPING
        if early_stopper is not None and early_stopper.should_stop(val_loss):
            print("Early stopping triggered.")
            break

    # END OF TRAINING
    total_time_sec = time.time() - start_time_sec
    # Average over the epochs that actually ran: early stopping may have cut
    # the loop short, and num_epochs may be 0.
    epochs_run = len(history["loss"])
    time_per_epoch_sec = total_time_sec / epochs_run if epochs_run > 0 else 0.0
    print(f"Time total:     {total_time_sec:.2f} sec")
    print(f"Time per epoch: {time_per_epoch_sec:.2f} sec\n")

    # PLOT CURVES
    if plot_figs:
        _plot_history(history)

    if save_filepath is not None:
        # Saves the whole module object (not just its state_dict); kept as-is
        # so existing loading code keeps working.
        torch.save(model, save_filepath)

    return model, history

# Training configuration: a single epoch with mean-squared-error loss on the
# predicted box values.
num_epochs = 1
criterion = nn.MSELoss()

# Only the parameters of classifier[6] are handed to the optimizer.
optimizer = optim.Adam(params=model.classifier[6].parameters(), lr=0.001)
early_stopper = EarlyStopper(delta=0.001, patience=3)

# Train, logging the running loss every 2 batches and plotting the curves at
# the end; the model is not written to disk.
model, history = train_bbox_regression_model(
    model,
    optimizer,
    criterion,
    train_dataloader,
    test_dataloader,
    num_epochs,
    device,
    scheduler=None,
    plot_figs=True,
    nb_batches_to_display=2,
    early_stopper=early_stopper,
    save_filepath=None,
)