# Librerias

In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader

from sklearn.metrics import precision_score, recall_score, f1_score

import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
import glob
import math
import random
from tqdm import tqdm
import yaml
import cv2
import time

import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType

# Modelo 

In [13]:
class STELLE_Seg(nn.Module):
    def __init__(self):
        super(STELLE_Seg,self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3,padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(),
            nn.MaxPool2d(2), #H/2

            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),  # H/4

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),  # H/8
        )
        self.middle = nn.Sequential(
            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),

            nn.Conv2d(in_channels=128,out_channels=64,kernel_size=3,padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
        )
        self.decoder = nn.Sequential(

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),  # H/4

            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),  # H/2

            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),  # H
        )

        self.segmentator = nn.Conv2d(16, 14, kernel_size=1)

    def forward(self,x):
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = self.segmentator(x)
        return x # [B, num_classes, H, W]

# Dataset

In [14]:
class STELLE_Data(Dataset):
    def __init__(self, yaml_path, split='train',img_size=(144, 160)):
        with open(yaml_path, 'r') as f:
            data = yaml.safe_load(f)

        assert split in ['train', 'val', 'test'], "Split debe ser 'train', 'val' o 'test'"

        self.image_dir = data[split]
        self.label_dir = self.image_dir.replace('/images', '/labels')
        self.img_size = img_size
        self.num_classes = data['nc']
        self.class_names = data['names']

        self.image_files = [f for f in os.listdir(self.image_dir) if f.endswith(('.jpg', '.png'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_file = self.image_files[idx]
        image_path = os.path.join(self.image_dir, img_file)
        label_path = os.path.join(self.label_dir, os.path.splitext(img_file)[0] + ".txt")

        image = Image.open(image_path).convert("L")
        image = image.resize((self.img_size[1], self.img_size[0]))
        image_tensor = transforms.ToTensor()(image)

        h, w = self.img_size
        mask = np.zeros(self.img_size, dtype=np.uint8)

        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    # Se espera: al menos 1 token para clase + 6 tokens para 3 puntos (mínimo)
                    if len(parts) < 7:
                        continue
    
                    # El primer token es el ID de clase
                    class_id = int(parts[0])
                    # Extraer las coordenadas normalizadas (x, y)
                    coords = list(map(float, parts[1:]))
                    if len(coords) % 2 != 0:
                        # Si no son pares, la línea no es válida
                        continue
    
                    # Convertir la lista de coordenadas a un array y dar forma (n_puntos, 2)
                    pts = np.array(coords, dtype=np.float32).reshape(-1, 2)
                    # Convertir coordenadas normalizadas a píxeles
                    pts[:, 0] *= w  # x: ancho
                    pts[:, 1] *= h  # y: alto
                    pts = pts.astype(np.int32)
                    # Asegurarse que los puntos queden dentro de la imagen
                    pts[:, 0] = np.clip(pts[:, 0], 0, w - 1)
                    pts[:, 1] = np.clip(pts[:, 1], 0, h - 1)
    
                    # Dibujar el polígono en la máscara usando el valor del ID de clase
                    cv2.fillPoly(mask, [pts], color=class_id+1)
    
        mask_tensor = torch.from_numpy(mask).long()
        return image_tensor, mask_tensor

In [15]:
dataset =STELLE_Data("dataset/data.yaml",split='test')
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [41]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = STELLE_Seg()
model.to(device)
checkpoint = torch.load("weights/STELLE_Seg.pth", map_location=device)
model.load_state_dict(checkpoint)
model.eval()

STELLE_Seg(
  (encoder): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.01)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.01)
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (middle): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1

# Evaluacion

In [38]:
def pixel_accuracy_multiclass(pred, target):
    # pred: [B, C, H, W] logits o probs
    # target: [B, H, W] con labels 0..C-1 (clase correcta por pixel)
    pred_labels = pred.argmax(dim=1)
    correct = (pred_labels == target).float()
    return correct.sum() / correct.numel()

def iou_multiclass(pred, target, num_classes=14):
    # pred: [B, C, H, W]
    # target: [B, H, W]
    pred_labels = pred.argmax(dim=1)
    ious = []
    for cls in range(num_classes):
        pred_mask = (pred_labels == cls)
        target_mask = (target == cls)
        intersection = (pred_mask & target_mask).float().sum()
        union = (pred_mask | target_mask).float().sum()
        if union == 0:
            ious.append(torch.tensor(1.0))  # clase no presente -> IoU perfecta
        else:
            ious.append(intersection / union)
    return ious # Return list for per-class breaking down later

def precision_multiclass(pred, target, num_classes=14):
    pred_labels = pred.argmax(dim=1)
    precisions = []
    for cls in range(num_classes):
        pred_mask = (pred_labels == cls)
        target_mask = (target == cls)
        tp = (pred_mask & target_mask).sum().float()
        fp = (pred_mask & (~target_mask)).sum().float()
        precisions.append(tp / (tp + fp + 1e-8))
    return precisions # Return list

def recall_multiclass(pred, target, num_classes=14):
    pred_labels = pred.argmax(dim=1)
    recalls = []
    for cls in range(num_classes):
        pred_mask = (pred_labels == cls)
        target_mask = (target == cls)
        tp = (pred_mask & target_mask).sum().float()
        fn = ((~pred_mask) & target_mask).sum().float()
        recalls.append(tp / (tp + fn + 1e-8))
    return recalls # Return list

def f1_score_multiclass(precisions, recalls): # Modified to take lists of prec/rec
    f1s = []
    for p, r in zip(precisions, recalls):
        f1s.append(2 * (p * r) / (p + r + 1e-8))
    return f1s # Return list

In [43]:
def evaluate_model_multiclass_per_class(model, dataset, device, num_classes=14):
    model.eval()

    overall_pixel_accuracies = []

    per_class_ious = [[] for _ in range(num_classes)]
    per_class_precisions = [[] for _ in range(num_classes)]
    per_class_recalls = [[] for _ in range(num_classes)]
    per_class_f1s = [[] for _ in range(num_classes)]

    # New: To store total pixel count for each class across the dataset
    total_class_pixels = torch.zeros(num_classes, dtype=torch.long)

    for i in tqdm(range(len(dataset))):
        img, map_gt, *_ = dataset[i]
        img = img.unsqueeze(0).to(device)

        if map_gt.ndim == 3 and map_gt.shape[0] == num_classes:
            target = map_gt.argmax(dim=0).to(device)
        else:
            target = map_gt.to(device)

        target_cpu = target.cpu() # Move target to CPU once for all calculations

        with torch.no_grad():
            output = model(img)

        output = output.cpu()

        # Accumulate pixel counts for each class from the ground truth
        # Note: map_gt_flat should be 1D with class labels
        unique_labels, counts = torch.unique(target_cpu, return_counts=True)
        for label, count in zip(unique_labels, counts):
            if 0 <= label < num_classes: # Ensure label is within expected range
                total_class_pixels[label] += count.item()


        pixel_acc = pixel_accuracy_multiclass(output, target_cpu)
        overall_pixel_accuracies.append(pixel_acc.item())

        ious_img = iou_multiclass(output, target_cpu, num_classes)
        prec_img = precision_multiclass(output, target_cpu, num_classes)
        rec_img = recall_multiclass(output, target_cpu, num_classes)
        f1_img = f1_score_multiclass(prec_img, rec_img)

        for cls_idx in range(num_classes):
            per_class_ious[cls_idx].append(ious_img[cls_idx].item())
            per_class_precisions[cls_idx].append(prec_img[cls_idx].item())
            per_class_recalls[cls_idx].append(rec_img[cls_idx].item())
            per_class_f1s[cls_idx].append(f1_img[cls_idx].item())

    mean_pixel_acc = np.mean(overall_pixel_accuracies)

    mean_per_class_ious = [np.mean(cls_scores) if cls_scores else 0.0 for cls_scores in per_class_ious] # Handle empty lists for IoU
    mean_per_class_precisions = [np.mean(cls_scores) if cls_scores else 0.0 for cls_scores in per_class_precisions]
    mean_per_class_recalls = [np.mean(cls_scores) if cls_scores else 0.0 for cls_scores in per_class_recalls]
    mean_per_class_f1s = [np.mean(cls_scores) if cls_scores else 0.0 for cls_scores in per_class_f1s]

    # Calculate overall mean (macro-average) across all classes (unweighted)
    overall_mean_iou_unweighted = np.mean(mean_per_class_ious)
    overall_mean_precision_unweighted = np.mean(mean_per_class_precisions)
    overall_mean_recall_unweighted = np.mean(mean_per_class_recalls)
    overall_mean_f1_unweighted = np.mean(mean_per_class_f1s)

    # Calculate WEIGHTED Macro-averages for Precision, Recall, F1
    # Use total_class_pixels as weights. Normalize weights.
    total_pixels_sum = total_class_pixels.sum().item()
    if total_pixels_sum == 0:
        # Avoid division by zero if dataset is empty or all classes are truly empty
        weighted_precision = 0.0
        weighted_recall = 0.0
        weighted_f1 = 0.0
    else:
        # Filter out classes that never appeared (total_class_pixels[cls_idx] == 0)
        # to avoid multiplying by zero weight for non-existent classes.
        # Ensure that scores from non-existent classes (which might be 0.0 if cls_scores is empty)
        # don't get multiplied by a zero weight if they actually contributed a zero score due to 0 TP/FP/FN.
        # It's better to exclude classes that truly have 0 ground truth pixels.

        weighted_precisions_list = []
        weighted_recalls_list = []
        weighted_f1s_list = []
        valid_weights = []

        for cls_idx in range(num_classes):
            if total_class_pixels[cls_idx] > 0: # Only include classes that actually have pixels in the ground truth
                weighted_precisions_list.append(mean_per_class_precisions[cls_idx] * total_class_pixels[cls_idx].item())
                weighted_recalls_list.append(mean_per_class_recalls[cls_idx] * total_class_pixels[cls_idx].item())
                weighted_f1s_list.append(mean_per_class_f1s[cls_idx] * total_class_pixels[cls_idx].item())
                valid_weights.append(total_class_pixels[cls_idx].item())

        if sum(valid_weights) > 0:
            weighted_precision = sum(weighted_precisions_list) / sum(valid_weights)
            weighted_recall = sum(weighted_recalls_list) / sum(valid_weights)
            weighted_f1 = sum(weighted_f1s_list) / sum(valid_weights)
        else:
            weighted_precision = 0.0
            weighted_recall = 0.0
            weighted_f1 = 0.0


    return {
        'overall_pixel_acc': mean_pixel_acc,
        'overall_iou_unweighted': overall_mean_iou_unweighted, # Renamed for clarity
        'overall_precision_unweighted': overall_mean_precision_unweighted,
        'overall_recall_unweighted': overall_mean_recall_unweighted,
        'overall_f1_unweighted': overall_mean_f1_unweighted,
        'overall_precision_weighted': weighted_precision, # New: Weighted Macro Precision
        'overall_recall_weighted': weighted_recall,       # New: Weighted Macro Recall
        'overall_f1_weighted': weighted_f1,               # New: Weighted Macro F1
        'per_class_iou': mean_per_class_ious,
        'per_class_precision': mean_per_class_precisions,
        'per_class_recall': mean_per_class_recalls,
        'per_class_f1': mean_per_class_f1s,
        'total_class_pixels_in_testset': total_class_pixels.tolist() # New: Return pixel counts
    }


In [22]:
def measure_inference_time(model, dataset, device, warmup=5):
    model.eval()
    times = []

    for i in range(len(dataset)):
        img, *_ = dataset[i]
        img = img.unsqueeze(0).to(device)  # (1, 1, H, W)

        if i < warmup:
            with torch.no_grad():
                model(img)
            continue

        start = time.time()
        with torch.no_grad():
            model(img)
        end = time.time()
        times.append(end - start)

    avg_time = np.mean(times)
    return avg_time, 1.0 / avg_time

In [44]:
results = evaluate_model_multiclass_per_class(model, dataset, device, num_classes=14)

print(f"Overall Pixel Accuracy: {results['overall_pixel_acc']:.4f}")
print(f"Overall Mean IoU (Unweighted): {results['overall_iou_unweighted']:.4f}") # Renamed
print(f"Overall Mean Precision (Unweighted): {results['overall_precision_unweighted']:.4f}") # Renamed
print(f"Overall Mean Recall (Unweighted): {results['overall_recall_unweighted']:.4f}") # Renamed
print(f"Overall Mean F1 Score (Unweighted): {results['overall_f1_unweighted']:.4f}") # Renamed

print(f"\nOverall Precision (Weighted by Pixels): {results['overall_precision_weighted']:.4f}") # NEW
print(f"Overall Recall (Weighted by Pixels): {results['overall_recall_weighted']:.4f}")       # NEW
print(f"Overall F1 Score (Weighted by Pixels): {results['overall_f1_weighted']:.4f}")         # NEW

print("\n--- Per-Class Metrics ---")
for cls_idx in range(14):
    print(f"Class {cls_idx}:")
    print(f"  IoU: {results['per_class_iou'][cls_idx]:.4f}")
    print(f"  Precision: {results['per_class_precision'][cls_idx]:.4f}")
    print(f"  Recall: {results['per_class_recall'][cls_idx]:.4f}")
    print(f"  F1 Score: {results['per_class_f1'][cls_idx]:.4f}")

print("\n--- Total Pixels per Class in Testset ---") # NEW
for cls_idx, count in enumerate(results['total_class_pixels_in_testset']):
    print(f"Class {cls_idx}: {count} pixels")


100%|████████████████████████████████████████████████████████████████████████████████| 132/132 [00:02<00:00, 53.12it/s]

Overall Pixel Accuracy: 0.9564
Overall Mean IoU (Unweighted): 0.8970
Overall Mean Precision (Unweighted): 0.3409
Overall Mean Recall (Unweighted): 0.3366
Overall Mean F1 Score (Unweighted): 0.3347

Overall Precision (Weighted by Pixels): 0.7393
Overall Recall (Weighted by Pixels): 0.7445
Overall F1 Score (Weighted by Pixels): 0.7401

--- Per-Class Metrics ---
Class 0:
  IoU: 0.8807
  Precision: 0.8689
  Recall: 0.8653
  F1 Score: 0.8663
Class 1:
  IoU: 0.9366
  Precision: 0.3457
  Recall: 0.3480
  F1 Score: 0.3467
Class 2:
  IoU: 0.8280
  Precision: 0.3001
  Recall: 0.2753
  F1 Score: 0.2785
Class 3:
  IoU: 0.8025
  Precision: 0.8478
  Recall: 0.8883
  F1 Score: 0.8571
Class 4:
  IoU: 0.9820
  Precision: 0.0452
  Recall: 0.0428
  F1 Score: 0.0439
Class 5:
  IoU: 0.8760
  Precision: 0.2701
  Recall: 0.2158
  F1 Score: 0.2332
Class 6:
  IoU: 0.8994
  Precision: 0.8593
  Recall: 0.8787
  F1 Score: 0.8658
Class 7:
  IoU: 0.7496
  Precision: 0.4087
  Recall: 0.4231
  F1 Score: 0.4118
Class 




In [24]:
avg_time, fps = measure_inference_time(model, dataset, device=device)

print("\n=== Tiempo de inferencia ===")
print(f"Tiempo medio por imagen: {avg_time:.4f} segundos")
print(f"Imágenes por segundo (FPS): {fps:.2f}")


=== Tiempo de inferencia ===
Tiempo medio por imagen: 0.0010 segundos
Imágenes por segundo (FPS): 969.01
