# In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from glob import glob

# ================== PARAMETERS ==================
DATA_DIR = "processed"   # folder holding the per-day X_*, T_* and EVENTS_* .npy files
PROB_THRESHOLD = 0.7     # a window is "active" when its predicted probability is >= this value
MIN_WINDOWS = 1          # shortest active run (in windows) that may become an event
COOLDOWN_SEC = 30        # minimum gap in seconds between two emitted events
BATCH_SIZE = 64          # windows per forward pass

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ================== LOAD FILE LIST ==================
# One sorted path list per file family; the evaluation loop consumes them in lockstep.
X_files, T_files, EVENTS_files = [
    sorted(glob(os.path.join(DATA_DIR, pattern)))
    for pattern in ("X_*.npy", "T_*.npy", "EVENTS_*.npy")
]

# ================== MODEL ==================
# Infer the network's input geometry from the first day's X array.
if not X_files:
    raise FileNotFoundError(f"No X_*.npy files found in {DATA_DIR!r}")
# Only the shape is needed here, so memory-map the file instead of reading it all into RAM.
X_sample = np.load(X_files[0], mmap_mode="r")
# Axis 0 indexes windows (the evaluation batches over it); axes 1/2 are window length and features.
if X_sample.ndim != 3:
    raise ValueError(
        f"Expected X arrays of shape (n_windows, window, features), got {X_sample.shape}"
    )
WINDOW, FEATURES = X_sample.shape[1], X_sample.shape[2]

class FallCNN_AllFeatures(nn.Module):
    """Small CNN that maps one window of multi-feature data to a single logit.

    The first convolution spans all ``features`` columns at once, collapsing the
    feature axis to width 1; the remaining layers operate along the time axis only.

    Args:
        window: number of time steps per input window (input height).
        features: number of feature columns per time step (input width).

    Input:  tensor of shape ``(batch, 1, window, features)``.
    Output: tensor of shape ``(batch, 1)`` of raw logits (apply ``torch.sigmoid``
    for a probability).
    """

    def __init__(self, window, features):
        super().__init__()
        self.net = nn.Sequential(
            # padding=(2, 0) keeps the time length at `window` for the height-5 kernel,
            # matching the second conv. Without it the time axis shrank by 4 and the
            # flattened size (32 * (window // 4 - 1)) never matched the Linear below.
            # Padding adds no parameters, so checkpoint tensor shapes are unchanged.
            nn.Conv2d(1, 16, kernel_size=(5, features), padding=(2, 0)),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),                       # time: window -> window // 2
            nn.Conv2d(16, 32, kernel_size=(5, 1), padding=(2, 0)),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),                       # time: window // 2 -> window // 4
            nn.Flatten(),
            nn.Linear(32 * (window // 4) * 1, 64),      # 32 channels x (window // 4) steps x width 1
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Return logits of shape (batch, 1) for ``x`` of shape (batch, 1, window, features)."""
        return self.net(x)

# Build the network on the selected device in inference mode, then restore the
# trained weights (tensors are remapped onto `device` while loading).
# NOTE(review): torch.load unpickles the checkpoint; only load files you trust
# (PyTorch supports weights_only=True to restrict what can be unpickled).
model = FallCNN_AllFeatures(WINDOW, FEATURES).to(device).eval()
model.load_state_dict(torch.load("fall_cnn_final.pt", map_location=device))

# ================== EVALUATE ==================
# A predicted event matches a ground-truth event when their timestamps differ by at
# most this many seconds (previously hard-coded as 2 inside the matching loop).
MATCH_TOLERANCE_SEC = 2


def _day_key(path):
    """Return the day identifier of a ``<PREFIX>_<day>.npy`` path (text after the first '_')."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return stem.split("_", 1)[1]


def _predict_probs(net, X_day, batch_size, dev):
    """Return a 1-D array with the sigmoid output of ``net`` for every window in ``X_day``.

    A channel axis is inserted so each batch is (batch, 1, window, features).
    Intended to be called under ``torch.no_grad()``.
    """
    chunks = []
    for start in range(0, len(X_day), batch_size):
        batch = X_day[start:start + batch_size]
        Xb = torch.tensor(batch[:, None, :, :], dtype=torch.float32, device=dev)
        chunks.append(torch.sigmoid(net(Xb)).cpu().numpy().ravel())
    return np.concatenate(chunks) if chunks else np.empty(0, dtype=np.float32)


def _cooldown_elapsed(last_event_time, t):
    """True when no event was emitted yet or at least COOLDOWN_SEC passed since the last one."""
    return last_event_time is None or (t - last_event_time).total_seconds() >= COOLDOWN_SEC


def _detect_events(probs, times):
    """Convert per-window probabilities into a list of event timestamps.

    Consecutive windows with ``prob >= PROB_THRESHOLD`` form a run. A run of at
    least MIN_WINDOWS windows becomes an event unless it falls within COOLDOWN_SEC
    of the previously emitted event. The event is stamped with the time of the
    first window after the run; a run still active at the last window is stamped
    with that last window's time (previously such trailing runs were dropped).

    NOTE(review): stamping at the run's end can lag the true onset by the run's
    duration, which matters with MATCH_TOLERANCE_SEC = 2 s — confirm intended.

    Raises:
        ValueError: if ``probs`` and ``times`` differ in length.
    """
    if len(probs) != len(times):
        raise ValueError(f"{len(probs)} probabilities but {len(times)} timestamps")

    events = []
    run_length = 0
    last_event_time = None
    for p, t in zip(probs, times):
        if p >= PROB_THRESHOLD:
            run_length += 1
            continue
        # First window below threshold: close the current run (possibly empty).
        if run_length >= MIN_WINDOWS and _cooldown_elapsed(last_event_time, t):
            events.append(t)
            last_event_time = t
        run_length = 0

    # Flush a run that is still active when the day's windows run out.
    if run_length > 0 and run_length >= MIN_WINDOWS and _cooldown_elapsed(last_event_time, times[-1]):
        events.append(times[-1])
    return events


def _count_matches(predicted, ground_truth, tolerance_sec):
    """Greedily pair each prediction with the first unmatched GT event within ``tolerance_sec``.

    Matched GT events are tracked by index, so duplicate GT timestamps can each be
    matched once. Returns the number of matched pairs (the true positives).
    """
    matched = set()
    for pe in predicted:
        for gi, ge in enumerate(ground_truth):
            if gi not in matched and abs((pe - ge).total_seconds()) <= tolerance_sec:
                matched.add(gi)
                break
    return len(matched)


# The three lists are globbed and sorted independently; refuse to evaluate if they
# cannot be paired one-to-one (zip would otherwise silently truncate/misalign).
if not (len(X_files) == len(T_files) == len(EVENTS_files)):
    raise ValueError(
        f"Unequal file counts in {DATA_DIR!r}: {len(X_files)} X, "
        f"{len(T_files)} T, {len(EVENTS_files)} EVENTS"
    )

TP = FP = FN = 0
total_gt = total_pred = 0

print("\n=== FALL DETECTION RESULTS ===\n")

with torch.no_grad():
    for X_file, T_file, EVENTS_file in zip(X_files, T_files, EVENTS_files):
        day = _day_key(X_file)
        if _day_key(T_file) != day or _day_key(EVENTS_file) != day:
            raise ValueError(
                f"Files do not belong to the same day: {X_file}, {T_file}, {EVENTS_file}"
            )

        X_day = np.load(X_file)
        # allow_pickle: T/EVENTS are presumably object arrays of datetime-like values
        # (their differences expose .total_seconds()) — TODO confirm with the producer.
        T_day = np.load(T_file, allow_pickle=True)
        EVENTS_day = np.load(EVENTS_file, allow_pickle=True)

        probs = _predict_probs(model, X_day, BATCH_SIZE, device)
        predicted_events = _detect_events(probs, T_day)
        gt_events = EVENTS_day
        n_matched = _count_matches(predicted_events, gt_events, MATCH_TOLERANCE_SEC)

        TP += n_matched
        FP += len(predicted_events) - n_matched   # predictions with no GT partner
        FN += len(gt_events) - n_matched          # GT falls nobody predicted

        total_gt += len(gt_events)
        total_pred += len(predicted_events)

        print(f"{day}: GT falls={len(gt_events)}, predicted={len(predicted_events)}")

precision = TP / (TP + FP) if TP + FP > 0 else 0
recall = TP / (TP + FN) if TP + FN > 0 else 0

print("\n=== SUMMARY ===")
print(f"GT falls: {total_gt}")
print(f"Predicted falls: {total_pred}")
print(f"TP: {TP}, FP: {FP}, FN: {FN}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
