In [3]:
import torch, random, math, numpy as np
from tqdm import tqdm
from nets.attention_model import AttentionModel
from problems.tsp.problem_tsp import TSP

# Hyper-parameters for AttentionModel. They must match the architecture stored in the
# checkpoint: load_state_dict is strict by default and rejects missing/unexpected keys
# and shape mismatches.
model_kwargs = dict(
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
    n_heads=8,
    tanh_clipping=10.0,
    normalization="batch",
    problem=TSP(),
)
model = AttentionModel(**model_kwargs)

# Pretrained TSP-20 weights, mapped onto the CPU so no GPU is needed.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint_path = "pretrained/tsp_20/epoch-99.pt"
ckpt = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(ckpt["model"])

# Switch to inference mode (e.g. batch normalization then uses its running statistics),
# then select greedy decoding -- presumably argmax node choice at each step; TODO confirm.
model.eval()
model.set_decode_type("greedy")

def euclidean_distance(a, b):
    """Return the straight-line (L2) distance between points ``a`` and ``b``.

    Only the first two components of each point are used, so extra
    coordinates (if any) are ignored.
    """
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    return math.hypot(dx, dy)

def calc_distance(path, coords):
    """Return the length of the closed tour ``path`` over the points ``coords``.

    Every node in ``path`` contributes the edge to its successor, and the last
    node wraps around to the first, so the cycle is closed. An empty path has
    length 0.0.

    Args:
        path: sequence of indices into ``coords`` giving the visiting order.
        coords: indexable collection of points accepted by ``euclidean_distance``.

    Returns:
        float: total Euclidean length of the closed tour.
    """
    n = len(path)
    total = 0.0
    for idx, node in enumerate(path):
        # Modulo wraps the final node back to the start of the tour.
        successor = path[(idx + 1) % n]
        total += euclidean_distance(coords[node], coords[successor])
    return total

def greedy_tour(coords, start=0):
    """Build a nearest-neighbour tour over ``coords``.

    Starting at node ``start``, repeatedly move to the closest not-yet-visited
    node (measured with ``euclidean_distance``) until every node is visited.
    The returned order is open; ``calc_distance`` adds the closing edge.

    Args:
        coords: sequence of points indexable as ``p[0], p[1]``.
        start: index of the first node of the tour (default 0, the original behaviour).

    Returns:
        list[int]: a permutation of ``range(len(coords))`` beginning with ``start``,
        or ``[]`` when ``coords`` is empty.

    Raises:
        ValueError: if ``coords`` is non-empty and ``start`` is not a valid node index.
    """
    n = len(coords)
    if n == 0:
        # Previously returned [0], i.e. a node that does not exist.
        return []
    if not 0 <= start < n:
        raise ValueError(f"start must be in [0, {n}), got {start}")

    tour = [start]
    unvisited = {i for i in range(n) if i != start}
    cur = start
    while unvisited:
        # Linear scan per step -> O(n^2) overall; fine for small instances.
        nxt = min(unvisited, key=lambda i: euclidean_distance(coords[cur], coords[i]))
        tour.append(nxt)
        unvisited.remove(nxt)
        cur = nxt
    return tour

# ---- Experiment: pretrained attention model vs. nearest-neighbour greedy baseline ----
# N random instances of `num_nodes` points drawn uniformly from the unit square.
# num_nodes should match the instance size of the loaded checkpoint (tsp_20).
N = 100
num_nodes = 20
SEED = 1234  # fixed so the sampled instances (and the printed means) are reproducible

random.seed(SEED)        # drives the coordinate sampling below
torch.manual_seed(SEED)  # only matters if the model ever samples (non-greedy decoding)

ai_ds, gr_ds = [], []  # per-instance closed-tour lengths: AI model / greedy baseline

for _ in tqdm(range(N)):
    coords = [[random.random(), random.random()] for _ in range(num_nodes)]
    # Batch containing a single instance: shape (1, num_nodes, 2).
    coords_tensor = torch.tensor(coords, dtype=torch.float).unsqueeze(0)

    with torch.no_grad():
        out = model(coords_tensor, return_pi=True)

    # The forward pass may return the tour tensor itself or a tuple of outputs;
    # the tour is taken to be the first 2-D tensor, i.e. (batch, num_nodes).
    candidates = [out] if isinstance(out, torch.Tensor) else list(out)
    pi = next((x for x in candidates if isinstance(x, torch.Tensor) and x.dim() == 2), None)
    if pi is None:
        raise RuntimeError("model output contains no 2-D (batch, num_nodes) tour tensor")

    tour_tensor = pi[0]
    tour_list = tour_tensor.long().tolist()
    # A valid TSP tour visits every node exactly once; fail loudly instead of
    # silently scoring a malformed tour.
    if sorted(tour_list) != list(range(num_nodes)):
        raise RuntimeError(f"model returned an invalid tour: {tour_list}")

    # Plain nested list (despite the name) of the float32-rounded coordinates the
    # model actually saw, so both methods are scored on identical points.
    coords_np = coords_tensor.squeeze(0).tolist()

    ai_len = calc_distance(tour_list, coords_np)
    gr_path = greedy_tour(coords_np)
    gr_len = calc_distance(gr_path, coords_np)

    ai_ds.append(ai_len)
    gr_ds.append(gr_len)

mean_ai = np.mean(ai_ds)
mean_gr = np.mean(gr_ds)
diff = mean_gr - mean_ai  # > 0 means the AI tours are shorter on average

print(f"🧠 AI 평균 경로 길이     : {mean_ai:.4f}")
print(f"🧠 Greedy 평균 경로 길이 : {mean_gr:.4f}")
print(f"📊 차이 (Greedy - AI)    : {diff:.4f}  →  {'AI 우세' if diff>0 else 'Greedy 우세'}")

100%|████████████████████████████████████████| 100/100 [00:00<00:00, 374.76it/s]

🧠 AI 평균 경로 길이     : 3.8717
🧠 Greedy 평균 경로 길이 : 4.5576
📊 차이 (Greedy - AI)    : 0.6860  →  AI 우세





In [2]:
pip install ortools

Collecting ortools
  Downloading ortools-9.12.4544-cp312-cp312-macosx_11_0_arm64.whl.metadata (3.1 kB)
Collecting absl-py>=2.0.0 (from ortools)
  Downloading absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting protobuf<5.30,>=5.29.3 (from ortools)
  Downloading protobuf-5.29.4-cp38-abi3-macosx_10_9_universal2.whl.metadata (592 bytes)
Collecting immutabledict>=3.0.0 (from ortools)
  Downloading immutabledict-4.2.1-py3-none-any.whl.metadata (3.5 kB)
Downloading ortools-9.12.4544-cp312-cp312-macosx_11_0_arm64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m57.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading absl_py-2.2.2-py3-none-any.whl (135 kB)
Downloading immutabledict-4.2.1-py3-none-any.whl (4.7 kB)
Downloading protobuf-5.29.4-cp38-abi3-macosx_10_9_universal2.whl (417 kB)
Installing collected packages: protobuf, immutabledict, absl-py, ortools
  Attempting uninstall: protobuf
    Found existing installation: 