In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from kornia.color import bgr_to_rgb
from matplotlib.colors import ListedColormap
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torchgeo.datasets.utils import stack_samples

from minerva.datasets import DFC2020
from minerva.loss import AuxCELoss
from minerva.models import MinervaPSP
from minerva.transforms import MinervaCompose, Normalise
from minerva.utils.utils import get_cuda_device
from minerva.utils.visutils import get_mlp_cmap

In [None]:
# Model / data configuration.
# patch_size is presumably (channels, height, width) for the 4-band s2hr input — confirm.
patch_size = (4, 256, 256)
feature_dim = 512  # passed to MinervaPSP as psp_out_channels
n_classes = 10
batch_size = 8
encoder_name = "resnet50"
encoder_depth = 5

# Seed the RNGs so Restart & Run All is reproducible: torch drives weight
# initialisation and RandomSampler's draws; numpy is seeded for completeness.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Path to the pre-trained backbone checkpoint. Whitespace and surrounding quotes
# (e.g. from pasting a copied path) are stripped and a leading `~` is expanded.
pre_train_path = Path(input("Path to the pre-trained backbone weights").strip().strip("\"'")).expanduser()
# Fail fast here instead of after the model has been constructed.
if not pre_train_path.is_file():
    raise FileNotFoundError(f"No checkpoint file found at {pre_train_path}")

In [None]:
# Select the compute device used by both models and the batches.
# NOTE(review): get_cuda_device(0) presumably returns CUDA device 0, possibly
# falling back to CPU when CUDA is unavailable — confirm in minerva.utils.utils.
device = get_cuda_device(0)

In [None]:
# Root data directory. Whitespace and surrounding quotes (e.g. from pasting a
# copied path) are stripped and a leading `~` is expanded.
root = Path(input("Path to the root directory containing all the data").strip().strip("\"'")).expanduser()

# The DFC2020 data is read from the "DFC2020" sub-directory of `root`.
train_root = root / "DFC2020"

In [None]:
# DFC2020 with the high-resolution Sentinel-2 bands and labels; the image is
# passed through Normalise(4095) (presumably scaling raw values by 1/4095).
# NOTE(review): training uses split="test" (the "val" split is left commented
# out below, presumably for evaluation) — confirm this is intended.
image_transform = MinervaCompose({"image": Normalise(4095)})
train_dataset = DFC2020(
    str(train_root),
    split="test",
    use_s2hr=True,
    labels=True,
    transforms=image_transform,
)
#val_dataset = DFC2020(str(train_root), split="val", use_s2hr=True, labels=True, transforms=MinervaCompose({"image": Normalise(4095)}))

In [None]:
# Each "epoch" draws a fixed number of random samples instead of the whole dataset.
samples_per_epoch = 256
train_sampler = RandomSampler(train_dataset, num_samples=samples_per_epoch)

# stack_samples collates the per-sample dicts into batched "image"/"mask" tensors.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
    collate_fn=stack_samples,
    num_workers=2,
)

In [None]:
# Rebuild the architecture used for pre-training so its weights can be loaded.
# n_classes=1 presumably matches the head shape stored in the checkpoint; new
# heads for the real task are attached afterwards.
pretrain_model = MinervaPSP(
    AuxCELoss(),
    patch_size,
    n_classes=1,
    encoder_name=encoder_name,
    encoder_depth=encoder_depth,
    psp_out_channels=feature_dim,
    #freeze_backbone=True
)

# map_location="cpu" lets a checkpoint saved on a GPU load on any machine; the
# model is moved to `device` later, before training.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
pretrain_model.load_state_dict(torch.load(pre_train_path, map_location="cpu"))

In [None]:
# Attach a segmentation head with n_classes outputs to the loaded model (which
# was built with n_classes=1), using the same upsampling/activation as the baseline.
# NOTE(review): PReLU as the final activation on outputs fed to a CE loss is unusual — confirm.
pretrain_model.model.make_segmentation_head(n_classes, upsampling=32, activation=torch.nn.PReLU)

In [None]:
# Attach an auxiliary classification head, configured with the same aux_params
# dict that the baseline model receives.
pretrain_model.model.make_classification_head({"classes": n_classes, "activation": torch.nn.PReLU})

In [None]:
# Move the model to the target device *before* constructing its optimiser, as the
# torch.optim documentation recommends. The optimiser is built after the new
# heads were attached, so their parameters are optimised as well.
pretrain_model.to(device)
pretrain_opt = Adam(pretrain_model.parameters(), lr=1e-3)
pretrain_model.set_optimiser(pretrain_opt)

In [None]:
# Baseline for comparison: the same PSP architecture with no pre-trained checkpoint
# loaded, built directly with segmentation and auxiliary classification heads
# configured like the ones attached to `pretrain_model`.
baseline_model = MinervaPSP(
    AuxCELoss(),
    patch_size,
    n_classes=n_classes,
    encoder_name=encoder_name,
    encoder_depth=encoder_depth,
    psp_out_channels=feature_dim,
    upsampling=32,
    aux_params={"classes": n_classes, "activation": torch.nn.PReLU},
    classification_on=True,
    activation=torch.nn.PReLU,
)

# Move to the target device before constructing the optimiser (torch.optim recommendation).
baseline_model.to(device)
baseline_opt = Adam(baseline_model.parameters(), lr=1e-3)
baseline_model.set_optimiser(baseline_opt)

In [None]:
# Train the pre-trained and baseline models side by side on identical batches and
# report each model's mean loss and pixel accuracy after every epoch.
n_epochs = 50
# Visualise only the first batch of every `vis_every`-th epoch; plotting every
# batch of every epoch would emit thousands of figures.
vis_every = 1

cmap_style = ListedColormap(list(train_dataset.colours.values()), N=n_classes)
cmap = get_mlp_cmap(cmap_style, n_classes)


def _plot_batch(name, images, masks, seg_pred):
    """Show the input image, ground-truth mask and predicted mask of each sample.

    Args:
        name: Model label, used as the figure title.
        images: Batch of input images; the first three channels are displayed.
        masks: Batch of integer ground-truth masks.
        seg_pred: Batch of per-class segmentation outputs; the arg-max over the
            class dimension (dim 1 of the batch) is displayed.
    """
    n_samples = seg_pred.shape[0]
    # squeeze=False keeps ``axs`` 2-D even when the batch holds a single sample.
    fig, axs = plt.subplots(3, n_samples, figsize=(10, 4), squeeze=False)
    fig.suptitle(name)
    for j in range(n_samples):
        # NOTE(review): assumes the first three s2hr channels are B2, B3, B4
        # (blue, green, red), hence the BGR -> RGB swap — confirm against DFC2020.
        rgb = bgr_to_rgb(images[j, :3]).clamp(0, 1)
        axs[0, j].imshow(rgb.cpu().numpy().transpose(1, 2, 0))
        axs[1, j].imshow(masks[j].cpu().numpy(), cmap=cmap, vmin=0, vmax=n_classes)
        axs[2, j].imshow(
            seg_pred[j].detach().argmax(dim=0).cpu().numpy(),
            cmap=cmap,
            vmin=0,
            vmax=n_classes,
        )
    plt.setp(fig.get_axes(), xticks=[], yticks=[])
    plt.show()
    # Release the figure explicitly; many figures are created over training.
    plt.close(fig)


for epoch in range(n_epochs):
    pretrain_losses = []
    pretrain_accs = []
    baseline_losses = []
    baseline_accs = []
    runs = (
        ("pretrain", pretrain_model, pretrain_losses, pretrain_accs),
        ("baseline", baseline_model, baseline_losses, baseline_accs),
    )
    for batch_idx, batch in enumerate(train_dataloader):
        images = batch["image"].to(device).float()
        masks = batch["mask"].to(device).long()

        # Uses MinervaModel.step.
        for name, model, losses, accs in runs:
            loss, pred = model.step(images, masks, train=True)
            losses.append(loss.item())
            # pred[0] is presumably the segmentation output. Store this batch's
            # fraction of correctly classified pixels; unlike dividing a pixel
            # count by batch_size * H * W, this stays correct for a short batch.
            accs.append((torch.argmax(pred[0], 1) == masks).float().mean().item())

            if batch_idx == 0 and epoch % vis_every == 0:
                _plot_batch(name, images, masks, pred[0])

    for name, _, losses, accs in runs:
        print(
            f"Train {epoch} ({name} model)| Loss: {np.mean(losses)}| Accuracy: {np.mean(accs) * 100.0}%"
        )