# Claude-Ansatz: PyTorch-Dataset erstellen

In [1]:
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF


class H5PatchDataset(Dataset):
    """
    PyTorch dataset that tiles an H5 image and its mask into overlapping square patches.

    Patches are laid on a regular grid with stride ``patch_size - overlap``. If the
    grid does not end flush with the bottom/right border, an extra row/column of
    patches aligned to that border (plus the bottom-right corner patch) is added,
    so every pixel is covered by at least one patch.

    Args:
        bild_path: Path to the H5 file holding the image under the key ``"Image"``.
        maske_path: Path to the H5 file holding the mask under the key ``"exported_data"``.
        patch_size: Edge length of the square patches (e.g. 256).
        overlap: Overlap between neighbouring patches in pixels (e.g. 32);
            must satisfy ``0 <= overlap < patch_size``.
        transform: Optional callable ``transform(image, mask) -> (image, mask)``
            applied to both tensors together (e.g. ``SyncedTransform``).

    Raises:
        ValueError: if the image is not 2-D, the mask is not 3-D ``(H, W, C)``,
            their height/width differ, ``overlap`` is outside ``[0, patch_size)``,
            or the image is smaller than one patch.

    Each item is a dict with ``'image'`` (float32, ``(1, patch_size, patch_size)``
    before any transform), ``'mask'`` (float32, ``(C, patch_size, patch_size)``),
    ``'position'`` (top-left ``(y, x)`` of the patch, needed for reconstruction)
    and ``'index'``.
    """

    def __init__(self, bild_path, maske_path, patch_size=256, overlap=32, transform=None):
        self.bild_path = bild_path
        self.maske_path = maske_path
        self.patch_size = patch_size
        self.overlap = overlap
        self.transform = transform

        # Load both arrays fully into memory; patches are sliced from them in __getitem__.
        with h5py.File(bild_path, 'r') as f:
            self.bild = np.array(f["Image"])

        with h5py.File(maske_path, 'r') as f:
            self.maske = np.array(f["exported_data"])

        self._check_shapes()
        self.patch_positions = self._calculate_patch_positions()

    def _check_shapes(self):
        """Ensure the image is (H, W) and the mask is (H, W, C) with matching H and W."""
        if self.bild.ndim != 2:
            raise ValueError(f"expected a 2-D image, got shape {self.bild.shape}")
        if self.maske.ndim != 3:
            raise ValueError(f"expected a 3-D (H, W, C) mask, got shape {self.maske.shape}")
        if self.maske.shape[:2] != self.bild.shape:
            raise ValueError(
                f"mask shape {self.maske.shape} does not match image shape {self.bild.shape}"
            )

    def _calculate_patch_positions(self):
        """Return the top-left (y, x) of every patch, grid first, then border patches."""
        h, w = self.bild.shape
        size = self.patch_size
        stride = size - self.overlap
        if self.overlap < 0 or stride <= 0:
            raise ValueError(
                f"overlap must satisfy 0 <= overlap < patch_size, "
                f"got overlap={self.overlap}, patch_size={size}"
            )
        if h < size or w < size:
            # Otherwise the border patches below would get negative offsets.
            raise ValueError(f"image of shape {(h, w)} is smaller than patch_size={size}")

        last_y, last_x = h - size, w - size
        grid_ys = range(0, last_y + 1, stride)
        grid_xs = range(0, last_x + 1, stride)
        positions = [(y, x) for y in grid_ys for x in grid_xs]

        # The regular grid stops short of the border unless the remainder divides evenly.
        need_bottom = last_y % stride != 0
        need_right = last_x % stride != 0
        if need_bottom:
            positions.extend((last_y, x) for x in grid_xs)
        if need_right:
            positions.extend((y, last_x) for y in grid_ys)
        if need_bottom and need_right:
            positions.append((last_y, last_x))

        return positions

    def __len__(self):
        return len(self.patch_positions)

    def __getitem__(self, idx):
        y, x = self.patch_positions[idx]
        size = self.patch_size

        bild_patch = self.bild[y:y + size, x:x + size]
        maske_patch = self.maske[y:y + size, x:x + size, :]

        # Cast in NumPy first: torch.from_numpy rejects some dtypes (e.g. uint16 on
        # older torch versions); float32 is what .float() produced anyway.
        bild_tensor = torch.from_numpy(
            np.ascontiguousarray(bild_patch, dtype=np.float32)
        ).unsqueeze(0)  # (1, H, W)
        maske_tensor = torch.from_numpy(
            np.ascontiguousarray(maske_patch, dtype=np.float32)
        ).permute(2, 0, 1)  # (C, H, W)

        # Image and mask go through the transform together so geometric ops stay aligned.
        if self.transform is not None:
            bild_tensor, maske_tensor = self.transform(bild_tensor, maske_tensor)

        return {
            'image': bild_tensor,
            'mask': maske_tensor,
            'position': (y, x),  # needed for reconstruction
            'index': idx
        }

    def get_original_shape(self):
        """Return the (H, W) shape of the full image."""
        return self.bild.shape


class SyncedTransform:
    """
    Random augmentation that keeps an image and its mask aligned.

    Geometric operations (rotation, horizontal flip, vertical flip) reuse the same
    random draw for image and mask. Photometric operations (brightness, contrast)
    are applied to the image only.

    Args:
        rotation_range: Maximum absolute rotation in degrees; 0 disables rotation.
        flip_prob: Probability of each flip (horizontal and vertical independently).
        brightness: Brightness jitter strength; 0 disables it.
        contrast: Contrast jitter strength; 0 disables it.

    NOTE(review): torchvision's adjust_brightness / adjust_contrast clamp float
    tensors to [0, 1]. If the image is not normalised to that range, these steps
    will saturate it — confirm the input range.
    """

    def __init__(self, rotation_range=15, flip_prob=0.5, brightness=0.2, contrast=0.2):
        self.rotation_range = rotation_range
        self.flip_prob = flip_prob
        self.brightness = brightness
        self.contrast = contrast

    @staticmethod
    def _draw():
        """One uniform sample from [0, 1) taken from torch's global RNG."""
        return torch.rand(1).item()

    def _jitter(self, strength):
        """Multiplicative factor spread symmetrically around 1 by ``strength``."""
        return 1 + (self._draw() * 2 - 1) * strength

    def __call__(self, image, mask):
        # Geometry: identical parameters for image and mask.
        if self.rotation_range > 0:
            limit = self.rotation_range
            angle = self._draw() * 2 * limit - limit
            image = TF.rotate(image, angle, interpolation=TF.InterpolationMode.BILINEAR)
            # Nearest-neighbour so mask values are not blended during rotation.
            mask = TF.rotate(mask, angle, interpolation=TF.InterpolationMode.NEAREST)

        if self._draw() < self.flip_prob:
            image, mask = TF.hflip(image), TF.hflip(mask)

        if self._draw() < self.flip_prob:
            image, mask = TF.vflip(image), TF.vflip(mask)

        # Photometry: image only; the mask stays untouched.
        if self.brightness > 0:
            image = TF.adjust_brightness(image, self._jitter(self.brightness))

        if self.contrast > 0:
            image = TF.adjust_contrast(image, self._jitter(self.contrast))

        return image, mask


def reconstruct_from_patches(predictions, positions, original_shape, patch_size, overlap,
                             weight_mask=None):
    """
    Reassemble a full-size array from overlapping patch predictions.

    Each patch is multiplied by a blending window and accumulated; the result is
    normalised by the accumulated window weight per pixel.

    Args:
        predictions: Sequence of patch predictions, each shaped (C, patch_size, patch_size).
            Torch tensors (including ones that require grad or live on GPU) or NumPy arrays.
        positions: Iterable of top-left (y, x) positions, one per prediction
            (ints or 0-d tensors).
        original_shape: (H, W) of the original image.
        patch_size: Edge length of the patches.
        overlap: Overlap in pixels; only used to build the default blending window.
        weight_mask: Optional (patch_size, patch_size) blending window. Defaults to
            ``create_weight_mask(patch_size, overlap)``.

    Returns:
        float32 NumPy array of shape (C, H, W). Pixels covered by no patch are 0.

    Raises:
        ValueError: if ``predictions`` is empty or its length differs from ``positions``.
    """
    positions = list(positions)
    if len(predictions) == 0:
        raise ValueError("predictions must contain at least one patch")
    if len(predictions) != len(positions):
        raise ValueError(
            f"got {len(predictions)} predictions but {len(positions)} positions"
        )

    h, w = original_shape

    if weight_mask is None:
        weight_mask = create_weight_mask(patch_size, overlap)
    # Floor the window: pixels reached only by zero-weight patch edges (a linear fade
    # may start at 0, e.g. along the image border) would otherwise come out as 0.
    # With the floor they get an (effectively) unweighted mean of the covering patches.
    min_weight = 1e-6
    window = np.maximum(np.asarray(weight_mask, dtype=np.float32), min_weight)

    def as_array(pred):
        if torch.is_tensor(pred):
            # detach() so predictions that still require grad can be converted.
            return pred.detach().cpu().numpy()
        return np.asarray(pred)

    channels = as_array(predictions[0]).shape[0]
    output = np.zeros((channels, h, w), dtype=np.float32)
    weights = np.zeros((h, w), dtype=np.float32)

    for pred, (y, x) in zip(predictions, positions):
        y, x = int(y), int(x)
        rows = slice(y, y + patch_size)
        cols = slice(x, x + patch_size)
        output[:, rows, cols] += as_array(pred) * window
        weights[rows, cols] += window

    # Normalise in place; uncovered pixels (weight 0) keep their value of 0.
    np.divide(output, weights, out=output, where=weights > 0)
    return output


def create_weight_mask(patch_size, overlap):
    """
    Build a (patch_size, patch_size) blending window for overlapping patches.

    The window is 1 in the interior and ramps linearly along each of the four
    borders over ``overlap`` pixels; in the corners the two ramps multiply.
    Ramp values are (i + 1) / (overlap + 1) for i = 0 .. overlap - 1, so they never
    reach 0. The rising edge of one patch and the falling edge of a neighbour that
    overlaps it by exactly ``overlap`` pixels sum to 1.

    Args:
        patch_size: Edge length of the square window.
        overlap: Ramp width in pixels; ``overlap <= 0`` yields an all-ones window.

    Returns:
        float32 array of shape (patch_size, patch_size).

    Raises:
        ValueError: if ``overlap`` exceeds ``patch_size``.
    """
    mask = np.ones((patch_size, patch_size), dtype=np.float32)

    if overlap <= 0:
        return mask
    if overlap > patch_size:
        raise ValueError(f"overlap ({overlap}) must not exceed patch_size ({patch_size})")

    # Drop the 0 and 1 end points: a zero weight would erase pixels covered only by
    # a single patch edge (e.g. along the image border), and the plateau supplies the 1.
    ramp = np.linspace(0.0, 1.0, overlap + 2)[1:-1]
    # Top
    mask[:overlap, :] *= ramp[:, np.newaxis]
    # Bottom
    mask[-overlap:, :] *= ramp[::-1, np.newaxis]
    # Left
    mask[:, :overlap] *= ramp[np.newaxis, :]
    # Right
    mask[:, -overlap:] *= ramp[np.newaxis, ::-1]

    return mask

In [None]:
from pathlib import Path

from torch.utils.data import DataLoader

# Input files live in ../data relative to the notebook's working directory.
cwd = Path.cwd()
data_dir = cwd.parent / 'data'
path = Path('PE-2024-01126-M_00_s0021_PM_Complete_Transmittance_Stitched_Flat_v004.h5')
fullpath = data_dir / path
Maskpath = Path('PE-2024-01126-M_00_s0021_PM_Complete_Transmittance_Stitched_Flat_v004-Image_Probabilities.h5')
Maskfullpath = data_dir / Maskpath

# Training dataset with synchronized augmentation of image and mask.
transform = SyncedTransform(rotation_range=15, flip_prob=0.5)
dataset = H5PatchDataset(
    bild_path=fullpath,
    maske_path=Maskfullpath,
    patch_size=256,
    overlap=32,
    transform=transform,
)

print(f"Dataset Größe: {len(dataset)} Patches")
print(f"Original Bildgröße: {dataset.get_original_shape()}")

# NOTE(review): in Jupyter on platforms that start DataLoader workers with "spawn"
# (Windows, macOS), workers cannot import classes defined in the notebook itself;
# use num_workers=0 there.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

# Example: load one batch.
batch = next(iter(dataloader))
print(f"Batch Image Shape: {batch['image'].shape}")
print(f"Batch Mask Shape: {batch['mask'].shape}")

# Example: reconstruction without augmentation (inference mode). This reuses the
# input files from above; point it at held-out data for a real evaluation.
test_dataset = H5PatchDataset(
    bild_path=fullpath,
    maske_path=Maskfullpath,
    patch_size=256,
    overlap=32,
    transform=None,  # no augmentation for inference
)

# Collect all "predictions" (here simply the masks themselves, as a round-trip check).
predictions = []
positions = []
for idx in range(len(test_dataset)):
    sample = test_dataset[idx]
    predictions.append(sample['mask'])
    positions.append(sample['position'])

# Reconstruct using the dataset's own tiling parameters so they cannot drift apart.
reconstructed = reconstruct_from_patches(
    predictions,
    positions,
    test_dataset.get_original_shape(),
    test_dataset.patch_size,
    test_dataset.overlap,
)
print(f"Rekonstruiertes Bild Shape: {reconstructed.shape}")

Dataset Größe: 1044 Patches
Original Bildgröße: (7956, 6488)


## CNN aus meiner Deep-Learning-Übung
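Für unsere Anwendung so nicht sinnvoll: Das Netz klassifiziert 3×32×32-Bilder in nur 10 Klassen (10 Outputs), während unser Dataset 1-kanalige 256×256-Patches mit pixelweisen Masken liefert. Es muss also angepasst werden.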


In [None]:
import torch.nn.functional as F
from torch.autograd import Variable
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image



class Net(nn.Module):
    """
    Small CNN classifier taken from a deep-learning exercise.

    Expects input batches of shape (N, 3, 32, 32) — the 32*4*4 flatten size of
    ``fc1`` follows from that — and returns log-probabilities of shape (N, 10),
    so it pairs with ``nn.NLLLoss``.
    """

    def __init__(self):
        super().__init__()
        # 32 -> conv5 -> 28 -> conv5 -> 24 -> pool -> 12
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=0, dilation=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5, stride=1, padding=0, dilation=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 12 -> conv3 -> 10 -> conv3 -> 8 -> pool -> 4
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=0, dilation=1)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=0, dilation=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool1(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool2(x)
        # flatten(1) keeps the batch dimension; a wrongly sized input now fails in
        # fc1 instead of being silently regrouped into a different number of rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.log_softmax(self.fc2(x), dim=1)
        return x

    def train_net(self, criterion, optimizer, trainloader, epochs, _net="CNN"):
        """
        Train the network in place and print the loss every 10 batches.

        Args:
            criterion: Loss applied to (log-probabilities, target), e.g. ``nn.NLLLoss``.
            optimizer: Optimizer over this model's parameters.
            trainloader: DataLoader yielding ``(data, target)`` batches.
            epochs: Number of passes over ``trainloader``.
            _net: ``"MLP"`` flattens each batch to (N, 3*32*32) before the forward
                pass; any other value feeds the data unchanged.
        """
        log_interval = 10
        self.train()
        for epoch in range(epochs):
            for batch_idx, (data, target) in enumerate(trainloader):
                if _net == "MLP":
                    data = data.view(-1, 3 * 32 * 32)
                optimizer.zero_grad()
                net_out = self(data)
                loss = criterion(net_out, target)
                loss.backward()
                optimizer.step()
                if batch_idx % log_interval == 0:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data),
                                                                                   len(trainloader.dataset),
                                                                                   100. * batch_idx / len(trainloader),
                                                                                   loss.item()))

    def test_net(self, criterion, testloader, _net="CNN"):
        """
        Evaluate on ``testloader``, print and return the average loss and accuracy.

        Runs in eval mode without gradient tracking and restores the previous
        train/eval mode afterwards.

        Args:
            criterion: Loss applied to (log-probabilities, target), e.g. ``nn.NLLLoss``.
            testloader: DataLoader yielding ``(data, target)`` batches.
            _net: ``"MLP"`` flattens each batch to (N, 3*32*32) before the forward pass.

        Returns:
            ``(average_loss_per_sample, accuracy_percent)``.
        """
        n_samples = len(testloader.dataset)
        # With reduction='mean' (the default) the criterion returns a per-batch average;
        # weight it by the batch size so the final division yields a per-sample average
        # even when the last batch is smaller.
        per_batch_mean = getattr(criterion, "reduction", "mean") == "mean"
        total_loss = 0.0
        correct = 0

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for data, target in testloader:
                    if _net == "MLP":
                        data = data.view(-1, 3 * 32 * 32)
                    net_out = self(data)
                    batch_loss = criterion(net_out, target).item()
                    total_loss += (batch_loss * data.size(0)) if per_batch_mean else batch_loss
                    pred = net_out.argmax(dim=1)  # index of the max log-probability
                    correct += int(pred.eq(target).sum().item())
        finally:
            self.train(was_training)

        test_loss = total_loss / n_samples
        acc = 100. * correct / n_samples
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(test_loss, correct,
                                                                                     n_samples, acc))
        return test_loss, acc

In [None]:
from pathlib import Path

# Create the model.
model = Net()
model_type = "CNN"

# Create the optimizer and the (loss-)criterion.
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
criterion = nn.NLLLoss()

# Train and save.
# NOTE(review): `trainloader` is not defined in any cell shown here. Net expects
# (data, target) batches of shape (N, 3, 32, 32), whereas H5PatchDataset yields
# dicts of 1-channel patches, so this cell needs a matching loader before it can run.
model.train()
model.train_net(criterion, optimizer, trainloader, 2, _net=model_type)
save_path = Path("data") / f"net_{model_type}.pt"
save_path.parent.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), save_path)