import torch

# Environment check: report the PyTorch build and whether a CUDA device is usable.
gpu_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {gpu_ready}")
if not gpu_ready:
    print("⚠️ GPU not detected. Go to Runtime -> Change runtime type to enable it.")
else:
    print(f"✅ GPU detected: {torch.cuda.get_device_name(0)}")

In [None]:
import torch

# Repeats the environment check from the first cell (PyTorch version + CUDA status).
# The device name is only queried when CUDA is available.
_has_gpu = torch.cuda.is_available()
for _line in (
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {_has_gpu}",
    f"✅ GPU detected: {torch.cuda.get_device_name(0)}"
    if _has_gpu
    else "⚠️ GPU not detected. Go to Runtime -> Change runtime type to enable it.",
):
    print(_line)

In [None]:
import os
import sys

if not os.path.exists('DSA-RL-Tracker'):
    !git clone https://github.com/MahsaAbadian/DSA-RL-Tracker.git

%cd DSA-RL-Tracker
sys.path.append(os.getcwd())

In [None]:
from StopModule.src.train_standalone import StopDataset, plot_samples

# Visual sanity check before training: a tiny dataset (10 samples per class),
# four of which are plotted. Endpoints carry the STOP label, midpoints the GO label.
PREVIEW_PER_CLASS = 10
preview_ds = StopDataset(samples_per_class=PREVIEW_PER_CLASS)
plot_samples(
    preview_ds,
    num_samples=4,
)

In [None]:
from StopModule.src.train_standalone import train_stop_detector

# Train the standalone stop detector: 20 epochs, batch size 64.
# NOTE(review): the original note said this generates 16,000 samples while passing
# samples=8000 — presumably 8,000 per class (STOP / GO); confirm in train_stop_detector.
# The original note also expected roughly 99% accuracy.
train_stop_detector(epochs=20, batch_size=64, samples=8000)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from StopModule.src.models import StandaloneStopDetector
from Experiment4_separate_stop_v2.src.train import CurveMakerFlexible, crop32, load_curve_config

# --- Load the trained detector and switch it to inference mode ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model = StandaloneStopDetector().to(device)
state_dict = torch.load("StopModule/weights/stop_detector_v1.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()

# --- Draw one synthetic 128x128 curve and keep its first point sequence ---
cfg, _ = load_curve_config()
maker = CurveMakerFlexible(h=128, w=128, config=cfg)
img, mask, pts_all = maker.sample_curve(width_range=(2, 4))
pts = pts_all[0]

# --- Probe the detector at the start, middle and end of the path ---
indices = [0, len(pts) // 2, len(pts) - 1]
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax, idx in zip(axes, indices):
    # Point coordinates index img as img[r, c].
    r, c = int(pts[idx][0]), int(pts[idx][1])

    # Map of the path walked so far: points 0..idx inclusive set to 1.0.
    path_mask = np.zeros_like(img)
    for p in pts[: idx + 1]:
        path_mask[int(p[0]), int(p[1])] = 1.0

    # Two-channel batch of one: (image crop, path-history crop) around (r, c).
    # NOTE(review): crop32 presumably cuts a 32x32 window centred on (r, c) — confirm.
    c_img = crop32(img, r, c)
    c_path = crop32(path_mask, r, c)
    x = torch.tensor(np.stack([c_img, c_path]), dtype=torch.float32).unsqueeze(0).to(device)

    with torch.no_grad():
        prob = torch.sigmoid(model(x)).item()

    ax.imshow(c_img, cmap='gray')
    ax.set_title(f"Position: {idx}/{len(pts)}\nTerminal Prob: {prob:.2%}")

plt.suptitle("Supervised Detector Verification (Exp 4 Gold Standard)", fontsize=16)
plt.show()

In [None]:
import os

from google.colab import files

# Download the trained stop-detector weights from the Colab VM to the local machine.
# `os` is imported here so the cell does not depend on an earlier cell having run.
# This is the same path the verification cell loads; presumably it is the file
# train_stop_detector saves (confirm in train_standalone).
weight_path = "StopModule/weights/stop_detector_v1.pth"
if os.path.exists(weight_path):
    print("✅ Training complete! Downloading weights...")
    files.download(weight_path)
else:
    print("❌ Weight file not found. Check training logs.")

In [None]:
# Fetch the latest changes
!git fetch origin

# Rebase your local commit onto the remote branch
!git rebase origin/main

# If there are conflicts, resolve them, then:
!git add .
!git rebase --continue

# Force push (since rebase rewrites history)
!git push origin main --force-with-lease