In [1]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -q gdown

[0m

In [2]:
# URLs are quoted because '?' is a glob character in the shell.
!gdown "https://drive.google.com/uc?id=1-h9Uno7UB8MPWmYpcWQzMy0ukWkEZ62F"
!gdown "https://drive.google.com/uc?id=1O5TF1LJMTr03FSUWO5_4Abh-Irl3U0DA"

Downloading...
From (uriginal): https://drive.google.com/uc?id=1-h9Uno7UB8MPWmYpcWQzMy0ukWkEZ62F
From (redirected): https://drive.google.com/uc?id=1-h9Uno7UB8MPWmYpcWQzMy0ukWkEZ62F&confirm=t&uuid=9594ae9e-0fda-4812-890f-1f480110f7ae
To: /kaggle/working/transmission_layer.tar.gz
100%|████████████████████████████████████████| 688M/688M [00:45<00:00, 15.1MB/s]
Downloading...
From (uriginal): https://drive.google.com/uc?id=1O5TF1LJMTr03FSUWO5_4Abh-Irl3U0DA
From (redirected): https://drive.google.com/uc?id=1O5TF1LJMTr03FSUWO5_4Abh-Irl3U0DA&confirm=t&uuid=dcee1697-2553-4d4e-b216-030bdff2434f
To: /kaggle/working/reflection_layer.tar.gz
100%|████████████████████████████████████████| 690M/690M [00:52<00:00, 13.2MB/s]


In [3]:
%%capture
# Unpack both downloaded archives; %%capture hides tar's verbose (-v) file listing.
!tar -xvzf /kaggle/working/transmission_layer.tar.gz
!tar -xvzf /kaggle/working/reflection_layer.tar.gz

In [4]:
import numpy as np
import scipy.stats as st
# Candidate Gaussian blur sigmas for the synthetic reflections: 80 evenly spaced values in [1, 5].
k_sz = np.linspace(1, 5, num=80)


In [5]:


# functions for synthesizing images with reflection (details in the paper)
def gkern(kernlen=100, nsig=1):
    """Build a square, Gaussian-shaped 2D weight map whose peak value is 1.

    Args:
        kernlen: Side length of the returned array.
        nsig: Half-width of the sampled interval in standard deviations
            (half a bin is added on each side).

    Returns:
        A ``(kernlen, kernlen)`` float array scaled so that its maximum is 1.
    """
    step = (2*nsig+1.)/(kernlen)
    # kernlen+1 bin edges -> kernlen bins of standard-normal probability mass.
    edges = np.linspace(-nsig-step/2., nsig+step/2., kernlen+1)
    mass_1d = np.diff(st.norm.cdf(edges))
    # Element-wise square root of the separable outer product.
    weights = np.sqrt(np.outer(mass_1d, mass_1d))
    # Normalise to unit sum first, then rescale so the peak equals 1.
    weights = weights/weights.sum()
    return weights/weights.max()


# 560x560 Gaussian weight map, copied into three identical channels -> shape (560, 560, 3).
_g_mask_2d = gkern(560, 3)
g_mask = np.dstack([_g_mask_2d] * 3)

def syn_data(t,r,sigma):
    """Blend a transmission image with a blurred reflection image.

    Args:
        t: Transmission (clean) image, indexed as ``[row, col, channel]`` with
            three channels. Presumably float values in [0, 1] -- the callers in
            this notebook divide uint8 images by 255.
        r: Reflection image with the same layout as ``t``.
        sigma: Standard deviation of the Gaussian blur applied to ``r``.

    Returns:
        Tuple ``(t, r_blur_mask, blend)``: the transmission image, the blurred
        and masked reflection, and their blend, each after the final
        ``1/2.2`` power. ``blend`` is clipped to [0, 1].

    Raises:
        ValueError: If the image leaves no room for a random crop of the
            module-level ``g_mask`` (a 10 pixel margin is required).
    """
    # Raise to the power 2.2 before blending; undone with 1/2.2 at the end.
    t=np.power(t,2.2)
    r=np.power(r,2.2)

    # Odd kernel size covering roughly +/- 2 sigma.
    sz=int(2*np.ceil(2*sigma)+1)
    # NOTE(review): the documented Python signature is
    # GaussianBlur(src, ksize, sigmaX[, dst[, sigmaY[, borderType]]]), so the
    # 4th positional argument lands in `dst` and the trailing 0 in `sigmaY`.
    # Left unchanged to keep the generated data identical -- TODO confirm intent.
    r_blur=cv2.GaussianBlur(r,(sz,sz),sigma,sigma,0)
    blend=r_blur+t

    # Random attenuation factor in [1.08, 1.18).
    att=1.08+np.random.random()/10.0

    # Per channel: lower the reflection by the mean overshoot of the pixels
    # whose blend exceeds 1, scaled by `att`.
    for i in range(3):
        maski=blend[:,:,i]>1
        mean_i=max(1.,np.sum(blend[:,:,i]*maski)/(maski.sum()+1e-6))
        r_blur[:,:,i]=r_blur[:,:,i]-(mean_i-1)*att
    r_blur[r_blur>=1]=1
    r_blur[r_blur<=0]=0

    h,w=r_blur.shape[0:2]
    # Take the crop limits from the mask itself instead of a hard-coded 560,
    # so a differently sized g_mask keeps working.
    mask_h,mask_w=g_mask.shape[0:2]
    if mask_w-w-10<=0 or mask_h-h-10<=0:
        raise ValueError(
            f"image of size {h}x{w} is too large for a {mask_h}x{mask_w} mask "
            "(a margin of more than 10 pixels is required)"
        )
    # Random window of the Gaussian mask, applied as spatially varying weight.
    neww=np.random.randint(0, mask_w-w-10)
    newh=np.random.randint(0, mask_h-h-10)
    alpha1=g_mask[newh:newh+h,neww:neww+w,:]
    # Random transmission weight in (0.8, 1.0].
    alpha2 = 1-np.random.random()/5.0
    r_blur_mask=np.multiply(r_blur,alpha1)
    blend=r_blur_mask+t*alpha2

    t=np.power(t,1/2.2)
    r_blur_mask=np.power(r_blur_mask,1/2.2)
    blend=np.power(blend,1/2.2)
    blend[blend>=1]=1
    blend[blend<=0]=0

    return t,r_blur_mask,blend


### Image Preprocessing. Set new_dim here.

In [6]:
import cv2
new_dim=256

img = cv2.imread("reflection_layer/4059.jpg")
if img is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError("could not read reflection_layer/4059.jpg")
# Array shape is (rows, cols, channels), i.e. (height, width, channels).
height, width, _ = img.shape
min_dim = min(height, width)
# Keep the top-left square so the resize does not distort the aspect ratio.
img = img[:min_dim, :min_dim]
# The interpolation flag must be passed by keyword: the third positional
# argument of cv2.resize is `dst`, so the flag was previously not applied.
img = cv2.resize(img, (new_dim, new_dim), interpolation=cv2.INTER_CUBIC)

r = np.float32(img)/255.0
             


In [7]:

img = cv2.imread("transmission_layer/4059.jpg")
if img is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError("could not read transmission_layer/4059.jpg")
# Array shape is (rows, cols, channels), i.e. (height, width, channels).
height, width, _ = img.shape
min_dim = min(height, width)
# Keep the top-left square so the resize does not distort the aspect ratio.
img = img[:min_dim, :min_dim]
# The interpolation flag must be passed by keyword: the third positional
# argument of cv2.resize is `dst`, so the flag was previously not applied.
img = cv2.resize(img, (new_dim, new_dim), interpolation=cv2.INTER_CUBIC)

t = np.float32(img)/255.0
             

In [8]:
# Pick one blur sigma at random from the candidates, then synthesize a sample.
sigma_index = np.random.randint(0, len(k_sz))
sigma = k_sz[sigma_index]
_t, blur_mask, blend = syn_data(t, r, sigma)
# Show the shapes of the three returned arrays.
(blend.shape, _t.shape, blur_mask.shape)

((256, 256, 3), (256, 256, 3), (256, 256, 3))

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import PIL
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Silence every Python warning for the rest of the session.
import warnings
warnings.filterwarnings("ignore")
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

plt.ion()   # interactive mode
# torch.set_default_dtype(torch.float32)

<contextlib.ExitStack at 0x7da5fc9d6620>

In [10]:
# Display which device was selected.
device

device(type='cuda')

### Cropping and Resizing Reflection Images

In [11]:
# -p makes the cell re-runnable: no error if the folder already exists.
!mkdir -p /kaggle/working/r_images

In [12]:
import os
from tqdm import tqdm

path = "/kaggle/working/reflection_layer"
save_to = "/kaggle/working/r_images"
# Create the output folder here so the cell does not depend on a separate mkdir.
os.makedirs(save_to, exist_ok=True)

unreadable = []
for file in tqdm(os.listdir(path)):
    img = cv2.imread(os.path.join(path, file))
    if img is None:
        # cv2.imread returns None (no exception) for files it cannot decode.
        unreadable.append(file)
        continue
    # Array shape is (rows, cols, channels), i.e. (height, width, channels).
    height, width = img.shape[:2]
    min_dim = min(height, width)
    # Crop the top-left square, then scale it to new_dim x new_dim.
    img = img[:min_dim, :min_dim]
    img = cv2.resize(img, (new_dim, new_dim))
    cv2.imwrite(os.path.join(save_to, file), img)

if unreadable:
    print(f"Skipped {len(unreadable)} unreadable file(s), e.g. {unreadable[:10]}")
  

100%|██████████| 13700/13700 [01:23<00:00, 164.12it/s]


### Cropping and Resizing Target/Transmission Images

In [13]:
# -p makes the cell re-runnable: no error if the folder already exists.
!mkdir -p /kaggle/working/t_images

In [14]:
import os
from tqdm import tqdm

path = "/kaggle/working/transmission_layer"
save_to = "/kaggle/working/t_images"
# Create the output folder here so the cell does not depend on a separate mkdir.
os.makedirs(save_to, exist_ok=True)

unreadable = []
for file in tqdm(os.listdir(path)):
    img = cv2.imread(os.path.join(path, file))
    if img is None:
        # cv2.imread returns None (no exception) for files it cannot decode.
        unreadable.append(file)
        continue
    # Array shape is (rows, cols, channels), i.e. (height, width, channels).
    height, width = img.shape[:2]
    min_dim = min(height, width)
    # Crop the top-left square, then scale it to new_dim x new_dim.
    img = img[:min_dim, :min_dim]
    img = cv2.resize(img, (new_dim, new_dim))
    cv2.imwrite(os.path.join(save_to, file), img)

if unreadable:
    print(f"Skipped {len(unreadable)} unreadable file(s), e.g. {unreadable[:10]}")
  

100%|██████████| 13749/13749 [01:24<00:00, 162.51it/s]


In [15]:
import os

# Numeric stems of the preprocessed transmission images, in ascending order.
lt = sorted(int(name.split(".")[0]) for name in os.listdir("/kaggle/working/t_images"))
# Largest image number present.
lt[-1]

15000

### Filling in missing images by copying a neighbouring image

In [16]:
import os
import shutil

# Specify the folder path where the images are located
folder_path = "/kaggle/working/t_images"

# Specify the file extension of the images
file_extension = ".jpg"

# Specify the range of image numbers
start_number = 1
end_number = 15000

# List the folder once; a nested os.path.isfile scan per gap is quadratic
# in the worst case.
present = set(os.listdir(folder_path))

# Walk the numbers from high to low so the nearest later image is always
# known when a gap is found.
donor = None          # nearest existing filename with a larger number
highest = None        # largest-numbered existing filename in the range
trailing_gaps = []    # missing numbers above `highest`
for i in range(end_number, start_number - 1, -1):
    filename = str(i) + file_extension
    if filename in present:
        donor = filename
        if highest is None:
            highest = filename
    elif donor is not None:
        # Copy the next existing image to fill the missing one
        shutil.copy2(os.path.join(folder_path, donor),
                     os.path.join(folder_path, filename))
    else:
        trailing_gaps.append(filename)

# Gaps above the last existing image have no later donor and would stay
# missing without notice; fall back to the last existing image.
if highest is None:
    print(f"No {file_extension} images found in {folder_path}; nothing was filled")
else:
    for filename in trailing_gaps:
        shutil.copy2(os.path.join(folder_path, highest),
                     os.path.join(folder_path, filename))


In [17]:
import os
import shutil

# Specify the folder path where the images are located
folder_path = "/kaggle/working/r_images"

# Specify the file extension of the images
file_extension = ".jpg"

# Specify the range of image numbers
start_number = 1
end_number = 15000

# List the folder once; a nested os.path.isfile scan per gap is quadratic
# in the worst case.
present = set(os.listdir(folder_path))

# Walk the numbers from high to low so the nearest later image is always
# known when a gap is found.
donor = None          # nearest existing filename with a larger number
highest = None        # largest-numbered existing filename in the range
trailing_gaps = []    # missing numbers above `highest`
for i in range(end_number, start_number - 1, -1):
    filename = str(i) + file_extension
    if filename in present:
        donor = filename
        if highest is None:
            highest = filename
    elif donor is not None:
        # Copy the next existing image to fill the missing one
        shutil.copy2(os.path.join(folder_path, donor),
                     os.path.join(folder_path, filename))
    else:
        trailing_gaps.append(filename)

# Gaps above the last existing image have no later donor and would stay
# missing without notice; fall back to the last existing image.
if highest is None:
    print(f"No {file_extension} images found in {folder_path}; nothing was filled")
else:
    for filename in trailing_gaps:
        shutil.copy2(os.path.join(folder_path, highest),
                     os.path.join(folder_path, filename))


In [18]:
# !cp /kaggle/working/transmission_layer2/35.jpg /kaggle/working/transmission_layer2/34.jpg
# !cp /kaggle/working/blended2/35.jpg /kaggle/working/blended2/34.jpg

### Statically creating the dataset i.e. target_image + reflection = input_image

This step takes time. Consider saving the result as a Kaggle dataset.

In [19]:
# -p makes the cell re-runnable: no error if the folder already exists.
!mkdir -p /kaggle/working/i_images

In [20]:
# For every transmission image, blend it with the same-named reflection image
# and store the result under the same name in i_images.
os.makedirs("/kaggle/working/i_images", exist_ok=True)
for t_image in tqdm(os.listdir("/kaggle/working/t_images")):
    t_bgr = cv2.imread("/kaggle/working/t_images/"+t_image)
    r_bgr = cv2.imread("/kaggle/working/r_images/"+t_image)
    if t_bgr is None or r_bgr is None:
        # cv2.imread returns None instead of raising when a file cannot be read.
        raise FileNotFoundError(f"missing or unreadable image pair for {t_image}")
    # Local names keep the demo variables `t` and `r` from being overwritten.
    t_float = t_bgr.astype(np.float32)/255.0
    r_float = r_bgr.astype(np.float32)/255.0

    sigma=k_sz[np.random.randint(0, len(k_sz))]
    input_image = syn_data(t_float, r_float, sigma)[2]

    # Round to the nearest level; a bare astype(np.uint8) truncates and
    # darkens the image by up to one level.
    input_u8 = np.rint(input_image*255).astype(np.uint8)
    cv2.imwrite("/kaggle/working/i_images/"+t_image, input_u8)
  



    


100%|██████████| 15000/15000 [05:44<00:00, 43.50it/s]


## UNET

In [21]:
import torch
import torch.nn as nn

# U-Net architecture
class UNet(nn.Module):
    """U-Net with four down-sampling stages, a bottleneck and four up-sampling stages.

    Each encoder stage is a double 3x3 convolution followed by 2x2 max-pooling.
    Each decoder stage up-samples with a stride-2 transposed convolution,
    concatenates the matching encoder output along the channel axis and applies
    another double convolution. A final 1x1 convolution maps to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder (contracting path)
        self.conv1 = self.double_conv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = self.double_conv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = self.double_conv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = self.double_conv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = self.double_conv(512, 1024)

        # Decoder (expanding path); each conv sees up-sampled + skip channels
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv6 = self.double_conv(1024, 512)
        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv7 = self.double_conv(512, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv8 = self.double_conv(256, 128)
        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv9 = self.double_conv(128, 64)

        # 1x1 projection to the requested number of output channels
        self.outconv = nn.Conv2d(64, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` and return the 1x1-projected decoder output."""
        encoder_stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
        )
        decoder_stages = (
            (self.upconv1, self.conv6),
            (self.upconv2, self.conv7),
            (self.upconv3, self.conv8),
            (self.upconv4, self.conv9),
        )

        # Encoder: remember every pre-pooling feature map for the skip links.
        skips = []
        features = x
        for conv, pool in encoder_stages:
            features = conv(features)
            skips.append(features)
            features = pool(features)

        features = self.conv5(features)

        # Decoder: consume the skip links deepest-first.
        for upconv, conv in decoder_stages:
            upsampled = upconv(features)
            features = conv(torch.cat([upsampled, skips.pop()], dim=1))

        return self.outconv(features)


### Dataset for Static Input Images.

Has a cache mechanism (disabled by default) which may or may not improve speed. The cache does not work with datasets of more than 5000 images.

In [22]:

import random

class ReflectDataset(Dataset):
    """Dataset of pre-rendered (input, target) image pairs stored as numbered JPEGs.

    Sample ``idx`` is read from ``<input_dir>/<offset+idx+1>.jpg`` and
    ``<target_dir>/<offset+idx+1>.jpg``. Both images are divided by 255,
    converted to float32 tensors, moved to the notebook-level ``device`` and
    permuted to channel-first layout. The channel order is whatever
    ``cv2.imread`` returns.

    Args:
        input_dir: Folder with the blended input images.
        reflect_dir: Folder with the reflection images (stored, not read here).
        target_dir: Folder with the clean target images.
        dataset_len: Number of samples exposed by this dataset.
        offset: Number added to ``idx + 1`` to obtain the file number.
        do_cache: When True, keep every loaded pair in memory for reuse.
    """

    def __init__(self, input_dir, reflect_dir, target_dir, dataset_len, offset, do_cache=False):
        self.target_dir = target_dir
        self.reflect_dir = reflect_dir
        self.input_dir = input_dir
        self.dataset_len = dataset_len
        self.offset = offset

        self.do_cache = do_cache
        if self.do_cache:
            self.cache = [None]*dataset_len

    def __len__(self):
        return self.dataset_len

    def _load_tensor(self, path):
        """Read one image file as a float32 channel-first tensor on ``device``."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread returns None instead of raising when the read fails.
            raise FileNotFoundError(f"could not read image: {path}")
        image = image.astype(np.float32)/255.0
        return torch.from_numpy(image).to(device).permute(2, 0, 1)

    def __getitem__(self, idx):
        """Return the ``(input_image, target_image)`` tensor pair for ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, dataset_len)``.
            FileNotFoundError: If either image file cannot be read.
        """
        if not 0 <= idx < self.dataset_len:
            raise IndexError(f"index {idx} out of range for dataset of length {self.dataset_len}")

        # Explicit None test: an entry is either None or a loaded pair.
        if self.do_cache and self.cache[idx] is not None:
            return self.cache[idx]

        file_name = str(self.offset+idx+1)+".jpg"
        target_image = self._load_tensor(self.target_dir+"/"+file_name)
        input_image = self._load_tensor(self.input_dir+"/"+file_name)

        if self.do_cache:
            self.cache[idx]=(input_image,target_image)

        return input_image,target_image

### Class that generates on the fly. Quite slow.

In [23]:

import random

class ReflectDataset_OTF(Dataset):
    """Dataset that synthesizes the reflection-contaminated input on every access.

    For sample ``idx`` the target is read from ``<target_dir>/<offset+idx+1>.jpg``
    and a reflection image is drawn at random from
    ``<reflect_dir>/1.jpg`` .. ``<reflect_dir>/<num_reflections>.jpg``. The two
    are blended with the notebook-level ``syn_data`` using a random sigma from
    ``k_sz``. Both returned tensors are float32, channel-first and on the
    notebook-level ``device``.

    Args:
        input_dir: Stored but not read by this class.
        reflect_dir: Folder with the numbered reflection images.
        target_dir: Folder with the numbered clean target images.
        dataset_len: Number of samples exposed by this dataset.
        offset: Number added to ``idx + 1`` to obtain the target file number.
        num_reflections: Highest reflection image number to draw from.
    """

    def __init__(self, input_dir, reflect_dir, target_dir, dataset_len, offset, num_reflections=15000):
        self.target_dir = target_dir
        self.reflect_dir = reflect_dir
        self.input_dir = input_dir
        self.dataset_len = dataset_len
        self.offset = offset
        self.num_reflections = num_reflections

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        """Return a freshly synthesized ``(input_image, target_image)`` pair.

        Raises:
            IndexError: If ``idx`` is outside ``[0, dataset_len)``.
            FileNotFoundError: If the reflection or target image cannot be read.
        """
        if not 0 <= idx < self.dataset_len:
            raise IndexError(f"index {idx} out of range for dataset of length {self.dataset_len}")

        reflection_path = self.reflect_dir+"/"+str( random.randint(1, self.num_reflections) )+".jpg"
        reflection = cv2.imread(reflection_path)
        if reflection is None:
            # cv2.imread returns None instead of raising when the read fails.
            raise FileNotFoundError(f"could not read image: {reflection_path}")
        reflection = reflection.astype(np.float32)/255.0

        target_path = self.target_dir+"/"+str(self.offset+idx+1)+".jpg"
        target_image = cv2.imread(target_path)
        if target_image is None:
            raise FileNotFoundError(f"could not read image: {target_path}")
        target_image = target_image.astype(np.float32)/255.0

        sigma=k_sz[np.random.randint(0, len(k_sz))]
        # Only the blended image is needed as network input.
        input_image = syn_data(target_image, reflection, sigma)[2]
        input_image = torch.from_numpy(input_image.astype(np.float32)).to(device)
        input_image = input_image.permute(2, 0, 1)

        target_image = torch.from_numpy(target_image).to(device)
        target_image = target_image.permute(2,0,1)

        return input_image,target_image

In [24]:
# Count the preprocessed transmission images on disk.
t_image_names = os.listdir("/kaggle/working/t_images")
n = len(t_image_names)
print(n)

15000


## Reduce dataset here.

In [25]:
# Cap the number of image pairs used below (replaces the full count computed above).
n=1000

## Change from Static to Dynamic Here

In [26]:

# split = 0.25
# offset = int(split*n)
# train_ds = ReflectDataset_OTF("/kaggle/working/i_images","/kaggle/working/r_images","/kaggle/working/t_images",offset,0)
# test_ds = ReflectDataset_OTF("/kaggle/working/i_images","/kaggle/working/r_images","/kaggle/working/t_images",n-offset,offset)

In [27]:

# NOTE(review): with split = 0.25 the FIRST 25% of the n images become the
# training set and the remaining 75% the test set -- confirm this is intended.
split = 0.25
offset = int(split*n)
image_dirs = ("/kaggle/working/i_images", "/kaggle/working/r_images", "/kaggle/working/t_images")
# Train on files 1..offset, test on files offset+1..n.
train_ds = ReflectDataset(*image_dirs, dataset_len=offset, offset=0)
test_ds = ReflectDataset(*image_dirs, dataset_len=n-offset, offset=offset)

In [28]:
batch_size = 8
# Shuffle only the training data; evaluate in a fixed order with double-size batches.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size*2, shuffle=False)


In [29]:
# Sanity check: print the tensor shapes of the first training batch, if any.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    inputs_shape, targets_shape = first_batch[0].shape, first_batch[1].shape
    print(inputs_shape)
    print(targets_shape)

torch.Size([8, 3, 256, 256])
torch.Size([8, 3, 256, 256])


In [30]:
torch.cuda.empty_cache()

# Example usage
in_channels = 3  # Number of input channels
out_channels = 3 # Number of output channels
model = UNet(in_channels, out_channels).to(device)

# Shape check only: run under no_grad so no autograd graph is built and kept
# alive on the GPU through the global `predictions`.
with torch.no_grad():
    for i, batch in enumerate(train_loader):
        input_images = batch[0].to(device)
        target_images = batch[1].to(device)

        predictions = model(input_images)

        print(predictions.shape)
        break


torch.Size([8, 3, 256, 256])


In [31]:
# Number of batches each loader yields per epoch. len(ds)/batch_size is a
# float that undercounts when the last batch is only partially filled.
train_num_batches = len(train_loader)
test_num_batches = len(test_loader)

In [32]:

# Training hyper-parameters
num_epochs = 50
learning_rate = 0.0003
in_channels = 3  # Number of input channels
out_channels = 3 # Number of output channels

torch.cuda.empty_cache()
# model = Generator(in_channels,out_features).to(device)
# model.load_state_dict(torch.load('/kaggle/working/best_weights.pth', map_location = device))

# model.load_state_dict(torch.load('best_weights.pth'))

# Fresh network, pixel-wise MSE objective and Adam optimizer.
model = UNet(in_channels, out_channels).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, steps_per_epoch=len(train_ds), epochs=num_epochs)


In [None]:

# Train for num_epochs, evaluating on the test loader after every epoch and
# checkpointing whenever the test loss improves.
min_test_loss = float("inf")
for epoch in range(num_epochs):
    # Plain Python floats: summing the loss tensors themselves would keep every
    # batch's autograd graph alive until the end of the epoch.
    train_loss_sum = 0.0
    test_loss_sum = 0.0

    model.train()
    for i, batch in enumerate(train_loader):

        images = batch[0].to(device)
        targets = batch[1].to(device)

        # Forward pass
        predictions = model(images)
        train_loss = criterion(predictions, targets)

        # Backward and optimize
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch is not over-counted.
        train_loss_sum += train_loss.item() * images.size(0)

    model.eval()

    with torch.no_grad():

        for i, test_batch in enumerate(test_loader):
            images = test_batch[0].to(device)
            targets = test_batch[1].to(device)
            test_predictions = model(images)
            test_loss = criterion(test_predictions, targets)
            test_loss_sum += test_loss.item() * images.size(0)

    # Per-sample mean losses for the epoch.
    epoch_test_loss = test_loss_sum / max(len(test_loader.dataset), 1)
    epoch_train_loss = train_loss_sum / max(len(train_loader.dataset), 1)

    if min_test_loss > epoch_test_loss:
        min_test_loss = epoch_test_loss
        print("saving on epoch",epoch)
        torch.save(model.state_dict(), './model1.pth')

    loss_ratio = epoch_test_loss/epoch_train_loss if epoch_train_loss > 0 else float("inf")
    print (f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_train_loss:.4f} Test Loss: {epoch_test_loss:.4f}, Ratio: {loss_ratio:.4f}')


print('Finished Training')
# Weights after the final epoch (the best-on-test weights are in model1.pth).
PATH = './model2.pth'
torch.save(model.state_dict(), PATH)


saving on epoch 0
Epoch [1/50], Train Loss: 0.0896 Test Loss: 0.0337, Ratio: 0.3766
saving on epoch 1
Epoch [2/50], Train Loss: 0.0317 Test Loss: 0.0268, Ratio: 0.8467
saving on epoch 2
Epoch [3/50], Train Loss: 0.0264 Test Loss: 0.0211, Ratio: 0.8009
Epoch [4/50], Train Loss: 0.0256 Test Loss: 0.0235, Ratio: 0.9189
saving on epoch 4
Epoch [5/50], Train Loss: 0.0234 Test Loss: 0.0209, Ratio: 0.8925
saving on epoch 5
Epoch [6/50], Train Loss: 0.0231 Test Loss: 0.0203, Ratio: 0.8777
saving on epoch 6
Epoch [7/50], Train Loss: 0.0219 Test Loss: 0.0194, Ratio: 0.8848
Epoch [8/50], Train Loss: 0.0215 Test Loss: 0.0200, Ratio: 0.9293
Epoch [9/50], Train Loss: 0.0212 Test Loss: 0.0196, Ratio: 0.9252
saving on epoch 9
Epoch [10/50], Train Loss: 0.0209 Test Loss: 0.0186, Ratio: 0.8890
Epoch [11/50], Train Loss: 0.0215 Test Loss: 0.0186, Ratio: 0.8674
Epoch [12/50], Train Loss: 0.0211 Test Loss: 0.0195, Ratio: 0.9225
Epoch [13/50], Train Loss: 0.0207 Test Loss: 0.0198, Ratio: 0.9570
Epoch [14/50

In [None]:

# Restore the checkpoint saved during training and switch to inference mode.
best_state = torch.load('/kaggle/working/model1.pth', map_location = device)
model.load_state_dict(best_state)
model.eval()

In [None]:
import matplotlib.pyplot as plt

# Imported as TF: the alias F is already bound to torch.nn.functional earlier
# in this notebook, and rebinding it here would shadow that module.
import torchvision.transforms.functional as TF


def imshow(imgs):
    """Show one image tensor, or a list of them, side by side without axis ticks.

    Args:
        imgs: A single tensor or a list of tensors accepted by
            ``torchvision.transforms.functional.to_pil_image``.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        if img.is_floating_point():
            # to_pil_image converts floats to 8-bit; values outside [0, 1]
            # (e.g. raw network outputs) would wrap around instead of saturating.
            img = img.clamp(0, 1)
        img = TF.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        
# imshow(torchvision.utils.make_grid(list( train_ds[100] )))

In [None]:

# Show input | target | prediction for one sample of the first test batch.
with torch.no_grad():

    for i, test_batch in enumerate(test_loader):
        images = test_batch[0].to(device)
        targets = test_batch[1].to(device)
        test_predictions = model(images)

        # Second sample of the batch, or the only one if the batch has size 1.
        idx = min(1, len(images) - 1)
        grid = torchvision.utils.make_grid([images[idx],targets[idx],test_predictions[idx]])
        # The datasets load images with cv2.imread, which yields BGR channel
        # order; flip to RGB for display and clamp the raw network output to
        # the displayable [0, 1] range.
        grid = grid.flip(0).clamp(0, 1)
        imshow(grid)
        break