In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from util_model_v2 import DepthCompletionModule
import torchvision.transforms.functional as ttf
import numpy as np
from PIL import Image
from pathlib import Path
from sklearn.model_selection import train_test_split
import json
import time
from tqdm import tqdm
from collections import deque
import os

# Training / evaluation configuration shared by the code below.
batch_size = 8
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Valid ground-truth depth range: NyuDataset zeroes depth outside it and
# calc_error ignores those pixels. Same unit as NyuDataset's raw-value / 256
# scaling (presumably metres -- confirm).
min_depth, max_depth = 0.5, 7

# Fraction of image pixels kept as sparse-depth input: drawn uniformly from
# [min_dense_ratio, max_dense_ratio] for training, fixed for evaluation.
min_dense_ratio, max_dense_ratio = 0.15 / 100, 0.5 / 100
test_dense_ratio = 0.2 / 100


class NyuDataset(Dataset):
    """RGB + depth dataset for depth completion, indexed by ``./data/*.json``.

    Each JSON file holds a list of records, of which every 7th is kept. A record
    is a dict whose ``"ppm"`` / ``"pgm"`` values are paths to an 8-bit binary PPM
    (RGB) and a 16-bit binary PGM (depth) file. The JSON *files* are split
    95% / 5% into train / test with a fixed seed, so all records of one file land
    in the same split.

    ``__getitem__`` returns ``(rgb, sparse_depth, depth)``:
      * ``rgb``          -- float tensor (3, H, W), ImageNet mean/std normalised;
      * ``sparse_depth`` -- float tensor (1, H, W), a random subset of the valid
                            ``depth`` pixels, zeros elsewhere;
      * ``depth``        -- float tensor (1, H, W), raw value / 256 with values
                            outside ``[min_depth, max_depth]`` set to 0.

    Relies on the module-level ``batch_size``, ``min_depth``, ``max_depth``,
    ``min_dense_ratio``, ``max_dense_ratio`` and ``test_dense_ratio``.
    """

    def __init__(self, mode="train"):
        assert mode == "train" or mode == "test"
        self.mode = mode
        # Sorted so the seeded split below is reproducible: Path.glob order is
        # filesystem-dependent. NOTE: the resulting split may differ from runs
        # made before the sort was added.
        files = sorted(str(x) for x in Path("./data").glob("*.json"))
        train_files, test_files = train_test_split(files, test_size=0.05, random_state=0)
        self.files = train_files if self.mode == "train" else test_files
        self.dataset = []
        for file_path in self.files:
            with open(file_path, "r", encoding="utf-8") as f:
                # Keep every 7th record (presumably to thin out near-duplicate
                # consecutive frames).
                self.dataset.extend(json.load(f)[::7])
        self.dataset_length = len(self.dataset)
        # Truncate to a multiple of batch_size so every batch is full; the
        # trailing dataset_length % batch_size records are never served.
        self.dataset_length -= self.dataset_length % batch_size
        print(f"dataset {self.mode} with sample number {self.dataset_length}")

    @staticmethod
    def read_pgm(pgm_file_path):
        """Read a binary 16-bit PGM (``P5``, maxval 65535) into a PIL ``"I"`` image.

        Assumes the whole header ``P5 <width> <height> 65535`` is on the first
        line (no comments). Samples are read little-endian (``'<u2'``). The
        Netpbm spec stores 16-bit samples big-endian, so this presumably matches
        how these particular files were written.
        NOTE(review): confirm the byte order.
        """
        with open(pgm_file_path, 'rb') as pgm_file:
            p5, width, height, depth = pgm_file.readline().split()
            assert p5 == b'P5'
            assert depth == b'65535'
            width, height = int(width), int(height)
            data = np.fromfile(pgm_file, dtype='<u2', count=width * height)
            data = data.reshape([height, width]).astype(np.uint32)
            return Image.fromarray(data, mode="I")

    @staticmethod
    def read_ppm(ppm_file_path):
        """Read a binary 8-bit PPM (``P6``, maxval 255) into a PIL ``"RGB"`` image.

        Assumes the whole header ``P6 <width> <height> 255`` is on the first line.
        """
        with open(ppm_file_path, 'rb') as ppm_file:
            p6, width, height, depth = ppm_file.readline().split()
            assert p6 == b'P6'
            assert depth == b'255'
            width, height = int(width), int(height)
            data = np.fromfile(ppm_file, dtype=np.uint8, count=width * height * 3)
            data = data.reshape([height, width, 3])
            return Image.fromarray(data, mode="RGB")

    def transform_fn(self, rgb, depth):
        """Augment (train only) and convert a PIL RGB / depth pair to tensors.

        Train mode applies a random horizontal flip to both images. It also
        applies random brightness / contrast / saturation factors in [0.9, 1.1]
        to the RGB image.

        In both modes, RGB becomes a float tensor in [0, 1] and is then
        ImageNet-normalised. Depth becomes float, is divided by 256, and is
        zeroed outside ``[min_depth, max_depth]``.
        """
        if self.mode == "train":
            if np.random.uniform() < 0.5:
                rgb = ttf.hflip(rgb)
                depth = ttf.hflip(depth)

            brightness = np.random.uniform(0.9, 1.1)
            contrast = np.random.uniform(0.9, 1.1)
            saturation = np.random.uniform(0.9, 1.1)
            rgb = ttf.adjust_brightness(rgb, brightness)
            rgb = ttf.adjust_contrast(rgb, contrast)
            rgb = ttf.adjust_saturation(rgb, saturation)

        rgb = ttf.to_tensor(rgb)  # uint8 HWC -> float CHW in [0, 1]
        # to_tensor leaves the 32-bit "I" depth image unscaled (integer dtype).
        depth = ttf.to_tensor(depth).float()
        # Raw value / 256 -- presumably the files' fixed-point depth encoding;
        # min_depth / max_depth are expressed in the resulting unit.
        depth /= 256
        depth[depth < min_depth] = 0
        depth[depth > max_depth] = 0

        rgb = ttf.normalize(rgb, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        return rgb, depth

    def generate_sparse_depth(self, full_depth, num_sample):
        """Keep ``num_sample`` random pixels of ``full_depth`` whose value is > 0.01.

        ``full_depth`` must be 3-D (indexed as ``[c, h, w]``). If fewer candidate
        pixels exist than requested, all of them are kept. Returns a tensor of the
        same shape, zero everywhere except at the chosen pixels.
        """
        idx_candidate = torch.nonzero(full_depth > 1e-2)
        idx_sample = torch.randperm(len(idx_candidate))[:num_sample]
        idx_selected = idx_candidate[idx_sample]
        mask = torch.zeros_like(full_depth)
        mask[idx_selected[:, 0], idx_selected[:, 1], idx_selected[:, 2]] = 1.
        sparse_depth = mask * full_depth
        return sparse_depth

    def __getitem__(self, idx):
        """Load, augment and sparsify record ``idx``.

        If anything fails (missing/corrupt file, bad header, ...), a warning
        naming the record is emitted. An all-zero 480x640 sample is returned
        instead, so one bad file does not kill a long training run. Its depth has
        no valid pixels, so calc_error ignores it.
        """
        sample = self.dataset[idx]
        try:
            rgb = self.read_ppm(sample["ppm"])
            depth = self.read_pgm(sample["pgm"])
            rgb, depth = self.transform_fn(rgb, depth)
            # Sample counts are relative to the actual image size.
            num_pixels = depth.shape[-2] * depth.shape[-1]
            if self.mode == "train":
                num_sample = int(np.random.uniform(num_pixels * min_dense_ratio, num_pixels * max_dense_ratio))
            else:
                num_sample = int(num_pixels * test_dense_ratio)
            sparse_depth = self.generate_sparse_depth(depth, num_sample=num_sample)
        except Exception as e:
            import warnings
            warnings.warn(f"NyuDataset: failed to load sample {idx} ({sample!r}): {e!r}; returning zeros")
            rgb = torch.zeros([3, 480, 640])
            depth = torch.zeros([1, 480, 640])
            sparse_depth = torch.zeros([1, 480, 640])
        return rgb, sparse_depth, depth

    def __len__(self):
        """Number of samples served (truncated to a multiple of ``batch_size``)."""
        return self.dataset_length


def calc_error(depth_pred, depth_target):
    """Compute depth-completion metrics and the training loss over valid pixels.

    A pixel is valid when ``min_depth < depth_target < max_depth`` (module-level
    constants). All metrics are averaged over valid pixels only.

    Args:
        depth_pred: predicted depth, same shape as ``depth_target``.
        depth_target: ground-truth depth; missing pixels are expected to be 0 so
            they fall outside the valid range.

    Returns:
        dict with float entries ``MSE``, ``RMSE``, ``MAE``, ``ABS_REL``,
        ``DELTA1.02`` and ``DELTA1.05``. ``DELTA<t>`` is the fraction of valid
        pixels with ``max(pred/target, target/pred) < t``. The dict also holds
        ``loss``: a scalar tensor (MSE between the square roots of prediction and
        target) that keeps autograd history.

        If no pixel is valid, the metrics are 0 and ``loss`` is a zero tensor
        still connected to ``depth_pred``. Without this, ``F.mse_loss`` on empty
        tensors returns NaN, and backward would poison the weights.
    """
    mask = torch.logical_and(torch.gt(depth_target, min_depth), torch.lt(depth_target, max_depth))
    depth_pred = depth_pred[mask]
    depth_target = depth_target[mask]
    # The epsilon keeps the divisions below finite when no pixel is valid.
    n_valid_element = depth_target.shape[0] + 1e-4

    # Metrics are for reporting only: skip autograd so no extra per-pixel
    # buffers are kept alive until backward().
    with torch.no_grad():
        diff_mat = torch.abs(depth_pred - depth_target)
        rel_mat = torch.div(diff_mat, depth_target)
        max_ratio = torch.max(torch.div(depth_target, depth_pred), torch.div(depth_pred, depth_target))
        error = {}
        error["MSE"] = torch.sum(torch.pow(diff_mat, 2)) / n_valid_element
        error["RMSE"] = torch.sqrt(error["MSE"])
        error['MAE'] = torch.sum(diff_mat) / n_valid_element
        error['ABS_REL'] = torch.sum(rel_mat) / n_valid_element
        error['DELTA1.02'] = torch.sum(max_ratio < 1.02) / float(n_valid_element)
        error['DELTA1.05'] = torch.sum(max_ratio < 1.05) / float(n_valid_element)
    error = {K: V.item() for K, V in error.items()}

    if depth_target.numel() == 0:
        # Graph-connected zero: backward() then yields zero gradients, not NaN.
        error['loss'] = depth_pred.sum() * 0.0
    else:
        error['loss'] = F.mse_loss(torch.sqrt(depth_pred), torch.sqrt(depth_target))

    return error


# Metric keys produced by calc_error, in progress-bar display order.
_METRIC_KEYS = ("loss", "MSE", "RMSE", "MAE", "ABS_REL", "DELTA1.02", "DELTA1.05")


def _summarize(records):
    """Mean of every metric in ``_METRIC_KEYS`` over an iterable of per-batch dicts."""
    return {key: float(np.mean([record[key] for record in records])) for key in _METRIC_KEYS}


def _format_metrics(summary):
    """Render a metric summary as comma-separated ``key=value`` pairs for tqdm."""
    return ",".join(f"{key}={summary[key]:0.8f}" for key in _METRIC_KEYS)


def train(model, dataloader, epoch, optimizer):
    """Train ``model`` for one epoch over ``dataloader``.

    Each batch ``(rgb, sparse_depth, depth_target)`` is moved to the module-level
    ``device`` and optimised on ``calc_error(...)["loss"]``. The progress bar
    shows metric means over the most recent 100 batches.

    Returns:
        dict of those rolling-window metric means after the last batch
        (empty if the dataloader yielded nothing).
    """
    # Short pause, presumably so earlier console output is flushed before the
    # tqdm bar starts redrawing.
    time.sleep(0.2)
    model.train()
    recent = deque([], maxlen=100)
    pbar = tqdm(dataloader)
    pbar.set_description("train epoch {}".format(epoch))
    for rgb, sparse_depth, depth_target in pbar:
        optimizer.zero_grad()
        rgb, sparse_depth, depth_target = rgb.to(device), sparse_depth.to(device), depth_target.to(device)
        depth_pred = model(rgb, sparse_depth)
        error = calc_error(depth_pred, depth_target)
        loss = error["loss"]
        loss.backward()
        optimizer.step()

        error["loss"] = loss.item()
        recent.append(error)
        pbar.set_postfix_str(_format_metrics(_summarize(recent)))
    return _summarize(recent) if recent else {}


def test(model, dataloader, epoch):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    The progress bar shows running metric means over all batches seen so far.

    Returns:
        dict of metric means over the whole dataloader (empty if it yielded
        nothing).
    """
    # Short pause, presumably so earlier console output is flushed before the
    # tqdm bar starts redrawing.
    time.sleep(0.2)
    model.eval()
    records = []
    pbar = tqdm(dataloader)
    pbar.set_description("test epoch {}".format(epoch))
    with torch.no_grad():
        for rgb, sparse_depth, depth_target in pbar:
            rgb, sparse_depth, depth_target = rgb.to(device), sparse_depth.to(device), depth_target.to(device)
            depth_pred = model(rgb, sparse_depth)
            error = calc_error(depth_pred, depth_target)
            error["loss"] = error["loss"].item()
            records.append(error)
            pbar.set_postfix_str(_format_metrics(_summarize(records)))
    return _summarize(records) if records else {}


if __name__ == '__main__':
    # Create the checkpoint directory up front. Otherwise torch.save would fail
    # only after the first (multi-hour) epoch has finished.
    checkpoint_dir = Path("./model_5")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    dataset_train = NyuDataset(mode="train")
    dataset_test = NyuDataset(mode="test")

    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
    # Evaluation averages over the whole split, so shuffling buys nothing.
    dataloader_test = DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)

    model = DepthCompletionModule()
    # To resume: model.load_state_dict(torch.load("./model_1/model_3.pth", map_location="cpu"))
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(100):
        train(model, dataloader_train, epoch, optimizer)
        test(model, dataloader_test, epoch)
        # state_dict does not depend on train/eval mode, so no mode switch is needed.
        torch.save(model.state_dict(), checkpoint_dir / f"model_2_{epoch}.pth")


dataset train with sample number 64824
dataset test with sample number 3080


train epoch 0: 100%|██████████| 8103/8103 [2:24:39<00:00,  1.07s/it, loss=0.00012583,MSE=0.00165993,RMSE=0.04012409,MAE=0.01413243,ABS_REL=0.00410362,DELTA1.02=0.95942058,DELTA1.05=0.98908504]  
test epoch 0: 100%|██████████| 385/385 [02:18<00:00,  2.78it/s, loss=0.00011121,MSE=0.00150803,RMSE=0.03836651,MAE=0.01345236,ABS_REL=0.00381847,DELTA1.02=0.96408651,DELTA1.05=0.98984460]
train epoch 1: 100%|██████████| 8103/8103 [2:24:34<00:00,  1.07s/it, loss=0.00010477,MSE=0.00138307,RMSE=0.03604351,MAE=0.01213178,ABS_REL=0.00350645,DELTA1.02=0.96862254,DELTA1.05=0.99178981]  
test epoch 1: 100%|██████████| 385/385 [02:20<00:00,  2.75it/s, loss=0.00007760,MSE=0.00104721,RMSE=0.03187641,MAE=0.01069824,ABS_REL=0.00305733,DELTA1.02=0.97369308,DELTA1.05=0.99314845]
train epoch 2: 100%|██████████| 8103/8103 [2:24:36<00:00,  1.07s/it, loss=0.00009082,MSE=0.00119573,RMSE=0.03370765,MAE=0.01123372,ABS_REL=0.00323854,DELTA1.02=0.97234880,DELTA1.05=0.99295590]  
test epoch 2: 100%|██████████| 385/385 

KeyboardInterrupt: 

In [6]:
import torch
import torch.nn as nn
from torchvision import models


class ResnetBackBone(nn.Module):
    """Encoder made of the four residual stages of torchvision's pretrained ResNet-18.

    The ResNet stem (conv1 / bn1 / relu / maxpool) is not used. Callers feed a
    64-channel feature map straight into ``layer1``. ``forward`` returns the
    outputs of ``layer2``, ``layer3`` and ``layer4``, named by their downsampling
    factor relative to the input.
    """

    def __init__(self):
        super().__init__()
        resnet = models.resnet18(pretrained=True)
        # Attribute names become state_dict keys -- keep them for checkpoints.
        self.conv_1, self.conv_2, self.conv_3, self.conv_4 = (
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
        )

    def forward(self, x):
        half = self.conv_2(self.conv_1(x))
        quarter = self.conv_3(half)
        eighth = self.conv_4(quarter)
        return half, quarter, eighth


class CovBnRelu(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> in-place ReLU.

    Conv2d arguments are passed through unchanged. The conv has no bias because
    the following BatchNorm adds its own.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class AttentionRefinementModule(nn.Module):
    """3x3 CovBnRelu followed by per-channel gating.

    The gate is ``sigmoid(BN(conv1x1(spatial mean of the features))))``: one
    weight per sample and channel, multiplied back onto the features.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_1 = CovBnRelu(in_channels, out_channels)
        self.conv_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0, bias=False)
        self.bn_attention = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        features = self.conv_1(x)
        pooled = features.mean(dim=(2, 3), keepdim=True)
        gate = self.bn_attention(self.conv_attention(pooled)).sigmoid()
        return features * gate


class DecovBnRelu(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, concat an encoder skip, refine.

    ``x_up`` is upsampled to ``in_channels_up // 2`` channels at twice its
    spatial size. The result is concatenated on the channel axis with
    ``x_encoder`` (``in_channels_encoder`` channels, which must already be at the
    doubled size). An AttentionRefinementModule then maps it to ``out_channels``.
    """

    def __init__(self, in_channels_up, in_channels_encoder, out_channels):
        super().__init__()
        halved = in_channels_up // 2
        # kernel 4 / stride 2 / padding 1 gives exactly 2x spatial size.
        self.conv_1 = nn.ConvTranspose2d(in_channels_up, halved, kernel_size=4, stride=2, padding=1)
        self.conv_2 = AttentionRefinementModule(halved + in_channels_encoder, out_channels)

    def forward(self, x_up, x_encoder):
        upsampled = self.conv_1(x_up)
        merged = torch.cat((upsampled, x_encoder), dim=1)
        return self.conv_2(merged)


class DepthEstimationModule(nn.Module):
    """RGB-only depth regressor: encoder-decoder with skip connections, exp output.

    ``kernel_size`` is accepted but unused (presumably kept for signature parity
    with DepthCompletionModule). The output is a strictly positive 1-channel
    depth map at input resolution.
    """

    def __init__(self, kernel_size=5):
        super().__init__()
        # Creation order fixes state_dict keys and parameter order -- keep it.
        self.conv_rgb = CovBnRelu(3, 64)
        self.afm_1 = AttentionRefinementModule(64, 64)
        self.backbone = ResnetBackBone()
        self.conv_8_16 = CovBnRelu(512, 512, stride=2)

        self.decoder_16_8 = DecovBnRelu(512, 512, 256)
        self.decoder_8_4 = DecovBnRelu(256, 256, 128)
        self.decoder_4_2 = DecovBnRelu(128, 128, 64)
        self.decoder_2_1 = DecovBnRelu(64, 64, 128)

        self.out_layer_depth = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, rgb):
        stem = self.afm_1(self.conv_rgb(rgb))
        skip2, skip4, skip8 = self.backbone(stem)
        up = self.decoder_16_8(self.conv_8_16(skip8), skip8)
        up = self.decoder_8_4(up, skip4)
        up = self.decoder_4_2(up, skip2)
        up = self.decoder_2_1(up, stem)
        # exp keeps the predicted depth strictly positive.
        return torch.exp(self.out_layer_depth(up))


class DepthCompletionModule(nn.Module):
    """Two-stage depth completion: coarse encoder-decoder depth + affinity propagation.

    Stage 1 encodes ``rgb`` (3 channels) and ``sparse_depth`` (1 channel) into
    48 + 16 = 64 channels. These pass through the ResNet-18 stages and a decoder
    with skip connections, which predicts two things:
      * a strictly positive coarse depth map (``exp`` of a 1-channel conv);
      * per-pixel affinities over a ``kernel_size x kernel_size`` neighbourhood
        (softmax over ``kernel_size**2`` channels, so they sum to 1 per pixel).

    Stage 2 repeats ``num_iterations`` times:
      1. Pixels with a sparse measurement (> 1e-3) are overwritten by it.
      2. Every pixel is replaced by the affinity-weighted sum of its neighbourhood.
    Borders are zero-padded by ``nn.Unfold``.

    Args:
        kernel_size: propagation neighbourhood size (odd, so that padding
            ``kernel_size // 2`` preserves the spatial size).
        num_iterations: number of propagation steps (default 12, the original
            hard-coded value). A plain attribute, not part of the state_dict.

    Input height and width must be divisible by 16. The encoder halves
    resolution four times, and each decoder upsampling must match its skip
    connection. The smoke test below uses 480 x 640.
    """

    def __init__(self, kernel_size=5, num_iterations=12):
        super(DepthCompletionModule, self).__init__()
        # Submodule names/order define state_dict keys -- keep them stable.
        self.conv_rgb = CovBnRelu(3, 48)
        self.conv_depth = CovBnRelu(1, 16)
        self.afm_1 = AttentionRefinementModule(64, 64)
        self.backbone = ResnetBackBone()
        self.conv_8_16 = CovBnRelu(512, 512, stride=2)

        self.decoder_16_8 = DecovBnRelu(in_channels_up=512, in_channels_encoder=512, out_channels=256)
        self.decoder_8_4 = DecovBnRelu(in_channels_up=256, in_channels_encoder=256, out_channels=128)
        self.decoder_4_2 = DecovBnRelu(in_channels_up=128, in_channels_encoder=128, out_channels=64)
        self.decoder_2_1 = DecovBnRelu(in_channels_up=64, in_channels_encoder=64, out_channels=128)

        self.out_layer_depth = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1)
        self.out_layer_affinity = nn.Conv2d(128, kernel_size * kernel_size, kernel_size=3, stride=1, padding=1)

        # Gathers each pixel's kernel_size x kernel_size neighbourhood (zero padded).
        self.layer_unfold = nn.Unfold(kernel_size=kernel_size, dilation=1, padding=kernel_size // 2)
        self.num_iterations = num_iterations

    def forward(self, rgb, sparse_depth):
        """Return refined depth with the same (B, 1, H, W) layout as ``sparse_depth``."""
        # Stage 1: coarse depth and affinities at input resolution.
        feat1 = torch.cat([self.conv_rgb(rgb), self.conv_depth(sparse_depth)], dim=1)
        feat1 = self.afm_1(feat1)
        feat2, feat4, feat8 = self.backbone(feat1)
        feat16 = self.conv_8_16(feat8)
        feat8_up = self.decoder_16_8(feat16, feat8)
        feat4_up = self.decoder_8_4(feat8_up, feat4)
        feat2_up = self.decoder_4_2(feat4_up, feat2)
        feat_out = self.decoder_2_1(feat2_up, feat1)  # (B, 128, H, W)

        coarse_depth = torch.exp(self.out_layer_depth(feat_out))  # (B, 1, H, W), > 0
        affinity = torch.softmax(self.out_layer_affinity(feat_out), dim=1)  # (B, k*k, H, W)

        # Stage 2: iterative affinity propagation.
        mask = torch.gt(sparse_depth, 1e-3).float()
        known = mask * sparse_depth  # loop-invariant, hoisted
        unknown = 1 - mask
        refined_depth = coarse_depth
        for _ in range(self.num_iterations):
            refined_depth = known + unknown * refined_depth
            # (B, k*k, H*W) -> (B, k*k, H, W); valid because depth has 1 channel.
            neighbours = self.layer_unfold(refined_depth).reshape(affinity.shape)
            refined_depth = torch.sum(neighbours * affinity, dim=1, keepdim=True)

        # NOTE(review): measurements are re-injected before each step but not
        # after the last one, so the output at sparse pixels is the propagated
        # value, not the measurement itself.
        return refined_depth


if __name__ == '__main__':
    # Output-shape smoke test on random inputs. The demo_* names keep this cell
    # from overwriting the training cell's globals (`batch_size`, `model`),
    # because all notebook cells share one namespace.
    demo_batch_size = 4
    demo_model = DepthCompletionModule()
    demo_rgb = torch.randn([demo_batch_size, 3, 480, 640])
    demo_sparse_depth = torch.randn([demo_batch_size, 1, 480, 640])
    # Only the shape is inspected, so skip autograd: the full-resolution decoder
    # activations would otherwise be kept for a backward pass that never runs.
    with torch.no_grad():
        depth_predict = demo_model(demo_rgb, demo_sparse_depth)
    print("depth_predict", depth_predict.shape)


depth_predict torch.Size([4, 1, 480, 640])


In [7]:
# Show current GPU utilisation and memory usage on this machine.
!nvidia-smi

Wed Jan 19 08:46:16 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 465.27       CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Tesla V1...  Off  | 00000000:3B:00.0 Off |                    0 |
| N/A   40C    P0    35W / 250W |  29235MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces