# finestSAM - Train

<a target="_blank" href="https://colab.research.google.com/github/WholeNow/MSSegSAM/blob/main/notebooks/train.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

### Checkpoint Information
The checkpoint downloaded by default is the original **Meta SAM (ViT-B)** model.

> **To use a different checkpoint:**
> 1. Consult the [README](../finestSAM/sav/README.md) in the `finestSAM/sav` directory for download links.
> 2. Update the `finestSAM/config.py` file:
>    - Set `cfg.model.type` to the matching model type (e.g., `'vit_h'`, `'vit_l'`, or `'vit_b'`).
>    - Set `cfg.model.checkpoint` to the filename of your new checkpoint.

In [None]:
import os

# Environment detection
try:
    import google.colab
    IS_COLAB = True
except ImportError:
    IS_COLAB = False

if IS_COLAB:
    BASE_PATH = "/content"
    
    # Install dependencies
    %pip install -q Lightning segmentation_models_pytorch wget

    # Clone repository
    repo_name = "MSSegSAM"
    repo_url = "https://github.com/WholeNow/MSSegSAM.git"
    repo_path = os.path.join(BASE_PATH, repo_name)

    if not os.path.exists(repo_path):
        !git clone {repo_url}

    BASE_PATH = os.path.join(BASE_PATH, repo_name)
else:
    BASE_PATH = os.getcwd()

import wget

# Checkpoint verification: make sure the default SAM ViT-B weights are on
# disk, downloading them into finestSAM/sav only when they are absent.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
CHECKPOINT_DIR = os.path.join(BASE_PATH, "finestSAM", "sav")
# The local file name is the last component of the download URL.
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, os.path.basename(CHECKPOINT_URL))

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

checkpoint_present = os.path.exists(CHECKPOINT_PATH)
if not checkpoint_present:
    wget.download(CHECKPOINT_URL, out=CHECKPOINT_DIR)

print(f"Environment initialized in: {BASE_PATH}")

# Path definitions: every data folder is rooted at <BASE_PATH>/prep_data.
DATA_PREP_DIR = os.path.join(BASE_PATH, "prep_data")
RAW_DATA_DIR, COCO_DIR = (
    os.path.join(DATA_PREP_DIR, subfolder)
    for subfolder in ("Datasets_raw", "Datasets_COCO")
)

# One print call, one line per path; the output matches three separate prints.
print(
    f"Data Prep Directory: {DATA_PREP_DIR}",
    f"Raw Data Directory:  {RAW_DATA_DIR}",
    f"Output Directory:    {COCO_DIR}",
    sep="\n",
)

In [None]:
# Optional: uncomment this cell to build the COCO dataset from the raw data before training.
# # Configuration for Dataset Creation
# SCRIPT_PATH = os.path.join(DATA_PREP_DIR, "create_dataset.py")

# INPUT_PATH = os.path.join(RAW_DATA_DIR, "DatasetProcessed")
# DATASET_PATH = os.path.join(COCO_DIR, "dataset") 

# DATASET_IDS = "all"
# MODALITY = "T1"
# SLICE_RANGE = "all"
# REMOVE_EMPTY_SLICES = True

# # Define dataset path for training
# TRAIN_DATASET_PATH = DATASET_PATH

# if os.path.exists(TRAIN_DATASET_PATH):
#     print(f"Dataset already exists at {TRAIN_DATASET_PATH}, creation can be skipped if not needed.")
# else:
#     if os.path.exists(SCRIPT_PATH):
#         flag_empty = "--remove_empty" if REMOVE_EMPTY_SLICES else ""

#         !python "{SCRIPT_PATH}" \
#             --input_dir "{INPUT_PATH}" \
#             --output_dir "{TRAIN_DATASET_PATH}" \
#             --dataset_ids {DATASET_IDS} \
#             --modality {MODALITY} \
#             --slice_range {SLICE_RANGE} \
#             {flag_empty}
#     else:
#         print(f"Error: create_dataset.py not found in {DATA_PREP_DIR}")

In [None]:
# Dataset folder (inside the COCO output directory) that is passed to training.
DATASET_NAME = "dataset"
DATASET_PATH = os.path.join(COCO_DIR, DATASET_NAME)

In [None]:
# Switch to model directory
os.chdir(BASE_PATH)

print(f"Starting training...")
print(f"Target Dataset: {DATASET_PATH}")

# Run Train
%run finestSAM/__main__.py --mode train --dataset "{DATASET_PATH}"

# Return to base directory
os.chdir(BASE_PATH)