In [2]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import random
import torch.nn as nn
import torch.optim as optim
from torchvision.models import convnext_base
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import copy
import gc

In [3]:
class LoadDataset(Dataset):
    """Two-class image dataset built from a pair of directories.

    Every image file found in ``low_dir`` is labelled 1 and every image file
    found in ``high_dir`` is labelled 0. The combined (path, label) pairs are
    shuffled once, at construction time.

    Args:
        low_dir: Directory holding the label-1 images.
        high_dir: Directory holding the label-0 images.
        max: Upper bound on how many images are drawn at random from *each*
            directory. ``None`` (the default) keeps every image; a value
            larger than the number of available files keeps all of them
            instead of failing.
        image_size: Target size handed to ``transforms.Resize``.
        augment: When true, the random augmentation pipeline runs on the PIL
            image before the resize / tensor conversion.

    Raises:
        ValueError: If ``max`` is negative.
    """

    # File extensions (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, low_dir, high_dir, max=None, image_size = (360,360), augment=False):
        self.augment = augment
        # NOTE: the parameter name shadows the builtin ``max``; it is kept for
        # backward compatibility with existing callers.
        self.max = max

        low_images = self._collect_images(low_dir, max)
        high_images = self._collect_images(high_dir, max)

        combined = [(path, 1) for path in low_images] + [(path, 0) for path in high_images]
        random.shuffle(combined)
        # Built with comprehensions rather than ``zip(*combined)`` so that an
        # empty dataset yields empty tuples instead of an unpacking error.
        self.image_paths = tuple(path for path, _ in combined)
        self.labels = tuple(label for _, label in combined)

        self.base_transforms = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
        self.augmentation = transforms.Compose([transforms.RandomHorizontalFlip(),
                                              transforms.ColorJitter(brightness = 0.4, contrast = 0.4, saturation = 0.4, hue = 0.2),
                                              transforms.GaussianBlur(kernel_size = 3, sigma = (0.1,2.0)),
                                              transforms.RandomAffine(degrees = 5, translate = (0.25,0.25))])

    @classmethod
    def _collect_images(cls, directory, limit=None):
        """Return image paths found directly inside ``directory``.

        The listing is sorted before sampling so that, for a fixed ``random``
        seed, the selection does not depend on the order ``os.listdir``
        happens to return.

        Args:
            directory: Directory to scan (not recursive).
            limit: ``None`` to return every image, otherwise the maximum
                number of randomly chosen images to return.

        Returns:
            A list of joined ``directory``/filename paths.

        Raises:
            ValueError: If ``limit`` is negative.
        """
        candidates = sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )
        if limit is None:
            return candidates
        if limit < 0:
            raise ValueError(f"max must be non-negative, got {limit}")
        # Clamp so that asking for more images than exist is not an error.
        return random.sample(candidates, min(limit, len(candidates)))

    def __len__(self):
        """Return the total number of images across both classes."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return ``(image_tensor, label)``.

        The label is a scalar ``float32`` tensor (1.0 for ``low_dir`` images,
        0.0 for ``high_dir`` images).
        """
        # The context manager closes the underlying file once the pixel data
        # has been converted; ``convert`` returns an independent image.
        with Image.open(self.image_paths[idx]) as handle:
            img = handle.convert("RGB")

        if self.augment:
            img = self.augmentation(img)

        img = self.base_transforms(img)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img, label

In [4]:
# Dataset root: the Kaggle input mount followed by the dataset's own folder.
base_path = "/kaggle/input/loli-street-low-light-image-enhancement-of-street" + "/LoLI-Street Dataset"

# Each split has a 'low' and a 'high' sub-directory; 'Val' serves as the test split.
train_low_dir, train_high_dir = (
    os.path.join(base_path, 'Train', subset) for subset in ('low', 'high')
)
test_low_dir, test_high_dir = (
    os.path.join(base_path, 'Val', subset) for subset in ('low', 'high')
)

# Training set: up to 10000 images per class, with augmentation enabled.
train_dataset = LoadDataset(
    low_dir=train_low_dir,
    high_dir=train_high_dir,
    max=10000,
    image_size=(360, 360),
    augment=True,
)
# Evaluation set: up to 200 images per class, no augmentation.
test_dataset = LoadDataset(
    low_dir=test_low_dir,
    high_dir=test_high_dir,
    max=200,
    image_size=(360, 360),
    augment=False,
)

# Only the training loader reshuffles between epochs.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=4)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=4)

In [None]:
# Fine-tune a ConvNeXt-Base model as a single-logit binary classifier,
# validating after every epoch and keeping the best epoch's weights.
from torchvision.models import ConvNeXt_Base_Weights

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the default pretrained weights, then replace the last classifier layer
# with a linear layer that emits one logit per image.
model = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT)
model.classifier[2] = nn.Linear(in_features=model.classifier[2].in_features, out_features=1)
model = model.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# mode='max' because the scheduler is stepped with validation accuracy below.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

num_epochs = 7
best_val_acc, best_val_loss = 0.0, float('inf')
best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(num_epochs):
    # Drop unreachable Python objects and cached CUDA blocks before each epoch.
    gc.collect()
    torch.cuda.empty_cache()

    # ---------------- training pass ----------------
    model.train()
    train_loss = 0.0
    train_correct = 0
    total = 0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
        inputs = inputs.to(device)
        # The loss compares against the model output, so labels get a
        # trailing dimension of size 1.
        labels = labels.to(device).unsqueeze(1).float()

        # Clearing gradients before the forward pass is equivalent to doing it
        # right before backward(): the forward pass never reads .grad.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; weight it by batch size so the epoch
        # average is per-sample.
        train_loss += loss.item() * inputs.size(0)
        preds = torch.sigmoid(outputs).gt(0.5).float()
        train_correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    train_acc = train_correct / total
    avg_train_loss = train_loss / total

    # ---------------- validation pass ----------------
    model.eval()
    val_loss = 0.0
    val_correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc=f"Epoch {epoch+1} Validation"):
            inputs = inputs.to(device)
            labels = labels.to(device).unsqueeze(1).float()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * inputs.size(0)
            preds = torch.sigmoid(outputs).gt(0.5).float()
            val_correct += preds.eq(labels).sum().item()
            total += labels.size(0)

    val_acc = val_correct / total
    avg_val_loss = val_loss / total

    print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")

    scheduler.step(val_acc)
    # A checkpoint is written for every epoch, independent of the "best" one.
    torch.save(model.state_dict(), f"convnext_epoch_{epoch+1}.pt")

    # New best: strictly higher accuracy, or the same accuracy with lower loss.
    higher_accuracy = val_acc > best_val_acc
    same_accuracy_lower_loss = (val_acc == best_val_acc) and (avg_val_loss < best_val_loss)
    if higher_accuracy or same_accuracy_lower_loss:
        best_val_acc, best_val_loss = val_acc, avg_val_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save(best_model_wts, "Classifier.pt")

# Leave the model holding the best weights seen during training.
model.load_state_dict(best_model_wts)
print(f"Training complete. Best Val Acc: {best_val_acc:.4f}, Best Val Loss: {best_val_loss:.4f}")

Downloading: "https://download.pytorch.org/models/convnext_base-6075fbad.pth" to /root/.cache/torch/hub/checkpoints/convnext_base-6075fbad.pth
100%|██████████| 338M/338M [00:01<00:00, 211MB/s] 
Epoch 1 Training:  20%|█▉        | 977/5000 [13:57<56:36,  1.18it/s]  