In [1]:
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large
from torch import nn
from torch import optim
import segmentation_models_pytorch as smp

from tqdm import tqdm
import pandas as pd
import os


In [2]:
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import tifffile

class GeoDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None, train=True, header=None):
        self.img_target = pd.read_csv(csv_file, header=header)
        self.root_dir = root_dir
        self.transform = transform
        self.train = train

    def __len__(self):
        return len(self.img_target)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.img_target.iloc[index, 0])
        image = tifffile.imread(img_path)

        label_path = os.path.join(self.root_dir, self.img_target.iloc[index, 1])
        y_label = tifffile.imread(label_path)

        image = image[:, :, :3]

        if self.transform:
            auggmented = self.transform(image = np.array(image), mask= np.array(y_label))
            image, mask = auggmented['image'], auggmented['mask']

        return image, mask+1


In [3]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 3
TRAIN_VAL_SPLIT = 0.8
HEIGHT = 512
WIDTH  = 512

batch_size = 8
num_epochs = 10

train_transforms = A.Compose([
    A.Resize(height=256, width=256, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=30, p=0.5),
    A.Normalize(mean=(0.1759, 0.2107, 0.2502), std=(0.3352, 0.3322, 0.3726)),
    ToTensorV2()
])


val_transforms = A.Compose([
    A.Resize(height=256, width=256, p=1.0),
    A.Normalize(mean=(0.0511, 0.0766, 0.0931), std=(0.0292, 0.0357, 0.0528)),
    ToTensorV2()
])

In [4]:
DEVICE

'cuda'

In [5]:
traing_dataset = GeoDataset("/content/drive/MyDrive/GeoAI/training_index.csv", root_dir="/content/drive/MyDrive/GeoAI/training",
                             transform=train_transforms, train=True)

train_loader = DataLoader(traing_dataset, batch_size=batch_size, shuffle=True)


In [6]:
val_dataset = GeoDataset("/content/drive/MyDrive/GeoAI/validation_index.csv", root_dir="/content/drive/MyDrive/GeoAI/validation",
                             transform=val_transforms, train=False)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [7]:

def f1_dice_score(preds, true_mask):
    '''
    https://towardsdatascience.com/metrics-to-evaluate-your-semantic-segmentation-model-6bcb99639aa2
    preds should be (B, 25, H, W)
    true_mask should be (B, H, W)
    '''

    f1_batch = []

    for i in range(len(preds)):
        f1_image = []
        img = preds[i].to(DEVICE)
        mask = true_mask[i].to(DEVICE)

        # Change shape of img from [25, H, W] to [H, W]
        img = torch.argmax(img, dim=0)

        for label in range(3):
            if torch.sum(mask == label) != 0:
                area_of_intersect = torch.sum((img == label) * (mask == label))
                area_of_img = torch.sum(img == label)
                area_of_label = torch.sum(mask == label)
                f1 = 2 * area_of_intersect / (area_of_img + area_of_label)
                f1_image.append(f1)

        f1_batch.append(np.mean([tensor.cpu() for tensor in f1_image]))
    return np.mean(f1_batch)

In [8]:

def accuracy(preds, true_mask):
    '''
    preds should be (B, C, H, W)
    true_mask should be (B, H, W)
    '''
    # Ensure preds and true_mask are on the same device
    preds = preds.to(true_mask.device)

    # Calculate the number of classes (number of channels in preds)
    num_classes = preds.size(1)

    accuracy_batch = []

    for i in range(len(preds)):
        img = torch.argmax(preds[i], dim=0)
        mask = true_mask[i]

        class_accuracy = []

        for c in range(num_classes):
            # Calculate accuracy for each class separately
            class_pred = (img == c)
            class_mask = (mask == c)
            class_correct = (class_pred & class_mask).sum().item()
            class_total = class_mask.sum().item()

            # Avoid division by zero
            if class_total > 0:
                class_accuracy.append(class_correct / class_total)
            else:
                class_accuracy.append(0.0)

        # Average accuracy over classes
        average_accuracy = sum(class_accuracy) / num_classes

        accuracy_batch.append(average_accuracy)

    return torch.mean(torch.tensor(accuracy_batch))


In [9]:
def train():
    min_val_f1 = 0.3

    for epoch in range(STARTING_EPOCH + 1, STARTING_EPOCH + EPOCHS + 1):

        # Train model
        model.train()
        train_losses = []
        train_accuracy = []
        train_f1 = []

        for i, batch in enumerate(train_loader):
            # Extract data, labels
            img_batch, mask_batch = batch  # img [B,3,H,W], mask[B,H,W]

            # Train model
            optimizer.zero_grad()
            # with torch.cuda.amp.autocast():
            output = model(img_batch.to(DEVICE))  # output: [B, 25, H, W]

            pred = output
            pred = pred.to('cpu')

            mask_batch = mask_batch.type(torch.LongTensor).to('cpu')

            loss = criterion(pred, mask_batch)
            loss.backward()

            # torch.nn.utils.clip_grad_norm_(model.parameters(), 6)
            optimizer.step()

            # Add current loss to temporary list (after 1 epoch take avg of all batch losses)
            f1 = f1_dice_score(output, mask_batch)
            acc = accuracy(output, mask_batch)
            train_losses.append(loss.item())
            train_accuracy.append(acc)
            train_f1.append(f1)
            # print(f'Train Epoch: {epoch}, batch: {i} | Batch metrics | loss: {loss.item():.4f}, f1: {f1:.3f}, accuracy: {acc:.3f}')

        # Update global metrics
        print(
            f'TRAIN       Epoch: {epoch} | Epoch metrics | loss: {np.mean(train_losses):.4f}, f1: {np.mean(train_f1):.3f}, accuracy: {np.mean(train_accuracy):.3f}')
        total_train_losses.append(np.mean(train_losses))
        total_train_accuracy.append(np.mean(train_accuracy))
        total_train_f1.append(np.mean(train_f1))

        # Validate model
        model.eval()
        val_losses = []
        val_accuracy = []
        val_f1 = []

        for i, batch in enumerate(val_loader):
            # Extract data, labels
            img_batch, mask_batch = batch
            img_batch = img_batch.to(DEVICE)
            mask_batch = mask_batch.to(DEVICE)

            # Validate model
            with torch.cuda.amp.autocast():



                output = model(img_batch)
                pred = output.to('cuda')
                mask_batch = mask_batch.type(torch.LongTensor).to('cuda')
                loss = criterion(pred, mask_batch)

            # Add current loss to temporary list (after 1 epoch take avg of all batch losses)
            f1 = f1_dice_score(output, mask_batch)
            acc = accuracy(output, mask_batch)
            val_losses.append(loss.item())
            val_accuracy.append(acc)
            val_f1.append(f1)

            # print(f'Val Epoch: {epoch}, batch: {i} | Batch metrics | loss: {loss.item():.4f}, f1: {f1:.3f}, accuracy: {acc:.3f}')

        # Update global metrics
        print(
            f'VALIDATION  Epoch: {epoch} | Epoch metrics | loss: {np.mean(val_losses):.4f}, f1: {np.mean(val_f1):.3f}, accuracy: {np.mean(val_accuracy):.3f}')
        print('---------------------------------------------------------------------------------')
        total_val_losses.append(np.mean(val_losses))
        total_val_accuracy.append(np.mean(val_accuracy))
        total_val_f1.append(np.mean(val_f1))

        # Save the model
        if np.mean(val_f1) > min_val_f1:
            torch.save(model.state_dict(),
                       f'/mm/{epoch}.pt')
            min_val_f1 = np.mean(val_f1)

        # Save the results so far
        temp_df = pd.DataFrame(list(zip(total_train_losses, total_val_losses, total_train_f1, total_val_f1,
                                        total_train_accuracy, total_val_accuracy)),
                               columns=['train_loss', 'val_loss', 'train_f1', 'test_f1', 'train_accuracy',
                                        'test_accuracy'])
        temp_df.to_csv('train_val_measures')



In [10]:
model = smp.PSPNet(
    encoder_name = 'resnet101',
    encoder_weights = 'imagenet',
    classes = 3,
    activation = None, # could be None for logits or 'softmax2d' for multiclass segmentation
).to(DEVICE)

In [11]:
total_train_losses   = []
total_val_losses     = []
total_train_accuracy = []
total_val_accuracy   = []
total_train_f1       = []
total_val_f1         = []


# HYPERPARAMETERS for training run
STARTING_EPOCH = 10
EPOCHS = 20
LR = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss().to(DEVICE)

train()


TRAIN       Epoch: 11 | Epoch metrics | loss: 0.4168, f1: 0.543, accuracy: 0.446
VALIDATION  Epoch: 11 | Epoch metrics | loss: 25.6683, f1: 0.420, accuracy: 0.333
---------------------------------------------------------------------------------


KeyboardInterrupt: ignored