In [5]:
from ultralytics import YOLO
import numpy as np
import cv2
import torch
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch import nn
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt

In [20]:
yolo_path = "../yolo/yolo_h5.pt"
test_image_id = 1075
training_data_dir = "../data/training_data"
test_data_dir = "../data/test_data"
imgdir = "../data/miami_fall_24_jpgs"

def get_img(img_id):
    return cv2.imread(f"../data/miami_fall_24_jpgs/{img_id}.jpg")

def get_tps_coords(img_id, img):
    cls = ["finger", "toe"]
    ret = {}
    h, w = img.shape[:2]
    for c in cls:
        fp = f"../data/tps_files/{img_id}_{c}.TPS"
        coordinates = []
        skip = 2
        with open(fp, "r") as f:
            for line in f:
                line = line.strip()
                if not line or "=" in line:
                    continue
                parts = line.split()
                if len(parts) == 2:
                    if skip > 0:
                        skip -= 1
                        continue
                    try:
                        x , y = map(float, parts)
                        
                        coordinates.append((x, h - 1 - y))
                        #print((x, y))
                    except ValueError:
                        continue
        ret[c] = coordinates
    return ret

In [7]:
def crop_toe_boxes(r, image, g_coords, show=False, output_name="output"):
    target_classes = [2, 3] #["bot_finger", "bot_toe"]
    test_classes = [0, 1] #["up_finger", "up_toe"]
    classmap = {2: "finger", 3: "toe"}
    crops = []  # store cropped images
    coords_list = []  # store corresponding coordinates (for later heatmaps)
    tps = []
    result = r[0]  # first image
    boxes = result.boxes
    
    # Convert to numpy arrays for convenience
    xyxy = boxes.xyxy.cpu().numpy()   # shape (N, 4)
    cls_ids = boxes.cls.cpu().numpy() # shape (N,)
    conf = boxes.conf.cpu().numpy()   # optional if you want confidence filtering
    
    # Loop and filter
    for (x1, y1, x2, y2), cls_id in zip(xyxy, cls_ids):
        if int(cls_id) in target_classes:
            # Crop the image
            x1i, y1i, x2i, y2i = map(int, [x1, y1, x2, y2])
            crop = image[y1i:y2i, x1i:x2i].copy()  # copy to avoid referencing original image
            coords_list.append([x1i, y1i, x2i, y2i])  # store original coordinates
            
            l_coords = []
            valid = True
            for (x, y) in g_coords[classmap[int(cls_id)]]:
                x_local = x - x1
                y_local = y - y1
                # Check valid
                if not (0 <= x_local < (x2i - x1i) and 0 <= y_local < (y2i - y1i)):
                    valid = False
                    break

                l_coords.append((x_local, y_local))

            if not valid:
                continue
            
            crops.append(crop)
            tps.append(l_coords)

            if show:
                copy = crop.copy()
                for (lx, ly) in l_coords:
                    cv2.circle(copy, (int(round(lx)), int(round(ly))), 5, (0, 0, 255), -1)
                cv2.imshow(f"Crop Class {int(cls_id)}", copy)
                cv2.waitKey(0)   # waits for a key press
                cv2.destroyWindow(f"Crop Class {int(cls_id)}")
        elif int(cls_id) in test_classes:
            x1i, y1i, x2i, y2i = map(int, [x1, y1, x2, y2])
            crop = image[y1i:y2i, x1i:x2i].copy()
            outpath = f"{test_data_dir}/{output_name}_{classmap[int(cls_id)+2]}.jpg"
            #print(outpath)
            cv2.imwrite(outpath, crop)
            
    return crops, coords_list, tps



In [38]:
IMAGENORMALIZE = A.Compose(
        [
            A.LongestMaxSize(max_size=224),
            A.PadIfNeeded(
                min_height=224,
                min_width=224,
                border_mode=0,
                value=0,
            ),
            A.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)
            ),
            ToTensorV2(),
        ],
        keypoint_params=A.KeypointParams(
            format="xy",
            remove_invisible=False
        )
    )
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD  = np.array([0.229, 0.224, 0.225])

def preprocess_image(crop, tps):
    aug = IMAGENORMALIZE(image=crop, keypoints=tps)
    return aug["image"], aug["keypoints"]

def produce_training(r, img, tps_coords, name="sample"):
    crops, box_coords, local_tps_coords = crop_toe_boxes(r, img, tps_coords, False)
    a_rescaled = []
    a_coords = []
    for i in range(len(crops)):
        if len(local_tps_coords[i]) != 9:
            continue
        rescaled, r_coords = preprocess_image(crops[i], local_tps_coords[i])
        data = {
            "image": rescaled,
            "keypoints": torch.tensor(r_coords, dtype=torch.float32)
        }
        torch.save(data, f"{training_data_dir}/vit/{name}_{i}.pt")
        a_rescaled.append(rescaled)
        a_coords.append(r_coords)

def inspect_data(img_tensor, tps):
    img = img_tensor.permute(1,2,0).cpu().numpy()
    img = img * IMAGENET_STD + IMAGENET_MEAN   # denormalize
    img = (img * 255).clip(0,255).astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    for i, (x, y) in enumerate(tps):
        x, y = int(x), int(y)
        cv2.circle(img, (x,y), 4, (0,0,255), -1)
        cv2.putText(
            img,
            str(i),
            (x+5, y-5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0,255,0),
            1,
            cv2.LINE_AA
        )
    cv2.imshow("Debug", img)

    #print("Press any key for next image, or 'q' to quit")
    key = cv2.waitKey(0)

    if key == ord('q'):
        cv2.destroyAllWindows()
        return False  # stop loop
    return True

  A.PadIfNeeded(


In [39]:
def process_images():
    model = YOLO(yolo_path)
    dir_path = Path(imgdir)
    count = 0
    for file in dir_path.iterdir():
        print(f"Processing file {count}", end="\r", flush=True)
        try:
            if ".jpg" in file.name:
                imgid = file.name.replace(".jpg", "")
                if int(imgid) > 1000:
                    process_image(imgid, model)
            count += 1
        except Exception as e:
            count += 1
            print()
            print(f"Failed to process file {file}: {e}")
            #continue
            break

def load_pts():
    dir_path = Path(f"{training_data_dir}/vit")
    pt_files = sorted(dir_path.glob("*.pt"))
    for pt_file in pt_files:
        data = torch.load(pt_file)
        image_tensor = data["image"]
        keypoints = data["keypoints"]
        cont = inspect_data(image_tensor, keypoints)
        if not cont:
            break

def process_image(imgid, model):
    img = get_img(imgid)
    tps = get_tps_coords(imgid, img)
    r = model(img, verbose=False)
    produce_training(r, img, tps, imgid)

In [41]:
process_images()

Processing file 849
Failed to process file ..\data\miami_fall_24_jpgs\2.99.jpg: invalid literal for int() with base 10: '2.99'


In [42]:
import torch, numpy as np, glob
from pathlib import Path

d = Path("../data/training_data/vit")
files = sorted(d.glob("*.pt"))
print("Found files:", len(files))

bad = []
stats = {"img_min":[], "img_max":[], "img_mean":[], "kp_min":[], "kp_max":[], "kp_len":[]}

for p in files:
    data = torch.load(p)
    img = data["image"]
    kps = data["keypoints"]
    # image checks
    stats["img_min"].append(float(img.min()))
    stats["img_max"].append(float(img.max()))
    stats["img_mean"].append(float(img.mean()))
    # keypoint checks
    stats["kp_min"].append(float(kps.min()))
    stats["kp_max"].append(float(kps.max()))
    stats["kp_len"].append(int(kps.numel()))
    # sanity
    if kps.shape not in [(9,2),(18,)]:
        bad.append((p, "bad shape", kps.shape))
    if len(kps) == 0:
        bad.append((p, "zero keypoints"))
    # any kp outside 0..224?
    if (kps < -10).any() or (kps > 1000).any():
        bad.append((p, "kps extreme values"))
        
print("Image min/max/mean (examples):", np.percentile(stats["img_min"], [0,50,100]), np.percentile(stats["img_max"], [0,50,100]), np.mean(stats["img_mean"]))
print("Kp min/max examples:", np.percentile(stats["kp_min"], [0,50,100]), np.percentile(stats["kp_max"], [0,50,100]))
print("Keypoint tensor sizes (numel): unique:", set(stats["kp_len"]))
if bad:
    print("Bad files (first 10):", bad[:10])
else:
    print("No obvious bad files found.")

Found files: 1623
Image min/max/mean (examples): [    -2.1179     -2.1179     -1.1932] [    0.91451       1.786      2.6226] -0.5137831589248799
Kp min/max examples: [    0.27038      18.723      39.233] [     185.77      208.28      223.97]
Keypoint tensor sizes (numel): unique: {18}
No obvious bad files found.
