In [None]:
from google.colab import drive

# Mount Google Drive under /content/drive. force_remount=True re-mounts even if
# the drive is already attached, which refreshes stale credentials.
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [None]:
# List the dataset folders available under the mounted Drive's Datasets directory.
!ls "/content/drive/MyDrive/Datasets/"

defocused_blurred  DIV2K_train_HR  DIV2K_valid_HR  sharp
DIV2K_blurred	   DIV2K_train_LR  motion_blurred


In [None]:
import numpy as np
import os

# Sanity-check the ground-truth PSF kernels before training.
# The predictor trained later emits non-negative KERNEL_SIZE x KERNEL_SIZE kernels
# that sum to 1, so targets with the wrong shape, non-finite/negative entries or a
# sum far from 1 would make the MSE objective unreachable.
KERNEL_DIR = "/content/drive/MyDrive/Datasets/DIV2K_blurred/kernels"
KERNEL_SIZE = 15   # expected PSF side length (must match the training cell)
SUM_ATOL = 1e-3    # allowed deviation of sum(kernel) from 1


def psf_problem(arr, atol=SUM_ATOL):
    """Return a short reason why ``arr`` is not a valid normalised PSF, or None if it is.

    Args:
        arr: a non-empty numpy array holding one kernel.
        atol: allowed absolute deviation of ``arr.sum()`` from 1.
    """
    if not np.all(np.isfinite(arr)):
        return "non-finite values"
    if arr.min() < 0:
        return f"negative values (min={arr.min():.3g})"
    if abs(arr.sum() - 1.0) > atol:
        return f"sum={arr.sum():.4f}"
    return None


bad_files = []      # (filename, shape) for kernels with the wrong shape
invalid_files = []  # (filename, reason) for correctly shaped but invalid kernels
# sorted() makes the reported order deterministic (os.listdir order is arbitrary).
for fname in sorted(os.listdir(KERNEL_DIR)):
    if not fname.endswith(".npy"):
        continue
    arr = np.load(os.path.join(KERNEL_DIR, fname))
    if arr.shape != (KERNEL_SIZE, KERNEL_SIZE):
        bad_files.append((fname, arr.shape))
        continue  # value checks are meaningless for a mis-shaped kernel
    reason = psf_problem(arr)
    if reason is not None:
        invalid_files.append((fname, reason))

print(f"Talált hibás PSF fájlok: {len(bad_files)}")
print(bad_files[:5])  # show only the first 5 offending files
print(f"Invalid PSF kernels (non-finite, negative or not summing to 1): {len(invalid_files)}")
print(invalid_files[:5])  # show only the first 5 offending files

Talált hibás PSF fájlok: 0
[]


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm

# --- Parameters ---
# Root folder holding the `blurred/` images and matching `kernels/` .npy files.
DATA_DIR = "/content/drive/MyDrive/Datasets/DIV2K_blurred"
KERNEL_SIZE = 15             # side length of the square PSF the model predicts
BATCH_SIZE, EPOCHS = 16, 20  # mini-batch size / number of passes over the data
LR = 1e-4                    # Adam learning rate

# --- Dataset ---
class PSFDataset(Dataset):
    """Pairs each blurred PNG with its ground-truth PSF kernel.

    For every ``<name>.png`` in ``blur_dir`` the kernel is read from
    ``<name>.npy`` in ``kernel_dir``.

    Args:
        blur_dir: directory containing the blurred ``.png`` images.
        kernel_dir: directory containing the matching ``.npy`` kernels.
        transform: optional callable applied to the RGB PIL image.

    Each item is ``(image, kernel)`` where ``kernel`` is a float32 tensor of
    shape ``[1, H, W]`` normalised to sum to 1.

    Raises (from ``__getitem__``):
        ValueError: if a kernel's sum is non-finite or not positive.
    """

    def __init__(self, blur_dir, kernel_dir, transform=None):
        self.blur_dir = blur_dir
        self.kernel_dir = kernel_dir
        # Sorted so the index -> file mapping is deterministic.
        self.filenames = sorted(f for f in os.listdir(blur_dir) if f.endswith('.png'))
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def _load_kernel(self, fname):
        """Load the PSF belonging to image ``fname`` as a ``[1, H, W]`` float32 tensor summing to 1."""
        # splitext only swaps the final extension; str.replace would also rewrite
        # any '.png' occurring elsewhere in the name.
        stem, _ = os.path.splitext(fname)
        kernel_path = os.path.join(self.kernel_dir, stem + '.npy')
        kernel = np.load(kernel_path).astype(np.float32)
        total = kernel.sum()
        # PSFPredictor emits non-negative kernels normalised to sum 1, so the
        # target is normalised the same way; a kernel without a positive finite
        # sum cannot be a PSF and is reported instead of silently trained on.
        if not np.isfinite(total) or total <= 0:
            raise ValueError(f"Invalid PSF kernel (sum={total}): {kernel_path}")
        kernel /= total
        return torch.from_numpy(kernel).unsqueeze(0)  # [1, H, W]

    def __getitem__(self, idx):
        fname = self.filenames[idx]
        # Context manager closes the underlying file handle deterministically.
        with Image.open(os.path.join(self.blur_dir, fname)) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self._load_kernel(fname)

# --- Transform ---
# The network has to recover a KERNEL_SIZE x KERNEL_SIZE *pixel* blur kernel, so
# the pixel scale of the blur must survive preprocessing. Resize((128, 128))
# rescaled every image (and the blur baked into it) to 128x128 regardless of its
# size and aspect ratio; for full-resolution DIV2K images (~2K px wide — confirm
# they were not pre-shrunk) that squeezes a 15x15 PSF to roughly one pixel.
# Random native-resolution patches keep the blur intact and double as
# augmentation; pad_if_needed guards against images smaller than the patch.
# NOTE(review): if a later cell reuses `transform` for inference, use a
# deterministic crop there (e.g. transforms.CenterCrop(PATCH_SIZE)).
PATCH_SIZE = 128
transform = transforms.Compose([
    transforms.RandomCrop(PATCH_SIZE, pad_if_needed=True),
    transforms.ToTensor(),
])

dataset = PSFDataset(
    blur_dir=os.path.join(DATA_DIR, "blurred"),
    kernel_dir=os.path.join(DATA_DIR, "kernels"),
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# --- Network ---
class PSFPredictor(nn.Module):
    """Predict one normalised PSF kernel per RGB input image.

    Architecture: two 3x3 conv + ReLU layers (3 -> 32 -> 64 channels), global
    average pooling to a 64-d vector, and a linear head producing
    ``kernel_size ** 2`` values reshaped to ``[B, 1, kernel_size, kernel_size]``.

    The head output is clipped to be non-negative (ReLU) and divided by its
    per-sample sum plus 1e-8, so each kernel sums to 1 unless every entry was
    clipped to 0.

    NOTE(review): two stacked stride-1 3x3 convolutions give only a 5x5 receptive
    field before global pooling, which may be too small to resolve a 15x15 blur;
    a deeper/strided encoder is worth trying (it would change the state_dict).
    """

    def __init__(self, kernel_size=15):
        super().__init__()
        # Creation order conv1 -> conv2 -> fc fixes both the seeded weight
        # initialisation and the state_dict keys (encoder.0, encoder.2, fc).
        conv_stack = [
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.encoder = nn.Sequential(*conv_stack)
        self.fc = nn.Linear(64, kernel_size ** 2)
        self.kernel_size = kernel_size

    def forward(self, x):
        side = self.kernel_size
        pooled = torch.flatten(self.encoder(x), 1)      # [B, 64]
        raw = self.fc(pooled).view(-1, 1, side, side)   # [B, 1, side, side]
        psf = F.relu(raw)                               # non-negative entries
        denom = psf.sum(dim=(2, 3), keepdim=True) + 1e-8
        return psf / denom                              # each kernel sums to ~1

# --- Training ---
# Seed torch's global RNG (CPU and CUDA) before building the model, so weight
# initialisation, DataLoader shuffling and torch-based random transforms are
# reproducible across runs. (cuDNN kernels may still be non-deterministic on GPU.)
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PSFPredictor(kernel_size=KERNEL_SIZE).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.MSELoss()  # element-wise mean squared error between predicted and target kernels

# Train for EPOCHS passes and report the per-sample mean loss of each epoch.
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0  # sum of per-sample losses over the epoch
    n_seen = 0        # number of samples processed in the epoch
    for imgs, kernels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        imgs, kernels = imgs.to(device), kernels.to(device)
        pred = model(imgs)
        loss = criterion(pred, kernels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch *mean*; weight it by the batch size so the
        # epoch average is exact even when the last batch is smaller.
        batch_n = imgs.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n

    # max(..., 1) avoids ZeroDivisionError if the loader yielded no samples.
    epoch_loss = total_loss / max(n_seen, 1)
    print(f"\u2705 Epoch {epoch+1}: Loss = {epoch_loss:.6f}")

# --- Save model ---
# torch.save does not create missing parent directories, so ensure the target
# folder exists on Drive before writing the weights.
MODEL_PATH = "/content/drive/MyDrive/models/psf_predictor.pth"
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
print("\n📅 Modell mentve: psf_predictor.pth")

Epoch 1/20: 100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


✅ Epoch 1: Loss = 0.000632


Epoch 2/20: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


✅ Epoch 2: Loss = 0.000584


Epoch 3/20: 100%|██████████| 50/50 [00:16<00:00,  3.09it/s]


✅ Epoch 3: Loss = 0.000551


Epoch 4/20: 100%|██████████| 50/50 [00:17<00:00,  2.91it/s]


✅ Epoch 4: Loss = 0.000533


Epoch 5/20: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


✅ Epoch 5: Loss = 0.000527


Epoch 6/20: 100%|██████████| 50/50 [00:16<00:00,  3.05it/s]


✅ Epoch 6: Loss = 0.000526


Epoch 7/20: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]


✅ Epoch 7: Loss = 0.000525


Epoch 8/20: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]


✅ Epoch 8: Loss = 0.000525


Epoch 9/20: 100%|██████████| 50/50 [00:16<00:00,  2.97it/s]


✅ Epoch 9: Loss = 0.000525


Epoch 10/20: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]


✅ Epoch 10: Loss = 0.000525


Epoch 11/20: 100%|██████████| 50/50 [00:16<00:00,  3.04it/s]


✅ Epoch 11: Loss = 0.000525


Epoch 12/20: 100%|██████████| 50/50 [00:16<00:00,  3.04it/s]


✅ Epoch 12: Loss = 0.000525


Epoch 13/20: 100%|██████████| 50/50 [00:16<00:00,  2.97it/s]


✅ Epoch 13: Loss = 0.000525


Epoch 14/20: 100%|██████████| 50/50 [00:16<00:00,  3.00it/s]


✅ Epoch 14: Loss = 0.000525


Epoch 15/20: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]


✅ Epoch 15: Loss = 0.000525


Epoch 16/20: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


✅ Epoch 16: Loss = 0.000524


Epoch 17/20: 100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


✅ Epoch 17: Loss = 0.000524


Epoch 18/20: 100%|██████████| 50/50 [00:17<00:00,  2.92it/s]


✅ Epoch 18: Loss = 0.000524


Epoch 19/20: 100%|██████████| 50/50 [00:16<00:00,  2.98it/s]


✅ Epoch 19: Loss = 0.000524


Epoch 20/20: 100%|██████████| 50/50 [00:16<00:00,  2.97it/s]

✅ Epoch 20: Loss = 0.000524

📅 Modell mentve: psf_predictor.pth



