# Predicting the position of a go board in an image using UNet

This notebook uses the UNet architecture to create a heatmap from an image containing a go board.

In [1]:
import json
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
from matplotlib.patches import Polygon
import numpy as np
from PIL import Image
from tqdm import tqdm

from scipy.spatial import distance as dist

from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T

from skimage.transform import resize

import cv2

from tqdm import tqdm, tqdm_notebook

from skimage.draw import polygon2mask, polygon, polygon_perimeter

from adabelief_pytorch import AdaBelief

import torch
import torch.nn.functional as F
import torch.nn as nn
from unet import UNet

datafolder = "board_masks/upload/"

In [3]:
# define a dataset class for the Dataloaders
class MaskDataset:
    def __init__(self, baseFile , folder, image_size):
        self.basePath = folder
        self.imageSize = image_size
        
        # load the dataset
        with open(baseFile) as json_file:
            self.jsonData = json.load(json_file)
        
        self.images = []
        
        self.masks = []
        
        # save every image and ground truth mask
        for p in tqdm(self.jsonData):
            data = p["data"]
            imagePath = data["image"]
        
            self.images.append(np.array(Image.open(self.basePath + os.path.basename(imagePath))))
            
            y = np.array(p["completions"][0]["result"][0]["value"]["points"])
            mask = polygon2mask((800,800), y*8).astype(bool).T
            
            self.masks.append(mask)
    
    def __len__(self):
        return len(self.jsonData)
    
    def __getitem__(self, index):
        
        # a simple image transformer wich resizes the images
        transform = T.Compose([
            T.ToPILImage(),
            T.Resize(self.imageSize),
            T.ToTensor(),
        ])
        
        img = self.images[index]
        X = transform(img)
        
        mask = resize(self.masks[index], (self.imageSize,self.imageSize), order = 0,preserve_range=True)
        
        return X, torch.LongTensor(mask)
    

In [6]:
# our dataset
d = MaskDataset("result.json",datafolder, 128)

100%|██████████| 802/802 [00:53<00:00, 15.10it/s]


In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_channels=3, n_classes=2, wf=5, depth=4, padding=True, up_mode='upsample').to(device)
optim = AdaBelief(model.parameters(), lr=1e-3, eps=1e-8, betas=(0.9,0.999), weight_decouple = True, rectify = False)

[31mPlease check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  ---------
adabelief-pytorch=0.0.5  1e-08  False              False
>=0.1.0 (Current 0.2.0)  1e-16  True               True
[34mSGD better than Adam (e.g. CNN for Image Classification)    Adam better than SGD (e.g. Transformer, GAN)
----------------------------------------------------------  ----------------------------------------------
Recommended eps = 1e-8                                      Recommended eps = 1e-16
[34mFor a complete table of recommended hyperparameters, see
[34mhttps://github.com/juntang-zhuang/Adabelief-Optimizer
[32mYou can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.
[0m
Weight decoupling enabled in AdaBelief


In [8]:
batch_size = 16
train_dl = DataLoader(d, batch_size, shuffle=True, num_workers=8,)

In [9]:
epochs = 100
batches = len(train_dl)

for epoch in range(epochs):
    total_loss = 0
    progress = tqdm_notebook(enumerate(train_dl), desc="Loss: ", total=batches)
    
    model.train()
    
    for i, (X, target) in progress:
        X = X.to(device)  # [N, 3, H, W]
        target = target.to(device)  # [N, H, W] with class indices (0, 1)
        
        outputs = model(X) # [N, 1, H, W]
        
        loss = F.cross_entropy(outputs, target)
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        current_loss = loss.item()
        total_loss += current_loss
        
        progress.set_description(f"Epoch: {epoch} | Loss: {(total_loss/(i+1))}")
        
    torch.cuda.empty_cache() 
    val_losses = 0

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress = tqdm_notebook(enumerate(train_dl), desc="Loss: ", total=batches)


Loss:   0%|          | 0/51 [00:00<?, ?it/s]



Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

Loss:   0%|          | 0/51 [00:00<?, ?it/s]

In [19]:
torch.save(model.state_dict(), 'checkpoint1.pth')