In [None]:
!pip install opendatasets
!pip install pandas
!pip install albumentations

In [1]:
import os
import albumentations as alb
#import tensorflow as tf
import datetime
import torch
import numpy as np
import opendatasets as od

from model import UNet
from utils import plot, get_data_loaders, evaluate, get_dice_score
#from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
from albumentations.pytorch.transforms import ToTensorV2
from torchmetrics.classification import Dice

  warn(


In [None]:
# Download dataset from kaggle
od.download(
    "https://www.kaggle.com/datasets/balraj98/deepglobe-land-cover-classification-dataset")

In [None]:
# model.py
import torch
import numpy as np

from torch import nn

# U-Net

# Two convolution block. Performs two consecutive convolutions
class TwoConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='same'):
        super().__init__()

        self.module_list = nn.ModuleList([])
        
        #Using Henriks convultion layering or the one introduced in the Unet paper?
        self.module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))
        self.module_list.append(nn.ReLU())

        self.module_list.append(nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding))
        self.module_list.append(nn.ReLU())

    def forward(self, x):
        y = x
        for module in self.module_list:
            y = module(y)
        return y

# UNet encoder block. Performs two convolutions and max pooling.
class ConvPool(TwoConv):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='same'):
        super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.max = nn.MaxPool2d(2, 2)

    def forward(self, x):
        c = super().forward(x)
        p = self.max(c)
        return c, p

# UNet decoder block. Performs upsampling, concatenation of the two inputs and two convolutions.
class UpConv(TwoConv):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='same'):
        super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        # We may use different upsampling method here.
        self.upsampling = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x, skip):
        u = self.upsampling(x)
        u = torch.cat([u, skip], 1)
        c = super().forward(u)
        return c, u


class UNet(nn.Module):
    def __init__(self, in_channels, min, max, num_classes):
        super().__init__()
        self.enc_layers = nn.ModuleList([])
        self.dec_layers = nn.ModuleList([])
        self.enc_final = None
        self.dec_final = None
        self.softmax = None

        # When go down the encoder/up the decoder the number of filter doubles/halves
        # respectively. For that we will generate the powers of two.
        # List of powers of 2 [min, 2*min, 4*min, ..., max]
        channels = []
        power = min
        for i in range(int(np.log2(max // min))):
            channels.append(power)
            power = power*2

        # Construct list of blocks for the encoder
        self.enc_layers.append(ConvPool(in_channels, min))
        for i in range(len(channels)-1):
            enc_layer = ConvPool(channels[i], channels[i+1])
            self.enc_layers.append(enc_layer)

        # Construct list of blocks for the encoder
        for i in range(len(channels)-1):
            dec_layer = UpConv(channels[i+1], channels[i])
            self.dec_layers.insert(0, dec_layer)
        self.dec_layers.insert(0, UpConv(max, channels[-1]))

        # Set up final convolutions for the encoder and decoder
        self.enc_final = TwoConv(channels[len(channels)-1], max, 3, 1, 'same')
        self.dec_final = nn.Conv2d(min, num_classes, 1, 1)
        self.softmax = nn.Softmax(0)

    def forward(self, x):
        # Collect the values for skip connections to the decoder
        skip_connections = []
        p = x
        # Encoder
        for layer in self.enc_layers:
            c, p = layer(p)
            skip_connections.append(c)

        # Bottleneck
        c =  self.enc_final(p)

        # Decoder
        for layer in self.dec_layers:
            skip = skip_connections.pop()
            c, u = layer(c, skip) # if we do not need c we can use _ instead
        c = self.dec_final(c)

        return self.softmax(c)

In [None]:
# dataset.py
#Custom dataset for SatelliteSet
import cv2
import pandas as pd
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset


# Custom dataset class to load deep globe dataset
class SatelliteSet(Dataset):
    def __init__(self,
                 # mpandas dataframe loaded with a meta data csv containing image file names
                 meta_data, 
                 # Class dictionary
                 class_dict,
                 # directory where the data is stored
                 data_dir, 
                 # albumentations transform
                 transform=None):
        self.meta_data = meta_data
        self.data_dir = data_dir
        self.transform = transform
        self.class_dict = class_dict

    # number of samples in dataset
    def __len__(self):
        return len(self.meta_data)

    # load and return sample at index idx
    def __getitem__(self, idx):
        # Read image
        img_path = os.path.join(self.data_dir, self.meta_data.iloc[idx]['sat_image_path'])
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Read target mask
        mask_path = os.path.join(self.data_dir, self.meta_data.iloc[idx]['mask_path'])
        mask = cv2.imread(mask_path)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        if self.transform:
            # In the transform from albumentation we pass both the image and the mask together to make sure
            # they undergo the same transformation, e.g. this ensure both have the same random crop
            transformed = self.transform(image = image, mask = mask)
            image = transformed['image'].to(torch.float32)
            mask = transformed['mask']
            mask = torch.tensor(np.apply_along_axis(lambda k: self.class_dict[tuple(k)], 2, mask))
            mask_onehot = torch.zeros((len(self.class_dict), mask.shape[0], mask.shape[1]))
            for w,h in mask.nonzero(as_tuple=False):
                mask_onehot[mask[w,h], w, h] = 1
            mask = mask_onehot.to(torch.float32)

        #image.require_grad = True

        return image, mask
    


In [None]:
# utils.py
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import pandas as pd

from PIL import Image
from dataset import SatelliteSet
from torch.utils.data import SubsetRandomSampler
from torchmetrics import JaccardIndex
from torchmetrics.classification import Dice

def plot(sample, data_dir):

    plt.figure(figsize=(5,4))
    ax = plt.subplot(2,2,1)
    plt.imshow(np.asarray(Image.open(os.path.join(data_dir, sample['sat_image_path'].iloc[0]))))
    plt.gray()
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)

    ax = plt.subplot(2,2,2)
    plt.imshow(np.asarray(Image.open(os.path.join(data_dir, sample['mask_path'].iloc[0]))))
    plt.gray()
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)

    plt.show()

# TODO: make something like a dictionary for parametere to pass to data loader
def get_data_loaders(data_dir, transform, shuffle_dataset, test_split, random_seed, batch_size):

    # Load metadata csv
    metadata = pd.read_csv(os.path.join(data_dir, 'subsample.csv'))

    # We need to filter for row where 'split' is 'train' because samples where
    # 'split' is 'valid' or 'test' have no target mask
    metadata = metadata[metadata['split'] == 'train']

    class_dict = pd.read_csv(os.path.join(data_dir, 'class_dict.csv'))

    classes = {}
    c = 0
    for i in class_dict.index:
        classes[tuple(class_dict.iloc[i,1:])] = c
        c += 1

    # We need to filter for row where 'split' is 'train' because samples where
    # 'split' is 'valid' or 'test' have no target mask

    dataset = SatelliteSet(meta_data=metadata, class_dict=classes, data_dir=data_dir, transform=transform)

    # Creating data indices for training and validation splits:
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(test_split * dataset_size))
    #TODO:Implement validation

    if shuffle_dataset:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, test_indices = indices[split:], indices[:split]

    # Creating PT data samplers and loaders:
    train_sampler = SubsetRandomSampler(train_indices)
    test_sampler = SubsetRandomSampler(test_indices)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 
                                            sampler=train_sampler)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=test_sampler)
    return train_loader, test_loader

def evaluate(model, writer, dataloader, epoch):
    iou_score = 0
    iou = JaccardIndex('multiclass', num_classes=7)
    dice_score = 0
    dice = Dice(num_classes=7)
    for x, y in dataloader:
        pred = model(x)
        pred = torch.argmax(pred, 1)
        y = torch.argmax(y, 1)
        # Calculate IOU
        iou_score += iou(pred, y)
        # Calculate DICE
        dice_score += dice(pred, y)

    writer.add_scalar('DICE Score', dice_score / len(dataloader), epoch)
    writer.add_scalar('IOU Score', iou_score / len(dataloader), epoch)

def get_iou_score(pred, y):
    iou_score = 0
    intersection = np.logical_and(y, pred)
    union = np.logical_or(y, pred)
    iou_score += np.sum(intersection) / np.sum(union)
    return iou_score

def get_dice_score(pred, y):
    dice_score = 0
    intersection = np.logical_and(y, pred)
    dice_score += np.sum(intersection) / (y.size() + pred.size())
    return dice_score

In [2]:
# Define constants

# directories
data_dir = 'deepglobe-land-cover-classification-dataset' # change to directory containing the data
trained_models = 'trained_models'
train_dir = 'train'
log_dir = 'runs'

In [3]:
# Training configuration (hyperparameters)

## Data
test_split = .2 #20% for test split
#valdation_split = .2 #20% for validation split
random_seed = np.random.seed()
shuffle_dataset = True

transform = alb.Compose([
    alb.RandomCrop(width=256, height=256),
    alb.HorizontalFlip(p=0.5),
    ToTensorV2()
    ],
    # we want the mask and the image to have the same augmentation (especially when we crop)
    # this way we pass the image and the mask simultaneously to the pipeline
    additional_targets={'image': 'image', 'mask': 'mask'}
    )

## model architecture
in_channels = 3
min_channels = 16
max_channels = 128
num_classes = 7

## Training
learning_rate = 0.1
batch_size = 10
epochs = 3

In [4]:
# Training

# setup training enviroment

#Labels


# init data loader/generator
train_dataloader, test_dataloader = get_data_loaders(data_dir, transform, shuffle_dataset, test_split, random_seed, batch_size)

# init model, optimizer and loss function
model = UNet(in_channels, min_channels, max_channels, num_classes)
opt = torch.optim.SGD(model.parameters(), learning_rate)
loss_func = torch.nn.CrossEntropyLoss() #Dice(num_classes=num_classes)

# Set up summary writer for tensorboard
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(os.path.join(log_dir, current_time))

# start training
for epoch in range(epochs):
    print(f"Epoch: {epoch}")
    for x, y in train_dataloader:
        opt.zero_grad() 
        pred = model(x)
        # TODO: check the output of the model. is it one hot encoded, rgb or else ?
        #pred = torch.argmax(pred, 1)
        pred = torch.softmax(pred, 1)
        #pred.requires_grad = True
        loss = loss_func(pred, y)
        loss.backward()
        opt.step()
        
    # TODO: this is only the loss from the last batch. we want the accumulated loss or something else.
    writer.add_scalar('Loss', loss, epoch)
    
    # Evaluate model
    model.eval()
    evaluate(model, writer, test_dataloader, epoch)
    model.train()

# Save model after training
torch.save(model.state_dict(), os.path.join(trained_models, current_time))

Epoch: 0
Batch
Batch
Epoch: 1
Batch
Batch
Epoch: 2
Batch
Batch


RuntimeError: Parent directory trained_models does not exist.

In [None]:
# Model evaluation

model = UNet(0, 0, 0, 0)
model.load_state_dict(torch.load(PATH))
model.eval()