In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.transforms import functional as F
from tqdm import tqdm

# Global configuration.
# Seeding makes weight initialization, DataLoader shuffling and the synthetic
# Gaussian noise reproducible across Restart & Run All (cuDNN kernels may
# still introduce small nondeterminism on GPU).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:



# Denoising dataset: yields (noisy image, clean image) pairs built on the fly.
class DenoisingDataset(Dataset):
    """Image-denoising dataset that synthesizes noisy inputs on the fly.

    Every supported image file found directly in ``root_dir`` (not searched
    recursively) is loaded as RGB, passed through ``transform``, and paired
    with a copy corrupted by additive Gaussian noise. A fresh noise sample is
    drawn on every access, so reading the same index twice yields two
    different noisy inputs.

    Args:
        root_dir: Directory containing the clean images.
        transform: Optional callable applied to the PIL image. If it is
            ``None`` or does not return a tensor, the image is converted with
            ``to_tensor`` (values scaled to [0, 1]).
        noise_mean: Mean of the Gaussian noise added in ``__getitem__``.
        noise_std: Standard deviation of the Gaussian noise added in
            ``__getitem__``.

    Raises:
        ValueError: If ``root_dir`` contains no supported image file.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, root_dir, transform=None, noise_mean=0.0, noise_std=0.3):
        # Sorted so index -> file mapping is stable across runs
        # (os.listdir returns entries in arbitrary order).
        self.image_paths = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, name))
        )
        if not self.image_paths:
            raise ValueError(f"Aucune image valide trouvée dans le répertoire: {root_dir}")
        self.transform = transform
        self.noise_mean = noise_mean
        self.noise_std = noise_std

    def add_gaussian_noise(self, image, mean=0, std=0.3):
        """Return ``image`` plus N(mean, std^2) noise, clamped to [0, 1]."""
        noise = torch.randn_like(image) * std + mean
        return torch.clamp(image + noise, 0, 1)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(noisy_image, clean_image)`` tensors for the image at ``idx``."""
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file once the RGB copy exists.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        if not isinstance(image, torch.Tensor):
            # add_gaussian_noise needs a tensor; mirror ToTensor's [0, 1] scaling.
            image = F.to_tensor(image)

        noisy_image = self.add_gaussian_noise(image, mean=self.noise_mean, std=self.noise_std)
        return noisy_image, image


# Data configuration (tune here rather than inline).
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 16
TRAIN_DIR = 'archive/images/train'  # BSD500 training split

# Dataset transforms: fixed-size resize, then conversion to a [0, 1] float tensor.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor()
])

# Load the BSD500 training images; shuffling follows the global torch seed.
train_dataset = DenoisingDataset(root_dir=TRAIN_DIR, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)




In [3]:
# Generator: a small fully-convolutional network that maps a noisy RGB image
# to a denoised RGB image of the same spatial size.
class CGNet(nn.Module):
    """Three-layer convolutional generator (3 -> 64 -> 64 -> 3 channels).

    Each convolution uses a 3x3 kernel with stride 1 and padding 1, so the
    output keeps the input's height and width. ReLU follows the first two
    convolutions; the last one has no activation, so outputs are not bounded
    to [0, 1].
    """

    def __init__(self):
        super().__init__()
        same_size = dict(kernel_size=3, stride=1, padding=1)
        layers = [
            nn.Conv2d(3, 64, **same_size),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, **same_size),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, **same_size),
        ]
        # The attribute name `model` and the layer order fix the state_dict
        # keys (model.0 / model.2 / model.4) used by saved checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        denoised = self.model(x)
        return denoised


# Discriminator: scores how likely an image is to be a clean (real) one.
class Discriminator(nn.Module):
    """Convolutional discriminator returning one probability per image.

    Three 4x4, stride-2 convolutions (64, 128, 256 channels; BatchNorm after
    the last two), each followed by LeakyReLU(0.2), are followed by a 4x4,
    stride-1 convolution down to one channel. That map is averaged over its
    spatial dimensions and passed through a sigmoid, giving an output of
    shape (batch, 1) with values in (0, 1). With these kernel/stride/padding
    settings the input height and width must be at least 32 pixels for the
    final convolution to produce a non-empty map.
    """

    def __init__(self):
        super().__init__()

        def downsample(in_ch, out_ch, use_bn):
            # Conv (halves H and W) -> optional BatchNorm -> LeakyReLU.
            stage = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)]
            if use_bn:
                stage.append(nn.BatchNorm2d(out_ch))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        self.model = nn.Sequential(
            *downsample(3, 64, use_bn=False),
            *downsample(64, 128, use_bn=True),
            *downsample(128, 256, use_bn=True),
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=0),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        patch_logits = self.model(x)
        pooled = patch_logits.mean(dim=(2, 3))  # global average pooling -> (batch, 1)
        return self.sigmoid(pooled)




In [None]:
# Model initialization
generator = CGNet().to(device)
discriminator = Discriminator().to(device)

# Training hyperparameters
LEARNING_RATE = 0.0002
ADV_WEIGHT = 0.001  # weight of the adversarial term relative to the pixel-wise MSE
num_epochs = 1

# Losses and optimizers
pixelwise_loss = nn.MSELoss()
adversarial_loss = nn.BCELoss()  # the discriminator already applies a sigmoid
optimizer_G = optim.Adam(generator.parameters(), lr=LEARNING_RATE)
optimizer_D = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)

# Training
for epoch in range(num_epochs):
    generator.train()
    discriminator.train()
    g_loss_sum = 0.0
    d_loss_sum = 0.0
    n_seen = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    for noisy_imgs, clean_imgs in progress:
        noisy_imgs = noisy_imgs.to(device)
        clean_imgs = clean_imgs.to(device)

        # Target labels for the discriminator
        batch_size = noisy_imgs.size(0)
        valid = torch.ones((batch_size, 1), device=device)
        fake = torch.zeros((batch_size, 1), device=device)

        # ----- Generator update -----
        optimizer_G.zero_grad()
        generated_imgs = generator(noisy_imgs)
        g_loss_pixelwise = pixelwise_loss(generated_imgs, clean_imgs)
        g_loss_adversarial = adversarial_loss(discriminator(generated_imgs), valid)
        g_loss = ADV_WEIGHT * g_loss_adversarial + g_loss_pixelwise
        g_loss.backward()
        optimizer_G.step()

        # ----- Discriminator update -----
        # zero_grad also discards the gradients g_loss.backward() left on the
        # discriminator's parameters.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(clean_imgs), valid)
        fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Accumulate sample-weighted losses so the epoch summary is a true mean,
        # not just the last batch's value.
        g_value = g_loss.item()
        d_value = d_loss.item()
        g_loss_sum += g_value * batch_size
        d_loss_sum += d_value * batch_size
        n_seen += batch_size
        progress.set_postfix(g_loss=f"{g_value:.4f}", d_loss=f"{d_value:.4f}")

    n_seen = max(n_seen, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}] | Generator Loss: {g_loss_sum / n_seen:.4f} | Discriminator Loss: {d_loss_sum / n_seen:.4f}")

# Save the trained weights
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')




In [None]:
# Apply several closing-then-opening cycles to each channel of an image.
def apply_multiple_morphology_operations(generated_image, iterations=20, kernel_size=3):
    """Smooth an image with repeated morphological closing followed by opening.

    The image is quantized to 8 bits, then each channel independently goes
    through ``iterations`` cycles of closing followed by opening with a
    ``kernel_size`` x ``kernel_size`` rectangular structuring element.

    Args:
        generated_image: Tensor of shape (C, H, W) with values expected in
            [0, 1]; out-of-range values are clipped before quantization.
        iterations: Number of close/open cycles.
        kernel_size: Side length of the rectangular structuring element.

    Returns:
        CPU tensor of shape (C, H, W), values in [0, 1], with the same dtype
        as ``generated_image``.
    """
    # Clip and round before the uint8 cast: NumPy's float -> uint8 cast is not
    # saturating (out-of-range values wrap or are undefined), and plain
    # truncation biases every value downwards.
    image_chw = generated_image.detach().cpu().float().numpy()
    image_u8 = np.clip(np.rint(image_chw * 255.0), 0, 255).astype(np.uint8)

    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))

    for _ in range(iterations):
        for c in range(image_u8.shape[0]):
            # Channel-first layout: each channel is a 2-D plane for OpenCV.
            channel = np.ascontiguousarray(image_u8[c])
            closed = cv2.morphologyEx(channel, cv2.MORPH_CLOSE, kernel)
            image_u8[c] = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel)

    refined = torch.from_numpy(image_u8.astype(np.float32) / 255.0)
    return refined.to(generated_image.dtype)


# Fuse an image with the high-frequency detail of another.
def fuse_with_high_frequencies(base_image, refined_image, ksize=5):
    """Add the high-frequency detail of ``refined_image`` to ``base_image``.

    High frequencies are taken as ``refined - GaussianBlur(refined)`` (OpenCV
    derives sigma from ``ksize`` because sigmaX is 0). The sum with
    ``base_image`` is clipped to [0, 1].

    Args:
        base_image: Tensor of shape (C, H, W) with values on a [0, 1] scale.
        refined_image: Tensor with the same shape as ``base_image``.
        ksize: Odd side length of the Gaussian kernel.

    Returns:
        CPU tensor of shape (C, H, W), values in [0, 1], with the dtype of
        ``base_image``.

    Raises:
        ValueError: If the two images do not have the same shape.
    """
    if tuple(base_image.shape) != tuple(refined_image.shape):
        raise ValueError(
            f"shape mismatch: {tuple(base_image.shape)} vs {tuple(refined_image.shape)}"
        )

    base_hwc = base_image.detach().cpu().float().permute(1, 2, 0).numpy()
    # OpenCV expects a contiguous H x W x C buffer.
    refined_hwc = np.ascontiguousarray(
        refined_image.detach().cpu().float().permute(1, 2, 0).numpy()
    )

    # reshape restores the channel axis that OpenCV drops for 1-channel input.
    refined_blur = cv2.GaussianBlur(refined_hwc, (ksize, ksize), 0).reshape(refined_hwc.shape)
    high_freq = refined_hwc - refined_blur
    fused_hwc = np.clip(base_hwc + high_freq, 0.0, 1.0)

    fused = torch.from_numpy(np.ascontiguousarray(fused_hwc)).permute(2, 0, 1)
    return fused.to(base_image.dtype)


# Detect edges with the Canny detector.
def detect_canny_edges(image, threshold1=100, threshold2=200):
    """Return the Canny edge map of an RGB image tensor.

    Args:
        image: Tensor of shape (3, H, W), RGB, values expected in [0, 1];
            out-of-range values are clipped before 8-bit quantization.
        threshold1: Lower hysteresis threshold passed to ``cv2.Canny``.
        threshold2: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        float32 CPU tensor of shape (1, H, W): 1.0 on edge pixels, 0.0 elsewhere.
    """
    image_hwc = image.detach().cpu().float().permute(1, 2, 0).numpy()
    # Clip + round: NumPy's float -> uint8 cast does not saturate.
    image_u8 = np.ascontiguousarray(
        np.clip(np.rint(image_hwc * 255.0), 0, 255).astype(np.uint8)
    )
    image_gray = cv2.cvtColor(image_u8, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(image_gray, threshold1=threshold1, threshold2=threshold2)
    # cv2.Canny marks edges with 255 -> map to 1.0.
    return torch.from_numpy(edges.astype(np.float32) / 255.0).unsqueeze(0)


# Display helper: denoise one sample, post-process it, and compare images, edges and PSNR.
def show_sample_with_metrics(generator, dataset, index, iterations=3):
    """Denoise ``dataset[index]``, then plot every stage with its Canny edges and PSNR.

    Steps:
    - Run the generator in eval mode on the noisy image; its previous
      train/eval mode is restored afterwards.
    - Clamp the output to [0, 1].
    - Refine it with ``apply_multiple_morphology_operations`` and fuse the
      result with ``fuse_with_high_frequencies``.
    - Print the PSNR (data range 1.0) of the noisy, generated, refined and
      fused images against the clean image.
    - Show a 2x5 figure: images on the top row, Canny edges on the bottom row.

    Args:
        generator: Model mapping a (1, C, H, W) batch to a batch of the same
            shape (assumed, as with CGNet).
        dataset: Indexable returning ``(noisy, clean)`` CHW tensors in [0, 1]
            (as DenoisingDataset does).
        index: Sample index in ``dataset``.
        iterations: Number of close/open cycles for the morphological refinement.

    Returns:
        dict mapping "Noisy", "Generated", "Refined", "Fused" to PSNR in dB.
    """
    noisy, clean = dataset[index]

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            generated = generator(noisy.unsqueeze(0).to(device)).squeeze(0)
    finally:
        generator.train(was_training)

    # The generator has no output activation; clamp to the [0, 1] range that
    # display, post-processing and PSNR all assume.
    noisy = noisy.cpu()
    clean = clean.cpu()
    generated = generated.clamp(0.0, 1.0).cpu()

    refined = apply_multiple_morphology_operations(generated, iterations=iterations)
    fused = fuse_with_high_frequencies(generated, refined)

    # PSNR against the clean image; explicit data_range avoids dtype-based inference.
    clean_hwc = clean.permute(1, 2, 0).numpy()
    candidates = {"Noisy": noisy, "Generated": generated, "Refined": refined, "Fused": fused}
    psnr_values = {
        name: psnr(clean_hwc, img.permute(1, 2, 0).numpy().astype(clean_hwc.dtype), data_range=1.0)
        for name, img in candidates.items()
    }
    print(", ".join(f"PSNR {name}: {value:.2f}" for name, value in psnr_values.items()))

    panels = [
        ("Noisy Image", noisy),
        ("Clean Image", clean),
        ("Generated Image", generated),
        ("Refined Image", refined),
        ("Fused Image", fused),
    ]
    fig, axes = plt.subplots(2, len(panels), figsize=(20, 16))
    for col, (title, img) in enumerate(panels):
        axes[0, col].imshow(img.permute(1, 2, 0).numpy())
        axes[0, col].set_title(title)
        axes[0, col].axis('off')

        edges = detect_canny_edges(img)
        axes[1, col].imshow(edges.squeeze().cpu().numpy(), cmap='gray')
        axes[1, col].set_title(f"{title} Edges")
        axes[1, col].axis('off')

    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called in a loop
    return psnr_values


# Example: display the first few samples.
# NOTE(review): these samples come from the training split, so the PSNR values
# are not a held-out evaluation; point this at a test split to measure generalization.
NUM_EXAMPLES = 5
for idx in range(min(NUM_EXAMPLES, len(train_dataset))):
    show_sample_with_metrics(generator, train_dataset, idx, iterations=3)
