In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from dataset import LOLDataset  # Import the dataset class

# Set device (MPS for M1 Pro GPU)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


In [2]:
class LOLDataset(Dataset):
    def __init__(self, low_light_dir, high_light_dir, transform=None):
        self.low_light_dir = low_light_dir
        self.high_light_dir = high_light_dir
        self.transform = transform
        self.low_light_images = sorted(os.listdir(low_light_dir))
        self.high_light_images = sorted(os.listdir(high_light_dir))

    def __len__(self):
        return len(self.low_light_images)

    def __getitem__(self, idx):
        low_light_img_path = os.path.join(self.low_light_dir, self.low_light_images[idx])
        high_light_img_path = os.path.join(self.high_light_dir, self.high_light_images[idx])

        low_light_img = Image.open(low_light_img_path).convert('RGB')
        high_light_img = Image.open(high_light_img_path).convert

In [3]:
class BrighteningModule(nn.Module):
    def __init__(self):
        super(BrighteningModule, self).__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Learn a pixel-wise scaling factor for brightness adjustment
        scale = self.relu(self.conv(x))
        return x * (1 + scale)  # Adjust brightness


class DenoisingModule(nn.Module):
    def __init__(self):
        super(DenoisingModule, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x + residual  # Skip connection for denoising


class RecursiveEnhancementModel(nn.Module):
    def __init__(self, num_iterations=3):
        super(RecursiveEnhancementModel, self).__init__()
        self.num_iterations = num_iterations
        self.brightening_module = BrighteningModule()
        self.denoising_module = DenoisingModule()

    def forward(self, x):
        for _ in range(self.num_iterations):
            # Brighten the image
            x = self.brightening_module(x)
            # Denoise the image
            x = self.denoising_module(x)
        return x

In [4]:
def train_model(model, dataloader, criterion, optimizer, num_epochs=10, device="cpu"):
    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for low_light_imgs, high_light_imgs in dataloader:
            low_light_imgs = low_light_imgs.to(device)
            high_light_imgs = high_light_imgs.to(device)

            optimizer.zero_grad()
            outputs = model(low_light_imgs)
            loss = criterion(outputs, high_light_imgs)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")

In [5]:
if __name__ == "__main__":
    # Define paths
    LOW_LIGHT_DIR = "./lol_dataset/our485/low"
    HIGH_LIGHT_DIR = "./lol_dataset/our485/high"

    # Define transformations
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    # Load dataset
    dataset = LOLDataset(LOW_LIGHT_DIR, HIGH_LIGHT_DIR, transform=transform)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

    # Initialize model, loss function, and optimizer
    model = RecursiveEnhancementModel(num_iterations=3)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    train_model(model, dataloader, criterion, optimizer, num_epochs=20, device=device)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/achintyajha/miniconda3/envs/ml/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/achintyajha/miniconda3/envs/ml/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'LOLDataset' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/achintyajha/miniconda3/envs/ml/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/achintyajha/miniconda3/envs/ml/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'LOLDataset' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module

RuntimeError: DataLoader worker (pid(s) 2213) exited unexpectedly