# Part 1: Creation of Dataset : pair of fenced and de-fenced images

In [None]:
# Import the libraries needed to draw synthetic fences on the images
import os
import cv2
import numpy as np
import random
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import shutil
from glob import glob

In [None]:
# Uncomment only to delete the previously generated paired_dataset output directory
# shutil.rmtree("/kaggle/working/paired_dataset") 

In [None]:
# Define input and output directories
# input_dir: root folder that is scanned for the source images.
input_dir = '/kaggle/input/cocotest2014'
# output_base_dir: root folder under which the generated train/test
# input, target and edge images are written.
output_base_dir = '/kaggle/working/paired_dataset'
print("Using input directory:", input_dir)

In [None]:
# Step 1: Lay out the output tree. Each split ('train', 'test') gets the same
# three sibling folders: 'input' (fenced), 'target' (clean) and 'edge' (edge map).
train_input_dir, train_target_dir, train_edge_dir = (
    os.path.join(output_base_dir, 'train', kind) for kind in ('input', 'target', 'edge')
)
test_input_dir, test_target_dir, test_edge_dir = (
    os.path.join(output_base_dir, 'test', kind) for kind in ('input', 'target', 'edge')
)

# exist_ok=True makes re-running this cell harmless when the folders already exist.
for directory in (
    train_input_dir, train_target_dir, train_edge_dir,
    test_input_dir, test_target_dir, test_edge_dir,
):
    os.makedirs(directory, exist_ok=True)

print("Directories created under:", output_base_dir)

In [None]:
# Get all image paths (top level of input_dir only, '.jpg' files only).
# glob() makes no guarantee about the order of its results, so the list is
# sorted to keep the positional split below identical from run to run.
image_paths = sorted(glob(os.path.join(input_dir, '*.jpg')))

# Split data (80% train, 20% test)
# NOTE: "Step 3: Gather and Split the Dataset" further down rebuilds
# train_paths / test_paths from a recursive scan, so the values assigned here
# only matter if that later cell is skipped.
split_idx = int(0.8 * len(image_paths))
train_paths = image_paths[:split_idx]
test_paths = image_paths[split_idx:]

In [None]:
# Step 2: Define Processing Function

def generate_fence(image, grid_size=(80, 40), color=(2,255, 255), thickness=4):
    """Draw a diamond-mesh "fence" over a copy of ``image``.

    Diamonds are laid out in rows spaced half a diamond-height apart; every
    odd row is shifted right by half a diamond-width so the outlines interlock.
    The grid extends one step past the image edges so the border is covered.

    Args:
        image: Image array; only its first two dimensions (height, width) are read.
        grid_size: ``(dx, dy)`` — full width and full height of one diamond, in pixels.
        color: Line colour handed to ``cv2.polylines``.
        thickness: Line thickness handed to ``cv2.polylines``.

    Returns:
        A new array with the fence drawn on it; ``image`` itself is left untouched.
    """
    canvas = image.copy()
    cell_w, cell_h = grid_size
    height, width = canvas.shape[:2]

    half_w = cell_w // 2
    half_h = cell_h // 2
    row_count = (height // half_h) + 2
    col_count = (width // cell_w) + 2

    # Vertex offsets from a diamond's centre, in drawing order: top, right, bottom, left.
    corner_offsets = ((0, -half_h), (half_w, 0), (0, half_h), (-half_w, 0))

    for row_idx in range(row_count):
        center_y = row_idx * half_h
        # Odd rows are staggered by half a cell.
        row_shift = half_w if row_idx % 2 == 1 else 0

        for col_idx in range(col_count):
            center_x = row_shift + col_idx * cell_w
            corners = [(center_x + off_x, center_y + off_y) for off_x, off_y in corner_offsets]
            # cv2.polylines expects each polygon as an (N, 1, 2) int32 array.
            diamond = np.array(corners, np.int32).reshape((-1, 1, 2))
            cv2.polylines(canvas, [diamond], isClosed=True, color=color, thickness=thickness)

    return canvas


In [None]:
def process_and_save_image(img_path, target_size=(256, 256), out_dirs=None, filename_prefix=''):
    """Create and save one (target, input, edge) training triple from an image file.

    The source image is resized to ``target_size`` (the clean *target*), a
    synthetic fence is drawn on it (the *input*), and a Canny edge map of the
    fenced image is computed (the *edge*). Each result is written as a PNG
    named ``<filename_prefix>_<kind>.png`` into the matching output folder.

    Args:
        img_path: Path of the source image.
        target_size: Size passed to ``cv2.resize``.
        out_dirs: Mapping with the keys ``'target'``, ``'input'`` and ``'edge'``
            giving the destination folder for each output.
        filename_prefix: Prefix for the three output file names.

    Returns:
        None. If ``cv2.imread`` cannot read ``img_path`` the function returns
        without writing anything.
    """
    source = cv2.imread(img_path)
    if source is None:
        return

    clean = cv2.resize(source, target_size)
    fenced = generate_fence(clean)
    edges = cv2.Canny(fenced, threshold1=100, threshold2=200)

    file_names = {
        'target': f"{filename_prefix}_target.png",
        'input': f"{filename_prefix}_input.png",
        'edge': f"{filename_prefix}_edge.png",
    }
    # Resolve every destination before the first write, so a missing key in
    # out_dirs fails before any file is produced.
    destinations = [os.path.join(out_dirs[kind], name) for kind, name in file_names.items()]

    for destination, pixels in zip(destinations, (clean, fenced, edges)):
        cv2.imwrite(destination, pixels)

In [None]:
# Step 3: Gather and Split the Dataset
# Recursively collect every .png/.jpg/.jpeg file under input_dir.
all_image_paths = []
for root, dirs, files in os.walk(input_dir):
    for file in files:
        if file.lower().endswith(('.png', '.jpg', '.jpeg')):
            all_image_paths.append(os.path.join(root, file))

# os.walk lists entries in arbitrary order; sorting gives the shuffle below a
# stable starting point so the split is the same on every run.
all_image_paths.sort()

print(f"Found {len(all_image_paths)} images in the dataset.")

# Shuffle with a private, seeded generator: the split becomes reproducible
# without touching the state of the global `random` module.
split_seed = 42
random.Random(split_seed).shuffle(all_image_paths)

# Truncate only after shuffling, so that a cap smaller than the dataset keeps a
# random sample rather than whichever files happened to be listed first.
max_images = 50000  # Set to None to process all images
if max_images is not None:
    all_image_paths = all_image_paths[:max_images]

# 80% train / 20% test.
split_idx = int(0.8 * len(all_image_paths))
train_paths = all_image_paths[:split_idx]
test_paths  = all_image_paths[split_idx:]

print(f"Processing {len(train_paths)} training images and {len(test_paths)} testing images.")

In [None]:
# Run both splits through the same pipeline. Each entry is:
# (file-name tag, label used in progress messages, source paths, output folders).
split_jobs = (
    ("train", "training", train_paths,
     {'input': train_input_dir, 'target': train_target_dir, 'edge': train_edge_dir}),
    ("test", "testing", test_paths,
     {'input': test_input_dir, 'target': test_target_dir, 'edge': test_edge_dir}),
)

for split_tag, split_label, split_paths, out_dirs in split_jobs:
    print(f"Processing {split_label} images...")
    for idx, img_path in enumerate(split_paths):
        # Output files are named <tag>_<index>_<kind>.png.
        filename_prefix = f"{split_tag}_{idx}"
        process_and_save_image(img_path, target_size=(256, 256), out_dirs=out_dirs, filename_prefix=filename_prefix)
        # Progress message every 100 images.
        if idx % 100 == 0:
            print(f"Processed {idx} {split_label} images...")

print("Processing complete. Images saved under:", output_base_dir)

In [None]:
# Show a random image
# Quick visual sanity check: the path is drawn from all_image_paths, i.e. it is
# one of the original source images, not one of the generated fenced outputs.
random_image_path = random.choice(all_image_paths)
img = mpimg.imread(random_image_path)
plt.imshow(img)
plt.axis('off')  # Hide axes
plt.title(f"Random Image: {os.path.basename(random_image_path)}")
plt.show()

# Part 2: De-fence the images

In [None]:
# Import necessary libraries
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
from torchmetrics.functional import peak_signal_noise_ratio as psnr
from torchmetrics.functional import structural_similarity_index_measure as ssim

In [None]:
# 2. Define the Generator (U-Net style)
# We use a UnetSkipConnectionBlock from the pix2pix architecture.
class UnetSkipConnectionBlock(nn.Module):
    """One nesting level of a pix2pix-style U-Net.

    A block downsamples its input with a stride-2 convolution, runs the nested
    ``submodule`` (if any), and upsamples back with a stride-2 transposed
    convolution. Every block except the outermost returns its input
    concatenated with its output along the channel axis (the skip connection).

    Args:
        outer_nc: Number of channels produced by the up-convolution.
        inner_nc: Number of channels produced by the down-convolution.
        input_nc: Number of input channels; defaults to ``outer_nc``.
        submodule: The nested block placed between the down and up paths.
        outermost: Build the outermost block (no normalisation, ``Tanh`` output,
            no skip concatenation). Takes precedence over ``innermost``.
        innermost: Build the innermost block (no nested submodule).
        use_dropout: Append ``Dropout(0.5)`` to an intermediate block.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 outermost=False, innermost=False, use_dropout=False):
        super().__init__()
        self.outermost = outermost
        in_channels = outer_nc if input_nc is None else input_nc

        # The down-convolution is built first in every variant, ahead of the
        # up-convolution.
        down_conv = nn.Conv2d(in_channels, inner_nc, kernel_size=4, stride=2, padding=1, bias=False)

        if outermost:
            # The nested block returns 2 * inner_nc channels because of its skip
            # concatenation. The final up-convolution keeps its bias.
            up_conv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [down_conv, submodule, nn.ReLU(True), up_conv, nn.Tanh()]
        elif innermost:
            # Nothing is nested here, so the up path sees inner_nc channels only.
            up_conv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=False)
            layers = [
                nn.LeakyReLU(0.2, True), down_conv,
                nn.ReLU(True), up_conv, nn.BatchNorm2d(outer_nc),
            ]
        else:
            up_conv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=False)
            layers = [
                nn.LeakyReLU(0.2, True), down_conv, nn.BatchNorm2d(inner_nc),
                submodule,
                nn.ReLU(True), up_conv, nn.BatchNorm2d(outer_nc),
            ]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block; non-outermost blocks return ``cat([x, block(x)])`` on dim 1."""
        out = self.model(x)
        if self.outermost:
            return out
        # Skip connection: stack the block input and output channel-wise.
        return torch.cat([x, out], 1)

class UNetGenerator(nn.Module):
    """U-Net generator assembled from nested ``UnetSkipConnectionBlock`` levels.

    The network is built from the inside out: the innermost block, then
    ``num_downs - 5`` dropout blocks at ``ngf * 8`` channels, then three blocks
    that step the width down (8x -> 4x -> 2x -> 1x ``ngf``), and finally the
    outermost block mapping ``input_nc`` channels to ``output_nc`` channels.

    Args:
        input_nc: Number of channels of the network input.
        output_nc: Number of channels of the generated image.
        num_downs: Total number of downsampling levels.
        ngf: Base channel width of the outermost level.
    """

    def __init__(self, input_nc=4, output_nc=3, num_downs=8, ngf=64):
        super().__init__()
        widest = ngf * 8

        # Innermost level.
        core = UnetSkipConnectionBlock(widest, widest, innermost=True)
        # Deep intermediate levels at full width, with dropout.
        for _ in range(num_downs - 5):
            core = UnetSkipConnectionBlock(widest, widest, submodule=core, use_dropout=True)
        # Each wrapping level halves the channel width on the way out.
        for factor in (4, 2, 1):
            core = UnetSkipConnectionBlock(ngf * factor, ngf * factor * 2, submodule=core)

        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=core, outermost=True)

    def forward(self, x):
        """Run ``x`` through the nested U-Net."""
        return self.model(x)


In [None]:
# defining the default transforms of the image
def default_transform(img, num_channels):
    """Resize ``img`` to 256x256, convert it to a tensor and normalise it.

    Normalisation uses mean 0.5 and std 0.5 for each of the ``num_channels``
    channels, which maps a [0, 1] tensor to [-1, 1].

    Args:
        img: Image to transform.
        num_channels: Number of channels in ``img`` (sets the length of the
            mean/std lists).

    Returns:
        The transformed tensor.
    """
    channel_mean = [0.5] * num_channels
    channel_std = [0.5] * num_channels
    pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])
    return pipeline(img)

# Dataset of (fenced image + edge map, clean target image) pairs used for de-fencing
class DeFencingDataset(Dataset):
    """Paired dataset for de-fencing: (fenced image + edge map) -> clean image.

    The three directories are listed independently and each listing is sorted;
    sample ``idx`` is the ``idx``-th file of every listing. The directories
    must therefore hold the same number of images, with names that sort in the
    same order.

    Args:
        input_dir: Directory of fenced images.
        target_dir: Directory of clean (fence-free) images.
        edge_dir: Directory of edge maps.

    Raises:
        ValueError: If the three directories do not contain the same number of
            image files. Index-based pairing would otherwise mispair samples or
            fail with an IndexError in the middle of an epoch.
    """

    # File extensions (compared in lower case) treated as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, input_dir, target_dir, edge_dir):
        self.input_paths = self._list_images(input_dir)
        self.target_paths = self._list_images(target_dir)
        self.edge_paths = self._list_images(edge_dir)

        counts = (len(self.input_paths), len(self.target_paths), len(self.edge_paths))
        if len(set(counts)) != 1:
            raise ValueError(
                "input/target/edge directories must contain the same number of images, "
                f"got {counts[0]}/{counts[1]}/{counts[2]}"
            )

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted full paths of the image files directly inside ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(cls._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Number of samples (one per fenced input image)."""
        return len(self.input_paths)

    def __getitem__(self, idx):
        """Return ``(input_tensor, target_tensor)`` for sample ``idx``.

        ``input_tensor`` is the 3-channel fenced image with the 1-channel edge
        map appended (4 channels); ``target_tensor`` is the 3-channel clean image.
        """
        # Load images
        fenced_img = Image.open(self.input_paths[idx]).convert('RGB')  # 3 channels
        target_img = Image.open(self.target_paths[idx]).convert('RGB')   # 3 channels
        edge_img   = Image.open(self.edge_paths[idx]).convert('L')         # 1 channel

        # Apply transforms (normalization to [-1, 1])
        fenced_tensor = default_transform(fenced_img, 3)
        target_tensor = default_transform(target_img, 3)
        edge_tensor   = default_transform(edge_img, 1)

        # Concatenate the fenced image and edge map along the channel dimension -> 4 channels
        input_tensor = torch.cat([fenced_tensor, edge_tensor], dim=0)
        return input_tensor, target_tensor


In [None]:
# 3. Define the Discriminator (PatchGAN)

class PatchGANDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of 4x4 convolutions ending in a 1-channel map.

    Layout: one stride-2 convolution with LeakyReLU, then ``n_layers - 1``
    stride-2 conv/BatchNorm/LeakyReLU stages, one stride-1 stage of the same
    kind, and a final stride-1 convolution down to a single channel. The
    channel width doubles at each stage and is capped at ``8 * ndf``. The
    output has no sigmoid applied.

    Args:
        input_nc: Number of input channels.
        ndf: Channel width of the first convolution.
        n_layers: Number of stride-2 convolutions (including the first one).
    """

    def __init__(self, input_nc=7, ndf=64, n_layers=3):
        super().__init__()
        kernel, pad = 4, 1

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # Width multiplier after each stage, starting from the first conv (1x).
        multipliers = [1]
        multipliers += [min(2 ** n, 8) for n in range(1, n_layers)]
        multipliers.append(min(2 ** n_layers, 8))
        # All normalised stages downsample except the last one.
        strides = [2] * (n_layers - 1) + [1]

        for mult_in, mult_out, stride in zip(multipliers, multipliers[1:], strides):
            layers.extend([
                nn.Conv2d(ndf * mult_in, ndf * mult_out, kernel_size=kernel, stride=stride, padding=pad, bias=False),
                nn.BatchNorm2d(ndf * mult_out),
                nn.LeakyReLU(0.2, True),
            ])

        # Final projection to one score per patch.
        layers.append(nn.Conv2d(ndf * multipliers[-1], 1, kernel_size=kernel, stride=1, padding=pad))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-patch score map for ``x``."""
        return self.model(x)

# The conditional discriminator concatenates the conditioned input (fenced+edge, 4 channels)
# with the target (or generated de-fenced image, 3 channels) → 7 channels.
class DeFencingDiscriminator(nn.Module):
    """Conditional discriminator that scores a (conditioning input, image) pair.

    The two tensors are stacked along the channel axis and passed to a
    7-channel ``PatchGANDiscriminator``.
    """

    def __init__(self):
        super().__init__()
        self.model = PatchGANDiscriminator(input_nc=7)

    def forward(self, input_image, target_image):
        """Return the PatchGAN output for ``input_image`` paired with ``target_image``."""
        # Channel-wise stacking of the condition and the candidate image.
        pair = torch.cat([input_image, target_image], dim=1)
        return self.model(pair)


In [None]:
# 4. Initialize Models, Losses and Optimizers

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default constructors: 4-channel generator input (fenced RGB + edge map),
# 7-channel discriminator input (4-channel condition + 3-channel image).
generator = UNetGenerator().to(device)
discriminator = DeFencingDiscriminator().to(device)

# BCEWithLogitsLoss is used because the discriminator ends in a plain
# convolution (no sigmoid), i.e. it outputs raw logits.
criterion_GAN = nn.BCEWithLogitsLoss()  # Adversarial loss
criterion_L1 = nn.L1Loss()              # L1 loss

# One Adam optimizer per network, with identical hyper-parameters.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

LAMBDA = 100  # weight for L1 loss
# Create datasets and dataloaders
# Only the training loader is shuffled; the test loader keeps a fixed order.
train_dataset = DeFencingDataset(train_input_dir, train_target_dir, train_edge_dir)
test_dataset  = DeFencingDataset(test_input_dir, test_target_dir, test_edge_dir)
train_loader  = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
test_loader   = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=2)

In [None]:
# 5. Training Loop
num_epochs = 10

for epoch in range(1, num_epochs+1):
    generator.train()
    discriminator.train()
    for i, (input_img, target_img) in enumerate(train_loader):
        input_img = input_img.to(device)   # shape: [B, 4, 256, 256]
        target_img = target_img.to(device)   # shape: [B, 3, 256, 256]
        
        # ---------------------
        # Train Generator
        # ---------------------
        optimizer_G.zero_grad()
        fake_img = generator(input_img)      # Generated de-fenced image
        # Discriminator's output on fake pair
        pred_fake = discriminator(input_img, fake_img)
        valid = torch.ones_like(pred_fake, device=device)
        loss_G_GAN = criterion_GAN(pred_fake, valid)
        loss_G_L1  = criterion_L1(fake_img, target_img)
        loss_G = loss_G_GAN + LAMBDA * loss_G_L1
        loss_G.backward()
        optimizer_G.step()
        
        # ---------------------
        # Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Real pair loss
        pred_real = discriminator(input_img, target_img)
        loss_D_real = criterion_GAN(pred_real, valid)
        # Fake pair loss (detach fake image)
        fake_detach = fake_img.detach()
        pred_fake = discriminator(input_img, fake_detach)
        fake = torch.zeros_like(pred_fake, device=device)
        loss_D_fake = criterion_GAN(pred_fake, fake)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        optimizer_D.step()

        
        if i % 50 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(train_loader)}] "
                  f"Loss_G: {loss_G.item():.4f} Loss_D: {loss_D.item():.4f}")
    if epoch % 10 == 0:
        generator.eval()
        with torch.no_grad():
            for test_input, test_target in test_loader:
                test_input = test_input.to(device)
                test_target = test_target.to(device)
                fake_test = generator(test_input)
                # Show the first image of the batch
                def imshow(img, title):
                    npimg = img.cpu().detach().numpy()
                    # rescale from [-1,1] to [0,1]
                    npimg = (npimg + 1) / 2
                    plt.imshow(np.transpose(npimg, (1, 2, 0)))
                    plt.title(title)
                    plt.axis('off')
                plt.figure(figsize=(15,5))
                plt.subplot(1,3,1)
                imshow(test_input[0, :3, :, :], "Fenced Image")
                plt.subplot(1,3,2)
                imshow(test_target[0], "Ground Truth")
                plt.subplot(1,3,3)
                imshow(fake_test[0], "De-fenced (Generated)")
                plt.show()
                break  # only display one batch
# Save Model Checkpoints
torch.save(generator.state_dict(), f"defencing_generator.pth")
torch.save(discriminator.state_dict(), f"defencing_discriminator.pth")

In [None]:
# 6. Evaluating the model
def evaluate_model(generator, test_loader, device):
    """Evaluate the generator on a loader and report the average PSNR and SSIM.

    The generator is switched to eval mode and run without gradient tracking.
    Generated and target images are rescaled from [-1, 1] to [0, 1] and the
    metrics are computed per image, then averaged over all images.

    Args:
        generator: Model mapping an input batch to a generated image batch.
        test_loader: Iterable yielding ``(input_img, target_img)`` batches.
        device: Device the batches are moved to before inference.

    Returns:
        Tuple ``(avg_psnr, avg_ssim)`` of Python floats.

    Raises:
        ValueError: If ``test_loader`` yields no images, since an average over
            zero samples is undefined.
    """
    generator.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    count = 0

    with torch.no_grad():
        for input_img, target_img in test_loader:
            input_img = input_img.to(device)   # shape: [B, 4, 256, 256]
            target_img = target_img.to(device) # shape: [B, 3, 256, 256]

            fake_img = generator(input_img)

            # Rescale from [-1, 1] to [0, 1] so that data_range=1.0 below is correct.
            fake_img = (fake_img + 1) / 2.0
            target_img = (target_img + 1) / 2.0

            # Compute metrics per image in the batch
            for i in range(fake_img.size(0)):
                current_psnr = psnr(fake_img[i], target_img[i], data_range=1.0)
                # ssim expects a 4D tensor: [B, C, H, W] so we unsqueeze at batch dimension.
                current_ssim = ssim(fake_img[i].unsqueeze(0), target_img[i].unsqueeze(0), data_range=1.0)
                total_psnr += current_psnr.item()
                total_ssim += current_ssim.item()
                count += 1

    # Fail with a clear message instead of a ZeroDivisionError when there was nothing to score.
    if count == 0:
        raise ValueError("test_loader yielded no images; cannot compute average PSNR/SSIM")

    avg_psnr = total_psnr / count
    avg_ssim = total_ssim / count
    print(f"Average PSNR: {avg_psnr:.2f} dB")
    print(f"Average SSIM: {avg_ssim:.4f}")
    return avg_psnr, avg_ssim


In [None]:
# 7. Visualize the model
def visualize_outputs(generator, test_loader, device, num_samples=3):
    """Plot fenced input, ground truth and generated output for the first batch.

    Only the first batch of ``test_loader`` is used. For each of up to
    ``num_samples`` images in that batch, one figure with three side-by-side
    panels is shown. Nothing is plotted when the loader yields no batch.

    Args:
        generator: Model mapping an input batch to a generated image batch.
        test_loader: Iterable yielding ``(input_img, target_img)`` batches.
        device: Device the batch is moved to before inference.
        num_samples: Maximum number of images to display.
    """
    generator.eval()
    with torch.no_grad():
        first_batch = next(iter(test_loader), None)
        if first_batch is None:
            return

        input_img, target_img = first_batch
        input_img = input_img.to(device)
        target_img = target_img.to(device)
        fake_img = generator(input_img)

        # Bring everything from [-1, 1] back to [0, 1] for display.
        fake_img = (fake_img + 1) / 2.0
        target_img = (target_img + 1) / 2.0
        # Only the first 3 channels (RGB) of the input are shown.
        input_img_vis = (input_img[:, :3, :, :] + 1) / 2.0

        panels = (
            (input_img_vis, "Fenced Input (RGB)"),
            (target_img, "Ground Truth De-fenced"),
            (fake_img, "Generated De-fenced"),
        )
        shown = min(num_samples, fake_img.size(0))
        for sample_idx in range(shown):
            plt.figure(figsize=(12,4))
            for slot, (batch, caption) in enumerate(panels, start=1):
                plt.subplot(1,3,slot)
                # CHW -> HWC, the layout plt.imshow expects.
                plt.imshow(np.transpose(batch[sample_idx].cpu().numpy(), (1,2,0)))
                plt.title(caption)
                plt.axis('off')
            plt.show()

In [None]:
# 8. Call the evaluation and visualization function
# Reload the saved generator weights before evaluating.
# map_location=device remaps the stored tensors onto the current device, so a
# checkpoint written in a GPU session still loads in a CPU-only session.
generator.load_state_dict(torch.load("defencing_generator.pth", map_location=device))
generator.to(device)
# Quantitative metrics (average PSNR / SSIM), then a qualitative look at 6 samples.
avg_psnr, avg_ssim = evaluate_model(generator, test_loader, device)
visualize_outputs(generator, test_loader, device, num_samples=6)