In [1]:
# Tetris pieces on a board 14 rows high and 10 columns wide

In [21]:
import os
from collections import deque
from random import random,randint,sample
import numpy as np
import cv2
import torch
import torch.nn as nn


# The seven piece shapes as row-major grids. Occupied cells hold the piece's
# 1-based id, empty cells hold 0.
pieces = [
    [[1, 1],
     [1, 1]],
    [[0, 2, 0],
     [2, 2, 2]],
    [[0, 3, 3],
     [3, 3, 0]],
    [[4, 4, 0],
     [0, 4, 4]],
    [[5, 5, 5, 5]],
    [[0, 0, 6],
     [6, 6, 6]],
    [[7, 0, 0],
     [7, 7, 7]],
]

# Board size in cells.
board_width = 10
board_height = 14

def index_of(piece):
    """Return the 0-based piece index encoded in a piece grid.

    The sum of all cells is divided by four and truncated to get the
    1-based id, so this is only meaningful for grids whose four occupied
    cells each hold that id.
    """
    cell_total = np.sum(piece)
    piece_id = int(cell_total / 4)
    return piece_id - 1


def rotate(input_piece, n_times):
    """Return a copy of ``input_piece`` turned 90 degrees clockwise ``n_times``.

    The grid is row-major with y growing downwards, so the top row of the
    input becomes the right-hand column of the result. The input itself is
    never modified; with ``n_times == 0`` a row-wise copy is returned.
    """
    grid = [row[:] for row in input_piece]

    for _ in range(n_times):
        height = len(grid)
        width = len(grid[0])
        # Old column `col`, read from the bottom row upwards, is new row `col`.
        grid = [
            [grid[height - 1 - k][col] for k in range(height)]
            for col in range(width)
        ]
    return grid


def collision(p, x, y, width=None, height=None):
    """Tell whether piece ``p`` placed at ``(x, y)`` leaves the board.

    Only the board boundary is checked here; overlap with pieces already on
    the board is not considered.

    Args:
        p: Piece grid (list of rows); its bounding box is what gets tested.
        x: Column of the piece's left-most cell.
        y: Row of the piece's top-most cell (y grows downwards).
        width: Board width in cells. Defaults to the module-level
            ``board_width`` when omitted.
        height: Board height in cells. Defaults to the module-level
            ``board_height`` when omitted.

    Returns:
        True when any part of the bounding box is outside the board.
    """
    # Resolve the defaults lazily so callers with a differently sized board
    # can pass their own dimensions.
    if width is None:
        width = board_width
    if height is None:
        height = board_height

    if x < 0 or y < 0:
        return True
    return x + len(p[0]) > width or y + len(p) > height


class Tetris:
    """Board state and move generation for a piece-packing Tetris variant.

    Pieces may be placed at any free position (this class implements no
    gravity and no line clearing). Each of the seven shapes can be used a
    limited number of times, and the game ends when no remaining piece fits.
    """

    # Orientations tried per piece index: the square has one, the S/Z/I
    # shapes two, the rest four.
    _ROTATIONS = (1, 4, 2, 2, 2, 4, 4)
    # How many copies of each shape are available per game.
    _COPIES_PER_PIECE = 5

    def __init__(self, width=10, height=14):
        """Create a board of ``width`` x ``height`` cells and reset the game."""
        self.width = width
        self.height = height
        self.area_size = self.width * self.height
        # Every piece covers four cells, so this bounds the number of moves.
        self.total = int(self.area_size / 4)

        self.reset()

    def _empty_board(self):
        """Return a fresh all-zero board (list of rows)."""
        return [[0] * self.width for _ in range(self.height)]

    def reset(self):
        """Start a new game and return the feature vector of the empty board."""
        self.score = 0
        self.remained = [self._COPIES_PER_PIECE] * len(self._ROTATIONS)
        self.current_step = 0
        self.gameover = False
        # stps[k] is the board after k moves; slot 0 is the empty board.
        self.stps = [self._empty_board() for _ in range(self.total + 1)]
        self.current = self.stps[0]
        return self.state_as_nn_input(self.current)

    def board_add_pieces(self, adding_piece, x, y, real=False):
        """Try to place ``adding_piece`` with its top-left corner at ``(x, y)``.

        The placement is always carried out on a copy of the current board.
        A cell conflicts when it is already occupied or lies outside the
        board.

        Args:
            adding_piece: Piece grid, already rotated as desired.
            x: Column of the piece's left-most cell.
            y: Row of the piece's top-most cell.
            real: When False the call is a pure probe. When True and the
                placement has no conflict, the piece's remaining count is
                decremented and the new board becomes ``self.current``; on a
                conflict nothing is changed.

        Returns:
            ``(board, conflict)`` where ``board`` is the board with the
            non-conflicting cells filled in and ``conflict`` is a bool.
        """
        board = [row[:] for row in self.current]
        conflict = False

        for r, piece_row in enumerate(adding_piece):
            for c, cell in enumerate(piece_row):
                if not cell:
                    continue
                row = y + r
                col = x + c
                inside = 0 <= row < self.height and 0 <= col < self.width
                # Checking bounds first keeps negative offsets from wrapping
                # around through Python's negative indexing.
                if not inside or board[row][col]:
                    conflict = True
                else:
                    board[row][col] = cell

        if real and not conflict:
            idx = index_of(adding_piece)
            self.remained[idx] -= 1
            if any(e < 0 for e in self.remained):
                print("Error in board_add_pieces function!")
            # Rebinding (instead of mutating in place) keeps earlier
            # snapshots stored in self.stps intact.
            self.current = board
        return board, conflict

    def step(self, idx, r, x, y):
        """Place piece ``idx`` rotated ``r`` times at ``(x, y)`` for real.

        On success the step counter advances, the board is recorded in
        ``self.stps`` and the step number is added to the score. A
        conflicting placement prints an error and leaves the game unchanged.
        """
        p = rotate(pieces[idx], r)
        current_state, conflit = self.board_add_pieces(p, x, y, real=True)
        if conflit:
            print("Error in step function")
            return
        self.current_step += 1
        self.stps[self.current_step] = current_state
        self.score += self.current_step

    def next_steps(self):
        """Enumerate every legal placement of every piece still in stock.

        Returns:
            Dict mapping ``(piece_index, rotations, x, y)`` to the feature
            vector of the board that placement would produce. When the dict
            is empty ``self.gameover`` is set to True.
        """
        choices = {}
        for i, left in enumerate(self.remained):
            if left <= 0:
                continue

            p = pieces[i]
            for r in range(self._ROTATIONS[i]):
                # The ranges keep the bounding box on this instance's board.
                max_x = self.width - len(p[0])
                max_y = self.height - len(p)
                for x in range(max_x + 1):
                    for y in range(max_y + 1):
                        state, conflict = self.board_add_pieces(p, x, y)
                        if not conflict:
                            # NOTE(review): the stock counts in this vector
                            # are those before the candidate is consumed.
                            choices[(i, r, x, y)] = self.state_as_nn_input(state)
                p = rotate(p, n_times=1)

        if not choices:
            self.gameover = True
        return choices

    def state_as_nn_input(self, state):
        """Encode a board as a 1-D float tensor for the network.

        The vector is the per-row fill counts divided by the board width,
        followed by the per-column fill counts divided by the board height,
        followed by the remaining stock of each piece divided by the initial
        stock.
        """
        capture_row = torch.tensor(self.row_count(state), dtype=torch.float32) / self.width
        capture_col = torch.tensor(self.col_count(state), dtype=torch.float32) / self.height
        number_of_remained = (
            torch.tensor(self.remained, dtype=torch.float32) / self._COPIES_PER_PIECE
        )
        return torch.cat((capture_row, capture_col, number_of_remained), 0)

    def col_count(self, s):
        """Return the number of occupied cells in each column of board ``s``."""
        return [sum(1 for row in s if row[i]) for i in range(self.width)]

    def row_count(self, s):
        """Return the number of occupied cells in each row of board ``s``."""
        return [sum(1 for cell in row if cell) for row in s]
    
    
class DQN(nn.Module):
    """Three-layer fully connected network that scores a board feature vector.

    The expected input is what ``Tetris.state_as_nn_input`` returns: one
    value per board row, one per board column and one per piece type.

    Args:
        input_size: Length of the feature vector. The default, 14 + 10 + 7,
            matches a board with 14 rows and 10 columns and 7 piece types.
    """

    def __init__(self, input_size=14 + 10 + 7):
        super().__init__()
        # Each layer stays wrapped in nn.Sequential so parameter names
        # (l1.0.weight, ...) remain compatible with saved checkpoints.
        self.l1 = nn.Sequential(nn.Linear(input_size, 512), nn.ReLU(inplace=True))
        self.l2 = nn.Sequential(nn.Linear(512, 128), nn.ReLU(inplace=True))
        self.l3 = nn.Sequential(nn.Linear(128, 1))
        self._create_weights()

    def _create_weights(self):
        """Xavier-initialise every linear layer's weights and zero its bias."""
        for w in self.modules():
            if isinstance(w, nn.Linear):
                nn.init.xavier_uniform_(w.weight)
                nn.init.constant_(w.bias, 0)

    def forward(self, t):
        """Map a ``(batch, input_size)`` tensor to ``(batch, 1)`` scores."""
        t = self.l1(t)
        t = self.l2(t)
        t = self.l3(t)
        return t
    
    
############################################################################    
# True when PyTorch reports that a CUDA device is available.
powerful = torch.cuda.is_available()
############################################################################    
    
def trainDQN(train_model=None, num_epochs=1000,
             checkpoint_path="/home/fuxi/tetris/checkpoints_14*10.pth"):
    """Train a value network by self-play with an epsilon-greedy policy.

    The agent plays games on a fresh ``Tetris`` environment. Every move
    stores ``[state, reward, next_state, done]`` in a replay buffer, where
    the reward is the step number reached. Each time a game ends, and once
    the buffer holds at least 1/20 of its capacity, one optimisation step is
    run on a random batch and the weights are written to
    ``checkpoint_path``.

    Args:
        train_model: Network to train. A new ``DQN`` is created when omitted
            (a fresh one per call, rather than one shared default instance).
        num_epochs: Number of optimisation steps to run.
        checkpoint_path: File the state dict is saved to after every step.
    """
    # Use the GPU when there is one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(123)

    model = DQN() if train_model is None else train_model
    model.to(device)
    env = Tetris()
    state = env.reset().to(device)

    replay_memory_size = 30000
    batch_size = 1024
    learning_rate = 0.05
    gamma = 0.9
    epsilon = 0.9
    epoch = 0
    replay_memory = deque(maxlen=replay_memory_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    while epoch < num_epochs:
        epsilon = epsilon * 0.999
        choices = env.next_steps()
        if choices:
            actions, states = zip(*choices.items())
            states = torch.stack(states).to(device)
            model.eval()
            with torch.no_grad():
                scores = model(states)[:, 0]
            model.train()
            if random() < epsilon:
                # randint includes its upper bound.
                index = randint(0, len(choices) - 1)
            else:
                index = torch.argmax(scores).item()

            next_state = states[index, :]
            idx, r, x, y = actions[index]
            env.step(idx, r, x, y)

            # Whether this move ended the game is only known on the next
            # next_steps() call, so the flag starts False and is patched
            # below once the game turns out to be over.
            replay_memory.append([state, env.current_step, next_state, False])

            state = next_state
            continue

        # No piece fits any more: the last stored move was terminal.
        if replay_memory:
            replay_memory[-1][3] = True
        final_score = env.score
        state = env.reset().to(device)

        # Do not train until enough moves have been collected.
        if len(replay_memory) < replay_memory_size / 20:
            continue

        epoch += 1
        batch = sample(replay_memory, min(len(replay_memory), batch_size))
        state_batch, reward_batch, next_state_batch, done_batch = zip(*batch)

        state_batch = torch.stack(state_batch).to(device)
        next_state_batch = torch.stack(next_state_batch).to(device)
        reward_batch = torch.tensor(
            reward_batch, dtype=torch.float32, device=device).unsqueeze(1)
        not_done = 1.0 - torch.tensor(
            done_batch, dtype=torch.float32, device=device).unsqueeze(1)

        Q_values = model(state_batch)
        model.eval()
        with torch.no_grad():
            next_prediction_batch = model(next_state_batch)
        model.train()

        # Regression target: the reward alone for terminal moves, otherwise
        # the reward plus the discounted value predicted for the next state.
        y_batch = reward_batch + gamma * not_done * next_prediction_batch

        optimizer.zero_grad()
        loss = criterion(Q_values, y_batch)
        loss.backward()
        optimizer.step()

        print("Epoch: {}/{} ".format(epoch, num_epochs))
        print(final_score)
        torch.save(model.state_dict(), checkpoint_path)

In [22]:
# Start training with the default arguments (the recorded output below ends in a KeyboardInterrupt).
trainDQN()

Epoch: 1/1000 
465
Epoch: 2/1000 
378
Epoch: 3/1000 
378
Epoch: 4/1000 
351
Epoch: 5/1000 
406
Epoch: 6/1000 
435
Epoch: 7/1000 
406
Epoch: 8/1000 
465
Epoch: 9/1000 
496
Epoch: 10/1000 
406
Epoch: 11/1000 
435
Epoch: 12/1000 
435
Epoch: 13/1000 
378
Epoch: 14/1000 
435
Epoch: 15/1000 
378
Epoch: 16/1000 
351
Epoch: 17/1000 
325
Epoch: 18/1000 
378
Epoch: 19/1000 
325
Epoch: 20/1000 
351
Epoch: 21/1000 
351
Epoch: 22/1000 
351
Epoch: 23/1000 
276
Epoch: 24/1000 
300
Epoch: 25/1000 
325
Epoch: 26/1000 
300
Epoch: 27/1000 
351
Epoch: 28/1000 
378
Epoch: 29/1000 
351
Epoch: 30/1000 
325
Epoch: 31/1000 
325
Epoch: 32/1000 
351
Epoch: 33/1000 
378
Epoch: 34/1000 
378
Epoch: 35/1000 
351
Epoch: 36/1000 
351
Epoch: 37/1000 
406
Epoch: 38/1000 
351
Epoch: 39/1000 
378
Epoch: 40/1000 
351
Epoch: 41/1000 
435
Epoch: 42/1000 
406


KeyboardInterrupt: 

In [23]:
def test(m):
    """Play one greedy game with model ``m``, printing the board and score.

    At every move all legal placements are scored by the model and the one
    with the highest score is played. The loop ends when no piece fits.

    Args:
        m: Trained network; it is moved to the GPU when one is available and
            put into eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(123)
    model = m.to(device)
    model.eval()

    env = Tetris()
    env.reset()
    while True:
        next_steps = env.next_steps()
        if not next_steps:
            env.gameover = True
            break

        next_actions, next_states = zip(*next_steps.items())
        next_states = torch.stack(next_states).to(device)
        # Pure inference: no gradients are needed.
        with torch.no_grad():
            predictions = model(next_states)[:, 0]
        index = torch.argmax(predictions).item()
        idx, r, x, y = next_actions[index]
        env.step(idx, r, x, y)
        print(env.current)
        print(env.score)
            
# Rebuild the network, restore the saved weights and play one greedy game.
t = DQN()
# map_location="cpu" lets a checkpoint written from GPU tensors load on a
# host without CUDA; test() moves the model to the right device afterwards.
t.load_state_dict(
    torch.load("/home/fuxi/tetris/checkpoints_14*10.pth", map_location="cpu"))

test(t)

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 5, 5, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 5, 5, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 5, 5, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [

[[1, 1, 5, 5, 5, 5, 6, 6, 2, 0], [1, 1, 7, 7, 7, 0, 0, 6, 2, 2], [6, 0, 1, 1, 7, 3, 3, 6, 2, 4], [6, 0, 1, 1, 3, 3, 1, 1, 4, 4], [6, 6, 5, 5, 5, 5, 1, 1, 4, 0], [0, 3, 3, 0, 0, 0, 0, 0, 0, 0], [3, 3, 5, 5, 5, 5, 0, 4, 4, 0], [7, 7, 5, 5, 5, 5, 6, 6, 4, 4], [7, 6, 6, 6, 7, 7, 7, 6, 1, 1], [7, 6, 5, 5, 5, 5, 7, 6, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2, 0, 1, 1, 7, 7, 7, 7, 0, 0], [2, 2, 1, 1, 7, 7, 7, 7, 6, 0], [2, 0, 0, 0, 0, 0, 6, 6, 6, 0]]
351
[[1, 1, 5, 5, 5, 5, 6, 6, 2, 0], [1, 1, 7, 7, 7, 0, 0, 6, 2, 2], [6, 0, 1, 1, 7, 3, 3, 6, 2, 4], [6, 0, 1, 1, 3, 3, 1, 1, 4, 4], [6, 6, 5, 5, 5, 5, 1, 1, 4, 0], [0, 3, 3, 0, 0, 0, 0, 0, 0, 0], [3, 3, 5, 5, 5, 5, 0, 4, 4, 0], [7, 7, 5, 5, 5, 5, 6, 6, 4, 4], [7, 6, 6, 6, 7, 7, 7, 6, 1, 1], [7, 6, 5, 5, 5, 5, 7, 6, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 3, 0], [2, 0, 1, 1, 7, 7, 7, 7, 3, 3], [2, 2, 1, 1, 7, 7, 7, 7, 6, 3], [2, 0, 0, 0, 0, 0, 6, 6, 6, 0]]
378
[[1, 1, 5, 5, 5, 5, 6, 6, 2, 0], [1, 1, 7, 7, 7, 0, 0, 6, 2, 2], [6, 0, 1, 1, 7, 3, 3, 6, 2, 4