In [None]:
# Architecture switch read by the later cells (data download, library check,
# model import). Fail right here if the name is not one of the known options.
ARCHITECTURE = "HYSTO_SEG"
assert ARCHITECTURE in ("HYSTO_SEG", "VITAE_V2", "VITAE_V2_OCR")

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the following cells read the dataset
# and model archives from /content/drive/MyDrive.
drive.mount("/content/drive")

In [None]:
# get data
# %%capture # uncomment to mute unzip
!unzip /content/drive/MyDrive/CRAG.zip
#!unzip /content/drive/MyDrive/valid_ori_crag.zip -d /content/valid_ori_crag/
if ARCHITECTURE == "VITAE_V2":
    !unzip /content/drive/MyDrive/vitae_v2_sem_seg.zip
elif ARCHITECTURE == "VITAE_V2_OCR":
    !unzip /content/drive/MyDrive/vitae_v2_sem_seg_ocr.zip
elif ARCHITECTURE == "HYSTO_SEG":
    !unzip /content/drive/MyDrive/histo_seg_pt.zip

In [None]:
import torch

# cuda and torch version
# torch.__version__ usually looks like "<version>+<build tag>" (e.g. "+cu118"),
# but the "+<build tag>" part is absent on some builds, so it must not be
# indexed unconditionally.
versions = torch.__version__
parts = versions.split("+")
version = parts[0].strip()
if len(parts) > 1:
    cuda_version = parts[1].strip()
elif torch.version.cuda is not None:
    # rebuild a tag in the same "cuXYZ" form from the CUDA version torch reports
    cuda_version = "cu" + torch.version.cuda.replace(".", "")
else:
    cuda_version = "cpu"
print(version, cuda_version)

In [None]:
# check libraries

missing_libraries = []
try:
    import albumentations as A
except ImportError:
    missing_libraries.append("albumentations")
try:
    from thop import profile
except ImportError:
    missing_libraries.append("thop")
if ARCHITECTURE == "VITAE_V2" or ARCHITECTURE == "VITAE_V2_OCR":
    try:
        from functools import partial
    except ImportError:
        missing_libraries.append("functools")
    try:
        from timm.models.layers import trunc_normal_

        # from vitae_v2_sem_seg.vitaev2 import ViTAEv2, CustomSegmentationHead
        from vitae_v2_sem_seg_ocr.vitaev2 import ViTAEv2_OCR
    except ImportError:
        missing_libraries.append("timm")
    try:
        from einops import rearrange
    except ImportError:
        missing_libraries.append("einops")
    try:
        from mmcv.cnn import ConvModule
    except ImportError:
        missing_libraries.append("mmcv-full")
elif ARCHITECTURE == "HISTO_SEG":
    pass

# install missing libraries if any
if missing_libraries:
    print("The following libraries are missing: " + ", ".join(missing_libraries))
    print("Installing missing libraries...")
    if "timm" in missing_libraries:
        !pip install timm
        import timm
        from timm.models.layers import trunc_normal_
    if "functools" in missing_libraries:
        !pip install functools
        from functools import partial
    if "albumentations" in missing_libraries:
        !pip install albumentations
        import albumentations as A
    if "einops" in missing_libraries:
        !pip install einops
        from einops import rearrange
    if "mmcv-full" in missing_libraries:
        !pip install openmim -U
        #!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cuda_vesion}/{version}/index.html
        !mim install mmcv
        from mmcv.cnn import ConvModule
    if "thop" in missing_libraries:
        !pip install thop
        from thop import profile
try:
    if ARCHITECTURE == "VITAE_V2":
        from vitae_v2_sem_seg.vitaev2 import ViTAEv2, CustomSegmentationHead
    elif ARCHITECTURE == "VITAE_V2_OCR":
        from vitae_v2_sem_seg_ocr.vitaev2 import ViTAEv2_OCR as seg_model
    elif ARCHITECTURE == "HYSTO_SEG":
        from histo_seg_pt.histo_seg_pytorch import HistoSeg_ as seg_model
except ImportError:
    raise Exception("Could not import model!")

In [None]:
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
import random
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

# Folders of the unzipped CRAG dataset, keyed "<split>_x" for the images and
# "<split>_y" for the annotation masks of each split.
paths = {
    f"{split}_{suffix}": f"/content/CRAG/{split}/{folder}"
    for split in ("train", "test")
    for suffix, folder in (("x", "Images"), ("y", "Annotation"))
}

""" test_original = {
    'test_x': '/content/valid_ori_crag/Images_ori',
    'test_y': '/content/valid_ori_crag/Annotation_ori'
} """

# data generator class following pytorch's Dataset class
class DataGen(torch.utils.data.Dataset):
    """Dataset that reads image/mask pairs from two folders with OpenCV.

    Args:
        paths: dict with the keys "train_x", "train_y", "test_x", "test_y",
            each mapping to a directory.
        train: use the "train_*" directories when truthy, else the "test_*" ones.
        info: print how many files were found in each directory.
        target_size: size passed to ``cv2.resize`` for both image and mask.
        task: "bin" (any non-zero mask pixel is foreground) or "inst"
            (every distinct mask value gets its own index).

    Raises:
        ValueError: if the image and mask directories hold a different number
            of files.

    Each item is ``(image, mask)``: the image is a float32 tensor permuted to
    channels-first and scaled by 1/255; the mask is the encoded mask with a
    leading channel dimension, divided by ``max_label + 1``.
    """

    def __init__(
        self, paths, train=True, info=True, target_size=(256, 256), task="bin"
    ):
        assert task in ["bin", "inst"], "Task must be one of ['bin', 'inst']"
        self.task = task
        self.paths = paths
        self.train = train
        self.info = info
        split = "train" if self.train else "test"
        self.x = self._get_image_paths(self.paths[f"{split}_x"])
        self.y = self._get_image_paths(self.paths[f"{split}_y"])
        if len(self.x) != len(self.y):
            # images and masks are paired by position, so the counts must agree
            raise ValueError(
                f"Found {len(self.x)} images but {len(self.y)} masks for the "
                f"'{split}' split"
            )
        self.len = len(self.x)
        self.target_size = target_size

    def _get_image_paths(self, directory):
        """Return the sorted paths of all .png/.jpg files in ``directory``."""
        # os.listdir returns entries in arbitrary order. Sorting makes the i-th
        # image and the i-th mask line up — this assumes an image and its
        # annotation sort to the same position (e.g. identical file names);
        # verify against the dataset.
        image_paths = [
            os.path.join(directory, filename)
            for filename in sorted(os.listdir(directory))
            if filename.endswith(".png") or filename.endswith(".jpg")
        ]
        if self.info:
            print(f"Found {len(image_paths)} images in {directory}")
        return image_paths

    def _encode_mask(self, mask, task="bin"):
        """Turn a raw mask array into an int64 label array.

        "bin": every non-zero pixel becomes 255 (``mask`` is modified in place).
        "inst": each distinct non-zero value is replaced by its position in the
        sorted list of unique values; 0 stays 0.
        """
        assert task in ["bin", "inst"], "task must be one of ['bin', 'inst']"
        if task == "bin":
            mask[mask > 0] = 255
            encoded_mask = mask
        elif task == "inst":
            unique_labels = np.unique(mask)
            encoded_mask = np.zeros_like(mask)
            for idx, label in enumerate(unique_labels):
                if label == 0:  # Skip the background
                    continue
                encoded_mask[mask == label] = idx
        return encoded_mask.astype(np.int64)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Load, resize and normalise the ``idx``-th image/mask pair."""
        image = cv2.imread(self.x[idx], cv2.IMREAD_COLOR)
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image: {self.x[idx]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, self.target_size)

        if self.task == "bin":
            mask = cv2.imread(self.y[idx], cv2.IMREAD_GRAYSCALE)
            # blended border pixels are harmless here: _encode_mask maps every
            # non-zero value to foreground
            interpolation = cv2.INTER_LINEAR
        else:
            # NOTE(review): IMREAD_COLOR yields a 3-channel mask, so the encoded
            # mask keeps a trailing channel axis — confirm this is intended.
            mask = cv2.imread(self.y[idx], cv2.IMREAD_COLOR)
            # label values must not be interpolated, or new "instances" appear
            interpolation = cv2.INTER_NEAREST
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {self.y[idx]}")
        mask = cv2.resize(mask, self.target_size, interpolation=interpolation)
        encoded_mask = self._encode_mask(mask, task=self.task)

        image = torch.tensor(image, dtype=torch.float32)
        mask = torch.tensor(
            encoded_mask, dtype=torch.int64
        )  # Create the mask as an integer tensor
        image = image.permute(
            2, 0, 1
        )  # Permute dimensions to (channels, height, width)
        mask = mask.unsqueeze(0)  # Add a channel dimension for the mask
        image = image / 255.0
        # Normalize the mask labels. For task "bin" the foreground value is 255,
        # so foreground pixels become 255/256 rather than exactly 1.0.
        mask = mask / (np.max(encoded_mask) + 1)
        return image, mask


# one dataset per split, both built from the same path table
train_dataset = DataGen(paths=paths, train=True)
test_dataset = DataGen(paths=paths, train=False)

# display test images and masks
def display_image_mask(image, mask):
    """Show ``image`` and ``mask`` side by side in one matplotlib figure.

    The mask is drawn with the "jet" colormap; axes are hidden on both panels.
    """
    fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(10, 5))

    image_ax.imshow(image)
    mask_ax.imshow(mask, cmap="jet")

    # identical cosmetics for both panels
    for panel, title in ((image_ax, "Image"), (mask_ax, "Mask")):
        panel.set_title(title)
        panel.axis("off")

    plt.tight_layout()
    plt.show()

# check data integrity: print tensor statistics and plot one random sample
# from the training split, then one from the test split
for dataset in (train_dataset, test_dataset):
    sample_index = random.randint(0, len(dataset) - 1)
    image, mask = dataset[sample_index]
    print(f"Image shape: {image.shape}, Min: {image.min()}, Max: {image.max()}")
    print(
        f"Mask shape: {mask.shape}, Min: {mask.min()}, Max: {mask.max()}, Classes: {len(np.unique(mask))}"
    )
    display_image_mask(image.permute(1, 2, 0).numpy(), mask.squeeze().numpy())


batch_size = 8

# DataLoader options per split: only the training loader shuffles its samples
gen_params_train = dict(batch_size=batch_size, shuffle=True, num_workers=0)
gen_params_test = dict(batch_size=batch_size, shuffle=False, num_workers=0)

# create data generators
train_generator = torch.utils.data.DataLoader(train_dataset, **gen_params_train)
test_generator = torch.utils.data.DataLoader(test_dataset, **gen_params_test)

In [None]:
import torch
from torch.optim.optimizer import Optimizer
import torch.nn as nn
import torch.nn.functional as F


class Lion(Optimizer):
    """Sign-based optimizer that keeps one moving average of gradients per parameter.

    Per step and parameter: the weights are shrunk by ``1 - lr * weight_decay``,
    moved by ``-lr * sign(lerp(exp_avg, grad, 1 - beta1))``, and ``exp_avg`` is
    then updated towards the gradient with weight ``1 - beta2``.

    Args:
        params: parameters or parameter groups to optimise.
        lr: step size, must be positive.
        betas: ``(beta1, beta2)``, each within [0, 1].
        weight_decay: decoupled weight-decay factor.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: tuple = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        assert lr > 0.0
        assert all(0.0 <= beta <= 1.0 for beta in betas)
        super().__init__(params, dict(lr=lr, betas=betas, weight_decay=weight_decay))

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update; returns the closure's loss, or None without a closure."""
        loss = None
        if closure is not None:
            # the closure has to run with autograd on to recompute gradients
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue  # nothing to do for parameters without a gradient

                grad = p.grad
                lr = group["lr"]
                wd = group["weight_decay"]
                beta1, beta2 = group["betas"]
                state = self.state[p]
                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p)
                exp_avg = state["exp_avg"]

                # decoupled weight decay
                p.mul_(1 - lr * wd)
                # direction: sign of the beta1-interpolation of momentum and gradient
                direction = exp_avg.lerp(grad, 1 - beta1).sign_()
                p.add_(direction, alpha=-lr)
                # momentum itself is tracked with beta2
                exp_avg.lerp_(grad, 1 - beta2)
        return loss

# metrics
# value methods are used to get the value of the metric after each epoch
# for the training loop
class DiceCoef:
    """Dice coefficient over all elements of two tensors (no thresholding).

    Calling the object computes ``(2 * sum(t * p) + s) / (sum(t) + sum(p) + s)``
    with ``s = 0.0001``, stores the result and returns it; ``value()`` returns
    the stored result of the latest call (None before the first call).
    """

    def __init__(self) -> None:
        self._value = None

    def __call__(self, y_true, y_pred):
        smooth = 0.0001
        truth = y_true.flatten()
        pred = y_pred.flatten()
        overlap = (truth * pred).sum()
        denominator = truth.sum() + pred.sum() + smooth
        self._value = (2.0 * overlap + smooth) / denominator
        return self._value

    def value(self):
        return self._value


class Recall:
    """Recall over all elements of two tensors (no thresholding).

    Calling the object computes ``(sum(t * p) + s) / (sum(t) + s)`` with
    ``s = 0.0001``; the first argument is the ground truth. ``value()`` returns
    the result of the latest call (None before the first call).
    """

    def __init__(self) -> None:
        self._value = None

    def __call__(self, y_true, y_pred):
        smooth = 0.0001
        truth = y_true.flatten()
        pred = y_pred.flatten()
        overlap = (truth * pred).sum()
        self._value = (overlap + smooth) / (truth.sum() + smooth)
        return self._value

    def value(self):
        return self._value


class Precision:
    """Precision over all elements of two tensors (no thresholding).

    Calling the object computes ``(sum(t * p) + s) / (sum(p) + s)`` with
    ``s = 0.0001``; the second argument is the prediction. ``value()`` returns
    the result of the latest call (None before the first call).
    """

    def __init__(self) -> None:
        self._value = None

    def __call__(self, y_true, y_pred):
        smooth = 0.0001
        truth = y_true.flatten()
        pred = y_pred.flatten()
        overlap = (truth * pred).sum()
        self._value = (overlap + smooth) / (pred.sum() + smooth)
        return self._value

    def value(self):
        return self._value


class PrecisionTorchMetrics:
    """Precision computed with ``torchmetrics.Precision(task="binary")``.

    Note the argument order ``(y_pred, y_true)``: it is the reverse of the
    hand-written metric classes in this notebook. ``value()`` returns the
    result of the latest call (None before the first call).
    """

    def __init__(self) -> None:
        self._value = None

    def __call__(self, y_pred, y_true):
        # torchmetrics is not imported in the visible cells of this notebook,
        # so it is imported on first use instead of failing with a NameError.
        import torchmetrics

        y_pred_f = y_pred.flatten()
        y_true_f = y_true.flatten()
        # a freshly built metric lives on the CPU; move it to where the data is
        metric = torchmetrics.Precision(task="binary").to(y_pred_f.device)
        self._value = metric(y_pred_f, y_true_f)
        return self._value

    def value(self):
        return self._value


class F1Score:
    """F1 score built from smoothed precision and recall (no thresholding).

    Precision is ``(sum(t * p) + s) / (sum(p) + s)`` and recall is
    ``(sum(t * p) + s) / (sum(t) + s)`` with ``s = 0.0001``; the result is
    ``2 * precision * recall / (precision + recall)``. ``value()`` returns the
    result of the latest call (None before the first call).
    """

    def __init__(self) -> None:
        self._value = None

    def __call__(self, y_true, y_pred):
        smooth = 0.0001
        truth = y_true.flatten()
        pred = y_pred.flatten()
        numerator = (truth * pred).sum() + smooth
        precision = numerator / (pred.sum() + smooth)
        recall = numerator / (truth.sum() + smooth)
        self._value = 2 * (precision * recall) / (precision + recall)
        return self._value

    def value(self):
        return self._value


class DiceScore:
    """Dice score with a smoothing term of 1.0 and an optional ignored label.

    When ``ignore_index`` is given, every position whose ground-truth value
    equals it is zeroed in both (copied) tensors before the score is computed;
    the caller's tensors are not modified. ``value()`` returns the result of
    the latest call (None before the first call).
    """

    def __init__(self) -> None:
        self._value = None

    def __call__(self, y_true, y_pred, ignore_index=None):
        smooth = 1.0
        # work on flat copies so masking below never touches the inputs
        truth = y_true.clone().view(-1)
        pred = y_pred.clone().view(-1)
        if ignore_index is not None:
            ignored = truth == ignore_index
            truth[ignored] = 0
            pred[ignored] = 0
        overlap = (truth * pred).sum()
        denominator = truth.sum() + pred.sum() + smooth
        self._value = (2.0 * overlap + smooth) / denominator
        return self._value

    def value(self):
        return self._value


class DiceLoss(nn.Module):
    """Loss defined as ``1 - dice`` over all elements of inputs and targets.

    No activation is applied here: ``inputs`` are used exactly as passed in.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # apply a sigmoid here only if the model does not already end with a
        # sigmoid or equivalent activation layer
        # inputs = F.sigmoid(inputs)

        flat_inputs = inputs.view(-1)
        flat_targets = targets.view(-1)

        overlap = (flat_inputs * flat_targets).sum()
        denominator = flat_inputs.sum() + flat_targets.sum() + smooth
        dice = (2.0 * overlap + smooth) / denominator

        return 1 - dice

# combined loss: Dice loss plus binary cross-entropy (BCE) loss
# only the BCE term is scaled by counter_weight (loss = dice + counter_weight * bce)
class CombinedLoss(nn.Module):
    """Sum of the Dice loss and a scaled binary cross-entropy loss.

    ``forward`` returns ``dice + counter_weight * bce``. Both terms receive the
    same ``outputs`` and ``targets``; the BCE term is an ``nn.BCELoss``.
    """

    def __init__(self, counter_weight) -> None:
        super().__init__()
        self.counter_weight = counter_weight
        self.dice_loss = DiceLoss()
        self.bce_loss = nn.BCELoss()

    def forward(self, outputs, targets):
        dice_term = self.dice_loss(outputs, targets)
        bce_term = self.bce_loss(outputs, targets)
        weighted_bce = self.counter_weight * bce_term
        return dice_term + weighted_bce

# print model information
def test_model_(model):
    """Run a random batch through ``model`` and print shape, range and cost.

    The dummy batch has the training batch size and the shape
    ``(batch, model.in_channels, model.img_size, model.img_size)``, so ``model``
    must expose ``in_channels`` and ``img_size``. It is created on the device
    of the model's first parameter (CPU if the model has no parameters), which
    keeps the check working after the model was moved to the GPU.
    """
    dummy_shape = (
        gen_params_train["batch_size"],
        model.in_channels,
        model.img_size,
        model.img_size,
    )
    first_param = next(model.parameters(), None)
    model_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )
    dummy = torch.randn(dummy_shape, device=model_device)
    # only the output values are inspected, so no autograd graph is needed
    with torch.no_grad():
        out = model(dummy)
    flops, params = profile(model, inputs=(dummy,))
    print(f"Output Shape: {out.shape}")
    print(f"Min: {torch.min(out)}, Max: {torch.max(out)}")
    print(f"Flops: {flops}, Params: {params} \n")

In [None]:
import time
import datetime
import csv

# write metrics to csv file
def write_to_csv(
    loss,
    metrics,
    epoch,
    header_written=False,
    filename: str = f"train_{datetime.date.today()}.csv",
):
    """Append one row of loss and metric values to a CSV file.

    Args:
        loss: loss value for the row.
        metrics: mapping with the keys "dice_coef", "f1", "recall",
            "precision" and "dice_score".
        epoch: zero-based epoch index; the row stores ``epoch + 1``.
        header_written: pass True to suppress the header row.
        filename: target file, opened in append mode. The default name is
            built once, when this function is defined, from that day's date.

    Returns:
        True, meaning the file has its header after this call.

    The header row is written only when ``header_written`` is falsy AND the
    file is still empty, so repeated calls with ``header_written=False`` no
    longer scatter header rows through the data.
    """
    metric_columns = ["dice_coef", "f1", "recall", "precision", "dice_score"]
    # newline="" lets the csv module control line endings itself
    with open(filename, "a", newline="") as f:
        writer = csv.writer(f)
        f.seek(0, 2)  # 2 = seek relative to the end of the file
        file_is_empty = f.tell() == 0
        if not header_written and file_is_empty:
            writer.writerow(["epoch", "loss", *metric_columns])
        writer.writerow(
            [epoch + 1, loss, *(metrics[column] for column in metric_columns)]
        )
    return True

# training loop
def train_step(
    model,
    train_generator,
    metrics,
    loss_fn,
    optimizer,
    device,
    save_path="model_epoch_",
    epochs: int = 10,
    save: bool = False,
    test_after_epoch: bool = False,
    test_generator=None,
    load_saved_model: bool = False,
    saved_model_path: str = None,
):
    """Train ``model`` for ``epochs`` epochs, logging every batch to CSV.

    Args:
        model: module to train. Batches are moved to ``device``; the model
            itself is not moved here.
        train_generator: sized iterable yielding ``(x, y)`` batches.
        metrics: mapping name -> metric object that is called as
            ``metric(y_true, y_pred)`` and exposes ``value()``.
        loss_fn: callable ``loss_fn(outputs, y)`` returning a scalar tensor.
        optimizer: optimizer whose ``param_groups`` hold "lr" and
            "weight_decay".
        device: device the batches (and a loaded checkpoint) are mapped to.
        save_path: prefix of the checkpoint files written when ``save`` is set.
        epochs: number of passes over ``train_generator``.
        save: write ``<save_path><epoch>.pth`` (state dict) and
            ``<save_path><epoch>_cm.pth`` (whole model) after every epoch.
        test_after_epoch: run ``test_step`` on ``test_generator`` after each
            epoch.
        test_generator: loader used when ``test_after_epoch`` is set.
        load_saved_model: load ``saved_model_path`` into ``model`` first.
        saved_model_path: state-dict file used with ``load_saved_model``.

    Raises:
        ValueError: if ``train_generator`` has no batches, or
            ``test_after_epoch`` is set without a ``test_generator``.
        Exception: "Loss increased too many times" once the learning rate has
            been reduced six times.

    Learning-rate schedule: at the start of an epoch, if the tracked loss of
    the previous epoch is higher than the one before it, "lr" and
    "weight_decay" of every param group are multiplied by 0.1. The tracked
    loss is the running average over all samples seen so far (not the loss of
    the single epoch) — NOTE(review): confirm this is the intended signal.
    """
    if test_after_epoch and test_generator is None:
        raise ValueError("test_after_epoch=True requires a test_generator")
    num_batches = len(train_generator)
    if num_batches == 0:
        raise ValueError("train_generator yields no batches")
    if load_saved_model:
        # map_location lets a checkpoint saved on the GPU load on a CPU runtime
        model.load_state_dict(torch.load(saved_model_path, map_location=device))

    total_loss = 0
    total_samples = 0
    start = time.time()
    loss_checker = []
    loss_down = 0
    header_written = False  # header goes out once per call, not once per epoch

    for epoch in range(epochs):
        # test_step switches the model to eval mode, so training mode has to
        # be (re-)enabled at the start of every epoch
        model.train()

        # Compare the last two epoch results once per epoch. Doing this per
        # batch would shrink the learning rate on every batch after a single
        # bad epoch and abort the run a few batches later.
        if len(loss_checker) > 1:
            if loss_down < 6:
                if loss_checker[-1] > loss_checker[-2]:
                    print(
                        f"Loss increased from {loss_checker[-2]} to {loss_checker[-1]}\n",
                        end="",
                        flush=True,
                    )
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = param_group["lr"] * 0.1
                        param_group["weight_decay"] = (
                            param_group["weight_decay"] * 0.1
                        )
                    loss_down += 1
            else:
                raise Exception("Loss increased too many times")

        batch = 0
        total_loss_epoch = 0
        total_samples_epoch = 0
        average_metrics = []
        epoch_start = time.time()
        print(f"Epoch {epoch+1}/{epochs}\n-------------------------------")
        for x, y in train_generator:
            x = x.to(device)
            y = y.to(device)
            batch += 1
            optimizer.zero_grad()

            outputs = model(x)
            loss = loss_fn(outputs, y)
            batch_size = x.size(0)

            loss.backward()
            optimizer.step()

            total_loss += loss.item() * batch_size
            total_samples += batch_size
            total_loss_epoch += loss.item() * batch_size
            total_samples_epoch += batch_size

            # Metrics take (y_true, y_pred); passing them the other way round
            # swaps recall and precision. Detach so the stored metric values
            # do not keep the autograd graph alive.
            predictions = outputs.detach()
            for metric in metrics.values():
                metric(y, predictions)

            avg_loss = total_loss / total_samples
            metric_values = {
                metric_name: round(metric.value().item(), 4)
                for metric_name, metric in metrics.items()
            }
            average_metrics.append(metric_values)

            current_time = time.time()
            elapsed_time = current_time - start
            # the batch counter restarts every epoch, so the per-batch time
            # must be measured from the start of the epoch as well
            avg_time_per_batch = (current_time - epoch_start) / batch
            remaining_batches = num_batches - batch
            remaining_time = remaining_batches * avg_time_per_batch
            print(
                f"\r Batch [{batch}/{num_batches}] Loss: {avg_loss:.4f} Metrics: {metric_values} Elapsed Time: {elapsed_time:.2f}s/{elapsed_time/60:.2f}m Remaining Time (epoch): {remaining_time:.2f}s/{remaining_time/60:.2f}m",
                flush=True,
                end=" ",
            )
            header_written = write_to_csv(
                avg_loss, metric_values, epoch, header_written
            )

        avg_loss = total_loss / total_samples
        avg_loss_epoch = total_loss_epoch / total_samples_epoch
        print(avg_loss)
        avg_metrics = {
            metric_name: round(
                sum(batch_metrics[metric_name] for batch_metrics in average_metrics)
                / len(average_metrics),
                4,
            )
            for metric_name in average_metrics[0].keys()
        }
        loss_checker.append(avg_loss)
        print(
            f"Epoch {epoch + 1} finished. \n Avg Loss: {avg_loss_epoch:.4f} Avg Metrics: {avg_metrics}"
        )
        if save:
            torch.save(model.state_dict(), f"{save_path}{epoch + 1}.pth")
            print(f"Model state dict saved at {save_path}{epoch + 1}.pth")
            torch.save(model, f"{save_path}{epoch + 1}_cm.pth")
            print(f"Complete model saved at {save_path}{epoch + 1}_cm.pth")

        # using test_step to test after each epoch
        if test_after_epoch:
            test_step(model, test_generator, metrics, loss_fn, device)


def test_step(
    model,
    test_generator,
    metrics,
    loss_fn,
    device,
    load_saved_model: bool = False,
    saved_model_path: str = None,
):
    """Evaluate ``model`` on ``test_generator`` and append the result to CSV.

    Args:
        model: module to evaluate. Batches are moved to ``device``; the model
            itself is not moved here.
        test_generator: iterable yielding ``(x, y)`` batches.
        metrics: mapping name -> metric object that is called as
            ``metric(y_true, y_pred)`` and exposes ``value()``.
        loss_fn: callable ``loss_fn(outputs, y)`` returning a scalar tensor.
        device: device the batches (and a loaded checkpoint) are mapped to.
        load_saved_model: load ``saved_model_path`` into ``model`` first.
        saved_model_path: state-dict file used with ``load_saved_model``.

    Raises:
        ValueError: if ``test_generator`` yields no batches.

    The model is evaluated in eval mode without gradients and is put back
    into the mode it had on entry. The averaged loss and the batch-averaged
    metrics are printed and written to ``test_<today>.csv``.
    """
    if load_saved_model:
        # map_location lets a checkpoint saved on the GPU load on a CPU runtime
        model.load_state_dict(torch.load(saved_model_path, map_location=device))
    was_training = model.training
    model.eval()
    total_loss = 0
    total_samples = 0
    batch_metrics = []

    try:
        with torch.no_grad():
            for x, y in test_generator:
                x = x.to(device)
                y = y.to(device)

                outputs = model(x)
                loss = loss_fn(outputs, y)
                batch_size = x.size(0)

                total_loss += loss.item() * batch_size
                total_samples += batch_size

                # metrics take (y_true, y_pred); the reverse order swaps
                # recall and precision
                for metric in metrics.values():
                    metric(y, outputs)

                batch_metrics.append(
                    {
                        metric_name: round(metric.value().item(), 4)
                        for metric_name, metric in metrics.items()
                    }
                )
    finally:
        # hand the model back in the mode the caller had it in, so a training
        # loop calling this between epochs keeps training in train mode
        model.train(was_training)

    if not batch_metrics:
        raise ValueError("test_generator yields no batches")

    avg_metrics = {
        metric_name: round(
            sum(values[metric_name] for values in batch_metrics) / len(batch_metrics),
            4,
        )
        for metric_name in batch_metrics[0].keys()
    }
    avg_loss = total_loss / total_samples
    print(f"Test - Avg Loss (Batch): {avg_loss:.4f} Avg Metrics: {avg_metrics}")
    # log the averages that were just printed, not the metrics of the last batch
    write_to_csv(
        avg_loss,
        avg_metrics,
        0,
        header_written=False,
        filename=f"test_{datetime.date.today()}.csv",
    )

# initialize metrics
dice_coef = DiceCoef()
f1 = F1Score()
recall = Recall()
precision = Precision()
dice_score = DiceScore()

# metric library: name -> metric object, the names become the CSV columns
metrics = dict(
    dice_coef=dice_coef,
    f1=f1,
    recall=recall,
    precision=precision,
    dice_score=dice_score,  # same formula as dice_coef, but smoothed with 1.0 instead of 0.0001
)


model = seg_model(backbone="xception")

# test_model_(seg_model)


optimizer = Lion(model.parameters(), lr=0.0001, weight_decay=0.00001)
# loss_fn = nn.BCEWithLogitsLoss()
# loss = dice + 0.5 * bce, i.e. the BCE term counts half as much as the Dice term
loss_fn = CombinedLoss(counter_weight=0.5)
num_epochs = 20
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)


# train, checkpoint after every epoch and evaluate on the test split in between
train_step(
    model=model,
    train_generator=train_generator,
    metrics=metrics,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    save_path="model_epoch_",
    epochs=num_epochs,
    save=True,
    test_after_epoch=True,
    test_generator=test_generator,
)
# test_step(model, test_data, metrics, loss_fn, device)
# test_step(model, test_data, metrics, loss_fn, device)

In [None]:
# test already trained model
model_path = "/content/model_epoch_19.pth"
# NOTE(review): the training cell builds the model with backbone="xception",
# this cell with dropout_ratio=0.1 — confirm both produce the same state-dict
# layout, otherwise loading the checkpoint fails.
# Use the same `device` that is handed to test_step below, so the model and
# the batches always end up on the same device (also on a CPU-only runtime).
model = seg_model(dropout_ratio=0.1).to(device)
test_step(
    model,
    test_generator,
    metrics,
    loss_fn,
    device,
    load_saved_model=True,
    saved_model_path=model_path,
)

In [None]:
import shutil

# copy model to google drive: first the pickled full model, then the state dict
shutil.copy(
    src="/content/model_epoch_19_cm.pth",
    dst="/content/drive/MyDrive/vitaeocr_eppch_19_cm.pth",
)
shutil.copy(
    src="/content/model_epoch_19.pth",
    dst="/content/drive/MyDrive/vitaeocr_eppch_19_cm_dict.pth",
)