This kernel is inspired by Heng's [discussion](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/106462#latest-618693) of [HRNet-Semantic-Segmentation](https://arxiv.org/abs/1904.04514). Below, an HRNet encoder is combined with a C1 decoder, and only the decoder's parameters are trained.
![](https://raw.githubusercontent.com/HRNet/HRNet-Semantic-Segmentation/master/figures/seg-hrnet.png?generation=1565963628491533&amp;alt=media)


In [28]:
from HRNet.c1_decoder import get_decoder
from HRNet.hrnet import get_encoder
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch
import numpy as np
import cv2
from PIL import Image
import scipy.io
import matplotlib.pylab as plt
import glob

In [2]:
def transform(image, size):
    """Turn an image-like input into a resized, batched float tensor.

    ``image`` is cast with ``np.uint8`` (the notebook passes both PIL images
    and arrays read from .mat files), wrapped back into a PIL image, resized
    to ``size`` and converted with ``transforms.ToTensor``, which rescales
    uint8 values into [0, 1]. A leading dimension of 1 is prepended so the
    per-sample results can be joined with ``torch.cat``.
    """
    pipeline = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    pil_image = Image.fromarray(np.uint8(image))
    tensor = pipeline(pil_image)
    return tensor.unsqueeze(0)

In [3]:
def show(output):
    """Draw the first prediction of a batch as a colour-coded class map.

    ``output[0]`` is reduced with argmax over its first axis, so ``output``
    presumably holds per-class scores shaped (batch, classes, H, W) -- TODO
    confirm against the decoder. The shape of the resulting class-index map
    is printed, then each index is looked up in the module-level ``colors``
    palette and the image is drawn with ``plt.imshow``.
    """
    scores = output[0].detach().numpy()
    label_map = scores.argmax(axis=0)
    print(label_map.shape)
    plt.imshow(colors[label_map])

In [4]:
# Build the encoder with the project helper ``get_encoder`` from the checkpoint
# file 'encoder_epoch_30.pth' (presumably weights saved after 30 epochs of an
# earlier run -- TODO confirm).
encoder = get_encoder('encoder_epoch_30.pth')

In [5]:
# NOTE(review): 59 is presumably the number of output classes and '' an empty
# weights path (i.e. a freshly initialized decoder) -- confirm against
# HRNet.c1_decoder.get_decoder.
decoder = get_decoder(59, '')

In [8]:
# Common spatial size for every image and label. It is passed to
# ``transforms.Resize``, which reads a 2-tuple as (height, width).
size = (800, 544)

In [11]:
# Colour palette read from the 'colors' entry of color150.mat. ``show``
# indexes it with predicted class ids, so row k is the colour drawn for
# class k.
colors = scipy.io.loadmat('color150.mat')['colors']

In [9]:
# Folder prefixes are concatenated directly with glob patterns when loading
# the data, so both must keep their trailing slash. (Plain strings: the
# former f-strings had no placeholders.)
output_folder = 'data/annotations/pixel-level/'  # .mat label files ('groundtruth' key)
input_folder = 'data/photos/'  # .jpg input photos

In [18]:
def _load_image(path, size):
    """Load one photo as a (1, 3, H, W) float tensor via ``transform``.

    The file is opened in a ``with`` block so its handle is closed. It is
    converted to RGB so every sample has three channels before stacking.
    """
    with Image.open(path) as img:
        return transform(img.convert('RGB'), size)


def _load_label(path, size):
    """Load one ``groundtruth`` mask as a (1, H, W) int64 class-index tensor.

    Labels must not go through ``transform``. ``ToTensor`` rescales uint8
    values into [0, 1], and ``Resize``'s default bilinear interpolation blends
    neighbouring class ids, which destroys the integer targets that
    ``nn.NLLLoss`` expects. Nearest-neighbour resizing keeps every pixel a
    valid class id, and ``size`` is read as (height, width), matching
    ``transforms.Resize``.

    NOTE(review): assumes ``groundtruth`` is a 2-D (h, w) array -- confirm.
    """
    mask = np.asarray(scipy.io.loadmat(path)['groundtruth'], dtype=np.float32)
    mask = torch.from_numpy(mask)[None, None]  # (1, 1, h, w) for interpolate
    resized = torch.nn.functional.interpolate(mask, size=size, mode='nearest')
    return resized[0].long()  # (1, H, W)


# NOTE(review): only the first 2 files are loaded (the full set has 1004 per the
# original note), presumably to iterate quickly. Images and labels are paired
# purely by sorted filename order; this assumes both folders hold matching files.
inputs = [_load_image(x, size) for x in sorted(glob.glob(f'{input_folder}*.jpg'))[:2]] # 1004
labels = [_load_label(x, size) for x in sorted(glob.glob(f'{output_folder}*.mat'))[:2]]

In [26]:
# Concatenate the per-sample tensors (each carries a leading dimension of 1)
# along that dimension and pair them as (input, label) samples.
# torch.cat raises on an empty list, so this fails if no files were found.
dataset = TensorDataset(torch.cat(inputs), torch.cat(labels))

In [33]:
class EncoderDecoder(nn.Module):
    """Chain an encoder and a decoder, exposing only the decoder for training.

    ``forward`` feeds the encoder's output straight into the decoder.
    ``parameters`` is overridden to return only the decoder's parameters, so
    an optimizer built from ``model.parameters()`` never updates the encoder.
    The encoder is therefore frozen.

    NOTE(review): ``model.train()`` still switches the encoder to training
    mode. If it contains BatchNorm layers (not visible here), their running
    statistics will drift even though its weights are frozen; consider
    calling ``encoder.eval()`` explicitly.
    """

    def __init__(self, encoder, decoder):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Return ``decoder(encoder(x))``."""
        # The encoder is never optimized (see ``parameters``), so recording its
        # autograd graph only wastes memory. It also made encoder ``.grad``
        # buffers accumulate forever, because ``zero_grad`` on the optimizer or
        # model only reaches the decoder's parameters.
        with torch.no_grad():
            features = self.encoder(x)
        return self.decoder(features)

    def parameters(self, recurse=True):
        """Return only the decoder's (trainable) parameters.

        Takes ``recurse`` to match ``nn.Module.parameters``'s signature, so
        generic callers that pass it keep working.
        """
        return self.decoder.parameters(recurse=recurse)

In [30]:
def train(model, dataset, epochs=2, batch_size=2, device=None):
    """Train ``model`` on ``dataset`` with Adam and return per-epoch mean losses.

    Parameters
    ----------
    model : nn.Module
        Must output log-probabilities, because the loss is ``nn.NLLLoss``.
        Only the tensors returned by ``model.parameters()`` are optimized.
    dataset : torch Dataset
        Yields ``(input, target)`` pairs. Targets must be int64 class indices,
        as ``nn.NLLLoss`` requires; pixels labelled ``-1`` are ignored.
    epochs : int
        Number of passes over the (shuffled) data.
    batch_size : int
        Mini-batch size.
    device : torch.device or str, optional
        If given, the model and every batch are moved there (e.g. ``'cuda'``).
        By default nothing is moved.

    Returns
    -------
    list of float
        For each epoch, the mean of that epoch's per-batch losses.
    """
    if device is not None:
        # Move the model before building the optimizer so it tracks the
        # on-device parameters.
        model = model.to(device)
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters())
    model.train()
    criterion = nn.NLLLoss(ignore_index=-1)
    losses = []
    for e in range(epochs):
        epoch_loss = 0.0
        count = 0
        for data, target in dataloader:
            if device is not None:
                data, target = data.to(device), target.to(device)
            prediction = model(data)
            loss = criterion(prediction, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            count += 1
            # Accumulate a Python float rather than the loss tensor. Summing
            # tensors chains every batch's autograd history together for the
            # whole epoch.
            epoch_loss += loss.item()
        mean_loss = epoch_loss / count
        losses.append(mean_loss)
        print(f'epoch {e} loss: {mean_loss}')
    return losses

In [34]:
# Wrap the encoder and decoder into one module. Its overridden parameters()
# exposes only the decoder, so only the decoder is optimized.
model = EncoderDecoder(encoder, decoder)

In [35]:
# Train with train()'s defaults (2 epochs, batch size 2). The recorded run of
# this cell was stopped with a KeyboardInterrupt (see the output below).
train(model, dataset)

KeyboardInterrupt: 