In [1]:
from c2048 import Game, push

In [2]:
import torch
from torch import nn
import torch.optim as optim
import numpy as np

In [3]:
class NN2048(nn.Module):
    """Convolutional value network for a 4x4 board.

    Two rounds of 2x1 / 1x2 convolutions produce four branches
    (aa, ab, ba, bb). Each branch is ReLU'd, flattened and sent through its
    own single-output linear head, and the four head outputs are summed.

    Input:  (N, input_size, 4, 4); any dtype, cast to float32 in ``forward``.
            The head widths (filter2 * 8 or filter2 * 9) only match a 4x4
            spatial grid.
    Output: (N, 1).

    Note: ``drop_prob`` is accepted but not used anywhere in this class.
    """

    def __init__(self, input_size=16, filter1=512, filter2=4096, drop_prob=0.):
        super().__init__()
        vertical, horizontal = (2, 1), (1, 2)
        # First round: pair vertically adjacent cells (a) / horizontally adjacent cells (b).
        self.conv_a = nn.Conv2d(input_size, filter1, kernel_size=vertical, padding=0)
        self.conv_b = nn.Conv2d(input_size, filter1, kernel_size=horizontal, padding=0)
        # Second round: both directions again on each first-round branch.
        self.conv_aa = nn.Conv2d(filter1, filter2, kernel_size=vertical, padding=0)
        self.conv_ab = nn.Conv2d(filter1, filter2, kernel_size=horizontal, padding=0)
        self.conv_ba = nn.Conv2d(filter1, filter2, kernel_size=vertical, padding=0)
        self.conv_bb = nn.Conv2d(filter1, filter2, kernel_size=horizontal, padding=0)
        self.relu = nn.ReLU()
        # On a 4x4 board: same direction twice -> 2x4 or 4x2 = 8 cells,
        # mixed directions -> 3x3 = 9 cells.
        self.W_aa = nn.Linear(filter2 * 8, 1)
        self.W_ab = nn.Linear(filter2 * 9, 1)
        self.W_ba = nn.Linear(filter2 * 9, 1)
        self.W_bb = nn.Linear(filter2 * 8, 1)

    def flatten(self, x):
        """Collapse every dimension except the batch one: (N, ...) -> (N, prod(...))."""
        return x.view(x.size(0), -1)

    def forward(self, x):
        x = x.float()
        branch_a = self.relu(self.conv_a(x))
        branch_b = self.relu(self.conv_b(x))
        # (second conv, linear head, first-round input), summed in aa, ab, ba, bb order.
        paths = (
            (self.conv_aa, self.W_aa, branch_a),
            (self.conv_ab, self.W_ab, branch_a),
            (self.conv_ba, self.W_ba, branch_b),
            (self.conv_bb, self.W_bb, branch_b),
        )
        return sum(head(self.flatten(self.relu(conv(src)))) for conv, head, src in paths)

In [4]:
# Tile value -> channel index: empty cell (0) -> 0, tile 2**k -> k for k = 1..15.
table = {2**k: k for k in range(1, 16)}
table[0] = 0


def make_input(grid):
    """One-hot encode the top-left 4x4 of ``grid`` into a (16, 4, 4) float64 array.

    Plane ``table[v]`` is set to 1 at every position holding tile value ``v``.
    Any cell value not present in ``table`` raises KeyError.
    """
    planes = np.zeros(shape=(16, 4, 4))
    rows, cols = np.indices((4, 4))
    levels = np.array([[table[grid[r, c]] for c in range(4)] for r in range(4)])
    planes[levels, rows, cols] = 1
    return planes

In [5]:
def add_two(mat):
    """Write a 2 into one uniformly chosen empty (== 0) cell of ``mat``, in place.

    Returns ``mat`` itself (same object). On a board with no zero cell the
    underlying ``np.random.randint(0, 0)`` call raises ValueError.
    """
    empty_cells = np.argwhere(mat == 0)
    chosen = empty_cells[np.random.randint(0, len(empty_cells))]
    mat[tuple(chosen)] = 2
    return mat

In [18]:
def Vchange(grid, v):
    """Augment one (C, H, W) board encoding with its 8 dihedral symmetries.

    Returns ``(xtrain, ytrain)``:
      xtrain: (8, C, H, W). For base in (grid, grid transposed in its last two
              axes), appended in this order: base, rows flipped, columns
              flipped, both flipped.
      ytrain: ``np.array([v] * 8)``, the same target for every variant.
    """
    variants = []
    for base in (grid, grid.swapaxes(1, 2)):
        variants.extend([
            base,
            base[:, ::-1, :],
            base[:, :, ::-1],
            base[:, ::-1, ::-1],
        ])
    return np.array(variants), np.array([v] * 8)

def _train_on_target(model, optimizer, loss_fn, state, target, device):
    """One optimizer step regressing V(state) and its 8 symmetric copies towards ``target``.

    ``state`` is a one-hot board from ``make_input``; ``Vchange`` expands it
    into the 8 symmetric boards, all sharing the same target. The loss is
    halved as in the original notebook. Leaves ``model`` in eval mode and
    returns the loss value.
    """
    x, y = Vchange(state, target)
    x = torch.from_numpy(x).to(device)
    y = torch.from_numpy(y).unsqueeze(dim=1).to(device).float()
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x), y) / 2
    loss.backward()
    optimizer.step()
    model.eval()
    return loss.item()


def gen_sample_and_learn(model, optimizer, loss_fn, is_train=True):
    """Play one game greedily with respect to ``model`` and optionally learn online.

    Each turn a 2 is spawned (``add_two``) and all 4 moves are tried on copies
    of the board with ``push``. A move is kept when ``push`` returns ``s >= 0``
    (assumed to mean "legal move, s = merge score"; TODO confirm against
    c2048). The chosen move maximises ``2*s + model(afterstate)``.

    When ``is_train`` is true, the previous afterstate is regressed towards
    that best value (a bootstrapped target). At game over the target is 0.

    Args:
        model: network mapping (N, 16, 4, 4) one-hot boards to (N, 1) values.
            Inputs are sent to the device of its first parameter.
        optimizer, loss_fn: used only when ``is_train`` is true (may be None otherwise).
        is_train: whether to take a gradient step every turn.

    Returns:
        (moves played, max tile on the final board, sum of 2*s over chosen moves)
    """
    device = next(model.parameters()).device
    model.eval()
    game_len = 0
    game_score = 0
    # The builtin int replaces np.int, which was removed in NumPy 1.24 (same dtype).
    last_grid1 = add_two(np.zeros((4, 4), dtype=int))
    last_grid2 = make_input(last_grid1)
    while True:
        # add_two mutates in place: grid_array and last_grid1 are the same array.
        grid_array = add_two(last_grid1)
        board_list = []
        for m in range(4):
            g = grid_array.copy()
            s = push(g, m)
            if s >= 0:
                board_list.append((g, m, s))

        if board_list:
            boards = np.array([make_input(g) for g, _, _ in board_list])
            # Pure inference: no autograd graph, one device->host sync for all moves.
            with torch.no_grad():
                values = model(torch.from_numpy(boards).to(device)).flatten().tolist()
            game_len += 1
            best_v = None
            for i, (g, _, s) in enumerate(board_list):
                v = 2 * s + values[i]
                if best_v is None or v > best_v:
                    best_v = v
                    best_score = 2 * s
                    best_grid1 = g
                    best_grid2 = boards[i]
            game_score += best_score
        else:
            # No legal move: terminal state, value 0.
            best_v = 0
            best_grid1 = None
            best_grid2 = None

        if is_train:
            _train_on_target(model, optimizer, loss_fn, last_grid2, best_v, device)

        if not board_list:
            break
        last_grid2 = best_grid2
        last_grid1 = best_grid1

    return game_len, grid_array.max(), game_score

In [16]:
# Training configuration.
lr = 1e-3
weight_decay = 1e-5
MAX_GAMES = 200
TARGET_TILE = 4096  # stop early once a game reaches this tile
SEED = 0

# Seed both RNGs used by the game (numpy: tile spawns) and the network (torch: init).
torch.manual_seed(SEED)
np.random.seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = NN2048().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.5, 0.999))
loss_fn = nn.MSELoss()

for j in range(MAX_GAMES):
    # result = (moves played, max tile, score)
    result = gen_sample_and_learn(model, optimizer, loss_fn)
    print(j, result)
    if result[1] >= TARGET_TILE:
        break

0 (152, 128, 1388)
1 (740, 1024, 11720)
2 (234, 128, 2376)
3 (608, 512, 8212)
4 (580, 512, 7736)
5 (303, 256, 3452)
6 (280, 256, 3196)
7 (668, 512, 9144)
8 (402, 256, 4588)
9 (518, 512, 6808)
10 (846, 1024, 13052)
11 (591, 512, 8204)
12 (532, 512, 7168)
13 (636, 512, 8456)
14 (569, 512, 7656)
15 (581, 512, 7828)
16 (544, 512, 7404)
17 (922, 1024, 14312)
18 (644, 512, 8760)
19 (867, 512, 12364)
20 (954, 1024, 14588)
21 (632, 512, 8424)
22 (857, 1024, 13476)
23 (1426, 2048, 25944)
24 (906, 1024, 13700)
25 (635, 512, 8700)
26 (1017, 1024, 15356)
27 (586, 512, 7976)
28 (408, 256, 4648)
29 (1000, 1024, 15160)
30 (1288, 1024, 20516)
31 (531, 512, 7260)
32 (470, 512, 6504)
33 (1062, 1024, 16668)
34 (1084, 1024, 16904)
35 (580, 512, 7736)
36 (1035, 1024, 16428)
37 (701, 512, 9992)
38 (542, 512, 7368)
39 (1108, 1024, 18016)
40 (482, 512, 6616)
41 (955, 1024, 15100)
42 (524, 512, 7212)
43 (813, 1024, 12672)
44 (1155, 1024, 18172)
45 (880, 1024, 13836)
46 (786, 1024, 12376)
47 (285, 256, 3260)
48

KeyboardInterrupt: 

In [19]:
num_epochs = 50


def test(model, num_games=None):
    """Play evaluation games with ``model`` without any training.

    Args:
        model: the value network passed to ``gen_sample_and_learn``.
        num_games: how many games to play. Defaults to the global
            ``num_epochs``, read at call time. Values <= 0 play nothing
            (the old ``while epoch != num_epochs`` never ended for negatives).

    Prints ``(game number, (moves, max tile, score))`` for each game.
    """
    if num_games is None:
        num_games = num_epochs
    for epoch in range(1, num_games + 1):
        res = gen_sample_and_learn(model, None, None, False)
        print(epoch, res)


test(model)

1 (1556, 2048, 27764)
2 (1550, 2048, 27744)
3 (1744, 2048, 31704)
4 (1047, 1024, 16540)
5 (1065, 1024, 16752)
6 (1516, 2048, 27200)
7 (676, 512, 9704)
8 (1903, 2048, 34644)
9 (1494, 2048, 26776)
10 (1092, 1024, 17164)
11 (1447, 2048, 26116)
12 (1922, 2048, 34860)
13 (838, 1024, 13048)
14 (1050, 1024, 16556)
15 (1563, 2048, 27836)
16 (1056, 1024, 16600)
17 (1442, 2048, 26076)
18 (1024, 1024, 16332)
19 (543, 512, 7372)
20 (1511, 2048, 27196)
21 (641, 512, 8492)
22 (1045, 1024, 16588)
23 (1107, 1024, 17292)
24 (1055, 1024, 16604)
25 (830, 1024, 12812)
26 (1016, 1024, 16096)
27 (1050, 1024, 16556)
28 (1595, 2048, 28332)
29 (1005, 1024, 16000)
30 (1186, 1024, 18444)
31 (1011, 1024, 16156)
32 (815, 1024, 12684)
33 (1039, 1024, 16460)
34 (1052, 1024, 16580)
35 (1180, 1024, 18376)
36 (2032, 2048, 36504)
37 (1553, 2048, 27736)
38 (935, 1024, 14428)
39 (546, 512, 7388)
40 (1019, 1024, 16236)
41 (1320, 2048, 23816)
42 (613, 512, 8216)
43 (1396, 2048, 24760)
44 (1007, 1024, 16108)
45 (832, 1024, 1