In [1]:
import os
from tqdm import tqdm
from skimage.metrics import structural_similarity as sim

from models.N2N_Unet import N2N_Unet_DAS, N2N_Orig_Unet, Cut2Self
from metric import Metric
from masks import Mask
from utils import *
from transformations import *

import numpy as np
import torch
import matplotlib.pyplot as plt
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

#sys.path.insert(0, str(Path(__file__).resolve().parents[3])) #damit die Pfade auf dem Server richtig waren (copy past von PG)
from absl import app
from torch.utils.tensorboard import SummaryWriter

In [3]:
def generate_patches_from_list(data, num_patches_per_img=None, shape=(64, 64), augment=True, shuffle=False):
    """
    Extract patches from every image in `data` and concatenate them into one tensor.

    Parameters
    ----------
    data                : iterable of torch.Tensor
                        Images laid out as (..., H, W), e.g. (C, H, W). Iterating over a batch
                        tensor (B, C, H, W) yields such images. Each image gets a leading
                        sample dimension (`unsqueeze(0)`) before patch extraction.
    num_patches_per_img : int, optional(default=None)
                        If None, as many non-overlapping patches as fit into each image are
                        extracted. Otherwise this many (possibly overlapping) patches are
                        returned per image.
    shape               : tuple(int), optional(default=(64, 64))
                        Patch size (height, width).
    augment             : bool, optional(default=True)
                        Add rotated (90/180/270 degrees) and row-flipped copies of each patch.
                        Only applied when the patches are square.
    shuffle             : bool, optional(default=False)
                        Randomly permute the patches across all images.

    Returns
    -------
    patches : torch.Tensor
            All patches concatenated along dim 0, each of shape (..., shape[0], shape[1]).

    Raises
    ------
    ValueError
        If `data` contains no images.
    """
    patches = [
        generate_patches(img.unsqueeze(0), num_patches=num_patches_per_img, shape=shape, augment=augment)
        for img in data
    ]
    if not patches:
        # torch.cat would fail with an opaque message on an empty list.
        raise ValueError("generate_patches_from_list: 'data' contains no images")
    patches = torch.cat(patches, dim=0)

    if shuffle:
        patches = patches[torch.randperm(len(patches))]
    return patches

def generate_patches(data, num_patches=None, shape=(64, 64), augment=True):
    """
    Extract (and optionally augment) 2-D patches from the last two dimensions of `data`.

    Augmentation adds the 90/180/270-degree rotations in the (H, W) plane plus a flip along
    the H (row) axis of every rotation, i.e. an eight-fold increase in patches.

    Parameters
    ----------
    data        : torch.Tensor
                Tensor of shape (S, ..., H, W); patches are concatenated along dim 0.
    num_patches : int, optional(default=None)
                If None, a non-overlapping grid of patches is extracted and all of them
                (plus augmentations) are returned in extraction order. Otherwise `num_patches`
                random patches are cut and, after augmentation, `num_patches` patches in
                total are drawn at random without replacement from the pool.
    shape       : tuple(int), optional(default=(64, 64))
                Patch size (height, width).
    augment     : bool, optional(default=True)
                Apply the rotation/flip augmentation. Only used if the patches are square.

    Returns
    -------
    patches : torch.Tensor
            Patches of shape (N, ..., shape[0], shape[1]).
    """
    patches = __extract_patches__(data, num_patches=num_patches, shape=shape)
    if augment and shape[0] == shape[1]:
        patches = __augment_patches__(patches)

    if num_patches is not None:
        # Draw without replacement: the pool always holds at least `num_patches` patches, and
        # sampling with replacement would duplicate some patches while silently dropping others.
        patches = patches[torch.randperm(len(patches))[:num_patches]]

    print('Generated patches:', patches.shape)
    return patches
    
def __extract_patches__(data, num_patches=None, shape=(64, 64)):
    """
    Cut 2-D patches out of the last two (H, W) dimensions of `data`.

    Parameters
    ----------
    data        : torch.Tensor
                Tensor of shape (S, ..., H, W); patches are concatenated along dim 0.
    num_patches : int, optional(default=None)
                If None, a non-overlapping grid of patches is extracted, row by row
                (incomplete border patches are dropped). Otherwise `num_patches` patches
                are cut at random, possibly overlapping, top-left corners.
    shape       : tuple(int, int), optional(default=(64, 64))
                Patch size (height, width).

    Returns
    -------
    torch.Tensor of shape (S * n, ..., shape[0], shape[1]), n being the number of patch positions.

    Raises
    ------
    ValueError
        If `data` is smaller than `shape` in H or W, or if no patch would be produced.
    """
    patch_h, patch_w = shape[0], shape[1]
    height, width = data.shape[-2], data.shape[-1]
    if height < patch_h or width < patch_w:
        raise ValueError(f"data of size {height}x{width} is smaller than patch shape {patch_h}x{patch_w}")

    patches = []
    if num_patches is None:
        for y in range(0, height - patch_h + 1, patch_h):
            for x in range(0, width - patch_w + 1, patch_w):
                patches.append(data[..., y:y + patch_h, x:x + patch_w])
    else:
        for _ in range(num_patches):
            # Row and column offsets come from their own ranges so non-square inputs stay in bounds.
            y = int(torch.randint(0, height - patch_h + 1, (1,)))
            x = int(torch.randint(0, width - patch_w + 1, (1,)))
            patches.append(data[..., y:y + patch_h, x:x + patch_w])

    if not patches:
        raise ValueError("no patches extracted (num_patches must be positive)")
    return torch.cat(patches, dim=0)

def __augment_patches__(patches):
    """
    Eight-fold rotation/flip augmentation of a stack of patches.

    Returns, concatenated along dim 0: the patches rotated by 0, 90, 180 and 270 degrees in
    the last two dimensions, followed by those four rotations flipped along dim -2.
    Requires square patches (rot90 of a non-square patch changes its shape).
    """
    quarter_turns = [torch.rot90(patches, k, (-2, -1)) for k in range(4)]
    rotated = torch.cat(quarter_turns, dim=0)
    flipped = torch.flip(rotated, [-2])
    return torch.cat([rotated, flipped], dim=0)

def show_tensor_as_picture(img):
    """
    Display an image tensor with matplotlib.

    Parameters
    ----------
    img : torch.Tensor
        A single image (C, H, W) or (H, W), or a batch (B, C, H, W), of which only the
        first image is shown. 3-channel images are shown as RGB. Single-channel and 2-D
        images are shown in grayscale.
    """
    img = img.cpu().detach()
    if img.dim() == 4:
        img = img[0]  # batch -> first image
    if img.dim() == 3 and img.shape[0] == 3:
        img = img.permute(1, 2, 0)  # CHW -> HWC, the layout imshow expects for RGB
    elif img.dim() == 3 and img.shape[0] == 1:
        img = img[0]  # (1, H, W) -> (H, W); imshow rejects a leading singleton channel
    cmap = 'gray' if img.dim() == 2 else None
    plt.imshow(img, cmap=cmap, interpolation='nearest')
    plt.show()

def n2void_mask(image_shape, num_masked_pixels=8):
    """
    Build a uniform pixel-selection mask (Noise2Void style).

    Parameters
    ----------
    image_shape       : tuple(int)
                      (channels, height, width) for a single image, or
                      (batch, channels, height, width) for a batch.
    num_masked_pixels : int, optional(default=8)
                      Number of pixel locations set to 1 per image. The same locations are
                      used for every channel.

    Returns
    -------
    torch.Tensor
        Float mask with 1 at the selected pixels and 0 elsewhere. For a batch shape, one
        independently drawn mask per image is stacked along dim 0.
    """
    if len(image_shape) != 3:
        per_image = [
            select_random_pixels((image_shape[1], image_shape[2], image_shape[3]), num_masked_pixels)
            for _ in range(image_shape[0])
        ]
        return torch.stack(per_image)
    return select_random_pixels(image_shape, num_masked_pixels)
        

    
def select_random_pixels(image_shape, num_masked_pixels):
    """
    Mark `num_masked_pixels` random (row, col) locations of a (channels, height, width) image.

    The locations are drawn without replacement from all height*width positions, and the
    same locations are marked in every channel. The returned float tensor is an expanded
    view, so its channels share storage.
    """
    channels, height, width = image_shape[0], image_shape[1], image_shape[2]
    chosen = torch.randperm(height * width)[:num_masked_pixels]
    flat_mask = torch.zeros(height * width)
    flat_mask[chosen] = 1
    return flat_mask.view(height, width).unsqueeze(0).expand(channels, -1, -1)

In [None]:
transform = transforms.Compose([
    # Random crop, then resize to 128x128 (covers the former "crop, then resize" TODO).
    transforms.RandomResizedCrop((128, 128)),
    transforms.ToTensor(),                   # PIL image -> tensor
    transforms.Lambda(lambda x: x.float()),  # ensure float dtype
    # Scale so the maximum becomes 1; the clamp avoids 0/0 = NaN for an all-black crop.
    transforms.Lambda(lambda x: x / torch.max(x).clamp_min(1e-8)),
])

# Device is forced to CPU, as before; set FORCE_CPU = False to use CUDA when available.
FORCE_CPU = True
device = "cpu" if FORCE_CPU or not torch.cuda.is_available() else "cuda"
methode = "n2n_orig"
#store_path = log_files()
#run = wandb.init(entity="", project="my-project-name", anonymous="allow")
#writer = SummaryWriter(log_dir=os.path.join(store_path, "tensorboard"))

celeba_dir = 'dataset/celeba_dataset'
dataset = datasets.CelebA(root=celeba_dir, split='train', download=True, transform=transform)
dataset_validate = datasets.CelebA(root=celeba_dir, split='valid', download=True, transform=transform)
dataLoader = DataLoader(dataset, batch_size=64, shuffle=True)
dataLoader_validate = DataLoader(dataset_validate, batch_size=64, shuffle=True)

mask = Mask.cut2self_mask((128, 128), 64).to(device)

model = N2N_Orig_Unet(3, 3).to(device)
# Alternative model: model = Cut2Self(mask).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)


print(f"Using {device} device")

In [None]:
# Pull one shuffled training batch and report the image and label tensor shapes.
original, label = next(iter(dataLoader))
for loaded in (original, label):
    print(loaded.shape)


In [None]:
# Simulated noisy observation: original + 0.5 * N(0, 1) noise, cut into 6 patches per image.
noise_image = torch.add(original, torch.randn_like(original), alpha=0.5).to(device)
patches = generate_patches_from_list(noise_image, num_patches_per_img=6)
print(patches.shape)

In [None]:

# Inspect the first six noisy patches.
for shown_patch in patches[:6]:
    show_tensor_as_picture(shown_patch)
print("-------------------------------------------------------------------")


mask = n2void_mask(patches.shape, num_masked_pixels=8).to(device)
print(patches.shape)
print(noise_image.shape)
print(mask.shape)
print((1-mask).shape)

# Blind-spot input: the masked pixels are set to zero.
mask_noise = patches * (1-mask)

# Noise2Void-style alternative (not fed to the model yet): replace every masked pixel by a
# random neighbour inside a p x p window. Work on a copy so `patches` remains the untouched
# target used for `target_pixel` below.
p = 5
half_window = p // 2  # -(p // 2)..p // 2 is symmetric; -p//2 would floor to -3 for p = 5
height, width = patches.shape[-2], patches.shape[-1]
replaced_patches = patches.clone()
# The mask marks the same pixels in every channel, so one neighbour per (batch, row, col) is
# drawn from the channel-0 coordinates and all channels are copied from it.
for b, y, x in torch.nonzero(mask[:, 0]).tolist():
    new_y = max(0, min(height - 1, y + torch.randint(-half_window, half_window + 1, (1,)).item()))
    new_x = max(0, min(width - 1, x + torch.randint(-half_window, half_window + 1, (1,)).item()))
    replaced_patches[b, :, y, x] = patches[b, :, new_y, new_x]

with torch.no_grad():  # inspection only, no gradients needed
    denoised = model(mask_noise)
denoised_pixel = denoised * mask
target_pixel = patches * mask

show_tensor_as_picture(mask_noise)
show_tensor_as_picture(denoised)
show_tensor_as_picture(denoised_pixel)
show_tensor_as_picture(target_pixel)

In [None]:
num_masked_pixels = 8
mask = n2void_mask(patches.shape, num_masked_pixels=num_masked_pixels).to(device)
print("maks.shape: ", mask.shape)
print("patches.shape: ", patches.shape)
print("--------")
p = 5  # side length of the square neighbourhood a replacement pixel is drawn from
cords = torch.nonzero(mask)  # (batch, channel, row, col) of every masked entry
bearbeittete_Bilder = patches.clone()  # "edited images": patches with masked pixels replaced
half_window = p // 2  # symmetric offsets -(p // 2)..p // 2 (-p//2 would floor to -3 for p = 5)
height, width = patches.shape[-2], patches.shape[-1]
new_x = new_y = None
# Every masked location of every image is processed. The mask marks the same pixels in every
# channel, so one neighbour is drawn per (batch, row, col) location (taken from channel 0),
# and all channels are copied from that neighbour. Here x is the row and y is the column.
for batch, x, y in torch.nonzero(mask[:, 0]).tolist():
    new_x = max(0, min(height - 1, x + torch.randint(-half_window, half_window + 1, (1,)).item()))
    new_y = max(0, min(width - 1, y + torch.randint(-half_window, half_window + 1, (1,)).item()))
    bearbeittete_Bilder[batch, :, x, y] = patches[batch, :, new_x, new_y]

print("new x: ", new_x)
print("new y: ", new_y)
print("bearbeitete Bilder.shape: ", bearbeittete_Bilder.shape)
show_tensor_as_picture(patches[0]*(mask[0]).to(device))
show_tensor_as_picture(bearbeittete_Bilder[0]*(mask[0]).to(device))
show_tensor_as_picture(patches[2]*(mask[2]).to(device))
show_tensor_as_picture(bearbeittete_Bilder[2]*(mask[2]).to(device))


In [None]:
show_tensor_as_picture(bearbeittete_Bilder.to(device))

# Masked pixels of sample 2, before and after replacement.
for shown in (patches[2], bearbeittete_Bilder[2]):
    show_tensor_as_picture(shown * mask[2].to(device))
print(torch.count_nonzero(mask[0]).item())

# Per-channel masks of sample 0.
for channel in range(3):
    print(mask[0][channel])


In [None]:
original, label = next(iter(dataLoader))
original = original.to(device)
noise1 = torch.randn_like(original)        # std 1.0
noise2 = torch.randn_like(original) * 0.5  # std 0.5

# One-hot channel selectors: channel 2 is named blue, channel 0 red.
blue_chanel = torch.zeros_like(noise1)
blue_chanel[:, 2, :, :] = 1

red_chanel = torch.zeros_like(noise1)
red_chanel[:, 0, :, :] = 1

# Noise restricted to a single colour channel. Each name now matches the channel it affects;
# previously `blue_m` used the red selector and `red_m` the blue one.
blue_m = noise1 * blue_chanel
red_m = noise2 * red_chanel

show_tensor_as_picture(original+blue_m)
show_tensor_as_picture(original+blue_m+red_m)
show_tensor_as_picture(blue_m)
show_tensor_as_picture(red_m)



Files already downloaded and verified


AttributeError: 'CelebA' object has no attribute 'mean'

AttributeError: 'CelebA' object has no attribute 'data'