In [2]:
import os
import json
import random
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

# ----------------------------------------------------------------------------
# 1) CONFIGURATION
# ----------------------------------------------------------------------------
# Input locations, given relative to the working directory.
DATA_JSON = "../data/processed/final_annotations_without_occluded.json"
IMG_DIR = "../data/images"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Default seed for set_seed().
SEED = 42

# Edge length (pixels) of the square tile crops and of the presence crops.
TILE_SIZE = 224
PRESENCE_SIZE = 64

# DataLoader batch sizes. BATCH_OFFSET is not referenced in the visible code.
BATCH_TILE = 64
BATCH_OFFSET = 32
BATCH_PRESENCE = 128

# ----------------------------------------------------------------------------
# 2) UTILITIES
# ----------------------------------------------------------------------------
def set_seed(seed=SEED):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, the generators of all CUDA devices are seeded too.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Seed once up front so later random draws (e.g. the random.shuffle used for
# the train/val/test split) start from a fixed state.
set_seed()

def collate_fn(batch):
    """Transpose a batch of samples into one tuple per field.

    ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``. Samples of
    unequal length are truncated to the shortest one, and an empty batch
    yields ``()``.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# ----------------------------------------------------------------------------
# 3) LOAD ANNOTATIONS & SPLIT
# ----------------------------------------------------------------------------
# Read the annotation file as UTF-8 explicitly instead of relying on the
# platform's default text encoding, which differs between operating systems.
with open(DATA_JSON, "r", encoding="utf-8") as f:
    ann = json.load(f)

# Shuffle the image names in place, then carve the list into test / val /
# train: 20% of all images for test, 10% of the remainder for validation,
# and everything left over for training.
images = list(ann["images"].keys())
random.shuffle(images)
n_test = int(len(images) * 0.2)
n_val = int((len(images) - n_test) * 0.1)

test_imgs = images[:n_test]
val_imgs = images[n_test:n_test + n_val]
train_imgs = images[n_test + n_val:]


def subset(subset_list):
    """Return an annotation dict restricted to the given image names.

    Args:
        subset_list: Iterable of keys of ``ann["images"]`` to keep.

    Returns:
        A dict with the shared ``"all_parts"`` entry and an ``"images"``
        mapping containing only the requested images. The per-image entries
        are the same objects as in ``ann``, not copies.

    Raises:
        KeyError: If a name in ``subset_list`` is not in ``ann["images"]``.
    """
    return {
        "all_parts": ann["all_parts"],
        "images": {fn: ann["images"][fn] for fn in subset_list},
    }


train_ann = subset(train_imgs)
val_ann = subset(val_imgs)
test_ann = subset(test_imgs)

# ----------------------------------------------------------------------------
# 4) PRECOMPUTE STATIC AVERAGE OFFSETS
# ----------------------------------------------------------------------------
# Build mapping part_name->idx and vice versa
parts = ann["all_parts"]
part2idx = {p: i for i, p in enumerate(parts)}
idx2part = {i: p for p, i in part2idx.items()}
num_parts = len(parts)

# offsets[s][t] collects one (dx, dy, w, h) sample for every training image in
# which both part s ("seed") and part t ("target") are annotated: the
# displacement from the seed's box centre to the target's box centre, plus the
# target box's width and height.
offsets = {
    seed_i: {tgt_j: [] for tgt_j in range(num_parts) if tgt_j != seed_i}
    for seed_i in range(num_parts)
}

for fn, info in train_ann["images"].items():
    # Box centre and size per part index. If a part is listed more than once
    # in an image, the last entry wins.
    centers = {}
    sizes = {}
    for pi in info["available_parts"]:
        idx = part2idx[pi["part_name"]]
        bb = pi["absolute_bounding_box"]
        cx = bb["left"] + bb["width"] / 2
        cy = bb["top"] + bb["height"] / 2
        centers[idx] = (cx, cy)
        sizes[idx] = (bb["width"], bb["height"])
    for s, (cx, cy) in centers.items():
        for t, (tx, ty) in centers.items():
            if t == s:
                continue
            tw, th = sizes[t]
            offsets[s][t].append((tx - cx, ty - cy, tw, th))

# avg_offsets[s][t] is the mean (dx, dy, w, h) over all collected samples.
avg_offsets = {}
for s, d in offsets.items():
    avg_offsets[s] = {}
    for t, lst in d.items():
        if not lst:
            # The pair never co-occurs in the training images. Averaging an
            # empty list would store a scalar NaN, which breaks consumers that
            # unpack four values; leaving the key out lets their
            # ``.get(t, (0, 0, 0, 0))`` fallback apply instead.
            continue
        avg_offsets[s][t] = np.array(lst).mean(axis=0)  # dx, dy, w, h

# ----------------------------------------------------------------------------
# 5) DATASETS
# ----------------------------------------------------------------------------
class TileDataset(Dataset):
    """Square image tiles paired with a multi-hot vector of visible parts.

    Every image owns ``tiles_per_img`` consecutive indices. The first of them
    yields the tile centred in the image; the others yield a tile at a position
    drawn with ``random.randint`` on each access. A part's label entry is 1.0
    when the centre of its bounding box lies within the tile (edges inclusive).
    """

    def __init__(self, annotations, image_dir, tile_size=TILE_SIZE, tiles_per_img=5):
        per_image = annotations["images"]
        self.ann = per_image
        self.images = list(per_image.keys())
        self.image_dir = image_dir
        self.tile_size = tile_size
        self.tiles_pi = tiles_per_img
        resize = transforms.Resize((tile_size, tile_size))
        self.tf = transforms.Compose([resize, transforms.ToTensor()])

    def __len__(self):
        return self.tiles_pi * len(self.images)

    def __getitem__(self, idx):
        """Return ``(transformed_tile, label)`` for flat tile index ``idx``."""
        image_no, tile_no = divmod(idx, self.tiles_pi)
        fn = self.images[image_no]
        img = Image.open(os.path.join(self.image_dir, fn)).convert("RGB")
        W, H = img.size
        side = self.tile_size

        if tile_no == 0:
            # First tile of each image: centred crop.
            x0 = (W - side) // 2
            y0 = (H - side) // 2
        else:
            # Remaining tiles: random top-left corner, x drawn before y.
            x0 = random.randint(0, max(0, W - side))
            y0 = random.randint(0, max(0, H - side))
        x1 = x0 + side
        y1 = y0 + side

        crop = img.crop((x0, y0, x1, y1))
        label = torch.zeros(num_parts, dtype=torch.float32)
        for part in self.ann[fn]["available_parts"]:
            pidx = part2idx[part["part_name"]]
            box = part["absolute_bounding_box"]
            cx = box["left"] + box["width"] / 2
            cy = box["top"] + box["height"] / 2
            inside_x = x0 <= cx <= x1
            inside_y = y0 <= cy <= y1
            if inside_x and inside_y:
                label[pidx] = 1.0
        return self.tf(crop), label

class PresenceDataset(Dataset):
    """Fixed-size crops at predicted part locations with a present/absent label.

    For every annotated ("seed") part of every image and every target part
    index, the target centre is predicted as the seed's box centre plus the
    average seed->target offset. One item is emitted per (seed, target) pair:
    a ``tile_size`` x ``tile_size`` crop centred on that predicted location.
    This is the same centring that ``inference_on_split`` uses, so the model is
    trained on the kind of patch it is later asked to score.

    The label is 1.0 when the target part is listed in the image's
    ``available_parts`` and 0.0 otherwise.

    Args:
        annotations: Dict with an ``"images"`` mapping of file name to
            per-image annotation.
        image_dir: Directory containing the image files.
        avg_offsets: ``avg_offsets[seed][target]`` -> sequence of four values
            ``(dx, dy, w, h)``; only ``dx`` and ``dy`` are used here.
        tile_size: Edge length of the square crop in pixels.
    """

    def __init__(self, annotations, image_dir, avg_offsets, tile_size=PRESENCE_SIZE):
        self.images = list(annotations["images"].keys())
        self.ann = annotations["images"]
        self.image_dir = image_dir
        self.avg_off = avg_offsets
        self.tile_size = tile_size
        self.tf = transforms.Compose([
            transforms.Resize((tile_size, tile_size)),
            transforms.ToTensor(),
        ])

        # One (file name, target index, crop x0, crop y0) tuple per item.
        self.items = []
        half = tile_size / 2
        for fn, info in self.ann.items():
            # Box centre of every annotated part; each acts as a seed.
            seeds = []
            for pi in info["available_parts"]:
                si = part2idx[pi["part_name"]]
                bb = pi["absolute_bounding_box"]
                cx = bb["left"] + bb["width"] / 2
                cy = bb["top"] + bb["height"] / 2
                seeds.append((si, cx, cy))
            for si, cx, cy in seeds:
                for ti in range(num_parts):
                    dx, dy = self._offset(si, ti)
                    x0 = int(cx + dx - half)
                    y0 = int(cy + dy - half)
                    self.items.append((fn, ti, x0, y0))

    def _offset(self, si, ti):
        """Return the average ``(dx, dy)`` from seed ``si`` to target ``ti``.

        Falls back to ``(0.0, 0.0)`` when no usable entry exists: the pair is
        absent (including ``ti == si``), the entry is not a 4-vector, or it
        contains NaN/inf (e.g. a mean taken over zero samples).
        """
        entry = self.avg_off.get(si, {}).get(ti)
        if entry is None:
            return 0.0, 0.0
        entry = np.asarray(entry, dtype=float)
        if entry.shape != (4,) or not np.isfinite(entry).all():
            return 0.0, 0.0
        return float(entry[0]), float(entry[1])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(transformed_crop, label)`` for item ``idx``."""
        fn, ti, x0, y0 = self.items[idx]
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(os.path.join(self.image_dir, fn)) as im:
            img = im.convert("RGB")
        W, H = img.size
        # Keep the crop window inside the image where the image is large enough.
        x0 = max(0, min(x0, W - self.tile_size))
        y0 = max(0, min(y0, H - self.tile_size))
        crop = img.crop((x0, y0, x0 + self.tile_size, y0 + self.tile_size))
        # Present if the target part appears anywhere in the ground truth.
        present = any(
            part2idx[pi["part_name"]] == ti
            for pi in self.ann[fn]["available_parts"]
        )
        return self.tf(crop), torch.tensor(float(present), dtype=torch.float32)

# ----------------------------------------------------------------------------
# 6) MODELS
# ----------------------------------------------------------------------------
class TileNet(nn.Module):
    """ResNet-18 with its final layer replaced by a ``num_parts``-way head.

    ``forward`` applies a sigmoid to the head's output, giving one value in
    (0, 1) per part.
    """

    def __init__(self, num_parts):
        super().__init__()
        resnet = models.resnet18(pretrained=True)
        feature_dim = resnet.fc.in_features
        # Swap the stock classifier for one sized to the number of parts.
        resnet.fc = nn.Linear(feature_dim, num_parts)
        self.backbone = resnet

    def forward(self, x):
        logits = self.backbone(x)
        return logits.sigmoid()

class PresenceNet(nn.Module):
    """Small CNN producing one sigmoid score per input crop.

    Two conv/ReLU/max-pool stages each halve the spatial size, so the linear
    head is sized for inputs of ``PRESENCE_SIZE`` x ``PRESENCE_SIZE`` pixels.
    """

    def __init__(self):
        super().__init__()
        conv_layers = [
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.conv = nn.Sequential(*conv_layers)

        # Two 2x poolings leave a (PRESENCE_SIZE // 4)-wide map of 64 channels.
        pooled_side = PRESENCE_SIZE // 4
        flat_features = 64 * pooled_side ** 2
        head_layers = [
            nn.Flatten(),
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        features = self.conv(x)
        scores = self.fc(features)
        # Drop the trailing size-1 dimension: one score per sample.
        return torch.sigmoid(scores).squeeze(1)

# ----------------------------------------------------------------------------
# 7) TRAINING FUNCTIONS
# ----------------------------------------------------------------------------
def train_tile_detector():
    """Train a ``TileNet`` on tiles from the training split.

    Runs 5 epochs of Adam (lr=1e-4) with binary cross-entropy over
    ``TileDataset(train_ann, IMG_DIR)``, prints the mean batch loss per epoch,
    writes the weights to ``tile_net.pth`` and returns the trained model.
    """
    dataset = TileDataset(train_ann, IMG_DIR)
    loader = DataLoader(dataset, batch_size=BATCH_TILE, shuffle=True, num_workers=4)
    model = TileNet(num_parts).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(5):
        model.train()
        n_batches = 0
        loss_sum = 0
        for inputs, targets in tqdm(loader, desc=f"TileTrain {epoch}"):
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            batch_loss = F.binary_cross_entropy(model(inputs), targets)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            n_batches += 1
            loss_sum += batch_loss.item()
        print(" Epoch", epoch, "Loss", loss_sum / n_batches)

    torch.save(model.state_dict(), "tile_net.pth")
    return model

def train_presence_verifier():
    """Train a ``PresenceNet`` on crops from the training split.

    Runs 5 epochs of Adam (lr=1e-4) with binary cross-entropy over
    ``PresenceDataset(train_ann, IMG_DIR, avg_offsets)``, prints the mean batch
    loss per epoch, writes the weights to ``pres_net.pth`` and returns the
    trained model.
    """
    dataset = PresenceDataset(train_ann, IMG_DIR, avg_offsets)
    loader = DataLoader(dataset, batch_size=BATCH_PRESENCE, shuffle=True, num_workers=4)
    model = PresenceNet().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(5):
        model.train()
        n_batches = 0
        loss_sum = 0
        for inputs, targets in tqdm(loader, desc=f"PresTrain {epoch}"):
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            batch_loss = F.binary_cross_entropy(model(inputs), targets)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            n_batches += 1
            loss_sum += batch_loss.item()
        print(" Epoch", epoch, "Loss", loss_sum / n_batches)

    torch.save(model.state_dict(), "pres_net.pth")
    return model

# ----------------------------------------------------------------------------
# 8) INFERENCE PIPELINE
# ----------------------------------------------------------------------------
def inference_on_split(split_ann):
    """Predict the set of missing parts for every image of a split.

    Per image:
      1. Crop the central ``TILE_SIZE`` tile and run the tile model on it;
         parts scoring above 0.5 become seeds.
      2. With no seeds, every part is reported missing.
      3. Otherwise each seed is assumed to sit at the tile centre. For every
         other part, a ``PRESENCE_SIZE`` patch centred at tile centre +
         average seed->part offset is scored by the presence model. A part's
         confidence is the maximum over all seeds (a seed contributes its own
         tile-model score for itself).
      4. Parts whose confidence stays below 0.5 are reported missing.

    Both models are loaded from ``pres_net.pth`` / ``tile_net.pth`` in the
    working directory.

    Args:
        split_ann: Annotation dict whose ``"images"`` keys are file names
            under ``IMG_DIR``.

    Returns:
        List of ``(file_name, missing)`` where ``missing`` is a set of part
        indices.
    """
    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    pres_model = PresenceNet().to(DEVICE)
    pres_model.load_state_dict(torch.load("pres_net.pth", map_location=DEVICE))
    pres_model.eval()

    tile_model = TileNet(num_parts).to(DEVICE)
    tile_model.load_state_dict(torch.load("tile_net.pth", map_location=DEVICE))
    tile_model.eval()

    # Build the preprocessing pipelines once rather than once per crop.
    tile_tf = transforms.Compose([
        transforms.Resize((TILE_SIZE, TILE_SIZE)),
        transforms.ToTensor(),
    ])
    pres_tf = transforms.Compose([
        transforms.Resize((PRESENCE_SIZE, PRESENCE_SIZE)),
        transforms.ToTensor(),
    ])

    def mean_shift(si, ti):
        # Average (dx, dy) from seed si to target ti; zero when the entry is
        # absent, not a 4-vector, or non-finite (e.g. a mean over no samples).
        entry = avg_offsets[si].get(ti)
        if entry is None or np.shape(entry) != (4,):
            return 0.0, 0.0
        if not np.all(np.isfinite(entry)):
            return 0.0, 0.0
        return float(entry[0]), float(entry[1])

    results = []
    for fn in tqdm(split_ann["images"].keys(), desc="Inf images"):
        with Image.open(os.path.join(IMG_DIR, fn)) as im:
            img = im.convert("RGB")
        W, H = img.size

        # 1) central tile
        x0, y0 = (W - TILE_SIZE) // 2, (H - TILE_SIZE) // 2
        crop = img.crop((x0, y0, x0 + TILE_SIZE, y0 + TILE_SIZE))
        inp = tile_tf(crop).unsqueeze(0).to(DEVICE)

        # 2) detect seeds
        with torch.no_grad():
            preds = tile_model(inp)[0].cpu().numpy()
        seeds = [(i, float(p)) for i, p in enumerate(preds) if p > 0.5]

        # No seeds: treat every part as missing.
        if not seeds:
            results.append((fn, set(range(num_parts))))
            continue

        # 3) for every seed, estimate each part's location and verify it.
        # The seed's true position is unknown, so the tile centre stands in.
        cx = x0 + TILE_SIZE / 2
        cy = y0 + TILE_SIZE / 2
        part_conf = {i: 0.0 for i in range(num_parts)}
        for si, conf in seeds:
            for ti in range(num_parts):
                if ti == si:
                    part_conf[ti] = max(part_conf[ti], conf)
                    continue
                dx, dy = mean_shift(si, ti)
                x1 = int(cx + dx - PRESENCE_SIZE / 2)
                y1 = int(cy + dy - PRESENCE_SIZE / 2)
                # Keep the patch inside the image where the image allows it.
                x1 = max(0, min(x1, W - PRESENCE_SIZE))
                y1 = max(0, min(y1, H - PRESENCE_SIZE))
                box_crop = img.crop((x1, y1, x1 + PRESENCE_SIZE, y1 + PRESENCE_SIZE))
                inp2 = pres_tf(box_crop).unsqueeze(0).to(DEVICE)
                with torch.no_grad():
                    p = pres_model(inp2).item()
                part_conf[ti] = max(part_conf[ti], p)

        # 4) missing = parts whose best confidence is below 0.5
        missing = {ti for ti, c in part_conf.items() if c < 0.5}
        results.append((fn, missing))

    return results

# ----------------------------------------------------------------------------
# 9) RUN EVERYTHING
# ----------------------------------------------------------------------------

    # 1) train submodules
# Import the metrics before training starts, so a missing scikit-learn install
# fails immediately instead of after training and the first inference pass.
from sklearn.metrics import f1_score, precision_score, recall_score

tile_model = train_tile_detector()
pres_model = train_presence_verifier()

# 2) inference & evaluation
for split, ann_split in [("VAL", val_ann), ("TEST", test_ann)]:
    res = inference_on_split(ann_split)
    # Build one multi-hot row per image: ground-truth missing parts (Yt) and
    # predicted missing parts (Yp). Images without a "missing_parts" entry are
    # treated as having nothing missing.
    Yt, Yp = [], []
    for fn, miss in res:
        gt = {part2idx[p] for p in ann_split["images"][fn].get("missing_parts", [])}
        Yt.append([int(i in gt) for i in range(num_parts)])
        Yp.append([int(i in miss) for i in range(num_parts)])
    Yt = np.array(Yt)
    Yp = np.array(Yp)
    print(f"{split} Micro-F1:", f1_score(Yt, Yp, average="micro", zero_division=0))
    print(f"{split} Macro-F1:", f1_score(Yt, Yp, average="macro", zero_division=0))
    print(f"{split} Precision:", precision_score(Yt, Yp, average="micro", zero_division=0))
    print(f"{split} Recall:", recall_score(Yt, Yp, average="micro", zero_division=0))


TileTrain 0:   2%|▏         | 11/449 [00:50<33:19,  4.57s/it]


KeyboardInterrupt: 