In [6]:
import torch.nn as nn
import torch.nn.functional as F

class CropYieldMLP(nn.Module):
    """Feed-forward regressor with three Linear -> BatchNorm -> ReLU stages and a scalar head.

    Hidden widths are 128 -> 64 -> 32, followed by a single output unit. Dropout with
    probability ``dropout_rate`` follows the first two stages only.

    The submodule attribute names (fc1, bn1, dropout1, ..., output) and their creation
    order are part of the saved-model / state_dict layout. Keep them stable so existing
    checkpoints still load.

    Args:
        input_dim: number of input features per sample.
        dropout_rate: drop probability used by ``dropout1`` and ``dropout2``.
    """

    def __init__(self, input_dim, dropout_rate=0.2):
        super().__init__()

        # Stage 1: input_dim -> 128
        self.fc1 = nn.Linear(input_dim, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.dropout1 = nn.Dropout(p=dropout_rate)

        # Stage 2: 128 -> 64
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.dropout2 = nn.Dropout(p=dropout_rate)

        # Stage 3: 64 -> 32 (no dropout after this stage)
        self.fc3 = nn.Linear(64, 32)
        self.bn3 = nn.BatchNorm1d(32)

        # Regression head: 32 -> 1
        self.output = nn.Linear(32, 1)

    def forward(self, x):
        """Run the three hidden stages, then the linear head.

        Args:
            x: input batch whose last dimension is ``input_dim``.

        Returns:
            Tensor with one output column per sample.
        """
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, None),
        )
        hidden = x
        for linear, norm, drop in stages:
            hidden = F.relu(norm(linear(hidden)))
            if drop is not None:
                hidden = drop(hidden)
        return self.output(hidden)


In [7]:
import torch
import joblib
import json
import numpy as np
import pandas as pd

class CropYieldPredictor:
    """Interactive command-line predictor for crop production.

    On construction it loads the fitted Season/State encoders, the District+Crop
    target-mean table, the trained model, and the raw dataset. The dataset is used only
    to build the selection menus.

    ``predict()`` prompts for one sample, prints the predicted production and returns it.
    ``predict_from_values()`` runs the same computation on already-chosen values without
    prompting.
    """

    def __init__(self, model_path, season_encoder_path, state_encoder_path, mean_encoding_path, dataset_path):
        """Load every artifact needed for inference.

        Args:
            model_path: checkpoint holding a whole model object (the loaded value is
                used directly as a module), not just a state_dict.
            season_encoder_path: joblib file with the fitted Season encoder.
            state_encoder_path: joblib file with the fitted State encoder.
            mean_encoding_path: JSON object mapping ``"<District>_<Crop>"`` keys to the
                encoded value used as a model feature.
            dataset_path: CSV with at least the Season, State_Name, District_Name and
                Crop columns.
        """
        # joblib.load and torch.load(weights_only=False) both unpickle, and unpickling
        # can execute arbitrary code: only load trusted artifacts.
        self.season_encoder = joblib.load(season_encoder_path)
        self.state_encoder = joblib.load(state_encoder_path)

        # JSON text is UTF-8 by specification; don't depend on the platform default.
        with open(mean_encoding_path, 'r', encoding='utf-8') as f:
            self.mean_encoding_dict = json.load(f)

        # The checkpoint is a pickled nn.Module, so full unpickling is required.
        # weights_only defaults to True from PyTorch 2.6 (and warns in 2.4/2.5),
        # which would reject it. Unpickling also needs the model's class to be
        # defined in the running session.
        # map_location='cpu' keeps a GPU-saved model on the same device as the CPU
        # input tensor built in predict_from_values().
        self.model = torch.load(model_path, map_location='cpu', weights_only=False)
        self.model.eval()  # switch to inference mode before predicting

        # The raw dataset is only used to build the selection menus.
        self.dataset = pd.read_csv(dataset_path)
        self.seasons = self.dataset['Season'].unique().tolist()
        self.states = self.dataset['State_Name'].unique().tolist()
        self.districts = self.dataset['District_Name'].unique().tolist()
        self.crops = self.dataset['Crop'].unique().tolist()

    def get_input(self, options, prompt):
        """Print a numbered menu of ``options`` and return the one the user picks.

        Re-prompts until the user enters an integer between 1 and ``len(options)``.
        Zero and negative numbers are rejected rather than wrapping around to the end
        of the list.

        Raises:
            ValueError: if ``options`` is empty, since there is nothing to choose from.
        """
        if not options:
            raise ValueError(f"No options available for {prompt}")
        print(f"\n{prompt}")
        for idx, option in enumerate(options, 1):
            print(f"{idx}. {option}")
        while True:
            try:
                selected_index = int(input(f"Enter the number for {prompt}: "))
            except ValueError:
                selected_index = None
            if selected_index is not None and 1 <= selected_index <= len(options):
                return options[selected_index - 1]
            print(f"Please enter a whole number between 1 and {len(options)}.")

    def get_districts_for_state(self, state_input):
        """Return the unique District_Name values recorded for ``state_input``, in dataset order."""
        in_state = self.dataset['State_Name'] == state_input
        return self.dataset.loc[in_state, 'District_Name'].unique().tolist()

    def _read_area(self):
        """Prompt until the user enters a finite, non-negative area in hectares."""
        while True:
            try:
                area = float(input("Enter Area (in hectares): "))
            except ValueError:
                area = None
            # Negative values make log1p meaningless (NaN below -1); inf/nan are rejected too.
            if area is not None and np.isfinite(area) and area >= 0:
                return area
            print("Please enter a non-negative number of hectares.")

    def get_user_inputs(self):
        """Prompt for Season, State, District (restricted to that State), Crop and Area.

        Returns:
            tuple: (season, state, district, crop, area_in_hectares).
        """
        season_input = self.get_input(self.seasons, "Season")
        state_input = self.get_input(self.states, "State Name")

        # Only offer districts that appear under the chosen state in the dataset.
        state_districts = self.get_districts_for_state(state_input)
        district_input = self.get_input(state_districts, "District Name")

        crop_input = self.get_input(self.crops, "Crop Name")
        area_input = self._read_area()

        return season_input, state_input, district_input, crop_input, area_input

    def _build_features(self, season_input, state_input, district_input, crop_input, area_input):
        """Encode one sample into the single-row feature matrix fed to the model."""
        # Season: the encoder takes a 2-D [[value]] input and returns one encoded row.
        season_encoded = self.season_encoder.transform([[season_input]])

        # State: integer label reshaped into a single column.
        state_encoded = self.state_encoder.transform([state_input]).reshape(-1, 1)

        # District+Crop target-mean encoding.
        key = f"{district_input}_{crop_input}"
        # NOTE(review): pairs missing from the table fall back to 0. This can happen
        # because the crop menu is not filtered by district. Confirm that this matches
        # how unseen pairs were encoded during training.
        district_crop_encoded = np.array([[self.mean_encoding_dict.get(key, 0)]])

        # Area is log1p-transformed and shaped (1, 1) so it can be stacked as a column.
        area_log = np.log1p(area_input).reshape(1, 1)

        # Column order (season, state, district+crop, area) must stay in sync with training.
        return np.hstack([season_encoded, state_encoded, district_crop_encoded, area_log])

    def predict_from_values(self, season_input, state_input, district_input, crop_input, area_input):
        """Predict production for one already-chosen sample, without prompting.

        Args:
            season_input: a Season value known to the season encoder.
            state_input: a State_Name value known to the state encoder.
            district_input: District_Name; combined with ``crop_input`` as the lookup key.
            crop_input: Crop name.
            area_input: area in hectares (must be > -1 for log1p to be finite).

        Returns:
            float: the model output mapped back through ``expm1``.
        """
        X_input = self._build_features(season_input, state_input, district_input, crop_input, area_input)
        X_tensor = torch.tensor(X_input, dtype=torch.float32)

        with torch.no_grad():
            log_pred = self.model(X_tensor).item()

        # The model works on the log1p scale; invert it to get production.
        return float(np.expm1(log_pred))

    def predict(self):
        """Prompt for one sample, print the predicted production, and return it.

        Returns:
            float: the predicted production (also printed).
        """
        season_input, state_input, district_input, crop_input, area_input = self.get_user_inputs()
        predicted_production = self.predict_from_values(
            season_input, state_input, district_input, crop_input, area_input
        )
        # NOTE(review): the unit label is not verifiable from this code; confirm it
        # matches the units of the training target.
        print(f"\n🌾 Predicted Production: {predicted_production:.2f} quintals")
        return predicted_production

# Example usage: build the predictor from the saved training artifacts, then run one
# interactive prediction. Relative paths resolve against the current working directory.
if __name__ == "__main__":
    artifact_paths = {
        "model_path": 'my_model_complete.pth',
        "season_encoder_path": 'season_onehot_encoder.pkl',
        "state_encoder_path": 'state_label_encoder.pkl',
        "mean_encoding_path": 'district_crop_mean_prod.json',
        "dataset_path": 'crop_production.csv',  # source of the menu options
    }
    predictor = CropYieldPredictor(**artifact_paths)

    # Prompt the user and print the prediction.
    predictor.predict()


  self.model = torch.load(model_path)



Season
1. Kharif     
2. Whole Year 
3. Autumn     
4. Rabi       
5. Summer     
6. Winter     


Enter the number for Season:  2



State Name
1. Andaman and Nicobar Islands
2. Andhra Pradesh
3. Arunachal Pradesh
4. Assam
5. Bihar
6. Chandigarh
7. Chhattisgarh
8. Dadra and Nagar Haveli
9. Goa
10. Gujarat
11. Haryana
12. Himachal Pradesh
13. Jammu and Kashmir 
14. Jharkhand
15. Karnataka
16. Kerala
17. Madhya Pradesh
18. Maharashtra
19. Manipur
20. Meghalaya
21. Mizoram
22. Nagaland
23. Odisha
24. Puducherry
25. Punjab
26. Rajasthan
27. Sikkim
28. Tamil Nadu
29. Telangana 
30. Tripura
31. Uttar Pradesh
32. Uttarakhand
33. West Bengal


Enter the number for State Name:  1



District Name
1. NICOBARS
2. NORTH AND MIDDLE ANDAMAN
3. SOUTH ANDAMANS


Enter the number for District Name:  1



Crop Name
1. Arecanut
2. Other Kharif pulses
3. Rice
4. Banana
5. Cashewnut
6. Coconut 
7. Dry ginger
8. Sugarcane
9. Sweet potato
10. Tapioca
11. Black pepper
12. Dry chillies
13. other oilseeds
14. Turmeric
15. Maize
16. Moong(Green Gram)
17. Urad
18. Arhar/Tur
19. Groundnut
20. Sunflower
21. Bajra
22. Castor seed
23. Cotton(lint)
24. Horse-gram
25. Jowar
26. Korra
27. Ragi
28. Tobacco
29. Gram
30. Wheat
31. Masoor
32. Sesamum
33. Linseed
34. Safflower
35. Onion
36. other misc. pulses
37. Samai
38. Small millets
39. Coriander
40. Potato
41. Other  Rabi pulses
42. Soyabean
43. Beans & Mutter(Vegetable)
44. Bhindi
45. Brinjal
46. Citrus Fruit
47. Cucumber
48. Grapes
49. Mango
50. Orange
51. other fibres
52. Other Fresh Fruits
53. Other Vegetables
54. Papaya
55. Pome Fruit
56. Tomato
57. Rapeseed &Mustard
58. Mesta
59. Cowpea(Lobia)
60. Lemon
61. Pome Granet
62. Sapota
63. Cabbage
64. Peas  (vegetable)
65. Niger seed
66. Bottle Gourd
67. Sannhamp
68. Varagu
69. Garlic
70. Ginger
71. Oi

Enter the number for Crop Name:  1
Enter Area (in hectares):  1254



🌾 Predicted Production: 1955.32 quintals


