In [152]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import *
import wandb
%matplotlib inline

In [160]:
class model_encoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    The encoder is six Conv2d -> ReLU -> BatchNorm2d stages taking 1 channel up
    to 32 channels. The decoder is six ConvTranspose2d -> ReLU -> BatchNorm2d
    stages followed by one bare ConvTranspose2d that produces a single output
    channel. No activation follows that last layer, so ``forward`` returns
    unbounded values (logits).

    Spatial sizes for a 32x32 input, from the kernel/stride/padding settings:
    encoder 32 -> 15 -> 15 -> 11 -> 11 -> 8 -> 8,
    decoder 8 -> 15 -> 15 -> 27 -> 27 -> 30 -> 30 -> 32.
    """

    # One row per stage: (in_channels, out_channels, kernel_size, stride, padding)
    _ENCODER_STAGES = (
        (1, 8, 5, 2, 1),
        (8, 8, 1, 1, 0),
        (8, 16, 5, 1, 0),
        (16, 16, 1, 1, 0),
        (16, 32, 4, 1, 0),
        (32, 32, 1, 1, 0),
    )
    _DECODER_STAGES = (
        (32, 16, 3, 2, 1),
        (16, 16, 1, 1, 0),
        (16, 8, 5, 2, 3),
        (8, 8, 1, 1, 0),
        (8, 4, 4, 1, 0),
        (4, 4, 1, 1, 0),
    )

    def __init__(self):
        super().__init__()

        # Layers are collected in flat lists (not nested Sequentials) so that
        # the module indices inside `encoder` / `decoder` stay 0, 1, 2, ...
        down_layers = []
        for c_in, c_out, kernel, stride, pad in self._ENCODER_STAGES:
            down_layers.append(nn.Conv2d(c_in, c_out, kernel, stride=stride, padding=pad))
            down_layers.append(nn.ReLU(inplace=True))
            down_layers.append(nn.BatchNorm2d(c_out))
        self.encoder = nn.Sequential(*down_layers)

        up_layers = []
        for c_in, c_out, kernel, stride, pad in self._DECODER_STAGES:
            up_layers.append(nn.ConvTranspose2d(c_in, c_out, kernel, stride=stride, padding=pad))
            up_layers.append(nn.ReLU(inplace=True))
            up_layers.append(nn.BatchNorm2d(c_out))
        # Output projection back to a single channel; deliberately left
        # without ReLU/BatchNorm.
        up_layers.append(nn.ConvTranspose2d(4, 1, 3, stride=1, padding=0))
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [161]:
batch_size = 128

# Load the stored boards and divide by 1000.
# NOTE(review): BCEWithLogitsLoss expects targets in [0, 1]; this assumes the
# raw values never exceed 1000 -- confirm against the data.
board_array = np.load('board.npy') / 1000
board_tensor = torch.tensor(board_array, dtype=torch.float32)
# Give every sample an explicit channel axis: (N, 1, 32, 32).
board_tensor = torch.reshape(board_tensor, (board_tensor.shape[0], 1, 32, 32))
# `board` is kept as the name of the full TensorDataset, as before.
board = TensorDataset(board_tensor)

# Hold out 40288 samples for evaluation and train on the rest. With the
# original 690288-sample file this yields the same 650000 / 40288 split, but
# it no longer raises if the file has a different length.
n_test = 40288
n_train = len(board) - n_test
train_dataset, test_dataset = random_split(board, [n_train, n_test])

# Iterate over the training split only. The loader was previously built over
# the full `board` dataset, so the held-out samples were trained on as well.
dataset = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

In [162]:
# Optimisation hyperparameters.
num_epochs = 15
learning_rate = 1e-3
momentum = 0.9

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Fresh autoencoder, moved to the selected device before the optimizer is
# built so the optimizer holds the on-device parameters.
model = model_encoder().to(device)

# SGD with momentum. Alternative that was tried:
#   torch.optim.Adam(model.parameters(), lr=learning_rate)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

# Binary cross-entropy computed directly on the model's raw outputs.
loss_fn = torch.nn.BCEWithLogitsLoss()

In [163]:
# Start a Weights & Biases run in the "halite" project; the training loop's
# wandb.log calls report to this run.
wandb.init(project="halite")
# Register the model with W&B.
# NOTE(review): presumably this enables gradient/parameter logging -- confirm
# against the wandb.watch documentation.
wandb.watch(model)
# Switch the model to training mode. As the last expression in the cell the
# returned module is echoed, which is what prints the architecture below.
model.train()

W&B Run: https://app.wandb.ai/arb426/halite/runs/txhyzi6c
Call `%%wandb` in the cell containing your training loop to display live results.


model_encoder(
  (encoder): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace)
    (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
    (5): ReLU(inplace)
    (6): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1))
    (8): ReLU(inplace)
    (9): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
    (11): ReLU(inplace)
    (12): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1))
    (14): ReLU(inplace)
    (15): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1

In [None]:
# Train the autoencoder to reconstruct its own input.
j = 0  # running count of optimizer steps across all epochs
# Make sure BatchNorm is in training mode even if this cell is run on its own
# or after an evaluation cell.
model.train()
for i in range(num_epochs):
    print(i+1, "of", num_epochs)
    for data in dataset:
        # Each batch is a tuple; element 0 is the image tensor, which is both
        # the input and the reconstruction target.
        data = data[0].to(device)
        optimizer.zero_grad()
        res = model(data)
        loss = loss_fn(res, data)
        loss.backward()
        optimizer.step()
        # Log a plain Python float, not the tensor: the tensor carries the
        # autograd graph and lives on the training device.
        wandb.log({"Loss": loss.item()})
        j += 1

1 of 15
Resuming run: https://app.wandb.ai/arb426/halite/runs/txhyzi6c
2 of 15
3 of 15
4 of 15
5 of 15
6 of 15
7 of 15
8 of 15
9 of 15
10 of 15
11 of 15


In [None]:
# Show five random held-out samples (left column) next to the model's
# reconstruction of each (right column).
fig, axes = plt.subplots(5, 2, figsize=(18, 16), dpi=80, facecolor='w', edgecolor='k')

# The model may live on the GPU while the dataset tensors are on the CPU.
model_device = next(model.parameters()).device
was_training = model.training
# Evaluate with BatchNorm's running statistics and without autograd, so that
# plotting neither depends on batch-of-one statistics nor alters the model.
model.eval()
try:
    with torch.no_grad():
        for i in range(5):
            # Draw the index from the test split's actual size; a fixed upper
            # bound larger than the split raises IndexError.
            j = np.random.randint(0, len(test_dataset))
            sample = test_dataset[j][0]  # one image of shape (1, 32, 32)
            # Add the batch axis -> (1, 1, 32, 32) and move to the model's device.
            img = model(sample.unsqueeze(0).to(model_device))
            # The model is trained with BCEWithLogitsLoss, so its output is a
            # logit map; sigmoid maps it back to the input's [0, 1] scale.
            recon = torch.reshape(torch.sigmoid(img), (32, 32)).cpu()
            axes[i, 0].imshow(sample.squeeze(0).numpy())
            axes[i, 0].set_title("input")
            axes[i, 1].imshow(recon.numpy())
            axes[i, 1].set_title("reconstruction")
finally:
    # Restore whichever mode the model was in so later training is unaffected.
    model.train(was_training)
plt.tight_layout()