In [None]:
import os, random, math, numpy as np, torch, torch.nn as nn, torch.optim as optim
from collections import deque

# Headless SDL: use dummy video/audio drivers unless the caller already set them.
os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
os.environ.setdefault("SDL_AUDIODRIVER", "dummy")

# Checkpoint directory: an explicit CHECKPOINT_DIR environment variable wins;
# otherwise use Kaggle's writable output dir when it exists, else the current
# working directory (a hard-coded "/kaggle/working" cannot be created on
# machines without a writable /kaggle).
DEFAULT_CHECKPOINT_DIR = os.environ.get(
    "CHECKPOINT_DIR",
    "/kaggle/working" if os.path.isdir("/kaggle/working") else os.getcwd(),
)
os.makedirs(DEFAULT_CHECKPOINT_DIR, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

CHECKPOINT_PATH = os.path.join(DEFAULT_CHECKPOINT_DIR, "double_dqn_flappy.pth")

# --- Hyperparameters ---
EPI_NUMS = 30000        # default number of training episodes
LR = 5e-5               # Adam learning rate
BATCH_SIZE = 128        # replay minibatch size
BUFFER_SIZE = 100000    # replay buffer capacity (transitions)
GAMMA = 0.995           # discount factor
TAU = 0.001             # soft target-update coefficient
EPSILON_DECAY = 80000   # time constant (in steps) of the exponential epsilon decay
EPSILON_START = 1.0     # initial exploration rate
EPSILON_FINAL = 0.01    # asymptotic exploration rate
WARMUP_STEPS = 10000    # random transitions collected before learning starts
MAX_STEPS = 6000        # per-episode step limit; NOTE(review): confirm the training loop enforces it

# --- Replay Buffer ---
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) tuples.

    Backed by ``deque(maxlen=capacity)``: once full, each push evicts the
    oldest transition.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns CPU tensors ``(states, actions, rewards, next_states, dones)``;
        actions are int64 and every other field is float32. ``random.sample``
        raises ValueError if fewer than ``batch_size`` transitions are stored.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = (
            np.array(column) for column in zip(*picked)
        )

        def to_float(arr):
            return torch.tensor(arr, dtype=torch.float32)

        return (
            to_float(states),
            torch.tensor(actions, dtype=torch.int64),
            to_float(rewards),
            to_float(next_states),
            to_float(dones),
        )

    def __len__(self):
        return len(self.buffer)

class DQN(nn.Module):
    """Q-network: an MLP mapping a state vector to one Q-value per action.

    Linear(input_dim, 256) -> ReLU -> Linear(256, 256) -> ReLU ->
    Linear(256, output_dim). The layers live in ``self.net`` (an
    ``nn.Sequential``), so state-dict keys are ``net.0.*``, ``net.2.*`` and
    ``net.4.*``; keep this layout to stay compatible with saved checkpoints.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden = 256
        layers = [nn.Linear(input_dim, hidden), nn.ReLU()]
        layers += [nn.Linear(hidden, hidden), nn.ReLU()]
        layers.append(nn.Linear(hidden, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values for ``x`` (last dimension becomes ``output_dim``)."""
        return self.net(x)


class Agent:
    """Double-DQN agent with an online ``policy_net`` and a soft-updated ``target_net``.

    Uses the module-level ``device``, ``LR``, ``GAMMA``, ``TAU`` and ``DQN``.
    """

    def __init__(self, state_dim=4, n_actions=2):
        self.device = device
        self.n_actions = n_actions
        self.policy_net = DQN(state_dim, n_actions).to(self.device)
        self.target_net = DQN(state_dim, n_actions).to(self.device)
        # The target network starts as an exact copy of the online network.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LR)
        self.loss_fn = nn.SmoothL1Loss()

    def act(self, state, epsilon):
        """Pick an epsilon-greedy action for a single (unbatched) ``state``.

        With probability ``epsilon`` returns a uniformly random action index;
        otherwise returns the argmax of the policy network's Q-values.
        """
        if random.random() < epsilon:
            return random.randint(0, self.n_actions - 1)
        # Inference only: avoid building an autograd graph on every env step.
        with torch.no_grad():
            state_v = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            q_vals = self.policy_net(state_v)
        return q_vals.argmax(dim=1).item()

    def update(self, buffer, batch_size):
        """Run one Double-DQN gradient step on a minibatch sampled from ``buffer``.

        Returns the scalar loss, or ``None`` (and performs no update) when
        ``buffer`` holds fewer than ``batch_size`` transitions.
        """
        if len(buffer) < batch_size:
            return None

        state, action, reward, next_state, done = (
            t.to(self.device) for t in buffer.sample(batch_size)
        )

        # Current Q(s, a) for the actions that were actually taken.
        q_values = self.policy_net(state).gather(1, action.unsqueeze(1)).squeeze(1)

        # Double DQN target: the policy network selects the next action and the
        # target network evaluates it. The target is a constant for this step,
        # so compute it without autograd instead of building and discarding a graph.
        with torch.no_grad():
            next_actions = self.policy_net(next_state).argmax(1, keepdim=True)
            next_q_values = self.target_net(next_state).gather(1, next_actions).squeeze(1)
            # (1 - done) removes the bootstrap term for terminal transitions.
            expected_q = reward + GAMMA * next_q_values * (1.0 - done)

        loss = self.loss_fn(q_values, expected_q)

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

        # Soft (Polyak) update: target <- TAU * policy + (1 - TAU) * target.
        with torch.no_grad():
            for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
                target_param.mul_(1.0 - TAU).add_(param, alpha=TAU)

        return loss.item()

    def save(self, path):
        """Save policy, target and optimizer state dicts to ``path``."""
        torch.save({
            'policy_state_dict': self.policy_net.state_dict(),
            'target_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load(self, path):
        """Restore state written by ``save``.

        A missing target entry falls back to the policy weights; the optimizer
        state is restored only if present in the checkpoint.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.policy_net.load_state_dict(checkpoint['policy_state_dict'])
        self.target_net.load_state_dict(checkpoint.get('target_state_dict', checkpoint['policy_state_dict']))
        if 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f" Loaded checkpoint from {path}")
        print(f" Loaded checkpoint from {path}")


# --- Environment ---
class FlappyBirdEnv:
    """Minimal, render-free Flappy Bird simulation with a gym-like interface.

    ``reset()`` returns the initial observation. ``step(action)`` returns
    ``(observation, reward, done, info)``; action 1 flaps and any other value
    lets the bird fall. Observations are float32 vectors
    ``[dy, vel, dist, gap_y]`` (see ``_get_state``).
    """

    def __init__(self, difficulty='normal', seed=None):
        """Configure the physics and reward constants, then reset.

        Args:
            difficulty: 'easy', 'normal' or 'hard' (KeyError otherwise); selects
                the pipe gap, pipe spacing and scroll speed.
            seed: if given, seeds the *global* ``random`` and ``np.random``
                generators, which also affects every other user of them.
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
        presets = {
            'easy':  {'PIPE_GAP':220,'PIPE_SPACING':300,'SCROLL_SPEED':2},
            'normal':{'PIPE_GAP':180,'PIPE_SPACING':280,'SCROLL_SPEED':3},
            'hard':  {'PIPE_GAP':150,'PIPE_SPACING':260,'SCROLL_SPEED':3}
        }
        p = presets[difficulty]
        self.PIPE_GAP = p['PIPE_GAP']
        self.PIPE_SPACING = p['PIPE_SPACING']
        self.SCROLL_SPEED = p['SCROLL_SPEED']
        self.GRAVITY = 0.5
        self.FLAP_VEL = -9.0
        self.MAX_VEL = 12.0
        self.SCREEN_WIDTH = 288
        self.SCREEN_HEIGHT = 512
        self.GROUND_HEIGHT = 112
        self.pipe_width = 52
        self.INIT_PIPE_OFFSET = 100
        self.MIN_GAP_Y = 50
        self.LIVING_REWARD = 0.1
        self.SCORE_REWARD = 10.0
        self.DEATH_PENALTY = -7
        self.VERTICAL_WEIGHT = 0.3
        self.VELOCITY_WEIGHT = 0.1
        self.CENTER_BONUS_MULT = 5.0
        self.APPROACHING_THRESHOLD = 0.3
        self.APPROACHING_MULTIPLIER = 2.0
        self.reset()

    def reset(self):
        """Start a new episode with three pipes off-screen to the right; return the state."""
        self.bird_x = 80
        self.bird_y = self.SCREEN_HEIGHT//2
        self.bird_vel = 0.0
        # Each pipe is a mutable [x_left, gap_top_y] list.
        self.pipes = []
        start_x = self.SCREEN_WIDTH + self.INIT_PIPE_OFFSET
        for i in range(3):
            gy = random.randint(self.MIN_GAP_Y, self.SCREEN_HEIGHT - self.MIN_GAP_Y - self.PIPE_GAP)
            self.pipes.append([start_x + i*self.PIPE_SPACING, gy])
        self.done = False
        self.score = 0
        # id()s of live pipes the bird has already passed.
        self.scored_pipes = set()
        return self._get_state()

    def step(self, action):
        """Advance the simulation by one frame.

        Returns ``(state, reward, done, info)``. On collision the episode ends
        with ``DEATH_PENALTY``. Clearing a pipe inside its gap yields
        ``SCORE_REWARD`` plus a centering bonus. Otherwise a shaped living
        reward penalizes vertical offset and speed, scaled by
        ``APPROACHING_MULTIPLIER`` when the next pipe is close.
        """
        # Physics: a flap *sets* the upward velocity, then gravity is applied.
        if action == 1:
            self.bird_vel = self.FLAP_VEL
        self.bird_vel += self.GRAVITY
        self.bird_vel = np.clip(self.bird_vel, -self.MAX_VEL, self.MAX_VEL)
        self.bird_y += self.bird_vel

        # Scroll pipes left and recycle the leftmost one once it is fully off-screen.
        for p in self.pipes:
            p[0] -= self.SCROLL_SPEED
        if (self.pipes[0][0] + self.pipe_width) < 0:
            retired = self.pipes.pop(0)
            # Scoring is keyed on id(). CPython readily gives a freed list's id
            # to the next list allocated (the new pipe below), which would make
            # that pipe look "already scored" forever. Drop the retired id so
            # scored_pipes only ever holds ids of live pipes.
            self.scored_pipes.discard(id(retired))
            new_x = self.pipes[-1][0] + self.PIPE_SPACING
            gy = random.randint(self.MIN_GAP_Y, self.SCREEN_HEIGHT - self.MIN_GAP_Y - self.PIPE_GAP)
            self.pipes.append([new_x, gy])

        if self._check_collision():
            self.done = True
            return self._get_state(), self.DEATH_PENALTY, True, {}

        just_scored = False
        for pipe in self.pipes:
            px, gy = pipe
            if (px + self.pipe_width) < self.bird_x and id(pipe) not in self.scored_pipes:
                if gy < self.bird_y < gy + self.PIPE_GAP:
                    self.score += 1
                    just_scored = True
                self.scored_pipes.add(id(pipe))

        # NOTE: once a pipe has been passed, _get_next_pipe points at the
        # following pipe, so the centering bonus is measured against the
        # upcoming gap rather than the one just cleared.
        dy_norm, vel_norm, dx_norm = self._get_normalized_values()
        if just_scored:
            reward = self.SCORE_REWARD + max(0, (1.0 - abs(dy_norm)) * self.CENTER_BONUS_MULT)
        else:
            reward = self.LIVING_REWARD - abs(dy_norm)*self.VERTICAL_WEIGHT - abs(vel_norm)*self.VELOCITY_WEIGHT
            if dx_norm < self.APPROACHING_THRESHOLD:
                reward *= self.APPROACHING_MULTIPLIER
        return self._get_state(), reward, self.done, {}

    def _check_collision(self):
        """Return True if the bird hit the ceiling, the ground or a pipe.

        The bird is treated as a box extending 12 px in each direction around
        ``(bird_x, bird_y)``.
        """
        GROUND_Y = self.SCREEN_HEIGHT - self.GROUND_HEIGHT
        if self.bird_y <= 0 or self.bird_y >= GROUND_Y:
            return True
        for px, gy in self.pipes:
            if (self.bird_x + 12) > px and (self.bird_x - 12) < (px + self.pipe_width):
                if self.bird_y - 12 < gy or self.bird_y + 12 > gy + self.PIPE_GAP:
                    return True
        return False

    def _get_next_pipe(self):
        """Return the first pipe whose right edge has not yet passed the bird (else the first pipe)."""
        for px, gy in self.pipes:
            if px + self.pipe_width >= self.bird_x:
                return px, gy
        return self.pipes[0]

    def _get_normalized_values(self):
        """Return ``(dy, vel, dx)`` relative to the next pipe.

        ``dy`` is the gap-center offset in units of PIPE_GAP, ``vel`` is the
        velocity over MAX_VEL, and ``dx`` is the horizontal distance over
        SCREEN_WIDTH.
        """
        px, gy = self._get_next_pipe()
        gap_center = gy + self.PIPE_GAP/2
        dy = (gap_center - self.bird_y)/float(self.PIPE_GAP)
        vel = self.bird_vel/self.MAX_VEL
        dx = (px - self.bird_x)/self.SCREEN_WIDTH
        return dy, vel, dx

    def _get_state(self):
        """Observation ``[dy, vel, dist, gap_y]`` as float32; gap_y is the gap top over SCREEN_HEIGHT."""
        dy, vel, dist = self._get_normalized_values()
        _, gy = self._get_next_pipe()
        gap_y = gy/self.SCREEN_HEIGHT
        return np.array([dy, vel, dist, gap_y], dtype=np.float32)


def train_loop(num_episodes=EPI_NUMS, resume=False, difficulty="normal"):
    """Train a Double-DQN agent on FlappyBirdEnv and checkpoint it.

    First fills the replay buffer with WARMUP_STEPS random-action transitions.
    Then runs ``num_episodes`` epsilon-greedy episodes, each capped at
    MAX_STEPS steps, with one gradient update per environment step. Every
    50 episodes it logs progress and saves to CHECKPOINT_PATH, and it saves
    once more at the end.

    Args:
        num_episodes: number of training episodes.
        resume: if True and CHECKPOINT_PATH exists, load it first. Only
            network and optimizer state are restored; the epsilon schedule
            restarts from EPSILON_START.
        difficulty: FlappyBirdEnv preset name.
    """
    env = FlappyBirdEnv(difficulty=difficulty)
    agent = Agent()
    buffer = ReplayBuffer(BUFFER_SIZE)

    if resume and os.path.exists(CHECKPOINT_PATH):
        agent.load(CHECKPOINT_PATH)

    total_steps, scores = 0, []
    # Only the most recent 100 losses are ever reported, so bound the history.
    losses = deque(maxlen=100)

    print(f"Collecting {WARMUP_STEPS} random transitions for warmup...")
    state = env.reset()
    for _ in range(WARMUP_STEPS):
        a = random.randint(0, 1)
        ns, r, d, _ = env.step(a)
        buffer.push(state, a, r, ns, float(d))
        state = env.reset() if d else ns
    print(f"Warmup finished. Replay buffer size = {len(buffer)}")

    for ep in range(1, num_episodes + 1):
        s = env.reset()
        ep_r = 0
        # The environment has no time limit, so a competent agent could fly
        # forever. Cap the episode at MAX_STEPS. Hitting the cap is a
        # truncation, not a terminal state, so the last transition keeps
        # done=False and still bootstraps.
        for _ in range(MAX_STEPS):
            eps = EPSILON_FINAL + (EPSILON_START - EPSILON_FINAL) * math.exp(-total_steps / EPSILON_DECAY)
            a = agent.act(s, eps)
            ns, r, done, _ = env.step(a)
            buffer.push(s, a, r, ns, float(done))
            loss = agent.update(buffer, BATCH_SIZE)
            # `is not None` so a legitimate 0.0 loss is still recorded.
            if loss is not None:
                losses.append(loss)
            s = ns
            ep_r += r
            total_steps += 1
            if done:
                break

        scores.append(env.score)
        if ep % 50 == 0:
            avg_loss = np.mean(losses) if losses else float("nan")
            print(f"Ep {ep:5d} | Steps {total_steps:7d} | Score {env.score:3d} | "
                  f"EpReward {ep_r:6.2f} | Eps {eps:.3f} | "
                  f"AvgScore50 {np.mean(scores[-50:]):.2f} | AvgLoss100 {avg_loss:.4f}")
            agent.save(CHECKPOINT_PATH)

    agent.save(CHECKPOINT_PATH)
    print("Training finished. Model saved to", CHECKPOINT_PATH)

# Run training only when executed as a script or notebook cell (where
# __name__ == "__main__"), not as a side effect of importing this module.
if __name__ == "__main__":
    train_loop(num_episodes=EPI_NUMS, difficulty="hard", resume=False)


Using device: cuda
Collecting 10000 random transitions for warmup...
Warmup finished. Replay buffer size = 10000
Ep    50 | Steps    1684 | Score   0 | EpReward -21.54 | Eps 0.979 | AvgScore50 0.00 | AvgLoss100 0.1447
Ep   100 | Steps    3376 | Score   0 | EpReward -10.17 | Eps 0.959 | AvgScore50 0.00 | AvgLoss100 0.1161
Ep   150 | Steps    5062 | Score   0 | EpReward -13.39 | Eps 0.939 | AvgScore50 0.00 | AvgLoss100 0.0693
Ep   200 | Steps    6765 | Score   0 | EpReward -10.21 | Eps 0.920 | AvgScore50 0.00 | AvgLoss100 0.0357
Ep   250 | Steps    8479 | Score   0 | EpReward -12.54 | Eps 0.900 | AvgScore50 0.00 | AvgLoss100 0.0166
Ep   300 | Steps   10183 | Score   0 | EpReward -19.92 | Eps 0.882 | AvgScore50 0.00 | AvgLoss100 0.0158
Ep   350 | Steps   11902 | Score   0 | EpReward -10.10 | Eps 0.863 | AvgScore50 0.00 | AvgLoss100 0.0262
Ep   400 | Steps   13642 | Score   0 | EpReward -11.78 | Eps 0.845 | AvgScore50 0.00 | AvgLoss100 0.0374
Ep   450 | Steps   15338 | Score   0 | EpReward