In [1]:
import torch 
from torch.utils.data import random_split
import numpy as np
torch.manual_seed(42)
np.random.seed(42)


import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os

from PIL import Image
import os
from tqdm import tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)
    

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [4]:
vgg = models.vgg19(pretrained=True).features
vgg = nn.Sequential(*list(vgg.children())[:8])  # Use only first few layers
for param in vgg.parameters():
    param.requires_grad = False



In [5]:
class ImageDataset(Dataset):
    def __init__(self, images_dir, transform=None):
        self.images = sorted(os.listdir(images_dir))
        self.images_dir = images_dir
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.images_dir, self.images[idx]))

        # Split the image 
        width, height = image.size        
        truth_image = image.crop((0, 0, width // 3, height))
        glare_image = image.crop((width // 3, 0, (width//3)*2, height))

        if self.transform:
            truth_image = self.transform(truth_image)
            glare_image = self.transform(glare_image)

        return truth_image, glare_image

In [6]:
transform = transforms.Compose([
    # change to grayscale
    transforms.Grayscale(),
    transforms.Resize((512, 512)),
    transforms.ToTensor()
])

In [7]:
train_path = '../dataset/SD1/train'
dataset = ImageDataset(train_path, transform=transform)


# Assuming dataset is already defined
total_size = len(dataset)
train_size = int(total_size * 0.8)  # 80% for training
val_size = total_size - train_size  # Remaining 20% for validation

# Perform the split
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [8]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)

In [9]:
adversarial_loss = nn.BCELoss().to(device)
content_loss = nn.MSELoss().to(device)
pixel_loss = nn.MSELoss().to(device)
l1_loss = nn.L1Loss().to(device)

learning_rate = 1e-4
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

num_epochs = 10
best_val_loss = float('inf')
best_generator_path = "../checkpoint/best_generator.pth"
best_discriminator_path = "../checkpoint/best_discriminator.pth"
if not os.path.exists("../checkpoint"):
    os.makedirs("../checkpoint")

for epoch in range(num_epochs):
        generator.train()
        discriminator.train()
        
        train_pbar = tqdm(train_dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}] Training")
        for truth_image, glare_image in train_pbar:

            truth_image = truth_image.to(device)
            glare_image = glare_image.to(device)
            batch_size = glare_image.size(0)

            # Adversarial ground truths
            valid = torch.ones((batch_size, 1, 32, 32), requires_grad=False).to(device)
            fake = torch.zeros((batch_size, 1, 32, 32), requires_grad=False).to(device)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Generate a batch of images
            generated_image = generator(glare_image)

            # VGG19 feature extraction
            vgg_gen_feat = vgg(generated_image.repeat(1, 3, 1, 1))  # Repeat channel dimension for VGG19
            vgg_no_glare_feat = vgg(truth_image.repeat(1, 3, 1, 1))

            # Loss measures generator's ability to fool the discriminator
            g_loss_adv = adversarial_loss(discriminator(generated_image), valid)
            g_loss_content = content_loss(vgg_gen_feat, vgg_no_glare_feat)
            g_loss_pixel = pixel_loss(generated_image, truth_image)
            g_loss = g_loss_adv + g_loss_content + g_loss_pixel

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(truth_image), valid)
            fake_loss = adversarial_loss(discriminator(generated_image.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

        # Validation
        generator.eval()
        val_pbar = tqdm(val_dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}] Validation")
        with torch.no_grad():
            val_l1_loss = 0
            
            for truth_image, glare_image in val_pbar:
                truth_image = truth_image.to(device)
                glare_image = glare_image.to(device)
                generated_image = generator(glare_image)
                val_l1_loss += l1_loss(generated_image, truth_image).item()

            val_l1_loss /= len(val_dataloader)
            print(f'Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}, Val L1 Loss: {val_l1_loss}')

        if val_l1_loss < best_val_loss:
            best_val_loss = val_l1_loss
            torch.save(generator.state_dict(), best_generator_path)
            torch.save(discriminator.state_dict(), best_discriminator_path)


Epoch [1/10] Training:   0%|          | 0/600 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/gx/anaconda3/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gx/anaconda3/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'ImageDataset' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/gx/anaconda3/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gx/anaconda3/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^

RuntimeError: DataLoader worker (pid(s) 61860) exited unexpectedly