### Imports & Preparation

In [1]:
import sys
sys.path.insert(1, '../')  # to load from any submodule in the repo

from models2D import detector2D, predictor2D
from utils import dpcr_utils, dpcr_generator

import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import torch
import matplotlib
import matplotlib.pyplot as plt
import time

# Pick the compute device. `torch.device("cuda")` only builds a device descriptor and is
# always truthy, so it cannot detect a GPU; ask the CUDA runtime instead.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using", torch.cuda.device_count(), "CUDA devices")
else:
    device = torch.device("cpu")
    print("CUDA not available, using CPU")

# Shared seed for reproducible runs; the file is parsed as a single integer.
SEED_PATH = '../utils/seed.txt'
with open(SEED_PATH, "r") as seed_file:
    seed = int(seed_file.read())

print("Using Seed:", seed)

# Seed torch and numpy's legacy global RNG (later cells sample with np.random).
torch.manual_seed(seed)
np.random.seed(seed)

Using 1 CUDA devices
Using Seed: 34897567


### Model Loading

In [3]:
# Hyperparameters shared by both networks; presumably they must match the values used
# when the checkpoints were trained — NOTE(review): confirm against the training script.
MODEL_KWARGS = dict(k = 10, emb_dims = 1024, dropout = 0.5)

detector = detector2D.Model(**MODEL_KWARGS).to(device)
predictor = predictor2D.Model(**MODEL_KWARGS).to(device)

# Wrapped in DataParallel *before* loading, presumably because the checkpoints were saved
# from DataParallel-wrapped models (state-dict keys prefixed with "module.") — confirm.
detector = torch.nn.DataParallel(detector)
predictor = torch.nn.DataParallel(predictor)

# map_location remaps stored tensors onto the selected device, so checkpoints saved on a
# GPU still load on a CPU-only machine.
detector_checkpoint = torch.load('../models2D/detector2D_model_e30.t7', map_location = device)
detector.load_state_dict(detector_checkpoint['model_state_dict'])
print ("> Loaded detector model (%d epochs)" % detector_checkpoint['epoch'])

# The predictor's 'model_state_dict' is indexed with [-1]: it appears to hold a sequence of
# state dicts and the last one is used — NOTE(review): confirm the checkpoint format.
predictor_checkpoint = torch.load('../models2D/predictor2D_model_e30.t7', map_location = device)
predictor.load_state_dict(predictor_checkpoint['model_state_dict'][-1])
print ("> Loaded predictor model (%d epochs)" % predictor_checkpoint['epoch'])

# Switch to inference mode (disables dropout); `_ =` suppresses the module repr output.
_ = detector.eval()
_ = predictor.eval()

> Loaded detector model (30 epochs)
> Loaded predictor model (30 epochs)


In [4]:
# Generate the evaluation samples and report how long generation took.
# Argument semantics belong to dpcr_generator.getTrainingArray; the first argument is
# presumably the number of samples — NOTE(review): confirm. Each entry is later sliced
# as a 2-D array with at least 5 columns (see the inference cell).
# perf_counter is monotonic and high-resolution, so it is the right clock for durations.
start = time.perf_counter()

testArr = dpcr_generator.getTrainingArray(100, resolution = 50, max_iter = 5, gamma = 1.2)

print ("Total Time: ", time.perf_counter() - start)

Total Time:  0.04598546028137207


In [5]:
def getClusterPoints(pts, threshold):
    """Greedily group nearby points and return one centroid per group.

    Points are visited in order. A point that is not yet in any group starts a
    new group made of itself plus every point closer than ``threshold`` to it.
    A point that already belongs to a group adds its own close neighbours to the
    first (oldest) group containing it. Groups are never merged with each other,
    so a point may end up in more than one group.

    Each group is replaced by the unweighted mean of its distinct members.

    Args:
        pts: (N, D) array-like of point coordinates.
        threshold: pairwise distances strictly below this count as neighbours.

    Returns:
        (G, D) float array of group centroids with G <= N; (0, D) for empty input.
    """
    pts = np.asarray(pts)
    if pts.shape[0] == 0:
        return np.zeros((0, pts.shape[1]))

    neighbor_index = sp.spatial.distance.cdist(pts, pts) < threshold

    # Sets give O(1) membership tests and count every point once per group, so the
    # centroid is not biased toward points reached through several neighbours.
    groups = []
    for i in range(neighbor_index.shape[0]):
        neighbors = np.nonzero(neighbor_index[i, :].reshape(-1))[0].tolist()

        # Oldest group that already contains point i, if any.
        group = next((g for g in groups if i in g), None)

        if group is None:
            # Always include i itself so a group is never empty (threshold <= 0
            # would otherwise yield an empty group and a NaN centroid).
            groups.append(set(neighbors) | {i})
        else:
            group.update(neighbors)

    cluster_pts = np.zeros((len(groups), pts.shape[1]))
    for k, group in enumerate(groups):
        cluster_pts[k] = np.mean(pts[sorted(group)], axis = 0)

    return cluster_pts

In [1]:
# Iteratively grow a randomly chosen test sample. On each pass the detector's argmax output
# selects points; each selected point plus the predictor's per-point offset becomes a
# candidate. Candidates that are far enough from the cloud are clustered and appended.
MAX_ITERATIONS = 10        # upper bound on refinement passes
MIN_POINT_DISTANCE = 0.05  # candidates closer than this to an existing point are dropped
CLUSTER_THRESHOLD = 0.05   # surviving candidates closer than this are merged

entry = testArr[np.random.randint(0, len(testArr))]

origin_points = np.copy(entry[:,0:2])

points = entry[:,0:2]
# Remaining columns of the sample (not used in this cell).
nearest_hidden = entry[:,2:4]
edge_mask = entry[:,4].astype(bool)

extra_points = []

iteration_count = 0

# Pure inference: no autograd graph is needed, so don't build one.
with torch.no_grad():
    for i in range(MAX_ITERATIONS):

        # (N, 2) -> (1, 2, N): batch of one, coordinates as channels.
        pts = torch.from_numpy(np.swapaxes(np.expand_dims(points, axis = 0), 1,2)).float().to(device)

        edge_points_prediction = detector(pts).squeeze(0).transpose(0,1)
        new_points_prediction = predictor(pts).squeeze(0).transpose(0,1)

        new_points_prediction_np = new_points_prediction.cpu().numpy()
        # argmax over the last axis; any non-zero index marks the point as selected.
        edge_points_prediction_np = np.argmax(edge_points_prediction.cpu().numpy(), axis = 1).astype(bool)

        # Release device tensors before the next pass.
        del pts, edge_points_prediction, new_points_prediction

        iteration_count += 1

        if not np.any(edge_points_prediction_np):
            break

        # Candidate = selected point + its predicted offset.
        new_points = (points + new_points_prediction_np)[edge_points_prediction_np]

        # Keep only candidates at least MIN_POINT_DISTANCE away from every existing point.
        dists = sp.spatial.distance.cdist(new_points, points)
        new_points = new_points[np.all(dists >= MIN_POINT_DISTANCE, axis = 1)]

        # Skip clustering when nothing survived the distance filter.
        if new_points.shape[0] > 0:
            new_points = getClusterPoints(new_points, CLUSTER_THRESHOLD)

        if new_points.shape[0] == 0:
            break

        extra_points.append(new_points)
        points = np.concatenate((points, new_points), axis = 0)

print ("Ran %d iterations!" % (iteration_count))

# Input points in black; points added on each pass coloured along viridis
# (earlier passes darker, later passes lighter).
fig, ax = plt.subplots(1, 1, figsize=(16, 9))

ax.set_facecolor([0.5, 0.5, 0.5])
ax.grid(True)
ax.axis('equal')
ax.set_title("Input points and points added per iteration")
ax.set_xlabel("x")
ax.set_ylabel("y")

ax.scatter(
    origin_points[:,0],
    origin_points[:,1],
    color = 'black',
    label = 'input',
    s = 100
)

# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
cmap = plt.get_cmap('viridis')
n_passes = len(extra_points)

for i, pass_points in enumerate(extra_points):
    ax.scatter(
        pass_points[:,0],
        pass_points[:,1],
        alpha = 0.9,
        label = "iteration %d" % (i+1),
        color = cmap((i + 1) / n_passes),
        s = 100
    )

# Trailing ';' suppresses the Legend repr in the cell output.
ax.legend();

# To export: fig.savefig("example_.png", dpi = 200)

NameError: name 'testArr' is not defined