In [1]:
import numpy as np
import os
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset

import random
import pickle

import matplotlib.pyplot as plt
%matplotlib inline

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
# `device` is read as a global by the network, MCTS and training code below.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [1]:
# Base directory for the dataset pickles and weight files; it is combined with
# file names through os.path.join, so "" resolves to the current working directory.
PATH = ""

In [0]:
# Reward per tile merge; only referenced from commented-out code in this notebook.
param_join_reward = 1
# Reward returned by the game core when the requested move is not possible.
param_inv_move_reward = -8
# NOTE(review): param_ohe_state / param_input_shape are not read anywhere else in
# the visible notebook - presumably leftovers from an earlier DQN version; confirm.
param_ohe_state = True
param_input_shape = 4*4*14 if param_ohe_state else 4*4
# Number of move directions; size of the policy head output.
param_n_actions = 4
# Mini-batch size used by the training DataLoader.
param_batch_size = 64

In [0]:
# Move direction id -> unit step on the board as (row delta, column delta).
# The non-zero component selects the array axis along which tiles slide.
directions_dict = {0: np.array((0,1)),
                   1: np.array((1,0)),
                   2: np.array((0,-1)),
                   3: np.array((-1,0)),}

# Direction id -> one-hot vector
def dir_to_ohe(direction_int):
    """Encode a direction id (0..3) as a length-4 one-hot float vector."""
    encoded = np.zeros(4)
    encoded[direction_int] = 1
    return encoded

# One-hot direction vector -> direction id
def ohe_to_dir(direction_ohe):
    """Return the index of the largest entry of a one-hot direction vector."""
    return np.asarray(direction_ohe).argmax()

# Tile value -> one-hot vector
def tile_to_ohe(n):
    """Encode a tile value as a length-14 one-hot vector.

    An empty cell (n <= 0) maps to the all-zero vector; a tile 2**k sets
    index k - 1 (2 -> index 0, 4 -> index 1, ...).
    """
    encoded = np.zeros(14)
    if not n > 0:
        return encoded
    exponent = int(np.log2(n))
    encoded[exponent - 1] = 1
    return encoded

# Board state -> one-hot planes for the neural network
def state_to_ohe(state):
    """Encode a 4x4 board as a (14, 4, 4) array of one-hot tile planes."""
    planes = np.zeros((4,4,14))
    for row, col in np.ndindex(4, 4):
        planes[row, col] = tile_to_ohe(state[row, col])
    # Move the tile-class axis to the front (channels-first layout).
    return planes.transpose((2,0,1))

def position_to_tuple(position):
    """Convert a 2-D sequence into a (hashable) tuple of row tuples."""
    return tuple(tuple(row) for row in position)

# Game engine
class Game_Core_2048:
    """Rules engine for 2048 on a 4x4 integer board.

    Attributes set by ``reset``:
        gameboard: 4x4 int array, 0 marks an empty cell.
        undo: copy of the board taken before the last valid move.
        moves_dict: result of ``find_all_moves`` for the current board.
        gameover: True when no direction has an available move.
        score: accumulated merge score.
        inv_move_count: number of consecutive invalid moves.

    Moves read the module-level ``directions_dict`` and ``param_inv_move_reward``.
    """

    def __init__(self):
        self.reset()

    def reset(self, random_start=0., max_number=128):
        """Start a new game.

        With probability ``1 - random_start`` the board gets the standard two
        random tiles; otherwise a random position with tiles up to
        ``max_number`` is generated.
        """
        self.gameboard = np.zeros((4, 4)).astype(int)
        if np.random.rand() >= random_start:
            self.place_random_number()
            self.place_random_number()
        else:
            self.make_random_state(max_number)
        # Make restore() safe even before the first move.
        self.undo = self.gameboard.copy()
        self.moves_dict = self.find_all_moves()
        self.gameover = self.check_gameover(self.moves_dict)
        self.score = 0
        self.inv_move_count = 0

    def random_free_cell(self):
        """Return the (row, col) index pair of a uniformly chosen empty cell."""
        free_cells = np.argwhere(self.gameboard == 0)
        return free_cells[np.random.randint(len(free_cells))]

    def place_random_number(self):
        """Put a new tile on a random empty cell: 2 with probability 0.9, else 4."""
        n = 2 if np.random.rand() < 0.9 else 4
        cell = self.random_free_cell()
        self.gameboard[cell[0], cell[1]] = n

    def place_number(self, position, state):
        """Place a tile on a copy of ``state`` (used by MCTS chance nodes).

        ``position`` indexes the empty cells of ``state``: values below the
        number of empty cells place a 2 in that cell, values at or above it
        place a 4 in cell ``position - n_free``.

        Returns:
            (new_state, gameover) where gameover tells whether any move is
            still possible on the new state.
        """
        state = state.copy()
        free_cells = np.argwhere(state == 0)
        number = 2
        if position >= len(free_cells):
            number = 4
            position -= len(free_cells)
        cell = free_cells[position]
        state[cell[0], cell[1]] = number
        moves_dict = self.find_all_moves(state)
        state_gameover = self.check_gameover(moves_dict)
        return state, state_gameover

    def make_random_state(self, max_number):
        """Fill the empty cells with a random starting position.

        Each empty cell independently stays empty or receives a random power
        of two between 2 and ``max_number``.
        """
        idx = np.where(self.gameboard==0)
        # int(): randint expects an integer bound, np.log2 returns a float.
        powers = np.random.randint(int(np.log2(max_number)), size=(len(idx[0])))+1
        # Plain bool: the np.bool alias was removed in NumPy 1.24.
        mask = np.random.randint(2, size=(len(idx[0]))).astype(bool)
        self.gameboard[idx] = 2**powers*mask

    def find_moves(self, gameboard, direction, already_summed_mask, offset):
        """Find where the four tiles of one row/column end up for a move.

        The line is the column ``offset`` (horizontal directions) or the row
        ``offset`` (vertical directions). The search is vectorised over the
        four tiles of the line instead of using nested loops.

        Returns:
            (moves, mask): ``moves`` holds the destination (row, col) of every
            cell of the line, sorted along the line; ``mask`` is True for the
            cells whose destination differs from their current position.
        """
        axis = np.argwhere(direction != 0)[0, 0]  # Axis the move travels along
        del_axis = int(not axis)  # Axis used when dropping finished tiles

        # A column of cells for horizontal moves, a row of cells for vertical ones.
        shape = (4, 1) if axis else (1, 4)
        inds = np.indices(shape)
        dirs = np.tile(direction[axis], 4).reshape(shape)
        mask = np.ones(shape)  # Mask of tiles that can still be shifted

        check_inds = inds.copy()  # Indices probed along the move axis
        check_inds[axis] = check_inds[axis] + offset  # Start at the offset line
        mask = (mask * (gameboard[check_inds[0], check_inds[1]] != 0)).astype(bool)  # Empty cells do not move
        keep0, keep1 = check_inds[0][mask == 0], check_inds[1][mask == 0]  # Remember before dropping
        if axis:
            cur_values = gameboard[:, offset].reshape(shape)  # Current values of the line
        else:
            cur_values = gameboard[offset, :].reshape(shape)  # Current values of the line
        cur_values_mask = np.ones(shape).astype(bool)
        # Drop everything that cannot move any further from the search.
        check_inds = np.delete(check_inds, np.where(mask.flatten() == 0), axis=del_axis + 1)
        dirs = np.delete(dirs, np.where(mask.flatten() == 0), axis=del_axis)
        cur_values = np.delete(cur_values, np.where(mask.flatten() == 0), axis=del_axis)
        cur_values_mask = np.delete(cur_values_mask, np.where(mask.flatten() == 0), axis=del_axis)
        mask = np.delete(mask, np.where(mask.flatten() == 0), axis=del_axis)

        # While something can still be moved further
        while check_inds.size > 0:
            cur_values[cur_values_mask] = gameboard[check_inds[0], check_inds[1]][cur_values_mask]  # Refresh current values
            prev_inds = check_inds.copy()
            check_inds[axis] = check_inds[axis] + dirs * mask  # Step the probe by one cell in the move direction
            mask = ((-direction[axis] * check_inds[axis]) >= 1 * (direction[axis] > 0)).astype(
                bool)  # Board boundary check
            asm_mask = already_summed_mask[check_inds[0], check_inds[1]] == 0  # Skip cells that already received a merge
            mask = mask * asm_mask
            new_values = gameboard[check_inds[0], check_inds[1]]  # Values to compare for equality
            cur_values_mask = new_values != 0
            mask_eq = mask * (cur_values == new_values)
            mask_free = mask * (gameboard[check_inds[0], check_inds[1]] == 0)  # Free-cell check
            mask = mask_eq + mask_free
            keep0 = np.r_[keep0, prev_inds[0][mask == 0]]  # Store finished shifts before dropping them
            keep1 = np.r_[keep1, prev_inds[1][mask == 0]]
            # Drop everything that cannot move any further from the search.
            check_inds = np.delete(check_inds, np.where(mask.flatten() == 0), axis=del_axis + 1)
            dirs = np.delete(dirs, np.where(mask.flatten() == 0), axis=del_axis)
            cur_values = np.delete(cur_values, np.where(mask.flatten() == 0), axis=del_axis)
            cur_values_mask = np.delete(cur_values_mask, np.where(mask.flatten() == 0), axis=del_axis)
            mask = np.delete(mask, np.where(mask.flatten() == 0), axis=del_axis)

        # Build the returned arrays.
        moves = np.c_[keep0.reshape(-1, 1), keep1.reshape(-1, 1)]
        if axis:
            moves = moves[moves[:, 0].argsort(axis=0)]
            mask = moves[:, 1] != offset
        else:
            moves = moves[moves[:, 1].argsort(axis=0)]
            mask = moves[:, 0] != offset

        return moves, mask

    def get_offset(self, offset, direction):
        """Translate a 1-based distance from the target edge into a line index.

        Returns:
            (index, axis): for a positive direction the index is negative
            (counted from the end of the axis), otherwise it equals ``offset``.
        """
        axis = np.argwhere(direction != 0)[0, 0]
        offset = -direction[axis] * offset
        if direction[axis] > 0:
            offset = offset - 1
        return offset, axis

    def find_all_moves(self, state=None):
        """Compute the available moves in all four directions.

        Args:
            state: board to analyse; the engine's own board when None.

        Returns:
            dict ``{direction: (move_available, {offset: (moves, mask)})}``.
        """
        moves_dict = {}
        source = self.gameboard if state is None else state
        for d in range(4):
            direction = directions_dict[d]
            # Work on a scratch copy: the lines are resolved one after another.
            temp_gb = source.copy()
            already_summed_mask = np.zeros_like(temp_gb)
            dir_dict = {}
            moves_available = False
            # The three lines that are not already at the target edge
            for offset in range(3):
                offset, axis = self.get_offset(offset + 1, direction)
                moves, mask = self.find_moves(temp_gb, direction, already_summed_mask, offset)
                dir_dict[offset] = (moves, mask)
                if not moves_available:
                    moves_available = np.any(mask)
                already_summed_mask[moves[mask, 0], moves[mask, 1]] = temp_gb[moves[mask, 0], moves[mask, 1]] != 0
                if axis:
                    temp_gb[moves[mask, 0], moves[mask, 1]] = temp_gb[moves[mask, 0], moves[mask, 1]] + temp_gb[
                        mask, offset]
                    temp_gb[mask, offset] = 0
                else:
                    temp_gb[moves[mask, 0], moves[mask, 1]] = temp_gb[moves[mask, 0], moves[mask, 1]] + temp_gb[
                        offset, mask]
                    temp_gb[offset, mask] = 0
            moves_dict[d] = (moves_available, dir_dict)
        return moves_dict

    def check_gameover(self, moves_dict):
        """Return True when no direction in ``moves_dict`` has an available move."""
        return not np.any([vm[0] for vm in moves_dict.values()])

    @staticmethod
    def _apply_moves(board, all_moves, axis):
        """Apply pre-computed moves to ``board`` in place and return the merge score."""
        movescore = 0
        for offset, (moves, mask) in all_moves.items():
            dst = (moves[mask, 0], moves[mask, 1])
            src = (mask, offset) if axis else (offset, mask)
            # A non-empty destination means a merge worth twice its value.
            movescore += board[dst].sum()*2
            # Shift onto the new position (or sum with the equal tile there) ...
            board[dst] = board[dst] + board[src]
            # ... and clear the old position.
            board[src] = 0
        return movescore

    def move(self, direction):
        """Play ``direction`` on the engine's own board.

        Returns:
            (reward, terminal, invalid_move). A valid move yields its merge
            score as reward and spawns a random tile; an invalid move yields
            ``param_inv_move_reward``; a finished game yields (0, True, False).
        """
        if self.gameover:
            # The game is already over
            return 0, True, False

        terminal = False
        if self.moves_dict[direction][0]:
            # The direction is valid
            self.inv_move_count = 0
            invalid_move = False
            self.undo = self.gameboard.copy()
            axis = np.argwhere(directions_dict[direction] != 0)[0, 0]
            movescore = self._apply_moves(self.gameboard, self.moves_dict[direction][1], axis)
            self.score += movescore
            reward = movescore
            # Spawn a random tile
            self.place_random_number()
            # Check whether the game can continue
            self.moves_dict = self.find_all_moves()
            self.gameover = self.check_gameover(self.moves_dict)
            terminal = self.gameover
        else:
            # Invalid move: count it for the invalid-move tolerance
            self.inv_move_count += 1
            invalid_move = True
            reward = param_inv_move_reward

        return reward, terminal, invalid_move

    def restore(self):
        """Undo the last valid move (the board only; the score is not rolled back)."""
        self.gameboard = self.undo.copy()
        self.moves_dict = self.find_all_moves()
        # Keep the game-over flag in sync with the restored board.
        self.gameover = self.check_gameover(self.moves_dict)

    def move_from_state(self, state, direction):
        """Play ``direction`` on a copy of an arbitrary ``state`` (used by MCTS).

        Unlike ``move`` no random tile is spawned and the engine's own board
        and score are left untouched.

        Returns:
            (reward, new_state, terminal, invalid_move) - always four values,
            including the case where ``state`` is already game over.
        """
        state = state.copy()
        moves_dict = self.find_all_moves(state)
        if self.check_gameover(moves_dict):
            return 0, state, True, False

        terminal = False
        if moves_dict[direction][0]:
            invalid_move = False
            axis = np.argwhere(directions_dict[direction] != 0)[0, 0]
            reward = self._apply_moves(state, moves_dict[direction][1], axis)
            terminal = self.check_gameover(self.find_all_moves(state))
        else:
            invalid_move = True
            reward = param_inv_move_reward

        return reward, state, terminal, invalid_move

In [0]:
"""
Environment. Изначально создавался под DQN (action -> reward, state, terminal) и GUI от pygame.
Пришлось вносить срочные изменения при переходе на метод AlphaGo.
Сейчас частично используются методы отсюда, частично из GameCore (в MCTS).
ToDo : Привести в нормальный вид.
"""
class Env2048():
    """Environment wrapper around ``Game_Core_2048``.

    Args:
        gui: when true, a ``Game2048`` GUI object is created for the core.
            NOTE(review): ``Game2048`` is not defined in the visible notebook,
            so only ``gui=False`` can work here - confirm where it lives.
        inv_move_tolerance: number of consecutive invalid moves after which
            ``act`` reports the episode as terminal; 0 disables the limit.
    """

    def __init__(self, gui=True, inv_move_tolerance=0):
        self.game_core = Game_Core_2048()
        self.gui = None
        if gui:
            self.gui = Game2048(self.game_core)
        self.inv_move_tolerance = inv_move_tolerance
        self.inv_move_count = 0

    def reset(self, random_start=0., max_number=128):
        """Restart the underlying game and clear the invalid-move counter."""
        self.game_core.reset(random_start, max_number)
        self.inv_move_count = 0

    def get_state(self):
        """Return a copy of the current board."""
        state = self.game_core.gameboard.copy()
        return state

    def act(self, action_ohe, ohe_state):
        """Play a one-hot encoded action on the real game.

        ``ohe_state`` is accepted for interface compatibility but not used.

        Returns:
            (state, action_ohe, reward, new_state, terminal)
        """
        direction = ohe_to_dir(action_ohe)
        state = self.get_state()
        reward, terminal, invalid_move = self.game_core.move(direction)
        new_state = self.get_state()
        if not invalid_move:
            self.inv_move_count = 0
            if self.gui:
                self.gui.move()
        else:
            self.inv_move_count += 1
            if self.inv_move_tolerance and self.inv_move_count >= self.inv_move_tolerance:
                terminal = True
        return state, action_ohe, reward, new_state, terminal

    def act_from_state(self, state, direction):
        """Simulate ``direction`` from an arbitrary ``state`` without touching the game.

        An invalid move is reported as terminal.

        Returns:
            (reward, new_state, terminal)
        """
        # The state has to be forwarded: move_from_state(state, direction).
        reward, new_state, terminal, invalid_move = self.game_core.move_from_state(state, direction)
        if invalid_move:
            terminal = True
        return reward, new_state, terminal

In [0]:
# MCTS node
class TreeNode():
    """Node of the search tree, alternating player nodes and chance nodes.

    A player node (``player`` true) has four actions (move directions) with
    priors ``P`` and value ``v`` taken from the network. A chance node has
    ``2 * n_free`` actions: a 2 or a 4 placed on each empty cell, with priors
    0.9 / 0.1 spread uniformly over the cells.

    Args:
        game_core: engine used to simulate moves and tile placements.
        state: 4x4 board of this node.
        reward: merge score accumulated on the path to this node.
        player: True for a player node, False for a chance node.
        net: policy/value network, stored and passed on to children.
        parent: parent node, None for the root.
        parent_a: action taken in the parent to reach this node.
        terminal: True when the search must stop at this node.
    """

    def __init__(self, game_core, state, reward, player, net, parent=None, parent_a=-1, terminal=False):
        self.game_core = game_core
        self.player = player
        self.state = state
        self.movescore = reward
        self.parent = parent
        self.parent_a = parent_a
        # Keep the network so expand() does not rely on a global `net`.
        self.net = net

        if self.player:
            # Query the network for the prior p and the value v
            with torch.no_grad():
                self.P, self.v, _ = net(torch.Tensor(state_to_ohe(state)).unsqueeze(0).to(device))
                self.P = self.P.detach().cpu().numpy().reshape(-1)
                self.v = self.v.detach().cpu().item()
            if parent is None:
                # Root node: mix Dirichlet noise into the prior
                self.P = self.P*0.75 + np.random.dirichlet((0.03,0.03,0.03,0.03))*0.25
            self.N = np.zeros(4)
            self.Q = np.zeros(4)
        else:
            # Chance node: probabilities of the random tile spawns
            freecells = np.argwhere(self.state == 0)
            self.P = np.array([0.9]*len(freecells) + [0.1]*len(freecells)) / len(freecells)
            self.v = 0.
            self.N = np.zeros(len(freecells)*2)
            self.Q = np.zeros(len(freecells)*2)
        self.V = 0.
        self.c_puct = 1
        self.number_of_visits = 0
        self.terminal = terminal
        self.children = {}
        self.has_children = False

    def get_pi(self, tau=1.):
        """Return the policy: visit counts sharpened by temperature ``tau`` and normalised."""
        pi = self.N**(1./tau) / (self.N**(1./tau)).sum()
        return pi

    def find_node(self, action, state):
        """Return the grandchild reached by ``action`` whose board equals ``state``.

        Returns None when this is not a player node or no such node exists.
        """
        if not self.player:
            return None
        if self.has_children:
            if action in self.children.keys():
                node = self.children[action]
                if node.has_children:
                    for k, v in node.children.items():
                        if np.all(state == v.state):
                            return v
        return None

    def make_root(self):
        """Detach this node from its parent and turn it into a root (adds prior noise)."""
        self.movescore = 0
        self.parent = None
        self.parent_a = -1
        self.P = self.P*0.75 + np.random.dirichlet((0.03,0.03,0.03,0.03))*0.25

    def select_child(self):
        """Pick the next node on the path.

        Player nodes maximise Q + U (PUCT); chance nodes sample a spawn from
        ``P``. The visit count of the chosen action is incremented.

        Returns:
            (node, finished): ``finished`` is True when ``node`` is terminal
            or was newly created.
        """
        if self.terminal:
            return self, True

        if self.player:
            U = self.Q + self.c_puct*self.P*(np.sqrt(self.N.sum())/(1+self.N))
            action = np.argmax(U)
        else:
            action = np.random.choice(np.arange(len(self.P)), p = self.P)

        self.N[action] += 1
        if self.has_children and action in self.children:
            return self.children[action], False
        return self.expand(action), True

    def expand(self, action):
        """Create, store and return the child reached by ``action``."""
        if self.player:
            reward, new_state, terminal, invalid_move = self.game_core.move_from_state(self.state, action)
            if invalid_move:
                # Invalid moves end the search path
                terminal = True
            child_score = self.movescore + reward
        else:
            new_state, terminal = self.game_core.place_number(action, self.state)
            child_score = self.movescore
        self.children[action] = TreeNode(self.game_core, new_state, child_score, not self.player,
                                         self.net, self, action, terminal)
        self.has_children = True
        return self.children[action]

    def play_to_leaf(self):
        """Descend from this node until a node is created or a terminal one is hit (one MCTS iteration)."""
        node = self
        search_finished = False
        while not search_finished:
            node, search_finished = node.select_child()
        return node

    def backup(self, leaf):
        """Propagate the value of ``leaf`` to the player nodes on the path up to the root."""
        leaf.number_of_visits += 1
        value = leaf.v
        parent_a = leaf.parent_a
        node = leaf.parent
        while node is not None:
            node.number_of_visits +=1
            if node.player:
                node.V += value
                # Running mean; N[parent_a] was already incremented in select_child
                node.Q[parent_a] = ((node.N[parent_a] - 1) * node.Q[parent_a] + value) / node.N[parent_a]
            parent_a = node.parent_a
            node = node.parent

In [0]:
def tree_search(root, state, net, number):
    """Run ``number`` MCTS iterations from ``root`` and return ``root``.

    ``state`` and ``net`` are accepted for interface compatibility; the
    search itself only uses ``root``.
    """
    for _ in range(number):
        # Select / expand down to a new leaf, then back its value up the path.
        reached = root.play_to_leaf()
        reached.backup(reached)
    return root

In [0]:
def _augment_position(state, pi, score):
    """Return the eight symmetric variants of one (state, pi, score) sample.

    The board is flipped / rotated and the policy entries are permuted to
    match. Each returned item is a list ``[one-hot state, policy, score]``.
    """
    def h_flip(policy):
        # A left-right flip swaps directions 0 and 2.
        flipped = policy.copy()
        flipped[0], flipped[2] = flipped[2], flipped[0]
        return flipped

    def v_flip(policy):
        # An up-down flip swaps directions 1 and 3.
        flipped = policy.copy()
        flipped[1], flipped[3] = flipped[3], flipped[1]
        return flipped

    rot90_state = np.rot90(state)
    rot90_pi = np.roll(pi, -1)
    variants = [
        (state, pi),
        (np.fliplr(state), h_flip(pi)),
        (np.flipud(state), v_flip(pi)),
        (rot90_state, rot90_pi),
        (np.fliplr(rot90_state), h_flip(rot90_pi)),
        (np.flipud(rot90_state), v_flip(rot90_pi)),
        (np.rot90(state, 2), np.roll(pi, -2)),
        (np.rot90(state, 3), np.roll(pi, -3)),
    ]
    return [[state_to_ohe(board), policy, score] for board, policy in variants]


def self_play(net, dataset_size=4096, num_MCTS=999, random_start=0., max_number=128):
    """Generate a training dataset by AlphaZero-style self-play.

    Games are played until at least ``dataset_size`` samples are collected.
    Every position is stored with its MCTS policy, augmented by the eight
    board symmetries, once as played and once with all tiles (and the score)
    doubled. After a game ends, the third field of each sample becomes
    ``log1p(final_score - score_at_that_position)``. The dataset is written
    to ``dataset.pickle`` under ``PATH`` after every game.

    Args:
        net: policy/value network used by the tree search.
        dataset_size: minimum number of samples to collect.
        num_MCTS: tree-search iterations per move.
        random_start: probability of starting a game from a random position.
        max_number: largest tile of a random starting position.

    Returns:
        (dataset, (mean_score, mean_num_moves, mean_max_num_reached,
        mean_inv_moves)); the statistics are all 0 when no game was played.
    """
    dataset = []
    env = Env2048(gui=False, inv_move_tolerance=1)

    total_score = 0
    total_num_moves = 0
    total_inv_moves = 0
    total_max_num_reached = 0

    start_time = time.time()
    i = 0
    while len(dataset) < dataset_size:
        i += 1
        print("Simulation {}, dataset length {}".format(i, len(dataset)))
        env.reset(random_start, max_number)
        terminal = False
        tau = 1. # Temperature
        num_moves = 0
        dataset_sim = []
        dataset_sim_double = []
        state = env.get_state()
        root = TreeNode(Game_Core_2048(), state, 0, True, net)
        print("Max number started: ", np.max(state))
        while not terminal:
            state = env.get_state()
            pi = tree_search(root, state, net, num_MCTS).get_pi(tau) # Build the policy
            direction_n = np.random.choice(np.arange(len(pi)), p=pi)
            direction = dir_to_ohe(direction_n)
            cur_score = env.game_core.score
            # Fill the dataset with the augmented samples
            dataset_sim += _augment_position(state, pi, cur_score)
            dataset_sim_double += _augment_position(state*2, pi, cur_score*2)

            _, _, reward, new_state, terminal = env.act(direction, ohe_state=False)
            new_root = root.find_node(direction_n, new_state) # Reuse the tree
            if new_root:
                root = new_root
                root.make_root()
            else:
                root = TreeNode(Game_Core_2048(), new_state, 0, True, net)
            num_moves += 1
            if num_moves > 30:
                # Play (almost) greedily after the opening moves
                tau = 0.1
            if reward == param_inv_move_reward:
                total_inv_moves += 1

        # Turn the stored running scores into the value targets
        cur_score = env.game_core.score
        for data in dataset_sim:
            data[2] = np.log1p(cur_score - data[2])
        for data in dataset_sim_double:
            data[2] = np.log1p(cur_score*2 - data[2])
        dataset += dataset_sim
        dataset += dataset_sim_double
        # Checkpoint after every game so an interrupted run keeps its data
        with open(os.path.join(PATH, "dataset.pickle"), "wb") as f:
            pickle.dump(dataset, f)
        total_score += cur_score
        total_num_moves += num_moves
        total_max_num_reached += np.max(new_state)
        print("Max number reached: {}, moves made: {}, score: {}".format(np.max(new_state), num_moves, cur_score))
        print("Time: ", (time.time() - start_time)/60.)

    if i == 0:
        # No game was played (dataset_size <= 0): avoid dividing by zero
        return dataset, (0., 0., 0., 0.)

    mean_score = total_score/i
    mean_num_moves = total_num_moves/i
    mean_max_num_reached = total_max_num_reached/i
    mean_inv_moves = total_inv_moves/total_num_moves

    return dataset, (mean_score, mean_num_moves, mean_max_num_reached, mean_inv_moves)

In [0]:
# Network modelled on the one from the paper.
class ConvBlock(nn.Module):
    """Input stem: 3x3 convolution (14 -> 128 channels), batch norm, ReLU."""

    def __init__(self):
        super().__init__()
        self.action_size = param_n_actions
        self.conv1 = nn.Conv2d(in_channels=14, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(128)

    def forward(self, s):
        features = self.conv1(s)
        features = self.bn1(features)
        return F.relu(features)

class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers with a skip connection.

    ``downsample`` is accepted but not used.
    """

    def __init__(self, inplanes=128, planes=128, stride=1, downsample=None):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1 = nn.Conv2d(inplanes, planes, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection followed by the final activation
        return F.relu(hidden + x)

class ConvBlockWider(nn.Module):
    """Input stem of the wider network: 3x3 convolution (14 -> 256 channels), batch norm, ReLU."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=14, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(256)

    def forward(self, s):
        features = self.conv1(s)
        features = self.bn1(features)
        return F.relu(features)

class ResBlockWider(nn.Module):
    """Residual block of the wider network (256 channels by default).

    Two 3x3 conv + batch-norm layers with a skip connection; ``downsample``
    is accepted but not used.
    """

    def __init__(self, inplanes=256, planes=256, stride=1, downsample=None):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1 = nn.Conv2d(inplanes, planes, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection followed by the final activation
        return F.relu(hidden + x)

class OutBlock(nn.Module):
    """Output heads on top of the 256-channel 4x4 feature map.

    Value head: 1x1 conv (256 -> 14), batch norm, ReLU, two linear layers
    down to one ReLU-activated scalar. Policy head: 1x1 conv (256 -> 32),
    batch norm, ReLU, linear layer to ``param_n_actions`` logits.
    """

    def __init__(self):
        super().__init__()
        # Value head
        self.conv = nn.Conv2d(256, 14, kernel_size=1)
        self.bn = nn.BatchNorm2d(14)
        self.fc1 = nn.Linear(14*4*4, 32)
        self.fc2 = nn.Linear(32, 1)

        # Policy head
        self.conv1 = nn.Conv2d(256, 32, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.fc = nn.Linear(32*4*4, param_n_actions)

    def forward(self,s):
        """Return (policy probabilities, value, policy logits)."""
        value = F.relu(self.bn(self.conv(s)))
        value = value.view(-1, 14*4*4)
        value = F.relu(self.fc1(value))
        value = F.relu(self.fc2(value))

        policy = F.relu(self.bn1(self.conv1(s)))
        policy = policy.view(-1, 32*4*4)
        p_logits = self.fc(policy)
        # exp(log_softmax) yields the probabilities
        p_probas = self.logsoftmax(p_logits).exp()
        return p_probas, value, p_logits

class Alpha2048net(nn.Module):
    """Policy/value network: wide conv stem, three wide residual blocks, output heads.

    The residual blocks are registered as attributes ``res_0`` .. ``res_2``.
    The module is moved to ``device`` at the end of construction.
    """

    def __init__(self, device="cpu"):
        super().__init__()
        # Narrower alternative kept from the original: ConvBlock() with 19 ResBlock()s.
        self.conv = ConvBlockWider()
        for index in range(3):
            setattr(self, "res_{}".format(index), ResBlockWider())
        self.outblock = OutBlock()
        self.to(device)

    def forward(self,s):
        """Return the output of the heads: (policy probabilities, value, policy logits)."""
        hidden = self.conv(s)
        for index in range(3):
            res_block = getattr(self, "res_{}".format(index))
            hidden = res_block(hidden)
        return self.outblock(hidden)

In [10]:
# Build the network on the selected device together with its Adam optimizer.
net = Alpha2048net(device=device)
net_optim = optim.Adam(net.parameters(), lr=1e-6, weight_decay=1e-5)
# Inference mode (batch norm uses its running statistics) until training starts.
net.eval()
# Presumably here only so the cell does not end by echoing the module repr.
print("")




In [11]:
# Load previously trained weights from PATH, mapped onto the current device.
net.load_state_dict(torch.load(os.path.join(PATH, "net_wider3_weights_dataset_18-wider3-32768.pth"), map_location=device))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [0]:
# Generate the dataset from self-play games.
#start_time = time.time()
dataset, stats = self_play(net, dataset_size=65536, num_MCTS=512, random_start=0.75, max_number=256)

# Final save; self_play also writes this same file after every game.
with open(os.path.join(PATH, "dataset.pickle"), "wb") as f:
    pickle.dump(dataset, f)

print("Mean score {:.3f}, mean number of moves {:.3f}, mean max number {:.3f}, invalid moves ratio {:.3f}".format(*stats))
#print("Time: ", (time.time() - start_time)/60.)

Simulation 1, dataset length 0
Max number started:  256
Max number reached: 256, moves made: 63, score: 428
Time:  6.480619430541992
Simulation 2, dataset length 1008
Max number started:  128
Max number reached: 256, moves made: 48, score: 592
Time:  12.481745676199596
Simulation 3, dataset length 1776
Max number started:  128
Max number reached: 128, moves made: 24, score: 132
Time:  15.46567670504252
Simulation 4, dataset length 2160
Max number started:  256
Max number reached: 256, moves made: 41, score: 248
Time:  20.26848164399465
Simulation 5, dataset length 2816
Max number started:  256
Max number reached: 512, moves made: 42, score: 760
Time:  26.14280281464259
Simulation 6, dataset length 3488
Max number started:  4
Max number reached: 128, moves made: 130, score: 1280
Time:  46.839362347126006
Simulation 7, dataset length 5568
Max number started:  256
Max number reached: 256, moves made: 70, score: 528
Time:  57.048868318398796
Simulation 8, dataset length 6688
Max number sta

In [0]:
# Prints an empty line only; this cell has no other effect.
print("")

In [0]:
def train(epoch, net, optim, loader, log):
    """Train ``net`` for one epoch over ``loader``.

    Each label row holds the target policy followed by the target value in
    its last column. The loss is the policy cross-entropy plus the squared
    value error, averaged over the batch. The mean batch loss is appended to
    ``log["train"]`` and printed on every 10th epoch.
    """
    net.train()
    loss_sum = 0.
    for data, label in loader:
        optim.zero_grad()
        data = data.to(device)
        target_pi = label[:,:-1].to(device)
        target_z = label[:,-1].to(device).view(-1, 1)
        p, v, _ = net(data)

        policy_loss = torch.sum((-target_pi * (1e-8 + p).log()), 1)
        value_loss = ((target_z - v)**2).view(-1)
        loss = (policy_loss + value_loss).mean()

        loss.backward()
        optim.step()

        loss_sum += loss.item()

    mean_loss = loss_sum / len(loader)
    log["train"].append(mean_loss)
    if epoch % 10 == 0:
        print("Epoch {} loss: {:.3f}".format(epoch, mean_loss))

In [0]:
def validate(epoch, net, loader, log):
    """Evaluate ``net`` on ``loader`` without updating it.

    Uses the same loss as training (policy cross-entropy plus squared value
    error, averaged over the batch). Each label row holds the target policy
    followed by the target value in its last column. The mean batch loss is
    appended to ``log["val"]`` and printed on every 10th epoch.
    """
    net.eval()
    mean_loss = 0.
    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        for data, label in loader:
            data = data.to(device)
            label_pi = label[:,:-1].to(device)
            label_z = label[:,-1].to(device).view(-1, 1)
            p, v, p_logits = net(data)

            loss_p = torch.sum((-label_pi * (1e-8 + p).log()), 1)
            loss_v =  ((label_z - v)**2).view(-1)
            loss = (loss_p + loss_v).mean()
            mean_loss += loss.item()

    mean_loss /= len(loader)
    log["val"].append(mean_loss)
    if epoch % 10 == 0:
        print("Epoch {} validation loss: {:.3f}".format(epoch, mean_loss))

In [0]:
# Build the PyTorch training loader from the [state, pi, z] samples in `dataset`.
# The train/test split is disabled, so every sample is used for training. To bring
# validation back, shuffle and split `dataset` here and build `loader_test` the same way.
dataset_train = dataset
dataset_train_T = list(zip(*dataset_train))

# Stack each column into one float32 ndarray first: building a tensor directly
# from a sequence of separate ndarrays is very slow.
data_state_train = torch.from_numpy(np.array(dataset_train_T[0], dtype=np.float32))
data_pi_train = torch.from_numpy(np.array(dataset_train_T[1], dtype=np.float32))
data_z_train = torch.from_numpy(np.array(dataset_train_T[2], dtype=np.float32)).view(-1,1)
# Label layout expected by train()/validate(): policy columns, then the value column.
labels_train = torch.cat((data_pi_train, data_z_train), dim=1)

tensor_dataset_train = TensorDataset(data_state_train, labels_train)
loader_train = DataLoader(tensor_dataset_train, batch_size=param_batch_size, shuffle=True)

In [16]:
# Per-epoch loss history; "val" stays empty while validation is commented out.
losses_log  = {"train" : [], "val" : []}
for epoch in range(31):
    train(epoch, net, net_optim, loader_train, losses_log)
    #validate(epoch, net, loader_test, losses_log)
# NOTE: this overwrites the same weight file that was loaded before training.
torch.save(net.state_dict(), os.path.join(PATH, "net_wider3_weights_dataset_18-wider3-32768.pth"))
# Back to inference mode for subsequent self-play.
net.eval()
print("")

Epoch 0 loss: 3.521
Epoch 10 loss: 1.651
Epoch 20 loss: 0.990
Epoch 30 loss: 0.853



In [14]:
# Load a previously generated dataset instead of running self-play.
# NOTE: pickle.load can execute arbitrary code - only open files you created yourself.
with open(os.path.join(PATH, "dataset_18-wider3-32768.pickle"), "rb") as f:
    dataset = pickle.load(f)
print(len(dataset))

33984


In [0]:
# Load a different weight file than the one used in the cells above.
# NOTE(review): presumably an older checkpoint; the state dict must match the current Alpha2048net layout - confirm.
net.load_state_dict(torch.load(os.path.join(PATH, "net_weights_dataset_5-2048.pth"), map_location=device))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])