In [1]:
# ========================= IMPORTS =========================
import tensorflow as tf  # For F1 score metric (optional)
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import numpy as np
import pickle
from collections import Counter
from imblearn.over_sampling import RandomOverSampler
import os


# ========================= DOWNLOAD AND ORGANIZE DATASET =========================

# Define paths
DATA_PATH = '/content/satellite_data'  # Path to save the dataset in Drive

# Download and extract the dataset
!curl -SL https://storage.googleapis.com/wandb_datasets/dw_train_86K_val_10K.zip > dw_data.zip
!unzip dw_data.zip
!rm dw_data.zip

# Move the extracted dataset to Google Drive
!mkdir -p {DATA_PATH}
!mv droughtwatch_data/train {DATA_PATH}/train
!mv droughtwatch_data/val {DATA_PATH}/val
!rm -r droughtwatch_data  # Clean up

# Verify dataset location and structure
import os
print(f"Train files: {len(os.listdir(DATA_PATH + '/train'))}")
print(f"Validation files: {len(os.listdir(DATA_PATH + '/val'))}")


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2050M  100 2050M    0     0  19.3M      0  0:01:46  0:01:46 --:--:-- 18.1M
Archive:  dw_data.zip
   creating: droughtwatch_data/
   creating: droughtwatch_data/val/
  inflating: droughtwatch_data/val/part-r-00090  
  inflating: droughtwatch_data/val/part-r-00061  
  inflating: droughtwatch_data/val/part-r-00052  
  inflating: droughtwatch_data/val/part-r-00043  
  inflating: droughtwatch_data/val/part-r-00040  
  inflating: droughtwatch_data/val/part-r-00042  
  inflating: droughtwatch_data/val/part-r-00067  
  inflating: droughtwatch_data/val/part-r-00026  
  inflating: droughtwatch_data/val/part-r-00046  
  inflating: droughtwatch_data/val/part-r-00023  
  inflating: droughtwatch_data/val/part-r-00083  
  inflating: droughtwatch_data/val/part-r-00011  
  inflating: droughtwatch_data/val/part-r-00058  
  inflating: droughtwat

In [2]:
import os

# Define the expected DATA_PATH and folder name
DATA_PATH = '/content/satellite_data'
EXPECTED_FOLDER_NAME = 'satellite_imagesfolder'
expected_folder_path = os.path.join(DATA_PATH, EXPECTED_FOLDER_NAME)

# Define your actual images folder location
actual_images_folder = '/content/drive/MyDrive/satelliteImages'

# Create the DATA_PATH directory if it doesn't exist
os.makedirs(DATA_PATH, exist_ok=True)

# Create a symbolic link if it doesn't exist already
if not os.path.exists(expected_folder_path):
    os.symlink(actual_images_folder, expected_folder_path)
    print(f"Symbolic link created: {expected_folder_path} -> {actual_images_folder}")
else:
    print(f"Symbolic link or folder already exists at: {expected_folder_path}")


Symbolic link created: /content/satellite_data/satellite_imagesfolder -> /content/drive/MyDrive/satelliteImages


In [4]:

# ========================= CONFIGURATION =========================
BATCH_SIZE = 32
IMG_SIZE = (64, 64)
TARGET_TRAIN_SAMPLES = 3000  # Target number of augmented satellite images
CIFAR_CLASSES_TO_REMOVE = {'cloud', 'forest', 'mountain', 'plain', 'sea'}

# ========================= DATA PREPARATION FUNCTIONS =========================

# Step 1: Augment Satellite Images from folder
def augment_satellite_images(image_folder, target_count):
    datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    img_generator = datagen.flow_from_directory(
        image_folder,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='binary',
        shuffle=True
    )

    # Check if images are found. If not, raise an error.
    if img_generator.samples == 0:
        raise ValueError(f"No images found in directory: {image_folder}. " +
                         "Ensure the directory exists and is structured with subdirectories per class.")

    # Create augmented dataset
    augmented_images = []
    num_batches = target_count // BATCH_SIZE + 1
    for _ in range(num_batches):
        batch, _ = next(img_generator)  # Use next() for Python 3
        augmented_images.extend(batch)
        if len(augmented_images) >= target_count:
            break

    # Convert list to a numpy array and take only the required target_count images
    augmented_images = np.array(augmented_images[:target_count])
    labels = np.ones(target_count, dtype=np.int64)  # Label 1 for satellite images

    return tf.data.Dataset.from_tensor_slices((augmented_images, labels))

# Step 2: Load EuroSAT dataset
def load_eurosat():
    # The tf.keras.datasets.eurosat may not be available in some TF versions;
    # if so, use an alternative dataset. Adjust accordingly.
    (train_images, train_labels), (_, _) = tf.keras.datasets.eurosat.load_data()
    labels = np.ones(len(train_images), dtype=np.int64)
    return tf.data.Dataset.from_tensor_slices((train_images, labels))

# Step 3: Load existing TFRecord satellite data
def parse_satellite_tfrecord(record):
    feature_description = {
        "B2": tf.io.FixedLenFeature([], tf.string),
        "B3": tf.io.FixedLenFeature([], tf.string),
        "B4": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(record, feature_description)

    # Convert bands to RGB
    red = tf.io.decode_raw(parsed["B4"], tf.uint8)
    green = tf.io.decode_raw(parsed["B3"], tf.uint8)
    blue = tf.io.decode_raw(parsed["B2"], tf.uint8)

    # Assume the image is square; calculate image size from one band
    img_size = tf.cast(tf.sqrt(tf.cast(tf.shape(red)[0], tf.float32)), tf.int32)
    red = tf.reshape(red, (img_size, img_size))
    green = tf.reshape(green, (img_size, img_size))
    blue = tf.reshape(blue, (img_size, img_size))

    image = tf.stack([red, green, blue], axis=-1)
    image = tf.image.resize(image, IMG_SIZE)
    image = tf.cast(image, tf.float32) / 255.0

    return image, tf.constant(1, dtype=tf.int64)

def load_tfrecord_data(pattern):
    files = tf.io.gfile.glob(pattern)
    return tf.data.TFRecordDataset(files).map(parse_satellite_tfrecord)

# Step 4: Prepare CIFAR-100 dataset with class filtering
def load_filtered_cifar():
    (x_train, y_train), (_, _) = tf.keras.datasets.cifar100.load_data(label_mode='fine')

    # Get class names and filter
    class_names = tf.keras.datasets.cifar100.get_label_names()
    remove_indices = [i for i, name in enumerate(class_names) if name in CIFAR_CLASSES_TO_REMOVE]

    # Create mask for valid classes (invert the mask for removal)
    mask = np.isin(y_train, remove_indices, invert=True).flatten()

    x_filtered = x_train[mask]
    y_filtered = np.zeros(len(x_filtered), dtype=np.int64)  # Label 0 for non-satellite images

    return tf.data.Dataset.from_tensor_slices((x_filtered, y_filtered))

# Step 5: Combine all datasets
def create_final_dataset():
    # Load and combine satellite data
    satellite_folder = os.path.join(DATA_PATH, 'satellite_imagesfolder')
    augmented_ds = augment_satellite_images(satellite_folder, TARGET_TRAIN_SAMPLES)

    try:
        eurosat_ds = load_eurosat()
    except Exception as e:
        print("EuroSAT dataset could not be loaded:", e)
        eurosat_ds = tf.data.Dataset.from_tensor_slices(([], []))

    tfrecord_pattern = os.path.join(DATA_PATH, 'train/part*')
    tfrecord_ds = load_tfrecord_data(tfrecord_pattern)

    # Combine all satellite sources
    satellite_ds = augmented_ds.concatenate(eurosat_ds).concatenate(tfrecord_ds)
    satellite_ds = satellite_ds.map(
        lambda img, lbl: (tf.image.resize(img, IMG_SIZE) / 255.0, lbl),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # Load and prepare CIFAR data
    cifar_ds = load_filtered_cifar().map(
        lambda img, lbl: (tf.image.resize(img, IMG_SIZE) / 255.0, lbl),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # Balance datasets by taking the same number of samples from both
    satellite_count = sum(1 for _ in satellite_ds)
    cifar_count = sum(1 for _ in cifar_ds)
    min_count = min(satellite_count, cifar_count)

    balanced_ds = satellite_ds.take(min_count).concatenate(cifar_ds.take(min_count))

    # Shuffle and prepare final dataset
    return balanced_ds.shuffle(buffer_size=min_count * 2).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


In [5]:
# ========================= IMPORTS =========================
import tensorflow as tf  # For F1 score metric (optional)
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import numpy as np
import pickle
import json
from collections import Counter
from imblearn.over_sampling import RandomOverSampler
import os
import shutil
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# ========================= DOWNLOAD AND ORGANIZE DATASET =========================

# Define paths
DATA_PATH = '/content/satellite_data'  # Path to save the dataset in Drive

# Download and extract the dataset (this dataset is separate from your custom images)
!curl -SL https://storage.googleapis.com/wandb_datasets/dw_train_86K_val_10K.zip > dw_data.zip
!unzip dw_data.zip
!rm dw_data.zip

# Move the extracted dataset to Google Drive folder
!mkdir -p {DATA_PATH}
!mv droughtwatch_data/train {DATA_PATH}/train
!mv droughtwatch_data/val {DATA_PATH}/val
!rm -r droughtwatch_data  # Clean up

# Verify dataset location and structure
print(f"Train files: {len(os.listdir(os.path.join(DATA_PATH, 'train')))}")
print(f"Validation files: {len(os.listdir(os.path.join(DATA_PATH, 'val')))}")

# Define the expected folder name for your satellite images
EXPECTED_FOLDER_NAME = 'satellite_imagesfolder'
expected_folder_path = os.path.join(DATA_PATH, EXPECTED_FOLDER_NAME)

# Define your actual images folder location (update this path if needed)
actual_images_folder = '/content/drive/MyDrive/satelliteImages'

# Create the DATA_PATH directory if it doesn't exist
os.makedirs(DATA_PATH, exist_ok=True)

# Create a symbolic link if it doesn't exist already
if not os.path.exists(expected_folder_path):
    os.symlink(actual_images_folder, expected_folder_path)
    print(f"Symbolic link created: {expected_folder_path} -> {actual_images_folder}")
else:
    print(f"Symbolic link or folder already exists at: {expected_folder_path}")

# ----------------------------------------------------------------
# Check if the expected satellite images folder has subdirectories.
# If not, create a default subdirectory ("class1") and move all image files into it.
subdirs = [d for d in os.listdir(expected_folder_path) if os.path.isdir(os.path.join(expected_folder_path, d))]
if len(subdirs) == 0:
    default_class_folder = os.path.join(expected_folder_path, "class1")
    os.makedirs(default_class_folder, exist_ok=True)
    for file in os.listdir(expected_folder_path):
        file_path = os.path.join(expected_folder_path, file)
        if os.path.isfile(file_path) and file.lower().endswith(('.png', '.jpg', '.jpeg')):
            shutil.move(file_path, os.path.join(default_class_folder, file))
    print(f"Moved image files to default subdirectory: {default_class_folder}")
else:
    print("Subdirectories detected; no need to create a dummy class folder.")

# ========================= CONFIGURATION =========================
BATCH_SIZE = 32
IMG_SIZE = (64, 64)
TARGET_TRAIN_SAMPLES = 3000  # Target number of augmented satellite images
CIFAR_CLASSES_TO_REMOVE = {'cloud', 'forest', 'mountain', 'plain', 'sea'}

# ========================= DATA PREPARATION FUNCTIONS =========================

# Step 1: Augment Satellite Images from folder
def augment_satellite_images(image_folder, target_count):
    datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    img_generator = datagen.flow_from_directory(
        image_folder,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='binary',
        shuffle=True
    )

    # Check if images are found. If not, raise an error.
    if img_generator.samples == 0:
        raise ValueError(f"No images found in directory: {image_folder}. " +
                         "Ensure the directory exists and is structured with subdirectories per class.")

    # Create augmented dataset
    augmented_images = []
    num_batches = target_count // BATCH_SIZE + 1
    for _ in range(num_batches):
        batch, _ = next(img_generator)  # Use next() for Python 3 iterators
        augmented_images.extend(batch)
        if len(augmented_images) >= target_count:
            break

    # Use the actual number of images generated
    actual_count = len(augmented_images)
    if actual_count < target_count:
        print(f"Warning: Only {actual_count} augmented images generated, less than the target of {target_count}.")
    augmented_images = np.array(augmented_images[:actual_count])
    labels = np.ones(actual_count, dtype=np.int64)  # Label 1 for satellite images

    return tf.data.Dataset.from_tensor_slices((augmented_images, labels))

# Step 2: Load EuroSAT dataset
def load_eurosat():
    # The tf.keras.datasets.eurosat may not be available in all TF versions;
    # if not, use an alternative dataset or comment out this section.
    (train_images, train_labels), (_, _) = tf.keras.datasets.eurosat.load_data()
    labels = np.ones(len(train_images), dtype=np.int64)
    return tf.data.Dataset.from_tensor_slices((train_images, labels))

# Step 3: Load existing TFRecord satellite data
def parse_satellite_tfrecord(record):
    feature_description = {
        "B2": tf.io.FixedLenFeature([], tf.string),
        "B3": tf.io.FixedLenFeature([], tf.string),
        "B4": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(record, feature_description)

    # Convert bands to RGB
    red = tf.io.decode_raw(parsed["B4"], tf.uint8)
    green = tf.io.decode_raw(parsed["B3"], tf.uint8)
    blue = tf.io.decode_raw(parsed["B2"], tf.uint8)

    # Assume the image is square; calculate image size from one band
    img_size = tf.cast(tf.sqrt(tf.cast(tf.shape(red)[0], tf.float32)), tf.int32)
    red = tf.reshape(red, (img_size, img_size))
    green = tf.reshape(green, (img_size, img_size))
    blue = tf.reshape(blue, (img_size, img_size))

    image = tf.stack([red, green, blue], axis=-1)
    image = tf.image.resize(image, IMG_SIZE)
    image = tf.cast(image, tf.float32) / 255.0

    return image, tf.constant(1, dtype=tf.int64)

def load_tfrecord_data(pattern):
    files = tf.io.gfile.glob(pattern)
    return tf.data.TFRecordDataset(files).map(parse_satellite_tfrecord)

# Step 4: Prepare CIFAR-100 dataset with class filtering
def load_filtered_cifar():
    (x_train, y_train), (_, _) = tf.keras.datasets.cifar100.load_data(label_mode='fine')

    # Get class names and filter
    class_names = tf.keras.datasets.cifar100.get_label_names()
    remove_indices = [i for i, name in enumerate(class_names) if name in CIFAR_CLASSES_TO_REMOVE]

    # Create mask for valid classes (invert the mask for removal)
    mask = np.isin(y_train, remove_indices, invert=True).flatten()

    x_filtered = x_train[mask]
    y_filtered = np.zeros(len(x_filtered), dtype=np.int64)  # Label 0 for non-satellite images

    return tf.data.Dataset.from_tensor_slices((x_filtered, y_filtered))

# Step 5: Combine all datasets
def create_final_dataset():
    # Load and combine satellite data
    satellite_folder = os.path.join(DATA_PATH, 'satellite_imagesfolder')
    augmented_ds = augment_satellite_images(satellite_folder, TARGET_TRAIN_SAMPLES)

    try:
        eurosat_ds = load_eurosat()
    except Exception as e:
        print("EuroSAT dataset could not be loaded:", e)
        eurosat_ds = tf.data.Dataset.from_tensor_slices(([], []))

    tfrecord_pattern = os.path.join(DATA_PATH, 'train/part*')
    tfrecord_ds = load_tfrecord_data(tfrecord_pattern)

    # Combine all satellite sources
    satellite_ds = augmented_ds.concatenate(eurosat_ds).concatenate(tfrecord_ds)
    satellite_ds = satellite_ds.map(
        lambda img, lbl: (tf.image.resize(img, IMG_SIZE) / 255.0, lbl),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # Load and prepare CIFAR data
    cifar_ds = load_filtered_cifar().map(
        lambda img, lbl: (tf.image.resize(img, IMG_SIZE) / 255.0, lbl),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # Balance datasets by taking the same number of samples from both
    satellite_count = sum(1 for _ in satellite_ds)
    cifar_count = sum(1 for _ in cifar_ds)
    min_count = min(satellite_count, cifar_count)

    balanced_ds = satellite_ds.take(min_count).concatenate(cifar_ds.take(min_count))

    # Shuffle and prepare final dataset
    return balanced_ds.shuffle(buffer_size=min_count * 2).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# ========================= MODEL DEFINITION =========================

def create_model():
    model = models.Sequential([
        layers.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3)),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(1, activation='sigmoid')
    ])

    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model

# ========================= TRAINING AND EVALUATION =========================

def train():
    # Create the full dataset
    full_dataset = create_final_dataset()
    dataset_batches = list(full_dataset)  # Materialize dataset to count samples
    dataset_size = len(dataset_batches) * BATCH_SIZE

    if dataset_size == 0:
        raise ValueError("The combined dataset is empty. Please check your data sources.")

    # Calculate sizes for train and validation splits
    val_size = int(0.2 * dataset_size)
    train_size = dataset_size - val_size

    # Determine number of batches for train and validation
    train_batches = train_size // BATCH_SIZE
    val_batches = len(dataset_batches) - train_batches

    train_ds = full_dataset.take(train_batches)
    val_ds = full_dataset.skip(train_batches)

    # Create and train model
    model = create_model()
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=15,
        callbacks=[tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)]
    )

    # Save model and training history
    model.save('/content/satellite_classifier.keras')
    with open('/content/training_history.json', 'w') as f:
        json.dump(history.history, f)

    return model, history

# ========================= START TRAINING =========================

trained_model, training_history = train()


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2050M  100 2050M    0     0  22.1M      0  0:01:32  0:01:32 --:--:-- 23.1M
Archive:  dw_data.zip
   creating: droughtwatch_data/
   creating: droughtwatch_data/val/
  inflating: droughtwatch_data/val/part-r-00090  
  inflating: droughtwatch_data/val/part-r-00061  
  inflating: droughtwatch_data/val/part-r-00052  
  inflating: droughtwatch_data/val/part-r-00043  
  inflating: droughtwatch_data/val/part-r-00040  
  inflating: droughtwatch_data/val/part-r-00042  
  inflating: droughtwatch_data/val/part-r-00067  
  inflating: droughtwatch_data/val/part-r-00026  
  inflating: droughtwatch_data/val/part-r-00046  
  inflating: droughtwatch_data/val/part-r-00023  
  inflating: droughtwatch_data/val/part-r-00083  
  inflating: droughtwatch_data/val/part-r-00011  
  inflating: droughtwatch_data/val/part-r-00058  
  inflating: droughtwat

TypeError: Incompatible dataset elements:
  (TensorSpec(shape=(64, 64, 3), dtype=tf.float32, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None)) vs.   (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(), dtype=tf.float32, name=None))

In [6]:
import tensorflow as tf
import numpy as np
import json
import os
import shutil
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator


In [None]:
# ========================= CONFIGURATION =========================
BATCH_SIZE = 32
IMG_SIZE = (64, 64)
TARGET_TRAIN_SAMPLES = 3000  # Target number of augmented satellite images
CIFAR_CLASSES_TO_REMOVE = {'cloud', 'forest', 'mountain', 'plain', 'sea'}

# Global list of CIFAR-100 fine labels (100 classes)
CIFAR100_FINE_LABELS = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle',
    'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
    'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
    'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
    'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
    'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
    'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger',
    'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
]

# ========================= DATA PREPARATION FUNCTIONS =========================

# Step 1: Augment Satellite Images from folder
def augment_satellite_images(image_folder, target_count):
    datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    img_generator = datagen.flow_from_directory(
        image_folder,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='binary',
        shuffle=True
    )

    # Check if images are found. If not, raise an error.
    if img_generator.samples == 0:
        raise ValueError(f"No images found in directory: {image_folder}. " +
                         "Ensure the directory exists and is structured with subdirectories per class.")

    # Create augmented dataset
    augmented_images = []
    num_batches = target_count // BATCH_SIZE + 1
    for _ in range(num_batches):
        batch, _ = next(img_generator)  # Use next() for Python 3 iterators
        augmented_images.extend(batch)
        if len(augmented_images) >= target_count:
            break

    # Use the actual number of images generated (if less than target, print a warning)
    actual_count = len(augmented_images)
    if actual_count < target_count:
        print(f"Warning: Only {actual_count} augmented images generated, less than the target of {target_count}.")
    augmented_images = np.array(augmented_images[:actual_count])
    labels = np.ones(actual_count, dtype=np.int64)  # Label 1 for satellite images

    return tf.data.Dataset.from_tensor_slices((augmented_images, labels))

# Step 2: Load EuroSAT dataset
def load_eurosat():
    # Attempt to load EuroSAT; if not available, this will raise an exception.
    (train_images, train_labels), (_, _) = tf.keras.datasets.eurosat.load_data()
    labels = np.ones(len(train_images), dtype=np.int64)
    return tf.data.Dataset.from_tensor_slices((train_images, labels))

# Step 3: Load existing TFRecord satellite data
def parse_satellite_tfrecord(record):
    feature_description = {
        "B2": tf.io.FixedLenFeature([], tf.string),
        "B3": tf.io.FixedLenFeature([], tf.string),
        "B4": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(record, feature_description)

    # Convert bands to RGB
    red = tf.io.decode_raw(parsed["B4"], tf.uint8)
    green = tf.io.decode_raw(parsed["B3"], tf.uint8)
    blue = tf.io.decode_raw(parsed["B2"], tf.uint8)

    # Assume the image is square; calculate image size from one band
    img_size = tf.cast(tf.sqrt(tf.cast(tf.shape(red)[0], tf.float32)), tf.int32)
    red = tf.reshape(red, (img_size, img_size))
    green = tf.reshape(green, (img_size, img_size))
    blue = tf.reshape(blue, (img_size, img_size))

    image = tf.stack([red, green, blue], axis=-1)
    image = tf.image.resize(image, IMG_SIZE)
    image = tf.cast(image, tf.float32) / 255.0

    return image, tf.constant(1, dtype=tf.int64)

def load_tfrecord_data(pattern):
    files = tf.io.gfile.glob(pattern)
    if not files:
        # Create an empty dataset with the expected element spec:
        empty_images = np.empty((0, IMG_SIZE[0], IMG_SIZE[1], 3), dtype=np.float32)
        empty_labels = np.empty((0,), dtype=np.int64)
        return tf.data.Dataset.from_tensor_slices((empty_images, empty_labels))
    return tf.data.TFRecordDataset(files).map(parse_satellite_tfrecord)

# Step 4: Prepare CIFAR-100 dataset with class filtering
def load_filtered_cifar():
    (x_train, y_train), (_, _) = tf.keras.datasets.cifar100.load_data(label_mode='fine')
    # Determine which CIFAR classes to remove using our global CIFAR100_FINE_LABELS list.
    remove_indices = [i for i, name in enumerate(CIFAR100_FINE_LABELS) if name in CIFAR_CLASSES_TO_REMOVE]
    # Create mask for valid classes (invert the mask for removal)
    mask = np.isin(y_train, remove_indices, invert=True).flatten()
    x_filtered = x_train[mask]
    y_filtered = np.zeros(len(x_filtered), dtype=np.int64)  # Label 0 for non-satellite images
    return tf.data.Dataset.from_tensor_slices((x_filtered, y_filtered))

# Step 5: Combine all datasets
def create_final_dataset():
    # Load and combine satellite data
    satellite_folder = os.path.join(DATA_PATH, 'satellite_imagesfolder')
    augmented_ds = augment_satellite_images(satellite_folder, TARGET_TRAIN_SAMPLES)

    try:
        eurosat_ds = load_eurosat()
    except Exception as e:
        print("EuroSAT dataset could not be loaded:", e)
        empty_images = np.empty((0, IMG_SIZE[0], IMG_SIZE[1], 3), dtype=np.float32)
        empty_labels = np.empty((0,), dtype=np.int64)
        eurosat_ds = tf.data.Dataset.from_tensor_slices((empty_images, empty_labels))

    tfrecord_pattern = os.path.join(DATA_PATH, 'train/part*')
    tfrecord_ds = load_tfrecord_data(tfrecord_pattern)

    # Combine all satellite sources
    satellite_ds = augmented_ds.concatenate(eurosat_ds).concatenate(tfrecord_ds)
    satellite_ds = satellite_ds.map(
        lambda img, lbl: (tf.image.resize(img, IMG_SIZE) / 255.0, lbl),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # Load and prepare CIFAR data
    cifar_ds = load_filtered_cifar().map(
        lambda img, lbl: (tf.image.resize(img, IMG_SIZE) / 255.0, lbl),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # Balance datasets by taking the same number of samples from both
    satellite_count = sum(1 for _ in satellite_ds)
    cifar_count = sum(1 for _ in cifar_ds)
    min_count = min(satellite_count, cifar_count)
    print(f"Balancing datasets to {min_count} samples each.")

    balanced_ds = satellite_ds.take(min_count).concatenate(cifar_ds.take(min_count))

    # Shuffle and prepare final dataset
    return balanced_ds.shuffle(buffer_size=min_count * 2).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# ========================= MODEL DEFINITION =========================

def create_model():
    model = models.Sequential([
        layers.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3)),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(1, activation='sigmoid')
    ])

    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model

# ========================= TRAINING AND EVALUATION =========================

def train():
    # Create the full dataset
    full_dataset = create_final_dataset()
    dataset_batches = list(full_dataset)  # Materialize dataset to count batches
    dataset_size = len(dataset_batches) * BATCH_SIZE

    if dataset_size == 0:
        raise ValueError("The combined dataset is empty. Please check your data sources.")

    # Calculate sizes for train and validation splits
    val_size = int(0.2 * dataset_size)
    train_size = dataset_size - val_size

    # Determine number of batches for train and validation
    train_batches = train_size // BATCH_SIZE
    print(f"Training batches: {train_batches}, Total batches: {len(dataset_batches)}")
    train_ds = full_dataset.take(train_batches)
    val_ds = full_dataset.skip(train_batches)

    # Create and train model
    model = create_model()
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=15,
        callbacks=[tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)]
    )

    # Save model and training history
    model.save('/content/satellite_classifier.keras')
    with open('/content/training_history.json', 'w') as f:
        json.dump(history.history, f)

    return model, history

# ========================= START TRAINING =========================

trained_model, training_history = train()


Found 205 images belonging to 1 classes.
EuroSAT dataset could not be loaded: module 'keras.api.datasets' has no attribute 'eurosat'
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
[1m169001437/169001437[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 0us/step
Balancing datasets to 47500 samples each.
Training batches: 2375, Total batches: 2969
Epoch 1/15
   2373/Unknown [1m61s[0m 6ms/step - accuracy: 0.9722 - loss: 0.1032

