In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import os
from PIL import Image
from torchvision import transforms
import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [None]:


# Augmentation pipeline applied to every (blended, transmission_layer) pair.
# The four geometric transforms each fire with probability 0.9, so almost every
# output has several of them stacked; the photometric ones fire less often.
augmentations = A.Compose([
    A.HorizontalFlip(p=0.9),  # Horizontal flip with 90% probability
    A.VerticalFlip(p=0.9),    # Vertical flip with 90% probability
    A.RandomRotate90(p=0.9),  # Random multiple-of-90-degree rotation with 90% probability
    A.Rotate(limit=45, p=0.9),  # Random rotation between -45 and 45 degrees with 90% probability
    A.RandomBrightnessContrast(p=0.5),  # Random brightness and contrast adjustment with 50% probability
    A.Blur(blur_limit=7, p=0.3),  # Random blur (blur_limit=7) with 30% probability
    A.Resize(512, 512),  # Always resize to 512x512 so every saved image has the same size
])


# List of directories containing paired input/target subdirectories
# ("blended" and "transmission_layer" are read from each of these).
dataset_dirs = [r"D:\MANIPAL\Research\Reflection_Removal\Code\Dataset\DSLR\unaligned_test50",
                r"D:\MANIPAL\Research\Reflection_Removal\Code\Dataset\DSLR\unaligned_train250",
                r"D:\MANIPAL\Research\Reflection_Removal\Code\Dataset\Smartphone\unaligned150"
                ]  # Add paths to your directories here

# Global output directories: augmented inputs and augmented targets are written
# under the same file name in these two folders.
output_input_dir = r"D:/MANIPAL/Research/Reflection_Removal/Code/new_aug/aug_blended"
output_target_dir = r"D:/MANIPAL/Research/Reflection_Removal/Code/new_aug/aug_transmission_layer"

# Ensure output directories exist
os.makedirs(output_input_dir, exist_ok=True)
os.makedirs(output_target_dir, exist_ok=True)

# Function to augment and save images
def augment_and_save(input_path, target_path, output_input_dir, output_target_dir, new_name):
    """Augment one (input, target) image pair and write both results to disk.

    Both images are read with OpenCV, passed together through the module-level
    ``augmentations`` pipeline, and saved under ``new_name`` in their respective
    output directories.

    Args:
        input_path: Path of the input image to read.
        target_path: Path of the paired target image to read.
        output_input_dir: Directory that receives the augmented input.
        output_target_dir: Directory that receives the augmented target.
        new_name: File name (with extension) used for both outputs.

    Raises:
        FileNotFoundError: If either source image cannot be read.
        OSError: If either augmented image cannot be written.
    """
    # cv2.imread signals failure by returning None instead of raising, so check
    # explicitly to get a clear error that names the offending file.
    input_image = cv2.imread(input_path)
    if input_image is None:
        raise FileNotFoundError(f"Could not read input image: {input_path}")
    target_image = cv2.imread(target_path)
    if target_image is None:
        raise FileNotFoundError(f"Could not read target image: {target_path}")

    # Apply augmentations
    # NOTE(review): the target is passed as `mask`. Albumentations applies only
    # spatial transforms to masks (brightness/contrast and blur touch the input
    # only) and, by default, resamples masks with nearest-neighbour
    # interpolation. Confirm this is intended for a photographic target.
    augmented = augmentations(image=input_image, mask=target_image)

    # Save augmented images; cv2.imwrite returns False when the write fails.
    outputs = (
        (output_input_dir, augmented['image']),
        (output_target_dir, augmented['mask']),
    )
    for directory, image in outputs:
        out_path = os.path.join(directory, new_name)
        if not cv2.imwrite(out_path, image):
            raise OSError(f"Failed to write augmented image: {out_path}")

# Process each dataset directory
AUGMENTATIONS_PER_IMAGE = 5  # Number of augmented copies written for every source pair
image_counter = 0  # Counter to ensure unique filenames across all directories

for dataset_dir in dataset_dirs:
    input_dir = os.path.join(dataset_dir, "blended")
    target_dir = os.path.join(dataset_dir, "transmission_layer")

    # Pairs are formed by position in the sorted listings.
    input_files = sorted(os.listdir(input_dir))
    target_files = sorted(os.listdir(target_dir))

    # zip() would silently drop the surplus files and could shift the pairing,
    # so refuse to continue when the two folders do not line up.
    if len(input_files) != len(target_files):
        raise ValueError(
            f"{dataset_dir}: {len(input_files)} files in 'blended' but "
            f"{len(target_files)} in 'transmission_layer'"
        )

    for input_file, target_file in zip(input_files, target_files):
        # Input and target are expected to share a base name; warn when they
        # do not so a mispaired sample is visible in the cell output.
        if os.path.splitext(input_file)[0] != os.path.splitext(target_file)[0]:
            print(f"Warning: pairing '{input_file}' with differently named '{target_file}' in {dataset_dir}")

        input_path = os.path.join(input_dir, input_file)
        target_path = os.path.join(target_dir, target_file)

        # Write AUGMENTATIONS_PER_IMAGE independently augmented copies per pair
        for _ in range(AUGMENTATIONS_PER_IMAGE):
            new_name = f"aug_image_{image_counter:04d}.jpg"  # Unique name with a zero-padded counter
            augment_and_save(input_path, target_path, output_input_dir, output_target_dir, new_name)
            image_counter += 1

print("Augmentation completed. All images are stored in 'aug_blended' and 'aug_transmission_layer'.")


Augmentation completed. All images are stored in 'aug_blended' and 'aug_transmission_layer'.


In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm stages with an additive shortcut.

    The first convolution applies ``stride``; the second always uses stride 1.
    When the block keeps both the resolution (``stride == 1``) and the channel
    count, the shortcut is an empty ``nn.Sequential`` that passes the input
    through unchanged. Otherwise it is a 1x1 convolution with the same stride
    followed by batch norm, so both branches have matching shapes.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main branch (layers are created in the same order as they are applied).
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut branch.
        if stride == 1 and in_channels == out_channels:
            shortcut = nn.Sequential()
        else:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        self.skip_connection = shortcut

    def forward(self, x):
        """Return ``relu(main_branch(x) + skip_connection(x))``."""
        shortcut = self.skip_connection(x)
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return self.relu(main + shortcut)

# ResNet-style Encoder and Decoder
class ReflectionRemovalNet(nn.Module):
    """Encoder / bottleneck / decoder network built from ``ResidualBlock`` units.

    The encoder maps 3 input channels to 256 and contains two stride-2 blocks.
    The decoder contains two stride-2 transposed convolutions, reduces the
    channels back to 3 and ends in a sigmoid, so outputs lie in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder: 3 -> 64 -> 128 -> 256 channels.
        encoder_stages = [
            ResidualBlock(3, 64, stride=1),
            ResidualBlock(64, 128, stride=2),
            ResidualBlock(128, 256, stride=2),
        ]
        self.encoder = nn.Sequential(*encoder_stages)

        # Bottleneck: one residual block at 256 channels.
        self.bottleneck = nn.Sequential(ResidualBlock(256, 256, stride=1))

        # Decoder: 256 -> 128 -> 64 -> 3 channels, with one transposed
        # convolution after each residual block.
        decoder_stages = [
            ResidualBlock(256, 128, stride=1),
            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            ResidualBlock(128, 64, stride=1),
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_stages)

    def forward(self, x):
        """Run ``x`` through encoder, bottleneck and decoder in sequence."""
        return self.decoder(self.bottleneck(self.encoder(x)))

class PerceptualLoss(nn.Module):
    """L1 distance between MobileNetV2 feature maps of two image batches.

    The feature extractor is the first 15 blocks of ``mobilenet_v2().features``.
    It is frozen and always kept in eval mode, so its batch-norm statistics are
    never updated by training.

    Args:
        path: Optional path to a local MobileNetV2 state_dict. If the file
            exists the weights are loaded from it; otherwise the ImageNet
            weights shipped with torchvision are used.
    """

    def __init__(self, path=r"anti_reflection\models\mobilenet_v2-b0353104.pth"):
        super(PerceptualLoss, self).__init__()

        if path and os.path.isfile(path):
            # Use the local checkpoint when it is present.
            mobilenet = models.mobilenet_v2()
            mobilenet.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
        elif hasattr(models, "MobileNet_V2_Weights"):
            # Current torchvision API; `pretrained=True` is deprecated there.
            mobilenet = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        else:
            # Older torchvision releases only know the `pretrained` flag.
            mobilenet = models.mobilenet_v2(pretrained=True)

        # Kept as an attribute for backward compatibility with code that reads it.
        self.features = mobilenet.features
        self.layers = nn.Sequential(*list(self.features.children())[:15]).eval()

        for param in self.layers.parameters():
            param.requires_grad = False

        # torchvision's ImageNet models expect inputs normalized with these
        # statistics. Non-persistent buffers follow .to(device) without being
        # written to a state_dict.
        self.register_buffer(
            "_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    def train(self, mode=True):
        """Keep the frozen feature extractor in eval mode whatever ``mode`` is."""
        super().train(mode)
        self.layers.eval()
        return self

    def _extract(self, images):
        """Normalize ``images`` (assumed RGB in [0, 1]) and return their features."""
        return self.layers((images - self._mean) / self._std)

    def forward(self, input, target):
        """Return the mean absolute difference between the two feature maps."""
        input_features = self._extract(input)
        target_features = self._extract(target)
        return nn.functional.l1_loss(input_features, target_features)


# gives pixel level accuracy
class ReconstructionLoss(nn.Module):
    """Pixel-wise reconstruction loss between a prediction and its target.

    Args:
        loss_type: ``'l1'`` for mean absolute error (``nn.L1Loss``) or ``'l2'``
            for mean squared error (``nn.MSELoss``). Matching is
            case-insensitive.

    Raises:
        ValueError: If ``loss_type`` is neither ``'l1'`` nor ``'l2'``.
    """

    def __init__(self, loss_type='l1'):
        super(ReconstructionLoss, self).__init__()
        kind = str(loss_type).lower()
        if kind == 'l1':
            self.loss_fn = nn.L1Loss()
        elif kind == 'l2':
            self.loss_fn = nn.MSELoss()
        else:
            # Fail at construction time; otherwise `loss_fn` would be missing
            # and the error would only surface later, inside forward().
            raise ValueError(f"Unsupported loss_type {loss_type!r}; expected 'l1' or 'l2'")

    def forward(self, predicted, target):
        """Return the configured loss between ``predicted`` and ``target``."""
        return self.loss_fn(predicted, target)


Traceback (most recent call last):
  File "c:\Users\chris\.vscode\extensions\ms-python.python-2024.20.0-win32-x64\python_files\python_server.py", line 130, in exec_user_input
    retval = callable_(user_input, user_globals)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 1, in <module>
NameError: name 'nn' is not defined



In [None]:
class ReflectionDataset(torch.utils.data.Dataset):
    """Paired dataset of reflection-contaminated images and their clear targets.

    Files are paired by position in the sorted listing of each directory, so the
    two folders must contain the same number of images in matching sort order.

    Args:
        reflected_dir: Directory holding the input (reflected) images.
        clear_dir: Directory holding the target (clear) images.
        transform: Optional callable applied separately to both images of a pair.

    Raises:
        ValueError: If the two directories hold a different number of images.
    """

    # Only files with these extensions (case-insensitive) are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, reflected_dir, clear_dir, transform=None):
        self.reflected_dir = reflected_dir
        self.clear_dir = clear_dir
        self.transform = transform
        self.reflected_files = self._list_images(reflected_dir)
        self.clear_files = self._list_images(clear_dir)

        # Positional pairing is only valid when both listings have the same
        # length; otherwise indexing would fail or return mismatched pairs.
        if len(self.reflected_files) != len(self.clear_files):
            raise ValueError(
                f"Found {len(self.reflected_files)} images in {reflected_dir!r} but "
                f"{len(self.clear_files)} in {clear_dir!r}; the folders must be paired"
            )

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted image file names found directly in ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.reflected_files)

    def __getitem__(self, idx):
        """Return ``{'input': reflected, 'target': clear}`` for pair ``idx``."""
        reflected_image = Image.open(os.path.join(self.reflected_dir, self.reflected_files[idx])).convert('RGB')
        clear_image = Image.open(os.path.join(self.clear_dir, self.clear_files[idx])).convert('RGB')

        if self.transform:
            reflected_image = self.transform(reflected_image)
            clear_image = self.transform(clear_image)

        return {'input': reflected_image, 'target': clear_image}

# Data Preprocessing
def get_dataloader(reflected_dir, clear_dir, batch_size=16):
    """Build a shuffling DataLoader over the paired reflection dataset.

    Each image is resized to 256x256 and converted to a tensor before batching.

    Args:
        reflected_dir: Directory of input (reflected) images.
        clear_dir: Directory of target (clear) images.
        batch_size: Number of pairs per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``ReflectionDataset``.
    """
    preprocess = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    paired_images = ReflectionDataset(
        reflected_dir=reflected_dir,
        clear_dir=clear_dir,
        transform=preprocess,
    )
    return torch.utils.data.DataLoader(paired_images, batch_size=batch_size, shuffle=True)


Traceback (most recent call last):
  File "c:\Users\chris\.vscode\extensions\ms-python.python-2024.20.0-win32-x64\python_files\python_server.py", line 130, in exec_user_input
    retval = callable_(user_input, user_globals)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 1, in <module>
NameError: name 'torch' is not defined



In [None]:
def train_model(
    model, dataloader, device, num_epochs=25, learning_rate=1e-4, output_dir="trained_models/"
):
    """Train ``model`` with a combined MSE + perceptual + L1 objective.

    The per-batch loss is ``mse + 0.5 * perceptual + 0.5 * l1``. Adam is used
    with a constant learning rate. A checkpoint is written every 50 epochs as
    ``model_epoch_<n>.pth`` and the final weights as ``new_model_final.pth``,
    all inside ``output_dir``.

    Args:
        model: Network to train; it is moved to ``device``.
        dataloader: Yields dicts with ``'input'`` and ``'target'`` tensors.
        device: Device used for the model, the loss network and the batches.
        num_epochs: Number of passes over ``dataloader``.
        learning_rate: Adam learning rate.
        output_dir: Directory for checkpoints (created if missing).

    Raises:
        ValueError: If ``dataloader`` contains no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Without this the epoch average below would divide by zero.
        raise ValueError("dataloader is empty; nothing to train on")

    os.makedirs(output_dir, exist_ok=True)
    checkpoint_every = 50  # epochs between intermediate checkpoints

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # defining all the loss functions
    perceptual_loss_fn = PerceptualLoss().to(device)
    perceptual_loss_fn.eval()  # the loss network is a fixed feature extractor
    mse_loss_fn = nn.MSELoss()  # MSE Loss
    reconstruction_loss_fn = nn.L1Loss()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for batch in dataloader:
            inputs, targets = batch['input'].to(device), batch['target'].to(device)

            # Forward pass
            outputs = model(inputs)
            mse_loss = mse_loss_fn(outputs, targets)
            perceptual_loss = perceptual_loss_fn(outputs, targets)
            reconstruction_loss = reconstruction_loss_fn(outputs, targets)

            # Weighted sum of all the losses
            loss = mse_loss + 0.5 * perceptual_loss + 0.5 * reconstruction_loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/num_batches:.4f}")

        if (epoch + 1) % checkpoint_every == 0:  # checkpointing
            torch.save(model.state_dict(), os.path.join(output_dir, f"model_epoch_{epoch+1}.pth"))

    torch.save(model.state_dict(), os.path.join(output_dir, "new_model_final.pth"))
    print("Training complete. Model saved.")


In [None]:
if __name__ == "__main__":
    # Training hyperparameters.
    batch_size, num_epochs, learning_rate = 16, 100, 1e-3

    # Data and checkpoint locations.
    reflected_dir = r"anti_reflection\dataset\aug_blended"  # training inputs (with reflections)
    clear_dir = r"anti_reflection\dataset\aug_transmission_layer"  # training targets (no reflections)
    output_dir = r"anti_reflection\models"  # where the trained model is stored
    os.makedirs(output_dir, exist_ok=True)

    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataloader = get_dataloader(reflected_dir, clear_dir, batch_size)
    model = ReflectionRemovalNet()

    train_model(
        model,
        dataloader,
        device,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        output_dir=output_dir,
    )


Epoch [1/100], Loss: 0.2718
Epoch [2/100], Loss: 0.2682
Epoch [3/100], Loss: 0.2568
Epoch [4/100], Loss: 0.2403
Epoch [5/100], Loss: 0.2270
Epoch [6/100], Loss: 0.2138
Epoch [7/100], Loss: 0.2030
Epoch [8/100], Loss: 0.1939
Epoch [9/100], Loss: 0.1858
Epoch [10/100], Loss: 0.1784
Epoch [11/100], Loss: 0.1708
Epoch [12/100], Loss: 0.1632
Epoch [13/100], Loss: 0.1557


KeyboardInterrupt: 

In [None]:
def load_model(model_path, device):
    """Load trained ``ReflectionRemovalNet`` weights and return the model in eval mode.

    Args:
        model_path: Path to a state_dict saved with ``torch.save(model.state_dict(), ...)``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    model = ReflectionRemovalNet()
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs; it also avoids the FutureWarning that
    # recent PyTorch releases emit for a bare torch.load().
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

def process_image(image_path, transform, device):
    """Load an image as RGB, apply ``transform`` and return it as a 1-item batch on ``device``."""
    rgb = Image.open(image_path).convert('RGB')
    batch = transform(rgb).unsqueeze(0)  # add the leading batch dimension
    return batch.to(device)

def save_output_image(output_tensor, save_path):
    """Convert a model output tensor to a PIL image and save it at ``save_path``.

    Dimension 0 is squeezed away when it has size 1 (the batch dimension of a
    single-image prediction), and the tensor is moved to the CPU and detached
    before conversion.
    """
    image_tensor = output_tensor.squeeze(0).cpu().detach()
    to_pil = transforms.ToPILImage()
    to_pil(image_tensor).save(save_path)

def test_model(model_path, input_dir, output_dir, device):
    """Run the trained model over every image in ``input_dir`` and save the results.

    Each image is resized to 256x256, passed through the model without gradient
    tracking, and written to ``output_dir`` under its original file name. Files
    without a recognized image extension are reported and skipped.

    Args:
        model_path: Path to the trained model's state_dict.
        input_dir: Directory with the images to process.
        output_dir: Directory for the predictions (created if missing).
        device: Device used for inference.
    """
    os.makedirs(output_dir, exist_ok=True)

    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    model = load_model(model_path, device)
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    for img_name in sorted(os.listdir(input_dir)):
        # Only process files that look like images.
        if not img_name.lower().endswith(image_suffixes):
            print(f"Skipping non-image file: {img_name}")
            continue

        source_path = os.path.join(input_dir, img_name)
        destination_path = os.path.join(output_dir, img_name)

        batch = process_image(source_path, preprocess, device)

        # Inference only, so no gradients are needed.
        with torch.no_grad():
            prediction = model(batch)

        save_output_image(prediction, destination_path)
        print(f"Processed {img_name} and saved to {destination_path}")

if __name__ == "__main__":
    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): input_dir is the same aug_blended folder the training cell
    # reads from, so these predictions are made on training images.
    model_path = r"anti_reflection\models\new_model_final.pth"  # trained model weights
    input_dir = r"anti_reflection\dataset\aug_blended"  # folder with reflected images to process
    output_dir = r"anti_reflection\dataset\predicted"  # folder that receives the outputs

    test_model(model_path, input_dir, output_dir, device)


  model.load_state_dict(torch.load(model_path, map_location=device))


Processed 0.png and saved to D:\MANIPAL\Research\Reflection_Removal\Code\predicted_imgs\0.png
Processed aug_image_0009.jpg and saved to D:\MANIPAL\Research\Reflection_Removal\Code\predicted_imgs\aug_image_0009.jpg
Processed aug_image_0017.jpg and saved to D:\MANIPAL\Research\Reflection_Removal\Code\predicted_imgs\aug_image_0017.jpg
Processed cupboard.jpg and saved to D:\MANIPAL\Research\Reflection_Removal\Code\predicted_imgs\cupboard.jpg
Processed dada.jpg and saved to D:\MANIPAL\Research\Reflection_Removal\Code\predicted_imgs\dada.jpg


In [None]:
def calculate_psnr(output_image, ground_truth_image, max_pixel=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        output_image: Predicted image (array-like).
        ground_truth_image: Reference image with a compatible shape.
        max_pixel: Largest possible pixel value; 1.0 for images scaled to
            [0, 1], 255 for 8-bit images.

    Returns:
        ``20 * log10(max_pixel / sqrt(mse))``, or ``inf`` when the images are
        identical.
    """
    # Compute in float64 so integer inputs (e.g. uint8) cannot wrap around in
    # the subtraction or the square.
    output = np.asarray(output_image, dtype=np.float64)
    reference = np.asarray(ground_truth_image, dtype=np.float64)

    mse = np.mean((reference - output) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(max_pixel / np.sqrt(mse))

def calculate_ssim(output_image, ground_truth_image):
    """Return the mean SSIM between two RGB images, computed on grayscale.

    Both images are scaled by 255, cast to uint8, converted to grayscale and
    scaled back to float32 in [0, 1]. Local statistics use an 11x11 Gaussian
    window with sigma 1.5.
    """
    c1, c2 = 0.01 ** 2, 0.03 ** 2

    def to_gray(rgb):
        # Round-trip through uint8 because the grayscale conversion is done on 8-bit data.
        as_uint8 = (rgb * 255).astype(np.uint8)
        return cv2.cvtColor(as_uint8, cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0

    def smooth(channel):
        return cv2.GaussianBlur(channel, (11, 11), 1.5)

    x = to_gray(output_image)
    y = to_gray(ground_truth_image)

    # Local means, variances and covariance under the Gaussian window.
    mean_x, mean_y = smooth(x), smooth(y)
    var_x = smooth(x ** 2) - mean_x ** 2
    var_y = smooth(y ** 2) - mean_y ** 2
    cov_xy = smooth(x * y) - mean_x * mean_y

    numerator = (2 * mean_x * mean_y + c1) * (2 * cov_xy + c2)
    denominator = (mean_x ** 2 + mean_y ** 2 + c1) * (var_x + var_y + c2)
    return (numerator / (denominator + 1e-6)).mean()

def calculate_precision_recall_accuracy(output_image, ground_truth_image):
    """Return precision, recall and accuracy of thresholded grayscale images.

    Both RGB images are scaled by 255, cast to uint8, converted to grayscale
    and binarized with ``cv2.threshold`` (threshold 128, max value 1). The
    metrics compare the two binary maps pixel by pixel.

    Returns:
        Dict with keys ``"Precision"``, ``"Recall"`` and ``"Accuracy"``.
    """
    def binarize(rgb):
        gray = cv2.cvtColor((rgb * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        _, binary = cv2.threshold(gray, 128, 1, cv2.THRESH_BINARY)
        return binary

    predicted = binarize(output_image)
    actual = binarize(ground_truth_image)
    predicted_neg = 1 - predicted
    actual_neg = 1 - actual

    # Confusion-matrix counts: TP, FP, FN, TN
    tp = np.sum(predicted * actual)
    fp = np.sum(predicted * actual_neg)
    fn = np.sum(predicted_neg * actual)
    tn = np.sum(predicted_neg * actual_neg)

    eps = 1e-6  # avoid division by zero
    return {
        "Precision": tp / (tp + fp + eps),
        "Recall": tp / (tp + fn + eps),
        "Accuracy": (tp + tn) / (tp + fp + fn + tn + eps),
    }




In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def test_gan_model_on_directory(generator_path, test_dir, ground_truth_dir):
    """Evaluate a trained ``ReflectionRemovalNet`` against ground-truth images.

    Images from ``test_dir`` and ``ground_truth_dir`` are paired by position in
    their sorted listings. Each input is run through the network at its native
    resolution, the output is resized to the ground-truth size, and PSNR, SSIM,
    precision, recall and accuracy are averaged over all pairs and printed.

    Args:
        generator_path: Path to the model's saved state_dict.
        test_dir: Directory with the input images.
        ground_truth_dir: Directory with the reference images.

    Returns:
        Dict with the averaged ``"PSNR"``, ``"SSIM"``, ``"Precision"``,
        ``"Recall"`` and ``"Accuracy"``.

    Raises:
        ValueError: If the two directories hold a different number of images.
    """
    generator = ReflectionRemovalNet().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers and
    # avoids the FutureWarning recent PyTorch emits for a bare torch.load().
    generator.load_state_dict(torch.load(generator_path, map_location=device, weights_only=True))
    generator.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def list_images(directory):
        # Ignore stray non-image files, which Image.open could not read.
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(image_extensions)
        )

    test_images = list_images(test_dir)
    ground_truth_images = list_images(ground_truth_dir)

    # zip() would silently drop unmatched files and could shift the pairing.
    if len(test_images) != len(ground_truth_images):
        raise ValueError(
            f"Found {len(test_images)} images in {test_dir!r} but "
            f"{len(ground_truth_images)} in {ground_truth_dir!r}; the folders must be paired"
        )

    psnr_values = []
    ssim_values = []
    precision_values = []
    recall_values = []
    accuracy_values = []

    for test_img_name, gt_img_name in zip(test_images, ground_truth_images):
        test_img_path = os.path.join(test_dir, test_img_name)
        gt_img_path = os.path.join(ground_truth_dir, gt_img_name)

        input_image = Image.open(test_img_path).convert("RGB")
        input_tensor = transform(input_image).unsqueeze(0).to(device)

        with torch.no_grad():
            output_tensor = generator(input_tensor)

        # (1, C, H, W) tensor -> (H, W, C) float array in [0, 1]
        output_tensor = output_tensor.squeeze(0).cpu().numpy()
        output_image = np.transpose(output_tensor, (1, 2, 0))
        output_image = np.clip(output_image, 0, 1)

        ground_truth_image = Image.open(gt_img_path).convert("RGB")
        ground_truth_image = np.asarray(ground_truth_image).astype(np.float32) / 255.0

        # Match the ground-truth resolution so the metrics compare equal shapes.
        output_image = cv2.resize(output_image, (ground_truth_image.shape[1], ground_truth_image.shape[0]), interpolation=cv2.INTER_LINEAR)

        psnr_value = calculate_psnr(output_image, ground_truth_image)
        ssim_value = calculate_ssim(output_image, ground_truth_image)
        metrics = calculate_precision_recall_accuracy(output_image, ground_truth_image)

        psnr_values.append(psnr_value)
        ssim_values.append(ssim_value)
        precision_values.append(metrics["Precision"])
        recall_values.append(metrics["Recall"])
        accuracy_values.append(metrics["Accuracy"])

    avg_psnr = np.mean(psnr_values)
    avg_ssim = np.mean(ssim_values)
    avg_precision = np.mean(precision_values)
    avg_recall = np.mean(recall_values)
    avg_accuracy = np.mean(accuracy_values)

    print("Average Metrics for Directory:")
    print(f"PSNR: {avg_psnr:.4f}")
    print(f"SSIM: {avg_ssim:.4f}")
    print(f"Precision: {avg_precision:.4f}")
    print(f"Recall: {avg_recall:.4f}")
    print(f"Accuracy: {avg_accuracy:.4f}")

    return {
        "PSNR": avg_psnr,
        "SSIM": avg_ssim,
        "Precision": avg_precision,
        "Recall": avg_recall,
        "Accuracy": avg_accuracy,
    }

# Evaluation inputs: trained weights, images to process, and their references.
# NOTE(review): these are the same aug_blended / aug_transmission_layer folders
# the training cell reads, so the reported metrics are measured on training data.
model_path = r"anti_reflection\models\new_model_final.pth"
test_image_dir = r"anti_reflection\dataset\aug_blended"
ground_truth_dir = r"anti_reflection\dataset\aug_transmission_layer"

metrics = test_gan_model_on_directory(
    generator_path=model_path,
    test_dir=test_image_dir,
    ground_truth_dir=ground_truth_dir,
)


  generator.load_state_dict(torch.load(generator_path, map_location=device))


Average Metrics for Directory:
PSNR: 16.0225
SSIM: 0.4029
Precision: 0.7755
Recall: 0.7496
Accuracy: 0.8101
