In [None]:
# Load IPython's autoreload extension and use mode 2, which reloads imported
# modules before each cell execution so edits to project .py files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
"""
This notebook demonstrates how to write out training and validation sets.
"""

import sys
import os

# Set up imports: make the parent of the current working directory importable
# so the project packages used below can be found.
project_root = os.path.abspath("..")
# Guard against duplicates so re-running this cell does not keep growing sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

# Enable efficient use of GPU memory
# NOTE(review): configure_tensorflow_gpu is a project helper whose behavior is
# not visible here; it is called before any other TensorFlow usage in this
# notebook — presumably that ordering matters, confirm before moving it.
from config.gpu.gpu_utils import configure_tensorflow_gpu
configure_tensorflow_gpu()

from preprocessing.spark_session import spark  # Reuse the preconfigured SparkSession
import pyspark.sql.functions as F

In [None]:
# Read the multi-scene time-series parquet and decode the CDL column as UTF-8 text
df = (
    spark.read.parquet("../data/CDL_multiple_scene_ts.parquet")
    .withColumn('CDL', F.decode(F.col('CDL'), 'UTF-8'))
)

# Row counts per (CDL label, year), largest groups first
(
    df.groupBy('CDL', 'year')
    .count()
    .orderBy('count', ascending=False)
    .show()
)

In [None]:
# Number of rows available for each year
rows_per_year = df.groupBy('year').count()
rows_per_year.show()

In [None]:
# Per-year average of the numeric columns
yearly_means = df.groupBy('year').mean()
yearly_means.show()

In [None]:
import glob

# Year-based split: each pattern matches the parquet part files under the
# directories whose name contains the given year.
# glob returns paths in arbitrary order, so the lists are sorted to keep the
# file order (and therefore the dataset order) reproducible across runs.
train_files = sorted(glob.glob('../data/CDL_multiple_scene_ts.parquet/*/*2021*/*.parquet'))  # 2021 → train
val_files = sorted(glob.glob('../data/CDL_multiple_scene_ts.parquet/*/*2020*/*.parquet'))    # 2020 → val
# NOTE(review): the test split reads from CDL_unique_scene_ts, unlike train/val
# which read from CDL_multiple_scene_ts — confirm this is intentional.
test_files = sorted(glob.glob('../data/CDL_unique_scene_ts.parquet/*/*2019*/*.parquet'))     # 2019 → test

print("Train files:", len(train_files))
print("Val files:  ", len(val_files))
print("Test files: ", len(test_files))

In [None]:
import tensorflow as tf 

# Hyperparameters and constants

# Crops we will identify
targeted_cultivated_crops_list = ['Soybeans', 'Rice', 'Corn', 'Cotton']

# Crops we identify as "Cultivated"
other_cultivated_crops_list = [
    'Other Hay/Non Alfalfa', 'Pop or Orn Corn', 'Peanuts', 'Sorghum', 'Oats', 'Peaches',
    'Clover/Wildflowers', 'Pecans', 'Sod/Grass Seed', 'Other Crops', 'Dry Beans', 'Winter Wheat',
    'Alfalfa', 'Potatoes', 'Peas', 'Herbs', 'Rye', 'Cantaloupes', 'Sunflower',
    'Watermelons', 'Sweet Corn', 'Sweet Potatoes'
]

# The label legend: three generic classes followed by one class per targeted crop.
# Built from targeted_cultivated_crops_list so the two lists cannot drift apart.
label_legend = ['Uncultivated', 'Cultivated', 'No Crop Growing'] + targeted_cultivated_crops_list

# Define model batch size and time-series bucketing size
BATCH_SIZE = 1028
DAYS_IN_SERIES = 120
DAYS_PER_BUCKET = 5
# Number of DAYS_PER_BUCKET-sized steps in the series, plus one
MAX_IMAGES_PER_SERIES = (DAYS_IN_SERIES // DAYS_PER_BUCKET) + 1
FRAMES_TO_CHECK = 2
BUCKETING_STRATEGY = "random"
NUM_FEATURES = 16 # 12 bands + 4 indices (or 12 for bands only)

print(f"🔢 MAX_IMAGES_PER_SERIES: {MAX_IMAGES_PER_SERIES}")
print(f"📦 Batch shape: [{BATCH_SIZE} x {MAX_IMAGES_PER_SERIES}]")

#### Time Series bucketing 

`MAX_IMAGES_PER_SERIES` is calculated based on two parameters: `DAYS_IN_SERIES` and `DAYS_PER_BUCKET`.  
It represents the maximum number of time steps (i.e., satellite images or observations) per pixel within one series of `DAYS_IN_SERIES` days.

The formula is:

`MAX_IMAGES_PER_SERIES = (DAYS_IN_SERIES // DAYS_PER_BUCKET) + 1`

*Examples:*
- If `DAYS_IN_SERIES = 120` and `DAYS_PER_BUCKET = 5`, then `MAX_IMAGES_PER_SERIES = 25`
- If `DAYS_IN_SERIES = 100` and `DAYS_PER_BUCKET = 5`, then `MAX_IMAGES_PER_SERIES = 21`

The `BATCH_SIZE` parameter defines how many pixels are processed **in parallel** during each training step.  
It refers to different spatial points (locations) in the dataset.  
Each "pixel" here means one location with its own full time series of features.

---

#### What does a single pixel’s time series look like?

Each pixel’s time series is a sequence of feature vectors—e.g., values for NDVI, red band, NIR band, etc.—captured at different time steps:

```python
[
  [0.2, 123],
  [0.3, 118],
  [0.35, 110],
  [0.4, 100],
  [0.45, 95]
]  # shape: (5, 2) = (time_steps, features)
```

In this example:

* There are 5 time steps (observations)

* Each step includes 2 features (e.g., NDVI and surface temperature)

With a batch size of 1028 pixels, each with 5 time steps of 2 features:

```python
batch = np.array([
    [[...], [...], [...], [...], [...]],  # pixel 1
    [[...], [...], [...], [...], [...]],  # pixel 2
    ...
    [[...], [...], [...], [...], [...]]   # pixel 1028
])  # shape: (1028, 5, 2) = (batch_size, time_steps, features)
```

#### Parameter Summary

| Parameter              | Meaning                                                                 |
|------------------------|-------------------------------------------------------------------------|
| `BATCH_SIZE`           | Number of distinct pixels (locations) processed per batch               |
| `MAX_IMAGES_PER_SERIES`| Maximum number of time steps per pixel (e.g., 25 for a 120-day series)  |
| `features`             | Number of bands or indices per time step (e.g., NDVI, red, nir…)        |


#### Estimate mean and standard deviation to normalize band values

In [None]:
# Normalization
from dataloader import make_from_pandas, filter_double_croppings, parse

train_files_ds = make_from_pandas(train_files)

# Set the normalization flag to False to get the un-normalized data.
# Zeros/ones are passed as means/stds because parse still takes those arguments.
non_normed_ds = (
    train_files_ds
    .filter(filter_double_croppings)
    .map(lambda x: parse(
        x,
        norm=False,
        means=tf.zeros([NUM_FEATURES], dtype=tf.float32),
        stds=tf.ones([NUM_FEATURES], dtype=tf.float32),
        label_legend_=label_legend,
        targeted_cultivated_crops_list=targeted_cultivated_crops_list,
        other_cultivated_crops_list=other_cultivated_crops_list,
        days_in_series=DAYS_IN_SERIES,
        days_per_bucket=DAYS_PER_BUCKET,
        max_images_per_series=MAX_IMAGES_PER_SERIES,
        frames_to_check=FRAMES_TO_CHECK,
        bucketing_strategy=BUCKETING_STRATEGY
    ), num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
)

# Loop through the dataset, saving both the data and associated labels.
# Batches are collected under their own names so the final tensor names are
# never bound to plain Python lists.
data_batches = []
label_batches = []

for data, label in non_normed_ds:
    data_batches.append(data)
    label_batches.append(label)

# Reshape to just get the imagery values - no need to maintain the time-series structure for the following plots.
# The row width is read from the data itself instead of a hard-coded 18, so it
# stays correct if the feature layout changes. With the current settings it is
# expected to be 18 (12 bands + 4 indices + 1 SCL + 1 label).
stacked_data = tf.concat(data_batches, axis=0)
num_features = stacked_data.shape[-1]
all_non_normalized_data = tf.reshape(stacked_data, shape=(-1, num_features))

all_labels = tf.reshape(tf.concat(label_batches, axis=0), shape=(-1, len(label_legend)))

In [None]:
# Per-column mean and (population) standard deviation over non-zero entries only.
#
# The statistics are computed with an explicit 0/1 mask and column-wise sums.
# Masking through tf.ragged.boolean_mask drops zeros *per row*, which shifts the
# remaining values of that row to the left; reducing over axis 0 afterwards then
# mixes values from different features whenever a row contains an isolated zero.
# NOTE: a column with no non-zero entries yields NaN (0 / 0) for both statistics.
nonzero_mask = tf.cast(all_non_normalized_data != 0, all_non_normalized_data.dtype)
nonzero_counts = tf.reduce_sum(nonzero_mask, axis=0)

means = tf.reduce_sum(all_non_normalized_data * nonzero_mask, axis=0) / nonzero_counts

# Zero entries are masked out again after centering so they do not contribute
squared_deviation = tf.square((all_non_normalized_data - means) * nonzero_mask)
stds = tf.sqrt(tf.reduce_sum(squared_deviation, axis=0) / nonzero_counts)

In [None]:
# Means of the first NUM_FEATURES columns (the ones passed on for normalization)
means[:NUM_FEATURES]

In [None]:
# Standard deviations of the first NUM_FEATURES columns
stds[:NUM_FEATURES]

#### Load train and validation data

In [None]:
from dataloader import make_dataset

# Keyword settings forwarded to make_dataset; only the first NUM_FEATURES
# entries of the estimated statistics are used for normalization.
dataset_options = dict(
    method="pandas",  # or "tensorflow"
    batch_size=BATCH_SIZE,
    means=means[0:NUM_FEATURES],
    stds=stds[0:NUM_FEATURES],
    label_legend=label_legend,
    targeted_cultivated_crops_list=targeted_cultivated_crops_list,
    other_cultivated_crops_list=other_cultivated_crops_list,
    days_in_series=DAYS_IN_SERIES,
    days_per_bucket=DAYS_PER_BUCKET,
    max_images_per_series=MAX_IMAGES_PER_SERIES,
    frames_to_check=FRAMES_TO_CHECK,
    bucketing_strategy=BUCKETING_STRATEGY,
    augment=False,
)

train_ds, val_ds = make_dataset(train_files, val_files, **dataset_options)

In [None]:
# Inspect the element structure of the training dataset
train_element_spec = train_ds.element_spec
print(train_element_spec)

In [None]:
# Inspect the element structure of the validation dataset
print(val_ds.element_spec)

In [None]:
# Write both datasets to disk; the directory name records the feature count
train_ds.save(f"../data/train_ds_with_idx_{NUM_FEATURES}f")
val_ds.save(f"../data/val_ds_with_idx_{NUM_FEATURES}f")

print("💾 Datasets saved to disk.")

In [None]:
# Example output: take the first batch and show its features
# (y holds the matching one-hot labels)
batch_iterator = iter(train_ds)
X, y = next(batch_iterator)
print(X)

In [None]:
# Labels of the example batch; y is assigned in the cell that draws the batch
print(y)

#### Visual checks

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

col_names = ['coastal', 'blue', 'green', 'red', 'rededge1', 'rededge2',
             'rededge3', 'nir', 'nir08', 'nir09', 'swir16', 'swir22',
             'NDVI', 'EVI', 'NDWI', 'NDBI']

# One column name per feature: the width is taken from col_names rather than
# repeating a literal 16 in every reshape/slice below.
n_cols = len(col_names)
assert n_cols == NUM_FEATURES, (
    f"col_names has {n_cols} entries but NUM_FEATURES is {NUM_FEATURES}"
)

# Normalized data: gather the features of every training batch and flatten to rows
all_normalized_data = tf.reshape(tf.concat([d[0] for d in train_ds], axis=0), shape=(-1, n_cols))
df_norm = pd.DataFrame(all_normalized_data.numpy(), columns=col_names)
df_norm = df_norm.drop_duplicates()  # Ignore padded rows

# Non-normalized data: keep only the feature columns, then drop every row
# that contains a zero in any column
df_non_norm = pd.DataFrame(all_non_normalized_data[:, :n_cols].numpy(), columns=col_names)
df_non_norm = df_non_norm[df_non_norm != 0].dropna()


def _violin_panel(ax, frame, title):
    """Draw one violin per column of `frame` on `ax` with shared labelling."""
    sns.violinplot(data=frame, ax=ax)
    ax.tick_params(axis='x', labelrotation=90)
    ax.set_title(title)
    ax.set_xlabel('Feature')
    ax.set_ylabel('Data Value')


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

_violin_panel(ax1, df_norm, 'Normalized Dataset')
ax1.set_ylim((-2.5, 2.5))  # Optional

_violin_panel(ax2, df_non_norm, 'Non-Normalized Dataset')

plt.tight_layout()
plt.show()


In [None]:
# Display the non-normalized feature table (rows containing a zero were removed)
df_non_norm

In [None]:
# Display the normalized feature table (duplicate rows were removed)
df_norm

In [None]:
import numpy as np

# Index of the largest entry of each label row = class id into label_legend
heights = tf.argmax(all_labels, axis=1).numpy()

# Count occurrences of every class id. bincount with minlength gives exactly one
# count per legend entry, including zero for classes that never occur.
# np.histogram with bins=len(label_legend) instead spreads its bins over
# [min, max] of the ids present, so bars are assigned to the wrong labels
# whenever the first or last class is missing.
class_counts = np.bincount(heights, minlength=len(label_legend))

plt.bar(label_legend, class_counts)
plt.title('Crop Types in Training Set')
plt.xticks(rotation=-45, ha='left')
plt.show()

In [None]:
# Names for the numeric SCL codes, used to colour the points
scl_mapper = {
    0.0: 'No Data',
    1.0: 'Saturated Or Defective',
    2.0: 'Dark Area Pixels',
    3.0: 'Cloud Shadows',
    4.0: 'Vegetation',
    5.0: 'Not Vegetated',
    6.0: 'Water',
    7.0: 'Unclassified',
    8.0: 'Cloud Medium Probability',
    9.0: 'Cloud High Probability',
    10.0: 'Thin Cirrus',
}

# Feature positions inside one un-normalized time step.
# NDVI is looked up by name: with the layout used in this notebook
# (12 bands + 4 indices + 1 SCL + 1 label) the column at -3 is the last index
# (NDBI), not NDVI.
# NOTE(review): assumes parse() emits features in col_names order — confirm.
ndvi_idx = col_names.index('NDVI')
scl_idx = -2  # second to last column, just before the label

fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(15, 15))

# One panel per batch for the first 9 batches; zip stops once the axes run out.
for i, (ax, (data, label)) in enumerate(zip(axs.flat, non_normed_ds)):
    # From the i-th batch, take the i-th sample
    sample = data.numpy()[i]

    # A separate name is used so the Spark DataFrame `df` is not overwritten
    series_df = pd.DataFrame(sample[:, [ndvi_idx, scl_idx]], columns=['NDVI', 'SCL'])
    series_df['image in series'] = np.arange(0, series_df.shape[0], step=1)
    series_df['SCL Label'] = series_df.SCL.map(scl_mapper)

    sns.scatterplot(data=series_df, x='image in series', y='NDVI', hue='SCL Label', ax=ax)
    ax.set_title(label_legend[tf.argmax(label[i, :]).numpy()])

plt.show()