# 🛑 Stop Module Training on Google Colab

This notebook uses supervised learning to train a **Standalone Stop Detector** that recognizes curve endpoints in DSA (digital subtraction angiography) images.

## Features
- **Vessel Realism**: Trains with tapering (wide→narrow) and fading (bright→dim) to mimic how real vessels narrow and dim toward their ends
- **Balanced Dataset**: Generates equal numbers of positive (endpoint) and negative (midpoint) samples
- **GPU Accelerated**: Uses Colab's free GPU, when available, for faster training
- **Validation Metrics**: Tracks per-class accuracy (Stop vs Go)
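
As a rough illustration of the tapering and fading idea, the sketch below computes per-point width and intensity profiles along a curve. It assumes a simple linear ramp from the start value to the end value; `taper_and_fade` is a hypothetical helper, and the repo's `CurveMakerFlexible` may shape these profiles differently. The defaults mirror the `sample_curve` arguments used later in this notebook (`start_width=6`, `end_width=1`, `start_intensity=0.9`, `end_intensity=0.3`).

```python
import numpy as np

def taper_and_fade(n_points, start_width=6.0, end_width=1.0,
                   start_intensity=0.9, end_intensity=0.3):
    """Hypothetical sketch: per-point width and intensity along a curve,
    assuming a linear ramp from the start value to the end value."""
    t = np.linspace(0.0, 1.0, n_points)  # 0 at the curve start, 1 at its end
    widths = start_width + (end_width - start_width) * t
    intensities = start_intensity + (end_intensity - start_intensity) * t
    return widths, intensities

widths, intensities = taper_and_fade(5)
print(widths)       # widest at the start, 1 pixel at the end
print(intensities)  # brightest at the start, dimmest at the end
```

With these settings the last point of the curve is the narrowest and dimmest, which is exactly where the detector has to decide STOP.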

## What You'll Get
- A trained stop detector model (`stop_detector_v1.pth`)
- Training/validation accuracy metrics
- Visualizations of sample crops


In [None]:
# Install and setup
!git clone https://github.com/mahsaabadian/DSA-RL-Tracker.git
%cd DSA-RL-Tracker
!pip install -r Experiment1/requirements.txt


## Check GPU Availability


In [None]:
# Check CUDA / GPU availability
import torch
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device count:", torch.cuda.device_count())
    print("CUDA device name:", torch.cuda.get_device_name(0))
    print("GPU Memory:", f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️  GPU not available; training will fall back to CPU (slower)")


## Configure Training Parameters

Adjust these variables in the next cell as needed:
- `EPOCHS`: Number of training epochs (default: 15)
- `BATCH_SIZE`: Batch size (default: 64; increase if GPU memory allows)
- `SAMPLES_PER_CLASS`: Samples per class (default: 5000, i.e. 10k samples in total)
- `LEARNING_RATE`: Learning rate (default: 1e-4)
- `VESSEL_REALISM`: Enable vessel-realistic features (tapering and fading). **Recommended: True**
- `OUTPUT_NAME`: Base name of the saved weights file (default: `stop_detector_v1`)
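
The configuration values set in the next cell (`EPOCHS`, `BATCH_SIZE`, `SAMPLES_PER_CLASS`, `LEARNING_RATE`, `VESSEL_REALISM`, `OUTPUT_NAME`) reach `src/train_standalone.py` as the command-line flags `--epochs`, `--batch_size`, `--samples`, `--lr`, `--output`, and `--no_vessel_realism` (only when realism is disabled). The sketch below rebuilds that command as an argument list and estimates the number of training batches per epoch. Both helper names are hypothetical, and the batch count assumes the 80/20 split described in the training section and that a final partial batch is kept.

```python
def build_train_argv(epochs=15, batch_size=64, samples=5000, lr=1e-4,
                     vessel_realism=True, output_name="stop_detector_v1"):
    """Rebuild the training command from the notebook's config as an argv list."""
    argv = [
        "python", "-u", "src/train_standalone.py",
        "--epochs", str(epochs),
        "--batch_size", str(batch_size),
        "--samples", str(samples),          # per class, so 2 * samples in total
        "--lr", str(lr),
        "--output", f"weights/{output_name}.pth",
    ]
    if not vessel_realism:
        argv.append("--no_vessel_realism")  # flag is only added when disabled
    return argv

def batches_per_epoch(samples_per_class, batch_size, train_fraction=0.8):
    """Training batches per epoch, assuming an 80/20 split and a kept partial batch."""
    total = 2 * samples_per_class                 # endpoints + midpoints
    n_train = int(round(total * train_fraction))  # size of the training split
    return -(-n_train // batch_size)              # ceiling division

print(" ".join(build_train_argv()))
print(batches_per_epoch(5000, 64))   # 10,000 total -> 8,000 train -> 125
```

With the defaults, 5,000 samples per class give 10,000 samples, 8,000 of them for training, or 125 batches of 64 per epoch; raising the batch size to 128 gives 63 batches. Passing the list to `subprocess.run(argv, check=True, cwd=...)` is one alternative to the `!{cmd}` shell call that sidesteps quoting issues.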


In [None]:
# Training Configuration
EPOCHS = 15              # Number of training epochs
BATCH_SIZE = 64          # Batch size (increase to 128 if GPU memory allows)
SAMPLES_PER_CLASS = 5000 # Samples per class (5000 endpoints + 5000 midpoints = 10k total)
LEARNING_RATE = 1e-4     # Learning rate
VESSEL_REALISM = True    # Enable vessel-realistic features (tapering & fading)
OUTPUT_NAME = "stop_detector_v1"  # Output filename (will save as .pth)

print("Training Configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Samples per Class: {SAMPLES_PER_CLASS} (Total: {SAMPLES_PER_CLASS * 2})")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Vessel Realism: {VESSEL_REALISM} {'(tapering & fading enabled)' if VESSEL_REALISM else '(uniform curves)'}")
print(f"  Output: StopModule/weights/{OUTPUT_NAME}.pth")


## Visualize Sample Data (Optional)
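
Each training sample is a two-channel 32×32 crop: the image around a point, plus a path mask of the trajectory traversed up to that point. For a positive (STOP) sample the crop is centered on the last point of the curve with the full path drawn; for a negative (GO) sample it is centered on the midpoint with only the first half of the path drawn, as in the visualization cell. The sketch below reproduces that construction on a toy straight vessel. `crop32_sketch` and `rasterize_path` are hypothetical stand-ins, and zero padding at the image borders is an assumption about how `crop32` behaves.

```python
import numpy as np

def crop32_sketch(img, r, c, size=32):
    """Hypothetical stand-in for crop32: a size x size window centered on (r, c).

    Pixels outside the image are zero-padded (an assumption; the repo's
    crop32 may pad or clamp differently).
    """
    half = size // 2
    padded = np.pad(img, half, mode="constant", constant_values=0.0)
    # Original (r, c) sits at (r + half, c + half) in `padded`, so the window
    # starting at (r, c) in padded coordinates is centered on the original point.
    return padded[r:r + size, c:c + size]

def rasterize_path(shape, pts):
    """Mark every (row, col) point of a traversed path with 1.0."""
    mask = np.zeros(shape, dtype=np.float32)
    for r, c in pts:
        rr = min(max(int(r), 0), shape[0] - 1)  # clip to stay inside the image
        cc = min(max(int(c), 0), shape[1] - 1)
        mask[rr, cc] = 1.0
    return mask

# Toy straight "vessel": row 64, columns 20..99 of a 128x128 image
pts = np.array([[64.0, float(c)] for c in range(20, 100)])
img = rasterize_path((128, 128), pts)  # stand-in for the rendered DSA image

end_r, end_c = pts[-1].astype(int)
mid_idx = len(pts) // 2
mid_r, mid_c = pts[mid_idx].astype(int)

# Positive (STOP): full path drawn, crop centered on the last point
pos = np.stack([crop32_sketch(img, end_r, end_c),
                crop32_sketch(rasterize_path(img.shape, pts), end_r, end_c)])
# Negative (GO): path drawn only up to the midpoint, crop centered on it
neg = np.stack([crop32_sketch(img, mid_r, mid_c),
                crop32_sketch(rasterize_path(img.shape, pts[:mid_idx + 1]), mid_r, mid_c)])
print(pos.shape, neg.shape)
```

In this toy case the path channel looks the same in both crops (the trail ends at the center pixel), so the distinguishing cue is in the image channel: at the endpoint the vessel stops at the center, while at the midpoint it continues past it.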

Before training, let's visualize one positive (endpoint) and one negative (midpoint) sample:


In [None]:
# Quick visualization of sample crops
import sys
import os
import numpy as np
import matplotlib.pyplot as plt

# Add project root to path
project_root = "/content/DSA-RL-Tracker"
if project_root not in sys.path:
    sys.path.append(project_root)

from Experiment5_reward_config.src.train import CurveMakerFlexible, crop32, load_curve_config

# Create a sample curve with vessel realism
cfg = load_curve_config()[0]
maker = CurveMakerFlexible(h=128, w=128, config=cfg)

# Generate sample with vessel realism
w_range = (2, 5)
img, mask, pts_all = maker.sample_curve(
    width_range=w_range,
    curvature_factor=1.0,
    noise_prob=0.3,
    invert_prob=0.5,
    width_variation="wide_to_narrow",
    start_width=6,
    end_width=1,
    intensity_variation="bright_to_dim",
    start_intensity=0.9,
    end_intensity=0.3
)

pts = pts_all[0]

# Show endpoint crop (positive sample)
end_pt = pts[-1]
path_mask = np.zeros_like(img)
for p in pts:
    path_mask[int(p[0]), int(p[1])] = 1.0

crop_img = crop32(img, int(end_pt[0]), int(end_pt[1]))
crop_path = crop32(path_mask, int(end_pt[0]), int(end_pt[1]))

# Show midpoint crop (negative sample)
mid_idx = len(pts) // 2
mid_pt = pts[mid_idx]
path_mask_mid = np.zeros_like(img)
for p in pts[:mid_idx+1]:
    path_mask_mid[int(p[0]), int(p[1])] = 1.0

crop_img_mid = crop32(img, int(mid_pt[0]), int(mid_pt[1]))
crop_path_mid = crop32(path_mask_mid, int(mid_pt[0]), int(mid_pt[1]))

# Visualize
fig, axes = plt.subplots(2, 3, figsize=(12, 8))

# Full curve
axes[0, 0].imshow(img, cmap='gray')
axes[0, 0].plot(pts[:, 1], pts[:, 0], 'r-', linewidth=2, alpha=0.5)
axes[0, 0].plot(end_pt[1], end_pt[0], 'go', markersize=10, label='Endpoint')
axes[0, 0].plot(mid_pt[1], mid_pt[0], 'bo', markersize=10, label='Midpoint')
axes[0, 0].set_title('Full Curve (Vessel with Tapering & Fading)')
axes[0, 0].legend()
axes[0, 0].axis('off')

# Endpoint crop (positive)
axes[0, 1].imshow(crop_img, cmap='gray')
axes[0, 1].set_title('Endpoint Crop (STOP - Positive)')
axes[0, 1].axis('off')

axes[0, 2].imshow(crop_path, cmap='jet')
axes[0, 2].set_title('Path Mask at Endpoint')
axes[0, 2].axis('off')

# Midpoint crop (negative)
axes[1, 0].imshow(img, cmap='gray')
axes[1, 0].plot(pts[:, 1], pts[:, 0], 'r-', linewidth=2, alpha=0.5)
axes[1, 0].plot(mid_pt[1], mid_pt[0], 'bo', markersize=10)
axes[1, 0].set_title('Full Curve (Midpoint Highlighted)')
axes[1, 0].axis('off')

axes[1, 1].imshow(crop_img_mid, cmap='gray')
axes[1, 1].set_title('Midpoint Crop (GO - Negative)')
axes[1, 1].axis('off')

axes[1, 2].imshow(crop_path_mid, cmap='jet')
axes[1, 2].set_title('Path Mask at Midpoint')
axes[1, 2].axis('off')

plt.tight_layout()
plt.show()

print("✅ Sample visualization complete!")
print("   Top row: Endpoint (STOP) - narrow and faded")
print("   Bottom row: Midpoint (GO) - wider and brighter")


## Train Stop Detector

This will:
1. Generate a balanced dataset (endpoints vs. midpoints)
2. Split it into training and validation sets (80/20)
3. Train a CNN classifier
4. Save the model with the best validation accuracy
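
Steps 2 and 4 can be sketched as follows: a stratified 80/20 split that keeps both classes balanced in the validation set, per-class accuracy for STOP (label 1) and GO (label 0), and best-epoch selection by validation accuracy. This is an illustration under those assumptions, not the code in `train_standalone.py`, which may split or select differently; all function names are hypothetical.

```python
import numpy as np

def stratified_split(labels, train_fraction=0.8, seed=0):
    """Split indices 80/20 within each class so both splits stay balanced."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    train, val = [], []
    for cls in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == cls))
        n_train = int(round(len(idx) * train_fraction))
        train.append(idx[:n_train])
        val.append(idx[n_train:])
    return np.concatenate(train), np.concatenate(val)

def per_class_accuracy(probs, labels, threshold=0.5):
    """Accuracy on STOP (label 1) and GO (label 0) samples separately."""
    preds = (np.asarray(probs) > threshold).astype(int)
    labels = np.asarray(labels)
    stop = labels == 1
    return {
        "stop_acc": float((preds[stop] == 1).mean()),
        "go_acc": float((preds[~stop] == 0).mean()),
        "overall": float((preds == labels).mean()),
    }

def best_epoch(val_accs):
    """Return (epoch, accuracy) of the best validation score; ties keep the earliest."""
    best = (None, float("-inf"))
    for epoch, acc in enumerate(val_accs, start=1):
        if acc > best[1]:  # a training loop would save the weights here
            best = (epoch, acc)
    return best

# Toy usage: 5,000 STOP + 5,000 GO labels, as with the default config
labels = np.array([1] * 5000 + [0] * 5000)
train_idx, val_idx = stratified_split(labels)
print(len(train_idx), len(val_idx), int(labels[val_idx].sum()))  # 8000 2000 1000

# Toy predictions for 4 STOP and 4 GO samples
print(per_class_accuracy([0.9, 0.7, 0.4, 0.8, 0.2, 0.6, 0.7, 0.3],
                         [1, 1, 1, 1, 0, 0, 0, 0]))
```

Per-class accuracy matters even on a balanced set: a model that always answers GO scores 50% overall while missing every STOP, which `stop_acc` makes obvious.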


In [None]:
# Run training
%cd /content/DSA-RL-Tracker/StopModule

# Build command
cmd = (
    "python -u src/train_standalone.py"
    f" --epochs {EPOCHS} --batch_size {BATCH_SIZE}"
    f" --samples {SAMPLES_PER_CLASS} --lr {LEARNING_RATE}"
    f" --output weights/{OUTPUT_NAME}.pth"
)

if not VESSEL_REALISM:
    cmd += " --no_vessel_realism"

print("üöÄ Starting training...")
print(f"Command: {cmd}\n")
!{cmd}


## Training Results Summary


In [None]:
# Check if weights were saved
import os
weight_path = f"/content/DSA-RL-Tracker/StopModule/weights/{OUTPUT_NAME}.pth"

if os.path.exists(weight_path):
    file_size = os.path.getsize(weight_path) / 1024  # KB
    print("✅ Model saved successfully!")
    print(f"   Path: {weight_path}")
    print(f"   Size: {file_size:.1f} KB")
else:
    print(f"⚠️  Weight file not found at: {weight_path}")
    print("   Checking the weights directory for other files...")
    weights_dir = "/content/DSA-RL-Tracker/StopModule/weights"
    if os.path.exists(weights_dir):
        found = os.listdir(weights_dir)
        if found:
            print(f"   Found files: {found}")
        else:
            print("   No weight files found")
    else:
        print(f"   Weights directory does not exist: {weights_dir}")


## Test the Trained Model (Optional)
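
The test cell stacks the image crop and the path crop into a single `(1, 2, 32, 32)` input and turns the model's output into a probability with a sigmoid, assuming the model returns one logit per sample. The NumPy sketch below mirrors those two steps without PyTorch; `to_model_input` and `decide` are hypothetical helpers.

```python
import math
import numpy as np

def to_model_input(crop_img, crop_path):
    """Stack two 32x32 channels into a (1, 2, 32, 32) float32 batch."""
    x = np.stack([crop_img, crop_path], axis=0).astype(np.float32)  # (2, 32, 32)
    return x[np.newaxis]                                            # add batch axis

def decide(logit, threshold=0.5):
    """Convert one raw logit into (probability, label) like the test cell."""
    # Numerically stable sigmoid: avoids math.exp overflow for large |logit|
    if logit >= 0:
        prob = 1.0 / (1.0 + math.exp(-logit))
    else:
        z = math.exp(logit)
        prob = z / (1.0 + z)
    return prob, ("STOP" if prob > threshold else "GO")

x = to_model_input(np.zeros((32, 32)), np.ones((32, 32)))
print(x.shape)      # (1, 2, 32, 32)
print(decide(2.0))  # probability ~0.881 -> STOP
```

Because the sigmoid is monotonic and equals 0.5 at zero, `prob > 0.5` is the same test as `logit > 0`. The same helpers could also score a midpoint crop built as in the visualization cell, where the expected answer is GO.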

Load the trained model and check its prediction at the endpoint of a freshly generated curve:


In [None]:
# Quick test of the trained model
import torch
import sys
import os
import numpy as np
import matplotlib.pyplot as plt

# Add project root
project_root = "/content/DSA-RL-Tracker"
if project_root not in sys.path:
    sys.path.append(project_root)

from StopModule.src.models import StandaloneStopDetector
from Experiment5_reward_config.src.train import CurveMakerFlexible, crop32, load_curve_config

# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = StandaloneStopDetector().to(device)
weight_path = f"/content/DSA-RL-Tracker/StopModule/weights/{OUTPUT_NAME}.pth"

if os.path.exists(weight_path):
    model.load_state_dict(torch.load(weight_path, map_location=device))
    model.eval()
    print(f"✅ Model loaded from {weight_path}")
    
    # Generate test samples
    cfg = load_curve_config()[0]
    maker = CurveMakerFlexible(h=128, w=128, config=cfg)
    
    # Test on endpoint (should predict STOP)
    img, mask, pts_all = maker.sample_curve(
        width_range=(2, 5),
        width_variation="wide_to_narrow",
        start_width=6,
        end_width=1,
        intensity_variation="bright_to_dim",
        start_intensity=0.9,
        end_intensity=0.3,
        noise_prob=0.3
    )
    pts = pts_all[0]
    end_pt = pts[-1]
    path_mask = np.zeros_like(img)
    for p in pts:
        path_mask[int(p[0]), int(p[1])] = 1.0
    
    crop_img = crop32(img, int(end_pt[0]), int(end_pt[1]))
    crop_path = crop32(path_mask, int(end_pt[0]), int(end_pt[1]))
    crop_input = torch.tensor(np.stack([crop_img, crop_path], axis=0), dtype=torch.float32).unsqueeze(0).to(device)
    
    with torch.no_grad():
        logit = model(crop_input)
        prob = torch.sigmoid(logit).item()
    
    print(f"\nüìä Test Results:")
    print(f"   Endpoint prediction: {prob:.3f} ({'STOP' if prob > 0.5 else 'GO'})")
    print(f"   Expected: STOP (probability should be > 0.5)")
    
    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(img, cmap='gray')
    axes[0].plot(pts[:, 1], pts[:, 0], 'r-', linewidth=2, alpha=0.5)
    axes[0].plot(end_pt[1], end_pt[0], 'go', markersize=10)
    axes[0].set_title(f'Full Curve\nPrediction: {"STOP" if prob > 0.5 else "GO"} ({prob:.3f})')
    axes[0].axis('off')
    
    axes[1].imshow(crop_img, cmap='gray')
    axes[1].set_title('Image Crop')
    axes[1].axis('off')
    
    axes[2].imshow(crop_path, cmap='jet')
    axes[2].set_title('Path Mask')
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()
else:
    print(f"⚠️  Model file not found: {weight_path}")
    print("   Please run training first")


## Download Trained Weights

Download the trained model weights to your local machine:


In [None]:
# Download weights
import os
from google.colab import files

weight_path = f"/content/DSA-RL-Tracker/StopModule/weights/{OUTPUT_NAME}.pth"

if os.path.exists(weight_path):
    print(f"📥 Downloading {OUTPUT_NAME}.pth...")
    files.download(weight_path)
    print("✅ Download requested; check your browser's downloads")
else:
    print(f"⚠️  File not found: {weight_path}")
    print("   Please check the training output above for the correct path")


## Next Steps

After training, you can:

1. **Use the model in FineTune module**: Load these weights to replace RL-trained stop heads
2. **Integrate into RL training**: Use as a pretrained stop detector
3. **Test on real DSA images**: Use the model for inference on actual vessel images

The model file (`stop_detector_v1.pth`) contains the trained weights and can be loaded with:
```python
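import torch
from StopModule.src.models import StandaloneStopDetector

model = StandaloneStopDetector()
model.load_state_dict(torch.load("stop_detector_v1.pth", map_location="cpu"))
model.eval()  # switch layers such as dropout/batch norm to inference behavior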
```
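
When the detector is used inside a tracker, one hypothetical way to turn per-step probabilities into a stop decision is to stop at the first step whose probability exceeds a threshold, optionally requiring several consecutive detections to suppress isolated false alarms. `first_stop_index` and its `patience` parameter are illustrative assumptions, not part of the repo.

```python
from typing import Optional, Sequence

def first_stop_index(probs: Sequence[float], threshold: float = 0.5,
                     patience: int = 1) -> Optional[int]:
    """Return the step at which `patience` consecutive probabilities exceed
    `threshold`, or None if the detector never fires."""
    run = 0  # length of the current streak of above-threshold steps
    for step, p in enumerate(probs):
        run = run + 1 if p > threshold else 0
        if run >= patience:
            return step
    return None

# Toy per-step STOP probabilities along a tracked path
probs = [0.1, 0.6, 0.2, 0.7, 0.8]
print(first_stop_index(probs))              # 1
print(first_stop_index(probs, patience=2))  # 4
```

With `patience=1` this reduces to the plain `prob > 0.5` rule used in the test cell.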
