# Setting

In [31]:
# マリオ関連のimport
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros

# プロット関連のimport
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import display

# 数値関連のimport
import math
import numpy as np
import numpy.random as rnd

# 警告関連のimport
import warnings

# マルチプロセス関連のimport
from concurrent.futures import ThreadPoolExecutor

In [32]:
# Silence the UserWarning messages raised from gym.envs.registration
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module="gym.envs.registration",
)

# Seed NumPy's global RNG so the GA run is reproducible
rnd.seed(1704034800)

# Allow larger animations to be embedded in the notebook output
plt.rcParams['animation.embed_limit'] = 200

# Condition

In [22]:
# Stage to play (gym-super-mario-bros environment id)
STAGE = 'SuperMarioBros-1-1-v0'

# Action set passed to JoypadSpace; every gene of a chromosome is an index
# into this list, and each entry is the button combination for that action.
MOVEMENT = [['right', 'B'], ['right', 'A', 'B']]

In [None]:
# GA constants
MAX_WORKERS = 16       # Maximum number of worker threads for the ThreadPoolExecutor
MAX_GENERATIONS = 50   # Maximum number of generations
NUM_MARIOS = 50        # Population size (individuals per generation)
LEN_CHROMOSOME = 200   # Chromosome length (number of actions per individual)
CROSS_RATE = 0.9       # Share of each new generation produced by crossover; the top
                       # ceil(NUM_MARIOS * (1 - CROSS_RATE)) individuals are kept as elites
MUTATION_RATE = 0.1    # Probability that a given child is mutated at all
FRAME_INTERVAL = 10    # Number of frames each action is held

# Method

In [23]:
def create_generation():
    """Build the initial population of random chromosomes.

    Returns:
        numpy.ndarray of shape (NUM_MARIOS, LEN_CHROMOSOME) whose entries are
        random action indices in range(len(MOVEMENT)).
    """
    population_shape = (NUM_MARIOS, LEN_CHROMOSOME)
    return rnd.randint(0, len(MOVEMENT), size=population_shape)

In [24]:
def cross(parent1, parent2):
    """Single-point crossover of two parent chromosomes.

    A cut point ``k`` is drawn uniformly from 1..n-1 (n = chromosome length),
    so each child always contains genes from both parents:

        child1 = parent1[:k] + parent2[k:]
        child2 = parent2[:k] + parent1[k:]

    Drawing ``k`` from 0..n-1 would allow ``k == 0``. The "children" would then
    be plain swapped copies of the parents, with no recombination.

    Args:
        parent1: 1-D sequence of action indices.
        parent2: 1-D sequence of action indices, same length as ``parent1``.

    Returns:
        Tuple ``(child1, child2)`` of new numpy arrays. The parents are not
        modified, so callers may mutate the children freely.
    """
    length = len(parent1)
    if length < 2:
        # No interior cut point exists; return independent copies.
        return np.array(parent1, copy=True), np.array(parent2, copy=True)

    cut = rnd.randint(1, length)  # upper bound is exclusive -> 1..length-1
    child1 = np.concatenate([parent1[:cut], parent2[cut:]])
    child2 = np.concatenate([parent2[:cut], parent1[cut:]])
    return child1, child2

In [25]:
def mutation(mario):
    """Randomly rewrite a few genes of a chromosome in place.

    With probability MUTATION_RATE, ceil(5% of LEN_CHROMOSOME) distinct
    positions are chosen. Each of them is overwritten with a random action
    index from range(len(MOVEMENT)); the new value may equal the old one.

    Args:
        mario: mutable chromosome (indexable sequence of action indices).
            It is modified in place.

    Returns:
        The same object that was passed in.
    """
    if rnd.random() >= MUTATION_RATE:
        return mario

    gene_count = math.ceil(0.05 * LEN_CHROMOSOME)
    positions = rnd.choice(LEN_CHROMOSOME, gene_count, replace=False)
    for pos in positions:
        mario[pos] = rnd.randint(len(MOVEMENT))
    return mario

In [26]:
def sorts(fitnesses, generation):
    """Order individuals from best to worst fitness.

    Args:
        fitnesses: fitness value per individual.
        generation: individuals aligned index-by-index with ``fitnesses``.

    Returns:
        A ``zip`` iterator. For non-empty input it yields two tuples: the
        fitnesses in descending order, then the individuals in the matching
        order. This allows ``fitnesses, generation = sorts(...)``. Ties keep
        their original relative order because the sort is stable.
    """
    ranked = list(zip(fitnesses, generation))
    # Compare on fitness only: individuals may be numpy arrays, which are not orderable.
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    return zip(*ranked)

In [27]:
def print_fitness(fitnesses, current_generation):
    """Print a one-line fitness summary for a generation.

    Args:
        fitnesses: fitness values already sorted in descending order (as
            returned by ``sorts``). The first entry is therefore the best
            and the last entry is the worst.
        current_generation: generation number shown in the first column.
    """
    # Local names avoid shadowing the built-in max()/min().
    best = fitnesses[0]
    worst = fitnesses[-1]
    # Average over the actual number of entries rather than assuming NUM_MARIOS;
    # int() truncates toward zero.
    average = int(sum(fitnesses) / len(fitnesses))
    print("{:<3}   max: {:<5}   min: {:<5}   avg: {:<5}".format(current_generation, best, worst, average))

In [28]:
def roulette_selection(fitnesses, generation):
    """Pick two distinct parents by fitness-proportional (roulette) selection.

    Args:
        fitnesses: non-negative fitness per individual, aligned
            index-by-index with ``generation``.
        generation: indexable sequence of individuals.

    Returns:
        Tuple ``(parent1, parent2)`` of two different individuals.

    Roulette weights only work when at least two individuals have non-zero
    fitness. An all-zero total would produce NaN probabilities. A single
    non-zero entry makes numpy reject sampling two without replacement. In
    both cases the function falls back to uniform selection instead of
    raising.
    """
    weights = np.asarray(fitnesses, dtype=float)
    if np.count_nonzero(weights) >= 2:
        probabilities = weights / weights.sum()
    else:
        probabilities = None  # uniform
    first, second = rnd.choice(len(weights), 2, p=probabilities, replace=False)
    return generation[first], generation[second]

In [29]:
def evaluate(mario):
    """Play one chromosome in a fresh environment and return its fitness.

    Each action index in ``mario`` is held for FRAME_INTERVAL frames. Play
    stops early as soon as the environment reports ``done``.

    Args:
        mario: sequence of action indices into MOVEMENT.

    Returns:
        The final ``info["x_pos"]``, plus 50000 if ``info["flag_get"]`` is
        set (reaching the goal dominates any distance score). An empty
        chromosome scores 0.
    """
    env = JoypadSpace(gym_super_mario_bros.make(STAGE), MOVEMENT)
    try:
        env.reset()
        # Default result if no step is taken (empty chromosome).
        info = {"x_pos": 0, "flag_get": False}
        done = False
        for action in mario:
            for _ in range(FRAME_INTERVAL):
                _, _, done, info = env.step(action)
                if done:
                    break
            if done:
                break
    finally:
        # Always release the emulator, even if stepping raises.
        env.close()

    evaluation = info["x_pos"]
    if info["flag_get"]:
        evaluation += 50000  # reaching the goal scores high
    return evaluation

# Execution

In [30]:
# Best individual of each generation
super_marios = []

# Number of top-ranked individuals copied unchanged into the next generation.
# It is the same for every generation, so compute it once.
# NOTE(review): CROSS_RATE acts as the share of the population produced by
# crossover here, not as a per-pair crossover probability.
num_elite = math.ceil(NUM_MARIOS * (1 - CROSS_RATE))

# Run the GA. Fitness evaluations run concurrently on a thread pool
# (ThreadPoolExecutor uses threads, not processes).
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Initial generation
    generation = create_generation()

    for current_generation in range(1, MAX_GENERATIONS + 1):
        # Fitness: one environment run per individual, evaluated in parallel
        fitnesses = list(executor.map(evaluate, generation))

        # Sort best-first and print a summary line
        fitnesses, generation = sorts(fitnesses, generation)
        print_fitness(fitnesses, current_generation)

        # Keep this generation's best individual
        super_marios.append(generation[0])

        # Stop once the goal is reached (evaluate adds 50000 on flag_get)
        if fitnesses[0] > 50000:
            break

        # Next generation: elites first, then mutated children of roulette-selected parents
        next_generation = list(generation[:num_elite])
        while len(next_generation) < NUM_MARIOS:
            parent1, parent2 = roulette_selection(fitnesses, generation)
            child1, child2 = cross(parent1, parent2)
            next_generation.extend([mutation(child1), mutation(child2)])

        # Children are added in pairs, so the list can exceed NUM_MARIOS; trim it
        generation = next_generation[:NUM_MARIOS]

1     max: 1674    min: 299     avg: 742  


KeyboardInterrupt: 

# Result

In [None]:
def mario_animation(mario):
    """Replay a chromosome in the emulator and show it as an inline animation.

    Each action index in ``mario`` is held for FRAME_INTERVAL frames, and
    playback stops early when the environment reports ``done``. Every frame
    is captured with ``env.render(mode='rgb_array')`` into a matplotlib
    ArtistAnimation. The animation is displayed as JavaScript HTML
    (``jshtml``).

    Args:
        mario: sequence of action indices into MOVEMENT.
    """
    env = JoypadSpace(gym_super_mario_bros.make(STAGE), MOVEMENT)
    fig, ax = plt.subplots()
    frames = []
    try:
        env.reset()
        done = False
        for action in mario:
            for _ in range(FRAME_INTERVAL):
                _, _, done, _ = env.step(action)
                frames.append([ax.imshow(env.render(mode='rgb_array'))])
                if done:
                    break
            if done:
                break
    finally:
        # Release the emulator even if stepping or rendering fails.
        env.close()

    anime = animation.ArtistAnimation(fig, frames, interval=20, blit=True)
    rc('animation', html='jshtml')
    display(anime)
    # Close the figure so the inline backend does not also render a static copy of it.
    plt.close(fig)

In [None]:
# ゴール到達のアニメーションを表示
# mario_animation(super_marios[-1])