In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    Conv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Dropout,
    BatchNormalization,
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import KFold
from tensorflow.keras.preprocessing.image import (
    img_to_array,
    load_img,
    ImageDataGenerator,
)
from PIL import Image
import os
import glob
import matplotlib.pyplot as plt

In [None]:
# Function to convert .tif to .png
def convert_tif_to_png(folder_path):
    """Convert every ``*.tif`` file directly inside a folder to an RGB PNG.

    Each PNG is written next to its source file with the same base name, and
    the source ``.tif`` is deleted once the PNG has been saved.

    Args:
        folder_path: Directory to scan (not recursive).
    """
    # glob.escape keeps characters such as "[" in the folder name from being
    # interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(folder_path), "*.tif")
    for img_path in glob.glob(pattern):
        # Swap only the trailing extension; str.replace would also rewrite a
        # ".tif" that happens to appear earlier in the path.
        png_path = os.path.splitext(img_path)[0] + ".png"
        # Close the source file before deleting it (required on Windows and
        # avoids leaking one file handle per image).
        with Image.open(img_path) as img:
            rgb_img = img if img.mode == "RGB" else img.convert("RGB")
            rgb_img.save(png_path, "PNG")
        os.remove(img_path)

In [None]:
# Load, scale and label every image referenced by a {label_key: path} mapping
def preprocess_images(image_paths, target_size=(224, 224)):
    """Turn a mapping of label keys to file paths into image and label arrays.

    Args:
        image_paths: Mapping whose keys start with an integer followed by an
            underscore (the integer becomes the label) and whose values are
            image file paths.
        target_size: Size passed to ``load_img`` for every image.

    Returns:
        Tuple ``(images, labels)`` of NumPy arrays; pixel values are divided
        by 255.0 and labels are the integers parsed from the keys.
    """

    def _load_scaled(file_path):
        # Read the file at the requested size and scale the pixels by 1/255
        pixels = img_to_array(load_img(file_path, target_size=target_size))
        return pixels / 255.0

    # The image is loaded before its key is parsed, for each entry in turn
    loaded = [
        (_load_scaled(file_path), int(key.split("_")[0]))
        for key, file_path in image_paths.items()
    ]
    pixel_stack = [pixels for pixels, _ in loaded]
    counts = [count for _, count in loaded]
    return np.array(pixel_stack), np.array(counts)

In [None]:
# Build and compile the small CNN used for count regression
def create_model(input_shape):
    """Create a compiled Sequential CNN with a single linear output unit.

    The network is three Conv2D(3x3, relu) + MaxPooling2D(2x2) stages with
    32, 64 and 128 filters, followed by Flatten, Dense(256, relu),
    Dropout(0.5), BatchNormalization and Dense(1, linear).

    Args:
        input_shape: Shape given to the first convolution as ``input_shape``.

    Returns:
        The model, compiled with the "adam" optimizer and mean squared error.
    """
    stack = []
    for stage, n_filters in enumerate((32, 64, 128)):
        if stage == 0:
            # Only the first layer declares the input shape
            conv = Conv2D(
                n_filters, (3, 3), activation="relu", input_shape=input_shape
            )
        else:
            conv = Conv2D(n_filters, (3, 3), activation="relu")
        stack.append(conv)
        stack.append(MaxPooling2D((2, 2)))

    # Dense head ending in one linear unit (regression output)
    stack.append(Flatten())
    stack.append(Dense(256, activation="relu"))
    stack.append(Dropout(0.5))
    stack.append(BatchNormalization())
    stack.append(Dense(1, activation="linear"))

    model = Sequential(stack)
    model.compile(optimizer="adam", loss="mean_squared_error")
    return model

In [None]:
# Function to process and tile images with overlap
def process_and_tile_images_with_overlap(image_path, tile_size=(224, 224), overlap=0.4):
    """Cut an image into overlapping RGB tiles with pixel values scaled by 1/255.

    Args:
        image_path: Path of the image to open with PIL.
        tile_size: ``(width, height)`` of each tile in pixels.
        overlap: Fraction of a tile shared with the next tile along an axis.
            Must be below 1.

    Returns:
        NumPy array with one entry per tile. The array is empty when the
        image is smaller than a single tile.

    Raises:
        ValueError: If ``overlap`` is 1 or greater (the stride would not be
            positive).
    """
    if overlap >= 1:
        raise ValueError(f"overlap must be < 1, got {overlap}")

    tile_width, tile_height = tile_size
    # One stride per axis, never below one pixel so range() always advances
    stride_x = max(1, int(tile_width * (1 - overlap)))
    stride_y = max(1, int(tile_height * (1 - overlap)))

    tiles = []
    # The context manager releases the file handle once tiling is done
    with Image.open(image_path) as img:
        # Force three channels so RGBA / palette / grayscale files produce
        # tiles with the same channel count as images read through load_img
        source = img if img.mode == "RGB" else img.convert("RGB")
        img_width, img_height = source.size

        # "+ 1" keeps the last tile that ends exactly on the right/bottom
        # edge, including the single tile of an image that is exactly one
        # tile in size.
        for x in range(0, img_width - tile_width + 1, stride_x):
            for y in range(0, img_height - tile_height + 1, stride_y):
                box = (x, y, x + tile_width, y + tile_height)
                tile_array = img_to_array(source.crop(box))
                tiles.append(tile_array / 255.0)  # Normalize the tiles

    return np.array(tiles)

In [None]:
# Function to predict and sum trees with post-processing and thresholding
def predict_and_sum_trees(
    model, image_path, tile_size=(224, 224), overlap_threshold=0.4, count_threshold=0.5
):
    """Tile an image, run the model on every tile and count tiles above a threshold.

    Args:
        model: Object with a ``predict`` method that accepts the tile array.
        image_path: Path of the image to tile.
        tile_size: ``(width, height)`` of each tile in pixels.
        overlap_threshold: Overlap fraction between neighboring tiles; it is
            forwarded as ``overlap`` to process_and_tile_images_with_overlap.
        count_threshold: A tile contributes 1 to the result when its
            prediction is strictly greater than this value.

    Returns:
        Integer number of tile predictions above ``count_threshold``; 0 when
        the image produced no tiles.
    """
    tiles = process_and_tile_images_with_overlap(
        image_path, tile_size=tile_size, overlap=overlap_threshold
    )
    # Nothing to predict on (e.g. image smaller than one tile): avoid handing
    # an empty array to model.predict.
    if len(tiles) == 0:
        return 0

    tile_predictions = model.predict(tiles)
    # NOTE(review): each prediction is reduced to 0/1 here, so the result is
    # the number of tiles above the threshold rather than a sum of predicted
    # per-tile counts, and overlapping tiles can each contribute. Confirm this
    # is the intended post-processing for a regression model.
    confident = tile_predictions > count_threshold
    adjusted_count = int(np.count_nonzero(confident))
    return adjusted_count

In [None]:
def visualize_model(model, layer_id, filter_id, image):
    """Display the activation map of one filter of one layer for a single image.

    Args:
        model: Keras model whose intermediate output is inspected.
        layer_id: Index into ``model.layers``.
        filter_id: Index along the last axis of the selected layer's output.
        image: One input image without a batch axis; a leading axis is added
            before prediction.

    Layers whose output is not 4-D (for example Flatten or Dense) cannot be
    sliced into a 2-D map and are skipped with a message.
    """
    layer = model.layers[layer_id]
    layer_output = layer.output

    # The slice below indexes four axes, so check the rank up front instead of
    # only the size of the last axis (a Dense layer has many units but a 2-D
    # output). Checking before predicting also skips a wasted forward pass.
    if len(layer_output.shape) != 4:
        print(f"Layer {layer.name} has no 2-D feature maps to visualize")
        return

    submodel = tf.keras.models.Model(inputs=model.inputs, outputs=layer_output)
    feature_map = submodel.predict(image[np.newaxis, ...])

    filter_activation = feature_map[0, :, :, filter_id]
    plt.matshow(filter_activation, cmap="viridis")
    plt.title(f"Layer {layer.name} Filter {filter_id}")
    plt.show()

In [None]:
# Directory used as an on-disk cache for augmented images: it is passed to
# save_augmented_images when the cache is created and to
# load_augmented_images when it already holds files.
augmented_images_dir = "augmented_images"

In [None]:
# Function to save augmented images
def save_augmented_images(images, labels, directory):
    """Write images to ``directory`` as ``augmented_image_<label>_<index>.png``.

    Args:
        images: Iterable of image arrays with values on a 0-1 scale.
        labels: Iterable of labels, paired with ``images`` by position.
        directory: Output folder; created (with parents) when missing.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(directory, exist_ok=True)
    for i, (image, label) in enumerate(zip(images, labels)):
        filename = f"augmented_image_{label}_{i}.png"
        image_path = os.path.join(directory, filename)
        # Convert back to 0-255 range. Round instead of truncating so float
        # error (e.g. 253.99998) does not drop a level, and clip so values
        # slightly outside 0-1 cannot wrap around in uint8.
        pixels = np.clip(np.rint(np.asarray(image) * 255), 0, 255).astype(np.uint8)
        Image.fromarray(pixels).save(image_path)

In [None]:
# Read the cached augmented PNG files back into arrays
def load_augmented_images(directory):
    """Load every ``*.png`` in ``directory`` together with its label.

    The label is the integer found in the third underscore-separated field of
    the file name (``augmented_image_<label>_<index>.png``).

    Args:
        directory: Folder to scan (not recursive).

    Returns:
        Tuple ``(images, labels)`` of NumPy arrays; pixel values are divided
        by 255.0.
    """
    loaded = []
    for png_file in glob.glob(directory + "/*.png"):
        name_fields = os.path.basename(png_file).split("_")
        tree_count = int(name_fields[2])
        # Scale pixel values by 1/255
        scaled = img_to_array(load_img(png_file)) / 255.0
        loaded.append((scaled, tree_count))

    pixel_stack = [pixels for pixels, _ in loaded]
    counts = [count for _, count in loaded]
    return np.array(pixel_stack), np.array(counts)

In [None]:
# Labeled training images: each key maps to the path of one image file.
# preprocess_images takes the integer before the first underscore of the key
# as the label, so "example.png" (key "1_trees") is labeled 1.
# Paths are relative to the working directory.
image_paths = {
    "13_trees": "13_trees.png",
    "15_trees": "15_trees.png",
    "19_trees": "19_trees.png",
    "20_trees": "20_trees.png",
    "21_trees": "21_trees.png",
    "1_trees": "example.png",
    "22_trees": "22_trees.png",
    "34_trees": "34_trees.png",
    "35_trees": "35_trees.png",
    "36_trees": "36_trees.png",
    "41_trees": "41_trees.png",
}

In [None]:
# Preprocess the images: load every file in image_paths at the default
# 224x224 target size, scale pixels by 1/255 and parse the integer labels
# from the dictionary keys.
images, labels = preprocess_images(image_paths)

In [None]:
# Define data augmentation configuration: random rotation, width/height
# shifts, shear, zoom and horizontal/vertical flips. This generator is used
# below to expand the labeled images into a larger training set.
data_generator = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode="nearest",
)

In [None]:
# Check if augmented images exist, if not, create and save them
if (
    not os.path.exists(augmented_images_dir)
    or len(os.listdir(augmented_images_dir)) == 0
):
    # Augment data to create a larger dataset (about 200 images in total).
    # The per-image quota is computed once and is at least 1: with more than
    # 200 source images "200 // len(images)" is 0, which an equality check
    # after the first increment would never match, looping forever on the
    # endless generator.
    augmented_per_image = max(1, 200 // max(1, len(images)))
    augmented_images, augmented_labels = [], []
    for img, label in zip(images, labels):
        # flow() expects a batch axis, so wrap the single image
        single_image_batch = img[np.newaxis, ...]
        num_augmented = 0
        for batch in data_generator.flow(single_image_batch, batch_size=1):
            augmented_image = batch[0]
            augmented_images.append(augmented_image)
            augmented_labels.append(label)
            num_augmented += 1
            # flow() yields batches indefinitely; stop once the quota is met
            if num_augmented >= augmented_per_image:
                break
    augmented_images = np.array(augmented_images)
    augmented_labels = np.array(augmented_labels)
    save_augmented_images(augmented_images, augmented_labels, augmented_images_dir)
else:
    # Load augmented images
    augmented_images, augmented_labels = load_augmented_images(augmented_images_dir)

In [None]:
# Define K-fold cross validation: 5 shuffled folds with a fixed random_state
# so the split is the same on every run.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

In [None]:
# Perform K-fold cross validation
# NOTE(review): the folds are built from the original `images`, while the
# final model further down is trained on `augmented_images`; confirm that
# evaluating on the un-augmented set is intended.
fold_no = 1
fold_val_losses = []  # final-epoch validation loss (MSE) of each fold
for train_index, test_index in kf.split(images):
    train_images, test_images = images[train_index], images[test_index]
    train_labels, test_labels = labels[train_index], labels[test_index]

    # Release Keras global state left by the previous fold's model so memory
    # does not grow with every model created in this loop
    tf.keras.backend.clear_session()

    # Create a new model for each fold
    model = create_model(input_shape=(224, 224, 3))
    print(f"Training fold {fold_no}...")

    # Train the model
    history = model.fit(
        train_images,
        train_labels,
        batch_size=5,
        epochs=100,
        validation_data=(test_images, test_labels),
    )

    # Keep this fold's result; `history` is overwritten on the next iteration
    fold_val_losses.append(float(history.history["val_loss"][-1]))

    # Increase fold number
    fold_no += 1

# Summarize the cross-validation instead of discarding the per-fold results
print(f"Validation loss per fold: {fold_val_losses}")
print(f"Mean validation loss: {np.mean(fold_val_losses):.4f}")

In [None]:
# After K-fold cross validation, train a final model on the full augmented
# dataset. No validation data is passed to this fit, so it reports training
# loss only.
final_model = create_model(input_shape=(224, 224, 3))
final_model.fit(augmented_images, augmented_labels, batch_size=5, epochs=100)

In [None]:
# Save the final model to "final_tree_counting_model.h5" (the path is
# relative to the working directory)
final_model.save("final_tree_counting_model.h5")

In [None]:
# Convert .tif to .png before prediction if necessary.
# NOTE: convert_tif_to_png removes each original .tif file after writing the
# corresponding .png, so this cell modifies the "unlabeled" folder in place.
unlabeled_folder_path = "unlabeled"
convert_tif_to_png(unlabeled_folder_path)

In [None]:
# Overlap fraction between neighboring tiles; passed to predict_and_sum_trees
# below, which forwards it as the tiling `overlap`.
overlap_threshold = 0.05

In [None]:
# Path to the specific unlabeled image you want to count trees in.
# The file name is joined onto unlabeled_folder_path, so the image is expected
# inside that folder.
unlabeled_image_name = "PalmTreePlantation_transparent_mosaic_group1.png"
unlabeled_image_path = os.path.join(unlabeled_folder_path, unlabeled_image_name)

In [None]:
# Predict and count trees in the unlabeled image.
# count_threshold is not passed, so predict_and_sum_trees uses its default
# of 0.5.
tree_count = predict_and_sum_trees(
    final_model,
    unlabeled_image_path,
    tile_size=(224, 224),
    overlap_threshold=overlap_threshold,
)
print(f"Image: {unlabeled_image_name}, Predicted Tree Count: {tree_count}")

In [None]:
# Show the activation of filter 0 of the model's first layer (layer_id=0) for
# one of the augmented training images.
example_index = 0  # Index of the image to visualize
visualize_model(
    final_model, layer_id=0, filter_id=0, image=augmented_images[example_index]
)