# Testing Snake creation and evolution
Much (but not all) of this code will be duplicated in [main.py](../main.py).

In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm

import tsnake.initialize as init
from tsnake.snake import TSnake, Element, Node
from tsnake.grid import Grid, Point
from tsnake.utils import dist, seg_intersect


## Load images and compute masked regions
### The images I used for the plane are available [here](https://drive.google.com/open?id=1TK6rc-USD4KKI0Bss_B6r4l13oAD2uqG)

In [2]:
msk_path = '../t-snake-mask-generation/examples/places2/case1_mask.png'
img_path = '../t-snake-mask-generation/examples/places2/case1_raw.png'

# Load the mask (with convert=True) and the grayscale image it applies to.
mask = init.load_mask(path=msk_path, convert=True)
image = init.load_grayscale_image(img_path)

# NOTE: Uncomment to visualize the initial (disjoint) masked regions.
# The disjoint regions are only needed for this visualization, so they are
# computed inline here instead of being assigned to `regions` and then
# immediately overwritten below.
# init.visualize_masked_regions(mask, init._find_disjoint_masked_regions(mask))

# Build the masked-region objects; each one is later used to initialize a t-snake.
regions = init.compute_masked_regions(image, mask)
print('number of masked regions:', len(regions))


number of masked regions: 8


## Initialize t-snakes
_Note: The length-17 t-snake (i.e. the short one) is visibly wrong; it is not yet clear how to fix that._
*Legend*
* Green lines: Normal vectors
* White Nodes: Initial 5 nodes of the snake
* Yellow Nodes: Final 5 nodes of the snake
    * This shows the direction in which the snake was initialized: counterclockwise or clockwise

In [3]:
### Parameters for T-snakes ###
sigma = 20.0    # gaussian filter sigma
p = 1.0         # scale final image force with p
c = 2.0         # scale gradient magnitude of image (applied before p)
a = 1.0         # tension parameter
b = 1.0         # bending parameter
q = 1.0         # amplitude of the inflation force
gamma = 1.0     # friction coefficient
dt = 1.0        # time step
threshold = 10  # inflation force threshold

# Initialize one t-snake per masked region, in region order.
# NOTE: To show the t-snakes on the image, call region.visualize() on each region.
tsnakes = [
    region.initialize_tsnake(
        N=1000, p=p, c=c, sigma=sigma, a=a, b=b, q=q, gamma=gamma,
        dt=dt, threshold=threshold,
    )
    for region in regions
]

# Sort by node count (shortest first). `regions` is reordered with the same
# (stable) permutation so that regions[i] stays paired with tsnakes[i];
# later cells index both lists with the same index
# (e.g. regions[1].show_snake(...) alongside tsnakes[1]).
order = sorted(range(len(tsnakes)), key=lambda i: len(tsnakes[i].nodes))
regions = [regions[i] for i in order]
tsnakes = [tsnakes[i] for i in order]

t_snake_lengths = [len(t.nodes) for t in tsnakes]
print('Length of T-Snakes initialized on image:\n{}'.format(t_snake_lengths))


Length of T-Snakes initialized on image:
[32, 132, 174, 352, 464, 607, 1030, 1049]


## Create Grid

In [4]:
# Reload the grayscale image from disk (presumably to build the grid from an
# unmodified copy) and report its size.
image = init.load_grayscale_image(img_path)
print('image shape: ', image.shape)

# Construct the grid over the image at scale 1, then generate its simplex grid.
grid = Grid(scale=1, image=image)
# NOTE: Uncomment to compute the image force (expensive calculation).
# force = grid.get_image_force(2,2,2)
grid.gen_simplex_grid()

simplex_shape = grid.grid.shape
print('Simplex grid shape: {}'.format(simplex_shape))


image shape:  (512, 680)
Simplex grid shape: (512, 680)


## Test Intersection Computation

In [5]:
# Report sizes with len() instead of np.shape(). `intersections` is a ragged list
# (each t-snake has a different number of intersections), and np.shape() on a
# ragged nested sequence raises ValueError in NumPy >= 1.24. The 1-tuples printed
# here match what np.shape() printed on older NumPy, e.g. "(8,)".
print('shape of tsnakes:', (len(tsnakes),))

# Compute snake intersections with grid
intersections = grid.get_snake_intersections(tsnakes)
print('intersections shape:', (len(intersections),))

# Number of grid intersections found for each t-snake, in tsnakes order.
n_inter_for_each_t_snake = [len(inter) for inter in intersections]
print('num of intersections for each t-snake:', n_inter_for_each_t_snake)


shape of tsnakes: (8,)
intersections shape: (8,)
num of intersections for each t-snake: [32, 132, 174, 352, 464, 609, 1030, 1054]


## Test snake evolution

In [6]:
import os

# Index of the t-snake to evolve and plot.
# NOTE(review): regions[snake_idx] is used to plot tsnakes[snake_idx]; this
# assumes `regions` and `tsnakes` are in the same order (tsnakes is sorted by
# node count when it is created) — confirm.
snake_idx = 1

# Figures are written to images/ via save_fig; make sure the directory exists.
os.makedirs('images', exist_ok=True)

regions[snake_idx].show_snake(save_fig='images/img0.png')

# Test snake evolution
iterations = 1  # dummy value for testing purposes
M = 5           # deformation steps per m_step call

# Evolve just one of the snakes on the grid (for testing purposes)
snakes = [tsnakes[snake_idx]]

# Each iteration runs:
# 1) the m-step function (M deformation steps per snake), then
# 2) the reparameterization of the snakes on the grid (once every M steps).
for j in tqdm(range(iterations)):
    for snake in snakes:
        snake.m_step(M)
    new_snakes = grid.reparameterize(snakes)
    # Carry the reparameterized snakes into the next iteration; previously the
    # result was discarded, so iterations > 1 re-evolved the stale snakes.
    # NOTE(review): assumes reparameterize returns the updated list of snakes
    # (it may return None if it works in place) — confirm.
    if new_snakes is not None:
        snakes = new_snakes

    # TODO: record node positions for the "Visualize Evolution" cell. X and Y
    # would need shape (num_nodes, iterations + 1), with column 0 holding the
    # initial positions, e.g.:
    # for i in range(snake.num_nodes):
    #     pos = snake.nodes[i].position
    #     X[i, j+1] = pos[0, 0]
    #     Y[i, j+1] = pos[0, 1]

    regions[snake_idx].show_snake(save_fig='images/img{}.png'.format(j + 1))


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




<Figure size 576x576 with 0 Axes>

<Figure size 576x576 with 0 Axes>

## Visualize Evolution

In [19]:
# Sketch: plot recorded node positions frame by frame and save one image per
# frame. It is disabled because X and Y (node-position arrays) are never
# populated: the code that fills them in the "Test snake evolution" cell is
# commented out.
# NOTE(review): before enabling, check that:
#   1) the loop bound matches the number of recorded columns. The evolution
#      cell writes column j+1 once per iteration, i.e. iterations + 1 columns,
#      not M.
#   2) plt.clf() also clears the imshow() background drawn before the loop.
#   3) `colors` is currently unused.
#   4) the saved files reuse the images/img{i}.png names written by the
#      evolution cell.
# plt.imshow(image, cmap=plt.cm.binary)
# colors = ['red', 'blue', 'orange', 'green','black']
# for i in tqdm(range(M)):
#     plt.clf()
#     plt.scatter(Y[:,i], X[:,i], c='red', s=1, alpha=0.5)
#     plt.savefig('images/img{}.png'.format(i))
