In [None]:
!pip install segmentation-models-pytorch monai pytorch-lightning

In [None]:
import sys
import os
import glob 
import pytorch_lightning as pl
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import segmentation_models_pytorch as smp
from sklearn.model_selection import train_test_split
from google.colab import drive

In [None]:
# Mount Google Drive so the dataset and output folders under /content/drive are reachable.
drive.mount(mountpoint="/content/drive")

In [None]:
# --- Paths: fill these in before running -------------------------------------------
# Folder holding the project modules imported below (module, datamodule, callbacks,
# logger, utils); it is added to sys.path in the next cell.
PROJECT_PATH = ""
# Passed to get_callbacks(dirpath=...) as the checkpoint directory.
CHECKPOINT_PATH = ""
# Passed to get_logger(log_dir=...).
LOG_PATH = ""
# NOTE(review): not used in any of the visible cells — TODO confirm it is still needed.
SAVE_PATH = ""

# Index files handed to DataModule; TRAIN_INDEX_PATH is also the save_path given to
# compute_class_weight_and_get_mask.
TRAIN_INDEX_PATH = ""
VAL_INDEX_PATH = ""

In [None]:
# Make the project's helper modules (module, datamodule, callbacks, logger, utils)
# importable. Insert at the front rather than append: these are generic top-level
# names, and an installed package with the same name would otherwise shadow them.
if PROJECT_PATH not in sys.path:
    sys.path.insert(0, PROJECT_PATH)

In [None]:
from module import SegmentationModule
from datamodule import DataModule
from callbacks import get_callbacks
from logger import get_logger
from utils import compute_class_weight_and_get_mask, set_seed, ComboLoss

In [None]:
# Seed the RNGs through the project's utils.set_seed helper (implementation not shown here;
# presumably seeds Python/NumPy/torch). The data split below passes its own random_state=42.
set_seed(42)

In [None]:
# Root folder searched (recursively) for .h5 files; the trailing "" keeps the final slash.
ds_path = os.path.join("/content/drive", "MyDrive", "")

In [None]:
# Collect every .h5 file under ds_path at any depth. glob's ordering depends on the
# filesystem, so sort it; otherwise train_test_split(random_state=42) can still produce
# a different train/val split on another run or machine.
# NOTE(review): if ds_path is the whole Drive root, any unrelated .h5 file is picked up too.
all_files = sorted(glob.glob(os.path.join(ds_path, "**", "*.h5"), recursive=True))
if not all_files:
    raise FileNotFoundError(f"No .h5 files found under {ds_path!r}")

In [None]:
# Hold out 20% of the files for validation. The split is per file, and the fixed
# random_state keeps it repeatable for a given file order.
train_path, val_path = train_test_split(
    all_files,
    random_state=42,
    test_size=0.2,
)

In [None]:
# Lightning DataModule over the train/val file lists.
# NOTE(review): TRAIN_INDEX_PATH is only written (as save_path of
# compute_class_weight_and_get_mask) in a later cell, so this presumably relies on
# DataModule reading the index lazily (e.g. in setup()) — TODO confirm.
data_module = DataModule(
    train_data=train_path,
    val_data=val_path,
    index_train_path=TRAIN_INDEX_PATH,
    index_val_path=VAL_INDEX_PATH,
    batch_size=64,
    num_workers=2,
)

In [None]:
# U-Net with an ImageNet-pretrained EfficientNet-B2 encoder: 3 input channels and
# 4 output channels (one per class in `remap` further down).
backbone = smp.Unet(
    encoder_name="efficientnet-b2",
    encoder_weights="imagenet",
    classes=4,
    in_channels=3,
)

In [None]:
# Warm-up phase: freeze every encoder parameter so only the rest of the network trains.
# NOTE(review): this stops gradient updates only; if the encoder has BatchNorm layers,
# their running statistics still change while the model is in train mode — confirm intended.
# Trailing ';' suppresses the returned module's repr in the notebook output.
backbone.encoder.requires_grad_(False);

In [None]:
# Map the raw label values 0/50/100/150 (presumably mask grey levels) to class ids 0..3.
remap = {raw_value: class_id for class_id, raw_value in enumerate((0, 50, 100, 150))}

# Per-class weights computed from the training files. The helper also receives
# save_path=TRAIN_INDEX_PATH — presumably it writes the training index/mask file that
# DataModule(index_train_path=...) reads; TODO confirm.
class_weight = compute_class_weight_and_get_mask(
    file_path=train_path,
    remap=remap,
    save_path=TRAIN_INDEX_PATH,
)

# Number of classes is taken from remap (4 entries) so the two cannot drift apart.
criterion = ComboLoss(ce_weight=class_weight, num_classes=len(remap))

In [None]:
# Lightning wrapper around the U-Net; loss, optimizer and scheduler are injected via setters.
model = SegmentationModule(backbone=backbone, num_classes=4)
model.set_criterion(criterion)

# Optimize only parameters that still require gradients (the encoder was frozen above).
trainable_params = [param for param in backbone.parameters() if param.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=1e-3, weight_decay=1e-5)
# Multiply the LR by 0.2 after patience=2 scheduler steps without a decrease in the
# monitored metric (which metric is monitored is decided inside set_optimizer — not shown).
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=2)

current_lr = optimizer.param_groups[0]['lr']
print(f"Active lr: {current_lr}")

model.set_optimizer(optimizer, scheduler)

In [None]:
# Callbacks for the warm-up stage. The filename template embeds the epoch and the
# logged `val_loss` value; get_callbacks is project code — presumably it builds a
# ModelCheckpoint from these arguments (TODO confirm).
callbacks = get_callbacks(
    dirpath=CHECKPOINT_PATH,
    filename="best-warmup-{epoch:02d}-{val_loss:.2f}",
)

# Experiment logger writing under LOG_PATH with a run name for this warm-up stage.
logger = get_logger(
    log_dir=LOG_PATH,
    name="segmentation_multiclass_warmup",
)

In [None]:
# Single-GPU Lightning trainer with fp16 mixed-precision training.
trainer = pl.Trainer(
    max_epochs=100,
    accelerator="gpu",
    devices=1,
    precision="16-mixed",     # keyword is `precision`; a misspelling is rejected by Trainer
    callbacks=callbacks,
    logger=logger,
    log_every_n_steps=10,     # Trainer's keyword is plural: log_every_n_steps
)

In [None]:
# Run the warm-up training; pass the DataModule by keyword so its role is explicit.
trainer.fit(model=model, datamodule=data_module)