{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18189d24-d525-4bd1-a837-6af29603aba8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from math import atan2, degrees, radians, cos, sin\n",
    "from datetime import datetime, timedelta\n",
    "import json\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import random\n",
    "from collections import deque, namedtuple\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# CUDA 디바이스 설정\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# 경험 저장을 위한 named tuple 정의\n",
    "Experience = namedtuple('Experience', ('state', 'action', 'reward', 'next_state', 'done'))\n",
    "\n",
    "# Dueling DQN 네트워크 정의\n",
    "class DuelingDQN(nn.Module):\n",
    "    def __init__(self, state_dim, action_dim):\n",
    "        super(DuelingDQN, self).__init__()\n",
    "        self.fc1 = nn.Linear(state_dim, 128)\n",
    "        self.fc2 = nn.Linear(128, 64)\n",
    "        \n",
    "        # 상태 가치 스트림\n",
    "        self.value_stream = nn.Linear(64, 1)\n",
    "        # 액션 이점 스트림\n",
    "        self.advantage_stream = nn.Linear(64, action_dim)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = torch.relu(self.fc1(x))\n",
    "        x = torch.relu(self.fc2(x))\n",
    "        \n",
    "        value = self.value_stream(x)\n",
    "        advantage = self.advantage_stream(x)\n",
    "        \n",
    "        # Q 값 계산: V(s) + (A(s,a) - mean(A(s)))\n",
    "        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))\n",
    "        return q_values\n",
    "\n",
    "# 항해 환경 클래스 정의\n",
    "class NavigationEnv:\n",
    "    def __init__(self):\n",
    "        # 그리드 맵 로드\n",
    "        self.grid = np.load('land_sea_grid_cartopy_downsized.npy')\n",
    "        self.n_rows, self.n_cols = self.grid.shape\n",
    "        \n",
    "        # 경도/위도 범위\n",
    "        self.lat_min, self.lat_max = 30, 38\n",
    "        self.lon_min, self.lon_max = 120, 127\n",
    "        \n",
    "        # 시작점과 종료점 설정\n",
    "        self.start_pos = self.latlon_to_grid(37.46036, 126.52360)\n",
    "        self.end_pos = self.latlon_to_grid(30.62828, 122.06400)\n",
    "        \n",
    "        # 시간 관리\n",
    "        self.step_time_minutes = 12\n",
    "        self.max_steps = 300\n",
    "        self.cumulative_time = 0\n",
    "        self.step_count = 0\n",
    "        \n",
    "        # 조류 및 풍향/풍속 데이터 경로\n",
    "        self.tidal_data_dir = r\"C:\\baramproject\\tidal_database\"\n",
    "        self.wind_data_dir = r\"C:\\baramproject\\wind_database_2\"\n",
    "        \n",
    "        # 액션 공간 정의 (상대 각도: 종료점 방향 기준 8방향)\n",
    "        self.action_space = np.array([0, 45, 90, 135, 180, -135, -90, -45])  # 상대 각도 (도 단위)\n",
    "        \n",
    "        # 그리드 이동 방향 매핑 (상, 우상, 우, 우하, 하, 좌하, 좌, 좌상)\n",
    "        self.grid_directions = [\n",
    "            (-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1)\n",
    "        ]\n",
    "        \n",
    "        # 연비 효율 관련 계수\n",
    "        self.k_c = 0.1  # 조류 영향 계수\n",
    "        self.k_w = 0.005  # 풍속 영향 계수\n",
    "        \n",
    "        # 경로 저장용 리스트\n",
    "        self.path = []\n",
    "        \n",
    "        # 환경 초기화\n",
    "        self.reset()\n",
    "\n",
    "    def latlon_to_grid(self, lat, lon):\n",
    "        \"\"\"위도/경도를 그리드 좌표로 변환\"\"\"\n",
    "        row = int((self.lat_max - lat) / (self.lat_max - self.lat_min) * self.n_rows)\n",
    "        col = int((lon - self.lon_min) / (self.lon_max - self.lon_min) * self.n_cols)\n",
    "        return row, col\n",
    "\n",
    "    def reset(self, start_time=None):\n",
    "        start_date = datetime(2018, 1, 1, 0, 0)\n",
    "        end_date = datetime(2018, 12, 29, 0, 0)  # 12월 29일로 변경하여 여유를 둠\n",
    "        if start_time is None:\n",
    "            time_delta = (end_date - start_date).total_seconds()\n",
    "            random_seconds = np.random.randint(0, int(time_delta / 60 / 30) + 1) * 30 * 60\n",
    "            start_time = start_date + timedelta(seconds=random_seconds)\n",
    "        \n",
    "        self.current_pos = self.start_pos\n",
    "        self.visit_count = {}\n",
    "        self.prev_action = None\n",
    "        self.current_time = start_time\n",
    "        self.cumulative_time = 0\n",
    "        self.load_tidal_data()\n",
    "        self.map_tidal_to_grid()\n",
    "        self.load_wind_data()\n",
    "        self.map_wind_to_grid()\n",
    "        self.prev_distance = self.get_distance_to_end()\n",
    "        self.step_count = 0\n",
    "        self.path = [self.current_pos]\n",
    "        return self._get_state()\n",
    "\n",
    "    def get_relative_position_and_angle(self):\n",
    "        \"\"\"종료점을 기준으로 한 상대 좌표와 각도 계산\"\"\"\n",
    "        rel_pos = np.array(self.end_pos) - np.array(self.current_pos)\n",
    "        distance = np.linalg.norm(rel_pos)\n",
    "        end_angle = degrees(atan2(rel_pos[1], rel_pos[0])) % 360\n",
    "        return rel_pos, distance, end_angle\n",
    "\n",
    "    def get_distance_to_end(self):\n",
    "        \"\"\"종료점까지의 거리 계산\"\"\"\n",
    "        rel_pos = np.array(self.end_pos) - np.array(self.current_pos)\n",
    "        return np.linalg.norm(rel_pos)\n",
    "\n",
    "    def angle_to_grid_direction(self, abs_action_angle):\n",
    "        \"\"\"절대 각도를 그리드 이동 방향으로 매핑\"\"\"\n",
    "        grid_angles = np.array([0, 45, 90, 135, 180, 225, 270, 315])\n",
    "        angle_diff = np.abs(grid_angles - abs_action_angle)\n",
    "        closest_idx = np.argmin(angle_diff)\n",
    "        return self.grid_directions[closest_idx]\n",
    "\n",
    "    def load_data(self, data_dir, filename_prefix, time_str):\n",
    "        data_file = os.path.join(data_dir, f\"{filename_prefix}{time_str}.json\")\n",
    "        if not os.path.exists(data_file):\n",
    "            print(f\"Warning: Data file {data_file} not found. Episode will be terminated.\")\n",
    "            return None\n",
    "        \n",
    "        with open(data_file, 'r') as f:\n",
    "            data = json.load(f)\n",
    "        return data[\"result\"][\"data\"]\n",
    "\n",
    "    def map_data_to_grid(self, data, dir_key, speed_key):\n",
    "        \"\"\"공통 데이터 매핑 함수\"\"\"\n",
    "        grid_dir = np.zeros((self.n_rows, self.n_cols))\n",
    "        grid_speed = np.zeros((self.n_rows, self.n_cols))\n",
    "        grid_valid = np.zeros((self.n_rows, self.n_cols), dtype=bool)\n",
    "        \n",
    "        if data is None:\n",
    "            return grid_dir, grid_speed, grid_valid\n",
    "        \n",
    "        positions = [(float(item[\"pre_lat\"]), float(item[\"pre_lon\"])) for item in data]\n",
    "        directions = [float(item[dir_key]) for item in data]\n",
    "        speeds = [float(item[speed_key]) for item in data]\n",
    "        \n",
    "        for pos, dir, speed in zip(positions, directions, speeds):\n",
    "            lat, lon = pos\n",
    "            row, col = self.latlon_to_grid(lat, lon)\n",
    "            if 0 <= row < self.n_rows and 0 <= col < self.n_cols:\n",
    "                grid_dir[row, col] = dir\n",
    "                grid_speed[row, col] = speed\n",
    "                grid_valid[row, col] = True\n",
    "        \n",
    "        return grid_dir, grid_speed, grid_valid\n",
    "\n",
    "    def load_tidal_data(self):\n",
    "        \"\"\"조류 데이터 로드\"\"\"\n",
    "        time_str = self.current_time.strftime(\"%Y%m%d_%H%M\")\n",
    "        tidal_data = self.load_data(self.tidal_data_dir, \"tidal_\", time_str)\n",
    "        if tidal_data is not None:\n",
    "            self.tidal_data = tidal_data\n",
    "        else:\n",
    "            self.tidal_data = None  # 데이터가 없으면 None으로 설정\n",
    "\n",
    "    def map_tidal_to_grid(self):\n",
    "        \"\"\"조류 데이터를 그리드에 매핑\"\"\"\n",
    "        if self.tidal_data is not None:\n",
    "            self.tidal_grid_dir, self.tidal_grid_speed, self.tidal_grid_valid = self.map_data_to_grid(\n",
    "                self.tidal_data, \"current_dir\", \"current_speed\"\n",
    "            )\n",
    "        else:\n",
    "            self.tidal_grid_dir = np.zeros((self.n_rows, self.n_cols))\n",
    "            self.tidal_grid_speed = np.zeros((self.n_rows, self.n_cols))\n",
    "            self.tidal_grid_valid = np.zeros((self.n_rows, self.n_cols), dtype=bool)\n",
    "\n",
    "    def load_wind_data(self):\n",
    "        \"\"\"풍향/풍속 데이터 로드\"\"\"\n",
    "        time_str = self.current_time.strftime(\"%Y%m%d_%H%M\")\n",
    "        wind_data = self.load_data(self.wind_data_dir, \"wind_\", time_str)\n",
    "        if wind_data is not None:\n",
    "            self.wind_data = wind_data\n",
    "        else:\n",
    "            self.wind_data = None  # 데이터가 없으면 None으로 설정\n",
    "\n",
    "    def map_wind_to_grid(self):\n",
    "        \"\"\"풍향/풍속 데이터를 그리드에 매핑\"\"\"\n",
    "        if self.wind_data is not None:\n",
    "            self.wind_grid_dir, self.wind_grid_speed, self.wind_grid_valid = self.map_data_to_grid(\n",
    "                self.wind_data, \"wind_dir\", \"wind_speed\"\n",
    "            )\n",
    "        else:\n",
    "            self.wind_grid_dir = np.zeros((self.n_rows, self.n_cols))\n",
    "            self.wind_grid_speed = np.zeros((self.n_rows, self.n_cols))\n",
    "            self.wind_grid_valid = np.zeros((self.n_rows, self.n_cols), dtype=bool)\n",
    "\n",
    "    def calculate_fuel_consumption(self, abs_action_angle, position):\n",
    "        \"\"\"연료 소비 계산\"\"\"\n",
    "        row, col = position\n",
    "        \n",
    "        tidal_dir, tidal_speed = 0, 0\n",
    "        if 0 <= row < self.n_rows and 0 <= col < self.n_cols and self.tidal_grid_valid[row, col]:\n",
    "            tidal_dir = self.tidal_grid_dir[row, col]\n",
    "            tidal_speed = self.tidal_grid_speed[row, col]\n",
    "        \n",
    "        wind_dir, wind_speed = 0, 0\n",
    "        if 0 <= row < self.n_rows and 0 <= col < self.n_cols and self.wind_grid_valid[row, col]:\n",
    "            wind_dir = self.wind_grid_dir[row, col]\n",
    "            wind_speed = self.wind_grid_speed[row, col]\n",
    "        \n",
    "        tidal_dir_rad = (90 - tidal_dir) * np.pi / 180\n",
    "        wind_dir_rad = (90 - wind_dir) * np.pi / 180\n",
    "        action_angle_rad = (90 - abs_action_angle) * np.pi / 180\n",
    "        \n",
    "        theta_c = action_angle_rad - tidal_dir_rad\n",
    "        theta_w = action_angle_rad - wind_dir_rad\n",
    "        \n",
    "        f_0 = 1\n",
    "        tidal_effect = -self.k_c * tidal_speed * cos(theta_c)\n",
    "        wind_effect = self.k_w * wind_speed * cos(theta_w)\n",
    "        total_fuel = f_0 + wind_effect + tidal_effect\n",
    "        \n",
    "        return total_fuel\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"환경 스텝 실행\"\"\"\n",
    "        # 스텝 수 증가\n",
    "        self.step_count += 1\n",
    "    \n",
    "        # 상대 위치 및 각도 계산\n",
    "        rel_pos, distance, end_angle = self.get_relative_position_and_angle()\n",
    "        rel_action_angle = self.action_space[action]\n",
    "        abs_action_angle = (end_angle + rel_action_angle) % 360\n",
    "        \n",
    "        # 턴 페널티 계산 (이전 방향이 있을 경우)\n",
    "        turn_penalty = 0\n",
    "        if hasattr(self, 'previous_direction') and self.previous_direction is not None:\n",
    "            angle_diff = min((abs_action_angle - self.previous_direction) % 360, \n",
    "                             (self.previous_direction - abs_action_angle) % 360)\n",
    "            turn_penalty = angle_diff * 0.1\n",
    "        \n",
    "        # 그리드 이동 방향 계산\n",
    "        move_dir = self.angle_to_grid_direction(abs_action_angle)\n",
    "        new_pos = (self.current_pos[0] + move_dir[0], self.current_pos[1] + move_dir[1])\n",
    "        \n",
    "        # 연료 소비 계산\n",
    "        current_fuel = self.calculate_fuel_consumption(abs_action_angle, self.current_pos)\n",
    "        next_fuel = self.calculate_fuel_consumption(abs_action_angle, new_pos)\n",
    "        fuel_reduction = current_fuel - next_fuel\n",
    "        \n",
    "        # 새 위치가 유효한지 확인하고 이동\n",
    "        if (0 <= new_pos[0] < self.n_rows and 0 <= new_pos[1] < self.n_cols and \n",
    "            self.grid[new_pos[0], new_pos[1]] == 0):\n",
    "            self.current_pos = new_pos\n",
    "            self.path.append(self.current_pos)\n",
    "        \n",
    "        # 이전 방향 업데이트\n",
    "        self.previous_direction = abs_action_angle\n",
    "        \n",
    "        # 이전 액션 업데이트\n",
    "        self.prev_action = action\n",
    "        \n",
    "        # 시간 업데이트\n",
    "        self.cumulative_time += self.step_time_minutes\n",
    "        if self.cumulative_time >= 30:\n",
    "            next_time = self.current_time + timedelta(minutes=30)\n",
    "            end_date = datetime(2018, 12, 31, 23, 30)\n",
    "            if next_time <= end_date:\n",
    "                self.current_time = next_time\n",
    "                self.load_tidal_data()\n",
    "                self.map_tidal_to_grid()\n",
    "                self.load_wind_data()\n",
    "                self.map_wind_to_grid()\n",
    "            else:\n",
    "                print(\"Warning: Time exceeds 2018 range. Keeping previous data.\")\n",
    "            self.cumulative_time -= 30\n",
    "        \n",
    "        # 상태, 보상, 종료 여부 계산\n",
    "        state = self._get_state()\n",
    "        current_distance = self.get_distance_to_end()\n",
    "        distance_reward = (self.prev_distance - current_distance) * 2.0\n",
    "        self.prev_distance = current_distance\n",
    "        \n",
    "        goal_reward = 100 if tuple(self.current_pos) == self.end_pos else 0\n",
    "        reward = -current_fuel + fuel_reduction * 1.0 + distance_reward - turn_penalty + goal_reward\n",
    "        \n",
    "        # 종료 조건: 목표 도달 또는 스텝 수 300 초과\n",
    "        done = tuple(self.current_pos) == self.end_pos or self.step_count >= self.max_steps\n",
    "        \n",
    "        return state, reward, done, {}\n",
    "\n",
    "    def _get_state(self):\n",
    "        row, col = self.current_pos\n",
    "        rel_pos, distance, end_angle = self.get_relative_position_and_angle()\n",
    "        \n",
    "        tidal_dir, tidal_speed = 0, 0\n",
    "        if hasattr(self, 'tidal_grid_valid') and self.tidal_grid_valid[row, col]:\n",
    "            tidal_dir = self.tidal_grid_dir[row, col]\n",
    "            tidal_speed = self.tidal_grid_speed[row, col]\n",
    "        \n",
    "        wind_dir, wind_speed = 0, 0\n",
    "        if hasattr(self, 'wind_grid_valid') and self.wind_grid_valid[row, col]:\n",
    "            wind_dir = self.wind_grid_dir[row, col]\n",
    "            wind_speed = self.wind_grid_speed[row, col]\n",
    "        \n",
    "        return np.array([rel_pos[0], rel_pos[1], distance, tidal_dir, tidal_speed, wind_dir, wind_speed])\n",
    "\n",
    "# DQN 에이전트 클래스 정의\n",
    "class DQNAgent:\n",
    "    \"\"\"Double-DQN agent using DuelingDQN networks and a uniformly sampled replay buffer.\n",
    "\n",
    "    Note: alpha, beta_start / beta_end and n_steps are stored as hyperparameters, but\n",
    "    the sampling and loss code in this class does not use priorities, importance\n",
    "    weights or multi-step returns.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, state_dim, action_dim):\n",
    "        \"\"\"Create the policy / target networks, the optimizer and the replay buffer.\"\"\"\n",
    "        self.state_dim = state_dim\n",
    "        self.action_dim = action_dim\n",
    "\n",
    "        self.lr = 0.0001\n",
    "        self.gamma = 0.99\n",
    "        self.batch_size = 64\n",
    "        self.buffer_size = 100000\n",
    "        self.target_update = 1000  # copy policy -> target every this many select_action calls\n",
    "        self.epsilon_start = 1.0\n",
    "        self.epsilon_end = 0.01\n",
    "        self.epsilon_decay = 10000\n",
    "        self.n_steps = 3\n",
    "\n",
    "        self.alpha = 0.6\n",
    "        self.beta_start = 0.4\n",
    "        self.beta_end = 1.0\n",
    "\n",
    "        # Both networks live on the selected device; the target starts as a copy of the policy\n",
    "        self.policy_net = DuelingDQN(state_dim, action_dim).to(device)\n",
    "        self.target_net = DuelingDQN(state_dim, action_dim).to(device)\n",
    "        self.target_net.load_state_dict(self.policy_net.state_dict())\n",
    "        self.target_net.eval()\n",
    "        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.lr)\n",
    "\n",
    "        self.memory = deque(maxlen=self.buffer_size)\n",
    "        self.step_count = 0\n",
    "\n",
    "    def select_action(self, state, epsilon):\n",
    "        \"\"\"Epsilon-greedy action selection; also increments step_count.\"\"\"\n",
    "        self.step_count += 1\n",
    "        if random.random() < epsilon:\n",
    "            return random.randrange(self.action_dim)\n",
    "        state = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
    "        with torch.no_grad():\n",
    "            q_values = self.policy_net(state)\n",
    "        return q_values.argmax().item()\n",
    "\n",
    "    def store_experience(self, state, action, reward, next_state, done):\n",
    "        \"\"\"Append one transition to the replay buffer, paired with a placeholder priority of 1.0.\"\"\"\n",
    "        experience = Experience(state, action, reward, next_state, done)\n",
    "        self.memory.append((experience, 1.0))\n",
    "\n",
    "    def sample_batch(self):\n",
    "        \"\"\"Uniformly sample up to batch_size transitions and return them as a tuple of Experience.\"\"\"\n",
    "        batch = random.sample(self.memory, min(len(self.memory), self.batch_size))\n",
    "        return tuple(experience for experience, _priority in batch)\n",
    "\n",
    "    def compute_loss(self, batch, beta):\n",
    "        \"\"\"Double-DQN mean-squared TD error for a batch of Experience tuples.\n",
    "\n",
    "        Args:\n",
    "            batch: sequence of Experience(state, action, reward, next_state, done).\n",
    "            beta: accepted for interface compatibility; not used by this loss.\n",
    "\n",
    "        Returns:\n",
    "            Scalar loss tensor attached to the policy network graph.\n",
    "        \"\"\"\n",
    "        states, actions, rewards, next_states, dones = zip(*batch)\n",
    "        # Stack with NumPy first: building a tensor directly from a tuple of ndarrays\n",
    "        # converts element by element and is very slow.\n",
    "        states = torch.from_numpy(np.asarray(states, dtype=np.float32)).to(device)\n",
    "        actions = torch.from_numpy(np.asarray(actions, dtype=np.int64)).to(device)\n",
    "        rewards = torch.from_numpy(np.asarray(rewards, dtype=np.float32)).to(device)\n",
    "        next_states = torch.from_numpy(np.asarray(next_states, dtype=np.float32)).to(device)\n",
    "        dones = torch.from_numpy(np.asarray(dones, dtype=np.float32)).to(device)\n",
    "\n",
    "        # squeeze(1) rather than squeeze() so a batch of one keeps shape (1,)\n",
    "        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)\n",
    "\n",
    "        # Targets carry no gradient: the policy net picks the next action, the target net scores it\n",
    "        with torch.no_grad():\n",
    "            next_actions = self.policy_net(next_states).max(1)[1].unsqueeze(1)\n",
    "            target_next_q_values = self.target_net(next_states).gather(1, next_actions).squeeze(1)\n",
    "            targets = rewards + (1 - dones) * self.gamma * target_next_q_values\n",
    "\n",
    "        return nn.MSELoss()(q_values, targets)\n",
    "\n",
    "    def update(self):\n",
    "        \"\"\"Run one optimisation step once the buffer holds at least batch_size transitions.\n",
    "\n",
    "        The target network is synchronised whenever step_count is a multiple of target_update.\n",
    "        \"\"\"\n",
    "        if len(self.memory) < self.batch_size:\n",
    "            return\n",
    "\n",
    "        beta = self.beta_start + (self.beta_end - self.beta_start) * min(1.0, self.step_count / 50000)\n",
    "        batch = self.sample_batch()\n",
    "        loss = self.compute_loss(batch, beta)\n",
    "\n",
    "        self.optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "\n",
    "        if self.step_count % self.target_update == 0:\n",
    "            self.target_net.load_state_dict(self.policy_net.state_dict())\n",
    "\n",
    "# 학습 루프 정의\n",
    "def train_dqn(env, agent, max_episodes=20000):\n",
    "    \"\"\"Train agent on env with a linearly decaying epsilon and periodic debug dumps.\n",
    "\n",
    "    Every 100th episode the travelled path is saved as a PNG and the per-step debug\n",
    "    records as JSON; every 1000th episode (except episode 0) the reward curve is\n",
    "    saved. The policy network weights are written to disk after the last episode.\n",
    "\n",
    "    Args:\n",
    "        env: environment exposing reset(), step(action), grid, path, start_pos, end_pos.\n",
    "        agent: agent exposing select_action, store_experience, update, policy_net and\n",
    "            the epsilon_start / epsilon_end / epsilon_decay attributes.\n",
    "        max_episodes: number of episodes to run.\n",
    "\n",
    "    Returns:\n",
    "        (rewards, path_lengths): per-episode total reward and number of steps.\n",
    "    \"\"\"\n",
    "    rewards = []\n",
    "    path_lengths = []\n",
    "    epsilon = agent.epsilon_start\n",
    "    # Constant per-step decrement of the linear epsilon schedule\n",
    "    epsilon_step = (agent.epsilon_start - agent.epsilon_end) / agent.epsilon_decay\n",
    "\n",
    "    image_dir = r\"C:\\baramproject\\trained_model\\sibal17\\episode_debug_image\"\n",
    "    data_dir = r\"C:\\baramproject\\trained_model\\sibal17\\episode_debug_data\"\n",
    "    os.makedirs(image_dir, exist_ok=True)\n",
    "    os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "    for episode in tqdm(range(max_episodes), desc=\"Training Episodes\"):\n",
    "        # Debug records are only written for every 100th episode, so only collect them then\n",
    "        log_episode = episode % 100 == 0\n",
    "\n",
    "        state = env.reset()\n",
    "        total_reward = 0\n",
    "        path_length = 0\n",
    "        done = False\n",
    "        debug_data = []\n",
    "\n",
    "        while not done:\n",
    "            epsilon = max(agent.epsilon_end, epsilon - epsilon_step)\n",
    "            action = agent.select_action(state, epsilon)\n",
    "            next_state, reward, done, _ = env.step(action)\n",
    "\n",
    "            if log_episode:\n",
    "                # Inference-only forward pass: no autograd graph is needed for logging\n",
    "                with torch.no_grad():\n",
    "                    state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)\n",
    "                    q_values = agent.policy_net(state_tensor).cpu().numpy().flatten()\n",
    "                debug_data.append({\n",
    "                    \"step\": path_length,\n",
    "                    \"state\": state.tolist(),\n",
    "                    \"action\": action,\n",
    "                    \"reward\": reward,\n",
    "                    \"next_state\": next_state.tolist(),\n",
    "                    \"q_values\": q_values.tolist(),\n",
    "                    \"epsilon\": epsilon\n",
    "                })\n",
    "\n",
    "            agent.store_experience(state, action, reward, next_state, done)\n",
    "            agent.update()\n",
    "\n",
    "            state = next_state\n",
    "            total_reward += reward\n",
    "            path_length += 1\n",
    "\n",
    "        rewards.append(total_reward)\n",
    "        path_lengths.append(path_length)\n",
    "\n",
    "        if log_episode:\n",
    "            print(f\"Episode {episode}, Total Reward: {total_reward}, Path Length: {path_length}\")\n",
    "\n",
    "            # Path of the episode drawn over the land/sea grid\n",
    "            plt.figure(figsize=(10, 8))\n",
    "            plt.imshow(env.grid, cmap='gray')\n",
    "            path_array = np.array(env.path)\n",
    "            plt.plot(path_array[:, 1], path_array[:, 0], 'r-', label='Path')\n",
    "            plt.plot(env.start_pos[1], env.start_pos[0], 'go', label='Start')\n",
    "            plt.plot(env.end_pos[1], env.end_pos[0], 'bo', label='End')\n",
    "            plt.legend()\n",
    "            plt.title(f\"Episode {episode} Path\")\n",
    "            plt.savefig(os.path.join(image_dir, f\"episode_{episode}.png\"))\n",
    "            plt.close()\n",
    "\n",
    "            with open(os.path.join(data_dir, f\"episode_{episode}.json\"), 'w') as f:\n",
    "                json.dump(debug_data, f, indent=4)\n",
    "\n",
    "        if episode % 1000 == 0 and episode > 0:\n",
    "            plt.plot(rewards)\n",
    "            plt.title(\"Total Rewards Over Episodes\")\n",
    "            plt.xlabel(\"Episode\")\n",
    "            plt.ylabel(\"Reward\")\n",
    "            plt.savefig(os.path.join(image_dir, f\"rewards_episode_{episode}.png\"))\n",
    "            plt.close()\n",
    "\n",
    "    torch.save(agent.policy_net.state_dict(), r\"C:\\baramproject\\trained_model\\sibal17\\navigation_model.pth\")\n",
    "    return rewards, path_lengths\n",
    "\n",
    "# 메인 실행\n",
    "if __name__ == \"__main__\":\n",
    "    # Build the environment and an agent sized to its state / action spaces\n",
    "    env = NavigationEnv()\n",
    "    state_dim = 7\n",
    "    action_dim = len(env.action_space)\n",
    "    agent = DQNAgent(state_dim, action_dim)\n",
    "\n",
    "    rewards, path_lengths = train_dqn(env, agent)\n",
    "\n",
    "    # Show the reward curve, then the path-length curve\n",
    "    for series, title, ylabel in (\n",
    "        (rewards, \"Total Rewards Over Episodes\", \"Reward\"),\n",
    "        (path_lengths, \"Path Lengths Over Episodes\", \"Path Length\"),\n",
    "    ):\n",
    "        plt.plot(series)\n",
    "        plt.title(title)\n",
    "        plt.xlabel(\"Episode\")\n",
    "        plt.ylabel(ylabel)\n",
    "        plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
