In [37]:
import numpy as np
import torch
import pygame
import random
from torch import nn
from torch.distributions import Categorical
import torch.nn.functional as F
import copy
from collections import deque
import tqdm

# RGB palette indexed by a cell / figure color value (Figure.color is 1 and
# empty cells are 0). NOTE(review): not referenced by the game or training code
# in the cells shown here — presumably meant for a pygame renderer; confirm.
colors = [
    (0, 0, 0),
    (120, 37, 179),
    (100, 179, 179),
    (80, 34, 22),
    (80, 134, 22),
    (180, 34, 22),
    (180, 34, 122),
]


In [38]:
# Prefer the GPU when one is available.
# NOTE(review): in the cells shown here the networks are never moved with
# .to(device), so training actually runs on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [39]:
class Event():
    """Lightweight stand-in for a pygame event carrying only ``type`` and ``key``."""

    # Class-level defaults, overwritten per instance by __init__.
    type = None
    key = None

    def __init__(self, type, key):
        self.type, self.key = type, key

In [40]:
class Figure:
    """A tetromino positioned on the board.

    Every shape is a list of rotations; each rotation lists the four occupied
    cells of a 4x4 box numbered row-major (cell = 4 * row + col). ``x``/``y``
    are the board column/row of the box's top-left corner.
    """
    x = 0
    y = 0

    # Rotation tables for the seven shape types.
    figures = [
        [[1, 5, 9, 13], [4, 5, 6, 7]],
        [[4, 5, 9, 10], [2, 6, 5, 9]],
        [[6, 7, 9, 10], [1, 5, 6, 10]],
        [[1, 2, 5, 9], [0, 4, 5, 6], [1, 5, 9, 8], [4, 5, 6, 10]],
        [[1, 2, 6, 10], [5, 6, 7, 9], [2, 6, 10, 11], [3, 5, 6, 7]],
        [[1, 4, 5, 6], [1, 4, 5, 9], [4, 5, 6, 9], [1, 5, 6, 9]],
        [[1, 2, 5, 6]],
    ]

    def __init__(self, x, y):
        self.x, self.y = x, y
        # Shape type is drawn uniformly at random from the table above.
        self.type = random.randint(0, len(self.figures) - 1)
        self.color = 1
        self.rotation = 0

    def image(self):
        """Return the four occupied box cells for the current rotation."""
        rotations = self.figures[self.type]
        return rotations[self.rotation]

    def rotate(self):
        """Advance to the next rotation, wrapping back to the first one."""
        n_rotations = len(self.figures[self.type])
        self.rotation = (self.rotation + 1) % n_rotations


In [41]:
class Tetris:
    """Tetris board plus the piece mechanics used by the agent.

    ``field`` is a list of ``height`` rows (row 0 at the top, pieces fall
    toward larger row indices), each a list of ``width`` ints: 0 marks an empty
    cell, any other value a settled block.
    """

    def __init__(self, height, width):
        # Display-related settings (level, board origin, cell size); none of the
        # methods below read them. Presumably used by a renderer — confirm.
        self.level = 2
        self.x = 100
        self.y = 60
        self.zoom = 20
        self.figure = None

        self.height = height
        self.width = width
        # Builds the empty field and initialises score / done.
        self.reset()

    def reset(self):
        """Empty the field and clear the score and game-over flag.

        The current figure is left unchanged.
        """
        self.field = [[0] * self.width for _ in range(self.height)]
        self.score = 0
        self.done = False

    def new_figure(self):
        """Spawn a new random figure with its box at column 3, row 0."""
        self.figure = Figure(3, 0)

    def intersects(self, figure):
        """Return True if ``figure`` overlaps a settled block or sticks out of
        the bottom, left or right edge of the field.

        The top edge is not checked.
        """
        for cell in figure.image():
            # Figure cells index a 4x4 box in row-major order.
            i, j = divmod(cell, 4)
            row, col = i + figure.y, j + figure.x
            # Bounds are tested before indexing so the lookup cannot overrun.
            if row > self.height - 1 or col > self.width - 1 or col < 0 \
                    or self.field[row][col] > 0:
                return True
        return False

    def step(self, figure):
        """Return a copy of the field with ``figure`` stamped in at its position.

        ``self.field`` is not modified and no rows are cleared; the caller is
        expected to pass a figure that fits (see ``intersects``).
        """
        # Cells are plain ints, so copying each row is a full deep copy.
        state = [row[:] for row in self.field]
        for cell in figure.image():
            i, j = divmod(cell, 4)
            state[i + figure.y][j + figure.x] = figure.color
        return state

    def break_lines(self):
        """Remove every completely filled row and add ``lines ** 2`` to the score.

        All rows above a removed row move down and the same number of empty rows
        is inserted at the top, so the field keeps its size. The field list is
        updated in place so existing references to it stay valid.
        """
        remaining = [row for row in self.field if 0 in row]
        lines = len(self.field) - len(remaining)
        if lines:
            self.field[:] = [[0] * self.width for _ in range(lines)] + remaining
        self.score += lines ** 2

    def go_space(self):
        """Hard-drop the current figure to its lowest free position and freeze it."""
        while not self.intersects(self.figure):
            self.figure.y += 1
        self.figure.y -= 1
        self.freeze()

    def go_down(self):
        """Move the current figure down one row; freeze it if it cannot move."""
        self.figure.y += 1
        if self.intersects(self.figure):
            self.figure.y -= 1
            self.freeze()

    def freeze(self):
        """Write the current figure into the field, clear full rows, spawn the
        next figure and set ``done`` if that new figure already collides."""
        for cell in self.figure.image():
            i, j = divmod(cell, 4)
            self.field[i + self.figure.y][j + self.figure.x] = self.figure.color
        self.break_lines()
        self.new_figure()
        self.done = self.intersects(self.figure)

    def go_side(self, dx):
        """Shift the current figure horizontally by ``dx``; undo the move if it collides.

        Returns True when the move was blocked.
        """
        old_x = self.figure.x
        self.figure.x += dx
        intersects = self.intersects(self.figure)
        if intersects:
            self.figure.x = old_x
        return intersects

    def rotate(self):
        """Rotate the current figure, reverting if the new orientation collides."""
        old_rotation = self.figure.rotation
        self.figure.rotate()
        if self.intersects(self.figure):
            self.figure.rotation = old_rotation


In [42]:
class NeuralNet(nn.Module):
    """Three-layer MLP: two ReLU hidden layers of width ``hidden_dims`` and a
    linear output of size ``output_dims``."""

    def __init__(self, input_dims, hidden_dims, output_dims):
        super().__init__()
        # The attribute names fc1/fc2/fc3 become the state_dict keys; keep them
        # stable so previously saved weights still load.
        self.fc1 = nn.Linear(input_dims, hidden_dims)
        self.fc2 = nn.Linear(hidden_dims, hidden_dims)
        self.fc3 = nn.Linear(hidden_dims, output_dims)

    def forward(self, x):
        """Apply fc1 -> ReLU -> fc2 -> ReLU -> fc3 (no activation on the output)."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [16]:
# Build the value network and a target copy that starts with identical weights.
feature_len = 5  # length of the feature vector returned by get_state_properties
behavior = NeuralNet(feature_len,64,1)
target = NeuralNet(feature_len,64,1)

target.load_state_dict(behavior.state_dict())

# Saves the whole pickled module; the loading cell reads it back with
# torch.load. PyTorch's serialization docs recommend saving state_dict()
# instead, but that would require changing the loader as well.
# WARNING: running this cell overwrites any trained checkpoint at FILE with
# these freshly initialised weights.
FILE = "model.pth"
torch.save(target, FILE)

In [17]:
def get_state_properties(state):
    """Compute the hand-crafted board features used by the value network.

    Args:
        state: board as a list of rows (row 0 at the top); a cell is 0 when
            empty and non-zero when occupied.

    Returns:
        ``(feature, reward)`` where ``feature`` is a float32 tensor
        ``[bumpiness, holes, sum_of_heights, max_height, full_lines]`` and
        ``reward`` is ``10 * full_lines``. Heights are reduced by the number of
        full rows, i.e. measured as if those rows were already cleared.
    """
    height_board = len(state)
    width_board = len(state[0])
    # lines[i] stays 1 only while every cell of row i is occupied.
    lines = [1] * height_board
    tower_h = [0] * width_board
    holes = [0] * width_board
    for j in range(width_board):
        top_reached = False
        for i in range(height_board):
            if state[i][j] == 0:
                lines[i] = 0
                if top_reached:
                    # Empty cell below the column's highest block.
                    holes[j] += 1
            elif not top_reached:
                tower_h[j] = height_board - i
                top_reached = True
    # Each column accumulates |height difference| to both neighbours, so every
    # adjacent pair contributes twice to the total.
    bumps = 0
    for j in range(width_board - 1):
        bumps += 2 * abs(tower_h[j + 1] - tower_h[j])
    cleared = sum(lines)
    heights = np.array(tower_h) - cleared
    feature = torch.tensor(
        [bumps, sum(holes), heights.sum(), heights.max(), cleared],
        dtype=torch.float32,
    )
    return feature, cleared * 10


def get_bfs_score(state):
    """Hand-tuned placement score (higher is better).

    Penalises bumpiness, holes (weighted 13x) and the maximum column height,
    taken from ``get_state_properties``.
    """
    feature, _reward = get_state_properties(state)
    bumpiness, holes, _total_height, max_height, _lines = (v.item() for v in feature)
    return -bumpiness - 13 * holes - max_height

def simulate(t):
    """Choose the best (rotation, x) for ``t.figure`` using the ``target`` network.

    Every rotation of the current piece type is hard-dropped at every column
    offset from -3 to ``t.width - 1``; each resulting board is scored by
    ``target`` on its ``get_state_properties`` features and the highest-scoring
    placement wins.

    Returns:
        ``(rotation, x)`` of the best placement, or ``(0, 3)`` when the piece
        cannot spawn or no placement fits.
    """
    fig = Figure(3, 0)
    fig.type = t.figure.type
    fig.color = t.figure.color
    opt = float("-inf")
    opt_rotation, opt_x = 0, fig.x
    if t.intersects(fig):
        return opt_rotation, opt_x
    fig.x = -3
    # Inference only: no autograd graph is needed for scoring.
    with torch.no_grad():
        for _ in range(t.width + 3):
            for _ in range(len(fig.figures[fig.type])):
                if not t.intersects(fig):
                    # Drop until the next row would collide.
                    while not t.intersects(fig):
                        fig.y += 1
                    fig.y -= 1
                    possible_state = t.step(fig)
                    feature, _ = get_state_properties(possible_state)
                    score = target(feature).item()
                    if score > opt:
                        opt = score
                        opt_rotation = fig.rotation
                        opt_x = fig.x
                    fig.y = 0
                # A full cycle of rotations returns the figure to rotation 0.
                fig.rotate()
            fig.x += 1
    return opt_rotation, opt_x


def run_ai(t):
    """Return a one-element list holding the next key event for the AI player.

    Emits K_UP until the piece's rotation matches the plan from ``simulate``,
    then K_RIGHT / K_LEFT until its column matches, then K_SPACE.
    """
    rotation, x = simulate(t)
    if t.figure.rotation != rotation:
        key = pygame.K_UP
    elif t.figure.x < x:
        key = pygame.K_RIGHT
    elif t.figure.x > x:
        key = pygame.K_LEFT
    else:
        key = pygame.K_SPACE
    return [Event(pygame.KEYDOWN, key)]

In [20]:
feature_len = 5
FILE = "model.pth"
# model.pth holds a whole pickled NeuralNet (written with torch.save(module)),
# so full unpickling must be allowed. Recent PyTorch releases default torch.load
# to weights_only=True, which rejects such a file. Unpickling can execute
# arbitrary code: only load checkpoints you produced yourself.
try:
    behavior = torch.load(FILE, weights_only=False)
except TypeError:
    # Older PyTorch releases have no `weights_only` argument.
    behavior = torch.load(FILE)
# Fresh target network initialised with the loaded weights.
target = NeuralNet(feature_len,64,1)
target.load_state_dict(behavior.state_dict())

width = 10
height = 20

env = Tetris(height, width)

learning_rate = 5e-4
gamma = 0.99
# Exploration rate: starts at `epsilon`, lowered by `eps_dec` per episode
# in train() and held at `eps_min` once it gets there.
epsilon = 1
eps_dec = 1e-4
eps_min = 0.001
optimizer = torch.optim.AdamW(behavior.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Only the most recent 30000 transitions are kept for replay.
replay_memory = deque(maxlen=30000)
# Returns of the last 100 episodes, averaged in the progress print-out.
scores = deque(maxlen=100)
batch_size = 512

In [21]:
n_episodes = 100000


def simulate_RL(t, isRandom):
    """Enumerate every hard-drop placement of ``t.figure`` and pick one.

    Each rotation of the current piece type is dropped at every column offset
    from -3 to ``t.width - 1``. Each resulting board is featurised with
    ``get_state_properties`` and scored by the ``behavior`` network.

    Args:
        t: the game; its field is read but not modified.
        isRandom: if True, return a uniformly random legal placement instead
            of the best-scoring one (exploration).

    Returns:
        ``(state, feature, reward, done)`` for the chosen placement, where
        ``reward`` is 10 per full row on the resulting board. When the current
        piece cannot spawn or has no legal drop, returns
        ``(None, zeros(feature_len), -300, True)``.
    """
    fig = Figure(3, 0)
    # Copy the live piece *before* the spawn check: Figure() picks a random
    # type, and testing that random shape would misjudge game over.
    fig.type = t.figure.type
    fig.color = t.figure.color
    reward, best_state, best_feature, done = -300, None, torch.zeros(feature_len), True
    if t.intersects(fig):
        return best_state, best_feature, reward, done
    fig.x = -3
    opt = float("-inf")
    states, features, rewards = [], [], []
    # Inference only: no autograd graph is needed for scoring placements.
    with torch.no_grad():
        for _ in range(t.width + 3):
            for _ in range(len(fig.figures[fig.type])):
                if not t.intersects(fig):
                    # Drop until the next row would collide.
                    while not t.intersects(fig):
                        fig.y += 1
                    fig.y -= 1
                    done = False
                    state = t.step(fig)
                    feature, r = get_state_properties(state)
                    score = behavior(feature).item()
                    if score > opt:
                        opt = score
                        reward = r
                        best_state = state
                        best_feature = feature
                    states.append(state)
                    features.append(feature)
                    rewards.append(r)
                    fig.y = 0
                # A full cycle of rotations returns the figure to rotation 0.
                fig.rotate()
            fig.x += 1
    if isRandom and not done:
        index = random.randint(0, len(states) - 1)
        return states[index], features[index], rewards[index], done
    return best_state, best_feature, reward, done

def train(eps):
    """Run the DQN-style training loop (replay memory + target network).

    Relies on notebook globals: ``env``, ``behavior``, ``target``, ``optimizer``,
    ``criterion``, ``replay_memory``, ``scores``, ``batch_size``, ``gamma``,
    ``eps_dec``, ``eps_min``, ``feature_len``, ``n_episodes`` and ``FILE``.
    Transitions store feature vectors of boards *after* a placement (as
    returned by ``simulate_RL``), so the network learns a value per board.

    Args:
        eps: initial exploration rate. It is lowered by ``eps_dec`` at the start
            of every episode while it is above ``eps_min``, then held at
            ``eps_min``.
    """
    pbar = tqdm.trange(n_episodes)
    for t in pbar:
        env.reset()
        # Features of the empty starting board (all zeros).
        feature = torch.zeros(feature_len)
        # Linear epsilon decay, one step per episode.
        eps =  eps - eps_dec \
                if eps > eps_min else eps_min
        # Episode return: sum of simulate_RL rewards (10 per cleared row,
        # -300 on game over).
        score = 0
        while True:
            # Explore (uniformly random legal placement) with probability eps.
            random_action =  random.random() <= eps
            env.new_figure()
            next_state, next_feature, reward, done = simulate_RL(env, random_action)
            score += reward

            replay_memory.append((feature, reward, next_feature, done))
            feature = next_feature
            # One gradient step per placed piece; the minibatch is smaller than
            # batch_size until replay_memory holds that many transitions.
            batch = random.sample(replay_memory, min(len(replay_memory), batch_size))
            feature_batch, reward_batch, next_feature_batch, done_batch = zip(*batch)
            feature_batch = torch.stack(tuple(feat for feat in feature_batch))
            reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])
            next_feature_batch = torch.stack(tuple(feat for feat in next_feature_batch))

            q_values = behavior(feature_batch)
            with torch.no_grad():
                next_prediction_batch = target(next_feature_batch)

            # TD target: r for terminal transitions, else r + 1 + gamma * V_target(s').
            # NOTE(review): the extra +1 is a constant bonus on non-terminal steps
            # (presumably a survival reward) that is not included in `score`.
            # `reward` / `done` inside the generator are its own loop variables
            # and do not overwrite the outer `reward` / `done` used below.
            y_batch = torch.cat(
                    tuple(reward if done else reward + 1 + gamma * prediction for reward, done, prediction in
                          zip(reward_batch, done_batch, next_prediction_batch)))[:, None]
            optimizer.zero_grad()
            loss = criterion(q_values, y_batch)
            loss.backward()
            optimizer.step()
            if done:
                scores.append(score)
                break
            # Commit the chosen placement to the live board, then clear full rows.
            env.field = next_state
            env.break_lines()
        if t%100 ==0:
            # Mean return over the last 100 episodes (scores has maxlen=100).
            print(np.mean(scores), eps)
            # Checkpoints `target` (last synced from `behavior` on a multiple
            # of 30 episodes), not `behavior`.
            torch.save(target, FILE)
        # Hard-update the target network every 30 episodes.
        if t%30 == 0:
            target.load_state_dict(behavior.state_dict())


In [22]:
# Launch training from the initial exploration rate set in the config cell.
train(epsilon)

  0%|                                                                             | 2/100000 [00:00<4:54:39,  5.66it/s]

-300.0 0.99995


  0%|                                                                           | 101/100000 [00:36<9:45:22,  2.84it/s]

-299.5 0.9949500000000006


  0%|▏                                                                         | 201/100000 [01:14<11:31:39,  2.40it/s]

-299.7 0.9899500000000011


  0%|▏                                                                         | 301/100000 [01:55<11:06:00,  2.49it/s]

-299.8 0.9849500000000017


  0%|▎                                                                         | 401/100000 [02:39<13:37:07,  2.03it/s]

-299.5 0.9799500000000022


  1%|▍                                                                          | 501/100000 [03:21<9:21:31,  2.95it/s]

-299.8 0.9749500000000028


  1%|▍                                                                         | 601/100000 [04:04<13:03:04,  2.12it/s]

-299.3 0.9699500000000033


  1%|▌                                                                         | 701/100000 [04:47<11:16:22,  2.45it/s]

-300.0 0.9649500000000039


  1%|▌                                                                         | 801/100000 [05:29<11:15:01,  2.45it/s]

-299.7 0.9599500000000044


  1%|▋                                                                          | 901/100000 [06:10<9:10:16,  3.00it/s]

-299.7 0.954950000000005


  1%|▋                                                                        | 1001/100000 [06:53<10:15:57,  2.68it/s]

-299.7 0.9499500000000055


  1%|▊                                                                        | 1101/100000 [07:36<12:24:58,  2.21it/s]

-299.4 0.9449500000000061


  1%|▉                                                                        | 1201/100000 [08:18<13:18:03,  2.06it/s]

-299.6 0.9399500000000066


  1%|▉                                                                        | 1301/100000 [08:59<12:13:04,  2.24it/s]

-299.5 0.9349500000000072


  1%|█                                                                        | 1401/100000 [09:41<10:12:57,  2.68it/s]

-299.8 0.9299500000000077


  2%|█                                                                        | 1501/100000 [10:24<11:30:39,  2.38it/s]

-299.5 0.9249500000000083


  2%|█▏                                                                       | 1601/100000 [11:06<12:08:11,  2.25it/s]

-299.6 0.9199500000000088


  2%|█▏                                                                       | 1701/100000 [11:49<11:58:17,  2.28it/s]

-299.6 0.9149500000000094


  2%|█▎                                                                       | 1801/100000 [12:34<12:48:58,  2.13it/s]

-299.7 0.9099500000000099


  2%|█▍                                                                       | 1901/100000 [13:17<11:57:29,  2.28it/s]

-299.5 0.9049500000000105


  2%|█▍                                                                       | 2001/100000 [14:00<11:12:56,  2.43it/s]

-299.2 0.899950000000011


  2%|█▌                                                                       | 2101/100000 [14:46<11:55:50,  2.28it/s]

-299.2 0.8949500000000116


  2%|█▌                                                                       | 2201/100000 [15:28<12:08:07,  2.24it/s]

-299.0 0.8899500000000121


  2%|█▋                                                                       | 2301/100000 [16:12<13:39:43,  1.99it/s]

-299.0 0.8849500000000127


  2%|█▊                                                                       | 2401/100000 [17:02<10:04:58,  2.69it/s]

-299.0 0.8799500000000132


  3%|█▊                                                                       | 2501/100000 [17:47<11:36:51,  2.33it/s]

-298.7 0.8749500000000138


  3%|█▉                                                                       | 2601/100000 [18:32<10:23:32,  2.60it/s]

-299.5 0.8699500000000143


  3%|█▉                                                                       | 2701/100000 [19:18<12:16:24,  2.20it/s]

-298.9 0.8649500000000149


  3%|██                                                                       | 2801/100000 [20:06<13:55:52,  1.94it/s]

-298.7 0.8599500000000154


  3%|██                                                                       | 2901/100000 [20:56<13:30:31,  2.00it/s]

-298.9 0.854950000000016


  3%|██▏                                                                      | 3001/100000 [21:41<11:31:03,  2.34it/s]

-298.4 0.8499500000000165


  3%|██▎                                                                      | 3101/100000 [22:28<14:10:37,  1.90it/s]

-298.4 0.8449500000000171


  3%|██▎                                                                      | 3201/100000 [23:15<10:54:42,  2.46it/s]

-299.1 0.8399500000000176


  3%|██▍                                                                      | 3301/100000 [23:59<11:44:31,  2.29it/s]

-299.0 0.8349500000000182


  3%|██▍                                                                      | 3401/100000 [24:45<14:12:50,  1.89it/s]

-298.7 0.8299500000000187


  4%|██▌                                                                      | 3501/100000 [25:34<11:03:40,  2.42it/s]

-297.8 0.8249500000000193


  4%|██▋                                                                      | 3601/100000 [26:24<13:41:30,  1.96it/s]

-298.4 0.8199500000000198


  4%|██▋                                                                      | 3701/100000 [27:14<14:05:57,  1.90it/s]

-299.1 0.8149500000000204


  4%|██▊                                                                      | 3801/100000 [28:04<13:57:52,  1.91it/s]

-298.7 0.8099500000000209


  4%|██▊                                                                      | 3901/100000 [28:56<11:30:32,  2.32it/s]

-298.0 0.8049500000000215


  4%|██▉                                                                      | 4001/100000 [29:49<14:34:43,  1.83it/s]

-297.1 0.799950000000022


  4%|██▉                                                                      | 4101/100000 [30:40<14:37:31,  1.82it/s]

-298.6 0.7949500000000226


  4%|███                                                                      | 4201/100000 [31:30<15:05:25,  1.76it/s]

-297.8 0.7899500000000231


  4%|███▏                                                                     | 4301/100000 [32:18<13:01:02,  2.04it/s]

-298.1 0.7849500000000237


  4%|███▏                                                                     | 4401/100000 [33:08<14:49:02,  1.79it/s]

-295.9 0.7799500000000242


  5%|███▎                                                                     | 4501/100000 [33:59<13:07:43,  2.02it/s]

-297.6 0.7749500000000248


  5%|███▎                                                                     | 4601/100000 [34:50<14:35:06,  1.82it/s]

-297.4 0.7699500000000253


  5%|███▍                                                                     | 4701/100000 [35:44<14:28:48,  1.83it/s]

-297.7 0.7649500000000259


  5%|███▌                                                                     | 4801/100000 [36:36<13:19:52,  1.98it/s]

-297.8 0.7599500000000264


  5%|███▌                                                                     | 4901/100000 [37:28<12:15:18,  2.16it/s]

-297.4 0.754950000000027


  5%|███▋                                                                     | 5001/100000 [38:17<13:02:47,  2.02it/s]

-298.4 0.7499500000000275


  5%|███▋                                                                     | 5101/100000 [39:11<15:35:45,  1.69it/s]

-295.6 0.7449500000000281


  5%|███▊                                                                     | 5201/100000 [39:59<13:50:11,  1.90it/s]

-296.9 0.7399500000000286


  5%|███▊                                                                     | 5301/100000 [40:51<11:56:21,  2.20it/s]

-296.4 0.7349500000000292


  5%|███▉                                                                     | 5401/100000 [41:41<16:56:17,  1.55it/s]

-297.4 0.7299500000000297


  6%|████                                                                     | 5501/100000 [42:33<12:19:37,  2.13it/s]

-296.2 0.7249500000000303


  6%|████                                                                     | 5601/100000 [43:27<14:54:18,  1.76it/s]

-296.9 0.7199500000000308


  6%|████▏                                                                    | 5701/100000 [44:27<15:20:21,  1.71it/s]

-295.9 0.7149500000000314


  6%|████▏                                                                    | 5801/100000 [45:22<15:26:16,  1.69it/s]

-295.5 0.7099500000000319


  6%|████▎                                                                    | 5901/100000 [46:18<14:52:30,  1.76it/s]

-295.4 0.7049500000000325


  6%|████▍                                                                    | 6001/100000 [47:13<13:54:42,  1.88it/s]

-295.7 0.699950000000033


  6%|████▍                                                                    | 6101/100000 [48:09<18:43:58,  1.39it/s]

-295.7 0.6949500000000336


  6%|████▌                                                                    | 6201/100000 [49:06<15:52:11,  1.64it/s]

-295.5 0.6899500000000341


  6%|████▌                                                                    | 6301/100000 [50:01<14:06:02,  1.85it/s]

-293.3 0.6849500000000347


  6%|████▋                                                                    | 6401/100000 [50:58<13:32:29,  1.92it/s]

-295.0 0.6799500000000352


  7%|████▋                                                                    | 6501/100000 [51:56<13:51:44,  1.87it/s]

-293.7 0.6749500000000358


  7%|████▊                                                                    | 6601/100000 [52:53<14:56:16,  1.74it/s]

-295.8 0.6699500000000363


  7%|████▉                                                                    | 6701/100000 [53:52<15:05:48,  1.72it/s]

-294.1 0.6649500000000369


  7%|████▉                                                                    | 6801/100000 [54:50<14:44:18,  1.76it/s]

-295.3 0.6599500000000375


  7%|█████                                                                    | 6901/100000 [55:50<14:50:48,  1.74it/s]

-293.6 0.654950000000038


  7%|█████                                                                    | 7001/100000 [56:47<13:55:17,  1.86it/s]

-296.1 0.6499500000000386


  7%|█████▏                                                                   | 7101/100000 [57:46<15:53:44,  1.62it/s]

-293.6 0.6449500000000391


  7%|█████▎                                                                   | 7201/100000 [58:45<13:02:26,  1.98it/s]

-292.5 0.6399500000000397


  7%|█████▎                                                                   | 7301/100000 [59:46<17:16:31,  1.49it/s]

-292.3 0.6349500000000402


  7%|█████▎                                                                 | 7401/100000 [1:00:47<16:42:13,  1.54it/s]

-291.0 0.6299500000000408


  8%|█████▎                                                                 | 7501/100000 [1:01:45<16:30:07,  1.56it/s]

-292.7 0.6249500000000413


  8%|█████▍                                                                 | 7601/100000 [1:02:44<16:06:43,  1.59it/s]

-293.1 0.6199500000000419


  8%|█████▍                                                                 | 7701/100000 [1:03:46<14:42:30,  1.74it/s]

-290.7 0.6149500000000424


  8%|█████▌                                                                 | 7801/100000 [1:04:48<15:34:09,  1.64it/s]

-291.7 0.609950000000043


  8%|█████▌                                                                 | 7901/100000 [1:05:48<15:50:58,  1.61it/s]

-293.0 0.6049500000000435


  8%|█████▋                                                                 | 8001/100000 [1:06:51<19:54:45,  1.28it/s]

-290.8 0.5999500000000441


  8%|█████▊                                                                 | 8101/100000 [1:07:52<14:57:24,  1.71it/s]

-290.8 0.5949500000000446


  8%|█████▊                                                                 | 8201/100000 [1:08:55<16:01:24,  1.59it/s]

-291.7 0.5899500000000452


  8%|█████▉                                                                 | 8301/100000 [1:09:58<17:51:07,  1.43it/s]

-290.5 0.5849500000000457


  8%|█████▉                                                                 | 8401/100000 [1:11:00<17:31:40,  1.45it/s]

-291.0 0.5799500000000463


  9%|██████                                                                 | 8501/100000 [1:12:02<16:07:31,  1.58it/s]

-291.4 0.5749500000000468


  9%|██████                                                                 | 8601/100000 [1:13:06<16:21:43,  1.55it/s]

-288.8 0.5699500000000474


  9%|██████▏                                                                | 8701/100000 [1:14:10<15:04:55,  1.68it/s]

-290.2 0.5649500000000479


  9%|██████▏                                                                | 8801/100000 [1:15:16<15:06:24,  1.68it/s]

-288.8 0.5599500000000485


  9%|██████▎                                                                | 8901/100000 [1:16:22<19:11:12,  1.32it/s]

-288.4 0.554950000000049


  9%|██████▍                                                                | 9001/100000 [1:17:30<15:31:24,  1.63it/s]

-287.2 0.5499500000000496


  9%|██████▍                                                                | 9101/100000 [1:18:38<15:46:09,  1.60it/s]

-286.0 0.5449500000000501


  9%|██████▌                                                                | 9201/100000 [1:19:42<16:19:16,  1.55it/s]

-289.7 0.5399500000000507


  9%|██████▌                                                                | 9301/100000 [1:20:49<17:57:03,  1.40it/s]

-288.0 0.5349500000000512


  9%|██████▋                                                                | 9401/100000 [1:21:53<17:25:49,  1.44it/s]

-287.8 0.5299500000000518


 10%|██████▋                                                                | 9501/100000 [1:22:59<15:47:07,  1.59it/s]

-288.5 0.5249500000000523


 10%|██████▊                                                                | 9601/100000 [1:24:07<14:05:58,  1.78it/s]

-285.5 0.5199500000000529


 10%|██████▉                                                                | 9701/100000 [1:25:15<17:34:26,  1.43it/s]

-286.8 0.5149500000000534


 10%|██████▉                                                                | 9801/100000 [1:26:25<18:44:03,  1.34it/s]

-284.4 0.509950000000054


 10%|███████                                                                | 9901/100000 [1:27:32<17:22:45,  1.44it/s]

-288.1 0.5049500000000545


 10%|███████                                                               | 10001/100000 [1:28:40<16:38:00,  1.50it/s]

-284.1 0.4999500000000551


 10%|███████                                                               | 10101/100000 [1:29:50<15:47:59,  1.58it/s]

-286.0 0.4949500000000556


 10%|███████▏                                                              | 10201/100000 [1:30:59<15:43:43,  1.59it/s]

-284.5 0.4899500000000562


 10%|███████▏                                                              | 10301/100000 [1:32:09<16:36:33,  1.50it/s]

-282.7 0.4849500000000567


 10%|███████▎                                                              | 10401/100000 [1:33:23<18:25:37,  1.35it/s]

-278.5 0.4799500000000573


 11%|███████▎                                                              | 10501/100000 [1:34:34<18:10:49,  1.37it/s]

-283.4 0.4749500000000578


 11%|███████▍                                                              | 10601/100000 [1:35:45<19:55:33,  1.25it/s]

-284.1 0.4699500000000584


 11%|███████▍                                                              | 10701/100000 [1:36:58<17:11:16,  1.44it/s]

-280.7 0.4649500000000589


 11%|███████▌                                                              | 10801/100000 [1:38:09<17:14:18,  1.44it/s]

-281.2 0.4599500000000595


 11%|███████▋                                                              | 10901/100000 [1:39:23<19:58:50,  1.24it/s]

-279.2 0.45495000000006003


 11%|███████▋                                                              | 11001/100000 [1:40:37<16:39:42,  1.48it/s]

-279.9 0.4499500000000606


 11%|███████▊                                                              | 11101/100000 [1:41:51<19:32:54,  1.26it/s]

-278.6 0.44495000000006113


 11%|███████▊                                                              | 11201/100000 [1:43:05<18:31:34,  1.33it/s]

-279.5 0.4399500000000617


 11%|███████▉                                                              | 11301/100000 [1:44:20<18:46:23,  1.31it/s]

-276.5 0.43495000000006223


 11%|███████▉                                                              | 11401/100000 [1:45:36<18:08:58,  1.36it/s]

-278.4 0.4299500000000628


 12%|████████                                                              | 11501/100000 [1:46:50<19:03:15,  1.29it/s]

-277.6 0.42495000000006333


 12%|████████                                                              | 11601/100000 [1:48:06<21:43:50,  1.13it/s]

-277.2 0.4199500000000639


 12%|████████▏                                                             | 11701/100000 [1:49:26<21:40:16,  1.13it/s]

-272.2 0.41495000000006443


 12%|████████▎                                                             | 11801/100000 [1:50:46<16:27:55,  1.49it/s]

-274.4 0.409950000000065


 12%|████████▎                                                             | 11901/100000 [1:52:04<21:13:14,  1.15it/s]

-271.9 0.40495000000006554


 12%|████████▍                                                             | 12001/100000 [1:53:21<17:15:41,  1.42it/s]

-274.2 0.3999500000000661


 12%|████████▍                                                             | 12101/100000 [1:54:37<17:50:25,  1.37it/s]

-274.0 0.39495000000006664


 12%|████████▌                                                             | 12201/100000 [1:55:56<21:24:26,  1.14it/s]

-269.6 0.3899500000000672


 12%|████████▌                                                             | 12301/100000 [1:57:13<19:32:16,  1.25it/s]

-273.6 0.38495000000006774


 12%|████████▋                                                             | 12401/100000 [1:58:31<16:18:45,  1.49it/s]

-269.9 0.3799500000000683


 13%|████████▊                                                             | 12501/100000 [1:59:55<22:33:32,  1.08it/s]

-260.2 0.37495000000006884


 13%|████████▊                                                             | 12601/100000 [2:01:17<19:56:31,  1.22it/s]

-264.0 0.3699500000000694


 13%|████████▉                                                             | 12701/100000 [2:02:39<22:32:57,  1.08it/s]

-264.7 0.36495000000006994


 13%|████████▉                                                             | 12801/100000 [2:04:01<21:27:27,  1.13it/s]

-265.7 0.3599500000000705


 13%|█████████                                                             | 12901/100000 [2:05:23<17:33:05,  1.38it/s]

-262.4 0.35495000000007104


 13%|█████████                                                             | 13001/100000 [2:06:43<20:24:14,  1.18it/s]

-265.7 0.3499500000000716


 13%|█████████▏                                                            | 13101/100000 [2:08:07<19:41:55,  1.23it/s]

-259.9 0.34495000000007214


 13%|█████████▏                                                            | 13201/100000 [2:09:31<18:59:40,  1.27it/s]

-263.6 0.3399500000000727


 13%|█████████▎                                                            | 13301/100000 [2:10:56<23:08:20,  1.04it/s]

-259.1 0.33495000000007324


 13%|█████████▍                                                            | 13401/100000 [2:12:28<22:29:54,  1.07it/s]

-249.9 0.3299500000000738


 14%|█████████▍                                                            | 13501/100000 [2:13:59<23:33:53,  1.02it/s]

-253.0 0.32495000000007435


 14%|█████████▌                                                            | 13601/100000 [2:15:28<22:31:21,  1.07it/s]

-256.4 0.3199500000000749


 14%|█████████▌                                                            | 13701/100000 [2:16:54<21:22:10,  1.12it/s]

-260.7 0.31495000000007545


 14%|█████████▋                                                            | 13801/100000 [2:18:29<22:49:49,  1.05it/s]

-254.5 0.309950000000076


 14%|█████████▋                                                            | 13901/100000 [2:20:01<21:40:10,  1.10it/s]

-257.4 0.30495000000007655


 14%|█████████▊                                                            | 14001/100000 [2:21:40<23:45:56,  1.01it/s]

-246.2 0.2999500000000771


 14%|█████████▊                                                            | 14101/100000 [2:23:19<26:13:31,  1.10s/it]

-246.5 0.29495000000007765


 14%|█████████▉                                                            | 14201/100000 [2:25:03<25:57:56,  1.09s/it]

-241.7 0.2899500000000782


 14%|██████████                                                            | 14301/100000 [2:26:49<26:56:33,  1.13s/it]

-237.3 0.28495000000007875


 14%|██████████                                                            | 14401/100000 [2:28:38<36:48:29,  1.55s/it]

-234.7 0.2799500000000793


 15%|██████████▏                                                           | 14501/100000 [2:30:25<23:07:49,  1.03it/s]

-237.0 0.27495000000007985


 15%|██████████▏                                                           | 14601/100000 [2:32:15<25:45:01,  1.09s/it]

-237.5 0.2699500000000804


 15%|██████████▎                                                           | 14701/100000 [2:34:00<20:02:56,  1.18it/s]

-232.9 0.26495000000008095


 15%|██████████▎                                                           | 14801/100000 [2:35:42<21:18:33,  1.11it/s]

-235.5 0.2599500000000815


 15%|██████████▍                                                           | 14901/100000 [2:37:28<26:23:09,  1.12s/it]

-230.5 0.25495000000008206


 15%|██████████▌                                                           | 15001/100000 [2:39:13<27:00:50,  1.14s/it]

-229.4 0.2499500000000826


 15%|██████████▌                                                           | 15101/100000 [2:41:02<24:28:18,  1.04s/it]

-225.4 0.24495000000008316


 15%|██████████▋                                                           | 15201/100000 [2:42:42<25:07:37,  1.07s/it]

-239.0 0.2399500000000837


 15%|██████████▋                                                           | 15301/100000 [2:44:27<28:24:19,  1.21s/it]

-229.6 0.23495000000008426


 15%|██████████▊                                                           | 15401/100000 [2:46:18<25:35:09,  1.09s/it]

-223.4 0.2299500000000848


 16%|██████████▊                                                           | 15501/100000 [2:48:09<27:28:56,  1.17s/it]

-221.3 0.22495000000008536


 16%|██████████▉                                                           | 15601/100000 [2:50:06<28:29:19,  1.22s/it]

-210.6 0.2199500000000859


 16%|██████████▉                                                           | 15701/100000 [2:51:56<24:49:33,  1.06s/it]

-220.2 0.21495000000008646


 16%|███████████                                                           | 15801/100000 [2:53:52<28:41:13,  1.23s/it]

-206.2 0.209950000000087


 16%|███████████▏                                                          | 15901/100000 [2:55:52<26:47:28,  1.15s/it]

-200.4 0.20495000000008756


 16%|███████████▏                                                          | 16001/100000 [2:57:56<32:18:20,  1.38s/it]

-197.7 0.1999500000000881


 16%|███████████▎                                                          | 16101/100000 [3:00:06<26:59:00,  1.16s/it]

-181.9 0.19495000000008866


 16%|███████████▎                                                          | 16201/100000 [3:02:10<27:08:33,  1.17s/it]

-194.1 0.18995000000008921


 16%|███████████▍                                                          | 16301/100000 [3:04:18<29:07:47,  1.25s/it]

-190.4 0.18495000000008976


 16%|███████████▍                                                          | 16401/100000 [3:06:24<34:45:08,  1.50s/it]

-190.1 0.17995000000009032


 17%|███████████▌                                                          | 16501/100000 [3:08:29<37:15:40,  1.61s/it]

-194.7 0.17495000000009087


 17%|███████████▌                                                          | 16601/100000 [3:10:42<27:29:49,  1.19s/it]

-179.2 0.16995000000009142


 17%|███████████▋                                                          | 16701/100000 [3:12:59<33:14:57,  1.44s/it]

-170.6 0.16495000000009197


 17%|███████████▊                                                          | 16801/100000 [3:15:05<29:30:52,  1.28s/it]

-192.7 0.15995000000009252


 17%|███████████▊                                                          | 16901/100000 [3:17:36<48:42:08,  2.11s/it]

-147.6 0.15495000000009307


 17%|███████████▉                                                          | 17001/100000 [3:19:58<38:39:06,  1.68s/it]

-165.1 0.14995000000009362


 17%|███████████▉                                                          | 17101/100000 [3:22:20<40:04:21,  1.74s/it]

-165.3 0.14495000000009417


 17%|████████████                                                          | 17201/100000 [3:24:50<32:27:17,  1.41s/it]

-151.6 0.13995000000009472


 17%|████████████                                                          | 17301/100000 [3:27:31<38:38:23,  1.68s/it]

-127.4 0.13495000000009527


 17%|████████████▏                                                         | 17401/100000 [3:30:22<45:51:13,  2.00s/it]

-112.8 0.12995000000009582


 18%|████████████▎                                                         | 17501/100000 [3:33:08<42:59:52,  1.88s/it]

-118.5 0.12495000000009637


 18%|████████████▎                                                         | 17601/100000 [3:36:07<52:12:58,  2.28s/it]

-95.3 0.11995000000009692


 18%|████████████▍                                                         | 17701/100000 [3:39:01<38:32:44,  1.69s/it]

-104.4 0.11495000000009747


 18%|████████████▍                                                         | 17801/100000 [3:42:07<44:59:16,  1.97s/it]

-82.7 0.10995000000009802


 18%|████████████▌                                                         | 17901/100000 [3:45:03<30:09:34,  1.32s/it]

-99.9 0.10495000000009858


 18%|████████████▌                                                         | 18001/100000 [3:48:15<41:42:40,  1.83s/it]

-71.5 0.09995000000009913


 18%|████████████▋                                                         | 18101/100000 [3:51:29<41:26:21,  1.82s/it]

-66.2 0.09495000000009968


 18%|████████████▋                                                         | 18201/100000 [3:54:47<40:56:56,  1.80s/it]

-60.7 0.08995000000010023


 18%|████████████▊                                                         | 18301/100000 [3:58:17<56:43:52,  2.50s/it]

-84.4 0.08495000000010078


 18%|████████████▉                                                         | 18401/100000 [4:02:32<51:02:22,  2.25s/it]

-50.7 0.07995000000010133


 19%|████████████▉                                                         | 18501/100000 [4:06:45<47:43:17,  2.11s/it]

-47.8 0.07495000000010188


 19%|█████████████                                                         | 18601/100000 [4:11:22<57:40:40,  2.55s/it]

-8.3 0.06995000000010243


 19%|█████████████                                                         | 18701/100000 [4:15:49<56:57:28,  2.52s/it]

-19.1 0.06495000000010298


 19%|█████████████▏                                                        | 18801/100000 [4:20:38<65:31:59,  2.91s/it]

14.0 0.05995000000010318


 19%|█████████████▏                                                        | 18901/100000 [4:25:39<56:49:52,  2.52s/it]

34.5 0.054950000000103034


 19%|█████████████▎                                                        | 19001/100000 [4:30:30<63:08:53,  2.81s/it]

14.2 0.04995000000010289


 19%|█████████████▎                                                        | 19101/100000 [4:36:02<86:33:14,  3.85s/it]

82.2 0.04495000000010275


 19%|█████████████▍                                                        | 19201/100000 [4:41:44<88:18:56,  3.93s/it]

97.1 0.039950000000102605


 19%|█████████████▌                                                        | 19301/100000 [4:47:50<77:45:02,  3.47s/it]

142.3 0.03495000000010246


 19%|█████████████▌                                                        | 19401/100000 [4:53:49<71:28:17,  3.19s/it]

112.1 0.02995000000010232


 20%|█████████████▋                                                        | 19501/100000 [4:58:44<59:47:57,  2.67s/it]

97.4 0.024950000000102175


 20%|█████████████▋                                                        | 19601/100000 [5:04:07<81:53:05,  3.67s/it]

174.2 0.019950000000102032


 20%|█████████████▊                                                        | 19701/100000 [5:09:35<69:18:13,  3.11s/it]

169.9 0.014950000000101913


 20%|█████████████▋                                                       | 19801/100000 [5:15:55<133:39:00,  6.00s/it]

238.2 0.009950000000101943


 20%|█████████████▋                                                       | 19901/100000 [5:22:43<173:19:22,  7.79s/it]

297.6 0.0049500000001019735


 20%|█████████████▊                                                       | 20001/100000 [5:30:46<101:17:51,  4.56s/it]

425.4 0.001


 20%|█████████████▊                                                       | 20101/100000 [5:38:05<119:44:51,  5.40s/it]

392.0 0.001


 20%|██████████████▏                                                       | 20201/100000 [5:43:30<51:59:27,  2.35s/it]

167.2 0.001


 20%|██████████████▏                                                       | 20301/100000 [5:49:43<47:53:28,  2.16s/it]

303.3 0.001


 20%|██████████████▏                                                       | 20305/100000 [5:49:52<22:53:12,  1.03s/it]


KeyboardInterrupt: 

In [25]:
# Play a single game with the current agent and print its final score.
# At every new piece, `simulate_RL` is called with a `random_action` flag that
# is True when random.random() <= eps_min.
# Relies on env, feature_len, eps_min, simulate_RL, replay_memory and scores
# defined in earlier cells.
# NOTE(review): each transition is appended to replay_memory and the final
# score to scores, so this run also feeds the training buffers.
pbar = tqdm.trange(1)
for t in pbar:
    env.reset()
    score = 0
    feature = torch.zeros(feature_len).float()
    while True:
        random_action = eps_min >= random.random()
        env.new_figure()
        next_state, next_feature, reward, done = simulate_RL(env, random_action)
        score += reward
        replay_memory.append((feature, reward, next_feature, done))
        if not done:
            # Adopt the board returned by simulate_RL, then let the env
            # handle full rows (presumably clears them) before the next piece.
            env.field = next_state
            env.break_lines()
            feature = next_feature
            continue
        scores.append(score)
        print("///////////Game Over//////////////////")
        break
print(score)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.31s/it]

///////////Game Over//////////////////
390





In [36]:
# Manual stepping loop for debugging the engine from the keyboard.
# Commands: u = rotate, l = move left, r = move right, s = hard drop,
# q = quit; any other input just advances one step.
# After each command the result of simulate(t) and the board rows are printed.
t = Tetris(20, 10)

# Maps a typed command to the game action it triggers.
_manual_moves = {
    "u": lambda: t.rotate(),
    "l": lambda: t.go_side(-1),
    "r": lambda: t.go_side(1),
    "s": lambda: t.go_space(),
}

while True:
    sentence = input("Input:")
    if sentence == "q":
        break
    if t.figure is None:
        t.new_figure()
    move = _manual_moves.get(sentence)
    if move is not None:
        move()
    print(simulate(t))
    t.go_down()
    print()
    for line in t.field:
        print(line)

Input:q


In [36]:
# Interactive pygame front-end. Each frame merges real window/keyboard events
# with the synthetic events returned by run_ai(game), applies them to `game`,
# and redraws the board. Relies on Tetris, colors and run_ai from earlier cells.

# Initialize the game engine
pygame.init()

# Define some colors
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
GRAY = (128, 128, 128)

size = (400, 500)
fps = 10

try:
    screen = pygame.display.set_mode(size)
    pygame.display.set_caption("Tetris")

    # Fonts and the static "Game Over" banners never change, so build them
    # once instead of re-creating them on every frame.
    font = pygame.font.SysFont('Calibri', 25, True, False)
    font1 = pygame.font.SysFont('Calibri', 65, True, False)
    text_game_over = font1.render("Game Over", True, (255, 125, 0))
    text_game_over1 = font1.render("Press ESC", True, (255, 215, 0))

    # Loop until the user clicks the close button.
    done = False
    clock = pygame.time.Clock()
    game = Tetris(20, 10)
    counter = 0
    pressing_down = False

    while not done:
        if game.figure is None:
            game.new_figure()
        counter += 1
        if counter > 100000:
            counter = 0

        # Frames between automatic drops. Clamp to 1: with fps=10 any level
        # above 5 would make this 0 and the modulo would raise ZeroDivisionError.
        drop_every = max(1, fps // game.level // 2)
        if counter % drop_every == 0 or pressing_down:
            if not game.done:
                game.go_down()

        # Human input first, then the AI's synthetic events.
        for event in list(pygame.event.get()) + run_ai(game):
            if event.type == pygame.QUIT:
                done = True
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_UP:
                    game.rotate()
                if event.key == pygame.K_DOWN:
                    pressing_down = True
                if event.key == pygame.K_LEFT:
                    game.go_side(-1)
                if event.key == pygame.K_RIGHT:
                    game.go_side(1)
                if event.key == pygame.K_SPACE:
                    game.go_space()
                if event.key == pygame.K_ESCAPE:
                    game.__init__(20, 10)
            # Checked for every event. Placed after the loop it only saw the
            # last event of the frame (usually an AI event), so releasing DOWN
            # was missed, and a frame with no events could hit an unbound `event`.
            if event.type == pygame.KEYUP:
                if event.key == pygame.K_DOWN:
                    pressing_down = False

        screen.fill(WHITE)

        # Settled blocks: grid outline for every cell, filled rect where occupied.
        for i in range(game.height):
            for j in range(game.width):
                pygame.draw.rect(screen, GRAY, [game.x + game.zoom * j, game.y + game.zoom * i, game.zoom, game.zoom], 1)
                if game.field[i][j] > 0:
                    pygame.draw.rect(screen, colors[game.field[i][j]],
                                     [game.x + game.zoom * j + 1, game.y + game.zoom * i + 1, game.zoom - 2, game.zoom - 1])

        # Falling piece: image() lists occupied cells of a 4x4 box as row * 4 + col.
        if game.figure is not None:
            for i in range(4):
                for j in range(4):
                    p = i * 4 + j
                    if p in game.figure.image():
                        pygame.draw.rect(screen, colors[game.figure.color],
                                         [game.x + game.zoom * (j + game.figure.x) + 1,
                                          game.y + game.zoom * (i + game.figure.y) + 1,
                                          game.zoom - 2, game.zoom - 2])

        text = font.render("Score: " + str(game.score), True, BLACK)
        screen.blit(text, [0, 0])
        if game.done:
            screen.blit(text_game_over, [20, 200])
            screen.blit(text_game_over1, [25, 265])

        pygame.display.flip()
        clock.tick(fps)
finally:
    # Release the window even if the loop is interrupted (e.g. a
    # KeyboardInterrupt from the notebook); otherwise it is left hanging.
    pygame.quit()