In [4]:
import json
import os
import time
import warnings
from collections import deque
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as ttf
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm

# Mini-batch size handed to both DataLoaders in the training script below.
batch_size = 8
# Use the GPU when CUDA is available, otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ResnetBackBone(nn.Module):
    """Pretrained torchvision ResNet-18 trunk returning its last three stage outputs."""

    def __init__(self):
        super().__init__()
        resnet = models.resnet18(pretrained=True)
        children = list(resnet.children())
        # The first four children are applied as the stem
        # (presumably conv / bn / relu / max-pool in torchvision's ResNet -- confirm).
        self.input_pool = nn.Sequential(*children[:4])
        # Each nn.Sequential child is taken as one residual stage; the unpacking
        # below requires that there are exactly four of them.
        stages = [child for child in children if isinstance(child, nn.Sequential)]
        self.down_block_1, self.down_block_2, self.down_block_3, self.down_block_4 = stages

    def forward(self, x):
        """Return ``(feat8, feat16, feat32)``: the outputs of stages 2, 3 and 4.

        The numeric suffixes presumably denote the down-sampling factor relative
        to the input image -- confirm against torchvision's ResNet definition.
        """
        stem_out = self.input_pool(x)
        stage1_out = self.down_block_1(stem_out)
        feat8 = self.down_block_2(stage1_out)
        feat16 = self.down_block_3(feat8)
        feat32 = self.down_block_4(feat16)
        return feat8, feat16, feat32


class CovBnRelu(nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation and an in-place ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels (also the BatchNorm width).
        kernel_size, stride, padding, dilation: Forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,  # the BatchNorm that follows has its own bias term
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        return self.relu(self.bn(self.conv(x)))


class AttentionRefinementModule(nn.Module):
    """Re-weights each channel of a feature map with a gate derived from its spatial mean.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the 3x3 conv block and kept by the gate.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_1 = CovBnRelu(in_channels, out_channels)
        self.conv_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0, bias=False)
        self.bn_attention = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``conv_1(x)`` scaled per channel by a sigmoid gate in (0, 1)."""
        feat = self.conv_1(x)
        # Collapse H and W to 1x1 so the gate carries one value per channel.
        pooled = feat.mean(dim=(2, 3), keepdim=True)
        gate = torch.sigmoid(self.bn_attention(self.conv_attention(pooled)))
        # The 1x1 gate broadcasts over the spatial dimensions of ``feat``.
        return feat * gate


class ContextPath(nn.Module):
    """Context branch: ResNet features refined by attention and merged coarse-to-fine."""

    def __init__(self):
        super().__init__()
        self.backbone = ResnetBackBone()
        # 1x1 conv block applied to the spatially averaged deepest feature map.
        self.conv_avg = CovBnRelu(512, 128, kernel_size=1, stride=1, padding=0)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.arm16 = AttentionRefinementModule(256, 128)
        # nn.Upsample with its default interpolation mode, doubling H and W.
        self.up32 = nn.Upsample(scale_factor=2.0)
        self.up16 = nn.Upsample(scale_factor=2.0)
        self.conv_head32 = CovBnRelu(128, 128)
        self.conv_head16 = CovBnRelu(128, 128)

    def forward(self, x):
        """Return ``(feat16_up, feat32_up)``, both with 128 channels.

        ``feat32_up`` is the refined deepest feature after one 2x up-sampling;
        ``feat16_up`` is the merged feature after a further 2x up-sampling.
        The original annotations give 30x40 and 60x80 respectively for what
        appears to be a 480x640 input.
        """
        # The first backbone output is not used by this branch.
        _, feat16, feat32 = self.backbone(x)

        # Global context: spatial mean of the deepest map, kept as a 1x1 map.
        global_ctx = self.conv_avg(feat32.mean(dim=(2, 3), keepdim=True))

        # Deepest level: refine, add the broadcast global context, upsample, smooth.
        feat32_up = self.conv_head32(self.up32(self.arm32(feat32) + global_ctx))

        # Middle level: refine, add the upsampled deeper feature, upsample, smooth.
        feat16_up = self.conv_head16(self.up16(self.arm16(feat16) + feat32_up))
        return feat16_up, feat32_up


class SpatialPath(nn.Module):
    """Shallow branch: three stride-2 conv blocks followed by a 1x1 projection to 128 channels."""

    def __init__(self):
        super().__init__()
        self.conv_1 = CovBnRelu(3, 64, kernel_size=7, stride=2, padding=3)
        self.conv_2 = CovBnRelu(64, 64, kernel_size=3, stride=2, padding=1)
        self.conv_3 = CovBnRelu(64, 64, kernel_size=3, stride=2, padding=1)
        self.conv_out = CovBnRelu(64, 128, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run the four conv blocks in sequence.

        The original annotation gives an output of ``[4, 128, 60, 80]``, i.e.
        the three stride-2 blocks reduce H and W by a factor of 8 overall.
        """
        feat = x
        for layer in (self.conv_1, self.conv_2, self.conv_3, self.conv_out):
            feat = layer(feat)
        return feat


class FeatureFusionModule(nn.Module):
    """Fuses two feature maps by concatenation, a 1x1 conv block and a per-channel gate.

    Args:
        in_channels: Channel count of the concatenated inputs.
        out_channels: Channel count of the fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_1 = CovBnRelu(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv_atten = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_channels)

    def forward(self, fsp, fcp):
        """Return the fused map plus its gated copy.

        Args:
            fsp: First feature map (the spatial-path output at the call site in this file).
            fcp: Second feature map (the context-path output at the call site in this file).
                Both are concatenated along dim 1, so their other dimensions must match.
        """
        fused = self.conv_1(torch.cat((fsp, fcp), dim=1))
        # One gate value per channel, computed from the spatial mean of the fused map.
        gate = fused.mean(dim=(2, 3), keepdim=True)
        gate = torch.sigmoid(self.bn_atten(self.conv_atten(gate)))
        # Residual form: gated features are added back onto the un-gated ones.
        return fused * gate + fused


class BiSeNetOutput(nn.Module):
    """Prediction head: 3x3 conv block, 1x1 projection, then bilinear up-sampling.

    Args:
        in_channels: Channels of the incoming feature map.
        mid_channels: Channels after the 3x3 conv block.
        out_channels: Channels of the final prediction.
        up_factor: Scale factor handed to ``nn.Upsample``.
    """

    def __init__(self, in_channels, mid_channels, out_channels, up_factor):
        super().__init__()
        self.conv = CovBnRelu(in_channels, mid_channels)
        self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
        self.up = nn.Upsample(scale_factor=up_factor, mode="bilinear", align_corners=False)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` and enlarge it by ``up_factor``."""
        return self.up(self.conv_out(self.conv(x)))


class BiSeNetV1(nn.Module):
    """Two-branch network producing a single-channel map at the input resolution.

    The context and spatial branches are fused and passed through an 8x
    up-sampling head; the result is exponentiated, so the output is never
    negative. The training code in this file treats it as a depth prediction.
    """

    def __init__(self):
        super().__init__()
        self.cp = ContextPath()
        self.sp = SpatialPath()
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out_8 = BiSeNetOutput(256, 256, 1, up_factor=8)

    def forward(self, x):
        """Return ``exp(head(fuse(spatial(x), context(x))))``."""
        # Only the first context-path output is fused; the second one is unused here.
        context_feat, _unused_context = self.cp(x)
        spatial_feat = self.sp(x)
        fused = self.ffm(spatial_feat, context_feat)
        raw_out = self.conv_out_8(fused)
        return raw_out.exp()


class NyuDataset(Dataset):
    """RGB / depth pairs listed by the JSON index files found under ``./data``.

    Every ``*.json`` file is expected to hold a list of records with a ``"ppm"``
    (colour image path) and a ``"pgm"`` (16-bit depth image path) entry. The index
    files are split 95 % / 5 % into train and test, and every 5th record of each
    selected file is kept.

    Args:
        mode: ``"train"`` (random augmentation enabled) or ``"test"``.
    """

    def __init__(self, mode="train"):
        assert mode == "train" or mode == "test"
        self.mode = mode
        # glob() yields files in a filesystem-dependent order, and the seeded
        # split depends on the input order. Sorting makes the train/test
        # partition reproducible across machines and runs.
        files = sorted(str(x) for x in Path("./data").glob("*.json"))
        train_files, test_files = train_test_split(files, test_size=0.05, random_state=0)
        self.files = train_files if self.mode == "train" else test_files
        self.dataset = []
        for file_path in self.files:
            with open(file_path, "r", encoding="utf-8") as f:
                # Keep only every 5th record of each index file.
                self.dataset.extend(json.load(f)[::5])
        self.dataset_length = len(self.dataset)
        print(f"dataset {self.mode} with sample number {self.dataset_length}")

    @staticmethod
    def _uniform(low=0.0, high=1.0):
        """Draw one float uniformly from ``[low, high)`` using torch's RNG.

        torch's generator is seeded separately in every DataLoader worker,
        whereas NumPy's global RNG state can be duplicated across workers
        (depending on the PyTorch version), which would make all workers apply
        the same augmentation sequence.
        """
        return torch.empty(1).uniform_(low, high).item()

    @staticmethod
    def read_pgm(pgm_file_path):
        """Read a binary 16-bit PGM file into a PIL image of mode ``"I"``.

        The header must sit on a single line made of four whitespace-separated
        tokens: ``P5``, width, height and the max value ``65535``.

        Raises:
            AssertionError: If the magic number or max value differ from the above.
        """
        with open(pgm_file_path, 'rb') as pgm_file:
            p5, width, height, depth = pgm_file.readline().split()
            assert p5 == b'P5'
            assert depth == b'65535'
            width, height = int(width), int(height)
            # NOTE(review): samples are read as little-endian uint16, whereas the
            # Netpbm convention is most-significant-byte first -- presumably this
            # matches how the dataset files were written; confirm.
            data = np.fromfile(pgm_file, dtype='<u2', count=width * height)
            data = data.reshape([height, width]).astype(np.uint32)
            return Image.fromarray(data, mode="I")

    @staticmethod
    def read_ppm(ppm_file_path):
        """Read a binary 8-bit PPM file into a PIL image of mode ``"RGB"``.

        The header must sit on a single line made of four whitespace-separated
        tokens: ``P6``, width, height and the max value ``255``.

        Raises:
            AssertionError: If the magic number or max value differ from the above.
        """
        with open(ppm_file_path, 'rb') as ppm_file:
            p6, width, height, depth = ppm_file.readline().split()
            assert p6 == b'P6'
            assert depth == b'255'
            width, height = int(width), int(height)
            data = np.fromfile(ppm_file, dtype=np.uint8, count=width * height * 3)
            data = data.reshape([height, width, 3])
            return Image.fromarray(data, mode="RGB")

    def transform_fn(self, rgb, depth):
        """Augment (train mode only) and convert a PIL RGB / depth pair to tensors.

        Returns:
            ``(rgb, depth)`` float tensors. ``rgb`` is normalised with the mean /
            std given below; ``depth`` is the raw value divided by 1000
            (presumably millimetres to metres -- confirm).
        """
        if self.mode == "train":
            # Geometric augmentations are applied identically to both images so
            # that pixels stay aligned.
            if self._uniform(0.0, 1.0) < 0.5:
                rgb = ttf.hflip(rgb)
                depth = ttf.hflip(depth)

            degree = self._uniform(-5.0, 5.0)
            rgb = ttf.rotate(rgb, degree)
            depth = ttf.rotate(depth, degree)

            # Photometric augmentations only make sense for the colour image.
            brightness = self._uniform(0.9, 1.1)
            contrast = self._uniform(0.9, 1.1)
            saturation = self._uniform(0.9, 1.1)
            rgb = ttf.adjust_brightness(rgb, brightness)
            rgb = ttf.adjust_contrast(rgb, contrast)
            rgb = ttf.adjust_saturation(rgb, saturation)

        rgb = ttf.to_tensor(rgb)
        depth = ttf.to_tensor(depth)
        rgb = torch.as_tensor(rgb, dtype=torch.float)
        depth = torch.as_tensor(depth, dtype=torch.float)
        depth /= 1000

        rgb = ttf.normalize(rgb, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        return rgb, depth

    def __getitem__(self, idx):
        """Return the ``(rgb, depth)`` tensors of sample ``idx``.

        Loading is best-effort: if reading or transforming the sample fails, a
        warning naming the sample is emitted and an all-zero 480x640 pair is
        returned instead, so that one broken file does not abort the epoch.
        """
        sample = self.dataset[idx]
        try:
            rgb = self.read_ppm(sample["ppm"])
            depth = self.read_pgm(sample["pgm"])
            rgb, depth = self.transform_fn(rgb, depth)
        except Exception as exc:
            # Keep the fallback but make the failure visible instead of
            # silently feeding blank samples into training.
            warnings.warn(
                f"NyuDataset: failed to load sample {idx} ({sample!r}): {exc!r}; "
                "substituting an all-zero sample",
                RuntimeWarning,
            )
            rgb = torch.zeros([3, 480, 640])
            depth = torch.zeros([1, 480, 640])
        return rgb, depth

    def __len__(self):
        """Number of records kept for this mode."""
        return self.dataset_length


def calc_error(depth_pred, depth_target):
    """Compute reporting metrics and the training loss over valid target pixels.

    A pixel is valid when its target value lies strictly between ``1e-3`` and
    ``2``; every other pixel is ignored.

    Args:
        depth_pred: Predicted map; must be positive where the target is valid,
            because its logarithm is taken.
        depth_target: Ground-truth map with the same shape as ``depth_pred``.

    Returns:
        dict with Python floats under ``"MSE"``, ``"RMSE"``, ``"MAE"``,
        ``"ABS_REL"``, ``"DELTA1.02"``, ``"DELTA1.05"`` and a scalar tensor
        under ``"loss"`` (mean squared difference of the logarithms) that is
        still attached to the autograd graph of ``depth_pred``.
    """
    mask = torch.logical_and(torch.gt(depth_target, 1e-3), torch.lt(depth_target, 2))
    depth_pred = depth_pred[mask]
    depth_target = depth_target[mask]
    n_valid = depth_target.shape[0]
    # The small epsilon keeps the metric divisions finite when nothing is valid.
    n_valid_element = n_valid + 1e-4

    # The metrics are only reported, never back-propagated, so there is no
    # need to record them in the autograd graph.
    with torch.no_grad():
        error = {}
        diff_mat = torch.abs(depth_pred - depth_target)
        rel_mat = torch.div(diff_mat, depth_target)
        error["MSE"] = torch.sum(torch.pow(diff_mat, 2)) / n_valid_element
        error["RMSE"] = torch.sqrt(error["MSE"])
        error['MAE'] = torch.sum(diff_mat) / n_valid_element
        error['ABS_REL'] = torch.sum(rel_mat) / n_valid_element
        y_over_z = torch.div(depth_target, depth_pred)
        z_over_y = torch.div(depth_pred, depth_target)
        max_ratio = torch.max(y_over_z, z_over_y)
        error['DELTA1.02'] = torch.sum(max_ratio < 1.02) / float(n_valid_element)
        error['DELTA1.05'] = torch.sum(max_ratio < 1.05) / float(n_valid_element)

    error = {K: V.item() for K, V in error.items()}

    squared_log_diff = torch.pow(torch.log(depth_pred) - torch.log(depth_target), 2)
    if n_valid > 0:
        error['loss'] = torch.mean(squared_log_diff)
    else:
        # The mean of an empty tensor is NaN, and back-propagating a NaN would
        # corrupt the weights and the optimizer state. The sum of an empty
        # tensor is 0 and stays attached to the graph, so backward() still
        # works and yields zero gradients for a batch without valid pixels.
        error['loss'] = torch.sum(squared_log_diff)

    return error


def _format_metrics(history):
    """Format the mean of every tracked metric over ``history`` for the progress bar.

    Args:
        history: Iterable of per-batch dicts as returned by ``calc_error``, with
            the ``"loss"`` entry already converted to a Python float.
    """
    def avg(key):
        return np.mean([entry[key] for entry in history])

    return (
        f"loss={avg('loss'):0.8f},MSE={avg('MSE'):0.8f},RMSE={avg('RMSE'):0.8f},"
        f"MAE={avg('MAE'):0.8f},ABS_REL={avg('ABS_REL'):0.8f}, "
        f"DELTA1.02={avg('DELTA1.02'):0.8f}, DELTA1.05={avg('DELTA1.05'):0.8f}"
    )


def train(model, dataloader, epoch, optimizer):
    """Run one optimisation pass over ``dataloader``.

    The progress bar shows each metric averaged over the most recent 100 batches.

    Args:
        model: Network mapping an RGB batch to a prediction batch.
        dataloader: Iterable yielding ``(rgb, depth_target)`` batches.
        epoch: Epoch index, used only for the progress-bar label.
        optimizer: Optimizer stepping ``model``'s parameters.
    """
    # Short pause before the bar is created (presumably so earlier output is
    # flushed before tqdm starts drawing -- confirm).
    time.sleep(0.2)
    model.train()
    loss_count = deque([], maxlen=100)
    pbar = tqdm(dataloader)
    pbar.set_description("train epoch {}".format(epoch))
    for rgb, depth_target in pbar:
        optimizer.zero_grad()
        rgb, depth_target = rgb.to(device), depth_target.to(device)

        depth_pred = model(rgb)
        error = calc_error(depth_pred, depth_target)
        loss = error["loss"]

        loss.backward()
        optimizer.step()

        # Store a plain float so the history does not keep the graph alive.
        error['loss'] = error['loss'].item()
        loss_count.append(error)
        pbar.set_postfix_str(_format_metrics(loss_count))


def test(model, dataloader, epoch):
    """Evaluate ``model`` over ``dataloader`` without computing gradients.

    The progress bar shows each metric averaged over all batches seen so far.

    Args:
        model: Network mapping an RGB batch to a prediction batch.
        dataloader: Iterable yielding ``(rgb, depth_target)`` batches.
        epoch: Epoch index, used only for the progress-bar label.
    """
    # Same short pause as in train() before the bar is created.
    time.sleep(0.2)
    model.eval()
    loss_count = []
    pbar = tqdm(dataloader)
    pbar.set_description("test epoch {}".format(epoch))
    for rgb, depth_target in pbar:
        rgb, depth_target = rgb.to(device), depth_target.to(device)
        with torch.no_grad():
            depth_pred = model(rgb)
            error = calc_error(depth_pred, depth_target)

        error['loss'] = error['loss'].item()
        loss_count.append(error)
        pbar.set_postfix_str(_format_metrics(loss_count))


if __name__ == '__main__':
    # Build the datasets and loaders, train for 100 epochs and write one
    # checkpoint per epoch into ./model_1.
    dataset_train = NyuDataset(mode="train")
    dataset_test = NyuDataset(mode="test")

    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
    dataloader_test = DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)

    model = BiSeNetV1()
    # model.load_state_dict(torch.load("./model_1/model_v1.pth", map_location="cpu"))
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters())

    # torch.save does not create missing directories. Create the checkpoint
    # directory up front so a missing folder fails (or is fixed) immediately
    # rather than after the first full training epoch.
    os.makedirs("./model_1", exist_ok=True)

    for epoch in range(100):
        train(model, dataloader_train, epoch, optimizer)
        test(model, dataloader_test, epoch)

        model.eval()
        torch.save(model.state_dict(), f"./model_1/model_{epoch}.pth")


dataset train with sample number 90655
dataset test with sample number 4310


train epoch 0: 100%|██████████| 11332/11332 [54:58<00:00,  3.44it/s, loss=0.00357681,MSE=0.00243489,RMSE=0.04812039,MAE=0.03388047,ABS_REL=0.03879489, DELTA1.02=0.41192950, DELTA1.05=0.75313748]  
test epoch 0: 100%|██████████| 539/539 [02:15<00:00,  3.97it/s, loss=0.00346345,MSE=0.00243341,RMSE=0.04566022,MAE=0.03480058,ABS_REL=0.03938123, DELTA1.02=0.39687968, DELTA1.05=0.73574014]
train epoch 1: 100%|██████████| 11332/11332 [54:00<00:00,  3.50it/s, loss=0.00265416,MSE=0.00180431,RMSE=0.04151762,MAE=0.02932507,ABS_REL=0.03359532, DELTA1.02=0.46572401, DELTA1.05=0.79867962] 
test epoch 1: 100%|██████████| 539/539 [02:15<00:00,  3.99it/s, loss=0.00299351,MSE=0.00205106,RMSE=0.04073735,MAE=0.03024330,ABS_REL=0.03457515, DELTA1.02=0.48265217, DELTA1.05=0.79656483]
train epoch 2: 100%|██████████| 11332/11332 [53:59<00:00,  3.50it/s, loss=0.00172227,MSE=0.00116465,RMSE=0.03345234,MAE=0.02289120,ABS_REL=0.02605824, DELTA1.02=0.56913731, DELTA1.05=0.86648042] 
test epoch 2: 100%|██████████| 

KeyboardInterrupt: 

In [2]:
# Manual checkpoint of the current weights (e.g. after interrupting training).
# Relies on `model` and the imports from the training cell having been run in
# this kernel.
os.makedirs("./model_1", exist_ok=True)  # torch.save does not create directories
model.eval()
torch.save(model.state_dict(), "./model_1/model_v1.pth")

In [None]:
# Run the `nvidia-smi` command-line tool in a shell to inspect the GPU(s) on this machine.
!nvidia-smi