In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%reload_ext autoreload
%autoreload 2

import glob
import pickle
import os
import time
from collections import deque, namedtuple
from concurrent.futures import ThreadPoolExecutor, wait

import torch
from matplotlib import pyplot as plt
from geomloss import SamplesLoss
from torch import optim
from vectran.renderers.cairo import render as cairo_render
from vectran.util.evaluation_utils import vector_image_from_patches

from vecopt.aligner import (
    StatefulBatchAligner,
    init_ot_aligner,
    make_default_loss_fn,
    make_default_optimize_fn,
)
from vecopt.aligner_utils import (
    LossComposition, store_render_difference, 
    perceptual_bce, strip_confidence_grads, 
    compose, coords_only_grads
)
from vecopt.crossing_model import CrossingRefinerFull
from vecopt.inference import IntermediateOutputAligner

In [3]:
IntermediateSample = namedtuple('IntermediateSample', ['worker_idx', 'sample', 'filename'])
n_workers = 4

# Override with the VECOPT_INTERMEDIATE_DIR environment variable on other machines.
data_folder = os.environ.get(
    'VECOPT_INTERMEDIATE_DIR',
    '/home/apankov/vecopt_datasets/results/abc/our_curves/intermediate_output',
)


def load_intermediate_samples(folder, n_workers):
    """Unpickle every regular file in ``folder`` into a deque of IntermediateSample.

    Files are visited in sorted path order and tagged round-robin with a
    ``worker_idx`` in ``[0, n_workers)``. ``filename`` holds the full path as
    returned by ``glob`` (the same value the original loop stored).

    Security: ``pickle.load`` can execute arbitrary code -- only point this
    at trusted files.

    Args:
        folder: directory whose direct children are pickled samples.
        n_workers: number of round-robin buckets for ``worker_idx``.

    Returns:
        collections.deque of IntermediateSample.
    """
    samples = deque()
    # glob already returns paths prefixed with `folder`; re-joining them with
    # `folder` would break relative folders ('data' -> 'data/data/x').
    # Sorting makes the order (and hence the worker_idx tagging) reproducible;
    # non-files (sub-directories) are skipped because `open` would fail on them.
    paths = sorted(p for p in glob.glob(os.path.join(folder, '*')) if os.path.isfile(p))
    for idx, path in enumerate(paths):
        with open(path, 'rb') as handle:
            sample = pickle.load(handle)
        # NOTE(review): Worker below pops from the shared deque and never
        # reads worker_idx; it is kept for compatibility.
        samples.append(IntermediateSample(idx % n_workers, sample, path))
    return samples


data = load_intermediate_samples(data_folder, n_workers)
    
def make_aligner(device, n_steps,
                 weights_path='../vecopt/weights/best_crossings_mult.pt',
                 coords_only_margin=150):
    """Build a StatefulBatchAligner with the OT + perceptual-BCE loss stack.

    Args:
        device: torch device string (e.g. ``'cuda:0'``) for the crossing model
            and the aligner.
        n_steps: number of optimisation steps; ``n_steps - coords_only_margin``
            is passed to ``coords_only_grads``.
        weights_path: checkpoint loaded into the ``CrossingRefinerFull`` used by
            the perceptual BCE terms.
        coords_only_margin: offset subtracted from ``n_steps`` before calling
            ``coords_only_grads`` (the default 150 reproduces the original).

    Returns:
        The configured ``StatefulBatchAligner`` with ``store_render_difference``
        registered as a callback.
    """
    crossing_model = CrossingRefinerFull().to(device)
    # map_location loads the checkpoint straight onto `device` rather than onto
    # whatever device it was saved from (otherwise each GPU worker would also
    # allocate a copy there).
    crossing_model.load_state_dict(torch.load(weights_path, map_location=device))
    crossing_model.train(False)

    loss = LossComposition()
    ot_loss = SamplesLoss("sinkhorn", p=2, blur=.01, scaling=.5, reach=5.)
    loss.add(make_default_loss_fn(
        # schedule always returns 0.0 -- presumably disables the default BCE term; confirm
        bce_schedule=(lambda state: 0.0),
        ot_loss=ot_loss,
    ))
    loss.add(perceptual_bce(crossing_model, 3))
    loss.add(perceptual_bce(crossing_model, 4))

    # NOTE(review): with n_steps < coords_only_margin this argument is negative
    # (n_steps=10 gives -140) -- confirm coords_only_grads treats that as intended.
    grad_transformer = compose(strip_confidence_grads,
                               coords_only_grads(n_steps - coords_only_margin))

    aligner = StatefulBatchAligner(device=device)
    init_ot_aligner(
        aligner,
        loss_fn=loss,
        device=device,
        optimize_fn=make_default_optimize_fn(
            aligner,
            lr=0.25,
            transform_grads=grad_transformer,
            base_optimizer=optim.Adam,
        ),
    )

    aligner.add_callback(store_render_difference)

    return aligner
    

# Optimisation steps per sample.
# NOTE(review): the `config` cell later in this notebook records n_steps=500;
# 10 looks like a smoke-test value -- confirm before a full run.
n_steps = 10


class Worker:
    """Pulls samples from a shared deque and aligns them on one GPU.

    The same ``data`` deque is handed to every Worker, and each one pops
    samples until the deque is empty, so samples are distributed dynamically
    among the workers.
    """

    def __init__(self, data, gpu_idx):
        """
        Args:
            data: shared ``collections.deque`` of ``IntermediateSample`` items.
            gpu_idx: index of the CUDA device (``cuda:{gpu_idx}``) to run on.
        """
        self.gpu_idx = gpu_idx
        self.data = data

        device = f'cuda:{gpu_idx}'
        crossing_model = CrossingRefinerFull().to(device)
        # map_location keeps the checkpoint tensors off the device they were
        # saved from; each worker loads directly onto its own GPU.
        crossing_model.load_state_dict(
            torch.load('../vecopt/weights/best_crossings_mult.pt', map_location=device))
        crossing_model.train(False)
        aligner = make_aligner(device, n_steps)
        self.worker = IntermediateOutputAligner(aligner, n_steps=n_steps, crossing_model=crossing_model)

    def __call__(self):
        """Process samples until the shared deque is exhausted.

        Returns:
            list of ``(aligned_output, filename)`` tuples in processing order.

        Raises:
            Any exception raised by ``self.worker`` for a sample. Previously an
            ``IndexError`` from inside the alignment was swallowed and that
            sample silently dropped.
        """
        results = []
        while True:
            # Only the pop is guarded: another worker may empty the deque
            # between a truthiness check and the pop.
            try:
                sample = self.data.popleft()
            except IndexError:
                break
            print(sample.filename)
            results.append((self.worker(sample.sample), sample.filename))
        return results

In [4]:
# One Worker per GPU (cuda:0 .. cuda:{n_workers-1}); all share the same `data` deque.
workers = [Worker(data, gpu_idx) for gpu_idx in range(n_workers)]

In [5]:
# One thread per Worker so the GPUs run concurrently. With max_workers=1 the
# jobs ran one after another and the first worker drained the shared deque
# alone on a single GPU.
executor = ThreadPoolExecutor(max_workers=len(workers))

start = time.time()

jobs = [executor.submit(worker) for worker in workers]
wait(jobs)

end = time.time()

# .result() re-raises any exception from a worker thread instead of it being
# silently held inside the Future; each worker returns a list of (output, filename).
results = [item for job in jobs for item in job.result()]

KeyboardInterrupt: 

In [None]:
# Wall-clock seconds spent in the worker run. Raises NameError if the previous
# cell was interrupted before `end` was assigned.
elapsed_s = end - start
print(elapsed_s)

In [6]:
# Hyper-parameter record for this experiment, serialised with json.dumps in a later cell.
# NOTE(review): n_steps here (500) differs from the `n_steps = 10` used by the
# workers above -- confirm which value the recorded run used.
config = dict(
    crossing_model_weights='../vecopt/weights/best_crossings_mult.pt',
    ot_loss=dict(p=2, blur=0.01, scaling=0.5, reach=5.0),
    perceptual_bce=[3, 4],
    n_steps=500,
    coord_only_grads=350,
    batch_size=64,
    infer_crossings=True,
)

In [7]:
import json

In [8]:
# Serialise the config; the trailing bare name is displayed as the cell output.
config_json = json.dumps(config)
config_json

'{"crossing_model_weights": "../vecopt/weights/best_crossings_mult.pt", "ot_loss": {"p": 2, "blur": 0.01, "scaling": 0.5, "reach": 5.0}, "perceptual_bce": [3, 4], "n_steps": 500, "coord_only_grads": 350, "batch_size": 64, "infer_crossings": true}'