In [None]:
!pip install pyyaml tqdm

In [None]:
import multiprocessing
import os
from pathlib import Path

from PIL import Image
import yaml
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


In [None]:
# Root of the YOLO-format AnimalClue footprint dataset: the folder that
# contains 'images/', 'labels/' and 'data.yaml'.
# Set the FOOTPRINT_YOLO_ROOT environment variable to override the placeholder
# without editing the notebook.
YOLO_ROOT = Path(os.environ.get("FOOTPRINT_YOLO_ROOT", "/path/to/footprint_yolo"))

# Output root for the cropped patches, laid out as a classification dataset
# PATCH_ROOT/<split>/<class_name>/*.jpg (override with FOOTPRINT_PATCH_ROOT).
PATCH_ROOT = Path(os.environ.get("FOOTPRINT_PATCH_ROOT", "/path/to/footprint_patches"))

# YOLO-style splits; each is a subfolder of both images/ and labels/
SPLITS = ["train", "val", "test"]

YOLO_IMAGES_DIR = YOLO_ROOT / "images"
YOLO_LABELS_DIR = YOLO_ROOT / "labels"
DATA_YAML_PATH = YOLO_ROOT / "data.yaml"


In [None]:
def load_class_names(yaml_path):
    """
    Read the class-name list from a YOLO data.yaml.

    Both forms of the 'names' field are accepted:
      names: [dog, cat, ...]          # list, position == class id
      names: {0: dog, 1: cat, ...}    # dict keyed by class id (int or numeric str)

    Returns:
        list where element i is the name of class id i.

    Raises:
        ValueError: if the file has no usable 'names' field, or if dict keys
            are not exactly 0..N-1 (a gap would shift every later class id).
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # An empty file loads as None; a non-mapping top level has no 'names'.
    names = data.get("names") if isinstance(data, dict) else None
    if names is None:
        raise ValueError(f"No 'names' field found in {yaml_path}")

    if isinstance(names, dict):
        # Sort numerically: string keys like "10" must come after "2".
        ordered_keys = sorted(names, key=int)
        ids = [int(k) for k in ordered_keys]
        if ids != list(range(len(ids))):
            raise ValueError(
                f"Class ids in {yaml_path} are not contiguous from 0: {ids}"
            )
        names = [names[k] for k in ordered_keys]

    if not isinstance(names, list):
        raise ValueError(
            f"'names' in {yaml_path} must be a list or dict, got {type(names).__name__}"
        )
    return names

class_names = load_class_names(DATA_YAML_PATH)
assert class_names, f"No class names loaded from {DATA_YAML_PATH}"
print(f"Loaded {len(class_names)} class names.")
class_names[:10]  # preview the first few names


In [None]:
def yolo_to_xyxy(bbox, img_width, img_height, margin_factor=0.1):
    """
    Convert a YOLO normalized bbox (cx, cy, w, h) to integer pixel corners
    (x_min, y_min, x_max, y_max) in the convention of PIL's Image.crop:
    x_min/y_min are inclusive and x_max/y_max are exclusive, so a box that
    covers the whole image maps to (0, 0, img_width, img_height).

    Args:
        bbox: (cx, cy, w, h) normalized to [0, 1]; coordinates falling
            outside the image after conversion are clipped.
        img_width, img_height: image size in pixels.
        margin_factor: relative growth of the box width and height for extra
            context around the footprint, centered on the box
            (e.g. 0.1 enlarges each dimension by 10%, i.e. 5% per side).

    Returns:
        (x_min, y_min, x_max, y_max) as ints, or None if the clipped box is empty.
    """
    cx, cy, bw, bh = bbox

    # Normalized center/size -> pixel center/size
    cx_pix = cx * img_width
    cy_pix = cy * img_height
    bw_pix = bw * img_width
    bh_pix = bh * img_height

    # Add margin
    bw_pix *= (1.0 + margin_factor)
    bh_pix *= (1.0 + margin_factor)

    x_min = cx_pix - bw_pix / 2.0
    x_max = cx_pix + bw_pix / 2.0
    y_min = cy_pix - bh_pix / 2.0
    y_max = cy_pix + bh_pix / 2.0

    # Clip to image bounds. The max corner is exclusive (PIL crop), so it may
    # equal the image size; clipping to size-1 would drop the last row/column.
    x_min = max(0, int(round(x_min)))
    y_min = max(0, int(round(y_min)))
    x_max = min(img_width, int(round(x_max)))
    y_max = min(img_height, int(round(y_max)))

    # Empty after clipping (zero-size label or box entirely outside the image)
    if x_max <= x_min or y_max <= y_min:
        return None

    return x_min, y_min, x_max, y_max


In [None]:
def find_image_for_label(label_path, images_dir, img_exts=(".jpg", ".jpeg", ".png")):
    """
    Locate the image that belongs to a YOLO label file.

    Given a label path '.../xxx.txt', look for 'xxx<ext>' in images_dir for
    each ext in img_exts, in order. Each extension is tried as given, then in
    lower and upper case (e.g. '.JPG'), which matters on case-sensitive
    filesystems.

    Args:
        label_path: path (str or Path) of the label file; only its stem is used.
        images_dir: directory (str or Path) to search.
        img_exts: candidate image extensions, in priority order.

    Returns:
        Path of the first existing image file, or None if none is found.
    """
    stem = Path(label_path).stem
    images_dir = Path(images_dir)
    for ext in img_exts:
        # dict.fromkeys de-duplicates while keeping the as-given variant first.
        for variant in dict.fromkeys((ext, ext.lower(), ext.upper())):
            candidate = images_dir / f"{stem}{variant}"
            if candidate.is_file():
                return candidate
    return None


def _read_yolo_label_file(label_path):
    """Parse one YOLO detection label file into [(line_idx, class_id, (cx, cy, w, h)), ...].

    Blank lines are ignored. Lines that are not exactly 5 whitespace-separated
    fields, or whose fields are not numeric, are reported and skipped.
    line_idx is the 0-based line number in the file; it keeps patch filenames stable.
    """
    entries = []
    with Path(label_path).open("r", encoding="utf-8") as f:
        for idx_line, raw_line in enumerate(f):
            line = raw_line.strip()
            if not line:
                continue
            parts = line.split()
            entry = None
            if len(parts) == 5:
                try:
                    entry = (idx_line, int(parts[0]), tuple(map(float, parts[1:])))
                except ValueError:
                    entry = None
            if entry is None:
                tqdm.write(f"Bad label format in {label_path}: {line}")
                continue
            entries.append(entry)
    return entries


def create_classification_patches(
    yolo_images_dir,
    yolo_labels_dir,
    patch_root,
    splits,
    class_names,
    margin_factor=0.1,
    min_size=5,
):
    """
    Crop every labelled footprint out of a YOLO detection dataset and save
    the crops as an image-classification dataset.

    For each split, reads yolo_labels_dir/<split>/*.txt, finds the matching
    image in yolo_images_dir/<split>/, and writes one JPEG per valid box to
    patch_root/<split>/<class_name>/<image_stem>_<line_idx>.jpg.

    Args:
        yolo_images_dir, yolo_labels_dir, patch_root: str or Path.
        splits: iterable of split names (subfolders of the images/labels dirs).
        class_names: sequence mapping class id -> class name.
        margin_factor: passed to yolo_to_xyxy to enlarge each box for context.
        min_size: crops narrower or shorter than this many pixels are skipped.

    Returns:
        dict mapping split -> number of patches saved. Splits whose labels
        directory is missing are omitted.

    Label files without an image, malformed label lines, out-of-range class
    ids and unreadable images are reported and skipped. Existing patches with
    the same name are overwritten; patches from earlier runs are not deleted.
    """
    yolo_images_dir = Path(yolo_images_dir)
    yolo_labels_dir = Path(yolo_labels_dir)
    patch_root = Path(patch_root)
    patch_root.mkdir(parents=True, exist_ok=True)

    num_classes = len(class_names)
    counts = {}

    for split in splits:
        print(f"\n=== Processing split: {split} ===")
        split_images_dir = yolo_images_dir / split
        split_labels_dir = yolo_labels_dir / split

        if not split_labels_dir.exists():
            print(f"Warning: labels dir for split '{split}' not found: {split_labels_dir}")
            continue

        patch_count = 0
        label_files = sorted(split_labels_dir.glob("*.txt"))

        for label_path in tqdm(label_files, desc=f"{split} labels"):
            img_path = find_image_for_label(label_path, split_images_dir)
            if img_path is None:
                tqdm.write(f"Warning: no image found for label: {label_path}")
                continue

            entries = _read_yolo_label_file(label_path)
            if not entries:
                continue  # nothing to crop: skip decoding the image

            try:
                # convert() returns an independent, fully loaded copy, so the
                # file handle can be closed immediately.
                with Image.open(img_path) as src:
                    img = src.convert("RGB")
            except OSError as exc:  # includes PIL.UnidentifiedImageError
                tqdm.write(f"Warning: could not read image {img_path}: {exc}")
                continue
            w, h = img.size

            for idx_line, class_id, bbox in entries:
                # Negative ids would silently index from the end of the list.
                if not 0 <= class_id < num_classes:
                    tqdm.write(
                        f"Warning: class id {class_id} out of range "
                        f"(0..{num_classes - 1}) in {label_path}, line {idx_line}"
                    )
                    continue

                box = yolo_to_xyxy(bbox, w, h, margin_factor=margin_factor)
                if box is None:
                    continue
                x_min, y_min, x_max, y_max = box

                # Skip tiny patches (noise)
                if (x_max - x_min) < min_size or (y_max - y_min) < min_size:
                    continue

                # str(): YAML may load purely numeric class names as ints.
                out_dir = patch_root / split / str(class_names[class_id])
                out_dir.mkdir(parents=True, exist_ok=True)
                out_path = out_dir / f"{img_path.stem}_{idx_line}.jpg"
                img.crop((x_min, y_min, x_max, y_max)).save(out_path)
                patch_count += 1

        counts[split] = patch_count
        print(f"Finished split '{split}'. Total patches saved: {patch_count}")

    return counts


In [None]:
# Crop every labelled footprint into PATCH_ROOT/<split>/<class_name>/*.jpg.
#   margin_factor: relative box enlargement for extra context (tune this)
#   min_size:      skip crops narrower or shorter than this many pixels
create_classification_patches(
    YOLO_IMAGES_DIR,
    YOLO_LABELS_DIR,
    PATCH_ROOT,
    SPLITS,
    class_names,
    margin_factor=0.1,
    min_size=10,
)


In [None]:
class FootprintPatchDataset(Dataset):
    """
    Classification-style dataset for cropped footprint patches.

    Assumes directory structure:
      root_dir/
        class_name0/
          *.jpg
        class_name1/
          *.jpg
        ...

    Class indices are 0..N-1 over the class subdirectories in sorted order.
    Plain files and hidden entries (names starting with '.') are ignored, so
    `idx_to_class[label]` is always the class of `label`.

    Args:
        root_dir: directory containing one subdirectory per class.
        transform: optional callable applied to the RGB PIL image.
        classes: optional explicit list of class names. When given, class i
            gets index i whether or not its directory exists. This keeps label
            indices consistent across splits, e.g. when val lacks a class that
            train has. Missing class directories contribute no images.

    Items are (transformed_image, class_index) pairs.
    """
    def __init__(self, root_dir, transform=None, classes=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        if classes is None:
            # Only real, non-hidden subdirectories are classes. Stray files
            # such as '.DS_Store' must not consume an index; otherwise labels
            # and idx_to_class positions drift apart.
            classes = sorted(
                entry.name
                for entry in self.root_dir.iterdir()
                if entry.is_dir() and not entry.name.startswith(".")
            )
        self.idx_to_class = list(classes)
        self.class_to_idx = {name: idx for idx, name in enumerate(self.idx_to_class)}

        self.image_paths = []
        self.labels = []
        for class_idx, class_name in enumerate(self.idx_to_class):
            class_path = self.root_dir / class_name
            if not class_path.is_dir():
                continue  # a listed class with no patches in this split
            # Sorted so sample order does not depend on the filesystem. Hidden
            # files (e.g. macOS '._x.jpg' resource forks) are not real images.
            for fname in sorted(os.listdir(class_path)):
                if fname.startswith("."):
                    continue
                if fname.lower().endswith((".jpg", ".jpeg", ".png")):
                    self.image_paths.append(class_path / fname)
                    self.labels.append(class_idx)

        print(f"Loaded {len(self.image_paths)} images "
              f"from {self.root_dir}, {len(self.idx_to_class)} classes.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # Close the file right away; convert() returns an independent loaded
        # copy. This avoids accumulating open handles in DataLoader workers.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label


In [None]:
class ContrastiveTransform:
    """
    Turn one input into a pair of augmented views.

    `base_transform` is applied to the same input twice, and the two results
    are returned as a 2-tuple in call order. When the base transform is
    random, the two views differ.
    """
    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # The generator runs the transform exactly twice, first call first.
        first_view, second_view = (self.base_transform(x) for _ in range(2))
        return first_view, second_view

image_size = 128  # input side length; must match your encoder

# Random augmentation pipeline; ContrastiveTransform runs it twice per patch
# to produce two different views of the same footprint.
base_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(image_size, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        # Optional: append transforms.Normalize(mean, std) here (and to eval_transform).
    ]
)

contrastive_transform = ContrastiveTransform(base_transform)


In [None]:
batch_size = 64

# DataLoader worker processes, capped by the machine's CPU count.
# Under the 'spawn'/'forkserver' start methods (default on Windows/macOS and
# newer Linux Pythons), workers re-import __main__. In a notebook kernel that
# does not contain the classes defined above, so loading fails; only use
# worker processes with 'fork'.
mp_start_method = (
    multiprocessing.get_start_method(allow_none=True)
    or multiprocessing.get_all_start_methods()[0]
)
num_workers = min(4, os.cpu_count() or 1) if mp_start_method == "fork" else 0

train_dataset_contrastive = FootprintPatchDataset(
    PATCH_ROOT / "train",
    transform=contrastive_transform
)
# With drop_last=True, fewer samples than one batch would yield zero batches.
assert len(train_dataset_contrastive) >= batch_size, (
    f"Only {len(train_dataset_contrastive)} training patches; "
    f"need at least batch_size={batch_size} when drop_last=True"
)

train_loader_contrastive = DataLoader(
    train_dataset_contrastive,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True
)

# For supervised training / evaluation, use a deterministic transform
eval_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

val_dataset = FootprintPatchDataset(
    PATCH_ROOT / "val",
    transform=eval_transform
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers
)

# Each dataset numbers its classes from its own subdirectories. A val class
# must carry the same index as in train, or evaluation labels are silently wrong.
mismatched_classes = sorted(
    name
    for name, idx in val_dataset.class_to_idx.items()
    if train_dataset_contrastive.class_to_idx.get(name) != idx
)
assert not mismatched_classes, (
    f"val class indices differ from train for {mismatched_classes}; "
    "build both splits with the same class list"
)
