In [None]:
import os
import sys 
# Work from the `ASL` directory in the user's home and make both that
# directory and its `src` sub-folder importable.
# `expanduser('~')` still honours $HOME but also works when it is unset.
os.chdir(os.path.join(os.path.expanduser('~'), 'ASL'))
sys.path.insert(0, os.getcwd())
# Join the components instead of concatenating '/src' by hand.
sys.path.append(os.path.join(os.getcwd(), 'src'))

from pseudo_label import OpticalFlowLoader

# TODO(Jonas Frey):
# Things that need to change if the ScanNet frames are cropped up front:
# -> crop all labels
# -> crop in flow generation
# -> crop when loading gt/label

In [None]:
import imageio
arr = imageio.imread("/home/jonfrey/Datasets/scannet/scans/scene0000_00/color/0.jpg")
arr.shape
from PIL import Image
# Trim a border of `b` pixels from every side of the frame. The bounds are
# taken from the loaded array instead of hard-coding 968 x 1296, so the crop
# stays symmetric for any input resolution.
b = 11
height, width = arr.shape[:2]
arr2 = arr[b:height-b, b:width-b]
arr2.shape
Image.fromarray(arr2)

In [None]:
from visu import Visualizer
# Visualizer writing to `~/tmp`; the path is joined from its components and
# `expanduser` keeps working when $HOME is not set.
visualizer = Visualizer(os.path.join(os.path.expanduser('~'), 'tmp'), logger=None, epoch=0, store=False, num_classes=41)

In [None]:
from torchvision.utils import make_grid
from skimage.segmentation import slic
from skimage.segmentation import mark_boundaries

import copy
import torch
import cv2
import numpy as np
import imageio
import os

import matplotlib.pyplot as plt
from torch.nn.functional import one_hot
from math import ceil
import torchvision

"""
TODO: Jonas Frey
Fully regenerate the flow with the correct center crop to get rid of the border induced by the rectification step.
"""

def make_grid_sqaure( img, padding=5, pad_value= 0):
    """Tile a list of images into one roughly square mosaic.

    Args:
        img: sequence of equally shaped arrays that `np.stack` can combine;
            the permutation below treats the last axis as the channel axis.
        padding: forwarded to `torchvision.utils.make_grid`.
        pad_value: forwarded to `torchvision.utils.make_grid`.

    Returns:
        The mosaic as a channel-last `uint8` numpy array.
    """
    # ceil(sqrt(N)) tiles per row.
    tiles_per_row = ceil(len(img) ** 0.5)

    # N,H,W,C -> N,C,H,W, the layout make_grid works on.
    batch = torch.from_numpy(np.stack(img, axis=0))
    batch = batch.permute(0, 3, 1, 2)

    mosaic = torchvision.utils.make_grid(
        batch, nrow=tiles_per_row, padding=padding, pad_value=pad_value
    )

    # Back to channel-last before handing the result out as uint8.
    mosaic = mosaic.permute(1, 2, 0).numpy()
    return np.uint8(mosaic)
    
class PseudoLabelGenerator():
    """Fuse per-frame segmentation labels after warping them along optical flow.

    The label dicts of a sequence are remapped flow by flow
    (`_forward_index`) and then merged per label name into a single label map
    (`calculate_label`).
    """

    def __init__(self, confidence='equal', 
            flow_mode='sequential', visualizer= None, num_classes=41):
        """  
        confidence:
          'equal': perfect optical flow -> all project labels are equally good
          'linear': linear rate -> per frame
          'exponential': exponential rate -> per frame 
        flow_mode:
          'sequential': #-> 0->1, 1->2, 2->3
        visualizer:
          optional object offering `plot_detectron` and `plot_image`; plotting
          is skipped when it is None.
        num_classes:
          number of class slots used when counting votes for hard labels.
        """
        assert flow_mode == 'sequential', f"Not defined flow mode: {flow_mode}"
        # Stored because `_get_confidence_values` reads `self._confidence`.
        self._confidence = confidence
        self._flow_mode = flow_mode
        self._num_classes = num_classes
        self.visualizer = visualizer

    def calculate_label(self, flows, labels, images = None):
        """Warp all labels along `flows` and fuse them per label name.

        Args:
            flows: sequence of `(flow, valid)` pairs, one per frame step.
            labels: sequence of dicts mapping a label name to an array of
                shape H,W (hard label) or H,W,C (soft label). The dicts are
                modified in place by the warping step.
            images: optional sequence of images, only used for plotting.

        Returns:
            Dict mapping every label name to the fused H,W label map.
        """
        print("VISU Input")
        self.visu( flows, labels, images)

        seg_forwarded = self._forward_index( flows, labels )

        print("VISU Output")
        self.visu( flows, seg_forwarded, images)

        out = {}
        for k in seg_forwarded[0].keys():
            # `_forward_index` warps the dicts of `labels` in place, so the
            # entries read here are the forwarded ones. Use the current key
            # `k` (the former code always read 'pretrain25k').
            all_labels = [frame[k] for frame in labels]
            if len( all_labels[0].shape ) == 2:
                # COUNT BASED
                out[k] = self._vote_hard_labels(all_labels)
            else:
                # SUM BASED
                out[k] = np.stack( all_labels, axis=0).sum(axis=0).argmax(axis=2)

            if self.visualizer is not None and images is not None:
                self.visualizer.plot_detectron( images[0], out[k], tag=f"Output {k}", jupyter=True)

        return out

    def _vote_hard_labels(self, frames):
        """Return the per-pixel majority class over a list of H,W label maps.

        Class 0 is not counted, so pixels without any non-zero vote end up as 0.
        """
        stacked = np.stack( frames, axis=0)
        votes = np.zeros( (self._num_classes, *frames[0].shape) )
        for cls in np.unique( stacked):
            if cls == 0:
                continue
            # int(): the label arrays may be float typed, which is not a valid index.
            votes[ int(cls) ] = np.sum( stacked == cls, axis=0)
        return votes.argmax( axis=0)

    def _get_confidence_values( self, seq_length ):
        """Return `seq_length` normalised weights for the configured confidence mode.

        The weights sum to one; for 'linear' and 'exponential' later entries
        receive larger weights. An empty list is returned for `seq_length <= 0`.

        Raises:
            ValueError: if the configured confidence mode is not known.
        """
        if seq_length <= 0:
            return []

        if self._confidence == 'equal':
            return [float( 1/seq_length)] * seq_length 

        if self._confidence == 'linear':
            lin_rate = 0.1
            # Clamp at zero so far-away frames never get a negative weight.
            ret = [max(0, 1 - lin_rate * (seq_length - i)) for i in range(seq_length)]
        elif self._confidence == 'exponential':
            exp_rate = 0.8
            ret = [exp_rate ** (seq_length - i) for i in range(seq_length)]
        else:
            raise ValueError(f"Not defined confidence mode: {self._confidence}")

        s = sum(ret)
        return [r/s for r in ret]

    def _labels_add_channels(self, labels):
        """Give every 2-D label array a trailing channel axis (H,W -> H,W,1), in place."""
        for frame in labels:
            for k in frame.keys():
                if len(frame[k].shape) == 2:
                    frame[k] = frame[k][:,:,None]
        return labels

    def _forward_index(self, flows, labels ):
        """
        # OUTPUT OF OpticalFlowLoader # labels = H,W or H,W,C

        Warp the label dicts along the flows, starting at flows[0].

        For flow `i` the dict `labels[i]` joins the set of tracked frames and
        every tracked frame is remapped with that flow, so earlier frames
        accumulate all later flows. The dicts are modified in place and the
        same objects are returned. After the last flow hard labels are H,W
        again, soft labels stay H,W,C.

        NOTE(review): the `valid` entry of each flow pair is not applied.
        """ 
        labels = self._labels_add_channels(labels) # -> H,W,C C (1 or 40)

        H,W,_ = flows[0][0].shape
        # Pixel coordinate grids; identical for every flow, so built once.
        grid_h, grid_w = np.mgrid[0:H, 0:W].astype(np.float32)
        last = len(flows) - 1

        labels_forwarded = []

        for i, flow_entry in enumerate(flows):
            # START AT OLDEST FLOW
            flow = flow_entry[0]

            labels_forwarded.append( labels[i] ) # each entry is a dict with mutiple labels

            # Sampling position of every output pixel. cv2.remap needs
            # float32 maps, hence the explicit cast.
            h_ = (grid_h - flow[:,:,1]).astype(np.float32)
            w_ = (grid_w - flow[:,:,0]).astype(np.float32)

            for labels_frame in labels_forwarded:
                for k in labels_frame.keys():
                    if labels_frame[k].shape[2] != 1:
                        # soft label with prob
                        labels_frame[k] = cv2.remap( labels_frame[k], w_, h_, 
                                                      interpolation=cv2.INTER_LINEAR, 
                                                      borderMode=cv2.BORDER_CONSTANT, 
                                                      borderValue=0)
                    else:
                        # hard label: nearest neighbour keeps class ids intact;
                        # the channel axis is re-added after remap.
                        forw = cv2.remap( labels_frame[k], w_, h_, 
                                                      interpolation=cv2.INTER_NEAREST, 
                                                      borderMode=cv2.BORDER_CONSTANT, 
                                                      borderValue=0)[:,:,None]
                        # Drop the channel axis once the last flow was applied.
                        labels_frame[k] = forw[:,:,0] if i == last else forw

        return labels_forwarded

    def visu(self, flows, labels, images):
        """Plot, per label name, a mosaic of all frames with their labels overlaid.

        Does nothing when no visualizer is configured or no images are given.
        """
        if self.visualizer is None or images is None:
            return

        for label_name in labels[0].keys():
            # zip() limits the plot to frames that have a flow, a label and an image.
            detectron_plots = [
                self.visualizer.plot_detectron( img, label_dict[label_name], jupyter=False)
                for _, label_dict, img in zip(flows, labels, images)
            ]
            grid = make_grid_sqaure( detectron_plots, padding=5, pad_value=0)
            print("label_name ",label_name)
            self.visualizer.plot_image(grid, tag=label_name, jupyter=True)

# Loader with a lookahead of 4 and two registered label sources.
ofl = OpticalFlowLoader(lookahead=4)
ofl.register_predictions(
    "pretrain25k",
    "/home/jonfrey/Datasets/labels_generated/labels_pretrain25k/scans",
)
ofl.register_predictions("gt", "/home/jonfrey/Datasets/scannet/scans")
plg = PseudoLabelGenerator(visualizer=visualizer)

# Only the first sample of the loader is needed: unpack it and stop.
for j, val in enumerate(ofl):
    flows, labels, images = val
    break

plg.calculate_label(flows, labels, images)


In [None]:
# Show the shape of the first image of the sample unpacked from the loader.
images[0].shape

In [None]:

#             plo = visualizer.plot_detectron( img, label_dict[label_name] , jupyter=False) 

# NOTE(review): in the visible cells `uni` is only assigned as a local inside
# PseudoLabelGenerator.calculate_label; confirm it exists at notebook scope
# before running this cell.
uni

In [None]:
# Show the shape of the first image again (same expression as an earlier cell).
images[0].shape

In [None]:
# Compare how many flows, images and label dicts the sample contains.
len(flows), len(images), len(labels)

In [None]:
# ORIGINAL

from torchvision.utils import make_grid
from skimage.segmentation import slic
from skimage.segmentation import mark_boundaries

import copy
import torch
import cv2
import numpy as np
import imageio
import os

import matplotlib.pyplot as plt
from torch.nn.functional import one_hot

class PseudoLabelGenerator():
    """Pseudo-label generator (the notebook cell marks this copy as ORIGINAL).

    NOTE(review): besides `_confidence` and `_flow_mode`, the methods read
    `_pll`, `_nc`, `_H`, `_W`, `_visu`, `_visu_active` and
    `_pre_fusion_function`. None of them is assigned inside this class, so
    they have to be attached to the instance before those methods are used.
    """

    def __init__(self, confidence='equal', 
            flow_mode='sequential'):
        """  
        confidence:
          'equal': perfect optical flow -> all project labels are equally good
          'linear': linear rate -> per frame
          'exponential': exponential rate -> per frame 
        flow_mode:
          'sequential': #-> 0->1, 1->2, 2->3
        """
        assert flow_mode == 'sequential', f"Not defined flow mode: {flow_mode}"
        # Stored because `_get_confidence_values` and `_forward_index` read them.
        self._confidence = confidence
        self._flow_mode = flow_mode

    def __len__(self):
        """Return the `length` attribute of the attached loader `_pll`."""
        return self._pll.length

    def get_gt_label(self, index):
        """Return `seg[0][1]` of loader sample `index` (presumably the ground-truth label; verify against the loader)."""
        seg, depth, flow, paths = self._pll[index] 
        return seg[0][1]

    def get_img(self, index):
        """Return the image the loader provides for `index`."""
        return self._pll.getImage(index)

    def get_depth(self, index):
        """Return the first depth entry of loader sample `index`."""
        seg, depth, flow, paths = self._pll[index]
        return depth[0]

    def calculate_label(self, index=None, seg=None, flow=None, image= None ):
        """Forward all segmentations into one frame and fuse them by weighted voting.

        Args:
            index: loader index; when given, `seg` and `flow` are read from `_pll`.
            seg: list of per-frame segmentations, used when `index` is None.
            flow: list of flow entries matching `seg`, used when `index` is None.
            image: currently unused (the superpixel refinement that consumed
                it is disabled).

        Returns:
            Tuple `(label, last)`: `label` is the fused H,W map with -1 for
            pixels without any valid vote, `last` is the final forwarded
            entry as int32.

        Raises:
            NotImplementedError: for soft (multi-channel) labels.
        """
        # None instead of mutable list defaults in the signature.
        seg = [] if seg is None else seg
        flow = [] if flow is None else flow

        if index is not None:
            # Pass the pre-fusion function by keyword; positionally it would
            # be bound to `seg` and silently ignored.
            seg_forwarded = self._forward_index(
                index, pre_fusion_function=self._pre_fusion_function) #return H,W,C
        else:
            seg_forwarded = self._forward_index(index, seg, flow, self._pre_fusion_function)

        if self._visu_active:
            self._visu_seg_forwarded(seg_forwarded)

        if seg_forwarded[0].shape[2] != 1:
            # Only the hard-label fusion below exists; previously this case
            # failed with a NameError on `one_hot_acc`.
            raise NotImplementedError(
                "calculate_label only supports hard labels with a single channel")

        # -1 39 -> 0 -> 40 We assume that the network is also able to predict the invalid class
        # In reality this is not the case but this way we can use the ground truth labels for testing
        # New arrays are built so the caller's segmentation data is not modified.
        # NOTE(review): the clamp is hard-coded to 40 while the one-hot size uses `_nc`.
        seg_forwarded = [np.minimum(s + 1, 40) for s in seg_forwarded]

        confidence_values_list = self._get_confidence_values(seq_length= len(seg_forwarded))

        # Confidence-weighted one-hot votes; slot 0 collects the invalid class.
        one_hot_acc = np.zeros( (*seg_forwarded[0].shape, self._nc+1), dtype=np.float32)
        eye = np.eye(self._nc+1)
        for conf, frame in zip(confidence_values_list, seg_forwarded):
            one_hot_acc += (eye[frame.astype(np.int32)]).astype(np.float32) * conf
        one_hot_acc = one_hot_acc[:,:,0,:] # H,W,C
        invalid_labels = np.sum( one_hot_acc[:,:,1:],axis=2 ) == 0

        label = np.argmax( one_hot_acc[:,:,1:], axis=2 )
        label[ invalid_labels ] = -1 

        # NOTE: the superpixel refinement (`_superpixel_label`) and the related
        # visualisation that used to follow here are disabled.
        return label, (seg_forwarded[-1]-1).astype(np.int32)

    def _get_confidence_values( self, seq_length ):
        """Return `seq_length` normalised weights for the configured confidence mode.

        The weights sum to one; for 'linear' and 'exponential' later entries
        receive larger weights. An empty list is returned for `seq_length <= 0`.

        Raises:
            ValueError: if the configured confidence mode is not known.
        """
        if seq_length <= 0:
            return []

        if self._confidence == 'equal':
            return [float( 1/seq_length)] * seq_length 

        if self._confidence == 'linear':
            lin_rate = 0.1
            # Clamp at zero so far-away frames never get a negative weight.
            ret = [max(0, 1 - lin_rate * (seq_length - i)) for i in range(seq_length)]
        elif self._confidence == 'exponential':
            exp_rate = 0.8
            ret = [exp_rate ** (seq_length - i) for i in range(seq_length)]
        else:
            raise ValueError(f"Not defined confidence mode: {self._confidence}")

        s = sum(ret)
        return [r/s for r in ret]

    def _forward_index(self, index=None, seg=None, flow=None, pre_fusion_function=None ):
        """
        seg[0] , C,H,W

        pre_fusion_function should be used to integrate the depth measurments 
        to the semseg before forward projection !

        seg_forwarded[0] -> oldest_frame
        seg_forwarded[len(seg_forwarded)] -> latest_frame not forwarded

        The passed `seg` list itself is left untouched.
        """
        seg = [] if seg is None else seg
        flow = [] if flow is None else flow

        if index is not None:
            if pre_fusion_function is None:
                seg, _, flow, _ = self._pll[index]
            else:
                seg, _, flow, _ = pre_fusion_function( self._pll[index] )

        # Soft labels carry more than one channel in front.
        soft = len( seg[0].shape ) == 3 and seg[0].shape[0] != 1

        assert self._flow_mode == 'sequential'

        # C,H,W -> H,W,C; a new list keeps the caller's list unchanged.
        seg = [np.moveaxis( s, [0,1,2], [2,0,1] ) for s in seg]

        seg_forwarded = []
        # start at oldest frame (highest index) and walk down to index 1;
        # index 0 is appended unwarped at the end. The former zero-flow
        # branch for i == 0 could never be reached.
        for i in range(len(seg)-1, 0, -1):
            seg_forwarded.append( seg[i].astype(np.float32) )

            f = flow[i][0]

            h_, w_ = np.mgrid[0:self._H, 0:self._W].astype(np.float32)
            h_ -= f[:,:,1]
            w_ -= f[:,:,0]

            # Every frame collected so far is warped by the current flow.
            for j, s in enumerate(seg_forwarded): #  seg_forwarded, depth_forwarded
                if soft:
                    s = cv2.remap( s, w_, h_, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
                else:
                    # NOTE(review): the array layout passed to remap here is kept
                    # as found; confirm it against real single-channel data.
                    s = cv2.remap( s[None], w_, h_, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=-1)[None]
                seg_forwarded[j] = s

        if not soft:
            seg_forwarded = [np.moveaxis( s, [0,1,2], [2,0,1] ) for s in seg_forwarded]
        seg_forwarded.append( seg[0] )
        return seg_forwarded # H,W,C

    def _visu_seg_forwarded(self, seg):
        """Plot the forwarded segmentations, latest first, as an s x s grid with s = floor(sqrt(len(seg)))."""
        s = int( len(seg) ** 0.5 )
        ba = torch.zeros( (s*s, 3, *seg[0].shape), dtype= torch.float32 )
        for i in range( s*s ):
            ba[i,:] = torch.from_numpy( seg[-(i+1)] )[None,:,:].repeat(3,1,1)
        grid_ba = make_grid( ba ,nrow = s ,padding = 2,
          scale_each = False, pad_value = -1)[0]
        self._visu.plot_segmentation(seg= grid_ba +1 , jupyter=True)

    def _superpixel_label(self, img, label, segments=250):
        """Replace the labels inside every SLIC superpixel by its majority label.

        Args:
            img: image handed to `slic`.
            label: H,W label map in which -1 marks invalid pixels.
            segments: requested number of superpixels, must be below 256.

        Returns:
            Tuple `(out_label, img, segment_map)`.
        """
        assert segments < 256 #I think slic fails if segments > 256 given that a 8bit uint is returend!

        segment_map = slic(img, n_segments = segments, sigma = 5, start_label=0)
        out_label = copy.copy(label)
        # +1: the highest segment id has to be processed as well.
        for i in range(0, segment_map.max() + 1):
            m1 = segment_map == i
            m = m1 * ( label != -1 )
            unique_val, unique_counts = np.unique( label [m], return_counts=True)
            # fill a segment preferable not with invalid !
            if unique_counts.shape[0] == 0:
                val = -1
            else:
                ma = unique_counts == unique_counts.max()
                # Break ties randomly until a single winner is left.
                while ma.sum() != 1:
                    ma[np.random.randint(0,ma.shape[0])] = False
                val = unique_val[ma]
            out_label[m1] = val 

        return out_label, img, segment_map

    def _visu_superpixels(self, img, segments):
        """Show `img` with the superpixel boundaries of `segments` drawn on top."""
        import matplotlib.pyplot as plt
        from skimage.segmentation import mark_boundaries
        fig = plt.figure("Superpixels -- segments" )
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(mark_boundaries(img, segments))
        plt.axis("off")
        plt.show()

# plg_ws_2 = PseudoLabelGenerator(base_path='/home/jonfrey/results/scannet_pseudo_label/scannet', 
#                            visu=visu,
#                            window_size=window_size_2,
#                           cfg_loader = {"ignore_depth": True},
#                           visu_active=plot,
#                           refine_superpixel=False)
# pseudo_2, _ = plg_ws_2.calculate_label(
#                         index=None, 
#                         seg= seg[:window_size_2], 
#                         flow= flow_list[:window_size_2])