In [None]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F 

# PSNR 계산 함수
def calculate_psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    The mean squared error is taken over every element of both tensors at
    once (for a batch this is one PSNR of the whole batch, not an average of
    per-image PSNRs).

    Args:
        img1, img2: tensors of the same (broadcastable) shape.
        max_val: maximum possible pixel value (1.0 for ``ToTensor`` output).

    Returns:
        ``float('inf')`` when the inputs are identical, otherwise a 0-dim
        tensor that is detached from autograd.
    """
    # PSNR is a reporting metric only. Detaching prevents callers that
    # accumulate the result across batches (e.g. ``total += psnr``) from
    # keeping each batch's autograd graph and its saved tensors alive.
    mse = torch.mean((img1.detach() - img2.detach()) ** 2)
    if mse == 0:
        return float('inf')
    return 10 * torch.log10(max_val ** 2 / mse)

# SSIM 계산 함수
def calculate_ssim(img1, img2, max_val=1.0, window_size=3):
    """Return the mean structural similarity (SSIM) between two image batches.

    Local statistics use a uniform ``window_size`` x ``window_size`` box
    window (stride 1, same-size output) instead of the Gaussian window of the
    original SSIM formulation, so values are only comparable with other
    results of this function. The SSIM map is averaged over all positions,
    channels and batch entries.

    Args:
        img1, img2: tensors accepted by ``F.avg_pool2d``; in this script they
            are (N, C, H, W) batches presumably in [0, max_val].
        max_val: dynamic range of pixel values; sets the constants C1 and C2.
        window_size: odd side length of the averaging window (default 3).

    Returns:
        0-dim tensor, detached from autograd, holding the mean SSIM.

    Raises:
        ValueError: if ``window_size`` is not a positive odd integer.
    """
    if window_size < 1 or window_size % 2 == 0:
        raise ValueError(f"window_size must be a positive odd integer, got {window_size}")

    C1 = (0.01 * max_val) ** 2
    C2 = (0.03 * max_val) ** 2

    # Reporting metric only: never let callers retain an autograd graph.
    img1 = img1.detach()
    img2 = img2.detach()

    def local_mean(t):
        # count_include_pad=False averages only real pixels near the borders.
        # Counting the zero padding biases border means and variances (a
        # constant image would get non-zero variance along its edges).
        return F.avg_pool2d(
            t,
            kernel_size=window_size,
            stride=1,
            padding=window_size // 2,
            count_include_pad=False,
        )

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Var[X] = E[X^2] - E[X]^2 and Cov[X, Y] = E[XY] - E[X]E[Y] per window.
    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )
    return ssim_map.mean()

# 저해상도 이미지 생성 함수
def create_low_res_images(input_dir, output_dir, scale):
    """Write bicubic-downscaled copies of every PNG/JPEG image in a directory.

    Each image in ``input_dir`` whose name ends in .png/.jpg/.jpeg (any case)
    is shrunk by integer division of its width and height by ``scale`` and
    saved to ``output_dir`` under the same filename. ``output_dir`` is created
    if it does not exist.

    Args:
        input_dir: directory containing the source (high-resolution) images.
        output_dir: destination directory for the downscaled images.
        scale: positive integer downscaling factor.

    Raises:
        ValueError: if ``scale`` is smaller than 1.
    """
    if scale < 1:
        raise ValueError(f"scale must be >= 1, got {scale}")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    for filename in sorted(os.listdir(input_dir)):
        # Case-insensitive match so ".PNG" / ".JPG" / ".jpeg" are not skipped.
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        # The context manager closes the source file handle promptly.
        with Image.open(os.path.join(input_dir, filename)) as img:
            # Clamp to 1 px so images smaller than `scale` do not yield a
            # zero-sized resize (which PIL rejects).
            new_size = (max(1, img.width // scale), max(1, img.height // scale))
            lr_img = img.resize(new_size, Image.BICUBIC)
        lr_img.save(os.path.join(output_dir, filename))

# SRDataset 클래스 정의(학습 데이터DIV2K 사용)
class SRDataset(Dataset):
    """Paired high-/low-resolution image dataset (training data: DIV2K).

    HR and LR images are paired by position after sorting each directory's
    image filenames, so both directories must hold the same number of images
    whose sorted orders correspond. Both images of a pair are resized to
    ``target_size``; the LR input is therefore brought up to the same working
    size as the HR target before it reaches the model.

    Args:
        hr_dir: directory with the high-resolution (target) images.
        lr_dir: directory with the low-resolution (input) images.
        transform: optional callable applied to both resized PIL images
            (e.g. ``transforms.ToTensor()``).
        target_size: size passed to ``transforms.Resize``, e.g. ``(256, 256)``.

    Raises:
        ValueError: if the two directories contain different numbers of images.
    """

    # Only these extensions are treated as images. Other entries (hidden
    # files, sub-directories, ...) would otherwise shift the index pairing.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, hr_dir, lr_dir, transform=None, target_size=(256, 256)):
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.hr_filenames = self._list_images(hr_dir)
        self.lr_filenames = self._list_images(lr_dir)
        if len(self.hr_filenames) != len(self.lr_filenames):
            # Pairing is positional; unequal counts mean pairs cannot line up.
            raise ValueError(
                f"HR/LR image count mismatch: {len(self.hr_filenames)} in {hr_dir!r} "
                f"vs {len(self.lr_filenames)} in {lr_dir!r}"
            )
        self.transform = transform
        self.target_size = target_size

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted names of image files directly inside ``directory``."""
        return sorted(
            name
            for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.hr_filenames)

    def _load(self, path):
        """Open ``path``, convert it to RGB and resize it to ``target_size``."""
        # The with-block closes the file; convert() returns an independent copy.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return transforms.Resize(self.target_size)(rgb)

    def __getitem__(self, idx):
        hr_img = self._load(os.path.join(self.hr_dir, self.hr_filenames[idx]))
        lr_img = self._load(os.path.join(self.lr_dir, self.lr_filenames[idx]))

        if self.transform:
            hr_img = self.transform(hr_img)
            lr_img = self.transform(lr_img)

        # (model input, target) order expected by the training loop.
        return lr_img, hr_img

# SRCNN 모델 정의
class SRCNN(nn.Module):
    """Three-layer SRCNN-style network for 3-channel images.

    Layers are patch extraction (9x9, 3 -> 64), non-linear mapping
    (5x5, 64 -> 32) and reconstruction (5x5, 32 -> 3). Every convolution is
    padded by ``kernel_size // 2``, so the output has the same spatial size as
    the input. The network does not upsample; in this script the LR input is
    resized to the HR size by the dataset beforehand.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order are kept stable because they
        # define the state_dict keys of saved checkpoints.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        # The two hidden layers are each followed by a ReLU.
        for hidden in (self.conv1, self.conv2):
            x = self.relu(hidden(x))
        # The reconstruction layer is linear (no activation).
        return self.conv3(x)

# Dataset locations: DIV2K high-resolution targets and their low-resolution inputs.
train_hr_dir = r"C:\Users\kowm6\Desktop\SRdataset\DIV2K\DIV2K_train_HR"
train_lr_dir = r"C:\Users\kowm6\Desktop\SRdataset\DIV2K\DIV2K_train_LR"
valid_hr_dir = r"C:\Users\kowm6\Desktop\SRdataset\DIV2K\DIV2K_valid_HR"
valid_lr_dir = r"C:\Users\kowm6\Desktop\SRdataset\DIV2K\DIV2K_valid_LR"

# Both splits convert PIL images to tensors and work at a fixed 256x256 size.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset, valid_dataset = (
    SRDataset(hr_dir, lr_dir, transform, target_size=(256, 256))
    for hr_dir, lr_dir in (
        (train_hr_dir, train_lr_dir),
        (valid_hr_dir, valid_lr_dir),
    )
)

# Only the training batches are shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False)

# Training setup: loss, model on the best available device, and optimizer.
criterion = nn.MSELoss()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = SRCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Schedule and checkpointing: the weights with the lowest validation loss
# seen so far are written to save_path.
num_epochs = 2
best_valid_loss = float('inf')
save_path = r"C:\Users\kowm6\Desktop\sr\best_model.pth"

for epoch in range(num_epochs):
    # --- Training pass ---
    model.train()
    train_loss = 0.0
    train_psnr = 0.0
    for lr_imgs, hr_imgs in train_loader:
        lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
        preds = model(lr_imgs)
        loss = criterion(preds, hr_imgs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # PSNR is only reported. Using detached predictions and a Python float
        # keeps the running sum from holding every batch's autograd graph.
        train_psnr += float(calculate_psnr(preds.detach(), hr_imgs))

    # max(..., 1) avoids ZeroDivisionError if the loader is empty.
    num_train_batches = max(len(train_loader), 1)
    train_loss /= num_train_batches
    train_psnr /= num_train_batches

    # --- Validation pass ---
    model.eval()
    valid_loss = 0.0
    valid_psnr = 0.0
    valid_ssim = 0.0
    with torch.no_grad():
        for lr_imgs, hr_imgs in valid_loader:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
            preds = model(lr_imgs)
            valid_loss += criterion(preds, hr_imgs).item()

            # Per-batch PSNR and SSIM, averaged over batches below.
            valid_psnr += float(calculate_psnr(preds, hr_imgs))
            valid_ssim += float(calculate_ssim(preds, hr_imgs))

    num_valid_batches = max(len(valid_loader), 1)
    valid_psnr /= num_valid_batches
    valid_ssim /= num_valid_batches
    valid_loss /= num_valid_batches

    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
          f"Validation Loss: {valid_loss:.4f}, Train PSNR: {train_psnr:.4f}, "
          f"Validation PSNR: {valid_psnr:.4f}, Validation SSIM: {valid_ssim:.4f}")

    # Checkpoint the weights whenever validation loss improves.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), save_path)
        print("Best model saved!")

# Evaluate the best checkpoint on the held-out test pairs.
# The HR directory is the "testhr" folder and the LR directory is "testlr";
# the original assignment had the two swapped.
test_hr_dir = r"C:/Users/kowm6/Desktop/testhr"
test_lr_dir = r"C:/Users/kowm6/Desktop/testlr"
test_dataset = SRDataset(test_hr_dir, test_lr_dir, transform, target_size=(256, 256))
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# weights_only=True restricts unpickling to tensor data (and silences the
# FutureWarning newer PyTorch emits for torch.load). map_location lets a
# checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(save_path, map_location=device, weights_only=True))
model.eval()

test_psnr = 0.0
test_ssim = 0.0
num_test_images = 0

with torch.no_grad():
    for lr_imgs, hr_imgs in test_loader:
        lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
        preds = model(lr_imgs)

        # batch_size=1, so these are per-image metrics averaged below.
        test_psnr += float(calculate_psnr(preds, hr_imgs))
        test_ssim += float(calculate_ssim(preds, hr_imgs))
        num_test_images += 1

if num_test_images == 0:
    raise RuntimeError(f"No test images found in {test_hr_dir!r}")

test_psnr /= num_test_images
test_ssim /= num_test_images

print(f"Test PSNR: {test_psnr:.4f}, Test SSIM: {test_ssim:.4f}")


Epoch 1/2, Train Loss: 0.0574, Validation Loss: 0.0185, Train PSNR: 14.2296, Validation PSNR: 17.4362, Validation SSIM: 0.5050
Best model saved!
Epoch 2/2, Train Loss: 0.0113, Validation Loss: 0.0073, Train PSNR: 19.6674, Validation PSNR: 21.4736, Validation SSIM: 0.6553
Best model saved!
Test PSNR: 22.6087, Test SSIM: 0.7275


  model.load_state_dict(torch.load(save_path))
