In [1]:
import os
import torch

# Check the avtive DNN
######################
print(torch.cuda.device_count())

GPU_number = 1 # if you wanna use A5000, 2
torch.cuda.set_device(GPU_number)
print(torch.cuda.get_device_name())

if not torch.cuda.is_available():
    raise Exception('NO GPU!')

3
NVIDIA GeForce RTX 3090


In [2]:
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import scipy.io as scio
import h5py
import numpy as np
import time
import datetime
import math
import sys

# Create path for Dataset
#########################

class imgdataset(Dataset):
    def __init__(self, path):
        super(Imgdataset, self).__init__()
        self.data = []
        if os.path.exists(path):
            dir_list = os.listdir(path)
            dir_list = sorted(dir_list)
            self.data = [{'ground_truth': path + '/' + dir_list[i]} for i in range(len(dir_list))]

        else:
            raise FileNotFoundError('path doesnt exist!')

    def __getitem__(self, index):

        ground_truth= self.data[index]["ground_truth"]
        
        with h5py.File(ground_truth, 'r') as f:
            gt = torch.from_numpy(np.array(f["subframe_ideal128"])) #32,256,256=c,w,h
            depth = torch.from_numpy(np.array(f["depth_1024"]))
            
        gt = gt.permute(0,1,3,2)#c,h,w
        depth = depth.permute(1,0)#c,h,w
        return gt, depth
        
    def __len__(self):

        return len(self.data)
        
class Imgdataset_multipath(Dataset):

    def __init__(self, path):
        super(Imgdataset_multipath, self).__init__()
        self.data = []
        if os.path.exists(path):
            dir_list = os.listdir(path)
            dir_list = sorted(dir_list)
            self.data = [{'ground_truth': path + '/' + dir_list[i]} for i in range(len(dir_list))]

        else:
            raise FileNotFoundError('path doesnt exist!')

    def __getitem__(self, index):

        ground_truth= self.data[index]["ground_truth"]
        with h5py.File(ground_truth, 'r') as f:
            gt = torch.from_numpy(np.array(f["subframe128"])) #32,256,256=c,w,h
            depth = torch.from_numpy(np.array(f["depth_1024"]))
            
        gt = gt.permute(0,1,3,2)#c,h,w
        depth = depth.permute(1,0)#c,h,w
        return gt, depth

    def __len__(self):
        return len(self.data)

# dnn architecture
##################

class double_conv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(double_conv, self).__init__()
        self.d_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.d_conv(x)
        return x

class Unet(nn.Module):

    def __init__(self,in_ch, out_ch):
        super(Unet, self).__init__()
                
        self.dconv_down1 = double_conv(in_ch, 32)
        self.dconv_down2 = double_conv(32, 64)
        self.dconv_down3 = double_conv(64, 128)       

        self.maxpool = nn.MaxPool2d(2)
        self.upsample2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )
        
        self.upsample1 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )
        
        self.dconv_up2 = double_conv(64 + 64, 64)
        self.dconv_up1 = double_conv(32 + 32, 32)
        
        self.conv_last = nn.Conv2d(32, out_ch, 1)
        self.afn_last = nn.Tanh()
        
        
    def forward(self, x):
        
        inputs = x
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        
        conv3 = self.dconv_down3(x)
        x = self.upsample2(conv3)
        
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)
        x = self.upsample1(x)        
        x = torch.cat([x, conv1], dim=1)       

        x = self.dconv_up1(x)  
        x = self.conv_last(x)
        x = self.afn_last(x)
        
        out = x + inputs
        
        return out

class Unet128(nn.Module):

    def __init__(self,in_ch, out_ch):
        super(Unet128, self).__init__()
                
        self.dconv_down1 = double_conv(in_ch, 128)
        self.dconv_down2 = double_conv(128, 128)
        self.dconv_down3 = double_conv(128, 256)       

        self.maxpool = nn.MaxPool2d(2)
        self.upsample2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )
        self.upsample1 = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )
        self.dconv_up2 = double_conv(128 + 128, 128)
        self.dconv_up1 = double_conv(128 + 128, 128)
        
        self.conv_last = nn.Conv2d(128, out_ch, 1)
        self.afn_last = nn.Tanh()
        
        
    def forward(self, x):
        inputs = x
        conv1 = self.dconv_down1(x)       
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
       
        conv3 = self.dconv_down3(x)     
        x = self.upsample2(conv3)      
        x = torch.cat([x, conv2], dim=1)
        
        x = self.dconv_up2(x)
        x = self.upsample1(x)        
        x = torch.cat([x, conv1], dim=1)       

        x = self.dconv_up1(x)  
        x = self.conv_last(x)
        x = self.afn_last(x)
        
        out = x + inputs
        
        return out

class Unet_depth(nn.Module):

    def __init__(self,in_ch, out_ch):
        super(Unet_depth, self).__init__()
                
        self.dconv_down1 = double_conv(in_ch, 32)
        self.dconv_down2 = double_conv(32, 64)
        self.dconv_down3 = double_conv(64, 128)       

        self.maxpool = nn.MaxPool2d(2)
        self.upsample2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )
        self.upsample1 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )
        self.dconv_up2 = double_conv(64 + 64, 64)
        self.dconv_up1 = double_conv(32 + 32, 32)
        
        self.conv_last1 = nn.Conv2d(32, 32, 1)
        self.afn_last = nn.Tanh()
        self.conv_last = nn.Conv2d(32, out_ch, 1)
        
    def forward(self, x):
        inputs = x
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        
        conv3 = self.dconv_down3(x)

        
        x = self.upsample2(conv3)        
        x = torch.cat([x, conv2], dim=1)
        
        x = self.dconv_up2(x)
        x = self.upsample1(x)        
        x = torch.cat([x, conv1], dim=1)       

        x = self.dconv_up1(x)  
        
        x = self.conv_last1(x)
        x = self.afn_last(x)
        out = x + inputs
        out = self.conv_last(out)
        return out

class Unet_depth128(nn.Module):

    def __init__(self,in_ch, out_ch):
        super(Unet_depth128, self).__init__()
                
        self.dconv_down1 = double_conv(in_ch, 128)
        self.dconv_down2 = double_conv(128, 128)
        self.dconv_down3 = double_conv(128, 256)       

        self.maxpool = nn.MaxPool2d(2)
        self.upsample2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )
        self.upsample1 = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )
        self.dconv_up2 = double_conv(128 + 128, 128)
        self.dconv_up1 = double_conv(128 + 128, 128)
        
        self.conv_last1 = nn.Conv2d(128, 128, 1)
        self.afn_last = nn.Tanh()
        self.conv_last = nn.Conv2d(128, out_ch, 1)    
        
    def forward(self, x):
        inputs = x
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        
        conv3 = self.dconv_down3(x)
        
        x = self.upsample2(conv3)        
        x = torch.cat([x, conv2], dim=1)
        
        x = self.dconv_up2(x)
        x = self.upsample1(x)        
        x = torch.cat([x, conv1], dim=1)       

        x = self.dconv_up1(x)  
        
        x = self.conv_last1(x)
        x = self.afn_last(x)
        out = x + inputs
        out = self.conv_last(out)
        return out
    

# main Network ( 32) 
##############
class ADMM(nn.Module):

    def __init__(self):
        super(ADMM, self).__init__()
                
        self.unet1 = Unet(32, 32)
        self.unet2 = Unet(32, 32)
        self.unet3 = Unet(32, 32)
        self.unet4 = Unet(32, 32)
        self.unet5 = Unet(32, 32)
        self.unet6 = Unet(32, 32)
        self.unet7 = Unet(32, 32)
        self.unet8 = Unet(32, 32)
        self.unet9 = Unet(32, 32)
        self.gamma1 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma2 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma3 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma4 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma5 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma6 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma7 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma8 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma9 = torch.nn.Parameter(torch.Tensor([0]))

    def forward(self, y, Phi, Phi_r, Phi_s):
        x_list = []
        d_list = []
        theta = At(y,Phi_r)
        b = torch.zeros_like(theta)
        ### 1-3
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma1),Phi_r)
        x1 = x-b
        theta = self.unet1(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma2),Phi_r)
        x1 = x-b
        theta = self.unet2(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma3),Phi_r)
        x1 = x-b
        theta = self.unet3(x1)
        b = b- (x-theta)
        x_list.append(theta)
        ### 4-6
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma4),Phi_r)
        x1 = x-b
        theta = self.unet4(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma5),Phi_r)
        x1 = x-b
        theta = self.unet5(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma6),Phi_r)
        x1 = x-b
        theta = self.unet6(x1)
        b = b- (x-theta)
        x_list.append(theta)
        ### 7-9
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma7),Phi_r)
        x1 = x-b
        theta = self.unet7(x1)
        output_depth = torch.argmax(theta,dim = 1)
        d_list.append(output_depth)
        
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma8),Phi_r)
        x1 = x-b
        theta = self.unet8(x1)
        output_depth = torch.argmax(theta,dim = 1)
        d_list.append(output_depth)
        
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma9),Phi_r)
        x1 = x-b
        theta = self.unet9(x1)
        output_depth = torch.argmax(theta,dim = 1)
        d_list.append(output_depth)
        b = b- (x-theta)
        x_list.append(theta)
        x_list.append(d_list[-3])
        x_list.append(d_list[-2])
        x_list.append(d_list[-1])
        
        output_list = x_list[-6:]
        
        return output_list


class ADMM_depthnet(nn.Module):

    def __init__(self):
        super(ADMM_depthnet, self).__init__()
        
        self.depthnet =  Unet_depth(32,1)        
        self.unet1 = Unet(32, 32)
        self.unet2 = Unet(32, 32)
        self.unet3 = Unet(32, 32)
        self.unet4 = Unet(32, 32)
        self.unet5 = Unet(32, 32)
        self.unet6 = Unet(32, 32)
        self.unet7 = Unet(32, 32)
        self.unet8 = Unet(32, 32)
        self.unet9 = Unet(32, 32)
        self.gamma1 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma2 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma3 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma4 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma5 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma6 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma7 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma8 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma9 = torch.nn.Parameter(torch.Tensor([0]))

    def forward(self, y, Phi, Phi_r, Phi_s):
        x_list = []
        d_list = []
        theta = At(y,Phi_r)
        b = torch.zeros_like(theta)
        ### 1-3
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma1),Phi_r)
        x1 = x-b
        theta = self.unet1(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma2),Phi_r)
        x1 = x-b
        theta = self.unet2(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma3),Phi_r)
        x1 = x-b
        theta = self.unet3(x1)
        b = b- (x-theta)
        x_list.append(theta)
        ### 4-6
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma4),Phi_r)
        x1 = x-b
        theta = self.unet4(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma5),Phi_r)
        x1 = x-b
        theta = self.unet5(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma6),Phi_r)
        x1 = x-b
        theta = self.unet6(x1)
        b = b- (x-theta)
        x_list.append(theta)
        ### 7-9
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma7),Phi_r)
        x1 = x-b
        theta = self.unet7(x1)
        output_depth = self.depthnet(theta)
        d_list.append(output_depth)
        
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma8),Phi_r)
        x1 = x-b
        theta = self.unet8(x1)
        output_depth = self.depthnet(theta)
        d_list.append(output_depth)
        
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma9),Phi_r)
        x1 = x-b
        theta = self.unet9(x1)
        output_depth = self.depthnet(theta)
        d_list.append(output_depth)
        b = b- (x-theta)
        x_list.append(theta)
        x_list.append(d_list[-3])
        x_list.append(d_list[-2])
        x_list.append(d_list[-1])
        
        output_list = x_list[-6:]     
        return output_list

# main Network (128) 
##############

class ADMM_128(nn.Module):
    def __init__(self):
        super(ADMM_128, self).__init__()
               
        self.unet1 = Unet(128, 128)
        self.unet2 = Unet128(128, 128)
        self.unet3 = Unet128(128, 128)
        self.unet4 = Unet128(128, 128)
        self.unet5 = Unet128(128, 128)
        self.unet6 = Unet128(128, 128)
        self.unet7 = Unet128(128, 128)
        self.unet8 = Unet128(128, 128)
        self.unet9 = Unet128(128, 128)
        self.gamma1 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma2 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma3 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma4 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma5 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma6 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma7 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma8 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma9 = torch.nn.Parameter(torch.Tensor([0]))

    def forward(self, y, Phi, Phi_r, Phi_s):
        x_list = []
        d_list = []
        theta = At(y,Phi_r)
        b = torch.zeros_like(theta)
        ### 1-3
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma1),Phi_r)
        x1 = x-b
        theta = self.unet1(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma2),Phi_r)
        x1 = x-b
        theta = self.unet2(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma3),Phi_r)
        x1 = x-b
        theta = self.unet3(x1)
        b = b- (x-theta)
        x_list.append(theta)
        ### 4-6
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma4),Phi_r)
        x1 = x-b
        theta = self.unet4(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma5),Phi_r)
        x1 = x-b
        theta = self.unet5(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma6),Phi_r)
        x1 = x-b
        theta = self.unet6(x1)
        b = b- (x-theta)
        x_list.append(theta)
        ### 7-9
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma7),Phi_r)
        x1 = x-b
        theta = self.unet7(x1)
        output_depth = torch.argmax(theta,dim = 1)
        d_list.append(output_depth)
        
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma8),Phi_r)
        x1 = x-b
        theta = self.unet8(x1)
        output_depth = torch.argmax(theta,dim = 1)
        d_list.append(output_depth)
        
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma9),Phi_r)
        x1 = x-b
        theta = self.unet9(x1)
        output_depth = torch.argmax(theta,dim = 1)
        d_list.append(output_depth)
        b = b- (x-theta)
        x_list.append(theta)
        x_list.append(d_list[-3])
        x_list.append(d_list[-2])
        x_list.append(d_list[-1])
        
        output_list = x_list[-6:]
        
        return output_list

        
class ADMM_depthnet128(nn.Module):
    def __init__(self):
        super(ADMM_depthnet128, self).__init__()
        
        self.depthnet =  Unet_depth128(128,1)        
        self.unet1 = Unet128(128, 128)
        self.unet2 = Unet128(128, 128)
        self.unet3 = Unet128(128, 128)
        self.unet4 = Unet128(128, 128)
        self.unet5 = Unet128(128, 128)
        self.unet6 = Unet128(128, 128)
        self.unet7 = Unet128(128, 128)
        self.unet8 = Unet128(128, 128)
        self.unet9 = Unet128(128, 128)
        self.gamma1 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma2 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma3 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma4 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma5 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma6 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma7 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma8 = torch.nn.Parameter(torch.Tensor([0]))
        self.gamma9 = torch.nn.Parameter(torch.Tensor([0]))

    def forward(self, y, Phi, Phi_r, Phi_s):
        x_list = []
        d_list = []
        theta = At(y,Phi_r)
        b = torch.zeros_like(theta)
        ### 1-3
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma1),Phi_r)
        x1 = x-b
        theta = self.unet1(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma2),Phi_r)
        x1 = x-b
        theta = self.unet2(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma3),Phi_r)
        x1 = x-b
        theta = self.unet3(x1)
        b = b- (x-theta)
        x_list.append(theta)
        ### 4-6
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma4),Phi_r)
        x1 = x-b
        theta = self.unet4(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma5),Phi_r)
        x1 = x-b
        theta = self.unet5(x1)
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma6),Phi_r)
        x1 = x-b
        theta = self.unet6(x1)
        b = b- (x-theta)
        x_list.append(theta)
        ### 7-9
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma7),Phi_r)
        x1 = x-b
        theta = self.unet7(x1)
        output_depth = self.depthnet(theta)
        d_list.append(output_depth)
        
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma8),Phi_r)
        x1 = x-b
        theta = self.unet8(x1)
        output_depth = self.depthnet(theta)
        d_list.append(output_depth)
        
        b = b- (x-theta)
        x_list.append(theta)
        yb = A(theta+b,Phi_r)
        x = theta+b + At(torch.div(y-yb,Phi_s+self.gamma9),Phi_r)
        x1 = x-b
        theta = self.unet9(x1)
        output_depth = self.depthnet(theta)
        d_list.append(output_depth)
        b = b- (x-theta)
        x_list.append(theta)
        x_list.append(d_list[-3])
        x_list.append(d_list[-2])
        x_list.append(d_list[-1])
        
        output_list = x_list[-6:]
        
        return output_list


In [3]:
# matrix mulfunction
####################

def time2file_name(time):
    year = time[2:4]
    month = time[5:7]
    day = time[8:10]
    hour = time[11:13]
    minute = time[14:16]
    time_filename = year + '_' + month + '_' + day + '_' + hour + '_' + minute + '_'
    return time_filename

def A(xt,Phi_r):
    xt =  torch.unsqueeze(xt, 2).repeat(1,1,Phi_r.shape[2],1,1)
    yt = xt*Phi_r
    yt = torch.sum(yt,1)
    return yt

def At(y,Phi_r):
    temp = torch.unsqueeze(y, 1).repeat(1,Phi_r.shape[1],1,1,1).cuda()
    x = temp*Phi_r
    x = torch.sum(x,2)
    return x/4

def A0(x,Phi):
    temp = x*Phi
    y = torch.sum(temp,1)
    return torch.from_numpy(y)

def At0(y,Phi):
    temp = torch.unsqueeze(y, 1).repeat(1,Phi.shape[1],1,1)
    x = temp*Phi
    return x

# Define train and test
#######################

criterion  = nn.MSELoss()
criterion.cuda()
l1_loss = nn.L1Loss()
l1_loss.cuda()

def test(test_path, result_path, psnr_epoch,mask, mask_r, mask_s, noise_add= True, separate='separate', multi_path=False):
    
    test_list = os.listdir(test_path)
    test_list = sorted(test_list)
    psnr_sample = torch.zeros(len(test_list))
    pred = []
    outdepth = []
    
    for i in range(len(test_list)):
        with h5py.File(test_path + '/' + test_list[i],'r') as f:
            if multi_path:
                gt = torch.from_numpy(np.array(f["subframe128"]))
            else:
                gt = torch.from_numpy(np.array(f["subframe_ideal128"]))
            depth = torch.from_numpy(np.array(f["depth_1024"]))

        gt = torch.unsqueeze(gt.permute(0,1,3,2),0).cuda()#c,h,w
        gt = gt.float()
        
        if separate=='separate':
            coef = torch.tensor([1.0991, 1.3081, 0.6170, 1.0767])
            coef = torch.unsqueeze(coef,0)
            coef = torch.unsqueeze(coef,2)
            coef = torch.unsqueeze(coef,3)
            coef = torch.unsqueeze(coef,4)
            sg = gt.shape
            coefg = coef.repeat(sg[0],1,sg[2],sg[3],sg[4]).cuda()
            gt = gt*coefg
            gt = torch.sum(gt,1) 
            
        elif separate=='noseparate':
            a=1.3326
            gt = gt*a.cuda()
            gt = torch.sum(gt,1) 
            
        elif separate=='normal':
            gt = torch.sum(gt,1) 
        else:
            raise Exception('separate error')

        depth = torch.unsqueeze(depth.permute(1,0),0).cuda()#c,h,w
        depth = depth.float()
            
        if noise_add:
            gt_max = torch.max(gt)
            e_elec= 20000/gt_max
            noise_poisson =  torch.poisson(gt*e_elec)
            noise_gaussian = torch.normal(0,40,size = gt.shape).cuda()
            gt = (noise_poisson+noise_gaussian)/e_elec
            y = gt.detach()
        else:
            y=gt.detach()
            
        y = torch.unsqueeze(y,2)
        y=y.repeat([1,1,4,1,1])
        Phi = mask.repeat([1,1,1])
        Phi_s = mask_s.repeat([1, 1, 128, 128])
        Phi_r = mask_r.repeat([1, 1,1, 128, 128])
        
        y = y*Phi_r
        y = torch.sum(y, dim=1)       

        with torch.no_grad():
            out_pic_list = network(y, Phi, Phi_r, Phi_s)
            out_pic = out_pic_list[-4]
            out_depth = out_pic_list[-1]
            
            psnr_1 = 0
            for ii in range(frame_num):
                out_pic_p = out_pic[0, ii, :, :]
                gt_t = gt[0, ii, :, :]
                rmse = torch.sqrt(criterion(out_pic_p, gt_t))
                rmse = rmse.data
                psnr_1 += 10 * torch.log10(1 / criterion(out_pic_p, gt_t))
            psnr_1 = psnr_1 / (gt.shape[0] * frame_num)
            psnr_sample[i] = psnr_1   
            
        pred.append(out_pic.cpu().numpy())
        outdepth.append(out_depth.cpu().numpy())
        
    psnr_epoch.append(psnr_sample)
    return pred,outdepth, psnr_epoch

def train(epoch, learning_rate,mask, mask_r, mask_s,noise_add= True, separate='separate',loss_alpha=0.5,loss_beta_i=0.5,loss_beta_d=0.5,multi_path=False):

    epoch_loss = 0
    begin = time.time()
    if epoch<=500:
        optimizer = optim.Adam([{"params": network.depthnet.parameters(),"lr": depth_learning_rate*1},
                                  {'params': network.unet1.parameters()},
                                  {'params': network.unet2.parameters()},
                                  {'params': network.unet3.parameters()},
                                  {'params': network.unet4.parameters()},
                                  {'params': network.unet5.parameters()},
                                  {'params': network.unet6.parameters()},
                                  {'params': network.unet7.parameters()},
                                  {'params': network.unet8.parameters()},
                                  {'params': network.unet9.parameters()},
                                 ], lr=learning_rate*1)
    else:
        optimizer = optim.Adam([{"params": network.depthnet.parameters(),"lr": depth_learning_rate*1},
                                  {'params': network.unet1.parameters()},
                                  {'params': network.unet2.parameters()},
                                  {'params': network.unet3.parameters()},
                                  {'params': network.unet4.parameters()},
                                  {'params': network.unet5.parameters()},
                                  {'params': network.unet6.parameters()},
                                  {'params': network.unet7.parameters()},
                                  {'params': network.unet8.parameters()},
                                  {'params': network.unet9.parameters()},
                                 ], lr=learning_rate)

    if __name__ == '__main__':

        
        for iteration, batch in enumerate(train_data_loader):
            gt, depth = Variable(batch[0]), Variable(batch[1])
            gt = gt.cuda() # [batch,32,256,256]
            gt = gt.float()
            
            if separate=='separate':
                coef= torch.rand(4)+0.5
                coef = torch.unsqueeze(coef,0)
                coef = torch.unsqueeze(coef,2)
                coef = torch.unsqueeze(coef,3)
                coef = torch.unsqueeze(coef,4)
                sg = gt.shape
                coefg = coef.repeat(sg[0],1,sg[2],sg[3],sg[4]).cuda()
                gt = torch.sum(gt,1) 
            elif separate=='noseparate':
                a = torch.rand(1)+0.5
                gt = gt*a.cuda()
                gt = torch.sum(gt,1) 
            elif separate=='normal':
                gt = torch.sum(gt,1) 
            else:
                raise Exception('separate error')
                      
            depth = torch.unsqueeze(depth,1).cuda()
            depth = depth.float()
            batch_size1 = gt.shape[0]

            if noise_add:
                gt_max = torch.max(gt)
                e_elec= 20000/gt_max
                noise_poisson =  torch.poisson(gt*e_elec)
                noise_gaussian = torch.normal(0,40,size = gt.shape)
                gt = (noise_poisson+noise_gaussian)/e_elec
                y0 = gt.detach()
            else:
                y0=gt.detach()
                
            y0 = torch.unsqueeze(y0,2)
            y0=y0.repeat([1,1,4,1,1])
            
            Phi_r = mask_r.repeat([batch_size1, 1,1, 128, 128])
            Phi = mask.repeat([batch_size1,1,1])           
            Phi_s = torch.sum(Phi_r,dim=1) 
            
            y = y0*Phi_r
            y = torch.sum(y, dim=1)
            optimizer.zero_grad()
            model_out = network(y, Phi, Phi_r, Phi_s)

            Loss1 = torch.sqrt(criterion(model_out[-4], gt)) + 0.5*torch.sqrt(criterion(model_out[-5], gt)) + 0.5*torch.sqrt(criterion(model_out[-6], gt))
            Loss2 = torch.sqrt(criterion(model_out[-1],depth))+0.5*torch.sqrt(criterion(model_out[-2],depth))+0.5*torch.sqrt(criterion(model_out[-3],depth))
            Loss3 = l1_loss(model_out[-4], gt) + 0.5*l1_loss(model_out[-5], gt) + 0.5*l1_loss(model_out[-6], gt)
            Loss4 = l1_loss(model_out[-1],depth)+0.5*l1_loss(model_out[-2],depth)+0.5*l1_loss(model_out[-3],depth)
            Loss = Loss1*(1-loss_alpha)+Loss3*(1-loss_alpha)+Loss2*loss_alpha
            
            epoch_loss += Loss.data

            Loss.backward()
            optimizer.step()
            
    end = time.time()
    print("===> Epoch {} Complete: Avg. Loss: {:.7f}".format(epoch, epoch_loss / len(train_data_loader)), "  time: {:.2f}".format(end - begin))
    return (epoch_loss / len(train_data_loader),mask,mask_r,mask_s)

def checkpoint(epoch, model_path,mask,mask_r):
    model_out_path = './' + model_path + '/S{}'.format(stage_num) + "_model_epoch_{}.pth".format(epoch)
    model_out_path2 = './' + model_path + '/S{}'.format(stage_num) + "_model_epoch_{}_state_dict.pth".format(epoch)
    torch.save(network, model_out_path)
    torch.save(network.state_dict(), model_out_path2)
    print("Checkpoint saved to {}".format(model_out_path))

def main(learning_rate,depth_learning_rate,model_name,mask, mask_r, mask_s,noise_add= True, separate='separate',loss_alpha=0.5,loss_beta_i=0.5,loss_beta_d=0.5,multi_path=False):
    date_time = str(datetime.datetime.now())
    date_time = time2file_name(date_time)
    
    result_name = date_time +model_name
    result_path = 'recon' + '/' + result_name
    model_path = 'model' + '/' + date_time +model_name
    
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    if not os.path.exists(model_path):
        os.makedirs(model_path)
        
    file_para = open(model_path+r'/para.txt','w' )
    file_para.write('model_name'+':'+model_name+'\n')
    file_para.write('learning_rate'+':'+str(learning_rate)+'\n')
    file_para.write('depth_learning_rate'+':'+str(depth_learning_rate)+'\n')
    file_para.write('noise_add'+':'+str(noise_add)+'\n')
    file_para.write('separate'+':'+separate+'\n')
    file_para.write('loss_alpha'+':'+str(loss_alpha)+'\n')
    file_para.write('loss_beta_i'+':'+str(loss_beta_i)+'\n')
    file_para.write('loss_beta_d'+':'+str(loss_beta_d)+'\n')
    file_para.write('multi_path'+':'+str(multi_path)+'\n')
    file_para.close()
    
    print('model_name'+':'+model_name)
    print('learning_rate'+':{}'.format(learning_rate))
    print('depth_learning_rate'+':{}'.format(depth_learning_rate))
    print('noise_add'+':{}'.format(noise_add))
    print('separate'+':'+separate)
    print('loss_alpha'+':{}'.format(loss_alpha))
    print('loss_beta_i'+':{}'.format(loss_beta_i))
    print('loss_beta_d'+':{}'.format(loss_beta_d))
    print('multi_path'+':{}'.format(multi_path))
    
    psnr_epoch = []
    psnr_epoch2 = []
    loss_log = []
    psnr_max = 0
    checkpoint(0, model_path,mask,mask_r)

    for epoch in range(last_epoch + 1, last_epoch + max_iter + 1):
        train_out = train(epoch, learning_rate,mask, mask_r, mask_s,noise_add,separate,loss_alpha,loss_beta_i,loss_beta_d,multi_path)
        epoch_loss = train_out[0].detach()
        mask = train_out[1].cuda().detach()
        mask_r = train_out[2].cuda().detach()
        mask_s = train_out[3].cuda().detach()
        loss_log.append(epoch_loss)
        loss_log_n = torch.tensor(loss_log).to('cpu').detach().numpy().copy()
        np.save(r'./' + model_path+r'/loss_log',loss_log_n)
        np.savetxt(r'./' + model_path+r'/loss_log.txt',loss_log_n)
        
        if (epoch % 50 == 0) and (epoch < 500):
            learning_rate = learning_rate * 0.9
            depth_learning_rate = depth_learning_rate*0.9

        if epoch%100 == 0 or epoch==1:
            pred,outdepth, psnr_epoch = test(test_path1,  result_path, psnr_epoch,mask, mask_r, mask_s,noise_add,'separate',multi_path)
            print(psnr_epoch)
            psnr_mean = torch.mean(psnr_epoch[-1])
            print("Test result: {:.4f}".format(psnr_mean))
            checkpoint(epoch, model_path,mask,mask_r)
            name = result_path + '/S{}'.format(stage_num) + '_pred_' + '{}_{:.4f}'.format(epoch, psnr_mean) + '.mat'
            scio.savemat(name, {'pred': pred})
            name2 = result_path + '/S{}'.format(stage_num) + '_outdepth_' + '{}_{:.4f}'.format(epoch, psnr_mean) + '.mat'
            scio.savemat(name2, {'outdepth': outdepth})
            
        if math.isnan(epoch_loss):   
            break            

def generate_masks2(mask_path,shutter_bit,init_mask = 'real_opt_4x'):
    if init_mask == 'real_opt_4x':
        with h5py.File(mask_path + '/MAU19_optimized_shutter_606MHz_2Bit_4samp.mat', 'r') as f:
            mask =np.array(f["a"]) 

    over_rate = int(frame_num/shutter_bit)
    mask_r0 = np.repeat(mask,over_rate,axis = 0)    
    mask_r = np.zeros((frame_num, 4, 2, 2))
    p=0
    for i in range(2):
        for j in range(2):
            for k in range(4):
                mask_r[:,k,j,i] = mask_r0[:,p]
                p = p+1

    mask_s = np.sum(mask_r, axis=0)
    index = np.where(mask_s == 0)
    mask_s[index] = 1
    
    mask = mask.reshape([shutter_bit,4,4])
    mask = torch.from_numpy(mask)
    mask = mask.float()
    mask = mask.cuda()
    
    mask_s = torch.from_numpy(mask_s)
    mask_s = mask_s.float()
    mask_s = mask_s.cuda()
    
    mask_r = torch.from_numpy(mask_r)
    mask_r = mask_r.float()
    mask_r = mask_r.cuda()
    
    return mask, mask_r, mask_s


In [4]:
###################
     # TRAIN 
###################

In [5]:
# Define hyperparam
###################

# to use checkpoint
last_epoch=0
model_save_filename = ''

# traing param
max_iter = 1100
noise_add=False # out of services
batch_size =4
stage_num = 9
mode = 'train'  # train or test
learning_rate = 0.002
depth_learning_rate = 0.002
loss_alpha = 0.5 #the ratio of the enhancement reproduction and depth estimation.
loss_beta_i = 201
loss_beta_d = 200

# to change shutter patturns and datasets
data_path = r"./dataset/train128"  
test_path1 = r"./dataset/test128"  # simulation data for comparison
tap_num = 4
mask_path = r"./matlab"
init_mask = 'real_opt_4x'
separate='separate' #'separate' , 'noseparate' ,'normal'
frame_num = 128 #the lenght of transient images.
shutter_bit = frame_num
mask, mask_r, mask_s = generate_masks2(mask_path,shutter_bit,init_mask)

multi_path = True


In [None]:
# RUN training
##############

# network = ADMM_128().cuda()
network = ADMM_depthnet128().cuda()

if last_epoch != 0:
    network = torch.load(
        './model/' + model_save_filename + "/S{}_model_epoch_{}.pth".format(stage_num, last_epoch))
    
if multi_path:
    dataset = Imgdataset_multipath(data_path)
else:
    dataset = Imgdataset(data_path)

train_data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
mask_check_data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

model_name = 'Recons_depth_4x(pretrain)'

if __name__ == '__main__':
    main(learning_rate,depth_learning_rate,model_name,mask, mask_r, mask_s,noise_add,separate,loss_alpha,loss_beta_i,loss_beta_d,multi_path)
    

model_name:Recons_depth_4x(pretrain)
learning_rate:0.002
depth_learning_rate:0.002
noise_add:False
separate:separate
loss_alpha:0.5
loss_beta_i:201
loss_beta_d:200
multi_path:True
Checkpoint saved to ./model/25_02_09_16_45_Recons_depth_4x(pretrain)/S9_model_epoch_0.pth
===> Epoch 1 Complete: Avg. Loss: 0.8492954   time: 146.58
[tensor([19.1600, 19.4980, 10.8066, 17.1689, 17.4198, 14.3987, 10.3160, 13.0090,
        13.6839, 18.7218, 18.2228, 16.7289, 19.3871, 18.6625, 18.2957])]
Test result: 16.3653
Checkpoint saved to ./model/25_02_09_16_45_Recons_depth_4x(pretrain)/S9_model_epoch_1.pth
===> Epoch 2 Complete: Avg. Loss: 0.6977388   time: 143.01
===> Epoch 3 Complete: Avg. Loss: 0.5834948   time: 142.97
===> Epoch 4 Complete: Avg. Loss: 0.5928745   time: 143.12
===> Epoch 5 Complete: Avg. Loss: 0.5447168   time: 143.21
===> Epoch 6 Complete: Avg. Loss: 0.6296815   time: 142.99
===> Epoch 7 Complete: Avg. Loss: 0.5236457   time: 142.94
===> Epoch 8 Complete: Avg. Loss: 0.4972136   time: 

In [18]:
# Reproduction for real 
#######################
import glob
import os
import torch

# Check the avtive DNN
######################
print(torch.cuda.device_count())

GPU_number =2 # if you wanna use A5000, 2
torch.cuda.set_device(GPU_number)
print(torch.cuda.get_device_name())

if not torch.cuda.is_available():
    raise Exception('NO GPU!')

if not torch.cuda.is_available():
    raise Exception('NO GPU!')

# Choose Network
###################
a = ['25_02_09_16_45_Recons_depth_4x(pretrain)']
# network = ADMM_128().cuda()
network = ADMM_depthnet128().cuda()

# real scene test path
######################
calibration = 40000
test_path = r"./real"
mask_path = r"./matlab"
init_mask = 'real_opt_4x'
frame_num = 128 #the lenght of transient images.
shutter_bit = frame_num
mask, mask_r, mask_s = generate_masks2(mask_path,shutter_bit,init_mask)
Phi_r = mask_r.repeat([1, 1,1, 94, 106])
Phi = mask.repeat([1,1,1])
Phi_s = torch.sum(Phi_r,dim=1)
    
for model_name in a:
    print(model_name)
    result_path = 'recon/' + model_name
    if not os.path.exists(result_path):
        os.makedirs(result_path)
    psnr_epoch = []
    test_list = [f for f in os.listdir(test_path) if f.endswith('.mat')]
    print(test_list)
    psnr_sample = torch.zeros(len(test_list))
    pred = []
    outdepth = []
    for i_test in range(0,len(test_list)):
        print(i_test)
        print(test_path + '/' + test_list[i_test])
        gt = scio.loadmat(test_path + '/' + test_list[i_test])
        gt = gt['Gresult']
        calibration = 1.5*np.max(gt)
        gt = torch.from_numpy(np.array(gt))
        gt = torch.unsqueeze(gt.permute(2,0,1),0).float().cuda()/calibration#c,h,w
        y = gt
        torch.cuda.synchronize()
        start = time.time()
        with torch.no_grad():
            network.load_state_dict(torch.load(r'./model/' + model_name + '/S9_model_epoch_600_state_dict.pth')) # The best epoch iterations is 600 for real scene i think. 
            out_pic_list = network(y, Phi, Phi_r, Phi_s)
            torch.cuda.synchronize()
            elapsed_time = time.time() - start
            print(elapsed_time, 'sec.')
            
            out_pic = out_pic_list[-4] #K > 3
            out_depth = out_pic_list[-1] 
            pred.append(out_pic.cpu().numpy())
            outdepth.append(out_depth.cpu().numpy())
            name = result_path + '/real_{}'.format(int(calibration))+'reproduced_'+ test_list[i_test] 
            scio.savemat(name, {'pred': pred})
            name2 = result_path + '/real_{}'.format(int(calibration))+'depth_' + test_list[i_test]
            scio.savemat(name2, {'outdepth': outdepth})
            pred = []
            outdepth = []
    print('b')
        

3
NVIDIA RTX A5000
25_02_09_16_45_Recons_depth_4x(pretrain)
['1215_corner_wo.mat']
0
./real/1215_corner_wo.mat


  network.load_state_dict(torch.load(r'./model/' + model_name + '/S9_model_epoch_600_state_dict.pth'))


0.7010691165924072 sec.
b
