# Generate Bounding Boxes and Pose Keypoints

In [None]:
import os
import cv2
import xml.etree.ElementTree as ET
import mediapipe as mp
import numpy as np
from tqdm import tqdm

In [None]:
# Locations of the JAAD video clips and their XML annotations, plus the two output
# folders (created if missing): annotated preview frames and per-frame pose keypoints.
jaad_path = r'D:\mgr\PedestrianIntentionEstimation\JAAD_dataset\JAAD_clips'
annotation_dir = r'D:\mgr\PedestrianIntentionEstimation\JAAD_dataset\annotations'
output_dir = os.path.join(jaad_path, 'frames_with_bboxes')
pose_output_dir = os.path.join(jaad_path, 'pose_keypoints')
os.makedirs(output_dir, exist_ok=True)
os.makedirs(pose_output_dir, exist_ok=True)

In [None]:
# Initialize Mediapipe Pose (module-level instance, model_complexity=2).
# NOTE(review): the extraction function below builds and uses its own local Pose
# instance, so this global `pose` appears unused in the visible code — confirm before
# removing it (it loads a second, heavier model).
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(static_image_mode=False, model_complexity=2, enable_segmentation=False, min_detection_confidence=0.4)

def draw_keypoints(image, keypoints):
    """Draw every keypoint as a filled red (BGR) circle of radius 5, in place.

    Each keypoint row starts with x, y, which are multiplied by the image width and
    height respectively (presumably coordinates normalised to [0, 1]); any further
    columns are ignored. Returns the same ``image`` object.
    """
    height, width = image.shape[:2]
    for kp in keypoints:
        center = (int(kp[0] * width), int(kp[1] * height))
        cv2.circle(image, center, 5, (0, 0, 255), -1)
    return image


def _index_boxes_by_frame(root):
    """Group the annotated boxes of an annotation XML tree by frame index.

    Returns ``{frame_id: [(xtl, ytl, xbr, ybr), ...]}`` with coordinates truncated to
    ints. Boxes keep document order (tracks, then boxes within each track) inside every
    frame, so "first largest box" selection is unchanged.
    """
    by_frame = {}
    for track in root.findall('.//track'):
        for box in track.findall('.//box'):
            coords = tuple(int(float(box.get(k))) for k in ('xtl', 'ytl', 'xbr', 'ybr'))
            by_frame.setdefault(int(box.get('frame')), []).append(coords)
    return by_frame


def _pose_keypoints_for_largest_box(pose, frame, boxes, W, H, pad, min_box, min_vis, eps=1e-6):
    """Run ``pose`` on a padded crop around the largest of ``boxes`` in ``frame``.

    ``frame`` is only read here and should not yet have anything drawn on it.

    Returns ``(keypoints, points_px)``:
      * ``keypoints`` is a float32 array with one ``[x_norm, y_norm, z]`` row per
        landmark (full-frame normalised coordinates; zeros for landmarks whose
        visibility is below ``min_vis``), or ``None`` when the box is smaller than
        ``min_box``, the crop is empty or no pose is detected;
      * ``points_px`` lists the full-frame pixel positions of the kept landmarks.
    """
    areas = [(xbr - xtl) * (ybr - ytl) for (xtl, ytl, xbr, ybr) in boxes]
    xtl, ytl, xbr, ybr = boxes[int(np.argmax(areas))]
    bw, bh = xbr - xtl, ybr - ytl
    if bw < min_box or bh < min_box:
        return None, []

    px, py = int(pad * bw), int(pad * bh)
    cx1, cy1 = max(0, xtl - px), max(0, ytl - py)
    cx2, cy2 = min(W, xbr + px), min(H, ybr + py)
    if cx2 <= cx1 or cy2 <= cy1:
        return None, []

    crop = frame[cy1:cy2, cx1:cx2]
    result = pose.process(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
    if not result.pose_landmarks:
        return None, []

    cw, ch = max(1, cx2 - cx1), max(1, cy2 - cy1)
    out_pts, points_px = [], []
    for p in result.pose_landmarks.landmark:
        vis = float(p.visibility) if (p.visibility is not None and np.isfinite(p.visibility)) else 0.0
        zval = float(p.z) if (p.z is not None and np.isfinite(p.z)) else 0.0
        if vis < min_vis:
            out_pts.append([0.0, 0.0, 0.0])
            continue
        # Clamp inside the crop, shift back to full-frame pixels, clamp to the frame.
        x_pix = int(max(0, min(cw - 1, p.x * cw))) + cx1
        y_pix = int(max(0, min(ch - 1, p.y * ch))) + cy1
        x_pix = int(max(0, min(W - 1, x_pix)))
        y_pix = int(max(0, min(H - 1, y_pix)))
        out_pts.append([x_pix / (W - 1 + eps), y_pix / (H - 1 + eps), zval])
        points_px.append((x_pix, y_pix))
    return np.array(out_pts, dtype=np.float32), points_px


def extract_and_save_frames_with_bboxes_and_pose_keypoints(
        video_path, annotation_path, video_output_dir, pose_output_dir,
        pad=0.15, min_box=40, min_det_conf=0.7, min_vis=0.5):
    """Save an annotated preview and pose keypoints for every frame of a video.

    For each frame of ``video_path`` this:
      * runs MediaPipe Pose on a crop around the largest box annotated for that frame
        in ``annotation_path`` (padded by ``pad`` of the box size on each side; only if
        both box sides are at least ``min_box`` px);
      * draws all annotated boxes (green) and the kept landmarks (red);
      * writes a 640x480 JPEG preview to ``video_output_dir`` and a float32 keypoint
        array to ``pose_output_dir`` — shape (33, 3) of zeros when nothing was
        detected, otherwise one ``[x_norm, y_norm, z]`` row per landmark.

    Files are named ``<video_name>_frame_<frame_id:05d>.{jpg,npy}`` where
    ``video_name`` is the video file name without its extension.

    Args:
        min_det_conf: MediaPipe detection and tracking confidence threshold.
        min_vis: landmarks with lower visibility are stored as zeros.
    """
    # splitext matches how the batch driver names the per-video output folders.
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    os.makedirs(video_output_dir, exist_ok=True)
    os.makedirs(pose_output_dir, exist_ok=True)

    # Index boxes once; the old per-frame scan of every box was quadratic in video length.
    boxes_by_frame = _index_boxes_by_frame(ET.parse(annotation_path).getroot())

    cap = cv2.VideoCapture(video_path)
    pose = mp.solutions.pose.Pose(static_image_mode=False, model_complexity=1,
                                  enable_segmentation=False,
                                  min_detection_confidence=min_det_conf,
                                  min_tracking_confidence=min_det_conf)
    try:
        W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        frame_id = 0
        with tqdm(total=total_frames, desc=f"Processing {video_name}", unit="frame") as pbar:
            while cap.isOpened():
                ret, frame_full = cap.read()
                if not ret:
                    break

                boxes = boxes_by_frame.get(frame_id, [])
                keypoints, points_px = None, []
                if boxes:
                    # Estimate the pose BEFORE drawing anything: the padded crop contains
                    # the box outline, and drawing first fed those lines to MediaPipe.
                    keypoints, points_px = _pose_keypoints_for_largest_box(
                        pose, frame_full, boxes, W, H, pad, min_box, min_vis)
                if keypoints is None:
                    keypoints = np.zeros((33, 3), dtype=np.float32)

                for (xtl, ytl, xbr, ybr) in boxes:
                    cv2.rectangle(frame_full, (xtl, ytl), (xbr, ybr), (0, 255, 0), 2)
                for pt in points_px:
                    cv2.circle(frame_full, pt, 4, (0, 0, 255), -1)

                keypoints = np.nan_to_num(keypoints, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

                preview_img = cv2.resize(frame_full, (640, 480), interpolation=cv2.INTER_AREA)
                frame_filename = os.path.join(video_output_dir, f"{video_name}_frame_{frame_id:05d}.jpg")
                keypoints_filename = os.path.join(pose_output_dir, f"{video_name}_frame_{frame_id:05d}.npy")
                cv2.imwrite(frame_filename, preview_img)
                np.save(keypoints_filename, keypoints)

                frame_id += 1
                pbar.update(1)
    finally:
        # Release the native MediaPipe graph and the video handle even on errors.
        pose.close()
        cap.release()


# One batch per run: each execution of the driver below processes at most
# BATCH_SIZE videos that have not been processed yet.
BATCH_SIZE = 50

def is_already_processed(video_file: str, frames_root=None, kpts_root=None) -> bool:
    """Return True when both frame JPEGs and keypoint .npy files exist for ``video_file``.

    The outputs are looked up in ``<frames_root>/<name>`` and ``<kpts_root>/<name>``,
    where ``name`` is ``video_file`` without its extension. The roots default to the
    module-level ``output_dir`` / ``pose_output_dir``.

    Both outputs are required. The previous ``or`` check skipped, forever, any video
    whose frames or keypoints were missing entirely.

    NOTE(review): a run interrupted mid-video leaves both folders non-empty, so partial
    videos are still treated as processed; only count_generated_files compares counts.
    """
    vid_name = os.path.splitext(video_file)[0]
    frames_dir_vid = os.path.join(output_dir if frames_root is None else frames_root, vid_name)
    kpts_dir_vid = os.path.join(pose_output_dir if kpts_root is None else kpts_root, vid_name)

    def has_files(folder, ext):
        return os.path.isdir(folder) and any(f.lower().endswith(ext) for f in os.listdir(folder))

    return has_files(frames_dir_vid, '.jpg') and has_files(kpts_dir_vid, '.npy')

# Select the clips that have an annotation file and have not been processed yet,
# then process one batch of them per run.
all_videos = sorted(f for f in os.listdir(jaad_path) if f.lower().endswith('.mp4'))
# splitext (not str.replace('.mp4', ...)) so '.MP4' clips, which the case-insensitive
# filter above accepts, map to '<name>.xml' as well.
all_videos = [vf for vf in all_videos
              if os.path.exists(os.path.join(annotation_dir, os.path.splitext(vf)[0] + '.xml'))]

# choose unprocessed
remaining = [vf for vf in all_videos if not is_already_processed(vf)]

print(f"[INFO] Total videos: {len(all_videos)} | Already processed: {len(all_videos) - len(remaining)} | Remaining: {len(remaining)}")

if not remaining:
    print("[INFO] Nothing to do. All videos are already processed.")
else:
    current_batch = remaining[:BATCH_SIZE]
    print(f"[INFO] Processing one batch: {len(current_batch)} videos")

    for video_file in current_batch:
        vid_name = os.path.splitext(video_file)[0]
        video_path = os.path.join(jaad_path, video_file)
        annotation_file = vid_name + '.xml'
        annotation_path = os.path.join(annotation_dir, annotation_file)

        video_output_dir = os.path.join(output_dir, vid_name)
        video_pose_output_dir = os.path.join(pose_output_dir, vid_name)

        os.makedirs(video_output_dir, exist_ok=True)
        os.makedirs(video_pose_output_dir, exist_ok=True)

        extract_and_save_frames_with_bboxes_and_pose_keypoints(
            video_path, annotation_path, video_output_dir, video_pose_output_dir
        )

    left_after = len(remaining) - len(current_batch)
    print(f"\n[INFO] Batch done. Remaining videos to process next runs: {left_after}")

In [None]:
def count_generated_files(frames_root, kpts_root):
    """Count generated frame (.jpg) and keypoint (.npy) files per video folder.

    Each immediate sub-directory of ``frames_root`` / ``kpts_root`` is one video;
    extensions are matched case-insensitively and a missing root counts as empty.

    Returns ``(total_jpg, total_npy, mismatches, missing_frames, missing_kpts)``, where
    every list holds ``(video, n_jpg, n_npy)`` tuples sorted by video name:
      * missing_frames — no jpg but at least one npy;
      * missing_kpts   — no npy but at least one jpg;
      * mismatches     — both present with different counts.
    """
    def _per_video_counts(root, suffix):
        if not os.path.isdir(root):
            return {}
        counts = {}
        for entry in sorted(os.listdir(root)):
            sub = os.path.join(root, entry)
            if os.path.isdir(sub):
                counts[entry] = sum(name.lower().endswith(suffix) for name in os.listdir(sub))
        return counts

    jpg_counts = _per_video_counts(frames_root, ".jpg")
    npy_counts = _per_video_counts(kpts_root, ".npy")

    mismatches, missing_frames, missing_kpts = [], [], []
    for vid in sorted(jpg_counts.keys() | npy_counts.keys()):
        row = (vid, jpg_counts.get(vid, 0), npy_counts.get(vid, 0))
        n_jpg, n_npy = row[1], row[2]
        if n_jpg == n_npy:
            continue
        # Counts are non-negative and differ here, so a zero on one side means the
        # other side is positive.
        if n_jpg == 0:
            missing_frames.append(row)
        elif n_npy == 0:
            missing_kpts.append(row)
        else:
            mismatches.append(row)

    return sum(jpg_counts.values()), sum(npy_counts.values()), mismatches, missing_frames, missing_kpts

# Summarise the generated outputs (messages are in Polish): total JPG frames, total
# NPY keypoint files, then how many videos lack frames, lack keypoints, or have
# mismatched jpg/npy counts.
total_jpg, total_npy, mismatches, missing_frames, missing_kpts = count_generated_files(
    output_dir, pose_output_dir
)

print("\n PODSUMOWANIE WYGENEROWANYCH PLIKÓW ")
print(f"Łącznie klatek (JPG):      {total_jpg}")
print(f"Łącznie keypointów (NPY):  {total_npy}")
print(f"Wideo z brakującymi KLATKAMI (jpg, ale są npy): {len(missing_frames)}")
print(f"Wideo z brakującymi KEYPOINTAMI (npy, ale są jpg): {len(missing_kpts)}")
print(f"Wideo z NIESPÓJNYMI LICZBAMI jpg != npy: {len(mismatches)}")

def preview(rows, title, k=5):
    """Print ``title`` followed by up to ``k`` ``(video, n_jpg, n_npy)`` rows.

    Prints nothing at all when ``rows`` is empty.
    """
    if rows:
        print(f"\n{title} (pokazuję do {k}):")
        for video, jpg_count, npy_count in rows[:k]:
            print(f"  {video}: jpg={jpg_count}, npy={npy_count}")

# Show up to 5 example videos for each problem category (nothing is printed for an empty one).
preview(missing_frames, "Brakuje KLATEK (jpg, a są npy)")
preview(missing_kpts,   "Brakuje KEYPOINTÓW (npy, a są jpg)")
preview(mismatches,     "Niespójne liczby plików (jpg != npy)")


# Preprocess Annotations

In [None]:
import os
import xml.etree.ElementTree as ET
import numpy as np
import pickle

In [None]:
def preprocess_annotations(annotations_base_dir, cache_dir, video_names):
    """Cache per-frame labels and context features for the largest pedestrian of each frame.

    For every name in ``video_names`` (e.g. ``video_0001.xml``; the id is the text
    between the first ``_`` and the following ``.``) this:
      * parses ``annotations/video_<id>.xml`` plus the optional attributes, appearance,
        traffic and vehicle XMLs found next to it;
      * keeps, for each annotated frame, only the box with the largest area (the first
        one in document order on ties);
      * labels it 1 if it has an ``attribute`` child named ``cross`` with text
        ``crossing``, else 0;
      * pickles a list of ``(frame_id, label, traffic_info, vehicle_info,
        appearance_info, attributes_info)`` tuples to ``<cache_dir>/video_<id>.pkl``.

    Videos without the main annotation file are skipped with a warning. The totals of
    crossing / not-crossing samples are printed at the end.
    """
    os.makedirs(cache_dir, exist_ok=True)
    total_crossing = 0
    total_not_crossing = 0

    annotations_dir = os.path.join(annotations_base_dir, 'annotations')
    annotations_dirs = {
        'attributes': os.path.join(annotations_base_dir, 'annotations_attributes'),
        'appearance': os.path.join(annotations_base_dir, 'annotations_appearance'),
        'traffic': os.path.join(annotations_base_dir, 'annotations_traffic'),
        'vehicle': os.path.join(annotations_base_dir, 'annotations_vehicle')
    }

    def box_area(bx):
        return ((float(bx.get('xbr')) - float(bx.get('xtl')))
                * (float(bx.get('ybr')) - float(bx.get('ytl'))))

    for video in video_names:
        video_id = video.split('_')[1].split('.')[0]

        annotations_paths = {
            'annotations': os.path.join(annotations_dir, f"video_{video_id}.xml"),
            'attributes': os.path.join(annotations_dirs['attributes'], f"video_{video_id}_attributes.xml"),
            'appearance': os.path.join(annotations_dirs['appearance'], f"video_{video_id}_appearance.xml"),
            'traffic': os.path.join(annotations_dirs['traffic'], f"video_{video_id}_traffic.xml"),
            'vehicle': os.path.join(annotations_dirs['vehicle'], f"video_{video_id}_vehicle.xml")
        }

        annotations = {key: ET.parse(path).getroot()
                       for key, path in annotations_paths.items() if os.path.exists(path)}
        if 'annotations' not in annotations:
            print(f"[WARN] Missing main annotation file for video_{video_id}, skipping")
            continue

        # Every box in document order: tracks, then boxes within each track.
        all_boxes = [box
                     for track in annotations['annotations'].findall('.//track')
                     for box in track.findall('.//box')]

        # Largest box per frame in one pass (previously every box rescanned all boxes).
        # Strict '>' keeps the first box on ties, exactly like max() over the same order.
        largest = {}
        for box in all_boxes:
            frame_id = int(box.get('frame'))
            area = box_area(box)
            if frame_id not in largest or area > largest[frame_id][1]:
                largest[frame_id] = (box, area)

        preprocessed_data = []
        for box in all_boxes:
            frame_id = int(box.get('frame'))
            if largest[frame_id][0] is not box:
                continue

            # label crossing / not crossing
            label = int(any(attr.get('name') == 'cross' and attr.text == 'crossing'
                            for attr in box.findall('attribute')))
            if label == 1:
                total_crossing += 1
            else:
                total_not_crossing += 1

            # additional context features for this frame
            traffic_info = get_traffic_info(annotations.get('traffic', None), frame_id)
            vehicle_info = get_vehicle_info(annotations.get('vehicle', None), frame_id)
            appearance_info = get_appearance_info(annotations.get('appearance', None), frame_id)
            attributes_info = get_attributes_info(annotations.get('attributes', None), frame_id)

            preprocessed_data.append((frame_id, label, traffic_info, vehicle_info, appearance_info, attributes_info))

        with open(os.path.join(cache_dir, f"video_{video_id}.pkl"), 'wb') as f:
            pickle.dump(preprocessed_data, f)

    print(f"\nTotal 'crossing' samples: {total_crossing}")
    print(f"Total 'not crossing' samples: {total_not_crossing}")

def get_traffic_info(root, frame_id):
    """Return the traffic context of ``frame_id`` from a traffic XML root.

    Keys (in this order): ped_crossing, ped_sign, stop_sign (ints parsed from the
    ``<frame>`` attributes) and traffic_light (1 unless the attribute equals 'n/a').
    All zeros when ``root`` is None or no ``<frame>`` has that id; if several frames
    share the id, the last one wins.
    """
    info = dict(ped_crossing=0, ped_sign=0, stop_sign=0, traffic_light=0)
    if root is None:
        return info
    for node in root.findall('.//frame'):
        if int(node.get('id')) != frame_id:
            continue
        info = {key: int(node.get(key)) for key in ('ped_crossing', 'ped_sign', 'stop_sign')}
        info['traffic_light'] = int(node.get('traffic_light') != 'n/a')
    return info

def get_vehicle_info(root, frame_id):
    """Encode the vehicle action of ``frame_id`` as ``{'action': code}``.

    Codes: moving_slow=1, decelerating=2, stopped=3, accelerating=4. Any other value,
    a missing ``action`` attribute, an unknown frame or ``root is None`` give 0. If
    several ``<frame>`` elements share the id, the last one wins.

    NOTE(review): other action strings (e.g. a 'moving_fast' value, if the vehicle XML
    contains one) collapse into the same 0 as "unknown" — confirm this is intended.
    """
    action_codes = {'moving_slow': 1, 'decelerating': 2, 'stopped': 3, 'accelerating': 4}
    result = {'action': 0}
    if root is None:
        return result
    for node in root.findall('.//frame'):
        if int(node.get('id')) == frame_id:
            result = {'action': action_codes.get(node.get('action'), 0)}
    return result

def get_appearance_info(root, frame_id):
    """Return ``{'pose', 'clothing', 'objects'}`` values for ``frame_id`` from an appearance XML root.

    Each value is the float text of the same-named child element of a matching
    ``<box>``, or 0 when that child is absent. All zeros when ``root`` is None or no
    box has that frame.

    NOTE(review): boxes are matched by frame only, across all tracks, so with several
    pedestrians the last matching box wins — not necessarily the pedestrian that was
    labelled. Also, if the XML stores these values as box attributes instead of child
    elements, find() returns None and the defaults are used. Confirm against the data.
    """
    result = {'pose': 0, 'clothing': 0, 'objects': 0}
    if root is None:
        return result
    for track in root.findall('.//track'):
        for box in track.findall('.//box'):
            if int(box.get('frame')) != frame_id:
                continue
            result = {}
            for tag in ('pose', 'clothing', 'objects'):
                child = box.find(tag)
                result[tag] = 0 if child is None else float(child.text)
    return result

def get_attributes_info(root, frame_id):
    """Return ``{'age', 'gender', 'crossing_point'}`` for ``frame_id`` from an attributes XML root.

    ``age`` and ``crossing_point`` are the float text of the same-named child of a
    matching ``<box>`` (0 if absent); ``gender`` is 1 for text 'male' and 0 otherwise.
    All zeros when ``root`` is None or no box has that frame.

    NOTE(review): boxes are matched by frame only, across all tracks (last match wins),
    and values are read from child elements; if the attributes XML has a different
    layout (e.g. per-pedestrian attributes), this always returns the defaults — confirm.
    """
    info = {'age': 0, 'gender': 0, 'crossing_point': 0}
    if root is None:
        return info
    for track in root.findall('.//track'):
        for box in track.findall('.//box'):
            if int(box.get('frame')) != frame_id:
                continue
            age_node = box.find('age')
            gender_node = box.find('gender')
            cp_node = box.find('crossing_point')
            info = {
                'age': 0 if age_node is None else float(age_node.text),
                'gender': int(gender_node is not None and gender_node.text == 'male'),
                'crossing_point': 0 if cp_node is None else float(cp_node.text),
            }
    return info

In [None]:
annotations_base_dir = r'D:\mgr\PedestrianIntentionEstimation\JAAD_dataset'
cache_dir = r'D:\mgr\PedestrianIntentionEstimation\JAAD_dataset\cache'
# Only XML files: preprocess_annotations derives the id via name.split('_')[1], so a
# stray file in the folder (e.g. desktop.ini) would crash it.
video_names = sorted(f for f in os.listdir(os.path.join(annotations_base_dir, 'annotations'))
                     if f.lower().endswith('.xml'))
preprocess_annotations(annotations_base_dir, cache_dir, video_names)

print("Annotations preprocessed")

# Create Dataset

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
from torchvision import transforms
import pickle
from tqdm import tqdm
from PIL import Image
from multiprocessing import Value, cpu_count
import sys

In [None]:
# Process-shared call counter; JAADDataset.__getitem__ logs every 1000th call.
counter = Value('i', 0)

class JAADDataset(Dataset):
    """Frame-level samples: annotated frame image, pose keypoints and context features.

    Each sub-directory ``video`` of ``frames_dir`` is matched to
    ``<cache_dir>/video_<id>.pkl`` (``id`` = text after the first ``_``). A cached
    frame becomes a sample only if both
    ``<frames_dir>/<video>/<video>_frame_<id:05d>.jpg`` and the matching ``.npy`` under
    ``keypoints_dir`` exist. Videos without a cache file are skipped.
    """

    def __init__(self, frames_dir, keypoints_dir, cache_dir, transform=None):
        self.frames_dir = frames_dir
        self.keypoints_dir = keypoints_dir
        self.cache_dir = cache_dir
        self.transform = transform
        self.video_names = sorted(os.listdir(frames_dir))
        self.data = []

        for video in self.video_names:
            video_id = video.split('_')[1]
            cache_file = os.path.join(cache_dir, f"video_{video_id}.pkl")
            if not os.path.exists(cache_file):
                continue
            with open(cache_file, 'rb') as fh:
                records = pickle.load(fh)
            for frame_id, label, traffic, vehicle, appearance, attributes in records:
                stem = f"{video}_frame_{frame_id:05d}"
                frame_path = os.path.join(frames_dir, video, stem + ".jpg")
                keypoint_file = os.path.join(keypoints_dir, video, stem + ".npy")
                if os.path.exists(frame_path) and os.path.exists(keypoint_file):
                    self.data.append((frame_path, keypoint_file, label, traffic, vehicle, appearance, attributes))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        with counter.get_lock():
            counter.value += 1
            calls = counter.value
            if calls % 1000 == 0:
                print(f"[INFO] __getitem__ called {calls} times", flush=True)

        (frame_path, keypoint_file, label,
         traffic_info, vehicle_info, appearance_info, attributes_info) = self.data[idx]
        frame = Image.open(frame_path).convert('RGB')

        kp = np.load(keypoint_file)
        if kp.size == 0:  # empty keypoint array -> all-zero pose
            kp = np.zeros((33, 3), dtype=np.float32)

        def as_tensor(info):
            return torch.tensor(list(info.values()), dtype=torch.float32)

        keypoints = torch.tensor(kp, dtype=torch.float32)
        traffic_info = as_tensor(traffic_info)
        vehicle_info = as_tensor(vehicle_info)
        appearance_info = as_tensor(appearance_info)
        attributes_info = as_tensor(attributes_info)

        if self.transform:
            frame = self.transform(frame)

        return frame, keypoints, label, traffic_info, vehicle_info, appearance_info, attributes_info



# Data augmentation transformations: resize to 224x224, random horizontal flip,
# random rotation up to 10 degrees, tensor conversion and ImageNet mean/std normalisation.
# NOTE(review): this transform runs inside JAADDataset.__getitem__, and the cached .pt
# train/test files are written from that dataset's items. So the random flip/rotation
# is drawn once and frozen into the cache (no fresh augmentation per epoch), it is also
# applied to the test split, and a horizontal flip is not mirrored in the keypoints.
# Consider caching deterministic tensors and augmenting at load time for training only.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

frames_dir = r'D:\mgr\PedestrianIntentionEstimation\JAAD_dataset\JAAD_clips\frames_with_bboxes'
keypoints_dir = r'D:\mgr\PedestrianIntentionEstimation\JAAD_dataset\JAAD_clips\pose_keypoints'
cache_dir = r'D:\mgr\PedestrianIntentionEstimation\JAAD_dataset\cache'

base_dataset = JAADDataset(frames_dir, keypoints_dir, cache_dir, transform)

# Labels are already stored in the index built by JAADDataset (third field of every
# record), so read them directly. Going through base_dataset[i] decoded and augmented
# every image only to throw it away.
all_labels = np.array([int(record[2]) for record in base_dataset.data], dtype=np.int64)
total_frames = len(all_labels)
n_crossing = (all_labels == 1).sum()
n_not_crossing = (all_labels == 0).sum()
denom = max(total_frames, 1)  # an empty dataset shows 0.00% instead of nan%

print("\n STATYSTYKI ZBIORU ")
print(f"Łącznie klatek:      {total_frames}")
print(f"Crossing (1):        {n_crossing}  ({n_crossing / denom:.2%})")
print(f"Not crossing (0):    {n_not_crossing}  ({n_not_crossing / denom:.2%})")


# Split by video, not by frame: consecutive frames of one clip are nearly identical,
# so a frame-level random split leaks test content into training and inflates metrics.
# test_size=0.2 is the fraction of *videos* held out.
from sklearn.model_selection import GroupShuffleSplit

video_groups = [os.path.basename(os.path.dirname(record[0])) for record in base_dataset.data]
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_indices, test_indices = next(splitter.split(np.zeros(len(video_groups)), groups=video_groups))
train_indices, test_indices = train_indices.tolist(), test_indices.tolist()

train_set = Subset(base_dataset, train_indices)
test_set = Subset(base_dataset, test_indices)


train_save_dir = r'D:\mgr\PedestrianIntentionEstimation\JAAD_dataset\training_data'
test_save_dir = r'D:\mgr\PedestrianIntentionEstimation\JAAD_dataset\test_data'
os.makedirs(train_save_dir, exist_ok=True)
os.makedirs(test_save_dir, exist_ok=True)


if os.listdir(train_save_dir) and os.listdir(test_save_dir):
    print("Training data and testing data already exist. Skipping preprocessing.")
else:
    def save_preprocessed_data(dataset, save_dir):
        """Write every sample of ``dataset`` to ``<save_dir>/data_<idx>.pt`` as a dict.

        Samples come from ``dataset[idx]``, so they are stored already passed through
        the dataset's transform.
        """
        progress = tqdm(range(len(dataset)), desc=f"Saving to {os.path.basename(save_dir)}", unit="frame")
        for idx in progress:
            (frame, keypoints, label, traffic_info,
             vehicle_info, appearance_info, attributes_info) = dataset[idx]
            record = dict(
                frame=frame,
                keypoints=keypoints,
                label=label,
                traffic_info=traffic_info,
                vehicle_info=vehicle_info,
                appearance_info=appearance_info,
                attributes_info=attributes_info,
            )
            torch.save(record, os.path.join(save_dir, f'data_{idx}.pt'))

    save_preprocessed_data(train_set, train_save_dir)
    save_preprocessed_data(test_set, test_save_dir)


class PreprocessedDataset(Dataset):
    """Dataset over ``.pt`` sample dicts in ``data_dir``, in sorted file-path order.

    Each item is ``(frame, keypoints, label, traffic_info, vehicle_info,
    appearance_info, attributes_info)``; ``transform``, if given, is applied to the
    frame only.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.data_files = sorted(
            os.path.join(data_dir, name) for name in os.listdir(data_dir) if name.endswith('.pt')
        )
        self.transform = transform

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        record = torch.load(self.data_files[idx])
        keys = ('frame', 'keypoints', 'label', 'traffic_info',
                'vehicle_info', 'appearance_info', 'attributes_info')
        frame, *rest = [record[key] for key in keys]
        if self.transform:
            frame = self.transform(frame)
        return (frame, *rest)

# Datasets over the cached .pt files. No transform here: frames were already
# transformed before being cached.
train_dataset, test_dataset = (
    PreprocessedDataset(data_dir, transform=None) for data_dir in (train_save_dir, test_save_dir)
)


def collate_fn(batch):
    """Collate 7-field samples into a batch.

    Tensor fields are stacked along a new leading batch dimension; the labels become a
    1-D tensor built with ``torch.tensor``.
    """
    (frame_col, kp_col, label_col, traffic_col,
     vehicle_col, appearance_col, attr_col) = zip(*batch)
    stack = torch.stack
    return (stack(frame_col), stack(kp_col), torch.tensor(label_col), stack(traffic_col),
            stack(vehicle_col), stack(appearance_col), stack(attr_col))


print("Datasets created.")

# Model definition

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class PedestrianCrossingPredictor(nn.Module):
    """Binary crossing / not-crossing classifier combining an image with side features.

    Image branch: VGG19 convolutional features -> 7x7 adaptive average pool -> VGG19
    classifier without its final layer (4096-D) -> ``fc_img`` (256-D). ``goal_module``
    maps that 256-D embedding to a second 256-D vector. Both are concatenated with the
    flattened keypoints and the traffic / vehicle / appearance / attributes vectors
    (110 values in total) and fed to a three-layer head that outputs one logit per sample.

    NOTE: the attribute names below are the ``state_dict`` keys of saved checkpoints;
    renaming them breaks ``load_state_dict`` on existing ``.pth`` files.
    """
    def __init__(self):
        super(PedestrianCrossingPredictor, self).__init__()
        
        # VGG19 backbone with pretrained weights.
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour of
        # `weights=...`; check the installed version before changing it.
        vgg19 = models.vgg19(pretrained=True)
        self.vgg19_features = vgg19.features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.vgg19_classifier = nn.Sequential(*list(vgg19.classifier.children())[:-1])  # drop the last layer -> 4096-D output
        
        # Freeze every parameter of the first 36 modules of the convolutional feature extractor.
        for param in self.vgg19_features[:36].parameters():
            param.requires_grad = False
        
        # Dimensionality reduction of the 4096-D image embedding down to 256-D.
        self.fc_img = nn.Sequential(
            nn.Linear(4096, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU()
        )
        
        # "Goal" module: 256 -> 128 -> 256 transformation of the image embedding.
        self.goal_module = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
            nn.Linear(128, 256),
            nn.ReLU()
        )
        
        # Fusion head input: image (256) + goal (256) + side features.
        # keypoints (99) + traffic (4) + vehicle (1) + appearance (3) + attributes (3) = 110
        self.com_fc1 = nn.Linear(256 + 256 + 110, 256)
        self.com_bn1 = nn.BatchNorm1d(256)
        self.com_d1 = nn.Dropout(0.5)
        self.com_fc2 = nn.Linear(256, 128)
        self.com_bn2 = nn.BatchNorm1d(128)
        self.com_d2 = nn.Dropout(0.5)
        self.com_fc3 = nn.Linear(128, 1)

    def forward(self, x, keypoints, traffic_info, vehicle_info, appearance_info, attributes_info):
        """Return a ``[batch, 1]`` tensor of crossing logits (apply sigmoid for probabilities).

        ``keypoints`` is flattened per sample. Together with the four info tensors it
        must supply the 110 extra features ``com_fc1`` expects (99 keypoint values per
        the layer comment). All inputs share the batch size of ``x``, which is assumed
        to be an image batch ``[batch, 3, H, W]`` — TODO confirm.
        """
        batch_size = x.size(0)
        
        # Image features
        c_out = self.vgg19_features(x)
        c_out = self.avgpool(c_out)
        c_out = torch.flatten(c_out, 1)
        c_out = self.vgg19_classifier(c_out)  # [batch, 4096]
        c_out = self.fc_img(c_out)            # [batch, 256]
        
        # Side features
        # NOTE(review): .view() requires a contiguous tensor; .reshape() would be more forgiving.
        keypoints = keypoints.view(batch_size, -1)  # 33x3=99
        additional_info = torch.cat([keypoints, traffic_info, vehicle_info, appearance_info, attributes_info], dim=1)
        
        # goal module
        goal_out = self.goal_module(c_out)
        
        # Fusion head: Linear -> ReLU -> BatchNorm -> Dropout, twice, then a single logit.
        combined = torch.cat((c_out, goal_out, additional_info), dim=1)
        combined = self.com_fc1(combined)
        combined = torch.relu(combined)
        combined = self.com_bn1(combined)
        combined = self.com_d1(combined)
        combined = self.com_fc2(combined)
        combined = torch.relu(combined)
        combined = self.com_bn2(combined)
        combined = self.com_d2(combined)
        combined = self.com_fc3(combined)
        
        return combined


In [None]:
# weights initialization
def init_w(m):
    """Xavier-normal init for ``nn.Linear`` weights and a zero bias; other modules are untouched.

    Meant for ``module.apply(init_w)``. Linear layers built with ``bias=False`` are
    supported (``constant_`` on a ``None`` bias would raise).
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

model = PedestrianCrossingPredictor()
# Initialise only the newly added layers. model.apply(init_w) would also re-draw the
# Linear layers inside vgg19_classifier and throw away their pretrained weights.
for head in (model.fc_img, model.goal_module, model.com_fc1, model.com_fc2, model.com_fc3):
    head.apply(init_w)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

torch.save(model.state_dict(), r'D:\mgr\PedestrianIntentionEstimation\model\modelrn.pth')
torch.save(optimizer.state_dict(), r'D:\mgr\PedestrianIntentionEstimation\model\optimizerrn.pth')

print("Model saved")

# Training and testing the model

In [None]:
import matplotlib.pyplot as plt

def plot_metrics(history):
    """Plot per-epoch training loss, accuracy, recall and F1 score in a 2x2 grid.

    ``history`` maps 'loss', 'accuracy', 'recall' and 'f1' to per-epoch value lists;
    the epoch axis (1..N) is taken from the length of ``history['loss']``.
    """
    epoch_axis = range(1, len(history['loss']) + 1)
    panels = (
        ('loss', 'Training Loss', 'Loss'),
        ('accuracy', 'Training Accuracy', 'Accuracy'),
        ('recall', 'Training Recall', 'Recall'),
        ('f1', 'Training F1 Score', 'F1 Score'),
    )

    plt.figure(figsize=(14, 8))
    for position, (key, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(epoch_axis, history[key], marker='o')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.grid(True)

    plt.tight_layout()
    plt.show()


In [None]:
import pickle
import numpy as np
from tqdm import tqdm
import os
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Subset, Dataset
from multiprocessing import cpu_count
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [None]:
class PreprocessedDataset(Dataset):
    """Read cached ``.pt`` sample dicts from ``data_dir`` (files visited in sorted path order).

    Items are 7-tuples ``(frame, keypoints, label, traffic_info, vehicle_info,
    appearance_info, attributes_info)``; an optional ``transform`` is applied to the
    frame only.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        pt_names = [n for n in os.listdir(data_dir) if n.endswith('.pt')]
        self.data_files = sorted(os.path.join(data_dir, n) for n in pt_names)
        self.transform = transform

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        sample = torch.load(self.data_files[idx])
        (frame, keypoints, label, traffic_info,
         vehicle_info, appearance_info, attributes_info) = (
            sample[k] for k in ('frame', 'keypoints', 'label', 'traffic_info',
                                'vehicle_info', 'appearance_info', 'attributes_info'))
        if self.transform:
            frame = self.transform(frame)
        return frame, keypoints, label, traffic_info, vehicle_info, appearance_info, attributes_info

In [None]:
def find_best_threshold(y_true, y_probs):
    """Return ``(threshold, accuracy)`` maximising accuracy over thresholds 0.00..1.00 (step 0.01).

    ``y_probs`` must support ``>=`` against a float elementwise (e.g. a NumPy array);
    ties keep the lowest threshold.
    """
    best_t, best_acc = None, None
    for t in np.linspace(0.0, 1.0, 101):
        acc = accuracy_score(y_true, (y_probs >= t).astype(int))
        if best_acc is None or acc > best_acc:
            best_t, best_acc = t, acc
    return float(best_t), float(best_acc)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dir = r'D:\mgr\PedestrianIntentionEstimation\JAAD_dataset\training_data'
train_pt = PreprocessedDataset(train_dir, transform=None)

# Collect every training label (field 2 of each sample). This loads every .pt file,
# so the result is saved to disk for reuse by the training cell.
labels = np.array(
    [int(train_pt[idx][2]) for idx in tqdm(range(len(train_pt)), desc="Reading labels")],
    dtype=np.int64,
)

# save labels to npy file
out_path = r'D:\mgr\PedestrianIntentionEstimation\train_labels_all.npy'
np.save(out_path, labels)
print(f"Saved {labels.shape[0]} labels to: {out_path}")

In [None]:
import torch
print("GPU available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
else:
    print("Device:", "No CUDA device")

### Train

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import accuracy_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Subset
from multiprocessing import cpu_count
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import random

def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU, plus all CUDA devices if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)

def find_best_threshold(y_true, y_probs):
    """Grid-search a decision threshold in [0, 1] (101 points) for the best accuracy.

    Returns ``(threshold, accuracy)`` as floats; on ties the smallest threshold wins.
    """
    grid = np.linspace(0.0, 1.0, 101)
    scores = np.array([accuracy_score(y_true, (y_probs >= thr).astype(int)) for thr in grid])
    best = int(scores.argmax())
    return float(grid[best]), float(scores[best])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dir = r'D:\mgr\PedestrianIntentionEstimation\JAAD_dataset\training_data'
train_pt = PreprocessedDataset(train_dir, transform=None)

# Labels (cached as .npy because reading them means loading every .pt file).
LABELS_FILE = r'D:\mgr\PedestrianIntentionEstimation\train_labels_all.npy'
labels = None
if os.path.exists(LABELS_FILE):
    labels = np.load(LABELS_FILE)
    if labels.shape[0] == len(train_pt):
        print(f"[labels] Loaded from {LABELS_FILE} -> {labels.shape[0]} samples")
    else:
        # A stale cache (e.g. written before the .pt files were regenerated) would
        # silently attach labels to the wrong samples, so rebuild it.
        print(f"[labels] Cached labels ({labels.shape[0]}) do not match dataset size ({len(train_pt)}); rebuilding")
        labels = None
if labels is None:
    print("[labels] Building from dataset...")
    labels = []
    for i in tqdm(range(len(train_pt)), desc="Reading labels"):
        _, _, lab, _, _, _, _ = train_pt[i]
        labels.append(int(lab))
    labels = np.array(labels, dtype=np.int64)
    np.save(LABELS_FILE, labels)
    print(f"[labels] Saved to {LABELS_FILE}")

# Model
model = PedestrianCrossingPredictor().to(device)
model.load_state_dict(torch.load(
    r'D:\mgr\PedestrianIntentionEstimation\model\modelrn.pth',
    map_location=device
))

# Balancing classes
rng = np.random.default_rng(42)

crossing_idxs = np.where(labels == 1)[0]
non_crossing_idxs = np.where(labels == 0)[0]

n_keep = min(len(crossing_idxs), len(non_crossing_idxs))
crossing_keep = rng.choice(crossing_idxs, size=n_keep, replace=False)
non_crossing_keep = rng.choice(non_crossing_idxs, size=n_keep, replace=False)

final_idxs = np.concatenate([crossing_keep, non_crossing_keep])
rng.shuffle(final_idxs)

print(f"Crossing samples kept:      {len(crossing_keep)}")
print(f"Not-crossing samples kept:  {len(non_crossing_keep)}")
print(f"Total available (balanced): {len(final_idxs)}")


train_idx, val_idx = train_test_split(
    final_idxs, test_size=0.10, random_state=42, shuffle=True, stratify=labels[final_idxs]
)
train_subset = Subset(train_pt, train_idx)
val_subset   = Subset(train_pt, val_idx)

def collate_fn(batch):
    """Merge a list of 7-field samples into batched tensors.

    Every tensor field is stacked along a new leading batch dimension; the label
    field (position 2) is turned into a 1-D tensor with ``torch.tensor``.
    """
    frames, keypoints, labels_b, traffic, vehicle, appearance, attrs = zip(*batch)
    stacked_frames, stacked_kps, stacked_traffic, stacked_vehicle, stacked_app, stacked_attrs = (
        torch.stack(group)
        for group in (frames, keypoints, traffic, vehicle, appearance, attrs)
    )
    return (stacked_frames, stacked_kps, torch.tensor(labels_b), stacked_traffic,
            stacked_vehicle, stacked_app, stacked_attrs)

# NOTE(review): n_w is computed but both loaders use num_workers=0 — presumably because
# worker processes cannot import the notebook-defined dataset class (e.g. spawn on
# Windows); confirm before switching to n_w.
n_w = min(16, cpu_count())
# Pinned host memory only speeds up host->GPU copies, so request it only on CUDA.
comb_train_loader = DataLoader(
    train_subset, batch_size=32, shuffle=True,
    num_workers=0, pin_memory=(device.type == "cuda"), collate_fn=collate_fn
)
val_loader = DataLoader(
    val_subset, batch_size=64, shuffle=False,
    num_workers=0, pin_memory=(device.type == "cuda"), collate_fn=collate_fn
)


# Loss works on raw logits (the sigmoid is applied inside BCEWithLogitsLoss).
criterion = nn.BCEWithLogitsLoss()

# NOTE(review): optimizer and scheduler are created again further down, right before train().
optimizer = AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # halve the LR every 5 scheduler steps

# evaluation and train
def evaluate_on_loader(model, loader, criterion, device, threshold=None):
    """Score `model` on `loader` as a binary classifier, without gradients.

    Args:
        model: called as model(frames, keypoints, traffic, vehicle, appearance, attrs),
            returning one logit per sample.
        loader: yields 7-tuples (frames, keypoints, labels, traffic, vehicle, appearance, attrs).
        criterion: loss applied to (logits, labels as a float column).
        device: device every batch tensor is moved to.
        threshold: decision threshold on sigmoid probabilities; when None, the
            accuracy-maximizing threshold on this same data is chosen.

    Returns:
        (avg_loss, accuracy, recall, f1, threshold, targets, preds); avg_loss is the
        per-sample mean over len(loader.dataset).
    """
    model.eval()
    loss_sum = 0.0
    prob_list, target_list = [], []
    with torch.no_grad():
        for batch in loader:
            frames, keypoints, labels_batch, traffic, vehicle, appearance, attrs = (
                tensor.to(device) for tensor in batch
            )
            with autocast():
                logits = model(frames, keypoints, traffic, vehicle, appearance, attrs)
                loss = criterion(logits, labels_batch.float().view(-1, 1))
            # Weight by batch size so the final average is per sample, not per batch.
            loss_sum += loss.item() * frames.size(0)
            prob_list.extend(torch.sigmoid(logits).cpu().numpy().flatten())
            target_list.extend(labels_batch.cpu().numpy().flatten())

    avg_loss = loss_sum / len(loader.dataset)
    probs = np.array(prob_list)
    targets = np.array(target_list).astype(int)
    thr = threshold if threshold is not None else find_best_threshold(targets, probs)[0]
    preds = (probs >= thr).astype(int)
    acc, rec, f1 = (score(targets, preds) for score in (accuracy_score, recall_score, f1_score))
    return avg_loss, acc, rec, f1, thr, targets, preds

def train(model, train_loader, val_loader, optimizer, criterion, scheduler=None,
          num_epochs=10, device='cuda', verbose=True,
          patience=10, min_best_epoch=5, min_delta=1e-4,
          model_save_path=r'D:\mgr\PedestrianIntentionEstimation\model\trained_modelrnall30.pth'):
    """Train `model` with mixed precision, validate every epoch and save the weights.

    Each epoch does one pass over `train_loader` (autocast + GradScaler, gradient
    norm clipped to 1.0), computes training metrics at threshold 0.5, then scores
    `val_loader` with `evaluate_on_loader` (threshold tuned on the validation data).
    `scheduler.step()`, if given, is called once per epoch.

    Early stopping only starts after `min_best_epoch` epochs. From then on, the
    weights with the lowest validation loss (improving by more than `min_delta`)
    are snapshotted, and training stops after `patience` epochs without such an
    improvement. The snapshot, or the final weights if no snapshot was taken, is
    loaded back into `model` and written to `model_save_path`.

    Returns:
        (model, history) where history maps metric names to per-epoch lists.
    """
    model.to(device)
    scaler = GradScaler()
    history = {
        'loss': [], 'accuracy': [], 'recall': [], 'f1': [], 'threshold': [],
        'val_loss': [], 'val_acc': [], 'val_recall': [], 'val_f1': [], 'val_threshold': []
    }

    best_val_loss = float('inf')
    best_state = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss, probs_all, targets_all = 0.0, [], []

        for bidx, (frames, keypoints, labels_b, traffic, vehicle, appearance, attrs) in \
                enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")):

            if epoch == 0 and bidx == 0:
                print("Labels in first batch:", labels_b.cpu().numpy())
                print("Number of crossing (1):", (labels_b == 1).sum().item())
                print("Number of not crossing (0):", (labels_b == 0).sum().item())

            frames = frames.to(device); keypoints = keypoints.to(device)
            labels_b = labels_b.to(device); traffic = traffic.to(device)
            vehicle = vehicle.to(device); appearance = appearance.to(device); attrs = attrs.to(device)

            optimizer.zero_grad(set_to_none=True)
            with autocast():
                out = model(frames, keypoints, traffic, vehicle, appearance, attrs)
                loss = criterion(out, labels_b.float().view(-1, 1))

            scaler.scale(loss).backward()
            # Gradients are multiplied by the loss scale at this point; unscale them first
            # so max_norm=1.0 applies to the true gradient norm. scaler.step() then knows
            # they are already unscaled (and still skips the step on inf/NaN).
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer); scaler.update()

            # Weight by batch size so the epoch loss is a per-sample mean.
            running_loss += loss.item() * frames.size(0)
            probs_all.extend(torch.sigmoid(out).detach().cpu().numpy().flatten())
            targets_all.extend(labels_b.detach().cpu().numpy().flatten())

        # train metrics (fixed 0.5 threshold)
        train_preds = (np.array(probs_all) >= 0.5).astype(int)
        train_loss = running_loss / len(train_loader.dataset)
        train_acc  = accuracy_score(targets_all, train_preds)
        train_rec  = recall_score(targets_all, train_preds)
        train_f1   = f1_score(targets_all, train_preds)

        # validation (threshold tuned on the validation set itself)
        val_loss, val_acc, val_rec, val_f1, val_thr, _, _ = evaluate_on_loader(
            model, val_loader, criterion, device, threshold=None
        )

        history['loss'].append(train_loss); history['accuracy'].append(train_acc)
        history['recall'].append(train_rec); history['f1'].append(train_f1); history['threshold'].append(0.5)
        history['val_loss'].append(val_loss); history['val_acc'].append(val_acc)
        history['val_recall'].append(val_rec); history['val_f1'].append(val_f1); history['val_threshold'].append(val_thr)

        if verbose:
            print(f"Epoch {epoch+1}/{num_epochs} | "
                  f"Train: L {train_loss:.4f} A {train_acc:.4f} R {train_rec:.4f} F1 {train_f1:.4f} | "
                  f"Val:   L {val_loss:.4f} A {val_acc:.4f} R {val_rec:.4f} F1 {val_f1:.4f} Thr* {val_thr:.2f}")

        if scheduler:
            scheduler.step()

        # Early stopping: epochs up to min_best_epoch are neither snapshotted nor counted.
        if (epoch + 1) > min_best_epoch:
            if val_loss < best_val_loss - min_delta:
                best_val_loss = val_loss
                # Copy to CPU so later training steps cannot mutate the snapshot.
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"Early stopping: no val_loss improvement ≥ {min_delta} for {patience} epochs after epoch {min_best_epoch}.")
                    break

    if best_state is not None:
        model.load_state_dict(best_state)
    torch.save(model.state_dict(), model_save_path)
    return model, history


# Train: build a new optimizer/scheduler pair for this run, then fit the model.
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
trained_m, history = train(
    model, comb_train_loader, val_loader, optimizer, criterion, scheduler,
    num_epochs=30, device=device, verbose=True, patience=5, min_best_epoch=5,
)

plot_metrics({key: history[key] for key in ('loss', 'accuracy', 'recall', 'f1')})

# Plots: training vs. validation curves, one panel per metric
epochs = range(1, len(history['loss']) + 1)


def _plot_train_vs_val(panel, train_key, val_key, train_label, val_label, title, y_label):
    """Draw subplot `panel` of a 1x2 grid comparing a train curve with its val curve."""
    plt.subplot(1, 2, panel)
    plt.plot(epochs, history[train_key], marker='o', label=train_label)
    plt.plot(epochs, history[val_key], marker='o', label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.grid(True)
    plt.legend()


plt.figure(figsize=(12, 5))
_plot_train_vs_val(1, 'loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss (train vs val)', 'Loss')
_plot_train_vs_val(2, 'accuracy', 'val_acc', 'Train Acc', 'Val Acc', 'Accuracy (train vs val)', 'Accuracy')
plt.tight_layout()
plt.show()


### Test

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
from multiprocessing import cpu_count
from torch.utils.data import DataLoader, Subset
from torch.cuda.amp import autocast
from sklearn.metrics import accuracy_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay

# Test-time configuration: run on the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

TEST_DIR = r"D:\mgr\PedestrianIntentionEstimation\JAAD_dataset\test_data"
LABELS_FILE = r"D:\mgr\PedestrianIntentionEstimation\test_labelsrn.npy"    # cached test labels
MODEL_CKPT = r"D:\mgr\PedestrianIntentionEstimation\model\trained_modelrnall30.pth"

SAMPLES_PER_CLASS = 5000  # upper bound on test samples drawn per class

class PreprocessedDataset(torch.utils.data.Dataset):
    """Dataset over a directory of pre-computed ``.pt`` sample files.

    Files are visited in sorted path order, so indices are stable between runs.
    Each file is expected to hold a dict with the keys read in ``__getitem__``
    (per the original author: frame [3,224,224], keypoints [33,3], label 0/1,
    traffic_info [4], vehicle_info [1], appearance_info [3], attributes_info [3]).
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        pt_names = (name for name in os.listdir(data_dir) if name.endswith('.pt'))
        self.data_files = sorted(os.path.join(data_dir, name) for name in pt_names)
        self.transform = transform  # optional callable applied to the frame only

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        sample = torch.load(self.data_files[idx], map_location="cpu")
        frame, keypoints, label, traffic_info, vehicle_info, appearance_info, attributes_info = (
            sample[key] for key in (
                'frame', 'keypoints', 'label', 'traffic_info',
                'vehicle_info', 'appearance_info', 'attributes_info',
            )
        )
        if self.transform:
            frame = self.transform(frame)
        return frame, keypoints, label, traffic_info, vehicle_info, appearance_info, attributes_info

def collate_fn(batch):
    """Merge a list of 7-field samples into batched tensors.

    Every tensor field is stacked along a new leading batch dimension; the label
    field (position 2) is turned into a 1-D tensor with ``torch.tensor``.
    """
    frames, keypoints, labels_b, traffic, vehicle, appearance, attrs = zip(*batch)
    stacked_frames, stacked_kps, stacked_traffic, stacked_vehicle, stacked_app, stacked_attrs = (
        torch.stack(group)
        for group in (frames, keypoints, traffic, vehicle, appearance, attrs)
    )
    return (stacked_frames, stacked_kps, torch.tensor(labels_b), stacked_traffic,
            stacked_vehicle, stacked_app, stacked_attrs)

test_pt = PreprocessedDataset(TEST_DIR, transform=None)

# Labels: cached in LABELS_FILE because collecting them requires loading every .pt sample.
labels_all = np.load(LABELS_FILE) if os.path.exists(LABELS_FILE) else None
if labels_all is not None and labels_all.shape[0] != len(test_pt):
    # A cache whose length disagrees with the dataset is stale (samples were added or
    # removed); using it would silently pair sample indices with the wrong labels.
    print(f"[labels] {LABELS_FILE} has {labels_all.shape[0]} entries but the dataset has "
          f"{len(test_pt)} samples -> rebuilding")
    labels_all = None
if labels_all is not None:
    print(f"[labels] Loaded from {LABELS_FILE} -> {labels_all.shape[0]} samples")
else:
    print("[labels] Building from dataset...")
    labels_all = []
    for i in tqdm(range(len(test_pt)), desc="Extracting labels"):
        _, _, lab, _, _, _, _ = test_pt[i]
        labels_all.append(int(lab))
    labels_all = np.array(labels_all, dtype=np.int8)
    np.save(LABELS_FILE, labels_all)
    print(f"[labels] Saved to {LABELS_FILE}")

# Test set class statistics (max(1, ...) avoids division by zero on an empty set)
total_samples = labels_all.shape[0]
n_crossing = int((labels_all == 1).sum())
n_not_crossing = total_samples - n_crossing
print("\n STATYSTYKI ZBIORU TESTOWEGO ")
print(f"Łącznie próbek:      {total_samples}")
print(f"Crossing (1):        {n_crossing}  ({n_crossing / max(1,total_samples):.2%})")
print(f"Not crossing (0):    {n_not_crossing}  ({n_not_crossing / max(1,total_samples):.2%})")

# Model
class PedestrianCrossingPredictor(nn.Module):
    """Binary crossing-intention classifier: VGG19 image features fused with pose and context.

    The image branch (VGG19 features + all but the last classifier layer, then `fc_img`)
    yields a 256-D embedding. `goal_module` maps it to a second 256-D vector. Both are
    concatenated with the flattened keypoints and the context vectors (110 values in
    total), and the result goes through a 3-layer head that outputs one logit per sample.

    NOTE(review): attribute names define the state_dict keys; they must match saved
    checkpoints, so do not rename them.
    """

    def __init__(self):
        super().__init__()
        import torchvision.models as models
        # NOTE(review): pretrained=True presumably downloads ImageNet weights; when a full
        # checkpoint is loaded afterwards they are overwritten anyway. Confirm whether
        # torchvision's newer `weights=` argument should be used here.
        vgg19 = models.vgg19(pretrained=True)
        self.vgg19_features = vgg19.features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Drop the final layer of the VGG classifier to keep its 4096-D penultimate output.
        self.vgg19_classifier = nn.Sequential(*list(vgg19.classifier.children())[:-1])  # 4096-D

        # Freeze the first 36 modules of the VGG19 feature extractor.
        for p in self.vgg19_features[:36].parameters():
            p.requires_grad = False

        # 4096 -> 512 -> 256 image embedding
        self.fc_img = nn.Sequential(
            nn.Linear(4096, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Dropout(0.5),
            nn.Linear(512, 256), nn.ReLU()
        )
        # 256 -> 128 -> 256 transform of the image embedding, concatenated alongside it
        self.goal_module = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(), nn.BatchNorm1d(128), nn.Dropout(0.5),
            nn.Linear(128, 256), nn.ReLU()
        )
        # 99 keypoints + (4+1+3+3)=11 context = 110
        self.com_fc1 = nn.Linear(256 + 256 + 110, 256)
        self.com_bn1 = nn.BatchNorm1d(256)
        self.com_d1 = nn.Dropout(0.5)
        self.com_fc2 = nn.Linear(256, 128)
        self.com_bn2 = nn.BatchNorm1d(128)
        self.com_d2 = nn.Dropout(0.5)
        self.com_fc3 = nn.Linear(128, 1)

    def forward(self, x, keypoints, traffic_info, vehicle_info, appearance_info, attributes_info):
        """Return crossing logits of shape [B, 1] (apply sigmoid for probabilities).

        Args:
            x: image batch, [B, 3, 224, 224] per the original comment.
            keypoints: per-sample keypoints, flattened here to [B, 99]
                (assumes 33 x 3 values per sample — TODO confirm).
            traffic_info, vehicle_info, appearance_info, attributes_info: 2-D context
                tensors whose widths (4, 1, 3, 3 per the original comments) must add up
                to the 11 context values expected by `com_fc1`.
        """
        # x: [B,3,224,224]
        c_out = self.vgg19_features(x)
        c_out = self.avgpool(c_out)
        c_out = torch.flatten(c_out, 1)
        c_out = self.vgg19_classifier(c_out)  # [B,4096]
        c_out = self.fc_img(c_out)            # [B,256]

        B = x.size(0)
        # view() requires a contiguous tensor; batches built with torch.stack are contiguous.
        keypoints = keypoints.view(B, -1)     # [B,99]
        additional_info = torch.cat([keypoints, traffic_info, vehicle_info, appearance_info, attributes_info], dim=1)  # [B,110]

        goal_out = self.goal_module(c_out)
        combined = torch.cat((c_out, goal_out, additional_info), dim=1)
        combined = torch.relu(self.com_fc1(combined))
        combined = self.com_bn1(combined)
        combined = self.com_d1(combined)
        combined = torch.relu(self.com_fc2(combined))
        combined = self.com_bn2(combined)
        combined = self.com_d2(combined)
        logits = self.com_fc3(combined)       # [B,1]
        return logits

# Rebuild the architecture, move it to `device` and restore the trained checkpoint.
model = PedestrianCrossingPredictor()
model.to(device)  # nn.Module.to works in place and returns the same module
model.load_state_dict(torch.load(MODEL_CKPT, map_location=device))
criterion = nn.BCEWithLogitsLoss()  # expects raw logits

# Test sampling: draw the same number of samples from each class. A dedicated seeded
# RNG (seed 42, as in training) makes the subset, and so the reported metrics,
# reproducible across runs without touching the global `random` state.
test_rng = random.Random(42)
crossing_idx = np.where(labels_all == 1)[0].tolist()
non_crossing_idx = np.where(labels_all == 0)[0].tolist()
# Cap by the smaller class as well; otherwise a class with fewer than
# SAMPLES_PER_CLASS samples would leave the "balanced" subset imbalanced.
n_per_class = min(SAMPLES_PER_CLASS, len(crossing_idx), len(non_crossing_idx))
if n_per_class < SAMPLES_PER_CLASS:
    print(f"[sampling] Only {n_per_class} samples per class available "
          f"(requested {SAMPLES_PER_CLASS})")
crossing_sampled = test_rng.sample(crossing_idx, n_per_class)
non_crossing_sampled = test_rng.sample(non_crossing_idx, n_per_class)
balanced_indices = crossing_sampled + non_crossing_sampled
test_rng.shuffle(balanced_indices)

balanced_test = Subset(test_pt, balanced_indices)

def collate_fn(batch):
    """Merge a list of 7-field samples into batched tensors.

    Every tensor field is stacked along a new leading batch dimension; the label
    field (position 2) is turned into a 1-D tensor with ``torch.tensor``.
    """
    frames, keypoints, labels_b, traffic, vehicle, appearance, attrs = zip(*batch)
    stacked_frames, stacked_kps, stacked_traffic, stacked_vehicle, stacked_app, stacked_attrs = (
        torch.stack(group)
        for group in (frames, keypoints, traffic, vehicle, appearance, attrs)
    )
    return (stacked_frames, stacked_kps, torch.tensor(labels_b), stacked_traffic,
            stacked_vehicle, stacked_app, stacked_attrs)

# NOTE(review): n_w is computed but the loader uses num_workers=0 — presumably because
# worker processes cannot import the notebook-defined dataset class; confirm.
n_w = min(16, cpu_count())
# Pinned host memory only speeds up host->GPU copies, so request it only on CUDA.
balanced_loader = DataLoader(
    balanced_test, batch_size=32, shuffle=False, num_workers=0,
    pin_memory=(device.type == "cuda"), collate_fn=collate_fn
)

def find_best_threshold(y_true, y_probs):
    """Grid-search the decision threshold that maximizes accuracy.

    Thresholds 0.00, 0.01, ..., 1.00 are tried. The grid comes from np.linspace;
    a float-step np.arange accumulates rounding error such as 0.07000000000000001.
    A sample is predicted positive when its probability is >= the threshold.
    Ties go to the smallest threshold.

    Note: choosing the threshold on the data whose accuracy is reported makes that
    accuracy optimistic.

    Args:
        y_true: binary ground-truth labels (array-like of 0/1).
        y_probs: predicted positive-class probabilities, same length as y_true.

    Returns:
        (threshold, accuracy) as Python floats.

    Raises:
        ValueError: if the inputs are empty or have different lengths.
    """
    y_true = np.asarray(y_true).reshape(-1)
    y_probs = np.asarray(y_probs).reshape(-1)
    if y_true.size == 0 or y_true.size != y_probs.size:
        raise ValueError(
            f"y_true and y_probs must be non-empty and equally long "
            f"(got {y_true.size} and {y_probs.size})"
        )
    ths = np.linspace(0.0, 1.0, 101)
    # One row of predictions per threshold: shape (n_thresholds, n_samples).
    preds = (y_probs[None, :] >= ths[:, None]).astype(int)
    accs = (preds == y_true[None, :]).mean(axis=1)
    best = int(np.argmax(accs))  # first index of the maximum -> smallest best threshold
    return float(ths[best]), float(accs[best])

def test(model, criterion, test_loader, ablation=None, threshold=None):
    """Evaluate `model` on `test_loader`, optionally zeroing one context input.

    Batches are moved to the module-level `device`.

    Args:
        model: called as model(frames, keypoints, traffic, vehicle, appearance, attributes),
            returning one logit per sample.
        criterion: loss applied to (logits, labels as a float column).
        test_loader: yields 7-tuples (frames, keypoints, labels, traffic, vehicle,
            appearance, attributes).
        ablation: None, or one of 'traffic', 'vehicle', 'appearance', 'attributes'.
            That input is replaced with zeros before the forward pass.
        threshold: fixed decision threshold (e.g. one tuned on validation data). When
            None, the accuracy-maximizing threshold on this same test data is used,
            which makes the reported accuracy optimistic.

    Returns:
        (avg_loss, accuracy, recall, f1, threshold, targets, preds)

    Raises:
        ValueError: if `ablation` is not one of the supported names.
    """
    supported_ablations = ('traffic', 'vehicle', 'appearance', 'attributes')
    if ablation is not None and ablation not in supported_ablations:
        # Fail loudly: a typo would otherwise run an un-ablated test that looks like an ablation.
        raise ValueError(f"Unknown ablation {ablation!r}; expected None or one of {supported_ablations}")

    model.eval()
    test_running_loss = 0.0
    all_probs, test_targets = [], []

    with torch.no_grad():
        for frames, keypoints, labels, traffic_info, vehicle_info, appearance_info, attributes_info in tqdm(
            test_loader, desc="Testing", unit="batch"
        ):
            frames = frames.to(device); keypoints = keypoints.to(device)
            traffic_info = traffic_info.to(device); vehicle_info = vehicle_info.to(device)
            appearance_info = appearance_info.to(device); attributes_info = attributes_info.to(device)
            labels = labels.to(device)

            if ablation == 'traffic':
                traffic_info = torch.zeros_like(traffic_info)
            elif ablation == 'vehicle':
                vehicle_info = torch.zeros_like(vehicle_info)
            elif ablation == 'appearance':
                appearance_info = torch.zeros_like(appearance_info)
            elif ablation == 'attributes':
                attributes_info = torch.zeros_like(attributes_info)

            with autocast():
                outputs = model(frames, keypoints, traffic_info, vehicle_info, appearance_info, attributes_info)
                loss = criterion(outputs, labels.unsqueeze(1).float())

            # Weight by batch size so the final average is per sample.
            test_running_loss += loss.item() * frames.size(0)
            probs = torch.sigmoid(outputs).cpu().numpy().reshape(-1)
            all_probs.extend(probs)
            test_targets.extend(labels.cpu().numpy().reshape(-1))

    avg_test_loss = test_running_loss / len(test_loader.dataset)
    all_probs = np.array(all_probs)
    test_targets = np.array(test_targets).astype(int)

    if threshold is None:
        best_thresh, best_acc = find_best_threshold(test_targets, all_probs)
    else:
        best_thresh = threshold
        best_acc = accuracy_score(test_targets, (all_probs >= threshold).astype(int))
    preds = (all_probs >= best_thresh).astype(int)
    test_recall = recall_score(test_targets, preds)
    test_f1 = f1_score(test_targets, preds)
    return avg_test_loss, best_acc, test_recall, test_f1, best_thresh, test_targets, preds

# Test
avg_loss, acc, rec, f1, thr, targets, preds = test(model, criterion, balanced_loader)

# Report the per-class counts actually evaluated; they can be below SAMPLES_PER_CLASS
# when a class has fewer test samples available.
n_not_eval, n_cross_eval = (int(c) for c in np.bincount(targets, minlength=2)[:2])

print(f"\nWyniki testu (not crossing: {n_not_eval}, crossing: {n_cross_eval}) ")
print(f"Test Loss: {avg_loss:.4f}")
print(f"Best Threshold: {thr:.2f}")
print(f"Accuracy: {acc:.4f}")
print(f"Recall: {rec:.4f}")
print(f"F1 Score: {f1:.4f}")

# Confusion Matrix. labels=[0, 1] keeps the matrix 2x2 (matching the two display labels)
# even if one class is missing from targets and predictions.
cm = confusion_matrix(targets, preds, labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Not Crossing", "Crossing"])
fig, ax = plt.subplots(figsize=(6, 6))
disp.plot(ax=ax, cmap='Blues', values_format='d')
plt.title(f"Confusion Matrix (Thr={thr:.2f}, {n_not_eval} not crossing / {n_cross_eval} crossing)")
plt.tight_layout()
plt.show()
