In [10]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms import v2

# NOTE(review): `your_dataset_file` is a placeholder module name; running this
# cell raised ModuleNotFoundError. Point it at the real module defining FLYDataset.
from your_dataset_file import FLYDataset  # Pfad anpassen, z. B. `from fly_dataset import FLYDataset`

# ----------- Config -------------
# Root directory of the dataset, passed to FLYDataset.
DATA_PATH = "/scratch/cv-course2025/group2/data"
# Index of the sample to visualize.
IDX = 42
# Camera index passed to FLYDataset via `cam=`.
CAM = 0
# Image size (height, width) in pixels assumed by this cell.
# NOTE(review): confirm against the dataset's actual image size.
H = 480
W = 980
# --------------------------------

# ----------- Define transform -------------
# Augmentation pipeline:
#   1. random rotation with an angle drawn from [-25, 25] degrees,
#   2. conversion to float32 (scale=True rescales integer inputs to [0, 1]),
#   3. normalization (x - 0.5) / 0.5, which maps [0, 1] to [-1, 1].
# NOTE(review): `transform` is not applied anywhere in this cell.
transform = v2.Compose(
    [
        v2.RandomRotation(degrees=(-25, 25)),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# ----------- Original image + keypoints -------------
dataset = FLYDataset(DATA_PATH, mode="training", cam=CAM)
# NOTE(review): __getvisual__ is assumed to return the display image and
# pixel-space keypoint coordinates as (xs, ys); confirm in FLYDataset.
img_orig, (x_orig, y_orig) = dataset.__getvisual__(IDX)

# ----------- Prepare the rotated image -------------
# Raw image tensor and keypoints from __getitem__. Per the original note, the
# image is [1, H, W] and the keypoints are [N, 2] normalized (x, y) pairs.
img_tensor, kps = dataset[IDX]
img_tensor = img_tensor.clone()
kps = kps.clone()

# Rotate only the image with torchvision; keypoints are rotated by hand below.
# The degenerate range (angle, angle) always yields exactly `angle` degrees, so
# the seed does not affect this rotation; it only sets the global RNG state.
torch.manual_seed(0)
angle = 25  # degrees; torchvision treats positive angles as counter-clockwise
img_trans = v2.RandomRotation(degrees=(angle, angle))(img_tensor)

# Take scale and pivot from the tensor that was actually rotated, not from the
# hard-coded H, W config. rotate() pivots on the center of its own input, so the
# keypoints must use the same size or they drift away from the image content.
img_h, img_w = img_tensor.shape[-2:]

# Normalized keypoints -> pixel coordinates (column 0 = x, column 1 = y).
# Converting to float32 keeps the matmul below dtype-consistent even if the
# dataset yields float64 or integer keypoints; it also avoids mutating `kps`.
kps_px = kps.float() * torch.tensor([img_w, img_h], dtype=torch.float32)

# Rotate the keypoints about the image center to match the image rotation.
# In image coordinates (y axis pointing down), a counter-clockwise rotation by
# theta maps an offset (dx, dy) from the center to
#     x' =  cos(theta) * dx + sin(theta) * dy
#     y' = -sin(theta) * dx + cos(theta) * dy
# (The y-up matrix [[cos, -sin], [sin, cos]] would turn the points clockwise
# here, i.e. opposite to the image.)
# NOTE(review): depending on whether the keypoints use pixel-index or
# pixel-edge coordinates, the exact pivot may differ from torchvision's by
# half a pixel.
angle_rad = math.radians(angle)
cos_a, sin_a = math.cos(angle_rad), math.sin(angle_rad)
center = torch.tensor([img_w / 2, img_h / 2], dtype=torch.float32)
rel_kps = kps_px - center
rotation_matrix = torch.tensor(
    [
        [cos_a, sin_a],
        [-sin_a, cos_a],
    ],
    dtype=torch.float32,
)
kps_rotated = rel_kps @ rotation_matrix.T + center
x_rot, y_rot = kps_rotated[:, 0], kps_rotated[:, 1]

# Convert the rotated image back to uint8 for display.
# NOTE(review): `* 0.5 + 0.5` undoes Normalize(mean=0.5, std=0.5), so this
# assumes dataset[IDX] already returns values in [-1, 1] (the `transform`
# pipeline defined earlier is not applied in this cell). Confirm.
img_vis = img_trans.detach().cpu().squeeze().numpy()
img_vis = (img_vis * 0.5 + 0.5) * 255
img_vis = np.clip(img_vis, 0, 255).astype(np.uint8)

# ----------- Visualization -------------
def _draw_panel(ax, image, xs, ys, color, title):
    """Show `image` in grayscale on `ax`, overlay (xs, ys) as dots, title it, hide the axes."""
    ax.imshow(image, cmap="gray")
    ax.scatter(xs, ys, c=color, s=20)
    ax.set_title(title)
    ax.axis("off")


fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Left: untouched sample; right: rotated image with rotated keypoints.
_draw_panel(axes[0], img_orig, x_orig, y_orig, "lime", "Original")
_draw_panel(axes[1], img_vis, x_rot, y_rot, "red", "Mit Rotation")

plt.tight_layout()
plt.show()


ModuleNotFoundError: No module named 'your_dataset_file'