In [None]:
! pip install torchvision
! pip install torchfcn

In [None]:
import os
# Working directory for the notebook; the `Bags` dataset class later looks for
# its data under `<root>/bags_data`.
root = "."
# Shell escape: list the current directory before switching into `root`.
!ls .
os.chdir(root)
# Echo the directory that is now current as a sanity check.
print(os.getcwd())

In [None]:
import pandas as pd
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import shutil
import collections
import PIL
import torchfcn

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from collections import Counter
from torchvision import transforms

In [None]:
# Random-crop / random-flip augmentation pipeline built from torchvision steps.
# NOTE(review): nothing else in the visible notebook references this object --
# `Bags` takes a boolean `transform` flag and does its own conversion -- so this
# pipeline currently appears to be unused. Confirm before relying on it.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),            # array/tensor -> PIL image
        transforms.RandomCrop(60),          # random 60x60 patch
        transforms.RandomHorizontalFlip(),  # random left/right mirror
        transforms.ToTensor(),              # PIL image -> tensor
    ]
)

In [None]:
# create dataset class
class Bags(Dataset):
  """Two-class (background / bag) segmentation dataset.

  Files are looked up under ``<root>/bags_data``:

  * ``imagesets/train.txt`` and ``imagesets/val.txt`` -- one sample id per line
  * ``JPEGImages/<id>.jpg`` -- input image
  * ``segmentation_mask/<id>.png`` -- per-pixel label mask

  Every image and mask is resized to ``image_size``. Mask pixels equal to 255
  are remapped to -1 (presumably so the loss can skip them -- confirm against
  the training loss in use).
  """
  class_names = np.array([
        'background',
        'bag'
    ])
  # Per-channel mean in B, G, R order; subtracted in `transform` and added
  # back in `untransform`.
  mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
  # (width, height) that every image and mask is resized to.
  image_size = (224, 224)

  def __init__(self, root, split='train', transform=False):
      """Index the image/mask file pairs of both the train and val splits.

      Args:
          root: directory that contains the ``bags_data`` folder.
          split: split served by this instance (``'train'`` or ``'val'``).
          transform: when truthy, ``__getitem__`` returns the tensors produced
              by ``transform``; otherwise it returns the raw numpy arrays.
      """
      self.root = root
      self.split = split
      self._transform = transform

      dataset_dir = osp.join(self.root, 'bags_data')
      self.files = collections.defaultdict(list)
      # Distinct loop name so the `split` argument is not shadowed.
      for split_name in ['train', 'val']:
          imgsets_file = osp.join(
              dataset_dir, 'imagesets/%s.txt' % split_name)
          # Context manager: the list file is closed as soon as it is read.
          with open(imgsets_file) as id_file:
              for did in id_file:
                  did = did.strip()
                  if not did:
                      # Tolerate blank or trailing empty lines in the list.
                      continue
                  self.files[split_name].append({
                      'img': osp.join(
                          dataset_dir, 'JPEGImages/%s.jpg' % did),
                      'lbl': osp.join(
                          dataset_dir, 'segmentation_mask/%s.png' % did),
                  })

  def __len__(self):
      """Return the number of samples in the split served by this instance."""
      return len(self.files[self.split])

  def __getitem__(self, index):
      """Load sample ``index`` of the active split.

      Returns:
          ``(img, lbl)``. With ``transform`` enabled these are the tensors
          returned by ``transform``; otherwise ``img`` is an RGB ``uint8``
          array and ``lbl`` an ``int32`` array.
      """
      data_file = self.files[self.split][index]
      # load image; convert('RGB') already promotes grayscale/palette inputs
      # to three channels, so no further mode check is required.
      img = PIL.Image.open(data_file['img']).convert('RGB')
      img = img.resize(self.image_size)
      img = np.array(img, dtype=np.uint8)
      # load label; nearest-neighbour resampling keeps the mask values
      # discrete -- an interpolating filter would invent in-between labels.
      lbl = PIL.Image.open(data_file['lbl'])
      lbl = lbl.resize(self.image_size, resample=PIL.Image.NEAREST)
      lbl = np.array(lbl, dtype=np.int32)
      lbl[lbl == 255] = -1
      if self._transform:
        return self.transform(img, lbl)
      else:
        return img, lbl

  def transform(self, img, lbl):
      """Convert an RGB image array and its label map into tensors.

      The image is flipped to BGR channel order, mean-centred with
      ``mean_bgr`` and laid out channel-first.

      Returns:
          ``(img, lbl)`` as a float tensor and a long tensor.
      """
      # RGB -> BGR so each channel is paired with the matching `mean_bgr`
      # entry (and so `untransform`, which flips back, is a true inverse).
      img = np.ascontiguousarray(img[:, :, ::-1], dtype=np.float64)
      img -= self.mean_bgr
      img = img.transpose(2, 0, 1)
      img = torch.from_numpy(img).float()
      lbl = torch.from_numpy(lbl).long()
      return img, lbl

  def untransform(self, img, lbl):
      """Invert ``transform``: tensors back to an RGB ``uint8`` image + labels.

      The input tensors are left unmodified.
      """
      img = img.numpy().transpose(1, 2, 0)
      # Out-of-place add: `.numpy()` shares memory with the tensor, so an
      # in-place `+=` would silently modify the caller's data.
      img = img + self.mean_bgr
      # Round (not truncate) and clip so float32 round-off cannot turn e.g.
      # 29.999996 into 29 or wrap slightly negative values around to 255.
      img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
      # BGR -> RGB
      img = img[:, :, ::-1]
      lbl = lbl.numpy()
      return img, lbl

In [None]:
# create data sets and dataloaders
# Both splits are served with tensor conversion switched on.
train_set = Bags(root, split="train", transform=True)
test_set = Bags(root, split="val", transform=True)

# Use the GPU when one is available; worker processes and pinned memory are
# only requested in that case.
cuda = torch.cuda.is_available()
kwargs = {}
if cuda:
    kwargs = {'num_workers': 4, 'pin_memory': True}

# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
    **kwargs
)
val_loader = DataLoader(
    test_set,
    batch_size=32,
    shuffle=False,
    **kwargs
)

In [None]:
def get_parameters(model, bias=False):
    """Yield the ``Conv2d`` parameters of ``model`` for use in an optimizer.

    Args:
        model: module whose sub-modules are walked via ``model.modules()``.
        bias: when true yield each ``Conv2d`` bias, otherwise its weight.

    Yields:
        One parameter per ``Conv2d`` layer.

    Raises:
        ValueError: if a sub-module is neither a convolution nor one of the
            known layer/container types that are skipped.
    """
    import torch.nn as nn

    # Layer and container types that are passed over without yielding.
    ignorable = (
        nn.ReLU,
        nn.MaxPool2d,
        nn.Dropout2d,
        nn.Sequential,
        torchfcn.models.FCN32s,
        torchfcn.models.FCN16s,
        torchfcn.models.FCN8s,
    )

    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            yield layer.bias if bias else layer.weight
            continue

        if isinstance(layer, nn.ConvTranspose2d):
            # weight is frozen because it is just a bilinear upsampling,
            # so nothing is yielded; such layers are expected to be bias-free.
            if bias:
                assert layer.bias is None
            continue

        if not isinstance(layer, ignorable):
            raise ValueError('Unexpected module: %s' % str(layer))

# learning rate
lr = 1e-3
# NOTE(review): `criterion` is not passed to the Trainer below and is not
# referenced anywhere else in the visible notebook -- confirm whether it is
# still needed.
criterion = torch.nn.CrossEntropyLoss()

# Checkpoint (.pth) file to resume from, or None to start from VGG16 weights.
resume =  None

# Create model
model = torchfcn.models.FCN32s(n_class=2)
start_epoch = 0
start_iteration = 0

# use a pretrained .pth file if provided.
if resume:
    # Load onto the CPU first so that a checkpoint written on a GPU machine
    # can also be restored on a CPU-only one; the model is moved to the GPU
    # further down when `cuda` is set.
    checkpoint = torch.load(resume, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    # Fresh run: initialise the network from pretrained VGG16 weights.
    vgg16 = torchfcn.models.VGG16(pretrained=True)
    model.copy_params_from_vgg16(vgg16)

# use cuda option
if cuda:
    model = model.cuda()

# Initialize optimizer: conv weights use the base learning rate, conv biases
# use twice that rate with weight decay disabled.
optimizer = torch.optim.SGD(
    [
        {'params': get_parameters(model, bias=False)},
        {'params': get_parameters(model, bias=True),
         'lr': lr * 2, 'weight_decay': 0},
    ],
    lr=lr)

# Restore the optimizer state only after the optimizer has been built around
# the (possibly GPU-resident) model parameters.
if resume:
    optimizer.load_state_dict(checkpoint['optim_state_dict'])

# create a trainer object
trainer = torchfcn.Trainer(
    cuda=cuda,
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    out=".",
    max_iter=1000,
    interval_validate=400,
)
# Continue counting from the checkpoint (both stay 0 for a fresh run).
trainer.epoch = start_epoch
trainer.iteration = start_iteration
trainer.train()

In [None]:
import tqdm
from torch.autograd import Variable

def validate():
    """Evaluate the global ``model`` on ``val_loader`` and report metrics.

    Runs the model over every validation batch with gradients disabled and
    the model in eval mode, prints the four scores returned by
    ``torchfcn.utils.label_accuracy_score`` as percentages, and writes a tiled
    visualisation of up to nine samples to ``viz_evaluate.png``.

    The model's previous train/eval mode is restored before returning.
    """
    # Imported here because the notebook's import cells do not bring these in,
    # yet this function needs them.
    import fcn
    import skimage.io

    # Number of classes follows the dataset's own class list.
    n_class = len(val_loader.dataset.class_names)
    use_cuda = torch.cuda.is_available()

    visualizations = []
    label_trues, label_preds = [], []

    # Evaluation must run with dropout disabled; remember the current mode so
    # it can be put back afterwards.
    was_training = model.training
    model.eval()
    try:
        # no_grad replaces the removed `Variable(..., volatile=True)` idiom and
        # keeps autograd from building a graph during inference.
        with torch.no_grad():
            for batch_idx, (data, target) in tqdm.tqdm(enumerate(val_loader),
                                                       total=len(val_loader),
                                                       ncols=80, leave=False):
                if use_cuda:
                    data, target = data.cuda(), target.cuda()
                score = model(data)

                imgs = data.cpu()
                # arg-max over the class dimension -> predicted label map
                lbl_pred = score.max(1)[1].cpu().numpy()
                lbl_true = target.cpu()
                for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
                    img, lt = val_loader.dataset.untransform(img, lt)
                    label_trues.append(lt)
                    label_preds.append(lp)
                    # Only the first nine samples are rendered.
                    if len(visualizations) < 9:
                        viz = fcn.utils.visualize_segmentation(
                            lbl_pred=lp, lbl_true=lt, img=img, n_class=n_class,
                            label_names=val_loader.dataset.class_names)
                        visualizations.append(viz)
    finally:
        model.train(was_training)

    metrics = torchfcn.utils.label_accuracy_score(
        label_trues, label_preds, n_class=n_class)
    metrics = np.array(metrics)
    metrics *= 100
    print('Accuracy: {0}, Accuracy Class: {1}, Mean IU: {2}, FWAV Accuracy: {3}'.format(*metrics))
    viz = fcn.utils.get_tile_image(visualizations)
    skimage.io.imsave('viz_evaluate.png', viz)

# Run the evaluation once over the validation loader.
validate()