In [1]:
import torch
from torch import nn
import torch.optim as optim
import numpy as np
import random

In [2]:
class NN2048(nn.Module):
    """Value network for a 2048 board given as stacked one-hot planes.

    Two first-stage convolutions look at vertically adjacent (2x1) and
    horizontally adjacent (1x2) cell pairs. Each first-stage map is then fed
    through both a 2x1 and a 1x2 second-stage convolution, giving four
    branches. Every branch is flattened and sent through its own linear head;
    the network output is the sum of the four heads, shape (N, 1).

    The linear input sizes (filter2 * 8 and filter2 * 9) match a 4x4 spatial
    input: the branches end up 2x4, 3x3, 3x3 and 4x2.

    Args:
        input_size: number of input channels.
        filter1: output channels of the first-stage convolutions.
        filter2: output channels of the second-stage convolutions.
        drop_prob: accepted for API compatibility; not used by this class.
    """

    def __init__(self, input_size=16, filter1=512, filter2=4096, drop_prob=0.):
        super().__init__()
        tall, wide = (2, 1), (1, 2)
        # Sub-modules are created in this exact order on purpose: it fixes
        # both the parameter ordering and the RNG draws used for init.
        self.conv_a = nn.Conv2d(input_size, filter1, kernel_size=tall, padding=0)
        self.conv_b = nn.Conv2d(input_size, filter1, kernel_size=wide, padding=0)
        self.conv_aa = nn.Conv2d(filter1, filter2, kernel_size=tall, padding=0)
        self.conv_ab = nn.Conv2d(filter1, filter2, kernel_size=wide, padding=0)
        self.conv_ba = nn.Conv2d(filter1, filter2, kernel_size=tall, padding=0)
        self.conv_bb = nn.Conv2d(filter1, filter2, kernel_size=wide, padding=0)
        self.relu = nn.ReLU()
        self.W_aa = nn.Linear(filter2 * 8, 1)
        self.W_ab = nn.Linear(filter2 * 9, 1)
        self.W_ba = nn.Linear(filter2 * 9, 1)
        self.W_bb = nn.Linear(filter2 * 8, 1)

    def flatten(self, x):
        """Collapse every dimension after the batch dimension into one."""
        batch = x.size()[0]
        return x.view(batch, -1)

    def forward(self, x):
        """Return the summed output of the four branch heads for batch `x`."""
        x = x.float()
        vertical = self.relu(self.conv_a(x))
        horizontal = self.relu(self.conv_b(x))
        branches = (
            (self.conv_aa, vertical, self.W_aa),
            (self.conv_ab, vertical, self.W_ab),
            (self.conv_ba, horizontal, self.W_ba),
            (self.conv_bb, horizontal, self.W_bb),
        )
        out = None
        # Accumulate left to right so the summation order stays aa+ab+ba+bb.
        for conv, features, head in branches:
            term = head(self.flatten(self.relu(conv(features))))
            out = term if out is None else out + term
        return out

In [3]:
def make_input(grid):
    """One-hot encode a 4x4 integer grid into a (16, 4, 4) float array.

    For every cell (i, j) the plane selected by the cell's value is set to 1
    at position (i, j); all other entries stay 0.
    """
    planes = np.zeros(shape=(16, 4, 4))
    rows, cols = np.indices((4, 4))
    # One fancy-indexed write instead of a double loop over the 16 cells.
    planes[grid[rows, cols], rows, cols] = 1
    return planes

def add_two(mat):
    """Write a 1 into one randomly chosen zero cell of `mat` and return `mat`.

    The array is modified in place. The cell is picked uniformly with
    np.random.randint, which raises ValueError when `mat` has no zero cell.
    """
    empty_cells = np.argwhere(mat == 0)
    chosen = empty_cells[np.random.randint(0, len(empty_cells))]
    mat[tuple(chosen)] = 1
    return mat

In [4]:
# Score table indexed by the tile code stored in a grid cell (see lookup below).
# For k >= 1 the entry equals (k - 1) * 2**k.
# NOTE(review): this looks like the cumulative 2048 score needed to build a
# 2**k tile out of 2-tiles — confirm.
singleScore=[0,0,4,16,48,128,320,768,1792,4096,9216,20480,45056,98304,212992,458752,983040]
# Precomputed table read from 'move.npy' in the working directory; `move`
# indexes it with four cell values and takes the whole last axis.
# NOTE(review): presumably maps a line of four tile codes to the line after
# sliding/merging towards index 0 — confirm against how move.npy was generated.
moveDict=np.load('move.npy')

def move(list):
    """Look up the precomputed moveDict entry for one line of four cells.

    Args:
        list: indexable with at least four entries; elements 0..3 are used as
            indices into moveDict. Note the parameter name shadows the
            builtin `list` inside this function.

    Returns:
        moveDict[list[0], list[1], list[2], list[3], :].
    """
    return moveDict[list[0],list[1],list[2],list[3],:]

def lookup(x):
    """Return singleScore[x], the score-table entry for a single cell value."""
    return singleScore[x]

# Rebind the name to an element-wise version so it can be applied to whole
# arrays; from here on `lookup` is the np.vectorize wrapper, not the function.
lookup = np.vectorize(lookup)

def getScore(matrix):
    """Return the sum of the singleScore entries of every cell in `matrix`."""
    tile_scores = lookup(matrix)
    return np.sum(tile_scores)

def getMove(grid):
    """List every move that actually changes `grid`.

    Directions 0..3 are tried in order with moveGrid. Returns a list of
    (new_grid, direction, getScore(new_grid)) tuples, skipping directions
    whose result is identical to `grid`; the list is empty when no move
    changes the board.
    """
    # Lazy generator: each direction is moved, compared and scored before the
    # next one is computed, same as a plain loop.
    attempts = ((moveGrid(grid, direction), direction) for direction in range(4))
    return [
        (moved, direction, getScore(moved))
        for moved, direction in attempts
        if not isSame(grid, moved)
    ]
        
def moveGrid(grid,i):
    """Apply one move to a 4x4 grid and return the resulting integer grid.

    Direction codes: 0 = up, 1 = left, 2 = down, 3 = right. Every line is
    resolved through `move`; vertical moves work on the transposed grid and
    transpose the result back, while down/right reverse each line before and
    after the lookup. Any other value of `i` returns None.
    """
    is_vertical = i == 0 or i == 2
    is_reversed = i == 2 or i == 3
    if not (is_vertical or is_reversed or i == 1):
        return None

    work = np.transpose(grid) if is_vertical else grid
    lines = []
    for row in range(4):
        line = work[row, :]
        if is_reversed:
            lines.append(np.flip(move(np.flip(line))))
        else:
            lines.append(move(line))
    moved = np.stack(lines, axis=0).astype(int)
    return moved.T if is_vertical else moved

def isSame(grid1,grid2):
    """Return True when every element-wise comparison grid1 == grid2 holds."""
    return np.all(grid1==grid2)

In [5]:
def Vchange(grid, v):
    """Expand one (grid, value) pair into its 8 mirrored/transposed variants.

    `grid` is indexed as (plane, row, col). The variants are, in order: the
    grid itself, flipped along axis 1, flipped along axis 2, flipped along
    both; then the same four for the grid with axes 1 and 2 swapped.

    Returns:
        (xtrain, ytrain): the 8 variants stacked into one array, and an
        array holding `v` repeated 8 times.
    """
    variants = []
    for base in (grid, grid.swapaxes(1, 2)):
        flipped_ax1 = base[:, ::-1, :]
        flipped_ax2 = base[:, :, ::-1]
        flipped_both = flipped_ax2[:, ::-1, :]
        variants.extend([base, flipped_ax1, flipped_ax2, flipped_both])
    xtrain = np.array(variants)
    ytrain = np.array([v] * 8)
    return xtrain, ytrain

def gen_sample_and_learn(model, optimizer, loss_fn, is_train = False, explorationProb=0.1):
    """Play one full game with `model` as value estimator, optionally learning online.

    Every turn a new tile is spawned with add_two, all board-changing moves
    are listed with getMove, and each candidate is valued as
    (candidate score - current score) + model(candidate board). When
    `is_train` is true, the model is fitted, before the move is played,
    towards the best such value using the previously chosen board and its 7
    symmetric variants from Vchange; the target is 0 once no move is left.

    Args:
        model: network mapping a (N, 16, 4, 4) batch to (N, 1) values. All
            tensors are created on the device of the model's first parameter.
        optimizer: optimizer stepped after each turn; only used when
            `is_train` is true.
        loss_fn: callable(pred, target) returning a scalar loss; only used
            when `is_train` is true.
        is_train: enable the per-turn update and exploration.
        explorationProb: while training, probability of playing a uniformly
            random legal move instead of the best-valued one.

    Returns:
        (game_len, max_tile, game_score, last_loss): number of moves played,
        2 ** (largest cell value on the final board), getScore of the last
        chosen board, and the last training loss (0 if no update was made).
    """
    # Follow the model's own device rather than assuming CUDA, so the same
    # code runs on CPU-only machines.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    game_len = 0
    game_score = 0
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.
    last_grid1 = np.zeros((4,4), dtype=int)
    last_grid1 = add_two(last_grid1)
    last_grid2 = make_input(last_grid1)
    last_loss = 0

    while True:
        # add_two writes into last_grid1 in place and returns the same array.
        grid_array = add_two(last_grid1)
        board_list = getMove(grid_array)
        if board_list:
            boards = np.array([make_input(g) for g, _, _ in board_list])
            # Pure inference: skip building an autograd graph and fetch all
            # candidate values in a single transfer.
            with torch.no_grad():
                p = model(torch.from_numpy(boards).to(device)).flatten().tolist()
            game_len += 1
            best_v = None
            for i, (g, m, s) in enumerate(board_list):
                v = (s - game_score) + p[i]
                if best_v is None or v > best_v:
                    best_v = v
                    best_score = s
                    best_grid1 = g
                    best_grid2 = boards[i]
        else:
            # No legal move: the game is over and the target value is 0.
            best_v = 0
            best_grid1 = None
            best_grid2 = None

        if is_train:
            x, y = Vchange(last_grid2, best_v)
            x = torch.from_numpy(x).to(device)
            y = torch.from_numpy(y).unsqueeze(dim=1).to(device).float()
            model.train()
            optimizer.zero_grad()
            pred = model(x)
            loss = loss_fn(pred, y)/2
            last_loss = loss.item()
            loss.backward()
            # Gradient clipping was tried and left disabled:
            # nn.utils.clip_grad_norm_(model.parameters(), 10.0)
            optimizer.step()
            model.eval()

        # The terminal update (if any) has been applied above; stop here.
        if not board_list:
            break

        # Epsilon-greedy: during training, occasionally play a random legal move.
        if is_train and random.random() < explorationProb:
            idx = random.randint(0, len(board_list) - 1)
            game_score = board_list[idx][2]
            last_grid1 = board_list[idx][0]
            last_grid2 = boards[idx]
        else:
            game_score = best_score
            last_grid1 = best_grid1
            last_grid2 = best_grid2

    return game_len, 2**grid_array.max(), game_score, last_loss

In [6]:
# Adam hyper-parameters used for the optimizer below.
lr = 1e-3
weight_decay = 1e-6
beta1 = 0.8

# The model is moved to the GPU here, so this cell requires a CUDA device.
model = NN2048().cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(beta1, 0.999))
# Loss module (mean squared error); passed to train() as its `loss` argument.
loss=nn.MSELoss()

In [7]:
# Number of self-play games run by train(); read from module scope at call time.
num_epochs = 500

def train(model, optimizer, loss):
    """Play `num_epochs` training games and print one summary line per game.

    Each game is one gen_sample_and_learn call with training enabled and an
    exploration probability of 0 (always the best-valued move).
    """
    games_done = 0
    while games_done != num_epochs:
        games_done += 1
        game_len, max_tile, final_score, last_loss = gen_sample_and_learn(model, optimizer, loss, True, 0)
        print ('epoch', games_done, game_len, max_tile, final_score, last_loss)

train(model, optimizer, loss)

epoch 1 183 128 1708 48358292.0
epoch 2 327 256 3820 9177346.0
epoch 3 357 256 4136 4941937.0
epoch 4 534 512 7276 7169395.0
epoch 5 339 256 3932 3546295.0
epoch 6 303 256 3548 2301029.0
epoch 7 288 256 3276 2189338.5
epoch 8 460 512 6160 2036827.5
epoch 9 284 256 3216 2543565.5
epoch 10 520 512 7152 1699343.25
epoch 11 180 128 1692 1005748.3125
epoch 12 405 256 4728 2928656.25
epoch 13 295 256 3324 1410495.25
epoch 14 228 128 2272 583761.375
epoch 15 496 512 6712 2718500.0
epoch 16 449 512 5912 1774196.0
epoch 17 292 256 3280 1200197.25
epoch 18 319 256 3596 904424.5625
epoch 19 642 512 8404 2231533.0
epoch 20 796 1024 12524 45923924.0
epoch 21 332 256 3884 5655401.5
epoch 22 874 1024 13420 11231558.0
epoch 23 498 512 6556 9027185.0
epoch 24 601 512 8104 6758472.0
epoch 25 514 512 7120 7210343.0
epoch 26 550 512 7420 5027830.0
epoch 27 603 512 8124 7076512.0
epoch 28 880 1024 13548 16244992.0
epoch 29 231 128 2384 6075521.0
epoch 30 664 512 8760 3068349.5
epoch 31 1037 1024 16448 1880

epoch 236 1774 2048 32332 427115840.0
epoch 237 1939 2048 35168 452629760.0
epoch 238 1753 2048 32160 422012512.0
epoch 239 597 512 7932 218271424.0
epoch 240 1024 1024 16276 130852712.0
epoch 241 1544 2048 27676 291618496.0
epoch 242 599 512 8092 127791288.0
epoch 243 1895 2048 34236 269127872.0
epoch 244 1030 1024 16328 166544160.0
epoch 245 549 512 7440 160299648.0
epoch 246 799 1024 12492 87342488.0
epoch 247 1061 1024 16664 70324096.0
epoch 248 1517 2048 27252 320203584.0
epoch 249 789 1024 12392 281971232.0
epoch 250 1035 1024 16428 147069936.0
epoch 251 1544 2048 27676 215336976.0
epoch 252 991 1024 15820 184757152.0
epoch 253 1552 2048 27740 135413408.0
epoch 254 953 1024 15424 114659624.0
epoch 255 1086 1024 17036 105984128.0
epoch 256 1040 1024 16464 113278712.0
epoch 257 1127 1024 18180 69017480.0
epoch 258 1041 1024 16472 71477456.0
epoch 259 1491 2048 26972 147666480.0
epoch 260 1056 1024 16600 146684848.0
epoch 261 1048 1024 16520 90861056.0
epoch 262 558 512 7532 5771225

epoch 455 1792 2048 32684 210591104.0
epoch 456 964 1024 15520 219404128.0
epoch 457 1037 1024 16324 77722848.0
epoch 458 1850 2048 32260 248639008.0
epoch 459 2971 4096 59812 1163210752.0
epoch 460 953 1024 15080 585645440.0
epoch 461 1545 2048 27680 788963904.0
epoch 462 1529 2048 27128 581698560.0
epoch 463 2005 2048 36240 486929440.0
epoch 464 2324 2048 41172 468842688.0
epoch 465 1000 1024 15544 451783200.0
epoch 466 1083 1024 16900 127033456.0
epoch 467 1191 1024 19036 104174256.0
epoch 468 2538 4096 51816 690284544.0
epoch 469 1764 2048 32412 726021120.0
epoch 470 1828 2048 33040 801590400.0
epoch 471 1911 2048 34780 438189568.0
epoch 472 1854 2048 33480 786504064.0
epoch 473 1232 1024 19544 197428800.0
epoch 474 606 512 8140 120423800.0
epoch 475 1119 1024 17436 223253472.0
epoch 476 1590 2048 28124 128936912.0
epoch 477 2055 2048 36768 317257088.0
epoch 478 1700 2048 30200 410689728.0
epoch 479 1575 2048 27996 526655712.0
epoch 480 2035 2048 36652 271640768.0
epoch 481 2006 20

In [8]:
# Continue training the same model and optimizer for 100 more games;
# train() picks up the new module-level num_epochs.
num_epochs = 100

train(model, optimizer, loss)

epoch 1 1175 1024 18820 518705088.0
epoch 2 956 1024 15108 222067440.0
epoch 3 1632 2048 28888 386933568.0
epoch 4 1894 2048 33884 526279808.0
epoch 5 1092 1024 16984 241735440.0
epoch 6 1297 2048 23632 321521472.0
epoch 7 1820 2048 32960 194003456.0
epoch 8 730 512 9916 220834144.0
epoch 9 2052 2048 36864 125004352.0
epoch 10 2027 2048 36588 127902992.0
epoch 11 1374 1024 23408 310899648.0
epoch 12 1860 2048 33524 163150560.0
epoch 13 1949 2048 35272 444601536.0
epoch 14 840 1024 13084 358040160.0
epoch 15 1052 1024 16604 196852928.0
epoch 16 2069 2048 37016 142798448.0
epoch 17 1559 2048 27804 198553888.0
epoch 18 973 1024 15664 319185472.0
epoch 19 919 1024 14744 190993552.0
epoch 20 1545 2048 27680 195034528.0
epoch 21 1502 2048 27068 99031528.0
epoch 22 1507 2048 27116 202415152.0
epoch 23 1304 1024 21668 467251328.0
epoch 24 1058 1024 16680 308212288.0
epoch 25 1923 2048 35072 340256192.0
epoch 26 1755 2048 32172 363763200.0
epoch 27 1047 1024 16512 154744544.0
epoch 28 2976 4096

In [11]:
# Number of evaluation games played by test(); read from module scope at call time.
num_epochs = 50

# Put the weights on the GPU before evaluating.
model.cuda()

def test(model):
    """Play `num_epochs` games without training and print one summary line per game.

    No optimizer or loss is passed, so the reported loss stays 0.
    """
    games_done = 0
    while games_done != num_epochs:
        games_done += 1
        game_len, max_tile, final_score, last_loss = gen_sample_and_learn(model, None, None, False)
        print ('epoch', games_done, game_len, max_tile, final_score, last_loss)

test(model)

epoch 1 1779 2048 32564 0
epoch 2 1800 2048 32740 0
epoch 3 1053 1024 16584 0
epoch 4 2051 2048 36872 0
epoch 5 1043 1024 16484 0
epoch 6 1069 1024 16776 0
epoch 7 2060 2048 36872 0
epoch 8 1989 2048 35676 0
epoch 9 2616 4096 52712 0
epoch 10 1052 1024 16624 0
epoch 11 2071 2048 37028 0
epoch 12 920 1024 14720 0
epoch 13 1231 1024 19788 0
epoch 14 1965 2048 35464 0
epoch 15 1810 2048 32860 0
epoch 16 1864 2048 33448 0
epoch 17 1076 1024 16848 0
epoch 18 3606 4096 72836 0
epoch 19 2636 4096 52936 0
epoch 20 2197 2048 38872 0
epoch 21 2030 2048 36244 0
epoch 22 2045 2048 36744 0
epoch 23 2959 4096 59724 0
epoch 24 2063 2048 36972 0
epoch 25 3087 4096 61548 0
epoch 26 2117 2048 37656 0
epoch 27 952 1024 15048 0
epoch 28 2060 2048 36920 0
epoch 29 1764 2048 32240 0
epoch 30 1175 1024 18844 0
epoch 31 1838 2048 33164 0
epoch 32 2010 2048 36268 0
epoch 33 1010 1024 16028 0
epoch 34 3049 4096 60780 0
epoch 35 1041 1024 16496 0
epoch 36 2766 4096 56632 0
epoch 37 2058 2048 36920 0
epoch 38 898

In [None]:
# This cell repeats the evaluation cell above and redefines test() with the
# same behaviour. Number of evaluation games, read from module scope:
num_epochs = 50

# Put the weights on the GPU before evaluating.
model.cuda()

def test(model):
    """Play `num_epochs` games without training and print one summary line per game.

    No optimizer or loss is passed, so the reported loss stays 0.
    """
    games_done = 0
    while games_done != num_epochs:
        games_done += 1
        game_len, max_tile, final_score, last_loss = gen_sample_and_learn(model, None, None, False)
        print ('epoch', games_done, game_len, max_tile, final_score, last_loss)

test(model)

epoch 1 1847 2048 33308 0
epoch 2 2026 2048 36412 0
epoch 3 2178 2048 39084 0
epoch 4 976 1024 15468 0
epoch 5 2961 4096 59736 0
epoch 6 1814 2048 32892 0
epoch 7 1826 2048 33004 0
epoch 8 1942 2048 35212 0
epoch 9 2788 4096 56652 0
epoch 10 1971 2048 35868 0
epoch 11 1342 1024 21140 0
epoch 12 1037 1024 16452 0
epoch 13 1568 2048 27920 0
epoch 14 2046 2048 36748 0
epoch 15 1097 1024 17192 0
epoch 16 2090 2048 37244 0
epoch 17 2082 2048 37108 0
epoch 18 1578 2048 28012 0
epoch 19 2079 2048 37164 0
epoch 20 1456 2048 26260 0
epoch 21 1977 2048 35576 0
epoch 22 1014 1024 16088 0
epoch 23 2986 4096 60012 0


In [9]:
import os
# The checkpoint is written to <experiment_dir>/<filename>.
experiment_dir = "model"
filename = "model1.pth.tar"
# Value stored under 'epoch' in the checkpoint. This rebinds the same global
# that train()/test() read.
# NOTE(review): confirm it reflects the number of games actually trained.
num_epochs = 500

def save_model(state, filename='model.pth.tar'):
    """Serialize `state` with torch.save to experiment_dir/filename.

    Args:
        state: object to serialize (here a dict of epoch, model and
            optimizer state).
        filename: file name, joined onto the module-level `experiment_dir`.

    The target directory is created first if it does not exist, because
    torch.save does not create missing parent directories.
    """
    filename = os.path.join(experiment_dir, filename)
    target_dir = os.path.dirname(filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    torch.save(state, filename)

# Write the checkpoint. Note that model.cpu() moves the module's parameters
# to the CPU in place, so the live `model` is on the CPU after this call and
# must be moved back (model.cuda()) before further GPU use.
save_model({
    'epoch': num_epochs,
    'state_dict': model.cpu().state_dict(),
    'optimizer': optimizer.state_dict(),
}, filename)