In [10]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%reload_ext autoreload
%autoreload 2

import glob
import pickle
import os
from collections import namedtuple
from multiprocessing import Pool, Queue
from time import sleep

import torch
from matplotlib import pyplot as plt
from geomloss import SamplesLoss
from torch import optim
from vectran.renderers.cairo import render as cairo_render
from vectran.util.evaluation_utils import vector_image_from_patches

from vecopt.aligner import (
    StatefulBatchAligner,
    init_ot_aligner,
    make_default_loss_fn,
    make_default_optimize_fn,
)
from vecopt.aligner_utils import (
    LossComposition, store_render_difference, 
    perceptual_bce, strip_confidence_grads, 
    compose, coords_only_grads
)
from vecopt.crossing_model import CrossingRefinerFull
from vecopt.inference import IntermediateOutputAligner

In [52]:
IntermediateSample = namedtuple('IntermediateSample', ['worker_idx', 'sample', 'filename'])
n_workers = 4

data = []

data_folder = '/home/apankov/vecopt_datasets/results/abc/our_curves/intermediate_output'
worker_idx = -1
for filename in glob.glob(os.path.join(data_folder, '*')):
    path = os.path.join(data_folder, filename)
    with open(path, 'rb') as handle:
        sample = pickle.load(handle)
        
    worker_idx = (worker_idx + 1) % n_workers
    
    data.append(IntermediateSample(worker_idx, sample, filename))
    
def make_aligner(device, n_steps):
    crossing_model = CrossingRefinerFull().to(device)
    crossing_model.load_state_dict(torch.load('../vecopt/weights/best_crossings_mult.pt'))
    _ = crossing_model.train(False)

    loss = LossComposition()
    ot_loss = SamplesLoss("sinkhorn", p=2, blur=.01, scaling=.5, reach=5.)
    loss.add(make_default_loss_fn(
        bce_schedule=(lambda state: 0.0),
        ot_loss=ot_loss
    ))
    loss.add(perceptual_bce(crossing_model, 3))
    loss.add(perceptual_bce(crossing_model, 4))

    grad_transformer = compose(strip_confidence_grads, coords_only_grads(n_steps - 150))

    aligner = StatefulBatchAligner(device=device)
    init_ot_aligner(aligner, loss_fn=loss, device=device,
                    optimize_fn=make_default_optimize_fn(
                        aligner, 
                        lr=0.25, 
                        transform_grads=grad_transformer,
                        base_optimizer=optim.Adam,
                    ))

    aligner.add_callback(store_render_difference)
    
    return aligner
    

n_steps = 500
    
    
class Model:
    def __init__(self, data, results, gpu_idx, total_gpus):
        self.results = results
        self.gpu_idx = gpu_idx
        self.total_gpus = total_gpus
        self.data = data
        
        crossing_model = CrossingRefinerFull().to(f'cuda:{gpu_idx}')
        crossing_model.load_state_dict(torch.load('../vecopt/weights/best_crossings_mult.pt'))
        _ = crossing_model.train(False)
        aligner = make_aligner(f'cuda:{gpu_idx}', n_steps)
        self.worker = (IntermediateOutputAligner(aligner, n_steps=n_steps, crossing_model=crossing_model))
    
    def __call__(self):
        for idx in range(self.gpu_idx, len(self.data), self.total_gpus):
            sample = self.data[idx]
            self.results.append(self.worker(sample.sample))

In [53]:
results = []

models = [
    Model(data, results, 0, 4),
    Model(data, results, 1, 4),
    Model(data, results, 2, 4),
    Model(data, results, 3, 4)
]

In [54]:
from joblib import Parallel, delayed
from time import sleep

In [None]:
Parallel(n_jobs=4, verbose=10)(delayed(model)() for model in models)

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
