# Student Submission Template - Implement Your RNN Agent

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

## Fill in Your Information

In [None]:
# Submission metadata; replace every placeholder before training.
STUDENT_INFO = {
    # Fallback agent name, and the source of the saved weights file name.
    "name": "Your Name",
    "student_id": "Your ID",
    # Preferred agent name; StudentAgent falls back to "name" when this is empty.
    "team_name": "Your Team",
    "description": "Brief description",
}

## Define Your Model

In [None]:
class MyModel(nn.Module):
    """Single-layer LSTM followed by a linear head that scores each action.

    Args:
        input_size: Number of features per time step. The default (10)
            matches the 10-element observation built in
            ``simulate_game_visualization``.
        hidden_size: Width of the LSTM hidden state.
        num_actions: Number of output logits. The default (4) matches the
            four move directions used by the game loop.
    """

    def __init__(self, input_size: int = 10, hidden_size: int = 64,
                 num_actions: int = 4):
        super().__init__()
        # TODO: Define your network architecture
        # Attribute names are kept as ``lstm`` / ``fc`` so saved state_dict
        # keys do not change.
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_actions)

    def forward(self, x, hidden=None):
        """Run the sequence through the LSTM and score the last time step.

        Args:
            x: Batch-first input of shape (batch, seq_len, input_size).
            hidden: Optional ``(h, c)`` state from a previous call; ``None``
                starts from a zero state.

        Returns:
            Tuple ``(logits, hidden)``: logits has shape
            (batch, num_actions); hidden is the updated LSTM state.
        """
        out, hidden = self.lstm(x, hidden)
        # Only the final time step feeds the classifier.
        return self.fc(out[:, -1, :]), hidden

## Agent Class (Do Not Modify Class Name)

In [None]:
class StudentAgent(nn.Module):
    """Wraps ``MyModel`` with the stateful, one-observation-at-a-time API.

    The recurrent state is kept on the instance between ``get_action`` calls
    and cleared by ``reset``.
    """

    def __init__(self):
        super().__init__()
        # Prefer the team name; fall back to the student's name if it is empty.
        self.name = STUDENT_INFO.get("team_name") or STUDENT_INFO["name"]
        self.info = STUDENT_INFO
        self.model = MyModel()
        self.hidden = None

    def reset(self):
        """Forget the recurrent state (called before a new game starts)."""
        self.hidden = None

    def get_action(self, obs):
        """Pick the highest-scoring action for a single observation.

        Args:
            obs: One observation as a 1-D numeric sequence (NumPy array,
                list or tuple). It is converted to float32 and reshaped to
                a (batch=1, seq_len=1, features) tensor.

        Returns:
            The index of the largest model output as a Python int.
        """
        # as_tensor accepts arrays of any numeric dtype as well as plain
        # Python sequences, whereas from_numpy only accepts ndarrays.
        x = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            # The first call omits the hidden argument so that models whose
            # forward() takes only ``x`` keep working.
            if self.hidden is None:
                result = self.model(x)
            else:
                result = self.model(x, self.hidden)
            # Recurrent models return (output, hidden); plain ones return
            # just the output, in which case no state is carried over.
            if isinstance(result, tuple):
                out, self.hidden = result
            else:
                out = result
            return torch.argmax(out, dim=-1).item()

    def forward(self, x, hidden=None):
        """Delegate to the wrapped model; returns whatever it returns."""
        return self.model(x, hidden)

## Training Code

In [None]:
def train(data_path="train_X.npy", labels_path="train_Y.npy",
          epochs=10, lr=0.001, batch_size=32):
    """Train a fresh ``StudentAgent`` and save its weights.

    Args:
        data_path: ``.npy`` file with the inputs. Each sample is passed
            directly to ``agent.model``; with the default ``MyModel`` that
            means an array of shape (num_samples, seq_len, 10).
        labels_path: ``.npy`` file with one integer class label per sample.
        epochs: Number of passes over the dataset.
        lr: Learning rate for the Adam optimizer.
        batch_size: Mini-batch size.

    Returns:
        The trained ``StudentAgent``. Its ``state_dict`` is also written to
        ``submissions/<name>_agent.pth``.

    Raises:
        ValueError: If the dataset is empty or the number of inputs and
            labels differ.
    """
    import os

    print("Loading data...")
    X = np.load(data_path)
    Y = np.load(labels_path)
    print(f"Loaded {len(X)} samples")

    # Fail early with a clear message; otherwise an empty dataset only
    # surfaces later as a ZeroDivisionError in the accuracy computation.
    if len(X) == 0:
        raise ValueError(f"No training samples found in {data_path!r}")
    if len(X) != len(Y):
        raise ValueError(
            f"Sample/label count mismatch: {len(X)} inputs vs {len(Y)} labels"
        )

    agent = StudentAgent()
    print(f"Model has {sum(p.numel() for p in agent.parameters())} params")

    dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(X), torch.LongTensor(Y)
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(agent.parameters(), lr=lr)

    # Make training mode explicit so layers such as dropout or batch norm
    # (if a student adds them) behave correctly.
    agent.train()

    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        for batch_x, batch_y in loader:
            optimizer.zero_grad()
            out = agent.model(batch_x)
            # Recurrent models return (logits, hidden); only logits are needed.
            if isinstance(out, tuple):
                out = out[0]
            loss = criterion(out, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (torch.argmax(out, dim=1) == batch_y).sum().item()

        acc = 100 * correct / len(X)
        # Reported loss is the mean of the per-batch losses.
        print(f"Epoch {epoch+1}/{epochs}: Loss={total_loss/len(loader):.4f}, Acc={acc:.2f}%")

    # Save model
    os.makedirs("submissions", exist_ok=True)
    save_path = f"submissions/{STUDENT_INFO['name'].replace(' ', '_').lower()}_agent.pth"
    torch.save(agent.state_dict(), save_path)
    print(f"\nSaved to {save_path}")
    return agent

## Main Program - Train Your Model

In [None]:
# Hyperparameters for this run; edit these values to tune training.
epochs, lr, batch_size = 10, 0.001, 32

# Train with the default data files. The returned agent is reused by the
# visualization cells below.
agent = train(
    epochs=epochs,
    lr=lr,
    batch_size=batch_size,
)

## Visualize Tron Game (Similar to tron_env.py)

In [None]:
# Board and playback settings (intended to mirror tron_env.py).
GRID_SIZE, CELL_SIZE, FPS = 20, 20, 15

# Integer codes stored in the grid array; the value 1 is not used here.
EMPTY, P1_HEAD, P1_TRAIL, P2_HEAD, P2_TRAIL = 0, 2, 3, 4, 5

# RGB colors on a 0-255 scale (intended to mirror tron_env.py).
BLACK = (0, 0, 0)
BLUE, DARK_BLUE = (50, 50, 255), (0, 0, 150)
RED, DARK_RED = (255, 50, 50), (150, 0, 0)

def _lidar_observation(grid, pos):
    """Build the 10-element observation for the player at ``pos``.

    Elements 0-7 are ray lengths in eight directions (clockwise starting
    from "up"), each counted in cells up to and including the first wall or
    occupied cell and divided by GRID_SIZE. Elements 8 and 9 are the
    player's row and column divided by GRID_SIZE.
    """
    directions = [(-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1)]
    obs = np.zeros(10, dtype=np.float32)
    for i, (dr, dc) in enumerate(directions):
        dist = 0
        r, c = pos
        while True:
            r += dr
            c += dc
            dist += 1
            out_of_bounds = r < 0 or r >= GRID_SIZE or c < 0 or c >= GRID_SIZE
            # Bounds are tested first so the grid is never indexed off-board.
            if out_of_bounds or grid[r, c] != EMPTY:
                break
        obs[i] = min(dist, GRID_SIZE) / GRID_SIZE
    obs[8] = pos[0] / GRID_SIZE
    obs[9] = pos[1] / GRID_SIZE
    return obs


def _block_reversal(prev_dir, action):
    """Return ``prev_dir`` if ``action`` is its exact opposite, else ``action``."""
    opposite = {0: 1, 1: 0, 2: 3, 3: 2}
    # prev_dir is -1 before the first move, so nothing is blocked then.
    if opposite.get(prev_dir) == action:
        return prev_dir
    return action


def _is_crash(grid, pos):
    """Return True if ``pos`` is off the board or on a non-empty cell."""
    r, c = pos
    if r < 0 or r >= GRID_SIZE or c < 0 or c >= GRID_SIZE:
        return True
    return grid[r, c] != EMPTY


def simulate_game_visualization(agent1, agent2=None, max_steps=200, fps=15):
    """Simulate one two-player game and build an animation of it.

    Both players move simultaneously. The game stops at the first step where
    either player would leave the board, hit an occupied cell, or both would
    enter the same cell; that final crashing move is not recorded.

    Args:
        agent1: Player 1. Must provide ``reset()`` and ``get_action(obs)``
            returning an action in {0: up, 1: down, 2: left, 3: right}.
        agent2: Optional player 2 with the same interface. When ``None``,
            player 2 picks uniformly random actions.
        max_steps: Maximum number of moves to simulate.
        fps: Animation playback speed in frames per second.

    Returns:
        Tuple ``(anim, grid_history)``: the matplotlib ``FuncAnimation`` and
        the list of grid snapshots (one per recorded step, including the
        initial state).
    """
    import random

    grid = np.zeros((GRID_SIZE, GRID_SIZE), dtype=int)

    # Random spawn points, re-drawn until the players start far enough apart
    # (Manhattan distance greater than half the board size).
    while True:
        p1_r, p1_c = random.randint(2, GRID_SIZE-3), random.randint(2, GRID_SIZE-3)
        p2_r, p2_c = random.randint(2, GRID_SIZE-3), random.randint(2, GRID_SIZE-3)
        if abs(p1_r - p2_r) + abs(p1_c - p2_c) > GRID_SIZE / 2:
            break

    p1_pos = [p1_r, p1_c]
    p2_pos = [p2_r, p2_c]
    grid[p1_r, p1_c] = P1_HEAD
    grid[p2_r, p2_c] = P2_HEAD

    agent1.reset()
    # Compare against None explicitly rather than relying on the agent
    # object's truthiness.
    if agent2 is not None:
        agent2.reset()

    # -1 means "no previous move yet".
    p1_dir = -1
    p2_dir = -1

    grid_history = [grid.copy()]

    moves = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    for step in range(max_steps):
        # Each player observes the board from its own position.
        action1 = agent1.get_action(_lidar_observation(grid, p1_pos))
        if agent2 is not None:
            action2 = agent2.get_action(_lidar_observation(grid, p2_pos))
        else:
            action2 = random.randint(0, 3)

        # A 180-degree turn is replaced by "keep going"; the same rule is
        # applied to both players so they are treated symmetrically.
        action1 = _block_reversal(p1_dir, action1)
        action2 = _block_reversal(p2_dir, action2)

        p1_dir = action1
        p2_dir = action2

        # Calculate new positions
        dr, dc = moves[action1]
        new_p1 = [p1_pos[0] + dr, p1_pos[1] + dc]

        dr, dc = moves[action2]
        new_p2 = [p2_pos[0] + dr, p2_pos[1] + dc]

        # Stop on any crash, including a head-on move into the same cell.
        if _is_crash(grid, new_p1) or _is_crash(grid, new_p2) or new_p1 == new_p2:
            break

        # Old head cells become trail; new cells become heads.
        grid[p1_pos[0], p1_pos[1]] = P1_TRAIL
        grid[p2_pos[0], p2_pos[1]] = P2_TRAIL
        p1_pos = new_p1
        p2_pos = new_p2
        grid[p1_pos[0], p1_pos[1]] = P1_HEAD
        grid[p2_pos[0], p2_pos[1]] = P2_HEAD

        grid_history.append(grid.copy())

    # Create visualization
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_facecolor('black')

    # Cell code -> matplotlib RGB (0-1 range), computed once for all frames.
    palette = {
        P1_HEAD: np.array(BLUE) / 255.0,
        P1_TRAIL: np.array(DARK_BLUE) / 255.0,
        P2_HEAD: np.array(RED) / 255.0,
        P2_TRAIL: np.array(DARK_RED) / 255.0,
    }

    def draw_grid(frame):
        """Redraw the axes from scratch for snapshot number ``frame``."""
        ax.clear()
        ax.set_facecolor('black')
        ax.set_xlim(0, GRID_SIZE)
        # Inverted y-axis so row 0 is drawn at the top.
        ax.set_ylim(GRID_SIZE, 0)
        ax.set_aspect('equal')
        ax.axis('off')

        current_grid = grid_history[frame]

        for r in range(GRID_SIZE):
            for c in range(GRID_SIZE):
                color = palette.get(int(current_grid[r, c]))
                if color is None:
                    continue
                rect = plt.Rectangle((c, r), 1, 1, facecolor=color, edgecolor='none')
                ax.add_patch(rect)

        ax.set_title(f'Blindfolded Tron - Step {frame}/{len(grid_history)-1}',
                    color='white', fontsize=14)

    # Create animation
    anim = FuncAnimation(fig, draw_grid, frames=len(grid_history),
                        interval=1000//fps, blit=False)

    plt.tight_layout()
    # Close the figure so the notebook does not also show a static copy.
    plt.close()

    return anim, grid_history

# Play one game with the trained agent as player 1. No agent2 is passed, so
# player 2 moves randomly (see simulate_game_visualization).
anim, history = simulate_game_visualization(agent, max_steps=200, fps=15)
# Kept as the cell's final expression so the notebook displays the result.
HTML(anim.to_jshtml())

## Final Game State

In [None]:
# Draw the last recorded snapshot of the game as a static figure.
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_facecolor('black')
ax.set_xlim(0, GRID_SIZE)
ax.set_ylim(GRID_SIZE, 0)  # inverted so row 0 is drawn at the top
ax.set_aspect('equal')
ax.axis('off')

final_grid = history[-1]

# Cell code -> RGB tuple on the 0-255 scale; codes not listed are left blank.
cell_colors = {
    P1_HEAD: BLUE,
    P1_TRAIL: DARK_BLUE,
    P2_HEAD: RED,
    P2_TRAIL: DARK_RED,
}

for row in range(GRID_SIZE):
    for col in range(GRID_SIZE):
        rgb = cell_colors.get(int(final_grid[row, col]))
        if rgb is None:
            continue
        # matplotlib expects color components in the 0-1 range.
        ax.add_patch(
            plt.Rectangle((col, row), 1, 1,
                          facecolor=np.array(rgb) / 255.0, edgecolor='none')
        )

ax.set_title(f'Final Game State - Step {len(history)-1}', color='white', fontsize=14)
plt.tight_layout()
plt.show()