# 🛑 Standalone Stop Detector Training (Supervised)

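This notebook trains a dedicated, vision-based stop detector for DSA curve tracking using supervised learning. It previews a few training samples, trains the detector, sanity-checks its predictions on a freshly generated curve, and downloads the resulting weights.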

## 📖 How to run this notebook
1. **Enable GPU**: Runtime -> Change runtime type -> T4 GPU.
2. **Run All Cells**: Runtime -> Run all.
3. **Download**: The final cell downloads the trained weights, `stop_detector_v1.pth` (browser download in Colab).

In [None]:
import torch

# Sanity-check the runtime before training; step 1 above asks for a GPU runtime.
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
else:
    # Make the missing GPU visible instead of silently continuing.
    print("⚠️ No GPU detected: switch the runtime (Runtime -> Change runtime type) before training.")

In [None]:
!pip install numpy scipy opencv-python tqdm matplotlib

In [None]:
import os, sys
if not os.path.exists('DSA-RL-Tracker'):
    !git clone https://github.com/MahsaAbadian/DSA-RL-Tracker.git
%cd DSA-RL-Tracker
sys.path.append(os.getcwd())

In [None]:
from StopModule.src.train_standalone import StopDataset, plot_samples

# Build a small preview dataset and plot a handful of samples to eyeball the data before training.
PREVIEW_SAMPLES_PER_CLASS = 10
PREVIEW_NUM_PLOTS = 4

preview_ds = StopDataset(samples_per_class=PREVIEW_SAMPLES_PER_CLASS)
plot_samples(preview_ds, num_samples=PREVIEW_NUM_PLOTS)

In [None]:
from StopModule.src.train_standalone import train_stop_detector

# Training hyperparameters, named so they are easy to find and tune.
# NOTE(review): the next cell loads StopModule/weights/stop_detector_v1.pth; presumably
# train_stop_detector writes that file. Confirm in StopModule/src/train_standalone.py.
TRAIN_EPOCHS = 20
TRAIN_BATCH_SIZE = 64
TRAIN_SAMPLES = 8000

train_stop_detector(
    epochs=TRAIN_EPOCHS,
    batch_size=TRAIN_BATCH_SIZE,
    samples=TRAIN_SAMPLES,
)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from StopModule.src.models import StandaloneStopDetector
from Experiment4_separate_stop_v2.src.train import CurveMakerFlexible, crop32, load_curve_config

# Qualitative check: score one freshly sampled curve at its start, midpoint and end point.
WEIGHTS_PATH = "StopModule/weights/stop_detector_v1.pth"
IMG_H, IMG_W = 128, 128  # canvas size passed to CurveMakerFlexible; also the pixel-clipping bounds


def to_pixel(point, h=IMG_H, w=IMG_W):
    """Map a (row, col) curve point to integer pixel indices clipped to an h x w canvas."""
    return int(np.clip(point[0], 0, h - 1)), int(np.clip(point[1], 0, w - 1))


def build_path_mask(image, points, h=IMG_H, w=IMG_W):
    """Return a mask shaped like `image` with 1.0 at every pixel visited by `points`."""
    path_mask = np.zeros_like(image)
    for point in points:
        path_mask[to_pixel(point, h, w)] = 1.0
    return path_mask


def predict_stop_prob(model, img_crop, path_crop, device):
    """Stack the image and path crops into a batch of one and return sigmoid(model(x)) as a float.

    Assumes the crops are 2-D arrays of equal shape, giving an input of shape (1, 2, H, W).
    """
    x = torch.tensor(np.stack([img_crop, path_crop]), dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        return torch.sigmoid(model(x)).item()


device = "cuda" if torch.cuda.is_available() else "cpu"
model = StandaloneStopDetector().to(device)
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
model.eval()

cfg, _ = load_curve_config()
maker = CurveMakerFlexible(h=IMG_H, w=IMG_W, config=cfg)
img, mask, pts_all = maker.sample_curve(width_range=(2, 4))
pts = pts_all[0]  # point list of the first sampled curve
assert len(pts) > 0, "sample_curve returned an empty point list"

indices = [0, len(pts) // 2, len(pts) - 1]
fig, axes = plt.subplots(1, len(indices), figsize=(15, 5))
for ax, idx in zip(axes, indices):
    cy, cx = to_pixel(pts[idx])
    c_img = crop32(img, cy, cx)
    # The path channel holds every point traced so far, up to and including idx.
    c_path = crop32(build_path_mask(img, pts[:idx + 1]), cy, cx)
    prob = predict_stop_prob(model, c_img, c_path, device)
    ax.imshow(c_img, cmap='gray')
    ax.set_title(f"Pos: {idx}\nStop Prob: {prob:.2%}")
    ax.axis("off")
plt.show()

In [None]:
import os

# Download the trained weights to the local machine (Colab only).
# Missing weights and non-Colab runtimes are reported instead of failing silently or crashing.
WEIGHTS_PATH = "StopModule/weights/stop_detector_v1.pth"

if not os.path.exists(WEIGHTS_PATH):
    print(f"⚠️ No weights found at {WEIGHTS_PATH}; did the training cell finish?")
else:
    try:
        from google.colab import files
    except ImportError:
        print(f"Not running in Colab; weights are at {os.path.abspath(WEIGHTS_PATH)}")
    else:
        files.download(WEIGHTS_PATH)