In [1]:
import torch.nn as nn
import torch
import torchvision.transforms.functional as func
import torch.optim as optim
from torch.amp import GradScaler, autocast

class UNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = UNetBlock(3,64)
        self.layer2 = UNetBlock(64,128)
        
        self.layer3 = UNetBlock(128, 256)
        self.up_3_to_4 = nn.ConvTranspose2d(256, 128, 2)
        
        self.layer4 = UNetBlock(256, 128)
        self.up_4_to_5 = nn.ConvTranspose2d(128, 64, 2)

        self.layer5 = UNetBlock(128, 64)
        self.output = nn.Conv2d(64,1,1)
        
        self.pool = nn.MaxPool2d(2)
    
    def forward(self, input):
        enc1 = self.layer1(input)
        enc2 = self.layer2(self.pool(enc1))
        
        bottleneck = self.layer3(self.pool(enc2))
        
        bottleneck_up = self.up_3_to_4(bottleneck)
        enc2_cropped = func.center_crop(enc2, [bottleneck_up.shape[2], bottleneck_up.shape[3]])
        dec1 = self.layer4(torch.cat([enc2_cropped, bottleneck_up], dim=1))
        
        dec1_up = self.up_4_to_5(dec1)
        enc1_cropped = func.center_crop(enc1, [dec1_up.shape[2], dec1_up.shape[3]])
        dec2 = self.layer5(torch.cat([enc1_cropped, dec1_up], dim=1))

        output = self.output(dec2)

        return output

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def dice_loss(pred, target, smooth=1.):
    pred = torch.sigmoid(pred)
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    intersection = (pred_flat * target_flat).sum()
    return 1 - ((2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth))

In [4]:
from torch.utils.data import Dataset, DataLoader
import os
from sklearn.model_selection import train_test_split
from torchvision import transforms
import tifffile as tif
import random

imageTransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
maskTransform = transforms.Compose([transforms.ToTensor()])

random.seed(42) # used for reproducibility 
batch_size = 8

class LandslideDataset(Dataset):
    def __init__(self, img_list, mask_list):
        self.img_list = img_list
        self.mask_list = mask_list
    
    def __len__(self):
        return len(self.img_list)
    
    def __getitem__(self, index):
        img_dir = self.img_list[index]
        mask_dir = self.mask_list[index]

        img = tif.imread(img_dir)
        mask = tif.imread(mask_dir)

        img = imageTransform(img)
        mask = (maskTransform(mask) > 0).float()

        return img, mask
    
def getMetrics(TP, TN, FP, FN):
    precision = TP / (TP + FP + 1e-8)
    recall = TP / (TP + FN + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
        
    iou_1  = TP / (TP + FP + FN + 1e-8)
    iou_0 = TN / (TN + FP + FN + 1e-8)
    miou = (iou_0 + iou_1) / 2 
        
    oa = (TP + TN)/(TP + TN + FP + FN + 1e-8)
    
    return precision, recall, f1, iou_1, miou, oa

path = "../data/"
regions = [f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f))]

subdataset_to_region = {
    "Hokkaido Iburi-Tobu": "Hokkaido Iburi-Tobu",
    "Jiuzhai valley (UAV-0.2m)": "Jiuzhai Valley",
    "Jiuzhai valley (UAV-0.5m)": "Jiuzhai Valley",
    "Lombok": "Lombok",
    "Longxi River (SAT)": "Longxi River",
    "Longxi River (UAV)": "Longxi River",
    "Mengdong Township": "Mengdong Township",
    "Moxi town (UAV-0.2m)": "Luding",
    "Moxi town (UAV-1m)": "Luding",
    "Moxitaidi (SAT)": "Luding",
    "Moxitaidi (UAV-0.6m)": "Luding",
    "Moxitaidi (UAV-1m)": "Luding",
    "palu": "Palu",
    "Tiburon Peninsula (planet)": "Tiburon Peninsula",
    "Tiburon Peninsula (Sentinel)": "Tiburon Peninsula",
    "Wenchuan": "Wenchuan"
}

regions_dict = {
    "Hokkaido Iburi-Tobu": [],
    "Jiuzhai Valley": [],
    "Lombok": [],
    "Longxi River": [],
    "Mengdong Township": [],
    "Luding": [],
    "Palu": [],
    "Tiburon Peninsula": [],
    "Wenchuan": []
}

for region in regions:
    if(region != "study areas shp"):
        dataset_dir = "../data/" + region
        image_dir = os.path.join(dataset_dir, "img")
        img_list = os.listdir(image_dir)
        
        all_images = sorted(os.path.join(image_dir, f) for f in img_list)
        
        regions_dict[subdataset_to_region[region]].extend(all_images)

output = "region,precision,recall,f1,iou,miou,oa"

for region in regions_dict:
    if len(regions_dict[region]) > 1000:
        regions_dict[region] = random.sample(regions_dict[region], 1000)
    
    image_paths = regions_dict[region]
    mask_paths = [f.replace("img", "mask") for f in image_paths]
    
    train_img, temp_img, train_mask, temp_mask = train_test_split(
        image_paths, mask_paths, test_size=.3, random_state=42
    )
    
    test_img, val_img, test_mask, val_mask = train_test_split(
        temp_img, temp_mask, test_size=.5, random_state=42
    )

    train_dataset = LandslideDataset(train_img, train_mask)
    val_dataset = LandslideDataset(val_img, val_mask)
    test_dataset = LandslideDataset(test_img, test_mask)
    
    trainLoader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    valLoader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    testLoader = DataLoader(test_dataset, batch_size=batch_size,shuffle=False,num_workers=0)
    
    epochs = 40
    
    best_iou = 0.0
    patience = 10
    counter = 0
        
    model_path = "../models/intra-region/" + region + ".pth"
    
    unet = UNet()
    unet.to(device)

    pos_weight = torch.tensor([4.5]).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.Adam(unet.parameters(), lr=1e-4, weight_decay=1e-5)

    scaler = GradScaler()
    
    print(f'region, epoch, train_loss, val_loss, precision, recall, f1, iou, miou, oa')
    
    for epoch in range(epochs):
        unet.train()
        running_loss = 0.0
        train_num = 0

        for i, data in enumerate(trainLoader, 0):
            image, mask = data
            
            image = image.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            
            with autocast(device_type="cuda"):
                outputs = unet(image)
                outputs = nn.functional.interpolate(outputs, size=mask.shape[2:], mode="bilinear", align_corners=False)
                bce = criterion(outputs, mask)
                dice = dice_loss(outputs, mask)
                loss = bce + dice
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            train_num += 1
        
        unet.eval()
        val_loss = 0.0
        val_num = 0
        
        TP, FP, FN, TN = 0,0,0,0
        
        with torch.no_grad():
            for data in valLoader:
                image, mask = data

                image = image.to(device)
                mask = mask.to(device)
            
                outputs = unet(image)
                outputs = nn.functional.interpolate(outputs, size=mask.shape[2:], mode="bilinear", align_corners=False)
                loss = criterion(outputs, mask)
                val_loss += loss.item()

                preds = torch.sigmoid(outputs) > .6
                
                preds_flat = preds.view(-1)
                mask_flat = mask.view(-1)
                
                TP += ((preds_flat == 1) & (mask_flat == 1)).sum().item()
                FP += ((preds_flat == 1) & (mask_flat == 0)).sum().item()
                FN += ((preds_flat == 0) & (mask_flat == 1)).sum().item()
                TN += ((preds_flat == 0) & (mask_flat == 0)).sum().item()
                
                val_num += 1
        
        precision, recall, f1, iou, miou, oa = getMetrics(TP, TN, FP, FN)
        
        print(f'{region}, {epoch+1}, {running_loss / train_num :.3f}, {val_loss / val_num :.3f}, {precision:.4f}, {recall:.4f}, {f1:.4f}, {iou:.4f}, {miou:.4f}, {oa:.4f}')

        if iou > best_iou:
            best_iou = iou
            counter = 0
            
            torch.save(unet.state_dict(), model_path)
        elif iou != 0.0:
            counter += 1
            if counter >= patience:
                break
    
    unet.load_state_dict(torch.load(model_path))
    
    TP, FP, FN, TN = 0,0,0,0
    
    for data in testLoader:
        image, mask = data

        image = image.to(device)
        mask = mask.to(device)
            
        outputs = unet(image)
        outputs = nn.functional.interpolate(outputs, size=mask.shape[2:], mode="bilinear", align_corners=False)

        preds = torch.sigmoid(outputs) > .6
                
        preds_flat = preds.view(-1)
        mask_flat = mask.view(-1)
                
        TP += ((preds_flat == 1) & (mask_flat == 1)).sum().item()
        FP += ((preds_flat == 1) & (mask_flat == 0)).sum().item()
        FN += ((preds_flat == 0) & (mask_flat == 1)).sum().item()
        TN += ((preds_flat == 0) & (mask_flat == 0)).sum().item()
        
    precision, recall, f1, iou, miou, oa = getMetrics(TP, TN, FP, FN)
    output += f'\n{region}, {precision:.4f}, {recall:.4f}, {f1:.4f}, {iou:.4f}, {miou:.4f}, {oa:.4f}'

with open("../results/intra-region/metrics.csv", "w") as f:
    f.write(output)

region, epoch, train_loss, val_loss, precision, recall, f1, iou, miou, oa
Hokkaido Iburi-Tobu, 1, 1.717, 0.836, 0.0000, 0.0000, 0.0000, 0.0000, 0.4500, 0.9000
Hokkaido Iburi-Tobu, 2, 1.337, 0.513, 0.4105, 0.7115, 0.5206, 0.3519, 0.6054, 0.8690
Hokkaido Iburi-Tobu, 3, 1.140, 0.523, 0.3940, 0.7658, 0.5203, 0.3517, 0.5994, 0.8588
Hokkaido Iburi-Tobu, 4, 1.113, 0.490, 0.4211, 0.7145, 0.5299, 0.3605, 0.6120, 0.8732
Hokkaido Iburi-Tobu, 5, 1.101, 0.485, 0.4097, 0.7440, 0.5284, 0.3591, 0.6078, 0.8672
Hokkaido Iburi-Tobu, 6, 1.092, 0.478, 0.4205, 0.7308, 0.5338, 0.3641, 0.6132, 0.8724
Hokkaido Iburi-Tobu, 7, 1.075, 0.478, 0.4481, 0.6821, 0.5409, 0.3707, 0.6232, 0.8842
Hokkaido Iburi-Tobu, 8, 1.068, 0.472, 0.4148, 0.7503, 0.5343, 0.3645, 0.6116, 0.8692
Hokkaido Iburi-Tobu, 9, 1.065, 0.504, 0.5027, 0.5653, 0.5321, 0.3625, 0.6286, 0.9006
Hokkaido Iburi-Tobu, 10, 1.065, 0.474, 0.4517, 0.6818, 0.5434, 0.3731, 0.6251, 0.8854
Hokkaido Iburi-Tobu, 11, 1.055, 0.467, 0.4148, 0.7573, 0.5361, 0.3662, 0.61