# Imports :

In [24]:
import os
import math
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms, datasets

# Config :

In [25]:
# Configuration
img_size = (128, 128)
num_noise_steps = 450
beta_min = 1e-4
beta_max = 0.02
variance_schedule = np.linspace(beta_min, beta_max, num_noise_steps)

# Diffusion Network

In [26]:
class SinusoidalPositionEncoding(nn.Module):
    def __init__(self, dim, max_len=10000):
        super(SinusoidalPositionEncoding, self).__init__()
        self.dim = dim
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
        pe = torch.zeros(max_len, dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def forward(self, t):
        return self.pe[t]

class ResidualBlock(nn.Module):
    def __init__(self, num_channels, filter_size, num_groups, name):
        super(ResidualBlock, self).__init__()
        self.norm1 = nn.GroupNorm(num_groups, num_channels)
        self.swish = nn.SiLU()  # Utilisation de SiLU comme approximation de Swish
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=filter_size, padding=1)
        self.norm2 = nn.GroupNorm(num_groups, num_channels)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=filter_size, padding=1)
        self.fc = nn.Linear(256, num_channels)  # Ajuste la taille de l'`embedding` à `num_channels`

    def forward(self, x, embedding):
        # Normalisation et activation
        out = self.norm1(x)
        out = self.swish(out)
        out = self.conv1(out)
        if len(embedding.shape) == 1:  # Si embedding est de la forme [256]
            embedding = embedding.unsqueeze(0)  # Redimensionner en [1, 256] pour un seul batch
        if len(embedding.shape) == 2:  # Si l'`embedding` est (batch_size, 256)
            embedding = self.fc(embedding)  # (batch_size, num_channels)
        embedding = embedding.unsqueeze(-1).unsqueeze(-1)
        embedding = embedding.expand(-1, -1, out.shape[2], out.shape[3])  # Diffusion (batch_size, num_channels, H, W)
        out = out + embedding
        out = self.norm2(out)
        out = self.swish(out)
        out = self.conv2(out)
        return x + out 

class AttentionBlock(nn.Module):
    def __init__(self, num_heads, num_key_channels, num_groups, name):
        super(AttentionBlock, self).__init__()
        self.norm = nn.GroupNorm(num_groups, num_key_channels)
        self.self_attention = nn.MultiheadAttention(num_key_channels, num_heads)

    def forward(self, x):
        B, C, H, W = x.shape
        x_flat = x.view(B, C, H * W).permute(2, 0, 1)  # Reshape en [HW, B, C]
        out, _ = self.self_attention(x_flat, x_flat, x_flat)
        out = out.permute(1, 2, 0).view(B, C, H, W)  # Retour en [B, C, H, W]
        return x + out  # Skip connection



class DiffusionUNet(nn.Module):
    def __init__(self, num_image_channels=1, initial_num_channels=64, num_groups=32, num_heads=1):
        super(DiffusionUNet, self).__init__()
        self.initial_num_channels = initial_num_channels
        self.conv_in = nn.Conv2d(num_image_channels, initial_num_channels, kernel_size=3, padding=1)
        self.res_block1 = ResidualBlock(initial_num_channels, 3, num_groups, "1")
        self.res_block2 = ResidualBlock(initial_num_channels, 3, num_groups, "2")
        self.downsample2 = nn.Conv2d(initial_num_channels, 2 * initial_num_channels, kernel_size=3, padding=1, stride=2)
        self.res_block3 = ResidualBlock(2 * initial_num_channels, 3, num_groups, "3")
        self.attn_block3 = AttentionBlock(num_heads, 2 * initial_num_channels, num_groups, "3")
        self.downsample4 = nn.Conv2d(2 * initial_num_channels, 4 * initial_num_channels, kernel_size=3, padding=1, stride=2)
        self.res_block5 = ResidualBlock(4 * initial_num_channels, 3, num_groups, "5")
        self.res_block7 = ResidualBlock(4 * initial_num_channels, 3, num_groups, "7")
        self.attn_block7 = AttentionBlock(num_heads, 4 * initial_num_channels, num_groups, "7")
        self.upsample4 = nn.ConvTranspose2d(4 * initial_num_channels, 2 * initial_num_channels, kernel_size=2, stride=2)
        self.res_block9 = ResidualBlock(2 * initial_num_channels, 3, num_groups, "9")
        self.upsample2 = nn.ConvTranspose2d(2 * initial_num_channels, initial_num_channels, kernel_size=2, stride=2)
        self.res_block11 = ResidualBlock(initial_num_channels, 3, num_groups, "11")
        self.conv_out = nn.Conv2d(initial_num_channels, num_image_channels, kernel_size=3, padding=1)
        self.position_encoding = SinusoidalPositionEncoding(4 * initial_num_channels)
        self.fc_embed = nn.Sequential(
            nn.Linear(4 * initial_num_channels, 4 * initial_num_channels),
            nn.SiLU(),
            nn.Linear(4 * initial_num_channels, 4 * initial_num_channels)
        )

    def forward(self, x, t):
        # Encode l'étape de bruit
        t_emb = self.position_encoding(t)
        t_emb = self.fc_embed(t_emb)
        x1 = self.conv_in(x)
        x2 = self.res_block1(x1, t_emb)
        x3 = self.res_block2(x2, t_emb)
        x4 = self.downsample2(x3)
        x5 = self.res_block3(x4, t_emb)
        x6 = self.attn_block3(x5)
        x7 = self.downsample4(x6)
        x8 = self.res_block7(x7, t_emb)
        x9 = self.attn_block7(x8)
        x10 = self.upsample4(x9)
        x11 = self.res_block9(x10 + x6, t_emb)
        x12 = self.upsample2(x11)
        x13 = self.res_block11(x12 + x3, t_emb)
        output = self.conv_out(x13 + x1)
        return output


# Bruit

In [27]:
def apply_noise_to_image(img, noise, noise_step, variance_schedule):# Fonction pour appliquer du bruit à une image
    alpha_bar = np.cumprod(1 - variance_schedule)
    alpha_bar_t = alpha_bar[noise_step]
    noisy_img = torch.sqrt(torch.tensor(alpha_bar_t)) * img + torch.sqrt(1 - torch.tensor(alpha_bar_t)) * noise
    return noisy_img

# NetCos et NetSin :

In [28]:
netCos = DiffusionUNet()  # instanciation du modèle
state_dict = torch.load('Models/128Cos.pth')#Chargement du modèle
netCos.load_state_dict(state_dict)# Charger les poids dans le modèle
netCos = netCos.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))# Déplacer le modèle vers le bon appareil (GPU ou CPU)

In [29]:
netSin = DiffusionUNet()  # instanciation du modèle
state_dict = torch.load('Models/128Sin.pth')#Chargement du modèle
netSin.load_state_dict(state_dict)# Charger les poids dans le modèle
netSin = netSin.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))# Déplacer le modèle vers le bon appareil (GPU ou CPU)

# Diffusion

In [38]:
imageclean = 'Datas/NoiseSIN/sin_ImageBruite1.tif.png'
imagenoise = 'Datas/NoiseCOS/cos_ImageBruite1.tif.png'

# Define the transformations
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor()
])

def load_images_from_directory(img_path, transform):
    if os.path.isfile(img_path):
        image = Image.open(img_path).convert('L')  # Load as grayscale
        image = transform(image)
        return image

# Load the images
Cos_noise_img = load_images_from_directory(imagenoise, transform)
Sin_noise_img = load_images_from_directory(imageclean, transform)
imgs = [clean_img, noise_img]


# Fonction génération d'image 

In [39]:
def generate_image(net, variance_schedule, img):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        if len(img.shape) == 3:  # img of shape [num_channels, height, width]
            img = img.unsqueeze(0)  # Becomes [1, num_channels, height, width]

        # Move images to the correct device (GPU or CPU)
        images = img.to(device)

        # Calculate alpha_bar and posterior variance
        alpha_bar = np.cumprod(1 - variance_schedule)
        alpha_bar_prev = np.hstack([1, alpha_bar[:-1]])
        posterior_variance = variance_schedule * (1 - alpha_bar_prev) / (1 - alpha_bar)

        for noise_step in reversed(range(len(variance_schedule))):  # Iterate over 250 steps
            z = torch.zeros_like(images)  # Set z to zero, avoiding the addition of noise

            predicted_noise = net(images, noise_step)
            sqrt_one_minus_beta = torch.sqrt(torch.tensor(1 - variance_schedule[noise_step], device=device))
            pred_noise = (torch.tensor(variance_schedule[noise_step], device=device) * predicted_noise / 
                          torch.sqrt(torch.tensor(1 - alpha_bar[noise_step], device=device)))

            # Update images without adding noise
            images = (1 / sqrt_one_minus_beta) * (images - pred_noise) + \
                     (torch.sqrt(torch.tensor(posterior_variance[noise_step], device=device)) * z)

            # Ensure images are 4D for Conv2D: [batch_size, channels, height, width]
            if images.dim() == 5:
                images = images.squeeze(1)  # Remove the second dimension (if it's size 1)
            if images.dim() != 4:
                raise ValueError(f"Expected 4D input for Conv2D, but got {images.dim()}D tensor.")
    return images.cpu().numpy().squeeze()

# Génération de patch puis diffusion

In [40]:
def diffPatch(net,noise_img):
    # Dimensions de l'image
    image_size = 512
    patch_size = 128
    recouvrement = 64
    
    # Calculer le pas basé sur le recouvrement
    pas = int((recouvrement / 100) * patch_size)
    
    # Créer les matrices pour l'image finale et les poids
    imageFinal = np.zeros((image_size, image_size))
    poidsImage = np.zeros((image_size, image_size))
    
    # Calculer les indices des patches
    patch_indices = [(i, j) for j in range(0, image_size - pas, pas) for i in range(0, image_size - pas, pas)]
    
    # Préparer une liste pour les images générées
    images_générées = []
    
    # Générer toutes les images pour chaque patch
    for (i, j) in patch_indices:
        patch_input = noise_img[:1, i:i + patch_size, j:j + patch_size]  # Extraire le patch
        image_générée = generate_image(net, variance_schedule, patch_input)  # Générer l'image pour le patch
        images_générées.append((i, j, image_générée))  # Stocker l'image générée avec ses indices
    
    # Mettre à jour l'image finale et les poids
    for (i, j, image_générée) in images_générées:
        imageFinal[i:i + patch_size, j:j + patch_size] += image_générée.squeeze()  # Mettre à jour l'image finale
        poidsImage[i:i + patch_size, j:j + patch_size] += 1  # Incrémenter le poids
    
    # Après cela, tu peux calculer la moyenne si nécessaire
    poidsImage[poidsImage == 0] = 1  # Remplacer les zéros par 1 pour éviter la division par zéro
    imageFinal /= poidsImage  # Normalisation
    return imageFinal

# Diffusion sur Cos et Sin pour une image

In [41]:
diff_Sin = diffPatch(netSin,Sin_noise_img)

RuntimeError: The size of tensor a (108) must match the size of tensor b (107) at non-singleton dimension 2

In [None]:
diff_Cos = diffPatch(netCos,Cos_noise_img)

# Reconstitution de la Phase

In [None]:
phase_image = np.angle(diff_Cos + (diff_Sin * 1j))
plt.figure(figsize=(5, 5))  # Ajuster la taille ici (en pouces)
plt.imshow(phase_image, cmap='gray')
plt.title('Image de Phase Reconstituée')
plt.axis('off')
plt.show()