<a href="https://colab.research.google.com/github/adamceek/gpx-hr-analyzer/blob/main/comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title
# script: comparison.py

import os
import glob
import xml.etree.ElementTree as ET
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.dates as mdates
from pathlib import Path

# =========================
# 🔧 USER SETTINGS
# =========================

# True = mount Google Drive and save to /MyDrive/001-Strava/HR_IMG
# False = save only to the local notebook directory
#         (or not at all if DOWNLOAD_RESULTS=False)
# Files are written only when USE_GOOGLE_DRIVE or DOWNLOAD_RESULTS is True.
USE_GOOGLE_DRIVE = False

# True = automatically download PNG + TSV when finished (Colab only)
DOWNLOAD_RESULTS = False

# "Left" / "Right" / None (None = try to auto-detect from watch filename)
HAND_OVERRIDE = None

# Comparison settings
# TOLERANCE: two readings count as a match when they differ by at most this many bpm.
TOLERANCE = 5
OFFSET_RANGE = range(-15, 15)      # tested time offsets in seconds: -15 .. +14 (range stop is exclusive)
START_TIME_STR = "16:30"           # "HH:MM" or None (filter window start, compared after the timezone shift)
END_TIME_STR = None                # "HH:MM" or None (filter window end, compared after the timezone shift)

# Device labels (change if you do not use Garmin/Huawei)
label1 = "Garmin HRM Dual"         # reference DEVICE (belt / chest strap)
label2 = "Huawei Watch Fit 3"      # comparison DEVICE (watch)

# Timezone offset of GPX timestamps to your local time (hours).
# Added to the parsed timestamps used for the start time and the HR samples.
TIMEZONE_OFFSET_HOURS = 1

# =========================
# 🌍 ENV & PATHS
# =========================

# The script treats itself as running in Google Colab when google.colab can
# be imported; otherwise the Colab helpers are set to None.
try:
    import google.colab  # type: ignore
    from google.colab import files, drive  # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
    files = None
    drive = None

# Output goes to Google Drive only when both Colab and the user setting allow
# it; in every other case the current working directory is the base.
if IN_COLAB and USE_GOOGLE_DRIVE:
    drive.mount('/content/drive', force_remount=True)
    BASE_DIR = "/content/drive/MyDrive/001-Strava"
else:
    BASE_DIR = os.getcwd()

IMG_DIR = os.path.join(BASE_DIR, "HR_IMG")
TSV_PATH = os.path.join(IMG_DIR, "comparison.tsv")
# Human-readable tolerance, shown on the chart and written to the TSV log.
# (The plus-minus sign was previously stored mis-encoded.)
tolerance_str = f"±{TOLERANCE} bpm"

# Optional "HH:MM" window limits converted to datetime.time (None = no limit).
START_TIME = datetime.strptime(START_TIME_STR, "%H:%M").time() if START_TIME_STR else None
END_TIME = datetime.strptime(END_TIME_STR, "%H:%M").time() if END_TIME_STR else None

# =========================
# 📂 FILE INPUT
# =========================

if IN_COLAB:
    # Clean up old GPX files from the working directory (optional but handy)
    for _stale_gpx in glob.glob("*.gpx"):
        try:
            os.remove(_stale_gpx)
        except FileNotFoundError:
            # Already gone - nothing to clean up.
            pass

    print("⬆️ Upload reference GPX (belt / chest strap)")
    uploaded_ref = files.upload()
    if not uploaded_ref:
        raise RuntimeError("No reference GPX uploaded.")
    # Only the first uploaded file is used.
    file1 = next(iter(uploaded_ref.keys()))

    print("⬆️ Upload comparison GPX (watch)")
    uploaded_cmp = files.upload()
    if not uploaded_cmp:
        raise RuntimeError("No comparison GPX uploaded.")
    file2 = next(iter(uploaded_cmp.keys()))
else:
    # Strip surrounding whitespace and quotes (terminals and "copy as path"
    # often wrap pasted paths in quotes).
    file1 = input("Path to reference GPX (belt): ").strip().strip("\"'")
    file2 = input("Path to comparison GPX (watch): ").strip().strip("\"'")

    # Fail early with a clear message instead of deep inside the XML parser.
    for _kind, _path in (("Reference", file1), ("Comparison", file2)):
        if not os.path.isfile(_path):
            raise FileNotFoundError(f"{_kind} GPX not found: {_path!r}")

print("Using files:")
print("  reference:", file1)
print("  compare:  ", file2)

# =========================
# 🖐️ HAND LABEL
# =========================

def get_hand_label(filename: str, override: str | None = None) -> str:
    if override in ("Left", "Right"):
        return override
    name = filename.lower()
    # Simple auto-detection from filename
    if "right" in name or "prav" in name:
        return "Right"
    if "left" in name or "lav" in name:
        return "Left"
    return "Unknown"

# Wrist label for the comparison (watch) file; shown in the chart info box
# and written to the TSV log.
hand_label = get_hand_label(file2, HAND_OVERRIDE)

# =========================
# 🗓️ ACTIVITY DATE, TYPE & START TIME
# =========================

def get_activity_date(filepath: str) -> str:
    """Return the date of the first timestamped track point as "DD_MM_YYYY".

    Track points without a <time> element, or with an empty one, are
    skipped instead of raising.

    NOTE(review): unlike get_start_time, TIMEZONE_OFFSET_HOURS is not applied
    here, so the date is the one encoded in the file's own timestamp.

    Args:
        filepath: Path to a GPX 1.1 file.

    Returns:
        The formatted date, or "unknown_date" if no usable timestamp exists.
    """
    ns = {'default': 'http://www.topografix.com/GPX/1/1'}
    root = ET.parse(filepath).getroot()
    for trkpt in root.findall('.//default:trkpt', ns):
        # findtext returns '' for a missing or empty <time>, so both are skipped.
        raw = (trkpt.findtext('default:time', '', ns) or '').strip()
        if not raw:
            continue
        # Map the "Z" suffix to an explicit UTC offset for fromisoformat.
        dt = datetime.fromisoformat(raw.replace("Z", "+00:00"))
        return dt.strftime("%d_%m_%Y")
    return "unknown_date"

def get_activity_type(filepath: str) -> str:
    """Return the activity type stored in a GPX file, capitalized.

    A <type> element that is a direct child of <trk> is preferred, so that an
    earlier unrelated element such as <metadata><link><type>text/html</type>
    is not mistaken for the activity. If no such element exists, the first
    element anywhere whose tag ends in "type" is used. Elements whose text is
    empty or whitespace-only are ignored.

    Args:
        filepath: Path to a GPX file (any namespace).

    Returns:
        The capitalized type text, or "Unknown" when none is found.
    """
    root = ET.parse(filepath).getroot()

    def _local_name(elem) -> str:
        # Drop the "{namespace}" prefix ElementTree puts in front of tags.
        tag = elem.tag if isinstance(elem.tag, str) else ""
        return tag.rsplit('}', 1)[-1].lower()

    def _text(elem) -> str:
        return (elem.text or "").strip()

    # Preferred source: <trk><type>...</type></trk>
    for trk in root.iter():
        if _local_name(trk) != "trk":
            continue
        for child in trk:
            if _local_name(child) == "type" and _text(child):
                return _text(child).capitalize()

    # Fallback: first element anywhere whose tag ends in "type".
    for elem in root.iter():
        if _local_name(elem).endswith("type") and _text(elem):
            return _text(elem).capitalize()
    return "Unknown"

def get_start_time(filepath: str) -> str:
    """Return the local start time of the activity as "HH:MM".

    The first track point with a non-empty <time> is used; its timestamp is
    shifted by TIMEZONE_OFFSET_HOURS before formatting. Track points without
    a <time> element, or with an empty one, are skipped instead of raising.

    Args:
        filepath: Path to a GPX 1.1 file.

    Returns:
        The formatted time, or "unknown" if no usable timestamp exists.
    """
    ns = {'default': 'http://www.topografix.com/GPX/1/1'}
    root = ET.parse(filepath).getroot()
    for trkpt in root.findall('.//default:trkpt', ns):
        # findtext returns '' for a missing or empty <time>, so both are skipped.
        raw = (trkpt.findtext('default:time', '', ns) or '').strip()
        if not raw:
            continue
        # Map the "Z" suffix to an explicit UTC offset for fromisoformat.
        dt = datetime.fromisoformat(raw.replace("Z", "+00:00"))
        dt = dt + timedelta(hours=TIMEZONE_OFFSET_HOURS)
        return dt.strftime("%H:%M")
    return "unknown"

# Date, activity type and start time are all read from the reference file (file1).
# They are used for the chart title/info box, the output filename and the TSV log.
comparison_label = get_activity_date(file1)
activity_type = get_activity_type(file1)
start_time_str = get_start_time(file1)

# =========================
# ❤️ HR DATA EXTRACTION
# =========================

def extract_hr_data(filepath: str, time_offset_sec: int = 0) -> dict:
    """Read heart-rate samples from a GPX file, keyed by shifted timestamp.

    Each track point that has both a timestamp and a Garmin
    TrackPointExtension v1 <hr> value yields one entry. Timestamps are
    shifted by TIMEZONE_OFFSET_HOURS plus ``time_offset_sec``, truncated to
    whole seconds, and kept only if their time of day lies inside the
    optional START_TIME / END_TIME window (both bounds inclusive).
    Track points whose <time> or <hr> is missing or empty are skipped.

    Args:
        filepath: Path to a GPX 1.1 file.
        time_offset_sec: Extra shift in seconds applied to every timestamp.

    Returns:
        Mapping of datetime -> heart rate (int). If several points land on
        the same second, the last one wins.
    """
    ns = {
        'default': 'http://www.topografix.com/GPX/1/1',
        'gpxtpx': 'http://www.garmin.com/xmlschemas/TrackPointExtension/v1'
    }
    root = ET.parse(filepath).getroot()
    # The shift is identical for every point, so build it once.
    shift = timedelta(hours=TIMEZONE_OFFSET_HOURS, seconds=time_offset_sec)

    data: dict[datetime, int] = {}
    for trkpt in root.findall('.//default:trkpt', ns):
        # findtext returns '' for a missing or empty element, so both are skipped.
        time_text = (trkpt.findtext('default:time', '', ns) or '').strip()
        hr_text = (trkpt.findtext('.//gpxtpx:hr', '', ns) or '').strip()
        if not time_text or not hr_text:
            continue

        # Map the "Z" suffix to an explicit UTC offset for fromisoformat.
        t = datetime.fromisoformat(time_text.replace("Z", "+00:00")) + shift
        t = t.replace(microsecond=0)

        clock = t.time()
        if START_TIME is not None and clock < START_TIME:
            continue
        if END_TIME is not None and clock > END_TIME:
            continue
        data[t] = int(hr_text)
    return data

# =========================
# 🕳️ GAP DETECTION
# =========================

def find_gaps(reference_times, compare_times, max_gap: int = 10):
    """Find stretches of the merged timeline that contain no samples.

    The timestamps of both recordings are merged and sorted. Whenever two
    neighbouring timestamps are more than ``max_gap`` seconds apart, the
    interval between them is reported and tagged by which recording
    contributed its end points:

    * ``'compare'``   - only the reference has an end point (watch data missing)
    * ``'reference'`` - only the comparison has an end point (belt data missing)
    * ``'both'``      - both recordings have an end point, so neither recorded
      anything in between

    The ``'both'`` case used to sit behind a condition that could never be
    true (every merged timestamp belongs to at least one recording), so those
    gaps were silently dropped.

    Args:
        reference_times: Iterable of datetimes from the reference recording.
        compare_times: Iterable of datetimes from the comparison recording.
        max_gap: Largest spacing in seconds that is still not a gap.

    Returns:
        List of ``(gap_start, gap_end, kind)`` tuples in chronological order.
    """
    ref_set = set(reference_times)
    cmp_set = set(compare_times)
    # Set union accepts any iterable input (lists, dict views, generators).
    timeline = sorted(ref_set | cmp_set)

    gaps = []
    for ts1, ts2 in zip(timeline, timeline[1:]):
        if (ts2 - ts1).total_seconds() <= max_gap:
            continue
        in_ref = ts1 in ref_set or ts2 in ref_set
        in_cmp = ts1 in cmp_set or ts2 in cmp_set
        if in_ref and not in_cmp:
            gaps.append((ts1, ts2, 'compare'))    # missing watch data
        elif in_cmp and not in_ref:
            gaps.append((ts1, ts2, 'reference'))  # missing belt data
        else:
            gaps.append((ts1, ts2, 'both'))       # neither recorded in between
    return gaps

# =========================
# 🧮 LOAD & ANALYZE DATA
# =========================

# The reference recording is read once, with no extra time offset.
data1 = extract_hr_data(file1)
total_points1 = len(data1)

# Score every candidate offset of the comparison recording against the
# reference, using only the timestamps present in both.
results = []
for offset in OFFSET_RANGE:
    data2 = extract_hr_data(file2, time_offset_sec=offset)
    common_times = sorted(data1.keys() & data2.keys())
    if common_times:
        hr1 = [data1[stamp] for stamp in common_times]
        hr2 = [data2[stamp] for stamp in common_times]

        # A pair matches when the two readings differ by at most TOLERANCE bpm.
        match_count = 0
        for ref_bpm, cmp_bpm in zip(hr1, hr2):
            if abs(ref_bpm - cmp_bpm) <= TOLERANCE:
                match_count += 1
        accuracy = (match_count / len(common_times)) * 100

        # A correlation coefficient needs at least two samples.
        corr = np.corrcoef(hr1, hr2)[0, 1] if len(common_times) >= 2 else float("nan")

        results.append(dict(
            offset=offset,
            accuracy=accuracy,
            corr=corr,
            common_times=common_times,
            hr1=hr1,
            hr2=hr2,
            total_points2=len(data2),
            match_count=match_count,
            data2=data2,
        ))

if not results:
    print("\n❌ No common timestamps found within the specified time window and offsets.")
else:
    # Pick the offset with the best accuracy. max() keeps the first of several
    # equally good entries, i.e. the earliest offset tested.
    best_result = max(results, key=lambda x: x['accuracy'])
    offset_str = f"{best_result['offset']}s"
    accuracy_str = f"{best_result['accuracy']:.2f}%"
    corr_str = f"{best_result['corr']:.4f}"
    match_count = best_result['match_count']

    # =========================
    # 📈 VISUALIZATION
    # =========================
    fig, ax = plt.subplots(figsize=(10, 5))

    # Shade the intervals reported by find_gaps. Only the first span of each
    # kind carries a label, so the legend shows each kind once.
    gaps = find_gaps(sorted(data1.keys()), list(best_result['data2'].keys()))
    colors = {'compare': 'gray', 'reference': 'orange', 'both': 'black'}
    labels_drawn = set()
    for start, end, label in gaps:
        if label not in labels_drawn:
            ax.axvspan(start, end, color=colors[label], alpha=0.2, zorder=0, label=label)
            labels_drawn.add(label)
        else:
            ax.axvspan(start, end, color=colors[label], alpha=0.2, zorder=0)

    # Clean legend labels: Belt vs Watch
    ax.plot(
        best_result['common_times'],
        best_result['hr1'],
        label="Belt",
        color='#51D2D6',
        linewidth=2.5,
        zorder=1
    )
    ax.plot(
        best_result['common_times'],
        best_result['hr2'],
        label=f"Watch (offset {best_result['offset']}s)",
        color='#F97D75',
        linewidth=2.5,
        alpha=0.6,
        zorder=2
    )

    ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
    fig.autofmt_xdate()

    avg1 = np.mean(best_result['hr1'])
    avg2 = np.mean(best_result['hr2'])
    max1 = np.max(best_result['hr1'])
    max2 = np.max(best_result['hr2'])

    # Summary box placed to the right of the axes.
    info_text = (
        f"Activity: {activity_type}\n"
        f"Ref device: {label1}\n"
        f"Cmp device: {label2}\n"
        f"Hand (cmp): {hand_label}\n"
        f"Ref points: {total_points1}\n"
        f"Cmp points: {best_result['total_points2']}\n"
        f"Compared: {len(best_result['common_times'])}\n"
        f"Match ({tolerance_str}): {match_count}\n"
        f"Accuracy: {accuracy_str}\n"
        f"Correlation: {corr_str}\n"
        f"Best offset: {offset_str}\n"
        f"{label1} avg/max: {avg1:.1f}/{max1}\n"
        f"{label2} avg/max: {avg2:.1f}/{max2}"
    )

    # Clean title: only date + time
    ax.set_title(
        f"Heart Rate Comparison – {comparison_label} – {start_time_str}",
        fontsize=16,
        weight='bold'
    )
    ax.set_xlabel("Time", fontsize=12, weight='bold')
    ax.set_ylabel("Heart Rate (bpm)", fontsize=12, weight='bold')
    ax.legend(loc="upper right", fontsize=9)

    ax.text(
        1.01, 0.5, info_text,
        transform=ax.transAxes,
        fontsize=11,
        verticalalignment='center',
        bbox=dict(boxstyle="round", facecolor='white', alpha=0.7)
    )

    ax.grid(False)
    ax.tick_params(axis='both', which='both', length=0)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.tight_layout()

    # =========================
    # 💾 SAVE (optional)
    # =========================
    # The image is written BEFORE plt.show(): once show() returns, pyplot may
    # no longer track the figure, and plt.savefig() would write an empty
    # canvas. Saving through the figure object also avoids relying on
    # pyplot's "current figure".
    SHOULD_SAVE_FILES = USE_GOOGLE_DRIVE or DOWNLOAD_RESULTS
    img_filepath = None

    if SHOULD_SAVE_FILES:
        os.makedirs(IMG_DIR, exist_ok=True)

        # Filename uses date + start time; a "(n)" suffix avoids overwriting
        # an earlier chart of the same activity.
        safe_time = start_time_str.replace(":", "-")
        base_filename = f"comparison_{comparison_label}_{safe_time}.png"
        img_filename = base_filename
        counter = 1
        while os.path.exists(os.path.join(IMG_DIR, img_filename)):
            img_filename = f"comparison_{comparison_label}_{safe_time}({counter}).png"
            counter += 1

        img_filepath = os.path.join(IMG_DIR, img_filename)
        fig.savefig(img_filepath, bbox_inches='tight', dpi=300)
        print(f"🖼️ Chart saved as image: {img_filepath}")

        # TSV log: write the header only when the file does not exist yet.
        if not os.path.exists(TSV_PATH):
            with open(TSV_PATH, "w", encoding="utf-8") as f:
                f.write("Date\tActivity\tTolerance\tRefPoints\tMatchedPoints\tAccuracy\tCorrelation\tHand\tOffset\n")

        output_line = (
            f"{comparison_label}\t{activity_type}\t{tolerance_str}\t"
            f"{total_points1}\t{match_count}\t{accuracy_str}\t{corr_str}\t"
            f"{hand_label}\t{offset_str}\n"
        )

        with open(TSV_PATH, "a", encoding="utf-8") as f:
            f.write(output_line)

        print(f"✅ Result appended to: {TSV_PATH}")

    plt.show()

    # Browser downloads are only available in Colab and only if files were written.
    if SHOULD_SAVE_FILES and IN_COLAB and DOWNLOAD_RESULTS:
        files.download(img_filepath)
        files.download(TSV_PATH)
