In [45]:
import rasterio
from rasterio.windows import Window
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch
from torch import optim
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from unet_parts import *
from datetime import datetime
from tqdm import tqdm
import segmentation_models_pytorch as smp
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import pandas as pd
from utils import metrics
import matplotlib.patches as mpatches
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, IoU

In [2]:
# Run identifier: day-month-hour-minute timestamp plus the experiment tag.
# It names the TensorBoard log directory and the checkpoint file.
run_started = datetime.now()
name = run_started.strftime("%d-%m-%Hh%Mm") + '-unet_ortho_simple'
print(name)

01-06-20h18m-unet_ortho_simple


In [36]:
# Class legend table (spreadsheet columns 0, 1 and 7). Rows are looked up positionally
# by 0-based class index and must contain a 'Name' column.
legend = pd.read_excel(io='../dataset_ICCV/legend.xlsx', usecols=[0, 1, 7])

In [37]:
class Formatter(object):
    """Callable for ``Axes.format_coord`` that shows the class name under the cursor.

    Args:
        im: the ``AxesImage`` returned by ``imshow``. Its array holds 0-based class indices.
        legend: optional DataFrame with a ``'Name'`` column, indexed positionally by
            class. When omitted, the notebook-level ``legend`` table is used.

    Returns (from ``__call__``):
        The class name of the pixel under ``(x, y)``, or ``""`` when the cursor is
        outside the image array.
    """

    def __init__(self, im, legend=None):
        self.im = im
        self.legend = legend

    def __call__(self, x, y):
        # imshow centres pixel (row, col) on integer coordinates, so each pixel spans
        # [i - 0.5, i + 0.5). Rounding (not truncating) selects the pixel under the cursor.
        col = int(np.floor(x + 0.5))
        row = int(np.floor(y + 0.5))
        array = self.im.get_array()
        n_rows, n_cols = array.shape[:2]
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            return ""
        z = array[row, col]
        table = self.legend if self.legend is not None else legend
        return f"{table.iloc[int(z)]['Name']}"

In [3]:
# Raster file paths read by CatalanDataset, keyed by modality.
raster_dirs = dict(
    orthophoto='/CatLCNet/ORTHO/ORTHO_CAT_2018_UTM_WGS84_31N_1m.tif',
    orthophotoIR='/CatLCNet/ORTHO/ORTHOIR_CAT_2018.tif',
    landcover='/CatLCNet/LC/LC_2018_UTM_WGS84_31N_1m.tif',
)

In [4]:
# Shuffle the tile list reproducibly (global numpy seed 1), then split it
# 60 / 20 / 20 into train / validation / test.
no_borders = np.load('../dataset_ICCV/no_borders_960.npy', allow_pickle=True)
np.random.seed(1)
np.random.shuffle(no_borders)

n_tiles = len(no_borders)
train_end, val_end = int(0.6 * n_tiles), int(0.8 * n_tiles)
no_borders_train = no_borders[:train_end]
no_borders_val = no_borders[train_end:val_end]
no_borders_test = no_borders[val_end:]

In [5]:
class Normalize(object):
    """Normalize the orthophoto tensors of a sample with precomputed channel statistics.

    Meant to run after ``ToTensor`` (the inputs must be float tensors). The two
    statistics files are loaded lazily on the first call and cached on the
    instance. They used to be re-read from disk for every sample.

    Args:
        stats_orthophoto_path: ``.npy`` file whose contents are unpacked into the
            positional arguments of ``transforms.Normalize`` (presumably
            ``(mean, std)`` — TODO confirm) for the RGB orthophoto.
        stats_orthophotoIR_path: same, for the infrared orthophoto.
    """

    def __init__(self,
                 stats_orthophoto_path='../dataset_ICCV/stats_orthophotoRGB.npy',
                 stats_orthophotoIR_path='../dataset_ICCV/stats_orthophotoIR.npy'):
        self.stats_orthophoto_path = stats_orthophoto_path
        self.stats_orthophotoIR_path = stats_orthophotoIR_path
        # Built on first use by _load().
        self._normalize_orthophoto = None
        self._normalize_orthophotoIR = None

    def _load(self):
        """Read both stats files once and build the cached ``transforms.Normalize`` objects."""
        stats_orthophoto = np.load(self.stats_orthophoto_path, allow_pickle=True)
        stats_orthophotoIR = np.load(self.stats_orthophotoIR_path, allow_pickle=True)
        self._normalize_orthophoto = transforms.Normalize(*stats_orthophoto)
        self._normalize_orthophotoIR = transforms.Normalize(*stats_orthophotoIR)

    def __call__(self, sample):
        if self._normalize_orthophoto is None:
            self._load()
        # 'tile' and 'landcover' pass through unchanged.
        return {'tile': sample['tile'],
                'orthophoto': self._normalize_orthophoto(sample['orthophoto']),
                'orthophotoIR': self._normalize_orthophotoIR(sample['orthophotoIR']),
                'landcover': sample['landcover']}


class ToTensor(object):
    """Turn the numpy arrays of a sample into torch tensors.

    Both orthophotos are cast to float32. Land-cover codes are shifted down by one,
    which makes the classes 0-based as ``nn.CrossEntropyLoss`` expects. The tile
    index is passed through untouched.
    """

    def __call__(self, sample):
        def as_float_tensor(array):
            return torch.from_numpy(array.astype(np.float32))

        # NOTE(review): landcover keeps the dtype it was read with (uint8 per the
        # reader's comment), so a 0 code would wrap to 255 here rather than -1.
        # Confirm the tiles never contain 0.
        zero_based_landcover = sample['landcover'] - 1

        return {
            'tile': sample['tile'],
            'orthophoto': as_float_tensor(sample['orthophoto']),
            'orthophotoIR': as_float_tensor(sample['orthophotoIR']),
            'landcover': torch.from_numpy(zero_based_landcover),
        }

In [6]:
class CatalanDataset(Dataset):
    """Square tiles read on the fly from the orthophoto, infrared and land-cover rasters.

    Args:
        tiles_list: indexable sequence of tile offsets. Each entry is unpacked as the
            leading (offset) arguments of ``rasterio.windows.Window``.
        raster_dirs: dict mapping ``'orthophoto'``, ``'orthophotoIR'`` and
            ``'landcover'`` to raster file paths.
        shape: side length in pixels of the square window read per tile. The
            default 320 is the training size; the original comment mentions 960
            as the alternative.
        transform (callable, optional): applied to the sample dict before it is
            returned.

    Each item is a dict with keys ``'tile'`` (the dataset index), ``'orthophoto'``,
    ``'orthophotoIR'`` and ``'landcover'`` (arrays from ``src.read``, band-first).
    """

    def __init__(self, tiles_list, raster_dirs, shape=320, transform=None):
        self.tiles_list = tiles_list
        self.raster_dirs = raster_dirs
        self.transform = transform
        self.shape = shape

    def __len__(self):
        return len(self.tiles_list)

    def _read(self, key, window):
        """Read all bands of raster ``raster_dirs[key]`` inside ``window``."""
        with rasterio.open(self.raster_dirs[key]) as src:
            return src.read(window=window)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        window = Window(*self.tiles_list[idx], self.shape, self.shape)

        sample = {'tile': idx}
        # orthophoto: (3, shape, shape). orthophotoIR: presumably a single band, since
        # the model takes 4 channels after concatenating with the RGB orthophoto.
        # landcover: (1, shape, shape).
        for key in ('orthophoto', 'orthophotoIR', 'landcover'):
            sample[key] = self._read(key, window)

        return self.transform(sample) if self.transform else sample

In [7]:
def _ortho_dataset(tiles):
    """Build a CatalanDataset over ``tiles`` with its own ToTensor -> Normalize pipeline."""
    return CatalanDataset(tiles, raster_dirs,
                          transform=transforms.Compose([ToTensor(), Normalize()]))


CatLC_ortho_train = _ortho_dataset(no_borders_train)
CatLC_ortho_val = _ortho_dataset(no_borders_val)
CatLC_ortho_test = _ortho_dataset(no_borders_test)

In [87]:
# Shuffle only where the order matters.
# - train: shuffle every epoch.
# - validation: the epoch loss is a mean of per-batch means, which can depend on
#   which samples land in the final, smaller batch. A fixed order keeps it
#   comparable across epochs for early stopping.
# - test: stays shuffled because plot_figure takes its first batch as a random
#   example; the accumulated test metrics do not depend on order.
train_loader = DataLoader(CatLC_ortho_train, batch_size=16, shuffle=True, num_workers=20)
val_loader = DataLoader(CatLC_ortho_val, batch_size=16, shuffle=False, num_workers=10)
test_loader = DataLoader(CatLC_ortho_test, batch_size=1, shuffle=True, num_workers=20)

In [9]:
class UNet(nn.Module):
    """U-Net with a 4-level encoder/decoder and skip connections that returns per-pixel logits.

    Building blocks (``DoubleConv``, ``Down``, ``Up``, ``OutConv``) come from
    ``unet_parts``. Submodule attribute names must stay stable because checkpoints
    store ``model.state_dict()``, which is keyed by them.

    Args:
        n_channels: number of input channels.
        n_classes: number of output channels (class logits).
        bilinear: forwarded to every ``Up`` block. When True, ``factor = 2`` halves
            the output width of ``down4`` and of ``up1``-``up3``.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear
        factor = 2 if bilinear else 1

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # Per-pixel classifier head
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every level's features as skip connections.
        features = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            features.append(down(features[-1]))

        # Decoder: start from the deepest map and consume the skips deepest-first.
        out = features.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, features.pop())

        return self.outc(out)

In [10]:
# Prefer the GPU when one is visible to torch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [11]:
# 4 input channels: the RGB orthophoto concatenated with the infrared orthophoto on
# dim 1. 41 land-cover classes.
model = UNet(n_channels=4, n_classes=41, bilinear=True).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [12]:
def save_model(name, epoch, model, optimizer, train_losses, val_losses, directory='saved_models'):
    """Write a training checkpoint to ``<directory>/<name>`` with ``torch.save``.

    Creates ``directory`` if needed; ``torch.save`` fails on a missing parent directory.

    Args:
        name: checkpoint file name (the run identifier).
        epoch: epoch index stored in the checkpoint.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        train_losses: list of per-epoch training losses.
        val_losses: list of per-epoch validation losses.
        directory: parent directory, ``'saved_models'`` by default.

    Returns:
        The path the checkpoint was written to.
    """
    import os

    os.makedirs(directory, exist_ok=True)
    PATH = os.path.join(directory, name)
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            }, PATH)
    return PATH

In [13]:
# Per-epoch loss histories and the early-stopping counter used by the training loop.
train_losses, val_losses = [], []
patience = 0

In [32]:
def _load_pickled(path):
    """Load ``path`` via ``np.load(..., allow_pickle=True)``.

    Security: unpickling can execute arbitrary code embedded in the file. Only use
    this on trusted, locally produced files.
    """
    with open(path, 'rb') as handle:
        return np.load(handle, allow_pickle=True)


# Colormap (and its norm) used to draw land-cover class maps.
cmap_mcsc = _load_pickled('../dataset_ICCV/cmap_mcsc.pickle')
norm_mcsc = _load_pickled('../dataset_ICCV/norm_mcsc.pickle')

In [41]:
def plot_figure():
    
    %matplotlib inline

    data = next(iter(test_loader))

    tile = no_borders_test[data['tile']]
    fig, axs = plt.subplots(1,3, figsize=(14,5))

    window = Window(*tile, 320, 320)
    with rasterio.open(raster_dirs['orthophoto']) as src:
        pic = src.read(window=window)
    axs[0].imshow(np.moveaxis(pic,0,-1))
    axs[0].set_title('Orthophoto')

    with rasterio.open(raster_dirs['landcover']) as src:
        pic_landcover = src.read(window=window) - 1
    plot_landcover = axs[1].imshow(pic_landcover[0], cmap=cmap_mcsc, vmin=0, vmax=41, interpolation='none')
    axs[1].set_title('landcover')
    axs[1].format_coord = Formatter(plot_landcover)

    input_data_highres = torch.cat((data['orthophoto'], data['orthophotoIR']), 1)

    prediction = model(input_data_highres.to(device))
    prediction = torch.max(prediction, 1)[1][0].cpu()
    labels = data['landcover'].squeeze().long().cpu()

    plot_prediction = axs[2].imshow(prediction, cmap=cmap_mcsc, vmin=0, vmax=41, interpolation='none')
    axs[2].set_title(f'prediction. acc: {torch.sum(prediction==labels)/labels.nelement():.2f}, iou: {metrics.mean_iou(prediction, labels, 41):.2f}')
    #axs[2].set_title(f'prediction, accuracy: {100*np.sum(torch.max(prediction, 1)[1].cpu() == pic_landcover[0])/pic_landcover.size:.2f}%')
    axs[2].format_coord = Formatter(plot_prediction)

    classes_in_images = np.unique([prediction.numpy(),labels.numpy()])
    patches =[mpatches.Patch(color=cmap_mcsc.colors[i],label=legend.iloc[i]['Name']+f' ({i+1})') for i in classes_in_images]
    fig.legend(handles=patches, loc='lower center', ncol = (len(patches) if len(patches) < 6 else 5))

    fig.suptitle(f'Tile {tile}')
    
    return fig

In [None]:
MAX_EPOCHS = 80
EARLY_STOPPING_PATIENCE = 20  # epochs without a new best validation loss before stopping


def run_epoch(loader, training):
    """Run one pass of ``model`` over ``loader`` and return the mean per-batch loss.

    training=True: train mode, gradients enabled, parameters updated with ``optimizer``.
    training=False: eval mode, no autograd graph.
    The tqdm bar shows per-batch loss, accuracy and learning rate, then the epoch mean.
    """
    model.train(training)
    batch_losses = torch.empty(len(loader), dtype=torch.float32)
    with tqdm(loader) as t, torch.set_grad_enabled(training):
        for batch_idx, data in enumerate(t):
            labels = torch.squeeze(data['landcover'].long(), 1).to(device)
            inputs_highres = torch.cat((data['orthophoto'], data['orthophotoIR']), 1).to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs_highres)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_losses[batch_idx] = loss.item()
            accuracy = 100 * torch.sum(torch.max(outputs, 1)[1] == labels) / labels.nelement()
            t.set_postfix({"batch_idx": batch_idx,
                           "Loss": loss.item(),
                           "Accuracy (%)": accuracy.item(),
                           "Learning rate": optimizer.param_groups[0]["lr"]})

        mean_loss = torch.mean(batch_losses).item()
        t.set_postfix({"Total loss": mean_loss})
    return mean_loss


# One writer (one event file) for the whole run, closed even if training is interrupted.
writer = SummaryWriter('logs/' + name)
try:
    for epoch in range(MAX_EPOCHS):
        train_loss = run_epoch(train_loader, training=True)
        train_losses.append(train_loss)

        val_loss = run_epoch(val_loader, training=False)
        # A new best (ties included) resets patience and is checkpointed. The first
        # epoch is always saved. This matches the previous `any(loss > v ...)` test.
        improved = epoch == 0 or val_loss <= min(val_losses)
        # Record before saving so the checkpoint's val_losses includes this epoch.
        val_losses.append(val_loss)
        if improved:
            patience = 0
            save_model(name, epoch, model, optimizer, train_losses, val_losses)
        else:
            patience += 1

        writer.add_figure('figure', plot_figure(), global_step=epoch)
        # Log this epoch's values directly. Indexing train_losses[epoch] would be wrong
        # if this cell is re-run, because the lists are created in an earlier cell.
        writer.add_scalars('Loss', {'Train': train_loss, 'Validation': val_loss}, global_step=epoch)
        writer.flush()

        if patience >= EARLY_STOPPING_PATIENCE:
            break
finally:
    writer.close()

 19%|█▉        | 247/1277 [03:44<15:29,  1.11it/s, batch_idx=246, Loss=1.54, Accuracy (%)=58.3, Learning rate=0.0001]

In [124]:
# Per-class test metrics over the 41 land-cover classes. Per-pixel (multi-dim)
# inputs are pooled globally before the per-class scores are computed.
# NOTE(review): the output below shows identical per-class Accuracy and Recall.
# Newer torchmetrics releases rename IoU to JaccardIndex.
per_class_kwargs = dict(num_classes=41, average=None, mdmc_average='global')
metric_collection = MetricCollection([
    Accuracy(**per_class_kwargs),
    Precision(**per_class_kwargs),
    Recall(**per_class_kwargs),
    IoU(num_classes=41, reduction='none'),
])
metric_collection.reset()

In [125]:
# Accumulate test-set statistics into metric_collection.
# - `update` only accumulates. Calling the collection directly (forward) also
#   computes per-batch values that are never used.
# - The result is not bound to `metrics`, which would shadow the `utils.metrics`
#   module that plot_figure relies on.
# - Labels stay on the CPU, where the metrics are computed.
model.eval()
with torch.no_grad(), tqdm(test_loader) as t:
    for data in t:
        labels = torch.squeeze(data['landcover'].long(), 1)
        inputs_highres = torch.cat((data['orthophoto'], data['orthophotoIR']), 1).to(device)

        preds = torch.max(model(inputs_highres), 1)[1]
        metric_collection.update(preds.cpu(), labels)

100%|██████████| 6808/6808 [28:11<00:00,  4.03it/s]


In [132]:
# Final per-class test metrics accumulated over the whole test set (displayed below).
test_metrics = metric_collection.compute()
test_metrics

{'Accuracy': tensor([0.9035, 0.0379, 0.8146, 0.6660, 0.7038, 0.0163, 0.8498, 0.6661, 0.6803,
         0.6029, 0.2578, 0.1351, 0.3726, 0.5039, 0.2712, 0.1665, 0.0804, 0.6883,
         0.0000, 0.1388, 0.5506, 0.7757, 0.5711, 0.4678, 0.5905, 0.4989, 0.7360,
         0.1561, 0.4084, 0.0018, 0.7972, 0.1523, 0.0358, 0.3572, 0.0000, 0.9351,
         0.6625, 0.4846, 0.4849, 0.5029, 0.0000]),
 'Precision': tensor([0.8722, 0.1804, 0.8153, 0.7494, 0.6602, 0.0710, 0.8099, 0.6643, 0.6549,
         0.6962, 0.2738, 0.1722, 0.2095, 0.5642, 0.4393, 0.3711, 0.6436, 0.5912,
         0.0000, 0.3899, 0.5287, 0.5764, 0.4476, 0.3582, 0.6151, 0.3962, 0.6351,
         0.3559, 0.5271, 0.1378, 0.6431, 0.3164, 0.3867, 0.6734, 0.0000, 0.6082,
         0.6349, 0.6908, 0.5271, 0.7013, 0.0000]),
 'Recall': tensor([0.9035, 0.0379, 0.8146, 0.6660, 0.7038, 0.0163, 0.8498, 0.6661, 0.6803,
         0.6029, 0.2578, 0.1351, 0.3726, 0.5039, 0.2712, 0.1665, 0.0804, 0.6883,
         0.0000, 0.1388, 0.5506, 0.7757, 0.5711, 0.46