In [28]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import GridSearchCV, cross_validate, train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, r2_score, mean_absolute_error
from sklearn.preprocessing import OneHotEncoder, LabelEncoder, MinMaxScaler, StandardScaler
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
from sklearn.impute import SimpleImputer

In [2]:
def data_loader(path: str) -> pd.DataFrame:
  data = pd.read_csv(path)
  print(f"{path} Loaded.")

  return data

data = data_loader('/content/crop_yield.csv')

/content/crop_yield.csv Loaded.


In [3]:
data

Unnamed: 0,Crop,Crop_Year,Season,State,Area,Production,Annual_Rainfall,Fertilizer,Pesticide,Yield
0,Arecanut,1997,Whole Year,Assam,73814.0,56708,2051.4,7024878.38,22882.34,0.796087
1,Arhar/Tur,1997,Kharif,Assam,6637.0,4685,2051.4,631643.29,2057.47,0.710435
2,Castor seed,1997,Kharif,Assam,796.0,22,2051.4,75755.32,246.76,0.238333
3,Coconut,1997,Whole Year,Assam,19656.0,126905000,2051.4,1870661.52,6093.36,5238.051739
4,Cotton(lint),1997,Kharif,Assam,1739.0,794,2051.4,165500.63,539.09,0.420909
...,...,...,...,...,...,...,...,...,...,...
19684,Small millets,1998,Kharif,Nagaland,4000.0,2000,1498.0,395200.00,1160.00,0.500000
19685,Wheat,1998,Rabi,Nagaland,1000.0,3000,1498.0,98800.00,290.00,3.000000
19686,Maize,1997,Kharif,Jammu and Kashmir,310883.0,440900,1356.2,29586735.11,96373.73,1.285000
19687,Rice,1997,Kharif,Jammu and Kashmir,275746.0,5488,1356.2,26242746.82,85481.26,0.016667


In [4]:
data.columns = data.columns.str.strip().str.lower()

In [5]:
data

Unnamed: 0,crop,crop_year,season,state,area,production,annual_rainfall,fertilizer,pesticide,yield
0,Arecanut,1997,Whole Year,Assam,73814.0,56708,2051.4,7024878.38,22882.34,0.796087
1,Arhar/Tur,1997,Kharif,Assam,6637.0,4685,2051.4,631643.29,2057.47,0.710435
2,Castor seed,1997,Kharif,Assam,796.0,22,2051.4,75755.32,246.76,0.238333
3,Coconut,1997,Whole Year,Assam,19656.0,126905000,2051.4,1870661.52,6093.36,5238.051739
4,Cotton(lint),1997,Kharif,Assam,1739.0,794,2051.4,165500.63,539.09,0.420909
...,...,...,...,...,...,...,...,...,...,...
19684,Small millets,1998,Kharif,Nagaland,4000.0,2000,1498.0,395200.00,1160.00,0.500000
19685,Wheat,1998,Rabi,Nagaland,1000.0,3000,1498.0,98800.00,290.00,3.000000
19686,Maize,1997,Kharif,Jammu and Kashmir,310883.0,440900,1356.2,29586735.11,96373.73,1.285000
19687,Rice,1997,Kharif,Jammu and Kashmir,275746.0,5488,1356.2,26242746.82,85481.26,0.016667


In [6]:
data.isna().any()
data.isnull().sum() / len(data) * 100
data.duplicated().sum()
data.describe().T.round(2)

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
crop_year,19689.0,2009.13,6.5,1997.0,2004.0,2010.0,2015.0,2020.0
area,19689.0,179926.57,732828.7,0.5,1390.0,9317.0,75112.0,50808100.0
production,19689.0,16435941.27,263056800.0,0.0,1393.0,13804.0,122718.0,6326000000.0
annual_rainfall,19689.0,1437.76,816.91,301.3,940.7,1247.6,1643.7,6552.7
fertilizer,19689.0,24103312.45,94946000.0,54.17,188014.62,1234957.44,10003847.2,4835407000.0
pesticide,19689.0,48848.35,213287.4,0.09,356.7,2421.9,20041.7,15750510.0
yield,19689.0,79.95,878.31,0.0,0.6,1.03,2.39,21105.0


In [7]:
data

Unnamed: 0,crop,crop_year,season,state,area,production,annual_rainfall,fertilizer,pesticide,yield
0,Arecanut,1997,Whole Year,Assam,73814.0,56708,2051.4,7024878.38,22882.34,0.796087
1,Arhar/Tur,1997,Kharif,Assam,6637.0,4685,2051.4,631643.29,2057.47,0.710435
2,Castor seed,1997,Kharif,Assam,796.0,22,2051.4,75755.32,246.76,0.238333
3,Coconut,1997,Whole Year,Assam,19656.0,126905000,2051.4,1870661.52,6093.36,5238.051739
4,Cotton(lint),1997,Kharif,Assam,1739.0,794,2051.4,165500.63,539.09,0.420909
...,...,...,...,...,...,...,...,...,...,...
19684,Small millets,1998,Kharif,Nagaland,4000.0,2000,1498.0,395200.00,1160.00,0.500000
19685,Wheat,1998,Rabi,Nagaland,1000.0,3000,1498.0,98800.00,290.00,3.000000
19686,Maize,1997,Kharif,Jammu and Kashmir,310883.0,440900,1356.2,29586735.11,96373.73,1.285000
19687,Rice,1997,Kharif,Jammu and Kashmir,275746.0,5488,1356.2,26242746.82,85481.26,0.016667


In [8]:
def feature(data: str) -> pd.DataFrame:
  data['crop_age'] = 2025 - data['crop_year']
  data['area_log'] = np.log(data['area'])
  data['fertilizer_pre_pesticide'] = data['pesticide'] /  data['fertilizer']
feature(data)

In [9]:
data

Unnamed: 0,crop,crop_year,season,state,area,production,annual_rainfall,fertilizer,pesticide,yield,crop_age,area_log,fertilizer_pre_pesticide
0,Arecanut,1997,Whole Year,Assam,73814.0,56708,2051.4,7024878.38,22882.34,0.796087,28,11.209304,0.003257
1,Arhar/Tur,1997,Kharif,Assam,6637.0,4685,2051.4,631643.29,2057.47,0.710435,28,8.800415,0.003257
2,Castor seed,1997,Kharif,Assam,796.0,22,2051.4,75755.32,246.76,0.238333,28,6.679599,0.003257
3,Coconut,1997,Whole Year,Assam,19656.0,126905000,2051.4,1870661.52,6093.36,5238.051739,28,9.886138,0.003257
4,Cotton(lint),1997,Kharif,Assam,1739.0,794,2051.4,165500.63,539.09,0.420909,28,7.461066,0.003257
...,...,...,...,...,...,...,...,...,...,...,...,...,...
19684,Small millets,1998,Kharif,Nagaland,4000.0,2000,1498.0,395200.00,1160.00,0.500000,27,8.294050,0.002935
19685,Wheat,1998,Rabi,Nagaland,1000.0,3000,1498.0,98800.00,290.00,3.000000,27,6.907755,0.002935
19686,Maize,1997,Kharif,Jammu and Kashmir,310883.0,440900,1356.2,29586735.11,96373.73,1.285000,28,12.647172,0.003257
19687,Rice,1997,Kharif,Jammu and Kashmir,275746.0,5488,1356.2,26242746.82,85481.26,0.016667,28,12.527235,0.003257


In [10]:
data['crop'].unique()

array(['Arecanut', 'Arhar/Tur', 'Castor seed', 'Coconut ', 'Cotton(lint)',
       'Dry chillies', 'Gram', 'Jute', 'Linseed', 'Maize', 'Mesta',
       'Niger seed', 'Onion', 'Other  Rabi pulses', 'Potato',
       'Rapeseed &Mustard', 'Rice', 'Sesamum', 'Small millets',
       'Sugarcane', 'Sweet potato', 'Tapioca', 'Tobacco', 'Turmeric',
       'Wheat', 'Bajra', 'Black pepper', 'Cardamom', 'Coriander',
       'Garlic', 'Ginger', 'Groundnut', 'Horse-gram', 'Jowar', 'Ragi',
       'Cashewnut', 'Banana', 'Soyabean', 'Barley', 'Khesari', 'Masoor',
       'Moong(Green Gram)', 'Other Kharif pulses', 'Safflower',
       'Sannhamp', 'Sunflower', 'Urad', 'Peas & beans (Pulses)',
       'other oilseeds', 'Other Cereals', 'Cowpea(Lobia)',
       'Oilseeds total', 'Guar seed', 'Other Summer Pulses', 'Moth'],
      dtype=object)

In [11]:
data

Unnamed: 0,crop,crop_year,season,state,area,production,annual_rainfall,fertilizer,pesticide,yield,crop_age,area_log,fertilizer_pre_pesticide
0,Arecanut,1997,Whole Year,Assam,73814.0,56708,2051.4,7024878.38,22882.34,0.796087,28,11.209304,0.003257
1,Arhar/Tur,1997,Kharif,Assam,6637.0,4685,2051.4,631643.29,2057.47,0.710435,28,8.800415,0.003257
2,Castor seed,1997,Kharif,Assam,796.0,22,2051.4,75755.32,246.76,0.238333,28,6.679599,0.003257
3,Coconut,1997,Whole Year,Assam,19656.0,126905000,2051.4,1870661.52,6093.36,5238.051739,28,9.886138,0.003257
4,Cotton(lint),1997,Kharif,Assam,1739.0,794,2051.4,165500.63,539.09,0.420909,28,7.461066,0.003257
...,...,...,...,...,...,...,...,...,...,...,...,...,...
19684,Small millets,1998,Kharif,Nagaland,4000.0,2000,1498.0,395200.00,1160.00,0.500000,27,8.294050,0.002935
19685,Wheat,1998,Rabi,Nagaland,1000.0,3000,1498.0,98800.00,290.00,3.000000,27,6.907755,0.002935
19686,Maize,1997,Kharif,Jammu and Kashmir,310883.0,440900,1356.2,29586735.11,96373.73,1.285000,28,12.647172,0.003257
19687,Rice,1997,Kharif,Jammu and Kashmir,275746.0,5488,1356.2,26242746.82,85481.26,0.016667,28,12.527235,0.003257


In [19]:
data.groupby('state').agg(
    avg_yield = ('yield', 'median')
).reset_index().sort_values(
    by='avg_yield', ascending=False
).reset_index(
    drop=True
).unstack(
    level=2
)

Unnamed: 0,Unnamed: 1,0
state,0,Delhi
state,1,Goa
state,2,Puducherry
state,3,Kerala
state,4,Arunachal Pradesh
state,5,Meghalaya
state,6,Tamil Nadu
state,7,Mizoram
state,8,Telangana
state,9,Gujarat


In [20]:
data

Unnamed: 0,crop,crop_year,season,state,area,production,annual_rainfall,fertilizer,pesticide,yield,crop_age,area_log,fertilizer_pre_pesticide
0,Arecanut,1997,Whole Year,Assam,73814.0,56708,2051.4,7024878.38,22882.34,0.796087,28,11.209304,0.003257
1,Arhar/Tur,1997,Kharif,Assam,6637.0,4685,2051.4,631643.29,2057.47,0.710435,28,8.800415,0.003257
2,Castor seed,1997,Kharif,Assam,796.0,22,2051.4,75755.32,246.76,0.238333,28,6.679599,0.003257
3,Coconut,1997,Whole Year,Assam,19656.0,126905000,2051.4,1870661.52,6093.36,5238.051739,28,9.886138,0.003257
4,Cotton(lint),1997,Kharif,Assam,1739.0,794,2051.4,165500.63,539.09,0.420909,28,7.461066,0.003257
...,...,...,...,...,...,...,...,...,...,...,...,...,...
19684,Small millets,1998,Kharif,Nagaland,4000.0,2000,1498.0,395200.00,1160.00,0.500000,27,8.294050,0.002935
19685,Wheat,1998,Rabi,Nagaland,1000.0,3000,1498.0,98800.00,290.00,3.000000,27,6.907755,0.002935
19686,Maize,1997,Kharif,Jammu and Kashmir,310883.0,440900,1356.2,29586735.11,96373.73,1.285000,28,12.647172,0.003257
19687,Rice,1997,Kharif,Jammu and Kashmir,275746.0,5488,1356.2,26242746.82,85481.26,0.016667,28,12.527235,0.003257


In [22]:
data.columns


to_drop = ['crop', 'season', 'state']

data.drop(
    columns = to_drop, inplace=True
)

In [23]:
data

Unnamed: 0,crop_year,area,production,annual_rainfall,fertilizer,pesticide,yield,crop_age,area_log,fertilizer_pre_pesticide
0,1997,73814.0,56708,2051.4,7024878.38,22882.34,0.796087,28,11.209304,0.003257
1,1997,6637.0,4685,2051.4,631643.29,2057.47,0.710435,28,8.800415,0.003257
2,1997,796.0,22,2051.4,75755.32,246.76,0.238333,28,6.679599,0.003257
3,1997,19656.0,126905000,2051.4,1870661.52,6093.36,5238.051739,28,9.886138,0.003257
4,1997,1739.0,794,2051.4,165500.63,539.09,0.420909,28,7.461066,0.003257
...,...,...,...,...,...,...,...,...,...,...
19684,1998,4000.0,2000,1498.0,395200.00,1160.00,0.500000,27,8.294050,0.002935
19685,1998,1000.0,3000,1498.0,98800.00,290.00,3.000000,27,6.907755,0.002935
19686,1997,310883.0,440900,1356.2,29586735.11,96373.73,1.285000,28,12.647172,0.003257
19687,1997,275746.0,5488,1356.2,26242746.82,85481.26,0.016667,28,12.527235,0.003257


In [24]:
X = data.drop('yield', axis=1)
y = data['yield']

In [26]:
scaler = StandardScaler()

num_cols = X.select_dtypes(
    include=['int64', 'float64']
).columns.tolist()

X[num_cols] = scaler.fit_transform(X[num_cols])

In [27]:
X

Unnamed: 0,crop_year,area,production,annual_rainfall,fertilizer,pesticide,crop_age,area_log,fertilizer_pre_pesticide
0,-1.866375,-0.144802,-0.062267,0.751197,-0.179880,-0.121745,1.866375,0.721464,2.195722
1,-1.866375,-0.236473,-0.062464,0.751197,-0.247217,-0.219385,1.866375,-0.085373,2.195722
2,-1.866375,-0.244443,-0.062482,0.751197,-0.253072,-0.227875,1.866375,-0.795722,2.195722
3,-1.866375,-0.218707,0.419954,0.751197,-0.234167,-0.200462,1.866375,0.278281,2.195722
4,-1.866375,-0.243157,-0.062479,0.751197,-0.252127,-0.226504,1.866375,-0.533977,2.195722
...,...,...,...,...,...,...,...,...,...
19684,-1.712480,-0.240071,-0.062475,0.073749,-0.249707,-0.223593,1.712480,-0.254976,1.611729
19685,-1.712480,-0.244165,-0.062471,0.073749,-0.252829,-0.227672,1.712480,-0.719303,1.611729
19686,-1.866375,0.178704,-0.060806,-0.099836,0.057755,0.222829,1.866375,1.203065,2.195722
19687,-1.866375,0.130756,-0.062461,-0.099836,0.022534,0.171758,1.866375,1.162894,2.195722


In [30]:
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2, random_state=42
)

In [37]:
tree_clf = DecisionTreeRegressor(random_state=42,
                                 max_depth=5, min_samples_split=2)
tree_clf.fit(
    X_train, y_train
)

In [35]:
param_grid = {
    'max_depth' : [1, 2, 5, 10],
    'min_samples_split' : [2, 5]
}

clf = GridSearchCV(
    tree_clf,
    param_grid=param_grid,
    cv=10
)

clf.fit(X, y)

In [40]:
print(f"Best params {clf.best_params_}")
print(f"Best score {clf.best_score_}")

Best params {'max_depth': 5, 'min_samples_split': 2}
Best score 0.9290307233692523


In [39]:
tree_clf.score(X, y)

0.9575166129403582

In [43]:
cv = cross_validate(tree_clf, X, y, cv=10)

print(cv)

{'fit_time': array([0.1021769 , 0.10637522, 0.24153423, 0.21555638, 0.10033107,
       0.10744286, 0.09907579, 0.11952424, 0.10765982, 0.13783813]), 'score_time': array([0.00299883, 0.00337958, 0.01608777, 0.00345802, 0.00274539,
       0.00283408, 0.00293231, 0.00326633, 0.00311351, 0.00372052]), 'test_score': array([0.81760909, 0.89371262, 0.96659155, 0.78098674, 0.99106647,
       0.99765205, 0.69871629, 0.97064328, 0.9931132 , 0.82838065])}
