In [1]:
# ===========================
# Import Libraries
# ===========================

# Deep Learning & Torch Utilities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader, Subset

# Segmentation Models
import segmentation_models_pytorch as smp

# Image Processing & Computer Vision
import cv2
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Scientific Computing & Morphological Operations
from scipy.ndimage import binary_closing, generate_binary_structure

# Utilities
import os
from pathlib import Path
from tqdm import tqdm

  check_for_updates()


In [None]:
class CustomDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index])
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        for i in range(mask.shape[0]):  # Parcourir les lignes
            for j in range(mask.shape[1]):  # Parcourir les colonnes
                if 127 < mask[i, j] :  # Si la valeur du pixel est 255
                    mask[i, j] = 1  # Remplace par 1
                else:
                    mask[i, j] = 0  # Remplace par 0


        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [None]:
def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    train_ds = CustomDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = CustomDataset(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        transform=val_transform,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

In [13]:
def check_accuracy(loader, model, threshold=0.5, device="cuda"):
    model.eval()
    num_correct = 0
    num_samples = 0
    dice_scores = []
    Jaccard_scores = []
    Recall_scores = []
    Precision_scores = []
    ConfmIndex_scores = []

    with torch.no_grad():
        for x, y in loader:
            x = x.float().to(device)
            y = y.float().unsqueeze(1).to(device)

            # Prédictions
            preds = model(x)

            # Calcul of metrics
            preds_binarized = (torch.sigmoid(preds) > threshold).float()
            FP = (preds_binarized * (1 - y)).sum()  
            FN = (y * (1 - preds_binarized)).sum()
            union = preds_binarized.sum() + y.sum() - (preds_binarized * y).sum() 
            intersection = (preds_binarized * y).sum()
            somme = preds_binarized.sum() + y.sum()
            dice = (2.0 * intersection) / (somme + 1e-6)
            Jaccard = intersection / (union + 1e-6)
            Recall = intersection / (intersection + FN + 1e-6)
            Precision = intersection / (intersection + FP + 1e-6)
            ConfmIndex = 1 - (FP + FN) / (intersection + 1e-6)
            Jaccard_scores.append(Jaccard.item())
            dice_scores.append(dice.item())
            Recall_scores.append(Recall.item())
            Precision_scores.append(Precision.item())
            ConfmIndex_scores.append(ConfmIndex.item())

            # For the accuracy calculation
            num_correct += (preds_binarized == y).sum()
            num_samples += preds.numel()

    # Moyenne des scores de Dice
    avg_ConfmIndex = sum(ConfmIndex_scores) / len(ConfmIndex_scores)
    avg_precision = sum(Precision_scores) / len(Precision_scores)
    avg_recall = sum(Recall_scores) / len(Recall_scores)
    avg_jaccard = sum(Jaccard_scores) / len(Jaccard_scores)
    avg_dice = sum(dice_scores) / len(dice_scores)
    acc = num_correct / num_samples

    print(f"Accuracy: {acc * 100:.2f}%, Average Dice: {avg_dice:.4f}, Average Jaccard: {avg_jaccard:.4f}, Average Recall: {avg_recall:.4f}, Average Precision: {avg_precision:.4f}, Average ConfmIndex: {avg_ConfmIndex:.4f}")
    model.train()


def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.float().to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")
        break

    model.train()


def save_predictions_as_imgs_Post(
    loader, model, size_closing, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.float().to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float().cpu().numpy()  # Move to CPU for processing
            
        # Define a 4x4 structuring element
        structuring_element = np.ones((size_closing, size_closing), dtype=bool)
        
        # Apply morphological closing
        preds_closed = torch.tensor([
            binary_closing(pred[0], structure=structuring_element)  # Apply closing to each prediction
            for pred in preds
        ]).unsqueeze(1)  # Add channel dimension back
        
        # Save predictions and ground truth
        torchvision.utils.save_image(
            torch.tensor(preds_closed).float(), f"{folder}/pred_{idx}_closed.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")
        break

    model.train()


  from scipy.ndimage.morphology import generate_binary_structure


In [4]:
device = torch.device("cuda:3")
weights_path = "unet_weights_v_2.pth"

# Charger le modèle (assurez-vous que sa structure est identique à celle utilisée pour sauvegarder les poids)
model = smp.Unet(
    encoder_name="resnet34",  # Assurez-vous que cela correspond à l'entraînement
    encoder_weights=None,    # Pas de poids pré-entraînés pour l'encodeur
    in_channels=3,           # Entrée avec 3 canaux (par ex. RGB)
    classes=1                # Une classe pour la segmentation
)

# Charger les poids dans le modèle
model.load_state_dict(torch.load(weights_path, map_location=device))

# Envoyer le modèle sur GPU (si disponible)
model = model.to(device)



  model.load_state_dict(torch.load(weights_path, map_location=device))


In [5]:
BATCH_SIZE = 10000
NUM_WORKERS = 2
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "./EBHI-SEG-Segmentation-train/image/"
TRAIN_MASK_DIR = "./EBHI-SEG-Segmentation-train/label/"
VAL_IMG_DIR = "./EBHI-SEG-Segmentation-test-by-class/Low-grade IN/image/"
VAL_MASK_DIR = "./EBHI-SEG-Segmentation-test-by-class/Low-grade IN/label/"


train_transform = A.Compose(
        [   A.Resize(256, 256),
            ToTensorV2(),
        ],
    )

val_transforms = A.Compose(
        [   A.Resize(256, 256),
            ToTensorV2(),
        ],
    )

_ , val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        1,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

In [7]:
check_accuracy(val_loader, model, device=device)

Accuracy: 96.30%, Average Dice: 0.9676, Average Jaccard: 0.9373, Average Recall: 0.9630, Average Precision: 0.9723, Average ConfmIndex: 0.9331


In [16]:
save_predictions_as_imgs(
    val_loader, model, folder="saved_images/Results of segmentation v_2/Low-grade IN/", device=device
)

In [15]:
save_predictions_as_imgs_Post(
    val_loader, model, 6, folder="saved_images/", device=device)

  preds_closed = torch.tensor([
  torch.tensor(preds_closed).float(), f"{folder}/pred_{idx}_closed.png"
