In [None]:
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

import math
import random
from copy import deepcopy
import json
import time
from tqdm import tqdm

from models.matching import OptimalMatching
from models.backbone import R2U_Net, NonMaxSuppression, DetectionBranch
from utils.utils import scores_to_permutations, permutations_to_polygons

import matplotlib.pyplot as plt

import geopandas as gpd
from shapely.geometry import Polygon
from rasterio import features

from skimage import io
from skimage.transform import resize
import cv2
from PIL import Image


# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Display the torch device selected in the first cell.
device

# Data

### RoofTop

In [None]:
from pycocotools.coco import COCO
from pathlib import Path
import os
from pycocotools import mask as cocomask


def clean(indices):
    """De-duplicate the vertices of a closed polygon while preserving their order.

    Args:
        indices: (N, 2) array of polygon vertices. The last row is treated as the
            closing vertex and ignored.

    Returns:
        A new (K + 1, 2) array: the K distinct vertices in first-occurrence order,
        followed by the first vertex again to re-close the ring. If there is no
        vertex besides the closing one, an empty (0, 2) array is returned, which
        callers can reject by length. The input array is never modified.
    """
    indices = np.asarray(indices)
    # dict.fromkeys keeps first-occurrence order while dropping repeated points.
    unique_points = list(dict.fromkeys(map(tuple, indices[:-1].tolist())))
    if not unique_points:
        return np.empty((0, 2), dtype=indices.dtype)
    return np.array(unique_points + [unique_points[0]], dtype=indices.dtype)





def collate_fn(batch):
    """Transpose a list of samples into a tuple of per-field tuples, without stacking.

    Samples carry variable-sized fields (vertex arrays, index dicts), so every field
    is left as a tuple and the caller stacks the fixed-size ones itself.
    """
    fields = zip(*batch)
    return tuple(fields)


    


class RoofTopImage(Dataset):
    """COCO-format roof-polygon dataset that builds PolyWorld-style training targets.

    Each item is ``(gt_permutation_matrix, vertices, gt_index, corner_mask, mask)``:

    * ``gt_permutation_matrix``: (corner_points, corner_points) float64 tensor. Row ``i``
      has a 1 in the column of the next vertex of the same polygon; unused slots are
      padded with the identity.
    * ``vertices``: (V, 2) int64 array of (row, col) vertex coordinates in window pixels.
      Position ``i`` corresponds to slot ``i`` of the permutation matrix.
    * ``gt_index``: dict mapping each (row, col) vertex tuple to its slot.
    * ``corner_mask``: (window, window) float64 tensor with 1 at every kept vertex.
    * ``mask``: (window, window) float64 tensor, the sum of the rasterised kept polygons.

    The image itself is not loaded or returned. ``s3_image_bucket``, ``prefix`` and
    ``batchsize`` are stored but not read by this class.
    """

    def __init__(self,
                 s3_image_bucket,
                 prefix,
                 annotation_path,
                 max_corner_points=256,
                 img_size=320,
                 batchsize = 0,
                 prediction = False,
                 prediction_file_path=None,
                 load_type = 'train',
                 cache_data = 'True',
                 cache_path='../data',
                 num_samples=10000,
                ):
        """
        Args:
            annotation_path: COCO annotation JSON.
            max_corner_points: number of vertex slots (permutation-matrix size).
            img_size: side of the square output window, in pixels.
            prediction: if true, replace ground truth with detections loaded from
                ``prediction_file_path`` via ``COCO.loadRes``.
            load_type: sub-folder name under ``cache_path``.
            cache_data: truthy to create the cache folder. NOTE: the default is the
                *string* 'True', so any non-empty string (even 'False') counts as true.
            num_samples: randomly keep at most this many image ids; ``None`` keeps all.
        """
        self.s3_image_bucket = s3_image_bucket
        self.prefix = prefix
        self.load_type = load_type
        self.cache_data = cache_data
        self.cache_path = cache_path

        self.annotation_path = annotation_path
        self.prediction = prediction
        self.img_annotations = COCO(self.annotation_path)
        self.segs = []
        self.batchsize = batchsize

        if self.prediction:
            print('Loading prediction data ...', prediction_file_path)
            with open(prediction_file_path) as f:
                prediction_file = json.load(f)
            self.img_annotations = self.img_annotations.loadRes(prediction_file)

        self.cat_id = self.img_annotations.getCatIds()
        self.img_ids = self.img_annotations.getImgIds(catIds=self.cat_id)
        # Sample real image ids (not positions 0..N-1), and never ask for more than exist.
        if num_samples is not None and num_samples < len(self.img_ids):
            self.img_ids = random.sample(self.img_ids, num_samples)

        self.window_size = img_size
        self.corner_points = max_corner_points

        if self.cache_data:
            Path(os.path.join(self.cache_path, self.load_type)).mkdir(parents=True, exist_ok=True)

    def __len__(self):
        return len(self.img_ids)

    @staticmethod
    def _scale_segmentation(segmentation, scale_x, scale_y):
        """Scale flat COCO polygons ``[x0, y0, x1, y1, ...]`` into window coordinates."""
        return [[v * (scale_x if j % 2 == 0 else scale_y) for j, v in enumerate(poly)]
                for poly in segmentation]

    def __getitem__(self, indx):
        idx = self.img_ids[indx]
        img = self.img_annotations.loadImgs(idx)[0]

        # Annotations are in source-image pixels. The previous code hard-coded a 300x300
        # source, so 300 stays as the fallback when the image record lacks its size.
        scale_y = self.window_size / img.get('height', 300)
        scale_x = self.window_size / img.get('width', 300)

        ann_ids = self.img_annotations.getAnnIds(imgIds=idx)
        coco_annotations = self.img_annotations.loadAnns(ann_ids)
        random.shuffle(coco_annotations)

        corner_mask = np.zeros((self.window_size, self.window_size))
        mask = np.zeros((self.window_size, self.window_size))

        gt_index = {}
        gt_indices = []
        gt_permutation_matrix = np.zeros((self.corner_points, self.corner_points))

        for ann in coco_annotations:
            # COCO stores [x0, y0, x1, y1, ...]; flip each pair to (row, col) = (y, x).
            corner_points = np.flip(np.array(ann['segmentation'][0]).reshape(-1, 2), 1)
            point_pair = corner_points * np.array([scale_y, scale_x])
            indices = point_pair.round().astype('int64').clip(0, self.window_size - 1)

            indices = clean(indices)
            if indices.shape[0] - 1 < 3:
                continue
            if indices[0][0] != indices[-1][0] or indices[0][1] != indices[-1][1]:
                continue

            polygon_keys = [tuple(p) for p in indices[:-1].tolist()]
            # Reject polygons that share a vertex with an already accepted polygon *before*
            # registering any of their vertices, so gt_index slots stay aligned with `vertices`.
            if any(key in gt_index for key in polygon_keys):
                continue
            # Skip polygons that would overflow the fixed-size permutation matrix.
            if len(gt_index) + len(polygon_keys) > self.corner_points:
                continue
            for key in polygon_keys:
                gt_index[key] = len(gt_index)

            corner_mask[indices[:-1][:, 0], indices[:-1][:, 1]] = 1
            gt_indices.append(indices.tolist())

            # Rasterise in window coordinates, the same frame as corner_mask. merge() combines
            # multi-ring segmentations into one RLE so decode() yields a single (H, W) mask.
            scaled = self._scale_segmentation(ann['segmentation'], scale_x, scale_y)
            rle = cocomask.frPyObjects(scaled, self.window_size, self.window_size)
            m = cocomask.decode(cocomask.merge(rle)).astype('float32')
            mask = mask + m.reshape((self.window_size, self.window_size))

        vertices = [v for polygon in gt_indices for v in polygon[:-1]]
        num_vertices = len(vertices)

        for polygon in gt_indices:
            ring = polygon[:-1]  # drop the closing duplicate
            for i, v1 in enumerate(ring):
                v2 = ring[(i + 1) % len(ring)]  # last vertex links back to the first
                gt_permutation_matrix[gt_index[tuple(v1)]][gt_index[tuple(v2)]] = 1

        # Unused slots map to themselves.
        gt_permutation_matrix[range(num_vertices, self.corner_points), range(num_vertices, self.corner_points)] = 1

        # reshape keeps `vertices` 2-D even when no polygon survived.
        vertices = np.array(vertices, dtype=np.int64).reshape(-1, 2)
        return torch.from_numpy(gt_permutation_matrix), vertices, gt_index, torch.from_numpy(corner_mask), torch.from_numpy(mask)

In [None]:
# --- RoofTop experiment configuration ---
IMAGE_BUCKET = "sundial-geometric-roof-inference"

# NOTE(review): the "train" prefix points at the *val* image folder while the annotations come from
# data/train. RoofTopImage stores `prefix` but never reads it, so today this has no effect; confirm intent.
TRAIN_IMAGE_PREFIX = 'data/val/images'
# VAL_IMAGE_PREFIX= 'data/val/images'

TRAIN_ANNOTATION_PATH = 'data/train/annotation.json'
# VAL_ANNOTATION_PATH = 'data/val/annotation.json'

PREDICTION_PATH = 'raw_val_predictions.json'

BATCH_SIZE = 6
MAX_CORNER_POINTS = 256   # vertex slots per image (permutation-matrix size)
NUM_EPOCHS = 10
INIT_LR = 0.001           # NOTE(review): the optimizer cell hard-codes lr=1e-4 and does not read this
LAMBDA = 1000
IMG_SIZE = 320
BASE_OUTPUT = "output"

train_data = RoofTopImage(
    IMAGE_BUCKET,
    TRAIN_IMAGE_PREFIX,
    TRAIN_ANNOTATION_PATH,
    max_corner_points=MAX_CORNER_POINTS,
    img_size=IMG_SIZE,
    batchsize=BATCH_SIZE,
    load_type='train',
)

# collate_fn keeps per-sample fields as tuples (vertex arrays differ in length between images).
trainloader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
    drop_last=True,
    collate_fn=collate_fn,
)

In [None]:
def sort_sync_nsm_points(nms_graph, vertices, gt_index, out_device=None):
    """Reorder NMS candidate points so they line up with the ground-truth vertex slots.

    For each image, candidates are greedily matched to GT vertices in order of
    increasing Euclidean distance. Each GT vertex and each candidate is used at most
    once. A matched candidate is written into the GT vertex's slot
    (``gt_index[b][vertex]``). Unmatched candidates then fill slots ``n, n+1, ...`` in
    their original order, where ``n`` is the number of GT vertices.

    Args:
        nms_graph: (B, N, D) tensor of candidate coordinates (D is 2 in this notebook).
        vertices: length-B sequence of (n_b, D) GT vertex arrays. An empty ``(0,)``
            array is treated as zero vertices.
        gt_index: length-B sequence of dicts mapping vertex tuples to slots.
        out_device: device of the result; defaults to the notebook-level ``device``.

    Returns:
        (B, N, D) integer tensor of reordered candidate coordinates.
    """
    if out_device is None:
        out_device = device

    B, N, D = nms_graph.shape
    sorted_nsm_points = np.zeros((B, N, D), dtype=int)
    candidates_all = nms_graph.detach().cpu().numpy()

    for b in range(B):
        candidates = candidates_all[b]
        gt_vertices = np.asarray(vertices[b]).reshape(-1, D)
        n, m = gt_vertices.shape[0], candidates.shape[0]
        sorted_nsm = np.zeros((N, D), dtype=int)

        # One row per (gt, candidate) pair: [distance, gt coords..., candidate coords...].
        distances = np.linalg.norm(gt_vertices[:, None] - candidates, axis=2).reshape(n * m, 1)
        table = np.hstack((distances, np.repeat(gt_vertices, m, axis=0), np.tile(candidates, (n, 1))))
        table = table[np.argsort(table[:, 0])]

        gt_used, cand_used = set(), set()
        for row in table:
            # Once either side is exhausted no further pair can match.
            if len(gt_used) == n or len(cand_used) == m:
                break
            gt_p = tuple(row[1:1 + D])
            cand_p = tuple(row[1 + D:1 + 2 * D])
            if gt_p in gt_used or cand_p in cand_used:
                continue
            # Float tuples hash/compare equal to the integer-tuple keys of gt_index.
            sorted_nsm[gt_index[b][gt_p]] = cand_p
            gt_used.add(gt_p)
            cand_used.add(cand_p)

        next_slot = n
        for cand_p in dict.fromkeys(map(tuple, candidates)):
            if cand_p not in cand_used:
                sorted_nsm[next_slot] = cand_p
                next_slot += 1
        sorted_nsm_points[b] = sorted_nsm

    return torch.from_numpy(sorted_nsm_points).to(out_device)

def prepare_gt_vertices(vertices, max_points=None, out_device=None):
    """Zero-pad each image's GT vertex list into one fixed-size float64 tensor.

    Args:
        vertices: length-B sequence. Element b is an (n_b, 2) array-like of vertex
            coordinates; an empty ``(0,)`` array counts as zero vertices.
        max_points: slots per image; defaults to the notebook-level ``MAX_CORNER_POINTS``.
        out_device: device of the result; defaults to the notebook-level ``device``.

    Returns:
        float64 tensor of shape (B, max_points, 2). Rows past n_b are zero.

    Raises:
        ValueError: if an image has more than ``max_points`` vertices.
    """
    if max_points is None:
        max_points = MAX_CORNER_POINTS
    if out_device is None:
        out_device = device

    # Size the batch from the input (not a global BATCH_SIZE) and zero-fill, so a short
    # batch never carries uninitialised rows. Build on the CPU and move once at the end.
    v_gt = torch.zeros((len(vertices), max_points, 2), dtype=torch.float64)
    for b, verts in enumerate(vertices):
        gt = torch.as_tensor(np.asarray(verts, dtype=np.float64).reshape(-1, 2))
        if gt.shape[0] > max_points:
            raise ValueError(f'image {b} has {gt.shape[0]} vertices but only {max_points} slots')
        v_gt[b, :gt.shape[0]] = gt
    return v_gt.to(out_device)

In [None]:
# Pull one batch from the RoofTop loader to sanity-check the targets (fields stay tuples via collate_fn).
permutation_matrix, vertices, gt_index, corner_mask, mask = next(iter(trainloader))

In [None]:
# Index 5 relies on the batch holding at least 6 samples (BATCH_SIZE is 6 with drop_last=True).
plt.imshow(permutation_matrix[5])

In [None]:
plt.imshow(mask[0])

In [None]:
plt.imshow(corner_mask[0])

In [None]:
corner_mask[0].shape

In [None]:
vertices[0]

In [None]:
gt_index

In [None]:
vertices[0].shape

In [None]:
# `nms_graph` is not produced by any cell above (the detection pipeline is built further down),
# so on a fresh kernel this call raised NameError. As a stand-in candidate graph, take the
# top-MAX_CORNER_POINTS pixels of each GT corner mask as (row, col) pairs. Every GT corner is
# included, so sort_sync_nsm_points should match them all at distance 0.
corner_batch = torch.stack(corner_mask)  # (B, H, W)
corner_width = corner_batch.shape[-1]
flat_top = torch.topk(corner_batch.flatten(1), MAX_CORNER_POINTS, dim=1).indices
nms_graph = torch.stack((torch.div(flat_top, corner_width, rounding_mode='floor'), flat_top % corner_width), dim=-1)
sorted_nsm_points = sort_sync_nsm_points(nms_graph, vertices, gt_index)

In [None]:
# Padded GT vertex tensor for the batch.
prepare_gt_vertices(vertices).shape

In [None]:
prepare_gt_vertices(vertices)[0]

In [None]:
# The permutation matrices arrive as a tuple (collate_fn does not stack); stack to check the batch shape.
torch.stack(permutation_matrix).shape

In [None]:
# Side-by-side view of the first sample's three training targets.
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
panels = [
    (corner_mask[0], "Corner Mask"),
    (mask[0], "Mask"),
    (permutation_matrix[0], "Permutation Matrix"),
]
for ax, (target, title) in zip(axes, panels):
    ax.imshow(target)
    ax.set_title(title)

### CrowdAI

In [None]:
class CrowdAI(Dataset):
    """CrowdAI mapping-challenge dataset yielding an image plus PolyWorld-style targets.

    Each item is ``(image, vertex_mask, seg_mask, permutation_matrix, segmentations)``:

    * ``image``: the image resized to (window, window[, C]) as a tensor.
    * ``vertex_mask``: (window, window) uint8 tensor, 1 at every polygon vertex.
    * ``seg_mask``: (window, window) uint8 tensor of filled polygons.
    * ``permutation_matrix``: (max_points, max_points) uint8 tensor linking each vertex
      to the next one of its polygon; unused slots are padded with the identity.
    * ``segmentations``: list of (k, 2) int tensors of (x, y) vertices in window pixels.

    Polygons that would push the total vertex count past ``max_points`` are dropped
    from every output, so all targets stay consistent.
    """

    def __init__(self, images_directory, annotations_path, window_size=320, max_points=256, max_images=4):
        """
        Args:
            images_directory: folder containing the image files.
            annotations_path: COCO-style JSON with ``images`` and ``annotations`` lists.
            window_size: side of the square output window, in pixels.
            max_points: size of the permutation matrix.
            max_images: keep only the first ``max_images`` image records (the original
                debug subset of 4); ``None`` keeps the whole split.
        """
        self.IMAGES_DIRECTORY = images_directory
        self.ANNOTATIONS_PATH = annotations_path

        self.window_size = window_size
        self.max_points = max_points

        # load annotation json
        with open(self.ANNOTATIONS_PATH) as f:
            self.annotations = json.load(f)

        self.images = pd.DataFrame(self.annotations['images'])[:max_images]
        self.labels = pd.DataFrame(self.annotations['annotations'])

    def _create_permutation_matrix(self, segmentations, N=256):
        """Link each vertex to the next one of its polygon (the last wraps to the first).

        Raises:
            ValueError: if the polygons hold more than ``N`` vertices in total.
        """
        total = sum(len(polygon) for polygon in segmentations)
        if total > N:
            raise ValueError(f'{total} vertices do not fit in a {N}x{N} permutation matrix')

        permutation_matrix = torch.zeros((N, N), dtype=torch.uint8)
        start = 0
        for polygon in segmentations:
            count = len(polygon)
            for v in range(count):
                permutation_matrix[start + v, start + (v + 1) % count] = 1
            start += count
        # Unused slots map to themselves.
        for i in range(start, N):
            permutation_matrix[i, i] = 1
        return permutation_matrix

    def _create_segmentation_mask(self, polygons, image_size):
        """Rasterise filled polygons into an (image_size, image_size) uint8 mask."""
        mask = np.zeros((image_size, image_size), dtype=np.uint8)
        for polygon in polygons:
            if len(polygon) == 0:
                continue
            # cv2.fillPoly only accepts 32-bit integer point arrays; numpy's default int may be int64.
            cv2.fillPoly(mask, [np.asarray(polygon, dtype=np.int32).reshape(-1, 2)], 1)
        return torch.tensor(mask, dtype=torch.uint8)

    def _create_vertex_mask(self, polygons, image_shape=(320, 320)):
        """Mark every (x, y) vertex at ``mask[y, x]``."""
        mask = torch.zeros(image_shape, dtype=torch.uint8)
        for poly in polygons:
            pts = torch.as_tensor(np.asarray(poly, dtype=np.int64).reshape(-1, 2))
            mask[pts[:, 1], pts[:, 0]] = 1
        return mask

    def _to_pixel(self, value):
        """Truncate a scaled coordinate to int and clip it into the window."""
        return min(max(0, int(value)), self.window_size - 1)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        record = self.images.iloc[idx]

        image = io.imread(os.path.join(self.IMAGES_DIRECTORY, record['file_name']))
        image = resize(image, (self.window_size, self.window_size), anti_aliasing=True)
        image = torch.from_numpy(image)

        # The resize stretches each axis independently, so x and y get their own scale.
        scale_x = self.window_size / record['width']
        scale_y = self.window_size / record['height']

        image_annotations = self.labels[self.labels['image_id'] == record['id']]

        segmentations = []
        n_vertices = 0
        for seg in image_annotations['segmentation'].values:
            flat = seg[0]  # first ring of a flat [x0, y0, x1, y1, ...] polygon
            poly = np.array(
                [[self._to_pixel(x * scale_x), self._to_pixel(y * scale_y)] for x, y in zip(flat[0::2], flat[1::2])],
                dtype=int,
            ).reshape(-1, 2)
            # Drop polygons that would overflow the permutation matrix.
            if n_vertices + len(poly) > self.max_points:
                continue
            n_vertices += len(poly)
            segmentations.append(poly)

        permutation_matrix = self._create_permutation_matrix(segmentations, N=self.max_points)
        vertex_mask = self._create_vertex_mask(segmentations, image_shape=(self.window_size, self.window_size))
        seg_mask = self._create_segmentation_mask(segmentations, image_size=self.window_size)

        segmentations = [torch.from_numpy(poly) for poly in segmentations]

        return image, vertex_mask, seg_mask, permutation_matrix, segmentations

In [None]:
# NOTE(review): identical redefinition of the `collate_fn` from the RoofTop data cell; it can be removed.
def collate_fn(batch):
    """Transpose samples into per-field tuples (no stacking), so variable-length polygon lists survive batching."""
    return tuple(zip(*batch))

In [None]:
# Build the CrowdAI validation dataset and its loader.
batch_size = 2
dataset = CrowdAI(
    images_directory='data/val/images/',
    annotations_path='data/val/annotation.json',
)
# NOTE(review): num_workers=4 with a Dataset class defined inside the notebook may fail on
# spawn-based platforms (macOS/Windows), where workers cannot import it; confirm on the target OS.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

In [None]:
crowdai_batch = next(iter(dataloader))
*stackable_fields, gt_polygons = crowdai_batch
# Fixed-size fields become batch tensors; polygons stay a per-image tuple of variable-length lists.
image, gt_vertex_mask, gt_seg_mask, gt_permutation_matrix = map(torch.stack, stackable_fields)

In [None]:
# Report the stacked batch shapes.
for field_name, field in (
    ('Image: ', image),
    ('Vertex: ', gt_vertex_mask),
    ('Segmentation: ', gt_seg_mask),
    ('Permutation: ', gt_permutation_matrix),
):
    print(field_name, field.shape)

In [None]:
# Number of images in the batch (one polygon list per image).
len(gt_polygons)

In [None]:
# Plot every image in the batch next to its segmentation mask, vertex mask and permutation matrix.
B = len(image)
# squeeze=False keeps `axes` 2-D, so axes[i, j] also works for a single-image batch.
fig, axes = plt.subplots(B, 4, figsize=(21, 7*B), squeeze=False)
for i in range(B):
    axes[i, 0].imshow(image[i])
    axes[i, 1].imshow(gt_seg_mask[i])
    axes[i, 2].imshow(gt_vertex_mask[i])
    axes[i, 3].imshow(gt_permutation_matrix[i])

plt.tight_layout()
plt.show()

In [None]:
# Plot every image, its vertex mask, and a scatter overlay of the vertices on the image.
B = len(image)

# squeeze=False keeps `axes` 2-D, so axes[i, j] also works for a single-image batch.
fig, axes = plt.subplots(B, 3, figsize=(21, 7*B), squeeze=False)

for i in range(B):
    # Original image
    axes[i, 0].imshow(image[i])
    axes[i, 0].set_title(f'Original Image {i+1}')
    axes[i, 0].axis('off')

    # Vertex mask as image
    axes[i, 1].imshow(gt_vertex_mask[i], cmap='viridis')
    axes[i, 1].set_title(f'Vertex Mask {i+1}')
    axes[i, 1].axis('off')

    # Scatter plot of vertex mask on original image (nonzero returns rows=y, cols=x)
    axes[i, 2].imshow(image[i])
    y, x = np.nonzero(gt_vertex_mask[i].numpy())
    scatter = axes[i, 2].scatter(x, y, c=gt_vertex_mask[i].numpy()[y, x], s=20, alpha=0.5, cmap='Oranges_r')
    axes[i, 2].set_title(f'Overlay {i+1}')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

# Network

## Vertex Detection

In [None]:
class DetectionBranch(nn.Module):
    """Vertex-detection head: maps 64-channel features to a 1-channel logit map.

    Shadows the ``DetectionBranch`` imported from ``models.backbone``. The layer
    indices inside ``self.conv`` (0, 1, 3) are state-dict keys and must stay put.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0, bias=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

### Blocks

In [None]:
class up_conv(nn.Module):
    """2x upsample (``nn.Upsample`` default mode) followed by 3x3 conv, BatchNorm and ReLU."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        self.up = nn.Sequential(upsample, conv, nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.up(x)


class Recurrent_block(nn.Module):
    """Recurrent conv unit: one shared 3x3 conv+BN+ReLU applied ``t + 1`` times.

    The first application sees ``x``; each later one sees ``x + previous_state``.
    ``t`` must be at least 1 (with ``t == 0`` the state is never assigned).
    """

    def __init__(self, ch_out, t=2):
        super().__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        for step in range(self.t):
            if step == 0:
                state = self.conv(x)
            state = self.conv(x + state)
        return state

        
class RRCNN_block(nn.Module):
    """Recurrent residual block.

    Projects the input to ``ch_out`` with a 1x1 conv, runs two ``Recurrent_block``s,
    and adds the projection back as a residual.
    """

    def __init__(self, ch_in, ch_out, t=2):
        super().__init__()
        self.RCNN = nn.Sequential(*(Recurrent_block(ch_out, t=t) for _ in range(2)))
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        projected = self.Conv_1x1(x)
        return projected + self.RCNN(projected)

### Backbone

In [None]:
class R2U_Net(nn.Module):
    """Recurrent-residual U-Net backbone returning 64-channel full-resolution features.

    There are four 2x max-pool stages, so input height and width should be divisible
    by 16 for the skip concatenations to line up. Attribute names and construction
    order are unchanged because they form the state-dict keys of the pretrained
    checkpoint.
    """

    def __init__(self, img_ch=3, t=1):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Registered but unused in forward(); kept so the module layout matches the original.
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder stages
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        # Decoder stages: upsample, then refine the (skip, upsampled) concatenation
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

    def forward(self, x):
        # Encoder: keep each stage's output as a skip before pooling into the next stage.
        skips = []
        feat = self.RRCNN1(x)
        for encoder in (self.RRCNN2, self.RRCNN3, self.RRCNN4, self.RRCNN5):
            skips.append(feat)
            feat = encoder(self.Maxpool(feat))

        # Decoder: consume skips deepest-first; the skip goes first in the concatenation.
        decoder = (
            (self.Up5, self.Up_RRCNN5),
            (self.Up4, self.Up_RRCNN4),
            (self.Up3, self.Up_RRCNN3),
            (self.Up2, self.Up_RRCNN2),
        )
        for upsample, refine in decoder:
            feat = refine(torch.cat((skips.pop(), upsample(feat)), dim=1))
        return feat

### NMS

In [None]:
# This block is not differentiable: peak selection relies on argmax / topk.

class NonMaxSuppression(nn.Module):
    """3x3 non-maximum suppression on a logit map, followed by top-k peak sampling.

    ``forward(feat)`` takes a (B, 1, H, W) tensor and returns:

    * the sigmoid map with every non-peak pixel zeroed;
    * a (B, n_peaks, 2) tensor of (row, col) peak coordinates, highest first.

    A pixel counts as a peak when the first maximum of its zero-padded 3x3 window
    is the centre.
    """

    def __init__(self, n_peaks=256):
        super().__init__()
        self.k, self.p, self.s = 3, 1, 1  # kernel, padding, stride
        self.center_idx = self.k ** 2 // 2
        self.sigmoid = nn.Sigmoid()
        self.unfold = nn.Unfold(kernel_size=self.k, padding=self.p, stride=self.s)
        self.n_peaks = n_peaks

    def sample_peaks(self, x):
        """Return the (row, col) of the ``n_peaks`` largest values of each x[b, 0]."""
        width = x.shape[3]
        per_image = []
        for heat in x[:, 0]:
            flat_idx = torch.topk(heat.flatten(), self.n_peaks).indices
            rows = torch.div(flat_idx, width, rounding_mode='floor')
            cols = flat_idx % width
            per_image.append(torch.stack((rows, cols), dim=1))
        return torch.stack(per_image, dim=0)

    def forward(self, feat):
        B, _, H, W = feat.shape
        prob = self.sigmoid(feat)

        windows = self.unfold(prob).view(B, self.k ** 2, H, W)
        is_peak = (windows.argmax(dim=1, keepdim=True) == self.center_idx).float()
        suppressed = prob * is_peak

        return suppressed, self.sample_peaks(suppressed)

## GNN

In [None]:
def MultiLayerPerceptron(channels: list, batch_norm=True):
    """Stack of 1x1 Conv1d layers following ``channels``.

    Every hidden layer is followed by an optional BatchNorm1d and a ReLU; the final
    projection is left bare.
    """
    layers = []
    last_layer = len(channels) - 1
    for layer_no, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
        layers.append(nn.Conv1d(c_in, c_out, kernel_size=1, bias=True))
        if layer_no == last_layer:
            continue
        if batch_norm:
            layers.append(nn.BatchNorm1d(c_out))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)


class Attention(nn.Module):
    """Multi-head scaled dot-product attention over (B, d_model, N) point features.

    The query/key/value projections start as deep copies of the output ``merge``
    conv, so all four 1x1 convs share initial weights.
    """

    def __init__(self, n_heads: int, d_model: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.dim = d_model // n_heads
        self.n_heads = n_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value):
        batch = query.size(0)
        q, k, v = (
            proj(tensor).view(batch, self.dim, self.n_heads, -1)
            for proj, tensor in zip(self.proj, (query, key, value))
        )
        scale = q.shape[1] ** .5
        weights = torch.nn.functional.softmax(torch.einsum('bdhn,bdhm->bhnm', q, k) / scale, dim=-1)
        mixed = torch.einsum('bhnm,bdhm->bdhn', weights, v)
        return self.merge(mixed.contiguous().view(batch, self.dim * self.n_heads, -1))


class AttentionalPropagation(nn.Module):
    """One self-attention message-passing step.

    The module returns an update computed from ``[x, attention(x)]``. The caller
    (``AttentionalGNN``) adds it to ``x`` as a residual.
    """

    def __init__(self, feature_dim: int, n_heads: int):
        super().__init__()
        self.attn = Attention(n_heads, feature_dim)
        self.mlp = MultiLayerPerceptron([feature_dim * 2, feature_dim * 2, feature_dim])
        # Final projection bias starts at zero.
        nn.init.constant_(self.mlp[-1].bias, 0.0)

    def forward(self, x):
        return self.mlp(torch.cat([x, self.attn(x, x, x)], dim=1))


class AttentionalGNN(nn.Module):
    """Attentional graph network over vertex descriptors.

    ``forward(feat, graph)`` concatenates normalized (B, N, 2) coordinates to the
    (B, feature_dim, N) descriptors and runs ``num_layers`` residual attention steps.
    It returns:

    * refined descriptors, shaped (B, feature_dim, N);
    * per-vertex offsets, shaped (B, N, 2) and bounded to [-1, 1] by Hardtanh.

    Sequential layer indices and construction order match the original layout, so
    state-dict keys are unchanged.
    """

    def __init__(self, feature_dim: int, num_layers: int):
        super().__init__()

        def conv(c_in, c_out):
            return nn.Conv1d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=True)

        def conv_bn_relu(c_in, c_out):
            return [conv(c_in, c_out), nn.BatchNorm1d(c_out), nn.ReLU(inplace=True)]

        self.conv_init = nn.Sequential(
            *conv_bn_relu(feature_dim + 2, feature_dim),
            *conv_bn_relu(feature_dim, feature_dim),
            conv(feature_dim, feature_dim),
        )

        self.layers = nn.ModuleList([AttentionalPropagation(feature_dim, 4) for _ in range(num_layers)])

        self.conv_desc = nn.Sequential(
            *conv_bn_relu(feature_dim, feature_dim),
            *conv_bn_relu(feature_dim, feature_dim),
        )

        self.conv_offset = nn.Sequential(
            *conv_bn_relu(feature_dim, feature_dim),
            *conv_bn_relu(feature_dim, feature_dim),
            conv(feature_dim, 2),
            nn.Hardtanh(),
        )

    def forward(self, feat, graph):
        coords = graph.permute(0, 2, 1)
        hidden = self.conv_init(torch.cat((feat, coords), dim=1))
        for layer in self.layers:
            hidden = hidden + layer(hidden)
        return self.conv_desc(hidden), self.conv_offset(hidden).permute(0, 2, 1)

### Optimal Connection Network

In [None]:
class ScoreNet(nn.Module):
    """Pairwise connection scorer.

    For (B, C, N) descriptors, every pair (i, j) gets the channel concatenation
    ``[d_i, d_j]``, which a stack of 1x1 convs maps to a (B, N, N) score matrix.
    ``in_ch`` must therefore equal 2 * C.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_ch, 256, kernel_size=1, stride=1, padding=0, bias=True)
        self.bn1 = nn.BatchNorm2d(256)
        self.conv2 = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=True)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        n_points = x.shape[-1]
        rows = x.unsqueeze(-1).repeat(1, 1, 1, n_points)
        pairs = torch.cat((rows, torch.transpose(rows, 2, 3)), dim=1)
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            pairs = self.relu(bn(conv(pairs)))
        return self.conv4(pairs)[:, 0]

## Matching Network

In [None]:
class OptimalMatching(nn.Module):
    """Matching head: refines candidate vertices and scores their connections.

    Shadows the ``OptimalMatching`` imported from ``models.matching``. The submodule
    names ``scorenet1``, ``scorenet2`` and ``gnn`` are state-dict keys.
    """

    def __init__(self):
        super().__init__()

        # Default configuration settings
        self.descriptor_dim = 64
        self.sinkhorn_iterations = 100
        self.attention_layers = 4
        self.correction_radius = 0.05

        # Modules
        self.scorenet1 = ScoreNet(self.descriptor_dim * 2)
        self.scorenet2 = ScoreNet(self.descriptor_dim * 2)
        self.gnn = AttentionalGNN(self.descriptor_dim, self.attention_layers)

    def normalize_coordinates(self, graph, ws, input):
        """Convert coordinates between pixel space and [-1, 1].

        * ``input='global'``: pixels to normalized coordinates.
        * ``input='normalized'``: normalized coordinates to rounded integer pixels,
          clamped into ``[0, ws - 1]``.
        * Any other value returns ``graph`` unchanged.
        """
        if input == 'global':
            return graph * 2 / ws - 1
        if input == 'normalized':
            pixels = torch.round((graph + 1) * ws / 2).long()
            return pixels.clamp(0, ws - 1)
        return graph

    def log_optimal_transport_batch(self, Z, iters):
        """Log-space Sinkhorn normalization of a batch of score matrices.

        Args:
            Z: (batch_size, m, n) log-potentials; larger entries receive more mass.
            iters: number of Sinkhorn iterations.

        Returns:
            A tensor shaped like Z holding the log transport plan under uniform row
            marginals (1/m) and uniform column marginals (1/n).
        """
        batch_size, m, n = Z.shape
        row_marginal = -torch.tensor(m).to(Z).log().expand(batch_size, m)
        col_marginal = -torch.tensor(n).to(Z).log().expand(batch_size, n)
        log_u = torch.zeros_like(row_marginal)
        log_v = torch.zeros_like(col_marginal)

        for _ in range(iters):
            log_v = col_marginal - torch.logsumexp(Z + log_u.unsqueeze(-1), dim=-2)
            log_u = row_marginal - torch.logsumexp(Z + log_v.unsqueeze(-2), dim=-1)

        return Z + log_u.unsqueeze(-1) + log_v.unsqueeze(-2)

    def predict(self, image, descriptors, graph):
        """Refine the candidate ``graph`` (B, N, 2 pixel coords) and predict the polygons.

        Returns ``(poly, permutation_mat, scores, sinkhorn_scores, refined_graph)``.
        """
        _, _, _, W = image.shape
        n_images, _, _ = graph.shape

        # Gather each image's descriptor vectors at its candidate (row, col) positions -> (B, C, N).
        sel_desc = torch.stack(
            [descriptors[b][:, graph[b][:, 0], graph[b][:, 1]] for b in range(n_images)],
            dim=0,
        )

        # Attentional GNN on descriptors plus normalized coordinates.
        norm_graph = self.normalize_coordinates(graph, W, input="global")
        sel_desc, offset = self.gnn(sel_desc, norm_graph)

        # Shift each point by at most correction_radius (normalized units), then back to pixels.
        refined = self.normalize_coordinates(norm_graph + offset * self.correction_radius, W, input="normalized")

        # Clockwise scores plus transposed counter-clockwise scores.
        scores = self.scorenet1(sel_desc) + torch.transpose(self.scorenet2(sel_desc), 1, 2)

        sinkhorn_scores = self.log_optimal_transport_batch(scores, self.sinkhorn_iterations)

        permutation_mat = scores_to_permutations(scores)  # linear sum assignment
        poly = permutations_to_polygons(permutation_mat, refined, out='torch')

        return poly, permutation_mat, scores, sinkhorn_scores, refined

# Training

### Loss Function

In [None]:
def cross_entropy_loss(sinkhorn_results, gt_permutation):
    """Negative mean log-probability at the ground-truth connections.

    ``sinkhorn_results`` holds log transport scores. Only entries where
    ``gt_permutation == 1`` contribute to the mean.
    """
    positive = gt_permutation == 1
    return torch.masked_select(sinkhorn_results, positive).mean().neg()

In [None]:
def iou_loss_function(pred, gt):
    """1 - mean IoU between dilated predicted vertices and the GT masks.

    NOTE: this definition is shadowed by the ``iou_loss_function`` defined in the
    next cell. Only the polygon *vertices* are rasterised (each dilated by a 3x3
    max-pool); polygon interiors are not filled.

    Args:
        pred: length-B sequence; element b is an iterable of (k, 2) tensors of
            (x, y) vertex coordinates.
        gt: (B, H, W) ground-truth masks.

    Returns:
        Scalar tensor ``1 - mean_b IoU_b``. An image with no polygons contributes IoU 0.
    """
    B, H, W = gt.shape
    total_iou = 0

    for batch in range(B):
        batch_pred = torch.zeros((H, W), device=gt.device)
        for poly in pred[batch]:
            pts = poly.long().to(gt.device)  # index tensors must live on the mask's device
            vertex_mask = torch.zeros((H, W), device=gt.device)
            vertex_mask[pts[:, 1], pts[:, 0]] = 1
            vertex_mask = F.max_pool2d(vertex_mask.unsqueeze(0), kernel_size=3, stride=1, padding=1).squeeze(0)
            batch_pred = torch.maximum(batch_pred, vertex_mask)

        target = gt[batch].to(batch_pred.dtype)
        intersection = torch.sum(torch.minimum(batch_pred, target))
        union = torch.sum(torch.maximum(batch_pred, target))
        total_iou += intersection / (union + 1e-6)  # epsilon avoids division by zero

    return 1 - total_iou / B

In [None]:
def iou_loss_function(pred, gt):
    """1 - mean IoU between filled predicted polygons and the GT masks.

    The masks are rasterised with OpenCV, so this value carries no gradient back to
    ``pred``. The result is a fresh leaf tensor with ``requires_grad=True``, kept for
    compatibility with callers that call ``.backward()`` on it.

    Args:
        pred: length-B sequence; element b is an iterable of (k, 2) point tensors.
        gt: (B, H, W) ground-truth masks with values in {0, 1}.

    Returns:
        Scalar float tensor ``1 - mean_b IoU_b`` on ``gt.device``. An image whose
        prediction and GT are both empty contributes IoU 0.
    """
    B, H, W = gt.shape
    iou = 0.0
    for batch in range(B):
        # Fill every polygon onto one canvas, so overlapping polygons form a union (max 1).
        pred_mask = np.zeros((H, W), dtype=np.uint8)
        for poly in pred[batch]:
            # cv2.fillPoly only accepts 32-bit integer point arrays.
            pts = poly.detach().cpu().numpy().astype(np.int32).reshape(-1, 1, 2)
            cv2.fillPoly(pred_mask, [pts], 1)

        # fillPoly reads points as (x, y). The transpose below follows the original code,
        # presumably because the points are (row, col); confirm against permutations_to_polygons.
        batch_pred_mask = torch.from_numpy(pred_mask).permute(1, 0).to(gt.device)

        target = gt[batch]
        intersection = torch.minimum(batch_pred_mask, target).sum()
        union = torch.maximum(batch_pred_mask, target).sum()
        iou = iou + intersection / union.clamp_min(1)  # an empty union yields IoU 0 instead of NaN

    mean_iou = torch.as_tensor(iou / B, dtype=torch.float32)
    return (1 - mean_iou).detach().clone().requires_grad_(True)

### Models

In [None]:
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints
# (newer torch versions accept weights_only=True for plain state dicts).

# Backbone (nn.Module.to / .train return the module itself, so the calls chain)
model = R2U_Net().to(device).train()
model.load_state_dict(torch.load('trained_weights/polyworld_backbone', map_location=device))

# Vertex Detection
head_ver = DetectionBranch().to(device).train()
head_ver.load_state_dict(torch.load('trained_weights/polyworld_seg_head', map_location=device))

# NMS (no parameters, nothing to load)
suppression = NonMaxSuppression().to(device)

# NOTE(review): only `matching` is passed to the optimizer below, yet the backbone and head
# are left in train() mode, so their BatchNorm running statistics keep updating; confirm
# whether they should be switched to eval() while frozen.

# Matching
matching = OptimalMatching().to(device).train()
matching.load_state_dict(torch.load('trained_weights/polyworld_matching', map_location=device))

# Freeze
# for param in model.parameters():
#     param.requires_grad = False

# for param in head_ver.parameters():
#     param.requires_grad = False

# for param in matching.parameters():
#     param.requires_grad = False

### Utils

In [None]:
def graph_to_vertex_mask(points, image):
    """Rasterize vertex positions into a binary per-image mask.

    Args:
        points: integer tensor/array of shape (B, N, 2). ``points[..., 0]``
            indexes the H axis of the mask and ``points[..., 1]`` the W axis.
            It may live on any device.
        image: tensor of shape (B, C, H, W); only its shape is used.

    Returns:
        ``torch.uint8`` CPU tensor of shape (B, H, W): 1 at every vertex, 0 elsewhere.
    """
    B, _, H, W = image.shape
    mask = torch.zeros((B, H, W), dtype=torch.uint8)

    # The mask is on the CPU, so the indices must be too (e.g. when `points` is a
    # CUDA NMS output), and advanced indexing needs an integer dtype.
    coords = torch.as_tensor(points).detach().to(device='cpu', dtype=torch.long)
    batch_indices = torch.arange(B).unsqueeze(1)  # (B, 1) broadcasts against (B, N)
    mask[batch_indices, coords[..., 0], coords[..., 1]] = 1

    return mask





def polygon_to_vertex_mask(polygons: list, height: int = 320, width: int = 320):
    """Mark every polygon vertex in a binary per-image mask.

    Args:
        polygons: one entry per image; each entry is a list of polygons given as
            array-likes of (x, y) vertices, either shaped (K, 2) or flat
            ``[x0, y0, x1, y1, ...]``. Coordinates are truncated to ``int``.
        height: mask height (defaults to the previously hard-coded 320).
        width: mask width (defaults to the previously hard-coded 320).

    Returns:
        ``torch.uint8`` tensor of shape (B, height, width) with ``mask[b, y, x] == 1``
        for each vertex.
    """
    mask = torch.zeros((len(polygons), height, width), dtype=torch.uint8)

    for batch, batch_polygons in enumerate(polygons):
        for poly in batch_polygons:
            vertices = torch.as_tensor(np.asarray(poly, dtype=int).reshape(-1, 2)).long()
            # Vertices are (x, y): x selects the column, y the row.
            mask[batch, vertices[:, 1], vertices[:, 0]] = 1

    return mask





def tensor_to_numpy(input: list):
    """Convert a nested list of tensors (one inner list per batch item) to numpy arrays.

    Each tensor is detached first, so tensors that still require grad (e.g. raw
    model outputs) can be converted; ``.cpu()`` copies device tensors to host.
    """
    return [
        [tensor.detach().cpu().numpy() for tensor in batch_tensors]
        for batch_tensors in input
    ]





def point_to_polygon(points: list):
    """Turn every polygon's coordinates into an integer (N, 2) vertex array.

    Args:
        points: one entry per image, each a list of polygons given as anything
            ``np.array`` accepts with an even number of coordinates.

    Returns:
        List (per image) of lists of ``int`` arrays shaped (N, 2).
    """
    def _as_vertex_array(poly):
        return np.array(poly, dtype=int).reshape(-1, 2)

    return [list(map(_as_vertex_array, image_polygons)) for image_polygons in points]









def polygon_to_seg_mask():
    """Placeholder for polygon -> segmentation-mask conversion; not implemented, returns ``None``."""
    return None




def point_to_permutation():
    """Placeholder for point -> permutation-matrix conversion; not implemented, returns ``None``."""
    return None

### Training Loop

In [None]:
# Adam over the matching network only. The backbone (`model`) and the vertex
# head (`head_ver`) are not in any parameter group, so `optimizer.step()` never
# updates them; add `{'params': model.parameters()}` etc. to fine-tune them too.
optimizer = torch.optim.Adam(
    [{'params': matching.parameters()}],
    lr=1e-4,
)

In [None]:
# Positive-class weight for the weighted BCE variants below. All of them are
# commented out, so `w` is not used by the active training loop.
w = 10 ** 2

# Candidate loss functions kept for experimentation:
# detection_loss_function = nn.BCELoss(weight=torch.tensor([w])).to(device)  # vertices; raised a CUDA error
# detection_loss_function = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([w])).to(device)  # vertices
#
# matching_loss_function = nn.CrossEntropyLoss().to(device)  # permutation matrix
# matching_loss_function = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([w])).to(device)
# matching_loss_function = nn.NLLLoss().to(device)
# matching_loss_function = nn.BCELoss().to(device)
#
# angle_loss_function = AngleLoss().to(device)

In [None]:
# Fixed-order batches of 2. `collate_fn` returns per-field tuples instead of
# stacking, so the training loop stacks the tensor fields itself.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn,
)

In [None]:
# Training loop. Only `matching` is optimized (it is the only parameter group in
# `optimizer`); the loss compares the Sinkhorn scores with the GT permutation
# matrices via `cross_entropy_loss`.
EPOCHS = 300
PLOT_EVERY = 10  # plot predicted vs. GT permutation every N epochs (1 = every epoch)

for epoch in range(EPOCHS):
    epoch_loss_sum = 0.0
    num_batches = 0

    for image, gt_vertex_mask, gt_seg_mask, gt_permutation_matrix, gt_polygons in tqdm(dataloader):
        # collate_fn yields tuples of per-sample tensors; stack the fields the
        # matching loss needs. Images arrive channels-last and become NCHW.
        image = torch.stack(image).float().permute(0, 3, 1, 2).to(device)
        gt_permutation_matrix = torch.stack(gt_permutation_matrix).float().to(device)
        # gt_vertex_mask / gt_seg_mask / gt_polygons are only needed by the
        # detection and segmentation losses, which are currently disabled.

        optimizer.zero_grad()

        # Forward pass. `feature_maps` (not `features`) avoids overwriting the
        # `rasterio.features` module imported at the top of the notebook.
        feature_maps = model(image)
        vertex_logits = head_ver(feature_maps)  # occupancy grid
        _, graph = suppression(vertex_logits)   # (B, 256, 2) --> vertex positions
        poly, permutation_matrix, scores, sinkhorn_scores, _ = matching.predict(image, feature_maps, graph)

        loss = cross_entropy_loss(sinkhorn_scores, gt_permutation_matrix)

        # Backward pass
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("dataloader yielded no batches")

    if epoch % PLOT_EVERY == 0 or epoch == EPOCHS - 1:
        # Last batch of the epoch: predicted vs. ground-truth permutation matrix.
        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        axes[0].imshow(permutation_matrix[0].detach().cpu().numpy())
        axes[0].set_title('Predicted permutation matrix')
        axes[1].imshow(gt_permutation_matrix[0].detach().cpu().numpy())
        axes[1].set_title('GT permutation matrix')
        plt.show()
        plt.close(fig)

    # Report the epoch mean (the previous log showed only the last batch's loss).
    print(f"Epoch {epoch} - Loss: {epoch_loss_sum / num_batches:.6f} (last batch: {loss.item():.6f})")

### Save

In [None]:
# Save every component plus the optimizer state. The optimizer only updates
# `matching`, so its weights must be saved or the training run is lost.
# 'model_state_dict' keeps its key (backbone) for compatibility with existing
# loaders; the loss is stored detached so no autograd graph is serialized.
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'head_ver_state_dict': head_ver.state_dict(),
    'matching_state_dict': matching.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.detach().cpu(),
}, 'trained_weights/polyworld_backbone_overfit_1sample.pth')

### Inference

In [None]:
# Inputs: one batch for inspection, image moved to `device` (not hard-coded CUDA,
# matching the rest of the notebook) and converted from NHWC to NCHW.
image, gt_vertex_mask, gt_seg_mask, gt_permutation_matrix, gt_polygons = next(iter(dataloader))
image, gt_vertex_mask, gt_seg_mask, gt_permutation_matrix = (
    torch.stack(image),
    torch.stack(gt_vertex_mask),
    torch.stack(gt_seg_mask),
    torch.stack(gt_permutation_matrix),
)
image = image.float().to(device).permute(0, 3, 1, 2)

In [None]:
# Shapes of the inspection batch.
print(f'Image:  {image.shape}')
print(f'Vertex:  {gt_vertex_mask.shape}')
print(f'Segmentation:  {gt_seg_mask.shape}')
print(f'Permutation:  {gt_permutation_matrix.shape}')

In [None]:
# Outputs. Inspection-only forward pass: no_grad skips building the autograd
# graph (less memory) and lets later cells call .numpy() without .detach().
# NOTE(review): the modules are in train mode, so BatchNorm layers still use batch statistics.
with torch.no_grad():
    features = model(image)
    vertex_logits = head_ver(features)
    _, graph = suppression(vertex_logits)
    poly, permutation_matrix, scores, sinkhorn_scores, graph_refined = matching.predict(image, features, graph)

In [None]:
# Shapes of every intermediate produced by the forward pass.
print(f'RGB:  {image.shape}')
print(f'Features Map:  {features.shape}')
print(f'Vertex Logits:  {vertex_logits.shape}')
print(f'Graph:  {graph.shape}')
print(f'Permutation Matrix:  {permutation_matrix.shape}')
print(f'Sinkhorn Scores:  {sinkhorn_scores.shape}')
print(f'Scores:  {scores.shape}')
print(f'Graph Refined:  {graph_refined.shape}')

In [None]:
# First image of the batch with its vertex logits overlaid semi-transparently.
rgb_hwc = image[0].permute(1, 2, 0).cpu().numpy()
logits_hw = vertex_logits[0].cpu().detach().numpy().squeeze()
plt.imshow(rgb_hwc)
plt.imshow(logits_hw, alpha=0.3)

In [None]:
# Sigmoid of the raw and Sinkhorn-normalized matching scores, side by side.
fig, axes = plt.subplots(1, 2, figsize=(20, 10))
for ax, (title, score_map) in zip(axes, [("Scores", scores[0]), ("Sinkhorn Scores", sinkhorn_scores[0])]):
    ax.imshow(torch.sigmoid(score_map).cpu().detach().numpy())
    ax.set_title(title)

In [None]:
sinkhorn_sigmoid = torch.sigmoid(sinkhorn_scores[0]).cpu().detach().numpy()
plt.imshow(sinkhorn_sigmoid)

In [None]:
i = 0

# Matching outputs for sample `i`: Sinkhorn scores, permutation matrix, raw scores.
panels = {
    "Sinkhorn Scores": sinkhorn_scores[i],
    "Permutation Matrix": permutation_matrix[i],
    "Scores": scores[i],
}
fig, axes = plt.subplots(1, 3, figsize=(20, 10))
for ax, (title, matrix) in zip(axes, panels.items()):
    ax.imshow(matrix.detach().cpu().numpy())
    ax.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
# Vertex mask built directly from the NMS output ...
mask1 = graph_to_vertex_mask(graph, image)

# ... and from the vertices of the polygons assembled by the matching step,
# transposed (presumably to share mask1's row/column convention).
poly_arrays = tensor_to_numpy(poly)
mask2 = polygon_to_vertex_mask(poly_arrays).permute(0, 2, 1)

In [None]:
first_polygon_mask = mask2[0]
plt.imshow(first_polygon_mask)

In [None]:
(mask1.shape, mask2.shape)

In [None]:
# Ground-truth vs. predicted vertex masks, plus the GT segmentation, for sample `i`.
i = 0
panels = [
    ('GT Vertex Mask', gt_vertex_mask[i].cpu().numpy().squeeze()),
    ('Prediction after NMS', mask1[i]),
    ('Prediction from Permutations Matrix + Graph', mask2[i]),
    ('GT Segmentation Mask', gt_seg_mask[i].cpu().numpy().squeeze()),
]
fig, axes = plt.subplots(1, 4, figsize=(21, 13))
for ax, (title, data) in zip(axes, panels):
    ax.imshow(data)
    ax.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
i = 0

# Ground-truth vs. predicted permutation matrix for sample `i`. `.detach()`
# avoids the "Can't call numpy() on Tensor that requires grad" error when the
# prediction still carries an autograd graph.
panels = [
    ('GT Permutation Matrix', gt_permutation_matrix[i]),
    ('Prediction Permutation Matrix', permutation_matrix[i]),
]
fig, axes = plt.subplots(1, 2, figsize=(21, 13))
for ax, (title, matrix) in zip(axes, panels):
    ax.imshow(matrix.detach().cpu().numpy())
    ax.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
# Sigmoid-activated vertex logits.
output = vertex_logits.sigmoid()
output.shape

In [None]:
i = 0

# Sigmoid output, raw logits and the RGB input for sample `i`.
panels = (
    ('Applied Sigmoid on Logits', output[i].detach().cpu().numpy().squeeze()),
    ('Logits', vertex_logits[i].detach().cpu().numpy().squeeze()),
    ('RGB Image', image[i].cpu().numpy().squeeze().transpose(1, 2, 0)),
)
fig, axes = plt.subplots(1, 3, figsize=(21, 13))
for ax, (title, data) in zip(axes, panels):
    ax.imshow(data)
    ax.set_title(title)

plt.tight_layout()
plt.show()

# Prediction

In [None]:
def bounding_box_from_points(points):
    """Return the COCO-style ``[x, y, width, height]`` box enclosing ``points``.

    Args:
        points: any array-like whose flattened values alternate x, y, x, y, ...
            (a flat COCO segmentation, a list of such lists, or an (N, 2) array).

    Returns:
        List of four ``int``: min x, min y, x extent, y extent (truncated).

    Raises:
        ValueError: if ``points`` holds no coordinates.
    """
    coords = np.asarray(points).flatten()
    if coords.size == 0:
        raise ValueError("bounding_box_from_points() needs at least one point")
    # Strided slices select the interleaved x / y values with integer indexing.
    # The previous float index list (np.arange(n / 2) * 2) is rejected by np.take.
    xs = coords[0::2]
    ys = coords[1::2]
    return [int(b) for b in (xs.min(), ys.min(), xs.max() - xs.min(), ys.max() - ys.min())]


def single_annotation(image_id, poly):
    """Build one COCO-style prediction record for ``poly`` on image ``image_id``.

    The category id is fixed at 100 and the score at 1; ``bbox`` is derived
    from the segmentation with ``bounding_box_from_points``.
    """
    return {
        "image_id": int(image_id),
        "category_id": 100,
        "score": 1,
        "segmentation": poly,
        "bbox": bounding_box_from_points(poly),
    }

In [None]:
def prediction(batch_size, images_directory, annotations_path):
    """Run PolyWorld inference over a CrowdAI split and collect polygon predictions.

    Args:
        batch_size: images per batch (also used as the number of loader workers).
        images_directory: image directory passed to ``CrowdAI``.
        annotations_path: annotation file passed to ``CrowdAI``.

    Returns:
        List of records built by ``single_annotation``, one per predicted polygon.
    """
    # Vertex detection
    model = R2U_Net().to(device).train()
    head_ver = DetectionBranch().to(device).train()

    # NMS
    suppression = NonMaxSuppression().to(device)

    # Generate the connections between vertices
    matching = OptimalMatching().to(device).train()

    # NOTE: The modules are set to .train() mode during inference to make sure that the BatchNorm layers
    # rely on batch statistics rather than the mean and variance estimated during training.
    # Experimentally, using batch stats makes the network perform better during inference.

    print("Loading pretrained model")
    model.load_state_dict(torch.load("./trained_weights/polyworld_backbone", map_location=device))
    head_ver.load_state_dict(torch.load("./trained_weights/polyworld_seg_head", map_location=device))
    matching.load_state_dict(torch.load("./trained_weights/polyworld_matching", map_location=device))

    # Initiate the dataloader
    CrowdAI_dataset = CrowdAI(images_directory=images_directory, annotations_path=annotations_path)
    dataloader = DataLoader(CrowdAI_dataset, batch_size=batch_size, shuffle=False, num_workers=batch_size)

    speed = []
    predictions = []
    with torch.no_grad():  # inference only: no autograd graph needed
        for sample_batched in tqdm(dataloader):
            rgb = sample_batched['image'].to(device).float()
            idx = sample_batched['image_idx']

            t0 = time.time()

            features = model(rgb)
            occupancy_grid = head_ver(features)
            _, graph_pressed = suppression(occupancy_grid)

            # predict() returns (polygons, permutation matrix, scores, sinkhorn
            # scores, refined graph); only the polygons are needed here.
            poly, *_ = matching.predict(rgb, features, graph_pressed)

            if rgb.is_cuda:
                torch.cuda.synchronize()  # CUDA kernels run async; wait before timing
            speed.append(time.time() - t0)

            for i, image_polygons in enumerate(poly):
                for p in image_polygons:
                    predictions.append(single_annotation(idx[i], [p]))

            del features, occupancy_grid, graph_pressed, poly, rgb, idx

    if speed:
        print("Average model speed: ", np.mean(speed) / batch_size, " [s / image]")

    return predictions

In [None]:
val_predictions = prediction(
    batch_size=1,
    images_directory="data/val/images/",
    annotations_path="data/val/annotation.json",
)

# Playground

In [None]:
with open('data/val/annotation.json') as annotation_file:
    annotations = json.load(annotation_file)

# COCO-style annotation file: one table of image records, one of annotations.
images = pd.DataFrame(annotations['images'])
labels = pd.DataFrame(annotations['annotations'])

In [None]:
# Annotations belonging to the image with id 9.
segmentations = labels.loc[labels['image_id'] == 9]

In [None]:
# First polygon of each annotation, reshaped from a flat coordinate list into
# an integer (N, 2) vertex array.
segmentations = [
    np.array(segmentation[0], dtype=int).reshape(-1, 2)
    for segmentation in segmentations['segmentation'].values
]
# segmentations = [np.array(poly * ratio, dtype=int).reshape(-1, 2) for poly in segmentations]
segmentations

In [None]:

N = 256
permutation_matrix = np.zeros((N, N))

# Vertices of all polygons are numbered consecutively; every vertex points to
# the next vertex of its polygon and the last one closes back to the first.
n = 0
for i, polygon in enumerate(segmentations):
    first = n
    last = n + len(polygon) - 1
    for row in range(first, last + 1):
        successor = row + 1 if row != last else first
        permutation_matrix[row, successor] = 1
    n = last + 1

    print(f'Polygon {i} Finished')

In [None]:
fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(permutation_matrix)

In [None]:
# Number of polygons selected for the image above.
len(segmentations)

In [None]:
def create_adjacency_matrix(polygons, N=256):
    """Build an N x N matrix linking every polygon vertex to its successor.

    Vertices of all polygons are numbered consecutively in input order: the
    first vertex of polygon k follows the last vertex of polygon k-1. Entry
    [a, b] is 1 when vertex b follows vertex a in the same polygon, and the
    last vertex of each polygon links back to that polygon's first vertex.

    Args:
        polygons: iterable of polygons, each a flat ``[x0, y0, x1, y1, ...]``
            sequence or a (K, 2) array of vertices.
        N: size of the square output (maximum total number of vertices).

    Returns:
        ``np.ndarray`` of dtype ``int`` and shape (N, N).

    Raises:
        ValueError: if a polygon has an odd number of coordinates, or the
            polygons contain more than N vertices in total.
    """
    permutation_matrix = np.zeros((N, N), dtype=int)

    offset = 0  # global index of the current polygon's first vertex
    for polygon in polygons:
        coords = np.asarray(polygon)
        if coords.size % 2:
            raise ValueError("polygon has an odd number of coordinates")
        num_vertices = coords.size // 2
        if num_vertices == 0:
            continue
        if offset + num_vertices > N:
            raise ValueError(f"polygons have more than N={N} vertices in total")

        local = np.arange(num_vertices)
        # Successor of each vertex, wrapping to this polygon's own first vertex.
        permutation_matrix[offset + local, offset + (local + 1) % num_vertices] = 1
        offset += num_vertices

    return permutation_matrix

In [None]:
adj_matrix = create_adjacency_matrix(polygons=segmentations, N=256)

In [None]:
np.shape(adj_matrix)

In [None]:
# Inspect the raw adjacency matrix.
adj_matrix

In [None]:
# Top-left 20 x 20 corner of the adjacency matrix.
plt.imshow(adj_matrix[0:20, 0:20])

In [None]:
def log_optimal_transport_batch(Z, iters):
    """Run log-space Sinkhorn iterations on a batch of score matrices.

    Uniform marginals are used: each of the m rows has mass 1/m and each of the
    n columns has mass 1/n.

    Args:
        Z: tensor of shape (batch_size, m, n) holding log-domain scores.
        iters: number of Sinkhorn iterations.

    Returns:
        Tuple ``(Z, u, v)``: the input ``Z`` itself (unchanged), the row
        potentials ``u`` of shape (batch_size, m, 1) and the column potentials
        ``v`` of shape (batch_size, 1, n). The log transport plan is ``Z + u + v``;
        it is not materialized here.
    """
    batch, rows, cols = Z.shape
    log_row_mass = -torch.tensor(rows).to(Z).log().expand(batch, rows)
    log_col_mass = -torch.tensor(cols).to(Z).log().expand(batch, cols)
    row_pot = torch.zeros_like(log_row_mass)
    col_pot = torch.zeros_like(log_col_mass)

    for _ in range(iters):
        # Alternate column and row normalization; the row update runs last.
        col_pot = log_col_mass - torch.logsumexp(Z + row_pot[..., :, None], dim=-2)
        row_pot = log_row_mass - torch.logsumexp(Z + col_pot[..., None, :], dim=-1)

    return Z, row_pot[..., :, None], col_pot[..., None, :]

In [None]:
# Toy batch of 8 binary 10x10 matrices (up to 10 random cells set to 1 each)
# for exercising log_optimal_transport_batch. A seeded generator keeps the
# demo reproducible across kernel restarts.
generator = torch.Generator().manual_seed(0)
Z = torch.zeros((8, 10, 10))
for i in range(8):
    idx = torch.randint(0, 10, (10, 2), generator=generator)
    Z[i, idx[:, 0], idx[:, 1]] = 1

In [None]:
plt.imshow(Z.detach().cpu().numpy()[0])

In [None]:
Z, u, v = log_optimal_transport_batch(Z, iters=100)

In [None]:
(v.shape, u.shape)

In [None]:
plt.imshow(Z.detach().cpu().numpy()[0])

In [None]:
plt.imshow(u.detach().cpu().numpy()[0])

In [None]:
plt.imshow(v.detach().cpu().numpy()[0])

In [None]:
x = torch.add(u[0], v[0])  # (m, 1) + (1, n) broadcasts to (m, n)

In [None]:
# Shape of the broadcast sum of the two potentials.
x.shape

In [None]:
plt.imshow(x.detach().cpu().numpy())