In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [3]:
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm

from objprint import objstr

from matplotlib import pyplot as plt

In [4]:
import torch
import monai
from monai.utils import ensure_tuple_rep
from accelerate import Accelerator
from timm.optim import optim_factory

In [5]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [5]:
torch.cuda.set_device("cuda:1")

In [6]:
for i in range(torch.cuda.device_count()):
    print(f"{i}:", torch.cuda.get_device_properties(i).name)

0: Tesla P100-PCIE-16GB
1: Tesla P100-PCIE-16GB


In [7]:
from src import utils
from src.loader import get_dataloader
from src.optimizer import LinearWarmupCosineAnnealingLR
from src.SlimUNETR.SlimUNETR import SlimUNETR
from src.utils import Logger, same_seeds, load_config

In [8]:
config, data_flag, is_HepaticVessel = load_config()
config.trainer.batch_size = 8
data_flag

'tbad_dataset'

In [11]:
same_seeds(config.trainer.seed)
logging_dir = Path(os.getcwd()) / "logs" / str(datetime.now()).replace(":","_")
accelerator = Accelerator(
    cpu=False, log_with=["tensorboard"], project_dir=str(logging_dir)
)
Logger(logging_dir if accelerator.is_local_main_process else None)
accelerator.init_trackers('main')
accelerator.print(objstr(config))

accelerator.print("Load Model...")
model = SlimUNETR(**config.slim_unetr)
model.to(accelerator.device)
image_size = config.trainer.image_size

accelerator.print("Load Dataloader...")
train_loader, val_loader, unlab_loader = get_dataloader(config, data_flag)

In [12]:
inference = monai.inferers.SlidingWindowInferer(
    roi_size=ensure_tuple_rep(image_size, 3),
    overlap=0.5,
    sw_device=accelerator.device,
    device=accelerator.device,
)
metrics = {
    "dice_metric": monai.metrics.DiceMetric(
        include_background=True,
        reduction=monai.utils.MetricReduction.MEAN_BATCH,
        get_not_nans=False,
    ),
    # 'hd95_metric': monai.metrics.HausdorffDistanceMetric(percentile=95, include_background=True, reduction=monai.utils.MetricReduction.MEAN_BATCH, get_not_nans=False)
}
post_trans = monai.transforms.Compose(
    [
        # monai.transforms.Activations(sigmoid=True),
        monai.transforms.AsDiscrete(threshold=0.5),
        monai.transforms.AsDiscrete(threshold=0.5),
    ]
)

In [13]:
optimizer = optim_factory.create_optimizer_v2(
    model,
    opt=config.trainer.optimizer,
    weight_decay=config.trainer.weight_decay,
    lr=config.trainer.lr,
    betas=(0.9, 0.95),
)
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer,
    warmup_epochs=config.trainer.warmup,
    max_epochs=config.trainer.num_epochs,
)
loss_functions = {
    "focal_loss": monai.losses.FocalLoss(to_onehot_y=False),
    "dice_loss": monai.losses.DiceLoss(
        smooth_nr=0, smooth_dr=1e-5, to_onehot_y=False, sigmoid=True
    ),
}


In [14]:
  
base_exp_path = f"{os.getcwd()}/model_store/{config.finetune.checkpoint}/seed{config.trainer.seed}/epoch{config.trainer.num_epochs}/"
base_exp_path += f"ims_{config.trainer.image_size}_rot_prob{config.trainer.rot_prob}_lrelu_split_new_class_GD_FL_g{config.trainer.gamma}"

model, starting_epoch, step, val_step = utils.resume_train_state(
    model, base_exp_path, train_loader, accelerator
)
print("Resuming training from epoch {}".format(starting_epoch))


In [17]:
import torch
from torch import nn


class ModelEmaV2(nn.Module):
    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = model
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)



class Trainer:

    def __init__(
        self,
        model,
        train_loader,
        unlab_loader,
        optimizer,
        accelerator,
        loss_functions,
        post_trans,
        unlab_weight,
    ):

        self.model = model
        self.train_loader = train_loader
        self.unlab_loader = unlab_loader

        self.optimizer = optimizer
        self.accelerator = accelerator
        self.loss_functions = loss_functions
        self.post_trans = post_trans
        self.unlab_weight = unlab_weight
        
        self.ema = ModelEmaV2(self.model, device=self.accelerator.device)
        

    def train_labeled_one_epoch(
        self,
        model: torch.nn.Module,
        num_epochs: int,
        epoch: int,
        step: int,
        use_transform: bool
    ):
        device = next(model.parameters()).device
        model.train()
        logits_list = []
        model.train()
        for i, image_batch in enumerate(self.train_loader):
            logits = model(image_batch["image"])
            total_loss, _ = calc_total_loss(logits, image_batch["label"], loss_functions)

            accelerator.log(values={"Train/Total Loss": float(total_loss)}, step=step)
            accelerator.print(
                f"Epoch [{epoch + 1}/{config.trainer.num_epochs}] Training [{i + 1}/{len(train_loader)}] Loss: {total_loss:1.5f}",
                flush=True,
            )
            step += 1

            accelerator.backward(total_loss)
            optimizer.step()
            optimizer.zero_grad()

            val_outputs = [post_trans(i) for i in logits]
            for metric_name in metrics:
                metrics[metric_name](y_pred=val_outputs, y=image_batch["label"])

            unlabeled_loss = self.calculate_loss(logits_list)

            self.accelerator.log(values={"Train/Unlabeled Loss": float(unlabeled_loss)}, step=step)
            self.accelerator.print(
                f"Epoch [{epoch + 1}/{num_epochs}] Training [{i + 1}/{len(unlab_loader)}] Unlab Loss: {unlabeled_loss:1.5f}",
                flush=True,
            )
            step += 1

            self.accelerator.backward(unlabeled_loss)
            self.optimizer.step()
            self.optimizer.zero_grad()

            return step



In [18]:
# from src.unlab.transforms import Transform
from monai.transforms.spatial import array
from typing import Callable, List
import numpy as np
import torch.nn.functional as F

In [19]:
step = 0
for epoch in tqdm(range(starting_epoch, config.trainer.num_epochs)):
    trainer.train_one_epoch(
        model, unlab_loader, config.trainer.num_epochs, epoch, step 
    )

  0%|          | 0/800 [00:00<?, ?it/s]

NameError: name 'trainer' is not defined

In [None]:
device = next(model.parameters()).device
model.train()
step = 0
for i, image_batch in enumerate(tqdm(unlab_loader)):
    img = trainer.center_crop_fc(image_batch[0]["image"].to(device))
    logits = model(img)
    val_outputs = [post_trans(logit) for logit in logits]
    break

label = trainer.center_crop_fc(image_batch[0]["label"].to(device))

In [None]:
pat_ind = 1
slice_num = label.shape[-1]

for i in list(range(slice_num))[::2]:
    fig, ax = plt.subplots(1, 3, figsize=(10, 30))
    im = img[pat_ind][0, ..., i].cpu().numpy().copy()
    lab = label[pat_ind][0, ..., i].cpu().numpy().copy()
    pred = val_outputs[pat_ind][0, ..., i].cpu().detach().numpy().copy()

    ax[0].imshow(im, cmap="gray")
    ax[1].imshow(lab, cmap="gray")
    ax[2].imshow(pred, cmap="gray")
    plt.show()