In [1]:
import os
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch

class LoLDataset(Dataset):
    """Paired low-light / normal-light image dataset.

    Images are paired purely by position after sorting each directory's file
    names: the i-th low-light file is matched with the i-th high-light file.
    Both directories must therefore contain the same number of images, named so
    that their sorted orders correspond (e.g. identical names in both folders).

    Args:
        low_light_dir: Directory containing the low-light input images.
        high_light_dir: Directory containing the matching reference images.
        transform: Optional callable applied independently to each resized RGB
            numpy image; its return value is what ``__getitem__`` yields.
        target_size: Output size passed to ``cv2.resize``, i.e. ``(width, height)``.

    Raises:
        ValueError: If the two directories contain different numbers of images.
    """

    def __init__(self, low_light_dir, high_light_dir, transform=None, target_size=(256, 256)):
        self.low_light_dir = low_light_dir
        self.high_light_dir = high_light_dir
        self.transform = transform
        self.target_size = target_size
        self.low_light_images = self._list_images(low_light_dir)
        self.high_light_images = self._list_images(high_light_dir)
        # Pairing is positional, so a count mismatch means the pairs are misaligned
        # (or __getitem__ would hit an IndexError part-way through an epoch).
        if len(self.low_light_images) != len(self.high_light_images):
            raise ValueError(
                f"Unpaired dataset: {len(self.low_light_images)} images in {low_light_dir} "
                f"but {len(self.high_light_images)} images in {high_light_dir}"
            )

    @staticmethod
    def _list_images(directory):
        """Return the sorted names of regular, non-hidden files in ``directory``.

        Hidden entries (e.g. ``.DS_Store`` or Jupyter's ``.ipynb_checkpoints``)
        and sub-directories are skipped because ``cv2.imread`` cannot load them.
        """
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.low_light_images)

    def __getitem__(self, idx):
        """Load, convert to RGB, resize and (optionally) transform the ``idx``-th pair.

        Returns:
            ``(low_img, high_img)`` — numpy RGB arrays, or whatever ``transform`` returns.

        Raises:
            ValueError: If either image file cannot be decoded by OpenCV.
        """
        low_img_path = os.path.join(self.low_light_dir, self.low_light_images[idx])
        high_img_path = os.path.join(self.high_light_dir, self.high_light_images[idx])

        # cv2.imread returns None (rather than raising) on unreadable files.
        low_img = cv2.imread(low_img_path)
        high_img = cv2.imread(high_img_path)

        if low_img is None or high_img is None:
            raise ValueError(f"Failed to load image at {low_img_path} or {high_img_path}")

        # OpenCV decodes to BGR channel order; convert to RGB.
        low_img = cv2.cvtColor(low_img, cv2.COLOR_BGR2RGB)
        high_img = cv2.cvtColor(high_img, cv2.COLOR_BGR2RGB)

        # Resize both images to target_size, interpreted by OpenCV as (width, height).
        low_img = cv2.resize(low_img, self.target_size, interpolation=cv2.INTER_AREA)
        high_img = cv2.resize(high_img, self.target_size, interpolation=cv2.INTER_AREA)

        if self.transform:
            low_img = self.transform(low_img)
            high_img = self.transform(high_img)

        return low_img, high_img

def _hwc_uint8_to_chw_float(image):
    """Convert an HxWxC image array to a CxHxW float tensor scaled by 1/255.

    Defined at module level (not as a lambda) so the dataset stays picklable,
    which DataLoader needs when ``num_workers > 0`` under the spawn start method.
    Assumes 8-bit pixel values, so the result lies in [0, 1].
    """
    return torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0


def get_data_loaders(low_light_dir, high_light_dir, batch_size=16, target_size=(256, 256),
                     shuffle=True, num_workers=0):
    """Build a DataLoader over a paired ``LoLDataset``.

    Each sample is a ``(low, high)`` pair of CxHxW float tensors in [0, 1].

    Args:
        low_light_dir: Directory of low-light input images.
        high_light_dir: Directory of matching reference images.
        batch_size: Samples per batch.
        target_size: ``(width, height)`` every image is resized to.
        shuffle: Shuffle each epoch (default True, as before). Pass False for
            a deterministic evaluation order.
        num_workers: DataLoader worker processes (default 0 = load in the main process).

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    dataset = LoLDataset(low_light_dir, high_light_dir, _hwc_uint8_to_chw_float, target_size)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

def save_image(tensor, path):
    """Write a single CxHxW RGB image tensor to ``path`` via OpenCV.

    The tensor is expected to hold values in [0, 1]. Anything outside that range
    (e.g. negative values from the generator's Tanh output) is clipped rather
    than allowed to wrap around in the uint8 cast.

    Args:
        tensor: 3-D tensor in channel-first RGB layout; may require grad or live on GPU.
        path: Destination file; its extension selects the encoder. Missing parent
            directories are created.

    Raises:
        OSError: If OpenCV fails to write the file.
    """
    # .detach() so tensors that still require grad can be converted to numpy.
    img = tensor.detach().cpu().numpy().transpose(1, 2, 0) * 255.0
    # Saturate to the valid range before casting: float->uint8 casts of
    # out-of-range values wrap around (e.g. -1.0 becomes 255).
    img = np.clip(img, 0.0, 255.0).round().astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # cv2.imwrite reports failure via its return value instead of raising.
    if not cv2.imwrite(path, img):
        raise OSError(f"Failed to write image to {path}")

In [2]:
import torch
import torch.nn as nn
from torchvision.models import vgg16

class Generator(nn.Module):
    """U-Net-style encoder/decoder that maps a 3-channel image to a 3-channel image.

    Four stride-2 convolutions downsample by 16x in total. Four stride-2
    transposed convolutions upsample back, and the matching encoder feature map
    is concatenated after each of the first three decoder stages. Input height
    and width should therefore be divisible by 16 so the skip tensors line up.

    The final activation is Tanh, so outputs lie in [-1, 1]. The training targets
    produced by ``get_data_loaders`` are in [0, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()

        def down(c_in, c_out):
            # Halves H and W.
            return nn.Sequential(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1), nn.ReLU())

        def up(c_in, c_out, activation):
            # Doubles H and W.
            return nn.Sequential(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1), activation)

        # Encoder
        self.enc1 = down(3, 64)
        self.enc2 = down(64, 128)
        self.enc3 = down(128, 256)
        self.enc4 = down(256, 512)

        # Decoder: dec3..dec1 take twice the channels because a skip tensor is concatenated on.
        self.dec4 = up(512, 256, nn.ReLU())
        self.dec3 = up(512, 128, nn.ReLU())
        self.dec2 = up(256, 64, nn.ReLU())
        self.dec1 = up(128, 3, nn.Tanh())

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        bottleneck = self.enc4(skip3)

        y = torch.cat([self.dec4(bottleneck), skip3], dim=1)
        y = torch.cat([self.dec3(y), skip2], dim=1)
        y = torch.cat([self.dec2(y), skip1], dim=1)
        return self.dec1(y)

class Discriminator(nn.Module):
    """Fully-convolutional discriminator producing a grid of real/fake probabilities.

    Three stride-2 conv + LeakyReLU(0.2) stages are followed by a 4x4, stride-1,
    unpadded conv to one channel and a Sigmoid. For an NxCxHxW input with H and W
    divisible by 8, the output is Nx1x(H/8 - 3)x(W/8 - 3); a 256x256 input gives 29x29.
    It only sees the image being judged, not the low-light input.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = []
        channels_in = 3
        for channels_out in (64, 128, 256):
            layers.append(nn.Conv2d(channels_in, channels_out, 4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2))
            channels_in = channels_out
        layers.append(nn.Conv2d(channels_in, 1, 4, stride=1, padding=0))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        probabilities = self.model(x)
        return probabilities

class PerceptualLoss(nn.Module):
    """MSE between frozen VGG16 feature maps of two image batches.

    Uses ``vgg16().features[:16]`` (torchvision layers up to and including the
    ReLU after conv3_3) with ImageNet-pretrained weights. The VGG parameters are
    frozen, but gradients still flow back into ``fake``, so the generator is trained.

    Args:
        device: Device on which the VGG feature extractor and normalization constants are placed.
        normalize_input: If True, apply ImageNet mean/std normalization to both
            inputs before VGG, which is what the pretrained weights expect for
            [0, 1] RGB images. Defaults to False, so existing runs keep their
            loss scale (inputs go to VGG unchanged).
    """

    def __init__(self, device, normalize_input=False):
        super(PerceptualLoss, self).__init__()
        try:
            # torchvision >= 0.13: explicit weights enum (same checkpoint as pretrained=True).
            from torchvision.models import VGG16_Weights
        except ImportError:
            # Older torchvision only supports the now-deprecated boolean flag.
            backbone = vgg16(pretrained=True)
        else:
            backbone = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
        vgg = backbone.features[:16].eval().to(device)
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg
        self.loss = nn.MSELoss()
        self.normalize_input = normalize_input
        # Buffers rather than parameters: they follow .to() but are never optimized.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))

    def forward(self, fake, real):
        """Return the mean squared error between VGG features of ``fake`` and ``real``."""
        if self.normalize_input:
            fake = (fake - self.mean) / self.std
            real = (real - self.mean) / self.std
        fake_features = self.vgg(fake)
        real_features = self.vgg(real)
        return self.loss(fake_features, real_features)

def initialize_weights(model):
    """Re-initialize conv and batch-norm weights in place, DCGAN style.

    - ``Conv2d`` / ``ConvTranspose2d``: weights ~ N(0, 0.02); biases are left unchanged.
    - ``BatchNorm2d`` with learnable affine parameters: weight ~ N(1, 0.02), bias = 0.
      Non-affine batch norms have no weight/bias and are skipped.
    - All other modules are untouched.

    Args:
        model: Any ``nn.Module``; every submodule (including ``model`` itself) is visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            # nn.init functions already run under no_grad, so .data is unnecessary.
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.affine:
            # BatchNorm2d(affine=False) has weight/bias set to None.
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0)

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import os

In [4]:
def train_gan(epochs=100, batch_size=16, lr=0.00005, beta1=0.5, start_epoch=0, checkpoint_path=None,
              train_low_dir="/content/drive/MyDrive/data/processed/our485/low",
              train_high_dir="/content/drive/MyDrive/data/processed/our485/high",
              eval_low_dir="/content/drive/MyDrive/data/processed/eval15/low",
              eval_high_dir="/content/drive/MyDrive/data/processed/eval15/high",
              output_dir="/content/drive/MyDrive/UPDATE",
              eval_every=10):
    """Train, or resume training, the low-light enhancement GAN.

    The generator loss is ``adv + 10 * L1 + 5 * perceptual``. The discriminator
    is trained with BCE on real (high-light) images vs. detached generator outputs.

    Only the generator is restored from ``checkpoint_path`` and saved afterwards.
    The discriminator and both Adam optimizers always start fresh.

    Args:
        epochs: Number of epochs to run; epochs are numbered
            ``start_epoch .. start_epoch + epochs - 1``.
        batch_size: Batch size for the training and evaluation loaders.
        lr: Adam learning rate for both networks.
        beta1: Adam beta1 for both networks (beta2 is 0.999).
        start_epoch: Number of the first epoch, used in logs and file names.
        checkpoint_path: Optional generator ``state_dict`` file to resume from.
        train_low_dir, train_high_dir: Paired training image directories.
        eval_low_dir, eval_high_dir: Paired evaluation image directories.
        output_dir: Directory for sample images and checkpoints (created if missing).
        eval_every: Positive interval in epochs for evaluation and checkpointing.
            The final epoch is always evaluated and checkpointed as well.

    Side effects:
        At each evaluation epoch, writes ``eval_epoch_{epoch}.png`` (first eval
        sample), ``epoch_{epoch}.png`` (a sample from the last training batch)
        and ``generator_epoch_{epoch}.pth`` into ``output_dir``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize models
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    perceptual_loss = PerceptualLoss(device)

    initialize_weights(generator)
    initialize_weights(discriminator)

    # Resume the generator only. torch.load unpickles, so only load trusted checkpoints.
    if checkpoint_path and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        generator.load_state_dict(checkpoint)
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        if checkpoint_path:
            print(f"Warning: checkpoint not found at {checkpoint_path}")
        print("No checkpoint loaded, starting from scratch or with initialized weights.")

    # Loss functions
    adversarial_loss = nn.BCELoss()
    pixel_loss = nn.L1Loss()

    # Optimizers
    g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    # Data loaders
    train_loader = get_data_loaders(train_low_dir, train_high_dir, batch_size)
    eval_loader = get_data_loaders(eval_low_dir, eval_high_dir, batch_size)

    # save_image/torch.save would otherwise fail on a missing directory.
    os.makedirs(output_dir, exist_ok=True)

    last_epoch = start_epoch + epochs - 1
    for epoch in range(start_epoch, start_epoch + epochs):
        generator.train()
        last_train_sample = None
        for i, (low_light, high_light) in enumerate(train_loader):
            low_light, high_light = low_light.to(device), high_light.to(device)

            # Train Discriminator
            d_optimizer.zero_grad()

            real_output = discriminator(high_light)
            # The discriminator outputs a spatial grid of probabilities, so the labels
            # take its output shape. There is no need for an extra forward pass.
            real_labels = torch.ones_like(real_output)
            fake_labels = torch.zeros_like(real_output)
            d_real_loss = adversarial_loss(real_output, real_labels)

            fake_images = generator(low_light)
            # Detach so the discriminator step does not backprop into the generator.
            fake_output = discriminator(fake_images.detach())
            d_fake_loss = adversarial_loss(fake_output, fake_labels)

            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()
            d_optimizer.step()

            # Train Generator: fool the (just updated) discriminator and match the target.
            g_optimizer.zero_grad()
            fake_output = discriminator(fake_images)
            g_adv_loss = adversarial_loss(fake_output, real_labels)
            g_pixel_loss = pixel_loss(fake_images, high_light)
            g_perceptual_loss = perceptual_loss(fake_images, high_light)
            g_loss = g_adv_loss + 10 * g_pixel_loss + 5 * g_perceptual_loss
            g_loss.backward()
            g_optimizer.step()

            # Keep a graph-free reference for the per-epoch training sample image.
            last_train_sample = fake_images[0].detach()

            if i % 10 == 0:
                print(f"Epoch [{epoch}/{last_epoch}] Batch [{i}/{len(train_loader)}] "
                    f"D Loss: {d_loss.item():.4f} G Loss: {g_loss.item():.4f}")

        # Evaluate and checkpoint periodically, and always after the final epoch.
        if epoch % eval_every == 0 or epoch == last_epoch:
            generator.eval()
            with torch.no_grad():
                for low_light, _ in eval_loader:
                    eval_images = generator(low_light.to(device))
                    save_image(eval_images[0], os.path.join(output_dir, f"eval_epoch_{epoch}.png"))
                    break  # one eval batch is enough for a visual sample
            if last_train_sample is not None:
                save_image(last_train_sample, os.path.join(output_dir, f"epoch_{epoch}.png"))
            torch.save(generator.state_dict(), os.path.join(output_dir, f"generator_epoch_{epoch}.pth"))

if __name__ == "__main__":
    # Resume from the epoch-280 generator and train epochs 281..480.
    checkpoint_path = "/content/drive/MyDrive/generator_epoch_280.pth"
    train_gan(
        checkpoint_path=checkpoint_path,
        start_epoch=281,
        epochs=200,
        lr=5e-6,
    )

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:04<00:00, 111MB/s]


Loaded checkpoint from /content/drive/MyDrive/generator_epoch_280.pth
Epoch [281/480] Batch [0/32] D Loss: 0.6932 G Loss: 4.2440
Epoch [281/480] Batch [10/32] D Loss: 0.6931 G Loss: 4.2888
Epoch [281/480] Batch [20/32] D Loss: 0.6928 G Loss: 4.4049
Epoch [281/480] Batch [30/32] D Loss: 0.6931 G Loss: 4.5884
Epoch [282/480] Batch [0/32] D Loss: 0.6930 G Loss: 3.8558
Epoch [282/480] Batch [10/32] D Loss: 0.6928 G Loss: 4.5542
Epoch [282/480] Batch [20/32] D Loss: 0.6927 G Loss: 4.4222
Epoch [282/480] Batch [30/32] D Loss: 0.6927 G Loss: 4.6977
Epoch [283/480] Batch [0/32] D Loss: 0.6924 G Loss: 4.3727
Epoch [283/480] Batch [10/32] D Loss: 0.6925 G Loss: 4.6222
Epoch [283/480] Batch [20/32] D Loss: 0.6926 G Loss: 4.4884
Epoch [283/480] Batch [30/32] D Loss: 0.6921 G Loss: 4.1030
Epoch [284/480] Batch [0/32] D Loss: 0.6920 G Loss: 4.2918
Epoch [284/480] Batch [10/32] D Loss: 0.6920 G Loss: 3.9667
Epoch [284/480] Batch [20/32] D Loss: 0.6922 G Loss: 4.6303
Epoch [284/480] Batch [30/32] D Lo

In [None]:
if __name__ == "__main__":
    # Resume from the epoch-480 generator and train epochs 481..600 with batch size 32.
    checkpoint_path = "/content/drive/MyDrive/UPDATE/generator_epoch_480.pth"
    train_gan(
        checkpoint_path=checkpoint_path,
        start_epoch=481,
        epochs=120,
        batch_size=32,
        lr=5e-6,
    )

Loaded checkpoint from /content/drive/MyDrive/UPDATE/generator_epoch_480.pth
Epoch [481/600] Batch [0/16] D Loss: 0.6933 G Loss: 4.3548
Epoch [481/600] Batch [10/16] D Loss: 0.6931 G Loss: 3.9231
Epoch [482/600] Batch [0/16] D Loss: 0.6931 G Loss: 4.4005
Epoch [482/600] Batch [10/16] D Loss: 0.6928 G Loss: 3.8293
Epoch [483/600] Batch [0/16] D Loss: 0.6929 G Loss: 4.6152
Epoch [483/600] Batch [10/16] D Loss: 0.6927 G Loss: 3.6724
Epoch [484/600] Batch [0/16] D Loss: 0.6927 G Loss: 3.9255
Epoch [484/600] Batch [10/16] D Loss: 0.6923 G Loss: 4.2722
Epoch [485/600] Batch [0/16] D Loss: 0.6926 G Loss: 3.7828
Epoch [485/600] Batch [10/16] D Loss: 0.6924 G Loss: 4.2178
Epoch [486/600] Batch [0/16] D Loss: 0.6922 G Loss: 4.1196
Epoch [486/600] Batch [10/16] D Loss: 0.6920 G Loss: 3.9088
Epoch [487/600] Batch [0/16] D Loss: 0.6917 G Loss: 4.1003
Epoch [487/600] Batch [10/16] D Loss: 0.6918 G Loss: 3.9935
Epoch [488/600] Batch [0/16] D Loss: 0.6921 G Loss: 3.9413
Epoch [488/600] Batch [10/16] D