In [1]:
import math
import os
import sys
import time
from tqdm import tqdm
from skimage.metrics import structural_similarity as sim

import config as config

from models.N2N_Unet import N2N_Unet_DAS, N2N_Orig_Unet, Cut2Self, U_Net_origi, U_Net, TestNet
from metric import Metric

import numpy as np

import torch
import matplotlib.pyplot as plt
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

In [5]:
def masking(img, percent):
    """Create a random binary mask with the same shape as a batch of images.

    For each sample along the first axis, ``round(C * H * W * percent)``
    distinct positions are drawn over all channels and pixels and set to 1;
    every other entry stays 0.

    Args:
        img: Tensor whose first dimension is iterated as the batch and whose
            last three dimensions are read as ``shape[-3], shape[-2], shape[-1]``.
        percent: Fraction of the elements of one sample to mark
            (e.g. ``0.005`` marks 0.5 %).

    Returns:
        A tensor created via ``torch.zeros_like(img)`` with ones at the
        selected positions.
    """
    channels, height, width = img.shape[-3], img.shape[-2], img.shape[-1]
    elems_per_sample = channels * height * width
    n_marked = int(np.round(elems_per_sample * percent))

    mask = torch.zeros_like(img)
    for sample_idx in range(img.shape[0]):
        # Draw this sample's positions without replacement.
        chosen = torch.randperm(elems_per_sample)[:n_marked]
        flat_mask = torch.zeros(elems_per_sample)
        # Set the selected entries of the flat mask to 1.
        flat_mask[chosen] = 1
        mask[sample_idx] = flat_mask.view(channels, height, width)
    return mask

def kaput(img, percent):
    """Variant of ``masking`` that draws indices from ``H * W`` only.

    The number of marked entries per sample is ``round(H * W * percent)`` and
    the indices are drawn from ``range(H * W)``, but they are written into a
    flattened ``(C, H, W)`` buffer. As a consequence only the first ``H * W``
    flat positions - i.e. the first channel - can ever be marked.

    NOTE(review): the name (German "kaputt" = broken) suggests this is kept on
    purpose as the faulty counterpart of ``masking`` - confirm before reuse.

    Args:
        img: Tensor whose first dimension is iterated as the batch and whose
            last three dimensions are read as ``shape[-3], shape[-2], shape[-1]``.
        percent: Fraction of ``H * W`` to mark per sample.

    Returns:
        A tensor created via ``torch.zeros_like(img)`` with ones at the
        selected positions.
    """
    channels, height, width = img.shape[-3], img.shape[-2], img.shape[-1]
    pixels_per_channel = height * width
    n_marked = int(np.round(pixels_per_channel * percent))

    mask = torch.zeros_like(img)
    for sample_idx in range(img.shape[0]):
        chosen = torch.randperm(pixels_per_channel)[:n_marked]
        flat_mask = torch.zeros(channels * height * width)
        # Set the selected entries of the flat mask to 1.
        flat_mask[chosen] = 1
        mask[sample_idx] = flat_mask.view(channels, height, width)
    return mask

# Sanity check: count how many entries each mask variant marks, once over the
# whole batch and once for the first sample only.
t = torch.randn(32, 3, 128, 128)

mask = masking(t, 0.005)
print(mask.sum())
print(mask[0].sum())

mask = kaput(t, 0.05)
print(mask.sum())
print(mask[0].sum())

tensor(7872.)
tensor(246.)
tensor(26208.)
tensor(819.)


In [None]:
# Training hyper-parameters read by the cells below.
lr = 4e-4                # initial learning rate handed to Adam
changeLR_steps = 5_000   # scheduler steps between two learning-rate changes
changeLR_rate = -0.5     # passed to get_lr_lambda as `lr_decrement`
modi = 0                 # NOTE(review): declared global in train() but never read there - purpose unclear
use_scheduler = True     # whether a LambdaLR scheduler is created and stepped

In [None]:
def get_lr_lambda(initial_lr, step_size, lr_decrement):
    """Build the multiplicative factor function for ``LambdaLR``.

    ``LambdaLR`` multiplies the optimizer's base learning rate by the value
    returned here, so the factor must be 1.0 at step 0. The target learning
    rate is ``initial_lr - (step // step_size) * lr_decrement`` and the factor
    is that target divided by ``initial_lr``, clamped at 0.

    The previous expression divided only the decrement term by ``initial_lr``
    (operator precedence), which made the factor at step 0 equal to
    ``initial_lr`` instead of 1.0.

    Args:
        initial_lr: Base learning rate the optimizer was created with
            (must be non-zero).
        step_size: Number of scheduler steps between two changes.
        lr_decrement: Amount subtracted from the learning rate every
            ``step_size`` steps. A negative value increases the rate.

    Returns:
        A function ``step -> factor`` suitable for
        ``torch.optim.lr_scheduler.LambdaLR``.
    """
    def lr_lambda(step):
        target_lr = initial_lr - (step // step_size) * lr_decrement
        # Never let the learning rate become negative.
        return max(0.0, target_lr / initial_lr)
    return lr_lambda

def add_norm_noise(img, snr_db):
    """Add Gaussian noise scaled so that signal/noise energy equals ``snr_db``.

    The energies are summed over the whole tensor (all samples together), so a
    single scale factor is applied to the entire batch.

    NOTE(review): despite its name, ``snr_db`` is used directly as a linear
    energy ratio - no dB-to-linear conversion takes place.

    Args:
        img: Clean input tensor.
        snr_db: Desired ratio ``sum(img**2) / sum(scaled_noise**2)``.

    Returns:
        Tuple ``(noisy, scale)`` where ``noisy`` is ``img`` plus the scaled
        noise and ``scale`` is the applied noise factor as a Python float.
    """
    gaussian = torch.randn_like(img)
    signal_energy = (img ** 2).sum()
    noise_energy = (gaussian ** 2).sum()
    scale = torch.sqrt(signal_energy / (snr_db * noise_energy))
    noisy = img + gaussian * scale
    return noisy, scale.item()

def masking(img, percent):
    """Create a random binary mask with the same shape as a batch of images.

    For each sample along the first axis, ``round(C * H * W * percent)``
    distinct positions are drawn over all channels and pixels and set to 1;
    every other entry stays 0.

    Args:
        img: Tensor whose first dimension is iterated as the batch and whose
            last three dimensions are read as ``shape[-3], shape[-2], shape[-1]``.
        percent: Fraction of the elements of one sample to mark.

    Returns:
        A tensor created via ``torch.zeros_like(img)`` with ones at the
        selected positions.
    """
    channels, height, width = img.shape[-3], img.shape[-2], img.shape[-1]
    elems_per_sample = channels * height * width
    n_marked = int(np.round(elems_per_sample * percent))

    mask = torch.zeros_like(img)
    for sample_idx in range(img.shape[0]):
        # Draw this sample's positions without replacement.
        chosen = torch.randperm(elems_per_sample)[:n_marked]
        flat_mask = torch.zeros(elems_per_sample)
        # Set the selected entries of the flat mask to 1.
        flat_mask[chosen] = 1
        mask[sample_idx] = flat_mask.view(channels, height, width)
    return mask

def calculate_loss(noise_imges, model, device, lambda_inv=2, mask_percent=0.005, noise_std=0.2):
    """Compute the reconstruction + invariance loss for one noisy batch.

    The model is evaluated twice: once on the noisy batch and once on a copy
    in which the masked positions are replaced by Gaussian noise. The loss is

        mean((f(x) - x)^2) + lambda_inv * sqrt(sum(mask * (f(x) - f(x_masked))^2) / sum(mask))

    Args:
        noise_imges: Noisy input batch (already on ``device``).
        model: Callable network mapping a batch to a batch of equal shape.
        device: Device the mask and the replacement noise are moved to.
        lambda_inv: Weight of the invariance term (default 2, the value that
            was previously hard-coded).
        mask_percent: Fraction of elements per sample that is masked
            (default 0.005, previously hard-coded).
        noise_std: Standard deviation of the values written into the masked
            positions (default 0.2, previously hard-coded).

    Returns:
        Tuple ``(loss, denoised, mask, loss_rec, loss_inv)`` where ``loss_rec``
        is the mean squared reconstruction error and ``loss_inv`` is the
        *unnormalised* masked squared difference between both predictions.
    """
    mask = masking(noise_imges, mask_percent).to(device)
    marked_points = torch.sum(mask)

    # Replace the masked positions with random values; everything else is kept.
    replacement = torch.normal(0, noise_std, size=noise_imges.shape).to(device)
    masked_input = noise_imges * (1 - mask) + replacement * mask

    denoised = model(noise_imges)
    denoised_masked = model(masked_input)

    loss_rec = torch.mean((denoised - noise_imges) ** 2)
    loss_inv = torch.sum(mask * (denoised - denoised_masked) ** 2)
    inv_rmse = (loss_inv / marked_points).sqrt()
    loss = loss_rec + lambda_inv * inv_rmse

    if torch.isnan(loss):
        # Dump the individual terms to make the source of the NaN visible.
        print(f"{loss_rec} + {loss_inv} * {inv_rmse}, marked_points: {marked_points}")
    return loss, denoised, mask, loss_rec, loss_inv


In [None]:
# Pick the compute device.
# - no usable GPU      -> CPU (previously this fell through to "cuda:3" and crashed)
# - at least four GPUs -> keep the original choice of GPU index 3
# - one to three GPUs  -> default CUDA device ("cuda:3" does not exist there)
gpu_count = torch.cuda.device_count()
if gpu_count == 0 or not torch.cuda.is_available():
    device = "cpu"
elif gpu_count > 3:
    device = "cuda:3"
else:
    device = "cuda"

# Root directory of the CelebA dataset, taken from the project config module.
celeba_dir = config.celeba_dir

# Pre-processing applied to every image: centre crop to 128x128, convert to a
# float tensor, then rescale with x * 2 - 1 (maps [0, 1] to [-1, 1]).
transform_noise = transforms.Compose([
    #transforms.RandomResizedCrop((128,128)),
    transforms.CenterCrop((128,128)),
    #transforms.Resize((512,512)), #for self2self
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.float()),
    transforms.Lambda(lambda x:  x * 2 -1),
    ])
print("lade Datensätze ...")
# The three official CelebA splits; download=False means the data must already
# exist under celeba_dir.
dataset_all = datasets.CelebA(root=celeba_dir, split='train', download=False, transform=transform_noise)

dataset_validate_all = datasets.CelebA(root=celeba_dir, split='valid', download=False, transform=transform_noise)

dataset_test_all = datasets.CelebA(root=celeba_dir, split='test', download=False, transform=transform_noise)

print(f"Using {device} device")


# Work on fixed subsets (first 6400 / 640 / 640 images) to keep epochs short.
dataset = torch.utils.data.Subset(dataset_all, list(range(6400)))
dataset_validate = torch.utils.data.Subset(dataset_validate_all, list(range(640)))
dataset_test = torch.utils.data.Subset(dataset_test_all, list(range(640)))
# Only the training loader is shuffled.
dataLoader = DataLoader(dataset, batch_size=32, shuffle=True)
dataLoader_validate = DataLoader(dataset_validate, batch_size=32, shuffle=False)
dataLoader_test = DataLoader(dataset_test, batch_size=32, shuffle=False)
# The architecture depends on the number of visible GPUs.
# NOTE(review): presumably single-GPU = local machine, otherwise = server - confirm.
if torch.cuda.device_count() == 1:
    model = N2N_Orig_Unet(3,3).to(device) #default
else:
    model = U_Net(in_chanel=3, first_out_chanel=96, batchNorm=True).to(device)

# configAtr = getattr(config, methode)  # i.e. config.<methode>, where methode is a string
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


# Optional learning-rate schedule; LambdaLR multiplies the base lr by the
# factor returned from lr_lambda at every scheduler step.
if use_scheduler:
    lr_lambda = get_lr_lambda(lr, changeLR_steps, changeLR_rate)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
else:
    scheduler = None

print("fertig")


In [None]:
def train(dataLoader, optimizer, scheduler, device, model, mode):
    """Run one pass over ``dataLoader`` in training or evaluation mode.

    Every batch is corrupted with ``add_norm_noise(original, 2)``. In
    ``"test"`` / ``"validate"`` mode the model only predicts (no gradients);
    in any other mode the loss from ``calculate_loss`` is back-propagated and
    the optimizer - and, if enabled, the scheduler - is stepped once per batch.

    Args:
        dataLoader: Iterable yielding ``(images, labels)`` batches; the labels
            are ignored.
        optimizer: Optimizer stepped in training mode.
        scheduler: Learning-rate scheduler or ``None``. It is stepped only when
            the module-level ``use_scheduler`` flag is true *and* a scheduler
            object was actually passed.
        device: Device the batches are moved to.
        model: Network to train / evaluate.
        mode: ``"test"`` or ``"validate"`` for evaluation; anything else
            trains. The total loss is logged only if ``"train"`` occurs in
            ``mode``.

    Returns:
        Tuple ``(loss_log, psnr_log, original_psnr_log, sim_log, loss_rec_log,
        loss_inv_log)`` of per-batch lists. ``loss_log``, ``loss_rec_log`` and
        ``loss_inv_log`` stay empty in evaluation mode.
    """
    is_eval = mode == "test" or mode == "validate"

    loss_log = []
    psnr_log = []
    original_psnr_log = []
    sim_log = []
    loss_rec_log = []
    loss_inv_log = []

    # The mode does not change during a pass, so set it once up front.
    if is_eval:
        model.eval()
    else:
        model.train()

    for original, _ in dataLoader:
        original = original.to(device)
        noise_images, _ = add_norm_noise(original, 2)
        noise_images = noise_images.to(device)

        if is_eval:
            with torch.no_grad():
                denoised = model(noise_images)
        else:
            loss, denoised, _, loss_rec, loss_inv = calculate_loss(noise_images, model, device)
            loss_rec_log.append(loss_rec.item())
            loss_inv_log.append(loss_inv.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Guard against scheduler=None even when the global flag is set.
            if use_scheduler and scheduler is not None:
                scheduler.step()

        # Metrics never need gradients; drop the autograd graph before logging.
        denoised = denoised.detach()

        # PSNR on the raw value range first ...
        original_psnr_batch = Metric.calculate_psnr(original, denoised)
        # ... then on batch-wise min-max normalised images.
        denoised = (denoised - denoised.min()) / (denoised.max() - denoised.min())
        original = (original - original.min()) / (original.max() - original.min())
        psnr_batch = Metric.calculate_psnr(original, denoised)
        similarity_batch, _ = Metric.calculate_similarity(original, denoised)

        if "train" in mode:
            loss_log.append(loss.item())
        psnr_log.append(psnr_batch.item())
        original_psnr_log.append(original_psnr_batch.item())
        sim_log.append(similarity_batch)

    return loss_log, psnr_log, original_psnr_log, sim_log, loss_rec_log, loss_inv_log


# Per-epoch collections of the per-batch loss terms returned by train().
loss_rec_log = []
loss_inv_log = []
for epoch in tqdm(range(50)):

    loss, psnr, original_psnr_log, similarity, loss_rec, loss_inv = train(dataLoader, optimizer, scheduler, device, model, mode="train")
    loss_rec_log.append(loss_rec)
    loss_inv_log.append(loss_inv)

    # Stop as soon as the loss diverged; the emptiness check avoids an
    # IndexError when the loader produced no batches.
    if loss and math.isnan(loss[-1]):
        break

    loss_val, psnr_val, original_psnr_log_val, similarity_val, _, _ = train(dataLoader_validate, optimizer, scheduler, device, model, mode="validate")

    print(f"Epoch {epoch}:\n"
          f"Train: loss={Metric.avg_list(loss):.5f}, psnr={Metric.avg_list(psnr):.5f}, psnr_orig={Metric.avg_list(original_psnr_log):.5f}, similarity={Metric.avg_list(similarity):.5f}\n"
          f"Validate: loss={0.00000}, psnr={Metric.avg_list(psnr_val):.5f}, psnr_orig={Metric.avg_list(original_psnr_log_val):.5f}, similarity={Metric.avg_list(similarity_val):.5f}")


loss_test, psnr_test, original_psnr_log_test, similarity_test, _, _ = train(dataLoader_test, optimizer, scheduler, device, model, mode="test")

# train() logs no loss in test mode, so loss_test is empty; report 0 (as the
# Validate line does) instead of averaging an empty list.
test_loss_avg = Metric.avg_list(loss_test) if loss_test else 0.0
print(f"Test: loss={test_loss_avg:.5f}, psnr={Metric.avg_list(psnr_test):.5f}, similarity={Metric.avg_list(similarity_test):.5f}")

In [None]:
# Each entry of loss_rec_log / loss_inv_log holds the per-batch values of one
# epoch. Reduce every entry to its mean so the x-axis really is the epoch;
# plotting the nested lists directly draws one line per batch index instead.
rec_per_epoch = [float(np.mean(epoch_values)) for epoch_values in loss_rec_log]
inv_per_epoch = [float(np.mean(epoch_values)) for epoch_values in loss_inv_log]

plt.figure(figsize=(10, 5))

# Mean reconstruction loss per epoch
plt.plot(range(len(rec_per_epoch)), rec_per_epoch, label='Reconstruction Loss')

# Mean (unnormalised) invariance loss per epoch
plt.plot(range(len(inv_per_epoch)), inv_per_epoch, label='Inversion Loss')

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Reconstruction and Inversion Loss over Epochs')
plt.legend()
plt.show()

In [None]:
from torchvision.utils import make_grid

# Qualitative check on the first test batch: original / denoised / noisy.
original, label = next(iter(dataLoader_test))
original = original.to(device)
noise_images, sigma = add_norm_noise(original, 2)
noise_images = noise_images.to(device)

# Inference only: evaluation mode and no autograd graph.
model.eval()
with torch.no_grad():
    denoised = model(noise_images)

comparison = torch.cat((original[:4], denoised[:4], noise_images[:4]), dim=0)
# The data pipeline rescales images with x * 2 - 1, so undo that for display;
# values outside the valid range (noise, network overshoot) are clamped
# explicitly rather than left to imshow's clipping.
comparison = ((comparison + 1) / 2).clamp(0, 1)
grid = make_grid(comparison, nrow=4, normalize=False).cpu()

plt.imshow(grid.permute(1, 2, 0))
plt.title("Rows: original / denoised / noisy")
plt.axis("off")
plt.show()
