# Training and Hyperparameter Tuning

In [None]:
import os
import ast
import gc
import json
import time
import random
from datetime import datetime
from collections import defaultdict, Counter

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights

import albumentations as A
from albumentations.pytorch import ToTensorV2


BASE_PATH = "/kaggle/input/deepfashion2-original-with-dataframes/DeepFashion2"

# Image directory per split. Note the extra "test/" level for the test split.
IMAGE_PATHS = {
    "train": f"{BASE_PATH}/deepfashion2_original_images/train/image",
    "validation": f"{BASE_PATH}/deepfashion2_original_images/validation/image",
    "test": f"{BASE_PATH}/deepfashion2_original_images/test/test/image",
}
# Per-split metadata CSVs, each named after its split.
CSV_PATHS = {
    split: f"{BASE_PATH}/img_info_dataframes/{split}.csv"
    for split in ("train", "validation", "test")
}

# (height, width) passed to A.Resize in the augmentation pipelines.
IMAGE_SIZE = (224, 224)

# Class names indexed by label id; index 0 is the background class.
CLASSES = [
    "background",
    "short sleeve top",
    "trousers",
    "shorts",
    "long sleeve top",
    "skirt",
    "vest dress",
    "short sleeve dress",
    "vest",
    "long sleeve outwear",
    "long sleeve dress",
    "sling dress",
    "sling",
    "short sleeve outwear",
]
NUM_CLASSES = len(CLASSES)

# Per-class weights for the classification cross-entropy (larger weight =>
# larger contribution to the loss). NOTE(review): how these numbers were
# derived is not shown here.
CLASS_WEIGHTS_DICT = {
    "short sleeve top": 0.335,
    "trousers": 0.434,
    "shorts": 0.656,
    "long sleeve top": 0.666,
    "skirt": 0.779,
    "vest dress": 1.338,
    "short sleeve dress": 1.395,
    "vest": 1.492,
    "long sleeve outwear": 1.785,
    "long sleeve dress": 3.037,
    "sling dress": 3.699,
    "sling": 12.098,
    "short sleeve outwear": 44.225,
}
# Weight tensor aligned with CLASSES; background keeps a neutral weight of 1.0.
CLASS_WEIGHTS = torch.tensor(
    [1.0 if idx == 0 else CLASS_WEIGHTS_DICT[name] for idx, name in enumerate(CLASSES)],
    dtype=torch.float32,
)

def set_seed(seed: int = 42) -> None:
    """Seed all RNGs used by this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and every
    CUDA device), then turns off cuDNN autotuning and requests deterministic
    cuDNN algorithms for reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(42)


def safe_read_csv(path: str) -> pd.DataFrame:
    """Read ``path`` as CSV, returning an empty DataFrame if the file is missing.

    Prints a one-line summary (row count and column names) on success and an
    error line when the file does not exist. Other read errors propagate.
    """
    try:
        frame = pd.read_csv(path)
    except FileNotFoundError:
        print(f"❌ Could not find CSV file at {path}")
        return pd.DataFrame()
    print(f"✅ Loaded dataset: {len(frame)} samples | columns: {list(frame.columns)}")
    return frame


def load_dataset(split: str) -> pd.DataFrame:
    """Load the metadata CSV for ``split`` (a key of ``CSV_PATHS``).

    An unknown split raises KeyError; a missing file yields an empty
    DataFrame via ``safe_read_csv``.
    """
    csv_path = CSV_PATHS[split]
    return safe_read_csv(csv_path)


def load_image(filename: str, split: str) -> np.ndarray:
    """Load an image of ``split`` from disk as an RGB array.

    Args:
        filename: file name relative to ``IMAGE_PATHS[split]``.
        split: one of the keys of ``IMAGE_PATHS``.

    Returns:
        The decoded image converted from OpenCV's BGR order to RGB, or ``None``
        when the file is missing or cannot be decoded (an error line is printed
        in both cases).
    """
    full_path = os.path.join(IMAGE_PATHS[split], filename)
    if not os.path.exists(full_path):
        print(f"❌ Image not found: {full_path}")
        return None
    try:
        img = cv2.imread(full_path)
        # cv2.imread reports an undecodable file by returning None, not by raising.
        if img is None:
            raise ValueError("cv2.imread could not decode the file")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    except Exception as exc:
        print(f"❌ Error loading image {full_path}: {exc}")
        return None



def _polygons_from_sequence(seq) -> list:
    """Convert a list of coordinate lists into ``(N, 2)`` int32 point arrays.

    A flat list of numbers is treated as a single polygon. Entries that are not
    numeric, have fewer than 3 points (6 coordinates) or an odd number of
    coordinates are skipped individually instead of aborting the whole parse.
    """
    if seq and all(isinstance(v, (int, float, np.integer, np.floating)) for v in seq):
        seq = [seq]
    polygons = []
    for poly in seq:
        try:
            arr = np.array(poly, dtype=np.int32)
        except (TypeError, ValueError):
            continue
        if arr.size < 6 or arr.size % 2:
            continue
        polygons.append(arr.reshape(-1, 2))
    return polygons


def parse_segmentation(seg_text) -> list:
    """Parse a segmentation annotation into a list of ``(N, 2)`` int32 polygons.

    Accepts ``None``/NaN (-> ``[]``), a list of polygons (each a flat
    ``[x1, y1, x2, y2, ...]`` list or a list of ``[x, y]`` pairs), a flat list
    of coordinates (one polygon), a NumPy array (one polygon), or the string
    form of a list as stored in the CSV. Malformed polygons are dropped
    individually; unparsable input yields ``[]``.
    """
    if seg_text is None:
        return []
    if isinstance(seg_text, float) and pd.isna(seg_text):
        return []

    if isinstance(seg_text, list):
        return _polygons_from_sequence(seg_text)

    if isinstance(seg_text, np.ndarray):
        arr = seg_text.astype(np.int32)
        # An odd coordinate count cannot be reshaped into (x, y) pairs.
        return [arr.reshape(-1, 2)] if arr.size >= 6 and arr.size % 2 == 0 else []

    if isinstance(seg_text, str):
        s = seg_text.strip()
        if s in {"", "[]", "nan", "none", "None"}:
            return []
        try:
            parsed = ast.literal_eval(s)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # The exceptions ast.literal_eval documents for malformed input.
            return []
        if isinstance(parsed, list):
            return _polygons_from_sequence(parsed)

    return []


def create_mask(polygons: list, image_shape) -> np.ndarray:
    """Rasterise polygons into a single-channel uint8 mask.

    Args:
        polygons: iterable of point lists; only entries that convert to a
            2-D int32 array with 2 columns and at least 3 rows are drawn,
            anything else is ignored.
        image_shape: shape whose first two entries give the mask size.

    Returns:
        uint8 array of shape ``image_shape[:2]``, 255 inside any polygon and 0
        elsewhere. Each polygon is filled separately.
    """
    canvas = np.zeros(image_shape[:2], dtype=np.uint8)
    for candidate in polygons:
        pts = np.asarray(candidate, dtype=np.int32)
        is_point_list = pts.ndim == 2 and pts.shape[1] == 2
        if not is_point_list or pts.shape[0] < 3:
            continue
        cv2.fillPoly(canvas, [pts], 255)
    return canvas


def clip_and_validate_bbox(bbox, img_w, img_h):
    """Clamp an ``[x1, y1, x2, y2]`` box to the image and reject degenerate results.

    x coordinates are clamped to ``[0, img_w - 1]`` and y coordinates to
    ``[0, img_h - 1]``. Returns the clamped box as a list, or ``None`` when the
    clamped box has zero or negative width or height.
    """
    def _clamp(value, upper):
        return max(0, min(value, upper))

    left, top, right, bottom = bbox
    left, right = _clamp(left, img_w - 1), _clamp(right, img_w - 1)
    top, bottom = _clamp(top, img_h - 1), _clamp(bottom, img_h - 1)
    if right <= left or bottom <= top:
        return None
    return [left, top, right, bottom]



# Albumentations pipelines applied to (image, masks, bboxes, labels).
# torchvision's Mask R-CNN expects float images in the 0-1 range and applies
# ImageNet mean/std normalization itself (GeneralizedRCNNTransform), so the
# pipelines only rescale pixels by 1/255. Normalizing with ImageNet statistics
# here as well would normalize the images twice.
train_transform = A.Compose(
    [
        A.Resize(*IMAGE_SIZE),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
        # (x - 0 * 255) / (1 * 255): plain rescale to [0, 1].
        A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), max_pixel_value=255.0),
        ToTensorV2(),
    ],
    # clip=True matches eval_transform: boxes are clamped to the image instead of raising.
    bbox_params=A.BboxParams(format="pascal_voc", min_visibility=0.1, label_fields=["labels"], clip=True),
)

eval_transform = A.Compose(
    [
        A.Resize(*IMAGE_SIZE),
        A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), max_pixel_value=255.0),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format="pascal_voc", min_visibility=0.0, label_fields=["labels"], clip=True),
)



class DeepFashionMaskRCNNDataset(Dataset):
    """Map DeepFashion2 annotation rows to ``(image, target)`` pairs for Mask R-CNN.

    Each row of ``df`` describes one garment instance: an image ``path``
    (relative to ``IMAGE_PATHS[split]``), a ``b_box`` given as ``[x, y, w, h]``,
    a ``category_name`` and a polygon ``segmentation``. The target dict holds
    ``boxes`` (``[x1, y1, x2, y2]``, float32), ``labels`` (ids into ``CLASSES``),
    binary uint8 ``masks``, ``image_id``, ``area`` and ``iscrowd``. For
    ``split == "test"`` an empty dict is returned in place of the target.
    """

    def __init__(self, df: pd.DataFrame, split: str = "train"):
        self.df = df.reset_index(drop=True)
        self.split = split
        if split != "test":
            # Ids follow the global CLASSES order (0 = background) so they agree
            # with CLASS_WEIGHTS, the predictor heads and the class names used in
            # evaluation, and are identical for every split and subsample.
            self.category_to_id = {name: idx for idx, name in enumerate(CLASSES) if idx > 0}
            if "category_name" in df.columns:
                unknown = sorted({str(c) for c in df["category_name"].dropna()} - set(self.category_to_id))
                if unknown:
                    print(f"⚠️ Categories not in CLASSES (left unlabeled): {unknown}")
        else:
            self.category_to_id = {}

    def __len__(self):
        return len(self.df)

    def _parse_bbox(self, bbox_str, img_w, img_h):
        """Convert an ``[x, y, w, h]`` box (or its string form) to ``[[x1, y1, x2, y2]]``.

        The box is clamped to the image; an empty list is returned for missing,
        malformed or degenerate boxes.
        """
        try:
            bbox = ast.literal_eval(bbox_str) if isinstance(bbox_str, str) else bbox_str
            if not bbox or len(bbox) != 4:
                return []
            x, y, w, h = bbox
            x1 = max(0, x)
            y1 = max(0, y)
            x2 = min(img_w - 1, x + w)
            y2 = min(img_h - 1, y + h)
            if x2 <= x1 or y2 <= y1:
                return []
            return [[x1, y1, x2, y2]]
        except Exception as exc:
            print(f"Error parsing bbox: {exc}")
            return []

    # NOTE(review): not referenced by the code visible in this file.
    def _bbox_is_valid(self, bbox):
        x1, y1, x2, y2 = bbox
        if not all(isinstance(v, (int, float)) for v in bbox):
            return False
        if x2 <= x1 + 2.0 or y2 <= y1 + 2.0:
            return False
        if x1 < 0 or y1 < 0 or x2 > IMAGE_SIZE[1] or y2 > IMAGE_SIZE[0]:
            return False
        return True

    def __getitem__(self, idx):
        # A sample whose transform fails is replaced by the following one. The
        # retries are bounded so a fully broken dataset raises instead of
        # recursing until RecursionError.
        n_rows = len(self.df)
        for offset in range(n_rows):
            sample = self._load_sample((idx + offset) % n_rows)
            if sample is not None:
                return sample
        raise RuntimeError(f"No sample could be transformed (starting at idx {idx})")

    def _load_sample(self, idx):
        """Build ``(image_tensor, target)`` for row ``idx``; ``None`` if the transform raises."""
        row = self.df.iloc[idx]

        image = load_image(row["path"], self.split)
        if image is None:
            # Unreadable image: substitute a black frame so the batch still forms.
            image = np.zeros((IMAGE_SIZE[0], IMAGE_SIZE[1], 3), dtype=np.uint8)
        img_h, img_w = image.shape[:2]

        boxes = self._parse_bbox(row.get("b_box"), img_w, img_h)

        labels = []
        if boxes and self.split != "test":
            category_name = row.get("category_name")
            if category_name and not pd.isna(category_name):
                cid = self.category_to_id.get(category_name)
                if cid is not None:
                    labels = [int(cid)] * len(boxes)

        masks = []
        if self.split != "test":
            polys = parse_segmentation(row.get("segmentation"))
            if polys:
                mask = create_mask(polys, image.shape)
                if mask.any():
                    # create_mask paints 255, but Mask R-CNN's mask loss is a
                    # binary cross-entropy, so targets must be 0/1.
                    masks = [(mask > 0).astype(np.uint8)]

        # albumentations pairs each box with one label: drop boxes we could not label.
        if len(labels) != len(boxes):
            boxes, labels = [], []
        # Mask R-CNN needs exactly one mask per box: drop half-annotated instances.
        if self.split != "test" and len(masks) != len(boxes):
            boxes, labels, masks = [], [], []

        transform = train_transform if self.split == "train" else eval_transform
        try:
            transformed = transform(image=image, masks=masks, bboxes=boxes, labels=labels)
        except Exception as exc:
            print(f"Transform error at idx {idx} ({row.get('path')}): {exc}")
            return None

        image_t = transformed["image"]
        # len() instead of truthiness: depending on the albumentations version
        # these outputs may be NumPy arrays, whose truth value is ambiguous.
        out_boxes = transformed["bboxes"]
        out_labels = transformed["labels"]
        out_masks = transformed.get("masks")

        if len(out_boxes):
            boxes_t = torch.as_tensor(np.asarray(out_boxes, dtype=np.float32)).reshape(-1, 4)
        else:
            boxes_t = torch.zeros((0, 4), dtype=torch.float32)
        if len(out_labels):
            labels_t = torch.as_tensor(np.asarray(out_labels, dtype=np.int64)).reshape(-1)
        else:
            labels_t = torch.zeros((0,), dtype=torch.int64)
        if out_masks is not None and len(out_masks):
            masks_t = torch.stack([torch.as_tensor(m, dtype=torch.uint8) for m in out_masks])
        else:
            masks_t = torch.zeros((0, IMAGE_SIZE[0], IMAGE_SIZE[1]), dtype=torch.uint8)

        if len(out_boxes) == 0 and len(boxes) > 0:
            print(f"Dropped all boxes at idx {idx}, path={row.get('path')}")

        # A box removed by min_visibility leaves its mask behind; keep counts aligned.
        if self.split != "test" and len(masks_t) != len(boxes_t):
            boxes_t, labels_t, masks_t = boxes_t[:0], labels_t[:0], masks_t[:0]

        # Works for an empty (0, 4) tensor too (yields an empty area tensor).
        area = (boxes_t[:, 2] - boxes_t[:, 0]) * (boxes_t[:, 3] - boxes_t[:, 1])

        target = {
            "boxes": boxes_t,
            "labels": labels_t,
            "masks": masks_t,
            "image_id": torch.tensor([idx]),
            "area": area,
            "iscrowd": torch.zeros((len(boxes_t),), dtype=torch.int64),
        }

        if self.split == "test":
            return image_t, {}
        return image_t, target


def collate_fn(batch):
    """Collate detection samples into ``(list_of_images, list_of_targets)``.

    Samples carry different numbers of boxes/masks, so nothing is stacked.
    ``None`` samples are discarded; if none remain, a single blank
    ``3 x H x W`` image with an empty target is returned instead.
    """
    kept = [sample for sample in batch if sample is not None]
    if kept:
        images = [image for image, _ in kept]
        targets = [target for _, target in kept]
        return images, targets

    height, width = IMAGE_SIZE
    empty_target = {
        "boxes": torch.zeros((0, 4), dtype=torch.float32),
        "labels": torch.zeros((0,), dtype=torch.int64),
        "masks": torch.zeros((0, height, width), dtype=torch.uint8),
        "image_id": torch.tensor([0]),
        "area": torch.zeros((0,), dtype=torch.float32),
        "iscrowd": torch.zeros((0,), dtype=torch.int64),
    }
    return [torch.zeros((3, height, width), dtype=torch.float32)], [empty_target]



def build_weighted_maskrcnn(num_classes: int, class_weights: torch.Tensor = None, device: str = "cuda"):
    """Build a pretrained Mask R-CNN (ResNet-50 FPN) with heads for ``num_classes``.

    The box and mask predictors are replaced so they output ``num_classes``
    classes, including background at index 0.

    If ``class_weights`` is given, the ROI-head classification loss becomes a
    weighted cross-entropy. torchvision's ``RoIHeads.forward`` calls the
    module-level function ``torchvision.models.detection.roi_heads.fastrcnn_loss``,
    so that function is replaced. Assigning an attribute on ``model.roi_heads``
    would have no effect. NOTE: the patch is process-wide, so it also applies
    to any other torchvision R-CNN model trained in the same process. Building
    again re-patches from the original (unwrapped) function.

    Args:
        num_classes: number of classes including background.
        class_weights: optional 1-D tensor with one weight per class.
        device: device the weights are initially moved to; at loss time they
            follow the logits' device.

    Returns:
        The model (not moved to ``device``).

    Raises:
        ValueError: if ``class_weights`` does not have ``num_classes`` entries.
    """
    weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
    model = maskrcnn_resnet50_fpn(weights=weights)

    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    if class_weights is not None:
        import functools
        from torchvision.models.detection import roi_heads as roi_heads_module

        w = class_weights.detach().float().to(device)
        if w.numel() != num_classes:
            raise ValueError(f"class_weights has {w.numel()} entries, expected {num_classes}")

        # functools.wraps records the wrapped function in __wrapped__, so
        # repeated builds unwrap to torchvision's original instead of stacking.
        base_loss = getattr(roi_heads_module.fastrcnn_loss, "__wrapped__", roi_heads_module.fastrcnn_loss)

        @functools.wraps(base_loss)
        def weighted_fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
            # Keep torchvision's box-regression loss; recompute only the classifier term.
            _, box_loss = base_loss(class_logits, box_regression, labels, regression_targets)
            flat_labels = torch.cat(labels, dim=0)
            cls_loss = F.cross_entropy(class_logits, flat_labels, weight=w.to(class_logits.device))
            return cls_loss, box_loss

        roi_heads_module.fastrcnn_loss = weighted_fastrcnn_loss

    return model


def train_maskrcnn_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean losses.

    The model is switched to training mode, so it returns a dict of loss
    tensors for each ``(images, targets)`` batch; their sum is backpropagated.
    Every 1000th batch (starting at batch 0) is logged.

    Returns:
        ``(mean_total_loss, {component_name: mean_value})``, averaged over the
        number of batches (``0.0`` and ``{}`` for an empty loader).
    """
    model.train()
    n_batches = len(loader)
    running_total = 0.0
    running_parts = defaultdict(float)

    for batch_idx, (images, targets) in enumerate(loader):
        images = [image.to(device) for image in images]
        targets = [{key: value.to(device) for key, value in target.items()} for target in targets]

        optimizer.zero_grad()
        losses = model(images, targets)
        batch_loss = sum(losses.values())
        batch_loss.backward()
        optimizer.step()

        running_total += batch_loss.item()
        for name, part in losses.items():
            running_parts[name] += part.item()

        if batch_idx % 1000 == 0:
            detail = ", ".join(f"{name}:{part.item():.4f}" for name, part in losses.items())
            print(f"Batch {batch_idx}/{n_batches} | Loss {batch_loss.item():.4f} | {detail}")

    denom = max(1, n_batches)
    return running_total / denom, {name: total / denom for name, total in running_parts.items()}


def validate_maskrcnn(model, loader, device):
    """Average the model's losses over ``loader`` without updating weights.

    torchvision detection models only return a loss dict in training mode, so
    the model is put in ``train()`` mode and gradients are disabled instead.
    The proposal sampling torchvision performs in training mode is randomized,
    so repeated calls can give slightly different values. The model is left
    in training mode.

    Returns:
        ``(mean_total_loss, {component_name: mean_value})``, averaged over the
        number of batches (``0.0`` and ``{}`` for an empty loader).
    """
    model.train()
    batch_count = len(loader)
    overall = 0.0
    totals = defaultdict(float)

    with torch.no_grad():
        for batch_images, batch_targets in loader:
            batch_images = [im.to(device) for im in batch_images]
            batch_targets = [{name: tensor.to(device) for name, tensor in tgt.items()} for tgt in batch_targets]
            losses = model(batch_images, batch_targets)
            overall += sum(losses.values()).item()
            for name, value in losses.items():
                totals[name] += value.item()

    scale = max(1, batch_count)
    return overall / scale, {name: value / scale for name, value in totals.items()}


def train_model(model, train_loader, optimizer, device, num_epochs=5, save_dir="training_results", val_loader=None):
    """Train for ``num_epochs`` epochs, checkpointing after every epoch.

    Each epoch runs ``train_maskrcnn_epoch`` and, when ``val_loader`` is given,
    ``validate_maskrcnn``. A checkpoint is written to
    ``save_dir/model_epoch_{n}.pth``. It holds the epoch number, model and
    optimizer state dicts, the train loss, and the val loss when validation
    ran.

    Returns:
        dict with:
        - ``train_losses`` and ``train_loss_components``: per-epoch lists.
        - ``val_losses`` and ``val_loss_components``: per-epoch lists, empty
          without a val loader.
        - ``training_time`` and ``avg_epoch_time``: seconds, including
          validation time.
    """
    os.makedirs(save_dir, exist_ok=True)
    print("=" * 80)
    print(f"Starting training at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Device: {device} | Epochs: {num_epochs} | Save dir: {save_dir}")
    print("=" * 80)

    train_losses, train_components_hist, epoch_times = [], [], []
    val_losses, val_components_hist = [], []

    for epoch in range(num_epochs):
        t0 = time.time()
        print(f"\nEpoch {epoch+1}/{num_epochs}\n" + "-" * 50)

        tr_loss, tr_comps = train_maskrcnn_epoch(model, train_loader, optimizer, device)
        train_losses.append(tr_loss)
        train_components_hist.append(tr_comps)

        val_loss = None
        if val_loader is not None:
            val_loss, val_comps = validate_maskrcnn(model, val_loader, device)
            # Keep the validation results instead of discarding them after printing.
            val_losses.append(val_loss)
            val_components_hist.append(val_comps)
            print(f"Val Loss: {val_loss:.4f}")

        epoch_time = time.time() - t0
        epoch_times.append(epoch_time)
        print(f"Train Loss: {tr_loss:.4f} | Epoch time: {epoch_time:.1f}s")

        ckpt = {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "train_loss": tr_loss,
        }
        if val_loss is not None:
            ckpt["val_loss"] = val_loss
        torch.save(ckpt, os.path.join(save_dir, f"model_epoch_{epoch+1}.pth"))
        print(f"✓ Saved checkpoint: epoch {epoch+1}")

    total_time = sum(epoch_times)
    # Guard the mean: np.mean([]) warns and returns nan when num_epochs == 0.
    avg_epoch_time = float(np.mean(epoch_times)) if epoch_times else 0.0
    print("\n" + "=" * 80)
    print("TRAINING COMPLETE")
    print(f"Total time: {total_time:.1f}s | Avg/epoch: {avg_epoch_time:.1f}s")
    print("=" * 80)

    return {
        "train_losses": train_losses,
        "train_loss_components": train_components_hist,
        "val_losses": val_losses,
        "val_loss_components": val_components_hist,
        "training_time": total_time,
        "avg_epoch_time": avg_epoch_time,
    }


def mask_iou(m1: np.ndarray, m2: np.ndarray) -> float:
    """Intersection-over-union of two masks (any non-zero value counts as inside).

    Returns 0.0 when both masks are empty.
    """
    intersection = np.count_nonzero(np.logical_and(m1, m2))
    union = np.count_nonzero(np.logical_or(m1, m2))
    if union == 0:
        return 0.0
    return float(intersection / union)


def _box_iou(box_a, box_b) -> float:
    """IoU of two ``[x1, y1, x2, y2]`` boxes; 0.0 when the union is empty."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return float(inter / union) if union > 0 else 0.0


def _average_precision(scores, correct, num_gt) -> float:
    """All-point interpolated average precision (PASCAL VOC 2010+ style) for one class.

    Args:
        scores: confidence of each detection of the class.
        correct: whether each detection matched a ground-truth instance.
        num_gt: number of ground-truth instances; missed objects lower recall
            through this denominator.

    Returns:
        AP in ``[0, 1]``; 0.0 when there is no ground truth or no detection.
    """
    if num_gt <= 0 or len(scores) == 0:
        return 0.0
    order = np.argsort(-np.asarray(scores, dtype=np.float64), kind="stable")
    hits = np.asarray(correct, dtype=np.float64)[order]
    tp = np.cumsum(hits)
    precision = tp / np.arange(1, len(hits) + 1)
    recall = tp / num_gt
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    # Precision envelope: make precision non-increasing as recall grows.
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]
    steps = np.where(mrec[1:] != mrec[:-1])[0]
    return float(np.sum((mrec[steps + 1] - mrec[steps]) * mpre[steps + 1]))


def evaluate_model(model, data_loader, device, iou_thresh=0.5, score_thresh=0.5, num_classes=NUM_CLASSES):
    """Compute detection AP and per-class mask IoU on ``data_loader``.

    Detections with score >= ``score_thresh`` are matched, in descending score
    order, to the unmatched ground-truth instance of the same class with the
    highest IoU. A match counts when that IoU is >= ``iou_thresh``. Mask IoU is
    used when predictions and targets both provide one mask per instance;
    otherwise box IoU is used.

    Returns:
        dict with:
        - ``per_class_AP``: ``{class_id: AP}`` for ids ``1..num_classes-1``.
          Unmatched ground truth counts as false negatives. Classes without
          ground truth report 0.0.
        - ``mAP``: mean AP over the classes that have ground truth.
        - ``per_class_mIoU``: ``{class_id: mean best mask IoU per GT instance}``.
          A GT instance with no same-class prediction contributes 0.0.
    """
    model.eval()

    detection_records = []
    gt_counts = Counter()
    per_class_ious = defaultdict(list)

    with torch.no_grad():
        for images, targets in data_loader:
            images = [img.to(device) for img in images]
            outputs = model(images)

            for out, tgt in zip(outputs, targets):
                scores = out["scores"].cpu().numpy()
                keep = np.flatnonzero(scores >= score_thresh)
                keep = keep[np.argsort(-scores[keep], kind="stable")]
                p_scores = scores[keep]
                p_boxes = out["boxes"].cpu().numpy()[keep]
                p_labels = out["labels"].cpu().numpy()[keep]
                p_masks = out["masks"].cpu().numpy()[keep, 0] > 0.5 if "masks" in out else None

                t_boxes = tgt["boxes"].cpu().numpy()
                t_labels = tgt["labels"].cpu().numpy()
                t_masks = tgt["masks"].cpu().numpy() > 0 if "masks" in tgt else None
                use_masks = (
                    p_masks is not None and len(p_masks) == len(p_boxes)
                    and t_masks is not None and len(t_masks) == len(t_boxes)
                )

                gt_counts.update(int(lab) for lab in t_labels)

                used = set()
                for i, (pb, ps, pl) in enumerate(zip(p_boxes, p_scores, p_labels)):
                    best_j, best_iou = -1, -1.0
                    for j, (tb, tl) in enumerate(zip(t_boxes, t_labels)):
                        if j in used or tl != pl:
                            continue
                        iou = mask_iou(p_masks[i], t_masks[j]) if use_masks else _box_iou(pb, tb)
                        if iou > best_iou:
                            best_j, best_iou = j, iou
                    matched = best_j >= 0 and best_iou >= iou_thresh
                    if matched:
                        used.add(best_j)
                    detection_records.append({"label": int(pl), "score": float(ps), "correct": bool(matched)})

                if use_masks:
                    for k, tl in enumerate(t_labels):
                        same_class = [p_masks[i] for i, pl in enumerate(p_labels) if pl == tl]
                        best = max((mask_iou(pm, t_masks[k]) for pm in same_class), default=0.0)
                        per_class_ious[int(tl)].append(best)

    per_class_ap = {}
    for c in range(1, num_classes):
        recs = [r for r in detection_records if r["label"] == c]
        per_class_ap[c] = _average_precision(
            [r["score"] for r in recs], [r["correct"] for r in recs], gt_counts.get(c, 0)
        )
    evaluated = [ap for c, ap in per_class_ap.items() if gt_counts.get(c, 0) > 0]
    mAP = float(np.mean(evaluated)) if evaluated else 0.0

    per_class_mIoU = {c: (float(np.mean(v)) if len(v) else 0.0) for c, v in per_class_ious.items()}
    return {"mAP": mAP, "per_class_AP": per_class_ap, "per_class_mIoU": per_class_mIoU}


def _has_usable_bbox(value) -> bool:
    """Return True if ``value`` is an ``[x, y, w, h]`` box with ``w >= 3`` and ``h >= 3``.

    ``value`` may be the box itself (list/tuple/array) or its string form as
    stored in the CSV. Missing, unparsable or malformed values give False.
    """
    if isinstance(value, str):
        try:
            value = ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return False
    try:
        _, _, w, h = list(value)[:4]
        return bool(w >= 3 and h >= 3)
    except (TypeError, ValueError):
        # None/NaN (not iterable), too few entries, or non-numeric sizes.
        return False


def clean_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only rows whose ``b_box`` is a box at least 3 px wide and 3 px tall.

    Boxes may be stored as lists or as their string representation. The result
    keeps the original columns and dtypes, even when no row survives, and has a
    fresh ``RangeIndex``. If there is no ``b_box`` column, no row can be
    validated and an empty frame with the same columns is returned.
    """
    if "b_box" not in df.columns:
        return df.iloc[0:0].reset_index(drop=True)
    keep = df["b_box"].map(_has_usable_bbox).astype(bool)
    return df.loc[keep].reset_index(drop=True)


from sklearn.model_selection import train_test_split


def stratified_subsample(df, n_samples, label_col="category_name", random_state=42):
    """Draw ``n_samples`` rows of ``df``, preserving the ``label_col`` class proportions.

    Returns ``df`` itself when it has at most ``n_samples`` rows. When
    stratification is impossible, ``train_test_split`` rejects it; this happens
    when a class has a single row, labels are missing, or fewer rows are
    requested than there are classes. In that case a plain random sample of
    the same size is drawn instead, with a printed warning.

    Args:
        df: frame to sample from.
        n_samples: number of rows to keep.
        label_col: column whose class distribution is preserved.
        random_state: seed for both the stratified split and the fallback.
    """
    if n_samples >= len(df):
        return df
    try:
        subset, _ = train_test_split(
            df, train_size=n_samples, stratify=df[label_col], random_state=random_state
        )
    except (ValueError, TypeError) as exc:
        print(f"⚠️ Stratified sampling failed ({exc}); using a random sample instead")
        subset = df.sample(n=n_samples, random_state=random_state)
    return subset


def plot_category_distribution(df: pd.DataFrame, split_name: str = "train"):
    """Bar-plot the number of rows per ``category_name``, annotated with the counts.

    Prints a message and returns without plotting when the column is absent.
    """
    if "category_name" not in df.columns:
        print(f"❌ No category information available for {split_name} split")
        return

    per_category = df["category_name"].value_counts()
    positions = range(len(per_category))

    plt.figure(figsize=(12, 6))
    bars = plt.bar(positions, per_category.values, alpha=0.7)
    plt.xlabel("Clothing Categories")
    plt.ylabel("Number of Samples")
    plt.title(f"Category Distribution - {split_name.upper()}")
    plt.xticks(positions, per_category.index, rotation=45, ha="right")
    for bar, count in zip(bars, per_category.values):
        x_center = bar.get_x() + bar.get_width() / 2
        plt.text(x_center, bar.get_height() + 10, str(count), ha="center", va="bottom")
    plt.tight_layout()
    plt.show()


def show_image_grid(df, split='train', start_index=0, rows=2, cols=3):
    """Show a ``rows x cols`` grid of images, starting at row ``start_index`` of ``df``.

    For non-test splits the segmentation is tinted on top of each image. The
    available metadata (category, scale, occlusion, zoom_in, viewpoint) is
    shown as the subplot title. Cells past the end of ``df``, or whose image
    cannot be loaded, show a placeholder message.
    """
    # squeeze=False always returns a 2-D array of Axes, whatever rows/cols are.
    _fig, axes = plt.subplots(rows, cols, figsize=(15, 10), squeeze=False)

    def _placeholder(ax, message):
        ax.text(0.5, 0.5, message, ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_xticks([])
        ax.set_yticks([])

    # (column name, label shown in the title)
    title_fields = [
        ('category_name', 'category_name'),
        ('scale', 'Scale'),
        ('occlusion', 'Occl'),
        ('zoom_in', 'zoom_in'),
        ('viewpoint', 'viewpoint'),
    ]

    for i, ax in enumerate(axes.flat):
        index = start_index + i
        if index >= len(df):
            _placeholder(ax, 'No more images')
            continue

        row = df.iloc[index]
        img = load_image(row['path'], split)
        if img is None:
            _placeholder(ax, 'Image not found')
            continue

        ax.imshow(img)

        if split != 'test':
            if 'segmentation' in row:
                # parse_segmentation itself handles None/NaN; avoid pd.isna on list values.
                polygons = parse_segmentation(row['segmentation'])
                if polygons:
                    mask = create_mask(polygons, img.shape)
                    # RGBA overlay that is transparent outside the mask, so only the
                    # garment is tinted. An RGB overlay drawn with alpha=0.4 would
                    # also darken every non-mask pixel with its black background.
                    overlay = np.zeros((*mask.shape, 4), dtype=np.float32)
                    overlay[mask > 0] = (1.0, 100 / 255, 100 / 255, 0.4)
                    ax.imshow(overlay)

            title_parts = [
                f"{label}: {row[key]}"
                for key, label in title_fields
                if key in row and not pd.isna(row[key])
            ]
            ax.set_title('\n'.join(title_parts), fontsize=8)

        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.suptitle(f'{split.upper()} Dataset', fontsize=14, y=1.02)
    plt.show()


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading datasets.")
    train_df = load_dataset("train")
    val_df = load_dataset("validation")

    print("Training samples grid:")
    show_image_grid(train_df, 'train', start_index=0, rows=5, cols=5)

    print("Validation samples grid:")
    show_image_grid(val_df, 'validation', start_index=0, rows=5, cols=5)

    # Subsample sizes for quick experiments; set to None to use a full split.
    train_size = 300
    val_size = 200

    if train_size is not None:
        print(f"Subsampling train -> {train_size} samples")
        train_df = stratified_subsample(train_df, train_size)

    if val_size is not None:
        print(f"Subsampling validation -> {val_size} samples")
        val_df = stratified_subsample(val_df, val_size)

    # Apply the same box filtering to both splits, so validation losses and
    # metrics are computed on the same kind of annotations as training.
    if not train_df.empty:
        train_df = clean_dataframe(train_df)
    if not val_df.empty:
        val_df = clean_dataframe(val_df)

    if train_df.empty:
        # DataLoader(shuffle=True) rejects an empty dataset with an opaque error.
        raise RuntimeError("❌ No usable training samples; check CSV_PATHS and the b_box column.")

    show_stats = False
    if show_stats:
        plot_category_distribution(train_df, "train")

    train_set = DeepFashionMaskRCNNDataset(train_df, split="train")
    val_set = DeepFashionMaskRCNNDataset(val_df, split="validation") if not val_df.empty else None

    train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)
    val_loader = DataLoader(val_set, batch_size=4, shuffle=False, num_workers=2, collate_fn=collate_fn) if val_set else None

    model = build_weighted_maskrcnn(NUM_CLASSES, CLASS_WEIGHTS, device=str(device))
    model.to(device)

    optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.005, momentum=0.9, weight_decay=0.0005)

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    train_history = train_model(
        model,
        train_loader=train_loader,
        optimizer=optimizer,
        device=device,
        num_epochs=5,
        save_dir="/kaggle/working/maskrcnn_training_results",
        val_loader=val_loader,
    )

    if val_loader is not None:
        results = evaluate_model(model, val_loader, device, num_classes=NUM_CLASSES)
        print("Evaluation:", results)

    print("🎉 Done.")


# Evaluation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, f1_score, average_precision_score, confusion_matrix
from collections import defaultdict
import torch
import time
from datetime import datetime
import os


def validate_model(
    model,
    val_loader,
    device,
    save_dir="training_results"
):
    """Run ``validate_maskrcnn`` once and report the loss and elapsed time.

    ``save_dir`` is accepted but not used.
    NOTE(review): ``format_time`` is not defined in the code visible here;
    it is assumed to be defined elsewhere in the notebook — confirm.

    Returns:
        dict with ``val_loss``, ``val_loss_components`` and
        ``validation_time`` (seconds).
    """
    banner = "=" * 80
    print(banner)
    print(f"Starting validation at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(banner)

    started = time.time()
    val_loss, val_components = validate_maskrcnn(model, val_loader, device)
    elapsed = time.time() - started

    print(f"Validation Loss: {val_loss:.4f}")
    print(f"Validation time: {format_time(elapsed)}")
    print(banner)

    return {
        'val_loss': val_loss,
        'val_loss_components': val_components,
        'validation_time': elapsed,
    }

In [None]:
# Class names indexed by label id (0 = background). Reuse CLASSES from the
# training cell so that label ids, class weights and the names printed in
# evaluation reports cannot drift apart.
classes = list(CLASSES)
def evaluate_model_comprehensive(
    model,
    data_loader,
    device,
    num_classes=14,
    class_names=classes,
    save_dir="training_results",
    dataset_name="validation"
):
    """Run ``evaluate_model`` plus confusion-matrix and summary plots, and write a text report.

    ``save_dir`` is created if it does not exist. The report is written to
    ``save_dir/evaluation_summary_{dataset_name}.txt``. Integer metric keys
    inside ``class_names``' range are printed with their class name.
    NOTE(review): ``format_time``, ``plot_confusion_matrix`` and
    ``plot_metrics_summary`` are not defined in the code visible here; they are
    assumed to exist elsewhere in the notebook — confirm.

    Returns:
        dict with ``evaluation_time`` (seconds) and ``metrics`` (the
        ``evaluate_model`` result).
    """
    os.makedirs(save_dir, exist_ok=True)

    print("=" * 80)
    print(f"Starting comprehensive evaluation at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Dataset: {dataset_name}")
    print("=" * 80)

    evaluation_start_time = time.time()

    results = evaluate_model(model, data_loader, device, num_classes=num_classes)

    print("Generating confusion matrix...")
    y_true, y_pred = get_predictions_for_confusion_matrix(model, data_loader, device)

    confusion_plotted = False
    if len(y_true) > 0 and len(y_pred) > 0:
        plot_confusion_matrix(y_true, y_pred, class_names, save_dir)
        confusion_plotted = True

    plot_metrics_summary(results, class_names, save_dir)

    evaluation_time = time.time() - evaluation_start_time

    print(f"\nEvaluation completed in: {format_time(evaluation_time)}")

    results_summary = {
        'evaluation_time': evaluation_time,
        'metrics': results
    }

    summary_path = os.path.join(save_dir, f'evaluation_summary_{dataset_name}.txt')
    with open(summary_path, 'w', encoding='utf-8') as f:
        f.write(f"MASK R-CNN EVALUATION SUMMARY - {dataset_name.upper()}\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Evaluation completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Evaluation time: {format_time(evaluation_time)}\n")
        f.write(f"Number of classes: {num_classes}\n\n")

        f.write("METRICS:\n")
        f.write("-" * 30 + "\n")
        for metric, value in results.items():
            if isinstance(value, dict):
                f.write(f"{metric}:\n")
                for k, v in value.items():
                    # Only integer ids inside the name list map to a class name;
                    # any other key is written as-is instead of raising TypeError.
                    if class_names and isinstance(k, (int, np.integer)) and 0 <= k < len(class_names):
                        f.write(f"  {class_names[k]}: {v:.4f}\n")
                    else:
                        f.write(f"  {k}: {v:.4f}\n")
            else:
                f.write(f"{metric}: {value:.4f}\n")

    print(f"\nEvaluation results saved to: {save_dir}")
    print("Files generated:")
    if confusion_plotted:
        print("- confusion_matrix.png")
    print("- metrics_summary.png")
    print(f"- evaluation_summary_{dataset_name}.txt")

    return results_summary

def get_predictions_for_confusion_matrix(model, data_loader, device, score_thresh=0.5, iou_thresh=0.5):
    """Collect paired (true, predicted) labels for a detection confusion matrix.

    Runs ``model`` in eval mode over ``data_loader``. Detections with score >=
    ``score_thresh`` are kept and paired with the ground truth by
    ``match_predictions_to_gt``, which uses label 0 for "no counterpart".

    Returns:
        ``(y_true, y_pred)`` as NumPy arrays of equal length.
    """
    model.eval()
    y_pred_all = []
    y_true_all = []

    with torch.no_grad():
        for images, targets in data_loader:
            outputs = model([image.to(device) for image in images])

            for output, target in zip(outputs, targets):
                confident = output["scores"].cpu().numpy() >= score_thresh
                if confident.any():
                    kept_labels = output["labels"].cpu().numpy()[confident]
                    kept_boxes = output["boxes"].cpu().numpy()[confident]
                else:
                    kept_labels = np.array([])
                    kept_boxes = np.array([]).reshape(0, 4)

                preds, truths = match_predictions_to_gt(
                    kept_boxes,
                    kept_labels,
                    target["boxes"].cpu().numpy(),
                    target["labels"].cpu().numpy(),
                    iou_thresh,
                )
                y_pred_all.extend(preds)
                y_true_all.extend(truths)

    return np.array(y_true_all), np.array(y_pred_all)



def match_predictions_to_gt(pred_boxes, pred_labels, gt_boxes, gt_labels, iou_thresh):
    """Pair predictions with ground truth for a detection confusion matrix.

    Predictions are processed in the given order. Each is matched to the
    still-unmatched GT box with the highest IoU, provided that IoU is
    >= ``iou_thresh`` and > 0. Matching ignores class, so a box found in the
    right place but with the wrong label appears as an off-diagonal
    ``(true=gt_label, pred=pred_label)`` entry instead of as a false positive
    plus a miss. Unmatched predictions become ``(true=0, pred=label)`` and
    unmatched GT boxes become ``(true=label, pred=0)``, with 0 standing for
    background / "nothing".
    NOTE(review): ``calculate_box_iou`` is not defined in the code visible
    here; it is assumed to return the IoU of two boxes — confirm.

    Returns:
        ``(matched_preds, matched_gts)``: equal-length lists of labels.
    """
    matched_preds = []
    matched_gts = []
    used_gt = set()

    for pred_box, pred_label in zip(pred_boxes, pred_labels):
        best_iou = 0
        best_gt_idx = -1

        for gt_idx, gt_box in enumerate(gt_boxes):
            if gt_idx in used_gt:
                continue
            iou = calculate_box_iou(pred_box, gt_box)
            if iou > best_iou and iou >= iou_thresh:
                best_iou = iou
                best_gt_idx = gt_idx

        matched_preds.append(pred_label)
        if best_gt_idx != -1:
            matched_gts.append(gt_labels[best_gt_idx])
            used_gt.add(best_gt_idx)
        else:
            matched_gts.append(0)

    for gt_idx, gt_label in enumerate(gt_labels):
        if gt_idx not in used_gt:
            matched_preds.append(0)
            matched_gts.append(gt_label)

    return matched_preds, matched_gts

def _plot_train_val_pair(ax, epochs, train_values, val_values, train_label, val_label, linestyle='-'):
    """Draw a blue training curve and a red validation curve on ``ax``."""
    ax.plot(epochs, train_values, f'b{linestyle}', label=train_label, linewidth=2)
    ax.plot(epochs, val_values, f'r{linestyle}', label=val_label, linewidth=2)


def _finish_epoch_axis(ax, title, ylabel):
    """Apply the title, epoch x-label, y-label, legend and grid shared by every panel."""
    ax.set_title(title, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_training_curves(train_losses, val_losses, train_maps, val_maps, train_loss_components, val_loss_components, save_dir="plots"):
    """Plot per-epoch training/validation curves in a 2x3 grid and save them.

    Panels: total loss, mAP, classification loss, box-regression loss, mask
    loss, and the two RPN losses (objectness solid, RPN box dashed).

    Args:
        train_losses: Total training loss per epoch; its length defines the
            epoch axis (1..N).
        val_losses: Total validation loss per epoch (same length).
        train_maps: Per-epoch training mAP. If this or ``val_maps`` is empty
            or None, the mAP panel shows a placeholder message instead.
        val_maps: Per-epoch validation mAP.
        train_loss_components: One dict per epoch with the keys
            'classifier', 'box_reg', 'mask', 'objectness' and 'rpn_box_reg'.
        val_loss_components: Same structure for validation.
        save_dir: Directory (created if missing) that receives
            ``training_curves.png``.
    """
    os.makedirs(save_dir, exist_ok=True)

    epochs = range(1, len(train_losses) + 1)

    def component(history, key):
        # Extract one loss term's per-epoch series from a list of dicts.
        return [entry[key] for entry in history]

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('Training Progress', fontsize=16, fontweight='bold')

    # Total loss.
    _plot_train_val_pair(axes[0, 0], epochs, train_losses, val_losses, 'Train Loss', 'Val Loss')
    _finish_epoch_axis(axes[0, 0], 'Total Loss', 'Loss')

    # mAP, or a placeholder when it was not recorded.
    if train_maps and val_maps:
        _plot_train_val_pair(axes[0, 1], epochs, train_maps, val_maps, 'Train mAP', 'Val mAP')
        _finish_epoch_axis(axes[0, 1], 'Mean Average Precision (mAP)', 'mAP')
    else:
        axes[0, 1].text(0.5, 0.5, 'mAP data not available',
                        ha='center', va='center', transform=axes[0, 1].transAxes)
        axes[0, 1].set_title('Mean Average Precision (mAP)', fontweight='bold')

    # Per-head loss terms, one panel each.
    for ax, key, title in (
        (axes[0, 2], 'classifier', 'Classification Loss'),
        (axes[1, 0], 'box_reg', 'Box Regression Loss'),
        (axes[1, 1], 'mask', 'Mask Loss'),
    ):
        _plot_train_val_pair(ax, epochs,
                             component(train_loss_components, key),
                             component(val_loss_components, key),
                             'Train', 'Val')
        _finish_epoch_axis(ax, title, 'Loss')

    # Both RPN terms share one panel: objectness solid, box regression dashed.
    rpn_ax = axes[1, 2]
    _plot_train_val_pair(rpn_ax, epochs,
                         component(train_loss_components, 'objectness'),
                         component(val_loss_components, 'objectness'),
                         'Train Objectness', 'Val Objectness')
    _plot_train_val_pair(rpn_ax, epochs,
                         component(train_loss_components, 'rpn_box_reg'),
                         component(val_loss_components, 'rpn_box_reg'),
                         'Train RPN Box', 'Val RPN Box', linestyle='--')
    _finish_epoch_axis(rpn_ax, 'RPN Losses', 'Loss')

    plt.tight_layout()
    plt.savefig(f'{save_dir}/training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure; without this, repeated calls under a non-interactive
    # backend keep every figure alive.
    plt.close(fig)

def plot_confusion_matrix(y_true, y_pred, class_names=None, save_dir="plots"):
    """Draw an annotated confusion-matrix heatmap and save it as a PNG.

    Args:
        y_true: Sequence of true integer class ids (for example the
            ``matched_gts`` output of ``match_predictions_to_gt``, where 0
            marks an unmatched prediction).
        y_pred: Sequence of predicted integer class ids, parallel to
            ``y_true``.
        class_names: Optional list where ``class_names[i]`` names class id
            ``i``. When given, the matrix always has one row/column per entry,
            even for classes that never occur. Ids outside
            ``range(len(class_names))`` are not counted.
        save_dir: Directory (created if missing) that receives
            ``confusion_matrix.png``.
    """
    os.makedirs(save_dir, exist_ok=True)

    if class_names is not None:
        # Pin the row/column order to class ids 0..K-1. Otherwise
        # confusion_matrix only includes labels present in the data, and one
        # missing class shifts every later tick label onto the wrong row.
        cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
        tick_labels = class_names
    else:
        cm = confusion_matrix(y_true, y_pred)
        # seaborn expects a list/bool/int/"auto" here; None is not accepted.
        tick_labels = "auto"

    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels, yticklabels=tick_labels,
                cbar_kws={'label': 'Count'})
    plt.title('Confusion Matrix', fontsize=16, fontweight='bold')
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.tight_layout()
    plt.savefig(f'{save_dir}/confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

def _class_tick_labels(classes, class_names):
    """Map class ids to display names, falling back to ``'Class {c}'``."""
    if class_names is None:
        return [f'Class {c}' for c in classes]
    return [class_names[c] if c < len(class_names) else f'Class {c}' for c in classes]


def _plot_per_class_bars(ax, per_class, class_names, title, ylabel, **bar_kwargs):
    """Bar-plot a ``{class_id: value}`` mapping on ``ax`` in the mapping's key order."""
    classes = list(per_class.keys())
    ax.bar(range(len(classes)), list(per_class.values()), alpha=0.7, **bar_kwargs)
    ax.set_title(title, fontweight='bold')
    ax.set_xlabel('Class')
    ax.set_ylabel(ylabel)
    ax.set_xticks(range(len(classes)))
    ax.set_xticklabels(_class_tick_labels(classes, class_names), rotation=45, ha='right')
    ax.grid(True, alpha=0.3)


def plot_metrics_summary(results, class_names=None, save_dir="plots"):
    """Render a 2x2 summary of evaluation metrics and save it as a PNG.

    Panels: per-class detection AP, per-class IoU, the overall scalar
    metrics, and detection vs. segmentation AP for classes present in both.
    Any panel whose data is missing or empty in ``results`` is left blank.

    Args:
        results: Dict that may contain the per-class mappings
            'detection_per_class_AP', 'segmentation_per_class_AP' and
            'per_class_IoU' (class id -> value), plus the scalar keys
            'precision', 'recall', 'f1', 'detection_mAP_050',
            'segmentation_mAP_050', 'mIoU' and 'pixel_accuracy' (0 when
            absent).
        class_names: Optional list indexed by class id, used for tick labels.
        save_dir: Directory (created if missing) that receives
            ``metrics_summary.png``.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Missing, None and empty mappings are all treated as "no data".
    det_ap = results.get('detection_per_class_AP') or {}
    seg_ap = results.get('segmentation_per_class_AP') or {}
    per_class_iou = results.get('per_class_IoU') or {}

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Model Performance Metrics', fontsize=16, fontweight='bold')

    if det_ap:
        _plot_per_class_bars(axes[0, 0], det_ap, class_names,
                             'Per-Class AP (Detection)', 'Average Precision')

    if per_class_iou:
        _plot_per_class_bars(axes[0, 1], per_class_iou, class_names,
                             'Per-Class IoU', 'IoU', color='orange')

    metrics = ['precision', 'recall', 'f1', 'detection_mAP_050', 'segmentation_mAP_050', 'mIoU', 'pixel_accuracy']
    metric_values = [results.get(m, 0) for m in metrics]

    axes[1, 0].bar(range(len(metrics)), metric_values, alpha=0.7, color='green')
    axes[1, 0].set_title('Overall Performance Metrics', fontweight='bold')
    axes[1, 0].set_xlabel('Metric')
    axes[1, 0].set_ylabel('Score')
    axes[1, 0].set_xticks(range(len(metrics)))
    axes[1, 0].set_xticklabels(metrics, rotation=45, ha='right')
    axes[1, 0].grid(True, alpha=0.3)

    # Only classes present in both mappings can be compared. Sorting gives a
    # stable, readable order instead of arbitrary set order. If detection AP
    # is absent, this is simply empty rather than raising KeyError.
    common_classes = sorted(set(det_ap) & set(seg_ap))
    if common_classes:
        x = np.arange(len(common_classes))
        width = 0.35

        axes[1, 1].bar(x - width / 2, [det_ap[c] for c in common_classes], width,
                       label='Detection AP', alpha=0.7)
        axes[1, 1].bar(x + width / 2, [seg_ap[c] for c in common_classes], width,
                       label='Segmentation AP', alpha=0.7)
        axes[1, 1].set_title('Detection vs Segmentation AP', fontweight='bold')
        axes[1, 1].set_xlabel('Class')
        axes[1, 1].set_ylabel('Average Precision')
        axes[1, 1].set_xticks(x)
        axes[1, 1].set_xticklabels(_class_tick_labels(common_classes, class_names),
                                   rotation=45, ha='right')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{save_dir}/metrics_summary.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def format_time(seconds):
    """Render a duration in seconds as a zero-padded ``HH:MM:SS`` string.

    Each field is converted with ``int`` after floor division / modulo.
    Hours are not wrapped at 24, so e.g. 90061 seconds gives ``"25:01:01"``.
    """
    fields = (seconds // 3600, (seconds % 3600) // 60, seconds % 60)
    return ":".join(f"{int(value):02d}" for value in fields)


In [None]:


# Restore previously trained weights into `model` and run validation.
# `model`, `val_loader` and `device` are not defined in this cell; they are
# expected to come from earlier notebook cells.
# The checkpoint is deserialized onto the CPU (map_location="cpu");
# load_state_dict then copies the tensors into the model's existing parameters.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (recent PyTorch versions accept weights_only=True for a plain
# state dict).
checkpoint_path = "/kaggle/input/mask-rcnn-model/RCNN_model (1).pth"
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

# Put the model in evaluation mode before validating.
model.eval()

# NOTE(review): validate_model is defined elsewhere in the notebook; its
# return structure is not visible here and is stored as-is in val_results.
val_results = validate_model(
    model=model,
    val_loader=val_loader,
    device=device,
    save_dir="/kaggle/working/maskrcnn_training_results"
)

In [None]:

# Comprehensive evaluation on the validation split.
# Class names and the class count come from the module-level CLASSES /
# NUM_CLASSES constants (14 entries, index 0 = "background"), so the two
# cannot drift apart. There is no module-level `classes` variable, and
# hard-coding 14 would silently disagree if CLASSES changes.
evaluation_results = evaluate_model_comprehensive(
    model=model,
    data_loader=val_loader,
    device=device,
    num_classes=NUM_CLASSES,
    class_names=CLASSES,
    save_dir="/kaggle/working/maskrcnn_training_results",
    dataset_name="validation"
)
