In [None]:
import geopandas as gpd
import leafmap
from shapely.ops import unary_union
from shapely.geometry import Point, mapping, box, shape
import shapely
import json
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
from einops import rearrange
import sys
sys.path.append("..")

from src.models.datamodule import MineDataModule

In [None]:

# Step up one level from the directory the kernel started in.
# NOTE: this is not idempotent - re-running the cell moves up another level.
os.chdir("..")

# Project root: parent of the new working directory plus the workspace
# sub-path. The original note ties this suffix to Lightning Studios; comment
# the suffix out when running in a different environment.
root = os.path.dirname(os.getcwd()) + "/workspaces/mine-segmentation"
root

In [None]:
# --- Chip / label directories per split (relative to the project root) ---
TRAIN_CHIP_DIR = "data/processed/chips/npy/512/train/chips/"
TRAIN_LABEL_DIR = "data/processed/chips/npy/512/train/labels/"
VAL_CHIP_DIR = "data/processed/chips/npy/512/val/chips/"
VAL_LABEL_DIR = "data/processed/chips/npy/512/val/labels/"
TEST_CHIP_DIR = "data/processed/chips/npy/512/test/chips/"
TEST_LABEL_DIR = "data/processed/chips/npy/512/test/labels/"

# --- Data module settings ---
METADATA_PATH = "configs/cnn/cnn_segment_metadata.yaml"
BATCH_SIZE = 1
NUM_WORKERS = 4
PLATFORM = "sentinel-2-l2a"

# When set, every relative path above is anchored at `root`.
is_lightning = True
if is_lightning:
    METADATA_PATH = f"{root}/{METADATA_PATH}"
    TRAIN_CHIP_DIR = f"{root}/{TRAIN_CHIP_DIR}"
    TRAIN_LABEL_DIR = f"{root}/{TRAIN_LABEL_DIR}"
    VAL_CHIP_DIR = f"{root}/{VAL_CHIP_DIR}"
    VAL_LABEL_DIR = f"{root}/{VAL_LABEL_DIR}"
    TEST_CHIP_DIR = f"{root}/{TEST_CHIP_DIR}"
    TEST_LABEL_DIR = f"{root}/{TEST_LABEL_DIR}"

### Check number of chips

In [None]:
# List the chip files of each split (train, val, test - in that order).
train_chip_files, val_chip_files, test_chip_files = (
    os.listdir(split_dir)
    for split_dir in (TRAIN_CHIP_DIR, VAL_CHIP_DIR, TEST_CHIP_DIR)
)

# Cell output: number of chips per split as (train, val, test).
tuple(
    len(names)
    for names in (train_chip_files, val_chip_files, test_chip_files)
)

### Plot chips, including data transformations

In [None]:
import matplotlib.pyplot as plt

# Gather the constructor arguments in one mapping so they can be inspected
# before the data module is built.
datamodule_kwargs = dict(
    train_chip_dir=TRAIN_CHIP_DIR,
    train_label_dir=TRAIN_LABEL_DIR,
    val_chip_dir=VAL_CHIP_DIR,
    val_label_dir=VAL_LABEL_DIR,
    test_chip_dir=TEST_CHIP_DIR,
    test_label_dir=TEST_LABEL_DIR,
    metadata_path=METADATA_PATH,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    platform=PLATFORM,
    data_augmentation=True,
)
datamodule = MineDataModule(**datamodule_kwargs)

# Run the module's setup for the "fit" stage.
datamodule.setup(stage="fit")

In [None]:
# Draw a single batch from the training loader.
batch = next(iter(datamodule.train_dataloader()))

# Report the batch shapes. A bare expression in the middle of a cell is not
# displayed, so print it explicitly.
print(batch["pixels"].shape, batch["label"].shape)

# Take the first sample of the batch. Indexing (rather than squeeze(0)) also
# works when BATCH_SIZE > 1, where squeeze(0) would leave the batch axis in
# place and break the rearrange below.
pixels = batch["pixels"][0]
label = batch["label"][0]

# Move channels last for imshow. The identity pattern on the label changes
# nothing; it only verifies that the label is 2-D (einops raises otherwise).
# NOTE(review): imshow needs 3 or 4 channels here - confirm the chip layout.
pixels = rearrange(pixels, "c h w -> h w c")
label = rearrange(label, "h w -> h w")

# Min-max scale to [0, 1]. A constant chip has zero range; skip the division
# in that case so the image becomes all zeros instead of NaN.
pixel_min = pixels.min()
pixel_range = pixels.max() - pixel_min
pixels = pixels - pixel_min
if pixel_range > 0:
    pixels = pixels / pixel_range

# Image and label side by side.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].imshow(pixels)
axes[0].axis("off")
axes[0].set_title("Image", fontsize=12)

axes[1].imshow(label, cmap="viridis")
axes[1].axis("off")
axes[1].set_title("Label", fontsize=12)

plt.tight_layout()
plt.show()

## Plot multiple chips

In [None]:
# Pick one training chip by its position in the directory listing.
file_index = 10
files = os.listdir(TRAIN_CHIP_DIR)
filename = f"{TRAIN_CHIP_DIR}/{files[file_index]}"

# Load it, report its shape, and show the array as the cell output.
img = np.load(filename)
print(img.shape)
img

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize

def plot_images_and_masks(root, seed=0, n_samples=5):
    """
    Plot randomly selected training chips next to their masks.

    Chips are read from ``TRAIN_CHIP_DIR`` and masks from ``TRAIN_LABEL_DIR``
    (module-level constants). The mask of a chip is found by replacing
    ``"_img"`` with ``"_mask"`` in the chip's file name. Both arrays are
    downsampled by a factor of two before plotting.

    Parameters:
    - root (str): Unused; kept so that existing calls keep working.
    - seed (int): Seed for the random selection. This reseeds NumPy's
      global random state.
    - n_samples (int): Number of chip/mask pairs to plot. Capped at the
      number of files in the chip directory.

    Returns:
    None

    Raises:
    ValueError: If ``n_samples`` is smaller than 1 or the chip directory
        contains no files.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    chips_dir = TRAIN_CHIP_DIR
    masks_dir = TRAIN_LABEL_DIR

    # os.listdir returns entries in arbitrary order; sort so that a given
    # seed always selects the same files.
    files = sorted(os.listdir(chips_dir))
    if not files:
        raise ValueError(f"No chip files found in {chips_dir}")

    n_rows = min(n_samples, len(files))

    # squeeze=False keeps `axs` two-dimensional even for a single row.
    fig, axs = plt.subplots(n_rows, 2, figsize=(10, 5 * n_rows), squeeze=False)

    # Draw distinct random positions in the sorted file list.
    np.random.seed(seed)
    indices = np.random.choice(len(files), n_rows, replace=False).tolist()
    print(indices)

    for row, file_index in enumerate(indices):
        chip_name = files[file_index]
        img = np.load(os.path.join(chips_dir, chip_name))

        # assumes chips are stored channel-first (C, H, W) - TODO confirm
        im2display = img.transpose((1, 2, 0))

        # Min-max scale to [0, 1]; a constant chip has zero range, so fall
        # back to zeros instead of dividing by zero.
        value_min = im2display.min()
        value_range = im2display.max() - value_min
        if value_range > 0:
            im2display = (im2display - value_min) / value_range
        else:
            im2display = np.zeros(im2display.shape, dtype=float)
        im2display = np.clip(im2display, 0, 1)

        mask_path = os.path.join(masks_dir, chip_name.replace("_img", "_mask"))
        mask = np.load(mask_path).squeeze()

        resized_img = resize(
            im2display, (im2display.shape[0] // 2, im2display.shape[1] // 2)
        )
        # Nearest-neighbour without anti-aliasing and with the original value
        # range: interpolating a label mask would invent in-between values
        # along the mask boundaries.
        resized_mask = resize(
            mask,
            (mask.shape[0] // 2, mask.shape[1] // 2),
            order=0,
            anti_aliasing=False,
            preserve_range=True,
        )

        # The date is taken from the first 8 characters of the 4th
        # underscore-separated field of the file name (read as YYYYMMDD).
        # Fall back to the file name when that field is missing or malformed.
        name_parts = chip_name.split("_")
        stamp = name_parts[3][:8] if len(name_parts) > 3 else ""
        if len(stamp) == 8 and stamp.isdigit():
            image_title = f"Image from {stamp[:4]}-{stamp[4:6]}-{stamp[6:]}"
        else:
            image_title = chip_name

        axs[row, 0].imshow(resized_img)
        axs[row, 0].set_title(image_title)

        axs[row, 1].imshow(resized_mask)
        axs[row, 1].set_title("Mask")

    plt.tight_layout()
    plt.show()

# Render the chip/mask grid; the seed drives the random chip selection.
plot_images_and_masks(root=root, seed=5)