<a href="https://colab.research.google.com/github/JericN/rice-disease-classifier/blob/main/multispectral/dataset_upload.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
! pip --quiet install datasets rasterio

In [21]:
from pathlib import Path
import numpy as np
import rasterio
from datasets import Dataset, DatasetDict
from tqdm import tqdm

In [22]:
import os
from google.colab import drive
import zipfile

import shutil

# Mount Google Drive
drive.mount('/content/drive')
zip_path = '/content/drive/Shareddrives/CS198-Drones/Multispectral/D2.zip'

# Extract the ZIP file (skipped when the target directory already exists)
extract_path = '/content/dataset'
if not os.path.exists(extract_path):
    # Open the archive BEFORE creating the directory: a missing or invalid ZIP
    # then fails without leaving an empty `extract_path` behind, which would
    # make every later run skip extraction.
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        os.makedirs(extract_path)
        try:
            zip_ref.extractall(extract_path)
        except BaseException:
            # Remove the partially extracted tree so a re-run starts clean
            # instead of treating the partial directory as a finished extraction.
            shutil.rmtree(extract_path, ignore_errors=True)
            raise

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [23]:
# Split folder names processed by the notebook, in processing order.
splits = ["train", "test", "validation"]
# Root of the extracted dataset; split folders are resolved relative to it.
data_dir = Path("/content/dataset/D2")

In [24]:
def load_tif_image(image_path):
    """Read all bands of a .tif image and return them as a float32 array.

    Args:
        image_path: Path of the raster file, passed to `rasterio.open`.

    Returns:
        The array returned by `src.read()` cast to `np.float32`
        (shape: (bands, height, width)).
    """
    with rasterio.open(image_path) as raster:
        all_bands = raster.read()
    return all_bands.astype(np.float32)

def load_label_mask(label_path):
    """Read the first band of a label raster and return it as a uint8 array.

    Args:
        label_path: Path of the mask file, passed to `rasterio.open`.

    Returns:
        Band 1 of the raster cast to `np.uint8` (shape: (height, width)).
    """
    with rasterio.open(label_path) as raster:
        first_band = raster.read(1)
    return first_band.astype(np.uint8)

In [25]:
def preprocess_image(image):
    """Scale a multispectral image by its global maximum and return float32.

    The image is converted to a float32 array and divided by one maximum taken
    over all of its elements (not per band), so non-negative inputs end up in
    [0, 1].

    Args:
        image: Array-like of numeric pixel values.

    Returns:
        A new float32 array. When the maximum is 0 (e.g. an all-zero patch)
        the converted array is returned unscaled, because dividing by zero
        would fill it with NaN/inf values.

    Raises:
        ValueError: If `image` is empty (raised by `max()` on an empty array).
    """
    image = np.array(image, dtype=np.float32)
    peak = image.max()
    if peak == 0:
        # Nothing to scale; avoids 0/0 -> NaN for blank patches.
        return image
    return image / peak  # Normalize

def preprocess_label(label):
    """Return a copy of `label` as an int64 NumPy array.

    Args:
        label: Array-like label mask.

    Returns:
        A new `np.int64` array with the same values as `label`.
    """
    label_int64 = np.array(label, dtype=np.int64)
    return label_int64

In [26]:
def create_dataset(split):
    """Build a `Dataset` of preprocessed image/mask pairs for one split.

    Images are read from `data_dir / split` and masks from
    `data_dir / f"{split}_labels"`. Each image is paired with the mask named
    `label_patch_<n>.tif`, where `<n>` is the text after the last underscore
    in the image's file stem. Images without a matching mask are skipped with
    a warning.

    Only regular files with a `.tif`/`.tiff` extension are treated as images,
    so sidecar files, OS metadata files and subdirectories inside the image
    folder are ignored instead of being handed to the raster reader.

    Args:
        split: Name of the split folder (e.g. "train").

    Returns:
        A `Dataset` with an "image" column and a "label" column, in sorted
        image-path order. It is empty (after a warning) when nothing loaded.

    Raises:
        FileNotFoundError: If the image or the mask directory does not exist.
    """
    images, masks = [], []
    image_dir = data_dir / split
    mask_dir = data_dir / f"{split}_labels"

    # Check if subdirectories exist
    if not image_dir.exists() or not mask_dir.exists():
        raise FileNotFoundError(f"Missing '{image_dir}' or '{mask_dir}'. Please check your dataset structure.")

    # Get all image files (regular .tif/.tiff files only, case-insensitive)
    image_files = sorted(
        path for path in image_dir.iterdir()
        if path.is_file() and path.suffix.lower() in (".tif", ".tiff")
    )

    for img_file in tqdm(image_files, desc=f"Processing {split}"):
        # Extract patch number from image filename
        patch_number = img_file.stem.split("_")[-1]  # Extract '3' from 'image_patch_3'
        mask_file = mask_dir / f"label_patch_{patch_number}.tif"  # Expected mask filename

        if not mask_file.exists():
            print(f"WARNING: No mask found for {img_file.name}")
            continue

        # Load and preprocess images/masks
        images.append(preprocess_image(load_tif_image(img_file)))
        masks.append(preprocess_label(load_label_mask(mask_file)))

    if not images:
        print(f"WARNING: No images were loaded for {split}")

    return Dataset.from_dict({
        "image": images,  # Shape: (bands, height, width)
        "label": masks      # Shape: (height, width)
    })


In [27]:
# Build one Dataset per split (keyed by split name), then wrap them in a DatasetDict
datasets = dict(zip(splits, map(create_dataset, splits)))
dataset_dict = DatasetDict(datasets)

  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
Processing train: 100%|██████████| 1299/1299 [00:12<00:00, 107.63it/s]
Processing test: 100%|██████████| 281/281 [00:02<00:00, 131.24it/s]
Processing validation: 100%|██████████| 277/277 [00:02<00:00, 126.80it/s]


In [28]:
# Display the DatasetDict summary: feature names and row count for each split.
dataset_dict

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 1299
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 281
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 277
    })
})

In [29]:
import numpy as np

# Function to print dataset info
def print_dataset_info(dataset_dict):
    """Print a short summary of every split in `dataset_dict`.

    For each split this prints the number of samples and, for the FIRST sample
    only, the shape and dtype of its image and label plus the unique values
    found in that one label. The class list therefore does not describe the
    whole split.

    NOTE(review): the dtypes printed are those NumPy infers when converting
    the returned sample back to an array, which may differ from the dtype the
    data was created with — confirm against the dataset's features if the
    stored dtype matters.

    Args:
        dataset_dict: Mapping of split name to a sequence of samples, where
            each sample is a mapping with "image" and "label" entries.

    Returns:
        None. Output is written with `print`.
    """
    for split, dataset in dataset_dict.items():
        print(f"--- {split.upper()} SET ---")
        print(f"Number of samples: {len(dataset)}")

        # An empty split has no first sample; indexing it would raise IndexError.
        if len(dataset) == 0:
            print("No samples to inspect.")
            print("-" * 40)
            continue

        # Get first sample
        sample = dataset[0]
        image = np.array(sample["image"])  # Convert list to NumPy array
        label = np.array(sample["label"])  # Convert list to NumPy array

        # Print shape and data type
        print(f"Image shape: {image.shape} | dtype: {image.dtype}")
        print(f"Label shape: {label.shape} | dtype: {label.dtype}")

        # Print unique label classes (of the first sample only)
        unique_classes = np.unique(label)
        print(f"Unique classes in label: {unique_classes}")

        print("-" * 40)

# Print dataset information: per-split sample count plus the shape, dtype and
# label values of the first sample in each split.
print_dataset_info(dataset_dict)


--- TRAIN SET ---
Number of samples: 1299
Image shape: (6, 256, 256) | dtype: float64
Label shape: (256, 256) | dtype: int64
Unique classes in label: [0 1]
----------------------------------------
--- TEST SET ---
Number of samples: 281
Image shape: (6, 256, 256) | dtype: float64
Label shape: (256, 256) | dtype: int64
Unique classes in label: [1 4]
----------------------------------------
--- VALIDATION SET ---
Number of samples: 277
Image shape: (6, 256, 256) | dtype: float64
Label shape: (256, 256) | dtype: int64
Unique classes in label: [4]
----------------------------------------


In [31]:
# Upload every split in `dataset_dict` to the Hub dataset repo "SodaXII/blb-ms-02".
# NOTE(review): presumably requires Hugging Face credentials with write access
# to be configured before this cell runs — confirm.
dataset_dict.push_to_hub("SodaXII/blb-ms-02")

Uploading the dataset shards:   0%|          | 0/6 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/SodaXII/blb-ms-02/commit/865fd966c16df44cb613774986bccfae5e49ed95', commit_message='Upload dataset', commit_description='', oid='865fd966c16df44cb613774986bccfae5e49ed95', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/SodaXII/blb-ms-02', endpoint='https://huggingface.co', repo_type='dataset', repo_id='SodaXII/blb-ms-02'), pr_revision=None, pr_num=None)