In [1]:
import pandas as pd

In [None]:
# --- Standard Imports ---
import pandas as pd
import numpy as np
import datetime
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from IPython.display import display
import traceback
import warnings
import os

# --- PyTorch & Forecasting Imports ---
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer # Example normalizer
from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss # Example metrics for model internal use

# Suppress common warnings
warnings.filterwarnings("ignore")
# Set seeds for reproducibility (optional)
# pl.seed_everything(42)

# --- Configuration ---
DATA_PATH_TRAIN = r"E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_2002_2023.csv"
DATA_PATH_VALIDATION = r"E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_test_2024.csv"

TARGET_COLUMNS = ['avg_min_price', 'avg_max_price', 'avg_modal_price'] # We will process one target at a time
# Define the primary target for this run (TFT often works best focusing on one main target)
# If you want to predict all three, you might need separate models or a multi-target setup (more complex)
PRIMARY_TARGET = 'avg_modal_price' # <--- FOCUS ON ONE TARGET FOR SIMPLICITY FIRST

DATE_COLUMN = 'full_date' # Name for the combined date column
YEAR_COL = 'year'; MONTH_COL = 'month'; DAY_COL = 'date' # Source columns
VALIDATION_YEAR = 2024

# --- User Selections & Encoding Maps ---
SELECTED_STATE_STR = "Maharashtra"
SELECTED_DISTRICT_STR = "Nashik"
SELECTED_COMMODITY_STR = "Wheat"

# Corrected Frequency Encoding Maps provided by user
state_name_encoding_map = {"maharashtra": 6291}
district_name_encoding_map = {"nashik": 6291}
commodity_name_encoding_map = {"wheat": 6291}

# --- TFT Configuration ---
# How many time steps back model should look
# Needs careful tuning based on data frequency and patterns (e.g., ~3-6 months for daily data)
ENCODER_LENGTH = 90
# How many time steps ahead to predict (relevant if predicting sequences, for single step use 1)
# For validation against 2024 data, we predict step-by-step or use the validation length
PREDICTION_LENGTH = 1 # Predict one step ahead for validation simplicity for now
BATCH_SIZE = 64      # Adjust based on GPU memory
NUM_WORKER = 0      # Set > 0 if using multi-processing for dataloaders (can cause issues on Windows)
TRAINER_GPUS = 0     # Set to 1 or more if GPU is available (requires CUDA setup)
LEARNING_RATE = 0.005 # Example learning rate
HIDDEN_SIZE = 32     # Example hidden size for TFT network layers
ATTENTION_HEADS = 4  # Example attention heads
DROPOUT = 0.1        # Example dropout rate
MAX_EPOCHS = 15      # Example max epochs (use EarlyStopping)


# --- Helper Functions (Outlier removal - can be skipped if data is clean) ---
def remove_outliers_iqr(df, columns_to_check):
    # ... (same as before) ...
    df_filtered = df.copy(); initial_rows = len(df_filtered)
    valid_columns = [col for col in columns_to_check if col in df_filtered.columns and pd.api.types.is_numeric_dtype(df_filtered[col])]
    if not valid_columns: return df_filtered
    subset_for_iqr = df_filtered[valid_columns]
    Q1 = subset_for_iqr.quantile(0.25); Q3 = subset_for_iqr.quantile(0.75); IQR = Q3 - Q1
    mask = ~((subset_for_iqr < (Q1 - 1.5 * IQR)) | (subset_for_iqr > (Q3 + 1.5 * IQR))).any(axis=1)
    df_filtered = df_filtered[mask]; rows_removed = initial_rows - len(df_filtered)
    if rows_removed > 0: print(f"Removed {rows_removed} rows via IQR.")
    return df_filtered

# --- Data Loading and Preprocessing Function (Adapted for TFT base) ---
def load_and_preprocess_base_data(path, date_col_name, year_col, month_col, day_col, target_cols, dataset_name="Training"):
    """Loads data, constructs date, basic cleaning - NO TFT specific features yet."""
    print("-" * 30); print(f"Processing {dataset_name} Dataset"); print("-" * 30)
    try:
        print(f"Loading {dataset_name} data from {path}..."); df = pd.read_csv(path); print(f"Loaded {len(df)} rows.")
        # 1. Construct Date
        date_components_cols = [year_col, month_col, day_col]
        if not all(col in df.columns for col in date_components_cols): print(f"Error: Date component cols missing: {[c for c in date_components_cols if c not in df.columns]}"); return None
        for col in date_components_cols: df[col] = pd.to_numeric(df[col], errors='coerce')
        df.dropna(subset=date_components_cols, inplace=True)
        print(f"Constructing '{date_col_name}'...");
        df[date_col_name] = pd.to_datetime({'year': df[year_col], 'month': df[month_col], 'day': df[day_col]}, errors='coerce')
        initial_rows_date = len(df); df.dropna(subset=[date_col_name], inplace=True)
        if initial_rows_date > len(df): print(f"Dropped {initial_rows_date - len(df)} rows due to invalid date components.")
        print(f"{len(df)} rows after date construction.")

        # 2. Ensure Price columns are numeric (only for the ones we might use)
        all_potential_targets = ['avg_min_price', 'avg_max_price', 'avg_modal_price']
        for col in all_potential_targets:
            if col in df.columns: df[col] = pd.to_numeric(df[col], errors='coerce')
            else: print(f"Warning: Price column '{col}' not found.")
        df.dropna(subset=all_potential_targets, how='any', inplace=True) # Drop if any price is missing
        print(f"{len(df)} rows after ensuring price columns numeric.")

        # 3. Drop OTHER unused columns
        cols_to_drop = ['calculationType', 'district_id', 'change',
                        'district_name_enc', 'commodity_name_enc', 'state_name_enc']
        existing_cols_to_drop = [col for col in cols_to_drop if col in df.columns]
        if existing_cols_to_drop: df.drop(columns=existing_cols_to_drop, axis=1, inplace=True)

        # 4. Apply IQR Outlier Removal (Optional)
        # df = remove_outliers_iqr(df, all_potential_targets) # Apply if needed

        # 5. Check required columns (encoded filters + date + targets)
        required_numeric_filter_cols = ['state_name', 'district_name', 'commodity_name']
        required_cols = [date_col_name] + all_potential_targets + required_numeric_filter_cols
        missing_req_cols = [col for col in required_cols if col not in df.columns]
        if missing_req_cols: print(f"Error: Required columns missing: {missing_req_cols}"); print(f"Available: {df.columns.tolist()}"); return None
        for col in required_numeric_filter_cols:
             if not pd.api.types.is_numeric_dtype(df[col]): print(f"Error: Col '{col}' expected numeric but isn't."); return None

        df.sort_values(date_col_name, inplace=True)
        print(f"{dataset_name} base data loaded. {len(df)} rows.")
        return df

    except FileNotFoundError: print(f"Error: {dataset_name} file not found at {path}"); return None
    except Exception as e: print(f"Error loading/preprocessing {dataset_name}: {e}"); traceback.print_exc(); return None


# --- Evaluation Metrics Function ---
def calculate_metrics(y_true, y_pred):
    # ... (same as before) ...
    y_true = np.array(y_true); y_pred = np.array(y_pred)
    if len(y_true) == 0 or len(y_pred) == 0: return np.nan, np.nan, np.nan
    if len(y_true) != len(y_pred):
        min_len = min(len(y_true), len(y_pred)); print(f"Warn: Mismatch metrics. Truncating to {min_len}.")
        if min_len == 0 : return np.nan, np.nan, np.nan
        y_true = y_true[:min_len]; y_pred = y_pred[:min_len]
    try:
        r2 = r2_score(y_true, y_pred); mae = mean_absolute_error(y_true, y_pred); mse = mean_squared_error(y_true, y_pred)
        return r2, mae, mse
    except Exception as e: print(f"Error calculating metrics: {e}"); return np.nan, np.nan, np.nan


# --- Plotting Function for Validation ---
def plot_validation_results(validation_actuals_df, validation_preds_df, target_column, date_col_name, title):
    """Plots actuals vs predictions for validation period."""
    fig = go.Figure(); target_label = target_column.replace("avg_", "").replace("_price", "").capitalize()

    # Ensure dataframes have the required columns and are sorted
    validation_actuals_df = validation_actuals_df.sort_values(by=date_col_name)
    validation_preds_df = validation_preds_df.sort_values(by='time_idx') # preds often keyed by time_idx

    # Add Actual Validation Data trace
    fig.add_trace(go.Scatter(
        x=validation_actuals_df[date_col_name],
        y=validation_actuals_df[target_column],
        mode='lines+markers', name=f'Actual {target_label} ({VALIDATION_YEAR})',
        line=dict(color='blue'), marker=dict(size=4)
    ))

    # Add Predicted Validation Data trace
    # Need to align predictions with actual dates if lengths differ slightly
    # Assuming predictions correspond to the time indices in validation_actuals_df for simplicity here
    # A more robust approach would merge based on time_idx
    if len(validation_actuals_df) == len(validation_preds_df):
         fig.add_trace(go.Scatter(
             x=validation_actuals_df[date_col_name], # Use actual dates for plotting preds
             y=validation_preds_df['prediction'].iloc[:, 0], # Assuming single prediction output
             mode='lines', name=f'Predicted {target_label} ({VALIDATION_YEAR})',
             line=dict(color='red')
         ))
    else:
         print(f"Warning: Length mismatch in plot data. Actuals: {len(validation_actuals_df)}, Preds: {len(validation_preds_df)}")
         # Attempt plot if possible, might be misaligned
         fig.add_trace(go.Scatter(
             x=validation_actuals_df[date_col_name].iloc[:len(validation_preds_df)], # Use actual dates for plotting preds
             y=validation_preds_df['prediction'].iloc[:len(validation_actuals_df), 0], # Assuming single prediction output
             mode='lines', name=f'Predicted {target_label} ({VALIDATION_YEAR} - Potential Mismatch)',
             line=dict(color='orange')
         ))


    fig.update_layout(title=title, xaxis_title=f'Date ({VALIDATION_YEAR})', yaxis_title=f'Price ({target_label})', hovermode="x unified", legend_title_text='Legend')
    return fig


# --- Main Execution Block ---
print("--- Temporal Fusion Transformer Forecasting & Validation ---")
print(f"--- (Nashik/Wheat: 2002-2023 Train, {VALIDATION_YEAR} Validate) ---")

# 1. Load Base Data
df_train_base = load_and_preprocess_base_data(DATA_PATH_TRAIN, DATE_COLUMN, YEAR_COL, MONTH_COL, DAY_COL, TARGET_COLUMNS, "Training (2002-2023)")
df_val_base = load_and_preprocess_base_data(DATA_PATH_VALIDATION, DATE_COLUMN, YEAR_COL, MONTH_COL, DAY_COL, TARGET_COLUMNS, f"Validation ({VALIDATION_YEAR})")

# Proceed only if both datasets loaded
if df_train_base is not None and df_val_base is not None:

    # 2. Get Encoded Values for Filtering
    try:
        selected_state_key=SELECTED_STATE_STR.strip().lower(); selected_district_key=SELECTED_DISTRICT_STR.strip().lower(); selected_commodity_key=SELECTED_COMMODITY_STR.strip().lower()
        encoded_state = state_name_encoding_map.get(selected_state_key); encoded_district = district_name_encoding_map.get(selected_district_key); encoded_commodity = commodity_name_encoding_map.get(selected_commodity_key)
        lookup_failed = False
        if encoded_state is None: print(f"Error: State '{SELECTED_STATE_STR}' missing map."); lookup_failed=True
        if encoded_district is None: print(f"Error: District '{SELECTED_DISTRICT_STR}' missing map."); lookup_failed=True
        if encoded_commodity is None: print(f"Error: Commodity '{SELECTED_COMMODITY_STR}' missing map."); lookup_failed=True
        if lookup_failed: print("Check maps."); df_train_base=df_val_base=None
        else: print(f"\nSelected: {SELECTED_STATE_STR}/{SELECTED_DISTRICT_STR}/{SELECTED_COMMODITY_STR} -> Encoded: St={encoded_state}, Di={encoded_district}, Co={encoded_commodity}")
    except Exception as e: print(f"Error mapping lookup: {e}"); df_train_base = df_val_base = None

# Proceed only if lookup succeeded
if df_train_base is not None and df_val_base is not None:

    # 3. Filtering Data Based on ENCODED Values
    print(f"\nFiltering datasets using encoded values...")
    filter_cols_num = ['state_name', 'district_name', 'commodity_name']

    # Filter Training Data
    if not all(col in df_train_base.columns for col in filter_cols_num): print("Error: Encoded filter cols missing Training."); filtered_df_train = pd.DataFrame()
    else:
        filtered_df_train = df_train_base[(df_train_base['state_name'] == encoded_state) & (df_train_base['district_name'] == encoded_district) & (df_train_base['commodity_name'] == encoded_commodity)].copy()
        filtered_df_train.sort_values(by=DATE_COLUMN, inplace=True)

    # Filter Validation Data
    if not all(col in df_val_base.columns for col in filter_cols_num): print("Error: Encoded filter cols missing Validation."); filtered_df_val = pd.DataFrame()
    else:
        filtered_df_val = df_val_base[(df_val_base['state_name'] == encoded_state) & (df_val_base['district_name'] == encoded_district) & (df_val_base['commodity_name'] == encoded_commodity)].copy()
        filtered_df_val.sort_values(by=DATE_COLUMN, inplace=True)

    if filtered_df_train.empty: print("\nWarning: No training data after filtering.")
    if filtered_df_val.empty: print("\nWarning: No validation data after filtering.")

    # 4. Prepare Data for TFT (Focusing on PRIMARY_TARGET)
    if not filtered_df_train.empty and not filtered_df_val.empty:
        print(f"\nPreparing data for TFT (Target: {PRIMARY_TARGET})...")

        # Combine temporarily for global time_idx calculation
        filtered_df_train["dataset_type"] = "train"
        filtered_df_val["dataset_type"] = "val"
        # Ensure validation data follows training data directly for time_idx continuity
        data_comb = pd.concat([filtered_df_train, filtered_df_val], ignore_index=True)
        data_comb = data_comb.sort_values(DATE_COLUMN)

        # Create time index
        data_comb['time_idx'] = (data_comb[DATE_COLUMN] - data_comb[DATE_COLUMN].min()).dt.days

        # Create group ID (only one series after filtering for Nashik/Wheat)
        data_comb['group_id'] = f"{SELECTED_DISTRICT_STR}_{SELECTED_COMMODITY_STR}_{PRIMARY_TARGET}"

        # Add time features (Known Future Inputs) - ensure they are strings for categorical embedding
        data_comb['month'] = data_comb[DATE_COLUMN].dt.month.astype(str)
        data_comb['day_of_week'] = data_comb[DATE_COLUMN].dt.dayofweek.astype(str)
        data_comb['day_of_year'] = data_comb[DATE_COLUMN].dt.dayofyear.astype(str)
        # Add year as real (continuous) known input
        data_comb['year_real'] = data_comb[DATE_COLUMN].dt.year

        # Convert target to float32 (recommended)
        data_comb[PRIMARY_TARGET] = data_comb[PRIMARY_TARGET].astype(np.float32)

        # Add lags or other features if desired (Time Varying Unknown)
        # Example: data_comb[f'{PRIMARY_TARGET}_lag1'] = data_comb.groupby('group_id')[PRIMARY_TARGET].shift(1)
        # data_comb = data_comb.dropna(subset=[f'{PRIMARY_TARGET}_lag1']) # Drop rows with NaN lags

        # Separate back into train and validation sets
        train_data_tft = data_comb[data_comb["dataset_type"] == "train"].copy()
        val_data_tft = data_comb[data_comb["dataset_type"] == "val"].copy()

        # --- Create TimeSeriesDataSet ---
        print("Creating TimeSeriesDataSet...")
        try:
            # Define Target Normalizer (fitted only on training data)
            target_normalizer = GroupNormalizer(groups=["group_id"], transformation="softplus") # or "standard" or None

            training_cutoff = train_data_tft["time_idx"].max() # Last time index in training data

            training_dataset = TimeSeriesDataSet(
                train_data_tft,
                time_idx="time_idx",
                target=PRIMARY_TARGET,
                group_ids=["group_id"],
                max_encoder_length=ENCODER_LENGTH,
                max_prediction_length=PREDICTION_LENGTH,
                # Static features (only group_id here as others are constant)
                static_categoricals=["group_id"],
                # Time-varying known features
                time_varying_known_categoricals=["month", "day_of_week", "day_of_year"],
                time_varying_known_reals=["year_real", "time_idx"], # time_idx needed as real feature
                # Time-varying unknown features (add lags here if created)
                time_varying_unknown_reals=[PRIMARY_TARGET], # Target needs to be here if no lags/other unknowns
                target_normalizer=target_normalizer,
                add_relative_time_idx=True, # Recommended
                add_target_scales=True,     # Recommended
                add_encoder_length=True,
                allow_missing_timesteps=True    # Recommended
            )

            # Create validation dataset using data from training dataset
            # Crucial: allow_missings=True needed because validation starts right after train
            validation_dataset = TimeSeriesDataSet.from_dataset(
                training_dataset,
                data_comb[data_comb["time_idx"] > training_cutoff - ENCODER_LENGTH], # Need encoder length overlap
                predict=False, # Change to True if only predicting, False if evaluating loss
                stop_randomization=True,
                # allow_missing_timesteps=True # Use if gaps exist, implies predict=True needed? Check docs
            )

            # Create dataloaders
            train_dataloader = training_dataset.to_dataloader(train=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKER)
            val_dataloader = validation_dataset.to_dataloader(train=False, batch_size=BATCH_SIZE * 2, num_workers=NUM_WORKER) # Larger batch size for validation often ok

            print("TimeSeriesDataSet and Dataloaders created.")

            # --- 5. Define Model & Trainer ---
            print("\nDefining TFT model and Trainer...")
            # Define callbacks
            early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5, verbose=True, mode="min")
            lr_monitor = LearningRateMonitor()
            # logger = TensorBoardLogger("tb_logs", name="tft_wheat_price") # Optional: for TensorBoard logging

            trainer = pl.Trainer(
                max_epochs=MAX_EPOCHS,
                gpus=TRAINER_GPUS,
                gradient_clip_val=0.1, # Helps prevent exploding gradients
                limit_train_batches=30,  # Limit batches per epoch for faster iteration/debugging (remove for full run)
                limit_val_batches=10,   # Limit validation batches (remove for full run)
                # fast_dev_run=True, # Uncomment for quick test run (1 batch train/val)
                callbacks=[lr_monitor, early_stop_callback],
                # logger=logger,
                # progress_bar_refresh_rate=30 # Deprecated, use enable_progress_bar=True
                enable_progress_bar=True
            )

            # Define TFT model from dataset parameters
            tft = TemporalFusionTransformer.from_dataset(
                training_dataset,
                learning_rate=LEARNING_RATE,
                hidden_size=HIDDEN_SIZE,
                attention_head_size=ATTENTION_HEADS,
                dropout=DROPOUT,
                hidden_continuous_size=HIDDEN_SIZE // 2, # Example configuration
                output_size=7,  # Number of quantiles to predict (usually 7 for P10, P50, P90 etc.)
                loss=QuantileLoss(), # Use QuantileLoss for probabilistic forecasts
                # reduce_on_plateau_patience=4 # Adjust LR scheduler patience
            )
            print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

            # --- 6. Train Model ---
            print("\nStarting model training...")
            trainer.fit(
                tft,
                train_dataloaders=train_dataloader,
                val_dataloaders=val_dataloader,
            )
            print("Training finished.")

            # --- 7. Evaluate on Validation Set ---
            print(f"\n--- Evaluating FINAL TFT Model Performance on {VALIDATION_YEAR} Validation Data ---")
            # Load best model checkpoint
            best_model_path = trainer.checkpoint_callback.best_model_path
            print(f"Loading best model from: {best_model_path}")
            best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

            # Predict on validation data
            # Predict using the dataloader to automatically handle encoder/decoder lengths
            print("Predicting on validation data...")
            predictions_raw = best_tft.predict(val_dataloader, return_index=True, return_decoder_lengths=True)
            # predictions_raw object contains 'prediction', 'index', 'decoder_lengths'

            # Extract actuals corresponding to the predictions
            actuals_raw = torch.cat([y[0] for _, y in iter(val_dataloader)])
            # Align actuals and predictions using the index returned by predict
            # This part can be tricky, ensure alignment based on time_idx and group_id
            # For a single series, alignment might be simpler if lengths match
            actuals_flat = actuals_raw.view(-1) # Flatten the actuals tensor
            # Predictions are often quantile forecasts, take the median (index 3 for 7 quantiles)
            predictions_median = predictions_raw.prediction[:, :, 3].view(-1) # Flatten the median prediction tensor

            # --- Ensure lengths match for metric calculation ---
            min_len_eval = min(len(actuals_flat), len(predictions_median))
            if len(actuals_flat) != len(predictions_median):
                 print(f"Warning: Mismatch evaluation lengths. Actuals: {len(actuals_flat)}, Preds: {len(predictions_median)}. Truncating.")
            actuals_aligned = actuals_flat[:min_len_eval].cpu().numpy() # Move to CPU and convert to numpy
            preds_aligned = predictions_median[:min_len_eval].cpu().numpy()

            # Inverse transform (if normalizer was used)
            # Get the normalizer fitted on the training data
            target_scaler = training_dataset.target_normalizer

            # Check if inverse_transform expects shape (n_samples, 1) or just (n_samples,)
            # Reshape if necessary based on the specific normalizer
            # Example for GroupNormalizer - may need adjustments
            # Create dummy group info matching the structure expected by the normalizer
            inverse_transform_df_actuals = pd.DataFrame({
                PRIMARY_TARGET: actuals_aligned,
                'group_id': [training_dataset.decoded_index.group_ids.iloc[0]] * len(actuals_aligned) # Use the group id
            })
            inverse_transform_df_preds = pd.DataFrame({
                PRIMARY_TARGET: preds_aligned,
                'group_id': [training_dataset.decoded_index.group_ids.iloc[0]] * len(preds_aligned)
            })

            actuals_inv = target_scaler.inverse_transform(inverse_transform_df_actuals)[PRIMARY_TARGET].values
            preds_inv = target_scaler.inverse_transform(inverse_transform_df_preds)[PRIMARY_TARGET].values

            # Calculate metrics on inverse-transformed data
            if len(actuals_inv) > 0 and len(preds_inv) > 0 :
                r2_val, mae_val, mse_val = calculate_metrics(actuals_inv, preds_inv)
                print(f"FINAL Validation R-squared (R2): {r2_val:.4f}")
                print(f"FINAL Validation Mean Absolute Error (MAE): {mae_val:.2f}")
                print(f"FINAL Validation Mean Squared Error (MSE): {mse_val:.2f}")

                # --- 8. Plot Validation Results ---
                print(f"\n--- Plotting FINAL Validation Results for {PRIMARY_TARGET} (Actual vs. Predicted {VALIDATION_YEAR}) ---")
                # Create dataframes for plotting (using inverse transformed values)
                # Need to associate predictions back with dates
                # Use the index returned by predict() which contains time_idx and group_id
                plot_preds_df = predictions_raw.index.copy()
                plot_preds_df['prediction'] = pd.DataFrame(preds_inv.reshape(-1,1)) # Reshape to 2D for consistency if needed
                # Map time_idx back to date using the original validation data
                time_idx_to_date = val_data_tft[[DATE_COLUMN, 'time_idx']].set_index('time_idx')
                plot_preds_df = plot_preds_df.join(time_idx_to_date, on='time_idx')

                # Actuals for plotting (already inverse transformed)
                plot_actuals_df = val_data_tft.iloc[:len(actuals_inv)].copy() # Ensure length matches evaluation
                plot_actuals_df[PRIMARY_TARGET] = actuals_inv

                plot_title_val = f'TFT Validation (Nashik/Wheat): {PRIMARY_TARGET.replace("avg_", "").replace("_price", "").capitalize()} Price (Actual vs. Predicted {VALIDATION_YEAR})'
                fig_val = plot_validation_results(plot_actuals_df, plot_preds_df, PRIMARY_TARGET, DATE_COLUMN, plot_title_val)
                fig_val.show()
            else:
                 print("Skipping metrics and plotting due to empty aligned actuals/predictions.")


        except Exception as e:
            print(f"Error during TFT training or evaluation: {e}")
            traceback.print_exc()

    else:
         print("\nCannot proceed: lack of data after filtering.")
else:
    print("\nFailed during data loading, preprocessing, or mapping lookup.")

print("\nProcess finished.")

--- Temporal Fusion Transformer Forecasting & Validation ---
--- (Nashik/Wheat: 2002-2023 Train, 2024 Validate) ---
------------------------------
Processing Training (2002-2023) Dataset
------------------------------
Loading Training (2002-2023) data from E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_2002_2023.csv...
Loaded 6246 rows.
Constructing 'full_date'...
6246 rows after date construction.
6246 rows after ensuring price columns numeric.
Training (2002-2023) base data loaded. 6246 rows.
------------------------------
Processing Validation (2024) Dataset
------------------------------
Loading Validation (2024) data from E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_test_2024.csv...
Loaded 278 rows.
Constructing 'full_date'...
278 rows after date construction.
278 rows after ensuring price columns numeric.
Validation (2024) base data loaded. 278 rows.

Selected: Maharashtra/Nashik/Wheat -> Encoded: St=6291, Di=6291,

Traceback (most recent call last):
  File "C:\Users\Shiva\AppData\Local\Temp\ipykernel_22400\2128141116.py", line 274, in <module>
    training_dataset = TimeSeriesDataSet(
  File "C:\Users\Shiva\AppData\Roaming\Python\Python310\site-packages\pytorch_forecasting\data\timeseries.py", line 637, in __init__
    self.index = self._construct_index(data, predict_mode=self.predict_mode)
  File "C:\Users\Shiva\AppData\Roaming\Python\Python310\site-packages\pytorch_forecasting\data\timeseries.py", line 1767, in _construct_index
    assert self.allow_missing_timesteps, msg
AssertionError: Time difference between steps has been idenfied as larger than 1 - set allow_missing_timesteps=True


## 2nd try

In [2]:
# !pip install pytorch-forecasting pytorch-lightning -U # Uncomment if needed

# --- Standard Imports ---
import pandas as pd
import numpy as np
import datetime
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from IPython.display import display
import traceback
import warnings
import os

# --- PyTorch & Forecasting Imports ---
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
# from pytorch_lightning.loggers import TensorBoardLogger # Optional
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss
# Correct import for TemporalFusionTransformer
from pytorch_forecasting.models import TemporalFusionTransformer

# Suppress common warnings
warnings.filterwarnings("ignore")
# Set seeds for reproducibility (optional)
# pl.seed_everything(42)

# --- Configuration ---
# Data Paths (Ensure these point to your pre-encoded files)
DATA_PATH_TRAIN = r"E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_2002_2023.csv"
DATA_PATH_VALIDATION = r"E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_test_2024.csv"

# Target Columns & Date Construction Columns
TARGET_COLUMNS = ['avg_min_price', 'avg_max_price', 'avg_modal_price']
PRIMARY_TARGET = 'avg_modal_price' # Focus on one target for TFT simplicity
DATE_COLUMN = 'full_date' # Name for the constructed date column
YEAR_COL = 'year'; MONTH_COL = 'month'; DAY_COL = 'date' # Source columns for date
VALIDATION_YEAR = 2024

# Filter Selections
SELECTED_STATE_STR = "Maharashtra"
SELECTED_DISTRICT_STR = "Nashik"
SELECTED_COMMODITY_STR = "Wheat"

# Frequency Encoding Maps (CRITICAL: CONFIRM THESE ARE CORRECT FOR YOUR FILES)
state_name_encoding_map = {"maharashtra": 6291}
district_name_encoding_map = {"nashik": 6291}
commodity_name_encoding_map = {"wheat": 6291}

# TFT Configuration
ENCODER_LENGTH = 90      # Input sequence length (e.g., 90 days) - NEEDS TUNING
PREDICTION_LENGTH = 1    # Predict 1 step ahead for this validation setup
BATCH_SIZE = 64          # Adjust based on memory
NUM_WORKER = 0           # Use 0 for Windows generally
TRAINER_GPUS = 0         # Set to 1+ if GPU available
LEARNING_RATE = 0.005    # Example, needs tuning
HIDDEN_SIZE = 32         # Example, needs tuning
ATTENTION_HEADS = 4      # Example, needs tuning
DROPOUT = 0.1            # Example, needs tuning
MAX_EPOCHS = 15          # Example, use EarlyStopping

# --- Outlier Removal Function ---
def remove_outliers_iqr(df, columns_to_check):
    """Removes outliers from specified numerical columns using the IQR method."""
    df_filtered = df.copy(); initial_rows = len(df_filtered)
    valid_columns = [col for col in columns_to_check if col in df_filtered.columns and pd.api.types.is_numeric_dtype(df_filtered[col])]
    if not valid_columns: return df_filtered
    subset_for_iqr = df_filtered[valid_columns]
    Q1 = subset_for_iqr.quantile(0.25); Q3 = subset_for_iqr.quantile(0.75); IQR = Q3 - Q1
    mask = ~((subset_for_iqr < (Q1 - 1.5 * IQR)) | (subset_for_iqr > (Q3 + 1.5 * IQR))).any(axis=1)
    df_filtered = df_filtered[mask]; rows_removed = initial_rows - len(df_filtered)
    if rows_removed > 0: print(f"Removed {rows_removed} rows via IQR.")
    return df_filtered

# --- Data Loading and Preprocessing Function ---
def load_and_preprocess_base_data(path, date_col_name, year_col, month_col, day_col, target_cols, dataset_name="Training"):
    """Loads data, constructs date, basic cleaning."""
    print("-" * 30); print(f"Processing {dataset_name} Dataset"); print("-" * 30)
    try:
        print(f"Loading {dataset_name} data from {path}..."); df = pd.read_csv(path); print(f"Loaded {len(df)} rows.")
        # 1. Construct Date
        date_components_cols = [year_col, month_col, day_col]
        if not all(col in df.columns for col in date_components_cols): print(f"Error: Date component cols missing: {[c for c in date_components_cols if c not in df.columns]}"); return None
        for col in date_components_cols: df[col] = pd.to_numeric(df[col], errors='coerce')
        df.dropna(subset=date_components_cols, inplace=True)
        print(f"Constructing '{date_col_name}'...");
        df[date_col_name] = pd.to_datetime({'year': df[year_col], 'month': df[month_col], 'day': df[day_col]}, errors='coerce')
        initial_rows_date = len(df); df.dropna(subset=[date_col_name], inplace=True)
        if initial_rows_date > len(df): print(f"Dropped {initial_rows_date - len(df)} rows due to invalid date components.")
        print(f"{len(df)} rows after date construction.")

        # 2. Ensure Price columns are numeric
        all_potential_targets = ['avg_min_price', 'avg_max_price', 'avg_modal_price']
        for col in all_potential_targets:
            if col in df.columns: df[col] = pd.to_numeric(df[col], errors='coerce')
            else: print(f"Warning: Price column '{col}' not found.")
        df.dropna(subset=all_potential_targets, how='any', inplace=True) # Drop if any price is missing
        print(f"{len(df)} rows after ensuring price columns numeric.")

        # 3. Drop OTHER unused columns
        cols_to_drop = ['calculationType', 'district_id', 'change','district_name_enc', 'commodity_name_enc', 'state_name_enc']
        existing_cols_to_drop = [col for col in cols_to_drop if col in df.columns]
        if existing_cols_to_drop: df.drop(columns=existing_cols_to_drop, axis=1, inplace=True)

        # 4. Apply IQR Outlier Removal (Optional)
        # df = remove_outliers_iqr(df, all_potential_targets)

        # 5. Check required columns (encoded filters + date + targets)
        required_numeric_filter_cols = ['state_name', 'district_name', 'commodity_name']
        required_cols = [date_col_name] + all_potential_targets + required_numeric_filter_cols
        missing_req_cols = [col for col in required_cols if col not in df.columns]
        if missing_req_cols: print(f"Error: Required columns missing: {missing_req_cols}"); print(f"Available: {df.columns.tolist()}"); return None
        for col in required_numeric_filter_cols:
             if not pd.api.types.is_numeric_dtype(df[col]): print(f"Error: Col '{col}' expected numeric but isn't."); return None

        df.sort_values(date_col_name, inplace=True)
        print(f"{dataset_name} base data loaded. {len(df)} rows.")
        return df

    except FileNotFoundError: print(f"Error: {dataset_name} file not found at {path}"); return None
    except Exception as e: print(f"Error loading/preprocessing {dataset_name}: {e}"); traceback.print_exc(); return None


# --- Evaluation Metrics Function ---
def calculate_metrics(y_true, y_pred):
    """Calculates R2, MAE, MSE after handling potential NaNs and length mismatches."""
    y_true = np.array(y_true).flatten(); y_pred = np.array(y_pred).flatten()
    # Remove rows with NaN in either array after alignment
    valid_mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
    y_true = y_true[valid_mask]; y_pred = y_pred[valid_mask]
    if len(y_true) == 0: print("Warning: No valid (non-NaN) points for metric calculation."); return np.nan, np.nan, np.nan
    if len(y_true) != len(y_pred): # Should ideally not happen if preprocessing is correct
        min_len = min(len(y_true), len(y_pred)); print(f"Warn: Mismatch metrics post-NaN. Truncating to {min_len}.")
        if min_len == 0 : return np.nan, np.nan, np.nan
        y_true = y_true[:min_len]; y_pred = y_pred[:min_len]
    try:
        r2 = r2_score(y_true, y_pred); mae = mean_absolute_error(y_true, y_pred); mse = mean_squared_error(y_true, y_pred)
        return r2, mae, mse
    except Exception as e: print(f"Error calculating metrics: {e}"); return np.nan, np.nan, np.nan

# --- Plotting Function for Validation ---
def plot_validation_results(validation_actuals_df, validation_preds_df, target_column, date_col_name, title):
    """Plots actuals vs predictions for validation period."""
    fig = go.Figure(); target_label = target_column.replace("avg_", "").replace("_price", "").capitalize()
    # Ensure dataframes have the required columns and are sorted
    plot_actuals_df = validation_actuals_df.sort_values(by=date_col_name).copy()
    plot_preds_df = validation_preds_df.sort_values(by='time_idx').copy() # preds often keyed by time_idx

    # Add Actual Validation Data trace
    fig.add_trace(go.Scatter(x=plot_actuals_df[date_col_name], y=plot_actuals_df[target_column], mode='lines+markers', name=f'Actual {target_label} ({VALIDATION_YEAR})', line=dict(color='blue'), marker=dict(size=4)))

    # Add Predicted Validation Data trace - Use dates from prediction index if possible
    # Merge predicted values back onto a date axis for plotting
    if date_col_name in plot_preds_df.columns:
        plot_preds_df_dated = plot_preds_df
    else: # If date column is missing in preds, merge it back based on time_idx
         time_idx_to_date = plot_actuals_df[[date_col_name, 'time_idx']].drop_duplicates().set_index('time_idx')
         plot_preds_df_dated = plot_preds_df.join(time_idx_to_date, on='time_idx')

    # Ensure prediction column exists and handle potential multi-output (take median)
    pred_col_name = 'prediction'
    if pred_col_name in plot_preds_df_dated.columns:
         if plot_preds_df_dated[pred_col_name].ndim > 1 and plot_preds_df_dated[pred_col_name].shape[1] > 1:
              # Assuming quantile forecasts, take median (index 3 for 7 quantiles P10,P25,P50,P75,P90)
              pred_values = plot_preds_df_dated[pred_col_name].iloc[:, 3]
              print("Plotting median (quantile 0.5) prediction.")
         else:
              pred_values = plot_preds_df_dated[pred_col_name].iloc[:, 0] # Single prediction output
    else:
         print(f"Warning: Prediction column '{pred_col_name}' not found in prediction results for plotting.")
         pred_values = pd.Series(dtype=float) # Empty series

    fig.add_trace(go.Scatter(
        x=plot_preds_df_dated[date_col_name],
        y=pred_values,
        mode='lines', name=f'Predicted {target_label} ({VALIDATION_YEAR})',
        line=dict(color='red')
    ))

    fig.update_layout(title=title, xaxis_title=f'Date ({VALIDATION_YEAR})', yaxis_title=f'Price ({target_label})', hovermode="x unified", legend_title_text='Legend')
    return fig


# --- Main Execution Block ---
print("--- Temporal Fusion Transformer Forecasting & Validation ---")
print(f"--- (Nashik/Wheat: 2002-2023 Train, {VALIDATION_YEAR} Validate) ---")

# 1. Load Base Data
df_train_base = load_and_preprocess_base_data(DATA_PATH_TRAIN, DATE_COLUMN, YEAR_COL, MONTH_COL, DAY_COL, TARGET_COLUMNS, "Training (2002-2023)")
df_val_base = load_and_preprocess_base_data(DATA_PATH_VALIDATION, DATE_COLUMN, YEAR_COL, MONTH_COL, DAY_COL, TARGET_COLUMNS, f"Validation ({VALIDATION_YEAR})")

# Init flags/variables for later checks
train_dataloader = None
val_dataloader = None
training_dataset = None # To access normalizer later

# Proceed only if both datasets loaded successfully
if df_train_base is not None and df_val_base is not None:

    # 2. Get Encoded Values for Filtering
    try:
        selected_state_key=SELECTED_STATE_STR.strip().lower(); selected_district_key=SELECTED_DISTRICT_STR.strip().lower(); selected_commodity_key=SELECTED_COMMODITY_STR.strip().lower()
        encoded_state = state_name_encoding_map.get(selected_state_key); encoded_district = district_name_encoding_map.get(selected_district_key); encoded_commodity = commodity_name_encoding_map.get(selected_commodity_key)
        lookup_failed = False
        if encoded_state is None: print(f"Error: State '{SELECTED_STATE_STR}' missing map."); lookup_failed=True
        if encoded_district is None: print(f"Error: District '{SELECTED_DISTRICT_STR}' missing map."); lookup_failed=True
        if encoded_commodity is None: print(f"Error: Commodity '{SELECTED_COMMODITY_STR}' missing map."); lookup_failed=True
        if lookup_failed: print("Check maps."); df_train_base=df_val_base=None
        else: print(f"\nSelected: {SELECTED_STATE_STR}/{SELECTED_DISTRICT_STR}/{SELECTED_COMMODITY_STR} -> Encoded: St={encoded_state}, Di={encoded_district}, Co={encoded_commodity}")
    except Exception as e: print(f"Error mapping lookup: {e}"); df_train_base = df_val_base = None

# Proceed only if lookup succeeded
if df_train_base is not None and df_val_base is not None:

    # 3. Filtering Data Based on ENCODED Values
    print(f"\nFiltering datasets using encoded values...")
    filter_cols_num = ['state_name', 'district_name', 'commodity_name']
    if not all(col in df_train_base.columns for col in filter_cols_num): print("Error: Encoded filter cols missing Training."); filtered_df_train = pd.DataFrame()
    else: filtered_df_train = df_train_base[(df_train_base['state_name'] == encoded_state) & (df_train_base['district_name'] == encoded_district) & (df_train_base['commodity_name'] == encoded_commodity)].copy(); filtered_df_train.sort_values(by=DATE_COLUMN, inplace=True)
    if not all(col in df_val_base.columns for col in filter_cols_num): print("Error: Encoded filter cols missing Validation."); filtered_df_val = pd.DataFrame()
    else: filtered_df_val = df_val_base[(df_val_base['state_name'] == encoded_state) & (df_val_base['district_name'] == encoded_district) & (df_val_base['commodity_name'] == encoded_commodity)].copy(); filtered_df_val.sort_values(by=DATE_COLUMN, inplace=True)

    if filtered_df_train.empty: print("\nWarning: No training data after filtering.")
    if filtered_df_val.empty: print("\nWarning: No validation data after filtering.")

    # 4. Prepare Data for TFT (Focusing on PRIMARY_TARGET)
    if not filtered_df_train.empty and not filtered_df_val.empty:
        print(f"\nPreparing data for TFT (Target: {PRIMARY_TARGET})...")
        try:
            # Combine temporarily for global time_idx
            filtered_df_train["dataset_type"] = "train"; filtered_df_val["dataset_type"] = "val"
            data_comb = pd.concat([filtered_df_train, filtered_df_val], ignore_index=True).sort_values(DATE_COLUMN)
            data_comb['time_idx'] = (data_comb[DATE_COLUMN] - data_comb[DATE_COLUMN].min()).dt.days
            data_comb['group_id'] = f"{SELECTED_DISTRICT_STR}_{SELECTED_COMMODITY_STR}_{PRIMARY_TARGET}" # Unique ID for this series/target
            # Add time features (Known Future Inputs)
            data_comb['month'] = data_comb[DATE_COLUMN].dt.month.astype(str)
            data_comb['day_of_week'] = data_comb[DATE_COLUMN].dt.dayofweek.astype(str)
            data_comb['day_of_month'] = data_comb[DATE_COLUMN].dt.day.astype(str) # Added day of month
            # data_comb['day_of_year'] = data_comb[DATE_COLUMN].dt.dayofyear.astype(str) # Can add if useful
            data_comb['week_of_year'] = data_comb[DATE_COLUMN].dt.isocalendar().week.astype(str) # Added week
            # data_comb['year_cat'] = data_comb[DATE_COLUMN].dt.year.astype(str) # Year also as categorical?
            data_comb['year_real'] = data_comb[DATE_COLUMN].dt.year # Year as real
            data_comb[PRIMARY_TARGET] = data_comb[PRIMARY_TARGET].astype(np.float32) # Ensure target is float32

            # --- Feature list definitions ---
            time_varying_known_categoricals = ["month", "day_of_week", "day_of_month", "week_of_year"]
            time_varying_known_reals = ["year_real", "time_idx"]
            # Add lags or other external features here if needed:
            time_varying_unknown_reals = [PRIMARY_TARGET] # If no other unknown features, target is needed here
            static_categoricals = ["group_id"]

            # Separate back into train and validation sets
            train_data_tft = data_comb[data_comb["dataset_type"] == "train"].copy()
            first_val_idx_actual = data_comb[data_comb["dataset_type"] == "val"]['time_idx'].min() # Use combined data to get correct first idx
            start_val_idx_with_overlap = max(0, first_val_idx_actual - ENCODER_LENGTH)
            val_data_tft_with_overlap = data_comb[data_comb["time_idx"] >= start_val_idx_with_overlap].copy()


            # --- Create TimeSeriesDataSet ---
            print("Creating TimeSeriesDataSet...")
            target_normalizer = GroupNormalizer(groups=["group_id"], transformation="softplus", center=True)

            training_dataset = TimeSeriesDataSet(
                train_data_tft,
                time_idx="time_idx", target=PRIMARY_TARGET, group_ids=["group_id"],
                max_encoder_length=ENCODER_LENGTH, max_prediction_length=PREDICTION_LENGTH,
                static_categoricals=static_categoricals,
                time_varying_known_categoricals=time_varying_known_categoricals,
                time_varying_known_reals=time_varying_known_reals,
                time_varying_unknown_reals=time_varying_unknown_reals,
                target_normalizer=target_normalizer,
                add_relative_time_idx=True, add_target_scales=True, add_encoder_length=True,
                allow_missing_timesteps=True # <-- Fix from previous error
            )

            # Create validation dataset using the same parameters
            validation_dataset = TimeSeriesDataSet.from_dataset(
                training_dataset, val_data_tft_with_overlap, # Use the overlap data
                predict=False, # Set to False for calculating validation loss
                stop_randomization=True,
            )

            # Create dataloaders
            train_dataloader = training_dataset.to_dataloader(train=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKER)
            val_dataloader = validation_dataset.to_dataloader(train=False, batch_size=BATCH_SIZE * 2, num_workers=NUM_WORKER)
            print("TimeSeriesDataSet and Dataloaders created.")

        except Exception as e:
            print(f"Error creating TimeSeriesDataSet or Dataloaders: {e}")
            traceback.print_exc(); train_dataloader = val_dataloader = None # Prevent proceeding

    # Proceed only if dataloaders created
    if 'train_dataloader' in locals() and 'val_dataloader' in locals() and train_dataloader is not None and val_dataloader is not None:
        # --- 5. Define Model & Trainer ---
        print("\nDefining TFT model and Trainer...")
        try:
            early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5, verbose=False, mode="min")
            lr_monitor = LearningRateMonitor(logging_interval='epoch')
            # logger = TensorBoardLogger("tb_logs", name="tft_nashik_wheat") # Optional

            trainer = pl.Trainer(
                max_epochs=MAX_EPOCHS, 
                # gpus=TRAINER_GPUS, 
                gradient_clip_val=0.1,
                accelerator="gpu",
                devices=1, # e.g., devices=1 for one GPU
                # gradient_clip_val=0.1,
                # limit_train_batches=30, limit_val_batches=10, # Uncomment for debugging
                # fast_dev_run=True, # Uncomment for quick test run
                callbacks=[lr_monitor, early_stop_callback],
                # logger=logger,
                enable_progress_bar=True, enable_checkpointing=True # Ensure checkpoints are saved
            )

            tft = TemporalFusionTransformer.from_dataset(
                training_dataset, learning_rate=LEARNING_RATE, hidden_size=HIDDEN_SIZE,
                attention_head_size=ATTENTION_HEADS, dropout=DROPOUT, hidden_continuous_size=HIDDEN_SIZE // 2,
                output_size=7, loss=QuantileLoss(), # Use QuantileLoss
                # log_interval=10, # Log less frequently
                # reduce_on_plateau_patience=4
            )
            print(f"TFT Network parameters: {tft.size()/1e3:.1f}k")

            # --- 6. Train Model ---
            print("\nStarting model training...")
            trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
            print("Training finished.")

            # --- 7. Evaluate on Validation Set ---
            print(f"\n--- Evaluating FINAL TFT Model Performance on {VALIDATION_YEAR} Validation Data ---")
            best_model_path = trainer.checkpoint_callback.best_model_path
            if best_model_path and os.path.exists(best_model_path):
                 print(f"Loading best model from: {best_model_path}")
                 best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
            else:
                 print("Warning: Best model checkpoint not found. Using model state from end of training.")
                 best_tft = tft # Use the model from the last epoch if checkpoint missing

            # Predict on validation data (ensure using the correct dataloader)
            print("Predicting on validation data...")
            # Use the validation dataloader created earlier
            raw_predictions = best_tft.predict(val_dataloader, mode="raw", return_index=True)
            # raw_predictions["prediction"] gives quantile forecasts
            # raw_predictions["index"] gives time_idx and group_id

            # Extract actuals corresponding to the predictions
            actuals_df_eval = val_data_tft_with_overlap[lambda x: x.time_idx.isin(raw_predictions.index.time_idx)][[DATE_COLUMN, 'time_idx', 'group_id', PRIMARY_TARGET]].copy()

            # Extract median prediction (quantile 0.5, which is index 3 for 7 quantiles)
            predictions_df_eval = raw_predictions.index.copy()
            predictions_df_eval['prediction_normalized'] = raw_predictions.prediction[:, :, 3] # Index 3 is median

            # Merge actuals and normalized predictions based on time_idx and group_id
            eval_results = pd.merge(
                 actuals_df_eval, predictions_df_eval, on=['time_idx', 'group_id'], how='inner'
            )

            # Inverse transform using the fitted normalizer from training_dataset
            target_scaler = training_dataset.target_normalizer

            # Need to structure data for inverse_transform (often needs group_id)
            actuals_inv = target_scaler.inverse_transform(eval_results[[PRIMARY_TARGET, 'group_id']].rename(columns={PRIMARY_TARGET:'target'}))['target'].values
            preds_inv = target_scaler.inverse_transform(eval_results[['prediction_normalized', 'group_id']].rename(columns={'prediction_normalized':'target'}))['target'].values

            # Calculate metrics
            if len(actuals_inv) > 0 and len(preds_inv) > 0:
                print(f"Evaluating based on {len(actuals_inv)} matched prediction points.")
                r2_val, mae_val, mse_val = calculate_metrics(actuals_inv, preds_inv)
                print(f"FINAL Validation R-squared (R2): {r2_val:.4f}")
                print(f"FINAL Validation Mean Absolute Error (MAE): {mae_val:.2f}")
                print(f"FINAL Validation Mean Squared Error (MSE): {mse_val:.2f}")

                # --- 8. Plot Validation Results ---
                print(f"\n--- Plotting FINAL Validation Results for {PRIMARY_TARGET} (Actual vs. Predicted {VALIDATION_YEAR}) ---")
                # Prepare dataframes for plotting with inverse-transformed values
                plot_actuals_df = eval_results[[DATE_COLUMN, 'time_idx']].copy()
                plot_actuals_df[PRIMARY_TARGET] = actuals_inv
                plot_preds_df = eval_results[[DATE_COLUMN, 'time_idx']].copy()
                plot_preds_df['prediction'] = preds_inv

                plot_title_val = f'TFT Validation (Nashik/Wheat): {PRIMARY_TARGET.replace("avg_", "").replace("_price", "").capitalize()} Price (Actual vs. Predicted {VALIDATION_YEAR})'
                fig_val = plot_validation_results(plot_actuals_df, plot_preds_df, PRIMARY_TARGET, DATE_COLUMN, plot_title_val)
                fig_val.show()
            else: print("Skipping metrics and plotting: No matched actuals/predictions found.")

        except Exception as e:
            print(f"Error during TFT training or evaluation: {e}"); traceback.print_exc()

    else: print("\nCannot proceed: lack of data after filtering or TFT dataset creation failed.")
else: print("\nFailed during data loading, preprocessing, or mapping lookup.")

print("\nProcess finished.")

  from .autonotebook import tqdm as notebook_tqdm
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.


--- Temporal Fusion Transformer Forecasting & Validation ---
--- (Nashik/Wheat: 2002-2023 Train, 2024 Validate) ---
------------------------------
Processing Training (2002-2023) Dataset
------------------------------
Loading Training (2002-2023) data from E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_2002_2023.csv...
Loaded 6246 rows.
Constructing 'full_date'...
6246 rows after date construction.
6246 rows after ensuring price columns numeric.
Training (2002-2023) base data loaded. 6246 rows.
------------------------------
Processing Validation (2024) Dataset
------------------------------
Loading Validation (2024) data from E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_test_2024.csv...
Loaded 278 rows.
Constructing 'full_date'...
278 rows after date construction.
278 rows after ensuring price columns numeric.
Validation (2024) base data loaded. 278 rows.

Selected: Maharashtra/Nashik/Wheat -> Encoded: St=6291, Di=6291,

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


TFT Network parameters: 75.9k

Starting model training...
Error during TFT training or evaluation: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`

Process finished.


Traceback (most recent call last):
  File "C:\Users\Shiva\AppData\Local\Temp\ipykernel_22884\92919337.py", line 327, in <module>
    trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
  File "c:\Users\Shiva\.conda\envs\tft_env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 554, in fit
    model = _maybe_unwrap_optimized(model)
  File "c:\Users\Shiva\.conda\envs\tft_env\lib\site-packages\pytorch_lightning\utilities\compile.py", line 111, in _maybe_unwrap_optimized
    raise TypeError(
TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`
