In [None]:
import torch
from pathlib import Path
from torch.utils.data import RandomSampler, DataLoader
from torch.optim import Adam
from torchgeo.datasets.utils import stack_samples
import numpy as np

from minerva.loss import AuxCELoss
from minerva.datasets import DFC2020
from minerva.models import MinervaPSP
from minerva.transforms import Normalise, MinervaCompose
from minerva.utils.utils import get_cuda_device

In [None]:
# Experiment configuration: everything a reader might want to tune lives here.
patch_size = (4, 256, 256)  # presumably (channels, height, width) — confirm against MinervaPSP
feature_dim = 512  # passed to MinervaPSP as `psp_out_channels`
n_classes = 15  # number of target classes for the fine-tuning heads
batch_size = 16
encoder_name = "resnet18"
encoder_depth = 5
pre_train_path = Path("../cache/SimConv-MkVIII.pt")  # pre-trained weights loaded into the model

# Seed the RNGs (random sampling order, freshly initialised heads) so that
# Restart-&-Run-All reproduces the same training run.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Pick the compute device with minerva's helper; the argument presumably selects
# CUDA device index 0. NOTE(review): the fallback when no GPU is present depends
# on get_cuda_device — confirm.
device = get_cuda_device(0)

In [None]:
# Ask for the data root at run time instead of hard-coding a machine-specific path.
# Strip stray whitespace (common when pasting a path) and expand "~" so
# home-relative paths resolve.
root = Path(input("Path to the root directory containing all the data").strip()).expanduser()

train_root = root / "DFC2020"

# Fail fast with a clear message rather than deep inside the dataset constructor.
if not train_root.is_dir():
    raise FileNotFoundError(f"Expected a DFC2020 data directory at {train_root}")

In [None]:
def _make_dfc2020(split):
    """Build a DFC2020 dataset for ``split`` with high-res S2 imagery and labels.

    Each call creates its own transform pipeline, which divides the "image"
    entry by 4095. NOTE(review): 4095 = 2**12 - 1, presumably the 12-bit
    maximum of the raw sensor values — confirm.
    """
    return DFC2020(
        str(train_root),
        split=split,
        use_s2hr=True,
        labels=True,
        transforms=MinervaCompose({"image": Normalise(4095)}),
    )


# NOTE(review): training on split="test" looks deliberate (presumably DFC2020
# provides no dedicated train split) — confirm.
train_dataset = _make_dfc2020("test")
val_dataset = _make_dfc2020("val")

In [None]:
# Each epoch draws 1096 random samples from the training split (RandomSampler's
# `replacement` defaults to False), batched with torchgeo's dict-stacking collate.
train_sampler = RandomSampler(train_dataset, num_samples=1096)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=2,
    collate_fn=stack_samples,
)

In [None]:
# NOTE(review): AuxCELoss presumably combines the segmentation loss with an
# auxiliary term from the classification head attached below — confirm.
criterion = AuxCELoss()

# Build the PSP model with a single-class placeholder head so that its
# architecture matches the pre-trained checkpoint; heads sized for `n_classes`
# are created after the weights are loaded.
# NOTE(review): assumes the checkpoint was saved with n_classes=1 — confirm.
model = MinervaPSP(
    criterion,
    patch_size,
    n_classes=1,
    encoder_name=encoder_name,
    encoder_depth=encoder_depth,
    psp_out_channels=feature_dim,
)

# Load onto the CPU so a checkpoint saved from a GPU also loads on CPU-only
# machines; the model is moved to `device` afterwards anyway.
# torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(pre_train_path, map_location="cpu"))

In [None]:
# Build a segmentation head for `n_classes` classes on the wrapped network.
# NOTE(review): presumably replaces the 1-class placeholder head the checkpoint
# was loaded with — confirm against make_segmentation_head.
model.model.make_segmentation_head(n_classes)

In [None]:
# Attach a classification head for `n_classes` classes to the wrapped network.
# NOTE(review): presumably consumed by the auxiliary term of AuxCELoss — confirm.
model.model.make_classification_head({"classes": n_classes})

In [None]:
# Move the model (including the freshly built heads) to the target device
# *before* constructing the optimiser, as the PyTorch optimiser docs recommend,
# so the optimiser is created over the parameters in their final placement.
model.to(device)
opt = Adam(model.parameters(), lr=1e-3)
model.set_optimiser(opt)

In [None]:
# Fine-tune for a fixed number of epochs, reporting mean loss and pixel accuracy.
n_epochs = 50

for epoch in range(n_epochs):
    losses = []
    n_correct = 0  # correctly predicted pixels this epoch
    n_total = 0  # pixels compared this epoch
    for batch in train_dataloader:
        images = batch["image"].to(device).float()
        # Class-index targets: cross-entropy style losses require integer (long) masks.
        # NOTE(review): assumes masks hold class indices with no channel axis,
        # i.e. the same shape as argmax(pred, 1) — confirm.
        masks = batch["mask"].to(device).long()

        # Uses MinervaModel.step.
        loss, pred = model.step(images, masks, train=True)
        losses.append(loss.item())

        # Accumulate matches and comparisons so the epoch accuracy is a true
        # fraction (correctly weighted even when the last batch is smaller).
        hits = torch.argmax(pred, 1) == masks
        n_correct += hits.sum().item()
        n_total += hits.numel()

    accuracy = n_correct / n_total if n_total else 0.0
    print(
        f"Train {epoch}| Loss: {np.mean(losses)}| Accuracy: {accuracy * 100.0}%"
    )