In [1]:
# import pandas as pd
# import glob
# import os

# # 1. Get a list of all CSV file paths
# path = '/content/drive/MyDrive/Agmarket/Price'
# all_files = glob.glob(os.path.join(path, "*.csv"))

# # 2. Use a list comprehension to read all files
# # We add a column 'Commodity' based on the filename for tracking
# df_list = []
# for filename in all_files:
#     df = pd.read_csv(filename)
#     # Extract filename without extension as the label
#     df['commodity_name'] = os.path.basename(filename).replace('.csv', '')
#     df_list.append(df)

# # 3. Combine everything into one DataFrame
# combined_df = pd.concat(df_list, axis=0, ignore_index=True)

# # 4. Export to a high-performance format
# combined_df.to_csv('combined_agriculture_data.csv', index=False)

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# **Data Collection**

In [3]:
data = pd.read_csv('/content/drive/MyDrive/Agmarket/Price/combined_agriculture_data.csv')

In [4]:
df = pd.DataFrame(data)
df

Unnamed: 0,t,cmdty,market_id,market_name,state_id,state_name,district_id,district_name,variety,p_min,p_max,p_modal,commodity_name
0,2023-11-03,Rice,1593,Bhulath,3,Punjab,36,Kapurthala,1009 Kar,700.0,700.0,700.0,PJ_20
1,2023-10-07,Rice,1614,Ahmedgarh,3,Punjab,53,Sangrur,Other,2203.0,2203.0,2203.0,PJ_20
2,2023-10-06,Rice,1614,Ahmedgarh,3,Punjab,53,Sangrur,Other,2203.0,2203.0,2203.0,PJ_20
3,2023-10-05,Rice,1614,Ahmedgarh,3,Punjab,53,Sangrur,Other,2203.0,2203.0,2203.0,PJ_20
4,2023-07-11,Rice,1581,Zira,3,Punjab,43,Firozpur,1009 Kar,100.0,300.0,200.0,PJ_20
...,...,...,...,...,...,...,...,...,...,...,...,...,...
2496799,2016-01-03,Tomato,770,Chickkaballapura,29,Karnataka,582,Chikkaballapura,Tomato,700.0,4500.0,3800.0,price_data - 2026-01-04T141759.571
2496800,2016-01-03,Tomato,120,Kolar,29,Karnataka,581,Kolar,Tomato,1000.0,4667.0,3453.0,price_data - 2026-01-04T141759.571
2496801,2016-01-03,Tomato,1767,Mulabagilu,29,Karnataka,581,Kolar,Tomato,1665.0,4665.0,3165.0,price_data - 2026-01-04T141759.571
2496802,2016-01-03,Tomato,515,Srinivasapur,29,Karnataka,581,Kolar,Tomato,1300.0,4800.0,3050.0,price_data - 2026-01-04T141759.571


In [5]:
df = df.drop(['commodity_name','market_id','state_id','district_id'], axis=1)

In [6]:
df.rename(columns={'t': 'Date','cmdty':'Commodity','market_name':'Market','state_name':'State','district_name':'District','variety':'Variety'}, inplace=True)

In [7]:
df = df.sort_values(by=['Date', 'Commodity', 'State'],ignore_index=True)

# **EDA and pre-processing**

In [8]:
# df.info()
# df.set_index('Date', inplace=True)

In [9]:
print(df['Commodity'].nunique())
print(df['Market'].nunique())
print(df['State'].nunique())

11
2477
32


In [10]:
# outliers
invalid_prices = df[
    (df['p_min'] > df['p_modal']) | (df['p_modal'] > df['p_max'])
]

len(invalid_prices)


4648

In [11]:
# Check if price is -ve (outliers again)
(df[['p_min', 'p_max', 'p_modal']] < 0).sum()

Unnamed: 0,0
p_min,0
p_max,0
p_modal,0


In [12]:
df['p_modal'].describe()

Unnamed: 0,p_modal
count,2496630.0
mean,2675.935
std,3024.149
min,1.0
25%,1400.0
50%,1975.0
75%,2900.0
max,460000.0


In [13]:
# Check if data is balanced
df['Commodity'].value_counts().head(10)

Unnamed: 0_level_0,count
Commodity,Unnamed: 1_level_1
Tomato,543947
Onion,401644
Paddy(Dhan)(Common),373082
Potato,323748
Maize,235419
Apple,192984
Wheat,141928
Rice,124675
Jowar (Sorghum),102086
Ragi (Finger Millet),54032


In [14]:
df

Unnamed: 0,Date,Commodity,Market,State,District,Variety,p_min,p_max,p_modal
0,2015-12-13,Potato,Chandigarh(Grain/Fruit),Chandigarh,Chandigarh,Other,200.0,500.0,350.0
1,2015-12-13,Potato,Keshopur,NCT of Delhi,West,Potato,500.0,900.0,800.0
2,2015-12-13,Potato,Ajmer(F&V),Rajasthan,Ajmer,Other,500.0,700.0,650.0
3,2015-12-13,Potato,Alwar (F&V),Rajasthan,Alwar,Other,300.0,800.0,600.0
4,2015-12-13,Potato,Sriganganagar (F&V),Rajasthan,Ganganagar,Other,600.0,600.0,600.0
...,...,...,...,...,...,...,...,...,...
2496799,2025-10-30,Wheat,Jiaganj,West Bengal,Murshidabad,Sonalika,2600.0,2620.0,2610.0
2496800,2025-10-30,Wheat,Asansol,West Bengal,Barddhaman,Kalyan,2765.0,3050.0,2850.0
2496801,2025-10-30,Wheat,Asansol,West Bengal,Barddhaman,Kalyan,2550.0,2750.0,2600.0
2496802,2025-10-30,Wheat,Durgapur,West Bengal,Barddhaman,Kalyan,2765.0,3050.0,2850.0


In [15]:
df['Commodity'].unique()

array(['Potato', 'Wheat', 'Paddy(Dhan)(Common)', 'Rice', 'Barley (Jau)',
       'Jowar (Sorghum)', 'Maize', 'Onion', 'Ragi (Finger Millet)',
       'Apple', 'Tomato'], dtype=object)

In [16]:
# check unique for state, district, market, variety
df['State'].unique()

array(['Chandigarh', 'NCT of Delhi', 'Rajasthan', 'Tripura',
       'Uttarakhand', 'Jammu & Kashmir', 'Meghalaya', 'Haryana', 'Punjab',
       'Gujarat', 'Jharkhand', 'Karnataka', 'Kerala', 'Madhya Pradesh',
       'Puducherry', 'Assam', 'Bihar', 'Chhattisgarh', 'Odisha',
       'West Bengal', 'Telangana', 'Manipur', 'Maharashtra',
       'Andhra Pradesh', 'Tamil Nadu', 'Uttar Pradesh',
       'Himachal Pradesh', 'Andaman & Nicobar Islands', 'Goa', 'Nagaland',
       'Arunachal Pradesh', 'Mizoram'], dtype=object)

In [17]:
df['District'].nunique()

527

In [18]:
df.isna().sum()

Unnamed: 0,0
Date,0
Commodity,0
Market,0
State,0
District,0
Variety,0
p_min,9970
p_max,7988
p_modal,174


In [19]:
#df = df.reset_index()

In [20]:
df

Unnamed: 0,Date,Commodity,Market,State,District,Variety,p_min,p_max,p_modal
0,2015-12-13,Potato,Chandigarh(Grain/Fruit),Chandigarh,Chandigarh,Other,200.0,500.0,350.0
1,2015-12-13,Potato,Keshopur,NCT of Delhi,West,Potato,500.0,900.0,800.0
2,2015-12-13,Potato,Ajmer(F&V),Rajasthan,Ajmer,Other,500.0,700.0,650.0
3,2015-12-13,Potato,Alwar (F&V),Rajasthan,Alwar,Other,300.0,800.0,600.0
4,2015-12-13,Potato,Sriganganagar (F&V),Rajasthan,Ganganagar,Other,600.0,600.0,600.0
...,...,...,...,...,...,...,...,...,...
2496799,2025-10-30,Wheat,Jiaganj,West Bengal,Murshidabad,Sonalika,2600.0,2620.0,2610.0
2496800,2025-10-30,Wheat,Asansol,West Bengal,Barddhaman,Kalyan,2765.0,3050.0,2850.0
2496801,2025-10-30,Wheat,Asansol,West Bengal,Barddhaman,Kalyan,2550.0,2750.0,2600.0
2496802,2025-10-30,Wheat,Durgapur,West Bengal,Barddhaman,Kalyan,2765.0,3050.0,2850.0


In [21]:
def missing_dates(group):
    full_range = pd.date_range(group['Date'].min(), group['Date'].max())
    return len(full_range) - group['Date'].nunique()

missing_by_group = (
    df.groupby(['Commodity', 'Market'])
      .apply(missing_dates)
      .sort_values(ascending=False)
      .head(10)
)

missing_by_group

  .apply(missing_dates)


Unnamed: 0_level_0,Unnamed: 1_level_0,0
Commodity,Market,Unnamed: 2_level_1
Wheat,Kalagategi,3589
Potato,Mawkyrwat,3572
Rice,Nongpoh (R-Bhoi),3545
Jowar (Sorghum),Manapparai,3544
Maize,Kesinga,3527
Paddy(Dhan)(Common),Pundri,3524
Wheat,Jaspur,3520
Paddy(Dhan)(Common),Kudchi,3519
Rice,Ramnagar,3517
Paddy(Dhan)(Common),Panipat,3504


In [22]:
df['Date'] = pd.to_datetime(df['Date'])

In [23]:
df['year'] = df['Date'].dt.year
df['month'] = df['Date'].dt.month
df['week'] = df['Date'].dt.isocalendar().week
df['dayofweek'] = df['Date'].dt.dayofweek

In [24]:
potato = df[(df['Commodity'] == 'Tomato') & (df['year'] == 2025)]

monthly_avg = (
    potato.groupby('month')['p_modal']
    .mean()
)

monthly_avg

Unnamed: 0_level_0,p_modal
month,Unnamed: 1_level_1
1,2050.607071
2,1944.553432
3,1464.871813
4,1616.676784
5,1673.006991
6,2291.452133
7,3131.431784
8,3931.72715
9,2494.922061
10,2552.252481


In [25]:
potato_market_stats = (
    potato.groupby('State')['p_modal']
    .mean()
    .sort_values(ascending=False)
    .head(10)
)

potato_market_stats

Unnamed: 0_level_0,p_modal
State,Unnamed: 1_level_1
Nagaland,4436.931818
Kerala,3223.187437
Tripura,3146.337918
Jammu & Kashmir,2494.063005
Himachal Pradesh,2478.468642
Tamil Nadu,2286.398699
NCT of Delhi,1828.855842
Chandigarh,1815.625
Assam,1772.559815
Gujarat,1725.435769


In [26]:
df = df.sort_values(['Commodity', 'Market', 'Date'], ignore_index = True)

In [27]:
df['price_change'] = (
    df.groupby(['Commodity', 'Market'])['p_modal']
      .diff()
)

In [28]:
df['pct_change'] = (
    df.groupby(['Commodity', 'Market'])['p_modal']
      .pct_change()
)

  .pct_change()


In [29]:
df

Unnamed: 0,Date,Commodity,Market,State,District,Variety,p_min,p_max,p_modal,year,month,week,dayofweek,price_change,pct_change
0,2022-12-22,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2022,12,51,3,,
1,2022-12-24,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2022,12,51,5,0.0,0.0
2,2022-12-31,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2022,12,52,5,0.0,0.0
3,2023-01-25,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2023,1,4,2,0.0,0.0
4,2023-01-31,Apple,Aarah,Bihar,Bhojpur,Apple,6000.0,8000.0,7000.0,2023,1,5,1,2000.0,0.4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2496799,2025-04-24,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,4,17,3,0.0,0.0
2496800,2025-04-30,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,4,18,2,0.0,0.0
2496801,2025-05-03,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,5,18,5,0.0,0.0
2496802,2025-05-10,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,5,19,5,0.0,0.0


In [30]:
print(df['price_change'].isna().sum())
print(df['pct_change'].isna().sum())

5130
4905


In [31]:
df['price_spike'] = df['pct_change'] > 0.15  # 15% daily spike
df['price_spike'].mean()

np.float64(0.07786994894272839)

In [32]:
final_df_for_analysis = df.copy('')
final_df_for_analysis

Unnamed: 0,Date,Commodity,Market,State,District,Variety,p_min,p_max,p_modal,year,month,week,dayofweek,price_change,pct_change,price_spike
0,2022-12-22,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2022,12,51,3,,,False
1,2022-12-24,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2022,12,51,5,0.0,0.0,False
2,2022-12-31,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2022,12,52,5,0.0,0.0,False
3,2023-01-25,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2023,1,4,2,0.0,0.0,False
4,2023-01-31,Apple,Aarah,Bihar,Bhojpur,Apple,6000.0,8000.0,7000.0,2023,1,5,1,2000.0,0.4,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2496799,2025-04-24,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,4,17,3,0.0,0.0,False
2496800,2025-04-30,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,4,18,2,0.0,0.0,False
2496801,2025-05-03,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,5,18,5,0.0,0.0,False
2496802,2025-05-10,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,5,19,5,0.0,0.0,False


In [33]:
df['State_District_Market'] = df['State'] + '_' + df['District'] + '_' + df['Market']
df

Unnamed: 0,Date,Commodity,Market,State,District,Variety,p_min,p_max,p_modal,year,month,week,dayofweek,price_change,pct_change,price_spike,State_District_Market
0,2022-12-22,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2022,12,51,3,,,False,Bihar_Bhojpur_Aarah
1,2022-12-24,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2022,12,51,5,0.0,0.0,False,Bihar_Bhojpur_Aarah
2,2022-12-31,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2022,12,52,5,0.0,0.0,False,Bihar_Bhojpur_Aarah
3,2023-01-25,Apple,Aarah,Bihar,Bhojpur,Apple,4000.0,6000.0,5000.0,2023,1,4,2,0.0,0.0,False,Bihar_Bhojpur_Aarah
4,2023-01-31,Apple,Aarah,Bihar,Bhojpur,Apple,6000.0,8000.0,7000.0,2023,1,5,1,2000.0,0.4,True,Bihar_Bhojpur_Aarah
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2496799,2025-04-24,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,4,17,3,0.0,0.0,False,Haryana_Sirsa_kalanwali
2496800,2025-04-30,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,4,18,2,0.0,0.0,False,Haryana_Sirsa_kalanwali
2496801,2025-05-03,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,5,18,5,0.0,0.0,False,Haryana_Sirsa_kalanwali
2496802,2025-05-10,Wheat,kalanwali,Haryana,Sirsa,Other,2425.0,2425.0,2425.0,2025,5,19,5,0.0,0.0,False,Haryana_Sirsa_kalanwali


# **Data Preparation**

In [34]:
deepar_df = df[['Date', 'Commodity', 'State_District_Market', 'p_modal']].copy()
deepar_df['Union'] = df['Commodity'] + "_" + df['State_District_Market']
deepar_df = deepar_df.sort_values(by=['Date', 'Union'], ignore_index=True)
deepar_df = deepar_df.drop('State_District_Market', axis = 1)
deepar_df

Unnamed: 0,Date,Commodity,p_modal,Union
0,2015-12-13,Potato,350.0,Potato_Chandigarh_Chandigarh_Chandigarh(Grain/...
1,2015-12-13,Potato,800.0,Potato_NCT of Delhi_West_Keshopur
2,2015-12-13,Potato,650.0,Potato_Rajasthan_Ajmer_Ajmer(F&V)
3,2015-12-13,Potato,600.0,Potato_Rajasthan_Alwar_Alwar (F&V)
4,2015-12-13,Potato,600.0,Potato_Rajasthan_Ganganagar_Sriganganagar (F&V)
...,...,...,...,...
2496799,2025-10-30,Wheat,2550.0,Wheat_West Bengal_Birbhum_Bolpur
2496800,2025-10-30,Wheat,2550.0,Wheat_West Bengal_Birbhum_Rampurhat
2496801,2025-10-30,Wheat,2550.0,Wheat_West Bengal_Birbhum_Sainthia
2496802,2025-10-30,Wheat,2610.0,Wheat_West Bengal_Murshidabad_Jiaganj


DeepAR requires an integer time index, not datetime.
Because internally
DeepAR uses RNNs and
RNNs operate on ordered sequences, not calendar dates

In [35]:
deepar_df['time_idx'] = (
        deepar_df.groupby('Union')['Date']
          .rank(method='dense')
          .astype(int)
    )

In [36]:
deepar_df.drop(['Commodity'], axis=1, inplace=True)

In [37]:
deepar_df = deepar_df.dropna(subset=["p_modal"])

In [38]:
deepar_test = deepar_df[deepar_df["Date"] >= "2023-01-01"].copy()
deepar_train = deepar_df[deepar_df["Date"] < "2023-01-01"].copy() #same data for train and val

In [39]:
deepar_test.drop(['Date'], axis=1, inplace=True)
deepar_train.drop(['Date'], axis=1, inplace=True)

In [40]:
deepar_train

Unnamed: 0,p_modal,Union,time_idx
0,350.0,Potato_Chandigarh_Chandigarh_Chandigarh(Grain/...,1
1,800.0,Potato_NCT of Delhi_West_Keshopur,1
2,650.0,Potato_Rajasthan_Ajmer_Ajmer(F&V),1
3,600.0,Potato_Rajasthan_Alwar_Alwar (F&V),1
4,600.0,Potato_Rajasthan_Ganganagar_Sriganganagar (F&V),1
...,...,...,...
1554545,2450.0,Wheat_West Bengal_Barddhaman_Asansol,1900
1554546,2360.0,Wheat_West Bengal_Barddhaman_Durgapur,2024
1554547,2150.0,Wheat_West Bengal_Barddhaman_Durgapur,2024
1554548,2080.0,Wheat_West Bengal_Nadia_Karimpur,439


# **DeepAR**

In [41]:
!pip install --no-cache-dir -v lightning pytorch-forecasting

Using pip 24.1.2 from /usr/local/lib/python3.12/dist-packages/pip (python 3.12)
Collecting lightning
  Obtaining dependency information for lightning from https://files.pythonhosted.org/packages/d6/e9/36b340c7ec01dad6f034481e98fc9fc0133307beb05c714c0542af98bbde/lightning-2.6.0-py3-none-any.whl.metadata
  Downloading lightning-2.6.0-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.9/44.9 kB[0m [31m130.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Link requires a different Python (3.12.12 not in: '>=3.8,<3.11'): https://files.pythonhosted.org/packages/28/9f/0ae16cc808f9b8bf6d4e6df3d32791ff37ef45e46fc72981e9ea5200b866/pytorch_forecasting-0.10.2-py3-none-any.whl (from https://pypi.org/simple/pytorch-forecasting/) (requires-python:>=3.8,<3.11)
  Link requires a different Python (3.12.12 not in: '>=3.8,<3.11'): https://files.pythonhosted.org/packages/5a/e6/a8c342e299abe24ff47df0ad4f809c5dd8057f4775691b6334e6972818c3/pytorch_forecasting-0.10.2.ta

In [42]:
!pip uninstall pytorch-lightning

Found existing installation: pytorch-lightning 2.6.0
Uninstalling pytorch-lightning-2.6.0:
  Would remove:
    /usr/local/lib/python3.12/dist-packages/lightning_fabric/*
    /usr/local/lib/python3.12/dist-packages/pytorch_lightning-2.6.0.dist-info/*
    /usr/local/lib/python3.12/dist-packages/pytorch_lightning/*
Proceed (Y/n)? Y
  Successfully uninstalled pytorch-lightning-2.6.0


In [43]:
!pip show lightning
!pip show pytorch-forecasting

Name: lightning
Version: 2.6.0
Summary: The Deep Learning framework to train, deploy, and ship AI products Lightning fast.
Home-page: https://github.com/Lightning-AI/lightning
Author: Lightning AI et al.
Author-email: developer@lightning.ai
License: Apache-2.0
Location: /usr/local/lib/python3.12/dist-packages
Requires: fsspec, lightning-utilities, packaging, pytorch-lightning, PyYAML, torch, torchmetrics, tqdm, typing-extensions
Required-by: pytorch-forecasting
Name: pytorch-forecasting
Version: 1.6.0
Summary: Forecasting timeseries with PyTorch - dataloaders, normalizers, metrics and models
Home-page: 
Author: Jan Beitner
Author-email: 
License: 
Location: /usr/local/lib/python3.12/dist-packages
Requires: lightning, numpy, pandas, scikit-base, scikit-learn, scipy, torch
Required-by: 


In [44]:
import pandas as pd
import torch
import lightning as pl
from pytorch_forecasting import DeepAR
from pytorch_forecasting.data import (TimeSeriesDataSet,GroupNormalizer)
from pytorch_forecasting.metrics import NormalDistributionLoss

In [45]:
# deepar_train = deepar_df[deepar_df['Date'] < '2022-01-01'].copy()
# deepar_val = deepar_df[(deepar_df['Date'] >= '2022-01-01') & (deepar_df['Date'] < '2023-01-01')].copy()
# deepar_test = deepar_df[deepar_df['Date'] >= '2023-01-01'].copy()

Why shorter encoder is BETTER for agri price data
> Reason 1 — Local dynamics dominate
Short-term price movement is driven by Recent supply, demand  
Not by
Prices from 2 years ago  
So the last 10–14 observations are far more predictive than the last 30–60.

> Reason 2 — Long histories introduce noise

In [46]:
max_prediction_length = 7
max_encoder_length = 14

In [47]:
training_cut_off = deepar_train['time_idx'].max() - max_prediction_length

In [48]:
training = TimeSeriesDataSet(
    deepar_train[deepar_train.time_idx <= training_cut_off],
    time_idx="time_idx",
    target="p_modal",
    group_ids=["Union"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    time_varying_unknown_reals=["p_modal"],
    target_normalizer=GroupNormalizer(groups=["Union"]), #normalize y
    allow_missing_timesteps=True
)



In [49]:
validation = TimeSeriesDataSet.from_dataset(
    training,  #inherit all imp parameters like 'groupby', 'max_enco','max_pred', etc., from training set
    deepar_train,
    predict=True, #this says that the set is for prediction (therefore validation set)
    stop_randomization=True
    #stop randomization of enc and deco len within specified min,max_enc_len and min,max_pred_len ranges cuz you want a consistent and reproducible benchmark.
    # this is false in training cuz it exposes the model to a wider variety of i/p sequence len and pred horizons during training, making it more robust and less prone to overfitting to specific sequence len.
)



In [50]:
# Data Loader
batch_size = 64
train_loader = training.to_dataloader(
    train=True, # Randomly samples forecast windows and shuffles them (shuffles windows not time order) which improves generalization
    batch_size=batch_size,
    num_workers=2
)

val_loader = validation.to_dataloader(
    train=False, # No randomization. No gradient updates
    batch_size=batch_size,
    num_workers=2
)

In [51]:
#model definition
model = DeepAR.from_dataset(
    training,
    learning_rate=1e-3, #DeepAR uses Adam optimizer. Adam is stable in the range: 1e-4 to 1e-3
    hidden_size=64,
    rnn_layers=2, # layer1 - short term, layer2 - mideium term
    # 3 is overkill, 1 leads to overfit
    dropout=0.1,
    loss=NormalDistributionLoss(),
    # DeepAR models a probability distribution, not quantiles directly — therefore it must be trained with a distribution-based loss.
)

/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/parsing.py:210: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/parsing.py:210: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.


In [None]:
# train
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    gradient_clip_val=0.1, # to prevent exploding gradients during training
    enable_model_summary=True
)

trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader
)

INFO: 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:lightning.pytorch.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [None]:
# test_dataset = TimeSeriesDataSet.from_dataset(
#     training,
#     deepar_test,
#     predict=True,
#     stop_randomization=True
# )

In [None]:
!pip list | grep -E "lightning|forecasting"

In [None]:
import pytorch_forecasting
from pytorch_forecasting import DeepAR

import inspect

print(pytorch_forecasting.__version__)
print(inspect.getmodule(DeepAR))

In [None]:
import lightning.pytorch as pl
from pytorch_forecasting import DeepAR
from pytorch_forecasting.metrics import NormalDistributionLoss

model = DeepAR.from_dataset(
    training,
    learning_rate=1e-3,
    hidden_size=64,
    rnn_layers=2,
    dropout=0.1,
    loss=NormalDistributionLoss(),
)

print(type(model))
print(pl.LightningModule)
print(isinstance(model, pl.LightningModule))