In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [2]:
import sys
sys.path.append('../src')

In [3]:
import numpy as np
import torch 
import torch.nn as nn
import os
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from sklearn.model_selection import train_test_split
from FishDataset import FishDataset

In [4]:
%load_ext autoreload
%autoreload 2
from model import UNet

In [5]:
train_transform = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

test_transform = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.ToTensor()
])

In [6]:
train_dataset = FishDataset('../data', download=True, transform=train_transform, target_transform=train_transform)

In [7]:
train_indices, test_indices = train_test_split(np.arange(len(train_dataset)), test_size=0.2, random_state=42)

In [8]:
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=SubsetRandomSampler(train_indices),
    num_workers=4
)

val_loader = DataLoader(
    FishDataset('../data', transform=test_transform, target_transform=test_transform),
    batch_size=32,
    sampler=SubsetRandomSampler(train_indices),
    num_workers=4
)

In [9]:
def jaccard(outputs, targets):
    outputs = outputs.view(outputs.size(0), -1)
    targets = targets.view(targets.size(0), -1)
    intersection = (outputs * targets).sum(1)
    union = (outputs + targets).sum(1) - intersection
    jac = (intersection + 0.001) / (union + 0.001)
    return jac.mean()

In [10]:
model = UNet()
model.cuda()

UNet(
  (down1): Sequential(
    (0): conv_block(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
      (leaky_relu): LeakyReLU(0.01)
    )
    (1): conv_block(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
      (leaky_relu): LeakyReLU(0.01)
    )
  )
  (down2): Sequential(
    (0): conv_block(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (leaky_relu): LeakyReLU(0.01)
    )
    (1): conv_block(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (leaky_relu): LeakyReLU(0.01)
    )
  )
  (down3): Sequential(
    (0): conv_block(
      (conv): Conv2d(64, 128,

In [11]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())

In [12]:
model_folder = os.path.abspath('../models')
if not os.path.exists(model_folder):
    os.mkdir(model_folder)
model_path = os.path.join(model_folder, 'unet.pt')

In [13]:
hist = {'loss': [], 'jaccard': [], 'val_loss': [], 'val_jaccard': []}
num_epochs = 5
display_steps = 50
best_jaccard = 0
for epoch in range(num_epochs):
    print('Starting epoch {}/{}'.format(epoch+1, num_epochs))
    # train
    model.train()
    running_loss = 0.0
    running_jaccard = 0.0
    for batch_idx, (images, masks, _) in enumerate(train_loader):
        images = Variable(images.cuda())
        masks = Variable(masks.cuda())
        
        optimizer.zero_grad()
        outputs = model(images)
        predicted = outputs.round()
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        
        jac = jaccard(outputs.round(), masks)
        running_jaccard += jac.data[0]
        running_loss += loss.data[0]
        
        if batch_idx % display_steps == 0:
            print('    ', end='')
            print('batch {:>3}/{:>3} loss: {:.4f}, jaccard {:.4f}\r'\
                  .format(batch_idx+1, len(train_loader),
                          loss.data[0], jac.data[0]))

        
    # evalute
    print('Finished epoch {}, starting evaluation'.format(epoch+1))
    model.eval()
    val_running_loss = 0.0
    val_running_jaccard = 0.0
    for images, masks, _ in val_loader:
        images = Variable(images.cuda())
        masks = Variable(masks.cuda())
        
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        val_running_loss += loss.data[0]
        jac = jaccard(outputs.round(), masks)
        val_running_jaccard += jac.data[0]

    train_loss = running_loss / len(train_loader)
    train_jaccard = running_jaccard / len(train_loader)
    val_loss = val_running_loss / len(val_loader)
    val_jaccard = val_running_jaccard / len(val_loader)
    
    hist['loss'].append(train_loss)
    hist['jaccard'].append(train_jaccard)
    hist['val_loss'].append(val_loss)
    hist['val_jaccard'].append(val_jaccard)
    
    if val_jaccard > best_jaccard:
        torch.save(model, model_path)
    print('    ', end='')
    print('loss: {:.4f}  jaccard: {:.4f} \
           val_loss: {:.4f} val_jaccard: {:4.4f}\n'\
           .format(train_loss, train_jaccard, val_loss, val_jaccard))

Starting epoch 1/5
    batch   1/685 loss: 0.7597, jaccard 0.1310
    batch  51/685 loss: 0.6546, jaccard 0.6962
    batch 101/685 loss: 0.6477, jaccard 0.7860
    batch 151/685 loss: 0.6474, jaccard 0.7801
    batch 201/685 loss: 0.6405, jaccard 0.8275
    batch 251/685 loss: 0.6364, jaccard 0.8467
    batch 301/685 loss: 0.6396, jaccard 0.8081
    batch 351/685 loss: 0.6308, jaccard 0.8697
    batch 401/685 loss: 0.6329, jaccard 0.8403
    batch 451/685 loss: 0.6293, jaccard 0.8602
    batch 501/685 loss: 0.6279, jaccard 0.8722
    batch 551/685 loss: 0.6297, jaccard 0.8639
    batch 601/685 loss: 0.6294, jaccard 0.8621
    batch 651/685 loss: 0.6205, jaccard 0.8920
Finished epoch 1, starting evaluation
    loss: 0.6381  jaccard: 0.8154            val_loss: 0.6274 val_jaccard: 0.8766

Starting epoch 2/5
    batch   1/685 loss: 0.6228, jaccard 0.8836
    batch  51/685 loss: 0.6248, jaccard 0.8651
    batch 101/685 loss: 0.6259, jaccard 0.8638
    batch 151/685 loss: 0.6236, jaccard 0.