In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%reload_ext autoreload
%autoreload 2

import glob
import pickle
import os
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from time import sleep

import torch
from matplotlib import pyplot as plt
from geomloss import SamplesLoss
from torch import optim
from vectran.renderers.cairo import render as cairo_render
from vectran.util.evaluation_utils import vector_image_from_patches

from vecopt.aligner import (
    StatefulBatchAligner,
    init_ot_aligner,
    make_default_loss_fn,
    make_default_optimize_fn,
)
from vecopt.aligner_utils import (
    LossComposition, store_render_difference, 
    perceptual_bce, strip_confidence_grads, 
    compose, coords_only_grads
)
from vecopt.crossing_model import CrossingRefinerFull
from vecopt.inference import IntermediateOutputAligner

In [2]:
IntermediateSample = namedtuple('IntermediateSample', ['worker_idx', 'sample', 'filename'])
n_workers = 4

data = []

data_folder = '/home/apankov/vecopt_datasets/results/abc/our_curves/intermediate_output'
worker_idx = -1
for filename in glob.glob(os.path.join(data_folder, '*')):
    path = os.path.join(data_folder, filename)
    with open(path, 'rb') as handle:
        sample = pickle.load(handle)
        
    worker_idx = (worker_idx + 1) % n_workers
    
    data.append(IntermediateSample(worker_idx, sample, filename))
    
def make_aligner(device, n_steps):
    crossing_model = CrossingRefinerFull().to(device)
    crossing_model.load_state_dict(torch.load('../vecopt/weights/best_crossings_mult.pt'))
    _ = crossing_model.train(False)

    loss = LossComposition()
    ot_loss = SamplesLoss("sinkhorn", p=2, blur=.01, scaling=.5, reach=5.)
    loss.add(make_default_loss_fn(
        bce_schedule=(lambda state: 0.0),
        ot_loss=ot_loss
    ))
    loss.add(perceptual_bce(crossing_model, 3))
    loss.add(perceptual_bce(crossing_model, 4))

    grad_transformer = compose(strip_confidence_grads, coords_only_grads(n_steps - 150))

    aligner = StatefulBatchAligner(device=device)
    init_ot_aligner(aligner, loss_fn=loss, device=device,
                    optimize_fn=make_default_optimize_fn(
                        aligner, 
                        lr=0.25, 
                        transform_grads=grad_transformer,
                        base_optimizer=optim.Adam,
                    ))

    aligner.add_callback(store_render_difference)
    
    return aligner
    

n_steps = 500
    
    
class Model:
    def __init__(self, data, gpu_idx, total_gpus):
        self.gpu_idx = gpu_idx
        self.total_gpus = total_gpus
        self.data = data
        
        crossing_model = CrossingRefinerFull().to(f'cuda:{gpu_idx}')
        crossing_model.load_state_dict(torch.load('../vecopt/weights/best_crossings_mult.pt'))
        _ = crossing_model.train(False)
        aligner = make_aligner(f'cuda:{gpu_idx}', n_steps)
        self.worker = (IntermediateOutputAligner(aligner, n_steps=n_steps, crossing_model=crossing_model))
    
    def __call__(self):
        results = []
        
        for idx in range(self.gpu_idx, len(self.data), self.total_gpus):
            sample = self.data[idx]
            results.append((self.worker(sample.sample), sample.filename))
            
        return results

In [3]:
models = [
    Model(data, 0, 4),
    Model(data, 1, 4),
    Model(data, 2, 4),
    Model(data, 3, 4)
]

In [4]:
executor = ThreadPoolExecutor(max_workers=4)

def call(x):
    return x()

results = executor.map(call, models)

In [6]:
results = list(results)

In [None]:
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    future_to_url = {executor.submit(load_url, url, 60): url for url in URLS}
    for future in concurrent.futures.as_completed(future_to_url):
        url = future_to_url[future]
        try:
            data = future.result()
        except Exception as exc:
            print('%r generated an exception: %s' % (url, exc))
        else:
            print('%r page is %d bytes' % (url, len(data)))

In [8]:
results

[[(tensor([[[ 4.0422e-03, -5.5553e-03,  2.8521e-04,  ..., -3.0744e-03,
             -1.1091e-02,  6.5028e-01],
            [-2.0370e-04, -4.8677e-03, -3.3178e-03,  ...,  3.1711e-03,
             -4.9337e-03,  1.3062e-01],
            [-2.2724e-04, -4.8564e-03, -3.3208e-03,  ...,  3.1445e-03,
             -4.8829e-03,  1.2462e-01],
            ...,
            [-3.0106e-04, -4.8138e-03, -3.3286e-03,  ...,  3.0611e-03,
             -4.7277e-03,  1.0704e-01],
            [-2.9942e-04, -4.8153e-03, -3.3282e-03,  ...,  3.0622e-03,
             -4.7311e-03,  1.0727e-01],
            [-2.9910e-04, -4.8157e-03, -3.3282e-03,  ...,  3.0624e-03,
             -4.7321e-03,  1.0734e-01]],
   
           [[ 4.0422e-03, -5.5553e-03,  2.8521e-04,  ..., -3.0744e-03,
             -1.1091e-02,  6.5028e-01],
            [-2.0370e-04, -4.8677e-03, -3.3178e-03,  ...,  3.1711e-03,
             -4.9337e-03,  1.3062e-01],
            [-2.2724e-04, -4.8564e-03, -3.3208e-03,  ...,  3.1445e-03,
             -4.882