In [None]:
# Colab-only: mount Google Drive at /content/drive so the dataset path used
# further down in this cell resolves.
from google.colab import drive
drive.mount('/content/drive')

#!pip install rioxarray
import numpy as np
import random
import os
import rasterio
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Load the prepared arrays from Drive. np.load on an .npz archive returns a
# lazy NpzFile that keeps the file open, so read the arrays inside a context
# manager; X and y are fully materialised before the archive is closed.
with np.load("/content/drive/MyDrive/NEW_MODIS_Combined/new_data.npz") as data:
    X = data["X"]
    y = data["y"]

# Fail early on an inconsistent archive instead of deep inside model.fit().
assert len(X) == len(y), f"sample count mismatch: X={X.shape}, y={y.shape}"

In [None]:
from keras.layers import Multiply, GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, Conv2D, Add, Activation, Concatenate
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, UpSampling2D, Concatenate, Layer
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.models import Model

def channel_attention(input_tensor, ratio=8):
    """Apply CBAM-style channel attention to a feature map.

    The spatial axes are collapsed with global average pooling and global max
    pooling. Both pooled descriptors are passed through the same two-layer
    bottleneck MLP (shared weights), summed, passed through a sigmoid and
    used to rescale ``input_tensor`` channel by channel.

    Args:
        input_tensor: Keras tensor whose last axis is the channel axis; the
            channel count must be statically known.
        ratio: Reduction factor of the MLP bottleneck.

    Returns:
        ``input_tensor`` multiplied by the per-channel attention weights.
    """
    channels = input_tensor.shape[-1]
    # Never build a zero-unit Dense layer when channels < ratio.
    hidden_units = max(1, channels // ratio)

    # One MLP, applied to both descriptors: instantiating the layers once and
    # calling them twice makes the weights shared.
    mlp_reduce = Dense(hidden_units, activation='relu')
    mlp_expand = Dense(channels)

    avg_pool = GlobalAveragePooling2D()(input_tensor)
    max_pool = GlobalMaxPooling2D()(input_tensor)

    # Restore singleton spatial axes so the weights broadcast over H and W.
    avg_pool = Reshape((1, 1, channels))(avg_pool)
    max_pool = Reshape((1, 1, channels))(max_pool)

    avg_out = mlp_expand(mlp_reduce(avg_pool))
    max_out = mlp_expand(mlp_reduce(max_pool))

    attention = Add()([avg_out, max_out])
    attention = Activation('sigmoid')(attention)

    return Multiply()([input_tensor, attention])

class SpatialAttention(Layer):
    """CBAM-style spatial attention.

    The input is reduced along its last (channel) axis with a mean and a max,
    the two single-channel maps are concatenated, and a ``kernel_size`` x
    ``kernel_size`` convolution with a sigmoid activation turns them into one
    attention map that rescales every spatial position of the input.

    Args:
        kernel_size: Side length of the square convolution kernel.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, kernel_size=7, **kwargs):
        super(SpatialAttention, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.conv = Conv2D(1, (self.kernel_size, self.kernel_size), padding="same", activation="sigmoid")

    def call(self, input_tensor):
        """Return ``input_tensor`` scaled by its spatial attention map."""
        avg_pool = tf.reduce_mean(input_tensor, axis=-1, keepdims=True)
        max_pool = tf.reduce_max(input_tensor, axis=-1, keepdims=True)
        # Plain tensor ops: no throw-away Concatenate/Multiply layer objects
        # are created each time the layer is called.
        pooled = tf.concat([avg_pool, max_pool], axis=-1)
        attention_map = self.conv(pooled)
        # The single-channel map broadcasts across the channel axis.
        return input_tensor * attention_map

    def get_config(self):
        """Include ``kernel_size`` so the layer can be saved and re-created."""
        config = super(SpatialAttention, self).get_config()
        config.update({"kernel_size": self.kernel_size})
        return config

def cbam_block(input_tensor, ratio=8):
    """Run channel attention followed by spatial attention on a feature map.

    Args:
        input_tensor: Feature map to refine.
        ratio: Bottleneck reduction factor forwarded to ``channel_attention``.

    Returns:
        The feature map after both attention stages.
    """
    # Channel stage first, then a default-configured spatial stage on top.
    channel_refined = channel_attention(input_tensor, ratio)
    return SpatialAttention()(channel_refined)

def conv_block(input, num_filters):
    """Stack two identical Conv(3x3, same) -> BatchNorm -> ReLU stages.

    Args:
        input: Tensor fed to the first convolution.
        num_filters: Number of filters used by both convolutions.

    Returns:
        Output tensor of the second ReLU.
    """
    features = input
    for _ in range(2):
        features = Conv2D(num_filters, 3, padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features

class ResizeLayer(Layer):
    """Resize a feature map to the height and width of a reference tensor.

    The layer is called with a two-element list
    ``[input_tensor, target_tensor]``. Axes 1 and 2 of ``target_tensor`` give
    the output size; ``input_tensor`` is resized with ``tf.image.resize``
    using that function's default interpolation.
    """

    def call(self, input):
        input_tensor, target_tensor = input
        # Prefer the statically known size so layers built on top of the
        # result keep concrete spatial dimensions.
        target_height, target_width = target_tensor.shape[1], target_tensor.shape[2]
        if target_height is None or target_width is None:
            # Size only known at run time: fall back to the dynamic shape.
            dynamic_shape = tf.shape(target_tensor)
            target_height, target_width = dynamic_shape[1], dynamic_shape[2]
        return tf.image.resize(input_tensor, [target_height, target_width])
def decoder_block(input_tensor, skip_tensor, num_filters):
    """One decoder stage: resize, merge with the skip, convolve, attend.

    Args:
        input_tensor: Feature map coming from the previous (coarser) stage.
        skip_tensor: Encoder feature map; also defines the target spatial size.
        num_filters: Filter count passed to ``conv_block``.

    Returns:
        The attention-refined feature map of this stage.
    """
    # Bring the incoming features to the skip connection's height/width.
    resized = ResizeLayer()([input_tensor, skip_tensor])
    merged = Concatenate()([resized, skip_tensor])
    refined = conv_block(merged, num_filters)
    return cbam_block(refined)

def build_resnetv2_unet(input_shape=(128, 128, 3), num_classes=4):
    """Build a U-Net with a ResNet50V2 encoder and CBAM-refined decoder.

    Skip connections are taken at 1/2, 1/4, 1/8 and 1/16 of the input
    resolution and the bridge at 1/32. ResNet50V2 downsamples inside the last
    block of each stack, so the ``*_out`` layers of stacks 2-4 are already at
    the next coarser resolution; the ``*_1_relu`` activations of those last
    blocks sit before the strided convolution and are tapped instead.

    Args:
        input_shape: ``(height, width, channels)`` of one input sample. The
            encoder is created with ``weights="imagenet"``.
            NOTE(review): the ImageNet weights presumably require 3 channels -
            confirm against the data.
        num_classes: Number of output classes (channels of the softmax map).

    Returns:
        A ``Model`` mapping an input image to a per-pixel softmax over
        ``num_classes`` classes.
    """
    inputs = Input(input_shape)

    base_model = ResNet50V2(include_top=False, weights="imagenet", input_tensor=inputs)

    # Resolutions in the comments are for the default 128x128 input.
    s1 = base_model.get_layer("conv1_conv").output           # 64x64
    s2 = base_model.get_layer("conv2_block3_1_relu").output  # 32x32
    s3 = base_model.get_layer("conv3_block4_1_relu").output  # 16x16
    s4 = base_model.get_layer("conv4_block6_1_relu").output  # 8x8
    b1 = base_model.get_layer("conv5_block3_out").output     # 4x4 (bridge)

    d1 = decoder_block(b1, s4, 512)  # 4→8
    d2 = decoder_block(d1, s3, 256)  # 8→16
    d3 = decoder_block(d2, s2, 128)  # 16→32
    d4 = decoder_block(d3, s1, 64)   # 32→64

    # conv1_conv is at half resolution, so one more 2x step restores full size.
    x = UpSampling2D((2, 2))(d4)     # 64→128
    x = Conv2D(32, 3, padding="same", activation="relu")(x)

    outputs = Conv2D(num_classes, 1, padding="same", activation="softmax")(x)

    model = Model(inputs, outputs, name="ResNetV2_U-Net_with_CBAM")
    return model


In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from keras.losses import SparseCategoricalCrossentropy
import tensorflow.keras.backend as K

def focal_loss(alpha=0.25, gamma=2.0):
    """Create a multi-class focal loss for integer (sparse) labels.

    Args:
        alpha: Constant scale applied to every term of the loss.
        gamma: Focusing exponent; larger values down-weight pixels that are
            already predicted with high probability.

    Returns:
        A function ``loss(y_true, y_pred)`` returning a scalar tensor.
        ``y_pred`` holds class probabilities on its last axis; ``y_true``
        holds integer class ids with the same leading shape, optionally with
        a trailing axis of size 1.
    """
    def loss(y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, tf.int32)

        # Labels shaped (..., 1) would one-hot to (..., 1, C) and broadcast
        # wrongly against (..., C) predictions, so drop that axis first.
        if (y_true.shape.rank is not None
                and y_true.shape.rank == y_pred.shape.rank
                and y_true.shape[-1] == 1):
            y_true = tf.squeeze(y_true, axis=-1)

        # Take the class count from the predictions instead of hard-coding it.
        num_classes = y_pred.shape[-1]
        if num_classes is None:
            num_classes = tf.shape(y_pred)[-1]

        y_true = tf.one_hot(y_true, depth=num_classes, dtype=y_pred.dtype)  # Convert to one-hot
        y_pred = tf.clip_by_value(y_pred, K.epsilon(), 1. - K.epsilon())

        # Calculate focal loss
        cross_entropy = -y_true * tf.math.log(y_pred)
        weight = alpha * tf.pow(1. - y_pred, gamma)

        # Only the true-class entry is non-zero, so sum over the class axis
        # and then average over pixels; averaging over classes as well would
        # shrink the loss by a factor of num_classes.
        per_pixel = tf.reduce_sum(weight * cross_entropy, axis=-1)
        return tf.reduce_mean(per_pixel)
    return loss

In [None]:
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient averaged over samples and classes.

    Args:
        y_true: Integer class ids with shape ``[batch, H, W]`` (a trailing
            axis of size 1 is also accepted).
        y_pred: Per-class scores with shape ``[batch, H, W, num_classes]``.
        smooth: Constant added to numerator and denominator so that classes
            absent from both tensors do not divide by zero.

    Returns:
        Scalar tensor: the mean Dice score over every (sample, class) pair.
    """
    y_pred = tf.cast(y_pred, tf.float32)
    y_true = tf.cast(y_true, tf.int32)

    # Labels shaped [batch, H, W, 1] would one-hot to rank 5; drop that axis.
    if (y_true.shape.rank is not None
            and y_true.shape.rank == y_pred.shape.rank
            and y_true.shape[-1] == 1):
        y_true = tf.squeeze(y_true, axis=-1)

    # Take the class count from the predictions instead of hard-coding it.
    num_classes = y_pred.shape[-1]
    if num_classes is None:
        num_classes = tf.shape(y_pred)[-1]

    y_true_one_hot = tf.one_hot(y_true, depth=num_classes)  # [batch, H, W, C]

    # Reduce over the two spatial axes -> one value per (sample, class).
    intersection = tf.reduce_sum(y_true_one_hot * y_pred, axis=[1, 2])
    union = tf.reduce_sum(y_true_one_hot, axis=[1, 2]) + tf.reduce_sum(y_pred, axis=[1, 2])
    dice = (2. * intersection + smooth) / (union + smooth)
    return tf.reduce_mean(dice)

In [None]:
# Per-sample input shape: everything after the leading sample axis of X.
input_shape = X.shape[1:]
# Build the network; the ResNet50V2 ImageNet weights are downloaded on first use.
model = build_resnetv2_unet(input_shape)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50v2_weights_tf_dim_ordering_tf_kernels_notop.h5
94668760/94668760 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step


In [None]:
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
# Both callbacks watch `val_loss`, which Keras only reports when fit() is
# given validation data.
callbacks = [
    # Halve the learning rate after 2 epochs without improvement.
    ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=2,
        verbose=1,
    ),
    # Stop after 5 epochs without improvement and restore the best weights.
    EarlyStopping(
        monitor="val_loss",
        patience=5,
        restore_best_weights=True,
    ),
]

In [None]:
from keras.metrics import MeanIoU
class MeanIoUMetric(tf.keras.metrics.MeanIoU):
    """Mean IoU that accepts per-class scores as predictions.

    The parent metric is updated with class ids, so the last axis of
    ``y_pred`` is reduced with ``argmax`` before delegating to it.
    """

    def __init__(self, num_classes, name='mean_iou', **kwargs):
        super().__init__(num_classes=num_classes, name=name, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Convert scores to class ids, then update the parent metric."""
        # y_true is passed through untouched.
        predicted_ids = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, predicted_ids, sample_weight)

model.compile(
    optimizer='adam',
    loss=focal_loss(alpha=0.75, gamma=3.0),
    metrics=[ "accuracy", dice_coefficient, MeanIoUMetric(num_classes=4)]
)

# Train the model.
# The callbacks monitor `val_loss`, which is only computed when validation
# data is supplied; without it they never act. validation_split holds out the
# trailing 20% of X/y (taken before shuffling) for that purpose.
history = model.fit(
    X,
    y,
    epochs=5,
    validation_split=0.2,
    callbacks=callbacks,
)
