In [4]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import *
import wandb
# Start a Weights & Biases run under the "halite" project; later cells
# report to it through wandb.watch and wandb.log.
wandb.init(project="halite")
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

W&B Run: https://app.wandb.ai/arb426/halite/runs/oum405qa
Call `%%wandb` in the cell containing your training loop to display live results.


In [5]:
class model_encoder(nn.Module):
    """Small convolutional autoencoder for single-channel images.

    The encoder is two ``Conv2d -> MaxPool2d -> ReLU`` stages; the decoder is
    three ``ConvTranspose2d`` layers with a ReLU after the first two. With the
    kernel/stride/padding values below, a ``(N, 1, 32, 32)`` input is reduced
    to a ``(N, 1, 5, 5)`` code and expanded back to ``(N, 1, 32, 32)``.

    The last decoder layer has no activation, so ``forward`` returns
    unbounded values rather than values squashed into a fixed range.
    """

    def __init__(self):
        super().__init__()

        # Spatial size for a 32x32 input: 32 -> 16 -> 8 -> 6 -> 5.
        encoder_layers = [
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=2, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=8, out_channels=1, kernel_size=5, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=1),
            nn.ReLU(inplace=True),
        ]

        # Spatial size for a 5x5 code: 5 -> 8 -> 14 -> 32.
        decoder_layers = [
            nn.ConvTranspose2d(in_channels=1, out_channels=8, kernel_size=2, stride=3, padding=3),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=8, out_channels=4, kernel_size=2, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=4, out_channels=1, kernel_size=6, stride=2, padding=0),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the bottleneck and decode it back.

        Args:
            x: float tensor of shape ``(N, 1, H, W)``.

        Returns:
            The decoder output for the encoded ``x``.
        """
        code = self.encoder(x)
        return self.decoder(code)

In [None]:
batch_size = 256
train_size = 450000  # boards used for training; everything else is held out

# Load the raw boards and scale them by 1/1000.
# NOTE(review): the training loss is BCEWithLogitsLoss, which expects targets
# in [0, 1]; this presumably relies on raw values being at most 1000 —
# confirm against board.npy.
board_array = np.load('board.npy') / 1000
board_tensor = torch.tensor(board_array, dtype=torch.float32)
# Add an explicit channel axis: (N, 1, 32, 32) is the layout Conv2d consumes.
board_tensor = torch.reshape(board_tensor, (board_tensor.shape[0], 1, 32, 32))
board = TensorDataset(board_tensor)

# Derive the held-out size from the data instead of hard-coding it, so the
# split does not break when board.npy has a different number of rows.
test_size = len(board) - train_size
if test_size <= 0:
    raise ValueError(
        f"board.npy holds {len(board)} boards; need more than {train_size} "
        "to leave a non-empty test split"
    )
train_dataset, test_dataset = random_split(board, [train_size, test_size])

# Iterate over the training split only. Building the loader on the full
# dataset would also train on test_dataset and invalidate any evaluation on it.
dataset = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Hyper-parameters for the training run.
num_epochs = 6
learning_rate = 1e-3

# Fresh autoencoder plus the optimiser that updates its parameters.
model = model_encoder()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# BCEWithLogitsLoss applies the sigmoid itself, so it is fed the model's
# raw (un-activated) output.
loss_fn = nn.BCEWithLogitsLoss()

# Register the model with the active W&B run, then put it in training mode.
wandb.watch(model)
model.train()

In [None]:
# Train the autoencoder to reconstruct its own input.
j = 0  # running count of optimisation steps across all epochs
for epoch in range(num_epochs):
    for data in dataset:
        # Each batch is a one-element sequence holding the board tensor.
        inputs = data[0]

        optimizer.zero_grad()
        res = model(inputs)
        # Input and target are the same tensor: reconstruction loss.
        loss = loss_fn(res, inputs)
        loss.backward()
        optimizer.step()

        # Log a plain Python float, not the graph-attached loss tensor.
        wandb.log({"Loss": loss.item()})
        j += 1

In [None]:
# Show random held-out boards (left) next to their reconstructions (right).
n_examples = 5
fig, axes = plt.subplots(
    n_examples, 2,
    figsize=(18, 16), dpi=80, facecolor='w', edgecolor='k',
    squeeze=False,  # keep axes 2-D even if n_examples is set to 1
)

# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    for i in range(n_examples):
        # Sample over the whole test split rather than a hard-coded range.
        j = np.random.randint(0, len(test_dataset))
        original = test_dataset[j][0]  # one board, shape (1, 32, 32)

        # Add the batch axis in front: (1, 32, 32) -> (1, 1, 32, 32).
        logits = model(original.unsqueeze(0))
        # The model is trained with BCEWithLogitsLoss, so its output is logits;
        # apply the sigmoid to get values on the same [0, 1] scale as the input.
        img = torch.sigmoid(logits)

        axes[i, 0].imshow(original.squeeze(0).numpy())
        axes[i, 0].set_title(f"Original (test index {j})")
        axes[i, 1].imshow(torch.reshape(img, (32, 32)).numpy())
        axes[i, 1].set_title("Reconstruction")

fig.tight_layout()