### Deep learning model training.


In [1]:
import time
import torch
import random
from model import unet
import torch.nn as nn
from glob import glob
import geopandas as gpd
from torchsummary import summary
from utils.imgShow import imsShow
from utils.dataloader import TraSet, ValSet
from utils.acc_metric import oa_binary, miou_binary


### dataset loading

In [2]:
dir_scene = 'data/dset/scene/'
dir_dem = 'data/dset/dem/' 
dir_truth = 'data/dset/truth/'  

In [3]:
### secene and truth pairwise data
## traset
ids_tra_gdf = gpd.read_file('data/dset/dset_tra.gpkg')
ids_tra = ids_tra_gdf['id_scene'].tolist()
paths_scene_tra = [dir_scene+id+'_nor.tif' for id in ids_tra]
paths_dem_tra = [dir_dem+id+'_dem_nor.tif' for id in ids_tra]
paths_truth_tra = [dir_truth+id+'.tif' for id in ids_tra] 
## valset
paths_patch_valset = sorted(glob('data/dset/valset/*'))
print(f'train scenes: {len(paths_scene_tra)}, vali patch: {len(paths_patch_valset)}')



train scenes: 48, vali patch: 1959


In [4]:
## Create dataset instances
tra_data = TraSet(paths_scene=paths_scene_tra, 
                   paths_truth=paths_truth_tra, 
                   paths_dem=paths_dem_tra,
                   path_size=(256, 256))
val_data = ValSet(paths_valset=paths_patch_valset)
tra_loader = torch.utils.data.DataLoader(tra_data, batch_size=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=4)


In [5]:
## check data loading time
time_start = time.time()
for x_batch, y_batch in tra_loader:
  print(f"Batch processing time: {time.time() - time_start:.1f}")
time_start = time.time()


Batch processing time: 5.0
Batch processing time: 6.7
Batch processing time: 8.3
Batch processing time: 13.2
Batch processing time: 14.8
Batch processing time: 16.5
Batch processing time: 21.3
Batch processing time: 26.5
Batch processing time: 28.2
Batch processing time: 29.8
Batch processing time: 31.4
Batch processing time: 33.0


#### Model training

In [6]:
### check model
model = unet(num_bands=7)
summary(model, input_size=(7,256,256), device='cpu')


Layer (type:depth-idx)                   Param #
├─Upsample: 1-1                          --
├─Sequential: 1-2                        --
|    └─Conv2d: 2-1                       1,024
|    └─BatchNorm2d: 2-2                  32
|    └─ReLU: 2-3                         --
├─Sequential: 1-3                        --
|    └─Conv2d: 2-4                       4,640
|    └─BatchNorm2d: 2-5                  64
|    └─ReLU: 2-6                         --
├─Sequential: 1-4                        --
|    └─Conv2d: 2-7                       18,496
|    └─BatchNorm2d: 2-8                  128
|    └─ReLU: 2-9                         --
├─Sequential: 1-5                        --
|    └─Conv2d: 2-10                      73,856
|    └─BatchNorm2d: 2-11                 256
|    └─ReLU: 2-12                        --
├─Sequential: 1-6                        --
|    └─Conv2d: 2-13                      110,656
|    └─BatchNorm2d: 2-14                 128
|    └─ReLU: 2-15                        --
├─Seq

Layer (type:depth-idx)                   Param #
├─Upsample: 1-1                          --
├─Sequential: 1-2                        --
|    └─Conv2d: 2-1                       1,024
|    └─BatchNorm2d: 2-2                  32
|    └─ReLU: 2-3                         --
├─Sequential: 1-3                        --
|    └─Conv2d: 2-4                       4,640
|    └─BatchNorm2d: 2-5                  64
|    └─ReLU: 2-6                         --
├─Sequential: 1-4                        --
|    └─Conv2d: 2-7                       18,496
|    └─BatchNorm2d: 2-8                  128
|    └─ReLU: 2-9                         --
├─Sequential: 1-5                        --
|    └─Conv2d: 2-10                      73,856
|    └─BatchNorm2d: 2-11                 256
|    └─ReLU: 2-12                        --
├─Sequential: 1-6                        --
|    └─Conv2d: 2-13                      110,656
|    └─BatchNorm2d: 2-14                 128
|    └─ReLU: 2-15                        --
├─Seq

In [7]:
### create loss and optimizer
loss_bce = nn.BCELoss()     
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [8]:
'''------train step------'''
def train_step(model, loss_fn, optimizer, x, y):
    optimizer.zero_grad()
    pred = model(x)
    loss = loss_fn(pred, y.float())
    loss.backward()
    optimizer.step()
    miou = miou_binary(pred=pred, truth=y)
    oa = oa_binary(pred=pred, truth=y)
    return loss, miou, oa

'''------validation step------'''
def val_step(model, loss_fn, x, y):
    model.eval()
    with torch.no_grad():
        pred = model(x.float())
        loss = loss_fn(pred, y.float())
    miou = miou_binary(pred=pred, truth=y)
    oa = oa_binary(pred=pred, truth=y)
    return loss, miou, oa

'''------train loops------'''
def train_loops(model, loss_fn, optimizer, tra_loader, 
                                    val_loader, epoches, device):
    model = model.to(device)
    size_tra_loader = len(tra_loader)
    size_val_loader = len(val_loader)
    for epoch in range(epoches):
        start = time.time()
        tra_loss, val_loss = 0, 0
        tra_miou, val_miou = 0, 0
        tra_oa, val_oa = 0, 0
        '''-----train the model-----'''
        for x_batch, y_batch in tra_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            loss, miou, oa = train_step(model=model, loss_fn=loss_fn, 
                                    optimizer=optimizer, x=x_batch, y=y_batch)
            tra_loss += loss.item()
            tra_miou += miou.item()
            tra_oa += oa.item()
        '''-----validation the model-----'''
        for x_batch, y_batch in val_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            loss, miou, oa = val_step(model=model, loss_fn=loss_fn, 
                                                    x=x_batch, y=y_batch)
            val_loss += loss.item()
            val_miou += miou.item()
            val_oa += oa.item()
        ## Accuracy
        tra_loss = tra_loss/size_tra_loader
        val_loss = val_loss/size_val_loader
        tra_miou = tra_miou/size_tra_loader
        val_miou = val_miou/size_val_loader
        tra_oa = tra_oa/size_tra_loader
        val_oa = val_oa/size_val_loader
        print(f'Ep{epoch+1}: tra-> Loss:{tra_loss:.3f},Oa:{tra_oa:.2f},Miou:{tra_miou:.2f}, '
              f'val-> Loss:{val_loss:.2f},Oa:{val_oa:.2f},Miou:{val_miou:.2f},time:{time.time()-start:.0f}s')
        ## show the result
        if (epoch+1)%10 == 0:
            model.eval()
            sam_index = random.randrange(len(val_data))
            patch, truth = val_data[sam_index]
            patch, truth = torch.unsqueeze(patch.float(), 0).to(device), truth.to(device)
            pred = model(patch)
            ## convert to numpy and plot
            patch = patch[0].to('cpu').detach().numpy().transpose(1,2,0)
            pred = pred[0].to('cpu').detach().numpy()
            truth = truth.to('cpu').detach().numpy()
            imsShow([patch, truth, pred], 
                    img_name_list=['input_patch', 'truth', 'prediction'] , figsize=(10,4))


In [None]:
# device = torch.device('cpu') 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loops(model=model, 
            loss_fn=loss_bce, 
            optimizer=optimizer,
            tra_loader=tra_loader, 
            val_loader=val_loader, 
            epoches=20,
            device=device)


Ep1: tra-> Loss:0.565,Oa:0.51,Miou:0.35, val-> Loss:0.33,Oa:0.70,Miou:0.54,time:38s
Ep2: tra-> Loss:0.310,Oa:0.93,Miou:0.80, val-> Loss:0.34,Oa:0.76,Miou:0.65,time:38s
Ep3: tra-> Loss:0.310,Oa:0.90,Miou:0.80, val-> Loss:0.18,Oa:0.80,Miou:0.72,time:38s


In [None]:
# # model saving
# path_save = 'model/trained/unet.pth'
# torch.save(model.state_dict(), path_save)   # save weights of the trained model 
# model.load_state_dict(torch.load(path_save, weights_only=True))  # load the weights of the trained model
