In [1]:
import sys
sys.path.append("../GCNet")
from GCNet_model import GCNet
import torch 
from torch.utils.data import random_split
import numpy as np
torch.manual_seed(42)
np.random.seed(42)


import torch.nn as nn
from torch.nn import MSELoss, L1Loss
from torch.optim import Adam

from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os

from PIL import Image
import os
from tqdm import tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
class ImageDataset(Dataset):
    def __init__(self, images_dir, transform=None):
        self.images = sorted(os.listdir(images_dir))
        self.images_dir = images_dir
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.images_dir, self.images[idx]))
        #image = image.convert('RGB')


        # Split the image 
        width, height = image.size        
        truth_image = image.crop((0, 0, width // 3, height))
        glare_image = image.crop((width // 3, 0, (width//3)*2, height))

        if self.transform:
            truth_image = self.transform(truth_image)
            glare_image = self.transform(glare_image)

            truth_image = truth_image.expand(3, -1, -1)
            glare_image = glare_image.expand(3, -1, -1)

        return glare_image, truth_image

In [4]:
transform = transforms.Compose([
    # change to grayscale
    transforms.Grayscale(),
    transforms.Resize((512, 512)),
    transforms.ToTensor()
])

In [5]:
model = GCNet(in_channels=3, out_channels=3).to(device)
model.load_state_dict(torch.load("../GCNet/GCNet_weight.pth", map_location=device),strict=False)
# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the last few layers
for param in model.final4.parameters():
    param.requires_grad = True


In [6]:
model

GCNet(
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (up): Interpolate()
  (conv0_0): GCVGGBlock(
    (model): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01, inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.01, inplace=True)
    )
  )
  (conv1_0): GCVGGBlock(
    (model): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01, inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, e

In [7]:
train_path = '../SD1/train'
dataset = ImageDataset(train_path, transform=transform)


# Assuming dataset is already defined
total_size = len(dataset)
train_size = int(total_size * 0.8)  # 80% for training
val_size = total_size - train_size  # Remaining 20% for validation

# Perform the split
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False)

In [8]:
mse_criterion = MSELoss().to(device)
l1_criterion = L1Loss().to(device)
optimizer = Adam(model.parameters(), lr=1e-5)
 
# Training loop
num_epochs = 10
 
# Variable to keep track of the best validation loss
best_val_loss = float('inf')
if not os.path.exists('../checkpoint'):
    os.makedirs(['../checkpoint'])
best_model_path = "../checkpoint/best_model.pth"

 
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_l1_loss = 0.0
 
    # Training phase with progress bar
    train_pbar = tqdm(train_dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}] Training")
    for glare_images, truth_images in train_pbar:
        glare_images, truth_images = glare_images.to(device), truth_images.to(device)
 
        # Zero the parameter gradients
        optimizer.zero_grad()
 
        # Forward pass
        outputs = model(glare_images)
 
        # Compute the loss
        loss = mse_criterion(outputs, truth_images)
        l1_loss = l1_criterion(outputs, truth_images)


 
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item() * glare_images.size(0)
        running_l1_loss += l1_loss.item() * glare_images.size(0)
 
        # Update progress bar
        train_pbar.set_postfix(loss=loss.item())
 
    epoch_loss = running_loss / len(train_dataloader.dataset)
    l1_epoch_loss = running_l1_loss / len(train_dataloader.dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], MSE Loss: {epoch_loss:.4f}, L1 Loss: {l1_epoch_loss}')
 
    # Validation phase with progress bar
    model.eval()
    val_loss = 0.0
 
    val_pbar = tqdm(val_dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}] Validation")
    with torch.no_grad():
        for glare_images, truth_images in val_pbar:
            glare_images, truth_images = glare_images.to(device), truth_images.to(device)
 
            # Forward pass
            outputs = model(glare_images)
 
            # Compute L1 loss
            loss = l1_criterion(outputs, truth_images)
 
            val_loss += loss.item() * glare_images.size(0)
 
            # Update progress bar
            val_pbar.set_postfix(loss=loss.item())
 
    val_loss /= len(val_dataloader.dataset)
    print(f'L1 Validation Loss: {val_loss:.4f}')
 
    # Save the model if the validation loss is the best we've seen so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f'Saved Best Model with Validation Loss: {val_loss:.4f}')
 
print('Training complete.')
print(f'Best validation loss: {best_val_loss:.4f}')

Epoch [1/10] Training: 100%|██████████| 4800/4800 [31:44<00:00,  2.52it/s, loss=0.0333]


Epoch [1/10], MSE Loss: 0.2936, L1 Loss: 0.33946515616960826


Epoch [1/10] Validation: 100%|██████████| 600/600 [07:12<00:00,  1.39it/s, loss=0.154] 


L1 Validation Loss: 0.1254
Saved Best Model with Validation Loss: 0.1254


Epoch [2/10] Training: 100%|██████████| 4800/4800 [32:03<00:00,  2.50it/s, loss=0.0195] 


Epoch [2/10], MSE Loss: 0.0314, L1 Loss: 0.13087072954745962


Epoch [2/10] Validation: 100%|██████████| 600/600 [07:12<00:00,  1.39it/s, loss=0.122] 


L1 Validation Loss: 0.1021
Saved Best Model with Validation Loss: 0.1021


Epoch [3/10] Training: 100%|██████████| 4800/4800 [31:56<00:00,  2.50it/s, loss=0.0105] 


Epoch [3/10], MSE Loss: 0.0210, L1 Loss: 0.10954819753067568


Epoch [3/10] Validation: 100%|██████████| 600/600 [07:12<00:00,  1.39it/s, loss=0.12]  


L1 Validation Loss: 0.0964
Saved Best Model with Validation Loss: 0.0964


Epoch [4/10] Training: 100%|██████████| 4800/4800 [31:57<00:00,  2.50it/s, loss=0.0137] 


Epoch [4/10], MSE Loss: 0.0179, L1 Loss: 0.10216635415252918


Epoch [4/10] Validation: 100%|██████████| 600/600 [07:11<00:00,  1.39it/s, loss=0.114] 


L1 Validation Loss: 0.0910
Saved Best Model with Validation Loss: 0.0910


Epoch [5/10] Training: 100%|██████████| 4800/4800 [32:05<00:00,  2.49it/s, loss=0.00961]


Epoch [5/10], MSE Loss: 0.0168, L1 Loss: 0.09934377415648972


Epoch [5/10] Validation: 100%|██████████| 600/600 [07:20<00:00,  1.36it/s, loss=0.115] 


L1 Validation Loss: 0.0943


Epoch [6/10] Training:  47%|████▋     | 2254/4800 [15:07<17:29,  2.43it/s, loss=0.0156] 