In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
from PIL import Image
import os
from tqdm import tqdm

In [31]:
os.environ["PYTHONHASHSEED"] = str(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [32]:
noisy_path = r"C:\Users\91909\Desktop\ML\DATA\NTIRE\noisy_data\DIV2K\DIV2K_train_HR\DIV2K_train_HR"
clean_path = r"C:\Users\91909\Desktop\ML\DATA\NTIRE\clean_data\DIV2K\DIV2K_train_HR\DIV2K_train_HR"

In [33]:
IMAGE_SIZE = 1536
PATCH_SIZE = 256
BATCH_SIZE = 1
EPOCHS = 10
LEARNING_RATE = 0.0001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EARLY_STOPPING_PATIENCE = 2

In [34]:
# class DenoiseDataset(Dataset):
#     def __init__(self, clean_dir, noisy_dir, transform):
#         self.clean_files = sorted(os.listdir(clean_dir))
#         self.noisy_files = sorted(os.listdir(noisy_dir))
#         self.clean_dir = clean_dir
#         self.noisy_dir = noisy_dir
#         self.transform = transform

#     def __len__(self):
#         return len(self.clean_files)

#     def __getitem__(self, idx):
#         clean_img = cv2.imread(os.path.join(self.clean_dir, self.clean_files[idx]))
#         noisy_img = cv2.imread(os.path.join(self.noisy_dir, self.noisy_files[idx]))

#         clean_img = cv2.cvtColor(clean_img, cv2.COLOR_BGR2RGB)
#         noisy_img = cv2.cvtColor(noisy_img, cv2.COLOR_BGR2RGB)

#         clean_img = cv2.resize(clean_img, (IMAGE_SIZE, IMAGE_SIZE))
#         noisy_img = cv2.resize(noisy_img, (IMAGE_SIZE, IMAGE_SIZE))

#         clean_img = self.transform(clean_img)
#         noisy_img = self.transform(noisy_img)

#         return noisy_img, clean_img

In [35]:
class ImageDataset(Dataset):
    def __init__(self, clean_dir, noisy_dir, transform=None):
        self.clean_files = sorted(os.listdir(clean_dir))
        self.noisy_files = sorted(os.listdir(noisy_dir))
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.transform = transform
    
    def __len__(self):
        return len(self.clean_files)
    
    def __getitem__(self, idx):
        clean_path = os.path.join(self.clean_dir, self.clean_files[idx])
        noisy_path = os.path.join(self.noisy_dir, self.noisy_files[idx])
        
        clean_img = Image.open(clean_path).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
        noisy_img = Image.open(noisy_path).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
        
        if self.transform:
            clean_img = self.transform(clean_img)
            noisy_img = self.transform(noisy_img)
        
        return noisy_img, clean_img

In [36]:
transform = transforms.Compose([
    transforms.ToTensor()
])

In [37]:
# train_dataset = DenoiseDataset(clean_path, noisy_path, transform)
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

In [38]:
dataset = ImageDataset(clean_path, noisy_path, transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

In [39]:
class SwinIR(nn.Module):
    def __init__(self, in_channels=3):
        super(SwinIR, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, in_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x

In [40]:
model = SwinIR().to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_history = []

In [41]:
best_loss = float('inf')
patience = 0

In [42]:
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    for noisy_imgs, clean_imgs in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        noisy_imgs, clean_imgs = noisy_imgs.to(DEVICE), clean_imgs.to(DEVICE)
        
        optimizer.zero_grad()
        output = model(noisy_imgs)
        loss = criterion(output, clean_imgs)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / len(train_loader)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.6f}")

Epoch 1/10:   0%|          | 0/800 [00:00<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 576.00 MiB (GPU 0; 4.00 GiB total capacity; 83.31 MiB already allocated; 1.88 GiB free; 634.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
plt.figure()
plt.plot(range(1, EPOCHS+1), loss_history, marker='o', linestyle='-')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Curve")
plt.grid()
plt.show()

In [None]:
def stitch_patches(patches, img_size, patch_size):
    h = w = img_size // patch_size
    return rearrange(patches, "(b h w) c p1 p2 -> b c (h p1) (w p2)", h=h, w=w)

In [None]:
def denoise_batch(noisy_batch):
    noisy_batch = rearrange(noisy_batch, 'b c (h p1) (w p2) -> (b h w) c p1 p2', p1=PATCH_SIZE, p2=PATCH_SIZE)
    with torch.no_grad():
        denoised_patches = model(noisy_batch)
    return stitch_patches(denoised_patches, IMAGE_SIZE, PATCH_SIZE)

In [None]:
clean_test_path = r"C:\Users\91909\Desktop\ML\DATA\NTIRE\clean_data\DIV2K\DIV2K_valid_HR\DIV2K_valid_HR"
noisy_test_path = r"C:\Users\91909\Desktop\ML\DATA\NTIRE\noisy_data\DIV2K\DIV2K_valid_HR\DIV2K_valid_HR"

In [None]:
model.eval()
noisy_img, clean_img = dataset[0]
noisy_img, clean_img = noisy_img.unsqueeze(0).to(DEVICE), clean_img.unsqueeze(0).to(DEVICE)
with torch.no_grad():
    denoised_img = model(noisy_img)

In [None]:
# test_dataset = DenoiseDataset(clean_test_path, noisy_test_path, transform)
# test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def show_images(noisy, denoised, clean):
    noisy, denoised, clean = noisy.squeeze().cpu().numpy(), denoised.squeeze().cpu().numpy(), clean.squeeze().cpu().numpy()
    noisy, denoised, clean = np.transpose(noisy, (1,2,0)), np.transpose(denoised, (1,2,0)), np.transpose(clean, (1,2,0))
    
    fig, axs = plt.subplots(1, 3, figsize=(12,4))
    axs[0].imshow(noisy)
    axs[0].set_title("Noisy Image")
    axs[1].imshow(denoised)
    axs[1].set_title("Denoised Image")
    axs[2].imshow(clean)
    axs[2].set_title("Clean Image")
    for ax in axs:
        ax.axis("off")
    plt.show()

In [None]:
show_images(noisy_img, denoised_img, clean_img)

In [None]:
model.eval()

SwinIR(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
)

In [None]:
for noisy_batch, _ in test_loader:
    noisy_batch = noisy_batch.to(DEVICE)
    denoised_batch = denoise_batch(noisy_batch)

    for i, denoised_image in enumerate(denoised_batch):
        denoised_image = denoised_image.permute(1, 2, 0).cpu().numpy()
        denoised_image = (denoised_image * 255).astype(np.uint8)
        cv2.imwrite(f"denoised_output_{i}.png", cv2.cvtColor(denoised_image, cv2.COLOR_RGB2BGR))

    break

In [None]:
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0

    for noisy_imgs, clean_imgs in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        noisy_imgs, clean_imgs = noisy_imgs.to(DEVICE), clean_imgs.to(DEVICE)

        # Convert images into 256x256 patches
        noisy_patches = rearrange(noisy_imgs, 'b c (h p1) (w p2) -> (b h w) c p1 p2', p1=PATCH_SIZE, p2=PATCH_SIZE)
        clean_patches = rearrange(clean_imgs, 'b c (h p1) (w p2) -> (b h w) c p1 p2', p1=PATCH_SIZE, p2=PATCH_SIZE)

        optimizer.zero_grad()
        output = model(noisy_patches)
        loss = criterion(output, clean_patches)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {avg_loss:.6f}")

    # Early Stopping
    if avg_loss < best_loss:
        best_loss = avg_loss
        patience = 0
        torch.save(model.state_dict(), "best_model.pth")
    else:
        patience += 1
        if patience >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered.")
            break

Epoch 1/20: 100%|██████████| 800/800 [14:37<00:00,  1.10s/it]


Epoch [1/20], Loss: 0.009627


Epoch 2/20: 100%|██████████| 800/800 [14:25<00:00,  1.08s/it]


Epoch [2/20], Loss: 0.002457


Epoch 3/20: 100%|██████████| 800/800 [14:46<00:00,  1.11s/it]


Epoch [3/20], Loss: 0.002189


Epoch 4/20: 100%|██████████| 800/800 [16:18<00:00,  1.22s/it] 


Epoch [4/20], Loss: 0.002098


Epoch 5/20: 100%|██████████| 800/800 [14:47<00:00,  1.11s/it]


Epoch [5/20], Loss: 0.002032


Epoch 6/20: 100%|██████████| 800/800 [15:00<00:00,  1.13s/it]


Epoch [6/20], Loss: 0.002037


Epoch 7/20: 100%|██████████| 800/800 [14:52<00:00,  1.12s/it]


Epoch [7/20], Loss: 0.001959


Epoch 8/20: 100%|██████████| 800/800 [15:07<00:00,  1.13s/it]


Epoch [8/20], Loss: 0.001948


Epoch 9/20: 100%|██████████| 800/800 [16:18<00:00,  1.22s/it]  


Epoch [9/20], Loss: 0.001915


Epoch 10/20: 100%|██████████| 800/800 [14:57<00:00,  1.12s/it]


Epoch [10/20], Loss: 0.001905


Epoch 11/20: 100%|██████████| 800/800 [14:57<00:00,  1.12s/it]


Epoch [11/20], Loss: 0.001881


Epoch 12/20: 100%|██████████| 800/800 [14:51<00:00,  1.11s/it]


Epoch [12/20], Loss: 0.001875


Epoch 13/20: 100%|██████████| 800/800 [14:52<00:00,  1.12s/it]


Epoch [13/20], Loss: 0.001861


Epoch 14/20: 100%|██████████| 800/800 [14:53<00:00,  1.12s/it]


Epoch [14/20], Loss: 0.001851


Epoch 15/20: 100%|██████████| 800/800 [14:47<00:00,  1.11s/it]


Epoch [15/20], Loss: 0.001841


Epoch 16/20: 100%|██████████| 800/800 [14:44<00:00,  1.11s/it]


Epoch [16/20], Loss: 0.001825


Epoch 17/20: 100%|██████████| 800/800 [14:42<00:00,  1.10s/it]


Epoch [17/20], Loss: 0.001824


Epoch 18/20: 100%|██████████| 800/800 [14:42<00:00,  1.10s/it]


Epoch [18/20], Loss: 0.001810


Epoch 19/20: 100%|██████████| 800/800 [15:06<00:00,  1.13s/it]


Epoch [19/20], Loss: 0.001802


Epoch 20/20: 100%|██████████| 800/800 [16:00<00:00,  1.20s/it] 

Epoch [20/20], Loss: 0.001798





In [None]:
def stitch_patches(patches, img_size, patch_size):
    h = w = img_size // patch_size
    return rearrange(patches, "(b h w) c p1 p2 -> b c (h p1) (w p2)", h=h, w=w)