In [None]:
!pip install pygame



In [None]:
# Mount Google Drive: the sprites, checkpoints and statistics files used by the
# cells below all live under /content/drive/MyDrive/.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
os.environ['SDL_VIDEODRIVER']='dummy'

import math
import random
import pygame
import numpy as np
import pickle
from collections import namedtuple
from itertools import count
from PIL import Image
from pygame.surfarray import array3d, pixels_alpha

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

In [None]:
# Release GPU memory that PyTorch's caching allocator holds but is not currently using.
torch.cuda.empty_cache()

In [None]:
class Environment():
    """Headless Flappy Bird game built on pygame, exposing a small RL-style API.

    ``step(action)`` advances one frame (``action == 1`` flaps) and returns
    ``(screen, reward, terminal)``: the display pixels from ``array3d``, a
    reward of 0.1 for surviving the frame, 1 for passing a pipe or -1 on a
    collision, and whether the episode is over.  ``reset()`` starts a new
    episode.

    The display, clock and sprites are created once, when the class body runs
    (i.e. when this cell is executed), and are shared by all instances.

    NOTE(review): a later cell defines another ``Environment`` class; once that
    cell runs, this definition is shadowed.
    """
    pygame.init()

    width, height = (288, 512)
    screen = pygame.display.set_mode((width, height))

    clock = pygame.time.Clock()
    fps = 30

    # Background
    background = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/background-day.png").convert()

    # Floor; step() scrolls it left and wraps it back to x=0 once floorx drops below floor_limit
    floor = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/base.png").convert()
    floor_limit = background.get_width() - floor.get_width()

    # Bird animation frames
    bird_downflap = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/yellowbird-downflap.png").convert_alpha()
    bird_midflap = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/yellowbird-midflap.png").convert_alpha()
    bird_upflap = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/yellowbird-upflap.png").convert_alpha()

    bird_index = 0
    bird_frames = [bird_downflap, bird_midflap, bird_upflap]
    bird = bird_frames[bird_index]

    # Opaque-pixel masks of the unrotated frames (not used by _is_collided,
    # which masks the rotated sprite instead).
    bird_hitmask = [pixels_alpha(image).astype(bool) for image in bird_frames]

    init_pos = (
        int(width * 0.2),
        int(height / 2)
    )

    # Pipe
    pipe_surface = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/pipe-green.png").convert_alpha()
    # Legacy class-level list kept so existing references still resolve; each
    # instance gets its own list in __init__ so instances no longer share pipes.
    pipes = []

    # Height of the gap between a top and a bottom pipe, and the range of the
    # multiplier _generate_pipes uses to place that gap.
    pipe_gap = 125

    pipe_min = 1
    pipe_max = 5


    def __init__(self):
        """Set up per-episode state: floor, bird, physics, score and the first two pipe pairs."""
        self.floorx, self.floory = (0, self.background.get_height() - self.floor.get_height())
        self.bird_rect = self.bird.get_rect(center=self.init_pos)

        # Game variables
        self.GRAVITY = 1
        self.FLAP_POWER = 9
        self.MAX_DROP_SPEED = 15

        # Vertical velocity (positive = downward in screen coordinates) and
        # horizontal scroll speed of the pipes, in pixels per frame.
        self.vel = 0
        self.speed = 4

        # Score
        self.score = 0

        self.tick = 0

        # Per-instance pipe list laid out as [top0, bottom0, top1, bottom1, ...].
        self.pipes = []
        self.pipes.extend(self._generate_pipes(offset=int(0.5 * self.width)))
        self.pipes.extend(self._generate_pipes(offset=int(0.5 * self.pipe_surface.get_width() + self.width)))


    def _generate_pipes(self, offset=0):
        """Return a ``(top, bottom)`` pair of pipe rects placed ``offset`` px right of the screen.

        The gap's upper edge is at ``25 * k + 50`` where
        ``k = random.randint(pipe_min, pipe_max + 1)``.  ``randint`` is inclusive
        on both ends, so ``k`` is 1..6 with the class defaults.  The gap is
        ``pipe_gap`` pixels tall.
        """
        gap_start = random.randint(self.pipe_min, self.pipe_max + 1) * 25 + 50

        top_bottom = gap_start - self.pipe_surface.get_height()  # top pipe ends where the gap starts
        bottom_top = gap_start + self.pipe_gap                   # bottom pipe starts where the gap ends

        top_pipe = self.pipe_surface.get_rect(topleft=(self.width + offset, top_bottom))
        bottom_pipe = self.pipe_surface.get_rect(topleft=(self.width + offset, bottom_top))

        return top_pipe, bottom_pipe


    def _is_collided(self):
        """Return True if the bird left the playfield or overlaps any pipe.

        Pipe collisions are pixel-exact: every opaque pixel of the rotated bird
        sprite, placed at ``bird_rect``'s top-left corner (where ``step`` blits
        it), is tested against each pipe rectangle.
        """
        # out-of-screen: more than half a sprite above the top edge, or touching the floor
        if self.bird_rect.top <= - self.bird.get_height() * 0.5 or self.bird_rect.bottom >= self.floory:
            return True

        # pixels_alpha is indexed [x, y] (width, height), so axis 0 is the
        # horizontal offset and axis 1 the vertical offset inside the sprite.
        mask = pixels_alpha(self.rotate_bird()).astype(bool)
        xs, ys = np.nonzero(mask)
        posx = self.bird_rect.x + xs
        posy = self.bird_rect.y + ys

        pipe_w = self.pipe_surface.get_width()
        pipe_h = self.pipe_surface.get_height()

        # pipe collision (strict bounds, as in the original per-pixel test)
        for pipe in self.pipes:
            inside = ((pipe.x < posx) & (posx < pipe.x + pipe_w) &
                      (pipe.y < posy) & (posy < pipe.y + pipe_h))
            if inside.any():
                return True

        return False


    def rotate_bird(self):
        """Return the current bird frame rotated by ``-3 * vel`` degrees (nose down while falling)."""
        return pygame.transform.rotozoom(self.bird, -self.vel * 3, 1)


    def bird_animation(self):
        """Return the sprite for ``bird_index`` and a rect for it centred on the bird's current position."""
        new_bird = self.bird_frames[self.bird_index]
        # Keep the current x; a hard-coded x made the bird jump sideways from init_pos.
        new_bird_rect = new_bird.get_rect(center=(self.bird_rect.centerx, self.bird_rect.centery))
        return new_bird, new_bird_rect


    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: 1 to flap; any other value does nothing.

        Returns:
            ``(screen, reward, terminal)``: the display pixels from ``array3d``,
            the reward (0.1 alive, 1 pipe passed, -1 collision), and whether
            the episode ended.
        """
        # service pygame's event queue (no window is shown with the dummy video driver)
        pygame.event.pump()

        # reward to stay alive
        reward = 0.1
        terminal = False

        self.tick += 1

        # gravity increases the fall speed until it reaches MAX_DROP_SPEED;
        # a flap replaces the velocity with an upward kick
        if self.vel < self.MAX_DROP_SPEED:
            self.vel += self.GRAVITY

        if action == 1:
            self.vel = -self.FLAP_POWER

        # bird movement
        self.bird_rect.centery += self.vel

        # floor movement
        self.floorx -= 1
        if self.floorx < self.floor_limit:
            self.floorx = 0

        # pipes' movement
        for pipe in self.pipes:
            pipe.centerx -= self.speed

        # The bird passes a pair on the single frame in which the pipe centre
        # crosses the bird centre; top and bottom share centerx, so count once.
        for pipe in self.pipes:
            if pipe.centerx < self.bird_rect.centerx <= pipe.centerx + self.speed:
                reward = 1
                self.score += 1
                break

        # Once the leading pair is fully off-screen, spawn a new pair and drop the old one.
        if self.pipes[0].x <= -self.pipe_surface.get_width():
            self.pipes.extend(self._generate_pipes())
            del self.pipes[:2]

        # cycle the wing-flap frame every 15 ticks
        if (self.tick + 1) % 15 == 0:
            self.bird_index = (self.bird_index + 1) % 3
            self.bird, self.bird_rect = self.bird_animation()

        if self._is_collided():
            reward = -1
            terminal = True

        # draw
        self.screen.blit(self.background, (0, 0))

        # even indices are top pipes (drawn upside down), odd indices bottom pipes
        flipped_pipe = pygame.transform.flip(self.pipe_surface, False, True)
        for i, pipe in enumerate(self.pipes):
            self.screen.blit(flipped_pipe if i % 2 == 0 else self.pipe_surface, pipe)

        self.screen.blit(self.floor, (self.floorx, self.floory))

        rotated_bird = self.rotate_bird()
        self.screen.blit(rotated_bird, self.bird_rect)

        pygame.display.update()
        screen = array3d(pygame.display.get_surface())
        # NOTE(review): caps the loop at `fps` frames per second, which can throttle training
        self.clock.tick(self.fps)

        return screen, reward, terminal


    def get_screen(self):
        """Return the current display pixels as an ``array3d`` array."""
        return array3d(pygame.display.get_surface())


    def reset(self):
        """Start a new episode by clearing the pipes and re-running ``__init__``."""
        self.pipes.clear()
        self.__init__()


    def quit(self):
        """Shut pygame down."""
        pygame.quit()


In [None]:
class Environment():
    """Headless Flappy Bird game built on pygame (revised version; shadows the earlier cell's class).

    ``step(action)`` advances one frame (``action == 1`` flaps) and returns
    ``(screen, reward, terminal)``: the display pixels from
    ``pygame.surfarray.array3d``, a reward of 0.1 for surviving, 1 for passing
    a pipe or -1 on a collision, and whether the episode is over.
    ``reset()`` starts a new episode.

    The display, clock and sprites are created once, when the class body runs
    (i.e. when this cell is executed), and are shared by all instances.
    """
    pygame.init()

    width, height = (288, 512)
    screen = pygame.display.set_mode((width, height))

    clock = pygame.time.Clock()
    fps = 30

    # Background
    background = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/background-day.png").convert()

    # Floor; step() scrolls it left and wraps it back to x=0 once floorx drops below floor_limit
    floor = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/base.png").convert()
    floor_limit = background.get_width() - floor.get_width()

    # Bird animation frames
    bird_downflap = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/yellowbird-downflap.png").convert_alpha()
    bird_midflap = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/yellowbird-midflap.png").convert_alpha()
    bird_upflap = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/yellowbird-upflap.png").convert_alpha()

    bird_index = 0
    bird_frames = [bird_downflap, bird_midflap, bird_upflap]
    bird = bird_frames[bird_index]

    # Opaque-pixel masks of the unrotated frames (not used by _is_collided,
    # which masks the rotated sprite instead).
    bird_hitmask = [pixels_alpha(image).astype(bool) for image in bird_frames]

    init_pos = (
        int(width * 0.2),
        int(height / 2)
    )

    # Pipe
    pipe_surface = pygame.image.load("/content/drive/MyDrive/Colab_Notebooks/FlappyBird/assets/sprites/pipe-green.png").convert_alpha()
    # Legacy class-level list kept so existing references still resolve; each
    # instance gets its own list in _init_ so instances no longer share pipes.
    pipes = []

    # Height of the gap between a top and a bottom pipe, and the range of the
    # multiplier _generate_pipes uses to place that gap.
    pipe_gap = 125
    pipe_min = 1
    pipe_max = 5

    def __init__(self):
        """Create a ready-to-step environment (previously only ``reset()`` initialised it)."""
        self._init_()

    def _init_(self):
        """(Re)initialise per-episode state: floor, bird, physics, score and the first two pipe pairs."""
        self.floorx, self.floory = (0, self.background.get_height() - self.floor.get_height())
        self.bird_rect = self.bird.get_rect(center=self.init_pos)

        # Game variables
        self.GRAVITY = 1
        self.FLAP_POWER = 9
        self.MAX_DROP_SPEED = 15

        # Vertical velocity (positive = downward in screen coordinates) and
        # horizontal scroll speed of the pipes, in pixels per frame.
        self.vel = 0
        self.speed = 4

        # Score
        self.score = 0

        self.tick = 0

        # Per-instance pipe list laid out as [top0, bottom0, top1, bottom1, ...].
        self.pipes = []
        self.pipes.extend(self._generate_pipes(offset=int(0.5 * self.width)))
        self.pipes.extend(self._generate_pipes(offset=int(0.5 * self.pipe_surface.get_width() + self.width)))


    def _generate_pipes(self, offset=0):
        """Return a ``(top, bottom)`` pair of pipe rects placed ``offset`` px right of the screen.

        The gap's upper edge is at ``25 * k + 50`` where
        ``k = random.randint(pipe_min, pipe_max + 1)``.  ``randint`` is inclusive
        on both ends, so ``k`` is 1..6 with the class defaults.  The gap is
        ``pipe_gap`` pixels tall.
        """
        gap_start = random.randint(self.pipe_min, self.pipe_max + 1) * 25 + 50

        top_bottom = gap_start - self.pipe_surface.get_height()  # top pipe ends where the gap starts
        bottom_top = gap_start + self.pipe_gap                   # bottom pipe starts where the gap ends

        top_pipe = self.pipe_surface.get_rect(topleft=(self.width + offset, top_bottom))
        bottom_pipe = self.pipe_surface.get_rect(topleft=(self.width + offset, bottom_top))

        return top_pipe, bottom_pipe


    def _is_collided(self):
        """Return True if the bird left the playfield or overlaps any pipe.

        Pipe collisions are pixel-exact: every opaque pixel of the rotated bird
        sprite, placed at ``bird_rect``'s top-left corner (where ``step`` blits
        it), is tested against each pipe rectangle.
        """
        # out-of-screen: slightly above the top edge, or touching the floor
        if self.bird_rect.top < - self.bird.get_height() * 0.1 or self.bird_rect.bottom >= self.floory:
            return True

        # pixels_alpha is indexed [x, y] (width, height): axis 0 is the
        # horizontal offset and axis 1 the vertical offset inside the sprite.
        mask = pixels_alpha(self.rotate_bird()).astype(bool)
        xs, ys = np.nonzero(mask)
        posx = self.bird_rect.x + xs
        posy = self.bird_rect.y + ys

        pipe_w = self.pipe_surface.get_width()
        pipe_h = self.pipe_surface.get_height()

        # pipe collision (strict bounds, as in the original per-pixel test)
        for pipe in self.pipes:
            inside = ((pipe.x < posx) & (posx < pipe.x + pipe_w) &
                      (pipe.y < posy) & (posy < pipe.y + pipe_h))
            if inside.any():
                return True

        return False


    def rotate_bird(self):
        """Return the current bird frame rotated by ``-3 * vel`` degrees (nose down while falling)."""
        return pygame.transform.rotozoom(self.bird, -self.vel * 3, 1)


    def bird_animation(self):
        """Return the sprite for ``bird_index`` and a rect for it centred on the bird's current position."""
        new_bird = self.bird_frames[self.bird_index]
        # Keep the current x; a hard-coded x made the bird jump sideways from init_pos.
        new_bird_rect = new_bird.get_rect(center=(self.bird_rect.centerx, self.bird_rect.centery))
        return new_bird, new_bird_rect


    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: 1 to flap; any other value does nothing.

        Returns:
            ``(screen, reward, terminal)``: the display pixels from ``array3d``,
            the reward (0.1 alive, 1 pipe passed, -1 collision), and whether
            the episode ended.
        """
        # service pygame's event queue (no window is shown with the dummy video driver)
        pygame.event.pump()

        # reward to stay alive
        reward = 0.1
        terminal = False

        self.tick += 1

        # gravity increases the fall speed until it reaches MAX_DROP_SPEED;
        # a flap replaces the velocity with an upward kick
        if self.vel < self.MAX_DROP_SPEED:
            self.vel += self.GRAVITY

        if action == 1:
            self.vel = -self.FLAP_POWER

        # The bird passes a pair on the single frame in which the pipe centre
        # lies just left of the bird centre (checked before this frame's
        # movement); top and bottom share centerx, so count the pair once.
        for pipe in self.pipes:
            if pipe.centerx < self.bird_rect.centerx <= pipe.centerx + self.speed:
                reward = 1
                self.score += 1
                break

        # bird movement
        self.bird_rect.centery += self.vel

        # floor movement
        self.floorx -= 1
        if self.floorx < self.floor_limit:
            self.floorx = 0

        # pipes' movement
        for pipe in self.pipes:
            pipe.centerx -= self.speed

        # Once the leading pair is fully off-screen, spawn a new pair and drop the old one.
        if self.pipes[0].x <= -self.pipe_surface.get_width():
            self.pipes.extend(self._generate_pipes())
            del self.pipes[:2]

        # cycle the wing-flap frame every 15 ticks
        if (self.tick + 1) % 15 == 0:
            self.bird_index = (self.bird_index + 1) % 3
            self.bird, self.bird_rect = self.bird_animation()

        if self._is_collided():
            reward = -1
            terminal = True

        # draw
        self.screen.blit(self.background, (0, 0))

        # even indices are top pipes (drawn upside down), odd indices bottom pipes
        flipped_pipe = pygame.transform.flip(self.pipe_surface, False, True)
        for i, pipe in enumerate(self.pipes):
            self.screen.blit(flipped_pipe if i % 2 == 0 else self.pipe_surface, pipe)

        self.screen.blit(self.floor, (self.floorx, self.floory))

        rotated_bird = self.rotate_bird()
        self.screen.blit(rotated_bird, self.bird_rect)

        pygame.display.update()
        screen = pygame.surfarray.array3d(pygame.display.get_surface())
        # NOTE(review): caps the loop at `fps` frames per second, which can throttle training
        self.clock.tick(self.fps)

        return screen, reward, terminal


    def get_screen(self):
        """Return the current display pixels as an ``array3d`` array."""
        return pygame.surfarray.array3d(pygame.display.get_surface())


    def reset(self):
        """Start a new episode by clearing the pipes and re-running ``_init_``."""
        self.pipes.clear()
        self._init_()


    def quit(self):
        """Shut pygame down."""
        pygame.quit()

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network: a stack of 4 grayscale 84x84 frames -> one Q-value per action.

    The attribute names conv1/bn1/.../fc2 are part of the saved ``state_dict``
    keys, so they must stay stable for checkpoints to load.
    """

    def __init__(self, n_actions):
        super().__init__()
        # (N, 4, 84, 84) -> (N, 64, 20, 20)
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=64, kernel_size=8, stride=4)
        self.bn1 = nn.BatchNorm2d(64)
        # (N, 64, 20, 20) -> (N, 32, 9, 9)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        # (N, 32, 9, 9) -> (N, 32, 7, 7)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(32)
        # flattened 32*7*7 features -> 256 hidden units -> Q-values
        self.fc1 = nn.Linear(32 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, n_actions)

    def forward(self, x):
        """Return Q-values of shape (N, n_actions) for input frames of shape (N, 4, 84, 84)."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        hidden = F.relu(self.fc1(torch.flatten(x, start_dim=1)))
        return self.fc2(hidden)

In [None]:
# Create the headless game environment used by the training cells below.
env = Environment()

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


# One step of experience. The training cells treat a transition whose reward is
# -1 as terminal, so no separate "done" field is stored.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` records with uniform random sampling."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        # slot the next push writes to (wraps around once the buffer is full)
        self.position = 0

    def push(self, *args):
        """Store ``Transition(*args)``, overwriting the oldest record once the buffer is full."""
        record = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(record)
        else:
            self.memory[self.position] = record
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


'\nTransition = namedtuple(\'Transition\',\n                        (\'state\', \'action\', \'next_state\', \'reward\'))\n\n\nclass ReplayMemory(object):\n\n    def __init__(self, capacity):\n        self.capacity = capacity\n        self.memory = []\n        self.position = 0\n\n    def push(self, *args):\n        """Saves a transition."""\n        if len(self.memory) < self.capacity:\n            self.memory.append(None)\n        self.memory[self.position] = Transition(*args)\n        self.position = (self.position + 1) % self.capacity\n\n    def sample(self, batch_size):\n        return random.sample(self.memory, batch_size)\n\n    def __len__(self):\n        return len(self.memory)\n'

In [None]:
def init_weights(m):
    """Initialise a Conv2d/Linear layer: Xavier-normal weights, bias filled with 0.01.

    Meant for ``module.apply(init_weights)``.  The type test is an exact
    match, so subclasses and all other module types are left untouched.
    Layers built with ``bias=False`` (``m.bias is None``) only get their
    weight initialised instead of raising ``AttributeError``.

    NOTE(review): no visible cell applies this to policy_net / target_net.
    """
    if type(m) in (nn.Conv2d, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

In [None]:
# Global step counter that drives the epsilon schedule. It is set to the count
# reached by an earlier run so a resumed session continues the decay from there
# (earlier resume points kept for reference).
# steps_done = 0
# steps_done = 650000
# steps_done = 1465500
# steps_done = 1842549
steps_done = 2258029

def select_action(state):
    """Choose an action epsilon-greedily with respect to ``policy_net``.

    epsilon decays exponentially from EPS_START towards EPS_END with time
    constant EPS_DECAY, measured in calls to this function (``steps_done``).
    Reads the globals EPS_START, EPS_END, EPS_DECAY, policy_net and device.

    Returns a (1, 1) long tensor holding the chosen action index.
    """
    global steps_done
    roll = random.random()

    epsilon = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1

    if roll <= epsilon:
        # explore: both actions equally likely
        explore_action = random.choices([0, 1], [0.5, 0.5])[0]
        return torch.tensor([[explore_action]], device=device, dtype=torch.long)

    # exploit: index of the largest Q-value in the (single) row
    with torch.no_grad():
        return policy_net(state).max(1).indices.view(1, 1)

In [None]:
# Frame preprocessing: RGB array -> 1x84x84 grayscale tensor in [0, 1].
# NOTE(review): array3d arrays are (width, height, 3) while ToPILImage reads
# (height, width, channels), so frames are transposed relative to the screen;
# trained checkpoints depend on this layout, so it is left unchanged.
resize = T.Compose([T.ToPILImage(),
                    T.Grayscale(),
                    T.CenterCrop((288, 288)),
                    T.Resize(84),
                    T.ToTensor()])

# hyperparameters
BATCH_SIZE = 64
GAMMA = 0.99 # discount factor
EPS_START = 1
EPS_END = 0.1
EPS_DECAY = 400000
TARGET_UPDATE = 10          # episodes between target-network syncs
MEMORY_CAPACITY = 10000     # replay-buffer size
MAX_EPISODE_STEPS = 2000    # an episode is cut off once t exceeds this
CHECKPOINT_EVERY = 2500     # episodes between policy checkpoints
# Same Drive folder as the sprites and statistics (the old path used "Colab Notebooks").
CHECKPOINT_DIR = '/content/drive/MyDrive/Colab_Notebooks/FlappyBird/'

n_actions = 2

# create networks
policy_net = DQN(n_actions).to(device)
target_net = DQN(n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-6)
criterion = nn.MSELoss()

memory = ReplayMemory(MEMORY_CAPACITY)

average_per_episode = []
score_per_episode = []
seq_length = []

num_episodes = int(1e5)
for i_episode in range(num_episodes):
    rewards = 0

    # Initialize the environment and state
    env.reset()
    init_screen = resize(env.get_screen())
    # (N, C, H, W): (1, 4, 84, 84) - the first frame repeated 4 times
    state = torch.cat(tuple(init_screen for _ in range(4))).unsqueeze(0).to(device)

    for t in count():
        # epsilon-greedy search (select action)
        action = select_action(state)
        screen, reward, done = env.step(action.item())
        rewards += reward

        screen = resize(screen).to(device)
        # float32 whether the env returned an int (1 / -1) or a float (0.1)
        reward = torch.tensor([reward], device=device, dtype=torch.float32)

        # slide the 4-frame window: drop the oldest frame, append the newest
        next_state = torch.cat((state[0, 1:], screen)).unsqueeze(0)

        # Store the transition in memory and move to the next state
        memory.push(state, action, next_state, reward)
        state = next_state

        if len(memory) >= 2 * BATCH_SIZE:
            for _ in range(2):  # two gradient updates per environment step
                # Transition of batch-tuples (see https://stackoverflow.com/a/19343/3343043)
                batch = Transition(*zip(*memory.sample(BATCH_SIZE)))

                state_batch = torch.cat(batch.state)
                action_batch = torch.cat(batch.action)
                reward_batch = torch.cat(batch.reward)

                # Terminal transitions are the ones that got the collision reward -1.
                non_final_mask = reward_batch != -1

                # Q(s_t, a) for the actions actually taken
                state_action_values = policy_net(state_batch).gather(1, action_batch)

                # max_a Q_target(s_{t+1}, a) for non-terminal next states, 0 for terminal ones.
                # The guard avoids torch.cat/forward on an empty batch when every sample is terminal.
                next_state_values = torch.zeros(BATCH_SIZE, device=device)
                if non_final_mask.any():
                    non_final_next_states = torch.cat(batch.next_state)[non_final_mask]
                    with torch.no_grad():
                        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

                # Compute the expected Q values
                expected_state_action_values = (next_state_values * GAMMA) + reward_batch

                loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

                # Optimize the model (gradient clamping to [-1, 1] is currently disabled)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if done or t > MAX_EPISODE_STEPS:
            # t counts from 0, so t + 1 steps were taken this episode
            average_per_episode.append(rewards / (t + 1))
            score_per_episode.append(env.score)
            seq_length.append(t)
            break

    print('Episode: {}, Score: {}, Reward: {}'.format(i_episode, env.score, rewards))

    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

    if i_episode % CHECKPOINT_EVERY == 0:
        PATH = os.path.join(CHECKPOINT_DIR, 'state_dict_model_' + str(i_episode) + '.pt')
        torch.save(policy_net.state_dict(), PATH, _use_new_zipfile_serialization=False)


print('Complete')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Episode: 10010, Score: 0, Reward: 2.2000000000000015
Episode: 10011, Score: 0, Reward: 4.799999999999995
Episode: 10012, Score: 0, Reward: 2.5000000000000018
Episode: 10013, Score: 0, Reward: 2.1000000000000014
Episode: 10014, Score: 0, Reward: 3.0000000000000018
Episode: 10015, Score: 0, Reward: 2.7000000000000024
Episode: 10016, Score: 0, Reward: 5.5999999999999925
Episode: 10017, Score: 0, Reward: 3.1000000000000014
Episode: 10018, Score: 0, Reward: 3.5999999999999996
Episode: 10019, Score: 0, Reward: 2.900000000000002
Episode: 10020, Score: 0, Reward: 2.2000000000000015
Episode: 10021, Score: 0, Reward: 4.299999999999997
Episode: 10022, Score: 0, Reward: 2.800000000000002
Episode: 10023, Score: 0, Reward: 2.4000000000000017
Episode: 10024, Score: 0, Reward: 2.1000000000000014
Episode: 10025, Score: 0, Reward: 2.2000000000000015
Episode: 10026, Score: 0, Reward: 2.900000000000002
Episode: 10027, Score: 0, Reward: 3.200

KeyboardInterrupt: ignored

In [None]:
# Show the epsilon-schedule counter; select_action's resume value is set from this.
steps_done

2258545

In [None]:
import pickle

# Persist the per-episode training statistics (plain Python lists) to Drive so
# a later session can reload and extend them.
for stats_path, stats in (
    ('/content/drive/MyDrive/Colab_Notebooks/FlappyBird/avg.txt', average_per_episode),
    ('/content/drive/MyDrive/Colab_Notebooks/FlappyBird/score.txt', score_per_episode),
    ('/content/drive/MyDrive/Colab_Notebooks/FlappyBird/sequence.txt', seq_length),
):
    with open(stats_path, 'wb') as fp:
        pickle.dump(stats, fp)

In [None]:
# Reload the per-episode statistics pickled by the save cell. That cell writes
# them under Colab_Notebooks (underscore), the same folder as the sprites and
# checkpoints, so they are read from there too (the old "Colab Notebooks" path
# pointed at a different directory).
# NOTE: pickle.load can execute arbitrary code; only load files you wrote yourself.
STATS_DIR = '/content/drive/MyDrive/Colab_Notebooks/FlappyBird/'

with open(STATS_DIR + 'avg.txt', 'rb') as fp:
    average_per_episode = pickle.load(fp)

with open(STATS_DIR + 'score.txt', 'rb') as fp:
    score_per_episode = pickle.load(fp)

with open(STATS_DIR + 'sequence.txt', 'rb') as fp:
    seq_length = pickle.load(fp)

In [None]:
# Frame preprocessing: RGB array -> 1x84x84 grayscale tensor in [0, 1].
# NOTE(review): array3d arrays are (width, height, 3) while ToPILImage reads
# (height, width, channels), so frames are transposed relative to the screen;
# trained checkpoints depend on this layout, so it is left unchanged.
resize = T.Compose([T.ToPILImage(),
                    T.Grayscale(),
                    T.CenterCrop((288, 288)),
                    T.Resize(84),
                    T.ToTensor()])

# hyperparameters
BATCH_SIZE = 64
GAMMA = 0.99 # discount factor
EPS_START = 1
EPS_END = 0.05
EPS_DECAY = 400000
TARGET_UPDATE = 10          # episodes between target-network syncs
MEMORY_CAPACITY = 10000     # only used when no replay buffer exists in the kernel yet
MAX_EPISODE_STEPS = 5000    # an episode is cut off once t exceeds this
CHECKPOINT_EVERY = 500      # episodes between policy checkpoints
CHECKPOINT_DIR = '/content/drive/MyDrive/Colab_Notebooks/FlappyBird/'
RESUME_CHECKPOINT = os.path.join(CHECKPOINT_DIR, 'state_dict_model_27500.pt')
# NOTE(review): the logged output of this cell starts at episode 27501, so this
# value was edited after that run; keep it consistent with RESUME_CHECKPOINT
# and with the steps_done value used by select_action.
START_EPISODE = 30001

n_actions = 2

# create networks, resuming the policy from a checkpoint; map_location lets a
# checkpoint written on a GPU runtime load on a CPU-only one (and vice versa)
policy_net = DQN(n_actions).to(device)
policy_net.load_state_dict(torch.load(RESUME_CHECKPOINT, map_location=device))

target_net = DQN(n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-6)
criterion = nn.MSELoss()

# Reuse the replay buffer and statistics lists already in the kernel (from the
# earlier training / loading cells) when present; otherwise start empty so
# this cell also runs on a fresh kernel.
if 'memory' not in globals():
    memory = ReplayMemory(MEMORY_CAPACITY)
for _stats_name in ('average_per_episode', 'score_per_episode', 'seq_length'):
    globals().setdefault(_stats_name, [])

num_episodes = int(1e5)
for i_episode in range(START_EPISODE, num_episodes):
    rewards = 0

    # Initialize the environment and state
    env.reset()
    init_screen = resize(env.get_screen())
    # (N, C, H, W): (1, 4, 84, 84) - the first frame repeated 4 times
    state = torch.cat(tuple(init_screen for _ in range(4))).unsqueeze(0).to(device)

    for t in count():
        # epsilon-greedy search (select action)
        action = select_action(state)
        screen, reward, done = env.step(action.item())
        rewards += reward

        screen = resize(screen).to(device)
        # float32 whether the env returned an int (1 / -1) or a float (0.1)
        reward = torch.tensor([reward], device=device, dtype=torch.float32)

        # slide the 4-frame window: drop the oldest frame, append the newest
        next_state = torch.cat((state[0, 1:], screen)).unsqueeze(0)

        # Store the transition in memory and move to the next state
        memory.push(state, action, next_state, reward)
        state = next_state

        if len(memory) >= 2 * BATCH_SIZE:
            for _ in range(2):  # two gradient updates per environment step
                # Transition of batch-tuples (see https://stackoverflow.com/a/19343/3343043)
                batch = Transition(*zip(*memory.sample(BATCH_SIZE)))

                state_batch = torch.cat(batch.state)
                action_batch = torch.cat(batch.action)
                reward_batch = torch.cat(batch.reward)

                # Terminal transitions are the ones that got the collision reward -1.
                non_final_mask = reward_batch != -1

                # Q(s_t, a) for the actions actually taken
                state_action_values = policy_net(state_batch).gather(1, action_batch)

                # max_a Q_target(s_{t+1}, a) for non-terminal next states, 0 for terminal ones.
                # The guard avoids torch.cat/forward on an empty batch when every sample is terminal.
                next_state_values = torch.zeros(BATCH_SIZE, device=device)
                if non_final_mask.any():
                    non_final_next_states = torch.cat(batch.next_state)[non_final_mask]
                    with torch.no_grad():
                        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

                # Compute the expected Q values
                expected_state_action_values = (next_state_values * GAMMA) + reward_batch

                loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

                # Optimize the model (gradient clamping to [-1, 1] is currently disabled)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if done or t > MAX_EPISODE_STEPS:
            # t counts from 0, so t + 1 steps were taken this episode
            average_per_episode.append(rewards / (t + 1))
            score_per_episode.append(env.score)
            seq_length.append(t)
            break

    print('Episode: {}, Score: {}, Reward: {}, Steps done: {}'.format(i_episode, env.score, rewards, steps_done))

    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

    if i_episode % CHECKPOINT_EVERY == 0:
        PATH = os.path.join(CHECKPOINT_DIR, 'state_dict_model_' + str(i_episode) + '.pt')
        torch.save(policy_net.state_dict(), PATH, _use_new_zipfile_serialization=False)


print('Complete')

Episode: 27501, Score: 2, Reward: 17.099999999999994, Steps done: 1842713
Episode: 27502, Score: 2, Reward: 16.999999999999993, Steps done: 1842876
Episode: 27503, Score: 0, Reward: 6.799999999999988, Steps done: 1842955
Episode: 27504, Score: 0, Reward: 7.499999999999986, Steps done: 1843041
Episode: 27505, Score: 1, Reward: 11.899999999999974, Steps done: 1843162
Episode: 27506, Score: 2, Reward: 17.099999999999994, Steps done: 1843326
Episode: 27507, Score: 2, Reward: 16.69999999999999, Steps done: 1843486
Episode: 27508, Score: 2, Reward: 17.5, Steps done: 1843654
Episode: 27509, Score: 0, Reward: 5.099999999999994, Steps done: 1843716
Episode: 27510, Score: 2, Reward: 17.299999999999997, Steps done: 1843882
Episode: 27511, Score: 1, Reward: 11.999999999999973, Steps done: 1844004
Episode: 27512, Score: 2, Reward: 17.4, Steps done: 1844171
Episode: 27513, Score: 0, Reward: 6.799999999999988, Steps done: 1844250
Episode: 27514, Score: 0, Reward: 6.899999999999988, Steps done: 184433

KeyboardInterrupt: ignored