<a href="https://colab.research.google.com/github/Lukas-Pupelis/Kursinis/blob/main/Kursinis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import cv2 as cv
import torch
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
import pandas as pd
import torch.nn.functional as F
from tensorflow.keras import backend as K
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score
from scipy.stats import ttest_rel
from torch.utils.data import Dataset, DataLoader
import seaborn as sns
!pip install pretrainedmodels
!pip install efficientnet_pytorch



In [None]:
# Mount Google Drive inside the Colab runtime; the data, weight and model
# paths configured in the next cell all live under /content/drive.
from google.colab import drive
# NOTE(review): the return value of drive.mount is stored in drive_dir, but no
# visible cell reads it — confirm whether the name is needed.
drive_dir = drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# --- Dataset locations on Google Drive --------------------------------------
data_dir = "/content/drive/My Drive/mp"
images_dir = os.path.join(data_dir, "images")                # original images
masks_dir = os.path.join(data_dir, "masks")                  # original masks
resized_images_dir = os.path.join(data_dir, "resized_images")
resized_masks_dir = os.path.join(data_dir, "resized_masks")
inverted_resized_masks_dir = os.path.join(data_dir, "inverted_resized_masks")

# --- Model sources and checkpoints ------------------------------------------
# root_dir is the folder added to sys.path; model_dir is the package inside it.
root_dir = "/content/drive/My Drive/MP-Net/src/segmentation_models"
model_dir = "/content/drive/My Drive/MP-Net/src/segmentation_models/segmentation_models_pytorch"
unet_dir = os.path.join(model_dir, "unet")
weight_dir = "/content/drive/My Drive/clam_weights/pre-trained_weights"

# U-Net model

In [None]:
import sys
from importlib import import_module

def create_model():
    """
    Build the U-Net (ResNet-101 encoder) and load its saved weights.

    ``root_dir`` is appended to ``sys.path`` (once) so that the
    ``segmentation_models_pytorch`` package stored there can be imported, and
    the ``Unet`` class is looked up dynamically in its ``unet.model`` module.
    The state dict is read from ``unet4.pth`` inside ``weight_dir`` and mapped
    to the CPU.

    Returns:
    --------
    model
        The ``Unet`` instance (3 input channels, 1 output class, no encoder
        weights downloaded) with the checkpoint loaded, switched to
        ``eval()`` mode.

    Raises:
    -------
    FileNotFoundError
        If ``unet_dir`` is not an existing directory.
    ModuleNotFoundError
        If ``segmentation_models_pytorch.unet.model`` cannot be imported.
    AttributeError
        If that module does not define ``Unet``.
    """
    if root_dir not in sys.path:
        sys.path.append(root_dir)

    # Raise real exceptions instead of print + sys.exit(): inside a notebook
    # SystemExit hides the traceback and cannot be handled by `except Exception`.
    if not os.path.isdir(unet_dir):
        raise FileNotFoundError(
            f"Error: The 'unet' directory was not found at '{unet_dir}'."
        )

    # Import and attribute lookup are guarded separately so that an
    # AttributeError raised while the module itself is being executed is not
    # misreported as "class not found".
    try:
        unet_module = import_module('segmentation_models_pytorch.unet.model')
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(f"Error importing the Unet model: {e}") from e

    try:
        UNet = getattr(unet_module, 'Unet')
    except AttributeError as e:
        raise AttributeError(
            "Error: The 'Unet' class was not found in 'model.py'."
        ) from e

    model = UNet(
        encoder_name="resnet101",
        encoder_weights=None,  # weights come from the checkpoint below
        in_channels=3,
        classes=1,
    )

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint_path = os.path.join(weight_dir, 'unet4.pth')
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

    model.eval()

    return model

# Resizing and saving to Drive

In [None]:
# Resize every image and mask to 256x256 and store the copies on Drive.
# Source and destination folders are paired explicitly instead of guessing the
# destination from an "images" substring in the source path.
for src_dir, dst_dir in [(images_dir, resized_images_dir), (masks_dir, resized_masks_dir)]:
    # cv.imwrite does not create missing folders, so make sure the target exists.
    os.makedirs(dst_dir, exist_ok=True)

    for img_name in os.listdir(src_dir):
        src_path = os.path.join(src_dir, img_name)
        img = cv.imread(src_path)
        if img is None:
            # Not a readable image (e.g. a stray non-image file): skip it
            # rather than crashing inside cv.resize.
            print(f"[WARN] Skipping unreadable file: {src_path}")
            continue

        resized_img = cv.resize(img, (256, 256), interpolation=cv.INTER_AREA)

        dst_path = os.path.join(dst_dir, img_name)
        # imwrite reports failure through its return value only.
        if not cv.imwrite(dst_path, resized_img):
            raise IOError(f"Could not write resized image to {dst_path}")
# Loading

In [None]:
def preprocess_mask(mask_gray, invert_mask, mask_convention):
    """
    Binarize a grayscale mask and normalise it to the black-foreground layout.

    Parameters:
    -----------
    mask_gray : np.ndarray
        Grayscale mask. If it contains any value other than 0 or 255 it is
        thresholded first (values above 128 become 255, the rest 0) and a
        notice is printed.
    invert_mask : bool
        When truthy, 0 and 255 are swapped once more after the convention
        handling.
    mask_convention : {"black_fg", "white_fg"}
        Layout of the incoming mask:
        - "black_fg": black = foreground, white = background (kept as is).
        - "white_fg": white = foreground, black = background (flipped).

    Returns:
    --------
    mask_bin : np.ndarray
        Mask holding only 0 and 255. With ``invert_mask`` false the foreground
        is black (0) and the background white (255); with ``invert_mask`` true
        the two are swapped.

    Raises:
    -------
    ValueError
        If ``mask_convention`` is neither "black_fg" nor "white_fg".
    """
    # 1) Force the mask to the two values 0 / 255.
    already_binary = np.isin(mask_gray, [0, 255]).all()
    if not already_binary:
        print("[INFO] Non-binary mask detected. Thresholding...")
        mask_gray = np.where(mask_gray > 128, 255, 0).astype(np.uint8)

    # 2) Reject unknown conventions, then bring "white_fg" input to black-FG.
    if mask_convention not in ("black_fg", "white_fg"):
        raise ValueError(f"Invalid mask_convention: {mask_convention}. Use 'black_fg' or 'white_fg'.")
    if mask_convention == "white_fg":
        mask_gray = 255 - mask_gray

    # 3) Optional extra flip requested by the caller.
    if invert_mask:
        mask_gray = 255 - mask_gray

    return mask_gray

In [None]:
class SegmentationDatasetRGB(Dataset):
    """
    Image/mask pairs read from two folders with OpenCV.

    File names are taken from ``images_dir`` (sorted); the mask for an image is
    looked up under the same file name in ``masks_dir``. Each mask is converted
    to grayscale and passed through ``preprocess_mask`` with the configured
    ``invert_mask`` / ``mask_convention``.

    Each item is a pair of float32 tensors: the RGB image as (C, H, W) and the
    processed mask as (1, H, W) with values 0 / 255.
    """

    def __init__(
        self,
        images_dir,
        masks_dir,
        invert_mask=False,
        mask_convention="black_fg",
        transform=None,
        debug_print=False
    ):
        super().__init__()
        self.images_dir, self.masks_dir = images_dir, masks_dir
        self.invert_mask = invert_mask
        # Tolerate different casing / surrounding blanks in the convention name.
        self.mask_convention = mask_convention.lower().strip()
        self.transform = transform
        self.debug_print = debug_print

        # Sorted so that the index -> file mapping is stable.
        self.image_names = sorted(os.listdir(images_dir))

    def __len__(self):
        return len(self.image_names)

    @staticmethod
    def _read_color(path, what):
        """Read ``path`` as a 3-channel BGR array or raise FileNotFoundError."""
        pixels = cv.imread(path, cv.IMREAD_COLOR)
        if pixels is None:
            raise FileNotFoundError(f"Could not read {what} at {path}")
        return pixels

    @staticmethod
    def _show_debug(file_name, img_tensor, mask_tensor, mask_gray, mask_bin):
        """Print tensor info and plot the mask before and after processing."""
        print(f"[DEBUG] Loading: {file_name}")
        print("  image shape:", img_tensor.shape)
        print("  mask shape: ", mask_tensor.shape)
        print("  unique mask values:", torch.unique(mask_tensor))

        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.imshow(mask_gray, cmap="gray")
        plt.title("Original Grayscale Mask")
        plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow(mask_bin, cmap="gray", vmin=0, vmax=255)
        plt.title("Processed Binary Mask (0=Black FG, 255=White BG)")
        plt.axis("off")

        plt.tight_layout()
        plt.show()

    def __getitem__(self, idx):
        file_name = self.image_names[idx]
        image_path = os.path.join(self.images_dir, file_name)
        mask_path = os.path.join(self.masks_dir, file_name)

        # Image: OpenCV loads BGR, the rest of the pipeline works on RGB.
        img_rgb = cv.cvtColor(self._read_color(image_path, "image"), cv.COLOR_BGR2RGB)

        # Mask: loaded as colour, reduced to one channel, then normalised.
        mask_gray = cv.cvtColor(self._read_color(mask_path, "mask"), cv.COLOR_BGR2GRAY)
        mask_bin = preprocess_mask(mask_gray, self.invert_mask, self.mask_convention)

        # Optional augmentation called with image= / mask= keywords.
        if self.transform:
            augmented = self.transform(image=img_rgb, mask=mask_bin)
            img_rgb, mask_bin = augmented["image"], augmented["mask"]

        # HWC -> CHW for the image; add a channel axis to the mask.
        img_tensor = torch.tensor(img_rgb, dtype=torch.float32).permute(2, 0, 1)
        mask_tensor = torch.tensor(mask_bin, dtype=torch.float32).unsqueeze(0)

        # Debug output is limited to the item at index 0.
        if self.debug_print and idx < 1:
            self._show_debug(file_name, img_tensor, mask_tensor, mask_gray, mask_bin)

        return img_tensor, mask_tensor

In [None]:
def get_dataloader(
    images_dir,
    masks_dir,
    invert_mask=False,
    mask_convention="black_fg",
    batch_size=32,
    debug_print=False,
    transform=None,
    shuffle=True
):
    """
    Creates a DataLoader for the SegmentationDatasetRGB.

    Parameters:
    -----------
    images_dir, masks_dir : str
        Folders with the images and the masks.
    invert_mask : bool
        Forwarded to the dataset (extra 0 <-> 255 flip of the mask).
    mask_convention : {"black_fg", "white_fg"}
        Forwarded to the dataset (layout of the masks on disk).
    batch_size : int
        Number of samples per batch.
    debug_print : bool
        Forwarded to the dataset (debug output for the first item).
    transform : callable or None
        Optional augmentation forwarded to the dataset; it is called as
        ``transform(image=..., mask=...)``. Defaults to None (no augmentation),
        which is what this function previously hard-coded.
    shuffle : bool
        Whether the DataLoader shuffles the samples. Defaults to True, the
        previously hard-coded value; pass False for a reproducible order.

    Returns:
    --------
    dataloader : torch.utils.data.DataLoader
    """
    dataset = SegmentationDatasetRGB(
        images_dir=images_dir,
        masks_dir=masks_dir,
        invert_mask=invert_mask,
        mask_convention=mask_convention,
        transform=transform,
        debug_print=debug_print
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Confusion matrix

In [None]:
def compute_confusion_matrix(y_true, y_pred):
    """
    Build the 2x2 confusion matrix of two binary (0/1) masks.

    Both arrays are flattened, so any shape is accepted as long as the two
    shapes match. Rows are the true label, columns the predicted label:
    ``[[TN, FP], [FN, TP]]``.
    """
    truth = y_true.flatten()
    guess = y_pred.flatten()

    def count(true_label, pred_label):
        # Number of pixels with this (true, predicted) label combination.
        return np.sum((truth == true_label) & (guess == pred_label))

    return np.array([
        [count(0, 0), count(0, 1)],
        [count(1, 0), count(1, 1)],
    ])

def plot_confusion_matrix(cm, class_names):
    """
    Draw ``cm`` as an annotated Seaborn heatmap.

    ``class_names`` labels both axes; rows are the true label and columns the
    predicted label. Cell values are formatted as integers.
    """
    _, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title("Aggregated Confusion Matrix")
    plt.show()

# Statistics

In [None]:
def calculate_metrics(y_true, y_pred, threshold=0.5):
    """
    Per-class segmentation scores for one ground-truth / prediction pair.

    Foreground is taken to be the pixels where ``y_true == 0`` and, for the
    prediction, the pixels where ``y_pred <= threshold``. A leading axis of
    length 1 on a 3-D input is dropped first. Every ratio falls back to 0
    when its denominator is 0.

    Returns the tuple
    ``(precision_bg, recall_bg, precision_fg, recall_fg, f1_fg, iou_fg,
    iou_bg, balanced_accuracy)``.
    """
    def ratio(numerator, denominator):
        # Guarded division shared by all scores below.
        return numerator / denominator if denominator > 0 else 0

    # Drop the channel axis of (1, H, W) inputs.
    if y_true.ndim == 3 and y_true.shape[0] == 1:
        y_true = y_true[0]
    if y_pred.ndim == 3 and y_pred.shape[0] == 1:
        y_pred = y_pred[0]

    # 1 marks foreground in both binary maps.
    y_true_bin = (y_true == 0).astype(np.uint8)
    y_pred_bin = (y_pred <= threshold).astype(np.uint8)

    print("Unique values in ground truth (binary):", np.unique(y_true_bin))
    print("Unique values in predicted mask (binary):", np.unique(y_pred_bin))

    pred_fg, pred_bg = (y_pred_bin == 1), (y_pred_bin == 0)
    true_fg, true_bg = (y_true_bin == 1), (y_true_bin == 0)

    TP = (pred_fg & true_fg).sum()
    FP = (pred_fg & true_bg).sum()
    FN = (pred_bg & true_fg).sum()
    TN = (pred_bg & true_bg).sum()

    # Background is scored as the "negative" class.
    precision_bg = ratio(TN, TN + FN)
    recall_bg = ratio(TN, TN + FP)
    precision_fg = ratio(TP, TP + FP)
    recall_fg = ratio(TP, TP + FN)

    f1_fg = ratio(2 * (precision_fg * recall_fg), precision_fg + recall_fg)

    iou_fg = ratio(TP, TP + FP + FN)
    iou_bg = ratio(TN, TN + FP + FN)

    balanced_accuracy = (recall_bg + recall_fg) / 2

    return precision_bg, recall_bg, precision_fg, recall_fg, f1_fg, iou_fg, iou_bg, balanced_accuracy

In [64]:
def test_model_and_calculate_statistics(
    model,
    dataloader,
    device="cpu",
    threshold=0.5
):
    """
    Run ``model`` over every batch of ``dataloader`` and score the predictions.

    For each batch the sigmoid of the model output is taken, every sample is
    plotted (image, ground-truth mask, thresholded prediction), and then every
    sample is scored with ``calculate_metrics`` and added to a running
    confusion matrix built by ``compute_confusion_matrix``.

    Returns ``(mean_metrics, total_cm)``: a dict with the mean of each
    per-sample metric (precision/recall for background and foreground, F1,
    IoU for both classes, balanced accuracy) and the summed 2x2 confusion
    matrix.
    """
    model.eval()
    model.to(device)

    # Same order as the tuple returned by calculate_metrics.
    metric_names = (
        "precision_bg",
        "recall_bg",
        "precision_fg",
        "recall_fg",
        "f1",
        "iou_fg",
        "iou_bg",
        "balanced_acc",
    )
    metrics_dict = {name: [] for name in metric_names}

    # Running sum of the per-sample confusion matrices.
    total_cm = np.zeros((2, 2), dtype=int)

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)

            logits = model(images)                           # (B,1,H,W)
            preds = torch.sigmoid(logits).cpu().numpy()      # probabilities in [0..1]
            masks = masks.cpu().numpy()                      # (B,1,H,W)

            print("Preds min:", preds.min(), "max:", preds.max())
            print("Unique values in ground truth masks:", np.unique(masks))

            # First pass over the batch: one three-panel figure per sample.
            for sample_idx, sample_mask in enumerate(masks):
                plt.figure(figsize=(10, 5))

                plt.subplot(1, 3, 1)
                img_display = images[sample_idx].permute(1, 2, 0).cpu().numpy()
                if img_display.max() > 1:
                    # Pixel values look like 0..255; rescale for imshow.
                    img_display = img_display / 255.0
                plt.imshow(img_display, cmap="gray")
                plt.title("Originali nuotrauka")

                plt.subplot(1, 3, 2)
                plt.imshow(sample_mask[0], cmap="gray")
                plt.title("Etaloninė kaukė")

                plt.subplot(1, 3, 3)
                plt.imshow((preds[sample_idx, 0] > threshold).astype(np.uint8), cmap="gray")
                plt.title("Prognozuojama kaukė")

                plt.show()

            # Second pass over the batch: metrics and confusion matrix.
            for sample_mask, sample_pred in zip(masks, preds):
                scores = calculate_metrics(
                    y_true=sample_mask,
                    y_pred=sample_pred,
                    threshold=threshold
                )
                for name, score in zip(metric_names, scores):
                    metrics_dict[name].append(score)

                # Same foreground definition as calculate_metrics:
                # truth == 0 and probability <= threshold.
                total_cm += compute_confusion_matrix(
                    y_true=(sample_mask == 0).astype(np.uint8),
                    y_pred=(sample_pred <= threshold).astype(np.uint8)
                )

    mean_metrics = {name: float(np.mean(values)) for name, values in metrics_dict.items()}

    return mean_metrics, total_cm

# Visualizing

In [None]:
def visualize_prediction(
    img_tensor,
    mask_tensor,
    pred_tensor,
    threshold=0.5
):
    """
    Plot one sample as three panels: image, ground-truth mask, predicted mask.

    - img_tensor:  tensor of shape (3,H,W); divided by 255 for display when its
      maximum exceeds 1.
    - mask_tensor: tensor of shape (1,H,W); drawn with a gray colormap clipped
      to the range 0..1.
    - pred_tensor: tensor of shape (1,H,W) with probabilities; binarized with
      ``pred > threshold`` before plotting.
    """
    image, truth, prob = (
        t.detach().cpu().numpy() for t in (img_tensor, mask_tensor, pred_tensor)
    )

    # CHW -> HWC for the image, drop the channel axis of the two masks.
    if image.ndim == 3 and image.shape[0] == 3:
        image = image.transpose(1, 2, 0)
    if truth.ndim == 3 and truth.shape[0] == 1:
        truth = truth[0]
    if prob.ndim == 3 and prob.shape[0] == 1:
        prob = prob[0]

    pred_bin = (prob > threshold).astype(np.uint8)

    img_display = image / 255.0 if image.max() > 1 else image

    gray_01 = {"cmap": "gray", "vmin": 0, "vmax": 1}
    panels = (
        (img_display, "Originali nuotrauka", {}),
        (truth, "Teisinga kaukė", gray_01),
        (pred_bin, "Prognozuojama kaukė", gray_01),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15,5))
    for ax, (data, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(data, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

def visualize_batch(
    model,
    dataloader,
    device="cpu",
    threshold=0.5
):
    """
    Plot the first sample of the first batch drawn from ``dataloader``.

    The model is put in eval mode on ``device``, run without gradients, and
    the sigmoid of its output is handed to ``visualize_prediction`` together
    with the matching image and ground-truth mask.
    """
    model.eval()
    model.to(device)

    batch_images, batch_masks = next(iter(dataloader))
    batch_images = batch_images.to(device)

    with torch.no_grad():
        probs = torch.sigmoid(model(batch_images)).cpu().numpy()  # (B,1,H,W)

    first = 0
    visualize_prediction(
        batch_images[first].cpu(),          # (3,H,W)
        batch_masks[first].cpu(),           # (1,H,W)
        torch.from_numpy(probs[first]),     # (1,H,W)
        threshold=threshold,
    )

# Running the evaluation and creating the metrics table

In [None]:
def main():
    """
    Evaluate the U-Net on the resized dataset and report the results.

    Builds the dataloader over ``resized_images_dir`` / ``resized_masks_dir``
    (image and mask files must share their names), loads the model, scores the
    whole dataset, prints the mean metrics and plots the aggregated confusion
    matrix.
    """
    # The masks on disk are white-foreground; "white_fg" makes the dataset
    # flip them to the black-foreground layout, with no additional inversion.
    dataloader = get_dataloader(
        images_dir=resized_images_dir,
        masks_dir=resized_masks_dir,
        invert_mask=False,
        mask_convention="white_fg",
        batch_size=4,
        debug_print=True,  # prints debug info for the first item
    )

    # create_model() already loads the saved weights.
    model = create_model()

    # To eyeball a single sample before the full run:
    #   visualize_batch(model, dataloader, device="cpu", threshold=0.5)

    metrics, aggregated_cm = test_model_and_calculate_statistics(
        model=model,
        dataloader=dataloader,
        device="cpu",
        threshold=0.1,
    )

    print("===== METRICS =====")
    for name, value in metrics.items():
        print(f"{name}: {value:.4f}")

    plot_confusion_matrix(aggregated_cm, class_names=["Background", "Foreground"])


if __name__ == "__main__":
    main()

In [None]:
import pandas as pd
from tabulate import tabulate  # Optional, for a nice CLI table

def print_metrics_table(metrics_dict):
    """
    Print a metrics dictionary as a one-row grid table.

    The dictionary becomes a single-row pandas DataFrame (keys are the column
    headers) which is rendered with ``tabulate`` in "grid" format.
    """
    single_row = pd.DataFrame([metrics_dict])

    print("Metrics Table:")
    rendered = tabulate(single_row, headers="keys", tablefmt="grid")
    print(rendered)


# Example usage: tabulate a hand-entered set of mean scores.
# Two further entries are currently excluded from the table:
#   "Accuracy": 0.9984041124064632
#   "F1-Score": 0.9991957977679148
mean_metrics = dict(
    Precision=0.0910,
    Recall=0.0205,
    IoU=0.0190,
)

# Render the table.
print_metrics_table(mean_metrics)

Metrics Table:
+----+-------------+----------+-------+
|    |   Precision |   Recall |   IoU |
|  0 |       0.091 |   0.0205 | 0.019 |
+----+-------------+----------+-------+
