In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error


In [2]:
# Load the engineered feature table (one row per state/date) and show its raw
# size and first rows before any cleaning is applied.
features = pd.read_csv(
    "../data/processed/feature_dataset.csv"
)

print(features.shape)
features.head(n=5)


(1881, 12)


Unnamed: 0,state,date,total_enrolment,monthly_growth,child_ratio,youth_ratio,adult_ratio,demo_update_pressure,biometric_update_pressure,demo_pressure_ratio,biometric_pressure_ratio,risk_score
0,100000,2025-09-02,3,,0.0,0.0,1.0,,,,,0.0
1,100000,2025-09-03,1,-0.666667,0.0,0.0,1.0,,,,,0.266667
2,100000,2025-09-08,1,0.0,0.0,0.0,1.0,,,,,0.0
3,100000,2025-09-09,1,0.0,0.0,0.0,1.0,,,,,0.0
4,100000,2025-09-11,2,1.0,0.0,0.0,1.0,,,,,0.4


In [3]:
# Parse dates, order rows chronologically within each state (the per-state
# time index built in the next cell relies on this ordering), and drop rows
# with no risk_score target to fit against.
# NOTE(review): the raw data contains a numeric state value ("100000") that
# does not look like a state name; it flows through to the forecast. Confirm
# whether it should be filtered out upstream.
features = (
    features
    .assign(date=lambda frame: pd.to_datetime(frame['date']))
    .sort_values(['state', 'date'])
    .dropna(subset=['risk_score'])
)


In [4]:
# Number each row 0, 1, 2, ... within its state. Because rows are sorted by
# (state, date) beforehand, this is the row's position in that state's dated
# sequence. It counts observations, not days, so gaps between dates are ignored.
features = features.assign(
    time_index=features.groupby('state').cumcount()
)


In [5]:
# Fit one simple linear trend (risk_score vs. time_index) per state and
# extrapolate it one step past that state's last observation.
# Results: `models` maps state -> fitted LinearRegression; `predictions` is a
# list of {'state', 'predicted_risk_score'} records.
MIN_OBSERVATIONS = 6  # states with fewer rows are too short to fit a trend

models = {}
predictions = []

for state, state_df in features.groupby('state'):
    if len(state_df) < MIN_OBSERVATIONS:
        continue  # skip very small states

    X = state_df[['time_index']]
    y = state_df['risk_score']

    model = LinearRegression()
    model.fit(X, y)

    # Predict the next time step (the next observation index, not necessarily
    # the next calendar day). The query is built as a DataFrame with the same
    # column name used in fit(); passing a bare list makes scikit-learn warn
    # that "X does not have valid feature names".
    next_time = state_df['time_index'].max() + 1
    X_next = pd.DataFrame({'time_index': [next_time]})
    future_risk = model.predict(X_next)[0]

    # NOTE(review): this is an unbounded linear extrapolation, so it can fall
    # outside the range of observed risk_score values. Confirm whether
    # risk_score is meant to be bounded (e.g. clipped to [0, 1]).
    predictions.append({
        'state': state,
        'predicted_risk_score': future_risk
    })

    models[state] = model




In [6]:
# Collect the per-state forecasts into one table, ranked from highest to lowest
# predicted risk. The column names are passed explicitly so the frame (and the
# sort) stays valid even if no state had enough rows to be modelled. In that
# case pd.DataFrame([]) would have no 'predicted_risk_score' column and
# sort_values would raise a KeyError.
forecast_df = pd.DataFrame(
    predictions,
    columns=['state', 'predicted_risk_score']
)

forecast_df = forecast_df.sort_values(
    'predicted_risk_score',
    ascending=False
)

forecast_df.head()


Unnamed: 0,state,predicted_risk_score
42,Uttar Pradesh,1.073926
28,Manipur,0.84063
29,Meghalaya,0.772353
6,Bihar,0.754004
5,Assam,0.524385


In [7]:
# Persist the ranked forecast without the DataFrame index. Create the results
# directory first: pandas refuses to write into a directory that does not
# exist, which would break a fresh Restart & Run All.
forecast_path = Path("../results/state_risk_forecast.csv")
forecast_path.parent.mkdir(parents=True, exist_ok=True)

forecast_df.to_csv(
    forecast_path,
    index=False
)

print("✅ FORECAST SAVED")


✅ FORECAST SAVED


In [8]:
# The ten highest-risk states (forecast_df is already sorted descending).
forecast_df.iloc[:10]


Unnamed: 0,state,predicted_risk_score
42,Uttar Pradesh,1.073926
28,Manipur,0.84063
29,Meghalaya,0.772353
6,Bihar,0.754004
5,Assam,0.524385
24,Ladakh,0.493978
9,Dadra & Nagar Haveli,0.369231
14,Delhi,0.175447
45,West Bengal,0.174216
0,100000,0.062963
