# Setting


In [None]:
# マリオ関連のimport
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros

# プロット関連のimport
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import display

# 数値関連のimport
import math
import numpy as np
import numpy.random as rnd

# 警告関連のimport
import warnings

# マルチプロセス関連のimport
from concurrent.futures import ThreadPoolExecutor

In [None]:
# Raise the size limit for animations embedded in the notebook output
plt.rcParams['animation.embed_limit'] = 200

# Seed NumPy's global random generator (Unix timestamp of 2024-01-01 00:00 JST)
rnd.seed(1704034800)

# Hide UserWarning messages raised from the gym.envs.registration module
warnings.filterwarnings("ignore", category=UserWarning, module="gym.envs.registration")

# Condition


In [None]:
# Stage to play (environment id passed to gym_super_mario_bros.make)
STAGE = 'SuperMarioBros-1-1-v0'

# Action patterns: each entry is one button combination handed to JoypadSpace;
# a gene in a chromosome is an index into this list
MOVEMENT = [
    ['right', 'B'],
    ['right', 'A', 'B']
]

In [None]:
# Constants
MAX_WORKERS = 10            # Maximum number of workers (threads of the ThreadPoolExecutor)
MAX_GENERATIONS = 50        # Maximum number of generations
NUM_MARIOS = 50             # Population size (number of individuals)
LEN_CHROMOSOME = 500        # Chromosome length (number of actions per individual)
CROSS_RATE = 0.9            # Crossover rate (the remainder is carried over as elites)
MUTATION_RATE = 0.1         # Probability that a child is mutated
MUTATION_POINTS_RATE = 0.1  # Fraction of genes rewritten when a mutation happens
FRAME_INTERVAL = 4          # Number of frames each action is held for
STOP_FRAMES = 10            # Number of consecutive identical positions treated as stagnation

# Method


In [None]:
def create_generation():
    """Build the initial generation.

    Returns:
        Integer array of shape (NUM_MARIOS, LEN_CHROMOSOME) whose entries are
        random indices into MOVEMENT.
    """
    num_actions = len(MOVEMENT)
    population_shape = (NUM_MARIOS, LEN_CHROMOSOME)
    return rnd.randint(num_actions, size=population_shape)

In [None]:
def cross(parent1, parent2):
    """Perform single-point crossover on two parent chromosomes.

    The cut point is drawn from 1 .. len(parent1) - 1 so that each child always
    receives at least one gene from each parent (a cut at 0 would merely clone
    the parents). The length is taken from the parents themselves.

    Args:
        parent1: First parent chromosome (1-D sequence of genes).
        parent2: Second parent chromosome, same length as ``parent1``.

    Returns:
        Tuple ``(child1, child2)`` of new arrays; the parents are not modified.
    """
    length = len(parent1)
    if length < 2:
        # Too short to split: return independent copies so that callers may
        # modify the children in place without touching the parents.
        return np.array(parent1, copy=True), np.array(parent2, copy=True)

    cross_point = rnd.randint(1, length)
    child1 = np.concatenate([parent1[:cross_point], parent2[cross_point:]])
    child2 = np.concatenate([parent2[:cross_point], parent1[cross_point:]])
    return child1, child2

In [None]:
def mutation(mario):
    """Randomly mutate a chromosome in place and return it.

    With probability MUTATION_RATE, ceil(MUTATION_POINTS_RATE * LEN_CHROMOSOME)
    distinct positions are each overwritten with a freshly drawn index into
    MOVEMENT (the new value may equal the old one). Otherwise the chromosome
    is left untouched.

    Args:
        mario: Chromosome to mutate; modified in place.

    Returns:
        The same ``mario`` object that was passed in.
    """
    # Most individuals skip mutation entirely.
    if not rnd.random() < MUTATION_RATE:
        return mario

    point_count = math.ceil(MUTATION_POINTS_RATE * LEN_CHROMOSOME)
    num_actions = len(MOVEMENT)
    for position in rnd.choice(LEN_CHROMOSOME, point_count, replace=False):
        mario[position] = rnd.randint(num_actions)
    return mario

In [None]:
def sorts(fitnesses, generation, images):
    """Reorder three parallel sequences by descending fitness.

    Args:
        fitnesses: Fitness value of each individual.
        generation: Chromosomes, aligned with ``fitnesses``.
        images: Screenshots, aligned with ``fitnesses``.

    Returns:
        An iterator yielding three tuples (fitnesses, generation, images), each
        sorted so the highest fitness comes first. Ties keep their input order.
    """
    ranked = list(zip(fitnesses, generation, images))
    # Only the fitness takes part in the comparison.
    ranked.sort(key=lambda entry: entry[0], reverse=True)
    return zip(*ranked)

In [None]:
def print_fitness(fitnesses, current_generation):
    """Print and return the fitness summary of one generation.

    Args:
        fitnesses: Non-empty sequence of fitness values sorted in descending
            order, so the first entry is the maximum and the last the minimum.
        current_generation: Generation number shown at the start of the line.

    Returns:
        Tuple ``(max, min, avg)`` where ``avg`` is the mean truncated to int.
    """
    # Local names deliberately avoid shadowing the max()/min() builtins, and
    # the sequence's own length is used rather than a global population size.
    best = fitnesses[0]
    worst = fitnesses[-1]
    average = int(sum(fitnesses) / len(fitnesses))
    print("{:<3}   max: {:<5}   min: {:<5}   avg: {:<5}".format(current_generation, best, worst, average))
    return best, worst, average

In [None]:
def roulette_selection(fitnesses, generation):
    """Pick two distinct parents by fitness-proportionate (roulette) selection.

    Args:
        fitnesses: Non-negative fitness of each individual, aligned with
            ``generation``.
        generation: Indexable population (at least two individuals).

    Returns:
        Tuple ``(parent1, parent2)`` of two different members of ``generation``.
    """
    weights = np.asarray(fitnesses, dtype=float)
    total = weights.sum()

    if total > 0 and np.count_nonzero(weights) >= 2:
        selection_rates = weights / total
    else:
        # Proportional sampling of two distinct parents is impossible when the
        # fitness sum is zero or fewer than two individuals have a non-zero
        # fitness, so fall back to uniform sampling (p=None).
        selection_rates = None

    parent_indexes = rnd.choice(len(generation), 2, p=selection_rates, replace=False)
    return generation[parent_indexes[0]], generation[parent_indexes[1]]

In [None]:

def evaluate(mario):
    """Play one episode with the given chromosome and measure its fitness.

    Each gene is applied for FRAME_INTERVAL frames. The episode stops when the
    environment reports ``done``, when the chromosome is exhausted, or when the
    x position has not changed over the last STOP_FRAMES actions.

    Args:
        mario: Non-empty sequence of action indices into MOVEMENT.

    Returns:
        Tuple ``(fitness, image, flag_get)``: the final ``info["x_pos"]``, an
        RGB screenshot of the final state, and the final ``info["flag_get"]``.
    """
    # Environment setup
    env = gym_super_mario_bros.make(STAGE)
    try:
        env = JoypadSpace(env, MOVEMENT)
        env.reset()

        # Play the episode
        positions = []
        done = False
        for action in mario:
            # Hold the same action for FRAME_INTERVAL frames
            for _ in range(FRAME_INTERVAL):
                _, _, done, info = env.step(action)
                if done:
                    break

            if done:
                break

            # Record the position reached after this action
            positions.append(info["x_pos"])

            # Stop early when the position has stagnated
            if len(positions) >= STOP_FRAMES and len(set(positions[-STOP_FRAMES:])) == 1:
                break

        # Fitness is the distance travelled
        fitness = info["x_pos"]

        # Copy the frame before the environment is closed.
        # NOTE(review): render() may hand back a buffer owned by the emulator;
        # the copy keeps the image valid after close() - confirm against nes_py.
        image = np.array(env.render(mode='rgb_array'))

        return fitness, image, info["flag_get"]
    finally:
        # Release the emulator; one environment is created per evaluation, so
        # leaving them open accumulates over every individual and generation.
        env.close()

# Execution


In [None]:
# Best Mario (top-ranked chromosome) of every generation
super_marios = []

# Fitness summary (max, min, avg) of every generation
generations_fitnesses = []

# Run the GA; evaluations are fanned out over a thread pool
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Initial generation
    generation = create_generation()

    # Maximum fitness of the previous generation
    previous_max_fitness = 0

    for current_generation in range(1, MAX_GENERATIONS + 1):
        # Evaluate every individual; executor.map yields results in input order,
        # so evaluations[i] belongs to generation[i]
        evaluations = list(executor.map(evaluate, generation))
        fitnesses = [evaluation[0] for evaluation in evaluations]
        images = [evaluation[1] for evaluation in evaluations]

        # Sort by descending fitness and show the best screenshot when the
        # maximum fitness improved over the previous generation
        fitnesses, generation, images = sorts(fitnesses, generation, images)
        if previous_max_fitness < fitnesses[0]:
            plt.figure(figsize=(2, 2))
            plt.imshow(images[0])
            plt.show()

        generation_fitnesses = print_fitness(fitnesses, current_generation)

        # Record this generation's summary and its top-ranked individual
        generations_fitnesses.append(generation_fitnesses)
        super_marios.append(generation[0])

        # Stop as soon as any individual reports flag_get (goal reached)
        if any(evaluation[2] for evaluation in evaluations):
            break

        # Build the next generation: the top-ranked individuals are carried over
        # unchanged, the rest are children of roulette-selected parents
        num_elite = math.ceil(NUM_MARIOS * (1 - CROSS_RATE))
        next_generation = list(generation[:num_elite])
        while len(next_generation) < NUM_MARIOS:
            parent1, parent2 = roulette_selection(fitnesses, generation)
            child1, child2 = cross(parent1, parent2)
            next_generation.extend([mutation(child1), mutation(child2)])

        previous_max_fitness = fitnesses[0]
        # Children are added in pairs, so trim any overshoot
        generation = next_generation[:NUM_MARIOS]

# Result


In [None]:
def mario_animation(mario):
    """Replay a chromosome and display the run as an inline animation.

    Args:
        mario: Sequence of action indices into MOVEMENT; each one is held for
            FRAME_INTERVAL frames and becomes one animation frame.
    """
    # Environment setup
    env = gym_super_mario_bros.make(STAGE)
    env = JoypadSpace(env, MOVEMENT)
    env.reset()

    def count_frames():
        """Dry-run the chromosome and return how many actions to animate.

        Returns the index of the action during which the episode ended, or
        len(mario) when the episode never ends (e.g. Mario is merely stuck),
        so the result is always usable with range(). The environment is reset
        afterwards in both cases so the animation starts from the beginning.
        """
        frames = len(mario)
        finished = False
        for count, action in enumerate(mario):
            for _ in range(FRAME_INTERVAL):
                _, _, done, _ = env.step(action)
                if done:
                    finished = True
                    break

            if finished:
                frames = count
                break

        env.reset()
        return frames

    # Initialisation callback: nothing to do, the first image is already drawn
    def init():
        pass

    # Advance the game by one action and refresh the displayed image
    def update(frame):
        for _ in range(FRAME_INTERVAL):
            _, _, done, _ = env.step(mario[frame])
            image.set_data(env.render(mode='rgb_array'))
            if done:
                env.reset()

    # Prepare the figure with the initial screen
    fig, ax = plt.subplots()
    image = ax.imshow(env.render(mode='rgb_array'))

    # Build the animation; the interval assumes 20 emulator frames per second
    # NOTE(review): the original comment states the game runs at 20 FPS - confirm
    anime = animation.FuncAnimation(fig, update, init_func=init, frames=range(count_frames()), interval=1000 * FRAME_INTERVAL / 20)

    # Show the animation as interactive HTML
    rc('animation', html='jshtml')
    display(anime)

In [None]:
# Animate the top-ranked Mario of the last generation that was evaluated
mario_animation(super_marios[-1])

In [None]:
def plot_fitnesses(fitnesses):
    """Plot the max, min and average fitness against the generation number.

    Args:
        fitnesses: Sequence of (max, min, avg) tuples, one per generation.
    """
    # Transpose the per-generation tuples into one series per statistic
    best_series, worst_series, mean_series = zip(*fitnesses)

    # Generations are numbered from 1 on the horizontal axis
    generations = list(range(1, len(best_series) + 1))

    # Draw one line per statistic
    labelled_series = (
        ('max', best_series),
        ('min', worst_series),
        ('avg', mean_series),
    )
    for label, series in labelled_series:
        plt.plot(generations, series, label=label)

    # Legend and axis labels
    plt.legend()
    plt.xlabel('Generation')
    plt.ylabel('Fitness')

    # Show the figure
    plt.show()

In [None]:
# Plot how the max/min/avg fitness changed over the generations
plot_fitnesses(generations_fitnesses)