In [1]:
import os
import numpy as np
import pandas as pd
import torch
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
from datetime import datetime
print("OK")

OK


In [2]:
# === Configuration ===
# Root of the raw TimeSen2Crop dataset: one sub-directory per tile, each holding
# a dates.csv plus one numerically named folder per class.
# NOTE: machine-specific absolute path — adjust when running elsewhere.
RAW_ROOT = "/scratch/users/dtran/croptype/dataset/TimeSen2Crop"
SAVE_DIR = "./"   # directory receiving the *_X.pt / *_y.pt / *_metadata.csv outputs
N_BANDS = 9       # number of band columns kept per sample
N_MONTHS = 12     # one composite per calendar month
# Positional indices of the band columns in each sample CSV. Derived from
# N_BANDS so the two settings cannot drift apart.
# Original note: "B1-B9 → B2-B12" (presumably CSV columns B1..B9 hold the
# Sentinel-2 bands B2..B12 — TODO confirm against the dataset documentation).
BAND_INDICES = list(range(N_BANDS))
FLAG_COL = 'Flag'  # quality-flag column; only rows with flag == 0 are used
# Tiles are split by geography: 13 for training, 1 for validation, 1 for test.
TRAIN_TILES = ['32TNT', '32TPT', '32TQT', '33TUM', '33TUN', '33TVM', '33TVN',
               '33TWM', '33TXN', '33UUP', '33UWP', '33UWQ', '33UXP']
VAL_TILE = ['33TWN']
TEST_TILE = ['33UVP']
print("OK")

OK


# Preprocessing dataset

+ Convert each sample's time series to 12 monthly composites, using the per-band median of the clear acquisitions (Flag = 0) within each month.

+ Remove cloudy/snowy/shadowed values.

+ If no clear value exists in a month, that month’s data is set to zero.

Output (data for train, test, and validation sets):

+ Clean data: a tensor of shape [N, 12, 9] (samples × months × bands) holding the monthly median of each band
    
+ Labels: the class (0, ..., 15) that each sample (pixel) belongs to

In [6]:
# === Helper Functions ===
def extract_monthly_median(data: pd.DataFrame, dates: list) -> np.ndarray:
    """Collapse one sample's time series into per-month median band values.

    Parameters
    ----------
    data : pd.DataFrame
        One row per acquisition. Band values are taken positionally from the
        columns listed in ``BAND_INDICES``; the quality flag is read from the
        ``FLAG_COL`` column.
    dates : list
        Acquisition dates, one per row of ``data``; each must parse as
        ``%Y%m%d`` after conversion to ``str``.

    Returns
    -------
    np.ndarray
        ``float32`` array of shape ``(N_MONTHS, N_BANDS)``. Row ``m - 1`` holds
        the band-wise median over the acquisitions of calendar month ``m``
        whose flag equals 0; a month without such an acquisition stays all-zero.
    """
    band_values = data.iloc[:, BAND_INDICES].values
    quality = data[FLAG_COL].values
    month_of_row = [datetime.strptime(str(stamp), "%Y%m%d").month for stamp in dates]

    # Bucket the usable (flag == 0) row positions by calendar month.
    clear_rows = {month_no: [] for month_no in range(1, 13)}
    for row, month_no in enumerate(month_of_row):
        if quality[row] == 0:
            clear_rows[month_no].append(row)

    # The array starts at zero, so months with an empty bucket need no work.
    composites = np.zeros((N_MONTHS, N_BANDS), dtype=np.float32)
    for month_no, rows in clear_rows.items():
        if rows:
            composites[month_no - 1] = np.median(band_values[rows], axis=0)
    return composites

def build_label_map(root: str) -> dict:
    """Collect the class labels found under every tile directory of ``root``.

    Each entry of ``root`` that is a directory is treated as a tile. Inside a
    tile, every entry whose name consists only of digits is a class; its name
    maps to the integer it spells. Tiles and classes are visited in sorted name
    order, so keys appear in first-seen order.

    Parameters
    ----------
    root : str
        Dataset root directory.

    Returns
    -------
    dict
        ``{class_name: int(class_name)}`` for every class name seen.
    """
    tile_dirs = (os.path.join(root, entry) for entry in sorted(os.listdir(root)))
    return {
        name: int(name)
        for tile_dir in tile_dirs
        if os.path.isdir(tile_dir)
        for name in sorted(os.listdir(tile_dir))
        if name.isdigit()
    }

def process_sample(cls_path, file, date_list, label):
    """Load one sample CSV and reduce it to its monthly composite.

    Parameters
    ----------
    cls_path : str
        Directory containing the sample file.
    file : str
        File name of the sample CSV inside ``cls_path``.
    date_list : list
        Acquisition dates of the tile; the CSV must have exactly one row per date.
    label : int
        Class label passed through unchanged.

    Returns
    -------
    tuple or None
        ``(composite, label)`` on success. ``None`` when the row count does not
        match ``date_list`` (silently) or when any exception occurs (after
        printing it), so a single bad file never aborts the whole run.
    """
    try:
        sample = pd.read_csv(os.path.join(cls_path, file))
        if len(sample) == len(date_list):
            return extract_monthly_median(sample, date_list), label
        return None
    except Exception as e:
        print(f"Error in {cls_path}/{file}: {e}")
        return None

def process_sample_wrapper(args):
    """Call ``process_sample`` with a packed argument tuple.

    ``executor.map`` hands each task to the callable as a single object, so the
    ``(cls_path, file, date_list, label)`` tuple is unpacked here.
    """
    return process_sample(*args)

def process_tile_parallel(tile: str, label_map: dict) -> tuple:
    """Build the monthly composites for every sample CSV of one tile.

    The tile directory ``RAW_ROOT/<tile>`` must contain a ``dates.csv`` with an
    ``acquisition_date`` column and one digit-named sub-directory per class,
    each holding the sample ``.csv`` files. Samples are processed in a pool of
    worker processes.

    Parameters
    ----------
    tile : str
        Tile name (sub-directory of ``RAW_ROOT``).
    label_map : dict
        Maps a class directory name to its integer label.

    Returns
    -------
    tuple
        ``(tile_data, metadata)`` — two lists of equal length. ``tile_data[i]``
        is the ``(composite, label)`` pair returned by ``process_sample`` and
        ``metadata[i]`` is the matching ``(tile, label, file)`` record. Samples
        for which ``process_sample`` returned ``None`` appear in neither list.
        Both lists are empty when the tile directory or a readable
        ``dates.csv`` is missing.
    """
    tile_data = []
    metadata = []
    tile_path = os.path.join(RAW_ROOT, tile)
    # Every exit returns the same 2-tuple shape, because callers unpack it.
    if not os.path.isdir(tile_path):
        return tile_data, metadata

    dates_path = os.path.join(tile_path, "dates.csv")
    if not os.path.exists(dates_path):
        print(f"Missing {dates_path}, skipping...")
        return tile_data, metadata
    try:
        date_list = pd.read_csv(dates_path)["acquisition_date"].tolist()
    except Exception as e:
        print(f"Could not read dates.csv in {tile}: {e}")
        return tile_data, metadata

    tasks = []
    task_metadata = []  # one record per task, same order as `tasks`
    for cls in sorted(os.listdir(tile_path)):
        cls_path = os.path.join(tile_path, cls)
        if not os.path.isdir(cls_path) or not cls.isdigit():
            continue
        label = label_map[cls]
        for file in sorted(os.listdir(cls_path)):
            if file.endswith(".csv"):
                tasks.append((cls_path, file, date_list, label))
                task_metadata.append((tile, label, file))

    with ProcessPoolExecutor(max_workers=16) as executor:
        results = tqdm(executor.map(process_sample_wrapper, tasks),
                       total=len(tasks), desc=f"{tile} samples")
        # executor.map yields results in task order, so pairing them with
        # task_metadata keeps each record attached to its own sample. `results`
        # comes first in zip() so the progress bar is iterated to exhaustion.
        for result, record in zip(results, task_metadata):
            if result is not None:
                # Keep metadata only for samples that produced data; otherwise
                # a skipped file would shift every later metadata row.
                tile_data.append(result)
                metadata.append(record)
    return tile_data, metadata

def _iter_tile_outputs(split_name, tiles, label_map):
    """Yield the result of ``process_tile_parallel`` for every tile, in order."""
    if len(tiles) > 1:
        # Parallelize over tiles. Every tile worker opens its own sample pool
        # inside process_tile_parallel, so never start more outer workers than
        # there are tiles.
        with ProcessPoolExecutor(max_workers=min(16, len(tiles))) as executor:
            yield from tqdm(
                executor.map(process_tile_parallel, tiles, [label_map] * len(tiles)),
                total=len(tiles),  # map() has no length; without this the bar shows no progress
                desc=f"Processing {split_name} tiles",
            )
    else:
        # Single tile — run it in this process so its sample progress bar is shown.
        yield process_tile_parallel(tiles[0], label_map)


def process_split(split_name, tiles, label_map):
    """Process all tiles of one split and write its tensors and metadata.

    Writes three files into ``SAVE_DIR``: ``{split_name}_X.pt`` (float32
    composites), ``{split_name}_y.pt`` (int64 labels) and
    ``{split_name}_metadata.csv`` (columns ``tile``, ``label``, ``file``).

    Parameters
    ----------
    split_name : str
        Prefix of the output files, e.g. ``"train"``.
    tiles : list
        Tile names belonging to the split; must not be empty.
    label_map : dict
        Maps a class directory name to its integer label.

    Raises
    ------
    ValueError
        If ``tiles`` is empty.
    """
    if not tiles:
        raise ValueError(f"No tiles given for split '{split_name}'")

    all_data, all_labels, all_metadata = [], [], []
    for tile_data, metadata in _iter_tile_outputs(split_name, tiles, label_map):
        for result, label in tile_data:
            all_data.append(result)
            all_labels.append(label)
        all_metadata.extend(metadata)

    os.makedirs(SAVE_DIR, exist_ok=True)
    X = torch.tensor(np.array(all_data), dtype=torch.float32)
    y = torch.tensor(np.array(all_labels), dtype=torch.long)
    torch.save(X, os.path.join(SAVE_DIR, f"{split_name}_X.pt"))
    torch.save(y, os.path.join(SAVE_DIR, f"{split_name}_y.pt"))
    print(f"Saved {split_name}: {X.shape[0]} samples to {split_name}_X.pt / {split_name}_y.pt")

    metadata_df = pd.DataFrame(all_metadata, columns=["tile", "label", "file"])
    metadata_df.to_csv(os.path.join(SAVE_DIR, f"{split_name}_metadata.csv"), index=False)
    print(f"Saved {split_name} metadata to {split_name}_metadata.csv")


# === Run All ===
if __name__ == "__main__":
    # Build the class-name -> label mapping once, then process the three
    # splits in the same order as before: train, val, test.
    label_map = build_label_map(RAW_ROOT)
    for split_name, split_tiles in (
        ("train", TRAIN_TILES),
        ("val", VAL_TILE),
        ("test", TEST_TILE),
    ):
        process_split(split_name, split_tiles, label_map)

32TNT samples: 100%|███████████████████████████████████████████████████████████████████| 21801/21801 [01:00<00:00, 361.92it/s]
32TQT samples: 100%|███████████████████████████████████████████████████████████████████| 23084/23084 [01:03<00:00, 361.51it/s]
32TPT samples: 100%|███████████████████████████████████████████████████████████████████| 24532/24532 [01:06<00:00, 367.34it/s]
33TUM samples: 100%|███████████████████████████████████████████████████████████████████| 26017/26017 [01:10<00:00, 371.18it/s]
33TUN samples: 100%|███████████████████████████████████████████████████████████████████| 31162/31162 [01:19<00:00, 391.64it/s]
33TVN samples: 100%|███████████████████████████████████████████████████████████████████| 35637/35637 [01:26<00:00, 411.64it/s]
33TVM samples: 100%|███████████████████████████████████████████████████████████████████| 45293/45293 [01:40<00:00, 450.63it/s]
33TWM samples: 100%|███████████████████████████████████████████████████████████████████| 58786/58786 [01:56<00:

Saved train: 822843 samples to train_X.pt / train_y.pt
Saved train metadata to train_metadata1.csv


33TWN samples: 100%|████████████████████████████████████████████████████████████████| 116369/116369 [00:24<00:00, 4802.62it/s]


Saved val: 116369 samples to val_X.pt / val_y.pt
Saved val metadata to val_metadata1.csv


33UVP samples: 100%|████████████████████████████████████████████████████████████████| 133419/133419 [00:27<00:00, 4897.25it/s]


Saved test: 133419 samples to test_X.pt / test_y.pt
Saved test metadata to test_metadata1.csv


# Checking Output

In [10]:
# Load the saved training split from SAVE_DIR, where process_split wrote it.
X_train = torch.load(os.path.join(SAVE_DIR, "train_X.pt"))
y_train = torch.load(os.path.join(SAVE_DIR, "train_y.pt"))
metadata_df = pd.read_csv(os.path.join(SAVE_DIR, "train_metadata.csv"))

# Features, labels and metadata must describe the same samples, row for row.
assert X_train.shape[0] == y_train.shape[0] == len(metadata_df), (
    "train_X.pt, train_y.pt and train_metadata.csv disagree on the number of samples"
)

# Shape of the train dataset: [N, n_months, n_bands]. There are N samples; each
# is an n_months x n_bands matrix holding the monthly median of every band.
print("Shape:", X_train.shape, y_train.shape, metadata_df.shape)
# Show one sample. Rows follow calendar order (row 0 = January ... row 11 = December),
# which differs from the acquisition order in dates.csv (original note: the series
# covers 9/2017-12/2017 followed by 1/2018-7/2018).
print("Sample data:", X_train[0])
# Show example labels
print("Labels:", y_train[:10])
print("Example metadata:", metadata_df.iloc[0])

Shape: torch.Size([822843, 12, 9]) torch.Size([822843]) (822843, 3)
Sample data: tensor([[1583.0000, 1662.0000, 1909.0000, 2089.0000, 2420.0000, 2493.0000,
         2745.0000, 1934.0000, 1819.0000],
        [   0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [   0.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
            0.0000,    0.0000,    0.0000],
        [ 595.5000,  953.0000, 1248.5000, 1818.0000, 2900.0000, 3278.5000,
         3807.5000, 3476.0000, 2169.5000],
        [ 397.5000,  709.5000,  620.5000, 1206.0000, 2818.0000, 3166.5000,
         3530.0000, 2115.5000, 1058.5000],
        [ 382.0000,  821.0000,  437.0000, 1367.0000, 4502.0000, 5357.0000,
         5826.0000, 2827.0000, 1261.0000],
        [ 385.0000,  792.0000,  565.0000, 1394.0000, 3696.0000, 4280.0000,
         4847.0000, 2754.0000, 1400.0000],
        [ 391.5000,  700.5000,  503.5000, 1210.5000, 3109.0000, 3710.5000,
         4359.0000