In [None]:
import os
import numpy as np
import torch

In [None]:
from torch.utils.data import DataLoader
from generate_data import generate_mec_data
from utils import load_model
from problems import MEC

In [None]:
# Load the trained model from its checkpoint directory.
MODEL_DIR = 'outputs/mec_6/demo_6_20241205T021257/'
# Unpack the unused second return value into a named throwaway rather than `_`:
# assigning `_` in a notebook cell makes IPython stop updating its `_`/`__`/`___`
# output-history variables for the rest of the session.
model, _load_extra = load_model(MODEL_DIR)

In [None]:
# Build a small, reproducible set of MEC instances, passing in the model's
# dependency so generated instances match what the model was built with.
torch.manual_seed(1234)
dataset = MEC.make_dataset(size=6, num_samples=2, dependency=model.dependency)

# The model is called on collated batches, so go through a DataLoader.
# NOTE(review): with batch_size=1, `batch` holds only the first of the
# num_samples=2 instances; the second instance is generated but never solved.
dataloader = DataLoader(dataset, batch_size=1)
batch_iter = iter(dataloader)
batch = next(batch_iter)

In [None]:
# Greedy-decode the batch with gradient tracking disabled.
model.eval()
model.set_decode_type('greedy')

with torch.no_grad():
    cost, log_p, pi = model(batch, return_pi=True)
# `tours` is the name the plotting cell iterates over.
tours = pi

# Dependency entries are shown shifted by +1 (presumably 0-based internally — TODO confirm).
dependency_display = [dep + 1 for dep in model.dependency]
print("model solution", " with dp of ", dependency_display)
for idx, route in enumerate(pi):
    print(cost[idx].item(), " -|- ", route)


In [None]:
# Plotting setup: matplotlib imports used by plot_mec
%matplotlib inline
from matplotlib import pyplot as plt

from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D

# Code inspired by Google OR Tools plot:
# https://github.com/google/or-tools/blob/fb12c5ded7423d524fc6c95656a9bdc290a81d4d/examples/python/cvrptw_plot.py

def plot_mec(data, tour, ax):
    """Draw one MEC instance's visiting order on a matplotlib Axes.

    The UAV start position(s) come first in the node list (index 0 for a single
    start position) and the task positions follow, so the ids in ``tour`` index
    into that combined list. The drawn route is prefixed with node 0.

    Args:
        data: mapping with tensors ``'UAV_start_pos'`` and ``'task_position'``
            (assumed to hold 2-D coordinates — TODO confirm against MEC).
        tour: 1-D tensor of node ids produced by the model.
        ax: matplotlib Axes to draw on.

    Returns:
        The same ``ax``, so callers may chain further customisation.
    """
    # np.atleast_2d turns a single (2,) start position into [[x, y]] and leaves
    # an already 2-D array unchanged. Without it, a (2,) tensor would be spliced
    # into the node list as two bare floats and indexing node 0 would fail.
    depot = np.atleast_2d(data['UAV_start_pos'].cpu().numpy()).tolist()
    locs = data['task_position'].cpu().numpy().tolist()
    nodes = depot + locs
    route = [0] + tour.cpu().numpy().tolist()

    # Coordinates of the stops, in visiting order.
    x_vals = [nodes[node][0] for node in route]
    y_vals = [nodes[node][1] for node in route]

    # Route as a connected polyline.
    ax.plot(x_vals, y_vals, marker='o', linestyle='-', color='b', label='Path')

    # Direction arrows between consecutive stops.
    for src, dst in zip(route[:-1], route[1:]):
        x_start, y_start = nodes[src]
        x_end, y_end = nodes[dst]
        ax.annotate('', xy=(x_end, y_end), xytext=(x_start, y_start),
                    arrowprops=dict(facecolor='black', edgecolor='black',
                                    arrowstyle='->', lw=2, mutation_scale=20))

    # Label every stop with its node id.
    for node, x, y in zip(route, x_vals, y_vals):
        ax.text(x, y, f'{node}', fontsize=30, ha='right', color='red')

    # Segment from the first stop to the last one, marking start and end.
    ax.plot([x_vals[0], x_vals[-1]], [y_vals[0], y_vals[-1]], '-', label='Start-End')

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title('MEC Path with Arrows')
    ax.legend()
    return ax
    

In [None]:
from demopts import *
# 求 loss
def getloss(p):
    return
# Compute the reference ("best") solution for one instance.
# NOTE(review): the search itself is not implemented yet; currently only the
# per-task time/energy terms are derived.
def chechformodel(data):
    """Derive the per-task time and energy terms of one MEC instance.

    Intended to eventually compute the best achievable loss for comparison with
    the model's tour. For now it prints a header line and returns the derived
    quantities instead of discarding them.

    Reads the module-level constants ``upload_speed``, ``UAV_p``,
    ``switched_capacitance`` and ``v`` (brought in by ``from demopts import *``);
    a NameError is raised if they are not defined.

    Args:
        data: one instance mapping (an item of the MEC dataset) with tensor
            entries ``task_position``, ``task_data``, ``IoT_resource``,
            ``UAV_resource`` and ``CPU_circles``.

    Returns:
        dict with ``task_num`` (``task_position.shape[0]``) and the element-wise
        results ``upload_time``, ``UAV_execute_time``, ``UAV_execute_energy``,
        ``UAV_transmit_energy``, ``IoT_execute_time`` and ``IoT_execute_energy``.
    """
    task_position = data["task_position"]
    task_data = data["task_data"]
    IoT_resource = data["IoT_resource"]
    UAV_resource = data["UAV_resource"]
    CPU_circles = data["CPU_circles"]
    task_num = task_position.shape[0]

    # UAV-side terms.
    upload_time = task_data / upload_speed
    UAV_execute_time = CPU_circles / UAV_resource
    UAV_execute_energy = UAV_execute_time * UAV_p
    UAV_transmit_energy = upload_time * UAV_p

    # IoT-side terms.
    IoT_execute_time = CPU_circles / IoT_resource
    IoT_execute_energy = switched_capacitance * IoT_resource ** (v - 1) * CPU_circles

    print("the fact best loss")
    return {
        "task_num": task_num,
        "upload_time": upload_time,
        "UAV_execute_time": UAV_execute_time,
        "UAV_execute_energy": UAV_execute_energy,
        "UAV_transmit_energy": UAV_transmit_energy,
        "IoT_execute_time": IoT_execute_time,
        "IoT_execute_energy": IoT_execute_energy,
    }


In [None]:
# Plot the results: one figure per solved instance. zip() stops at the shorter
# of `dataset` and `tours`, so only instances the model actually solved appear.
for data, tour in zip(dataset, tours):
    print("the model tour : ", tour)
    fig, ax = plt.subplots(figsize=(5, 5))
    plot_mec(data, tour, ax)
    chechformodel(data)

plt.show()
