# In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt
from collections import deque
from grid import generate_synthetic_traffic  

class TabuSearchOptimizer:
    """Tabu-search optimizer that selects an active subset of candidate base stations.

    The service area is a square grid of ``num_pixels`` pixels. ``num_bs`` candidate
    sites are placed uniformly at random on that grid, and a binary coverage matrix
    of shape ``(num_bs, num_pixels)`` is generated from a log-distance path-loss model
    with Gaussian noise. The search looks for a 0/1 activation vector minimizing
    ``K * n_bts / R**g``, where ``n_bts`` is the number of active stations and ``R``
    is the traffic-weighted coverage in percent.
    """

    def __init__(self, num_bs=50, num_pixels=2500, tabu_size=10, candidate_size=20):
        """Build the traffic map, the candidate sites and the coverage matrix.

        Args:
            num_bs: Number of candidate base-station sites.
            num_pixels: Number of grid pixels; must be a perfect square.
            tabu_size: Maximum number of station indices kept in the tabu list.
            candidate_size: Number of neighbour moves sampled per iteration.

        Raises:
            ValueError: If ``num_pixels`` is not a perfect square, or if the generated
                traffic map does not contain exactly ``num_pixels`` values.
        """
        side_len = int(round(np.sqrt(num_pixels)))
        if side_len * side_len != num_pixels:
            # The grid is square; any other pixel count would misalign the flat
            # pixel index with the traffic map.
            raise ValueError(f"num_pixels must be a perfect square, got {num_pixels}")

        self.num_bs = num_bs
        self.num_pixels = num_pixels
        self.side_len = side_len

        # K: scaling constant for the cost value.
        self.K = 10**8
        # g: weighting exponent. A larger g pushes the search towards higher
        # coverage R; a smaller g pushes it towards fewer base stations (N_BTS).
        self.g = 4.5

        # Tabu search parameters.
        self.tabu_size = tabu_size
        self.tabu_list = deque(maxlen=self.tabu_size)
        self.candidate_size = candidate_size
        # Pixel side length in km (grid distance * grid_resolution = distance in km).
        self.grid_resolution = 0.06

        # Traffic map: kept in 2-D for plotting and flattened (row-major) so that
        # flat pixel index px corresponds to (px // side_len, px % side_len).
        traffic_2d = self._generate_traffic_map_2d()
        self.traffic_map = traffic_2d
        self.traffic_weight_map = traffic_2d.flatten()
        if self.traffic_weight_map.size != num_pixels:
            raise ValueError(
                f"traffic map has {self.traffic_weight_map.size} values, expected {num_pixels}"
            )

        # Random candidate sites. The coverage model compares the first coordinate
        # with the pixel row and the second with the pixel column.
        self.candidate_locations = np.array([
            [random.randint(0, side_len - 1), random.randint(0, side_len - 1)]
            for _ in range(num_bs)
        ])

        self.coverage_matrix = self._generate_log_normal_coverage()
        self.current_signal_counts = np.zeros(self.num_pixels, dtype=int)

    def _generate_traffic_map_2d(self):
        """Return the synthetic multi-hotspot traffic map scaled by 3.0."""
        traffic_2d = generate_synthetic_traffic(
            width=self.side_len,
            height=self.side_len,
            pattern="multi_hotspot",
            rng=np.random.default_rng(42),
        )
        return traffic_2d * 3.0

    def _generate_log_normal_coverage(self):
        """Return a 0/1 int matrix (num_bs, num_pixels) of which pixels each site serves.

        A pixel is covered when it lies within a traffic-dependent maximum radius
        (0.3 km for weight >= 3.0, 0.6 km for weight >= 1.5, otherwise 1.2 km) and
        its path loss ``A + B*log10(d_km) + N`` with ``N ~ Normal(0, sqrt(10))`` is
        below ``P_threshold``. Noise is drawn only for in-range pixels, in pixel
        order, from NumPy's global random state.
        """
        # A: loss at very short distance (depends on carrier frequency, antenna
        # height, etc.; higher frequency -> larger A).
        A = 50
        # B: slope of the loss with log-distance.
        B = 40
        P_threshold = 100
        sigma = np.sqrt(10)

        pixel_ids = np.arange(self.num_pixels)
        rows = pixel_ids // self.side_len
        cols = pixel_ids % self.side_len
        weights = self.traffic_weight_map
        # Busier pixels are only served by nearby stations.
        max_radius = np.where(weights >= 3.0, 0.3, np.where(weights >= 1.5, 0.6, 1.2))

        coverage = np.zeros((self.num_bs, self.num_pixels), dtype=int)
        for i, (bx, by) in enumerate(self.candidate_locations):
            dist_km = np.sqrt((rows - bx) ** 2 + (cols - by) ** 2) * self.grid_resolution
            # Avoid log10(0) at the station's own pixel.
            dist_km[dist_km == 0] = 0.001
            in_range = np.flatnonzero(dist_km <= max_radius)
            noise = np.random.normal(0, sigma, size=in_range.size)
            p_loss = A + B * np.log10(dist_km[in_range]) + noise
            # Communication is possible when the path loss is below the threshold.
            coverage[i, in_range] = (p_loss < P_threshold).astype(int)
        return coverage

    def _weighted_coverage_percent(self, covered_mask):
        """Return traffic-weighted coverage R in percent for a per-pixel boolean mask.

        Returns 0.0 when the total traffic weight is not positive.
        """
        total_weighted_sum = np.sum(self.traffic_weight_map)
        if total_weighted_sum <= 0:
            return 0.0
        weighted_covered_sum = np.sum(covered_mask * self.traffic_weight_map)
        return (weighted_covered_sum / total_weighted_sum) * 100

    def _cost_from_coverage(self, n_bts, R):
        """Return ``K * n_bts / R**g``, or inf when no station is active or R is not > 0."""
        if n_bts == 0 or not R > 0:
            return float('inf')
        return self.K * (n_bts / (R ** self.g))

    def initialize_signal_counts(self, state):
        """Set ``current_signal_counts`` to the number of active stations covering each pixel."""
        active_indices = np.where(state == 1)[0]
        if len(active_indices) > 0:
            self.current_signal_counts = np.sum(self.coverage_matrix[active_indices], axis=0)
        else:
            self.current_signal_counts = np.zeros(self.num_pixels, dtype=int)

    def calculate_cost_fast(self, signal_counts, n_bts):
        """Return the cost from per-pixel signal counts (a pixel is covered if its count > 0)."""
        if n_bts == 0:
            return float('inf')
        return self._cost_from_coverage(n_bts, self._weighted_coverage_percent(signal_counts > 0))

    def calculate_cost(self, state):
        """Return the cost of a 0/1 activation vector, computed from scratch."""
        n_bts = np.sum(state)
        if n_bts == 0:
            return float('inf')
        covered_mask = np.any(self.coverage_matrix[np.where(state == 1)[0]], axis=0)
        return self._cost_from_coverage(n_bts, self._weighted_coverage_percent(covered_mask))

    def get_batch_candidates_optimized(self, current_state):
        """Sample neighbour states of ``current_state`` together with their count deltas.

        The neighbourhood is every single-station flip plus every swap turning one
        active station off and one inactive station on; up to ``candidate_size``
        moves are drawn without replacement.

        Returns:
            List of ``(cand_state, move_indices, delta_counts)`` where ``move_indices``
            is a frozenset of touched station indices and ``delta_counts`` is a fresh
            int array with the change to the per-pixel signal counts.
        """
        ones_indices = np.where(current_state == 1)[0]
        zeros_indices = np.where(current_state == 0)[0]

        all_possible_moves = [('flip', idx) for idx in range(self.num_bs)]
        all_possible_moves += [('swap', i, j) for i in ones_indices for j in zeros_indices]

        sample_k = min(len(all_possible_moves), self.candidate_size)
        sampled_moves = random.sample(all_possible_moves, sample_k)

        candidates = []
        for move in sampled_moves:
            cand_state = current_state.copy()
            if move[0] == 'flip':
                idx = move[1]
                cand_state[idx] = 1 - cand_state[idx]
                row = self.coverage_matrix[idx].astype(int)  # astype returns a copy
                # Switching on adds the station's coverage; switching off removes it.
                delta_counts = row if cand_state[idx] == 1 else -row
                move_indices = frozenset((idx,))
            else:  # 'swap'
                idx_off, idx_on = move[1], move[2]
                cand_state[idx_off] = 0
                cand_state[idx_on] = 1
                delta_counts = (self.coverage_matrix[idx_on].astype(int)
                                - self.coverage_matrix[idx_off])
                move_indices = frozenset((idx_off, idx_on))
            candidates.append((cand_state, move_indices, delta_counts))
        return candidates

    def print_detailed_metrics(self, state, title="Analysis"):
        """Print station count, raw area coverage, weighted coverage and cost for ``state``."""
        n_bts = np.sum(state)
        active_indices = np.where(state == 1)[0]
        if len(active_indices) == 0:
            print(f"[{title}] No Base Stations.")
            return

        covered_mask = np.any(self.coverage_matrix[active_indices], axis=0)
        raw_coverage = (np.sum(covered_mask) / self.num_pixels) * 100
        weighted_R = self._weighted_coverage_percent(covered_mask)
        cost = self._cost_from_coverage(n_bts, weighted_R)

        print(f"[{title}]")
        print(f" - 기지국 수: {n_bts}개")
        print(f" - 단순 면적 커버리지 (시각적): {raw_coverage:.2f}%")
        print(f" - 트래픽 가중 커버리지 (성능): {weighted_R:.2f}%")
        print(f" - Cost: {cost:.4f}")
        print("----------------------------------------------------")

    def run(self, max_iter=2000, initial_active=25):
        """Run tabu search from a random initial activation.

        Args:
            max_iter: Number of iterations.
            initial_active: Stations switched on in the random initial state;
                clamped to ``[0, num_bs]``.

        Returns:
            ``(initial_state, best_state, best_cost, current_cost_history,
            best_cost_history)``; both histories have ``max_iter + 1`` entries,
            index 0 being the initial cost.
        """
        current_state = np.zeros(self.num_bs, dtype=int)
        n_initial = max(0, min(initial_active, self.num_bs))
        current_state[random.sample(range(self.num_bs), n_initial)] = 1
        initial_state = current_state.copy()

        self.initialize_signal_counts(current_state)
        current_cost = self.calculate_cost_fast(self.current_signal_counts, np.sum(current_state))

        best_cost = current_cost
        best_state = current_state.copy()
        current_cost_history = [current_cost]
        best_cost_history = [best_cost]

        self.tabu_list.clear()

        print(f"Start Tabu Search... Initial Cost: {current_cost:.4f}")
        self.print_detailed_metrics(current_state, title="Initial State Analysis")

        for i in range(1, max_iter + 1):
            candidates_batch = self.get_batch_candidates_optimized(current_state)
            tabu = set(self.tabu_list)  # unchanged while this batch is evaluated

            best_candidate_state = None
            best_candidate_cost = float('inf')
            best_move_indices = None
            best_delta = None

            for cand_state, move_indices, delta_counts in candidates_batch:
                cand_cost = self.calculate_cost_fast(
                    self.current_signal_counts + delta_counts, np.sum(cand_state)
                )
                is_tabu = any(idx in tabu for idx in move_indices)
                # Aspiration criterion: a tabu move is allowed only if it beats the global best.
                if is_tabu and not cand_cost < best_cost:
                    continue
                if cand_cost < best_candidate_cost:
                    best_candidate_cost = cand_cost
                    best_candidate_state = cand_state
                    best_move_indices = move_indices
                    best_delta = delta_counts

            if best_candidate_state is not None:
                current_state = best_candidate_state
                current_cost = best_candidate_cost
                self.current_signal_counts += best_delta
                self.tabu_list.extend(best_move_indices)
                if current_cost < best_cost:
                    best_cost = current_cost
                    best_state = current_state.copy()

            current_cost_history.append(current_cost)
            best_cost_history.append(best_cost)

            if i % 500 == 0:
                print(f"Iter {i}: Cost {current_cost:.4f} (Best: {best_cost:.4f})")
                self.print_detailed_metrics(current_state, title=f"Iteration {i} Analysis")

        return initial_state, best_state, best_cost, current_cost_history, best_cost_history

# 시각화
def plot_final_map(optimizer, state, title="Final Optimization Result"):
    """Show the traffic heat map with the active base stations overlaid.

    Args:
        optimizer: Object providing ``traffic_map`` (2-D array) and
            ``candidate_locations`` (one coordinate pair per candidate site).
        state: 0/1 activation vector over the candidate sites.
        title: Figure title.
    """
    plt.figure(figsize=(8, 6))
    # imshow draws array rows along the y-axis and columns along the x-axis.
    plt.imshow(optimizer.traffic_map, origin="lower", cmap='viridis')
    plt.colorbar(label="Traffic Density")

    active_coords = optimizer.candidate_locations[np.where(state == 1)[0]]
    if len(active_coords) > 0:
        # The coverage model compares a site's first coordinate with the pixel ROW
        # and its second with the pixel COLUMN, so the column goes on x and the row on y.
        plt.scatter(active_coords[:, 1], active_coords[:, 0], marker='x', color='red',
                    s=100, linewidths=2, label='Base Station')
        # Only add a legend when a labeled artist exists.
        plt.legend(loc='upper right')

    plt.title(title)
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    # Guarded so importing this module does not launch a 2000-iteration search
    # and open figure windows; a notebook cell still runs it (__name__ == "__main__").
    optimizer = TabuSearchOptimizer(tabu_size=10, candidate_size=20)
    initial_state, final_state, final_cost, cur_hist, best_hist = optimizer.run(max_iter=2000)

    print("\n=== Final Optimization Result ===")
    optimizer.print_detailed_metrics(final_state, title="Final Best State Metric")

    # Cost history: the current cost shows exploration, the best cost shows convergence.
    plt.figure(figsize=(10, 6))
    plt.plot(cur_hist, color='blue', alpha=0.3, linewidth=1, label='Current Cost (Exploration)')
    plt.plot(best_hist, color='red', linewidth=2, label='Best Cost (Convergence)')
    plt.title("Tabu Search Cost History")
    plt.xlabel("Iteration")
    plt.ylabel("Cost")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    plot_final_map(optimizer, final_state, "Final Layout with Base Stations")