In [15]:
import gymnasium as gym
import flappy_bird_gymnasium
import torch
from PIL import Image
import numpy as np
import os
import time
import cv2  # For visualization

class DQN(torch.nn.Module):
    """Convolutional Q-network: three conv layers followed by a two-layer MLP head.

    Args:
        input_channels: number of channels in each input image (1 for the single
            grayscale frame produced by ``preprocess_frame``).
        n_actions: number of outputs, i.e. one Q-value per action.
        input_size: spatial size of the input, either an int (square image) or a
            ``(height, width)`` tuple. The default of 84 reproduces the original
            3136-feature flatten (64 channels * 7 * 7), so existing checkpoints load.

    Input shape is ``(batch, input_channels, height, width)``; output shape is
    ``(batch, n_actions)``.
    """

    def __init__(self, input_channels=1, n_actions=2, input_size=84):
        super().__init__()
        self.conv_layers = torch.nn.Sequential(
            torch.nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, kernel_size=4, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1),
            torch.nn.ReLU()
        )

        # Derive the flattened conv output size from a dummy forward pass instead
        # of hard-coding 3136, which is only correct for 84x84 inputs.
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        with torch.no_grad():
            dummy = torch.zeros(1, input_channels, height, width)
            n_flat = self.conv_layers(dummy).numel()

        self.fc_layers = torch.nn.Sequential(
            torch.nn.Linear(n_flat, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, n_actions)
        )

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for a (batch, C, H, W) input."""
        x = self.conv_layers(x)
        # flatten(start_dim=1) keeps the batch axis and, unlike .view, also
        # accepts non-contiguous tensors.
        x = torch.flatten(x, start_dim=1)
        return self.fc_layers(x)

def preprocess_frame(frame):
    """Turn one image array (e.g. the RGB frame from ``env.render()``) into a network input.

    The frame is converted to 8-bit grayscale, resized to 84x84 with PIL's
    default resampling filter, and scaled from 0-255 down to [0, 1]. It is
    returned as a float32 tensor of shape (1, 1, 84, 84): batch, channel,
    height, width.

    Raises:
        ValueError: if ``frame`` is None.
    """
    if frame is None:
        raise ValueError("Received None instead of a frame")

    small_gray = Image.fromarray(frame).convert("L").resize((84, 84))
    pixels = np.array(small_gray).astype(np.float32)
    # [None, None] prepends the batch and channel axes in a single indexing step.
    return (torch.from_numpy(pixels) / 255.0)[None, None]

def _load_model(model_path, device):
    """Build a DQN on ``device``, load the checkpoint's weights and switch to eval mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        KeyError: if the checkpoint lacks 'model_state_dict', 'episode' or 'reward'.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at {model_path}")

    model = DQN().to(device)
    # NOTE(review): torch.load unpickles arbitrary objects, so only load checkpoints
    # you trust. weights_only=True would be safer and silence the FutureWarning,
    # but it may reject non-tensor metadata (e.g. numpy scalars); verify first.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Loaded model from episode {checkpoint['episode']} with avg reward: {checkpoint['reward']:.2f}")
    return model


def _print_statistics(scores):
    """Print average / best / worst episode score, tolerating an empty list."""
    print("\nGame Statistics:")
    if not scores:
        # max()/min() raise on an empty sequence (e.g. num_episodes=0).
        print("No episodes completed.")
        return
    print(f"Average Score: {np.mean(scores):.2f}")
    print(f"Best Score: {max(scores):.2f}")
    print(f"Worst Score: {min(scores):.2f}")


def visualize_agent(model_path, num_episodes=20):
    """
    Load a trained model and visualize it playing Flappy Bird.

    Frames come from ``env.render()`` and are shown in an OpenCV window.
    Pressing 'q' in the window quits early, without printing statistics.
    Errors are printed with a traceback rather than raised.

    Args:
        model_path (str): Path to the saved model file
        num_episodes (int): Number of episodes to play
    """
    window_name = 'Flappy Bird'
    # Both names are read in `finally`, so they must exist even if setup fails early.
    env = None
    window_opened = False
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        model = _load_model(model_path, device)

        # rgb_array mode: frames are returned as arrays and displayed via OpenCV.
        env = gym.make("FlappyBird-v0", render_mode="rgb_array")
        scores = []

        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
        window_opened = True
        cv2.resizeWindow(window_name, 400, 600)

        for episode in range(num_episodes):
            # The env observation is ignored; the agent acts on the rendered frame.
            env.reset()
            frame = env.render()
            state = preprocess_frame(frame)
            episode_reward = 0
            done = False

            while not done:
                # OpenCV expects BGR channel order.
                cv2.imshow(window_name, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                if cv2.waitKey(20) & 0xFF == ord('q'):  # Press 'q' to quit
                    return

                # Greedy action: index of the largest Q-value.
                with torch.no_grad():
                    q_values = model(state.to(device))
                    action = q_values.argmax(dim=1).item()

                _, reward, terminated, truncated, _ = env.step(action)
                # Gymnasium reports episode end as terminated OR truncated.
                # Ignoring truncation would loop forever on a time-limited episode.
                done = terminated or truncated

                frame = env.render()
                state = preprocess_frame(frame)
                episode_reward += reward

            scores.append(episode_reward)
            print(f"Episode {episode + 1}/{num_episodes} - Score: {episode_reward:.2f}")

        _print_statistics(scores)

    except Exception as e:
        print(f"Error during visualization: {str(e)}")
        import traceback
        traceback.print_exc()

    finally:
        # Only tear down what was actually created. Otherwise a failure during
        # setup would be masked by an error raised from this cleanup.
        if env is not None:
            env.close()
        if window_opened:
            cv2.destroyAllWindows()

if __name__ == "__main__":
    # Example run: point model_path at your own checkpoint before executing.
    model_path = "downloaded_models/dqn_best.pth"
    visualize_agent(model_path=model_path)

Using device: cuda
Loaded model from episode 1618 with avg reward: 12.90


  checkpoint = torch.load(model_path, map_location=device)


Episode 1/20 - Score: 4.50
Episode 2/20 - Score: 1.50
Episode 3/20 - Score: -0.50
Episode 4/20 - Score: -4.80
Episode 5/20 - Score: 1.40
Episode 6/20 - Score: 1.50
Episode 7/20 - Score: 1.40
Episode 8/20 - Score: 1.40
Episode 9/20 - Score: 1.40
Episode 10/20 - Score: 0.90
Episode 11/20 - Score: 1.40
Episode 12/20 - Score: 3.50
Episode 13/20 - Score: -1.20


KeyboardInterrupt: 