In [1]:
!pip install datasets
from datasets import load_dataset



In [2]:
import numpy as np
import cv2
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import h5py
import matplotlib.pyplot as plt
import numpy as np

In [3]:
import yaml

In [4]:
parent_path = "drive/MyDrive/M202A/"

Hyperparameter Setup and other banal things

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Importing Data into Disk

In [6]:
from datasets import load_from_disk
ds = load_from_disk(parent_path+"NYU/")


Loading Testing Dataset

In [7]:
path_to_dataset = "drive//MyDrive//M202A//nyu_depth_v2_labeled.mat"

def extract_files(config):
  path = config["path"]
  f = h5py.File(path_to_dataset)
  length = len(f["images"])
  indices = np.random.randint(0,length,size=length//3)
  batch_size = length//3
  color = np.zeros((batch_size,480,640,3))
  rawdepth = np.zeros((batch_size,480,640))
  truedepth = np.zeros_like(rawdepth)
  for i in tqdm(range(0,len(indices))):
    j = indices[i]
    # read 0-th image. original format is [3 x 640 x 480], uint8
    img = f['images'][j].astype(np.uint8)
    img = img.transpose(2,1,0)
    color[i,...] = img
    rawD = f['rawDepths'][j]
    rawdepth[i,...] = rawD.T.astype(np.float32)
    depth_from_disk = f["depths"][j].astype(np.float32)
    truedepth[i,...] = depth_from_disk.T
  config["color"] = color
  config["depth"] = truedepth
  config["rawDepth"] = rawdepth
config={"path":path_to_dataset}
extract_files(config)

100%|██████████| 483/483 [00:28<00:00, 17.15it/s]


In [8]:
rgb_config = {"R_u":-0.00776206,"R_w":0.01519309,"Rvar":6.88208756e-06,
              "G_u":0.01147944,"G_w":0.50978071,"Gvar":9.19419688e-06,
              "B_u":0.00371937,"B_w":0.47502621,"Bvar":6.89479643e-06}

# Depth Noise based on the Channels, readings taken @0.5m

In [9]:
print(len(ds['train']))
ds.set_format("torch", device=device)
ds.shuffle()

47584


DatasetDict({
    train: Dataset({
        features: ['image', 'depth_map'],
        num_rows: 47584
    })
    validation: Dataset({
        features: ['image', 'depth_map'],
        num_rows: 654
    })
})

Loading Pretrained Unet for Depth

In [10]:
with open(parent_path+'config.yaml', 'r') as file:
    model_configs = yaml.safe_load(file)
model_configs['model']['channels']=1
model_configs['model']['base filters']=16

In [11]:
from model import *

Creating New Network by mixing the two

In [12]:
class Resnet_UNet(nn.Module):
  """
  Residual-Dense U-net for image denoising.
  """
  def __init__(self,**kwargs):
      super().__init__()
      channels = kwargs['channels']
      filters_0 = kwargs['base filters']
      filters_1 = 2 * filters_0
      filters_2 = 4 * filters_0
      filters_3 = 8 * filters_0

      # Encoder:
      # Level 0:
      self.drop = nn.Dropout(p=0.5)
      self.input_block = InputBlock(4, filters_0)
      self.block_0_0 = DenoisingBlock(filters_0, filters_0 // 2, filters_0)
      self.block_0_1 = DenoisingBlock(filters_0, filters_0 // 2, filters_0)
      self.down_0 = DownsampleBlock(filters_0, filters_1)

      # Level 1:
      self.block_1_0 = DenoisingBlock(filters_1, filters_1 // 2, filters_1)
      self.block_1_1 = DenoisingBlock(filters_1, filters_1 // 2, filters_1)
      self.down_1 = DownsampleBlock(filters_1, filters_2)

      # Level 2:
      self.block_2_0 = DenoisingBlock(filters_2, filters_2 // 2, filters_2)
      self.block_2_1 = DenoisingBlock(filters_2, filters_2 // 2, filters_2)
      self.down_2 = DownsampleBlock(filters_2, filters_3)

      # Level 3 (Bottleneck)
      self.block_3_0 = DenoisingBlock(filters_3, filters_3 // 2, filters_3)
      self.block_3_1 = DenoisingBlock(filters_3, filters_3 // 2, filters_3)

      # Decoder
      # Level 2:
      self.up_2 = UpsampleBlock(filters_3, filters_2, filters_2)
      self.block_2_2 = DenoisingBlock(filters_2, filters_2 // 2, filters_2)
      self.block_2_3 = DenoisingBlock(filters_2, filters_2 // 2, filters_2)

      # Level 1:
      self.up_1 = UpsampleBlock(filters_2, filters_1, filters_1)
      self.block_1_2 = DenoisingBlock(filters_1, filters_1 // 2, filters_1)
      self.block_1_3 = DenoisingBlock(filters_1, filters_1 // 2, filters_1)

      # Level 0:
      self.up_0 = UpsampleBlock(filters_1, filters_0, filters_0)
      self.block_0_2 = DenoisingBlock(filters_0, filters_0 // 2, filters_0)
      self.block_0_3 = DenoisingBlock(filters_0, filters_0 // 2, filters_0)

      self.output_block = OutputBlock(filters_0, channels)


  def forward(self, inputs):
      inputs = self.drop(inputs)
      out_0 = self.input_block(inputs)    # Level 0
      out_0 = self.block_0_0(out_0)
      out_0 = self.block_0_1(out_0)

      out_1 = self.down_0(out_0)          # Level 1
      out_1 = self.block_1_0(out_1)
      out_1 = self.block_1_1(out_1)

      out_2 = self.down_1(out_1)          # Level 2
      out_2 = self.block_2_0(out_2)
      out_2 = self.block_2_1(out_2)

      out_3 = self.down_2(out_2)          # Level 3 (Bottleneck)

      out_3 = self.block_3_0(out_3)
      out_3 = self.block_3_1(out_3)



      out_4 = self.up_2([out_3, out_2])   # Level 2
      out_4 = self.block_2_2(out_4)
      out_4 = self.block_2_3(out_4)

      out_5 = self.up_1([out_4, out_1])   # Level 1
      out_5 = self.block_1_2(out_5)
      out_5 = self.block_1_3(out_5)

      out_6 = self.up_0([out_5, out_0])   # Level 0
      out_6 = self.block_0_2(out_6)
      out_6 = self.block_0_3(out_6)

      return self.output_block(out_6), out_3


Loading Optimizers, Schedulers and Training the network

In [13]:
train_data = ds['train']
val_data = ds['validation']

In [14]:
from tqdm import tqdm
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

```
Sanity Testing the Segmentaion Loss before plugging it into the loss model
```

In [15]:
def add_noise_inloop(depth,cfg=rgb_config):
  R_noise  = cfg["Rvar"]*(torch.randn(depth.shape,device=device))+cfg["R_u"]
  G_noise  =cfg["Gvar"]*(torch.randn(depth.shape,device=device))+cfg["G_u"]
  B_noise  = cfg["Bvar"]*(torch.randn(depth.shape,device=device))+cfg["B_u"]
  noise = cfg["R_w"]*R_noise+cfg["G_w"]*G_noise+cfg["B_w"]*B_noise
  noise[:,1:,:,:] = 0
  depth = depth+noise
  return depth, noise[:,0,:,:].unsqueeze(1)

In [16]:
def test_loop():
  error = []
  color=config["color"]
  truedepth = config["depth"]
  rawdepth= config["rawDepth"]
  length = len(color)
  with torch.no_grad():
    for i in (range(length)):
      depth_img = torch.tensor(rawdepth[i],dtype=torch.float32).to(device).reshape(1,1,480,640)
      color_img  = torch.tensor(color[i],dtype=torch.float32).to(device).permute(2,0,1).unsqueeze(0)/255
      stacked = torch.hstack((depth_img,color_img))
      Z,n= add_noise_inloop(stacked,rgb_config)
      noise_estimate = ResUnet(Z)
      denoised_depth = (Z[:,0,:,:].unsqueeze(1) - n_estimate)
      mask =depth_img> 0
      masked_denoised_depth =denoised_depth*mask
      gt = truedepth[i]
      gt_tensor = torch.tensor(gt,dtype=torch.float).to(device)
      loss = (masked_denoised_depth-gt_tensor*mask).to("cpu").detach().numpy()
      error.append(loss)
    mae = [np.mean(abs(i)) for i in error]
    rmse = [np.sqrt(np.mean(i**2)) for i in error]
    print("Average mean squared error is ",np.mean(mae)*1000,"mm")
    print("Average root mean squared error is ",np.mean(rmse)*1000,"mm")

In [None]:
mse_loss = nn.MSELoss()
n_epochs = 5
lr = 1e-4
best_vloss = 1_000_000.
ResUnet = Resnet_UNet(**model_configs['model'])
ResUnet.to(device)
bs = 4
freq = bs * 200;
optimizer= torch.optim.AdamW(filter(lambda p:p.requires_grad,ResUnet.parameters()),lr=1e-3,weight_decay = 1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
for epoch in (range(n_epochs)):
  train_data.shuffle()
  val_data.shuffle()

  running_loss = 0

  ## Training Loop
  for i in tqdm(range(0,len(train_data),bs)):
    with torch.autocast(device_type='cuda', dtype=torch.float16):
      optimizer.zero_grad()
      color = train_data[i:i+bs]['image'].permute(0,3,1,2)/255 ## Channels first and normalizing to 0 and 1
      depth = train_data[i:i+bs]['depth_map'].unsqueeze(0).permute(1,0,2,3)
      stacked = torch.hstack((depth,color))
      Z,n = add_noise_inloop(stacked,rgb_config)
      n_estimate,denoised_embedding = ResUnet(Z)

      _,true_encoding = ResUnet(stacked)
      denoised_depth_estimate = (Z[:,0,:,:].unsqueeze(1) - n_estimate)
      loss = mse_loss(depth,denoised_depth_estimate)
      loss.backward(retain_graph=True)
      loss = mse_loss(true_encoding,denoised_embedding)
      loss.backward()
      optimizer.step()
      running_loss+=loss.item()
      if(i%freq==0 and i!=0):
        avg_loss = running_loss/freq
        running_loss = 0
        running_vloss = 0;
        with torch.no_grad():
          for count in range(len(val_data)):
            optimizer.zero_grad()
            color = train_data[i]['image'].permute(2,0,1).unsqueeze(0)/255 ## Channels first and normalizing to 0 and 1
            depth = train_data[i]['depth_map'].reshape(1,1,480,640)
            stacked = torch.hstack((depth,color))
            Z,n = add_noise_inloop(stacked,rgb_config)
            n_estimate,_ = ResUnet(Z)
            denoised_depth_estimate = (Z[:,0,:,:].unsqueeze(1) - n_estimate)
            loss = mse_loss(depth,denoised_depth_estimate)
            running_vloss+=loss.item()
          avg_vloss = running_vloss/len(val_data)
          print("Loss train {} valid {}".format(avg_loss,avg_vloss))
          if avg_loss < best_vloss:
            best_vloss = avg_loss
            model_path = 'model_{}_{}'.format(timestamp, i)
            torch.save(ResUnet.state_dict(), parent_path+"Autoencoder_MSE.pth")
        test_loop()
        scheduler.step()


  2%|▏         | 200/11896 [03:19<3:43:20,  1.15s/it]

Loss train 7.664919309036122e-06 valid 0.00026457398518064094


  2%|▏         | 201/11896 [04:22<64:14:24, 19.77s/it]

Average mean squared error is  13.480965979397297 mm
Average root mean squared error is  31.604032963514328 mm


  3%|▎         | 400/11896 [07:33<3:03:25,  1.04it/s]

Loss train 2.8091087460779818e-06 valid 0.00027125381063762484


  3%|▎         | 401/11896 [08:38<63:51:44, 20.00s/it]

Average mean squared error is  13.775179162621498 mm
Average root mean squared error is  31.81399405002594 mm


  5%|▌         | 600/11896 [11:48<2:56:35,  1.07it/s]

Loss train 2.9238418812838065e-06 valid 0.00043242143024548046


  5%|▌         | 601/11896 [12:53<63:19:17, 20.18s/it]

Average mean squared error is  16.498837620019913 mm
Average root mean squared error is  34.72406789660454 mm


  7%|▋         | 800/11896 [16:03<2:54:35,  1.06it/s]

Loss train 3.772790783500568e-06 valid 0.00023698442539827477


  7%|▋         | 801/11896 [17:06<60:10:28, 19.52s/it]

Average mean squared error is  12.967951595783234 mm
Average root mean squared error is  31.190648674964905 mm


  8%|▊         | 1000/11896 [20:17<2:52:57,  1.05it/s]

Loss train 2.4887033322329444e-06 valid 0.00022806708713148262


  8%|▊         | 1001/11896 [21:20<60:01:07, 19.83s/it]

Average mean squared error is  12.884109281003475 mm
Average root mean squared error is  31.03976882994175 mm


 10%|█         | 1200/11896 [24:32<2:46:21,  1.07it/s]

Loss train 2.576449904836409e-06 valid 0.00016075204100637108


 10%|█         | 1201/11896 [25:33<57:04:39, 19.21s/it]

Average mean squared error is  11.264032684266567 mm
Average root mean squared error is  29.60657700896263 mm


 12%|█▏        | 1400/11896 [28:44<2:45:50,  1.05it/s]

Loss train 4.520176840969725e-06 valid 0.0003166335292850459


 12%|█▏        | 1401/11896 [29:48<57:20:07, 19.67s/it]

Average mean squared error is  14.540267176926136 mm
Average root mean squared error is  32.67379850149155 mm


 13%|█▎        | 1600/11896 [33:06<4:04:48,  1.43s/it]

Loss train 4.428333668329288e-06 valid 0.00028264007730650166


 13%|█▎        | 1601/11896 [34:10<57:22:56, 20.07s/it]

Average mean squared error is  13.825428672134876 mm
Average root mean squared error is  31.988907605409622 mm


 15%|█▌        | 1800/11896 [37:34<2:37:55,  1.07it/s]

Loss train 4.7433197244117765e-06 valid 0.00037616019946948797


 15%|█▌        | 1801/11896 [38:36<53:49:48, 19.20s/it]

Average mean squared error is  15.645939856767654 mm
Average root mean squared error is  33.78966450691223 mm


 17%|█▋        | 2000/11896 [42:41<2:42:08,  1.02it/s]

Loss train 4.568801292066382e-06 valid 0.00028768729776128363


 17%|█▋        | 2001/11896 [43:44<54:20:20, 19.77s/it]

Average mean squared error is  14.119083061814308 mm
Average root mean squared error is  32.308779656887054 mm


 18%|█▊        | 2200/11896 [47:04<2:36:09,  1.03it/s]

Loss train 5.153584348676077e-06 valid 0.00027505204661653894


 19%|█▊        | 2201/11896 [48:08<53:08:21, 19.73s/it]

Average mean squared error is  13.70152086019516 mm
Average root mean squared error is  31.8513922393322 mm


 20%|██        | 2400/11896 [51:46<2:31:27,  1.04it/s]

Loss train 4.612522077422909e-06 valid 0.00020960979996265455


 20%|██        | 2401/11896 [52:48<51:16:01, 19.44s/it]

Average mean squared error is  12.439042329788208 mm
Average root mean squared error is  30.66541999578476 mm


 22%|██▏       | 2600/11896 [56:19<2:27:11,  1.05it/s]

Loss train 3.4538594343302976e-06 valid 0.00025002701195179047


 22%|██▏       | 2601/11896 [57:23<51:11:32, 19.83s/it]

Average mean squared error is  13.234775513410568 mm
Average root mean squared error is  31.380321830511093 mm


 24%|██▎       | 2800/11896 [1:00:50<2:26:25,  1.04it/s]

Loss train 3.833217834880997e-06 valid 0.00032378704410291275


 24%|██▎       | 2801/11896 [1:01:53<49:05:38, 19.43s/it]

Average mean squared error is  14.530714601278305 mm
Average root mean squared error is  32.750461250543594 mm


 25%|██▌       | 3000/11896 [1:05:21<2:22:03,  1.04it/s]

Loss train 4.656686670045929e-06 valid 0.00017357320884411404


 25%|██▌       | 3001/11896 [1:06:24<48:18:23, 19.55s/it]

Average mean squared error is  11.587334796786308 mm
Average root mean squared error is  29.915764927864075 mm


 27%|██▋       | 3200/11896 [1:09:57<2:20:59,  1.03it/s]

Loss train 3.8120650140172073e-06 valid 0.0002930659321222215


 27%|██▋       | 3201/11896 [1:11:00<47:05:31, 19.50s/it]

Average mean squared error is  14.137591235339642 mm
Average root mean squared error is  32.262276858091354 mm


 29%|██▊       | 3400/11896 [1:14:30<2:10:39,  1.08it/s]

Loss train 3.5896845452043635e-06 valid 0.00018893452729696597


 29%|██▊       | 3401/11896 [1:15:33<45:38:02, 19.34s/it]

Average mean squared error is  11.915131472051144 mm
Average root mean squared error is  30.210932716727257 mm


 30%|███       | 3600/11896 [1:19:06<2:33:15,  1.11s/it]

Loss train 3.871531046115706e-06 valid 0.00023565560413132185


 30%|███       | 3601/11896 [1:20:08<44:53:14, 19.48s/it]

Average mean squared error is  12.862281873822212 mm
Average root mean squared error is  31.111907213926315 mm


 32%|███▏      | 3800/11896 [1:23:52<2:47:58,  1.24s/it]

Loss train 6.942611923932418e-06 valid 0.00032828567907020495


 32%|███▏      | 3801/11896 [1:24:55<44:12:40, 19.66s/it]

Average mean squared error is  14.439752325415611 mm
Average root mean squared error is  32.71646425127983 mm


 34%|███▎      | 4000/11896 [1:28:18<2:07:42,  1.03it/s]

Loss train 5.581360904898247e-06 valid 0.0002086786379041193


 34%|███▎      | 4001/11896 [1:29:20<42:52:08, 19.55s/it]

Average mean squared error is  12.31416966766119 mm
Average root mean squared error is  30.589023604989052 mm


 35%|███▌      | 4200/11896 [1:32:53<2:50:55,  1.33s/it]

Loss train 4.875171767935171e-06 valid 0.00024059009855304807


 35%|███▌      | 4201/11896 [1:33:55<41:58:48, 19.64s/it]

Average mean squared error is  13.08816485106945 mm
Average root mean squared error is  31.299002468585968 mm


 37%|███▋      | 4400/11896 [1:37:24<5:55:38,  2.85s/it]

Loss train 5.38593116687025e-06 valid 0.00024514627653253847


 37%|███▋      | 4401/11896 [1:38:27<43:03:31, 20.68s/it]

Average mean squared error is  13.031548820436 mm
Average root mean squared error is  31.30243346095085 mm


 39%|███▊      | 4600/11896 [1:41:48<1:51:40,  1.09it/s]

Loss train 4.928988448114069e-06 valid 0.00023857798556486053


 39%|███▊      | 4601/11896 [1:42:48<37:54:32, 18.71s/it]

Average mean squared error is  12.798326089978218 mm
Average root mean squared error is  31.086590141057968 mm


 40%|████      | 4800/11896 [1:46:22<1:52:29,  1.05it/s]

Loss train 5.390002751823886e-06 valid 0.00029604633377732523


 40%|████      | 4801/11896 [1:47:24<37:59:05, 19.27s/it]

Average mean squared error is  13.818172737956047 mm
Average root mean squared error is  32.29440003633499 mm


 42%|████▏     | 5000/11896 [1:50:48<1:46:31,  1.08it/s]

Loss train 3.661805357637604e-06 valid 0.00018098721631720023


 42%|████▏     | 5001/11896 [1:51:50<37:01:21, 19.33s/it]

Average mean squared error is  11.718083173036575 mm
Average root mean squared error is  30.032724142074585 mm


 44%|████▎     | 5200/11896 [1:55:33<1:49:06,  1.02it/s]

Loss train 4.42804703652655e-06 valid 0.0002950562174600729


 44%|████▎     | 5201/11896 [1:56:35<35:56:32, 19.33s/it]

Average mean squared error is  14.167862944304943 mm
Average root mean squared error is  32.3311910033226 mm


 45%|████▌     | 5400/11896 [2:00:18<1:42:28,  1.06it/s]

Loss train 3.993387603031806e-06 valid 0.00023641989677460916


 45%|████▌     | 5401/11896 [2:01:20<34:58:51, 19.39s/it]

Average mean squared error is  12.719853781163692 mm
Average root mean squared error is  31.090272590517998 mm


 47%|████▋     | 5600/11896 [2:04:37<1:36:59,  1.08it/s]

Loss train 5.6300820392607424e-06 valid 0.00027193492362119075


 47%|████▋     | 5601/11896 [2:05:38<33:35:18, 19.21s/it]

Average mean squared error is  13.823184184730053 mm
Average root mean squared error is  32.028622925281525 mm


 49%|████▉     | 5800/11896 [2:09:17<1:40:09,  1.01it/s]

Loss train 6.7821076231666665e-06 valid 0.0003014754633286068


 49%|████▉     | 5801/11896 [2:10:19<32:38:25, 19.28s/it]

Average mean squared error is  14.280425384640694 mm
Average root mean squared error is  32.583195716142654 mm


 50%|█████     | 6000/11896 [2:13:50<1:38:51,  1.01s/it]

Loss train 6.491944279787276e-06 valid 0.00031953392273017324


 50%|█████     | 6001/11896 [2:14:53<31:52:00, 19.46s/it]

Average mean squared error is  14.576386660337448 mm
Average root mean squared error is  32.70316123962402 mm


 52%|█████▏    | 6200/11896 [2:18:47<1:28:34,  1.07it/s]

Loss train 5.330253144393282e-06 valid 0.000273971241384044


 52%|█████▏    | 6201/11896 [2:19:50<30:46:03, 19.45s/it]

Average mean squared error is  13.707445934414864 mm
Average root mean squared error is  31.89137950539589 mm


 54%|█████▍    | 6400/11896 [2:23:33<2:12:50,  1.45s/it]

Loss train 3.5558498115051407e-06 valid 0.00018999609509778468


 54%|█████▍    | 6401/11896 [2:24:35<29:55:35, 19.61s/it]

Average mean squared error is  12.019512243568897 mm
Average root mean squared error is  30.288400128483772 mm


 55%|█████▌    | 6600/11896 [2:28:19<2:24:41,  1.64s/it]

Loss train 2.4435994569671493e-06 valid 0.00017891860952325286


 55%|█████▌    | 6601/11896 [2:29:21<29:06:07, 19.79s/it]

Average mean squared error is  11.751226149499416 mm
Average root mean squared error is  30.057262629270554 mm


 57%|█████▋    | 6800/11896 [2:33:01<1:22:10,  1.03it/s]

Loss train 1.531256263689329e-06 valid 0.000264288983384321


 57%|█████▋    | 6801/11896 [2:34:04<27:48:04, 19.64s/it]

Average mean squared error is  13.349775224924088 mm
Average root mean squared error is  31.74426406621933 mm


 59%|█████▉    | 7000/11896 [2:37:57<1:36:08,  1.18s/it]

Loss train 4.797695493721221e-06 valid 0.0002567840029666377


 59%|█████▉    | 7001/11896 [2:39:00<26:53:26, 19.78s/it]

Average mean squared error is  13.185049407184124 mm
Average root mean squared error is  31.500566750764847 mm


 61%|██████    | 7200/11896 [2:42:39<1:33:05,  1.19s/it]

Loss train 3.485629956685443e-06 valid 0.0002440360030041871


 61%|██████    | 7201/11896 [2:43:43<26:05:45, 20.01s/it]

Average mean squared error is  13.192114420235157 mm
Average root mean squared error is  31.39495849609375 mm


 62%|██████▏   | 7400/11896 [2:47:16<1:10:14,  1.07it/s]

Loss train 3.4782458493509696e-06 valid 0.00021490497902300584


 62%|██████▏   | 7401/11896 [2:48:20<24:47:43, 19.86s/it]

Average mean squared error is  12.5998854637146 mm
Average root mean squared error is  30.780963599681854 mm


 64%|██████▍   | 7600/11896 [2:52:07<1:09:10,  1.04it/s]

Loss train 3.1442027113826045e-06 valid 0.00021514686701010446


 64%|██████▍   | 7601/11896 [2:53:09<23:10:21, 19.42s/it]

Average mean squared error is  12.41916511207819 mm
Average root mean squared error is  30.703073367476463 mm


 66%|██████▌   | 7800/11896 [2:56:51<1:04:35,  1.06it/s]

Loss train 2.8792888821271843e-06 valid 0.00022062077638689595


 66%|██████▌   | 7801/11896 [2:57:53<22:01:09, 19.36s/it]

Average mean squared error is  12.662212364375591 mm
Average root mean squared error is  30.888337641954422 mm


 67%|██████▋   | 8000/11896 [3:01:45<1:01:48,  1.05it/s]

Loss train 4.4510257282581735e-06 valid 0.00024079408356261265


 67%|██████▋   | 8001/11896 [3:02:48<21:08:04, 19.53s/it]

Average mean squared error is  13.096807524561882 mm
Average root mean squared error is  31.302105635404587 mm


 69%|██████▉   | 8200/11896 [3:06:34<1:01:03,  1.01it/s]

Loss train 3.857449151496439e-06 valid 0.0002670480478722289


 69%|██████▉   | 8201/11896 [3:07:36<19:50:15, 19.33s/it]

Average mean squared error is  13.446730561554432 mm
Average root mean squared error is  31.723804771900177 mm


 69%|██████▉   | 8224/11896 [3:08:01<1:04:31,  1.05s/it]

Eval

In [None]:

error = []
color=config["color"]
truedepth = config["depth"]
rawdepth= config["rawDepth"]
length = len(color)
denoised_depth_images = np.zeros_like(color)
with torch.no_grad():
  for i in (range(length)):
    depth_img = torch.tensor(rawdepth[i],dtype=torch.float32).to(device).reshape(1,1,480,640)
    color_img  = torch.tensor(color[i],dtype=torch.float32).to(device).permute(2,0,1).unsqueeze(0)/255
    stacked = torch.hstack((depth_img,color_img))
    Z,n= add_noise_inloop(stacked,rgb_config)
    noise_estimate = ResUnet(Z)
    denoised_depth = (Z[:,0,:,:].unsqueeze(1) - n_estimate)
    mask =depth_img> 0
    masked_denoised_depth =denoised_depth*mask
    gt = truedepth[i]
    gt_tensor = torch.tensor(gt,dtype=torch.float).to(device)
    loss = (masked_denoised_depth-gt_tensor*mask).to("cpu").detach().numpy()
    denoised_depth_images[i,...] = denoised_depth.detach().to("cpu").numpy()
    error.append(loss)
mae = [np.mean(abs(i)) for i in error]
rmse = [np.sqrt(np.mean(i**2)) for i in error]
print("Average mean squared error is ",np.mean(mae)*1000,"mm")
print("Average root mean squared error is ",np.mean(rmse)*1000,"mm")
random_indice = np.random.randint(low=0,high = len(mae),size=5)
denoised_images = denoised_depth_images[random_indice]
mean_squared_error = (truedepth[random_indice] - denoised_depth_images[random_indice])**2
fig,axs = plt.subplots(5,3)
for i in range(0,5):
      axs[i,0].imshow(truedepth[random_indice[i]],cmap="jet")
      axs[i,1].imshow(denoised_depth_images[random_indice[i]],cmap="jet")
      axs[i,2].imshow(mean_squared_error[i],cmap='jet') ## There will be a band of error on the borders, because of how the sensor was setup
plt.show()