## **4. Import the necessary packages:**
numpy, io, glob, tqdm_notebook, confusion_matrix, random, itertools, matplotlib.pyplot, torch, torch.nn,  torch.nn.functional, torch.utils.data, torch.optim, torch.optim.lr_scheduler, torch.nn.init


In [1]:
# Importing 
from skimage import io
from glob import glob
from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import confusion_matrix
import numpy as np
import random
import itertools
import matplotlib.pyplot as plt
import imagecodecs
# %matplotlib inline
# Torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.optim.lr_scheduler
import torch.nn.init
from torch.autograd import Variable
import os
from IPython.display import clear_output
import tifffile as tiff  

 ## **5. Initialization:**

In [2]:
# Parameters
IN_CHANNELS =  3                          # Number of input channels (e.g. RGB)
MAIN_FOLDER  =    "dataset/"   # Replace with your "/path/to/the/Images/folder/"
BATCH_SIZE =   10            # Number of samples in a mini-batch, example 10
LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"] # Label names
N_CLASSES = len(LABELS)                   # Number of classes
weights = torch.ones(N_CLASSES)           # Weights for class balancing
DATA_FOLDER = MAIN_FOLDER + 'Images/Image_{}.tif'
LABELS_FOLDER = MAIN_FOLDER + 'Labels/Label_{}.tif'

 ## **6. Functions you may need:**

In [3]:
# Let's define the standard ISPRS color palette
palette = {0 : (255, 255, 255), # Impervious surfaces (white)
           1 : (0, 0, 255),     # Buildings (blue)
           2 : (0, 255, 255),   # Low vegetation (cyan)
           3 : (0, 255, 0),     # Trees (green)
           4 : (255, 255, 0),   # Cars (yellow)
           5 : (255, 0, 0),     # Clutter (red)
           6 : (0, 0, 0)}       # Undefined (black)
invert_palette = {v: k for k, v in palette.items()}
def convert_from_color(arr_3d, palette=invert_palette):
    """ RGB-color encoding to grayscale labels """ '(From 0 to 6)'
    arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)
    for c, i in palette.items():
        m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
        arr_2d[m] = i
    return arr_2d
class Load_dataset(torch.utils.data.Dataset):
    def __init__(self, ids):
        super(Load_dataset, self).__init__()
        # List of files
        self.data_files = [DATA_FOLDER.format(id) for id in ids]
        self.label_files = [LABELS_FOLDER.format(id) for id in ids]
        # Sanity check : raise an error if some files do not exist
        for f in self.data_files + self.label_files:
            if not os.path.isfile(f):
                raise KeyError('{} is not a file !'.format(f))
    def __len__(self):
        return len(self.data_files) # the length of the used data
    
    def __getitem__(self, idx):
#         Pre-processing steps
        #     # Data is normalized in [0, 1]
        self.data = 1/255 * np.asarray(io.imread(self.data_files[idx]).transpose((2,0,1)), dtype='float32')
        self.label = np.asarray(convert_from_color(io.imread(self.label_files[idx])), dtype='int64')
        data_p, label_p = self.data,  self.label
        # Return the torch.Tensor values
        return (torch.from_numpy(data_p),
                torch.from_numpy(label_p))
def CrossEntropy2d(input, target, weight=None, size_average=True):
    """ 2D version of the cross entropy loss """
    dim = input.dim()
    if dim == 2:
        return F.cross_entropy(input, target, weight, size_average)
    elif dim == 4:
        output = input.view(input.size(0), input.size(1), -1)
        output = torch.transpose(output, 1, 2).contiguous()
        output = output.view(-1, output.size(2))
        target = target.view(-1)
        return F.cross_entropy(output, target, weight, size_average)
    else:
        raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim))
        
def metrics(predictions, gts, label_values=LABELS):
    cm = confusion_matrix(
        gts,
        predictions,
        range(len(label_values)))
    print("Confusion matrix :")
    print(cm)
    print("---")
    # Compute global accuracy
    total = sum(sum(cm))
    accuracy = sum([cm[x][x] for x in range(len(cm))])
    accuracy *= 100 / float(total)
    print("{} pixels processed".format(total))
    print("Total accuracy : {}%".format(accuracy))
    return accuracy

# # **7. Selecting training and testing data**

In [4]:
train_ids =list(range(0, 2000))
test_ids =  list(range(2000,2400))
train_data = Load_dataset(train_ids)
test_data = Load_dataset(test_ids)

## **8. Implement the Unet model**


In [30]:
def conv_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3,padding='same'),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=3,padding='same'),
        nn.ReLU()
    )
def crop_tensor(input,target):
    diff =input.size()[2] - target.size()[2] # difference in width
    diff = diff // 2
    return(input[:,:,diff:input.size()[2]-diff,diff:input.size()[2]-diff])
class UNet(nn.Module):
    def __init__(self,IN_CHANNELS=3,OUT_CHANNELS=1):
        super().__init__() 
        self.encode_conv1 = conv_block(IN_CHANNELS, 64)
        self.encode_conv2 = conv_block(64, 128)
        self.encode_conv3 = conv_block(128, 256)
       # self.encode_conv4 = conv_block(256, 512)
       # self.encode_conv5 = conv_block(512, 1024)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

       # self.conv_transpose1 = nn.ConvTranspose2d(in_channels=1024,out_channels=512,kernel_size=2,stride=2)
       # self.decode_conv1 = conv_block(1024, 512)
       # self.conv_transpose2 = nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=2,stride=2)
       # self.decode_conv2 = conv_block(512, 256)
        self.conv_transpose3 = nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=2,stride=2)
        self.decode_conv3 = conv_block(256, 128)
        self.conv_transpose4 = nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=2,stride=2)
        self.decode_conv4 = conv_block(128, 64)
        self.out = nn.Conv2d(in_channels=64,out_channels=OUT_CHANNELS,kernel_size=1)
     
    def forward(self, x):
        # Encoder
        x1 = self.encode_conv1(x)
        x2 = self.maxpool(x1)
        x3 = self.encode_conv2(x2)
        x4 = self.maxpool(x3)
        x5 = self.encode_conv3(x4)
        # x6 = self.maxpool(x5)
        # x7 = self.encode_conv4(x6)
       # x8 = self.maxpool(x7)
       # x9 = self.encode_conv5(x8)
       # print(" out : ",out.size())
        print(" x1 : ",x1.size())
        print(" x2 : ",x2.size())
        print(" x3 : ",x3.size())
        print(" x4 : ",x4.size())
        print(" x5 : ",x5.size())
        # print(" x6 : ",x6.size())
        # print(" x7 : ",x7.size())

        
        # Decoder x1,x3,x5,x7 will be used as input 
       # x10 = self.conv_transpose1(x9)
       # x7_cropped = crop_tensor(x7,x10)
       # x7_10 = torch.cat([x10,x7_cropped],1)
       # x11 = self.decode_conv1(x7_10)
       # x12 = self.conv_transpose2(x11) 
        # print(" x12 : ",x12.size())
        # x5_cropped = crop_tensor(x5,x12)
        # print(" x5 : ",x5.size())
        # print(" x5_cropped : ",x5_cropped.size())
        # x5_12 = torch.cat([x12,x5_cropped],dim=1)
        # x13 = self.decode_conv2(x5_12)
        x14 = self.conv_transpose3(x5) #x13
        x3_cropped = crop_tensor(x3,x14)
        x3_14 = torch.cat([x14,x3_cropped],dim=1)
        x15 = self.decode_conv3(x3_14)
        x16 = self.conv_transpose4(x15)
        x1_cropped = crop_tensor(x1,x16)
        x1_16 = torch.cat([x16,x1_cropped],dim=1)
        x17 = self.decode_conv4(x1_16)
        out = self.out(x17)

        # print(" x8 : ",x8.size())
        # print(" x9 : ",x9.size())
        # print(" x10 : ",x10.size())
        # print(" x11 : ",x11.size())
        # print(" x12 : ",x12.size())
        # print(" x13 : ",x13.size())
        print(" x14 : ",x14.size())
        print(" x15 : ",x15.size())
        print(" x16 : ",x16.size())
        print(" x17 : ",x17.size())
        
        return out.squeeze(0)

In [31]:
type(train_data)

__main__.Load_dataset

In [32]:
# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

In [33]:
dataiter = iter(train_data)
images, labels = dataiter.__next__()

In [34]:
model = UNet()

In [35]:
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [40]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(train_data):
        # Every data instance is an input + label pair
        inputs, labels = data
        print("inputs : ",inputs.size())
        inputs = inputs.unsqueeze(0)
        print("inputs unsqueezed : ",inputs.size())
        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs).squeeze()
        print("outputs : ",outputs.size())
        print("labels : ",labels.size())
        # Compute the loss and its gradients
        loss = CrossEntropy2d(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(train_data) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

In [41]:
train_one_epoch(0, None)

inputs :  torch.Size([3, 300, 300])
inputs unsqueezed :  torch.Size([1, 3, 300, 300])
 x1 :  torch.Size([1, 64, 300, 300])
 x2 :  torch.Size([1, 64, 150, 150])
 x3 :  torch.Size([1, 128, 150, 150])
 x4 :  torch.Size([1, 128, 75, 75])
 x5 :  torch.Size([1, 256, 75, 75])
 x14 :  torch.Size([1, 128, 150, 150])
 x15 :  torch.Size([1, 128, 150, 150])
 x16 :  torch.Size([1, 64, 300, 300])
 x17 :  torch.Size([1, 64, 300, 300])
outputs :  torch.Size([300, 300])
labels :  torch.Size([300, 300])




RuntimeError: Expected floating point type for target with class probabilities, got Long