In [4]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.models as models
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
import torchvision.transforms.functional as TF
import cv2
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
import numpy as np


In [5]:
class MaxPool(nn.Module):
    def __init__(self, pool_size):
        super(MaxPool, self).__init__()
        self.pool = nn.MaxPool2d(kernel_size=pool_size, stride=pool_size)

    def forward(self, x):
        # print("x shape ", x.shape)
        return self.pool(x)


class Noise(nn.Module):
    def __init__(self, R_scale):
        super(Noise, self).__init__()
        self.mean = 0
        self.stdev = 1  # as defined in the paper

    def forward(self, d_coarse):
        noise = torch.rand_like(d_coarse)*self.stdev + self.mean
        d_noised = d_coarse + noise
        return d_noised


class IntervalThreshold(nn.Module):
    def __init__(self, m, n):
        super(IntervalThreshold, self).__init__()
        self.m = m
        self.n = n

    def forward(self, d_pool):
        threshold = (torch.max(d_pool) - torch.min(d_pool))/min(self.m, self.n)
        return threshold


class ReconGraph(nn.Module):
    def __init__(self, m, n):
        super(ReconGraph, self).__init__()
        self.m = m
        self.n = n

    def forward(self, d_noised, threshold):
        neighbours = set()
        labels = {}

        count = 0
        # print(self.m, self.n)
        for i in range(self.m):
            for j in range(self.n):

                labels[(j, i)] = count  # Labeling each pixel in (x, y) form
                count += 1
                for dy in range(-1, 2):
                    for dx in range(-1, 2):

                        if dx != 0 and dy != 0 and i+dy >= 0 and i+dy < self.m and j+dx >= 0 and j+dx < self.n:
                            if abs(d_noised[0][i+dy][j+dx] - d_noised[0][i][j]) <= threshold:
                                # (x, y) format
                                neighbours.add(((j, i), (j+dx, i+dy)))
        adjacency_matrix = torch.zeros(
            (self.m*self.n, self.m*self.n), dtype=bool)
        # print(adjacency_matrix.shape)

        for val in neighbours:
            N1, N2 = val  # in (x, y) form
            N1_x, N1_y = N1
            N2_x, N2_y = N2

            l1 = labels[(N1_x, N1_y)]
            l2 = labels[(N2_x, N2_y)]

            # Symmetric connections
            adjacency_matrix[l1, l2] = 1
            adjacency_matrix[l2, l1] = 1

        return adjacency_matrix


class GraphDropout(nn.Module):
    def __init__(self, p=0.5) -> None:
        super(GraphDropout, self).__init__()
        self.p = p

    def forward(self, adjacency_matrix):
        if self.train:
            mask = torch.empty_like(adjacency_matrix).bernoulli_(1 - self.p)
            output = adjacency_matrix * mask

        else:
            output = adjacency_matrix

        return output


In [6]:
class ExtractGraph(nn.Module):

    def __init__(self) -> None:
        super(ExtractGraph, self).__init__()

        self.maxpool = MaxPool(pool_size=2)
        self.noise = Noise(R_scale=0.4)  # From paper results
        self.dropout = GraphDropout(p=0.5)

    def forward(self, d_coarse, R_scale):

        # print('d_coarse: ', d_coarse.shape, 'type: ', d_coarse.dtype)
        d_pool = self.maxpool.forward(d_coarse)
        m = d_pool.shape[1]
        n = d_pool.shape[2]
        self.interval_threshold = IntervalThreshold(m, n)
        self.recon_graph = ReconGraph(m, n)

        # print("pooled shape ", d_pool.shape)
        d_noise = self.noise.forward(d_pool)
        threshold = self.interval_threshold.forward(d_pool)
        adjacency_matrix = self.recon_graph.forward(d_noise, threshold)
        adjacency_matrix = self.dropout.forward(adjacency_matrix)

        return adjacency_matrix


In [7]:
class Encoder(nn.Module):

    def __init__(self) -> None:
        super(Encoder, self).__init__()
        encoder = models.resnet.resnet50(
            weights=models.ResNet50_Weights.DEFAULT)
        encoder = nn.Sequential(*list(encoder.children()))[:3]
        self.resnet_encoder = encoder

    def forward(self, x):
        self.resnet_encoder.eval()
        return self.resnet_encoder(x)


In [8]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [9]:
from torchvision.transforms import transforms
from PIL import Image
from torch.utils.data import DataLoader

preprocessing_transform_2 = transforms.Compose([
    transforms.ToTensor()
])

In [10]:
from gcn_depth_dataloader import GCNDepthDataLoader

# for train set
nyu_dataset_train = GCNDepthDataLoader(mode='train',
                                       image_folder='../dataset/dataset/nyu_depth_v2/official_splits/train/rgb',
                                       depth_folder='../dataset/dataset/nyu_depth_v2/official_splits/train/depth',
                                       transform=preprocessing_transform_2)

nyu_dataset_test = GCNDepthDataLoader(mode='test',
                                      image_folder='../dataset/dataset/nyu_depth_v2/official_splits/test/rgb',
                                      depth_folder='../dataset/dataset/nyu_depth_v2/official_splits/test/depth',
                                      transform=preprocessing_transform_2)


In [11]:
midas_model_type = "MiDaS_small"
midas = torch.hub.load("intel-isl/MiDaS", midas_model_type)
midas.to(device)
midas.eval()

midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

if midas_model_type == "DPT_Large" or midas_model_type == "DPT_Hybrid":
    midas_transform = midas_transforms.dpt_transform
else:
    midas_transform = midas_transforms.small_transform

Using cache found in /Users/adityadandwate/.cache/torch/hub/intel-isl_MiDaS_master


Loading weights:  None


Using cache found in /Users/adityadandwate/.cache/torch/hub/rwightman_gen-efficientnet-pytorch_master
Using cache found in /Users/adityadandwate/.cache/torch/hub/intel-isl_MiDaS_master


In [12]:
from torch.utils.data import Dataset, DataLoader


class GraphDataLoader(Dataset):
    def __init__(self, dataset, transform=None) -> None:
        super().__init__()
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        rgb, true_depth = self.dataset.__getitem__(index)
        features, adj_matrix, true_adj_matrix = self.graph_extract(
            rgb, true_depth)
        print('working')
        return features, adj_matrix, true_adj_matrix, true_depth

    def graph_extract(self, rgb, true_depth):
        extractor = ExtractGraph()
        encoder = Encoder()

        rgb = rgb.permute(1, 2, 0)*255
        rgb = rgb.numpy()
        # print('rgb shape ', rgb.shape, ' true depth shape ', true_depth.shape)
        rgb = midas_transform(rgb).to(device)  # (1, C, H, W)

        # print('rgb shape after midas transform ', rgb.shape)
        with torch.no_grad():
            depth_map = midas(rgb)  # (C, H, W) where C = 1
            down_rgb = encoder.forward(rgb)

        # print('midas depth shape: ',depth_map.shape, ' encoded shape: ', down_rgb.shape)
        target_size = down_rgb.shape[2:]
        num_downsampled_channels = down_rgb.shape[1]
        # Maxpool will downsample by half further
        target_size = [x*2 for x in target_size]
        # print('targetsize: ', target_size)
        resize_transform = transforms.Resize(target_size)
        depth_map = depth_map
        # Downsample midas output to (192, 256) using bilinear interpolation
        depth_map = resize_transform(depth_map)

        true_depth_map = resize_transform(true_depth)
        true_depth_map = true_depth_map.to(device=device).to(torch.float32)

        # print("true depth shape ", true_depth_map.shape, " depth map shape ", depth_map.shape)
        adjacency_matrix = extractor.forward(depth_map, 0.4)
        true_adjacency_matrix = extractor.forward(true_depth_map, 0.4)
        # print('pred mat shape: ', adjacency_matrix.shape)
        # print('true mat shape: ', true_adjacency_matrix.shape)

        # shape will be (64, 120*160)
        node_features = torch.reshape(down_rgb, (num_downsampled_channels, -1))
        node_features = node_features.t() # transpose to shape (N, D)

        return node_features, adjacency_matrix, true_adjacency_matrix


In [13]:
nyu_dataset_processes_train = GraphDataLoader(
    nyu_dataset_train, transform=preprocessing_transform_2)
nyu_dataset_processes_test = GraphDataLoader(
    nyu_dataset_test, transform=preprocessing_transform_2)

processed_train_dataloader = DataLoader(nyu_dataset_processes_train, batch_size=32, shuffle=True)
processed_test_dataloader = DataLoader(nyu_dataset_processes_test, batch_size=32, shuffle=True)


In [21]:
def convert_to_edgeindex(matrix):

    # print(features.shape)

    batch_size, num_nodes, _ = matrix.size()
    print('num of nodes ', num_nodes)
    edge_indices = []
    num_nodes = [num_nodes for i in range(batch_size)]


    for i in range(batch_size):
        adj = matrix[i]
        adj_coo = adj.to_sparse().coalesce()
        edge_index = adj_coo.indices()
        edge_indices.append(edge_index)
        
    # graph_data = [pyg_Data(edge_index=e, x=x)
    #               for e, x in zip(edge_indices, features)]

    return edge_indices


In [22]:
from torch_geometric.data import InMemoryDataset as pyg_Dataset, download_url
from torch_geometric.data import Data as pyg_Data
from torch_geometric.data import Batch as pyg_Batch
from torch_geometric.loader import DataLoader as pyg_Loader
from pathlib import Path 
from tqdm import tqdm

class GraphDataset(pyg_Dataset):

    def __init__(self, root = None, transform = None, pre_transform = None, pre_filter = None, log: bool = True):
        super().__init__(root, transform, pre_transform, pre_filter, log)
        self.data_list = []
        # self.intermed_edge_indices = []
        self.root = root
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        self.log = log
        print(self.processed_dir)
        # self.true_edge_indices = []

    @property
    def processed_file_names(self):
        """ List all pytorch files available in the root directory.
        A list of files in the processed_dir which needs to be found in order to skip the processing.
        Those files are created by the process functions and are graph that can be used for training and evaluation.
        """
        return [str(p.with_suffix(".pt").name) for p in Path(self.root).glob("*.vtk")]

    def process(self):
        
        print(self.processed_dir)
        idx = 0
        
        for img_batch in tqdm(processed_train_dataloader):
            feature, intermed_adj_matrix, true_adj_matrix, true_depth = img_batch # batch dim is 32

            # Feature shape should be (N, D) check if it in fact of that shape
            true_edge_indices, intermed_edge_indices =self.convert_to_edgeindex(intermed_adj_matrix, true_adj_matrix)

            for i in range(len(true_edge_indices)):
                true_edge_index = true_edge_indices[i]
                intermed_edge_index = intermed_edge_indices[i]

                file_save_path = Path(self.processed_dir) / f"{idx}.pt"

                x = feature[i]
                true_d = true_depth[i]

                data = pyg_Data(
                    x=x,
                    edge_index=intermed_edge_index,
                    true_depth = true_d,
                    true_edge_index = true_edge_index
                )

                torch.save(data, file_save_path)
                idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(Path(self.processed_dir) / f'{idx}.pt')
        return data

    def convert_to_edgeindex(self, intermed_matrix, true_matrix):

        # print(features.shape)
        batch_size, num_nodes, _ = intermed_matrix.size()
        true_edge_indices = []
        intermed_edge_indices = []
        num_nodes = [num_nodes for i in range(batch_size)]

        for i in range(batch_size):
            intermed_adj = intermed_matrix[i]
            intermed_adj_coo = intermed_adj.to_sparse().coalesce()
            intermed_edge_index = intermed_adj_coo.indices()

            true_adj = true_matrix[i]
            true_adj_coo = true_adj.to_sparse().coalesce()
            true_edge_index = true_adj_coo.indices()

            true_edge_indices.append(true_edge_index)
            intermed_edge_indices.append(intermed_edge_index)

        return true_edge_indices, intermed_edge_indices
    


In [23]:
dataset = GraphDataset(root = '../dataset/graphs/')

Processing...


../dataset/graphs/processed


  0%|          | 0/25 [00:00<?, ?it/s]

working
working
working
working
working
working
working
working
working


  0%|          | 0/25 [00:56<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Ground truth depth is of shape (1, 480, 640)

In [20]:
data_inst = torch.load('../dataset/graphs/processed/0.pt')

Data(x=[12288, 64], edge_index=[2, 15899], true_depth=[1, 480, 640], true_edge_index=[2, 18372])