In [None]:
!pip install datasets
from datasets import load_dataset



In [None]:
import numpy as np
import cv2
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import h5py
import matplotlib.pyplot as plt
import numpy as np

In [None]:
import yaml

In [None]:
# Base folder for project assets; later cells append "NYU/", 'config.yaml' and
# "MSE_Model.pth" to it, so the trailing slash is required.
# NOTE(review): looks like a mounted Google Drive path (Colab) — confirm before running elsewhere.
parent_path = "drive/MyDrive/M202A/"

Hyperparameter Setup and other banal things

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


Loading the Training Dataset from Disk

In [None]:
from datasets import load_from_disk
# Load the dataset stored under parent_path + "NYU/". Later cells read its
# 'train' and 'validation' splits and the 'image' / 'depth_map' columns.
ds = load_from_disk(parent_path+"NYU/")


Loading Testing Dataset

In [None]:
path_to_dataset = "drive//MyDrive//M202A//nyu_depth_v2_labeled.mat"

def extract_files(config):
  """
  Read a random third of the labelled .mat file into memory.

  Args:
    config: dict holding the file location under the key "path". The dict is
      updated in place with three arrays:
        "color"    -> (K, 480, 640, 3) images,
        "depth"    -> (K, 480, 640) values read from the file's "depths" set,
        "rawDepth" -> (K, 480, 640) values read from the file's "rawDepths" set,
      where K = number_of_images // 3.

  Returns:
    None; results are delivered through ``config``.
  """
  # Open the file named in the config (previously the global path was used and
  # config["path"] was ignored), read-only, and close it when done.
  with h5py.File(config["path"], "r") as f:
    length = len(f["images"])
    batch_size = length // 3
    # NOTE(review): randint samples WITH replacement, so the subset may contain
    # duplicate frames — confirm this is intended for a test set.
    indices = np.random.randint(0, length, size=batch_size)
    color = np.zeros((batch_size, 480, 640, 3))
    rawdepth = np.zeros((batch_size, 480, 640))
    truedepth = np.zeros_like(rawdepth)
    for i, j in enumerate(tqdm(indices)):
      # Stored image layout is [3 x 640 x 480], uint8; reorder to H x W x C.
      img = f['images'][j].astype(np.uint8)
      color[i, ...] = img.transpose(2, 1, 0)
      # Depth maps are stored transposed (640 x 480) as well.
      rawdepth[i, ...] = f['rawDepths'][j].T.astype(np.float32)
      truedepth[i, ...] = f["depths"][j].astype(np.float32).T
  config["color"] = color
  config["depth"] = truedepth
  config["rawDepth"] = rawdepth

config={"path":path_to_dataset}
extract_files(config)

100%|██████████| 483/483 [00:31<00:00, 15.30it/s]


In [None]:
# Nine per-channel constants, three for each of R, G and B.
# NOTE(review): the meaning of the *_u / *_w / *var suffixes is not shown in this
# notebook (presumably offset, weight and variance of a noise fit) — confirm.
# The dict is passed to add_noise_inloop as `cfg`, which does not read it.
rgb_config = dict(
    R_u=-0.00776206,
    R_w=0.01519309,
    Rvar=6.88208756e-06,
    G_u=0.01147944,
    G_w=0.50978071,
    Gvar=9.19419688e-06,
    B_u=0.00371937,
    B_w=0.47502621,
    Bvar=6.89479643e-06,
)

# Depth Noise based on the Channels, readings taken @0.5m

In [None]:
print(len(ds['train']))
# Make indexing return torch tensors that already live on `device`.
ds.set_format("torch", device=device)
# shuffle() does not modify the DatasetDict in place; it returns a shuffled
# copy. Rebind `ds` so the shuffle is actually kept, then display it.
ds = ds.shuffle()
ds

47584


DatasetDict({
    train: Dataset({
        features: ['image', 'depth_map'],
        num_rows: 47584
    })
    validation: Dataset({
        features: ['image', 'depth_map'],
        num_rows: 654
    })
})

Loading Pretrained Unet for Depth

In [None]:
# Read the model settings from config.yaml, then override the two entries
# that Resnet_UNet reads: output channel count and level-0 feature width.
with open(parent_path+'config.yaml', 'r') as config_file:
    model_configs = yaml.safe_load(config_file)
model_configs['model'].update({'channels': 1, 'base filters': 16})

In [None]:
from model import *

Creating New Network by mixing the two

In [None]:
class Resnet_UNet(nn.Module):
  """
  Four-level encoder/decoder network (U-Net layout) built from the blocks
  imported from ``model``.

  Keyword Args:
      channels: channel count handed to the output block.
      base filters: feature width at level 0; levels 1, 2 and 3 use
          2x, 4x and 8x this width.

  The input block is constructed for 4 input channels, and a Dropout(p=0.5)
  layer is applied directly to the network input.
  """
  def __init__(self,**kwargs):
      super().__init__()
      out_channels = kwargs['channels']
      w0 = kwargs['base filters']
      w1, w2, w3 = 2 * w0, 4 * w0, 8 * w0

      def dense(width):
          # Every DenoisingBlock in this network uses (width, width // 2, width).
          return DenoisingBlock(width, width // 2, width)

      # ---- Encoder ----
      # Level 0
      self.drop = nn.Dropout(p=0.5)
      self.input_block = InputBlock(4, w0)
      self.block_0_0 = dense(w0)
      self.block_0_1 = dense(w0)
      self.down_0 = DownsampleBlock(w0, w1)

      # Level 1
      self.block_1_0 = dense(w1)
      self.block_1_1 = dense(w1)
      self.down_1 = DownsampleBlock(w1, w2)

      # Level 2
      self.block_2_0 = dense(w2)
      self.block_2_1 = dense(w2)
      self.down_2 = DownsampleBlock(w2, w3)

      # Level 3 (bottleneck)
      self.block_3_0 = dense(w3)
      self.block_3_1 = dense(w3)

      # ---- Decoder ----
      # Level 2
      self.up_2 = UpsampleBlock(w3, w2, w2)
      self.block_2_2 = dense(w2)
      self.block_2_3 = dense(w2)

      # Level 1
      self.up_1 = UpsampleBlock(w2, w1, w1)
      self.block_1_2 = dense(w1)
      self.block_1_3 = dense(w1)

      # Level 0
      self.up_0 = UpsampleBlock(w1, w0, w0)
      self.block_0_2 = dense(w0)
      self.block_0_3 = dense(w0)

      self.output_block = OutputBlock(w0, out_channels)


  def forward(self, inputs):
      """Run the encoder, bottleneck and decoder; return the output block's result."""
      x = self.drop(inputs)

      # Encoder: each level keeps its result for the matching decoder level.
      enc_0 = self.block_0_1(self.block_0_0(self.input_block(x)))
      enc_1 = self.block_1_1(self.block_1_0(self.down_0(enc_0)))
      enc_2 = self.block_2_1(self.block_2_0(self.down_1(enc_1)))

      # Bottleneck
      bottom = self.block_3_1(self.block_3_0(self.down_2(enc_2)))

      # Decoder: each upsample block receives [deeper features, encoder features].
      dec_2 = self.block_2_3(self.block_2_2(self.up_2([bottom, enc_2])))
      dec_1 = self.block_1_3(self.block_1_2(self.up_1([dec_2, enc_1])))
      dec_0 = self.block_0_3(self.block_0_2(self.up_0([dec_1, enc_0])))

      return self.output_block(dec_0)


Loading Optimizers, Schedulers and Training the network

In [None]:
# Split handles used by the training cell.
train_data, val_data = ds['train'], ds['validation']

In [None]:
from tqdm import tqdm
from datetime import datetime
# Run timestamp formatted as YYYYMMDD_HHMMSS; no other cell visible in this
# notebook reads it.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

```
Sanity Testing the Segmentation Loss before plugging it into the loss model
```

In [None]:
def add_noise_inloop(depth,cfg=None):
  """
  Add standard-normal noise to channel 0 of a channels-second batch.

  Args:
    depth: tensor whose axis 1 indexes channels. The callers in this notebook
      pass a (N, 4, H, W) stack with depth in channel 0 and colour in 1..3.
      The input tensor is not modified.
    cfg: accepted so existing calls keep working; it is not read. The default
      is None rather than a global dict so that defining this function does
      not depend on another cell having run first.

  Returns:
    (noisy, noise): ``noisy`` has the shape of ``depth`` with noise added to
    channel 0 only; ``noise`` is the added noise with shape (N, 1, ...).
  """
  # Allocate on the input's device instead of a global one, and draw random
  # numbers only for the channel that is actually perturbed.
  noise = torch.zeros(depth.shape, device=depth.device)
  noise[:, :1] = torch.randn(depth[:, :1].shape, device=depth.device)
  return depth + noise, noise[:, :1]

In [None]:
def test_loop():
  """
  Evaluate ResUnet on the arrays stored in ``config`` and print MAE / RMSE.

  For every test frame the raw depth and colour image are stacked, noise is
  added to the depth channel, the network's noise estimate is subtracted, and
  the result is compared against the reference depth on pixels where the raw
  depth is > 0. Errors are averaged over all pixels of a frame (masked-out
  pixels contribute zero) and then over frames; values are printed multiplied
  by 1000 with an "mm" suffix.

  Reads the globals ``config``, ``ResUnet``, ``device`` and ``rgb_config``.
  """
  color = config["color"]
  truedepth = config["depth"]
  rawdepth = config["rawDepth"]
  error = []

  # Evaluate with dropout disabled, then put the model back in whatever mode
  # the caller had it in (this function is called from the training loop).
  was_training = ResUnet.training
  ResUnet.eval()
  try:
    with torch.no_grad():
      for i in range(len(color)):
        depth_img = torch.tensor(rawdepth[i],dtype=torch.float32).to(device).reshape(1,1,480,640)
        color_img = torch.tensor(color[i],dtype=torch.float32).to(device).permute(2,0,1).unsqueeze(0)/255
        stacked = torch.hstack((depth_img,color_img))
        Z,n = add_noise_inloop(stacked,rgb_config)
        noise_estimate = ResUnet(Z)
        # Use this frame's estimate. The previous code subtracted the stale
        # global `n_estimate` left behind by the training cell.
        denoised_depth = Z[:,0,:,:].unsqueeze(1) - noise_estimate
        mask = depth_img > 0
        gt_tensor = torch.tensor(truedepth[i],dtype=torch.float).to(device)
        residual = denoised_depth*mask - gt_tensor*mask
        error.append(residual.to("cpu").detach().numpy())
  finally:
    ResUnet.train(was_training)

  mae = [np.mean(abs(e)) for e in error]
  rmse = [np.sqrt(np.mean(e**2)) for e in error]
  # The first figure is a mean ABSOLUTE error; the label used to say "squared".
  print("Average mean absolute error is ",np.mean(mae)*1000,"mm")
  print("Average root mean squared error is ",np.mean(rmse)*1000,"mm")

In [None]:
mse_loss = nn.MSELoss()
n_epochs = 5
lr = 1e-3  # the value the optimizer was actually built with (the old `lr = 1e-4` was never used)
best_vloss = 1_000_000.
ResUnet = Resnet_UNet(**model_configs['model'])
ResUnet.to(device)
bs = 12
freq = bs * 200  # in samples: validate / checkpoint / decay the LR every 200 batches
use_amp = device.type == "cuda"  # float16 autocast is only enabled on CUDA
optimizer= torch.optim.AdamW(filter(lambda p:p.requires_grad,ResUnet.parameters()),lr=lr,weight_decay = 1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
# float16 gradients can underflow; scale the loss as the AMP documentation recommends.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in (range(n_epochs)):
  # shuffle() returns a new dataset, so keep the result; a per-epoch seed gives
  # a different but reproducible order each epoch.
  epoch_train = train_data.shuffle(seed=epoch)
  ResUnet.train()

  running_loss = 0
  batches_seen = 0

  ## Training Loop
  for i in tqdm(range(0,len(epoch_train),bs)):
    batch = epoch_train[i:i+bs]  # read the slice once instead of once per column
    color = batch['image'].permute(0,3,1,2)/255 ## Channels first and normalizing to 0 and 1
    depth = batch['depth_map'].unsqueeze(0).permute(1,0,2,3)
    stacked = torch.hstack((depth,color))
    Z,n = add_noise_inloop(stacked,rgb_config)

    optimizer.zero_grad()
    # Only the forward pass and loss run under autocast; backward and the
    # optimizer step happen outside it.
    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
      n_estimate = ResUnet(Z)
      denoised_depth_estimate = (Z[:,0,:,:].unsqueeze(1) - n_estimate)
      loss = mse_loss(depth,denoised_depth_estimate)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    running_loss+=loss.item()
    batches_seen+=1

    if(i%freq==0 and i!=0):
      # Average over the batches accumulated since the last report
      # (previously divided by `freq`, which is a sample count).
      avg_loss = running_loss/batches_seen
      running_loss = 0
      batches_seen = 0
      running_vloss = 0

      ## Validation Loop: one image at a time, dropout off, no gradients
      ResUnet.eval()
      with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
        for count in range(len(val_data)):
          # Read the validation split (previously train_data[i] was evaluated
          # len(val_data) times).
          sample = val_data[count]
          color = sample['image'].permute(2,0,1).unsqueeze(0)/255 ## Channels first and normalizing to 0 and 1
          depth = sample['depth_map'].reshape(1,1,480,640)
          stacked = torch.hstack((depth,color))
          Z,n = add_noise_inloop(stacked,rgb_config)
          n_estimate = ResUnet(Z)
          denoised_depth_estimate = (Z[:,0,:,:].unsqueeze(1) - n_estimate)
          loss = mse_loss(depth,denoised_depth_estimate)
          running_vloss+=loss.item()
      avg_vloss = running_vloss/len(val_data)
      print("Loss train {} valid {}".format(avg_loss,avg_vloss))

      # Checkpoint on the best VALIDATION loss (previously compared the training loss).
      if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        torch.save(ResUnet.state_dict(), parent_path+"MSE_Model.pth")

      test_loop()
      ResUnet.train()  # back to training mode whatever test_loop left behind
      scheduler.step()


  5%|▌         | 200/3966 [05:47<1:48:55,  1.74s/it]

Loss train 1.861854945673258e-05 valid 0.0002709411569537783


  5%|▌         | 201/3966 [06:50<20:59:26, 20.07s/it]

Average mean squared error is  12.667677365243435 mm
Average root mean squared error is  27.72698737680912 mm


 10%|█         | 400/3966 [12:32<1:40:56,  1.70s/it]

Loss train 1.4152530187251008e-05 valid 0.0001329170278174196


 10%|█         | 401/3966 [13:34<19:31:37, 19.72s/it]

Average mean squared error is  10.001235641539097 mm
Average root mean squared error is  24.729710072278976 mm


 15%|█▌        | 600/3966 [22:32<1:35:19,  1.70s/it]

Loss train 1.7262352788141773e-05 valid 0.00030914716428253995


 15%|█▌        | 601/3966 [23:34<18:36:32, 19.91s/it]

Average mean squared error is  13.445792719721794 mm
Average root mean squared error is  28.47258374094963 mm


 20%|██        | 800/3966 [33:10<1:39:03,  1.88s/it]

Loss train 1.843243475984006e-05 valid 0.00016927394248041059


 20%|██        | 801/3966 [34:12<17:37:55, 20.06s/it]

Average mean squared error is  10.744567960500717 mm
Average root mean squared error is  25.568634271621704 mm


 25%|██▌       | 1000/3966 [41:59<1:34:20,  1.91s/it]

Loss train 1.761124383241016e-05 valid 0.0001494062683792124


 25%|██▌       | 1001/3966 [43:02<16:38:12, 20.20s/it]

Average mean squared error is  10.387249290943146 mm
Average root mean squared error is  25.111721828579903 mm


 30%|███       | 1200/3966 [52:08<1:28:02,  1.91s/it]

Loss train 1.79364878567867e-05 valid 0.00022211395312561326


 30%|███       | 1201/3966 [53:10<15:25:19, 20.08s/it]

Average mean squared error is  11.784960515797138 mm
Average root mean squared error is  26.61459520459175 mm


 35%|███▌      | 1400/3966 [1:02:24<1:27:13,  2.04s/it]

Loss train 2.4371994762380685e-05 valid 0.00024411513545142418


 35%|███▌      | 1401/3966 [1:03:26<14:17:54, 20.07s/it]

Average mean squared error is  12.364715337753296 mm
Average root mean squared error is  27.147715911269188 mm


 40%|████      | 1600/3966 [1:12:39<1:33:57,  2.38s/it]

Loss train 2.5952625401259864e-05 valid 0.00034865953468942633


 40%|████      | 1601/3966 [1:13:45<14:04:25, 21.42s/it]

Average mean squared error is  13.870383612811565 mm
Average root mean squared error is  29.01744842529297 mm


 45%|████▌     | 1800/3966 [1:21:18<1:38:32,  2.73s/it]

Loss train 2.371732815542297e-05 valid 0.00028508394637603645


 45%|████▌     | 1801/3966 [1:22:22<12:32:55, 20.87s/it]

Average mean squared error is  13.05477973073721 mm
Average root mean squared error is  27.803409844636917 mm


 50%|█████     | 2000/3966 [1:30:45<1:45:04,  3.21s/it]

Loss train 3.216072207578691e-05 valid 0.0004099826713392386


 50%|█████     | 2001/3966 [1:31:51<12:01:51, 22.04s/it]

Average mean squared error is  15.197339467704296 mm
Average root mean squared error is  30.075281858444214 mm


 55%|█████▌    | 2200/3966 [1:41:18<4:23:21,  8.95s/it]

Loss train 2.4907964798330797e-05 valid 0.00023272005063342785


 55%|█████▌    | 2201/3966 [1:42:22<12:28:05, 25.43s/it]

Average mean squared error is  12.339246459305286 mm
Average root mean squared error is  26.896359398961067 mm


 61%|██████    | 2400/3966 [1:51:58<49:31,  1.90s/it]

Loss train 2.4607820517606645e-05 valid 0.0003255288153925956


 61%|██████    | 2401/3966 [1:53:09<9:57:01, 22.89s/it]

Average mean squared error is  14.000093564391136 mm
Average root mean squared error is  28.681961819529533 mm


 66%|██████▌   | 2600/3966 [2:01:05<40:01,  1.76s/it]

Loss train 2.4585669052612502e-05 valid 0.00030168601027291144


 66%|██████▌   | 2601/3966 [2:02:07<7:31:55, 19.86s/it]

Average mean squared error is  13.547930866479874 mm
Average root mean squared error is  28.170626610517502 mm


 71%|███████   | 2800/3966 [2:09:42<34:16,  1.76s/it]

Loss train 2.559706793059983e-05 valid 0.00023082226819382974


 71%|███████   | 2801/3966 [2:10:44<6:29:05, 20.04s/it]

Average mean squared error is  12.213987298309803 mm
Average root mean squared error is  26.77823230624199 mm


 76%|███████▌  | 3000/3966 [2:19:27<27:32,  1.71s/it]

Loss train 2.2303322172471477e-05 valid 0.00042778663932657675


 76%|███████▌  | 3001/3966 [2:20:30<5:22:15, 20.04s/it]

Average mean squared error is  15.667036175727844 mm
Average root mean squared error is  30.6987427175045 mm


 81%|████████  | 3200/3966 [2:29:53<23:04,  1.81s/it]

Loss train 3.081572437319361e-05 valid 0.0003000813773003712


 81%|████████  | 3201/3966 [2:30:54<4:10:49, 19.67s/it]

Average mean squared error is  13.646910898387432 mm
Average root mean squared error is  28.198838233947754 mm


 86%|████████▌ | 3400/3966 [2:39:33<18:00,  1.91s/it]

Loss train 2.959886066795055e-05 valid 0.00024015123163604829


 86%|████████▌ | 3401/3966 [2:40:35<3:09:16, 20.10s/it]

Average mean squared error is  12.381047941744328 mm
Average root mean squared error is  26.90998837351799 mm


 91%|█████████ | 3600/3966 [2:48:11<12:42,  2.08s/it]

Loss train 2.8089054212614427e-05 valid 0.00031378712491395765


 91%|█████████ | 3601/3966 [2:49:11<1:58:35, 19.49s/it]

Average mean squared error is  13.708987273275852 mm
Average root mean squared error is  28.245197609066963 mm


 96%|█████████▌| 3800/3966 [2:57:57<05:10,  1.87s/it]

Loss train 2.8634942282224074e-05 valid 0.00031415471483687


 96%|█████████▌| 3801/3966 [2:58:59<54:47, 19.92s/it]

Average mean squared error is  13.696121983230114 mm
Average root mean squared error is  28.358345851302147 mm


100%|██████████| 3966/3966 [3:06:00<00:00,  2.81s/it]
  5%|▌         | 200/3966 [05:44<1:48:47,  1.73s/it]

Loss train 2.1852483247736623e-05 valid 0.00039063011870515854


  5%|▌         | 201/3966 [06:45<20:32:12, 19.64s/it]

Average mean squared error is  15.036677941679955 mm
Average root mean squared error is  29.8590287566185 mm


  8%|▊         | 319/3966 [10:07<1:47:30,  1.77s/it]

Eval

In [None]:
error = []
color=config["color"]
truedepth = config["depth"]
rawdepth= config["rawDepth"]
length = len(color)
# One depth map per test frame: match the (N, 480, 640) depth array, not the
# (N, 480, 640, 3) colour array, otherwise storing a result fails on shape.
denoised_depth_images = np.zeros_like(truedepth)
ResUnet.eval()  # final evaluation: disable dropout
with torch.no_grad():
  for i in (range(length)):
    depth_img = torch.tensor(rawdepth[i],dtype=torch.float32).to(device).reshape(1,1,480,640)
    color_img  = torch.tensor(color[i],dtype=torch.float32).to(device).permute(2,0,1).unsqueeze(0)/255
    stacked = torch.hstack((depth_img,color_img))
    Z,n= add_noise_inloop(stacked,rgb_config)
    noise_estimate = ResUnet(Z)
    # Subtract this frame's estimate (previously the stale global `n_estimate`
    # from the training cell was used).
    denoised_depth = (Z[:,0,:,:].unsqueeze(1) - noise_estimate)
    mask =depth_img> 0  # compare only where the raw depth is > 0
    masked_denoised_depth =denoised_depth*mask
    gt = truedepth[i]
    gt_tensor = torch.tensor(gt,dtype=torch.float).to(device)
    loss = (masked_denoised_depth-gt_tensor*mask).to("cpu").detach().numpy()
    denoised_depth_images[i,...] = denoised_depth[0,0].detach().to("cpu").numpy()
    error.append(loss)
mae = [np.mean(abs(i)) for i in error]
rmse = [np.sqrt(np.mean(i**2)) for i in error]
# The first figure is a mean ABSOLUTE error; the label used to say "squared".
print("Average mean absolute error is ",np.mean(mae)*1000,"mm")
print("Average root mean squared error is ",np.mean(rmse)*1000,"mm")

# Show five random frames: reference depth, denoised depth, per-pixel squared error.
random_indice = np.random.randint(low=0,high = len(mae),size=5)
mean_squared_error = (truedepth[random_indice] - denoised_depth_images[random_indice])**2
fig,axs = plt.subplots(5,3)
axs[0,0].set_title("Reference depth")
axs[0,1].set_title("Denoised depth")
axs[0,2].set_title("Squared error")
for i in range(0,5):
      axs[i,0].imshow(truedepth[random_indice[i]],cmap="jet")
      axs[i,1].imshow(denoised_depth_images[random_indice[i]],cmap="jet")
      axs[i,2].imshow(mean_squared_error[i],cmap='jet') ## There will be a band of error on the borders, because of how the sensor was setup
plt.show()

In [None]:
# Draw five random test frames and show, per row: reference depth, denoised
# depth and the per-pixel squared difference between the two.
random_indice = np.random.randint(low=0, high=len(mae), size=5)
denoised_images = denoised_depth_images[random_indice]
mean_squared_error = (truedepth[random_indice] - denoised_depth_images[random_indice])**2
fig, axs = plt.subplots(5, 3)
for row, frame in enumerate(random_indice):
    reference_ax, denoised_ax, error_ax = axs[row]
    reference_ax.imshow(truedepth[frame], cmap="jet")
    denoised_ax.imshow(denoised_depth_images[frame], cmap="jet")
    # There will be a band of error on the borders, because of how the sensor was setup
    error_ax.imshow(mean_squared_error[row], cmap='jet')
plt.show()