In [2]:
# !pip install pytorch-forecasting pytorch-lightning -U # Ensure libraries are installed

# --- Standard Imports ---
import pandas as pd
import numpy as np
import datetime
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from IPython.display import display
import traceback
import warnings
import os

# --- PyTorch & Forecasting Imports ---
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
# from pytorch_lightning.loggers import TensorBoardLogger # Optional
from pytorch_forecasting import TimeSeriesDataSet, DeepAR # <--- Import DeepAR
from pytorch_forecasting.data import GroupNormalizer
# Import a suitable loss function for continuous data
from pytorch_forecasting.metrics import NormalDistributionLoss # <--- Loss for DeepAR

# Suppress common warnings
warnings.filterwarnings("ignore")
# pl.seed_everything(42) # Optional reproducibility

# --- Configuration ---
# Data Paths
DATA_PATH_TRAIN = r"E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_2002_2023.csv"
DATA_PATH_VALIDATION = r"E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_test_2024.csv"

# Target Columns & Date Construction Columns
TARGET_COLUMNS = ['avg_min_price', 'avg_max_price', 'avg_modal_price']
PRIMARY_TARGET = 'avg_modal_price' # Focus on one target
DATE_COLUMN = 'full_date'; YEAR_COL = 'year'; MONTH_COL = 'month'; DAY_COL = 'date'
VALIDATION_YEAR = 2024

# Filter Selections
SELECTED_STATE_STR = "Maharashtra"; SELECTED_DISTRICT_STR = "Nashik"; SELECTED_COMMODITY_STR = "Wheat"

# Frequency Encoding Maps (CONFIRM THESE ARE CORRECT)
state_name_encoding_map = {"maharashtra": 6291}
district_name_encoding_map = {"nashik": 6291}
commodity_name_encoding_map = {"wheat": 6291}

# --- DeepAR/Training Configuration ---
ENCODER_LENGTH = 90      # Input sequence length (e.g., 90 days) - NEEDS TUNING
PREDICTION_LENGTH = 1    # Predict 1 step ahead for validation
BATCH_SIZE = 64          # Adjust based on memory
NUM_WORKER = 0           # Use 0 for Windows generally
TRAINER_ACCELERATOR="auto" # Use "gpu" if available, else "cpu"
TRAINER_DEVICES="auto"   # Use all available devices of the chosen accelerator type
# --- DeepAR Specific Hyperparameters (Examples - Tune these) ---
HIDDEN_SIZE = 40
RNN_LAYERS = 2
DROPOUT = 0.1
LEARNING_RATE = 0.005    # Example, needs tuning
MAX_EPOCHS = 20          # Increase slightly for RNN? Use EarlyStopping

# --- Helper Functions (remove_outliers_iqr, load_and_preprocess_base_data - NO CHANGES) ---
def remove_outliers_iqr(df, columns_to_check):
    df_filtered = df.copy(); initial_rows = len(df_filtered)
    valid_columns = [col for col in columns_to_check if col in df_filtered.columns and pd.api.types.is_numeric_dtype(df_filtered[col])]
    if not valid_columns: return df_filtered
    subset_for_iqr = df_filtered[valid_columns]
    Q1 = subset_for_iqr.quantile(0.25); Q3 = subset_for_iqr.quantile(0.75); IQR = Q3 - Q1
    mask = ~((subset_for_iqr < (Q1 - 1.5 * IQR)) | (subset_for_iqr > (Q3 + 1.5 * IQR))).any(axis=1)
    df_filtered = df_filtered[mask]; rows_removed = initial_rows - len(df_filtered)
    if rows_removed > 0: print(f"Removed {rows_removed} rows via IQR.")
    return df_filtered

def load_and_preprocess_base_data(path, date_col_name, year_col, month_col, day_col, target_cols, dataset_name="Training"):
    print("-" * 30); print(f"Processing {dataset_name} Dataset"); print("-" * 30)
    try:
        print(f"Loading {dataset_name} data from {path}..."); df = pd.read_csv(path); print(f"Loaded {len(df)} rows.")
        # 1. Construct Date
        date_components_cols = [year_col, month_col, day_col]
        if not all(col in df.columns for col in date_components_cols): print(f"Error: Date component cols missing: {[c for c in date_components_cols if c not in df.columns]}"); return None
        for col in date_components_cols: df[col] = pd.to_numeric(df[col], errors='coerce')
        df.dropna(subset=date_components_cols, inplace=True)
        print(f"Constructing '{date_col_name}'...");
        df[date_col_name] = pd.to_datetime({'year': df[year_col], 'month': df[month_col], 'day': df[day_col]}, errors='coerce')
        initial_rows_date = len(df); df.dropna(subset=[date_col_name], inplace=True)
        if initial_rows_date > len(df): print(f"Dropped {initial_rows_date - len(df)} rows due to invalid date components.")
        print(f"{len(df)} rows after date construction.")
        # 2. Initial dropna()
        initial_rows = len(df); df.dropna(inplace=True)
        if initial_rows > len(df): print(f"{len(df)} rows after initial dropna(). {initial_rows - len(df)} removed.")
        # 3. Ensure Price columns numeric
        all_potential_targets = ['avg_min_price', 'avg_max_price', 'avg_modal_price']
        for col in all_potential_targets:
            if col in df.columns: df[col] = pd.to_numeric(df[col], errors='coerce')
            else: print(f"Warning: Price column '{col}' not found.")
        df.dropna(subset=all_potential_targets, how='any', inplace=True)
        print(f"{len(df)} rows after ensuring price columns numeric.")
        # 4. Drop OTHER unused columns
        cols_to_drop = ['calculationType', 'district_id', 'change','district_name_enc', 'commodity_name_enc', 'state_name_enc']
        existing_cols_to_drop = [col for col in cols_to_drop if col in df.columns]
        if existing_cols_to_drop: df.drop(columns=existing_cols_to_drop, axis=1, inplace=True)
        # 5. Apply IQR Outlier Removal (Optional)
        # df = remove_outliers_iqr(df, all_potential_targets)
        # 6. Check required columns
        required_numeric_filter_cols = ['state_name', 'district_name', 'commodity_name']
        required_cols = [date_col_name] + all_potential_targets + required_numeric_filter_cols
        missing_req_cols = [col for col in required_cols if col not in df.columns]
        if missing_req_cols: print(f"Error: Required columns missing: {missing_req_cols}"); print(f"Available: {df.columns.tolist()}"); return None
        for col in required_numeric_filter_cols:
             if not pd.api.types.is_numeric_dtype(df[col]): print(f"Error: Col '{col}' expected numeric but isn't."); return None
        df.sort_values(date_col_name, inplace=True)
        print(f"{dataset_name} base data loaded. {len(df)} rows.")
        return df
    except FileNotFoundError: print(f"Error: {dataset_name} file not found at {path}"); return None
    except Exception as e: print(f"Error loading/preprocessing {dataset_name}: {e}"); traceback.print_exc(); return None

# --- Evaluation Metrics Function (No changes) ---
def calculate_metrics(y_true, y_pred):
    y_true = np.array(y_true).flatten(); y_pred = np.array(y_pred).flatten()
    valid_mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
    y_true = y_true[valid_mask]; y_pred = y_pred[valid_mask]
    if len(y_true) == 0: print("Warning: No valid points for metric calculation."); return np.nan, np.nan, np.nan
    try:
        r2 = r2_score(y_true, y_pred); mae = mean_absolute_error(y_true, y_pred); mse = mean_squared_error(y_true, y_pred)
        return r2, mae, mse
    except Exception as e: print(f"Error calculating metrics: {e}"); return np.nan, np.nan, np.nan

# --- Plotting Function for Validation (No changes) ---
def plot_validation_results(validation_actuals_df, validation_preds_df, target_column, date_col_name, title):
    """Plots actuals vs predictions for validation period."""
    fig = go.Figure(); target_label = target_column.replace("avg_", "").replace("_price", "").capitalize()
    plot_actuals_df = validation_actuals_df.sort_values(by=date_col_name).copy()
    plot_preds_df = validation_preds_df.sort_values(by='time_idx').copy()

    fig.add_trace(go.Scatter(x=plot_actuals_df[date_col_name], y=plot_actuals_df[target_column], mode='lines+markers', name=f'Actual {target_label} ({VALIDATION_YEAR})', line=dict(color='blue'), marker=dict(size=4)))

    # Merge predicted values back onto a date axis for plotting
    pred_col_name = 'prediction' # Column name holding point forecast
    if date_col_name not in plot_preds_df.columns:
         time_idx_to_date = plot_actuals_df[[date_col_name, 'time_idx']].drop_duplicates().set_index('time_idx')
         plot_preds_df = plot_preds_df.join(time_idx_to_date, on='time_idx')

    if pred_col_name in plot_preds_df.columns and date_col_name in plot_preds_df.columns:
         # Ensure prediction is 1D
         if plot_preds_df[pred_col_name].ndim > 1:
              pred_values = plot_preds_df[pred_col_name].iloc[:, 0] # Assume first col is point forecast
         else:
              pred_values = plot_preds_df[pred_col_name]

         fig.add_trace(go.Scatter(x=plot_preds_df[date_col_name], y=pred_values, mode='lines', name=f'Predicted {target_label} ({VALIDATION_YEAR})', line=dict(color='red')))
    else:
         print(f"Warning: Could not find '{pred_col_name}' or '{date_col_name}' in prediction results for plotting.")

    fig.update_layout(title=title, xaxis_title=f'Date ({VALIDATION_YEAR})', yaxis_title=f'Price ({target_label})', hovermode="x unified", legend_title_text='Legend')
    return fig


# --- Main Execution Block ---
print("--- DeepAR Forecasting & Validation ---") # Changed Title
print(f"--- (Nashik/Wheat: 2002-2023 Train, {VALIDATION_YEAR} Validate) ---")

# 1. Load Base Data
df_train_base = load_and_preprocess_base_data(DATA_PATH_TRAIN, DATE_COLUMN, YEAR_COL, MONTH_COL, DAY_COL, TARGET_COLUMNS, "Training (2002-2023)")
df_val_base = load_and_preprocess_base_data(DATA_PATH_VALIDATION, DATE_COLUMN, YEAR_COL, MONTH_COL, DAY_COL, TARGET_COLUMNS, f"Validation ({VALIDATION_YEAR})")

# Init flags/variables
train_dataloader = None; val_dataloader = None; training_dataset = None

if df_train_base is not None and df_val_base is not None:
    # 2. Get Encoded Values for Filtering
    try:
        selected_state_key=SELECTED_STATE_STR.strip().lower(); selected_district_key=SELECTED_DISTRICT_STR.strip().lower(); selected_commodity_key=SELECTED_COMMODITY_STR.strip().lower()
        encoded_state = state_name_encoding_map.get(selected_state_key); encoded_district = district_name_encoding_map.get(selected_district_key); encoded_commodity = commodity_name_encoding_map.get(selected_commodity_key)
        lookup_failed = False
        if encoded_state is None: print(f"Error: State '{SELECTED_STATE_STR}' missing map."); lookup_failed=True
        if encoded_district is None: print(f"Error: District '{SELECTED_DISTRICT_STR}' missing map."); lookup_failed=True
        if encoded_commodity is None: print(f"Error: Commodity '{SELECTED_COMMODITY_STR}' missing map."); lookup_failed=True
        if lookup_failed: print("Check maps."); df_train_base=df_val_base=None
        else: print(f"\nSelected: {SELECTED_STATE_STR}/{SELECTED_DISTRICT_STR}/{SELECTED_COMMODITY_STR} -> Encoded: St={encoded_state}, Di={encoded_district}, Co={encoded_commodity}")
    except Exception as e: print(f"Error mapping lookup: {e}"); df_train_base = df_val_base = None

if df_train_base is not None and df_val_base is not None:
    # 3. Filtering Data
    print(f"\nFiltering datasets using encoded values...")
    filter_cols_num = ['state_name', 'district_name', 'commodity_name']
    if not all(col in df_train_base.columns for col in filter_cols_num): print("Error: Encoded filter cols missing Training."); filtered_df_train = pd.DataFrame()
    else: filtered_df_train = df_train_base[(df_train_base['state_name'] == encoded_state) & (df_train_base['district_name'] == encoded_district) & (df_train_base['commodity_name'] == encoded_commodity)].copy(); filtered_df_train.sort_values(by=DATE_COLUMN, inplace=True)
    if not all(col in df_val_base.columns for col in filter_cols_num): print("Error: Encoded filter cols missing Validation."); filtered_df_val = pd.DataFrame()
    else: filtered_df_val = df_val_base[(df_val_base['state_name'] == encoded_state) & (df_val_base['district_name'] == encoded_district) & (df_val_base['commodity_name'] == encoded_commodity)].copy(); filtered_df_val.sort_values(by=DATE_COLUMN, inplace=True)

    if filtered_df_train.empty: print("\nWarning: No training data after filtering.")
    if filtered_df_val.empty: print("\nWarning: No validation data after filtering.")

    # 4. Prepare Data for DeepAR (Target: PRIMARY_TARGET)
    if not filtered_df_train.empty and not filtered_df_val.empty:
        print(f"\nPreparing data for DeepAR (Target: {PRIMARY_TARGET})...")
        try:
            # Combine temporarily for global time_idx
            filtered_df_train["dataset_type"] = "train"; filtered_df_val["dataset_type"] = "val"
            data_comb = pd.concat([filtered_df_train, filtered_df_val], ignore_index=True).sort_values(DATE_COLUMN)
            data_comb['time_idx'] = (data_comb[DATE_COLUMN] - data_comb[DATE_COLUMN].min()).dt.days
            data_comb['group_id'] = f"{SELECTED_DISTRICT_STR}_{SELECTED_COMMODITY_STR}_{PRIMARY_TARGET}"
            # Add time features
            data_comb['month'] = data_comb[DATE_COLUMN].dt.month.astype(str)
            data_comb['day_of_week'] = data_comb[DATE_COLUMN].dt.dayofweek.astype(str)
            data_comb['day_of_month'] = data_comb[DATE_COLUMN].dt.day.astype(str)
            data_comb['week_of_year'] = data_comb[DATE_COLUMN].dt.isocalendar().week.astype(str)
            data_comb['year_real'] = data_comb[DATE_COLUMN].dt.year
            data_comb[PRIMARY_TARGET] = data_comb[PRIMARY_TARGET].astype(np.float32)

            # --- Feature list definitions ---
            time_varying_known_categoricals = ["month", "day_of_week", "day_of_month", "week_of_year"]
            time_varying_known_reals = ["year_real", "time_idx"] # time_idx is crucial here
            # DeepAR uses past target values implicitly via RNN structure,
            # but we still need to list the target in one of the variable groups.
            # If ONLY target is used as unknown input, put it in time_varying_unknown_reals
            time_varying_unknown_reals = [PRIMARY_TARGET]
            # If using lags, add lags to time_varying_unknown_reals and target MIGHT go to known reals? Check docs.
            static_categoricals = ["group_id"]

            # Separate back into train and validation sets
            train_data_tft = data_comb[data_comb["dataset_type"] == "train"].copy()
            first_val_idx_actual = data_comb[data_comb["dataset_type"] == "val"]['time_idx'].min()
            start_val_idx_with_overlap = max(0, first_val_idx_actual - ENCODER_LENGTH) # Need history for encoder
            val_data_tft_with_overlap = data_comb[data_comb["time_idx"] >= start_val_idx_with_overlap].copy()

            # --- Create TimeSeriesDataSet ---
            print("Creating TimeSeriesDataSet for DeepAR...")
            target_normalizer = GroupNormalizer(groups=["group_id"], transformation="softplus", center=True)

            training_dataset = TimeSeriesDataSet(
                train_data_tft,
                time_idx="time_idx", target=PRIMARY_TARGET, group_ids=["group_id"],
                max_encoder_length=ENCODER_LENGTH, max_prediction_length=PREDICTION_LENGTH,
                static_categoricals=static_categoricals,
                time_varying_known_categoricals=time_varying_known_categoricals,
                time_varying_known_reals=time_varying_known_reals,
                time_varying_unknown_reals=time_varying_unknown_reals,
                target_normalizer=target_normalizer,
                allow_missing_timesteps=True # Keep this!
            )

            validation_dataset = TimeSeriesDataSet.from_dataset(
                training_dataset, val_data_tft_with_overlap,
                predict=False, stop_randomization=True
            )
            train_dataloader = training_dataset.to_dataloader(train=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKER)
            val_dataloader = validation_dataset.to_dataloader(train=False, batch_size=BATCH_SIZE * 2, num_workers=NUM_WORKER)
            print("TimeSeriesDataSet and Dataloaders created.")

        except Exception as e: print(f"Error creating TimeSeriesDataSet: {e}"); traceback.print_exc(); train_dataloader = val_dataloader = None

    if train_dataloader is not None and val_dataloader is not None:
        # --- 5. Define Model & Trainer ---
        print("\nDefining DeepAR model and Trainer...")
        try:
            early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5, verbose=False, mode="min")
            lr_monitor = LearningRateMonitor(logging_interval='epoch')
            # logger = TensorBoardLogger("tb_logs", name="deepar_nashik_wheat")

            trainer = pl.Trainer(
                max_epochs=MAX_EPOCHS, accelerator=TRAINER_ACCELERATOR, devices=TRAINER_DEVICES,
                gradient_clip_val=0.1,
                # limit_train_batches=30, limit_val_batches=10, # For debugging
                callbacks=[lr_monitor, early_stop_callback],
                # logger=logger,
                enable_progress_bar=True, enable_checkpointing=True
            )

            # --- Instantiate DeepAR Model ---
            deepar = DeepAR.from_dataset(
                training_dataset,
                learning_rate=LEARNING_RATE,
                hidden_size=HIDDEN_SIZE,
                rnn_layers=RNN_LAYERS,
                dropout=DROPOUT,
                loss=NormalDistributionLoss(), # Use appropriate loss for continuous data
                # optimizer="Adam" # Default is Adam
            )
            print(f"DeepAR Network parameters: {deepar.size()/1e3:.1f}k")

            # --- 6. Train Model ---
            print("\nStarting model training...")
            trainer.fit(
                deepar, # Pass the deepar model object
                train_dataloaders=train_dataloader,
                val_dataloaders=val_dataloader,
            )
            print("Training finished.")

            # --- 7. Evaluate on Validation Set ---
            print(f"\n--- Evaluating FINAL DeepAR Model Performance on {VALIDATION_YEAR} Validation Data ---")
            best_model_path = trainer.checkpoint_callback.best_model_path
            if best_model_path and os.path.exists(best_model_path):
                 print(f"Loading best model from: {best_model_path}")
                 best_model = DeepAR.load_from_checkpoint(best_model_path)
            else:
                 print("Warning: Best model checkpoint not found. Using model state from end of training.")
                 best_model = deepar # Use the model from the last epoch

            print("Predicting on validation data...")
            # Use mode="prediction" for point forecast (mean)
            raw_predictions = best_model.predict(val_dataloader, mode="prediction", return_index=True)
            # raw_predictions is typically the tensor of predicted means here
            # index contains time_idx and group_id

            # Align actuals and predictions using the index from predict()
            actuals_df_eval = val_data_tft_with_overlap[lambda x: x.time_idx.isin(raw_predictions.index.time_idx)][[DATE_COLUMN, 'time_idx', 'group_id', PRIMARY_TARGET]].copy()
            predictions_df_eval = raw_predictions.index.copy()
            # Ensure prediction tensor is correctly shaped and assigned
            if isinstance(raw_predictions, torch.Tensor):
                predictions_df_eval['prediction_normalized'] = raw_predictions.cpu().numpy().flatten() # Get point forecast
            else: # Handle potential tuple output or different structure if API changes
                 try: predictions_df_eval['prediction_normalized'] = raw_predictions.prediction.cpu().numpy().flatten()
                 except AttributeError: print("Error: Could not extract predictions from model output."); predictions_df_eval['prediction_normalized'] = np.nan


            eval_results = pd.merge(actuals_df_eval, predictions_df_eval, on=['time_idx', 'group_id'], how='inner')

            # Inverse transform
            target_scaler = training_dataset.target_normalizer
            try:
                 # Structure for inverse transform
                 actuals_inv_input = eval_results[[PRIMARY_TARGET, 'group_id']].rename(columns={PRIMARY_TARGET: 'target'})
                 preds_inv_input = eval_results[['prediction_normalized', 'group_id']].rename(columns={'prediction_normalized': 'target'})

                 actuals_inv = target_scaler.inverse_transform(actuals_inv_input)['target'].values
                 preds_inv = target_scaler.inverse_transform(preds_inv_input)['target'].values
            except Exception as e:
                 print(f"Error during inverse transform: {e}. Cannot calculate metrics accurately.")
                 actuals_inv = np.array([]) # Empty arrays to skip metric calculation
                 preds_inv = np.array([])

            # Calculate metrics
            if len(actuals_inv) > 0 and len(preds_inv) > 0:
                print(f"Evaluating based on {len(actuals_inv)} matched prediction points.")
                r2_val, mae_val, mse_val = calculate_metrics(actuals_inv, preds_inv)
                print(f"FINAL Validation R-squared (R2): {r2_val:.4f}")
                print(f"FINAL Validation Mean Absolute Error (MAE): {mae_val:.2f}")
                print(f"FINAL Validation Mean Squared Error (MSE): {mse_val:.2f}")

                # --- 8. Plot Validation Results ---
                print(f"\n--- Plotting FINAL Validation Results for {PRIMARY_TARGET} (Actual vs. Predicted {VALIDATION_YEAR}) ---")
                # Prepare DFs for plotting (using inverse transformed)
                plot_actuals_df = eval_results[[DATE_COLUMN, 'time_idx']].copy(); plot_actuals_df[PRIMARY_TARGET] = actuals_inv
                plot_preds_df = eval_results[[DATE_COLUMN, 'time_idx']].copy(); plot_preds_df['prediction'] = preds_inv

                plot_title_val = f'DeepAR Validation (Nashik/Wheat): {PRIMARY_TARGET.replace("avg_", "").replace("_price", "").capitalize()} Price (Actual vs. Predicted {VALIDATION_YEAR})'
                fig_val = plot_validation_results(plot_actuals_df, plot_preds_df, PRIMARY_TARGET, DATE_COLUMN, plot_title_val)
                fig_val.show()
            else: print("Skipping metrics and plotting: No matched/valid actuals/predictions found after inverse transform.")

        except Exception as e: print(f"Error during DeepAR training or evaluation: {e}"); traceback.print_exc()

    else: print("\nCannot proceed: lack of data after filtering or TimeSeriesDataSet creation failed.")
else: print("\nFailed during data loading, preprocessing, or mapping lookup.")

print("\nProcess finished.")

  from .autonotebook import tqdm as notebook_tqdm
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.


--- DeepAR Forecasting & Validation ---
--- (Nashik/Wheat: 2002-2023 Train, 2024 Validate) ---
------------------------------
Processing Training (2002-2023) Dataset
------------------------------
Loading Training (2002-2023) data from E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_2002_2023.csv...
Loaded 6246 rows.
Constructing 'full_date'...
6246 rows after date construction.
6246 rows after ensuring price columns numeric.
Training (2002-2023) base data loaded. 6246 rows.
------------------------------
Processing Validation (2024) Dataset
------------------------------
Loading Validation (2024) data from E:\elevatetrsest\crop price predictor\Crop_price_Prediction\data\edited_nashik_test_2024.csv...
Loaded 278 rows.
Constructing 'full_date'...
278 rows after date construction.
278 rows after ensuring price columns numeric.
Validation (2024) base data loaded. 278 rows.

Selected: Maharashtra/Nashik/Wheat -> Encoded: St=6291, Di=6291, Co=6291

Filtering d

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


DeepAR Network parameters: 27.7k

Starting model training...
Error during DeepAR training or evaluation: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `DeepAR`

Process finished.


Traceback (most recent call last):
  File "C:\Users\Shiva\AppData\Local\Temp\ipykernel_26256\3624202359.py", line 282, in <module>
    trainer.fit(
  File "c:\Users\Shiva\.conda\envs\nlp_env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 554, in fit
    model = _maybe_unwrap_optimized(model)
  File "c:\Users\Shiva\.conda\envs\nlp_env\lib\site-packages\pytorch_lightning\utilities\compile.py", line 111, in _maybe_unwrap_optimized
    raise TypeError(
TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `DeepAR`
