In [1]:
import torch
x = torch.rand(5, 3)
print(x)

tensor([[0.8842, 0.1227, 0.7196],
        [0.0593, 0.5549, 0.7208],
        [0.8850, 0.5103, 0.3473],
        [0.0457, 0.6139, 0.5317],
        [0.9564, 0.2831, 0.4607]])


In [2]:
import random
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# Q-Learning test: 1 agent start on a 2D 30 by 30 grid, The End
# with 4 actions (up, down, left, right)
# and 1 goal (bottom right corner)
# starting from the top left corner (0, 0)
# with a reward of 1 for reaching the goal
# and a reward of -1 for hitting the walls
# and a reward of -0.01 for each step taken
# and a discount factor of 0.9

class TheEnd:
    def __init__(self, size=30, goal=(29, 29), walls=None):
        self.size = size
        self.goal_xy = goal
        self.walls = set(walls) if walls else set()
        self.player_xy = (0, 0)
        self.goal_reached = False

    def reset(self):
        """ゲームをリセットして、スタート地点 (0,0) に戻す。"""
        self.player_xy = (0, 0)
        self.goal_reached = False
        return self.player_xy 

    def is_valid(self, nx, ny):
        """範囲内かつ壁でなければ True を返す。"""
        if nx < 0 or nx >= self.size or ny < 0 or ny >= self.size:
            return False
        if (nx, ny) in self.walls:
            return False
        return True

    def step(self, action):
        """
        action: 0=up, 1=down, 2=left, 3=right
        Returns: (next_state, reward, done, info)
        """
        x, y = self.player_xy
        if action == 0:   # up
            nx, ny = x, y - 1
        elif action == 1: # down
            nx, ny = x, y + 1
        elif action == 2: # left
            nx, ny = x - 1, y
        elif action == 3: # right
            nx, ny = x + 1, y
        else:
            return self.player_xy, 0.0, False, {}

        reward = -0.01 
        
        # 移動先が有効かチェック（壁や外枠に当たる場合は移動させず罰則）
        if not self.is_valid(nx, ny):
            reward = -1.0
            # プレイヤー座標は更新しない（動かない）
            next_state = self.player_xy
        else:
            # 有効なら座標を更新
            self.player_xy = (nx, ny)
            next_state = (nx, ny)
        
        # ゴール判定
        if next_state == self.goal_xy:
            reward = 1.0
            self.goal_reached = True
            done = True
        else:
            done = False

        return next_state, reward, done, {}


def train_q_learning(env, num_episodes=500, alpha=0.1, gamma=0.9, epsilon=0.1):
    # Q テーブルの初期化
    Q = np.zeros((env.size, env.size, 4))
    all_episode_rewards = []
    
    # 追加: 各エピソードの軌跡を保存するリスト
    episode_paths = []

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        total_reward = 0.0
        
        # 今のエピソードで訪れた (x, y) のリスト
        path = []

        while not done:
            path.append(state)  # 現在位置を記録
            
            x, y = state
            # ε-greedy で行動を選択
            if random.random() < epsilon:
                action = random.randint(0, 3)
            else:
                action = np.argmax(Q[x, y, :])
            
            next_state, reward, done, _ = env.step(action)
            nx, ny = next_state
            total_reward += reward

            # Q 値の更新
            current_q = Q[x, y, action]
            max_next_q = np.max(Q[nx, ny, :])
            new_q = current_q + alpha * (reward + gamma * max_next_q - current_q)
            Q[x, y, action] = new_q

            state = next_state

        # エピソード終了時の最終位置も記録しておく
        path.append(state)
        all_episode_rewards.append(total_reward)
        episode_paths.append(path)

    # 学習後、各エピソードの軌跡を画像として保存
    for episode_index, path in enumerate(episode_paths):
        episode_total_reward = all_episode_rewards[episode_index]
        plt.figure()  # 新しい Figure を作成
        # path の x 座標、y 座標をそれぞれ取り出し
        xs = [p[0] for p in path]
        ys = [p[1] for p in path]

        # 壁の描画（散布図）: 壁が多い場合は描画に時間がかかるので注意
        wall_xs = [w[0] for w in env.walls]
        wall_ys = [w[1] for w in env.walls]
        plt.scatter(wall_xs, wall_ys, marker='s', color='black')  # 黒色で壁を描画

        # エージェントの通ったルートを描画
        plt.plot(xs, ys, marker='o', linestyle='-')
        
        # ゴールをわかりやすく描画
        # (デフォルトの色・スタイルに任せるために明示設定はしない)
        gx, gy = env.goal_xy
        plt.scatter(gx, gy, marker='X', s=200)  # ゴール

        # スタート位置を強調表示（(0,0) なら最初の要素 path[0] など）
        sx, sy = path[0]
        plt.scatter(sx, sy, marker='s', s=100)  # スタート

        # 軸の範囲（少し余裕を持たせる）
        plt.xlim(-1, env.size)
        plt.ylim(-1, env.size)

        plt.title(f"Episode {episode_index} Path, Reward: {episode_total_reward}")
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.grid(True)
        plt.xticks(range(env.size))
        plt.yticks(range(env.size))
        plt.savefig(f"./pic/episode_{episode_index}.png")  # ファイル名を指定
        plt.close()  # Figure を閉じる

    return Q, all_episode_rewards


In [5]:
# delete everything in the pic directory
import os
import shutil
if os.path.exists("./pic"):
    shutil.rmtree("./pic")
os.makedirs("./pic", exist_ok=True)

walls_pig = [(9, 3), (10, 3), (21, 3), (9, 4), (10, 4), (11, 4), (21, 4), (8, 5), (9, 5), (10, 5), (11, 5), (12, 5), (19, 5), (20, 5), (21, 5), (22, 5), (8, 6), (9, 6), (10, 6), (11, 6), (12, 6), (19, 6), (20, 6), (21, 6), (22, 6), (8, 7), (9, 7), (10, 7), (11, 7), (12, 7), (13, 7), (14, 7), (15, 7), (16, 7), (17, 7), (18, 7), (19, 7), (20, 7), (21, 7), (22, 7), (8, 8), (9, 8), (10, 8), (18, 8), (19, 8), (20, 8), (21, 8), (22, 8), (7, 9), (8, 9), (21, 9), (22, 9), (23, 9), (5, 10), (6, 10), (7, 10), (23, 10), (24, 10), (4, 11), (5, 11), (6, 11), (7, 11), (17, 11), (18, 11), (19, 11), (20, 11), (21, 11), (22, 11), (23, 11), (24, 11), (25, 11), (4, 12), (5, 12), (7, 12), (17, 12), (19, 12), (21, 12), (23, 12), (25, 12), (4, 13), (17, 13), (18, 13), (19, 13), (20, 13), (21, 13), (22, 13), (23, 13), (25, 13), (3, 14), (4, 14), (17, 14), (25, 14), (2, 15), (3, 15), (24, 15), (25, 15), (2, 16), (3, 16), (20, 16), (22, 16), (24, 16), (2, 17), (3, 17), (24, 17), (2, 18), (3, 18), (23, 18), (2, 19), (4, 19), (5, 19), (22, 19), (23, 19), (2, 20), (5, 20), (6, 20), (7, 20), (20, 20), (21, 20), (22, 20), (2, 21), (7, 21), (8, 21), (9, 21), (10, 21), (11, 21), (12, 21), (13, 21), (14, 21), (15, 21), (16, 21), (17, 21), (18, 21), (19, 21), (20, 21), (21, 21), (22, 21), (23, 21), (2, 22), (19, 22), (21, 22), (22, 22), (23, 22), (2, 23), (19, 23), (21, 23), (22, 23), (23, 23), (2, 24), (22, 24), (23, 24), (2, 25)]
walls_medium_difficulty = [(11, 0), (13, 0), (20, 0), (21, 0), (22, 0), (23, 0), (11, 1), (13, 1), (22, 1), (9, 2), (10, 2), (11, 2), (12, 2), (13, 2), (0, 3), (1, 3), (2, 3), (3, 3), (9, 3), (11, 3), (16, 3), (17, 3), (18, 3), (19, 3), (26, 3), (3, 4), (4, 4), (9, 4), (11, 4), (16, 4), (19, 4), (20, 4), (21, 4), (26, 4), (3, 5), (4, 5), (9, 5), (10, 5), (11, 5), (15, 5), (16, 5), (21, 5), (26, 5), (0, 6), (1, 6), (2, 6), (3, 6), (15, 6), (21, 6), (26, 6), (15, 7), (25, 7), (26, 7), (15, 8), (24, 8), (25, 8), (5, 9), (6, 9), (7, 9), (8, 9), (15, 9), (16, 9), (3, 10), (4, 10), (5, 10), (8, 10), (12, 10), (13, 10), (14, 10), (15, 10), (16, 10), (3, 11), (8, 11), (11, 11), (12, 11), (15, 11), (16, 11), (3, 12), (8, 12), (11, 12), (16, 12), (24, 12), (25, 12), (26, 12), (27, 12), (3, 13), (8, 13), (11, 13), (18, 13), (23, 13), (24, 13), (27, 13), (28, 13), (3, 14), (8, 14), (11, 14), (18, 14), (22, 14), (23, 14), (28, 14), (29, 14), (3, 15), (8, 15), (11, 15), (12, 15), (18, 15), (22, 15), (29, 15), (12, 16), (18, 16), (22, 16), (23, 16), (13, 17), (18, 17), (23, 17), (17, 18), (18, 18), (23, 18), (17, 19), (23, 19), (24, 19), (6, 21), (7, 21), (8, 21), (9, 21), (3, 22), (4, 22), (5, 22), (6, 22), (9, 22), (10, 22), (2, 23), (3, 23), (10, 23), (18, 23), (19, 23), (20, 23), (21, 23), (22, 23), (10, 24), (18, 24), (22, 24), (23, 24), (9, 25), (10, 25), (17, 25), (23, 25), (5, 26), (8, 26), (9, 26), (17, 26), (23, 26), (5, 27), (6, 27), (7, 27), (17, 27), (23, 27), (17, 28), (18, 28), (22, 28), (23, 28), (18, 29), (21, 29), (22, 29)]
lol_malware = [(24, 1), (2, 2), (3, 2), (15, 2), (22, 2), (23, 2), (24, 2), (25, 2), (26, 2), (2, 3), (3, 3), (4, 3), (5, 3), (6, 3), (7, 3), (9, 3), (10, 3), (11, 3), (12, 3), (13, 3), (14, 3), (15, 3), (16, 3), (17, 3), (22, 3), (2, 4), (4, 4), (5, 4), (7, 4), (9, 4), (10, 4), (13, 4), (14, 4), (17, 4), (21, 4), (22, 4), (23, 4), (2, 5), (4, 5), (5, 5), (7, 5), (10, 5), (13, 5), (17, 5), (21, 5), (22, 5), (23, 5), (24, 5), (25, 5), (2, 6), (4, 6), (7, 6), (8, 6), (11, 6), (12, 6), (17, 6), (18, 6), (19, 6), (20, 6), (22, 6), (25, 6), (2, 7), (8, 7), (22, 7), (24, 7), (25, 7), (2, 8), (8, 8), (22, 8), (23, 8), (24, 8), (10, 11), (11, 11), (12, 11), (13, 11), (14, 11), (15, 11), (17, 11), (18, 11), (20, 11), (21, 11), (22, 11), (23, 11), (24, 11), (2, 12), (6, 12), (7, 12), (10, 12), (12, 12), (15, 12), (16, 12), (17, 12), (20, 12), (2, 13), (3, 13), (6, 13), (7, 13), (10, 13), (12, 13), (13, 13), (16, 13), (17, 13), (20, 13), (3, 14), (6, 14), (7, 14), (8, 14), (10, 14), (13, 14), (14, 14), (15, 14), (16, 14), (20, 14), (3, 15), (4, 15), (6, 15), (8, 15), (10, 15), (14, 15), (15, 15), (20, 15), (4, 16), (6, 16), (8, 16), (9, 16), (10, 16), (20, 16), (4, 17), (5, 17), (6, 17), (9, 17), (10, 17), (20, 17), (21, 17), (11, 20), (12, 20), (13, 20), (14, 20), (18, 20), (19, 20), (20, 20), (21, 20), (22, 20), (3, 21), (4, 21), (5, 21), (6, 21), (7, 21), (8, 21), (9, 21), (10, 21), (11, 21), (14, 21), (15, 21), (18, 21), (22, 21), (23, 21), (3, 22), (9, 22), (15, 22), (16, 22), (18, 22), (3, 23), (10, 23), (15, 23), (16, 23), (18, 23), (3, 24), (10, 24), (11, 24), (12, 24), (13, 24), (14, 24), (15, 24), (18, 24), (3, 25), (11, 25), (12, 25), (18, 25), (3, 26), (18, 26), (19, 26), (3, 27), (19, 27)]
発 = [(22, 0), (23, 0), (24, 0), (8, 1), (9, 1), (14, 1), (15, 1), (21, 1), (22, 1), (6, 2), (7, 2), (8, 2), (9, 2), (10, 2), (14, 2), (15, 2), (16, 2), (20, 2), (21, 2), (6, 3), (10, 3), (16, 3), (17, 3), (18, 3), (19, 3), (20, 3), (4, 4), (5, 4), (10, 4), (11, 4), (17, 4), (18, 4), (19, 4), (11, 5), (16, 5), (17, 5), (20, 5), (11, 6), (16, 6), (20, 6), (21, 6), (11, 7), (15, 7), (16, 7), (21, 7), (6, 8), (7, 8), (8, 8), (9, 8), (10, 8), (11, 8), (16, 8), (17, 8), (18, 8), (19, 8), (20, 8), (21, 8), (6, 9), (6, 10), (7, 10), (8, 10), (14, 10), (15, 10), (19, 10), (20, 10), (21, 10), (22, 10), (23, 10), (8, 11), (9, 11), (10, 11), (14, 11), (15, 11), (19, 11), (9, 12), (10, 12), (15, 12), (16, 12), (19, 12), (23, 12), (24, 12), (25, 12), (26, 12), (5, 13), (6, 13), (10, 13), (11, 13), (16, 13), (18, 13), (19, 13), (22, 13), (23, 13), (6, 14), (7, 14), (10, 14), (11, 14), (16, 14), (17, 14), (18, 14), (21, 14), (22, 14), (3, 15), (4, 15), (8, 15), (9, 15), (10, 15), (15, 15), (16, 15), (20, 15), (21, 15), (4, 16), (5, 16), (20, 16), (5, 17), (6, 17), (7, 17), (19, 17), (20, 17), (21, 17), (7, 18), (8, 18), (19, 18), (21, 18), (22, 18), (23, 18), (7, 19), (8, 19), (9, 19), (18, 19), (19, 19), (23, 19), (24, 19), (6, 20), (9, 20), (10, 20), (18, 20), (24, 20), (4, 21), (5, 21), (6, 21), (10, 21), (18, 21), (19, 21), (25, 21), (10, 22), (11, 22), (17, 22), (19, 22), (25, 22), (11, 23), (16, 23), (17, 23), (19, 23), (20, 23), (25, 23), (26, 23), (10, 24), (16, 24), (20, 24), (21, 24), (26, 24), (5, 25), (6, 25), (7, 25), (8, 25), (9, 25), (10, 25), (15, 25), (16, 25), (21, 25), (26, 25), (4, 26), (5, 26), (6, 26), (15, 26), (21, 26)]
env = TheEnd(size=30, goal=(29, 29), walls=発)
learned_Q, rewards_history = train_q_learning(env, num_episodes=901)
print(rewards_history)

[-595.0499999999508, -27.37000000000064, -70.81000000000083, -38.430000000000206, -77.07000000000217, -147.17000000000334, -52.42999999999828, -71.47000000000008, -80.05000000000128, -43.93000000000053, -113.17000000000796, -22.97000000000051, -71.02999999999975, -90.11000000000503, -36.39000000000044, -23.650000000000613, -40.79000000000031, -47.36999999999915, -104.73000000000812, -30.43000000000091, -24.25000000000036, -57.54999999999879, -19.890000000000327, -36.79000000000067, -15.369999999999816, -49.24999999999882, -56.48999999999958, -40.310000000000684, -17.929999999999975, -25.670000000000712, -27.190000000000516, -38.350000000000364, -13.909999999999854, -57.80999999999912, -32.99000000000083, -42.0899999999996, -29.630000000001065, -29.510000000001384, -31.510000000000865, -17.390000000000047, -8.289999999999939, -39.889999999999866, -38.690000000000154, -18.45000000000024, -31.85000000000074, -31.47000000000127, -30.19000000000058, -31.110000000000817, -20.5700000000002, -

In [6]:
import imageio
import os
from glob import glob

# Create a GIF from the images in the ./pic directory
def create_gif_from_images(image_folder, output_gif_path, duration=100):
    images = []
    # Sort filenames numerically by extracting the episode number
    filenames = sorted(
        glob(os.path.join(image_folder, '*.png')),
        key=lambda x: int(os.path.basename(x).split('_')[1].split('.')[0])
    )
    for filename in filenames:
        images.append(imageio.imread(filename))
    imageio.mimsave(output_gif_path, images, duration=duration)

# Create the GIF
create_gif_from_images('./pic', 'q_learning_path.gif', duration=0.3)
# The GIF will be saved as 'q_learning_path.gif' in the current directory.

  images.append(imageio.imread(filename))
