In [None]:
#@title Installing necessary libraries, that are not already installed on Colab
# This might take a couple of minutes, and it may restart the runtime!
# If the runtime requests reloading, canceling is usually the better.
# (This might change in the future.)

# Detectron2 environment for instance segmentation
!pip install "git+https://github.com/facebookresearch/detectron2.git"

# SMP environment for tracking
!pip install segmentation-models-pytorch

# Filterpy for Kalman Filter definition (can be ignored, if Kalman Filter reference is not used)
!pip install filterpy

In [2]:
#@title Importing required functions and libraries

# Imports from Symmetry-Tracker repo

import sys
sys.path.append("/path/to/symmetry_tracker")

from symmetry_tracker.general_functionalities.video_transformation import TransformVideoFromTIFF

from symmetry_tracker.segmentation.segmentator import SingleVideoSegmentation
from symmetry_tracker.segmentation.segmentation_io import DisplaySegmentation, WriteSegmentation

from symmetry_tracker.tracking.symmetry_tracker import SingleVideoSymmetryTracking
from symmetry_tracker.tracking.tracking_io import DisplayTracks, WriteTracks, SaveTracksVideo, SaveTracks, LoadTracks
from symmetry_tracker.tracking.post_processing import InterpolateMissingObjects, RemoveShortPaths, HeuristicalEquivalence

# Other necessary imports

import torch
import os
import shutil

In [3]:
# ArrowSynth dataset variants whose models and test data are downloaded, e.g. ["1", "2"]
sample_categories = ["1"]

In [None]:
#@title Downloading the Models and Sample Data

!rm -r downloads
!mkdir downloads

for sample_category in sample_categories:
  sample_record_name = f"ArrowSynth{sample_category}"
  train_sample_ratio = "1.0"
  sample_id = "081"

  !mkdir downloads/$sample_record_name

  segmentator_models_dir = f"/path/to/{sample_record_name}/trained_nets/InstanceSegmentation/"
  segmentator_model_name = f"model_final.pth"
  segmentator_config_name = "config.yaml"

  tracking_models_dir = f"/path/to/{sample_record_name}/trained_nets/LocalTracking/"
  tracking_model_name = f"{sample_record_name}_[DLV3p,resnet50]_FBtr4_Ep50_Adv2_SR{train_sample_ratio}_NonPretrained_final.pth"
  #tracking_model_name = f"{sample_record_name}_[DLV3p,resnet50]_FBtr4_Ep50_Adv2_SR{train_sample_ratio}_final.pth"

  testdata_dir = f"/path/to/{sample_record_name}/data/"
  testdata = f"{sample_record_name}_test_samples.zip"

  !wget --no-clobber $segmentator_models_dir$segmentator_model_name -P downloads/$sample_record_name/
  !wget --no-clobber $segmentator_models_dir$segmentator_config_name -P downloads/$sample_record_name/
  !wget --no-clobber $tracking_models_dir$tracking_model_name -P downloads/$sample_record_name/

  !mkdir downloads/$sample_record_name/data
  !wget --no-clobber $testdata_dir$testdata -P downloads/$sample_record_name/data
  !unzip downloads/$sample_record_name/data/$testdata -d downloads/$sample_record_name/data/

In [14]:
# Dataset variant processed by the pipeline below; it should be one of the
# downloaded sample_categories
sample_category = "1"

In [None]:
#@title Pipeline parameter setup

sample_record_name = f"ArrowSynth{sample_category}"
sample_id = "081"
train_sample_ratio = "1.0"

segmentator_model_name = "model_final.pth"
segmentator_config_name = "config.yaml"
tracking_model_name = f"{sample_record_name}_[DLV3p,resnet50]_FBtr4_Ep50_Adv2_SR{train_sample_ratio}_NonPretrained_final.pth"
#tracking_model_name = f"{sample_record_name}_[DLV3p,resnet50]_FBtr4_Ep50_Adv2_SR{train_sample_ratio}_final.pth"

# Input paths (files fetched into ./downloads by the download cell)
download_dir = f"./downloads/{sample_record_name}/"
SegmentationModelPath = download_dir + segmentator_model_name
SegmentationModelConfigPath = download_dir + segmentator_config_name
TrackingModelPath = download_dir + tracking_model_name
InputVideoPath = f"{download_dir}data/{sample_record_name}_test_samples/{sample_id}/imgs/"

# Output paths
# os.makedirs(exist_ok=True) keeps this cell re-runnable; the former `!mkdir`
# calls reported "File exists" errors on every re-run.
output_dir = f"./outputs/{sample_record_name}/"
for output_subdir in ["segmentations", "trackings", "videos"]:
  os.makedirs(output_dir + output_subdir, exist_ok=True)

output_prefix = f"{sample_record_name}_{sample_id}"
SegmentationSavePath = f"{output_dir}segmentations/{output_prefix}_Segmentation.txt"
TrackingSavePath = f"{output_dir}trackings/{output_prefix}_Tracks.json"
TrackingWritePath = f"{output_dir}trackings/{output_prefix}_Tracks.txt"
TrackingVideoPath = f"{output_dir}videos/{output_prefix}_Tracks.mp4"

# TimeKernelSize is the number of neighbouring frames on EACH side of the central
# frame, so a local video holds 2*TimeKernelSize+1 frames (e.g. 2 -> 5 frames).
# This parameter is specific to the given tracking model.
TimeKernelSize = 4

# Run on the first GPU when one is available, otherwise on the CPU
Device = ("cuda:0" if torch.cuda.is_available() else "cpu")
print("Colab environment: "+Device)

In [None]:
#@title Instance Segmentation

# Score threshold handed to the segmentator (presumably lower-scoring detections
# are dropped — see SingleVideoSegmentation)
SEGMENTATION_SCORE_THRESHOLD = 0.4
# Set to True to render the segmentation result inline before saving it
SHOW_SEGMENTATION = False

Outmasks = SingleVideoSegmentation(
    InputVideoPath,
    SegmentationModelPath,
    SegmentationModelConfigPath,
    Device,
    Color="GRAYSCALE",
    ScoreThreshold=SEGMENTATION_SCORE_THRESHOLD,
)

if SHOW_SEGMENTATION:
  DisplaySegmentation(InputVideoPath, Outmasks)

# Persist the masks; the dataloading cell reads them back from this path
WriteSegmentation(Outmasks, SegmentationSavePath)

In [None]:
#@title Dataloading before Local Tracking

import cv2
import numpy as np
import pandas as pd
from symmetry_tracker.tracking.tracker_utilities import LoadAnnotationDF

AnnotPath = SegmentationSavePath

# Frame files sorted by name (assumed to be playback order)
VideoFrames = sorted(os.listdir(InputVideoPath))
if not VideoFrames:
  raise FileNotFoundError(f"No frames found in {InputVideoPath}")

# cv2.imread signals an unreadable file by returning None rather than raising
Img0 = cv2.imread(os.path.join(InputVideoPath, VideoFrames[0]))
if Img0 is None:
  raise OSError(f"Could not read frame {VideoFrames[0]} in {InputVideoPath}")

# [number of frames, frame height, frame width]
VideoShape = [len(VideoFrames), Img0.shape[0], Img0.shape[1]]

# NOTE(review): the 20-pixel minimum object size and 0.5 maximum overlap ratio are
# unexplained tuning values passed straight to LoadAnnotationDF
AnnotDF = LoadAnnotationDF(AnnotPath, VideoShape, MinObjectPixelNumber=20, MaxOverlapRatio=0.5)

# Cached so the tracking cell can reload these annotations (tracking writes into AnnotDF)
AnnotDF.to_pickle("AnnotDF_Temp.pkl")

In [22]:
#@title Redefining Local Tracking for Saliency Maps

import time
import cv2
import os
import numpy as np
import torch
import torch.nn.functional as F
import gc
from scipy.optimize import linear_sum_assignment

from symmetry_tracker.general_functionalities.misc_utilities import EncodeMultiRLE, DecodeMultiRLE, OuterBoundingBox, BoxOverlap, dfs
from symmetry_tracker.tracking.tracker_metrics import TracksIOU
from symmetry_tracker.tracking.tracker_utilities import LoadAnnotationDF, LoadPretrainedModel

# `display` and `progress` only drive the progress bar; LocalTracking_Saliency runs
# without it when they cannot be imported. Catch ImportError only, so unrelated
# errors (or a KeyboardInterrupt) are not silently swallowed.
try:
  from IPython.display import display
  from symmetry_tracker.general_functionalities.misc_utilities import progress
except ImportError:
  pass

from symmetry_tracker.tracking.symmetry_tracker import KernelTrackBbox

import matplotlib.pyplot as plt
import seaborn as sns

# Module-level accumulators appended to by KernelTrackBbox_Saliency, one entry per
# object analysed: a list of per-input-channel mean saliencies, and a list of
# per-input-channel radial saliency profiles around the bbox centre.
saliency_temporal_distributions = []
saliency_radial_distributions = []

def compute_saliency_maps_all_ch(model, input_tensor, marker):
    """Gradient saliency at one output pixel, averaged over all output channels.

    For every output channel ``i``, the output value at ``marker`` is
    back-propagated to the input and ``|d output[:, i, row, col] / d input|`` is
    accumulated; the sum is divided by the number of output channels.

    Args:
        model: torch module; its output is indexed as ``[batch, channel, row, col]``.
            It is switched to eval mode as a side effect.
        input_tensor: leaf tensor fed to ``model``. ``requires_grad`` is enabled on
            it in place and its ``.grad`` is overwritten.
        marker: ``(row, col)`` of the output pixel to explain.

    Returns:
        Tensor shaped like ``input_tensor`` (not requiring grad) holding the mean
        absolute input gradient.
    """
    model.eval()
    input_tensor.requires_grad_()

    with torch.enable_grad():
        output_sm = model(input_tensor)
        n_out_channels = output_sm.size(1)
        saliency_maps = torch.zeros_like(input_tensor)
        for i in range(n_out_channels):
            model.zero_grad()
            # backward() ACCUMULATES into .grad; without this reset each channel's map
            # would also contain the gradients of all previous channels.
            input_tensor.grad = None
            # .sum() gives a scalar target for any batch size (identity for batch 1)
            output_sm[:, i, marker[0], marker[1]].sum().backward(retain_graph=True)
            saliency_maps += input_tensor.grad.detach().abs()

        saliency_maps /= n_out_channels

    return saliency_maps

def compute_saliency_maps_center_ch(model, input_tensor, marker, central_ch=None):
    """Gradient saliency of the central output channel at one output pixel.

    Back-propagates ``output[:, central_ch, row, col]`` to the input and returns the
    absolute input gradient.

    Args:
        model: torch module; its output is indexed as ``[batch, channel, row, col]``.
            It is switched to eval mode as a side effect.
        input_tensor: leaf tensor fed to ``model``. ``requires_grad`` is enabled on
            it in place and its ``.grad`` is overwritten.
        marker: ``(row, col)`` of the output pixel to explain.
        central_ch: output channel to explain. Defaults to ``C // 2`` for ``C``
            output channels, i.e. the middle index of ``0 .. C-1`` (index
            ``TimeKernelSize`` if the model emits ``2*TimeKernelSize+1`` channels —
            assumption, confirm against the model).

    Returns:
        Tensor shaped like ``input_tensor`` (not requiring grad).
    """
    model.eval()
    input_tensor.requires_grad_()

    with torch.enable_grad():
        output_sm = model(input_tensor)
        if central_ch is None:
            # Middle of 0..C-1; the former `C//2 + 1` pointed one channel past it.
            central_ch = output_sm.size(1) // 2
        model.zero_grad()
        # Drop any gradient left on the input by an earlier backward pass
        input_tensor.grad = None
        # .sum() gives a scalar target for any batch size (identity for batch 1)
        output_sm[:, central_ch, marker[0], marker[1]].sum().backward()
        saliency_maps = input_tensor.grad.detach().abs()

    return saliency_maps

def calculate_radial_profile(image, marker):
    """Mean of a 2-D ``image`` over integer-distance rings around ``marker``.

    ``marker`` is ``(row, col)``. Each pixel falls into ring
    ``int(euclidean distance to marker)``; entry ``k`` of the result is the mean
    pixel value of ring ``k`` for ``k = 0 .. largest ring``. A ring with no pixels
    yields NaN (0/0).
    """
    row_c, col_c = marker
    rows, cols = np.indices(image.shape)
    ring = np.sqrt((cols - col_c)**2 + (rows - row_c)**2).astype(int).ravel()

    ring_sums = np.bincount(ring, weights=image.ravel())
    ring_sizes = np.bincount(ring)
    return ring_sums / ring_sizes

def _PlotSaliency(inputs, output, saliency_maps, saliency_temp_means, saliency_radial_means):
  """Debug figures for KernelTrackBbox_Saliency (drawn only when Plot=True)."""
  n_channels = np.shape(saliency_maps)[1]
  max_saliency = np.max(saliency_maps)

  # Figure 1: per-input-channel saliency (vmax = max/5, so the strongest responses
  # saturate); figure 2: the corresponding network input channels.
  fig1, ax1 = plt.subplots(1, n_channels, squeeze=False, figsize=[3*n_channels, 3])
  fig2, ax2 = plt.subplots(1, n_channels, squeeze=False, figsize=[3*n_channels, 3])
  for ch in range(n_channels):
    ax1[0][ch].imshow(saliency_maps[0, ch], cmap='magma', vmin=0, vmax=max_saliency/5)
    ax1[0][ch].axis('off')
    ax2[0][ch].imshow(inputs[0, ch].detach().cpu().numpy(), cmap='gray')
    ax2[0][ch].axis('off')
  plt.show()

  # Mean saliency per input channel
  sns.barplot(saliency_temp_means)
  plt.show()

  # Radial saliency profile per input channel
  for ch in range(n_channels):
    plt.plot(saliency_radial_means[ch])
  plt.show()

  # First channel of the thresholded model output
  plt.imshow(output[0, 0])
  plt.show()

def KernelTrackBbox_Saliency(LocalVideo, VideoShape, Model, Device, SegmentationConfidence, ObjectBbox,
                             AllChannels=False, Plot=False):
  """Locally tracks one object from its bbox and records input-gradient saliency.

  The frame stack ``LocalVideo`` plus one extra channel with the object's bbox
  painted as 255 is scaled by 1/255, zero-padded at the bottom/right to multiples
  of 16 and run through ``Model``; the sigmoid output is thresholded at
  ``SegmentationConfidence``.

  Saliency of the model output at the bbox centre is then computed, summarised per
  input channel, and appended to the module-level lists
  ``saliency_temporal_distributions`` (mean saliency per channel) and
  ``saliency_radial_distributions`` (radial profile per channel).

  Args:
    LocalVideo: channels-first frame stack, presumably with values in 0..255
      (it is divided by 255 here).
    VideoShape: [num_frames, height, width]; only height and width are used.
    Model: torch network applied to a (1, channels, H_pad, W_pad) tensor.
    Device: torch device string, e.g. "cuda:0" or "cpu".
    SegmentationConfidence: threshold on the sigmoid output.
    ObjectBbox: [x0, y0, x1, y1]; x indexes rows and y columns in this function.
    AllChannels: if True, average saliency over all output channels instead of
      using only the central one.
    Plot: if True, draw debug figures of the saliency maps, inputs and output.

  Returns:
    Boolean array ``output[0]`` with one slice per output channel.
    NOTE(review): it keeps the bottom/right padding added above — confirm this
    matches what KernelTrackBbox returns for boundary frames.
  """
  with torch.no_grad():
    # Extra input channel: the object's bbox painted on a black frame
    BboxImg = np.zeros([VideoShape[1], VideoShape[2]])
    [x0, y0, x1, y1] = ObjectBbox
    BboxImg[x0:x1, y0:y1] = 255
    inputs = np.append(LocalVideo, [BboxImg], axis=0)
    inputs = torch.Tensor(np.array(inputs, dtype=float)/255)

    # Zero-pad bottom/right so height and width are multiples of 16
    # (presumably a constraint of the network's encoder — TODO confirm)
    pad_h = (16 - inputs.shape[1] % 16) % 16
    pad_w = (16 - inputs.shape[2] % 16) % 16
    inputs = F.pad(inputs, (0, pad_w, 0, pad_h), mode='constant', value=0)
    inputs = torch.unsqueeze(inputs, dim=0).to(torch.device(Device))

    output = np.array(torch.sigmoid(Model(inputs).cpu()))
    output = (output > SegmentationConfidence) * 1.0

    # Saliency is explained at the bbox centre, (row, col) like the bbox indexing
    marker = [(x0 + x1)//2, (y0 + y1)//2]
    if AllChannels:
      saliency_maps = np.array(compute_saliency_maps_all_ch(Model, inputs, marker).cpu())
    else:
      saliency_maps = np.array(compute_saliency_maps_center_ch(Model, inputs, marker).cpu())

    torch.cuda.empty_cache()

    # Per-input-channel summaries, taken over the padded maps
    n_channels = np.shape(saliency_maps)[1]
    saliency_temp_means = [np.mean(saliency_maps[0, ch]) for ch in range(n_channels)]
    saliency_radial_means = [calculate_radial_profile(saliency_maps[0, ch], marker) for ch in range(n_channels)]

    if Plot:
      _PlotSaliency(inputs, output, saliency_maps, saliency_temp_means, saliency_radial_means)

    #### GLOBAL VARIABLE ACCESS SECTION ####
    saliency_temporal_distributions.append(saliency_temp_means)
    saliency_radial_distributions.append(saliency_radial_means)
    #### --- ####

  return np.array(output[0], dtype=bool)

def _LoadLocalVideo(VideoPath, VideoFrames, Frame, TimeKernelSize, Color):
  """Loads the channels-first frame stack centred on `Frame`.

  Neighbouring frames that fall outside the video are filled with the central
  frame. GRAYSCALE yields (2*TimeKernelSize+1, H, W); RGB yields
  (3*(2*TimeKernelSize+1), H, W), with each frame's three RGB channels adjacent.
  """
  NumFrames = len(VideoFrames)
  NumReps = 2*TimeKernelSize+1
  # Offsets of the neighbouring frames that actually exist in the video
  ValidOffsets = [dt for dt in range(-TimeKernelSize, TimeKernelSize+1)
                  if dt != 0 and 0 <= Frame+dt < NumFrames]

  if Color == "GRAYSCALE":
    CentralImg = cv2.imread(os.path.join(VideoPath, VideoFrames[Frame]), cv2.IMREAD_GRAYSCALE)
    LocalVideo = np.repeat(CentralImg[np.newaxis, ...], NumReps, axis=0)
    for dt in ValidOffsets:
      LocalVideo[dt+TimeKernelSize] = cv2.imread(os.path.join(VideoPath, VideoFrames[Frame+dt]), cv2.IMREAD_GRAYSCALE)

  elif Color == "RGB":
    CentralImg = cv2.cvtColor(cv2.imread(os.path.join(VideoPath, VideoFrames[Frame])), cv2.COLOR_BGR2RGB)
    CentralImg = np.transpose(CentralImg, (2, 0, 1))
    # Repeat the 3-channel central frame NumReps times along the channel axis
    LocalVideo = np.tile(CentralImg, (NumReps, 1, 1))
    for dt in ValidOffsets:
      LocalImg = cv2.cvtColor(cv2.imread(os.path.join(VideoPath, VideoFrames[Frame+dt])), cv2.COLOR_BGR2RGB)
      Start = 3*(dt+TimeKernelSize)
      LocalVideo[Start:Start+3] = np.transpose(LocalImg, (2, 0, 1))

  else:
    raise Exception(f"{Color} is an invalid keyword for Color")

  return LocalVideo

def LocalTracking_Saliency(VideoPath, VideoShape, AnnotDF, Model, Device, TimeKernelSize, Color = "GRAYSCALE", Marker = "BBOX", SegmentationConfidence = 0.2, MaxSaliencyFrames = 10):
  """Local tracking of every annotated object, recording saliency on full-window frames.

  Objects in frames whose whole temporal window (TimeKernelSize frames on each
  side) lies inside the video are tracked with KernelTrackBbox_Saliency, which
  also appends saliency statistics to the module-level lists; objects in the
  boundary frames are tracked with KernelTrackBbox. Each track is written into
  AnnotDF in place (columns "LocalTrackRLE" and "TrackBbox").

  Args:
    VideoPath: directory of frame images (sorted by file name).
    VideoShape: [num_frames, height, width].
    AnnotDF: annotations with at least "Frame", "ObjectID" and "SegBbox" columns.
    Model, Device: tracking network and torch device string.
    TimeKernelSize: neighbouring frames on each side of the central frame.
    Color: "GRAYSCALE" or "RGB".
    Marker: only "BBOX" is supported here.
    SegmentationConfidence: threshold on the model's sigmoid output.
    MaxSaliencyFrames: stop after this many full-window frames (None: whole video).
      When the limit is hit, the remaining frames are left untracked.

  Returns:
    AnnotDF (the same, mutated object).

  Raises:
    Exception: invalid Color or Marker keyword.
    ValueError: Marker == "CENTROID" (no saliency variant is defined).
  """
  if not Color in ["GRAYSCALE", "RGB"]:
    raise Exception(f"{Color} is an invalid keyword for Color")
  if not Marker in ["CENTROID", "BBOX"]:
    raise Exception(f"{Marker} is not an appropriate keyword for Marker")
  if Marker == "CENTROID":
    # Checked up front so the call fails before any tracking work is done
    raise ValueError("CENTROID marker does not have saliency maps defined for now")

  VideoFrames = sorted(os.listdir(VideoPath))
  NumFrames = len(VideoFrames)

  print("Local Tracking")
  # Best-effort progress bar: `display`/`progress` are optional imports
  try:
    ProgressBar = display(progress(0, NumFrames), display_id=True)
  except Exception:
    ProgressBar = None

  SaliencyFramesDone = 0
  for Frame in range(NumFrames):
    # Full window: Frame-TimeKernelSize >= 0 and Frame+TimeKernelSize <= NumFrames-1.
    # The lower bound is inclusive; the former `Frame > TimeKernelSize` skipped
    # Frame == TimeKernelSize although its window is complete.
    HasFullWindow = TimeKernelSize <= Frame < NumFrames - TimeKernelSize

    ObjectIDs = AnnotDF.query("Frame == @Frame")["ObjectID"]
    if len(ObjectIDs) > 0:
      # The frame stack depends only on Frame, so load it once for all its objects
      LocalVideo = _LoadLocalVideo(VideoPath, VideoFrames, Frame, TimeKernelSize, Color)

    for ObjectID in ObjectIDs:
      ObjectRows = AnnotDF.query("ObjectID == @ObjectID")
      ObjectBbox = ObjectRows["SegBbox"].iloc[0]

      # Local Tracking
      if HasFullWindow:
        print(f"Frame {Frame}")
        print(f"Tracking Object {ObjectID}")
        LocalTrack = KernelTrackBbox_Saliency(LocalVideo, VideoShape, Model, Device, SegmentationConfidence, ObjectBbox)
      else:
        LocalTrack = KernelTrackBbox(LocalVideo, VideoShape, Model, Device, SegmentationConfidence, ObjectBbox)

      AnnotDF.loc[ObjectRows.index, "LocalTrackRLE"] = [EncodeMultiRLE(LocalTrack)]

      # 3D Boundary Box calculation
      AnnotDF.loc[ObjectRows.index, "TrackBbox"] = [OuterBoundingBox(LocalTrack)]

    if ProgressBar is not None:
      ProgressBar.update(progress(Frame, NumFrames))

    #### EARLY STOP ####
    if HasFullWindow:
      SaliencyFramesDone += 1
    if MaxSaliencyFrames is not None and SaliencyFramesDone >= MaxSaliencyFrames:
      return AnnotDF

  if ProgressBar is not None:
    ProgressBar.update(progress(1, 1))

  return AnnotDF

In [None]:
#@title Local Tracking with Saliency Maps

# Start from the annotations cached by the dataloading cell, and empty the global
# saliency accumulators so re-running this cell does not append duplicate entries
# (KernelTrackBbox_Saliency appends to them on every call).
AnnotDF = pd.read_pickle("AnnotDF_Temp.pkl")
saliency_temporal_distributions.clear()
saliency_radial_distributions.clear()

Model = LoadPretrainedModel(TrackingModelPath, Device)
AnnotDF = LocalTracking_Saliency(InputVideoPath,
                                 VideoShape,
                                 AnnotDF,
                                 Model,
                                 Device,
                                 TimeKernelSize = TimeKernelSize,
                                 Color = "GRAYSCALE",
                                 Marker = "BBOX",
                                 SegmentationConfidence = 0.2)

# Release the model and any cached GPU memory (empty_cache is a no-op without CUDA)
del Model
gc.collect()
torch.cuda.empty_cache()

In [19]:
# Save distributions to pickle

import pickle

# "F10" in the file names refers to the 10-frame early stop of the saliency
# tracking. NOTE(review): keep it in sync if that limit is changed.
saliency_outputs = {
    f'saliency_central_temporal_distributions_AS{sample_category}_F10.pkl': saliency_temporal_distributions,
    f'saliency_central_radial_distributions_AS{sample_category}_F10.pkl': saliency_radial_distributions,
}

for output_path, distributions in saliency_outputs.items():
    # The with-block closes the file; no explicit close() is needed
    with open(output_path, 'wb') as f:
        pickle.dump(distributions, f)