In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from datasets.data_module import CloudCoverDataModule
from pathlib import Path
from models.segformer.lightning_module import LightningSegFormer

In [None]:
# Data locations and checkpoint used throughout this notebook.
TRAIN_X_DIR = Path("../../data/final/public/train_features/")
TRAIN_Y_DIR = Path("../../data/final/public/train_labels/")
TEST_X_DIR = Path("../../data/final/private/test_features/")
TEST_Y_DIR = Path("../../data/final/private/test_labels/")
CHECKPOINT_PATH = 'segformer_b5-epoch=44-val_loss=0.21.ckpt'

data_module = CloudCoverDataModule(
    train_X_folder_path=TRAIN_X_DIR,
    train_y_folder_path=TRAIN_Y_DIR,
    test_X_folder_path=TEST_X_DIR,
    test_y_folder_path=TEST_Y_DIR,
    train_batch_size=4,
    val_batch_size=3,
    test_batch_size=3,
    val_size=0.2,
    random_state=42,
)

# Only the "test" stage is set up: this notebook iterates the test dataloader.
data_module.prepare_data()
data_module.setup(stage="test")

segformer = LightningSegFormer.load_from_checkpoint(CHECKPOINT_PATH)


In [None]:
def denormalize_to_rgb(
    x: np.ndarray,
    mean: list = [0.485, 0.456, 0.406, 0.3568],
    std: list = [0.229, 0.224, 0.225, 0.2076],
    min_val: int = 0,
    max_val: int = 1
) -> np.ndarray:
    """Undo per-channel standardization and return the first three channels as uint8.

    Args:
        x: Standardized batch laid out channels-first, (batch, channels, H, W);
            ``mean``/``std`` are broadcast along axis 1, so they need one entry
            per channel of ``x``.
        mean: Per-channel means used for the original standardization. The
            defaults have four entries (presumably RGB plus one extra band —
            confirm against the data module).
        std: Per-channel standard deviations matching ``mean``.
        min_val: Lower end of the range the de-standardized values are mapped to.
        max_val: Upper end of that range (the result is then scaled by 255).

    Returns:
        uint8 array of shape (batch, 3, H, W) holding channels 0..2, with values
        clipped to [0, 255].
    """
    mean_arr = np.asarray(mean, dtype=np.float64).reshape(1, -1, 1, 1)
    std_arr = np.asarray(std, dtype=np.float64).reshape(1, -1, 1, 1)

    x = x * std_arr + mean_arr
    x = x * (max_val - min_val) + min_val

    # Clip before casting: de-standardized values routinely land slightly
    # outside [0, 1], and casting out-of-range floats to uint8 wraps around
    # (or is undefined), which shows up as speckled / inverted pixels.
    rgb = np.clip(x[:, :3, :, :] * 255, 0, 255)
    return rgb.astype(np.uint8)

In [None]:
def overlay_mask(img, mask, color=[255, 0, 0], alpha=0.5):
    """Return a copy of ``img`` with pixels where ``mask == 1`` blended toward ``color``.

    Only channels 0..2 of the last axis are blended; any further channels are
    copied unchanged. Blended values are written back in the image's own dtype
    (for uint8 images the fractional part is truncated).

    Args:
        img: Image laid out as (H, W, C) with C >= 3. Not modified.
        mask: Array broadcastable to (H, W); only pixels equal to 1 are tinted.
        color: Per-channel tint target; entries 0..2 are used.
        alpha: Weight given to ``color`` (0 keeps the pixel, 1 replaces it).
    """
    tinted = img.copy()
    hit = np.broadcast_to(mask == 1, tinted.shape[:2])

    for channel_idx in range(3):
        # View into the copy: assigning through it updates ``tinted`` in place.
        channel = tinted[:, :, channel_idx]
        channel[hit] = channel[hit] * (1 - alpha) + alpha * color[channel_idx]

    return tinted

In [None]:
def display_preds(X, y, y_hat):
    """Show each sample beside its predicted and ground-truth mask overlays.

    One figure with three panels is drawn per sample: the raw image, the
    prediction overlay (red) and the ground-truth overlay (green).

    Args:
        X: Image batch laid out channels-first; each ``X[i]`` is transposed from
            (C, H, W) to (H, W, C) for plotting (presumably uint8 RGB — confirm
            against the caller).
        y: Ground-truth masks indexed per sample; cast to uint8 before overlaying.
        y_hat: Predicted masks indexed per sample; cast to uint8 before overlaying.
    """
    # Release PyTorch's cached CUDA allocator memory before plotting.
    torch.cuda.empty_cache()

    for sample_idx in range(X.shape[0]):
        image = X[sample_idx].transpose((1, 2, 0))
        pred_mask = y_hat[sample_idx].astype(np.uint8)
        true_mask = y[sample_idx].astype(np.uint8)

        panels = (
            (image, "Original Image"),
            (overlay_mask(image, pred_mask, color=[255, 0, 0]), "Prediction Overlay"),
            (overlay_mask(image, true_mask, color=[0, 255, 0]), "Ground Truth Overlay"),
        )

        fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
        for ax, (panel_img, panel_title) in zip(axes, panels):
            ax.imshow(panel_img)
            ax.set_title(panel_title)
            ax.axis('off')

        plt.show()

In [None]:
# Number of test-dataloader batches to run through the model and plot.
MAX_BATCHES = 2

In [None]:
# Visualise model predictions on the first MAX_BATCHES batches of the test split.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model is never switched to inference mode elsewhere in this notebook;
# without eval() any dropout layers stay active and predictions vary run to run.
# Moving it explicitly guarantees weights and inputs share a device.
segformer.to(DEVICE)
segformer.eval()

data_loader = data_module.test_dataloader()

# Outputs are only visualised, so skip building the autograd graph
# (avoids holding activation memory for every batch).
with torch.no_grad():
    for batch_idx, (X, y) in enumerate(data_loader):
        if batch_idx >= MAX_BATCHES:
            break

        outputs = segformer(X.to(DEVICE))
        # Per-pixel class index; assumes outputs are per-class scores on dim 1,
        # i.e. (batch, classes, H, W) — TODO confirm against LightningSegFormer.
        y_hat = outputs.argmax(dim=1).cpu().numpy()

        X_rgb = denormalize_to_rgb(X.cpu().numpy())
        y_np = y.cpu().numpy()

        display_preds(X_rgb, y_np, y_hat)