In [1]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
import pickle
import time
from nets.model import Model
from utils import get_device, get_costs, get_costs_from_D, get_mrcst_costs
# Registry mapping a cost_function_key (as used in run names and by `test`) to its cost function.
cost_functions = dict(
    default=get_costs,
    mrcst=get_mrcst_costs,
    from_D=get_costs_from_D,
)

In [2]:
# Pick the compute device once; later cells read the global `device` (Model construction, .to(device)).
device, device_count = get_device()
(device, device_count)

torch.cuda.is_available(): True
torch.cuda.device_count(): 1
device_idxes: [0]
devices: [<torch.cuda.device object at 0x00000158877AF550>]
device_names: ['NVIDIA GeForce RTX 2070 with Max-Q Design'] 

torch.cuda.current_device(): 0
torch.cuda.device(current_device): <torch.cuda.device object at 0x00000158877AF550>
torch.cuda.get_device_name(current_device): NVIDIA GeForce RTX 2070 with Max-Q Design 

Using device: cuda 

Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


(device(type='cuda'), 1)

In [4]:
def load_model(load_graph_size, degree_constrain, cost_function_key, is_stp, test_graph_size=None):
    """
    Build a Model and load the newest matching checkpoint from ``pretrained/``.

    A run directory matches when its name, after a 16-character prefix (a timestamp such as
    '20221119-000812-'), equals the run name built from the hyper-parameters below. Inside each
    matching run the highest-numbered ``<epoch>.pt`` file is used. If several runs match, the
    candidate paths are printed with an index and the user types the index to load.

    @param load_graph_size: int, graph size encoded in the checkpoint's run name, e.g. 20, 50, 100
    @param degree_constrain: None / int > 1
    @param cost_function_key: str, key of ``cost_functions``, e.g. 'default', 'mrcst', 'from_D'
    @param is_stp: bool
    @param test_graph_size: int or None, graph size passed to ``Model``; defaults to load_graph_size
    @return: the ``Model`` in eval mode with the checkpoint's 'model_state_dict' loaded
    @raise KeyError: if cost_function_key is not a key of ``cost_functions``
    @raise FileNotFoundError: if no matching checkpoint is found
    """
    graph_dim = 2
    d_model = 256
    nhead = 8
    num_encoder_layers = 3
    num_decoder_layers = 2
    # Must match the 'df' field of the run name: gs100 runs are named with df128, the others with df512.
    dim_feedforward = 128 if load_graph_size == 100 else 512
    lr = 1e-04
    num_batches = 2500
    batch_size = 512
    target_epoch = None  # None -> use the highest epoch found in the run directory
    cost_functions[cost_function_key]  # fail fast with KeyError on an unknown key
    # The run name does not depend on the directory being inspected, so build it once.
    target_run_name = 'gs{}-dm{}-nh{}-nel{}-ndl{}-df{}-lr{:.0e}-dc{}-stp{}-cf{}-nb{}-bs{}'.format(
        load_graph_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
        lr, degree_constrain, is_stp, cost_function_key, num_batches, batch_size
    )
    paths = []
    # os.listdir order is arbitrary; sorting (by the timestamp prefix) keeps printed indices stable.
    for rn in sorted(os.listdir('pretrained')):
        if rn[16:] != target_run_name:
            continue
        run_dir = 'pretrained/{}'.format(rn)
        # Only '<int>.pt' files are checkpoints; anything else in the directory is ignored
        # instead of crashing int().
        num_epochs = []
        for fn in os.listdir(run_dir):
            stem, ext = os.path.splitext(fn)
            if ext == '.pt' and stem.isdigit():
                num_epochs.append(int(stem))
        if not num_epochs:
            continue
        path = '{}/{}.pt'.format(run_dir, max(num_epochs) if target_epoch is None else target_epoch)
        print('{} find satified checkpoint in:'.format(len(paths)), path)
        paths.append(path)
    if not paths:
        print('do not find satified checkpoint')
        # Previously execution continued and crashed with an unbound `target_path`.
        raise FileNotFoundError('no checkpoint in pretrained/ matches run name {}'.format(target_run_name))
    # int() rather than eval(): console input must never be executed as code.
    idx = 0 if len(paths) == 1 else int(input())
    target_path = paths[idx]
    print('load checkpoint in:', target_path)
    if not test_graph_size:
        test_graph_size = load_graph_size
    model = Model(test_graph_size, graph_dim, d_model, nhead, num_encoder_layers, num_decoder_layers,
                  dim_feedforward, degree_constrain, is_stp, device=device).eval()
    # map_location lets a checkpoint saved on one device load onto the current `device`.
    model.load_state_dict(torch.load(target_path, map_location=device)['model_state_dict'])
    return model

# Smoke check: a checkpoint can be located and loaded for every (graph size, task) pair used below.
for _gs in (50, 20):
    for _dc, _cf_key, _stp in ((2, 'default', False), (None, 'mrcst', False), (None, 'default', True)):
        load_model(_gs, _dc, _cf_key, _stp)
print()

0 find satified checkpoint in: pretrained/20221119-000812-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dc2-stpFalse-cfdefault-nb2500-bs512/99.pt
load checkpoint in: pretrained/20221119-000812-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dc2-stpFalse-cfdefault-nb2500-bs512/99.pt
0 find satified checkpoint in: pretrained/20221119-000735-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dcNone-stpFalse-cfmrcst-nb2500-bs512/99.pt
load checkpoint in: pretrained/20221119-000735-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dcNone-stpFalse-cfmrcst-nb2500-bs512/99.pt
0 find satified checkpoint in: pretrained/20221119-000526-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dcNone-stpTrue-cfdefault-nb2500-bs512/99.pt
load checkpoint in: pretrained/20221119-000526-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dcNone-stpTrue-cfdefault-nb2500-bs512/99.pt
0 find satified checkpoint in: pretrained/20221119-001123-gs20-dm256-nh8-nel3-ndl2-df512-lr1e-04-dc2-stpFalse-cfdefault-nb2500-bs512/99.pt
load checkpoint in: pretrained/20221119-001123-gs20-dm

In [5]:
def test(x, raw, model, cost_function_key, terminal_size, batch_size, sample_size, greedy=True, return_costs=False, return_edges=False):
    """
    Run ``model`` over ``x`` in mini-batches and summarise the cost of the sampled solutions.

    For every instance the model produces ``sample_size`` solutions. Each solution is scored with
    ``cost_functions[cost_function_key]`` against the matching ``raw`` entry, and the per-instance
    min / mean / std over the samples are accumulated.

    @param x: model input whose first two dims are (num_instances, graph_size)
    @param raw: raw_x of shape (batch_size, graph_size, 2) or D of shape (batch_size, graph_size, graph_size)
    @param model: callable returning (edges, log_prob_sum, lengths, weights_by_step), with edges
        viewable as (num_instances_in_batch * sample_size, 2, graph_size - 1)
    @param cost_function_key: key of ``cost_functions``, e.g. 'default', 'mrcst', 'from_D'
    @param terminal_size: forwarded to the model
    @param batch_size: forwarded to the model; the DataLoader feeds batch_size // sample_size
        instances per model call
    @param sample_size: number of solutions per instance, must be <= batch_size
    @param greedy: forwarded to the model
    @param return_costs: if True, fill the per-instance cost lists of the returned dict
    @param return_edges: if True, store the lowest-cost edge array of every instance
    @return: (mean of per-instance min cost, mean of per-instance mean cost,
        mean of per-instance std, wall-clock seconds per instance,
        dict with 'costs_min_list', 'costs_mean_list', 'costs_std_list', 'edge_list')
    """
    graph_size = x.shape[1]
    assert batch_size >= sample_size
    assert x.shape[:2] == raw.shape[:2]
    # The registry key is 'from_D' (the old check compared against 'from D' and never fired):
    # a distance matrix must have graph_size columns.
    assert cost_function_key != 'from_D' or raw.shape[-1] == graph_size
    cost_function = cost_functions[cost_function_key]
    dataloader = DataLoader(TensorDataset(x, raw), batch_size=batch_size // sample_size)
    num_instances = len(dataloader.dataset)
    cost = mean = std = 0
    costs_min_list, costs_mean_list, costs_std_list, edges_list = [], [], [], []
    start_time = time.time()
    with torch.no_grad():
        for data, raw_data in dataloader:
            edges, _log_prob_sum, _lengths, _weights_by_step = model(
                data, terminal_size=terminal_size, batch_size=batch_size, greedy=greedy, return_weights=True, max_sample_size=sample_size
            )  # edges: (len(data) * sample_size, 2, graph_size - 1)
            assert not len(edges) % len(data), 'length error, len(edges): {}, len(data): {}'.format(len(edges), len(data))
            assert sample_size == len(edges) // len(data)
            # Repeat every raw instance sample_size times so it lines up with its samples.
            raw_repeated = raw_data.unsqueeze(dim=1).expand(-1, sample_size, -1, -1).clone().view(-1, graph_size, raw.shape[-1])
            costs = cost_function(raw_repeated, edges)  # (len(data) * sample_size,)
            costs_by_instance = costs.view(-1, sample_size)
            costs_cpu = costs_by_instance.cpu()
            batch_cost_min = costs_cpu.min(dim=1)[0]
            batch_cost_mean = costs_cpu.mean(dim=1)
            # The (Bessel-corrected) std of a single sample is NaN and would poison the
            # aggregate; a single sample has no spread, so report 0.
            if sample_size > 1:
                batch_cost_std = costs_cpu.std(dim=1)
            else:
                batch_cost_std = torch.zeros_like(batch_cost_mean)
            assert len(data) == len(batch_cost_min) == len(batch_cost_mean) == len(batch_cost_std)
            cost += batch_cost_min.sum()
            mean += batch_cost_mean.sum()
            std += batch_cost_std.sum()
            if return_costs:
                costs_min_list.extend(batch_cost_min.tolist())
                costs_mean_list.extend(batch_cost_mean.tolist())
                costs_std_list.extend(batch_cost_std.tolist())
            if return_edges:
                best_idx = torch.argmin(costs_by_instance, dim=1)
                # Keep both index tensors on the same device as the argmin result.
                row_idx = torch.arange(len(best_idx), device=best_idx.device)
                best_edges = edges.view(-1, sample_size, 2, graph_size - 1)[row_idx, best_idx]
                edges_list.extend(np.array(e.cpu()) for e in best_edges)
        torch.cuda.empty_cache()
    return (
        cost / num_instances,
        mean / num_instances,
        std / num_instances,
        (time.time() - start_time) / num_instances,
        {
            'costs_min_list': costs_min_list,
            'costs_mean_list': costs_mean_list,
            'costs_std_list': costs_std_list,
            'edge_list': edges_list
        }
    )
load_graph_size, test_graph_size = 50, 20  # load the gs50 checkpoints, evaluate on gs20 instances
# The DataLoader in `test` holds 8 // 4 = 2 instances per batch, so the 3 instances are split into batches of [2, 1].
num_test, batch_size, sample_size = 3, 8, 4

# NOTE: np.load(..., allow_pickle=True) unpickles the file, which can execute arbitrary code;
# only point this at trusted data files.
with open('data/random/{}_test_seed1234.pkl'.format(test_graph_size), 'rb') as f:
    # Slice before .to(device) so only the instances under test are copied to the device.
    x = torch.tensor(np.load(f, allow_pickle=True), dtype=torch.float32)[:num_test].to(device)  # (num_test, graph_size, dim)
# The same coordinates serve as the raw input of the cost function. Clone instead of reading and
# parsing the file a second time; the two tensors stay independent as before.
raw_x = x.clone()

In [6]:
# dcmst (d=2): degree_constrain=2 with the 'default' cost; terminal_size is the full test graph size.
degree_constrain = 2
cost_function_key = 'default'
terminal_size = test_graph_size
model = load_model(load_graph_size, degree_constrain, cost_function_key, is_stp=False,
                   test_graph_size=test_graph_size)
test(x, raw_x, model, cost_function_key, terminal_size, batch_size, sample_size,
     greedy=True, return_costs=True, return_edges=True)

0 find satified checkpoint in: pretrained/20221119-000812-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dc2-stpFalse-cfdefault-nb2500-bs512/99.pt
load checkpoint in: pretrained/20221119-000812-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dc2-stpFalse-cfdefault-nb2500-bs512/99.pt


(tensor(3.2387),
 tensor(3.2995),
 tensor(0.0556),
 0.6177262465159098,
 {'costs_min_list': [3.1079583168029785,
   3.2559947967529297,
   3.352095127105713],
  'costs_mean_list': [3.1366958618164062,
   3.3844077587127686,
   3.3772921562194824],
  'costs_std_list': [0.03780844062566757,
   0.09447567909955978,
   0.034585416316986084],
  'edge_list': [array([[12,  4,  9, 17, 12, 15,  2, 16,  6,  1, 13, 19, 18,  3,  0, 14,
           10,  7,  5],
          [ 4,  9, 17,  6, 15,  2, 16,  8,  1, 13, 19, 18,  3,  0, 14, 10,
            7,  5, 11]], dtype=int64),
   array([[17, 13,  7,  8,  9, 15,  2, 18,  4, 10,  5,  1,  3, 19,  0, 14,
           12, 16, 11],
          [13,  7,  8,  9, 15,  2, 18,  4, 10,  5,  1,  3, 19,  0, 14, 12,
           16, 11,  6]], dtype=int64),
   array([[ 5, 13,  9, 14, 16,  6,  8,  3, 10, 19,  1,  5,  0, 11,  4,  2,
           17,  7, 18],
          [13,  9, 14, 16,  6,  8,  3, 10, 19,  1, 11,  0,  7,  4,  2, 17,
           12, 18, 15]], dtype=int64)]})

In [7]:
# mrcst: no degree constraint, scored with the 'mrcst' cost function; terminal_size is the full test graph size.
degree_constrain = None
cost_function_key = 'mrcst'
terminal_size = test_graph_size
model = load_model(load_graph_size, degree_constrain, cost_function_key, is_stp=False,
                   test_graph_size=test_graph_size)
test(x, raw_x, model, cost_function_key, terminal_size, batch_size, sample_size,
     greedy=True, return_costs=True, return_edges=True)

0 find satified checkpoint in: pretrained/20221119-000735-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dcNone-stpFalse-cfmrcst-nb2500-bs512/99.pt
load checkpoint in: pretrained/20221119-000735-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dcNone-stpFalse-cfmrcst-nb2500-bs512/99.pt


(tensor(131.4133),
 tensor(131.4133),
 tensor(0.),
 0.04753939310709635,
 {'costs_min_list': [120.3918228149414,
   137.89297485351562,
   135.95506286621094],
  'costs_mean_list': [120.3918228149414,
   137.89297485351562,
   135.95506286621094],
  'costs_std_list': [0.0, 0.0, 0.0],
  'edge_list': [array([[ 6,  6,  1,  1, 10, 10,  7, 13,  1,  1, 10,  6, 17, 17, 14, 17,
            7,  2,  7],
          [17,  1, 13, 10,  7, 14,  5,  3, 19, 18,  0,  9, 12, 15, 11,  4,
            2, 16,  8]], dtype=int64),
   array([[ 6, 11,  4,  4,  4,  5, 18,  4, 10,  4,  4, 18, 11,  4, 18,  4,
           18,  7, 15],
          [11,  4, 18, 10,  5,  1,  2,  3, 17, 19,  0, 15, 16, 14, 13, 12,
            7,  8,  9]], dtype=int64),
   array([[10,  7,  7,  0,  0,  4,  8,  0,  0,  0, 10,  4,  7, 15,  5, 11,
           10,  4, 17],
          [ 7,  0,  4,  8, 16, 11,  6,  5, 14,  9,  3, 17, 15, 18, 13,  1,
           19,  2, 12]], dtype=int64)]})

In [8]:
# stp: is_stp=True checkpoint, no degree constraint, 'default' cost, terminal_size=2.
degree_constrain = None
cost_function_key = 'default'
terminal_size = 2
model = load_model(load_graph_size, degree_constrain, cost_function_key, is_stp=True,
                   test_graph_size=test_graph_size)
test(x, raw_x, model, cost_function_key, terminal_size, batch_size, sample_size,
     greedy=True, return_costs=True, return_edges=True)

0 find satified checkpoint in: pretrained/20221119-000526-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dcNone-stpTrue-cfdefault-nb2500-bs512/99.pt
load checkpoint in: pretrained/20221119-000526-gs50-dm256-nh8-nel3-ndl2-df512-lr1e-04-dcNone-stpTrue-cfdefault-nb2500-bs512/99.pt


(tensor(0.4226),
 tensor(0.4226),
 tensor(0.),
 0.0059884389241536455,
 {'costs_min_list': [0.2954133450984955,
   0.4907519221305847,
   0.4814966320991516],
  'costs_mean_list': [0.2954133450984955,
   0.4907519221305847,
   0.4814966320991516],
  'costs_std_list': [0.0, 0.0, 0.0],
  'edge_list': [array([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
         dtype=int64),
   array([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
         dtype=int64),
   array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
         dtype=int64)]})