# ScarNet Tutorial
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NedaTavakoli/ScarNet/blob/main/examples/ScarNet_Tutorial.ipynb)

This notebook demonstrates how to use ScarNet for cardiac scar segmentation.


## Setup
First, install the required packages, clone the repository, and install its requirements.

In [None]:
# Install dependencies.
# These lines use IPython `!` shell escapes, so this cell only runs inside Jupyter/Colab,
# not as a plain Python script.
!pip install torch torchvision h5py matplotlib tqdm scikit-learn

# Clone the repository into ./ScarNet; later cells read data and weights from paths
# under 'ScarNet/'.
# NOTE(review): `git clone` fails if ./ScarNet already exists (e.g. when re-running this cell).
!git clone https://github.com/NedaTavakoli/ScarNet.git
# Install the repository's pinned requirements. Each `!` line runs in its own shell, so the
# `cd` only affects this one command.
# NOTE(review): this installs dependencies only, not the `scarnet` package itself — confirm
# that the `from scarnet...` imports in the next cell resolve (e.g. `pip install -e ScarNet`
# or adding ScarNet to sys.path).
!cd ScarNet && pip install -r requirements.txt

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import h5py
from tqdm.notebook import tqdm

from scarnet.models.scarnet import ScarNet
from scarnet.data.dataset import CardiacDataset
from scarnet.utils.visualization import Visualizer
from scarnet.config import Config

## Configuration

In [None]:
# Build the default ScarNet configuration object.
config = Config()

# Pick the compute device: CUDA when PyTorch reports it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## Load Data
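Build a small demonstration dataset from the first five magnitude images (`Mag_image/*.h5`) under `ScarNet/data` and their matching `4layer_mask` files.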

In [None]:
# For demonstration, we'll use a small subset of data: the first few image files found.
data_path = Path('ScarNet/data')
num_samples = 5

# Glob once so the image list and the derived mask list are guaranteed to line up.
x_files = sorted(data_path.glob('**/Mag_image/*.h5'))[:num_samples]
if not x_files:
    # Fail here with a clear message instead of building an empty dataset that would only
    # break later (e.g. a division by zero when averaging the training loss).
    raise FileNotFoundError(
        f'No Mag_image/*.h5 files found under {data_path.resolve()}; '
        'check that the repository and its data were cloned.'
    )
# Each mask path is the image path with every occurrence of 'Mag_image' replaced by
# '4layer_mask' (str.replace rewrites all occurrences, in directories and file name alike).
y_files = [Path(str(x).replace('Mag_image', '4layer_mask')) for x in x_files]

# Create dataset
dataset = CardiacDataset(
    x_files=x_files,
    y_files=y_files,
    imsize=128,
    augment=True,
    config=config,
)

# Create dataloader: one sample per batch, in file order.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
)

## Initialize Model

In [None]:
# Initialize a 4-class ScarNet; pretrained_path is presumably the MedSAM ViT-B backbone
# checkpoint (per the file name) — TODO confirm against ScarNet's constructor.
model = ScarNet(
    pretrained_path='ScarNet/weights/medsam_vit_b.pth',
    num_classes=4
).to(device)

# Load pretrained ScarNet weights if available. The file may be either a training checkpoint
# dict with a 'model_state_dict' entry or a bare state dict.
checkpoint_path = 'ScarNet/weights/scarnet_best.pth'
if Path(checkpoint_path).exists():
    # NOTE: torch.load unpickles the file (unless weights_only loading is in effect), so only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    print('Loaded pretrained weights')
else:
    # Make it visible that predictions come from the model as constructed above.
    print(f'No checkpoint found at {checkpoint_path}; using the model as constructed')

## Training Example

In [None]:
def _batch_loss(model, batch, criterion, device):
    """Run ``model`` on one batch and return the (differentiable) loss tensor.

    ``batch`` must be a mapping with ``'image'`` and ``'mask'`` tensors.
    """
    image = batch['image'].to(device)
    mask = batch['mask'].to(device)
    output = model(image)
    # squeeze(1) drops a singleton channel dimension from the mask when present
    # (assumed (B, 1, H, W) masks — TODO confirm against CardiacDataset).
    return criterion(output, mask.squeeze(1))


def _mean_val_loss(model, val_loader, criterion, device):
    """Return the mean loss of ``model`` over ``val_loader`` without updating weights."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch in val_loader:
            total += _batch_loss(model, batch, criterion, device).item()
    return total / len(val_loader)


def train_model(model, train_loader, val_loader=None, epochs=10, lr=1e-3):
    """Train ``model`` with Adam and cross-entropy loss, printing per-epoch losses.

    Args:
        model: Network whose raw outputs are fed to ``torch.nn.CrossEntropyLoss``.
        train_loader: Sized iterable of batches; each batch is a mapping with
            ``'image'`` and ``'mask'`` tensors.
        val_loader: Optional loader with the same batch format. When given, the
            mean validation loss is computed and printed after every epoch.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        dict with per-epoch mean losses under ``'train_loss'`` and ``'val_loss'``
        (the latter stays empty when no ``val_loader`` is given).

    Raises:
        ValueError: if ``train_loader`` or a given ``val_loader`` is empty.
    """
    # Validate up front: an empty loader would otherwise fail with a division by zero
    # only after a full (no-op) epoch.
    if len(train_loader) == 0:
        raise ValueError('train_loader is empty; check the dataset file paths')
    if val_loader is not None and len(val_loader) == 0:
        raise ValueError('val_loader is empty')

    # Train on whatever device the model's parameters already live on, rather than
    # relying on a notebook-global ``device``.
    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0

        for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
            loss = _batch_loss(model, batch, criterion, device)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)
        print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}')

        # Validation (only when a validation loader was supplied)
        if val_loader is not None:
            avg_val_loss = _mean_val_loss(model, val_loader, criterion, device)
            history['val_loss'].append(avg_val_loss)
            print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}')

    return history

# Uncomment to train (pass val_loader=... to also track validation loss)
# train_model(model, loader)