In [0]:
# Environment diagnostics: print the OS release file and the NVIDIA GPU status.
!cat /etc/os-release
!nvidia-smi

In [0]:
from torch import nn
from torch.nn import functional as F
import torch
from torchvision import models
import torchvision
  
from PIL import Image

import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim

import os
import glob
import random
import math
import numbers




# Get dataset
use_gdrive = False
if (use_gdrive):
  from google.colab import drive
  drive.mount('/content/drive/', True)
  dataset_dir = "/content/drive/My Drive/MIIPS/Deepglobe/road-train-1.v2/"
  project_dir = "/content/drive/My Drive/MIIPS/Deepglobe/road-train-1.v2/"
else:
  project_dir = "/opt/colab/Road_extraction/"
  dataset_dir = project_dir + "dataset/"

  
  
os.chdir(project_dir)

model_save_path = "classifier.pth"


!ls {dataset_dir}

# Setup

In [0]:
# Global flags.
# When True, CustomDataset skips the random resized crop (images keep their
# original size) and a smaller batch size is selected when the loaders are built.
USE_FULL_SIZE_IMAGES = True


# Define network
class ConvBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and, optionally, a ReLU."""

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        # When False, forward() stops after batch normalisation.
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        if not self.with_nonlinearity:
            return normalized
        return self.relu(normalized)
      
class DecoderBlock(nn.Module):
    """Transposed-convolution upsampling (kernel 2, stride 2) followed by two ConvBlocks."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels

        # Layout follows the decoder block shown in Figure 2 of the reference paper.
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        refine_first = ConvBlock(out_channels, out_channels)
        refine_second = ConvBlock(out_channels, out_channels)
        self.block = nn.Sequential(upsample, refine_first, refine_second)

    def forward(self, x):
        return self.block(x)
      
class UNet(nn.Module):
    """U-Net style segmentation network built on a torchvision ResNet-34 encoder.

    The encoder stages feed skip connections into a stack of ``DecoderBlock``
    modules; the output has ``n_classes`` channels.
    """

    def __init__(self, n_classes=1, num_filters=32, pretrained=True, is_deconv=False, dropout_p=0.3):
        """
        :param n_classes: number of output channels. 1 selects a sigmoid output,
            more than 1 selects a log-softmax over the channel dimension.
        :param num_filters: accepted for interface compatibility; not read by this class.
        :param pretrained: forwarded to ``torchvision.models.resnet34``.
            False - no pre-trained network is used
            True  - encoder is pre-trained with resnet34
        :param is_deconv: accepted for interface compatibility; not read by this
            class (every DecoderBlock upsamples with a transposed convolution).
        :param dropout_p: probability for ``self.dropout``, which is constructed
            but not applied in ``forward``.
        """
        super().__init__()
        self.n_classes = n_classes
        self.pool = nn.MaxPool2d(2, 2)
        self.encoder = torchvision.models.resnet34(pretrained=pretrained)

        # Encoder stem: the ResNet stem layers followed by this module's own 2x2 pool.
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        self.center = DecoderBlock(512, 256) # self.center is the decoder right after the pool 2x2

        # Each in_channels value must equal the channel count of the tensors
        # concatenated for that stage in forward().
        self.dec5 = DecoderBlock(768, 256)
        self.dec4 = DecoderBlock(512, 256)
        self.dec3 = DecoderBlock(384, 64)
        self.dec2 = DecoderBlock(128, 128)
        self.dec1 = DecoderBlock(128, 32)
        self.dec0 = ConvBlock(32, 32)
        self.final = nn.Conv2d(32, n_classes, kernel_size=1)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Run the encoder, then decode with skip connections.

        NOTE(review): the skip concatenations assume the input height/width
        survive every down/up-sampling step without rounding (presumably a
        multiple of 64) - confirm against the images actually used.
        """
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        center = self.center(self.pool(conv5))

        dec5 = self.dec5(torch.cat([center, conv5], 1))

        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        logits = self.final(dec0)
        if self.n_classes > 1:
            return F.log_softmax(logits, dim=1)
        # torch.sigmoid is referenced through the module so this class does not
        # depend on a bare `sigmoid` name being imported elsewhere in the notebook.
        return torch.sigmoid(logits)


# Define loss/acc functions
from torch.nn import Module
from torch import sigmoid

# Matches standard definition of jaccard (and definition on codelab)
def jaccard(outputs, targets):
  """Hard Jaccard index (intersection over union) of two binary masks.

  Both tensors are rounded and cast to int before comparison, so
  probabilities are binarized by rounding.

  Args:
    outputs: predicted mask tensor.
    targets: ground-truth mask tensor, same shape as ``outputs``.

  Returns:
    float: |A & B| / |A | B|. When both masks are empty after rounding the
    union is zero; that case returns 1.0 (the masks agree perfectly) instead
    of raising ZeroDivisionError.
  """
  outputs = outputs.round().int()
  targets = targets.round().int()

  intersection = float((outputs & targets).sum())
  union = float((outputs | targets).sum())

  if union == 0:
    return 1.0
  return intersection/union

# From previous implementation
def soft_jaccard(outputs, targets, weight=1):
    """Differentiable Jaccard score: sigmoid of ``outputs`` against the ``targets == 1`` mask.

    ``weight`` is accepted but not used.
    """
    eps = 1e-15
    probabilities = sigmoid(outputs)
    positives = (targets == 1).float()

    overlap = (probabilities * positives).sum()
    total = probabilities.sum() + positives.sum()
    # total counts the overlap twice, so remove it once to get the union
    return overlap / (total - overlap + eps)

    
# Directly from the paper
def paper_jaccard(predictions, labels, eps=1e-7):
  """Per-pixel soft Jaccard, averaged over all elements.

  Computes mean_i( y_i * p_i / (y_i + p_i - y_i * p_i) ) over the flattened
  tensors.

  Args:
    predictions: tensor of predicted values.
    labels: tensor of target values, same number of elements.
    eps: lower bound applied to the denominator. Where label and prediction
      are both exactly 0 the raw ratio is 0/0 = NaN, which would turn the
      whole mean (and any loss built on it) into NaN; with the clamp such
      elements contribute 0.

  Returns:
    0-dim tensor holding the mean ratio.
  """
  predictions = predictions.reshape(-1)
  labels = labels.reshape(-1)

  overlap = labels*predictions
  denominator = (labels + predictions - overlap).clamp(min=eps)

  return (overlap/denominator).sum()/len(predictions)


# Directly from the paper
class PaperLoss:
  """Combined loss: alpha * BCE - (1 - alpha) * log(Jaccard).

  Args:
    alpha: weight of the binary cross-entropy term; the log-Jaccard term is
      weighted by ``1 - alpha``.
    eps: lower bound applied to the Jaccard value before the logarithm, so a
      Jaccard of exactly 0 yields a large finite loss instead of an error.
  """
  def __init__(self, alpha = 0.7, eps = 1e-7):
    self.alpha = alpha
    self.eps = eps
    self.bce = nn.BCELoss()

  def __call__(self, predictions, labels):
    """Return the loss as a tensor that can be back-propagated."""
    j = paper_jaccard(predictions, labels)
    # torch.log keeps the Jaccard term inside the autograd graph. math.log
    # would convert the tensor to a Python float, so that term would contribute
    # no gradient and training would effectively optimise BCE alone.
    log_j = torch.log(j.clamp(min=self.eps))
    loss = self.alpha*self.bce(predictions,labels) - (1 - self.alpha)*log_j
    return loss
    




# Dataset utils    
class SeedableRandomRotation(object):
    """Rotate the image by angle.
       Adapted version of RandomRotation (https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomRotation) that allows for a specified seed.

    Two instances built with the same ``seed`` draw the same sequence of
    angles, which lets an image and its mask receive identical rotations.

    Args:
        degrees: a non-negative number ``d`` (angles are drawn from ``(-d, d)``)
            or a sequence of two numbers ``(min, max)``.
        resample, expand, center: passed through to
            ``transforms.functional.rotate``.
        seed: seed for this instance's private random generator. ``None``
            (default) draws a fresh seed from the global ``random`` module when
            the instance is created.
    """

    def __init__(self, degrees, resample=False, expand=False, center=None, seed = None):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

        self.resample = resample
        self.expand = expand
        self.center = center

        # A `seed=random.random()` default in the signature is evaluated only
        # once, when the class is defined, so every default-seeded instance
        # would share the same seed. Draw it per instance instead.
        if seed is None:
            seed = random.random()
        self.random = random.Random()
        self.random.seed(seed)

    @staticmethod
    def get_params(degrees, rand = random):
        """Get parameters for ``rotate`` for a random rotation.

        Args:
            degrees (sequence): ``(min, max)`` range of the angle.
            rand: object providing ``uniform``; defaults to the ``random`` module.

        Returns:
            sequence: params to be passed to ``rotate`` for random rotation.
        """
        angle = rand.uniform(degrees[0], degrees[1])

        return angle

    def __call__(self, img):
        """
            img (PIL Image): Image to be rotated.

        Returns:
            PIL Image: Rotated image.
        """

        angle = self.get_params(self.degrees, rand = self.random)

        return transforms.functional.rotate(img, angle, self.resample, self.expand, self.center)


class SeedableRandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.
       Adapted version of RandomResizedCrop (https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomResizedCrop) that allows for a specified seed.

    Two instances built with the same ``seed`` draw the same sequence of crop
    boxes, which lets an image and its mask receive identical crops.

    Args:
        size: output size; an int ``s`` is treated as ``(s, s)``.
        scale: ``(min, max)`` fraction of the source area to crop.
        ratio: ``(min, max)`` aspect ratio (width / height) of the crop.
        interpolation: passed through to ``transforms.functional.resized_crop``.
        seed: seed for this instance's private random generator. ``None``
            (default) draws a fresh seed from the global ``random`` module when
            the instance is created.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR, seed = None):
        if isinstance(size, tuple):
            self.size = size
        else:
            self.size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            # Imported here because the notebook's import cell does not import
            # `warnings`; without it this branch raised NameError.
            import warnings
            warnings.warn("range should be of kind (min, max)")

        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

        # A `seed=random.random()` default in the signature is evaluated only
        # once, when the class is defined, so every default-seeded instance
        # would share the same seed. Draw it per instance instead.
        if seed is None:
            seed = random.random()
        self.random = random.Random()
        self.random.seed(seed)

    @staticmethod
    def get_params(img, scale, ratio, rand = random):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
            rand: object providing ``uniform`` and ``randint``; defaults to the
                ``random`` module.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop. All four values are ints.
        """
        width, height = img.size[0], img.size[1]
        area = width * height
        # Loop-invariant: aspect ratios are sampled uniformly in log space.
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for attempt in range(10):
            target_area = rand.uniform(*scale) * area
            aspect_ratio = math.exp(rand.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= width and h <= height:
                i = rand.randint(0, height - h)
                j = rand.randint(0, width - w)
                return i, j, h, w

        # Fallback to central crop. Sizes are rounded to int so the returned
        # box holds pixel coordinates rather than floats.
        in_ratio = width / height
        if (in_ratio < min(ratio)):
            w = width
            h = int(round(w / min(ratio)))
        elif (in_ratio > max(ratio)):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio, rand = self.random)
        return transforms.functional.resized_crop(img, i, j, h, w, self.size, self.interpolation)

    
    
class CustomDataset(Dataset):
    """Pairs of (image, mask) tensors loaded from parallel lists of file paths.

    Args:
        image_paths: list of input image paths.
        target_paths: list of mask paths; ``target_paths[k]`` belongs to
            ``image_paths[k]``.
        augment: when True (default) each sample gets colour jitter (image
            only) plus a random rotation - and, unless
            ``USE_FULL_SIZE_IMAGES`` is set, a random resized crop to 448 -
            applied identically to image and mask. When False both are only
            converted to tensors at their original size, which is what an
            evaluation set normally wants.
    """

    def __init__(self, image_paths, target_paths, augment=True):

      assert len(image_paths) == len(target_paths), "image_paths and target_paths must have the same length"

      self.image_paths = image_paths
      self.target_paths = target_paths
      self.augment = augment

    @staticmethod
    def _geometric_steps(seed):
        """Build the seeded geometric transforms shared by image and mask."""
        steps = []
        if not USE_FULL_SIZE_IMAGES:
          steps.append(SeedableRandomResizedCrop(size = 448, scale = (0.6, 1.4), seed = seed))
        steps.append(SeedableRandomRotation(degrees = 30, seed = seed))
        return steps

    def __getitem__(self, index):
        """Return ``(image_tensor, mask_tensor)`` for sample ``index``."""
        image = Image.open(self.image_paths[index])
        # Mode '1' makes the mask strictly two-valued before it becomes a tensor.
        mask = Image.open(self.target_paths[index]).convert('1')

        if not self.augment:
          to_tensor = transforms.ToTensor()
          return to_tensor(image), to_tensor(mask)

        randomSeed = random.random()  # Ensure both the image and mask have the same seed so that the same transformations are applied to both

        # Colour jitter applies to the image only; the geometric steps are built
        # twice from the same seed so both pipelines draw identical parameters.
        image_steps = ([transforms.ColorJitter(brightness=0.2,contrast=0.2,hue=0.02)]
                       + self._geometric_steps(randomSeed)
                       + [transforms.ToTensor()])
        mask_steps = self._geometric_steps(randomSeed) + [transforms.ToTensor()]

        t_image = transforms.Compose(image_steps)(image)
        t_mask = transforms.Compose(mask_steps)(mask)

        return t_image, t_mask

    def __len__(self):
        return len(self.image_paths)

      
      
# Load dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sort because glob returns paths in arbitrary order. Images and masks are
# paired purely by position in these two sorted lists.
folder_data = sorted(glob.glob(dataset_dir + "train/*.jpg"))
folder_mask = sorted(glob.glob(dataset_dir + "train/*.png"))

if len(folder_data) != len(folder_mask):
  # Positional pairing is meaningless if the two lists differ in length.
  raise RuntimeError("Found %d images but %d masks in %strain/" % (len(folder_data), len(folder_mask), dataset_dir))

len_data = len(folder_data)
train_size = 0.7   # fraction of the samples used for training
split_index = int(len_data*train_size)

train_image_paths = folder_data[:split_index]
test_image_paths = folder_data[split_index:]

train_mask_paths = folder_mask[:split_index]
test_mask_paths = folder_mask[split_index:]

train_dataset = CustomDataset(train_image_paths, train_mask_paths)
test_dataset = CustomDataset(test_image_paths, test_mask_paths)

print("Size of train dataset :", len(train_dataset))
print("Size of test dataset  :", len(test_dataset))
print("Total number of images:", len(train_dataset) + len(test_dataset))

# Full size images take up a lot of memory; decrease batch size so that it fits
batch_size = 2 if USE_FULL_SIZE_IMAGES else 8

# Shuffle the training set so batches are not visited in the same sorted file
# order every epoch; the test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)


# Load model
model = UNet(n_classes=1).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-04, weight_decay = 1e-04)


criterion = PaperLoss()

# Training

In [0]:
import datetime
print("Started at:", datetime.datetime.now())

print_every = 10   # mini-batches between two log lines
save_every = 100   # mini-batches between two checkpoints

for epoch in range(35):  # loop over the dataset multiple times
    model.train()
    running_loss = 0.0

    for i, (image, mask) in enumerate(train_loader):
        # Move the batch to the device the model lives on (also works on CPU,
        # unlike torch.cuda.current_device()).
        inputs, mask = image.to(device), mask.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, mask)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % print_every == print_every - 1:
            # Average over exactly the batches accumulated since the last print;
            # the accuracy shown is for the current batch only.
            print('[%d, %5d] loss: %.3f, accuracy: %.3f'  % (epoch + 1, i + 1, running_loss/print_every, jaccard (outputs, mask)))
            running_loss = 0.0

        if i % save_every == save_every - 1:
            torch.save(model.state_dict(), model_save_path)

print('Saving')
torch.save(model.state_dict(), model_save_path)

print('Finished Training')
print("Finished at:", datetime.datetime.now())

# Validation

In [0]:
def get_tta_output(inputs, model):
  '''
  Makes predictions on 4 rotations of the inputs and averages them

  Each rotated batch (0, 90, 180 and 270 degrees) is passed through the model,
  the prediction is rotated back, and the four results are averaged.

  Args:
    inputs: batch tensor laid out as (batch, channels, height, width).
    model: callable mapping such a batch to a prediction with the same layout.

  Returns:
    Tensor: mean of the four un-rotated predictions.
  '''
  rots = 4
  spatial_dims = (2, 3)  # height and width axes

  # Predict non-rotated image; this starts the running sum
  total = model(inputs)

  # torch.rot90 rotates on the tensor's own device in exact quarter turns, so
  # there is no round trip through PIL (CPU copy plus 8-bit quantisation) and
  # no dependence on a CUDA device being present.
  for rot in range(1, rots):
    outputs_rot = model(torch.rot90(inputs, rot, spatial_dims))

    # Unrotate and add to the running sum (out of place, so the tensor
    # returned by the model is never modified)
    total = total + torch.rot90(outputs_rot, -rot, spatial_dims)

  # Taking an average, so divide by the number of rotations
  return total / rots

In [0]:
debug_print_every = 1
USE_TTA = False

import datetime  # also imported by the training cell; repeated so this cell runs on its own
import matplotlib.pyplot as plt
import numpy as np

acc_fn = jaccard


# Every weight is overwritten by the checkpoint, so skip fetching the
# pretrained encoder weights.
model = UNet(pretrained=False).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_save_path, map_location=device))

print("Done loading model.")

model.eval()

running_loss = 0.0
running_acc = 0.0
total_acc = 0.0
num_batches = 0

print( datetime.datetime.now(),"Starting validation",)
with torch.no_grad():
  for i, (image, mask) in enumerate(test_loader):
    inputs, mask = image.to(device), mask.to(device)

    if USE_TTA:
      outputs = get_tta_output(inputs, model)
    else:
      outputs = model(inputs)

    loss = criterion(outputs, mask)
    # Evaluate the metric once per batch and reuse the value.
    batch_acc = acc_fn(outputs, mask)

    # print statistics
    running_loss += loss.item()
    running_acc += batch_acc
    total_acc += batch_acc
    num_batches = i + 1

    if i % debug_print_every == (debug_print_every-1):    # print every debug_print_every mini-batches
      print('[%5d] loss: %.3f,\taccuracy: %.3f,\tavg_acc: %.3f'  % (i + 1, running_loss/debug_print_every, running_acc/debug_print_every, total_acc/(i+1)))

      running_loss = 0.0
      running_acc = 0.0

# max(..., 1) keeps the summary line working when the loader is empty.
print(datetime.datetime.now(), 'Finished validation with average accuracy:', total_acc/max(num_batches, 1))

In [0]:
# Display validation results
max_iters = 200   # maximum number of batches to visualise

import matplotlib.pyplot as plt
import numpy as np


# Every weight is overwritten by the checkpoint, so skip fetching the
# pretrained encoder weights.
model = UNet(pretrained=False).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_save_path, map_location=device))
model.eval()
print("Done loading.")

cmap = "viridis" # default; "hot" and "nipy_spectral" are alternatives

print("\tInput\tMask\tprediction\tbinarized") # print headers
with torch.no_grad():
  for i, (image, mask) in enumerate(test_loader):

    if i >= max_iters:
      break

    inputs, mask = image.to(device), mask.to(device)
    outputs = model(inputs)
    loss = criterion(outputs, mask)

    print('[%d, %5d] loss: %.3f, accuracy: %.3f'  % (-1, i + 1, loss.item(), jaccard (outputs, mask)))

    # One row of four panels, in the order of the header printed above. Each
    # panel shows the first sample of the batch and its first channel only.
    panels = [inputs, mask, outputs, outputs.round()]

    fig = plt.figure(figsize=(8, 8))
    for column, tensor in enumerate(panels, start=1):
      fig.add_subplot(1, len(panels), column)
      plt.imshow(tensor.cpu()[0][0], cmap=cmap)

    plt.show()
    # Release the figure so a long run does not accumulate open figures.
    plt.close(fig)

# Predict on validation images (for submission)

In [0]:
# Switches for the submission run.
SAVE_IMAGES = True
debug_show_images = False
USE_TTA = False
save_path = project_dir + "out/"

# Offset added to the predictions before rounding; 0 means thresholding at 0.5.
if not USE_TTA:
  threshold_bias = 0
else:
  # 0.25 was tried for the TTA case and is currently disabled.
  threshold_bias = 0


# Create the output directory, tolerating a previous run having created it.
try:
  os.mkdir(save_path)
except FileExistsError:
  print("Skipped creating " + save_path + "; already exists")

def binarize_and_save_mask():
  """Placeholder with no behavior; the saving logic currently lives inline in the prediction loop."""
  pass


import matplotlib.pyplot as plt
import numpy as np

from PIL import Image

import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
import glob

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_paths = glob.glob(dataset_dir + "valid/*.jpg")

# Every weight is overwritten by the checkpoint, so skip fetching the
# pretrained encoder weights.
model = UNet(pretrained=False).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_save_path, map_location=device))


print("Done loading.")

model.eval()

with torch.no_grad():
  for i,path in enumerate(image_paths):
    # Strip the trailing '_sat.jpg' to get the sample id.
    # NOTE(review): assumes every file matched by valid/*.jpg ends in '_sat.jpg' - confirm.
    fileid = os.path.split(path)[1][:-len('_sat.jpg')]
    print(fileid)

    image = Image.open(path)

    # Add a batch dimension and move to the model's device (works on CPU too).
    inputs = transforms.ToTensor()(image).unsqueeze(0).to(device)

    if USE_TTA:
      outputs = get_tta_output(inputs, model)
    else:
      outputs = model(inputs)

    # Binarize once: shifting by threshold_bias moves the effective cut-off
    # away from 0.5. Out of place, so the raw prediction stays available.
    binarized = (outputs + threshold_bias).round()

    if debug_show_images:
      cmap = "viridis" # default

      # Panels: input (first channel), raw prediction, binarized prediction.
      panels = [inputs, outputs, binarized]

      fig = plt.figure(figsize=(8, 8))
      for column, tensor in enumerate(panels, start=1):
        fig.add_subplot(1, len(panels), column)
        plt.imshow(tensor.cpu()[0][0], cmap=cmap)

      plt.show()
      plt.close(fig)

    if SAVE_IMAGES:
      mask_img = binarized[0]                 # select 1st (and only) image from batch
      mask_img = mask_img.expand(3,-1,-1)     # turn the grayscale image into 3 color channels
      # save_image itself maps [0, 1] to [0, 255] and clamps, so the 0/1 mask
      # is passed as-is; pre-multiplying by 255 only worked because of the clamp.
      out_img = mask_img.cpu()

      out_path = save_path + fileid + "_mask.png"

      torchvision.utils.save_image(out_img, out_path)

    print("{0:.2f} percent complete".format((i+1)/len(image_paths)*100))

      