In [1]:
import sys
sys.path.append('../src')

In [2]:
import numpy as np
import torch 
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from sklearn.model_selection import train_test_split
from FishDataset import FishDataset
from model import UNet

  (fname, cnt))
  (fname, cnt))


In [3]:
train_transform = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

test_transform = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.ToTensor()
])

In [4]:
train_dataset = FishDataset('../data', download=True, transform=train_transform, target_transform=train_transform)

In [5]:
train_indices, test_indices = train_test_split(np.arange(len(train_dataset)), test_size=0.2)

In [6]:
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=SubsetRandomSampler(train_indices),
    num_workers=4
)

val_loader = DataLoader(
    FishDataset('../data', transform=test_transform, target_transform=test_transform),
    batch_size=32,
    sampler=SubsetRandomSampler(train_indices),
    num_workers=4
)

In [7]:
def jaccard(outputs, targets):
    outputs = outputs.view(outputs.size(0), -1)
    targets = targets.view(targets.size(0), -1)
    intersection = (outputs * targets).sum(1)
    union = (outputs + targets).sum(1) - intersection
    jac = (intersection + 0.001) / (union + 0.001)
    return jac.mean()

In [8]:
model = UNet()
model.cuda()

UNet(
  (down1): Sequential(
    (0): conv_block(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
      (leaky_relu): LeakyReLU(0.01)
    )
    (1): conv_block(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
      (leaky_relu): LeakyReLU(0.01)
    )
  )
  (down2): Sequential(
    (0): conv_block(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (leaky_relu): LeakyReLU(0.01)
    )
    (1): conv_block(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (leaky_relu): LeakyReLU(0.01)
    )
  )
  (down3): Sequential(
    (0): conv_block(
      (conv): Conv2d(64, 128,

In [9]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())

In [None]:
hist = {'loss': [], 'val_loss': [], 'val_jaccard': []}
num_epochs = 20
for epoch in range(num_epochs):
    print('Starting epoch {}/{}'.format(epoch+1, num_epochs))
    # train
    model.train()
    running_loss = 0.0
    for batch_idx, (images, masks, _) in enumerate(train_loader):
        images = Variable(images.cuda())
        masks = Variable(masks.cuda())
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.data[0]
        if batch_idx % 10 == 0:
             print('loss: {:.4f}'.format(loss.data[0]))
        
    # evalute
    print('Finished epoch {}, starting evaluation'.format(epoch+1))
    model.eval()
    val_running_loss = 0.0
    val_running_jaccard = 0.0
    for images, masks, _ in val_loader:
        images = Variable(images.cuda())
        masks = Variable(masks.cuda())
        
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        val_running_loss += loss.data[0]
        predicted = outputs.data.round()
        val_running_jaccard += jaccard(predicted, masks.data)
    
    loss = running_loss / len(train_loader)
    val_loss = val_running_loss / len(val_loader)
    val_jaccard = val_running_jaccard / len(val_loader)
    
    hist['loss'].append(loss)
    hist['val_loss'].append(val_loss)
    hist['val_jaccard'].append(val_jaccard)
    
    print('loss: {:.4f}  val_loss: {:.4f} val_jaccard: {:4.4f}\n'.format(loss, val_loss, val_jaccard))


Starting epoch 1/20
loss: 0.7847
loss: 0.6674
loss: 0.6491
loss: 0.6526
loss: 0.6463
loss: 0.6425
loss: 0.6447
loss: 0.6427
loss: 0.6455
loss: 0.6418
loss: 0.6449
loss: 0.6402
loss: 0.6404
loss: 0.6396
loss: 0.6409
loss: 0.6358
loss: 0.6340
loss: 0.6352
loss: 0.6389
loss: 0.6369
loss: 0.6358
loss: 0.6345
loss: 0.6352
loss: 0.6363
loss: 0.6342
loss: 0.6339
loss: 0.6328
loss: 0.6326
loss: 0.6305
loss: 0.6289
loss: 0.6285
loss: 0.6276
loss: 0.6319
loss: 0.6295
loss: 0.6276
loss: 0.6287
loss: 0.6314
loss: 0.6364
loss: 0.6274
loss: 0.6302
loss: 0.6344
loss: 0.6287
loss: 0.6301
loss: 0.6256
loss: 0.6288
loss: 0.6194
loss: 0.6272
loss: 0.6234
loss: 0.6269
loss: 0.6267
loss: 0.6249
loss: 0.6249
loss: 0.6240
loss: 0.6255
loss: 0.6244
loss: 0.6301
loss: 0.6259
loss: 0.6280
loss: 0.6240
loss: 0.6231
loss: 0.6248
loss: 0.6246
loss: 0.6250
loss: 0.6268
loss: 0.6220
loss: 0.6242
loss: 0.6236
loss: 0.6235
loss: 0.6219
Finished epoch 1, starting evaluation
loss: 0.6336  val_loss: 0.6222 val_jaccard: 0