In [9]:
import torch
from torch import nn
import torch.optim as optim
import numpy as np
import random

In [10]:
class NN2048(nn.Module):
    """Value network for 2048 boards: a sum of scalar heads over convolutional features.

    The input is a batch of ``input_size``-channel 4x4 boards (e.g. the
    (16, 4, 4) one-hot planes built by ``make_input``).  Each branch applies
    two stacked unpadded convolutions with "line"/square kernels (2x1, 1x2,
    3x1, 1x3, 4x1, 1x4, 2x2), flattens the result and maps it to one scalar
    with its own ``nn.Linear``.  A 1x1-conv branch on the raw planes adds one
    more scalar.  The output is the sum of all those scalars, shape (N, 1).

    The ``nn.Linear`` input sizes are hard-coded for a 4x4 board, so the
    module only accepts 4x4 spatial inputs.

    NOTE(review): attribute names are the ``state_dict`` keys stored in
    checkpoints, and construction order fixes ``parameters()`` order (which
    the optimizer state depends on).  Renaming or reordering layers breaks
    loading of existing checkpoints.
    """
    def __init__(self, input_size=16, filter1=256, filter2=2048, drop_prob=0.):
        """Build every convolution branch and its linear head.

        Args:
            input_size: number of input channels (planes per board).
            filter1: output channels of the first convolution of each branch.
            filter2: output channels of the second convolution of each branch.
            drop_prob: NOTE(review): accepted but not used anywhere in this class.
        """
        super(NN2048, self).__init__()

        # First-level convolutions, applied directly to the input planes.
        # Output spatial sizes for a 4x4 board:
        #   conv_1 (1,1) -> 4x4, conv_a (2,1) -> 3x4, conv_a3 (3,1) -> 2x4,
        #   conv_a4 (4,1) -> 1x4, conv_b (1,2) -> 4x3, conv_b3 (1,3) -> 4x2,
        #   conv_b4 (1,4) -> 4x1, conv_c (2,2) -> 3x3.
        self.conv_1 = nn.Conv2d(in_channels=input_size, out_channels=input_size, kernel_size=(1,1), padding=0)
        self.conv_a = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(2,1), padding=0)
        self.conv_a3 = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(3,1), padding=0)
        self.conv_a4 = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(4,1), padding=0)
        self.conv_b = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(1,2), padding=0)
        self.conv_b3 = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(1,3), padding=0)
        self.conv_b4 = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(1,4), padding=0)
        self.conv_c = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(2,2), padding=0)
        
        # Second-level convolutions on the 2xN / Nx2 first-level maps.
        self.conv_aa = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(2,1), padding=0)
        self.conv_ab = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(1,2), padding=0)
        self.conv_ba = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(2,1), padding=0)
        self.conv_bb = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(1,2), padding=0)
        
        # Second-level convolutions that cross the longer first-level kernels
        # (e.g. a (3,1) map followed by a (1,3) kernel) and the 2x2 branch.
        self.conv_ab3 = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(1,3), padding=0)
        self.conv_ba3 = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(3,1), padding=0)
        self.conv_ab4 = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(1,4), padding=0)
        self.conv_ba4 = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(4,1), padding=0)
        self.conv_c2 = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(2,2), padding=0)
        
        self.relu = nn.ReLU()
        # Linear heads: in_features = filter2 * (output positions on a 4x4 board).
        #   aa: 3x4 -(2,1)-> 2x4 = 8     ab: 3x4 -(1,2)-> 3x3 = 9
        #   ba: 4x3 -(2,1)-> 3x3 = 9     bb: 4x3 -(1,2)-> 4x2 = 8
        #   ab3: 2x4 -(1,3)-> 2x2 = 4    ba3: 4x2 -(3,1)-> 2x2 = 4
        #   ab4: 1x4 -(1,4)-> 1x1 = 1    ba4: 4x1 -(4,1)-> 1x1 = 1
        #   c:  3x3 -(2,2)-> 2x2 = 4     W_1: conv_1 output, input_size * 4 * 4
        self.W_aa = nn.Linear(filter2 * 8, 1)
        self.W_ab = nn.Linear(filter2 * 9, 1)
        self.W_ba = nn.Linear(filter2 * 9, 1)
        self.W_bb = nn.Linear(filter2 * 8, 1)
        self.W_ab3 = nn.Linear(filter2 * 4, 1)
        self.W_ba3 = nn.Linear(filter2 * 4, 1)
        self.W_ab4 = nn.Linear(filter2 * 1, 1)
        self.W_ba4 = nn.Linear(filter2 * 1, 1)
        self.W_c = nn.Linear(filter2 * 4, 1)
        self.W_1 = nn.Linear(input_size * 16, 1)

    def flatten(self, x):
        """Flatten all but the leading (batch) dimension: (N, ...) -> (N, -1)."""
        N = x.size()[0]
        return x.view(N, -1)
        
    def forward(self, x):
        """Return board values of shape (N, 1) for ``x`` of shape (N, input_size, 4, 4)."""
        # Callers build x with torch.from_numpy on float64 NumPy arrays; the
        # layers are float32, so cast first.
        x = x.float()

        # First-level feature maps (ReLU after every convolution).
        x1 = self.relu(self.conv_1(x))
        a = self.relu(self.conv_a(x))
        b = self.relu(self.conv_b(x))
        c = self.relu(self.conv_c(x))
        a3 = self.relu(self.conv_a3(x))
        b3 = self.relu(self.conv_b3(x))
        a4 = self.relu(self.conv_a4(x))
        b4 = self.relu(self.conv_b4(x))
        
        # Second-level feature maps, flattened per sample for the linear heads.
        aa = self.flatten(self.relu(self.conv_aa(a)))
        ab = self.flatten(self.relu(self.conv_ab(a)))
        ba = self.flatten(self.relu(self.conv_ba(b)))
        bb = self.flatten(self.relu(self.conv_bb(b)))
        
        ab3 = self.flatten(self.relu(self.conv_ab3(a3)))
        ba3 = self.flatten(self.relu(self.conv_ba3(b3)))
        ab4 = self.flatten(self.relu(self.conv_ab4(a4)))
        ba4 = self.flatten(self.relu(self.conv_ba4(b4)))
        c2 = self.flatten(self.relu(self.conv_c2(c)))
        
        # The value is the sum of the scalar output of every head.
        out = self.W_aa(aa) + self.W_ab(ab) + self.W_ba(ba) + self.W_bb(bb) + \
              self.W_ab4(ab4) + self.W_ba4(ba4) + self.W_c(c2) + \
              self.W_ab3(ab3) + self.W_ba3(ba3) + self.W_1(self.flatten(x1))
        
        return out

In [11]:
def make_input(grid):
    """One-hot encode a 4x4 board of tile exponents into 16 binary planes.

    Plane ``k`` of the returned float ``(16, 4, 4)`` array is 1 exactly at the
    cells whose value is ``k``.  Exponents of 16 or more raise IndexError.
    """
    planes = np.zeros(shape=(16, 4, 4))
    row_idx, col_idx = np.indices((4, 4))
    # One fancy-indexed assignment: planes[value, r, c] = 1 for every cell.
    planes[grid[row_idx, col_idx], row_idx, col_idx] = 1
    return planes

def add_two(mat):
    """Put a new 2-tile (exponent 1) on a uniformly random empty cell of ``mat``.

    The board is modified in place and also returned.  Raises ValueError
    (from ``np.random.randint``) when there is no empty cell.
    """
    free_cells = np.argwhere(mat == 0)
    pick = np.random.randint(0, len(free_cells))
    target = tuple(free_cells[pick])
    mat[target] = 1
    return mat

In [12]:
# singleScore[k] is the total merge score accumulated by a tile 2**k built
# only from 2-tiles: score(k) = 2**k + 2*score(k-1), score(1) = 0, which
# solves to (k - 1) * 2**k.  Index 0 (empty cell) scores 0.
singleScore = [0] + [(k - 1) * 2 ** k for k in range(1, 17)]
# Lookup table loaded from move.npy (not generated in this notebook), indexed
# by the four cell exponents of one row.
moveDict = np.load(file='move.npy')

def move(list):
    """Slide one row of four tile exponents to the left using ``moveDict``.

    Returns ``moveDict[a, b, c, d, :]`` for the row ``(a, b, c, d)``.
    """
    first, second, third, fourth = list[0], list[1], list[2], list[3]
    return moveDict[first, second, third, fourth, :]

def lookup(x):
    """Return ``singleScore[x]``, the score contribution of one tile exponent ``x``."""
    return singleScore[x]

# Element-wise version over whole boards, kept for callers that use it directly.
# np.vectorize is only a Python-level loop, so getScore does not use it.
lookup = np.vectorize(lookup)

def getScore(matrix):
    """Return the game score implied by a board of tile exponents.

    The score is the sum of ``singleScore`` over every cell.  It is computed
    with a single NumPy gather because it is called for every candidate move
    of every step.  The values are the same as summing ``lookup(matrix)``.
    """
    score_table = np.asarray(singleScore)
    return score_table[np.asarray(matrix)].sum()

def getMove(grid):
    """List every legal move from ``grid``.

    Returns ``(new_grid, direction, score)`` tuples, one for each direction
    (0=up, 1=left, 2=down, 3=right, as in ``moveGrid``) that actually changes
    the board.  ``score`` is ``getScore(new_grid)``.  An empty list means no
    move is possible.
    """
    candidates = ((moveGrid(grid, direction), direction) for direction in range(4))
    return [
        (moved, direction, getScore(moved))
        for moved, direction in candidates
        if not isSame(grid, moved)
    ]
        
def moveGrid(grid,i):
    """Return the board after sliding every tile in direction ``i``.

    Directions: 0=up, 1=left, 2=down, 3=right.  Any other ``i`` returns None.
    Every direction is reduced to the left-slide ``move`` lookup:
      * up/down work on the transpose, so columns become rows;
      * down/right mirror each row before and after the slide.
    """
    if i not in (0, 1, 2, 3):
        return None
    vertical = i in (0, 2)
    mirrored = i in (2, 3)
    rows = grid.T if vertical else grid
    if mirrored:
        rows = rows[:, ::-1]
    moved = np.stack([move(rows[r, :]) for r in range(4)], axis=0)
    if mirrored:
        moved = moved[:, ::-1]
    moved = moved.astype(int)
    return moved.T if vertical else moved

def isSame(grid1,grid2):
    """Return True (as a NumPy bool) when no cell differs between the two boards."""
    any_difference = np.any(grid1 != grid2)
    return np.logical_not(any_difference)

In [13]:
def Vchange(grid, v):
    """Expand one training example into the 8 symmetries of the square board.

    ``grid`` is a (channels, H, W) array, e.g. the output of ``make_input``.
    Returns ``(xtrain, ytrain)``:
      * ``xtrain`` stacks, along a new leading axis, the identity, the row
        flip, the column flip and the 180-degree rotation of ``grid``, then
        the same four transforms of its spatial transpose;
      * ``ytrain`` repeats the target ``v`` 8 times, so every symmetric copy
        gets the same target.
    """
    variants = []
    for base in (grid, grid.swapaxes(1, 2)):
        variants.append(base)
        variants.append(base[:, ::-1, :])
        variants.append(base[:, :, ::-1])
        variants.append(base[:, ::-1, ::-1])
    xtrain = np.array(variants)
    ytrain = np.full(8, v)
    return xtrain, ytrain

def gen_sample_and_learn(model, optimizer, loss_fn, is_train = False, explorationProb=0.1):
    """Play one game of 2048 with ``model`` as afterstate value function, optionally learning.

    At every step, each legal move is scored as
    ``(score gained by the move) + model(board after the move)``, and the
    best move is taken.  When training, a uniformly random legal move is
    taken instead with probability ``explorationProb``.

    When ``is_train`` is true, a TD update is applied before each move.  The
    value of the previous afterstate is regressed toward that best one-step
    target, or toward 0 when no move is left, using its 8 board symmetries
    (``Vchange``) as the mini-batch.

    Args:
        model: value network mapping a batch of one-hot boards to (N, 1) values.
        optimizer: optimizer over ``model``'s parameters (unused if not training).
        loss_fn: regression loss such as ``nn.MSELoss()`` (unused if not training).
        is_train: whether to perform TD updates and epsilon-greedy exploration.
        explorationProb: probability of a random legal move while training.

    Returns:
        ``(game_len, max_tile, game_score, last_loss)``:
          * the number of moves played;
          * ``2**`` the largest exponent on the final board;
          * the final score;
          * the loss of the last training step (0 when not training).
    """
    # Run on the model's own device instead of hard-coding .cuda().  This is
    # the same for a CUDA model and also works on CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    model.eval()
    game_len = 0
    game_score = 0
    # np.int was deprecated in NumPy 1.20 and removed in 1.24; builtin int is the same dtype.
    last_grid1 = np.zeros((4,4),dtype=int)
    last_grid1 = add_two(last_grid1)
    last_grid2 = make_input(last_grid1)  # one-hot planes of the current afterstate
    last_loss = 0

    while True:
        # Environment step: a new tile appears on the previous afterstate.
        grid_array = add_two(last_grid1)
        board_list = getMove(grid_array)
        if board_list:
            boards = np.array([make_input(g) for g,m,s in board_list])
            # Pure inference: no autograd graph is needed.  Copy the values to
            # the host once instead of one device sync per .item() call.
            with torch.no_grad():
                p = model(torch.from_numpy(boards).to(device)).flatten().tolist()
            game_len += 1
            best_v = None
            for i, (g,m,s) in enumerate(board_list):
                # One-step lookahead: reward of the move plus the estimated afterstate value.
                v = (s - game_score) + p[i]
                if best_v is None or v > best_v:
                    best_v = v
                    best_score = s
                    best_grid1 = g
                    best_grid2 = boards[i]
        else:
            # No legal move: the game is over and the terminal target is 0.
            best_v = 0
            best_grid1 = None
            best_grid2 = None

        if is_train:
            # TD(0) update of V(previous afterstate) toward best_v, augmented
            # with all 8 board symmetries.
            x, y = Vchange(last_grid2, best_v)
            x = torch.from_numpy(x).to(device)
            y = torch.from_numpy(y).unsqueeze(dim=1).to(device).float()
            model.train()
            optimizer.zero_grad()
            pred = model(x)
            loss = loss_fn(pred, y)
            last_loss = loss.item()
            loss.backward()
            optimizer.step()
            model.eval()

        if not board_list:
            break

        # Epsilon-greedy exploration (training only): with probability
        # explorationProb follow a uniformly random legal move instead of the best one.
        if is_train and random.random() < explorationProb:
            idx = random.randint(0, len(board_list) - 1)
            game_score = board_list[idx][2]
            last_grid1 = board_list[idx][0]
            last_grid2 = boards[idx]
        else:
            game_score = best_score
            last_grid1 = best_grid1
            last_grid2 = best_grid2

    return game_len, 2**grid_array.max(), game_score, last_loss

In [14]:
# Adam hyper-parameters for the value network (no weight decay).
lr = 1e-4
weight_decay = 0
beta1 = 0.9

model = NN2048().cuda()
optimizer = optim.Adam(
    model.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
    weight_decay=weight_decay,
)
# TD targets are regressed with a plain mean-squared error.
loss = nn.MSELoss()

In [15]:
import os
experiment_dir = "model"

def save_model(state, filename='model.pth.tar'):
    """Save ``state`` with ``torch.save`` to ``experiment_dir/filename``.

    The target directory is created when it does not exist yet, because
    ``torch.save`` does not create parent directories.  Without this, the
    first checkpoint of a fresh run would fail.
    """
    filename = os.path.join(experiment_dir, filename)
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(state, filename)

In [18]:
def load_model(model, optimizer, checkpoint_path, map_location="cuda:0"):
    """Restore model and optimizer state from a checkpoint written by ``save_model``.

    Args:
        model: module whose parameters are overwritten via ``load_state_dict``.
        optimizer: optimizer whose state is overwritten via ``load_state_dict``.
        checkpoint_path: path to a ``torch.save``-d dict with the keys
            ``'state_dict'``, ``'optimizer'`` and ``'epoch'``.
        map_location: device the stored tensors are mapped to.  The default
            ``"cuda:0"`` is the previously hard-coded value; pass ``"cpu"`` on
            a machine without a GPU.

    Returns:
        ``(model, optimizer, epoch)``: the same objects, updated in place,
        plus the epoch recorded in the checkpoint.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt_dict = torch.load(checkpoint_path, map_location=map_location)

    model.load_state_dict(ckpt_dict['state_dict'])
    optimizer.load_state_dict(ckpt_dict['optimizer'])
    epoch = ckpt_dict['epoch']
    return model, optimizer, epoch

# Resume from the checkpoint train() writes after game 1000
# ("model5_" + epoch under experiment_dir "model").  The restored epoch is
# passed to train() so numbering continues from there.
model, optimizer, epoch = load_model(
    model=model,
    optimizer=optimizer,
    checkpoint_path="model/model5_1000.pth.tar",
)

In [None]:
num_epochs = 5000

def train(model, optimizer, loss, epoch = 0):
    """Play self-play TD training games until ``num_epochs`` games have been played.

    Each "epoch" is one game from ``gen_sample_and_learn`` with learning on
    and exploration off (explorationProb=0).  ``epoch`` is the number of
    games already played, e.g. the value restored by ``load_model``.  Every
    500 games a checkpoint ``model5_<epoch>.pth.tar`` is written through
    ``save_model``.  The module-level ``num_epochs`` is read at call time.
    """
    # '<' rather than '!=': resuming from a checkpoint whose epoch is already
    # past num_epochs must not loop forever.
    while epoch < num_epochs:
        epoch += 1
        game_len, max_score, game_score, last_loss = gen_sample_and_learn(model, optimizer, loss, True, 0)
        print ('epoch', epoch, game_len, max_score, game_score, last_loss)
        if epoch % 500 == 0:
            filename = "model5_"+str(epoch)+".pth.tar"
            # Copy the weights to the host for a device-independent checkpoint.
            # The live model is not moved off its device (previously
            # model.cpu() ... model.cuda() round-tripped every parameter).
            state_dict = model.state_dict()
            for key in list(state_dict):
                state_dict[key] = state_dict[key].cpu()
            save_model({
                'epoch': epoch,
                'state_dict': state_dict,
                'optimizer': optimizer.state_dict(),
            }, filename)
    
train(model, optimizer, loss, epoch)

epoch 1001 2039 2048 36700 310006592.0
epoch 1002 1008 1024 16140 401893888.0
epoch 1003 1060 1024 16800 397371648.0
epoch 1004 1031 1024 16396 320965536.0
epoch 1005 1802 2048 32808 389014208.0
epoch 1006 295 256 3324 425184256.0
epoch 1007 1084 1024 16904 220893408.0
epoch 1008 989 1024 15784 248499344.0
epoch 1009 2014 2048 36312 307389120.0
epoch 1010 884 1024 14364 585993856.0
epoch 1011 1538 2048 27548 371538048.0
epoch 1012 997 1024 15888 384254464.0
epoch 1013 661 1024 10620 484602688.0
epoch 1014 1586 2048 28104 385710848.0
epoch 1015 1082 1024 16892 276155168.0
epoch 1016 529 512 7248 320600896.0
epoch 1017 656 512 9548 271918720.0
epoch 1018 1046 1024 16536 192178720.0
epoch 1019 974 1024 15420 187487744.0
epoch 1020 1006 1024 15996 162793696.0
epoch 1021 799 1024 12492 178876448.0
epoch 1022 1287 2048 23532 248003040.0
epoch 1023 1956 2048 35320 239804240.0
epoch 1024 787 1024 12380 302293568.0
epoch 1025 1775 2048 32508 232799648.0
epoch 1026 913 1024 14672 331952000.0
epo

epoch 1214 1031 1024 16408 391033344.0
epoch 1215 2013 2048 36320 201528416.0
epoch 1216 1388 2048 25452 550840128.0
epoch 1217 1562 2048 27820 332196672.0
epoch 1218 563 512 7596 361207072.0
epoch 1219 2118 2048 37660 268234496.0
epoch 1220 1547 2048 27692 641771968.0
epoch 1221 801 1024 12504 528675104.0
epoch 1222 1126 1024 17752 364775168.0
epoch 1223 1029 1024 16368 393620224.0
epoch 1224 1810 2048 32860 448614208.0
epoch 1225 2064 2048 36976 376931072.0
epoch 1226 2043 2048 36600 580251968.0
epoch 1227 1473 2048 26768 1102102528.0
epoch 1228 660 1024 10588 1203939840.0
epoch 1229 922 1024 14732 815448384.0
epoch 1230 801 1024 12504 556027008.0
epoch 1231 685 1024 10780 483256288.0
epoch 1232 994 1024 15836 412011008.0
epoch 1233 792 1024 12416 360946720.0
epoch 1234 608 512 8208 260160064.0
epoch 1235 1047 1024 16540 204618048.0
epoch 1236 1865 2048 34320 203995120.0
epoch 1237 1285 1024 20496 178052992.0
epoch 1238 1412 1024 23280 215223968.0
epoch 1239 2032 2048 36624 416279552

epoch 1427 923 1024 14764 438397376.0
epoch 1428 1399 2048 25532 384430400.0
epoch 1429 2132 2048 37784 278893760.0
epoch 1430 1554 2048 27740 282001248.0
epoch 1431 2072 2048 37032 342753152.0
epoch 1432 1697 2048 31192 898620480.0
epoch 1433 1047 1024 16540 524571200.0
epoch 1434 803 1024 12572 621606784.0
epoch 1435 2453 4096 50536 1024663552.0
epoch 1436 1521 2048 27312 856999872.0
epoch 1437 565 512 7608 632240768.0
epoch 1438 1019 1024 16296 548486080.0
epoch 1439 1545 2048 27680 339766912.0
epoch 1440 917 1024 14704 461117248.0
epoch 1441 1068 1024 16768 297950336.0
epoch 1442 1040 1024 16464 303690112.0
epoch 1443 536 512 7296 261733312.0
epoch 1444 791 1024 12412 224112160.0
epoch 1445 967 1024 15548 182565728.0
epoch 1446 2105 2048 37496 227780960.0
epoch 1447 2012 2048 36336 402857696.0
epoch 1448 534 512 7288 732466624.0
epoch 1449 1942 2048 35188 347329344.0
epoch 1450 1002 1024 16096 600671168.0
epoch 1451 1026 1024 16344 394362176.0
epoch 1452 1050 1024 16556 257761120.0

epoch 1641 2839 4096 57476 1060505216.0
epoch 1642 523 512 7208 1369079552.0
epoch 1643 2040 2048 36764 724894272.0
epoch 1644 1556 2048 27760 924420544.0
epoch 1645 1554 2048 27740 812566784.0
epoch 1646 1047 1024 16540 717926784.0
epoch 1647 944 1024 14996 632179584.0
epoch 1648 1100 1024 17216 347927136.0
epoch 1649 565 512 7632 505528896.0
epoch 1650 816 1024 12812 457647616.0
epoch 1651 1938 2048 35164 301622784.0
epoch 1652 1553 2048 27736 407984032.0
epoch 1653 1805 2048 32824 444080448.0
epoch 1654 1551 2048 27724 431290816.0
epoch 1655 1990 2048 36112 704439808.0
epoch 1656 1000 1024 16060 775456256.0
epoch 1657 2064 2048 36952 475596992.0
epoch 1658 1053 1024 16584 573323264.0
epoch 1659 1080 1024 16872 386105216.0
epoch 1660 948 1024 15388 603064576.0
epoch 1661 917 1024 14696 445270560.0
epoch 1662 1494 2048 26988 329118400.0
epoch 1663 1098 1024 17448 397216640.0
epoch 1664 2034 2048 36636 396213504.0
epoch 1665 802 1024 12520 563635072.0
epoch 1666 1013 1024 16044 2833731

epoch 1854 1905 2048 34240 487751680.0
epoch 1855 1559 2048 27776 519362080.0
epoch 1856 788 1024 12384 761447424.0
epoch 1857 1061 1024 16636 283917184.0
epoch 1858 3086 4096 61516 557348480.0
epoch 1859 1055 1024 16652 805925568.0
epoch 1860 915 1024 14684 1133786368.0
epoch 1861 1551 2048 27724 789477504.0
epoch 1862 1042 1024 16476 546595200.0
epoch 1863 1017 1024 16224 600361024.0
epoch 1864 2048 2048 36768 364136000.0
epoch 1865 1050 1024 16556 653403072.0
epoch 1866 1554 2048 27740 479430656.0
epoch 1867 797 1024 12472 709231040.0
epoch 1868 1903 2048 34812 541927424.0
epoch 1869 1067 1024 16888 648307968.0
epoch 1870 2512 4096 51532 1321445120.0
epoch 1871 1797 2048 32780 1180251904.0
epoch 1872 3185 4096 62672 900399872.0
epoch 1873 1040 1024 16464 1036380352.0
epoch 1874 1122 1024 17300 804358016.0
epoch 1875 2003 2048 36192 598949376.0
epoch 1876 1014 1024 16184 683442688.0
epoch 1877 842 1024 12972 550186048.0
epoch 1878 1053 1024 16584 429079104.0
epoch 1879 2824 4096 5737

epoch 2066 1468 2048 26732 830572544.0
epoch 2067 2062 2048 36940 413580672.0
epoch 2068 1809 2048 32856 688780544.0
epoch 2069 1647 2048 28824 460294048.0
epoch 2070 2072 2048 37044 506343360.0
epoch 2071 2587 4096 52412 1095607680.0
epoch 2072 1048 1024 16520 888434304.0
epoch 2073 659 1024 10580 1757195136.0
epoch 2074 513 512 7104 1237565440.0
epoch 2075 1056 1024 16624 367720192.0
epoch 2076 1050 1024 16556 287634688.0
epoch 2077 1636 2048 29088 294055520.0
epoch 2078 1130 1024 17896 577604864.0
epoch 2079 525 512 7228 753080576.0
epoch 2080 1558 2048 27800 354246720.0
epoch 2081 1238 2048 22672 751226496.0
epoch 2082 1595 2048 28332 332335104.0
epoch 2083 1801 2048 32800 463906304.0
epoch 2084 1011 1024 16156 428206464.0
epoch 2085 1567 2048 27868 564916352.0
epoch 2086 1824 2048 32984 486612096.0
epoch 2087 2092 2048 37256 332076992.0
epoch 2088 1036 1024 16444 702443520.0
epoch 2089 1050 1024 16556 429315968.0
epoch 2090 1420 2048 25900 431304896.0
epoch 2091 1025 1024 16280 27

epoch 2278 1545 2048 27692 1923118336.0
epoch 2279 1023 1024 16268 1207117312.0
epoch 2280 781 1024 12336 1419668096.0
epoch 2281 570 512 7660 1073726016.0
epoch 2282 1722 2048 31812 640670528.0
epoch 2283 1552 2048 27728 703743232.0
epoch 2284 980 1024 15712 956365568.0
epoch 2285 1502 2048 27080 760775872.0
epoch 2286 1049 1024 16552 615428672.0
epoch 2287 2054 2048 36888 362532096.0
epoch 2288 2586 4096 52408 1476497792.0
epoch 2289 1824 2048 33044 1013766784.0
epoch 2290 1538 2048 27548 1083137024.0
epoch 2291 1776 2048 32524 1138760704.0
epoch 2292 1939 2048 35208 1079930624.0
epoch 2293 1818 2048 32952 890526464.0
epoch 2294 1304 2048 23708 1633521536.0
epoch 2295 2046 2048 36748 380222144.0
epoch 2296 986 1024 15784 934744768.0
epoch 2297 1072 1024 16800 394256544.0
epoch 2298 828 1024 12924 865213184.0
epoch 2299 927 1024 14796 660849664.0
epoch 2300 1784 2048 32576 335559648.0
epoch 2301 793 1024 12448 723431232.0
epoch 2302 1040 1024 16468 473125184.0
epoch 2303 2049 2048 367

epoch 2490 1801 2048 32800 210098160.0
epoch 2491 1035 1024 16428 225989152.0
epoch 2492 2129 2048 37980 17317284.0
epoch 2493 2009 2048 36264 324788064.0
epoch 2494 1053 1024 16584 356051776.0
epoch 2495 1998 2048 36168 256966752.0
epoch 2496 2354 2048 44488 520637664.0
epoch 2497 1742 2048 31816 962785472.0
epoch 2498 1073 1024 16808 565411200.0
epoch 2499 944 1024 14936 591101696.0
epoch 2500 1673 2048 30920 521168448.0
epoch 2501 3031 4096 60888 1079701248.0
epoch 2502 1049 1024 16552 983321280.0
epoch 2503 2010 2048 36268 320253248.0
epoch 2504 1038 1024 16444 700114560.0
epoch 2505 1066 1024 16868 373713088.0
epoch 2506 1579 2048 28028 338285248.0
epoch 2507 1784 2048 32576 239288352.0
epoch 2508 1520 2048 27392 479728384.0
epoch 2509 2010 2048 36304 478946688.0
epoch 2510 1494 2048 27000 696098112.0
epoch 2511 616 512 8288 790620032.0
epoch 2512 1952 2048 35288 188116976.0
epoch 2513 2031 2048 36620 345696512.0
epoch 2514 3079 4096 61468 981590400.0
epoch 2515 1048 1024 16544 12

epoch 2702 558 512 7500 510468320.0
epoch 2703 1432 2048 26012 328406080.0
epoch 2704 991 1024 15820 174313120.0
epoch 2705 1290 2048 23568 250481344.0
epoch 2706 504 512 6996 349504064.0
epoch 2707 1505 2048 27152 136964672.0
epoch 2708 1068 1024 16712 25875704.0
epoch 2709 1994 2048 36136 97234184.0
epoch 2710 1236 2048 22876 555071424.0
epoch 2711 1078 1024 16860 224209200.0
epoch 2712 1037 1024 16448 246680800.0
epoch 2713 1010 1024 16056 185450944.0
epoch 2714 2040 2048 36704 74445016.0
epoch 2715 1061 1024 16664 218351968.0
epoch 2716 2950 4096 58988 449110336.0
epoch 2717 1048 1024 16520 378854400.0
epoch 2718 3089 4096 61560 506492352.0
epoch 2719 3077 4096 61460 1548637696.0
epoch 2720 1551 2048 27724 2035671168.0
epoch 2721 937 1024 14944 1647855616.0
epoch 2722 941 1024 14920 1244270080.0
epoch 2723 1104 1024 17240 508902944.0
epoch 2724 787 1024 12380 836253824.0
epoch 2725 1812 2048 32884 421231168.0
epoch 2726 1019 1024 16236 481555072.0
epoch 2727 412 512 5532 876387584.

epoch 2914 1038 1024 16456 418845184.0
epoch 2915 1069 1024 16776 124796800.0
epoch 2916 1781 2048 32552 106340320.0
epoch 2917 3613 4096 72908 318203264.0
epoch 2918 1825 2048 33000 648736640.0
epoch 2919 2562 2048 46028 235295648.0
epoch 2920 943 1024 14988 1347218944.0
epoch 2921 653 512 8664 1384495744.0
epoch 2922 1529 2048 27488 706663488.0
epoch 2923 1041 1024 16472 434590528.0
epoch 2924 1532 2048 27516 302349184.0
epoch 2925 1547 1024 25480 148097280.0
epoch 2926 1874 2048 33884 94491184.0
epoch 2927 1762 2048 32220 169634480.0
epoch 2928 2701 2048 51520 146254144.0
epoch 2929 1544 2048 27676 954542336.0
epoch 2930 1032 1024 16412 778075712.0
epoch 2931 1559 2048 27776 184617856.0
epoch 2932 2960 4096 59732 812474368.0
epoch 2933 1093 1024 17048 1320427264.0
epoch 2934 1569 2048 27868 681828800.0
epoch 2935 966 1024 15592 1459043328.0
epoch 2936 1744 2048 32076 609577600.0
epoch 2937 1050 1024 16556 680709760.0
epoch 2938 1481 2048 26908 814256192.0
epoch 2939 1575 2048 27996 

In [9]:
num_epochs = 100

def test(model):
    """Play ``num_epochs`` evaluation games greedily and print one stats line per game.

    No learning or exploration happens (``is_train=False``).  Each line shows
    the game index, number of moves, largest tile, final score and last loss.
    The last loss is always 0 here because nothing is trained.  The
    module-level ``num_epochs`` is read at call time.
    """
    epoch = 0
    # '<' instead of '!=' so a non-positive num_epochs cannot loop forever.
    while epoch < num_epochs:
        epoch += 1
        game_len, max_score, game_score, last_loss = gen_sample_and_learn(model, None, None, False)
        print ('epoch', epoch, game_len, max_score, game_score, last_loss)

test(model)

epoch 1 2026 2048 36412 0
epoch 2 2040 2048 36680 0
epoch 3 1060 1024 16660 0
epoch 4 801 1024 12488 0
epoch 5 727 1024 11628 0
epoch 6 1042 1024 16476 0
epoch 7 796 1024 12476 0
epoch 8 925 1024 14280 0
epoch 9 1552 2048 27728 0
epoch 10 1801 2048 32800 0
epoch 11 1547 2048 27692 0
epoch 12 1054 1024 16588 0
epoch 13 1038 1024 16456 0
epoch 14 1031 1024 16408 0
epoch 15 1864 2048 33568 0
epoch 16 1809 2048 32856 0
epoch 17 1822 2048 32972 0
epoch 18 1553 2048 27736 0
epoch 19 1823 2048 32980 0
epoch 20 569 512 7628 0
epoch 21 1617 2048 29264 0
epoch 22 2048 2048 36760 0
epoch 23 1042 1024 16476 0
epoch 24 1044 1024 16488 0
epoch 25 1553 2048 27760 0
epoch 26 1004 1024 15984 0
epoch 27 1972 2048 35884 0
epoch 28 1041 1024 16472 0
epoch 29 2080 2048 37264 0
epoch 30 1559 2048 27804 0
epoch 31 1488 2048 26956 0
epoch 32 1813 2048 32888 0
epoch 33 1914 2048 34892 0
epoch 34 1544 2048 27648 0
epoch 35 1058 1024 16620 0
epoch 36 975 1024 15676 0
epoch 37 1054 1024 16588 0
epoch 38 980 1024 

In [10]:
# import os
# experiment_dir = "model"
# filename = "model6.pth.tar"
# num_epochs = 1000

# def save_model(state, filename='model.pth.tar'):
#     filename = os.path.join(experiment_dir, filename)
#     torch.save(state, filename)

# save_model({
#     'epoch': num_epochs,
#     'state_dict': model.cpu().state_dict(),
#     'optimizer': optimizer.state_dict(),
# }, filename)