<a href="https://colab.research.google.com/github/Shambhavi410/BCS-Winter-Project-STACK-O-MATIC/blob/main/BCS_Assignment4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gym
!pip install Box2D
!pip install tetris-gymnasium
import gymnasium as gym
from tetris_gymnasium.envs import Tetris
import matplotlib.pyplot as plt
import numpy as np
from collections import deque
import random
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
from IPython.display import Video
from gym.wrappers import RecordVideo
class DQN(nn.Module):
    """Single-hidden-layer MLP that maps a flat state vector to Q-values.

    Args:
        in_states: length of the input state vector.
        h1_nodes: width of the hidden layer.
        out_actions: number of discrete actions, i.e. the output size.
    """

    def __init__(self, in_states, h1_nodes, out_actions):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them.
        self.fc1 = nn.Linear(in_states, h1_nodes)
        self.out = nn.Linear(h1_nodes, out_actions)

    def forward(self, x):
        """Return one Q-value per action: ReLU hidden layer, linear output."""
        return self.out(F.relu(self.fc1(x)))
class ReplayMemory():
    """Bounded FIFO store of transitions with uniform random sampling.

    Args:
        maxlen: capacity; once full, appending evicts the oldest transition.
    """

    def __init__(self, maxlen):
        self.memory = deque(maxlen=maxlen)

    def append(self, transition):
        """Store one transition (any object; train() stores tuples)."""
        self.memory.append(transition)

    def sample(self, sample_size):
        """Return ``sample_size`` distinct stored transitions chosen at random.

        Raises ``ValueError`` (from ``random.sample``) if fewer are stored.
        """
        return random.sample(self.memory, sample_size)

    def __len__(self):
        return len(self.memory)
class TetrisDQL():
    """Deep Q-learning agent for the ``tetris_gymnasium/Tetris`` environment.

    Hyperparameters are class attributes. ``train`` fits a ``DQN`` policy and
    saves it to ``tetris_dqn.pt``; ``test`` reloads that file and records
    greedy rollouts to ``videos/``.
    """

    learning_rate_a = 0.001      # Adam learning rate
    discount_factor_g = 0.90     # discount (gamma) for bootstrapped targets
    network_sync_rate = 10       # env steps between target-net syncs (checked once per episode)
    replay_memory_size = 1000    # replay buffer capacity
    mini_batch_size = 32         # transitions sampled per optimisation step
    loss_fn = nn.MSELoss()
    optimizer = None             # created in train()
    # Human-readable action labels; presumably indexed by action id (not used below).
    ACTIONS = ['L', 'R', 'D', 'RC', 'RAC', 'HARD_DROP', 'SWAP', 'NO_Operation']
    # Observation keys concatenated, in this order, into the network input.
    STATE_KEYS = ('active_tetromino_mask', 'board', 'holder', 'queue')

    def preprocess_state(self, state):
        """Flatten a dict observation into one 1-D ``float32`` vector.

        Args:
            state: mapping with at least the keys in ``STATE_KEYS``; each value
                is array-like and is flattened in ``STATE_KEYS`` order.

        Returns:
            ``np.ndarray`` of dtype ``float32`` (the network's input dtype).
        """
        return np.concatenate(
            [np.asarray(state[key]).ravel() for key in self.STATE_KEYS]
        ).astype(np.float32, copy=False)

    def _state_size(self, env):
        """Return the length of the vector ``preprocess_state`` yields for ``env``.

        Derived from a sampled observation so it always agrees with
        ``preprocess_state``: only ``STATE_KEYS`` count, any rank works.
        """
        return self.preprocess_state(env.observation_space.sample()).size

    def _reset(self, env):
        """Reset ``env`` and return the preprocessed initial observation."""
        state = env.reset()
        # Gymnasium returns (obs, info); older APIs return obs alone.
        if isinstance(state, tuple):
            state = state[0]
        return self.preprocess_state(state)

    def _step(self, env, action):
        """Step ``env`` and return ``(state, reward, terminated, truncated)``.

        ``state`` is preprocessed. Accepts both the 5-tuple step API
        ``(obs, reward, terminated, truncated, info)`` and the legacy 4-tuple
        ``(obs, reward, done, info)``; the latter reports ``truncated=False``.
        """
        result = env.step(action)
        if len(result) == 5:
            new_state, reward, terminated, truncated, _info = result
        else:
            new_state, reward, terminated, _info = result
            truncated = False
        if isinstance(new_state, tuple):
            new_state = new_state[0]
        return self.preprocess_state(new_state), reward, terminated, truncated

    @staticmethod
    def _greedy_action(policy_dqn, state):
        """Return the index of the highest-Q action for one preprocessed state."""
        with torch.no_grad():
            q_values = policy_dqn(torch.as_tensor(state, dtype=torch.float32).unsqueeze(0))
        return q_values.argmax().item()

    def train(self, episodes, render=False):
        """Train an epsilon-greedy DQN for ``episodes`` episodes.

        Transitions are collected every step. One mini-batch optimisation
        step, one epsilon decay, and a possible target-network sync happen at
        the end of each episode, once the buffer holds more than
        ``mini_batch_size`` transitions.

        Args:
            episodes: number of episodes to run.
            render: currently unused; kept for interface compatibility.

        Side effects:
            Writes the policy weights to ``tetris_dqn.pt`` and shows a
            matplotlib figure of per-episode reward and epsilon.
        """
        env = gym.make("tetris_gymnasium/Tetris")
        print(f"Observation Space: {env.observation_space}")

        num_states = self._state_size(env)
        num_actions = env.action_space.n
        epsilon = 1
        epsilon_min = 0.001
        epsilon_decay = 0.99
        memory = ReplayMemory(self.replay_memory_size)
        policy_dqn = DQN(in_states=num_states, h1_nodes=64, out_actions=num_actions)
        target_dqn = DQN(in_states=num_states, h1_nodes=64, out_actions=num_actions)
        target_dqn.load_state_dict(policy_dqn.state_dict())
        self.optimizer = torch.optim.Adam(policy_dqn.parameters(), lr=self.learning_rate_a)
        rewards_per_episode = np.zeros(episodes)
        epsilon_history = []
        steps_since_sync = 0

        try:
            for i in range(episodes):
                state = self._reset(env)
                done = False
                episode_reward = 0
                print(f'Episode {i+1}/{episodes},Epsilon: {epsilon:.4f}')
                while not done:
                    if random.random() < epsilon:
                        action = env.action_space.sample()
                    else:
                        action = self._greedy_action(policy_dqn, state)

                    new_state, reward, terminated, truncated = self._step(env, action)
                    # Store `terminated` (not truncation): a time-limit cut
                    # should still bootstrap from the next state.
                    memory.append((state, action, new_state, reward, terminated))
                    state = new_state
                    episode_reward += reward
                    steps_since_sync += 1
                    done = terminated or truncated

                print(f'Episode {i+1} Reward: {episode_reward}')
                rewards_per_episode[i] = episode_reward
                if len(memory) > self.mini_batch_size:
                    mini_batch = memory.sample(self.mini_batch_size)
                    self.optimize(mini_batch, policy_dqn, target_dqn)
                    epsilon = max(epsilon * epsilon_decay, epsilon_min)
                    epsilon_history.append(epsilon)
                    if steps_since_sync > self.network_sync_rate:
                        target_dqn.load_state_dict(policy_dqn.state_dict())
                        steps_since_sync = 0
        finally:
            env.close()

        torch.save(policy_dqn.state_dict(), 'tetris_dqn.pt')
        plt.subplot(121)
        plt.plot(rewards_per_episode, label='Episode Reward')
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.title('Reward vs Episode')
        plt.legend()
        plt.subplot(122)
        # One entry per optimisation step (episodes before the buffer fills are absent).
        plt.plot(epsilon_history, label='Epsilon')
        plt.xlabel('Episode')
        plt.ylabel('Epsilon')
        plt.title('Epsilon vs Episode')
        plt.legend()
        plt.show()

    def optimize(self, mini_batch, policy_dqn, target_dqn):
        """Take one optimiser step on ``mini_batch`` and return the loss.

        Args:
            mini_batch: non-empty sequence of ``(state, action, new_state,
                reward, terminated)`` tuples.
            policy_dqn: network being trained; ``self.optimizer`` must own its
                parameters.
            target_dqn: network providing ``max_a' Q_target(new_state, a')`` for
                the target ``reward + gamma * max``. The bootstrap term is
                dropped for terminal transitions.

        Returns:
            The scalar loss as a Python float.
        """
        states, actions, new_states, rewards, terminated = zip(*mini_batch)
        states_t = torch.as_tensor(np.array(states), dtype=torch.float32)
        new_states_t = torch.as_tensor(np.array(new_states), dtype=torch.float32)
        actions_t = torch.as_tensor(np.asarray(actions, dtype=np.int64))
        rewards_t = torch.as_tensor(np.asarray(rewards, dtype=np.float32))
        terminated_t = torch.as_tensor(np.asarray(terminated, dtype=bool))

        current_q = policy_dqn(states_t)
        with torch.no_grad():
            next_q = target_dqn(new_states_t).max(dim=1).values
            targets = torch.where(
                terminated_t, rewards_t, rewards_t + self.discount_factor_g * next_q
            )
            # The target equals the (detached) prediction everywhere except at
            # the taken action, so only that entry contributes to the loss and
            # gradient; the MSE is still averaged over batch * num_actions.
            target_q = current_q.detach().clone()
            target_q[torch.arange(len(mini_batch)), actions_t] = targets

        loss = self.loss_fn(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def test(self, episodes):
        """Roll out the saved policy greedily for ``episodes`` episodes.

        Loads weights from ``tetris_dqn.pt`` (written by ``train``) and records
        every episode to ``videos/`` with file prefix ``test``.
        """
        env = gym.make("tetris_gymnasium/Tetris", render_mode='rgb_array')
        # ``gym`` is gymnasium in this file: use its RecordVideo so the wrapper
        # matches the environment, not the separate legacy ``gym`` package's.
        env = gym.wrappers.RecordVideo(
            env, video_folder="videos", episode_trigger=lambda e: True, name_prefix="test"
        )
        try:
            num_states = self._state_size(env)
            num_actions = env.action_space.n
            policy_dqn = DQN(in_states=num_states, h1_nodes=64, out_actions=num_actions)
            policy_dqn.load_state_dict(torch.load("tetris_dqn.pt", map_location="cpu"))
            policy_dqn.eval()
            for i in range(episodes):
                state = self._reset(env)
                done = False
                episode_reward = 0
                print(f'Test Episode {i+1}/{episodes}')
                while not done:
                    action = self._greedy_action(policy_dqn, state)
                    state, reward, terminated, truncated = self._step(env, action)
                    episode_reward += reward
                    done = terminated or truncated
                print(f'Test Episode {i+1} Reward: {episode_reward}')
        finally:
            # Closing finalises the recorded video even if an episode fails.
            env.close()

if __name__ == '__main__':
    # Train (saving weights to tetris_dqn.pt), then evaluate the saved policy.
    tetris = TetrisDQL()
    tetris.train(episodes=20000)
    tetris.test(episodes=10)
