# Imports

In [1]:
import os
import cv2
from pathlib import Path

# Switch the working directory to the project root so that the `src.*`
# imports below and the relative "data/..." paths resolve.
# Guarded so the cell is idempotent: an unconditional chdir would climb one
# directory higher on every re-run (or when the kernel already starts in the
# root). The presence of a `src` directory is used as the "already at the
# root" marker because the imports that follow require it.
if not Path("src").is_dir():
    os.chdir(Path().resolve().parent)

from src.data_processing.loader import SAMDataset, create_dataloader
from src.modeling.models import SAMModel

# Prepare the data

In [None]:
# Build the training dataset from the image/mask folders plus the spacing
# metadata file; paths are relative to the current working directory.
train_dataset = SAMDataset(
    image_dir="data/train_images",
    mask_dir="data/train_labels",
    spacing_metadata_dir="data/metadata/spacing_mm.txt",
    processor="facebook/sam-vit-base",
)

# Sanity check: report the dataset size and the shape of every field in the
# first record. The size must be read from `train_dataset`; the bare name
# `dataset` is never defined in this notebook and raised NameError.
print(f"Number of records: {len(train_dataset)}")
print("Example of one record:")
# Assumes each record is a dict whose values expose `.shape` - TODO confirm.
for k, v in train_dataset[0].items():
    print(f"{k}: {v.shape}")

In [None]:
# Build the test dataset. Unlike the training dataset it is given a
# bounding-box file instead of a mask directory.
# NOTE(review): image_dir points at "data/train_images" - confirm whether a
# separate test image directory was intended here.
test_dataset = SAMDataset(
    image_dir="data/train_images",
    bbox_file_dir="data/metadata/test1bbox.txt",
    spacing_metadata_dir="data/metadata/spacing_mm.txt",
    processor="facebook/sam-vit-base",
)

# Sanity check: report the dataset size and the shape of every field in the
# first record. The size must be read from `test_dataset`; the bare name
# `dataset` is never defined in this notebook and raised NameError.
print(f"Number of records: {len(test_dataset)}")
print("Example of one record:")
# Assumes each record is a dict whose values expose `.shape` - TODO confirm.
for k, v in test_dataset[0].items():
    print(f"{k}: {v.shape}")

# Use dataloader

In [None]:
# Loader settings are collected in one mapping so they are easy to tweak,
# then unpacked as keyword arguments alongside the training dataset.
loader_options = {
    "batch_size": 36,
    "train_ratio": 0.8,
    "shuffle": True,
    "num_workers": 2,
}
train_dataloader = create_dataloader(train_dataset, **loader_options)

In [None]:
# Pull a single batch from the loader to inspect what it yields.
batch_iterator = iter(train_dataloader)
batch = next(batch_iterator)

# Print the shape of every field in that batch.
print("Example of one batch:")
for field_name, field_value in batch.items():
    print(f"{field_name}: {field_value.shape}")

# Train Model

In [2]:
# Constructor settings for the SAM wrapper, kept together for easy editing.
sam_settings = dict(
    model_name="facebook/sam-vit-base",  # noted by the author as the default
    device="cpu",  # runs on CPU as written; switch to "cuda" if a GPU is available
    learning_rate=1e-5,  # noted by the author as the default
    weight_decay=0,  # noted by the author as the default
)
model = SAMModel(**sam_settings)

# The cross-validation run is currently disabled; uncomment to train.
# model.k_fold_cross_validation(
#     dataloader=train_dataloader,  # SAM DataLoader object
#     k_folds=5,  # Default: 5 folds
#     num_epochs=10,  # Default: 10 epochs per fold
# )