In [None]:
import os
import random
import json
import numpy as np
import vtk
from vtk.util.numpy_support import vtk_to_numpy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import sys
sys.path.append("../")

from lib.classes.models.demo import ReconstructionModel
from lib.classes.dataset import PollenDataset

def _split_sizes(total_size, train_frac=0.7, val_frac=0.15):
    """Return ``(train, val, test)`` split sizes that sum to ``total_size``.

    Train and validation sizes are truncated towards zero with ``int``; the
    test split receives the remainder so no sample is dropped.
    """
    train_size = int(train_frac * total_size)
    val_size = int(val_frac * total_size)
    return train_size, val_size, total_size - train_size - val_size


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-sample loss.

    If ``optimizer`` is given, the model is put in training mode and its
    parameters are updated after every batch. Otherwise the model is put in
    eval mode and gradients are disabled.

    The target is a placeholder: an all-zeros tensor shaped like the model
    output, because no 3D ground truth is available yet.

    Returns ``float("nan")`` when ``loader`` yields no samples, instead of
    dividing by zero.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    num_samples = 0
    with torch.set_grad_enabled(training):
        for batch in loader:
            # The original notes these as (B, 1, H, W) grayscale views.
            # TODO confirm the channel count against PollenDataset.
            left = batch['left_view'].to(device)
            right = batch['right_view'].to(device)
            outputs = model(left, right)
            # Dummy target: zero volume. zeros_like already allocates on
            # the same device as ``outputs``.
            target = torch.zeros_like(outputs)
            loss = criterion(outputs, target)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # The criterion returns a batch mean. Weighting it by the batch
            # size makes the final average per sample, even when the last
            # batch is smaller.
            count = left.size(0)
            total_loss += loss.item() * count
            num_samples += count
    return total_loss / num_samples if num_samples else float("nan")


def main(*, num_epochs=10, batch_size=4, lr=1e-3, seed=None):
    """Train, validate, save and test ``ReconstructionModel`` on left/right views.

    The objective is still a placeholder. ``nn.MSELoss`` is computed against
    an all-zeros tensor shaped like the model output, because the dataset is
    built with ``return_3d=False`` (no 3D ground truth).

    Args:
        num_epochs: Number of training epochs (default 10).
        batch_size: Mini-batch size for the train/val/test loaders (default 4).
        lr: Adam learning rate (default 1e-3).
        seed: If given, seeds the generator used by ``random_split`` so the
            70/15/15 partition is reproducible across runs. ``None`` keeps
            torch's default generator, as before.

    Side effects:
        Prints progress to stdout and writes the trained ``state_dict`` to
        ``reconstruction_model.pth`` in the current working directory.
    """
    # Resize, convert to tensor, then normalize with mean 0.5 / std 0.5
    # (maps ToTensor's [0, 1] range to [-1, 1], assuming 8-bit images).
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    dataset = PollenDataset(
        model_dir="../data-pipeline/data/models",
        transform=transform,
        # No 3D ground truth is available here; a real application would
        # need to preprocess it.
        return_3d=False
    )

    # Split into train / validation / test (70/15/15).
    train_size, val_size, test_size = _split_sizes(len(dataset))
    split_kwargs = {} if seed is None else {
        "generator": torch.Generator().manual_seed(seed)
    }
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], **split_kwargs
    )
    print(f"Dataset-Größen: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")

    # num_workers=0 loads batches in the main process.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ReconstructionModel().to(device)
    # Placeholder loss; a real application needs a proper 3D reconstruction loss.
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        epoch_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        print(f"Epoch {epoch+1}/{num_epochs} Training Loss: {epoch_loss:.4f}")

        val_loss = _run_epoch(model, val_loader, criterion, device)
        print(f"Epoch {epoch+1}/{num_epochs} Validation Loss: {val_loss:.4f}")

    torch.save(model.state_dict(), "reconstruction_model.pth")
    print("Training abgeschlossen. Modell gespeichert als reconstruction_model.pth.")

    # Final evaluation on the held-out test split.
    test_loss = _run_epoch(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}")

from torch import multiprocessing as mp

if __name__ == "__main__":
    # Choose the "spawn" start method before main() runs. force=True
    # overrides any start method already set in this interpreter.
    mp.set_start_method("spawn", force=True)
    main()


Dataset-Größen: Train=168, Val=36, Test=37
