In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
from torchsummary import summary
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [6]:
class Disc_Block(nn.Module):
    """Convolution -> batch norm -> LeakyReLU unit used inside the discriminator.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the convolution.
        stride: Stride of the 4x4 convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        # 4x4 reflect-padded convolution without a bias term; the BatchNorm2d
        # that follows is created with its default settings.
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=False,
            padding_mode="reflect",
        )
        normalisation = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply the conv/norm/activation stack to ``x`` and return the result."""
        return self.conv(x)

In [7]:
class Discriminator(nn.Module):
    """Conditional discriminator that scores an (input, candidate) image pair.

    The two images are concatenated along the channel axis, passed through an
    initial stride-2 convolution, a chain of ``Disc_Block`` stages and a final
    convolution with a single output channel.  The result is a spatial map of
    raw logits (no sigmoid is applied here).

    Args:
        in_channels: Channels of EACH of the two images; the first convolution
            therefore receives ``in_channels * 2`` channels.
        features: Channel widths of the successive stages.  ``features[0]`` is
            used by the initial convolution, every following entry becomes one
            ``Disc_Block``.  All blocks use stride 2 except the one built from
            the last entry, which uses stride 1.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Pick the stride-1 block by POSITION.  Comparing channel values
            # (feature == features[-1]) would also match an earlier stage that
            # happens to share the last stage's width.
            stride = 1 if index == last_index else 2
            layers.append(Disc_Block(prev_channels, feature, stride=stride))
            prev_channels = feature

        # Project to a single-channel map of logits.
        layers.append(
            nn.Conv2d(
                prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Return the logit map for the pair ``(x, y)``.

        ``x`` and ``y`` are concatenated on dim 1, so they must agree in every
        other dimension.
        """
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x

In [51]:
# def test():
#     x = torch.randn((1, 3, 256, 256))
#     y = torch.randn((1, 3, 256, 256))
#     model = Discriminator(in_channels=3)
#     preds = model(x, y)
#     print(model)
#     print(preds.shape)
# test()

Discriminator(
  (initial): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (model): Sequential(
    (0): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=reflect)
        (1): Batc

In [53]:


# torchsummary's `summary` defaults to device="cuda" and then builds its dummy
# inputs on the GPU whenever one is available.  Put the model on the same
# device and pass it explicitly so inputs and weights never end up on
# different devices.
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
model = Discriminator(in_channels=3).to(summary_device)
summary(
    model,
    [(3, 256, 256), (3, 256, 256)],  # Provide input shapes for both `x` and `y`
    device=summary_device,
)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           6,208
         LeakyReLU-2         [-1, 64, 128, 128]               0
            Conv2d-3          [-1, 128, 64, 64]         131,072
       BatchNorm2d-4          [-1, 128, 64, 64]             256
         LeakyReLU-5          [-1, 128, 64, 64]               0
          CNNBlock-6          [-1, 128, 64, 64]               0
            Conv2d-7          [-1, 256, 32, 32]         524,288
       BatchNorm2d-8          [-1, 256, 32, 32]             512
         LeakyReLU-9          [-1, 256, 32, 32]               0
         CNNBlock-10          [-1, 256, 32, 32]               0
           Conv2d-11          [-1, 512, 31, 31]       2,097,152
      BatchNorm2d-12          [-1, 512, 31, 31]           1,024
        LeakyReLU-13          [-1, 512, 31, 31]               0
         CNNBlock-14          [-1, 512,

In [8]:
class Gen_Block(nn.Module):
    """Resampling stage of the generator: (transposed) conv -> batch norm -> activation.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the stage.
        down: ``True`` builds a stride-2 reflect-padded ``Conv2d``; ``False``
            builds a stride-2 ``ConvTranspose2d``.
        act: ``"relu"`` selects ``nn.ReLU``; any other value selects
            ``nn.LeakyReLU(0.2)``.
        use_dropout: When true, ``forward`` applies dropout (p=0.5) to the
            stage output.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()

        if down:
            resample = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
                padding_mode="reflect",
            )
        else:
            resample = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False
            )

        if act == "relu":
            activation = nn.ReLU()
        else:
            activation = nn.LeakyReLU(0.2)

        self.conv = nn.Sequential(resample, nn.BatchNorm2d(out_channels), activation)

        # The dropout module is always created; `use_dropout` decides whether
        # forward() actually routes the activations through it.
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        """Run the stage on ``x``, applying dropout only when it was enabled."""
        features = self.conv(x)
        if self.use_dropout:
            return self.dropout(features)
        return features



In [9]:
class Generator(nn.Module):
    """Encoder-decoder generator with skip connections (U-Net layout).

    Eight stride-2 stages (``initial_down``, ``down1``..``down6``,
    ``bottleneck``) shrink the input; eight stride-2 transposed stages
    (``up1``..``up7``, ``final_up``) grow it back.  From ``up2`` onwards every
    decoder stage receives the previous decoder output concatenated with the
    encoder output of matching depth.  The output passes through ``Tanh``.

    Args:
        in_channels: Channels of the input image; the output has the same count.
        features: Base channel width; deeper stages use 2x, 4x and 8x of it.
    """

    def __init__(self, in_channels=3, features=64):
        super().__init__()

        def encoder(c_in, c_out):
            # Downsampling stage: LeakyReLU activation, no dropout.
            return Gen_Block(c_in, c_out, down=True, act="leaky", use_dropout=False)

        def decoder(c_in, c_out, dropout):
            # Upsampling stage: ReLU activation, dropout on request.
            return Gen_Block(c_in, c_out, down=False, act="relu", use_dropout=dropout)

        f1 = features
        f2 = features * 2
        f4 = features * 4
        f8 = features * 8

        # ---- encoder ----
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, f1, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = encoder(f1, f2)
        self.down2 = encoder(f2, f4)
        self.down3 = encoder(f4, f8)
        self.down4 = encoder(f8, f8)
        self.down5 = encoder(f8, f8)
        self.down6 = encoder(f8, f8)
        self.bottleneck = nn.Sequential(nn.Conv2d(f8, f8, 4, 2, 1), nn.ReLU())

        # ---- decoder ----
        # Input widths are doubled from up2 onwards because forward()
        # concatenates the matching encoder output onto each decoder input.
        self.up1 = decoder(f8, f8, True)
        self.up2 = decoder(f8 * 2, f8, True)
        self.up3 = decoder(f8 * 2, f8, True)
        self.up4 = decoder(f8 * 2, f8, False)
        self.up5 = decoder(f8 * 2, f4, False)
        self.up6 = decoder(f4 * 2, f2, False)
        self.up7 = decoder(f2 * 2, f1, False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f1 * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map ``x`` to an output image of the same channel count."""
        encoder_stages = (
            self.initial_down,
            self.down1,
            self.down2,
            self.down3,
            self.down4,
            self.down5,
            self.down6,
        )
        skips = []
        hidden = x
        for stage in encoder_stages:
            hidden = stage(hidden)
            skips.append(hidden)

        hidden = self.up1(self.bottleneck(hidden))

        decoder_stages = (
            self.up2,
            self.up3,
            self.up4,
            self.up5,
            self.up6,
            self.up7,
            self.final_up,
        )
        # Deepest encoder output pairs with the first remaining decoder stage.
        for stage, skip in zip(decoder_stages, reversed(skips)):
            hidden = stage(torch.cat([hidden, skip], 1))
        return hidden

In [56]:
# def test():
#     x = torch.randn((1, 3, 256, 256))
#     model = Generator(in_channels=3, features=64)
#     preds = model(x)
#     print(preds.shape)
# test()

torch.Size([1, 3, 256, 256])


In [10]:
# --- Hardware ---------------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Data -------------------------------------------------------------------
# Train / validation folders inside the kagglehub cache.
TRAIN_DIR = (
    "/root/.cache/kagglehub/datasets/tanvirnwu/"
    "loli-street-low-light-image-enhancement-of-street/versions/1/"
    "LoLI-Street Dataset/Train"
)
VAL_DIR = (
    "/root/.cache/kagglehub/datasets/tanvirnwu/"
    "loli-street-low-light-image-enhancement-of-street/versions/1/"
    "LoLI-Street Dataset/Val"
)
IMAGE_SIZE = 256
CHANNELS_IMG = 3

# --- Optimisation -----------------------------------------------------------
LEARNING_RATE = 2e-4
BATCH_SIZE = 64
NUM_WORKERS = 2
NUM_EPOCHS = 50
L1_LAMBDA = 100  # weight of the L1 term in the generator loss (see train_fn)
LAMBDA_GP = 10  # NOTE(review): not referenced by the training code visible here

# --- Checkpointing ----------------------------------------------------------
LOAD_MODEL = False
SAVE_MODEL = False
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"

In [11]:
# Geometric transforms shared by BOTH images of a pair.  "image0" is registered
# as an additional image target, so one call transforms both images together.
# The random horizontal flip belongs here: flipping only one image of a pair
# would leave the two images spatially misaligned.
both_transform = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
    ],
    additional_targets={"image0": "image"},
)

# Per-image pipeline with photometric augmentation.  MapDataset applies it to
# the `high` image.  Normalize with mean=std=0.5 maps 0..255 pixels to [-1, 1].
transform_only_input = A.Compose(
    [
        A.ColorJitter(p=0.2),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

# Per-image pipeline without augmentation (normalisation and tensor conversion
# only).  MapDataset applies it to the `low` image.
transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

In [12]:
class MapDataset(Dataset):
    """Paired image dataset read from ``<root_dir>/high`` and ``<root_dir>/low``.

    Files are paired by their position in the alphabetically sorted listing of
    each folder, so both folders must contain the same number of files.

    Args:
        root_dir: Directory that contains the ``high`` and ``low`` sub-folders.

    Raises:
        ValueError: If the two folders hold a different number of files, in
            which case positional pairing cannot be trusted.
    """

    def __init__(self, root_dir):
        self.high_dir = os.path.join(root_dir, "high")
        self.low_dir = os.path.join(root_dir, "low")
        self.high_images = sorted(os.listdir(self.high_dir))
        self.low_images = sorted(os.listdir(self.low_dir))
        if len(self.high_images) != len(self.low_images):
            raise ValueError(
                f"Unpaired dataset: {len(self.high_images)} files in {self.high_dir} "
                f"but {len(self.low_images)} files in {self.low_dir}"
            )

    def __len__(self):
        # Use the listing cached in __init__: it is what __getitem__ indexes
        # into, and it avoids re-scanning the directory on every len() call.
        return len(self.high_images)

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB uint8 array and close the file afterwards."""
        with Image.open(path) as image:
            # Force three channels so the 3-channel Normalize in the transform
            # pipelines also works for grayscale / RGBA / palette files.
            return np.array(image.convert("RGB"))

    def __getitem__(self, index):
        """Return the transformed ``(high_image, low_image)`` pair at ``index``.

        NOTE(review): the `high` image is returned first, and the training loop
        feeds the first element to the generator as its input -- confirm that
        this direction (high -> low) is the intended one.
        """
        high_image = self._load_rgb(os.path.join(self.high_dir, self.high_images[index]))
        low_image = self._load_rgb(os.path.join(self.low_dir, self.low_images[index]))

        # Shared transforms are applied to both images in a single call.
        augmentations = both_transform(image=high_image, image0=low_image)
        high_image = augmentations["image"]
        low_image = augmentations["image0"]

        # Per-image transforms (normalisation, tensor conversion, ...).
        high_image = transform_only_input(image=high_image)["image"]
        low_image = transform_only_mask(image=low_image)["image"]

        return high_image, low_image

In [14]:
# Fetch the LoLI-Street low-light dataset through kagglehub.  `path` receives
# the value returned by `dataset_download`; the following cells treat it as the
# local directory of the downloaded dataset version.
import kagglehub
path = kagglehub.dataset_download("tanvirnwu/loli-street-low-light-image-enhancement-of-street")

  from .autonotebook import tqdm as notebook_tqdm


Downloading from https://www.kaggle.com/api/v1/datasets/download/tanvirnwu/loli-street-low-light-image-enhancement-of-street?dataset_version_number=1...


100%|██████████| 2.63G/2.63G [16:16<00:00, 2.89MB/s] 

Extracting files...





In [60]:
# Show where the dataset was placed and what its top-level folder contains.
print(f"Dataset Path: {path}")
print(f"Contents of the dataset root folder: {os.listdir(path)}")

Dataset Path: /root/.cache/kagglehub/datasets/tanvirnwu/loli-street-low-light-image-enhancement-of-street/versions/1
Contents of the dataset root folder: ['LoLI-Street Dataset']


In [61]:
# Look one level deeper: list the sub-folders of the "LoLI-Street Dataset"
# directory inside the downloaded dataset version.
dataset_path = (
    "/root/.cache/kagglehub/datasets/tanvirnwu/"
    "loli-street-low-light-image-enhancement-of-street/versions/1"
)
loli_dataset_folder = os.path.join(dataset_path, "LoLI-Street Dataset")
print(f"Contents of LoLI-Street Dataset: {os.listdir(loli_dataset_folder)}")

Contents of LoLI-Street Dataset: ['Val', 'Test', 'YOLO Annotations', 'Train']


In [62]:
# Reuse the training folder from the configuration cell instead of repeating
# the same absolute path literal, so the two cannot drift apart.
train_path = TRAIN_DIR

In [63]:
import matplotlib.pyplot as plt

dataset = MapDataset(train_path)
loader = DataLoader(dataset, batch_size=5)
# Convert tensors back to PIL images
to_pil = transforms.ToPILImage()

# Visualize the first batch of image pairs, one pair per row.
for x, y in loader:
    # Denormalize (convert back to range [0, 1] for visualization)
    x = x * 0.5 + 0.5
    y = y * 0.5 + 0.5

    # Number of images in the batch
    batch_size = x.shape[0]

    # squeeze=False keeps `axes` two-dimensional even when the batch holds a
    # single pair; otherwise `axes[i, 0]` fails on the 1-D array that
    # plt.subplots returns for a single row.
    fig, axes = plt.subplots(
        batch_size, 2, figsize=(10, batch_size * 5), squeeze=False
    )

    for i in range(batch_size):
        # Column 0 shows the first tensor of the pair, column 1 the second.
        for column, (batch, label) in enumerate(((x, "Input"), (y, "Target"))):
            # Convert the CHW PyTorch tensor to an HWC numpy array for Matplotlib
            image_np = batch[i].permute(1, 2, 0).cpu().numpy()
            axes[i, column].imshow(image_np)
            axes[i, column].set_title(f"{label} Image {i+1}")
            axes[i, column].axis("off")

    plt.tight_layout()
    plt.show()
    break  # Remove this if you want to see multiple batches



Output hidden; open in https://colab.research.google.com to view.

In [None]:
from torchvision.utils import save_image

def save_some_examples(gen, val_loader, epoch, folder):
    """Save generator output for the first validation batch as PNG files.

    Writes ``y_gen_<epoch>.png`` (generator output) and ``input_<epoch>.png``
    (generator input) into ``folder``; when ``epoch == 1`` it also writes
    ``label_<epoch>.png`` (the second element of the batch).  All tensors are
    mapped back from [-1, 1] to [0, 1] before saving.

    Args:
        gen: Generator module; switched to eval mode for inference and back to
            train mode afterwards.
        val_loader: Iterable yielding ``(x, y)`` batches; only the first batch
            is used.
        epoch: Epoch number embedded in the file names.
        folder: Output directory; created if it does not exist.
    """
    # save_image cannot write into a missing directory, so create it up front.
    os.makedirs(folder, exist_ok=True)

    x, y = next(iter(val_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            y_fake = y_fake * 0.5 + 0.5  # remove normalization
            save_image(y_fake, folder + f"/y_gen_{epoch}.png")
            save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
            # NOTE(review): the label is written only for epoch 1; callers
            # that count epochs from 0 get it after their second epoch.
            if epoch == 1:
                save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    finally:
        # Restore training mode even if inference or saving raised.
        gen.train()


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The saved object is a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"`` (optimizer), serialised with ``torch.save``.
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from ``checkpoint_file`` in place.

    The file must hold a dict with the keys ``"state_dict"`` and
    ``"optimizer"`` (the layout written by ``save_checkpoint``).  Tensors are
    mapped onto ``DEVICE`` while loading.

    Args:
        checkpoint_file: Path of the checkpoint to read.
        model: Module that receives the saved ``"state_dict"``.
        optimizer: Optimizer that receives the saved ``"optimizer"`` state.
        lr: Learning rate written into every parameter group after loading.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])

    # Restoring the optimizer state also restores the learning rate stored in
    # the checkpoint, so overwrite it with the rate requested by the caller.
    for group in optimizer.param_groups:
        group["lr"] = lr

In [64]:
# Enable cuDNN's benchmark mode (per the PyTorch docs, cuDNN then times several
# convolution algorithms and picks the fastest).  This is a process-wide
# setting applied when the cell runs.
torch.backends.cudnn.benchmark = True


def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler):
    """Run one pass over ``loader``, alternating discriminator and generator steps.

    For every batch ``(x, y)``:
      1. The discriminator is updated on the real pair ``(x, y)`` against
         all-ones targets and on the generated pair ``(x, gen(x).detach())``
         against all-zeros targets; the two losses are averaged.
      2. The generator is updated with the adversarial loss (discriminator
         output on ``(x, gen(x))`` against all-ones targets) plus
         ``L1_LAMBDA`` times ``l1_loss(gen(x), y)``.

    Args:
        disc: Discriminator called as ``disc(x, y)``.
        gen: Generator called as ``gen(x)``.
        loader: Iterable of ``(x, y)`` batches.
        opt_disc: Optimizer for the discriminator parameters.
        opt_gen: Optimizer for the generator parameters.
        l1_loss: Reconstruction loss callable.
        bce: Adversarial loss callable applied to discriminator outputs.
        g_scaler: Gradient scaler used for the generator step.
        d_scaler: Gradient scaler used for the discriminator step.
    """
    # torch.cuda.amp.autocast() is deprecated in favour of torch.autocast.
    # Mixed precision is switched on only when CUDA is present, which matches
    # what the deprecated context manager did (it disabled itself without CUDA).
    use_amp = torch.cuda.is_available()
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        # Train Discriminator
        with torch.autocast(device_type="cuda", enabled=use_amp):
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            # detach(): the discriminator step must not backpropagate into gen.
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        # Clears every discriminator gradient, including those accumulated
        # while back-propagating the previous generator loss through disc.
        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generator
        with torch.autocast(device_type="cuda", enabled=use_amp):
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Refresh the progress-bar statistics every tenth batch.
        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )

In [65]:
def main():
    """Build the models, optimizers and data loaders, then train for NUM_EPOCHS.

    After every epoch a sample prediction for the first validation batch is
    written to the ``generated_images`` folder.  When ``LOAD_MODEL`` is true,
    both networks and optimizers are restored from ``CHECKPOINT_GEN`` /
    ``CHECKPOINT_DISC`` before training; when ``SAVE_MODEL`` is true,
    checkpoints are written every fifth epoch.
    """
    disc = Discriminator(in_channels=3).to(DEVICE)
    gen = Generator(in_channels=3, features=64).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999),)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    # Checkpoint restore, driven by the notebook-level configuration constants
    # (the previous commented-out version referred to a `config` module that
    # this notebook does not define).
    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)

    train_dataset = MapDataset(root_dir=TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )
    g_scaler = torch.GradScaler()
    d_scaler = torch.GradScaler()
    val_dataset = MapDataset(root_dir=VAL_DIR)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # Create the sample folder BEFORE training starts, so a missing directory
    # cannot make the first save fail after a whole epoch of work.
    examples_dir = "generated_images"
    os.makedirs(examples_dir, exist_ok=True)

    for epoch in range(NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        if SAVE_MODEL and epoch % 5 == 0:
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)

        save_some_examples(gen, val_loader, epoch, folder=examples_dir)

In [None]:
# Start training: builds the models and data loaders and runs NUM_EPOCHS epochs.
main()

  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  3%|▎         | 47/1875 [29:39<18:44:58, 36.92s/it, D_fake=0.24, D_real=0.779]