# Setup and imports

In [1]:
import os
from pathlib import Path
import xml.etree.ElementTree as ET

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

In [3]:
# Dataset locations: image files and their XML annotations under datasets/.
DATA_DIR = Path("datasets")
IMAGE_DIR = DATA_DIR / "images"
ANN_DIR = DATA_DIR / "annotations"

IMG_SIZE = 128  # presumably the side length images are resized to — confirm in transforms
BATCH_SIZE = 32
# is_available must be *called*: the bare function object is always truthy,
# which would select "cuda" even on machines without a GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parse XML: get crops and labels
We'll first parse the XML annotation files to extract the bounding-box coordinates and label of every annotated object in each image.

## Label mapping

In [None]:
# Map raw XML class names to binary targets: 1 = mask worn, 0 = no mask.
# Incorrectly worn masks are deliberately folded into the "no mask" class.
label_map = {
    "with_mask": 1,
    "without_mask": 0,
    # NOTE(review): the public face-mask dataset spells this class
    # "mask_weared_incorrect"; the original "mask_weard_incorrect" key is kept
    # for backward compatibility — confirm the spelling used in the XML files.
    "mask_weared_incorrect": 0,
    "mask_weard_incorrect": 0,
}

In [None]:
def parse_annotation(xml_path):
    """Parse a single XML annotation file."""
    tree = ET.parse(xml_path)
    root = tree.getroot()
    filename = root.find("filename").text # type: ignore
    objects = []
    for obj in root.findall("object"):
        name = obj.find("name").text # type: ignore
        bbox = obj.find("bndbox")
        xmin = int(bbox.find("xmin").text) # type: ignore
        ymin = int(bbox.find("ymin").text)   # type: ignore
        xmax = int(bbox.find("xmax").text)
        ymax = int(bbox.find("ymax").text)
        objects.append({"label": name, "bbox": (xmin, ymin, xmax, ymax)})
    
    return filename, objects

