Based on a Tetris implementation in Python, (c) 2010 "Kevin Chabowski" <kevin@kch42.de>

https://gist.github.com/silvasur/565419/d9de6a84e7da000797ac681976442073045c74a4

In [None]:
import torch
from torch.optim import Adam
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import copy

In [None]:
# Playfield size in cells: COLUMNS is the width of every board row and ROWS
# is the number of playable rows (new_board() appends one extra solid row
# below them).
COLUMNS = 8
ROWS = 16

In [None]:
# The seven stones. Each shape is a list of rows; 0 marks an empty cell and
# any other value is both "occupied" and the key of the cell's color in COLORS.
TETRIS_SHAPES = [
    [[1, 1, 1],
     [0, 1, 0]],

    [[0, 2, 2],
     [2, 2, 0]],

    [[3, 3, 0],
     [0, 3, 3]],

    [[4, 0, 0],
     [4, 4, 4]],

    [[0, 0, 5],
     [5, 5, 5]],

    [[6, 6, 6, 6]],

    [[7, 7],
     [7, 7]]
]

# Face color used when drawing a cell, keyed by the cell value. Keys 1-7 are
# the stone values from TETRIS_SHAPES; key 8 is the "water" value that
# TetrisState.fill() writes into empty cells.
COLORS = {
    1: '#fff200', # yellow
    2: '#ff0000', # red
    3: '#0000ff', # blue
    4: '#aa00ff', # violet
    5: '#15bb40', # green
    6: '#ff9900', # orange
    7: '#00bbb0', # turquoise
    8: '#42d9f4', # light blue (for water)
}

In [None]:
def rotate_clockwise(shape):
    """Return a new shape rotated by a quarter turn.

    The columns of ``shape`` become the rows of the result, taken from the
    last column to the first, each read top to bottom. ``shape`` itself is
    not modified.

    NOTE(review): with row 0 drawn at the top (as ``TetrisState.draw`` does),
    this moves the top-left cell to the bottom-left, i.e. it turns the shape
    counter-clockwise despite the function's name.
    """
    height = len(shape)
    rotated = []
    for column in reversed(range(len(shape[0]))):
        rotated.append([shape[line][column] for line in range(height)])
    return rotated


def check_collision(board, shape, offset):
    """Tell whether ``shape`` placed at ``offset`` overlaps the board or leaves it.

    Args:
        board: list of rows of cell values; any truthy value is occupied.
        shape: list of rows of cell values; only truthy cells are tested.
        offset: ``(x, y)`` position of the shape's top-left cell on the board.

    Returns:
        True if any occupied cell of the shape lands on an occupied board
        cell or outside the board on any side, otherwise False.
    """
    off_x, off_y = offset
    for cy, row in enumerate(shape):
        for cx, cell in enumerate(row):
            if not cell:
                # Empty shape cells may hang over the edge freely.
                continue
            y = cy + off_y
            x = cx + off_x
            # Explicit bounds check: a negative index would silently wrap
            # around to the opposite edge instead of raising IndexError.
            if not 0 <= y < len(board) or not 0 <= x < len(board[y]):
                return True
            if board[y][x]:
                return True
    return False


def remove_row(board, row):
    """Delete row ``row`` from ``board`` and return a board with a blank top row.

    ``board`` is modified in place (the row is removed from it); the return
    value is a new list made of one fresh row of COLUMNS zeros followed by
    the remaining rows, so the row count stays the same.
    """
    board.pop(row)
    blank_row = [0] * COLUMNS
    return [blank_row] + board


def join_matrixes(mat1, mat2, mat2_off):
    """Add the cells of ``mat2`` onto ``mat1`` in place and return ``mat1``.

    ``mat2_off`` is an ``(x, y)`` offset. Cells are written one row *above*
    the given y offset (``y - 1``): the caller passes the position at which
    the stone collided, so the stone is merged at the last free position.
    """
    shift_x, shift_y = mat2_off
    for dy, source_row in enumerate(mat2):
        target_y = dy + shift_y - 1
        for dx, value in enumerate(source_row):
            mat1[target_y][dx + shift_x] += value
    return mat1


def new_board():
    """Create an empty board: ROWS rows of COLUMNS zeros plus a solid floor.

    The extra last row is filled with 1s so that a falling stone collides
    with it like with any other occupied cell.
    """
    playfield = [[0] * COLUMNS for _ in range(ROWS)]
    floor = [1] * COLUMNS
    return playfield + [floor]

In [None]:
class TetrisState():
    """Mutable state of one Tetris game: the board and the falling stone.

    Attributes:
        board: rows of cell values as produced by ``new_board()``; 0 is an
            empty cell, other values are keys of ``COLORS``. The last row is
            the solid floor and is not part of the playfield.
        gameover: True once a freshly spawned stone collides immediately.
        stone: the falling shape (list of rows, see ``TETRIS_SHAPES``).
        stone_x, stone_y: board position of the stone's top-left cell.
    """

    # Cell value that fill() writes into empty cells reachable from the top.
    WATER = 8

    def __init__(self):
        self.board = new_board()
        self.gameover = False
        self.new_stone()

    def copy(self):
        """Return an independent deep copy of this state."""
        return copy.deepcopy(self)

    def new_stone(self):
        """Spawn a random stone, horizontally centred, at the top row.

        Sets ``gameover`` when the new stone already collides with the board.
        """
        self.stone = random.choice(TETRIS_SHAPES)
        self.stone_x = int(COLUMNS / 2 - len(self.stone[0]) / 2)
        self.stone_y = 0
        if check_collision(self.board, self.stone, (self.stone_x, self.stone_y)):
            # No room for the new stone: the game is over.
            self.gameover = True

    def _clear_full_rows(self):
        """Remove every completely filled playfield row (floor excluded)."""
        while True:
            for i, row in enumerate(self.board[:-1]):
                if 0 not in row:
                    self.board = remove_row(self.board, i)
                    # The board changed; rescan from the top.
                    break
            else:
                break

    def drop(self):
        """Move the stone down one row.

        When the stone collides it is merged into the board, a new stone is
        spawned and full rows are removed.

        Returns:
            True if the stone landed during this call, otherwise False
            (also False when the game is already over).
        """
        if self.gameover:
            return False
        self.stone_y += 1
        if not check_collision(self.board, self.stone, (self.stone_x, self.stone_y)):
            return False
        self.board = join_matrixes(self.board, self.stone, (self.stone_x, self.stone_y))
        self.new_stone()
        self._clear_full_rows()
        return True

    def drop_until_collision(self):
        """Drop the current stone until it lands; no-op when the game is over."""
        # drop() never reports a collision once gameover is set, so the
        # flag has to end the loop as well.
        while not self.gameover:
            if self.drop():
                break

    def move(self, delta_x):
        """Shift the stone horizontally by ``delta_x`` columns.

        The target column is clamped to the board width; the move is skipped
        if the stone would collide there or the game is over.
        """
        if self.gameover:
            return
        rightmost = COLUMNS - len(self.stone[0])
        new_x = self.stone_x + delta_x
        if new_x < 0:
            new_x = 0
        if new_x > rightmost:
            new_x = rightmost
        if not check_collision(self.board, self.stone, (new_x, self.stone_y)):
            self.stone_x = new_x

    def move_left(self):
        """Shift the stone one column to the left if possible."""
        self.move(-1)

    def move_right(self):
        """Shift the stone one column to the right if possible."""
        self.move(1)

    def _try_rotation(self, quarter_turns):
        """Apply ``rotate_clockwise`` ``quarter_turns`` times if that fits.

        Returns True when the rotated stone was accepted, False when it
        would collide or the game is over (the stone is left unchanged).
        """
        if self.gameover:
            return False
        candidate = self.stone
        for _ in range(quarter_turns):
            candidate = rotate_clockwise(candidate)
        if check_collision(self.board, candidate, (self.stone_x, self.stone_y)):
            return False
        self.stone = candidate
        return True

    def rotate_stone(self):
        """Rotate the stone once via ``rotate_clockwise`` if it fits."""
        self._try_rotation(1)

    def turn_cw(self):
        """Rotate the stone once via ``rotate_clockwise``; return success."""
        return self._try_rotation(1)

    def turn_ccw(self):
        """Rotate the stone the opposite way (three turns); return success."""
        return self._try_rotation(3)

    def rewards(self):
        """Score each possible action by simulating it on a copy.

        For every action the copy performs the action, drops the stone until
        it lands and is scored with ``reward()``. The scores are then
        min-max normalised to the range 0..1; if all scores are equal every
        action gets 0.

        Returns:
            dict mapping 'left', 'right', 'turn_cw', 'turn_ccw' and
            'nothing' (in that order) to the normalised score.
        """
        def simulate(action):
            trial = self.copy()
            if action is not None:
                action(trial)
            trial.drop_until_collision()
            return trial.reward()

        cls = type(self)
        raw = {
            'left': simulate(cls.move_left),
            'right': simulate(cls.move_right),
            'turn_cw': simulate(cls.turn_cw),
            'turn_ccw': simulate(cls.turn_ccw),
            'nothing': simulate(None),
        }
        lowest = min(raw.values())
        val_range = max(raw.values()) - lowest
        if val_range == 0:
            return {name: 0 for name in raw}
        return {name: (value - lowest) / val_range for name, value in raw.items()}

    def draw(self):
        """Plot the board and the falling stone with matplotlib."""
        fig = plt.figure()
        ax = fig.add_subplot(111, aspect='equal')
        plt.xlim(0, COLUMNS)
        plt.ylim(0, ROWS)

        def draw_cell(column, row, color):
            # Board row 0 is the top, while the plot's y axis grows upwards.
            ax.add_patch(patches.Rectangle((column, ROWS - row - 1),
                         1.0, 1.0, facecolor=COLORS[color], edgecolor='#000000'))

        # draw board (without the floor row)
        for row in range(ROWS):
            for column in range(COLUMNS):
                color = self.board[row][column]
                if color != 0:
                    draw_cell(column, row, color)
        # draw current stone
        for y, stone_row in enumerate(self.stone):
            for x, color in enumerate(stone_row):
                if color != 0:
                    draw_cell(self.stone_x + x, self.stone_y + y, color)

        plt.show()

    def fill(self):
        """Flood the board in place with "water" poured in from the top.

        An empty cell becomes ``WATER`` when it is in the top row or when the
        cell above, to the left or to the right of it is water. Passes are
        repeated until nothing changes; water never flows upwards.
        """
        changed = True
        while changed:
            changed = False
            for row in range(ROWS):
                for column in range(COLUMNS):
                    if self.board[row][column] != 0:
                        # block is not blank
                        continue
                    above = self.WATER if row == 0 else self.board[row - 1][column]
                    left = self.board[row][column - 1] if column > 0 else 0
                    right = self.board[row][column + 1] if column < COLUMNS - 1 else 0
                    if self.WATER in (above, left, right):
                        self.board[row][column] = self.WATER
                        changed = True

    def reward(self):
        """Return the number of playfield cells that ``fill()`` can flood.

        The flooding is done on a copy, so this state is left untouched.
        """
        flooded = self.copy()
        flooded.fill()
        return sum(
            1
            for row in range(ROWS)
            for column in range(COLUMNS)
            if flooded.board[row][column] == self.WATER
        )


In [None]:
# Greedy self-play: on every step take the action with the best simulated
# score, then let the stone fall one row and draw the board. The loop ends
# only once state.gameover is set.
random.seed(1)
state = TetrisState()

while True:
    scored_actions = list(state.rewards().items())
    # Shuffle first so that ties between equally scored actions are
    # broken randomly instead of by dictionary order.
    random.shuffle(scored_actions)
    best_action, _ = max(scored_actions, key=lambda pair: pair[1])
    handlers = {
        'left': state.move_left,
        'right': state.move_right,
        'turn_cw': state.turn_cw,
        'turn_ccw': state.turn_ccw,
    }
    handler = handlers.get(best_action)
    if handler is not None:  # 'nothing' has no handler
        handler()
    state.drop()
    if state.gameover:
        break
    state.draw()

In [None]:
# Reproducible demo: seeding makes random.choice pick the same stones each run.
random.seed(1)
state = TetrisState()
# Land the first stone straight down from its spawn position.
state.drop_until_collision()
# Try to shift the next stone two columns to the right, then land it.
state.move(2)
state.drop_until_collision()
# Shift the following stone as well, then show the board and the normalised
# per-action scores (the last expression is the cell's output).
state.move(2)
state.draw()
state.rewards()

In [None]:
# Second reproducible demo with a different seed and a hand-scripted
# sequence of moves and single-row drops.
random.seed(3)
state = TetrisState()
# Try to shift the first stone two columns to the left, then land it.
state.move(-2)
state.drop_until_collision()
# Steer the next stone step by step: one column right, 14 single-row drops,
# one column left, one more drop.
state.move(1)
for _ in range(14): state.drop()
state.move(-1)
state.drop()
# Try to shift three columns to the right and drop 12 more rows.
state.move(3)
for _ in range(12): state.drop()
# Show the resulting board and the normalised per-action scores.
state.draw()
state.rewards()

In [None]:
# Size of the network's output layer.
# NOTE(review): TetrisState.rewards() scores five actions (including
# 'nothing') while only four classes are listed here - confirm the mapping.
NUM_CLASSES = 4 # turn left, turn right, move left, move right
# Learning rate passed to the Adam optimizer in learn().
LEARING_RATE = 0.001

# Number of input features per sample: two values per board cell.
# NOTE(review): presumably one plane for the board and one for the falling
# stone - confirm against the (not visible) data preparation code.
input_size = COLUMNS * ROWS * 2
# Width of both hidden layers of Model2Linear.
HIDDEN_SIZE = 50

class Model2Linear(nn.Module):
    """Fully connected classifier with two hidden ReLU layers.

    Layer sizes: ``input_size`` -> ``HIDDEN_SIZE`` -> ``HIDDEN_SIZE`` ->
    ``NUM_CLASSES``.
    """

    def __init__(self):
        super().__init__()
        self.h1 = nn.Linear(input_size, HIDDEN_SIZE)
        self.h2 = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)
        self.h3 = nn.Linear(HIDDEN_SIZE, NUM_CLASSES)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of inputs.

        Args:
            x: tensor whose elements can be viewed as rows of
                ``input_size`` features.

        Returns:
            Tensor of shape ``(batch, NUM_CLASSES)`` holding
            log-probabilities, the input format ``F.nll_loss`` expects. The
            arg-max over dim 1 is the same as for plain probabilities.
        """
        # Flatten every sample into one feature vector.
        x = x.view(-1, input_size)
        x = F.relu(self.h1(x))
        x = F.relu(self.h2(x))
        # log_softmax rather than softmax: nll_loss needs log-probabilities.
        return F.log_softmax(self.h3(x), dim=1)

In [None]:
def evalulate(model):
    """Return the mean negative log-likelihood loss over ``validation_loader``.

    The model is switched to eval mode and run on the GPU with gradient
    tracking disabled, so no autograd graph is kept alive across batches.

    NOTE: ``F.nll_loss`` expects the model output to be log-probabilities.

    Args:
        model: module mapping a batch of inputs to per-class scores.

    Returns:
        The summed per-batch loss divided by the number of batches (a tensor
        when the loader yields at least one batch).
    """
    model.eval()
    loss = 0.0
    with torch.no_grad():
        for data, labels in validation_loader:
            predictions_per_class = model(data.cuda())
            loss += F.nll_loss(predictions_per_class, labels.cuda())
    return loss / len(validation_loader)

def learn():
    """Train a ``Model2Linear`` on ``loader`` for 1000 epochs and return it.

    Each epoch runs one pass over ``loader`` with the Adam optimizer and then
    prints the mean validation loss computed by ``evalulate``.

    NOTE: ``F.nll_loss`` expects the model output to be log-probabilities.
    NOTE(review): the previous code instantiated ``Model1Linear``, which is
    not defined in the visible source; ``Model2Linear`` is used instead -
    confirm this is the intended model.

    Returns:
        The trained model (on the GPU).
    """
    model = Model2Linear().cuda()
    optimizer = Adam(params=model.parameters(), lr=LEARING_RATE)

    for epoch in range(1000):
        model.train()
        for data, labels in loader:
            predictions_per_class = model(data.cuda())

            # how good are we? compare output with the target classes
            loss = F.nll_loss(predictions_per_class, labels.cuda())

            # Gradients accumulate across backward() calls, so clear the
            # ones left over from the previous batch first.
            optimizer.zero_grad()
            loss.backward()  # backpropagate
            optimizer.step()

        validation_loss = evalulate(model)
        print(f'Epoch: {epoch}, Loss: {validation_loss.item()}')

    return model
# Run the training; the IPython `%time` magic reports how long the call takes.
%time model = learn()