In [None]:
!pip install gym-tetris
!apt-get install -y xvfb x11-utils
!pip install pyvirtualdisplay==0.2.*

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gym-tetris
  Downloading gym_tetris-3.0.4-py3-none-any.whl (34 kB)
Collecting nes-py>=8.1.4
  Downloading nes_py-8.2.1.tar.gz (77 kB)
[K     |████████████████████████████████| 77 kB 4.4 MB/s 
Collecting pyglet<=1.5.21,>=1.4.0
  Downloading pyglet-1.5.21-py3-none-any.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 54.2 MB/s 
Building wheels for collected packages: nes-py
  Building wheel for nes-py (setup.py) ... [?25l[?25hdone
  Created wheel for nes-py: filename=nes_py-8.2.1-cp38-cp38-linux_x86_64.whl size=438555 sha256=27632337f07bee8a91f65f4d2326014424a9828591050046cfd812f6cbb559ba
  Stored in directory: /root/.cache/pip/wheels/17/e5/5c/8dfae61b44dbf56c458483aa09accef55a650e0527f6cbd872
Successfully built nes-py
Installing collected packages: pyglet, nes-py, gym-tetris
Successfully installed gym-tetris-3.0.4 nes-py-8.2.1 pyglet-1.5.21
Reading package lists.

In [None]:
#This is setting up the relevant packages
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
from PIL import Image
from scipy.ndimage import zoom
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torch.distributions import Categorical

from nes_py.wrappers import JoypadSpace
import gym_tetris
from gym_tetris.actions import MOVEMENT

import time

# Base Tetris environment (`.unwrapped`), exposed to the agent through the
# reduced MOVEMENT action set.
env = gym_tetris.make('TetrisA-v3').unwrapped
env = JoypadSpace(env, MOVEMENT)

# Seed the Python, NumPy and PyTorch RNGs so network weight initialisation and
# action sampling repeat across runs (GPU kernels may still add nondeterminism).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

#set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  deprecation(
  deprecation(


In [None]:
#Actor
from torch.distributions import Categorical
class Actor(nn.Module):
  """Policy network: maps a preprocessed state to a categorical action distribution.

  The whole input tensor is flattened (including any leading batch axis), so
  the network scores exactly one state per call.

  Args (constructor):
    inputSize: number of elements in the flattened state.
    outputSize: number of discrete actions.
  """
  def __init__(self, inputSize, outputSize):
    super().__init__()
    # Attribute names are the state_dict keys used by saved checkpoints.
    self.lin1 = nn.Linear(inputSize, 128)
    self.lin2 = nn.Linear(128, 256)
    self.head = nn.Linear(256, outputSize)

  def forward(self, state):
    """Return a Categorical distribution over actions for a single state."""
    state = state.flatten()
    output = F.relu(self.lin1(state))
    output = F.relu(self.lin2(output))
    # Hand the raw scores to Categorical as logits: it normalises with
    # log-softmax, so log_prob of an unlikely action keeps its true (very
    # negative) value instead of saturating at log(eps) of a clamped probability.
    return Categorical(logits=self.head(output))

In [None]:
#Critic
class Critic(nn.Module):
  """State-value network: maps one preprocessed state to a scalar estimate V(s).

  The attribute names lin1 / lin2 / head become the state_dict keys that the
  notebook saves and loads, so they must not change.
  """
  def __init__(self, inputSize):
    super().__init__()
    # Layers are created in the same order so seeded initialisation is unchanged.
    self.lin1 = nn.Linear(inputSize, 128)
    self.lin2 = nn.Linear(128, 256)
    self.head = nn.Linear(256, 1)

  def forward(self, state):
    """Flatten the whole input and return a 1-element value tensor."""
    hidden = state.flatten()
    for layer in (self.lin1, self.lin2):
      hidden = F.relu(layer(hidden))
    return self.head(hidden)

In [None]:
def stateResize(state, target_device=None):
  """Turn a raw frame into a small binary tensor for the actor/critic networks.

  Pipeline: channels-first, fixed crop, binarise, 8x spatial downsample with
  scipy `zoom`, re-binarise.

  Args:
    state: frame array in (H, W, C) layout; presumably the NES screen
      (240x256x3) returned by the env — TODO confirm. Must be at least
      208 rows by 176 columns for the crop below.
    target_device: torch device for the result. Defaults to the notebook-level
      `device`.

  Returns:
    float32 tensor of shape (1, C, 20, 10) whose values are 0.0 or 255.0.
  """
  if target_device is None:
    target_device = device
  state = np.moveaxis(state, 2, -3)           # (H, W, C) -> (C, H, W)
  state = state[:, 48:208, 96:176]            # fixed 160x80 crop (presumably the playfield)
  state = (state > 0).astype(np.uint8) * 255  # any lit pixel -> 255
  state = zoom(state, (1, 0.125, 0.125))      # downsample H and W by 8 -> 20x10
  state = (state > 220).astype(np.float32) * 255
  # Build the tensor directly on the target device; a `.type('torch.FloatTensor')`
  # cast would force it onto the CPU and mismatch CUDA-resident networks.
  return torch.as_tensor(state, device=target_device).unsqueeze(0)

In [None]:
# Both networks read the flattened stateResize output: presumably (1, 3, 20, 10)
# for an RGB frame, i.e. 3 * 20 * 10 = 600 input features.
n_actions = env.action_space.n
actor, critic = Actor(600, n_actions).to(device), Critic(600).to(device)

# One Adam optimiser (default hyper-parameters) per network.
actorOptimizer, criticOptimizer = (
    optim.Adam(actor.parameters()),
    optim.Adam(critic.parameters()),
)


def optimize(actorLoss, criticLoss):
  """Apply one optimiser step to the actor and one to the critic.

  Clears the gradients of both optimisers, back-propagates each loss, then
  steps both optimisers. Uses the notebook-level `actorOptimizer` and
  `criticOptimizer`.
  """
  optimizers = (actorOptimizer, criticOptimizer)
  for opt in optimizers:
    opt.zero_grad()
  for loss in (actorLoss, criticLoss):
    loss.backward()
  for opt in optimizers:
    opt.step()

  

In [None]:
def compute_returns(next_value, rewards, masks, gamma = 0.99):
  """Compute discounted returns backwards from a bootstrap value.

  G_t = rewards[t] + gamma * G_{t+1} * masks[t], with G_{len(rewards)} = next_value.
  A mask of 0 discards the carried-over return at that step.

  Returns:
    list of returns aligned with `rewards` (index 0 is the earliest step).
  """
  running = next_value
  backwards = []
  step = len(rewards) - 1
  while step >= 0:
    running = rewards[step] + gamma * running * masks[step]
    backwards.append(running)
    step -= 1
  return backwards[::-1]

In [None]:
# One-step (online) actor-critic: after every environment step both networks are
# updated from the TD error  r + gamma * V(s') - V(s),  with V(s') := 0 at termination.
num_episodes = 5
DISCOUNT_FACTOR = 0.99

for episode in range(num_episodes):
  state = stateResize(env.reset())  # exactly one reset per episode
  score = 0  # undiscounted sum of rewards in this episode
  I = 1      # gamma**t, scales both updates by how far into the episode we are

  for t in count():
    # Policy distribution and value estimate of the current state
    dist, currValue = actor(state), critic(state)
    action = dist.sample()

    next_frame, reward, done, info = env.step(action.item())
    score += reward
    next_state = stateResize(next_frame)

    # TD target: treated as a fixed regression target (semi-gradient TD), so it is
    # built without autograd; a terminal state has no future value to bootstrap.
    with torch.no_grad():
      nextValue = torch.zeros_like(currValue) if done else critic(next_state)
    td_target = reward + DISCOUNT_FACTOR * nextValue

    # Critic: pull V(s) towards the TD target
    critic_loss = F.mse_loss(currValue, td_target) * I

    # Actor: policy-gradient step weighted by the (constant) TD error
    advantage = (td_target - currValue).detach()
    log_prob = dist.log_prob(action).unsqueeze(0)
    actor_loss = -log_prob * advantage * I

    optimize(actor_loss, critic_loss)

    if done:
      break
    state = next_state
    I *= DISCOUNT_FACTOR

  print(f'episode {episode}: score {score} in {t + 1} steps')



In [None]:
# Persist only the learned parameters (state_dicts) of both networks, to the
# files 'actorNetwork' and 'criticNetwork' in the working directory.
for net, filename in ((actor, 'actorNetwork'), (critic, 'criticNetwork')):
  torch.save(net.state_dict(), filename)

In [None]:
# Restore both networks from the checkpoints written by torch.save above. The
# paths are relative, matching the names the save cell writes, rather than
# Colab-only absolute /content/... paths. Tensors are mapped straight onto the
# device the networks live on.
actor.load_state_dict(torch.load('actorNetwork', map_location=device))
critic.load_state_dict(torch.load('criticNetwork', map_location=device))

<All keys matched successfully>

In [None]:
from gym.wrappers.monitoring.video_recorder import VideoRecorder
testVideo = "testVideo.mp4"

# A fresh environment for the evaluation rollout, wrapped the same way as the
# training one, plus a recorder targeting `testVideo`.
env = JoypadSpace(gym_tetris.make('TetrisA-v3').unwrapped, MOVEMENT)
video = VideoRecorder(env, testVideo)

duration = 0
state = stateResize(env.reset())  # preprocessed first observation
from base64 import b64encode
def render_mp4(videopath: str) -> str:
  """Embed an MP4 file as an inline HTML5 <video> element.

  Args:
    videopath: path of the MP4 file to embed.

  Returns:
    HTML string whose <source> is a base64 `data:video/mp4` URI of the file.
  """
  # Context manager so the file handle is closed instead of leaking.
  with open(videopath, 'rb') as video_file:
    mp4 = video_file.read()
  base64_encoded_mp4 = b64encode(mp4).decode()
  return f'<video width=400 controls><source src="data:video/mp4;' \
         f'base64,{base64_encoded_mp4}" type="video/mp4"></video>'

  deprecation(
  deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(


In [None]:

# Roll out one episode with the trained policy, capturing a frame before every action.
try:
  with torch.no_grad():  # inference only: no autograd graph is needed
    while True:
      video.capture_frame()

      # Select and perform an action sampled from the policy
      dist = actor(state)
      action = dist.sample()
      next_state, reward, done, info = env.step(action.item())
      next_state = stateResize(next_state)

      # Move to the next state
      state = next_state
      if done:
        break
finally:
  # Close the recorder even if the rollout raises.
  video.close()


See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(


In [None]:
from IPython.display import HTML
# Embed the recorded evaluation episode in the cell output (last expression is displayed).
html = render_mp4(videopath=testVideo)
HTML(html)