In [None]:
!pip install torchmetrics
!pip install ptflops thop

In [None]:
from google.colab import drive

# Mount Google Drive so the project folder under MyDrive is reachable.
# force_remount=True mounts again even if Drive is already mounted.
GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT, force_remount=True)

In [None]:
cd /content/gdrive/MyDrive/Conj_Vessel_Extraction/

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tqdm import tqdm
import numpy as np
import random

from src.data.get_loaders import get_loaders
from src.models import DilTransAttUNet
from src.utils.Metrics import Metrics
from src.utils.model_analysis import analyse_model
from src.utils.losses import DiceBCELoss
from src.utils.config import DTYPE, get_device
from src.utils.config import ACCURACY, AUPRC, AUROC_, DICE_SCORE, F1_SCORE, JACCARD_INDEX, PRECISION, RECALL, SPECIFICITY
from src.inference.patch_inference import run_patchwise_test

from src.solver import Solver

In [None]:
# Select the compute device and report how many CUDA GPUs are visible.
device = get_device()
gpu_count = torch.cuda.device_count()
device, gpu_count

In [None]:
# Seed the RNGs *before* building the loaders. The training cell seeds only
# afterwards, so without this the train/val split would differ between runs
# (assuming get_loaders splits randomly — TODO confirm).
SPLIT_SEED = 42
random.seed(SPLIT_SEED)
np.random.seed(SPLIT_SEED)
torch.manual_seed(SPLIT_SEED)

TRAIN_DIR = "/content/gdrive/MyDrive/Conj_Vessel_Extraction/data/train"  # your train/val source folder

train_loader, val_loader = get_loaders(
    TRAIN_DIR,
    train_ratio=0.8,  # presumably the fraction of samples used for training
    batch_size=4,
    size=(512, 512),
    num_workers=2,
)

assert len(train_loader.dataset) > 0 and len(val_loader.dataset) > 0, "empty train/val split — check TRAIN_DIR"

print(f'Train samples: {len(train_loader.dataset)}')
print(f'Val samples: {len(val_loader.dataset)}')

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# --- Configuration: single source of truth for this run ----------------------
# These are the values the model was actually trained with. The previous
# version declared other values (epochs=2, a 5-level channel tuple,
# ReduceLROnPlateau settings) that were silently overridden or never used.
SEED = 42
MODEL_NAME = 'DilTransAttUNet'
EPOCHS = 1                                  # short smoke run; raise for real training
CHANNELS = [3, 32, 64, 128, 256, 512, 1024]  # first entry = input image channels
RESIDUAL = True
HEADS = 4
BIAS = True
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5
COSINE_T_MAX = 200   # NOTE(review): counted in scheduler.step() calls — confirm whether Solver steps per epoch or per batch
COSINE_ETA_MIN = 1e-6
SAVE_DIR = "../saved_models"
# Derived from the config so the checkpoint label can't drift from the model.
RUN_NAME = f"res{RESIDUAL}_head{HEADS}_ch{CHANNELS[-1]}"

# --- Reproducibility: seed before the model is built (weight init) ------------
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# --- Model, loss, optimiser, schedule ----------------------------------------
model = DilTransAttUNet(CHANNELS, HEADS, RESIDUAL, BIAS).to(device)

criterion = DiceBCELoss(device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=COSINE_T_MAX, eta_min=COSINE_ETA_MIN)

# --- Training ------------------------------------------------------------------
solver = Solver(
    model=model,
    epochs=EPOCHS,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    model_name=MODEL_NAME,
    run_name=RUN_NAME,
    save_dir=SAVE_DIR,
    save_each_epoch=False,
)

result = solver.fit()

In [None]:
# Profile the trained model on a single 3x512x512 input and print each statistic;
# float values are shown with three decimals, everything else as-is.
stats = analyse_model(model, device, input_size=(1, 3, 512, 512))

for stat_name, stat_value in stats.items():
    line = f"{stat_name}: {stat_value:.3f}" if isinstance(stat_value, float) else f"{stat_name}: {stat_value}"
    print(line)

In [None]:
# Evaluate the trained model on the held-out test set with overlapping patches.
# Reuses the `device` the model was trained on rather than querying it again.
THRESHOLD = 0.5    # binarisation threshold, shared by Metrics and inference
PATCH_SIZE = 512   # same as the training input size
STRIDE = 482       # PATCH_SIZE - STRIDE = 30 → neighbouring patches overlap by 30 (presumably pixels)
TEST_DIR = "/content/gdrive/MyDrive/Conj_Vessel_Extraction/data/test"

# solver.fit() may leave the model in train mode; force eval mode so any
# dropout / batch-norm layers behave deterministically during testing.
model.eval()

metrics = Metrics(device=device, threshold=THRESHOLD)

results = run_patchwise_test(
    model = model,
    data_dir = TEST_DIR,  # single root dir
    device = device,
    metrics = metrics,
    # save_dir = "/.../preds",     # optional
    patch_size = PATCH_SIZE,
    stride = STRIDE,
    threshold = THRESHOLD,
)

print("\n===== PATCHWISE TEST RESULTS =====")
for name, value in results.items():
    try:
        formatted = f"{value:.4f}"
    except (TypeError, ValueError):
        # Show non-numeric entries as-is instead of crashing the report.
        formatted = str(value)
    print(f"{name:12}: {formatted}")