# Imports

In [1]:
# # If Google Colab
# !git clone https://github.com/Samthesimpsons/CS701-Group-09-Project.git
# !pip3 install -r /content/CS701-Group-09-Project/requirements.txt
# !rm -rf /content/CS701-Group-09-Project/data
# !rm -rf /content/sample_data
# !unzip /content/CS701-Group-09-Project/data.zip -d /content/CS701-Group-09-Project/

In [2]:
import os

# Switch into the project root so that the relative paths used by later cells
# (e.g. "data/...", "models/") resolve from there.
# The location can be overridden without editing the notebook by setting the
# CS701_PROJECT_ROOT environment variable, e.g. to
# "/content/CS701-Group-09-Project" on Google Colab. When the variable is not
# set, the previously hard-coded local path is used, so existing behaviour is
# unchanged.
PROJECT_ROOT = os.environ.get(
    "CS701_PROJECT_ROOT",
    "C:\\Users\\samue\\OneDrive\\Desktop\\CS701-Group-09-Project",
)
os.chdir(PROJECT_ROOT)

In [3]:
import cv2
import shutil
import numpy as np

# `google.colab` is only importable inside a Colab runtime. Guard the import so
# this cell does not abort with ModuleNotFoundError when the notebook is run
# anywhere else; `files` is None in that case.
try:
    from google.colab import files
except ImportError:
    files = None
from src.visualization import (
    process_training_ct_scan_metadata,
    process_test_ct_scan_metadata,
    visualize_segmentation_from_numpy_arrays,
    generate_sweetviz_report,
)
from src.loader import SAMSegmentationDataset, create_dataloader
from src.trainer import SAMTrainer
from src.inference import run_SAM_inference_and_save_masks
from src.utils import get_latest_model_path

In [None]:
import torch

# Report whether PyTorch can see a CUDA device; when it can, also print the
# name of device 0 and the CUDA version reported by torch.
if not torch.cuda.is_available():
    print("CUDA is not available.")
else:
    print("CUDA is available.")
    device_name = torch.cuda.get_device_name(0)
    print(f"Device name: {device_name}")
    cuda_version = torch.version.cuda
    print(f"CUDA version: {cuda_version}")

# Exploratory Data Analysis

In [5]:
# train_data = process_training_ct_scan_metadata(
#     train_images_directory="data/train_images/",
#     train_labels_directory="data/train_labels/",
#     spacing_file_path="data/metadata/spacing_mm.txt",
# )

# train_data.head(5)

In [6]:
# test_data = process_test_ct_scan_metadata(
#     test_images_directory="data/test_images/",
#     spacing_file_path="data/metadata/spacing_mm.txt",
# )

# test_data.head(5)

In [None]:
# Load one training slice and its label mask as single-channel images and
# display them together.
train_image_path = "data/train_images/33/15.png"
train_mask_path = "data/train_labels/33/15.png"

image = cv2.imread(train_image_path, cv2.IMREAD_GRAYSCALE)
mask = cv2.imread(train_mask_path, cv2.IMREAD_GRAYSCALE)

# cv2.imread does not raise on a missing or unreadable file; it returns None.
# Fail here with a clear message instead of passing None to the visualiser.
if image is None:
    raise FileNotFoundError(f"Could not read image: {train_image_path}")
if mask is None:
    raise FileNotFoundError(f"Could not read mask: {train_mask_path}")

visualize_segmentation_from_numpy_arrays(image, mask)

In [8]:
# generate_sweetviz_report(
#     train_data, report_filename="results/EDA/train_data_EDA_report.html"
# )

# generate_sweetviz_report(
#     test_data, report_filename="results/EDA/test_data_EDA_report.html"
# )

# Modeling

In [None]:
pretrained_model_name = "wanglab/medsam-vit-base"


def print_record_summary(record):
    """Print every field of a dataset record or batch, one per line.

    Args:
        record: A mapping of field name to value (anything with ``.items()``).

    A value that exposes a ``shape`` attribute is summarised by that shape;
    any other value is printed as-is. Using ``getattr`` with a default replaces
    the previous bare ``except:``, which would also have hidden unrelated
    errors (including KeyboardInterrupt).
    """
    for key, value in record.items():
        print(f"{key}: {getattr(value, 'shape', value)}")


# Training split: images plus label masks.
train_dataset = SAMSegmentationDataset(
    image_dir="data/train_images",
    mask_dir="data/train_labels",
    spacing_metadata_dir="data/metadata/spacing_mm.txt",
    processor=pretrained_model_name,
)

print(f"Number of records: {len(train_dataset)}")
print("Example of one record:")
# Index 33 is an arbitrary sample used only for inspection.
print_record_summary(train_dataset[33])

# Test split: no masks are passed; a bounding-box file is supplied instead.
test_dataset = SAMSegmentationDataset(
    image_dir="data/test_images",
    bbox_file_dir="data/metadata/test1_bbox.txt",
    spacing_metadata_dir="data/metadata/spacing_mm.txt",
    processor=pretrained_model_name,
)

print("\n====================\n")
print(f"Number of records: {len(test_dataset)}")
print("Example of one record:")
print_record_summary(test_dataset[33])

train_dataloader = create_dataloader(
    train_dataset,
    batch_size=3,
    shuffle=True,
    num_workers=2,
)

# Pull a single batch to sanity-check what the dataloader yields.
batch = next(iter(train_dataloader))

print("\n====================\n")
print("Example of one batch:")
# Uses the same helper as above, so a batch entry without a `.shape`
# attribute is printed as-is instead of raising AttributeError.
print_record_summary(batch)

In [None]:
# Instantiate the trainer for the checkpoint named in `pretrained_model_name`,
# configured for the CPU with a learning rate of 1e-5.
trainer = SAMTrainer(
    learning_rate=1e-5,
    device="cpu",
    model_name=pretrained_model_name,
)

In [12]:
# Start the cross-validated training run over `train_dataloader`
# (k_folds=10, num_epochs=200).
# NOTE(review): the method is invoked on `trainer.model`, not on `trainer`
# itself - confirm that this is the intended receiver.
trainer.model.k_fold_cross_validation(
    num_epochs=200,
    k_folds=10,
    dataloader=train_dataloader,
)

In [None]:
latest_model_path = get_latest_model_path('models/')

# Split into (parent directory, final component) with os.path instead of
# str.split("/"). The previous two-name unpacking raised ValueError whenever
# the path used Windows separators, ended with a separator, or was nested more
# than one level deep. normpath also strips a trailing separator so the
# archive name below stays well-formed.
normalized_model_path = os.path.normpath(latest_model_path)
root_model_path, base_model_path = os.path.split(normalized_model_path)

# Writes "<normalized_model_path>.zip"; entries inside the archive are stored
# relative to the parent directory, i.e. they start with `base_model_path`.
shutil.make_archive(
    normalized_model_path,
    "zip",
    # An empty parent means the path had no directory part: archive from the
    # current working directory.
    root_dir=root_model_path or os.curdir,
    base_dir=base_model_path,
)

In [None]:
# Run inference with the model held by the trainer over `test_dataset`,
# using a batch size of 1 on the CPU, and keep whatever the helper returns.
results = run_SAM_inference_and_save_masks(
    device="cpu",
    batch_size=1,
    test_dataset=test_dataset,
    model=trainer.model,
)

In [None]:
# Load one test slice and the mask stored for it under data/test_labels
# (presumably written by the inference cell above - TODO confirm), then
# display them together with a set of bounding boxes.
test_image_path = "data/test_images/51/25.png"
predicted_mask_path = "data/test_labels/51/25.png"

image = cv2.imread(test_image_path, cv2.IMREAD_GRAYSCALE)
mask = cv2.imread(predicted_mask_path, cv2.IMREAD_GRAYSCALE)

# cv2.imread does not raise on a missing or unreadable file; it returns None.
# Fail here with a clear message instead of passing None to the visualiser.
if image is None:
    raise FileNotFoundError(f"Could not read image: {test_image_path}")
if mask is None:
    raise FileNotFoundError(f"Could not read mask: {predicted_mask_path}")

visualize_segmentation_from_numpy_arrays(
    image,
    mask,
    # Hard-coded boxes for this slice.
    # NOTE(review): presumably [x_min, y_min, x_max, y_max] in pixels, taken
    # from data/metadata/test1_bbox.txt - confirm against that file.
    [
        [299, 129, 444, 369],
        [308, 141, 362, 214],
        [110, 204, 332, 313],
        [219, 182, 283, 238],
        [263, 214, 309, 249],
        [183, 205, 230, 247],
        [192, 308, 248, 360],
        [116, 146, 217, 240],
    ],
    from_inference=True,
)

In [None]:
# Zip the `test_labels` directory found under `data`; the archive is written
# to "data/test_labels.zip".
shutil.make_archive(
    base_name="data/test_labels",
    format="zip",
    base_dir="test_labels",
    root_dir="data",
)