# In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt
from collections import deque
from grid import generate_synthetic_traffic  

class RandomWalkOptimizer:
    """Select a subset of candidate base stations with a random-walk search.

    A solution is a 0/1 integer vector of length ``num_bs``; entry ``i`` is 1
    when candidate station ``i`` is switched on. Its cost is
    ``K * n_active / R**g``, where ``R`` is the traffic-weighted coverage in
    percent (see ``calculate_cost_fast``). Neighbours whose cost is not higher
    than the current one are always accepted; worse neighbours are accepted
    with the fixed probability ``p``.
    """

    def __init__(self, num_bs=50, num_pixels=2500, p=0.03):
        """Build the traffic map, random candidate sites and coverage matrix.

        Args:
            num_bs: Number of candidate base-station sites.
            num_pixels: Number of grid pixels. Must be a positive perfect square
                because the map is a ``side_len x side_len`` grid.
            p: Probability of accepting a neighbour with a higher cost.

        Raises:
            ValueError: If ``num_pixels`` is not a positive perfect square.
        """
        if num_pixels <= 0:
            raise ValueError(f"num_pixels must be positive, got {num_pixels}")
        side_len = int(round(np.sqrt(num_pixels)))
        if side_len * side_len != num_pixels:
            # Pixel indices run over num_pixels while the map only has
            # side_len**2 cells, so a non-square count would fail later with an
            # IndexError deep inside coverage generation.
            raise ValueError(f"num_pixels must be a perfect square, got {num_pixels}")

        self.num_bs = num_bs
        self.num_pixels = num_pixels
        self.side_len = side_len

        # Cost-function constants: f = K * n_bts / R**g.
        self.K = 10**8
        self.g = 4.5
        # Probability of accepting a worse neighbour.
        self.p = p

        # Grid spacing: 60 m per pixel, expressed in km.
        self.grid_resolution = 0.06

        # The traffic map is stored in two layouts:
        #   traffic_map        -- the 2D array as generated (used for plotting);
        #                         NOTE(review): expected side_len x side_len, confirm in grid.py
        #   traffic_weight_map -- the same data flattened row-major (used in cost maths)
        traffic_2d = self._generate_traffic_map_2d()
        self.traffic_map = traffic_2d
        self.traffic_weight_map = traffic_2d.flatten()

        # Candidate sites as [row, col] grid indices drawn uniformly at random
        # (the coverage model measures distance as (row - bx, col - by)).
        self.candidate_locations = np.array(
            [
                [random.randint(0, side_len - 1), random.randint(0, side_len - 1)]
                for _ in range(num_bs)
            ],
            dtype=int,
        ).reshape(-1, 2)

        # (num_bs, num_pixels) 0/1 matrix: which pixels each station covers.
        self.coverage_matrix = self._generate_log_normal_coverage()

        # Cache: per-pixel number of active stations covering that pixel for the
        # current state; kept in sync incrementally by run().
        self.current_signal_counts = np.zeros(self.num_pixels, dtype=int)

    def _generate_traffic_map_2d(self):
        """Return the synthetic 2D traffic map from ``grid.generate_synthetic_traffic``.

        The map uses the "multi_hotspot" pattern with a fixed NumPy generator seed
        (42), so it is the same on every run, and it is scaled by 3.0.
        NOTE(review): the scaling presumably brings weights into the range of the
        1.5 / 3.0 radius thresholds used in coverage generation -- confirm.
        """
        traffic_2d = generate_synthetic_traffic(
            width=self.side_len,
            height=self.side_len,
            pattern="multi_hotspot",
            rng=np.random.default_rng(42),
        )
        return traffic_2d * 3.0

    def _generate_log_normal_coverage(self):
        """Return a (num_bs, num_pixels) 0/1 int matrix of pixels covered per station.

        Pixel ``px`` (row ``px // side_len``, column ``px % side_len``) is covered
        by station ``i`` when both hold:

        * its distance from the station is within a traffic-dependent radius:
          0.3 km if weight >= 3.0, 0.6 km if weight >= 1.5, otherwise 1.2 km;
        * the noisy path loss ``A + B * log10(d_km) + N`` is below ``P_threshold``,
          where ``N`` is drawn from ``np.random.normal(0, sqrt(10))`` (NumPy's
          global RNG, standard deviation sqrt(10)).

        Only in-range pixels consume a noise sample.
        """
        coverage = np.zeros((self.num_bs, self.num_pixels), dtype=int)
        A = 50
        B = 40
        P_threshold = 100
        noise_std = np.sqrt(10)

        rows, cols = np.divmod(np.arange(self.num_pixels), self.side_len)
        weights = self.traffic_weight_map
        # Busier pixels get a smaller maximum cell radius (km).
        max_radius = np.where(weights >= 3.0, 0.3, np.where(weights >= 1.5, 0.6, 1.2))

        for i, (bx, by) in enumerate(self.candidate_locations):
            dist_km = np.sqrt((rows - bx) ** 2 + (cols - by) ** 2) * self.grid_resolution
            # The station's own pixel has distance 0; avoid log10(0).
            dist_km[dist_km == 0] = 0.001

            in_range = dist_km <= max_radius
            noise = np.random.normal(0, noise_std, size=int(np.count_nonzero(in_range)))
            p_loss = A + B * np.log10(dist_km[in_range]) + noise
            coverage[i, in_range] = p_loss < P_threshold
        return coverage

    def initialize_signal_counts(self, state):
        """Recompute ``current_signal_counts`` from scratch for ``state``.

        Args:
            state: 0/1 activation vector of length ``num_bs``.
        """
        active_indices = np.where(state == 1)[0]
        if len(active_indices) > 0:
            self.current_signal_counts = np.sum(self.coverage_matrix[active_indices], axis=0)
        else:
            self.current_signal_counts = np.zeros(self.num_pixels, dtype=int)

    def calculate_cost_fast(self, signal_counts, n_bts):
        """Return the cost ``K * n_bts / R**g`` of a configuration.

        ``R`` is the percentage of total traffic weight lying on pixels with a
        positive signal count.

        Args:
            signal_counts: Per-pixel number of active stations covering the pixel
                (length ``num_pixels``); a pixel is covered when its count is > 0.
            n_bts: Number of active stations.

        Returns:
            The cost as a float; ``inf`` when there are no active stations, no
            traffic weight is covered, or the map carries no positive total weight.
        """
        if n_bts == 0:
            return float('inf')

        total_weighted_sum = np.sum(self.traffic_weight_map)
        if total_weighted_sum <= 0:
            # Coverage ratio would be 0/0 (NaN); treat as an unusable configuration.
            return float('inf')

        covered_mask = signal_counts > 0
        weighted_covered_sum = np.sum(covered_mask * self.traffic_weight_map)
        R = (weighted_covered_sum / total_weighted_sum) * 100

        if R <= 0:
            return float('inf')
        return self.K * (n_bts / (R ** self.g))

    def _flip_random_bit(self, state, delta_counts):
        """Toggle one uniformly chosen station in ``state`` in place and add its
        coverage change (+row when switched on, -row when switched off) to
        ``delta_counts`` in place."""
        idx = random.randint(0, self.num_bs - 1)
        state[idx] = 1 - state[idx]
        if state[idx] == 1:
            delta_counts += self.coverage_matrix[idx]
        else:
            delta_counts -= self.coverage_matrix[idx]

    def get_neighbor_with_delta(self, current_state):
        """Propose a neighbour of ``current_state`` and its signal-count delta.

        With probability 0.5 one random station is flipped. Otherwise one active
        station is swapped for one inactive station. When no swap is possible
        (all on or all off), the move falls back to a flip.

        Args:
            current_state: 0/1 activation vector; not modified.

        Returns:
            ``(neighbor_state, delta_counts)`` where ``delta_counts`` is the
            per-pixel change such that ``counts(current) + delta_counts ==
            counts(neighbor)``.
        """
        neighbor_state = current_state.copy()
        delta_counts = np.zeros(self.num_pixels, dtype=int)

        if random.random() >= 0.5:
            # Swap move: keeps the number of active stations constant.
            ones_indices = np.where(neighbor_state == 1)[0]
            zeros_indices = np.where(neighbor_state == 0)[0]
            if len(ones_indices) > 0 and len(zeros_indices) > 0:
                idx_to_off = random.choice(ones_indices)
                idx_to_on = random.choice(zeros_indices)
                neighbor_state[idx_to_off] = 0
                neighbor_state[idx_to_on] = 1
                delta_counts -= self.coverage_matrix[idx_to_off]
                delta_counts += self.coverage_matrix[idx_to_on]
                return neighbor_state, delta_counts

        # Flip move (also the fallback when a swap is impossible).
        self._flip_random_bit(neighbor_state, delta_counts)
        return neighbor_state, delta_counts

    def print_detailed_metrics(self, state, title="Analysis"):
        """Print station count, raw and traffic-weighted coverage, and cost for ``state``."""
        n_bts = np.sum(state)
        active_indices = np.where(state == 1)[0]

        if len(active_indices) == 0:
            print(f"[{title}] No Base Stations.")
            return

        covered_mask = np.any(self.coverage_matrix[active_indices], axis=0)

        # Unweighted share of pixels covered.
        raw_pixel_count = np.sum(covered_mask)
        raw_coverage = (raw_pixel_count / self.num_pixels) * 100

        # Share of traffic weight covered (the R used by the cost function).
        weighted_score = np.sum(covered_mask * self.traffic_weight_map)
        total_score = np.sum(self.traffic_weight_map)
        weighted_R = (weighted_score / total_score) * 100

        cost = self.calculate_cost_fast(covered_mask, n_bts)

        print(f"[{title}]")
        print(f" - 기지국 수: {n_bts}개")
        print(f" - 단순 면적 커버리지 (시각적): {raw_coverage:.2f}%")
        print(f" - 트래픽 가중 커버리지 (성능): {weighted_R:.2f}% ")
        print(f" - Cost: {cost:.4f}")
        print("----------------------------------------------------")

    def run(self, max_iter=10000, initial_active=25):
        """Run the random walk and return its trajectory.

        Args:
            max_iter: Number of neighbour proposals.
            initial_active: Number of stations switched on in the random starting
                solution; clamped to ``num_bs`` so small instances work.

        Returns:
            Tuple ``(initial_state, best_state, best_cost, current_cost_history,
            best_cost_history, count_good, count_bad)``. The histories have
            ``max_iter + 1`` entries (the initial cost first). ``count_good``
            counts accepted non-worse moves; ``count_bad`` counts accepted worse
            moves.
        """
        current_state = np.zeros(self.num_bs, dtype=int)
        n_initial = min(initial_active, self.num_bs)
        initial_indices = random.sample(range(self.num_bs), n_initial)
        current_state[initial_indices] = 1

        initial_state = current_state.copy()

        self.initialize_signal_counts(current_state)
        current_cost = self.calculate_cost_fast(self.current_signal_counts, np.sum(current_state))

        best_cost = current_cost
        best_state = current_state.copy()

        current_cost_history = [current_cost]
        best_cost_history = [best_cost]

        print(f"Start Optimization... Initial Cost: {current_cost:.4f}")
        self.print_detailed_metrics(current_state, title="Initial State")

        count_bad = 0
        count_good = 0

        for i in range(1, max_iter + 1):
            neighbor_state, delta_counts = self.get_neighbor_with_delta(current_state)

            # Incremental update: no need to re-sum the coverage matrix.
            temp_counts = self.current_signal_counts + delta_counts
            neighbor_cost = self.calculate_cost_fast(temp_counts, np.sum(neighbor_state))

            # Random-walk acceptance: always take non-worse moves, take worse ones
            # with probability p (random.random() is only drawn for worse moves).
            if neighbor_cost <= current_cost:
                accepted = True
                count_good += 1
            elif random.random() < self.p:
                accepted = True
                count_bad += 1
            else:
                accepted = False

            if accepted:
                current_state = neighbor_state
                current_cost = neighbor_cost
                self.current_signal_counts = temp_counts

            if current_cost < best_cost:
                best_cost = current_cost
                best_state = current_state.copy()

            current_cost_history.append(current_cost)
            best_cost_history.append(best_cost)

            if i % 1000 == 0:
                print(f"Iter {i}: Cost {current_cost:.4f} (Best: {best_cost:.4f})")

        return initial_state, best_state, best_cost, current_cost_history, best_cost_history, count_good, count_bad


# Visualization
def plot_final_map(optimizer, state, title="Final Optimization Result"):
    """Show the traffic heat-map with the active base stations overlaid.

    Args:
        optimizer: Object providing ``traffic_map`` (2D traffic array) and
            ``candidate_locations`` (one ``[row, col]`` grid index per station).
        state: 0/1 activation vector with one entry per candidate station.
        title: Figure title.
    """
    map_2d = optimizer.traffic_map

    plt.figure(figsize=(8, 6))
    # imshow draws array element [r, c] at x=c, y=r (origin at the bottom-left).
    plt.imshow(map_2d, origin="lower", cmap='viridis')
    plt.colorbar(label="Traffic Density")

    active_indices = np.where(state == 1)[0]
    active_coords = optimizer.candidate_locations[active_indices]

    if len(active_coords) > 0:
        # Each location is [row, col]; the coverage model measures distance as
        # (row - loc[0], col - loc[1]). Plot the column on x and the row on y
        # so the markers sit on the pixels they actually cover.
        xs = active_coords[:, 1]
        ys = active_coords[:, 0]
        plt.scatter(xs, ys, marker='x', color='red', s=100, linewidths=2, label='Base Station')
        # Only add a legend when there is a labelled artist to show.
        plt.legend(loc='upper right')

    plt.title(title)
    plt.tight_layout()
    plt.show()

# Acceptance probability p = 0.03 (reported as optimal in the paper).
optimizer = RandomWalkOptimizer(p=0.03)
(
    initial_state,
    final_state,
    final_cost,
    cur_hist,
    best_hist,
    count_good,
    count_bad,
) = optimizer.run(max_iter=10000)

# Run summary: each argument is printed on its own line.
print(
    "\n=== Final Optimization Result ===",
    f"Final Cost: {final_cost:.4f}",
    f"Initial Base Station : {np.sum(initial_state)}",
    f"Final Base Stations: {np.sum(final_state)}",
    f"Number of Acception of Bad Solution : {count_bad}",
    f"Number of Acception of Good Solution : {count_good}",
    sep="\n",
)

optimizer.print_detailed_metrics(initial_state, title="Initial State Metric")
optimizer.print_detailed_metrics(final_state, title="Final Best State Metric")

# Cost trajectory: faint line for the walk itself, bold line for the running best.
plt.figure(figsize=(10, 6))
plt.plot(cur_hist, color='blue', alpha=0.3, linewidth=1, label='Current Cost (Exploration)')
plt.plot(best_hist, color='red', linewidth=2, label='Best Cost (Convergence)')
plt.gca().set(
    title="Random Walk Optimization Cost History",
    xlabel="Iteration",
    ylabel="Cost",
)
plt.legend()
plt.show()

# Visualize the final station layout.
plot_final_map(optimizer, final_state, "Final Layout with Base Stations")