In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import math
import sys

import random
from scipy.stats import rice

In [2]:
# Sanity check of nn.MSELoss on random tensors: one forward pass, then backward.
loss = nn.MSELoss()
# NOTE(review): `input` shadows the Python builtin of the same name for the rest of the kernel session.
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)
print(output)
# Back-propagate the scalar loss; `input` was created with requires_grad=True.
output.backward()

tensor(1.3785, grad_fn=<MseLossBackward0>)


In [33]:
def format_data_numpy(noised,denoised):
    """Pair up noised and denoised samples in a single NumPy array.

    For every index ``i`` along the first axis of ``denoised``, ``noised[i]``
    and ``denoised[i]`` are stacked on a new leading axis, so ``result[i, 0]``
    is the noised sample and ``result[i, 1]`` is the denoised one.

    Args:
        noised: indexable collection of noised samples.
        denoised: array of denoised samples; ``denoised.shape[0]`` sets the
            number of pairs that are built.

    Returns:
        ``np.ndarray`` holding all stacked pairs.
    """
    n_pairs = denoised.shape[0]
    pairs = [np.stack((noised[k], denoised[k]), axis=0) for k in range(n_pairs)]
    return np.array(pairs)

In [37]:
def format_data(noised, denoised):
    """Stack two equally-shaped tensors along a new dimension 1.

    Args:
        noised: tensor placed at index 0 of the new dimension.
        denoised: tensor placed at index 1 of the new dimension; must have
            the same shape as ``noised``.

    Returns:
        Tensor ``t`` with ``t[:, 0] == noised`` and ``t[:, 1] == denoised``.

    Raises:
        AssertionError: if the two shapes differ.
    """
    same_shape = noised.shape == denoised.shape
    assert same_shape, "Input tensors must have the same shape."
    return torch.stack([noised, denoised], dim=1)

In [53]:
def separate_data(data):
    """Undo ``format_data``: split a stacked tensor into its two halves.

    Args:
        data: tensor whose dimension 1 has size 2.

    Returns:
        ``(noised, denoised)`` where ``noised`` is ``data[:, 0]`` and
        ``denoised`` is ``data[:, 1]``; dimension 1 is removed from both.

    Raises:
        AssertionError: if ``data.shape[1]`` is not 2.
    """
    n_stacked = data.shape[1]
    assert n_stacked == 2, "Input tensor must have as second dim = 2 shape[1]"
    # unbind removes dimension 1 and yields one slice per index along it.
    noised, denoised = torch.unbind(data, dim=1)
    return noised, denoised

General parameters of the models:

In [32]:
lr = 0.0002  # learning rate passed to every Adam optimizer created in train()

## Generator :
The generator takes denoised images and returns noised images.

In [3]:
# Generator
# nb_feat: the number of features the model handles
class Generator(nn.Module):
    """Two-layer fully connected network with Tanh activations.

    The network maps ``nb_feat`` input features to 100 hidden units and back
    to ``nb_feat`` outputs. Because the last layer is a Tanh, every output
    value lies in [-1, 1].

    Args:
        nb_feat: size of the last dimension of both the input and the output.
    """

    def __init__(self, nb_feat):
        super().__init__()
        self.nb_feat = nb_feat
        hidden = 100
        layers = [
            nn.Linear(nb_feat, hidden),
            nn.Tanh(),
            nn.Linear(hidden, nb_feat),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        out = self.main(x)
        return out

## Denoiser

## Discriminator 
It takes a pair of images and has to say whether this pair is real or fake, so its input is the stack of the two images.

In [35]:
class Discriminator(nn.Module):
    """Two-layer fully connected network with Tanh activations.

    The network maps ``nb_feat`` input features to 100 hidden units and back
    to ``nb_feat`` outputs, each in [-1, 1] because of the final Tanh.

    NOTE(review): the output has ``nb_feat`` values per input row, not a
    single real/fake score — confirm this is the intended output shape.

    Args:
        nb_feat: size of the last dimension of both the input and the output.
    """

    def __init__(self,nb_feat):
        super().__init__()
        hidden = 100
        layers = [
            nn.Linear(nb_feat, hidden),
            nn.Tanh(),
            nn.Linear(hidden, nb_feat),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        out = self.main(x)
        return out

In [37]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build both networks with 50 features and move each one to the chosen device.
generator = Generator(50).to(device)
discriminator = Discriminator(50).to(device)

In [None]:
class CustomCrossEntropyLoss(nn.Module):
    """Mean cross-entropy between raw class scores and integer class labels.

    The log-probabilities are computed with ``torch.log_softmax`` instead of
    ``log(softmax(x) + eps)``. The epsilon form biases every value and caps
    the per-sample loss at ``-log(1e-10)`` (about 23) once the softmax
    underflows, which hides large errors and flattens their gradient.
    """

    def __init__(self):
        super(CustomCrossEntropyLoss, self).__init__()

    def forward(self, logits, target):
        """Return the mean negative log-likelihood of the target classes.

        Args:
            logits: tensor of raw scores; classes are along dimension 1 and
                samples along dimension 0.
            target: one class index per sample, used to index dimension 1.

        Returns:
            0-dim tensor holding the loss averaged over the samples.
        """
        log_probabilities = torch.log_softmax(logits, dim=1)
        # Pick, for each row, the log-probability of its target class.
        true_class_log_probabilities = log_probabilities[range(logits.size(0)), target]
        return -true_class_log_probabilities.mean()

In [None]:
class CustomBCEWithLogitsLoss(nn.Module):
    """Binary cross-entropy on raw logits with an optional positive-class weight.

    Args:
        pos_weight: optional multiplier applied to the loss of elements whose
            target is 1 (number or tensor broadcastable to the target).
            ``None`` (the default) leaves the loss unweighted.

    The constructor previously declared a parameter named ``apha`` but then
    read an undefined ``pos_weight``, so building the loss raised NameError.
    The parameter now carries the name that ``forward`` actually uses.
    """

    def __init__(self, pos_weight=None):
        super(CustomBCEWithLogitsLoss, self).__init__()
        self.pos_weight = pos_weight

    def forward(self, inputs, target):
        """Return the mean (optionally weighted) binary cross-entropy.

        Args:
            inputs: raw logits.
            target: tensor of the same shape as ``inputs`` with values in [0, 1].

        Returns:
            0-dim tensor holding the loss averaged over all elements.

        Raises:
            ValueError: if ``inputs`` and ``target`` have different shapes.
        """
        # Ensure input and target are of the same shape
        if inputs.shape != target.shape:
            raise ValueError("Input and target must have the same shape")

        # log(sigmoid(x)) and log(1 - sigmoid(x)) = log(sigmoid(-x)), computed
        # directly from the logits so they stay finite for large |x|.
        log_p = nn.functional.logsigmoid(inputs)
        log_not_p = nn.functional.logsigmoid(-inputs)

        # Compute the binary cross-entropy loss
        loss = -(target * log_p + (1 - target) * log_not_p)

        # Apply positive weight if provided
        if self.pos_weight is not None:
            loss = loss * (self.pos_weight * target + (1 - target))

        return loss.mean()  # Return the mean loss

The loss function follows the entropy:

In [None]:
class FLoss(nn.Module):
    """Cross-entropy style loss on logits with a weight on the label-0 term.

    For a logit ``x`` with ``p = sigmoid(x)`` and label ``y``, each element
    contributes ``-(y * log(p) + alpha * (1 - y) * log(1 - p))`` and the
    result is averaged over all elements.

    Args:
        alpha: weight applied to the term of elements whose label is 0.

    Fixes relative to the previous version:
      * ``forward`` read an undefined ``alpha`` instead of ``self.alpha``.
      * ``forward`` returned the undefined names ``loss_d, loss_g``; it now
        returns the mean loss as a scalar, which can be back-propagated.
      * The label-1 term had a positive sign, so driving ``p`` to 0 lowered
        the loss for both labels; it is now negated like the label-0 term.
        NOTE(review): confirm this matches the intended objective.
    """

    def __init__(self,alpha = 0.5):
        super(FLoss, self).__init__()
        self.alpha = alpha

    def forward(self, logits, labels):
        """Return the mean loss of ``logits`` against ``labels``.

        Args:
            logits: raw scores.
            labels: tensor of the same shape as ``logits`` (1 = real, 0 = fake).

        Returns:
            0-dim tensor holding the loss averaged over all elements.

        Raises:
            ValueError: if ``logits`` and ``labels`` have different shapes.
        """
        if logits.shape != labels.shape:
            raise ValueError("Input and labels must have the same shape")

        # log(sigmoid(x)) and log(1 - sigmoid(x)) = log(sigmoid(-x)), computed
        # directly from the logits so they stay finite for large |x|.
        log_p = nn.functional.logsigmoid(logits)
        log_not_p = nn.functional.logsigmoid(-logits)

        loss = -(labels * log_p + self.alpha * (1 - labels) * log_not_p)

        # Custom terms (regularisation, penalties, ...) can be added to
        # `loss` here if needed.

        return loss.mean()


Now that we have our loss function, we want to initialise the optimizers of our models; they are created inside the train function.

In [38]:
#optimizer_g = torch.optim.Adam(generator.parameters(), lr=lr)
#optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr)

Then we have the training loop: \
We know that our dataloader gives us the ground truth, so all of its labels are set to true/1.

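Note: each sample of my data stacks the noised image first and the denoised image behind it (this is the order used by `format_data` and `separate_data`). **TODO**: the rest of this note ("Also we have …") was left unfinished.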

In [None]:
def train(data_loader,num_epochs,g_model,d_model,r_model,nb_feat):
    """Adversarial training loop; only the discriminator update is implemented.

    Relies on names defined elsewhere in the notebook: ``lr``, ``device``,
    ``FLoss``, ``separate_data`` and ``format_data``.

    Args:
        data_loader: iterable yielding ``(real_data, _)`` batches, where
            ``real_data`` stacks a noised and a denoised sample along
            dimension 1 (the layout ``separate_data`` asserts on).
        num_epochs: number of passes over ``data_loader``.
        g_model: generator; it is fed random noise of width ``nb_feat``.
        d_model: discriminator; scores real and fake pairs.
        r_model: denoiser; it is fed the noised half of each real batch.
        nb_feat: width of the random noise given to the generator.

    Notes:
        * ``format_data`` requires the generator output and the real samples
          to share a shape — assumes each batch is ``(N, 2, nb_feat)``;
          TODO confirm (images would need flattening first).
        * The generator and denoiser sections are placeholders: they call
          ``zero_grad``/``step`` without any backward pass, so they do not
          change the weights yet.

    Fixes relative to the previous version: the loss object ``fLoss`` is now
    created here (it was undefined); ``noised_data`` and ``epoch`` were
    undefined names; the valid/fake label tensors had swapped sizes (3N for
    the N real pairs, N for the 3N fake pairs) and are now shaped like the
    discriminator output; the noise batch uses the actual batch size instead
    of the global ``batch_size``; fake samples are detached for the
    discriminator step.
    """
    g_model = g_model.to(device)
    d_model = d_model.to(device)
    r_model = r_model.to(device)

    optimizer_g = torch.optim.Adam(g_model.parameters(), lr=lr)
    optimizer_d = torch.optim.Adam(d_model.parameters(), lr=lr)
    optimizer_r = torch.optim.Adam(r_model.parameters(), lr=lr)

    fLoss = FLoss()

    # TODO: loop until the loss is small enough instead of a fixed epoch count

    for epoch in range(num_epochs):

        for i, (real_data,_) in enumerate(data_loader):

            # Place the batch on the training device
            real_data = real_data.to(device)
            # Split the stacked batch into its noised and denoised halves
            noised,denoised = separate_data(real_data)
            # The last batch can be smaller than the loader's batch size
            n_samples = real_data.size(0)

            ##########################
            # Discriminator Training #
            ##########################

            optimizer_d.zero_grad()

            # Optimize the discriminator on the real pairs (label 1)
            outputs = d_model(real_data)
            valid_labels = torch.ones_like(outputs)
            lossD_valid = fLoss(outputs, valid_labels)
            lossD_valid.backward()

            # Generate fake data (the generator starts from random noise).
            # detach(): the discriminator step must not back-propagate into
            # the generator or the denoiser.
            z = torch.randn(n_samples, nb_feat, device=device)
            fake_data_noised = g_model(z).detach()
            fake_data_denoised = r_model(noised).detach()

            # We have 4 different types of pairs:
            # (x, y): real data with label 1 -> right/true
            # The three kinds below all get label 0 -> wrong/false
            # (x_hat, y_hat): fake data
            fake_fake = format_data(fake_data_noised,fake_data_denoised)
            # (x_hat, y): fake real
            fake_real = format_data(fake_data_noised,denoised)
            # (x, y_hat): real fake
            real_fake = format_data(noised,fake_data_denoised)

            all_fake_data = torch.cat([fake_fake,fake_real,real_fake], dim=0)

            # Optimize the discriminator on the generated pairs (label 0)
            outputs = d_model(all_fake_data)
            fake_labels = torch.zeros_like(outputs)
            lossD_f = fLoss(outputs, fake_labels)
            lossD_f.backward()
            optimizer_d.step()

            # Print losses
            if i % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(data_loader)}], '
                  f'Discriminator Loss: {lossD_valid.item() + lossD_f.item():.4f}')

        ####################
        # Update Generator #
        ####################
        # NOTE(review): placeholder, and it sits at epoch level; it probably
        # belongs inside the batch loop once implemented — confirm.

        optimizer_g.zero_grad()

        # The denoiser and the discriminator are kept fixed here.
        # Sketch of the intended update:
        #fake_data_denoised = r_model(data)
        #fake_data = fake_data_noised + fake_data_denoised
        #output = d_model(fake_data)

        #loss_g = fLoss(outputs, fake_labels)
        #loss_g.backward()
        optimizer_g.step()

        ###################
        # Update Denoiser #
        ###################

        optimizer_r.zero_grad()

        # Sketch of the intended update:
        #fake_data_noised = g_model(data)
        #fake_data = fake_data_noised + fake_data_denoised
        #output = d_model(fake_data)

        #loss_r = fLoss(outputs, fake_labels)
        #loss_r.backward()

        optimizer_r.step()
                

Note: \
Using `.detach()` ensures that the gradients from the discriminator do not flow back to the generator during the backward pass.

We get our transformed data :

In [9]:
def bruitage_racien(image,b = 0,loc=0,scale=1):
    """Add Rice-distributed noise to an image and clip the sum to [0, 255].

    Args:
        image: array to corrupt; one noise value is drawn per element.
        b: shape parameter forwarded to ``scipy.stats.rice.rvs``.
        loc: location parameter forwarded to ``rice.rvs``.
        scale: scale parameter forwarded to ``rice.rvs``.

    Returns:
        Array of the same shape as ``image`` with values limited to [0, 255].
    """
    rician_noise = rice.rvs(b, loc=loc, scale=scale, size=image.shape)
    return np.clip(image + rician_noise, 0, 255)

In [35]:
# Example file: one NIfTI volume, used below as the clean (denoised) reference.
filename = "./IMA_project/PIMA/test/0001/2_t2_tse_sag_384.nii"

data = np.array((nib.load(filename)).dataobj) # this volume contains 15 slices
# One label per entry along the first axis of `data`, all set to 1.
labels = torch.ones(data.shape[0])

# Build the noisy counterpart of every slice with the default noise parameters.
noised = []
for im in data: 
    noised.append(bruitage_racien(im))
    
noised = np.array(noised)

In [38]:
# NumPy and torch versions of the same paired data. Both helpers take
# (noised, denoised); the NumPy call used to pass the arguments the other
# way round, so the two arrays disagreed on which half sits at index 0.
all_im = format_data_numpy(noised,data)

denoised_torch = torch.tensor(data, dtype=torch.float32) 
noised_torch = torch.tensor(noised, dtype=torch.float32) 

data_torch = format_data(noised_torch,denoised_torch)

all_im.shape,data_torch.shape

((15, 2, 384, 384), torch.Size([15, 2, 384, 384]))

In [55]:
# Split the stacked tensor back into its noised / denoised halves.
snoi,sden = separate_data(data_torch)
# Detach first when the tensor tracks gradients, before converting it to NumPy.
if sden.requires_grad:
    sden = sden.detach()
# Convert to NumPy array
t_im = sden.numpy()
#plt.imshow(t_im[0])

Note the labels will depend on the MODEL :

In [22]:
# Scale the volume by 1/255 and wrap it, together with the labels, in a dataset.
# NOTE(review): this dataset holds only `data`, not the stacked (noised, denoised)
# pairs that train() splits with separate_data() — confirm which one train() should receive.
data_tensor = torch.from_numpy(data).float()/255.0
dataset = TensorDataset(data_tensor, labels)

batch_size = 4 
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Verification loop: print the shape and labels of every batch.
for batch_images, batch_labels in dataloader:
    print(f"Batch shape: {batch_images.shape}, Labels: {batch_labels}")

Batch shape: torch.Size([4, 384, 384]), Labels: tensor([0, 1, 0, 0])
Batch shape: torch.Size([4, 384, 384]), Labels: tensor([0, 0, 0, 0])
Batch shape: torch.Size([4, 384, 384]), Labels: tensor([1, 0, 1, 0])
Batch shape: torch.Size([3, 384, 384]), Labels: tensor([1, 1, 0])


In [39]:
# Same loop as before, printing only the per-sample shape (first dimension dropped).
for batch_images, batch_labels in dataloader:
    print(f"Batch shape: {batch_images.shape[1:]}, Labels: {batch_labels}")

Batch shape: torch.Size([384, 384]), Labels: tensor([0, 0, 0, 1])
Batch shape: torch.Size([384, 384]), Labels: tensor([0, 1, 0, 0])
Batch shape: torch.Size([384, 384]), Labels: tensor([0, 0, 1, 0])
Batch shape: torch.Size([384, 384]), Labels: tensor([1, 1, 0])


We still have to think about the labelling: \
It is easy for the denoiser and for the generator because each has only one input. For the discriminator, we consider that if at least one image of the pair is not real, then the label is fake (0); otherwise it is 1 (true).