In [3]:
pip install rasterio numpy pandas scikit-learn joblib matplotlib


Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import rasterio
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import joblib
import matplotlib.pyplot as plt

# Input locations: the first folder is passed to the loader for NDVI rasters,
# the second for TSDM rasters.
ndvi_folder, tsdm_folder = (
    r'C:\Users\FIqbal\Downloads\NDVI and TSDM\NDVI15',
    r'C:\Users\FIqbal\Downloads\NDVI and TSDM\TSDM15',
)

# Function to extract date from filename
def extract_date_from_filename(filename):
    """Return the YYYYMMDD token embedded in a raster file name.

    The original convention (a four-character prefix followed by the date,
    i.e. characters 4..11 of the base name) is tried first so existing
    behaviour is unchanged whenever that slice is eight digits. If it is not,
    the base name is searched for a standalone run of exactly eight digits,
    which also covers date-first names such as ``20240527NDVI15.tif``.

    Parameters
    ----------
    filename : str
        File name or full path; only the base name is inspected.

    Returns
    -------
    str
        The eight-digit date string when one can be found; otherwise the raw
        ``base_name[4:12]`` slice, exactly as before.
    """
    import re  # local import keeps this function self-contained

    base_name = os.path.basename(filename)
    positional = base_name[4:12]

    # Fast path: the fixed-position slice already looks like a date.
    if re.fullmatch(r'\d{8}', positional):
        return positional

    # Fallback: an 8-digit run that is not part of a longer digit sequence.
    match = re.search(r'(?<!\d)\d{8}(?!\d)', base_name)
    if match:
        return match.group(0)

    # Nothing date-like found: preserve the historical return value.
    return positional

# Load NDVI and TSDM data
def load_raster_data(folder_path):
    """Read band 1 of every GeoTIFF in ``folder_path`` into flat arrays keyed by date.

    Parameters
    ----------
    folder_path : str
        Directory that is scanned (non-recursively) for ``.tif`` / ``.tiff``
        files; the extension check is case-insensitive.

    Returns
    -------
    dict
        Maps the date string returned by ``extract_date_from_filename`` to a
        1-D floating-point array of the band-1 pixels. Pixels that rasterio
        reports as masked (nodata) are stored as NaN so that later
        missing-value handling can actually drop them.
    """
    data = {}
    # os.listdir order is arbitrary; sorting makes the result (and which file
    # wins when two names map to the same date) reproducible.
    for file_name in sorted(os.listdir(folder_path)):
        if not file_name.lower().endswith(('.tif', '.tiff')):
            continue
        date = extract_date_from_filename(file_name)
        if date in data:
            print(f"Warning: {file_name} maps to date {date}, which was already loaded; overwriting.")
        file_path = os.path.join(folder_path, file_name)
        with rasterio.open(file_path) as src:
            # masked=True returns a masked array with nodata pixels flagged.
            band = src.read(1, masked=True)
        # NaN only exists for floating dtypes, so promote integer rasters.
        if not np.issubdtype(band.dtype, np.floating):
            band = band.astype(np.float64)
        data[date] = band.filled(np.nan).ravel()  # Flatten to a 1D array
    return data

# Load data
print("Loading NDVI data...")
ndvi_data = load_raster_data(ndvi_folder)
print(f"NDVI data loaded: {len(ndvi_data)} files")

print("Loading TSDM data...")
tsdm_data = load_raster_data(tsdm_folder)
print(f"TSDM data loaded: {len(tsdm_data)} files")

# Align data by date
common_dates = set(ndvi_data.keys()) & set(tsdm_data.keys())
print(f"Common dates found: {len(common_dates)}")

# Prepare dataset: pair the NDVI and TSDM pixels of every shared date.
# Whole arrays are collected and concatenated once instead of appending one
# Python list per pixel, which is far slower and more memory-hungry.
ndvi_chunks = []
tsdm_chunks = []
# Iterate in sorted order: set iteration order of strings can change between
# kernel runs, which would change the row order and therefore the
# train/test split even with a fixed random_state.
for date in sorted(common_dates):
    ndvi_values = ndvi_data[date]
    tsdm_values = tsdm_data[date]
    if len(ndvi_values) != len(tsdm_values):
        # Rasters of different size cannot be paired pixel-by-pixel.
        print(f"Skipping {date}: NDVI has {len(ndvi_values)} pixels but TSDM has {len(tsdm_values)}")
        continue
    ndvi_chunks.append(ndvi_values)
    tsdm_chunks.append(tsdm_values)

if not ndvi_chunks:
    raise ValueError("No NDVI/TSDM raster pairs with matching dates and sizes were found.")

df = pd.DataFrame({
    'NDVI': np.concatenate(ndvi_chunks),
    'TSDM': np.concatenate(tsdm_chunks),
})
print("Data preview:")
print(df.head())

# Check for missing values
print("Checking for missing values...")
print(df.isnull().sum())

# Drop unusable rows. dropna() alone leaves +/-inf in place, and the
# estimator rejects infinite inputs, so infinities are converted to NaN first.
df = df.replace([np.inf, -np.inf], np.nan).dropna()
if df.empty:
    raise ValueError("No valid NDVI/TSDM samples remain after removing missing values.")

# Separate features and target (X stays a one-column DataFrame)
X = df[['NDVI']]
y = df['TSDM']

# Split the data: 80 % train / 20 % test, seeded for repeatability
print("Splitting the data...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train the model
print("Training the Random Forest model...")
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Make predictions on the held-out rows
print("Making predictions...")
y_pred = model.predict(X_test)

# Evaluate the model
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")

# Save the model
model_path = r'C:\Users\FIqbal\Downloads\NDVI and TSDM\New folder\RandomForest_TSDM_Model.pkl'
# Create the target folder first so a missing directory does not throw away
# the training work with a FileNotFoundError at the very last step.
model_dir = os.path.dirname(model_path)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)
joblib.dump(model, model_path)
print(f"Model saved to {model_path}")

# Prediction function
def predict_tsdm_from_ndvi(ndvi_raster_path):
    """Predict a TSDM grid from band 1 of an NDVI raster using the module-level ``model``.

    Pixels that rasterio reports as masked (nodata), or that are NaN/inf, are
    not passed to the model; they come back as NaN instead of a prediction
    made from a placeholder value.

    Parameters
    ----------
    ndvi_raster_path : str
        Path of the NDVI raster to read.

    Returns
    -------
    numpy.ndarray
        2-D float64 array with the same (height, width) as the input raster.
    """
    with rasterio.open(ndvi_raster_path) as src:
        # masked=True flags nodata pixels instead of returning them as values.
        band = src.read(1, masked=True)

    pixels = np.ma.filled(band.astype(np.float64), np.nan).ravel()
    valid = np.isfinite(pixels)

    predicted = np.full(pixels.shape, np.nan, dtype=np.float64)
    if valid.any():
        # The model in this notebook is fitted on a DataFrame with an 'NDVI'
        # column; predicting with the same column name avoids scikit-learn's
        # feature-name mismatch warning.
        features = pd.DataFrame({'NDVI': pixels[valid]})
        predicted[valid] = model.predict(features)

    return predicted.reshape(band.shape)

# Example usage
new_ndvi_raster_path = r'C:\Users\FIqbal\Downloads\NDVI and TSDM\NDVI15\New folder\20240527NDVI15.tif'
print(f"Predicting TSDM from {new_ndvi_raster_path}...")
predicted_tsdm = predict_tsdm_from_ndvi(new_ndvi_raster_path)

# Save predicted TSDM raster
output_path = r'C:\Users\FIqbal\Downloads\NDVI and TSDM\New folder\Predicted_TSDM.tif'
print(f"Saving predicted TSDM raster to {output_path}...")
# Reuse the input raster's profile (georeferencing, size, ...) for the output.
with rasterio.open(new_ndvi_raster_path) as src:
    profile = src.profile
# The output is a single float32 band. The NDVI file's nodata value has no
# meaning for predicted TSDM values, so NaN is declared as nodata instead.
profile.update(dtype=rasterio.float32, count=1, nodata=np.nan)

# Make sure the destination folder exists before opening the file for writing.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

with rasterio.open(output_path, 'w', **profile) as dst:
    dst.write(predicted_tsdm.astype(rasterio.float32), 1)

# Plot results: NDVI input on the left, predicted TSDM on the right.
print("Plotting results...")
fig, (ndvi_ax, tsdm_ax) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: band 1 of the NDVI raster used for the prediction.
# NOTE: this rebinds the module-level name `ndvi_data` to the 2-D array.
with rasterio.open(new_ndvi_raster_path) as src:
    ndvi_data = src.read(1)
ndvi_image = ndvi_ax.imshow(ndvi_data, cmap='viridis')
ndvi_ax.set_title('NDVI Raster')
fig.colorbar(ndvi_image, ax=ndvi_ax)

# Right panel: the predicted TSDM grid.
tsdm_image = tsdm_ax.imshow(predicted_tsdm, cmap='viridis')
tsdm_ax.set_title('Predicted TSDM Raster')
fig.colorbar(tsdm_image, ax=tsdm_ax)

plt.show()
