## 1. Install Libraries

In [1]:
!pip install stable-baselines3[extra] pygame


Collecting stable-baselines3[extra]
  Downloading stable_baselines3-2.6.0-py3-none-any.whl.metadata (4.8 kB)
Collecting gymnasium<1.2.0,>=0.29.1 (from stable-baselines3[extra])
  Downloading gymnasium-1.1.1-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)


# Intro to Reinforcement Learning

CartPole simulation을 위한 예시


In [8]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# -------------------------------
# CartPole 모델 (더 촘촘한 시뮬레이션: dt = 0.00001)
# -------------------------------
class CartPole:
    """Frictionless cart-pole (pole hinged on a sliding cart), stepped with explicit Euler.

    State vector ``[x, x_dot, theta, theta_dot]``: cart position and velocity,
    pole angle (radians, 0 = upright) and pole angular velocity. Units are
    presumably SI (``gravity`` defaults to 9.8) -- confirm before rescaling.

    The dynamics in :meth:`step` are the standard cart-pole equations (the
    same form as Gymnasium's ``CartPoleEnv``), where the ``4/3`` term models
    the pole as a uniform rod of half-length ``l``.

    All constructor arguments are keyword-only and default to the values this
    notebook originally hard-coded, so ``CartPole()`` behaves as before.

    Args:
        gravity: Gravitational acceleration, stored as ``g``.
        cart_mass: Cart mass, stored as ``mc``.
        pole_mass: Pole mass, stored as ``mp``.
        half_pole_length: Half of the pole length, stored as ``l``.
        dt: Integration time step in seconds; must be positive.
        initial_angle_deg: Pole angle in degrees that :meth:`reset` starts from.

    Raises:
        ValueError: If ``dt`` is not a positive number.
    """

    def __init__(
        self,
        *,
        gravity=9.8,
        cart_mass=0.01,
        pole_mass=0.1,
        half_pole_length=0.5,
        dt=0.00001,
        initial_angle_deg=10.0,
    ):
        # `not dt > 0` also rejects NaN.
        if not dt > 0:
            raise ValueError(f"dt must be positive, got {dt!r}")
        self.g = gravity
        self.mc = cart_mass
        self.mp = pole_mass
        self.l = half_pole_length  # half pole length
        # Very fine step: explicit Euler accumulates error quickly at large dt.
        self.dt = dt
        self.initial_angle_deg = initial_angle_deg
        # Derived constants reused on every step.
        self.total_mass = self.mc + self.mp
        self.pml = self.mp * self.l
        self.reset()

    def reset(self):
        """Place the cart at rest at x = 0 with the pole at ``initial_angle_deg``.

        Returns:
            The new state array (the same object stored in ``self.state``).
        """
        self.state = np.array([0.0, 0.0, np.deg2rad(self.initial_angle_deg), 0.0])
        return self.state

    def step(self, force=0.0):
        """Advance the state by one time step ``dt`` under a horizontal force.

        Args:
            force: Force applied to the cart along +x.

        Returns:
            The new state array (also stored in ``self.state``).
        """
        x, x_dot, theta, theta_dot = self.state
        costheta = np.cos(theta)
        sintheta = np.sin(theta)
        temp = (force + self.pml * theta_dot**2 * sintheta) / self.total_mass
        thetaacc = (self.g * sintheta - costheta * temp) / (
            self.l * (4.0 / 3.0 - self.mp * costheta**2 / self.total_mass)
        )
        xacc = temp - self.pml * thetaacc * costheta / self.total_mass

        # Explicit Euler: positions advance with the velocities from the start
        # of the step, before the velocities themselves are updated.
        x += self.dt * x_dot
        x_dot += self.dt * xacc
        theta += self.dt * theta_dot
        theta_dot += self.dt * thetaacc

        self.state = np.array([x, x_dot, theta, theta_dot])
        return self.state

# -------------------------------
# 시뮬레이션 및 애니메이션
# -------------------------------
def simulate(cartpole, sim_time=6.0, frame_interval=0.02):
    """Integrate ``cartpole`` with zero applied force and sample states for animation.

    Args:
        cartpole: Object exposing ``dt`` (integration step in seconds), a
            ``step(force)`` method advancing it by one ``dt``, and a ``state``
            array. It is advanced in place from its current state (not reset).
        sim_time: Total simulated time in seconds.
        frame_interval: Time between recorded frames in seconds. Values
            smaller than ``dt`` record every step.

    Returns:
        ``np.ndarray`` stacking the recorded states along axis 0, one row per
        frame. Frame ``k`` is the state after ``k * frame_steps`` steps, so
        frame 0 is the state ``cartpole`` had on entry.
    """
    dt = cartpole.dt
    # round() rather than int(): float ratios such as 0.3 / 0.1 evaluate to
    # 2.9999999999999996, and int() would silently drop a step.
    total_steps = round(sim_time / dt)
    # max(1, ...) guards against frame_interval < dt, which would make the
    # modulo below divide by zero.
    frame_steps = max(1, round(frame_interval / dt))
    sampled_states = []

    for step in range(total_steps):
        # Record before stepping so the first frame is the initial state (t = 0).
        if step % frame_steps == 0:
            sampled_states.append(cartpole.state.copy())
        cartpole.step(0.0)

    return np.array(sampled_states)

def animate_cartpole(states, pole_length=1.0, interval=20):
    """Build a matplotlib animation of a cart-pole trajectory.

    Args:
        states: Array-like indexed as ``states[i, 0]`` (cart position x) and
            ``states[i, 2]`` (pole angle theta in radians, 0 = upright), one
            row per frame -- e.g. the output of ``simulate``.
        pole_length: Drawn pole length in plot units. The default 1.0 equals
            twice the default ``CartPole`` half-length ``l = 0.5``.
        interval: Delay between displayed frames in milliseconds. The default
            of 20 ms matches ``simulate``'s default ``frame_interval=0.02``,
            so playback runs at roughly real time; pass
            ``frame_interval * 1000`` when sampling at another rate.

    Returns:
        The ``matplotlib.animation.FuncAnimation``. The figure is closed before
        returning so an inline backend does not also show a static copy;
        render it with ``ani.to_jshtml()`` or ``ani.save(...)``.
    """
    states = np.asarray(states)

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.set_xlim(-2.5, 2.5)
    ax.set_ylim(-0.5, 1.5)
    ax.set_aspect('equal')
    ax.set_title("CartPole Pendulum Exercise")

    cart_width = 0.4
    cart_height = 0.2
    wheel_radius = 0.05

    cart = plt.Rectangle((0, 0), cart_width, cart_height, color='black')
    pole_line, = ax.plot([], [], lw=4, color='blue')
    wheel_left = plt.Circle((0, 0), wheel_radius, color='gray')
    wheel_right = plt.Circle((0, 0), wheel_radius, color='gray')

    ax.add_patch(cart)
    ax.add_patch(wheel_left)
    ax.add_patch(wheel_right)

    def init():
        # Draw the cart centred at x = 0 with no pole until the first frame.
        cart.set_xy((-cart_width/2, 0))
        wheel_left.center = (-cart_width/4, 0)
        wheel_right.center = (cart_width/4, 0)
        pole_line.set_data([], [])
        return cart, wheel_left, wheel_right, pole_line

    def update(i):
        x = states[i, 0]
        theta = states[i, 2]

        cart.set_xy((x - cart_width/2, 0))
        wheel_left.center = (x - cart_width/4, 0)
        wheel_right.center = (x + cart_width/4, 0)

        # Pole is hinged at the top-centre of the cart; theta is measured from
        # vertical, so sin gives the horizontal and cos the vertical offset.
        pole_x = [x, x + pole_length * np.sin(theta)]
        pole_y = [cart_height, cart_height + pole_length * np.cos(theta)]
        pole_line.set_data(pole_x, pole_y)

        # blit=True requires returning every artist that changed.
        return cart, wheel_left, wheel_right, pole_line

    ani = animation.FuncAnimation(
        fig, update, frames=len(states),
        init_func=init, blit=True, interval=interval,
    )
    plt.close()
    return ani

# -------------------------------
# 실행
# -------------------------------
from IPython.display import HTML

# Build the model, integrate 6 s of unforced motion sampled every 20 ms,
# and turn the sampled trajectory into an animation.
cp = CartPole()
states = simulate(cp, sim_time=6.0, frame_interval=0.02)
ani = animate_cartpole(states)

# The notebook displays the value of a cell's last expression, so this must
# stay last: it renders the animation as an interactive JS/HTML player.
HTML(ani.to_jshtml())
