In [36]:
import tensorflow as tf
import glob
import numpy as np

# --- CONFIGURATION ---
# NOTE(review): machine-specific absolute path; consider making this configurable
# (e.g. via an environment variable) so the notebook runs on other machines.
data_path = '/Users/abhicado/Library/CloudStorage/GoogleDrive-agoda.wonders@gmail.com/My Drive/CS6140_Project_Data/*.tfrecord.gz'
# glob.glob returns matches in an arbitrary, filesystem-dependent order. Sort them
# so positional picks such as files[1000] select the same file on every run and
# every machine.
files = sorted(glob.glob(data_path))

# Parsing schema for one serialized example.
# Each band is a variable-length float list: VarLenFeature parses it into a
# SparseTensor, which _parse_function densifies. 'classification' is a single
# scalar float label per example.
feature_description = {
    'B1': tf.io.VarLenFeature(tf.float32),
    'B2': tf.io.VarLenFeature(tf.float32),
    'B3': tf.io.VarLenFeature(tf.float32),
    'B4': tf.io.VarLenFeature(tf.float32),
    'B5': tf.io.VarLenFeature(tf.float32),
    'B6': tf.io.VarLenFeature(tf.float32),
    'B7': tf.io.VarLenFeature(tf.float32),
    'B8': tf.io.VarLenFeature(tf.float32),
    'B8A': tf.io.VarLenFeature(tf.float32),
    'B9': tf.io.VarLenFeature(tf.float32),
    # 'B10' is intentionally not requested, so it is skipped during parsing.
    'B11': tf.io.VarLenFeature(tf.float32),
    'B12': tf.io.VarLenFeature(tf.float32),
    'classification': tf.io.FixedLenFeature([], tf.float32),
}

def _parse_function(example_proto):
    """Decode one serialized example into a dict of dense tensors.

    The record is parsed against the module-level ``feature_description``
    schema. Values that come back as ``tf.SparseTensor`` (the
    ``VarLenFeature`` entries) are converted with ``tf.sparse.to_dense``;
    all other values are passed through unchanged. Key order matches the
    order returned by ``tf.io.parse_single_example``.

    Args:
        example_proto: One serialized example, as yielded by a
            ``tf.data.TFRecordDataset``.

    Returns:
        A dict mapping each feature name to a dense tensor.
    """
    features = tf.io.parse_single_example(example_proto, feature_description)
    return {
        name: (
            tf.sparse.to_dense(value)
            if isinstance(value, tf.SparseTensor)
            else value
        )
        for name, value in features.items()
    }

# Pick a single file to inspect. The index is arbitrary; check it up front so an
# empty or short glob result fails with a message naming the data path, rather
# than a bare "list index out of range".
file_index = 1000
if file_index >= len(files):
    raise IndexError(
        f"file_index={file_index} but only {len(files)} file(s) matched {data_path!r}"
    )

raw_dataset = tf.data.TFRecordDataset(files[file_index], compression_type='GZIP')

# Shuffle the raw serialized records *before* parsing. A TFRecordDataset generally
# reports an unknown cardinality (its record count isn't known without reading the
# file), in which case shuffle() buffers the whole file. Shuffling first means only
# the records actually consumed downstream are parsed and densified, instead of
# every record in the file.
parsed_dataset = raw_dataset.shuffle(raw_dataset.cardinality()).map(_parse_function)

# Define the order of bands you want in your final tensor
band_names = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12']

print("--- COMBINING BANDS ---")
for record in parsed_dataset.take(1):
    # 1. Collect the per-band tensors in band_names order.
    bands_list = [record[b] for b in band_names]

    # 2. Stack them into a single tensor with bands on the last axis
    #    (channels-last). All bands must have the same length for tf.stack.
    # Observed shape of each band: (16641,) -> stacked shape: (16641, 12)
    image_flat = tf.stack(bands_list, axis=-1)

    print(f"Flat Stacked Shape: {image_flat.shape}")

    # 3. Reshape to (Height, Width, Channels), assuming a square tile
    #    (e.g. sqrt(16641) = 129). Round the float square root and verify it
    #    exactly: float error can't truncate `side` by one, and a non-square
    #    pixel count fails with a clear message instead of an opaque reshape error.
    num_pixels = image_flat.shape[0]
    side = int(round(num_pixels ** 0.5))
    if side * side != num_pixels:
        raise ValueError(
            f"Expected a square tile, but got {num_pixels} pixels per band"
        )

    image_3d = tf.reshape(image_flat, (side, side, len(band_names)))

    print(f"Final 3D Image Shape: {image_3d.shape}")
    print(f"Label: {record['classification'].numpy()}")

--- COMBINING BANDS ---
Flat Stacked Shape: (16641, 12)
Final 3D Image Shape: (129, 129, 12)
Label: 25.0
