In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import cv2
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch import optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pickle
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from itertools import product
import random

import segmentation_models_pytorch as smp

# Let cuDNN benchmark convolution algorithms and pick the fastest one.
torch.backends.cudnn.benchmark = True
# Gradient scaler used by the mixed-precision training loop. It is enabled only
# when CUDA is present, so CPU / MPS runs do not emit the "CUDA is not available"
# warning; a disabled scaler makes scale()/step()/update() plain pass-throughs.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())



In [2]:
# Pick the compute device: MPS is tried first, then CUDA; CPU is the fallback.
device = torch.device("cpu")
for backend_name, backend_is_available in (
    ("mps", torch.backends.mps.is_available),
    ("cuda", torch.cuda.is_available),
):
    if backend_is_available():
        device = torch.device(backend_name)
        break

print(f"Using {device}")

Using mps


In [None]:
# Notebook-wide configuration. Everything a reader might want to tune lives here.
run_test_cases = False # Set to True to run the debugging / visualisation cells
cloud_labels = ["Flower", "Gravel", "Fish", "Sugar"] # All cloud classes; list position is used as the mask channel index

# Original image resolution (rows = y = height, columns = x = width)
in_res_y = 1400
in_res_x = 2100

# Tiling: tiles are cut from the (padded) full image and fed to the network
tile_size = (512, 512) # (height, width) of one tile
stride = (512, 512) # (vertical, horizontal) step between tile origins

# data directories
test_dir = "./test_images" # not referenced by the cells shown in this notebook
train_dir = "./train_images" # source of both the training and the held-out images

# Model specifications
model_name = "pretrained" # [custom, pretrained]
num_filters = 32 # Number of filters in first conv layer (only for custom network)

# Training params
num_train_images = 5040 # first N usable images (sorted by file name) go to training
num_test_images = 256 # the next N usable images are held out for validation

num_workers = 0 # Number of worker processes handed to the DataLoaders

batch_size = 8
num_epochs = 10
lr = 6e-5 # AdamW learning rate
weight_decay = 2e-5 # AdamW weight decay

In [4]:
# Create Dataframe
df = pd.read_csv('train.csv')
# 'Image_Label' has the form '<file name>_<label>'. Split on the LAST underscore
# only, so a file name that itself contains '_' still yields exactly two columns.
df[['Image', 'Label']] = df['Image_Label'].str.rsplit('_', n=1, expand=True)

# Find images with at least one mask
df_nonempty = df.groupby("Image")["EncodedPixels"].apply(lambda x: x.notna().any()).reset_index()
df_nonempty = df_nonempty[df_nonempty["EncodedPixels"].astype(bool)]
valid_images = df_nonempty["Image"].unique().tolist()

# Set lookup keeps the filter linear in the number of files (list membership
# would rescan every valid name for every file on disk).
valid_image_set = set(valid_images)
image_names = [img for img in sorted(os.listdir(train_dir)) if img in valid_image_set]

# Deterministic split: the sorted file list is cut into a training part and a
# held-out part that the later cells use for validation.
train_images = image_names[:num_train_images]
test_images = image_names[num_train_images:num_train_images+num_test_images]

if run_test_cases:
    print(df[['Image', 'Label', 'EncodedPixels']].head(8))
    print()
    print(df['Image'].unique()[:10])

In [5]:
# Get labels and rle from image name
def get_labels_rle(image_name: str, df) -> tuple:
    """Look up the annotations stored for one image.

    Args:
        image_name: value to match against the 'Image' column.
        df: DataFrame with the columns 'Image', 'Label' and 'EncodedPixels'.

    Returns:
        A tuple ``(rles, labels)`` of two equally long lists, in row order:
        the 'EncodedPixels' entries (missing values are passed through as-is)
        and the matching 'Label' entries. Both lists are empty when the image
        does not occur in ``df``.
    """
    # Filter once and read both columns from the same selection.
    rows = df.loc[df['Image'] == image_name, ['EncodedPixels', 'Label']]
    rles = rows['EncodedPixels'].to_list()
    labels = rows['Label'].to_list()
    return rles, labels

# Debugging: print the raw annotations of the first few files in the training folder
if run_test_cases:

    # Get Files
    test_train_images = os.listdir(train_dir)[:4]
    print(f"Train images: {test_train_images}")

    for image in test_train_images:
        rles, labels = get_labels_rle(f"{image}", df)
        # rles and labels are parallel lists, so walk them by position
        for position in range(min(len(rles), len(labels))):
            rle, label = rles[position], labels[position]
            print(f"Label: {label} \n rle: {rle} \n")

In [6]:
# Convert rle mask encoding into 2D arrays
def rle_to_array(rle_list, shape=None) -> np.ndarray:
    """Decode one run-length-encoded mask into a 2D array.

    The encoding is a whitespace-separated string of
    ``start_0 length_0 start_1 length_1 ...`` where the start positions are
    1-based indices into the image flattened column by column.

    Args:
        rle_list: the encoded string, or a missing value (None, NaN, empty
            string) when the class is not present in the image.
        shape: optional ``(height, width)`` of the mask. Defaults to the
            notebook-wide original resolution ``(in_res_y, in_res_x)``.

    Returns:
        Array of shape ``(height, width)`` holding 1.0 inside the encoded runs
        and 0.0 elsewhere; all zeros for a missing encoding.
    """
    if shape is None:
        height, width = in_res_y, in_res_x
    else:
        height, width = shape

    # Flat pixel buffer, filled run by run and reshaped at the end
    flat_mask = np.zeros(height * width)

    # Skip if cloud formation is not on picture
    if not rle_list or pd.isna(rle_list):
        return flat_mask.reshape((height, width), order="F")

    rle_array = np.array(list(map(int, rle_list.split())), dtype=int)
    run_starts = rle_array[::2] - 1 # Offset because pixel 1 is arr position 0
    run_lengths = rle_array[1::2]

    # Mark every run; distinct loop names avoid shadowing the arrays above
    for run_start, run_length in zip(run_starts, run_lengths):
        flat_mask[run_start:run_start + run_length] = 1.0

    # The flat index runs down the columns first, i.e. Fortran order
    return flat_mask.reshape((height, width), order="F") # 2D array of [Height, Width]

# For debugging: show the first two files of the training folder with their decoded masks
if run_test_cases:

    # Get Files
    test_train_images = os.listdir(train_dir)[:2]
    print(f"Train images: {test_train_images}")

    # Plot files
    for image_name in test_train_images:
        # OpenCV loads BGR; convert so matplotlib shows the right colours
        img = cv2.cvtColor(cv2.imread(f"{train_dir}/{image_name}"), cv2.COLOR_BGR2RGB)

        rles, labels = get_labels_rle(image_name, df)

        # Raw Image
        plt.imshow(img)
        plt.show()

        # One figure per annotated class
        for label, rle in zip(labels, rles):
            mask = rle_to_array(rle)
            print(np.unique(mask))
            print(f"titel: {label}")
            plt.imshow(mask, cmap="grey", vmin=0.0, vmax=1.0)
            plt.show()
        print("---------------------------------")

In [7]:
def dice_coef(preds, target, eps=1e-6):
    """Soft Dice coefficient computed from raw logits.

    Args:
        preds: logits; a sigmoid is applied inside this function.
        target: ground-truth tensor of the same shape as ``preds``.
        eps: smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor: the Dice score is computed per entry of the first two
        dimensions (summing over dimensions 2 and 3) and then averaged.
    """
    spatial_dims = (2, 3)  # inputs are laid out as [B, 4, H, W]

    probabilities = torch.sigmoid(preds)
    intersection = torch.sum(probabilities * target, dim=spatial_dims)
    denominator = probabilities.sum(dim=spatial_dims) + target.sum(dim=spatial_dims) + eps

    per_map_dice = (2. * intersection + eps) / denominator
    return per_map_dice.mean()

In [8]:
def dice_loss(preds, target, eps=1e-6):
    """Dice loss: one minus the soft Dice coefficient of the logits.

    Both tensors are cast to float32 before the coefficient is computed.
    """
    score = dice_coef(preds.float(), target.float(), eps=eps)
    return 1 - score

In [9]:
# Per-class weight for positive pixels, one entry per class.
# NOTE(review): presumably produced by the neg/pos pixel-ratio cell further down — confirm.
pos_weight=torch.tensor([4.84217556, 5.62148491, 4.92171107, 4.60515229], dtype=torch.float32, device=device)

def bce_loss(logits, targets, pos_weight=pos_weight):
    """Binary cross-entropy on logits with a per-class positive weight.

    Computes ``-(w_c * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))``
    averaged over all elements. The fused logits formulation is used instead of
    an explicit ``sigmoid`` followed by ``log(p + 1e-8)``, which loses precision
    and caps the loss for confident wrong predictions.

    Args:
        logits: raw model outputs, [B, C, H, W].
        targets: ground-truth masks, [B, C, H, W].
        pos_weight: tensor with one weight per class, [C].

    Returns:
        Scalar tensor with the mean weighted loss.
    """
    # Always evaluate the loss in float32, even under mixed precision
    logits = logits.float()
    targets = targets.float()

    # Broadcast pos_weight to (1, C, 1, 1) and make sure it sits next to the logits
    class_weight = pos_weight.to(device=logits.device, dtype=logits.dtype).view(1, -1, 1, 1)

    return F.binary_cross_entropy_with_logits(logits, targets, pos_weight=class_weight)

In [10]:
def loss_fn(preds, target, split=0.5, eps=1e-6):
    """Blend of Dice loss and weighted BCE loss.

    ``split`` is the weight of the Dice term; the BCE term gets ``1 - split``.
    """
    dice_term = split * dice_loss(preds, target, eps)
    bce_term = (1 - split) * bce_loss(preds, target)
    return dice_term + bce_term

In [11]:
# Augmentation pipeline applied to each training tile; image and mask are
# transformed together. Transforms run in the order listed, so Normalize and
# Resize see the already augmented tile and ToTensorV2 runs last.
# NOTE(review): the recorded cell output lists the ElasticTransform, GaussNoise
# and CoarseDropout calls — presumably warnings about arguments that the
# installed albumentations version does not accept. TODO confirm and rename.
train_transform = A.Compose([
    # --- geometric transforms ---
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Transpose(p=0.5),

    # --- Random crop + resize ---
    # (disabled experiment, kept for reference)
    #A.RandomResizedCrop(
    #    height=tile_size[0], width=tile_size[1],
    #    scale=(0.7, 1.0),  # crop 70–100% of image
    #    ratio=(0.9, 1.1),  # aspect ratio jitter
    #    p=0.5
    #),

    # --- photometric transforms ---
    # Two brightness/contrast jitters are stacked: a mild one here and a
    # stronger one after the elastic deformation.
    A.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05, p=0.5),
    A.HueSaturationValue(hue_shift_limit=3, sat_shift_limit=5, val_shift_limit=5, p=0.5),
    
    A.ElasticTransform(alpha=50, sigma=8, alpha_affine=8, p=0.7),
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.7),
    A.GaussNoise(var_limit=(5, 20), p=0.7),

    # --- spatial transforms ---
    A.ShiftScaleRotate(
        shift_limit=0.05,
        scale_limit=0.1,
        rotate_limit=15,
        border_mode=cv2.BORDER_REFLECT,
        p=0.5
    ),

    # --- Cutout / dropout ---
    # Dropped regions are set to 0 in the image and in the mask.
    A.CoarseDropout(
        max_holes=8, max_height=64, max_width=64,
        min_holes=1, min_height=16, min_width=16,
        fill_value=0, mask_fill_value=0, p=0.5
    ),

    # --- normalization ---
    # Same mean/std as the denormalize() helper later in the notebook.
    A.Normalize(mean=(0.485, 0.456, 0.406), 
                std=(0.229, 0.224, 0.225)),

    # --- ensure correct shape ---
    A.Resize(height=tile_size[0], width=tile_size[1]),

    # --- to tensor ---
    ToTensorV2(),
])

# Deterministic pipeline for the held-out tiles: no augmentation, only the same
# normalization, resize and tensor conversion that close the training pipeline.
val_transform = A.Compose([
    # --- normalization ---
    A.Normalize(mean=(0.485, 0.456, 0.406), 
                std=(0.229, 0.224, 0.225)),

    # --- ensure correct shape ---
    A.Resize(height=tile_size[0], width=tile_size[1]),

    # --- to tensor ---
    ToTensorV2(),
])



  A.ElasticTransform(alpha=50, sigma=8, alpha_affine=8, p=0.7),
  A.GaussNoise(var_limit=(5, 20), p=0.7),
  original_init(self, **validated_kwargs)
  A.CoarseDropout(


In [None]:
class ImageDataset(Dataset):
    """Dataset that serves fixed-size tiles cut from full images and their masks.

    At construction time every image's mask is decoded once to decide which
    tiles to keep: tiles whose mask sum exceeds ``EMPTY_TILE_THRESHOLD`` are
    always kept, the remaining ("empty") tiles are kept with probability
    ``keep_empty_prob``. Items are ``(image_tile, mask_tile)`` pairs where the
    mask has one channel per entry of ``cloud_labels``.

    Args:
        data_frame: annotation table passed to ``get_labels_rle``.
        img_dir: directory that contains the image files.
        image_names: file names (relative to ``img_dir``) to draw tiles from.
        tile_size: ``(height, width)`` of a tile.
        stride: ``(vertical, horizontal)`` step between tile origins.
        transform: optional albumentations-style callable taking
            ``image=`` / ``mask=`` and returning a dict with those keys.
        keep_empty_prob: probability of keeping an empty tile.
        seed: optional seed for the empty-tile sampling; ``None`` keeps the
            selection random from run to run.
    """

    # A tile counts as empty when its mask sums to this value or less.
    EMPTY_TILE_THRESHOLD = 50

    def __init__(
        self,
        data_frame,
        img_dir,
        image_names,
        tile_size=(512, 512),
        stride=(512, 512),
        transform=None,
        keep_empty_prob=0.1,  # 10% of empty tiles
        seed=None,
    ):
        self.img_dir = img_dir
        self.image_names = image_names
        self.tile_size = tile_size
        self.stride = stride
        self.data_frame = data_frame
        self.transform = transform
        self.keep_empty_prob = keep_empty_prob
        self.tiles_index = []
        # Private generator: sampling is reproducible when a seed is given and
        # does not disturb the global `random` state.
        self._rng = random.Random(seed)

        self._prepare_tile_indices()

    def pad_to_tile_size(self, image, mask):
        """Pad image & mask so height/width become multiples of tile_size.

        Padding is added at the bottom and right only. The image ([H, W, C]) is
        reflect-padded, the mask ([C, H, W]) is zero-padded.
        """
        h, w = image.shape[:2]
        th, tw = self.tile_size
        pad_h = (th - h % th) % th
        pad_w = (tw - w % tw) % tw

        image_padded = np.pad(
            image, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect"
        )
        mask_padded = np.pad(
            mask, ((0, 0), (0, pad_h), (0, pad_w)), mode="constant"
        )
        return image_padded, mask_padded

    def _build_mask(self, image_name, height, width):
        """Decode all annotations of one image into a [num_labels, H, W] float32 mask."""
        rles, labels = get_labels_rle(image_name, self.data_frame)
        mask = np.zeros((len(cloud_labels), height, width), dtype=np.float32)
        for label, rle in zip(labels, rles):
            if rle is not None and label in cloud_labels:
                mask[cloud_labels.index(label)] = rle_to_array(rle)
        return mask

    def _prepare_tile_indices(self):
        """Precompute tile indices and filter out empty ones (keep a random fraction)."""
        th, tw = self.tile_size
        sh, sw = self.stride
        h, w = in_res_y, in_res_x  # canonical image size (1400x2100)

        # Tile origins cover the image after padding it up to a tile multiple
        y_positions = range(0, h + ((th - h % th) % th), sh)
        x_positions = range(0, w + ((tw - w % tw) % tw), sw)
        tile_origins = list(product(y_positions, x_positions))

        for img_idx, img_name in enumerate(self.image_names):
            # --- Load full mask once ---
            mask = self._build_mask(img_name, h, w)

            # --- Iterate over tiles ---
            for y, x in tile_origins:
                # The mask is not padded here: slicing past the border simply
                # gives a smaller window, and zero padding would not change its
                # sum. This avoids building a padded dummy image per file.
                tile = mask[:, y : y + th, x : x + tw]
                if tile.sum() > self.EMPTY_TILE_THRESHOLD:
                    self.tiles_index.append((img_idx, y, x))
                elif self._rng.random() < self.keep_empty_prob:
                    # keep only a fraction of the empty tiles
                    self.tiles_index.append((img_idx, y, x))

    def __len__(self):
        """Number of tiles selected at construction time."""
        return len(self.tiles_index)

    def __getitem__(self, idx):
        """Load the image of tile ``idx``, cut the tile and apply the transform.

        Returns:
            ``(image_tile, mask_tile)``. Without a transform the image tile is a
            float tensor [3, th, tw] and the mask a float tensor
            [num_labels, th, tw]; with a transform the image is whatever the
            transform returns and the mask is moved from channels-last back to
            channels-first.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        img_idx, y, x = self.tiles_index[idx]
        image_name = self.image_names[img_idx]

        # --- Load image ---
        image_path = f"{self.img_dir}/{image_name}"
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # --- Load full mask ---
        mask = self._build_mask(image_name, image.shape[0], image.shape[1])

        # --- Pad too small images ---
        image, mask = self.pad_to_tile_size(image, mask)

        # --- Crop tile ---
        th, tw = self.tile_size
        image_tile = np.ascontiguousarray(image[y : y + th, x : x + tw], dtype=np.uint8)
        mask_tile = mask[:, y : y + th, x : x + tw]

        # --- Transformations ---
        if self.transform:
            # The transform expects the mask channels-last, like the image
            transformed = self.transform(
                image=image_tile,
                mask=mask_tile.transpose(1, 2, 0)
            )
            image_tile = transformed["image"]
            mask_tile = transformed["mask"].permute(2, 0, 1)
        else:
            image_tile = torch.from_numpy(image_tile).permute(2, 0, 1).float()
            mask_tile = torch.from_numpy(mask_tile).float()

        return image_tile, mask_tile



In [13]:
class ConvolutionBlock(nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> ReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Both convolutions share kernel size and padding.
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size, padding):
        super().__init__()

        stages = []
        for stage_in_channels in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(stage_in_channels, out_channels, conv_kernel_size, padding=padding)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU())

        self.conv_block = nn.Sequential(*stages)

    def forward(self, x):
        # Result is [batch, out_channels, H_out, W_out]
        return self.conv_block(x)


In [14]:
class PoolBlock(nn.Module):
    """Max-pooling wrapper; ``downsample`` is handed to ``nn.MaxPool2d`` as kernel size."""

    def __init__(self, downsample):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=downsample)

    def forward(self, x):
        pooled = self.pool(x)
        return pooled

In [15]:
class UpSampleBlock(nn.Module):
    """Bilinear upsampling followed by a 1x1 convolution.

    The spatial size grows by ``upsample`` while the channel count shrinks to
    ``channels // upsample``.
    """

    def __init__(self, channels, upsample):
        super().__init__()

        resize = nn.Upsample(scale_factor=upsample, mode='bilinear', align_corners=False)
        reduce_channels = nn.Conv2d(channels, channels // upsample, kernel_size=1)
        self.up_sample_block = nn.Sequential(resize, reduce_channels)

    def forward(self, x):
        return self.up_sample_block(x)

In [16]:
class Network(nn.Module):
    """Small U-Net style encoder/decoder with three resolution levels.

    Encoder: three ConvolutionBlocks, each followed by 2x max pooling.
    Bottleneck: a single 3x3 convolution. Decoder: three UpSampleBlocks, each
    concatenated with the matching encoder output and fused by a
    ConvolutionBlock. A 1x1 convolution produces the per-class logits.

    Args:
        filters: number of channels after the first encoder block; deeper
            levels use 2x, 4x and 8x this value.
        num_classes: number of output channels (logit maps).
        in_channels: number of input image channels.

    Input height and width should be divisible by 8: the three 2x poolings
    round down, and the upsampled maps must match the encoder outputs to be
    concatenated.
    """

    def __init__(self, filters, num_classes=4, in_channels=3):
        super().__init__()

        # ConvolutionBlock(in_channels, out_channels, conv_kernel_size, padding)
        self.encoder_layer_0 = ConvolutionBlock(in_channels, filters, 3, 1)
        self.encoder_layer_1 = ConvolutionBlock(filters, 2*filters, 3, 1)
        self.encoder_layer_2 = ConvolutionBlock(2*filters, 4*filters, 3, 1)

        self.pool_block_0 = PoolBlock(downsample=2)
        self.pool_block_1 = PoolBlock(downsample=2)
        self.pool_block_2 = PoolBlock(downsample=2)

        self.bottle_neck = nn.Conv2d(4*filters, 8*filters, 3, padding=1)

        # Decoder blocks take the concatenation of upsampled and encoder maps
        self.decoder_layer_2 = ConvolutionBlock(8*filters, 4*filters, 3, 1)
        self.decoder_layer_1 = ConvolutionBlock(4*filters, 2*filters, 3, 1)
        self.decoder_layer_0 = ConvolutionBlock(2*filters, filters, 3, 1)

        self.up_sample_block2 = UpSampleBlock(8 * filters, upsample=2)
        self.up_sample_block1 = UpSampleBlock(4 * filters, upsample=2)
        self.up_sample_block0 = UpSampleBlock(2 * filters, upsample=2)

        self.output_layer = nn.Conv2d(filters, num_classes, 1)

    def forward(self, x): # [Batch, Color, Height, Width]
        """Return logits of shape [B, num_classes, H, W] for an input [B, in_channels, H, W]."""

        enc0 = self.encoder_layer_0(x) # [B, filters, H, W]
        pool0 = self.pool_block_0(enc0) # [B, filters, H / 2, W / 2]

        enc1 = self.encoder_layer_1(pool0) # [B, 2 * filters, H / 2, W / 2]
        pool1 = self.pool_block_1(enc1) # [B, 2 * filters, H / 4, W / 4]

        enc2 = self.encoder_layer_2(pool1) # [B, 4 * filters, H / 4, W / 4]
        pool2 = self.pool_block_2(enc2) # [B, 4 * filters, H / 8, W / 8]

        bottle_neck = self.bottle_neck(pool2) # [B, 8 * filters, H / 8, W / 8]

        up2 = self.up_sample_block2(bottle_neck) # [B, 4 * filters, H / 4, W / 4]
        concat2 = torch.cat([up2, enc2], dim=1) # [B, 8 * filters, H / 4, W / 4]
        dec2 = self.decoder_layer_2(concat2) # [B, 4 * filters, H / 4, W / 4]

        up1 = self.up_sample_block1(dec2) # [B, 2 * filters, H / 2, W / 2]
        concat1 = torch.cat([up1, enc1], dim=1) # [B, 4 * filters, H / 2, W / 2]
        dec1 = self.decoder_layer_1(concat1) # [B, 2 * filters, H / 2, W / 2]

        up0 = self.up_sample_block0(dec1) # [B, filters, H, W]
        concat0 = torch.cat([up0, enc0], dim=1) # [B, 2 * filters, H, W]
        dec0 = self.decoder_layer_0(concat0) # [B, filters, H, W]

        logits = self.output_layer(dec0) # [B, num_classes, H, W]

        return logits

In [None]:
# Datasets and DataLoader
# persistent_workers and prefetch_factor are only legal when worker processes
# exist; DataLoader raises ValueError if they are passed with num_workers == 0.
# Page-locked memory only helps host-to-CUDA copies, so pin only on CUDA.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=(device.type == "cuda"),
)
if num_workers > 0:
    loader_kwargs.update(persistent_workers=True, prefetch_factor=5)

train_dataset = ImageDataset(df, train_dir, train_images, tile_size=tile_size, stride=stride, transform=train_transform)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# The held-out images also live in train_dir; they get the deterministic transform
test_dataset = ImageDataset(df, train_dir, test_images, tile_size=tile_size, stride=stride, transform=val_transform)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

KeyboardInterrupt: 

In [None]:
# Calculate pos weight for BCE
# set num_workers=0 and disable persistent_workers and prefetch_factor before calculating
# One-off helper, switched off by default: counts mask pixels per class over the
# training loader and prints the negative/positive ratio.
if False:
    pos_weight = np.zeros(4)
    neg_pix = np.zeros(4)
    pos_pix = np.zeros(4)

    for image, mask in train_dataloader:
        # Work on a NumPy copy of the mask batch
        mask = mask.cpu().numpy()

        # Reduce over batch, height and width; keep the class axis
        pos_pix = pos_pix + (mask == 1).sum(axis=(0, 2, 3))
        neg_pix = neg_pix + (mask == 0).sum(axis=(0, 2, 3))

    pos_weight = neg_pix / pos_pix
    print(pos_weight)

In [None]:
model = smp.Unet(
    encoder_name="efficientnet-b4",        # backbone
    encoder_weights="imagenet",            # use pretrained ImageNet weights
    in_channels=3,                         # RGB images
    classes=4,                             # 4 cloud types
    decoder_use_batchnorm=True
)


def _apply_injected_dropout(module, inputs, output):
    """Forward hook: pass a block's output through its injected dropout child(ren)."""
    for child_name, child in module.named_children():
        if child_name.startswith("dropout_"):
            output = child(output)
    return output


# Inject dropout layers into decoder blocks.
# add_module alone only registers the layer: a module's forward runs just the
# children it calls explicitly, so the dropout would never be applied. The
# forward hook routes the block output through it; because the layer is a
# registered child, model.train() / model.eval() still switch it on and off.
# NOTE(review): assumes the decoder blocks return a single tensor — confirm for
# the installed segmentation_models_pytorch version.
for i, block in enumerate(model.decoder.blocks):
    block.add_module(f"dropout_{i}", nn.Dropout2d(p=0.5))
    block.register_forward_hook(_apply_injected_dropout)

In [None]:
#choose model
# "custom" replaces the pretrained model built above with the hand-written
# Network; "pretrained" keeps it. Either way the model is moved to `device`.
if model_name == "custom":
    model = Network(num_filters).to(device)
elif model_name == "pretrained":
    model = model.to(device)
else:
    # Fail loudly: continuing would silently train a model that was never
    # moved to the selected device.
    raise ValueError(f"Unknown model_name {model_name!r}: choose 'custom' or 'pretrained'")

In [None]:
# AdamW over all model parameters, using the learning rate and weight decay
# from the configuration cell.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)
# Halve the learning rate once the monitored value has not improved for more
# than `patience` consecutive scheduler steps.
scheduler = ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True,
)



In [None]:
# Albumentations ImageNet stats
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def denormalize(img_tensor):
    """Undo the mean/std normalization of one image tensor for display.

    Takes a channels-first tensor, returns a channels-last NumPy array with
    values clipped to [0, 1].
    """
    channels_last = img_tensor.permute(1, 2, 0).cpu().numpy()
    restored = std * channels_last + mean  # undo normalization
    return np.clip(restored, 0, 1)

# Debugging: walk the training loader and show every tile with its class masks
if run_test_cases:

    for image, mask in train_dataloader:
        # Move batch tensors to device
        image, mask = image.to(device), mask.to(device)

        # Loop over images in the batch
        for img, msk in zip(image, mask):
            # A pixel sum of exactly zero is reported as an error
            pixel_sum = np.sum(img.cpu().numpy())
            print(pixel_sum)
            if pixel_sum == 0:
                print("Error")

            print("image")
            plt.figure()
            plt.imshow(denormalize(img))
            plt.show(block=False)   # prevents overwriting
            plt.pause(0.1)          # forces GUI flush

            # One plot per mask channel
            for m in msk:
                print("mask")
                plt.imshow(m.cpu().numpy(), cmap="grey", vmin=0.0, vmax=1.0)
                plt.show()


In [None]:
# Training loop
# Mixed precision is a CUDA feature here; on CPU / MPS the autocast context is
# switched off (the gradient scaler then acts as a pass-through).
use_amp = device.type == "cuda"

for epoch in range(num_epochs):

    # Training
    model.train()
    train_loss = 0.0

    # Dynamic weighting of Dice and BCE loss: the Dice weight falls linearly
    # from 1.0 in the first epoch to 0.5 in the last one. A single-epoch run
    # stays at 1.0 instead of dividing by zero.
    progress = epoch / (num_epochs - 1) if num_epochs > 1 else 0.0
    split = 1.0 - progress * 0.5

    for image, mask in train_dataloader:
        image, mask = image.to(device), mask.to(device)

        with torch.autocast(device_type="cuda", enabled=use_amp):
            preds = model(image)
            loss = loss_fn(preds, mask, split)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
    train_loss /= len(train_dataloader)

    # Validation
    model.eval()
    val_loss = 0.0
    dice = 0.0

    with torch.no_grad():
        for image, mask in test_dataloader:
            image, mask = image.to(device), mask.to(device)

            preds = model(image)
            loss = loss_fn(preds, mask, split)

            val_loss += loss.item()

            # Soft Dice on the sigmoid probabilities; .item() keeps the running
            # sum a plain float instead of a device tensor
            dice += dice_coef(preds, mask).item()
        val_loss /= len(test_dataloader)
        # NOTE(review): val_loss uses the epoch-dependent `split`, so values from
        # different epochs are not directly comparable — consider stepping the
        # scheduler on a fixed-weight loss instead.
        scheduler.step(val_loss)
        dice /= len(test_dataloader)

    # Print metrics
    print(f"Epoch: {epoch}")
    print(f"Train loss: {train_loss:.4f}")
    print(f"Val loss: {val_loss:.4f}")
    print(f"Dice coefficient: {dice:.4f}")
    print()



KeyboardInterrupt: 

In [None]:
%matplotlib inline

# Show the first tile of up to `num_plots` held-out batches: input image,
# ground-truth mask and predicted probability map of class channel 0.
num_plots = 10
count = 1  # 1-based number of the batch currently being plotted

save_img = 3  # batch number whose three figures are also written to disk

# Inference mode must be set explicitly: if the training cell was interrupted
# mid-epoch the model is still in train mode (dropout / batch-norm updates).
model.eval()

with torch.no_grad():
    for image, mask in test_dataloader:
        image, mask = image.to(device), mask.to(device)

        # The model returns logits; convert them to probabilities so the
        # 0..1 colour scale below is meaningful
        preds = torch.sigmoid(model(image).float())

        mask, preds = mask.cpu().numpy(), preds.cpu().numpy()

        print("Image")
        plt.imshow(denormalize(image[0,:,:,:]))
        if save_img == count:
            plt.savefig("cloud_image.png")
        plt.show()
        
        print("Mask")
        plt.imshow(mask[0,0,:,:], vmax=1.0, vmin=0.0, cmap="Greys")
        if save_img == count:
            plt.savefig("cloud_mask.png")
        plt.show()
        
        print("Prediction")
        plt.imshow(preds[0,0,:,:], vmax=1.0, vmin=0.0, cmap="Greys")
        if save_img == count:
            plt.savefig("cloud_pred.png")
        plt.show()
        print("##########################################")

        count += 1
        # Stop after exactly num_plots batches have been shown
        if count > num_plots:
            break