# ビームサーチ＆エッジの教師データの有効性検証

In [None]:
# Build the reference edge targets for the validation split; beam_test()
# reads this module-level name inside its batch loop.
# NOTE(review): presumably the exact/optimal edge assignments produced by
# create_data_files — confirm the returned type/shape against that helper.
exact_edges_target = create_data_files(config, data_mode="val")

In [None]:
def _uelb_batch_to_tensors(batch):
    """Convert one DatasetReader batch into the tensors used by ``beam_test``.

    Returns ``(x_edges_capacity, y_edges, batch_commodities)``: a float tensor
    built from ``batch.edges_capacity`` and long tensors built from
    ``batch.edges_target`` and ``batch.commodities``. All are contiguous and
    do not track gradients.
    """
    x_edges_capacity = torch.FloatTensor(batch.edges_capacity).contiguous().requires_grad_(False)
    y_edges = torch.LongTensor(batch.edges_target).contiguous().requires_grad_(False)
    batch_commodities = torch.LongTensor(batch.commodities).contiguous().requires_grad_(False)
    return x_edges_capacity, y_edges, batch_commodities


def beam_test(net, config, master_bar, mode='val'):
    """Run beam search on ground-truth edge labels and compare load factors.

    Checks whether the edge-level supervision (``batch.edges_target``) lets
    ``BeamsearchUELB`` rebuild routings whose maximum load factor matches the
    ground truth. For every batch it prints:

    * the load factors ``compute_load_factor`` returns for the module-level
      ``exact_edges_target`` and for the batch's own ``edges_target``;
    * the mean maximum load factor of the beam-search paths
      (``mean_feasible_load_factor``) and the mean of ``batch.load_factor``.

    After the loop it prints ``Accuracy``: the summed ground-truth load factors
    divided by the summed beam-search load factors (NaN if the denominator is
    zero or no batch was processed).

    Args:
        net: Model; only switched to evaluation mode here, never called.
        config: Settings object; reads ``num_<mode>_data``,
            ``num_commodities``, ``num_nodes``, ``beam_size`` and
            ``batches_per_epoch``.
        master_bar: Parent progress bar handed to ``progress_bar``.
        mode: Dataset split name forwarded to ``DatasetReader``.

    Returns:
        ``(mean_maximum_load_factor, gt_load_factor)`` of the LAST batch
        processed (not epoch averages), or ``(nan, nan)`` when the dataset
        yields no batch.

    Note:
        Depends on the notebook globals ``exact_edges_target``,
        ``dtypeFloat`` and ``dtypeLong``.
    """
    net.eval()

    num_data = getattr(config, f'num_{mode}_data')
    # Instances are evaluated one at a time (batch_size is fixed, not taken
    # from config).
    batch_size = 1
    num_commodities = config.num_commodities
    num_nodes = config.num_nodes
    beam_size = config.beam_size
    batches_per_epoch = config.batches_per_epoch

    dataset = iter(DatasetReader(num_data, batch_size, mode))

    running_mean_maximum_load_factor = 0.0
    running_gt_load_factor = 0.0
    num_batches_done = 0
    # Returned to the caller; stay NaN if the reader yields nothing, instead
    # of raising UnboundLocalError at the return statement.
    mean_maximum_load_factor = float('nan')
    gt_load_factor = float('nan')

    # Process-wide setting for the tensor debug prints below; set once rather
    # than on every iteration.
    torch.set_printoptions(linewidth=200)

    with torch.no_grad():
        for batch_num in progress_bar(range(batches_per_epoch), parent=master_bar):
            print("batch_num: ", batch_num)
            # Fetch the next UELB instance; stop early if the reader runs dry.
            try:
                batch = next(dataset)
            except StopIteration:
                break

            x_edges_capacity, y_edges, batch_commodities = _uelb_batch_to_tensors(batch)

            # "kakai" is presumably a romanized "lower bound": the load factor
            # of the module-level exact_edges_target.
            # NOTE(review): the same global object is passed for every batch —
            # confirm compute_load_factor selects the instance that matches
            # this batch, otherwise each batch is compared to one reference.
            kakai_max_values_per_batch = compute_load_factor(exact_edges_target, x_edges_capacity, batch_commodities)
            max_values_per_batch = compute_load_factor(y_edges, x_edges_capacity, batch_commodities)
            print("kakai_max_values_per_batch: ", kakai_max_values_per_batch)
            print("max_values_per_batch: ", max_values_per_batch)

            # Beam search directly over the ground-truth edge labels.
            beam_search = BeamsearchUELB(
                y_edges, beam_size, batch_size, x_edges_capacity, batch_commodities,
                dtypeFloat, dtypeLong, mode_strict=True)
            pred_paths = beam_search.search()

            mean_maximum_load_factor = mean_feasible_load_factor(
                batch_size, num_commodities, num_nodes, pred_paths, x_edges_capacity, batch_commodities)
            print("mean_maximum_load_factor: ", mean_maximum_load_factor)
            gt_load_factor = np.mean(batch.load_factor)
            print("gt_load_factor: ", gt_load_factor)

            running_mean_maximum_load_factor += batch_size * mean_maximum_load_factor
            running_gt_load_factor += batch_size * gt_load_factor
            num_batches_done += 1

    # Guard the ratio: an empty dataset (or an all-zero predicted load factor)
    # previously raised ZeroDivisionError.
    if num_batches_done == 0 or running_mean_maximum_load_factor == 0:
        accuracy = float('nan')
    else:
        accuracy = running_gt_load_factor / running_mean_maximum_load_factor
    print("Accuracy: ", accuracy)

    return mean_maximum_load_factor, gt_load_factor