In [18]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import *
import wandb
%matplotlib inline

In [19]:
class model_encoder(nn.Module):
    def __init__(self):
        super(model_encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.MaxPool2d(2, stride=2),
            nn.ELU(True),
            nn.Conv2d(8, 1, 5, stride=1, padding=1),
            nn.MaxPool2d(2, stride=1),
            nn.ELU(True),
        )
 
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1, 8, 2, stride=3, padding=3),
            nn.ELU(True),
            nn.ConvTranspose2d(8, 4, 2, stride=2, padding=1),
            nn.ELU(True),
            nn.ConvTranspose2d(4, 1, 6, stride=2, padding=0),
        )
        
    def forward(self, x):
        x=self.encoder(x)
        x=self.decoder(x)
        return x

In [20]:
batch_size = 256
board=np.load('board.npy')/1000
board=torch.tensor(board, dtype=torch.float32)
board=torch.reshape(board, (board.shape[0],1, 32,32))
board=TensorDataset(board)    

In [21]:
num_epochs = 50
learning_rate = 5e-4
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model=model_encoder().to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)
loss_fn=torch.nn.BCEWithLogitsLoss()

In [22]:
wandb.init(project="halite")
wandb.watch(model)
model.train()

W&B Run: https://app.wandb.ai/arb426/halite/runs/9f3ri4k3
Call `%%wandb` in the cell containing your training loop to display live results.


model_encoder(
  (encoder): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ELU(alpha=True)
    (3): Conv2d(8, 1, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))
    (4): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
    (5): ELU(alpha=True)
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(1, 8, kernel_size=(2, 2), stride=(3, 3), padding=(3, 3))
    (1): ELU(alpha=True)
    (2): ConvTranspose2d(8, 4, kernel_size=(2, 2), stride=(2, 2), padding=(1, 1))
    (3): ELU(alpha=True)
    (4): ConvTranspose2d(4, 1, kernel_size=(6, 6), stride=(2, 2))
  )
)

In [None]:
j=0
for i in range(num_epochs):
    for data in dataset:
        data=data[0].to(device)
        res=model(data)
        optimizer.zero_grad()
        loss=loss_fn(res,data)
        loss.backward()
        optimizer.step()
        wandb.log({"Loss": loss})
        j+=1

Resuming run: https://app.wandb.ai/arb426/halite/runs/9f3ri4k3


In [None]:
fig=plt.figure(figsize=(18, 16), dpi= 80, facecolor='w', edgecolor='k')
for i in range(5):
    j=np.random.randint(0,240000)
    img=model(test_dataset[j][0].unsqueeze(1))
    plt.subplot(5,2,2*i+1)
    plt.imshow(test_dataset[j][0].squeeze(0).numpy())
    plt.subplot(5,2,2*i+2)
    plt.imshow(torch.reshape(img,(32,32)).detach().numpy())
plt.tight_layout()