# Setup: Importing Necessary Libraries
#### Assumes the imported libraries (torch, torchvision, torchmetrics, numpy, rasterio, patchify, opencv-python, scikit-image, pandas, Pillow and matplotlib) are installed in the environment

In [78]:
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageEnhance
import patchify
from patchify import unpatchify
import math
import cv2
import torch
from skimage.io import imread        # needs installation
import numpy as np
import random
import os
import copy
from glob import glob
import matplotlib.pyplot as plt
import rasterio
from rasterio.plot import show
from rasterio.merge import merge
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision
from torch import nn
from torchvision.io.image import read_image
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision.models.segmentation import deeplabv3_resnet101
import torchmetrics
from torchvision import transforms

In [10]:
# Chip geometry for the training image: square patches of `ps` pixels,
# cut with a sliding window that advances `s` pixels at a time.
ps = 64        # patch edge length in pixels
s = ps // 4    # window step -> neighbouring chips overlap by 75%

## Data: Import Training Image

In [79]:
img_path = "Image_Training.tif"  # Path to Image
with rasterio.open(img_path, 'r') as ds:
    arr = ds.read()  # rasterio returns band-first: (bands, rows, cols)

# Move bands last, (rows, cols, bands), which is the layout patchify expects.
arr2_image = np.transpose(arr, (1, 2, 0))
if arr2_image.shape[-1] != 3:
    raise ValueError(f'expected a 3-band image, got {arr2_image.shape[-1]} bands')

# patchify -> (n_rows, n_cols, 1, ps, ps, 3)
patch2 = patchify.patchify(arr2_image, (ps, ps, 3), step=s)
# Flatten the chip grid to (N, ps, ps, 3), then move bands to the front: (N, 3, ps, ps).
# A plain reshape straight to (N, 3, ps, ps) would interleave pixels and bands,
# so an explicit transpose is required.
patch_X = np.ascontiguousarray(patch2.reshape(-1, ps, ps, 3).transpose(0, 3, 1, 2))
print(f'Image shape: {patch_X.shape}')

## Data: Import Labels

In [82]:
target_path = "Label_Training.tif"  # Path to Label
n_classes = 8  # one-hot depth = number of classes, including background

with rasterio.open(target_path, 'r') as ds:
    arr = ds.read().squeeze().astype('uint8')  # (rows, cols) class ids

if arr.shape != arr2_image.shape[:2]:
    raise ValueError(f'label raster {arr.shape} does not match image {arr2_image.shape[:2]}')

labels = list(np.unique(arr))
unexpected = [v for v in labels if v >= n_classes]
if unexpected:
    # Such pixels get an all-zero one-hot vector (argmax later reads them as class 0).
    print(f'Warning: label values {unexpected} are outside 0..{n_classes - 1}')

# One-hot encode band-last (rows, cols, n_classes): channel i is 1 where class id == i.
# The image size is inferred from the raster instead of being hard-coded.
arr2_label = (arr[..., None] == np.arange(n_classes)).astype(np.uint8)

# patchify -> (n_rows, n_cols, 1, ps, ps, n_classes); chips must line up with patch_X,
# so the same patch size and step are used.
patch2 = patchify.patchify(arr2_label, (ps, ps, n_classes), step=s)
# (N, ps, ps, C) -> channel-first (N, C, ps, ps) via transpose (a reshape would scramble).
patch_Y = np.ascontiguousarray(patch2.reshape(-1, ps, ps, n_classes).transpose(0, 3, 1, 2))
print(patch_Y.shape)

## Data: Array partitioning for train test sample chips
###### A custom function is used to partition the chip arrays into training and validation samples in the PyTorch environment

In [13]:
def trainTesitsplit(x, y, ratio=0.1):  # follows systematic sampling
    """Split chip arrays into training and validation sets by systematic sampling.

    Every ``step``-th chip, starting at index 0, goes to validation. The start is
    not randomised. All remaining chips go to training, in their original order.

    NOTE(review): if the chips were cut with an overlapping stride, validation chips
    share pixels with neighbouring training chips, so they are not spatially
    independent.

    Args:
        x: numpy array of image chips; the first axis indexes chips.
        y: numpy array of label chips with the same number of chips as ``x``.
        ratio: approximate fraction of chips for validation, in (0, 1].

    Returns:
        ``(train_x, train_y, valid_x, valid_y)`` as float64 torch tensors.

    Raises:
        ValueError: if ``x`` and ``y`` differ in length or ``ratio`` is outside (0, 1].
    """
    if len(x) != len(y):
        raise ValueError(f'x and y must have the same number of chips ({len(x)} vs {len(y)})')
    if not 0 < ratio <= 1:
        raise ValueError(f'ratio must be in (0, 1], got {ratio}')

    x = torch.from_numpy(x.astype(float))
    y = torch.from_numpy(y.astype(float))

    N = x.shape[0]
    # At least one validation chip, so small inputs do not divide by zero.
    n = max(1, int(ratio * N))
    step = max(1, N // n)

    v_ind = torch.arange(0, N, step, dtype=torch.long)
    # Complement of the validation indices, kept in ascending (deterministic) order.
    is_train = torch.ones(N, dtype=torch.bool)
    is_train[v_ind] = False
    t_ind = torch.arange(N, dtype=torch.long)[is_train]

    return x[t_ind], y[t_ind], x[v_ind], y[v_ind]

# Data: Prepare pytorch data loader
For smooth training, data flow and memory usage, PyTorch provides the `Dataset` and `DataLoader` classes, which are designed to be customized to our data structure and training procedure. The example below is a custom dataset class adapted to serve image and label chips that are already torch tensors.

In [14]:
class CustomImageDataset(Dataset):
    """Map-style dataset pairing image chips with their label chips.

    Args:
        arr_: indexable collection of image chips (e.g. a tensor of shape (N, bands, ps, ps)).
        lbl_: indexable collection of label chips, same length as ``arr_``.
        tensify: if True, each item is converted to a float32 ``torch.Tensor`` on access,
            with its layout unchanged; otherwise items are returned exactly as stored.

    Raises:
        ValueError: if ``arr_`` and ``lbl_`` have different lengths.
    """

    def __init__(self, arr_, lbl_, tensify=False):
        if len(arr_) != len(lbl_):
            raise ValueError(f'image/label count mismatch: {len(arr_)} vs {len(lbl_)}')
        self.tensify = tensify
        self.arr_ = arr_
        self.lbl_ = lbl_
        # Kept under the original attribute name. torchvision's ToTensor accepts only
        # PIL images / ndarrays and treats arrays as HWC, so it raised on tensors and
        # would transpose channel-first arrays; a plain dtype conversion is used instead.
        self.tarsform = self._to_float_tensor

    @staticmethod
    def _to_float_tensor(item):
        """Convert an array-like chip to a float32 tensor without changing its layout."""
        return torch.as_tensor(item).float()

    def __len__(self):
        return len(self.arr_)

    def __getitem__(self, idx):
        image = self.arr_[idx]
        label = self.lbl_[idx]
        if self.tensify:
            return self.tarsform(image), self.tarsform(label)
        return image, label

# Model: Prepare DeepLabV3-ResNet Model
###### see https://pytorch.org/vision/0.11/models.html for further notes

In [15]:
def define_model(bakbone_depth=50, all_weight=None, num_class=8, phase='train'):
    """Build a torchvision DeepLabV3 segmentation model with a ResNet backbone.

    Args:
        bakbone_depth: ResNet backbone depth, 50 or 101.
        all_weight: accepted for interface compatibility; currently not used.
        num_class: number of output classes (channels of ``model(x)['out']``).
        phase: ``'train'`` puts the model in training mode; any other value puts it
            in evaluation mode.

    Returns:
        The DeepLabV3 model with the auxiliary head disabled.
        NOTE(review): other weight arguments use torchvision defaults (typically an
        ImageNet-pretrained backbone) — confirm for the installed torchvision version.

    Raises:
        ValueError: if ``bakbone_depth`` is neither 50 nor 101.
    """
    # Both builders must be imported in the imports cell; depth 101 needs
    # deeplabv3_resnet101 in addition to deeplabv3_resnet50.
    if bakbone_depth == 50:
        builder = deeplabv3_resnet50
    elif bakbone_depth == 101:
        builder = deeplabv3_resnet101
    else:
        raise ValueError('Depth for resnet not known')

    model = builder(num_classes=num_class, aux_loss=False)
    # train(False) is equivalent to eval().
    model.train(phase == 'train')
    return model

##### Specify model hyperparameters

In [16]:

# Model / training hyperparameters
bakbone_depth = 101     # ResNet backbone depth passed to define_model (50 or 101)
all_weight = None       # forwarded to define_model, which currently does not use it
num_class = 8           # number of classes, including background
phase = 'train'
bach_size = 20          # chips per batch
epochs = 50
lr_rate = 0.0001
SEED = 42               # seeds python/numpy/torch RNGs (weight init, DataLoader shuffling)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before building the model so initialisation and shuffling are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

# define and load model using parameters
model = define_model(bakbone_depth=bakbone_depth,
                     all_weight=all_weight,
                     num_class=num_class,
                     phase=phase)

model = model.to(device)  # put the model either on CPU or GPU depending on availability

##### The PyTorch training process differs from Keras, which is a high-level API built on top of TensorFlow: in PyTorch we define the training loop explicitly, or use a wrapper such as PyTorch Lightning as an alternative, which can be found here https://lightning.ai/pages/open-source/

In [30]:
optimizer = torch.optim.SGD(model.parameters(), lr_rate)  # plain SGD, no momentum
activation = torch.nn.Softmax(dim=1).to(device)  # probabilities over the class channel
# Class-weighted pixel accuracy. `mdmc_reduce` belongs to the legacy (pre-`task`)
# torchmetrics API and newer releases reject it, so only task-API arguments are passed.
acc_fn = torchmetrics.Accuracy(task='multiclass', num_classes=num_class, average='weighted').to(device)
ls_fn = nn.CrossEntropyLoss().to(device)  # on logits; one-hot targets act as class probabilities

# define data set and loader for training
x_train, y_train, x_valid, y_valid = trainTesitsplit(patch_X, patch_Y, ratio=0.1)  # sample array
train_dataset = CustomImageDataset(arr_=x_train, lbl_=y_train)  # put in custom dataset
valid_dataset = CustomImageDataset(arr_=x_valid, lbl_=y_valid)  # put in custom dataset
train_loader = DataLoader(dataset=train_dataset, batch_size=bach_size, drop_last=True, shuffle=True)
# Validation uses a fixed order and keeps the last partial batch, so every validation chip
# is scored (with drop_last=True, fewer than bach_size chips would give zero batches).
valid_loader = DataLoader(dataset=valid_dataset, batch_size=bach_size, drop_last=False, shuffle=False)

##### Categorical cross-entropy on softmax probabilities, kept as an alternative loss function. The loss actually used for training is torch's `CrossEntropyLoss` applied to the logits (it accepts one-hot / class-probability targets); softmax is applied only for accuracy and during inference.

In [20]:
def loss_fn(ref, pred, eps=1e-7):
    """Categorical cross-entropy between target and predicted class probabilities.

    Args:
        ref: target probabilities (e.g. one-hot labels), shape (N, C, ...) with classes on dim 1.
        pred: predicted probabilities (e.g. softmax over dim 1), same shape as ``ref``.
        eps: lower clamp applied to ``pred`` before the log, so zero probabilities
            give a finite loss instead of ``nan``.

    Returns:
        Scalar tensor: cross-entropy summed over the class axis only, then averaged over
        batch and spatial positions. This matches ``nn.CrossEntropyLoss`` with probability
        targets applied to the corresponding logits.
    """
    per_pixel = -torch.sum(ref * torch.log(pred.clamp(min=eps)), dim=1)
    return per_pixel.mean()

### Training 

In [34]:
save_dir = 'weight'
os.makedirs(save_dir, exist_ok=True)  # no-op if the folder already exists

VL = {'loss': [], 'acc': []}  # validation history, one entry per epoch
TA = {'loss': [], 'acc': []}  # training history, one entry per epoch

best_weight = None         # state_dict of the epoch with the lowest validation loss
best_valid = float('inf')  # lowest validation loss seen so far

for epoch in range(1, epochs + 1):
    # ---- training pass ----
    model.train()  # BatchNorm uses batch statistics and updates its running stats
    epoch_loss = 0.0
    epoch_acc = 0.0
    control = 0  # number of batches seen this epoch
    for j, (X, Y) in enumerate(train_loader):
        optimizer.zero_grad()
        X = X.float().to(device)
        Y = Y.float().to(device)  # one-hot targets, used as class probabilities by ls_fn
        logits = model(X)
        softs = activation(logits['out'])
        loss = ls_fn(logits['out'], Y)  # CrossEntropyLoss expects raw logits
        acc = acc_fn(torch.argmax(softs, dim=1), torch.argmax(Y, dim=1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        control += 1
        print(f'step: {j}, step_loss: {loss.item()}, step_acc: {acc.item()}', end='\r', flush=True)

    if control == 0:
        raise RuntimeError('train_loader yielded no batches; check bach_size / drop_last')
    # Average once per epoch, after all batches have been accumulated.
    epoch_loss = epoch_loss / control
    TA['loss'].append(epoch_loss)
    TA['acc'].append(epoch_acc / control)
    print(f'+++ {epoch}: step: {j}, train_loss: {epoch_loss}, train_acc:{epoch_acc/control} +++', end='\r', flush=True)

    # ---- validation pass ----
    model.eval()  # freeze BatchNorm statistics while evaluating
    val_loss = 0.0
    val_acc = 0.0
    control = 0
    with torch.no_grad():
        for image, target in valid_loader:
            image, target = image.float().to(device), target.float().to(device)
            output = model(image)
            softs = activation(output['out'])
            vloss = ls_fn(output['out'], target)
            vacc = acc_fn(torch.argmax(softs, dim=1), torch.argmax(target, dim=1))
            val_loss += vloss.item()
            val_acc += vacc.item()
            control += 1

    if control == 0:
        raise RuntimeError('valid_loader yielded no batches; check the validation split size')
    val_loss = val_loss / control
    val_acc = val_acc / control
    VL['loss'].append(val_loss)
    VL['acc'].append(val_acc)
    print(f'+++ {epoch}: step: {j}, valid_loss: {val_loss}, valid_acc: {val_acc}')

    # Keep the weights with the lowest validation loss (ties go to the later epoch).
    # The first epoch is always stored.
    if best_weight is None or val_loss <= best_valid:
        best_valid = val_loss
        best_weight = copy.deepcopy(model.state_dict())

name = os.path.join(save_dir, 'checkpoint.pth')
torch.save(best_weight, name)

# Learning curves: one point per epoch for both training and validation.
for metric, ylabel, train_label, valid_label in [
    ('loss', 'Loss', 'training loss', 'validation loss'),
    ('acc', 'Accuracy', 'Training accuracy', 'Validation accuracy'),
]:
    fig, ax = plt.subplots()
    epoch_axis = range(1, len(TA[metric]) + 1)
    ax.plot(epoch_axis, TA[metric], label=train_label)
    ax.plot(epoch_axis, VL[metric], label=valid_label)
    ax.set_xlabel('epochs')
    ax.set_ylabel(ylabel)
    ax.legend()
    plt.show()

# Evaluation

## Import Evaluation Image

In [45]:
test_img = "Image_Prediction.tif"  # path to the evaluation image
with rasterio.open(test_img, 'r') as ds:
    arr_ti = ds.read()  # rasterio returns band-first: (bands, rows, cols)

# Move bands last, (rows, cols, bands), with a transpose; reshaping to that shape
# would reinterpret the buffer and scramble pixels across bands.
arr2_ti = np.transpose(arr_ti, (1, 2, 0))

patchsize = ps
nbands = 3
new_row, new_col = 207, 173  # evaluation window (rows, cols)
arr2_ti = arr2_ti[:new_row, :new_col, :]
print(arr2_ti.shape)

# Non-overlapping tiles (step == patchsize); leftover rows/cols at the edges are dropped.
patch1_ti = patchify.patchify(arr2_ti, (patchsize, patchsize, nbands), step=patchsize)
num_patch_row, num_patch_col = patch1_ti.shape[:2]
num_total = num_patch_row * num_patch_col
# (rows, cols, 1, ps, ps, bands) -> (N, ps, ps, bands) -> channel-first (N, bands, ps, ps)
test_img_patch = patch1_ti.reshape(num_total, patchsize, patchsize, nbands).transpose(0, 3, 1, 2)
test_img_patch = torch.from_numpy(np.ascontiguousarray(test_img_patch, dtype=float))
print(test_img_patch.shape)

(207, 173, 3)
torch.Size([6, 3, 64, 64])


## Import Evaluation Labels

In [46]:
test_label = "Label_Prediction_new_classes.tif"  # Path to evaluation labels
patchsize = ps
nbands = 8  # one-hot channels = number of classes, including background
new_row, new_col = 207, 173  # same evaluation window as the image

with rasterio.open(test_label, 'r') as ds:
    label_ids = ds.read().astype('uint8').squeeze()  # (rows, cols) class ids

labels = list(np.unique(label_ids))
unexpected = [v for v in labels if v >= nbands]
if unexpected:
    # Such pixels get an all-zero one-hot vector (argmax later reads them as class 0).
    print(f'Warning: label values {unexpected} are outside 0..{nbands - 1}')

# One-hot over a fixed class axis 0..nbands-1, so the depth does not depend on which
# classes happen to occur in this tile. Result is band-last: (rows, cols, nbands).
arr_tl = (label_ids[..., None] == np.arange(nbands)).astype(np.uint8)
arr_tl = arr_tl[:new_row, :new_col, :]
print(arr_tl.shape)

# Non-overlapping tiles, same layout as the evaluation image chips.
patch1_tl = patchify.patchify(arr_tl, (patchsize, patchsize, nbands), step=patchsize)
num_patch_row, num_patch_col = patch1_tl.shape[:2]
num_total = num_patch_row * num_patch_col
# (rows, cols, 1, ps, ps, C) -> (N, ps, ps, C) -> channel-first (N, C, ps, ps)
test_label_patch = patch1_tl.reshape(num_total, patchsize, patchsize, nbands).transpose(0, 3, 1, 2)
test_label_patch = torch.from_numpy(np.ascontiguousarray(test_label_patch, dtype=float))
# print(test_label_patch.shape)
# print(test_label_patch.shape)

(207, 173, 8)
torch.Size([6, 8, 64, 64])


## Import Trained Model 

In [47]:
weights_name = os.path.join(save_dir, 'checkpoint.pth')  # best-validation checkpoint

# Build a separate model in eval mode (phase != 'train') for inference.
model_pred = define_model(bakbone_depth=bakbone_depth,
                          all_weight=all_weight,
                          num_class=num_class,
                          phase='test')
# Move the freshly built eval-mode model (not the training `model`) to the device.
model_pred = model_pred.to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model_pred.load_state_dict(torch.load(weights_name, map_location=device))  # learned weights

<All keys matched successfully>

In [51]:
# Performance of the model on the evaluation (prediction) data
model_pred.eval()  # inference: BatchNorm uses its running statistics
test_img_patch = test_img_patch.to(device)
test_label_patch = test_label_patch.to(device)
with torch.no_grad():  # no autograd graph is needed for inference
    logs = model_pred(test_img_patch.float())
    softs = torch.softmax(logs['out'], dim=1)
    accuracy = acc_fn(torch.argmax(softs, dim=1), torch.argmax(test_label_patch, dim=1))
accuracy
# print(accuracy)

tensor(0.0304, device='cuda:0')


## Reconstruct Predicted Patches Into A Complete Testing Area [Optional]

In [53]:
def reshape_prediction_by_unpatchify(prediction, patchsize, nclass, lab_array):
    """Stitch per-tile predictions back into one (rows, cols, nclass) mosaic.

    Args:
        prediction: tile predictions, reshapeable to
            (tile_rows * tile_cols, patchsize, patchsize, nclass) in row-major tile order.
        patchsize: tile edge length in pixels.
        nclass: number of class channels per pixel.
        lab_array: reference array; only its first two dimensions are used to work out
            how many whole tiles fit per row and column.

    Returns:
        Array of shape (tile_rows * patchsize, tile_cols * patchsize, nclass).

    Note (original author): the stitched mosaic may look blocky; it is meant to
    demonstrate the workflow, not as a final product.
    """
    tile_rows = int(lab_array.shape[0] / patchsize)
    tile_cols = int(lab_array.shape[1] / patchsize)
    # Grid layout expected by unpatchify: (rows, cols, 1, ps, ps, nclass)
    tile_grid = prediction.reshape((tile_rows, tile_cols, 1, patchsize, patchsize, nclass))
    mosaic_shape = (tile_rows * patchsize, tile_cols * patchsize, nclass)
    return unpatchify(tile_grid, mosaic_shape)

In [67]:
nclass = 8            # number of class channels in the softmax output
lab_array = arr_tl    # reference label mosaic, used to size the tile grid
# permute(0, 2, 3, 1) moves the class axis last; then copy to host memory as numpy.
prediction = softs.permute(0, 2, 3, 1).cpu().detach().numpy()
prediction_label_complete = reshape_prediction_by_unpatchify(
    prediction, patchsize, nclass, lab_array,
)
print(arr_tl.shape, prediction_label_complete.shape, arr2_ti.shape)

(207, 173, 8) (192, 128, 8) (207, 173, 3)


## Visualize Prediction Results [Optional]

In [71]:
# Side by side: test image, reference classes, predicted classes.
# Both class maps share one colour scale, so equal colours mean equal class ids.
class_map_kwargs = dict(vmin=0, vmax=num_class - 1, interpolation='nearest')

fig, ax = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)
# NOTE(review): imshow expects uint8 or floats in [0, 1] for RGB; rescale arr2_ti if the
# raster uses another value range.
ax[0].imshow(arr2_ti)
ax[0].set_title('Testing Image')
ax[1].imshow(np.argmax(arr_tl, axis=-1), **class_map_kwargs)
ax[1].set_title('Testing Label')
ax[2].imshow(np.argmax(prediction_label_complete, axis=-1), **class_map_kwargs)
ax[2].set_title('Prediction on test image')
plt.show()

## Save Prediction Results [Optional]

In [73]:
# Class-id map (rows, cols). np.argmax returns int64; cast to uint8 (ids 0..7 fit),
# a depth that OpenCV image writers support.
img_pred = np.argmax(prediction_label_complete, axis=-1).astype(np.uint8)
save = 'Prediction.tif'
# NOTE: cv2 writes a plain image without CRS/geotransform; use rasterio with the source
# raster's profile if the prediction must stay georeferenced.
if not cv2.imwrite(save, img_pred):
    raise IOError(f'cv2.imwrite could not write {save}')
print('Prediction saved:' + ' ' + save)

Prediction and Testing labels saved: Prediction.tif
