In [2]:
import numpy as np
from math import atan2, degrees, radians, cos, sin
from datetime import datetime, timedelta
import json
import os
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque, namedtuple
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# CUDA 디바이스 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 경험 저장을 위한 named tuple 정의
Experience = namedtuple('Experience', ('state', 'action', 'reward', 'next_state', 'done'))

# Dueling DQN 네트워크 정의
class DuelingDQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DuelingDQN, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        
        # 상태 가치 스트림
        self.value_stream = nn.Linear(64, 1)
        # 액션 이점 스트림
        self.advantage_stream = nn.Linear(64, action_dim)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        
        value = self.value_stream(x)
        advantage = self.advantage_stream(x)
        
        # Q 값 계산: V(s) + (A(s,a) - mean(A(s)))
        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
        return q_values

# SumTree 클래스 정의 (PER을 위한 우선순위 관리)
class SumTree:
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)  # 우선순위를 저장하는 트리
        self.data = np.zeros(capacity, dtype=object)  # 경험 데이터를 저장
        self.data_pointer = 0  # 다음 저장 위치

    def add(self, priority, data):
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer += 1
        if self.data_pointer >= self.capacity:
            self.data_pointer = 0  # 버퍼가 가득 차면 처음부터 덮어씀

    def update(self, tree_idx, priority):
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        # 부모 노드 업데이트
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, s):
        parent_idx = 0
        while True:
            left_child_idx = 2 * parent_idx + 1
            right_child_idx = left_child_idx + 1
            if left_child_idx >= len(self.tree):
                leaf_idx = parent_idx
                break
            if s <= self.tree[left_child_idx]:
                parent_idx = left_child_idx
            else:
                s -= self.tree[left_child_idx]
                parent_idx = right_child_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    def total_priority(self):
        return self.tree[0]  # 루트 노드의 합계

# 항해 환경 클래스 정의 (변경 없음)
class NavigationEnv:
    def __init__(self):
        self.grid = np.load('land_sea_grid_cartopy_downsized.npy')
        self.n_rows, self.n_cols = self.grid.shape
        self.lat_min, self.lat_max = 30, 38
        self.lon_min, self.lon_max = 120, 127
        self.start_pos = self.latlon_to_grid(37.46036, 126.52360)
        self.end_pos = self.latlon_to_grid(30.62828, 122.06400)
        self.step_time_minutes = 12
        self.max_steps = 300
        self.cumulative_time = 0
        self.step_count = 0
        self.tidal_data_dir = r"C:\baramproject\tidal_database"
        self.wind_data_dir = r"C:\baramproject\wind_database_2"
        self.action_space = np.array([0, 45, 90, 135, 180, -135, -90, -45])
        self.grid_directions = [(-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1)]
        self.k_c = 0.1
        self.k_w = 0.005
        self.path = []
        self.reset()

    def latlon_to_grid(self, lat, lon):
        row = int((self.lat_max - lat) / (self.lat_max - self.lat_min) * self.n_rows)
        col = int((lon - self.lon_min) / (self.lon_max - self.lon_min) * self.n_cols)
        return row, col

    def reset(self, start_time=None):
        start_date = datetime(2018, 1, 1, 0, 0)
        end_date = datetime(2018, 12, 29, 0, 0)
        if start_time is None:
            time_delta = (end_date - start_date).total_seconds()
            random_seconds = np.random.randint(0, int(time_delta / 60 / 30) + 1) * 30 * 60
            start_time = start_date + timedelta(seconds=random_seconds)
        
        self.current_pos = self.start_pos
        self.visit_count = {}
        self.prev_action = None
        self.current_time = start_time
        self.cumulative_time = 0
        self.load_tidal_data()
        self.map_tidal_to_grid()
        self.load_wind_data()
        self.map_wind_to_grid()
        self.prev_distance = self.get_distance_to_end()
        self.step_count = 0
        self.path = [self.current_pos]
        return self._get_state()

    def get_relative_position_and_angle(self):
        rel_pos = np.array(self.end_pos) - np.array(self.current_pos)
        distance = np.linalg.norm(rel_pos)
        end_angle = degrees(atan2(rel_pos[1], rel_pos[0])) % 360
        return rel_pos, distance, end_angle

    def get_distance_to_end(self):
        rel_pos = np.array(self.end_pos) - np.array(self.current_pos)
        return np.linalg.norm(rel_pos)

    def angle_to_grid_direction(self, abs_action_angle):
        grid_angles = np.array([0, 45, 90, 135, 180, 225, 270, 315])
        angle_diff = np.abs(grid_angles - abs_action_angle)
        closest_idx = np.argmin(angle_diff)
        return self.grid_directions[closest_idx]

    def load_data(self, data_dir, filename_prefix, time_str):
        data_file = os.path.join(data_dir, f"{filename_prefix}{time_str}.json")
        if not os.path.exists(data_file):
            print(f"Warning: Data file {data_file} not found. Episode will be terminated.")
            return None
        with open(data_file, 'r') as f:
            data = json.load(f)
        return data["result"]["data"]

    def map_data_to_grid(self, data, dir_key, speed_key):
        grid_dir = np.zeros((self.n_rows, self.n_cols))
        grid_speed = np.zeros((self.n_rows, self.n_cols))
        grid_valid = np.zeros((self.n_rows, self.n_cols), dtype=bool)
        if data is None:
            return grid_dir, grid_speed, grid_valid
        positions = [(float(item["pre_lat"]), float(item["pre_lon"])) for item in data]
        directions = [float(item[dir_key]) for item in data]
        speeds = [float(item[speed_key]) for item in data]
        for pos, dir, speed in zip(positions, directions, speeds):
            lat, lon = pos
            row, col = self.latlon_to_grid(lat, lon)
            if 0 <= row < self.n_rows and 0 <= col < self.n_cols:
                grid_dir[row, col] = dir
                grid_speed[row, col] = speed
                grid_valid[row, col] = True
        return grid_dir, grid_speed, grid_valid

    def load_tidal_data(self):
        time_str = self.current_time.strftime("%Y%m%d_%H%M")
        tidal_data = self.load_data(self.tidal_data_dir, "tidal_", time_str)
        self.tidal_data = tidal_data if tidal_data is not None else None

    def map_tidal_to_grid(self):
        if self.tidal_data is not None:
            self.tidal_grid_dir, self.tidal_grid_speed, self.tidal_grid_valid = self.map_data_to_grid(
                self.tidal_data, "current_dir", "current_speed"
            )
        else:
            self.tidal_grid_dir = np.zeros((self.n_rows, self.n_cols))
            self.tidal_grid_speed = np.zeros((self.n_rows, self.n_cols))
            self.tidal_grid_valid = np.zeros((self.n_rows, self.n_cols), dtype=bool)

    def load_wind_data(self):
        time_str = self.current_time.strftime("%Y%m%d_%H%M")
        wind_data = self.load_data(self.wind_data_dir, "wind_", time_str)
        self.wind_data = wind_data if wind_data is not None else None

    def map_wind_to_grid(self):
        if self.wind_data is not None:
            self.wind_grid_dir, self.wind_grid_speed, self.wind_grid_valid = self.map_data_to_grid(
                self.wind_data, "wind_dir", "wind_speed"
            )
        else:
            self.wind_grid_dir = np.zeros((self.n_rows, self.n_cols))
            self.wind_grid_speed = np.zeros((self.n_rows, self.n_cols))
            self.wind_grid_valid = np.zeros((self.n_rows, self.n_cols), dtype=bool)

    def calculate_fuel_consumption(self, abs_action_angle, position):
        row, col = position
        tidal_dir, tidal_speed = 0, 0
        if 0 <= row < self.n_rows and 0 <= col < self.n_cols and self.tidal_grid_valid[row, col]:
            tidal_dir = self.tidal_grid_dir[row, col]
            tidal_speed = self.tidal_grid_speed[row, col]
        wind_dir, wind_speed = 0, 0
        if 0 <= row < self.n_rows and 0 <= col < self.n_cols and self.wind_grid_valid[row, col]:
            wind_dir = self.wind_grid_dir[row, col]
            wind_speed = self.wind_grid_speed[row, col]
        tidal_dir_rad = (90 - tidal_dir) * np.pi / 180
        wind_dir_rad = (90 - wind_dir) * np.pi / 180
        action_angle_rad = (90 - abs_action_angle) * np.pi / 180
        theta_c = action_angle_rad - tidal_dir_rad
        theta_w = action_angle_rad - wind_dir_rad
        f_0 = 1
        tidal_effect = -self.k_c * tidal_speed * cos(theta_c)
        wind_effect = self.k_w * wind_speed * cos(theta_w)
        total_fuel = f_0 + wind_effect + tidal_effect
        return total_fuel

    def step(self, action):
        self.step_count += 1
        rel_pos, distance, end_angle = self.get_relative_position_and_angle()
        rel_action_angle = self.action_space[action]
        abs_action_angle = (end_angle + rel_action_angle) % 360
        turn_penalty = 0
        if hasattr(self, 'previous_direction') and self.previous_direction is not None:
            angle_diff = min((abs_action_angle - self.previous_direction) % 360, 
                             (self.previous_direction - abs_action_angle) % 360)
            turn_penalty = angle_diff * 0.1
        move_dir = self.angle_to_grid_direction(abs_action_angle)
        new_pos = (self.current_pos[0] + move_dir[0], self.current_pos[1] + move_dir[1])
        current_fuel = self.calculate_fuel_consumption(abs_action_angle, self.current_pos)
        next_fuel = self.calculate_fuel_consumption(abs_action_angle, new_pos)
        fuel_reduction = current_fuel - next_fuel
        if (0 <= new_pos[0] < self.n_rows and 0 <= new_pos[1] < self.n_cols and 
            self.grid[new_pos[0], new_pos[1]] == 0):
            self.current_pos = new_pos
            self.path.append(self.current_pos)
        self.previous_direction = abs_action_angle
        self.prev_action = action
        self.cumulative_time += self.step_time_minutes
        if self.cumulative_time >= 30:
            next_time = self.current_time + timedelta(minutes=30)
            end_date = datetime(2018, 12, 31, 23, 30)
            if next_time <= end_date:
                self.current_time = next_time
                self.load_tidal_data()
                self.map_tidal_to_grid()
                self.load_wind_data()
                self.map_wind_to_grid()
            else:
                print("Warning: Time exceeds 2018 range. Keeping previous data.")
            self.cumulative_time -= 30
        state = self._get_state()
        current_distance = self.get_distance_to_end()
        distance_reward = (self.prev_distance - current_distance) * 2.0
        
        self.prev_distance = current_distance
        goal_reward = 100 if tuple(self.current_pos) == self.end_pos else 0
        reward = -current_fuel + fuel_reduction * 1.0 + distance_reward - turn_penalty + goal_reward
        done = tuple(self.current_pos) == self.end_pos or self.step_count >= self.max_steps
        return state, reward, done, {}

    def _get_state(self):
        row, col = self.current_pos
        rel_pos, distance, end_angle = self.get_relative_position_and_angle()
        tidal_dir, tidal_speed = 0, 0
        if hasattr(self, 'tidal_grid_valid') and self.tidal_grid_valid[row, col]:
            tidal_dir = self.tidal_grid_dir[row, col]
            tidal_speed = self.tidal_grid_speed[row, col]
        wind_dir, wind_speed = 0, 0
        if hasattr(self, 'wind_grid_valid') and self.wind_grid_valid[row, col]:
            wind_dir = self.wind_grid_dir[row, col]
            wind_speed = self.wind_grid_speed[row, col]
        return np.array([rel_pos[0], rel_pos[1], distance, tidal_dir, tidal_speed, wind_dir, wind_speed])

# DQN 에이전트 클래스 정의 (PER 통합)
class DQNAgent:
    def __init__(self, state_dim, action_dim):
        self.state_dim = state_dim
        self.action_dim = action_dim
        
        # 기존 하이퍼파라미터
        self.lr = 0.0001
        self.gamma = 0.99
        self.batch_size = 64
        self.buffer_size = 100000
        self.target_update = 1000
        self.epsilon_start = 1.0
        self.epsilon_end = 0.01
        self.epsilon_decay = 10000
        
        # PER 관련 하이퍼파라미터
        self.alpha = 0.6  # 우선순위 가중치 조절
        self.beta_start = 0.4  # 초기 샘플링 편향 보정
        self.beta_end = 1.0  # 최종 샘플링 편향 보정
        
        # 신경망 및 옵티마이저
        self.policy_net = DuelingDQN(state_dim, action_dim).to(device)
        self.target_net = DuelingDQN(state_dim, action_dim).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.lr)
        
        # PER을 위한 SumTree 메모리
        self.memory = SumTree(self.buffer_size)
        self.step_count = 0
        self.max_priority = 1.0  # 초기 최대 우선순위

    def select_action(self, state, epsilon):
        self.step_count += 1
        if random.random() < epsilon:
            return random.randrange(self.action_dim)
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = self.policy_net(state)
        return q_values.argmax().item()

    def store_experience(self, state, action, reward, next_state, done):
        # 새로운 경험을 최대 우선순위로 저장
        experience = Experience(state, action, reward, next_state, done)
        self.memory.add(self.max_priority, experience)

    def sample_batch(self, beta):
        # 우선순위에 비례하여 경험 샘플링
        batch = []
        idxs = []
        priorities = []
        segment = self.memory.total_priority() / self.batch_size
        
        for i in range(self.batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = random.uniform(a, b)
            idx, p, data = self.memory.get_leaf(s)
            batch.append(data)
            idxs.append(idx)
            priorities.append(p)
        
        # 샘플링 확률과 가중치 계산
        sampling_probabilities = np.array(priorities) / self.memory.total_priority()
        is_weight = np.power(self.buffer_size * sampling_probabilities, -beta)
        is_weight /= is_weight.max()  # 정규화
        return batch, idxs, is_weight

    def compute_loss(self, batch, idxs, is_weight, beta):
        # 배치에서 데이터 추출
        states, actions, rewards, next_states, dones = zip(*batch)
        states = np.array(states)
        next_states = np.array(next_states)
        actions = np.array(actions)
        rewards = np.array(rewards)
        dones = np.array(dones)
        
        states = torch.FloatTensor(states).to(device)
        next_states = torch.FloatTensor(next_states).to(device)
        actions = torch.LongTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        dones = torch.FloatTensor(dones).to(device)
        
        # Q 값 계산
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(dim=1)[0]
            targets = rewards + self.gamma * next_q_values * (1 - dones)
        
        # TD 오차 계산 및 우선순위 업데이트
        td_errors = torch.abs(targets - q_values)
        loss = (torch.FloatTensor(is_weight).to(device) * (q_values - targets.detach()) ** 2).mean()
        
        for idx, td_error in zip(idxs, td_errors.detach().cpu().numpy()):
            priority = td_error.item() ** self.alpha
            self.memory.update(idx, priority)
            self.max_priority = max(self.max_priority, priority)
        
        return loss

    def update(self):
        if self.memory.data_pointer < self.batch_size:
            return
        
        # Beta 값을 학습 진행에 따라 증가
        beta = self.beta_start + (self.beta_end - self.beta_start) * min(1.0, self.step_count / 50000)
        batch, idxs, is_weight = self.sample_batch(beta)
        loss = self.compute_loss(batch, idxs, is_weight, beta)
        
        # 모델 업데이트
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        if self.step_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

# 학습 루프 정의 (변경 없음)
def train_dqn(env, agent, max_episodes=30000):
    rewards = []
    path_lengths = []
    epsilon = agent.epsilon_start
    
    image_dir = r"C:\baramproject\trained_model\sibal22\episode_debug_image"
    data_dir = r"C:\baramproject\trained_model\sibal22\episode_debug_data"
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)
    
    for episode in tqdm(range(max_episodes), desc="Training Episodes"):
        state = env.reset()
        total_reward = 0
        path_length = 0
        done = False
        debug_data = []
        
        while not done:
            epsilon = max(agent.epsilon_end, epsilon - (agent.epsilon_start - agent.epsilon_end) / agent.epsilon_decay)
            action = agent.select_action(state, epsilon)
            next_state, reward, done, _ = env.step(action)
            
            q_values = agent.policy_net(torch.FloatTensor(state).unsqueeze(0).to(device)).detach().cpu().numpy().flatten()
            debug_data.append({
                "step": path_length,
                "state": state.tolist(),
                "action": action,
                "reward": reward,
                "next_state": next_state.tolist(),
                "q_values": q_values.tolist(),
                "epsilon": epsilon
            })
            
            agent.store_experience(state, action, reward, next_state, done)
            agent.update()
            
            state = next_state
            total_reward += reward
            path_length += 1
        
        rewards.append(total_reward)
        path_lengths.append(path_length)
        
        if episode % 100 == 0:
            print(f"Episode {episode}, Total Reward: {total_reward}, Path Length: {path_length}")
            
            plt.figure(figsize=(10, 8))
            plt.imshow(env.grid, cmap='gray')
            path_array = np.array(env.path)
            plt.plot(path_array[:, 1], path_array[:, 0], 'r-', label='Path')
            plt.plot(env.start_pos[1], env.start_pos[0], 'go', label='Start')
            plt.plot(env.end_pos[1], env.end_pos[0], 'bo', label='End')
            plt.legend()
            plt.title(f"Episode {episode} Path")
            plt.savefig(os.path.join(image_dir, f"episode_{episode}.png"))
            plt.close()
            
            with open(os.path.join(data_dir, f"episode_{episode}.json"), 'w') as f:
                json.dump(debug_data, f, indent=4)
        
        if episode % 10000 == 0 and episode > 0:
            plt.plot(rewards)
            plt.title("Total Rewards Over Episodes")
            plt.xlabel("Episode")
            plt.ylabel("Reward")
            plt.savefig(os.path.join(image_dir, f"rewards_episode_{episode}.png"))
            plt.close()
    
    torch.save(agent.policy_net.state_dict(), r"C:\baramproject\trained_model\sibal22\navigation_model.pth")
    return rewards, path_lengths

# 메인 실행
if __name__ == "__main__":
    env = NavigationEnv()
    state_dim = 7
    action_dim = len(env.action_space)
    agent = DQNAgent(state_dim, action_dim)
    
    rewards, path_lengths = train_dqn(env, agent)
    
    plt.plot(rewards)
    plt.title("Total Rewards Over Episodes")
    plt.xlabel("Episode")
    plt.ylabel("Reward")
    plt.show()
    
    plt.plot(path_lengths)
    plt.title("Path Lengths Over Episodes")
    plt.xlabel("Episode")
    plt.ylabel("Path Length")
    plt.show()

Using device: cuda


Training Episodes:   0%|          | 0/30000 [00:00<?, ?it/s]

Episode 0, Total Reward: -3066.5816721047627, Path Length: 300
Episode 100, Total Reward: -521.3756744589826, Path Length: 300
Episode 200, Total Reward: 16.15992522806465, Path Length: 300
Episode 300, Total Reward: 132.76008370803322, Path Length: 300
Episode 400, Total Reward: 42.73644971508658, Path Length: 300
Episode 500, Total Reward: -22.95471225333901, Path Length: 300
Episode 600, Total Reward: 25.44360584774398, Path Length: 300
Episode 700, Total Reward: 39.06387280801141, Path Length: 300
Episode 800, Total Reward: 92.96721129149414, Path Length: 300
Episode 900, Total Reward: 108.64027623787109, Path Length: 300
Episode 1000, Total Reward: 45.603643538426866, Path Length: 300
Episode 1100, Total Reward: 39.387202245476146, Path Length: 300
Episode 1200, Total Reward: 61.32925566435068, Path Length: 300
Episode 1300, Total Reward: 26.757159385695957, Path Length: 300
Episode 1400, Total Reward: 20.92972983588207, Path Length: 300
Episode 1500, Total Reward: 67.853721078085

KeyboardInterrupt: 