In [1]:
!pip install torch torchvision numpy opencv-python tqdm lpips

Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lpips
Successfully installed lpips-0.1.4


In [2]:
import os
import cv2
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import lpips  # Perceptual loss for detail preservation

In [3]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths
LOW_LIGHT_DIR = "./lol_dataset/our485/low"
HIGH_LIGHT_DIR = "./lol_dataset/our485/high"
SAVE_MODEL_PATH = "low_light_enhancer.pth"

# Define image transformation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

In [4]:
# Custom dataset
class LowLightDataset(Dataset):
    def __init__(self, low_dir, high_dir, transform=None):
        self.low_images = sorted(os.listdir(low_dir))
        self.high_images = sorted(os.listdir(high_dir))
        self.low_dir = low_dir
        self.high_dir = high_dir
        self.transform = transform

    def __len__(self):
        return len(self.low_images)

    def __getitem__(self, idx):
        low_path = os.path.join(self.low_dir, self.low_images[idx])
        high_path = os.path.join(self.high_dir, self.high_images[idx])

        # Read images
        low_img = cv2.imread(low_path, cv2.IMREAD_COLOR)
        high_img = cv2.imread(high_path, cv2.IMREAD_COLOR)

        # Convert BGR to RGB
        low_img = cv2.cvtColor(low_img, cv2.COLOR_BGR2RGB)
        high_img = cv2.cvtColor(high_img, cv2.COLOR_BGR2RGB)

        # Resize to 256x256 for training
        low_img = cv2.resize(low_img, (256, 256))
        high_img = cv2.resize(high_img, (256, 256))

        if self.transform:
            low_img = self.transform(low_img)
            high_img = self.transform(high_img)

        return low_img, high_img

# Define U-Net model
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [5]:
# Initialize model, loss, and optimizer
model = UNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
l1_loss = nn.L1Loss()
perceptual_loss = lpips.LPIPS(net='vgg').to(device)

# Load dataset
dataset = LowLightDataset(LOW_LIGHT_DIR, HIGH_LIGHT_DIR, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for low_img, high_img in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        low_img, high_img = low_img.to(device), high_img.to(device)

        optimizer.zero_grad()
        enhanced_img = model(low_img)

        loss_l1 = l1_loss(enhanced_img, high_img)
        loss_perceptual = perceptual_loss(enhanced_img, high_img).mean()
        loss = loss_l1 + 0.1 * loss_perceptual  # Weighted loss

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(dataloader):.4f}")

    # Save model every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), SAVE_MODEL_PATH)

print("Training complete. Model saved.")


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /Users/achintyajha/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|████████████████████████████████████████| 528M/528M [00:24<00:00, 22.9MB/s]


Loading model from: /Users/achintyajha/miniconda3/envs/ml/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth


Epoch 1/50:  66%|████████████████████▎          | 40/61 [02:23<01:15,  3.59s/it]


error: OpenCV(4.10.0) /Users/xperience/GHA-Actions-OpenCV/_work/opencv-python/opencv-python/opencv/modules/imgproc/src/color.cpp:196: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'


In [None]:
# Load model
model = UNet()
model.load_state_dict(torch.load("low_light_enhancer.pth", map_location="cpu"))
model.eval()

# Load image
img_path = "./lol_dataset/our485/low/sample.jpg"
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (256, 256))
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
img = transform(img).unsqueeze(0)

In [None]:
# Enhance
with torch.no_grad():
    enhanced_img = model(img).squeeze().numpy()

# Convert back to uint8
enhanced_img = ((enhanced_img.transpose(1, 2, 0) + 1) * 127.5).astype(np.uint8)

# Save result
cv2.imwrite("enhanced.jpg", cv2.cvtColor(enhanced_img, cv2.COLOR_RGB2BGR))
print("Enhanced image saved.")