In [1]:
import numpy as np
import pygame
import random
from collections import deque
import pickle
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

class SnakeGame:
    """Grid-based Snake environment rendered with pygame.

    Positions are ``[x, y]`` pixel coordinates in multiples of ``grid_size``.
    ``snake`` is the head; ``snake_body`` holds the remaining segments,
    newest first, and grows by one segment per food eaten. The window is
    ``height + 200`` pixels tall; the strip below the board shows stats.

    Actions passed to :meth:`step` are relative to the current heading:
    ``0`` keeps going straight, ``1`` turns right (clockwise) and any other
    value turns left (counter-clockwise).
    """

    # Headings in clockwise order: index + 1 is a right turn, index - 1 a left turn.
    _CLOCKWISE = ('RIGHT', 'DOWN', 'LEFT', 'UP')
    # One step per heading, in grid cells (scaled by grid_size when applied).
    _CELL_DELTA = {
        'UP': (0, -1),
        'DOWN': (0, 1),
        'LEFT': (-1, 0),
        'RIGHT': (1, 0),
    }

    def __init__(self, width=400, height=400, grid_size=20):
        self.width = width
        self.height = height
        self.grid_size = grid_size
        self.reset()

        pygame.init()
        self.screen = pygame.display.set_mode((width, height + 200))
        pygame.display.set_caption('Snake RL')
        self.clock = pygame.time.Clock()
        self.font = pygame.font.Font(None, 36)

    def reset(self):
        """Start a new episode and return its initial state vector.

        The head is placed near the board centre with an empty body, food is
        placed at random and the initial heading is chosen uniformly.
        """
        self.snake = [(self.width//(2*self.grid_size))*self.grid_size,
                      (self.height//(2*self.grid_size))*self.grid_size]
        self.snake_body = []
        self.food = self._place_food()
        self.direction = random.choice(['UP', 'DOWN', 'LEFT', 'RIGHT'])
        self.score = 0
        self.game_over = False
        self.steps_without_food = 0
        self.max_steps_without_food = 200  # Episode ends after this many steps without eating
        self.current_direction_steps = 0  # Consecutive steps kept in the same heading
        self.max_direction_steps = 20  # At/above this, step() penalizes and the state flags it
        return self._get_state()

    def _place_food(self):
        """Return a random free cell for the food.

        The outermost ring of cells is never used. The cell is rejected if it
        is on ``snake_body`` or equal to ``snake``. Loops until a free cell is
        drawn, so at least one interior cell must be free.
        """
        max_col = (self.width - 2 * self.grid_size) // self.grid_size
        max_row = (self.height - 2 * self.grid_size) // self.grid_size
        while True:
            food = [random.randint(1, max_col) * self.grid_size,
                    random.randint(1, max_row) * self.grid_size]
            if food not in self.snake_body and food != self.snake:
                return food

    def _is_collision(self, x, y):
        """Return True if cell ``(x, y)`` is off the board or on the snake's body."""
        return (x < 0 or x >= self.width or
                y < 0 or y >= self.height or
                [x, y] in self.snake_body)

    def _get_state(self):
        """Return the 12-element 0/1 observation as an ``int`` NumPy array.

        Layout: danger straight / right / left (relative to the heading),
        heading is LEFT / RIGHT / UP / DOWN, food is left / right / above /
        below the head, and whether ``max_direction_steps`` has been reached.
        """
        head_x, head_y = self.snake
        food_x, food_y = self.food
        heading_idx = self._CLOCKWISE.index(self.direction)

        def danger(heading):
            # The head is always on the board here, so checking all four
            # bounds is equivalent to checking only the one we move toward.
            dx, dy = self._CELL_DELTA[heading]
            return self._is_collision(head_x + dx * self.grid_size,
                                      head_y + dy * self.grid_size)

        danger_straight = danger(self.direction)
        danger_right = danger(self._CLOCKWISE[(heading_idx + 1) % 4])
        danger_left = danger(self._CLOCKWISE[(heading_idx - 1) % 4])

        state = [
            danger_straight,
            danger_right,
            danger_left,
            self.direction == 'LEFT',
            self.direction == 'RIGHT',
            self.direction == 'UP',
            self.direction == 'DOWN',
            food_x < head_x,
            food_x > head_x,
            food_y < head_y,
            food_y > head_y,
            self.current_direction_steps >= self.max_direction_steps,
        ]
        return np.array(state, dtype=int)

    def step(self, action):
        """Apply a relative *action* and return ``(state, reward, done)``.

        On game over (wall, self-collision or too many steps without food)
        the reward is overwritten with ``-10`` and the head is not moved.
        """
        self.steps_without_food += 1
        reward = 0

        prev_direction = self.direction
        idx = self._CLOCKWISE.index(prev_direction)
        if action == 0:  # Continue straight
            new_direction = prev_direction
        elif action == 1:  # Turn right
            new_direction = self._CLOCKWISE[(idx + 1) % 4]
        else:  # Turn left
            new_direction = self._CLOCKWISE[(idx - 1) % 4]

        if new_direction == prev_direction:
            self.current_direction_steps += 1
        else:
            self.current_direction_steps = 0
        self.direction = new_direction

        head_x, head_y = self.snake
        prev_dist = abs(head_x - self.food[0]) + abs(head_y - self.food[1])

        dx, dy = self._CELL_DELTA[self.direction]
        x = head_x + dx * self.grid_size
        y = head_y + dy * self.grid_size

        # The old head becomes the first body segment.
        self.snake_body.insert(0, list(self.snake))

        # Shaping: reward getting closer to the food (Manhattan distance).
        new_dist = abs(x - self.food[0]) + abs(y - self.food[1])
        if new_dist < prev_dist:
            reward += 0.5
        else:
            reward -= 0.2

        # Penalty for moving in same direction too long
        if self.current_direction_steps >= self.max_direction_steps:
            reward -= 0.5

        if x == self.food[0] and y == self.food[1]:
            self.score += 1
            reward += 10
            # Commit the new head BEFORE placing food: _place_food only
            # excludes self.snake and snake_body, and the body contains the
            # old head, not the new one - otherwise food could spawn under it.
            self.snake = [x, y]
            self.food = self._place_food()
            self.steps_without_food = 0
            self.current_direction_steps = 0
        else:
            # No growth: drop the tail so the length stays the same.
            self.snake_body.pop()
            reward -= 0.05

            # Penalty for being near walls
            if x <= self.grid_size or x >= self.width - self.grid_size or \
               y <= self.grid_size or y >= self.height - self.grid_size:
                reward -= 0.1

        if (x < 0 or x >= self.width or
                y < 0 or y >= self.height or
                [x, y] in self.snake_body or
                self.steps_without_food >= self.max_steps_without_food):
            self.game_over = True
            reward = -10
            return self._get_state(), reward, True

        self.snake = [x, y]
        return self._get_state(), reward, False

    def render(self, stats=None):
        """Draw the board and, if *stats* is ``(episode, epsilon, avg_score)``,
        a text panel below it; then flip the display and tick the clock at 10.
        """
        self.screen.fill((0, 0, 0))
        pygame.draw.rect(self.screen, (50, 50, 50),
                        pygame.Rect(0, 0, self.width, self.height))

        # Draw snake head
        pygame.draw.rect(self.screen, (0, 255, 0),
                        pygame.Rect(self.snake[0], self.snake[1],
                                  self.grid_size-2, self.grid_size-2))

        # Draw snake body
        for segment in self.snake_body:
            pygame.draw.rect(self.screen, (0, 200, 0),
                           pygame.Rect(segment[0], segment[1],
                                     self.grid_size-2, self.grid_size-2))

        # Draw food
        pygame.draw.rect(self.screen, (255, 0, 0),
                        pygame.Rect(self.food[0], self.food[1],
                                  self.grid_size-2, self.grid_size-2))

        # Draw stats
        if stats:
            episode, epsilon, avg_score = stats
            text_color = (255, 255, 255)
            texts = [
                f'Episode: {episode}',
                f'Score: {self.score}',
                f'Epsilon: {epsilon:.2f}',
                f'Avg Score: {avg_score:.2f}',
                f'Steps: {self.current_direction_steps}'
            ]
            for i, text in enumerate(texts):
                text_surface = self.font.render(text, True, text_color)
                self.screen.blit(text_surface, (10, self.height + 10 + i * 40))

        pygame.display.flip()
        self.clock.tick(10)

class QLearningAgent:
    """Tabular Q-learning agent with an experience-replay buffer.

    ``q_table`` maps ``tuple(state)`` to a NumPy array with one Q-value per
    action; unseen states start at zeros.

    Args:
        state_size: length of the state vector (stored, not otherwise used here).
        action_size: number of discrete actions.
        random_move_prob: probability that :meth:`get_action` ignores both
            the Q-values and epsilon and returns a uniformly random action.
    """

    def __init__(self, state_size, action_size, random_move_prob=0.2):
        self.state_size = state_size
        self.action_size = action_size
        self.q_table = {}
        self.epsilon = 1.0
        self.epsilon_min = 0.1  # Floor for epsilon decay
        self.epsilon_decay = 0.9995  # Multiplied in once per replay() call that learns
        self.alpha = 0.2  # Learning rate
        self.gamma = 0.9  # Discount factor
        self.memory = deque(maxlen=10000)  # Replay buffer of (s, a, r, s') tuples
        self.batch_size = 64
        # NOTE(review): not read anywhere in this class; kept so existing
        # pickles and any external readers of the attribute keep working.
        self.min_random_moves = 3
        self.random_move_prob = random_move_prob

    def get_state_key(self, state):
        """Return a hashable key for *state* (the state converted to a tuple)."""
        return tuple(state)

    def get_action(self, state):
        """Choose an action index for *state*.

        With probability ``random_move_prob`` a uniform random action is taken.
        Otherwise epsilon-greedy: when exploring, 70% of the time the action is
        sampled from a softmax over Q-values (temperature 2.0), else uniformly;
        when exploiting, the highest Q-value wins and exact ties are broken
        uniformly at random.
        """
        # getattr keeps agents pickled before this attribute existed working.
        if random.random() < getattr(self, 'random_move_prob', 0.2):
            return random.randint(0, self.action_size - 1)

        state_key = self.get_state_key(state)
        if state_key not in self.q_table:
            self.q_table[state_key] = np.zeros(self.action_size)
        q_values = self.q_table[state_key]

        if random.random() < self.epsilon:
            if random.random() < 0.7:  # 70% of exploration will be weighted random
                temp = 2.0
                # Shift by the max before exponentiating: same softmax, but
                # no overflow to inf/NaN for large Q-values.
                probs = np.exp((q_values - np.max(q_values)) / temp)
                probs /= probs.sum()
                return np.random.choice(self.action_size, p=probs)
            return random.randint(0, self.action_size - 1)

        # Break only EXACT ties randomly; adding noise to the Q-values would
        # also override genuine differences smaller than the noise amplitude.
        best_actions = np.flatnonzero(q_values == np.max(q_values))
        return int(np.random.choice(best_actions))

    def remember(self, state, action, reward, next_state):
        """Append one ``(state, action, reward, next_state)`` transition to memory."""
        self.memory.append((state, action, reward, next_state))

    def learn(self, state, action, reward, next_state):
        """Apply one Q-learning update for a single transition.

        ``Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (reward + gamma * max Q(s', .))``

        NOTE(review): no terminal flag is stored, so transitions that end the
        episode still bootstrap from ``Q(next_state)``.
        """
        state_key = self.get_state_key(state)
        next_state_key = self.get_state_key(next_state)

        if state_key not in self.q_table:
            self.q_table[state_key] = np.zeros(self.action_size)
        if next_state_key not in self.q_table:
            self.q_table[next_state_key] = np.zeros(self.action_size)

        old_value = self.q_table[state_key][action]
        next_max = np.max(self.q_table[next_state_key])
        new_value = (1 - self.alpha) * old_value + self.alpha * (reward + self.gamma * next_max)
        self.q_table[state_key][action] = new_value

    def replay(self):
        """Learn from a random batch of stored transitions, then decay epsilon.

        Does nothing until memory holds at least ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        for state, action, reward, next_state in batch:
            self.learn(state, action, reward, next_state)

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

class Plotter:
    """Accumulate score/epsilon history and periodically write PNG charts.

    Each :meth:`update` call appends one entry to ``scores``, ``epsilons``
    and ``avg_scores``; the latter is the mean of the most recent
    ``running_avg_size`` scores at the time of the call.
    """

    def __init__(self):
        plt.style.use('dark_background')
        self.scores = []
        self.avg_scores = []
        self.epsilons = []
        self.running_avg_size = 100

    def update(self, score, epsilon):
        """Record *score* and *epsilon*; return the updated running-average score."""
        self.scores.append(score)
        self.epsilons.append(epsilon)
        recent = self.scores[-self.running_avg_size:]
        running_mean = np.mean(recent)
        self.avg_scores.append(running_mean)
        return running_mean

    def save_plots(self, episode):
        """On every 100th *episode*, overwrite the two chart PNGs in the working dir."""
        if episode % 100 != 0:
            return
        self._write_chart(
            [(self.scores, {'label': 'Score', 'alpha': 0.4}),
             (self.avg_scores, {'label': 'Average Score', 'linewidth': 2})],
            'Score', 'Training Progress', 'training_progress.png')
        self._write_chart(
            [(self.epsilons, {'label': 'Epsilon', 'linewidth': 2})],
            'Epsilon', 'Exploration Rate Decay', 'epsilon_decay.png')

    @staticmethod
    def _write_chart(series, ylabel, title, filename):
        """Plot each ``(values, plot_kwargs)`` pair on one figure and save it."""
        plt.figure(figsize=(10, 5))
        for values, plot_kwargs in series:
            plt.plot(values, **plot_kwargs)
        plt.xlabel('Episode')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.savefig(filename)
        plt.close()

def train(episodes=1000):
    """Train a QLearningAgent on SnakeGame with live rendering, then pickle it.

    Every step is rendered and pygame events are pumped; closing the window
    ends training early. Score/epsilon are recorded once per finished episode,
    and ``Plotter.save_plots`` writes charts every 100th episode. The agent is
    saved to ``snake_agent.pkl`` in ``finally``, so an interrupted run still
    saves its progress.

    Args:
        episodes: number of episodes to run.
    """
    env = SnakeGame()
    agent = QLearningAgent(state_size=12, action_size=3)  # 12 = SnakeGame state length
    plotter = Plotter()
    min_score_progress = -1  # Best episode score seen so far
    consecutive_no_progress = 0
    max_no_progress = 50  # Episodes without a new best score before boosting exploration
    max_steps_per_episode = 1000

    try:
        for episode in range(episodes):
            state = env.reset()
            steps_in_episode = 0

            # Force exploration if stuck
            if consecutive_no_progress >= max_no_progress:
                agent.epsilon = min(1.0, agent.epsilon * 2)
                consecutive_no_progress = 0
                print(f"Forcing exploration! Epsilon increased to {agent.epsilon}")

            # Running average over completed episodes, displayed during this one.
            avg_score = plotter.avg_scores[-1] if plotter.avg_scores else 0.0

            while not env.game_over:
                action = agent.get_action(state)
                next_state, reward, _ = env.step(action)

                # Store experience and learn from batch
                agent.remember(state, action, reward, next_state)
                agent.replay()

                state = next_state
                steps_in_episode += 1

                # Force episode end if stuck
                if steps_in_episode > max_steps_per_episode:
                    break

                env.render(stats=(episode, agent.epsilon, avg_score))

                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        raise KeyboardInterrupt

            # Record exactly one data point per episode so the plotted x-axis
            # and the 100-entry running average are in episodes, not steps.
            plotter.update(env.score, agent.epsilon)

            if env.score > min_score_progress:
                min_score_progress = env.score
                consecutive_no_progress = 0
            else:
                consecutive_no_progress += 1

            plotter.save_plots(episode)

            if episode % 10 == 0:
                print(f'Episode: {episode}, Score: {env.score}, Min Progress: {min_score_progress}, '
                      f'Epsilon: {agent.epsilon:.2f}, No Progress: {consecutive_no_progress}')

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    finally:
        print("\nSaving trained agent...")
        with open('snake_agent.pkl', 'wb') as f:
            pickle.dump(agent, f)
        print("Agent saved successfully!")
        pygame.quit()

def play_trained_agent():
    """Load the agent from ``snake_agent.pkl`` and watch it play until the window closes.

    Each game over prints the score and starts a new game. Security note:
    ``pickle.load`` can execute arbitrary code, so only load a pickle you
    produced yourself.
    """
    try:
        with open('snake_agent.pkl', 'rb') as f:
            agent = pickle.load(f)  # Trusted, locally produced file only
    except FileNotFoundError:
        print("No trained agent found! Please train the agent first.")
        return

    # The pickled epsilon is whatever exploration rate training ended at;
    # disable epsilon exploration for playback. (get_action may still take
    # some random moves independently of epsilon.)
    agent.epsilon = 0.0

    env = SnakeGame()
    state = env.reset()

    try:
        while True:
            action = agent.get_action(state)
            state, _, done = env.step(action)
            env.render(stats=(0, 0, env.score))  # Just show the score

            if done:
                print(f"Game Over! Score: {env.score}")
                state = env.reset()

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    raise KeyboardInterrupt

    except KeyboardInterrupt:
        print("\nPlayback stopped by user")
    finally:
        pygame.quit()

if __name__ == "__main__":
    # strip() so stray spaces around the answer (e.g. " train") are accepted.
    mode = input("Enter 'train' to train new agent or 'play' to watch trained agent: ").strip().lower()
    if mode == 'train':
        train()
    elif mode == 'play':
        play_trained_agent()
    else:
        print("Invalid mode! Please enter either 'train' or 'play'")

pygame 2.6.1 (SDL 2.28.4, Python 3.12.4)
Hello from the pygame community. https://www.pygame.org/contribute.html
Episode: 0, Score: 1, Min Progress: 1, Epsilon: 1.00, No Progress: 0
Episode: 10, Score: 7, Min Progress: 7, Epsilon: 0.61, No Progress: 0
Episode: 20, Score: 4, Min Progress: 7, Epsilon: 0.40, No Progress: 10
Episode: 30, Score: 5, Min Progress: 7, Epsilon: 0.27, No Progress: 20
Episode: 40, Score: 3, Min Progress: 12, Epsilon: 0.15, No Progress: 4
Episode: 50, Score: 8, Min Progress: 12, Epsilon: 0.10, No Progress: 14
Episode: 60, Score: 7, Min Progress: 12, Epsilon: 0.10, No Progress: 24
Episode: 70, Score: 7, Min Progress: 12, Epsilon: 0.10, No Progress: 34
Episode: 80, Score: 0, Min Progress: 12, Epsilon: 0.10, No Progress: 44
Forcing exploration! Epsilon increased to 0.19990187928652475
Episode: 90, Score: 8, Min Progress: 12, Epsilon: 0.16, No Progress: 4
Episode: 100, Score: 4, Min Progress: 12, Epsilon: 0.10, No Progress: 14
Episode: 110, Score: 6, Min Progress: 12,

In [2]:
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import argparse

# Text Preprocessing Function
def preprocess_text(text):
    """Clean *text* for summarization.

    Strips leading/trailing whitespace and collapses every internal run of
    whitespace (newlines, carriage returns, tabs, repeated spaces) into a
    single space.
    """
    return " ".join(text.split())

# Summarization with BART (exposed as the 'extractive' CLI method, although BART generates abstractive summaries)
def extractive_summary(text):
    """Summarize *text* with the ``facebook/bart-large-cnn`` summarization pipeline.

    NOTE: despite this function's name, BART writes new sentences, so the
    result is an abstractive summary rather than a selection of source
    sentences. The name is kept because the CLI's ``--method`` choice maps
    to it.

    Returns "" for empty or whitespace-only input without loading the model.
    Input longer than the model's maximum length is truncated. The pipeline
    is constructed on every call.
    """
    if not text or not text.strip():
        return ""
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    # truncation=True: over-length inputs would otherwise error inside the model.
    summary = summarizer(text, max_length=130, min_length=30, do_sample=False, truncation=True)
    return summary[0]['summary_text']

# Abstractive Summarization Function
def abstractive_summary(text):
    """Summarize *text* with the ``t5-small`` seq2seq model.

    The input is prefixed with T5's ``"summarize: "`` task prefix and
    truncated to 512 tokens; generation uses 4-beam search producing
    40-150 tokens. Tokenizer and model are loaded on every call.
    Returns "" for empty or whitespace-only input without loading the model.
    """
    if not text or not text.strip():
        return ""
    model_name = "t5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=512, truncation=True)
    outputs = model.generate(inputs, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return summary

# CLI Tool
def main(argv=None):
    """Command-line entry point: summarize ``--text`` using ``--method``.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``. Pass an
            explicit list (e.g. ``["--text", "...", "--method", "abstractive"]``)
            when calling from a notebook, where ``sys.argv`` belongs to the
            kernel launcher rather than to this tool.
    """
    parser = argparse.ArgumentParser(description="AI-Based Text Summarization Tool")
    parser.add_argument("--text", type=str, required=True, help="Input text to summarize")
    parser.add_argument("--method", type=str, choices=["extractive", "abstractive"], required=True, help="Summarization method")

    args = parser.parse_args(argv)
    text = preprocess_text(args.text)

    # argparse's `choices` guarantees the key exists, so `summary` is always bound.
    summarizers = {
        "extractive": extractive_summary,
        "abstractive": abstractive_summary,
    }
    summary = summarizers[args.method](text)

    print("\nOriginal Text:\n", text)
    print("\nGenerated Summary:\n", summary)

if __name__ == "__main__":
    # Parses sys.argv. Inside a Jupyter cell this is the kernel launcher's
    # argv (ipykernel_launcher.py), so the required --text/--method are
    # missing and argparse exits with SystemExit: 2; run the file as a script
    # or call main() with an explicit argument list there.
    main()

  from .autonotebook import tqdm as notebook_tqdm
usage: ipykernel_launcher.py [-h] --text TEXT --method
                             {extractive,abstractive}
ipykernel_launcher.py: error: the following arguments are required: --text, --method


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
