In [1]:
import numpy as np
import torch
import math 
import matplotlib.pyplot as plt
import pickle
import copy
from scipy import ndimage
from matplotlib.colors import LinearSegmentedColormap
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA
from ripser import ripser
from persim import plot_diagrams

# Seed torch's global RNG for reproducibility (numpy's global RNG is not seeded here).
seed_ = 123
torch.manual_seed(seed_)

<torch._C.Generator at 0x10e3c4170>

In [2]:
class LoadMazeData(): # This class contains the functions to load the data and pre-process it
    """Load the 3D-maze pickle datasets and turn them into (images, pose) arrays.

    Each dataset file is a pickled dict from which the keys ``'images'``,
    ``'positions'`` and ``'directions'`` are read; :meth:`get_maze_layout`
    additionally reads ``'maze_layout'`` from the spinning dataset.
    """

    # Maze (x, y) positions are normalised as (p - MAZE_CENTER) / MAZE_HALF_WIDTH
    # before being converted to polar coordinates (4.5 is the value this notebook uses).
    MAZE_CENTER = 4.5
    MAZE_HALF_WIDTH = 4.5

    # Dataset file names, looked up inside ``data_dir``.
    SPINNING_FILE = "3d_maze_dataset_spinning.pkl"
    WALKING_FILE = "3d_maze_dataset_walking_around.pkl"
    PRUNED_FILE = "3d_maze_dataset_pruned.pkl"

    def __init__(self, data_dir="data", train_percentage=0.9):
        """
        Parameters
        ----------
        data_dir : str
            Directory containing the dataset pickles (default ``"data"``, relative to the CWD).
        train_percentage : float
            Fraction of the samples, taken from the start in file order, that
            :meth:`_load_data` returns.
        """
        self.data_dir = data_dir
        self.train_percentage = train_percentage

    def _path(self, filename):
        # Identical to the former hard-coded "data/<file>" paths when data_dir == "data".
        return f"{self.data_dir}/{filename}"

    def load_data(self, filename):
        """Unpickle and return the object stored in ``filename``.

        Security: ``pickle.load`` can execute arbitrary code; only use it on trusted files.
        """
        with open(filename, 'rb') as file:
            return pickle.load(file)

    def direction_to_angle(self, direction):
        """Return the angle in radians, in [-pi, pi], of a direction vector ``(dx, dy)``."""
        dx, dy = direction
        return math.atan2(dy, dx)

    def xy_to_polar(self, x, y):
        """Convert Cartesian ``(x, y)`` to polar ``(radius, angle)`` with the angle in [-pi, pi]."""
        radius = math.sqrt(x**2 + y**2)
        angle = math.atan2(y, x)
        return radius, angle

    def _load_data(self, circular = True, pruned = False):
        """Load one dataset and return its training split.

        Parameters
        ----------
        circular : bool
            ``True`` selects the spinning dataset, ``False`` the walking-around dataset.
        pruned : bool
            If ``True``, the pruned dataset is loaded instead; this takes precedence over ``circular``.

        Returns
        -------
        (train_images, train_pos_dir)
            ``train_images`` are the images divided by 255. ``train_pos_dir`` is a float
            array with one row ``[r, angle, direction_angle]`` per sample. ``(r, angle)`` are the
            polar coordinates of the normalised maze position, and ``direction_angle`` is the
            agent's heading. Only the first ``int(n * train_percentage)`` samples are returned.
        """
        if pruned:
            filename = self.PRUNED_FILE
        elif circular:
            filename = self.SPINNING_FILE
        else:
            filename = self.WALKING_FILE
        data = self.load_data(self._path(filename))

        images = self.pre_process_images(np.array(data['images']))
        positions = np.array(data['positions'])
        directions = np.array(data['directions'])

        center, scale = self.MAZE_CENTER, self.MAZE_HALF_WIDTH
        rows = []
        for i in range(len(positions)):
            radius, angle = self.xy_to_polar((positions[i][0] - center) / scale,
                                             (positions[i][1] - center) / scale)
            heading = self.direction_to_angle((directions[i][0], directions[i][1]))
            rows.append([radius, angle, heading])
        # reshape keeps a well-formed (0, 3) array when the dataset is empty
        pos_dir = np.array(rows, dtype=float).reshape(-1, 3)

        train_size = int(images.shape[0] * self.train_percentage)
        return images[:train_size], pos_dir[:train_size]

    def pre_process_images(self, images):
        """Scale pixel values by 1/255 (maps 8-bit images to [0, 1])."""
        return images / 255.0

    def post_process_images(self, images):
        """Inverse of :meth:`pre_process_images`: scale values by 255."""
        return images * 255.0

    def get_maze_layout(self):
        """Return ``data['maze_layout']`` from the spinning dataset as a numpy array."""
        data = self.load_data(self._path(self.SPINNING_FILE))
        return np.array(data['maze_layout'])


In [3]:
data_handler = LoadMazeData()
# Load (train_images, train_pos_dir) for each dataset, in this order:
# X = walking around, P = pruned, S = spinning.
X, P, S = (
    data_handler._load_data(**load_options)
    for load_options in ({"circular": False}, {"pruned": True}, {"circular": True})
)

In [4]:
# Shape of the data
print("Shape of the data")
print("All points: ", X[0].shape)
print("Pruned points: ", P[0].shape)
print("spinning data: ", S[0].shape)

Shape of the data
All points:  (1400, 64, 64, 3)
Pruned points:  (332, 64, 64, 3)
spinning data:  (77, 64, 64, 3)


In [5]:
# Columns [0], [1] are the polar coordinates (r, angle) of the agent's maze position
# (centred and scaled by 4.5); column [2] is the agent's direction angle.
X[1][0, :]

array([ 0.91624569, -0.24497866, -2.35880685])

# Mapping to $T_1$

In [6]:
def T_1(v, phi, theta, remove_a_slice=False, R=3, r=0.5):
    """Embed ``(v, phi, theta)`` as a point of a filled (solid) torus in R^3.

    The torus' central circle has radius ``R`` in the xy-plane and its tube has radius
    ``r``; ``v`` scales the distance from the central circle in units of ``r``.

    Parameters
    ----------
    v : float or array-like
        Radial coordinate in the tube cross-section (``v * r`` is the distance from the tube's
        central circle). In this notebook it is the polar radius of the normalised maze position.
    phi : float or array-like
        Angle in radians around the tube cross-section; here the polar angle of the maze position.
    theta : float or array-like
        Angle in radians around the torus' central hole; here the agent's direction angle.
    remove_a_slice : bool
        Reserved for visualisation purposes; currently ignored.
    R : float
        Distance from the centre of the hole to the centre of the tube (default 3).
    r : float
        Radius of the tube (default 0.5).

    Returns
    -------
    (x, y, z)
        Coordinates, broadcast over the inputs (scalars in, scalars out).
    """
    tube_offset = v * r
    ring_radius = R + tube_offset * np.cos(phi)
    x = ring_radius * np.cos(theta)
    y = ring_radius * np.sin(theta)
    z = tube_offset * np.sin(phi)
    return x, y, z



In [7]:
# 3d plot using ploty of T_1(X[1])
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

# Embed every (r, angle, direction) row of each dataset into the filled torus T_1.
X_T_1, P_T_1, S_T_1 = (
    [T_1(*row) for row in dataset[1]]
    for dataset in (X, P, S)
)

In [8]:
# Number of images (= number of embedded points) in each dataset
print("\n".join(str(len(points)) for points in (X_T_1, P_T_1, S_T_1)))

1400
332
77


In [9]:
def plot_3d(data, axis_range=None):
    """Show an interactive plotly 3D scatter plot of ``data``.

    Parameters
    ----------
    data : sequence of (x, y, z) triples
        e.g. a list of ``T_1`` outputs or an (N, 3) array.
    axis_range : float, optional
        All three axes span ``[-axis_range, axis_range]``. By default this is the largest
        absolute coordinate in ``data`` plus 5%, so no point falls outside the view.
        A fixed [-3, 3] range clips the torus: with T_1's R=3, r=0.5 the x/y coordinates
        reach up to R + v*r > 3.

    Raises
    ------
    ValueError
        If ``data`` is not a collection of 3-element points.
    """
    coords = np.asarray(data, dtype=float)
    if coords.size == 0:
        coords = coords.reshape(0, 3)
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError(f"plot_3d expects (x, y, z) triples, got array of shape {coords.shape}")

    if axis_range is None:
        extent = float(np.max(np.abs(coords))) if len(coords) else 0.0
        axis_range = 1.05 * extent if extent > 0 else 1.0

    fig = go.Figure(data=[go.Scatter3d(x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
                                       mode='markers', marker=dict(size=2))])

    # Same range on every axis plus a cubic aspect ratio keeps the torus undistorted.
    axis_settings = dict(nticks=4, range=[-axis_range, axis_range])
    fig.update_layout(scene=dict(xaxis=dict(axis_settings),
                                 yaxis=dict(axis_settings),
                                 zaxis=dict(axis_settings),
                                 aspectmode='cube'),
                      margin=dict(r=20, l=10, b=10, t=10))
    fig.show()


# One interactive figure per dataset: walking around, pruned, spinning.
for points_3d in (X_T_1, P_T_1, S_T_1):
    plot_3d(points_3d)
del points_3d




In [10]:
def calculate_deltas(datapoints, delta_min, delta_max, homology_to_check, max_dim = 2, number_of_delta = 10):
        # Define delta values for the range you are interested in
        delta_check_range = np.linspace(delta_min, delta_max, number_of_delta)

        # Ripser computes persistent homology on the distance matrix from the point cloud
        ripser_ = ripser(datapoints, maxdim=max_dim) 
        deltas = []

        for delta in delta_check_range:


                ripser_['dgms'][0][:,1][np.isinf(ripser_['dgms'][0][:,1])] = 1000
                ripser_['dgms'][1][:,1][np.isinf(ripser_['dgms'][1][:,1])] = 1000
                if max_dim == 2:
                        ripser_['dgms'][2][:,1][np.isinf(ripser_['dgms'][2][:,1])] = 1000



                number_of_simplices_dim_0 = np.sum(ripser_['dgms'][0][:,1] > delta)
                number_of_simplices_dim_1 = np.sum(ripser_['dgms'][1][:,1] > delta)
                if max_dim == 2:
                        number_of_simplices_dim_2 = np.sum(ripser_['dgms'][2][:,1] > delta)
                

                print("Number of simplices of dimension 0: ", number_of_simplices_dim_0)
                print("Number of simplices of dimension 1: ", number_of_simplices_dim_1)
                if max_dim == 2:
                        print("Number of simplices of dimension 2: ", number_of_simplices_dim_2)


                if number_of_simplices_dim_0 == homology_to_check[0] and number_of_simplices_dim_1 == homology_to_check[1]   :
                        if max_dim == 2:
                                if number_of_simplices_dim_2 == homology_to_check[2]:
                                        print("found delta: ", delta)
                                        deltas.append(delta)
                                else:
                                     continue
                        else:       
                                print("found delta: ", delta)
                                deltas.append(delta)
                                
        if len(deltas) == 0:
            return None, None
        delta_1 = min(deltas)
        delta_2 = max(deltas)


        return delta_1, delta_2

# Convert each list of (x, y, z) tuples into an (N, 3) array for the homology computations.
X_T_1, P_T_1, S_T_1 = [np.array(points) for points in (X_T_1, P_T_1, S_T_1)]



In [11]:
# Excerpt from the ripser documentation, kept for reference. calculate_deltas indexes
# dgms[0] .. dgms[max_dim], i.e. the list holds one diagram per dimension 0..maxdim.
"""``dgms``: list (size maxdim) of ndarray (n_pairs, 2)
                For each dimension less than ``maxdim`` a list of persistence diagrams.
                Each persistent diagram is a pair (birth time, death time)."""

'``dgms``: list (size maxdim) of ndarray (n_pairs, 2)\n                For each dimension less than ``maxdim`` a list of persistence diagrams.\n                Each persistent diagram is a pair (birth time, death time).'

In [12]:
# Set to True to recompute the delta ranges in the cells below; the full computation takes a few hours.
calculate_homology = False

In [13]:

if calculate_homology:
    # Scan delta in [delta_min, delta_max] for the spinning (circular) data.
    # With max_dim=1 only H0 and H1 are compared; the trailing 0 in homology_to_check is unused.
    delta_min = 0.1
    delta_max = 10
    homology_to_check = [1,1,0]

    delta_1, delta_2 = calculate_deltas(S_T_1, delta_min, delta_max, homology_to_check, max_dim=1, number_of_delta=100)
    print("Circular points H = [1,1,0], delta 1: ", delta_1)
    print("Circular points H = [1,1,0], delta 2: ", delta_2)


In [14]:
if calculate_homology:
    # Same H = [1,1,0] search for all (walking-around) points. The delta range is set here so
    # this cell does not depend on values left over from the previous cell.
    delta_min = 0.1
    delta_max = 10
    homology_to_check = [1,1,0]
    delta_1, delta_2 = calculate_deltas(X_T_1, delta_min, delta_max, homology_to_check, max_dim=1, number_of_delta=100)
    print("All points H = [1,1,0], delta 1: ", delta_1)
    print("All points H = [1,1,0], delta 2: ", delta_2)

In [21]:
if calculate_homology:
    # H = [1,1,0] search for the pruned points (only H0 and H1 are compared with max_dim=1).
    delta_min = 0.1
    delta_max = 10

    homology_to_check = [1,1,0]
    delta_1, delta_2 = calculate_deltas(P_T_1, delta_min, delta_max, homology_to_check, max_dim=1, number_of_delta=100)
    print("Pruned points H = [1,1,0], delta 1: ", delta_1)
    print("Pruned points H = [1,1,0], delta 2: ", delta_2)

Number of simplices of dimension 0:  211
Number of simplices of dimension 1:  44
Number of simplices of dimension 0:  109
Number of simplices of dimension 1:  42
Number of simplices of dimension 0:  61
Number of simplices of dimension 1:  38
Number of simplices of dimension 0:  25
Number of simplices of dimension 1:  34
Number of simplices of dimension 0:  10
Number of simplices of dimension 1:  21
Number of simplices of dimension 0:  5
Number of simplices of dimension 1:  15
Number of simplices of dimension 0:  1
Number of simplices of dimension 1:  8
Number of simplices of dimension 0:  1
Number of simplices of dimension 1:  4
Number of simplices of dimension 0:  1
Number of simplices of dimension 1:  1
found delta:  0.9
Number of simplices of dimension 0:  1
Number of simplices of dimension 1:  1
found delta:  1.0
Number of simplices of dimension 0:  1
Number of simplices of dimension 1:  1
found delta:  1.1
Number of simplices of dimension 0:  1
Number of simplices of dimension 1: 

In [16]:
if calculate_homology:
    # H = [1,1,5] search for all points; max_dim=2 so H2 is compared as well.
    delta_min = 0.1
    delta_max = 2
    homology_to_check = [1,1,5]
    delta_1, delta_2 = calculate_deltas(X_T_1, delta_min, delta_max, homology_to_check, max_dim=2, number_of_delta=100)
    print("All points H = [1,1,5], delta 1: ", delta_1)
    print("All points H = [1,1,5], delta 2: ", delta_2)



In [18]:

if calculate_homology:
    # H = [1,1,5] search for the pruned points; max_dim=2 so H2 is compared as well.
    delta_min = 0.5
    delta_max = 4.5
    homology_to_check = [1,1,5]
    delta_1, delta_2 = calculate_deltas(P_T_1, delta_min, delta_max, homology_to_check, max_dim=2, number_of_delta=10)
    print("Pruned points H = [1,1,5], delta 1: ", delta_1)
    print("Pruned points H = [1,1,5], delta 2: ", delta_2)







Number of simplices of dimension 0:  10
Number of simplices of dimension 1:  21
Number of simplices of dimension 2:  5
Number of simplices of dimension 0:  1
Number of simplices of dimension 1:  1
Number of simplices of dimension 2:  5
found delta:  0.9444444444444444
Number of simplices of dimension 0:  1
Number of simplices of dimension 1:  1
Number of simplices of dimension 2:  5
found delta:  1.3888888888888888
Number of simplices of dimension 0:  1
Number of simplices of dimension 1:  1
Number of simplices of dimension 2:  5
found delta:  1.8333333333333333
Number of simplices of dimension 0:  1
Number of simplices of dimension 1:  1
Number of simplices of dimension 2:  5
found delta:  2.2777777777777777
Number of simplices of dimension 0:  1
Number of simplices of dimension 1:  1
Number of simplices of dimension 2:  5
found delta:  2.7222222222222223
Number of simplices of dimension 0:  1
Number of simplices of dimension 1:  1
Number of simplices of dimension 2:  5
found delta:  