In [20]:
import torch
import cv2
import numpy as np
import glob
import math
import os
import sys
from tqdm import tqdm
import torch
import torchvision.transforms as transforms
import torch.utils.data as torch_data
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

sys.path.append(os.path.normpath('../scripts/'))

#### Dataset

In [66]:
class GTiffDataset(torch_data.Dataset):
    def __init__(self, 
                 root_dir, split, tile_size = 256, stride = 256, debug = False, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.tile_size = tile_size
        self.stride = stride
        self.root_dir = root_dir
        self.transform = transform
        self.debug = debug
        self.images, self.masks = self.read_dir()
        
    def get_tiles(self, image, mask):
        i_tiles, m_tiles = [], []
        width = image.shape[1] - image.shape[1]%self.stride
        height = image.shape[0] - image.shape[0]%self.stride
             
        for i in range(0, height, self.stride):
            for j in range(0, width, self.stride):
                img_tile = image[
                    i:i+self.tile_size, 
                    j:j+self.tile_size
                ]
                mask_tile = image[
                    i:i+self.tile_size, 
                    j:j+self.tile_size
                ]
                i_tiles.append(img_tile)
                m_tiles.append(mask_tile)
                
                if self.debug:
                    # Debugging the tiles
                    cv2.imwrite("debug_" + str(i) + "_" + str(j) + "_img.png", img_tile)
                    cv2.imwrite("debug_" + str(i) + "_" + str(j) + "_mask.png", mask_tile)
        return i_tiles, m_tiles

    def read_dir(self):
        tiles = [[], []]
        images = sorted(glob.glob(self.root_dir + '/' + '101001000A4E4B00_4326_cropped.png'))
        masks = sorted(glob.glob(self.root_dir + '/' + '101001000A4E4B00_mask_4326.tif'))
        for idx, [i, m] in enumerate(zip(images, masks)):
            print('Reading item # {}/{}'.format(idx+1, len(images)))
            image = cv2.imread(i)
            mask = cv2.imread(m, 0)
            i_tiles, m_tiles = self.get_tiles(image, mask)
            for im, ma in zip(i_tiles, m_tiles):
                tiles[0].append(im)
                tiles[1].append(ma)
        return tiles
    
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = [self.images[idx], self.masks[idx]]
        return sample

In [71]:
gtiffdataset = GTiffDataset('../../data/pre-processed/dryvalleys/QB02', split='train',stride=128, debug=False)

Reading item # 1/1


In [72]:
train_dataloader = torch_data.DataLoader(gtiffdataset, num_workers=0)

##### LOSS

In [32]:
def focal_loss(output, target, gamma=2, alpha=0.5):
    n, c, h, w = logit.size()
    criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index,
                                    size_average=self.size_average)
    if self.cuda:
        criterion = criterion.cuda()

    logpt = -criterion(logit, target.long())
    pt = torch.exp(logpt)
    if alpha is not None:
        logpt *= alpha
    loss = -((1 - pt) ** gamma) * logpt

    if self.batch_average:
        loss /= n

    return loss

### Model

#### DeepLab V3 with ResNet backbone

In [5]:
from models import deeplab
model = deeplab.DeepLab(output_stride=16)
model.train()

DeepLab(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (downsample): Sequential(
          (0): Conv

In [35]:
optimizer = torch.optim.SGD(
    lr=0.001, 
    momentum=0.9,
    weight_decay=5e-4, 
    nesterov=False,
    params=model.get_1x_lr_params()
)

In [36]:
criterion = focal_loss

In [37]:
def train(model, optimizer, criterion, device, dataloader):
    model.train()
    train_loss = 0.0
    tbar = tqdm(dataloader)
    num_samples = len(dataloader)
    print(num_samples)
    for i, sample in enumerate(tbar):
        image, target = sample[0], sample[1]
        image, target = image.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

In [73]:
if __name__=="__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train(model, optimizer, criterion, device, train_dataloader)




  0%|                                                                                                                                        | 0/31476 [00:00<?, ?it/s]

31476


RuntimeError: cuda runtime error (30) : unknown error at ..\aten\src\THC\THCGeneral.cpp:87