In [1]:
import sys
import os
import matplotlib.pyplot as plt

# Make the repository root (two levels above the notebook's working directory)
# importable so the `src.*` packages resolve. TEMPORARY FIX.
# Guarded so re-running this cell does not append the same entry again.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "../.."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# All project imports go through the same `src.` root; mixing `models.*` and
# `src.models.*` can load the same package twice as distinct module objects.
from src.models.unet.unet import UNet
from src.models.data_management.cnn_formes import CNNFormes
from src.data_processing.dataset_loader import CoastData

  check_for_updates()


In [None]:
# Build the dataset for the "arenaldentem" coast only, from <repo>/data/processed,
# then split it (presumably into train/validation/test) for `unet.load_data`.
path = os.path.abspath(os.path.join(os.getcwd(), "..", "..", "data", "processed"))
data = CoastData(name="arenaldentem", data_path=path)

data_split = data.split_data()

In [None]:
# Fresh 3-class UNet for a quick experiment; MLflow tracking switched off.
unet = UNet(
    use_mlflow=False,
    experiment_name="test_experiments",
    num_classes=3,
)

In [None]:
# Load the data to the model: hand the coast split to the UNet together with
# CNNFormes (presumably the dataset wrapper used to build its loaders — confirm).
# NOTE(review): this rebinds `data`, which previously held the CoastData object.
data = unet.load_data(data_split, CNNFormes)

In [None]:
# Artifacts directory (<repo>/artifacts). Derived from the working directory rather
# than from the data cell's `path`, so this cell does not silently depend on that
# cell having run (the resolved location is the same).
artifact_path = os.path.abspath(os.path.join(os.getcwd(), "../../artifacts/"))

# Optional free-text note forwarded to `unet.train` as `run_description`.
description = ""

# Train the model (short 2-epoch run)
unet.train(epochs=2, artifact_path=artifact_path, run_description=description)

## UNet validation

In [13]:
from src.models.metrics import Metrics
from src.models.data_management.data_loader import DataLoaderManager

import cv2

import torch

In [14]:
# Artifacts directory (<repo>/artifacts), two levels above the notebook's working directory.
artifact_path = os.path.abspath(os.path.join(os.getcwd(), "../../artifacts/"))

num_classes = 3

# One metrics tracker per data split. Keys are the split names produced by the data
# loaders; values are the phase labels that prefix the metric names in the report
# (e.g. "val_accuracy"). `average=None` keeps per-class values (alternative: 'macro').
split_phases = {"train": "train", "validation": "val", "test": "test"}
metrics = {
    split_name: Metrics(phase=phase, num_classes=num_classes, average=None, use_margin=False)
    for split_name, phase in split_phases.items()
}

# Build the model and restore the best checkpoint stored under the artifacts directory.
model = UNet(num_classes=num_classes, experiment_name="test_experiments", use_mlflow=False)
model.load_model(os.path.abspath(os.path.join(artifact_path, "unet/models/best_model.pth")))
# Run folders for other input sizes (reference only):
# 256x256 -> 2025-03-02-16-06-29
# 352x352 -> 2025-03-14-08-00-53_ducknet

# Load every coast (no `name` filter, unlike the training cell) and split it.
# Named `coast_split` so the evaluation loop's per-split variable cannot clobber it.
data_path = os.path.abspath(os.path.join(os.getcwd(), "../../data/processed/"))

coast_data = CoastData(data_path)
coast_split = coast_data.split_data()

data = DataLoaderManager.load_data(coast_split)

CoastData: global - 1717 images
Coast: agrelo, Total size: 244
Coast: arenaldentem, Total size: 40
Coast: cadiz, Total size: 946
Coast: cies, Total size: 430
Coast: samarador, Total size: 57


In [15]:
# Evaluate the loaded model on every split with patch-based (sliding-window) inference
# and report per-class metrics.
# NOTE(review): `metrics` comes from the previous cell and presumably accumulates
# across `update_metrics` calls — re-run that cell before re-running this one.
# NOTE(review): mask pixel values are used directly as class labels; confirm the stored
# masks encode classes as 0..num_classes-1.
PATCH_SIZE = 256
PATCH_STRIDE = 128

for split_name in data:
    print(f"Split: {split_name}")
    image_paths = data[split_name]["images"]
    mask_paths = data[split_name]["masks"]
    # zip() would silently drop unpaired files; fail loudly instead.
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"{split_name}: {len(image_paths)} images but {len(mask_paths)} masks"
        )

    for img_path, mask_path in zip(image_paths, mask_paths):
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            pred = model.predict_patch(
                img_path, combination="max", patch_size=PATCH_SIZE, stride=PATCH_STRIDE
            )
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None (rather than raising) for missing/unreadable files.
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        mask = torch.from_numpy(mask).float()
        metrics[split_name].update_metrics(pred, mask)

    metrics[split_name].compute()
    print(metrics[split_name].get_last_epoch_info())

Split: train


  num_correct = mask.new_zeros(num_classes).scatter_(0, target, mask, reduce="add")


train metrics: 
	train_accuracy: tensor([1.0000, 0.2753, 0.3915])
	train_f1_score: tensor([0.5453, 0.4282, 0.5420])
	train_precision: tensor([0.3748, 0.9640, 0.8805])
	train_recall: tensor([1.0000, 0.2753, 0.3915])
	train_confusion_matrix: 
		1.0000 0.0000 0.0000
		0.6725 0.2753 0.0523
		0.5980 0.0105 0.3915

Split: validation
val metrics: 
	val_accuracy: tensor([1.0000, 0.2691, 0.3602])
	val_f1_score: tensor([0.5345, 0.4211, 0.5099])
	val_precision: tensor([0.3647, 0.9673, 0.8728])
	val_recall: tensor([1.0000, 0.2691, 0.3602])
	val_confusion_matrix: 
		1.0000 0.0000 0.0000
		0.6796 0.2691 0.0513
		0.6305 0.0093 0.3602

Split: test
test metrics: 
	test_accuracy: tensor([1.0000, 0.2528, 0.3663])
	test_f1_score: tensor([0.5301, 0.3997, 0.5136])
	test_precision: tensor([0.3606, 0.9531, 0.8591])
	test_recall: tensor([1.0000, 0.2528, 0.3663])
	test_confusion_matrix: 
		1.0000 0.0000 0.0000
		0.6889 0.2528 0.0583
		0.6209 0.0128 0.3663

