In [1]:
%pylab inline

import os
import numpy as np

from PIL import Image
from skimage.morphology import remove_small_objects, remove_small_holes

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import transforms

from pair_transforms import *
from FaceDataset import FaceDataset

from models.linknet import UNetResnet

from loss import dice_loss
from utils import rle_encode, write_results

Populating the interactive namespace from numpy and matplotlib


In [2]:
path_images = list(filter(lambda x: x.endswith('.jpg'), os.listdir('data/train/')))
print(len(path_images))
train_images, val_images = path_images[:1460], path_images[1460:1491]

1491


In [4]:
color_transforms = transforms.Compose([
    transforms.RandomGrayscale(p=0.01)
    ])

pair_transforms = Compose([
    RandomCrop(240),
    RandomHorizontalFlip(),
    ToTensor()
    ])

tensor_transforms = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
    ])

In [5]:
train_dataset = FaceDataset(
    images_dir='data/train/',
    images_name=train_images,
    target_dir='data/train_mask/',
    color_transforms=color_transforms,
    pair_transforms=pair_transforms,
    tensor_transforms=tensor_transforms)

val_dataset = FaceDataset(
    images_dir='data/train/',
    images_name=val_images,
    target_dir='data/train_mask/',
    pair_transforms=Compose([ ToTensor() ]),
    tensor_transforms=tensor_transforms)

train_data_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_data_loader = DataLoader(val_dataset, batch_size=2)

1460 images
31 images


In [10]:
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

net = UNetResnet().to(device)

In [63]:
# Check sizes
for batch in train_data_loader:
    break
    
out = net.forward(batch['img'].cuda())

print(batch['img'].shape)
print(out.shape)

torch.cuda.empty_cache()

torch.Size([2, 3, 240, 240])
torch.Size([2, 1, 240, 240])


## Train

In [67]:
lr = 1e-4

criterion = dice_loss
val_criterion = dice_loss
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=10)

In [None]:
def train(verbose=True):
    net.train()
    running_loss = 0.0
    train_loss = 0.0
    for i, batch in enumerate(train_data_loader):
        optimizer.zero_grad()
        
        batch['img'], batch['mask'] = batch['img'].to(device), batch['mask'].to(device)
        output = net(batch['img'])
        loss = criterion(output, batch['mask'])
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        train_loss += loss.item()
        
        if i % 100 == 99:
            if verbose:
                print(' [{} - {}],\ttrain loss: {:.5}'.format(epoch+1, i+1, running_loss/100))
            else:
                print('|', end='', flush=True)
            running_loss = 0.0
    train_loss /= i
    print('\n [{}], \ttrain loss: {:.5}'.format(epoch+1, train_loss))
    return train_loss

def validate():
    net.eval()
    val_loss = 0.0
    for i, batch in enumerate(val_data_loader):
        batch['img'], batch['mask'] = batch['img'].to(device), batch['mask'].to(device)
        with torch.no_grad():
            output = net(batch['img'])
        val_loss += val_criterion(output, batch['mask']).detach().item()
    val_loss /= len(val_dataset)
    print(' [{}], \tval loss: {:.5}\n'.format(epoch+1, val_loss))
    return val_loss

In [68]:
import time

num_epoch = 1
history = []

for epoch in range(num_epoch):
    train_loss = train()
    val_loss = validate()
    history.append((train_loss, val_loss))
    #scheduler.step(val_loss)

 [1 - 100],	train loss: 0.028007
 [1 - 200],	train loss: 0.03228
 [1 - 300],	train loss: 0.033948
 [1 - 400],	train loss: 0.031249
 [1 - 500],	train loss: 0.029961
 [1 - 600],	train loss: 0.031217
 [1 - 700],	train loss: 0.033077

 [1], 	train loss: 0.031503
 [1], 	val loss: 0.014317



In [54]:
#torch.save(net.state_dict(), 'models/linknet34')
#net.load_state_dict(torch.load('models/linknet34_best'))

## Test

In [64]:
tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
    ])

In [65]:
path_images = list(filter(lambda x: x.endswith('.jpg'), os.listdir('data/test/')))

test_data_loader = DataLoader(
    FaceDataset('data/test', path_images, tensor_transforms=tensor_transforms), batch_size=1)

2177 images


In [67]:
path_images = [int(x.split('.')[0]) for x in path_images]

write_results(net, test_data_loader, path_images)

In [53]:
import matplotlib.pyplot as plt

def save_pseudo_labels(dir_name='data/pseudo', threshold=0.25):
    for i,batch in enumerate(test_data_loader):
        batch['img'] = batch['img'].to(device)
        with torch.no_grad():
            output = net.forward(batch['img'])

        img = output[0].detach().cpu().numpy()
        post_img = remove_small_holes(remove_small_objects(img > threshold))
        im = Image.fromarray(post_img[0].astype('uint8')*255)
        im.save("{}/{}.png".format(dir_name, path_images[i]))  
    print("Processed {} imgs".format(i+1))

2176
