In [1]:
from typing import Optional, Tuple, List, Dict, Any, Union
import os
from pathlib import Path
import numpy as np
from glob import glob
import glob
import random
import pandas as pd
import matplotlib.pyplot as plt
import pickle


import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.classification import BinaryJaccardIndex, BinaryF1Score, Dice, BinaryPrecision, BinaryRecall
# from torchvision import transforms
import segmentation_models_pytorch as smp

In [2]:
class GeoImageDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_files = os.listdir(self.img_dir)
        self.mask_files = os.listdir(self.mask_dir)
        self.transform = transform

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        # Load image
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        # mask and img_file have so far the same name
        mask_path = os.path.join(self.mask_dir, self.img_files[idx])
        img = torch.load(img_path)
        # converts bool mask into integer (0/1)
        mask = torch.load(mask_path).long()
        # Apply transform (if any)
        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)

        return img, mask

In [10]:
class Trainer:
    def __init__(self, loss_fn, max_lr, epochs, transform):
        # self.model = model
        self.loss_fn = loss_fn
        self.max_lr = max_lr
        self.epochs = epochs
        self.transform = transform

        self.train_loss_steps = []
        self.validation_loss_steps = []
        self.validation_dice_steps = []
        self.validation_f1_epochs = []
        self.validation_precision_epochs = []
        self.validation_recall_epochs = []

        self.train_loss_epochs = []
        self.val_loss_epochs = []
        self.val_dice_epochs = []

        self.device = (
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )

    def train_and_save(self, model, train_dataloader, val_dataloader, sampling, validated, backbone, batch_size):
        optimizer = Adam(model.parameters(), lr=self.max_lr)

        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.max_lr,
            steps_per_epoch=len(train_dataloader),
            epochs=self.epochs,
        )
        filename = f"u_net_{sampling}_{validated}_{backbone}_{self.epochs}_{batch_size}_{self.max_lr}_{self.transform}"
        for t in range(self.epochs):
            print(f"Epoch {t+1}\n-------------------------------")
            loss, train_loss = self.train(model, train_dataloader, optimizer)
            self.train_loss_steps.extend(train_loss)
            self.train_loss_epochs.append(loss)
            loss, val_loss, dice, dice_list, f1_score, precision, recall = self.test(model, val_dataloader)
            self.validation_loss_steps.extend(val_loss)
            self.val_loss_epochs.append(loss)
            self.val_dice_epochs.append(dice)
            self.validation_dice_steps.extend(dice_list)
            self.validation_f1_epochs.append(f1_score)
            self.validation_precision_epochs.append(precision)
            self.validation_recall_epochs.append(recall)
            if dice >= np.max(self.val_dice_epochs):
                torch.save(
                    model.state_dict(),
                    f"models/{filename}_best_model.pt",
                )
                print("Model saved!")
            scheduler.step()

        
        data_dict = {
            "train_loss_epochs": self.train_loss_epochs,
            "val_loss_epochs": self.val_loss_epochs,
            "val_dice_epochs": self.val_dice_epochs,
            "validation_f1_epochs": self.validation_f1_epochs,
            "validation_precision_epochs": self.validation_precision_epochs,
            "validation_recall_epochs": self.validation_recall_epochs,
        }
        self.save_as_csv(data_dict, f"trainings_results/{filename}")
        steps_dict = {
            "train_loss_steps": self.train_loss_steps,
            "validation_loss_steps": self.validation_loss_steps,
            "validation_dice_steps": self.validation_dice_steps,
        }
        data_dict.update(steps_dict)
        self.save_as_pickle(data_dict, f"trainings_results/{filename}")
        torch.save(model.state_dict(), f"models/{filename}.pt")

    def train(
        self,
        model: nn.Module,
        dataloader: Any,
        optimizer: Any,
    ) -> Union[torch.Tensor, List[float]]:
        size = len(dataloader.dataset)
        loss_vals= []
        model.train()
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(self.device), y.to(self.device)

            pred = model(X)
            
            loss = self.loss_fn(pred, y.to(torch.float32)) # pred.squeeze(1)
            loss_vals.append(loss.item())
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 20 == 0:
                loss, current = loss.item(), (batch + 1) * len(X)
                
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
        avg_loss = np.mean(loss_vals)
        return avg_loss, loss_vals

    def test(
            self, 
            model: nn.Module, 
            dataloader: Any
        ) -> Union[torch.Tensor, torch.Tensor, List[float]]:
        loss_vals=  []
        Dice_idx_vals = []
        # size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss = 0
        test_dice = 0
        test_f1 = 0
        test_precision = 0
        test_recall = 0
        metric = Dice(zero_division=1).to(self.device)
        f1_score = BinaryF1Score(multidim_average='global').to(self.device)
        precision = BinaryPrecision(multidim_average='global').to(self.device)
        recall = BinaryRecall(multidim_average='global').to(self.device)

        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                loss = self.loss_fn(pred, y.to(torch.float32)).item() # pred.squeeze(1)
                test_loss += loss
                dice = metric(pred, y.to(torch.int8))
                loss_vals.append(loss)
                test_dice += dice.item()
                Dice_idx_vals.append(dice.item())
                test_f1 += f1_score(pred, y.to(torch.int8)).item()
                test_precision += precision(pred, y.to(torch.int8)).item()
                test_recall += recall(pred, y.to(torch.int8)).item()

        test_loss /= num_batches
        test_dice /= num_batches
        test_f1 /= num_batches
        test_precision /= num_batches
        test_recall /= num_batches
        # correct /= size

        # Dice_idx = 100 * metric(pred, y)
        print(
            f"Test Error: \n"
            f"Dice-Coefficient: {test_dice:>0.2f}, Avg loss: {test_loss:>5f} \n"
        )

        return test_loss, loss_vals, test_dice, Dice_idx_vals, test_f1, test_precision, test_recall

    def save_as_csv(self, data_dict: dict, filename: str) -> None:
        df = pd.DataFrame.from_dict(data_dict)
        df.to_csv(f"{filename}.csv", index=False)

    def save_as_pickle(self, data_dict: dict, filename: str) -> None:
        with open(f"{filename}.pkl", 'wb') as f:
            pickle.dump(data_dict, f)

In [4]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [5]:
root = Path(r"C:\Users\Fabian\Documents\Github_Masterthesis\Solarpark-detection\data_local\data_splitted_undersampling_refactored_not_color_cleaned")

train_images_path = Path(root / "train/images")
train_masks_path = Path(root / "train/masks")
val_images_path = Path(root / "val/images")
val_masks_path = Path(root / "val/masks")
test_images_path = Path(root / "test/images")
test_masks_path = Path(root / "test/masks")

In [6]:
transform = None
train_dataset = GeoImageDataset(train_images_path, train_masks_path, transform=transform)
val_dataset = GeoImageDataset(val_images_path, val_masks_path, transform=transform)
test_dataset = GeoImageDataset(test_images_path, test_masks_path, transform=transform)

In [7]:
batch_size = 32
shuffle = True

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)

In [8]:
print(f"Size train set: {train_dataset.__len__()}")
print(f"Size val set: {val_dataset.__len__()}")
print(f"Size test set: {test_dataset.__len__()}")

Size train set: 13472
Size val set: 3790
Size test set: 2213


In [11]:
loss_fn = nn.BCELoss()
sampling = "undersampling"
validated = "cleaned_verified"
epochs = 15
transform = "None"
max_lr = 0.01

for backbone in ["timm-resnest14d"]:
    model = smp.Unet(
        encoder_name=backbone,
        encoder_weights='imagenet',
        in_channels=4,
        classes=1,
        activation='sigmoid',
    ).to(device)
    trainer = Trainer(loss_fn, max_lr=max_lr, epochs=epochs, transform=transform)
    trainer.train_and_save(model, train_dataloader, val_dataloader, sampling, validated, backbone, batch_size)
    print(f"Model {backbone} trained and saved")
print("Done!")

Epoch 1
-------------------------------
loss: 0.883835  [   32/13472]
loss: 0.400794  [  672/13472]
loss: 0.236435  [ 1312/13472]
loss: 0.152067  [ 1952/13472]
loss: 0.108940  [ 2592/13472]
loss: 0.086918  [ 3232/13472]
loss: 0.071116  [ 3872/13472]
loss: 0.066269  [ 4512/13472]
loss: 0.062043  [ 5152/13472]
loss: 0.055427  [ 5792/13472]
loss: 0.042748  [ 6432/13472]
loss: 0.035058  [ 7072/13472]
loss: 0.054732  [ 7712/13472]
loss: 0.026789  [ 8352/13472]
loss: 0.029717  [ 8992/13472]
loss: 0.032248  [ 9632/13472]
loss: 0.039837  [10272/13472]
loss: 0.022128  [10912/13472]
loss: 0.022999  [11552/13472]
loss: 0.018790  [12192/13472]
loss: 0.018710  [12832/13472]
loss: 0.015033  [13472/13472]
Test Error: 
Dice-Coefficient: 0.00, Avg loss: 0.020892 

Model saved!
Epoch 2
-------------------------------
loss: 0.016574  [   32/13472]
loss: 0.016561  [  672/13472]
loss: 0.016449  [ 1312/13472]
loss: 0.014803  [ 1952/13472]
loss: 0.015724  [ 2592/13472]
loss: 0.010109  [ 3232/13472]
loss: 0.0