In [1]:
import json
import random
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
import torchvision.models as models

from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from tqdm import tqdm


# Inputs: a Label Studio JSON export (tasks with rectangle annotations) and a
# CSV mapping each rectangle label to a gender; only "male"/"female" rows are
# kept when the export is parsed.
JSON_PATH = Path("LS_export_26.12.json")
LABELS_CSV = Path("label_type_gender.csv")


# Training hyper-parameters.
# TEST_SIZE is the proportion of museum-number groups held out for validation.
# EPOCHS is an upper bound; early stopping may end training sooner.
# PATIENCE / MIN_DELTA control early stopping on validation macro-F1.
SEED = 42
TEST_SIZE = 0.2
BATCH_SIZE = 16
EPOCHS = 30
LR = 1e-4
PATIENCE = 5
MIN_DELTA = 1e-4

# PAD: extra context around each box, as a fraction of the box width/height.
# MIN_CROP_PX: boxes narrower or shorter than this (in pixels) make the
# dataset raise instead of returning a tiny crop.
PAD = 0.0
MIN_CROP_PX = 16

# All artefacts of this run (checkpoint, CSVs, report, diagnostic crops).
OUT_DIR = Path("gender_run_out_perspective_aug_A")
OUT_DIR.mkdir(parents=True, exist_ok=True)
BEST_PATH = OUT_DIR / "best_model.pt"

MIS_DIR = OUT_DIR / "mistakes"
LOW_DIR = OUT_DIR / "low_confidence"
MIS_DIR.mkdir(parents=True, exist_ok=True)
LOW_DIR.mkdir(parents=True, exist_ok=True)

# Diagnostics: how many misclassified / low-confidence crops to save, and the
# max-class-probability below which a prediction counts as low-confidence.
SAVE_TOP_MISTAKES = 60
SAVE_TOP_LOWCONF = 60
LOWCONF_THRESHOLD = 0.55

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


def seed_everything(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic kernels.

    Args:
        seed: Seed applied to ``random``, ``numpy.random`` and ``torch`` (CPU
            and every CUDA device).

    Deterministic algorithms are enabled with ``warn_only=True``. Some ops
    have no deterministic CUDA implementation, such as the backward pass of
    adaptive average pooling used by ResNet. Those ops then emit a warning
    instead of raising in the middle of training.
    """
    import os

    # cuBLAS needs a fixed workspace configuration for deterministic matmuls
    # on CUDA; without it, deterministic mode makes those ops raise. It must be
    # set before the first cuBLAS call, and an explicit user value is kept.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except TypeError:
        # torch < 1.11 has no ``warn_only`` keyword.
        torch.use_deterministic_algorithms(True)
    except AttributeError:
        # torch < 1.8 lacks the function; the cuDNN flags above still apply.
        pass

def seed_worker(worker_id: int):
    """Re-seed ``numpy.random`` and ``random`` for one DataLoader worker.

    The seed is torch's initial seed plus ``worker_id``, reduced modulo 2**32
    because NumPy only accepts 32-bit seeds.

    NOTE(review): inside a DataLoader worker ``torch.initial_seed()`` is
    presumably already worker-specific, so adding ``worker_id`` is likely
    redundant (but harmless) — confirm against the installed torch version.
    """
    derived_seed = torch.initial_seed() + worker_id
    derived_seed %= 2 ** 32
    np.random.seed(derived_seed)
    random.seed(derived_seed)

# Seed every RNG once up front. `g` is handed to the DataLoaders so that the
# shuffling order is reproducible across runs.
seed_everything(SEED)
g = torch.Generator()
g.manual_seed(SEED)

# 

# Load the Label Studio export and build the label -> gender lookup.
with open(JSON_PATH, "r", encoding="utf-8") as f:
    tasks = json.load(f)

label_df = pd.read_csv(LABELS_CSV, encoding="utf-8")
label2gender = dict(zip(label_df["label"], label_df["gender"]))

# One output row per rectangle whose label maps to "male" or "female".
rows = []
# Tasks skipped because the image path is empty or the file is not on disk.
missing_images = 0

for task in tasks:
    task_id = task.get("id")
    data = task.get("data", {}) or {}

    image_path = data.get("image_local_path")
    museum_number = data.get("Museum number")

    if not image_path:
        missing_images += 1
        continue

    image_path = Path(image_path)
    if not image_path.exists():
        missing_images += 1
        continue

    # NOTE(review): every annotation of a task is used. If a task was
    # labelled by several annotators, its boxes appear once per annotation —
    # confirm this duplication is intended.
    for ann in task.get("annotations", []):
        for r in ann.get("result", []):
            if r.get("type") != "rectanglelabels":
                continue

            v = r.get("value", {}) or {}
            labels = v.get("rectanglelabels", [])
            if not labels:
                continue

            # Only the first label of a multi-label rectangle is considered.
            label = labels[0]
            gender = label2gender.get(label)
            if gender not in {"male", "female"}:
                continue

            # x / y / width / height are stored unchanged; the dataset later
            # interprets them as percentages of the image width/height.
            rows.append({
                "task_id": task_id,
                "museum_number": museum_number,
                "image_path": str(image_path),
                "label": label,
                "gender": gender,
                "x": float(v["x"]),
                "y": float(v["y"]),
                "w": float(v["width"]),
                "h": float(v["height"]),
            })

df = pd.DataFrame(rows)

print("Missing images:", missing_images)
print("Total bboxes (male/female):", len(df))
# NOTE(review): if no rows were collected, `df` has no "gender" column and the
# next line raises KeyError.
print("Gender counts:\n", df["gender"].value_counts(dropna=False))

df.to_csv(OUT_DIR / "bboxes_gender_df.csv", index=False, encoding="utf-8")

# 

# Split by object: all boxes sharing a museum number land on the same side of
# the train/validation split. Tasks without a museum number fall back to their
# task id (as a string), so each such task forms its own group.
df["museum_number"] = df["museum_number"].fillna(df["task_id"].astype(str))

gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, val_idx = next(gss.split(df, groups=df["museum_number"]))

train_df = df.iloc[train_idx].reset_index(drop=True)
val_df = df.iloc[val_idx].reset_index(drop=True)

print("\nTrain:", len(train_df), "Val:", len(val_df))
print("Train gender:\n", train_df["gender"].value_counts())
print("Val gender:\n", val_df["gender"].value_counts())

# Persist the exact split for later inspection / reuse.
train_df.to_csv(OUT_DIR / "train_df.csv", index=False, encoding="utf-8")
val_df.to_csv(OUT_DIR / "val_df.csv", index=False, encoding="utf-8")

#

class GenderBboxDataset(Dataset):
    """Bounding-box crops paired with a binary gender target.

    Each ``df`` row must provide:
    - ``image_path``;
    - bbox columns ``x``, ``y``, ``w``, ``h``, as percentages (0-100) of the
      image width/height;
    - ``gender``;
    - the meta columns ``task_id``, ``museum_number`` and ``label``.

    Args:
        df: Annotation table; its index is reset.
        transform: Optional callable applied to the PIL crop.
        pad: Extra context added on every side, as a fraction of the box
            width/height. The padded box is clipped to the image.

    ``__getitem__`` returns ``(crop, target, meta)``. ``target`` is 1 for
    ``"male"`` and 0 otherwise. ``meta["bbox_px"]`` is the pixel box that was
    actually cropped, after clipping and padding.

    Raises:
        ValueError: From ``__getitem__`` when the box, clipped to the image,
            is smaller than ``MIN_CROP_PX`` pixels in either dimension.
    """

    def __init__(self, df: pd.DataFrame, transform=None, pad: float = 0.0):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.pad = pad

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # The context manager releases the file handle; convert() returns a
        # new, fully loaded image that stays valid after the file is closed.
        with Image.open(row.image_path) as src:
            img = src.convert("RGB")
        W, H = img.size

        # Percent -> pixel coordinates, always clipped to the image. A box
        # spilling past the border (e.g. float rounding in the export) would
        # otherwise be filled with black by PIL's crop.
        x1 = min(max(row.x / 100.0 * W, 0.0), float(W))
        y1 = min(max(row.y / 100.0 * H, 0.0), float(H))
        x2 = min(max((row.x + row.w) / 100.0 * W, 0.0), float(W))
        y2 = min(max((row.y + row.h) / 100.0 * H, 0.0), float(H))

        if (x2 - x1) < MIN_CROP_PX or (y2 - y1) < MIN_CROP_PX:
            raise ValueError(
                f"Too-small crop at idx={idx}: "
                f"w={(x2-x1):.2f}px h={(y2-y1):.2f}px "
                f"task_id={row.task_id} image={row.image_path}"
            )

        if self.pad > 0:
            # Grow the box proportionally to its size, staying inside the image.
            pad_x = self.pad * (x2 - x1)
            pad_y = self.pad * (y2 - y1)
            x1 = max(0, x1 - pad_x)
            y1 = max(0, y1 - pad_y)
            x2 = min(W, x2 + pad_x)
            y2 = min(H, y2 + pad_y)

        crop = img.crop((x1, y1, x2, y2))

        if self.transform is not None:
            crop = self.transform(crop)

        y = 1 if row.gender == "male" else 0

        meta = {
            "task_id": row.task_id,
            "museum_number": row.museum_number,
            "image_path": row.image_path,
            "label": row.label,
            "gender": row.gender,
            "bbox_px": (float(x1), float(y1), float(x2), float(y2)),
        }
        return crop, y, meta

def collate_keep_meta(batch):
    """Collate ``(image_tensor, target, meta)`` samples, leaving metas as-is.

    Images are stacked along a new leading batch dimension and targets become
    an int64 tensor. The meta dicts are returned as a plain list, so their
    string/tuple fields are not converted by the default collate function.
    """
    images = []
    targets = []
    metas = []
    for sample in batch:
        images.append(sample[0])
        targets.append(sample[1])
        metas.append(sample[2])

    stacked = torch.stack(images, dim=0)
    labels = torch.tensor(targets, dtype=torch.long)
    return stacked, labels, metas

# 

# Training augmentation: mild geometric jitter plus rare small erasures.
transform_train = T.Compose([
    # Fixed 224x224 output; the crop's aspect ratio is not preserved.
    T.Resize((224, 224)),

    # Mild perspective warp on a quarter of the samples.
    T.RandomPerspective(distortion_scale=0.10, p=0.25),

    # Rotation up to +/-4 degrees, shift up to 2%, scale 0.98-1.02, no shear.
    T.RandomAffine(
        degrees=4,
        translate=(0.02, 0.02),
        scale=(0.98, 1.02),
        shear=None,
    ),

    T.RandomHorizontalFlip(p=0.5),

    # ImageNet mean/std, matching the ImageNet-pretrained backbone.
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),

    # Operates on the normalized tensor: with 5% probability, replace 1-3% of
    # the image area with random values.
    T.RandomErasing(p=0.05, scale=(0.01, 0.03), ratio=(0.3, 3.3), value="random"),
])

# Validation: deterministic resize + normalization only.
transform_val = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

#

# Both splits use the same padding; only the training split is augmented.
train_ds = GenderBboxDataset(train_df, transform=transform_train, pad=PAD)
val_ds = GenderBboxDataset(val_df, transform=transform_val, pad=PAD)

# num_workers=0 loads data in the main process. seed_worker is only invoked
# for worker processes, so it has no effect unless num_workers is raised.
# The seeded generator `g` fixes the shuffle order.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_keep_meta,
    worker_init_fn=seed_worker,
    generator=g,
)
# NOTE(review): `g` is also passed to this unshuffled loader. It is not needed
# here and may advance the shared generator's state between training epochs —
# confirm this is intended.
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_keep_meta,
    worker_init_fn=seed_worker,
    generator=g,
)

print("\nReady loaders:", len(train_ds), len(val_ds))

#

# Class counts in the training split (target 0 = female, 1 = male).
n_female = int((train_df["gender"] == "female").sum())
n_male = int((train_df["gender"] == "male").sum())
print("Train counts:", {"female": n_female, "male": n_male})

# Inverse-frequency weights: each class is weighted by the share of the
# *other* class, so the rarer class contributes more to the loss. The two
# weights sum to 1.
# NOTE(review): an empty training split raises ZeroDivisionError, and a split
# containing only one class gives that class weight 0 — guard if possible.
w_female = n_male / (n_female + n_male)  # class 0
w_male = n_female / (n_female + n_male)  # class 1
class_weights = torch.tensor([w_female, w_male], dtype=torch.float32).to(device)
print("Class weights:", class_weights.detach().cpu().numpy())

criterion = nn.CrossEntropyLoss(weight=class_weights)

#

# ImageNet-pretrained ResNet-18 whose final layer is replaced by a 2-way
# classifier (output 0 = female, output 1 = male, matching the dataset target).
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(device)

# All parameters are fine-tuned; nothing is frozen.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

#

def train_one_epoch():
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``device``. Each batch loss is multiplied by its batch size, and the sum
    is divided by the dataset length. The result is the sample-weighted
    average of the per-batch losses.
    """
    model.train()
    running_loss = 0.0
    for inputs, targets, _ in tqdm(train_loader, desc="train", leave=False):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad(set_to_none=True)
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item() * targets.size(0)

    n_samples = len(train_loader.dataset)
    return running_loss / n_samples

def _caption_text_size(draw, text, font):
    """Return the approximate ``(width, height)`` in pixels of ``text`` drawn at (0, 0)."""
    try:
        _left, _top, right, bottom = draw.textbbox((0, 0), text, font=font)
    except (AttributeError, ValueError):
        # Pillow < 8.0 has no textbbox, and Pillow < 9.2 only supports it for
        # TrueType fonts; fall back to a generous per-character estimate.
        return 10 * len(text), 18
    return int(right) + 1, int(bottom) + 1


def _save_crop_with_caption(meta, pred_label, p_male, out_path: Path):
    """Save the crop described by ``meta`` with a caption strip underneath.

    Args:
        meta: Sample metadata as produced by ``GenderBboxDataset``. Uses
            ``image_path``, ``bbox_px``, ``gender`` and ``label``.
        pred_label: Predicted class name ("male" or "female").
        p_male: Predicted probability of the "male" class.
        out_path: Destination file; PIL picks the format from the extension.

    Each caption field is drawn on its own line. The canvas is widened when
    the text is wider than the crop, so the caption is never clipped.
    """
    with Image.open(meta["image_path"]) as src:
        img = src.convert("RGB")
    x1, y1, x2, y2 = meta["bbox_px"]
    crop = img.crop((x1, y1, x2, y2))

    conf = max(p_male, 1.0 - p_male)
    caption = f"true={meta['gender']} | pred={pred_label} | P(male)={p_male:.3f} | conf={conf:.3f} | label={meta['label']}"
    # One field per line keeps the caption readable under narrow crops.
    lines = caption.split(" | ")

    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        # Font file missing (e.g. outside Windows) or Pillow built without
        # FreeType: use PIL's built-in bitmap font instead.
        font = ImageFont.load_default()

    margin = 6
    line_gap = 4
    probe = ImageDraw.Draw(Image.new("RGB", (1, 1)))
    sizes = [_caption_text_size(probe, line, font) for line in lines]
    line_h = max(h for _w, h in sizes) + line_gap
    text_w = max(w for w, _h in sizes) + 2 * margin

    W, H = crop.size
    canvas_w = max(W, text_w)
    canvas_h = H + 2 * margin + line_h * len(lines)
    canvas = Image.new("RGB", (canvas_w, canvas_h), (255, 255, 255))
    canvas.paste(crop, (0, 0))
    draw = ImageDraw.Draw(canvas)

    for i, line in enumerate(lines):
        draw.text((margin, H + margin + i * line_h), line, fill=(0, 0, 0), font=font)
    canvas.save(out_path)

def save_diagnostics(y_true, y_pred, p_male_list, metas):
    """Write captioned crops for the worst mistakes and the least-confident predictions.

    Args:
        y_true: True class ids (1 = male, 0 = female), one per sample.
        y_pred: Predicted class ids, aligned with ``y_true``.
        p_male_list: Predicted P(male) per sample.
        metas: Per-sample metadata dicts, aligned with ``y_true``.

    Misclassified samples go to ``MIS_DIR``, most confident first, up to
    ``SAVE_TOP_MISTAKES``. Samples whose confidence (max class probability)
    is below ``LOWCONF_THRESHOLD`` go to ``LOW_DIR``, least confident first,
    up to ``SAVE_TOP_LOWCONF``. Images written by a previous call are deleted
    first, so both folders reflect only the current evaluation.
    """
    # OUT_DIR is reused between runs. Without this cleanup, a run with fewer
    # cases would leave stale higher-numbered images from an earlier run.
    for folder, pattern in ((MIS_DIR, "mistake_*.jpg"), (LOW_DIR, "lowconf_*.jpg")):
        for stale in list(folder.glob(pattern)):
            stale.unlink()

    mistakes = []
    lowconf = []

    for yt, yp, pm, meta in zip(y_true, y_pred, p_male_list, metas):
        true_label = "male" if yt == 1 else "female"
        pred_label = "male" if yp == 1 else "female"
        conf = max(pm, 1.0 - pm)

        if true_label != pred_label:
            mistakes.append((conf, meta, pred_label, pm))

        if conf < LOWCONF_THRESHOLD:
            lowconf.append((conf, meta, pred_label, pm))

    # Confident mistakes are the most informative; low-confidence ones are
    # shown starting from the most uncertain.
    mistakes.sort(key=lambda x: x[0], reverse=True)
    lowconf.sort(key=lambda x: x[0])

    for i, (_conf, meta, pred_label, pm) in enumerate(mistakes[:SAVE_TOP_MISTAKES], start=1):
        out = MIS_DIR / f"mistake_{i:03d}_task{meta['task_id']}.jpg"
        _save_crop_with_caption(meta, pred_label, pm, out)

    for i, (_conf, meta, pred_label, pm) in enumerate(lowconf[:SAVE_TOP_LOWCONF], start=1):
        out = LOW_DIR / f"lowconf_{i:03d}_task{meta['task_id']}.jpg"
        _save_crop_with_caption(meta, pred_label, pm, out)

@torch.no_grad()
def eval_epoch(save_examples: bool = False):
    """Evaluate the module-level ``model`` on ``val_loader``.

    Args:
        save_examples: When true, also write mistake / low-confidence crops
            through ``save_diagnostics``.

    Returns:
        ``(accuracy, macro_f1, y_true, y_pred)``; the last two are lists of
        class ids (1 = male, 0 = female).
    """
    model.eval()
    y_true = []
    y_pred = []
    p_male_list = []
    metas_all = []

    for inputs, targets, metas in tqdm(val_loader, desc="val", leave=False):
        probs = torch.softmax(model(inputs.to(device)), dim=1).cpu()

        y_true.extend(targets.numpy().tolist())
        y_pred.extend(probs.argmax(dim=1).numpy().tolist())
        # Column 1 holds P(male).
        p_male_list.extend(probs[:, 1].numpy().tolist())
        metas_all.extend(metas)

    acc = accuracy_score(y_true, y_pred)
    f1m = f1_score(y_true, y_pred, average="macro")

    if save_examples:
        save_diagnostics(y_true, y_pred, p_male_list, metas_all)

    return acc, f1m, y_true, y_pred

#

# Train with early stopping on validation macro-F1. The best checkpoint is
# written to BEST_PATH. Training stops after PATIENCE consecutive epochs
# without an improvement larger than MIN_DELTA.
# NOTE(review): the same validation split drives model selection and the
# final report below, so the reported scores are optimistically biased.
best_f1 = -1.0
best_epoch = 0
pat_left = PATIENCE
history = []

for epoch in range(1, EPOCHS + 1):
    train_loss = train_one_epoch()
    acc, f1m, y_true, y_pred = eval_epoch(save_examples=False)

    history.append({
        "epoch": epoch,
        "train_loss": float(train_loss),
        "val_acc": float(acc),
        "val_f1_macro": float(f1m),
    })

    print(f"Epoch {epoch:02d} | loss={train_loss:.4f} | val_acc={acc:.3f} | val_f1_macro={f1m:.3f}")

    if f1m > best_f1 + MIN_DELTA:
        best_f1 = f1m
        best_epoch = epoch
        pat_left = PATIENCE
        torch.save(model.state_dict(), BEST_PATH)
        print(f"  New best macro-F1={best_f1:.3f} (epoch {best_epoch}) saved to {BEST_PATH}")
    else:
        pat_left -= 1
        print(f"  no improvement, patience left: {pat_left}/{PATIENCE}")
        if pat_left <= 0:
            print(f"\nEarly stopping at epoch {epoch}. Best macro-F1={best_f1:.3f} at epoch {best_epoch}.")
            break

# Written only after the loop finishes, so an interrupted run saves no history.
pd.DataFrame(history).to_csv(OUT_DIR / "history.csv", index=False, encoding="utf-8")

#

# Reload the best checkpoint and evaluate it on the validation split. This
# also saves mistake / low-confidence crops and a text report.
# NOTE(review): the checkpoint is a plain state_dict. Consider passing
# weights_only=True to torch.load if the installed torch version supports it.
model.load_state_dict(torch.load(BEST_PATH, map_location=device))
acc, f1m, y_true, y_pred = eval_epoch(save_examples=True)

# labels=[0, 1] fixes row/column order: 0 = female, 1 = male.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
report = classification_report(y_true, y_pred, target_names=["female", "male"], digits=3, zero_division=0)

print("\nCONFUSION MATRIX (rows=true, cols=pred):\n", cm)
print("\nREPORT:\n", report)

with open(OUT_DIR / "report.txt", "w", encoding="utf-8") as f:
    f.write("CONFUSION MATRIX (rows=true, cols=pred)\n")
    f.write(str(cm) + "\n\n")
    f.write(report + "\n")

print(f"\nBest model: {BEST_PATH} | best val_f1_macro={best_f1:.3f} at epoch {best_epoch}")
print(f"All outputs in: {OUT_DIR.resolve()}")
print(f"Mistake crops: {MIS_DIR.resolve()}")
print(f"Low-confidence crops: {LOW_DIR.resolve()}")


Device: cpu
Missing images: 0
Total bboxes (male/female): 520
Gender counts:
 gender
male      359
female    161
Name: count, dtype: int64

Train: 425 Val: 95
Train gender:
 gender
male      296
female    129
Name: count, dtype: int64
Val gender:
 gender
male      63
female    32
Name: count, dtype: int64

Ready loaders: 425 95
Train counts: {'female': 129, 'male': 296}
Class weights: [0.69647056 0.3035294 ]


                                                      

Epoch 01 | loss=0.6127 | val_acc=0.747 | val_f1_macro=0.721
  New best macro-F1=0.721 (epoch 1) saved to gender_run_out_perspective_aug_A\best_model.pt


                                                      

Epoch 02 | loss=0.2815 | val_acc=0.695 | val_f1_macro=0.681
  no improvement, patience left: 4/5


                                                      

Epoch 03 | loss=0.2019 | val_acc=0.768 | val_f1_macro=0.664
  no improvement, patience left: 3/5


                                                      

Epoch 04 | loss=0.0917 | val_acc=0.747 | val_f1_macro=0.721
  no improvement, patience left: 2/5


                                                      

Epoch 05 | loss=0.0746 | val_acc=0.779 | val_f1_macro=0.746
  New best macro-F1=0.746 (epoch 5) saved to gender_run_out_perspective_aug_A\best_model.pt


                                                      

Epoch 06 | loss=0.0942 | val_acc=0.768 | val_f1_macro=0.709
  no improvement, patience left: 4/5


                                                      

Epoch 07 | loss=0.1064 | val_acc=0.758 | val_f1_macro=0.727
  no improvement, patience left: 3/5


                                                      

Epoch 08 | loss=0.0873 | val_acc=0.768 | val_f1_macro=0.741
  no improvement, patience left: 2/5


                                                      

Epoch 09 | loss=0.1113 | val_acc=0.789 | val_f1_macro=0.756
  New best macro-F1=0.756 (epoch 9) saved to gender_run_out_perspective_aug_A\best_model.pt


                                                      

Epoch 10 | loss=0.0967 | val_acc=0.726 | val_f1_macro=0.712
  no improvement, patience left: 4/5


                                                      

Epoch 11 | loss=0.0538 | val_acc=0.832 | val_f1_macro=0.783
  New best macro-F1=0.783 (epoch 11) saved to gender_run_out_perspective_aug_A\best_model.pt


                                                      

Epoch 12 | loss=0.0602 | val_acc=0.832 | val_f1_macro=0.793
  New best macro-F1=0.793 (epoch 12) saved to gender_run_out_perspective_aug_A\best_model.pt


                                                      

Epoch 13 | loss=0.0199 | val_acc=0.789 | val_f1_macro=0.764
  no improvement, patience left: 4/5


                                                      

Epoch 14 | loss=0.0426 | val_acc=0.779 | val_f1_macro=0.758
  no improvement, patience left: 3/5


                                                      

Epoch 15 | loss=0.0322 | val_acc=0.811 | val_f1_macro=0.777
  no improvement, patience left: 2/5


                                                      

Epoch 16 | loss=0.0295 | val_acc=0.821 | val_f1_macro=0.791
  no improvement, patience left: 1/5


                                                      

Epoch 17 | loss=0.0273 | val_acc=0.779 | val_f1_macro=0.754
  no improvement, patience left: 0/5

Early stopping at epoch 17. Best macro-F1=0.793 at epoch 12.


                                                  


CONFUSION MATRIX (rows=true, cols=pred):
 [[19 13]
 [ 3 60]]

REPORT:
               precision    recall  f1-score   support

      female      0.864     0.594     0.704        32
        male      0.822     0.952     0.882        63

    accuracy                          0.832        95
   macro avg      0.843     0.773     0.793        95
weighted avg      0.836     0.832     0.822        95


Best model: gender_run_out_perspective_aug_A\best_model.pt | best val_f1_macro=0.793 at epoch 12
All outputs in: C:\Users\Katya\mag_vase\ML\gender_run_out_perspective_aug_A
Mistake crops: C:\Users\Katya\mag_vase\ML\gender_run_out_perspective_aug_A\mistakes
Low-confidence crops: C:\Users\Katya\mag_vase\ML\gender_run_out_perspective_aug_A\low_confidence
