In [None]:
!pip install -q einops

## IMPORT LIBS

In [None]:
import os
import numpy as np
import random
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.amp import autocast
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torchvision.models as models
from tqdm import tqdm
import gc
import pandas as pd
from sklearn import preprocessing
from sklearn.preprocessing import OneHotEncoder
from einops import rearrange
import h5py
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr
from matplotlib import pyplot as plt
import seaborn as sns
from typing import List, Tuple
from datetime import datetime, timedelta
from pprint import pprint
import json
import math

## REPRODUCIBILITY

In [None]:
# cuBLAS workspace setting that PyTorch's deterministic mode asks for on CUDA;
# it is placed in the environment before any other torch call in this cell.
os.environ.update({'CUBLAS_WORKSPACE_CONFIG': ':4096:8'})

# One seed shared by every random number generator seeded below.
global_seed = 0

# The three generators are independent of each other, so the order is irrelevant.
torch.manual_seed(global_seed)
np.random.seed(global_seed)
random.seed(global_seed)

# Ask PyTorch to use deterministic kernels only.
torch.use_deterministic_algorithms(True)

In [None]:
def make_config(years: List[int], state: str, state_ansi: str, fips: List[str], crop_type: str, grow_season: List[int]):
    """Build the data-file configuration for a crop whose season stays inside one calendar year.

    Args:
        years: Calendar years to include; one entry is produced per year in
            every list under ``config["data"]``.
        state: Two-letter state abbreviation (upper-cased before use in paths).
        state_ansi: State ANSI code as it appears in the file names (e.g. "01").
        fips: County FIPS codes, stored unchanged under ``"FIPS"``.
        crop_type: Crop name used in the USDA paths. ``"Soybeans"`` is special:
            its folder is ``Soybeans`` but its file name uses ``Soybean``.
        grow_season: ``[first_month, last_month]`` (1-12, inclusive). A season
            that wraps around the year end (first > last) yields empty HRRR and
            Sentinel lists, exactly as before.

    Returns:
        A JSON-serialisable dict with keys ``FIPS``, ``years``, ``state``,
        ``crop_type`` and ``data`` (``HRRR.short_term``, ``USDA``, ``sentinel``).
    """
    state_code = state.upper()
    first_month, last_month = grow_season[0], grow_season[1]

    # Calendar quarters: (first month, last month, start "MM-DD", end "MM-DD").
    quarters = [
        (1, 3, "01-01", "03-31"),
        (4, 6, "04-01", "06-30"),
        (7, 9, "07-01", "09-30"),
        (10, 12, "10-01", "12-31"),
    ]

    config = {
        "FIPS": fips,
        "years": years,
        "state": state_code,
        "crop_type": crop_type,
        "data": {
            "HRRR": {
                "short_term": []
            },
            "USDA": [],
            "sentinel": []
        }
    }

    # The USDA file name drops the plural "s" for soybeans only.
    usda_name = "Soybean" if crop_type == "Soybeans" else crop_type

    for year in years:
        # HRRR data: one CSV per month of the growing season.
        hrrr_files = [
            f"HRRR/{year}/{state_code}/HRRR_{state_ansi}_{state_code}_{year}-{month:02d}.csv"
            for month in range(first_month, last_month + 1)
        ]
        config["data"]["HRRR"]["short_term"].append(hrrr_files)

        # USDA data: one CSV per year.
        config["data"]["USDA"].append(f"USDA/{crop_type}/{year}/USDA_{usda_name}_County_{year}.csv")

        # Sentinel data: keep every quarter that overlaps the season. A proper
        # interval-overlap test is used so that a season lying strictly inside
        # a quarter (e.g. May only) still selects that quarter's file.
        sentinel_files = [
            f"AG/{state_code}/{year}/Agriculture_{state_ansi}_{state_code}_{year}-{q_start}_{year}-{q_end}.h5"
            for q_first, q_last, q_start, q_end in quarters
            if q_first <= last_month and q_last >= first_month
        ]
        config["data"]["sentinel"].append(sentinel_files)

    return config

# Cotton in Alabama: shared settings for the train and test configurations.
state = "AL"
state_ansi = "01"
fips = [
    '01003', '01015', '01019', '01031', '01039', '01045', '01047', '01053', '01061',
    '01067', '01069', '01077', '01079', '01083', '01089', '01097', '01099', '01117',
]
crop_type = "Cotton"
grow_season = [4, 9]  # April to September

# Train split: 2018-2021.
years = [2018, 2019, 2020, 2021]
train_config = make_config(years, state, state_ansi, fips, crop_type, grow_season)
with open('train_config.json', 'w') as file:
    file.write(json.dumps(train_config))
print("Train config")
pprint(train_config)

# Test split: 2022 only.
years = [2022]
test_config = make_config(years, state, state_ansi, fips, crop_type, grow_season)
with open('test_config.json', 'w') as file:
    file.write(json.dumps(test_config))
print("Test config")
pprint(test_config)

## DATA LOADER

In [None]:
class Sentinel2Imagery(Dataset):
    """Satellite image stacks stored in HDF5 files, one sample per (county, year).

    The JSON config supplies the FIPS codes (``"FIPS"``), the years
    (``"years"``) and, per year, the list of HDF5 files to read
    (``data.sentinel``). Samples are ordered county-major: index
    ``i`` maps to county ``i // len(years)`` and year ``i % len(years)``.
    """

    def __init__(self, base_dir, config_file, transform=None):
        """
        Args:
            base_dir: Directory that the file paths in the config are relative to.
            config_file: Path of the JSON configuration file.
            transform: Optional callable applied to the images after they have
                been flattened to ``[(t g), c, h, w]``.
        """
        self.transform = transform
        self.base_dir = base_dir

        with open(config_file, 'r') as f:
            obj = json.load(f)

        self.fips_codes = obj["FIPS"]
        self.years = obj["years"]
        self.file_paths = obj["data"]["sentinel"]

    def __len__(self):
        """Number of (county, year) pairs."""
        return len(self.fips_codes) * len(self.years)

    def __getitem__(self, index):
        """Return ``(x, fips_code, year)`` with ``x`` a float32 tensor ``[t, g, c, h, w]``."""
        fips_index, year_index = divmod(index, len(self.years))

        fips_code = self.fips_codes[fips_index]
        year = self.years[year_index]
        file_paths = self.file_paths[year_index]

        temporal_list = []
        for file_path in file_paths:
            # The context manager closes the file; no explicit close() is needed.
            with h5py.File(os.path.join(self.base_dir, file_path), 'r') as hf:
                groups = hf[fips_code]
                # Every child group of the county contributes one entry to the
                # temporal axis (presumably one acquisition date each — confirm
                # against the dataset layout).
                for d in groups.keys():
                    grids = torch.from_numpy(np.asarray(groups[d]["data"]))
                    temporal_list.append(grids)

        x = torch.stack(temporal_list).to(torch.float32)
        # Stored layout is channels-last; move channels in front of height/width.
        x = rearrange(x, 't g h w c -> t g c h w')
        if self.transform:
            # Transforms see a plain image batch, so fold time and grid together
            # and unfold again afterwards.
            t, g, _, _, _ = x.shape
            x = rearrange(x, 't g c h w -> (t g) c h w')
            x = self.transform(x)
            x = rearrange(x, '(t g) c h w -> t g c h w', t=t, g=g)
        return x, fips_code, year

class HRRRComputedDataset(Dataset):
    """Daily HRRR weather variables from monthly CSV files, one sample per (county, year).

    The JSON config supplies the FIPS codes (``"FIPS"``), the years
    (``"years"``) and, per year, the list of monthly CSV files
    (``data.HRRR.short_term``). Samples are ordered county-major: index
    ``i`` maps to county ``i // len(years)`` and year ``i % len(years)``.
    """

    def __init__(self, base_dir, config_file, column_names=None):
        """
        Args:
            base_dir: Directory that the file paths in the config are relative to.
            config_file: Path of the JSON configuration file.
            column_names: Weather columns to extract; defaults to the nine
                variables listed below.
        """
        self.base_dir = base_dir
        # Only days 1..28 are kept, so every month yields the same number of days.
        self.day_range = [i + 1 for i in range(28)]

        with open(config_file, 'r') as f:
            obj = json.load(f)

        self.fips_codes = obj["FIPS"]
        self.years = obj["years"]
        self.short_term_file_path = obj["data"]["HRRR"]["short_term"]

        if column_names:
            self.column_names = column_names
        else:
            self.column_names = [
                'Avg Temperature (K)', 'Max Temperature (K)', 'Min Temperature (K)',
                'Precipitation (kg m**-2)', 'Relative Humidity (%)', 'Wind Gust (m s**-1)',
                'Wind Speed (m s**-1)', 'Downward Shortwave Radiation Flux (W m**-2)',
                'Vapor Pressure Deficit (kPa)'
            ]

    def __len__(self):
        """Number of (county, year) pairs."""
        return len(self.fips_codes) * len(self.years)

    def __getitem__(self, index):
        """Return ``(x_short, fips_code, year)`` with ``x_short`` float32 ``[m, d, g, p]``."""
        fips_index, year_index = divmod(index, len(self.years))

        fips_code = self.fips_codes[fips_index]
        year = self.years[year_index]
        short_term_file_paths = self.short_term_file_path[year_index]
        x_short = self.get_short_term_val(fips_code, short_term_file_paths)
        x_short = x_short.to(torch.float32)
        return x_short, fips_code, year

    def get_short_term_val(self, fips_code, file_paths):
        """Load the daily rows of one county and arrange them as ``[m, d, g, p]``.

        Args:
            fips_code: Five-character county FIPS code.
            file_paths: CSV paths relative to ``base_dir``, in chronological order.

        Returns:
            Tensor ``[months, days, grids, variables]`` of signed-log values.
            Months follow the order in which they first appear in
            ``file_paths``.
        """
        frames = []
        for file_path in file_paths:
            frame = pd.read_csv(os.path.join(self.base_dir, file_path))
            # Strip header whitespace BEFORE any column is looked up by name;
            # doing it afterwards would raise KeyError on padded headers.
            frame.columns = frame.columns.str.strip()
            frames.append(frame)

        df = pd.concat(frames, ignore_index=True)
        df["FIPS Code"] = df["FIPS Code"].astype(str).str.zfill(5)
        df = df[(df["FIPS Code"] == fips_code) & (df["Daily/Monthly"] == "Daily")]

        temporal_list = []
        # sort=False keeps months in file order. Sorting by month number would
        # place Jan-Jun ahead of the previous year's Sep-Dec for a season that
        # crosses the year boundary.
        for _, df_month in df.groupby('Month', sort=False):
            time_series = []
            for _, df_grid in df_month.groupby('Grid Index'):
                df_grid = df_grid.sort_values(by=['Day'], ascending=[True], na_position='first')
                df_grid = df_grid[df_grid["Day"].isin(self.day_range)]
                values = torch.from_numpy(df_grid[self.column_names].values)
                time_series.append(self.signed_log_transform(values))

            temporal_list.append(torch.stack(time_series))

        x_short = torch.stack(temporal_list)  # [m, g, d, p]
        # Same axis swap as einops 'm g d p -> m d g p'.
        return x_short.permute(0, 2, 1, 3)

    def signed_log_transform(self, data):
        """Return ``sign(x) * log10(|x| + 1e-9)`` element-wise.

        NOTE(review): for ``|x| < 1`` the logarithm is negative, so small
        positive inputs map to negative outputs; exact zeros map to zero.
        """
        epsilon = 1e-9  # small constant to avoid log(0)
        return torch.sign(data) * torch.log10(torch.abs(data) + epsilon)

class USDACropDataset(Dataset):
    """County-level USDA production and yield targets, one sample per (county, year).

    The JSON config supplies the FIPS codes (``"FIPS"``), the years
    (``"years"``) and one USDA CSV path per year (``data.USDA``). Samples are
    ordered county-major: index ``i`` maps to county ``i // len(years)`` and
    year ``i % len(years)``.
    """

    def __init__(self, base_dir, config_file, crop_type):
        """
        Args:
            base_dir: Directory that the file paths in the config are relative to.
            config_file: Path of the JSON configuration file.
            crop_type: Crop name; ``"Cotton"`` selects the bale / lb-per-acre
                columns, every other value selects the bushel columns.
        """
        self.base_dir = base_dir
        self.crop_type = crop_type

        with open(config_file, 'r') as f:
            obj = json.load(f)

        self.fips_codes = obj["FIPS"]
        self.years = obj["years"]
        self.file_paths = obj["data"]["USDA"]

        if crop_type == "Cotton":
            self.column_names = ['PRODUCTION, MEASURED IN 480 LB BALES', 'YIELD, MEASURED IN LB / ACRE']
        else:
            self.column_names = ['PRODUCTION, MEASURED IN BU', 'YIELD, MEASURED IN BU / ACRE']

    def __len__(self):
        """Number of (county, year) pairs."""
        return len(self.fips_codes) * len(self.years)

    def get_num_classes(self):
        """Number of distinct counties.

        If a ``fips_encoder`` attribute has been attached to the instance its
        ``classes_`` are counted; otherwise the distinct FIPS codes from the
        config are counted. The class never sets ``fips_encoder`` itself, so
        without the fallback this method always raised AttributeError.
        """
        encoder = getattr(self, "fips_encoder", None)
        if encoder is not None:
            return len(encoder.classes_)
        return len(set(self.fips_codes))

    def __getitem__(self, index):
        """Return ``(x, fips_code, year)`` where ``x`` holds the natural log of the selected columns."""
        fips_index, year_index = divmod(index, len(self.years))

        fips_code = self.fips_codes[fips_index]
        year = self.years[year_index]
        file_path = self.file_paths[year_index]
        df = pd.read_csv(os.path.join(self.base_dir, file_path))

        # Re-create zero-padded codes so they can be matched against the
        # two-digit state and three-digit county parts of the FIPS code.
        df['state_ansi'] = df['state_ansi'].astype(str).str.zfill(2)
        df['county_ansi'] = df['county_ansi'].astype(str).str.zfill(3)

        df = df[(df["state_ansi"] == fips_code[:2]) & (df["county_ansi"] == fips_code[-3:])]

        df = df[self.column_names]
        x = torch.from_numpy(df.values).to(torch.float32)
        # Targets are modelled in log space.
        x = torch.log(torch.flatten(x, start_dim=0))
        return x, fips_code, year

## MODEL ARCHITECTURE

In [None]:
class SatelliteEncoder(nn.Module):
    """Encode a time series of image grids into one feature vector per grid.

    Pipeline: ResNet-18 feature maps per image -> two 3D convolutions over
    (time, height, width) -> spatial mean -> sinusoidal positions + one
    self-attention layer over time -> temporal mean -> linear projection.
    """

    def __init__(self, output_dim=128, max_temporal_len=50):
        """
        Args:
            output_dim: Size of the feature vector produced per grid.
            max_temporal_len: Longest temporal sequence for which positional
                encodings are pre-computed.
        """
        super().__init__()
        # ImageNet-pretrained ResNet-18 without its pooling and classification layers.
        # `weights=` replaces the deprecated `pretrained=True` and selects the same weights.
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.features = nn.Sequential(*list(resnet.children())[:-2])

        # Temporal modeling components
        self.temporal_conv = nn.Sequential(
            nn.Conv3d(512, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.BatchNorm3d(256),
            nn.LeakyReLU(0.1),
            nn.Conv3d(256, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.1)
        )

        # Temporal attention mechanism
        self.temporal_attention = nn.MultiheadAttention(
            embed_dim=128,
            num_heads=4,
            batch_first=True
        )

        # Sinusoidal position encoding (a buffer: moves with the module, not trained)
        self.register_buffer(
            "pos_encoder",
            self._get_position_encoding(max_temporal_len, 128)
        )

        self.fc = nn.Linear(128, output_dim)

        # Layer normalization around the attention block
        self.norm1 = nn.LayerNorm(128)
        self.norm2 = nn.LayerNorm(128)

    def _get_position_encoding(self, max_len, d_model):
        """Return the sinusoidal position table of shape ``[1, max_len, d_model]``."""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Even columns hold sines, odd columns hold cosines.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[:d_model//2])

        return pe.unsqueeze(0)  # [1, max_len, d_model]

    @staticmethod
    def _group_time_per_grid(feats, batch, t, g):
        """Turn ``[batch*t*g, ch, h, w]`` (flattened in batch, t, g order) into ``[batch*g, ch, t, h, w]``.

        The time axis must be moved explicitly: viewing the flat axis directly
        as ``(batch*g, t)`` would interleave grids and time steps.
        """
        _, ch, h_out, w_out = feats.shape
        feats = feats.reshape(batch, t, g, ch, h_out, w_out)
        feats = feats.permute(0, 2, 3, 1, 4, 5)  # [batch, g, ch, t, h', w']
        return feats.reshape(batch * g, ch, t, h_out, w_out)

    def forward(self, x):
        """
        Args:
            x: Tensor ``[batch, t, g, c, h, w]``. NOTE(review): the ResNet stem
               presumably requires ``c == 3`` — confirm against the data.

        Returns:
            Tensor ``[batch, g, output_dim]``.

        Raises:
            ValueError: if ``t`` exceeds the pre-computed positional encodings.
        """
        batch, t, g, c, h, w = x.shape
        if t > self.pos_encoder.shape[1]:
            raise ValueError(
                f"Temporal length {t} exceeds max_temporal_len={self.pos_encoder.shape[1]}"
            )

        # Run every image through ResNet independently
        x = x.reshape(batch * t * g, c, h, w)
        x = self.features(x)  # [batch*t*g, 512, h', w']

        # Regroup so that each grid cell owns its own time series
        x = self._group_time_per_grid(x, batch, t, g)  # [batch*g, 512, t, h', w']

        # Apply 3D temporal convolutions
        x = self.temporal_conv(x)  # [batch*g, 128, t, h', w']

        # Average over the two spatial axes only; time is kept
        x = x.mean(dim=(-1, -2))  # [batch*g, 128, t]
        x = x.permute(0, 2, 1)  # [batch*g, t, 128]

        # Add positional encoding for the actual sequence length
        x = x + self.pos_encoder[:, :t, :]
        x = self.norm1(x)

        # Plain self-attention over time; no mask is applied
        attn_out, _ = self.temporal_attention(x, x, x, attn_mask=None)
        x = self.norm2(x + attn_out)  # Residual connection

        # Final temporal pooling and feature projection
        x = x.mean(dim=1)  # Average over temporal dimension
        x = self.fc(x)

        # Reshape back to [batch, g, output_dim]
        return x.view(batch, g, -1)


class WeatherEncoder(nn.Module):
    """Summarise each grid cell's daily weather series with a 2-layer LSTM."""

    def __init__(self, input_dim=9):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=64, num_layers=2, batch_first=True)
        self.fc = nn.Linear(in_features=64, out_features=32)

    def forward(self, x):
        """Map ``[batch, m, d, g, p]`` to ``[batch, g, 32]``."""
        batch, _, _, g, _ = x.shape
        # One sequence per (sample, grid); months and days merge into one time axis.
        sequences = rearrange(x, 'b m d g p -> (b g) (m d) p')
        hidden_states, _ = self.lstm(sequences)
        # Only the output at the final time step is projected.
        last_step = hidden_states[:, -1, :]
        projected = self.fc(last_step)
        return rearrange(projected, '(b g) c -> b g c', b=batch, g=g)

class CropYieldModel(nn.Module):
    """Fuse satellite and weather features per grid, pool to county level, predict two targets."""

    def __init__(self, num_fips):
        """
        Args:
            num_fips: Width of the FIPS vector that is concatenated to the
                pooled county features before the output head.
        """
        super().__init__()
        self.satellite_encoder = SatelliteEncoder()
        self.weather_encoder = WeatherEncoder()

        # Per-grid fusion of the 128 satellite and 32 weather features.
        fusion_stack = [
            nn.Linear(128 + 32, 64),
            nn.LayerNorm(64),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2),
            nn.Linear(64, 32),
            nn.LayerNorm(32),
            nn.LeakyReLU(0.1),
        ]
        self.fusion_layer = nn.Sequential(*fusion_stack)

        # Scores every grid; the softmax over dim=1 turns the scores into
        # weights that sum to one across grids.
        weighting_stack = [
            nn.Linear(32, 16),
            nn.LeakyReLU(0.1),
            nn.Linear(16, 1),
            nn.Softmax(dim=1),
        ]
        self.grid_weights = nn.Sequential(*weighting_stack)

        # Output head on the pooled features plus the FIPS vector.
        head_stack = [
            nn.Linear(32 + num_fips, 32),
            nn.LayerNorm(32),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
            nn.Linear(32, 16),
            nn.LeakyReLU(0.1),
            nn.Linear(16, 2),  # 2 outputs: production and yield
        ]
        self.output_layer = nn.Sequential(*head_stack)

    def forward(self, satellite, weather, fips):
        """Return a ``[batch, 2]`` prediction from satellite, weather and FIPS inputs."""
        per_grid = torch.cat(
            [self.satellite_encoder(satellite), self.weather_encoder(weather)],
            dim=-1,
        )  # [batch, g, 160]
        fused = self.fusion_layer(per_grid)  # [batch, g, 32]

        weights = self.grid_weights(fused)  # [batch, g, 1]
        pooled = (fused * weights).sum(dim=1)  # [batch, 32]

        head_input = torch.cat([pooled, fips], dim=-1)  # [batch, 32 + num_fips]
        return self.output_layer(head_input)  # [batch, 2]

## TRAIN AND TEST

In [None]:
def train_model(model, train_data, epochs=100, patience=5, save_path="best_model.pth"):
    """Train ``model`` and return it with the best (lowest training loss) weights loaded.

    Args:
        model: Network called as ``model(sentinel, hrrr, fips_onehot)``.
        train_data: Tuple ``(sentinel_loader, hrrr_loader, usda_loader)``. The
            loaders are iterated in lock-step, so they must yield the same
            samples in the same order.
        epochs: Maximum number of epochs.
        patience: Number of consecutive epochs without improvement of the
            training loss after which training stops.
        save_path: File that receives the state dict of the best epoch.

    Returns:
        The model on the training device. If at least one checkpoint was
        written during this call, its weights are restored before returning;
        previously the last epoch's weights were returned and the checkpoint
        was never used.

    Note:
        Relies on the notebook-level ``fips_encoder`` to one-hot encode FIPS codes.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = model.to(device)

    train_sentinel_loader, train_hrrr_loader, train_usda_loader = train_data
    num_batches = len(train_sentinel_loader)

    criterion = nn.HuberLoss()
    optimizer = optim.RMSprop(model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    best_train_loss = float('inf')
    counter = 0
    checkpoint_written = False

    for epoch in range(epochs):
        model.train()
        train_running_loss = 0.0
        batches = zip(train_sentinel_loader, train_hrrr_loader, train_usda_loader)
        # zip() has no length, so tqdm needs the total to draw a progress bar.
        for (sentinel, fips_code, year), (hrrr, _, _), (usda, _, _) in tqdm(
                batches, total=num_batches, desc=f"Epoch {epoch+1} Training"):
            # Release cached memory between steps.
            gc.collect()
            torch.cuda.empty_cache()
            sentinel, hrrr, usda = sentinel.to(device), hrrr.to(device), usda.to(device)
            fips_onehot = torch.tensor(
                fips_encoder.transform(np.array(fips_code).reshape(-1, 1)),
                dtype=torch.float32,
            ).to(device)
            optimizer.zero_grad()
            # device.type is "cuda" or "cpu", also for indexed devices such as "cuda:0".
            with autocast(device_type=device.type):
                output = model(sentinel, hrrr, fips_onehot)
                loss = criterion(output, usda)
            loss.backward()
            optimizer.step()
            train_running_loss += loss.item()

        train_epoch_loss = train_running_loss / num_batches

        scheduler.step()

        print(f"Epoch {epoch+1} - Train Loss: {train_epoch_loss:.4f}")

        if train_epoch_loss < best_train_loss:
            best_train_loss = train_epoch_loss
            torch.save(model.state_dict(), save_path)
            checkpoint_written = True
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping")
                break

    # Only load a checkpoint produced by this call; a stale file left at
    # save_path by an earlier run may belong to a different architecture.
    if checkpoint_written:
        model.load_state_dict(torch.load(save_path, map_location=device))
    return model

def evaluate_model(model, test_data):
    """Run the model over the test loaders and report error metrics per target.

    Args:
        model: Network called as ``model(sentinel, hrrr, fips_onehot)``.
        test_data: Tuple ``(sentinel_loader, hrrr_loader, usda_loader)``,
            iterated in lock-step.

    Returns:
        ``{"Production": {...}, "Yield": {...}}`` where each inner dict holds
        MAE, MSE, RMSE, MAPE, SMAPE, Max Error and Correlation Coefficient,
        rounded to two decimals. The metrics are computed on the values
        exactly as the loaders and the model produce them.

    Note:
        Relies on the notebook-level ``fips_encoder`` to one-hot encode FIPS codes.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    sentinel_loader, hrrr_loader, usda_loader = test_data

    prediction_batches, target_batches = [], []
    with torch.no_grad():
        # The FIPS codes are taken from the USDA loader's batch.
        for (sentinel, _, _), (hrrr, _, _), (usda, fips_code, year) in zip(sentinel_loader, hrrr_loader, usda_loader):
            sentinel = sentinel.to(device)
            hrrr = hrrr.to(device)
            usda = usda.to(device)
            encoded = fips_encoder.transform(np.array(fips_code).reshape(-1, 1))
            fips_onehot = torch.tensor(encoded, dtype=torch.float32).to(device)

            output = model(sentinel, hrrr, fips_onehot)

            prediction_batches.append(output.cpu().numpy())
            target_batches.append(usda.cpu().numpy())

    all_predictions = np.concatenate(prediction_batches, axis=0)
    all_ground_truth = np.concatenate(target_batches, axis=0)

    # Raw dump: predictions followed by ground truth, for each target column.
    for column in (0, 1):
        print(all_predictions[:, column])
        print(all_ground_truth[:, column])

    results = {}
    for column, target_name in enumerate(["Production", "Yield"]):
        y_true = torch.from_numpy(all_ground_truth[:, column])
        y_pred = torch.from_numpy(all_predictions[:, column])

        abs_err = torch.abs(y_true - y_pred)
        mse = ((y_pred - y_true) ** 2).mean()
        mape = (abs_err / torch.abs(y_true)).mean() * 100
        smape = 100 * (abs_err / ((torch.abs(y_true) + torch.abs(y_pred)) / 2)).mean()
        corr = torch.corrcoef(torch.stack((y_pred, y_true)))

        results[target_name] = {
            'MAE': round(abs_err.mean().item(), 2),
            'MSE': round(mse.item(), 2),
            'RMSE': round(torch.sqrt(mse).item(), 2),
            'MAPE': round(mape.item(), 2),
            'SMAPE': round(smape.item(), 2),
            'Max Error': round(abs_err.max().item(), 2),
            'Correlation Coefficient': round(corr[0, 1].item(), 2),
        }
    return results

## RUN MODEL

## Cotton

In [None]:
# Release memory held by earlier runs before building new loaders and a new model.
gc.collect()
torch.cuda.empty_cache()

# `fips` and `crop_type` are the notebook-level values set by the most recent
# configuration cell; the JSON files below are the ones that cell wrote.
base_dir = "/kaggle/input/cropnetv2"
train_config = "/kaggle/working/train_config.json"
test_config = "/kaggle/working/test_config.json"

# One loader per modality; batch_size=1 and no shuffling keep the three
# loaders aligned sample by sample.
train_sentinel_loader, train_hrrr_loader, train_usda_loader = (
    DataLoader(dataset, batch_size=1)
    for dataset in (
        Sentinel2Imagery(base_dir, train_config),
        HRRRComputedDataset(base_dir, train_config),
        USDACropDataset(base_dir, train_config, crop_type),
    )
)

test_sentinel_loader, test_hrrr_loader, test_usda_loader = (
    DataLoader(dataset, batch_size=1)
    for dataset in (
        Sentinel2Imagery(base_dir, test_config),
        HRRRComputedDataset(base_dir, test_config),
        USDACropDataset(base_dir, test_config, crop_type),
    )
)

with open(train_config, 'r') as f:
    obj = json.load(f)
    fips_codes = sorted(obj["FIPS"])

# One-hot encoder for county codes, used inside train_model / evaluate_model.
fips_encoder = OneHotEncoder(sparse_output=False).fit(np.array(fips).reshape(-1, 1))

model = CropYieldModel(len(fips_codes))

model = train_model(
    model,
    (train_sentinel_loader, train_hrrr_loader, train_usda_loader),
    epochs=20,
    patience=5,
    save_path="best_model.pth",
)

results = evaluate_model(model, (test_sentinel_loader, test_hrrr_loader, test_usda_loader))
print(crop_type)
pprint(results)

## Corn

In [None]:
# Corn in Illinois: shared settings for the train and test configurations.
state = "IL"
state_ansi = "17"
fips = [
    '17007', '17011', '17015', '17017', '17019', '17021', '17025', '17027',
    '17037', '17049', '17053', '17055', '17057', '17059', '17061', '17063', '17073',
    '17075', '17077', '17081', '17085', '17089', '17093', '17095', '17101', '17103',
    '17105', '17107', '17113', '17115', '17117', '17119', '17121', '17123', '17133',
    '17135', '17139', '17141', '17143', '17147', '17157', '17163', '17167', '17169',
    '17173', '17175', '17177', '17179', '17189', '17193', '17195', '17201', '17203',
]
crop_type = "Corn"
grow_season = [4, 9]  # April to September

# Train split: 2018-2021.
years = [2018, 2019, 2020, 2021]
train_config = make_config(years, state, state_ansi, fips, crop_type, grow_season)
with open('train_config.json', 'w') as file:
    file.write(json.dumps(train_config))
print("Train config")
pprint(train_config)

# Test split: 2022 only.
years = [2022]
test_config = make_config(years, state, state_ansi, fips, crop_type, grow_season)
with open('test_config.json', 'w') as file:
    file.write(json.dumps(test_config))
print("Test config")
pprint(test_config)

In [None]:
# Release memory held by earlier runs before building new loaders and a new model.
gc.collect()
torch.cuda.empty_cache()

# `fips` and `crop_type` are the notebook-level values set by the most recent
# configuration cell; the JSON files below are the ones that cell wrote.
base_dir = "/kaggle/input/cropnetv2"
train_config = "/kaggle/working/train_config.json"
test_config = "/kaggle/working/test_config.json"

# One loader per modality; batch_size=1 and no shuffling keep the three
# loaders aligned sample by sample.
train_sentinel_loader, train_hrrr_loader, train_usda_loader = (
    DataLoader(dataset, batch_size=1)
    for dataset in (
        Sentinel2Imagery(base_dir, train_config),
        HRRRComputedDataset(base_dir, train_config),
        USDACropDataset(base_dir, train_config, crop_type),
    )
)

test_sentinel_loader, test_hrrr_loader, test_usda_loader = (
    DataLoader(dataset, batch_size=1)
    for dataset in (
        Sentinel2Imagery(base_dir, test_config),
        HRRRComputedDataset(base_dir, test_config),
        USDACropDataset(base_dir, test_config, crop_type),
    )
)

with open(train_config, 'r') as f:
    obj = json.load(f)
    fips_codes = sorted(obj["FIPS"])

# One-hot encoder for county codes, used inside train_model / evaluate_model.
fips_encoder = OneHotEncoder(sparse_output=False).fit(np.array(fips).reshape(-1, 1))

model = CropYieldModel(len(fips_codes))

model = train_model(
    model,
    (train_sentinel_loader, train_hrrr_loader, train_usda_loader),
    epochs=20,
    patience=5,
    save_path="best_model.pth",
)

results = evaluate_model(model, (test_sentinel_loader, test_hrrr_loader, test_usda_loader))
print(crop_type)
pprint(results)

## Soybeans

In [None]:
# Soybeans in Illinois: shared settings for the train and test configurations.
state = "IL"
state_ansi = "17"
fips = [
    '17005', '17007', '17009', '17011', '17015', '17019', '17025', '17027', '17037',
    '17045', '17049', '17053', '17055', '17057', '17059', '17063', '17073', '17075', '17077',
    '17081', '17089', '17091', '17095', '17101', '17103', '17105', '17113', '17115', '17117',
    '17119', '17121', '17129', '17133', '17139', '17141', '17143', '17145', '17153', '17157',
    '17163', '17167', '17173', '17177', '17179', '17189', '17193', '17197', '17201', '17203',
]
crop_type = "Soybeans"
grow_season = [4, 9]  # April to September

# Train split: 2018-2021.
years = [2018, 2019, 2020, 2021]
train_config = make_config(years, state, state_ansi, fips, crop_type, grow_season)
with open('train_config.json', 'w') as file:
    file.write(json.dumps(train_config))

# Test split: 2022 only.
years = [2022]
test_config = make_config(years, state, state_ansi, fips, crop_type, grow_season)
with open('test_config.json', 'w') as file:
    file.write(json.dumps(test_config))

In [None]:
# Release memory held by earlier runs before building new loaders and a new model.
gc.collect()
torch.cuda.empty_cache()

# `fips` and `crop_type` are the notebook-level values set by the most recent
# configuration cell; the JSON files below are the ones that cell wrote.
base_dir = "/kaggle/input/cropnetv2"
train_config = "/kaggle/working/train_config.json"
test_config = "/kaggle/working/test_config.json"

# One loader per modality; batch_size=1 and no shuffling keep the three
# loaders aligned sample by sample.
train_sentinel_loader, train_hrrr_loader, train_usda_loader = (
    DataLoader(dataset, batch_size=1)
    for dataset in (
        Sentinel2Imagery(base_dir, train_config),
        HRRRComputedDataset(base_dir, train_config),
        USDACropDataset(base_dir, train_config, crop_type),
    )
)

test_sentinel_loader, test_hrrr_loader, test_usda_loader = (
    DataLoader(dataset, batch_size=1)
    for dataset in (
        Sentinel2Imagery(base_dir, test_config),
        HRRRComputedDataset(base_dir, test_config),
        USDACropDataset(base_dir, test_config, crop_type),
    )
)

with open(train_config, 'r') as f:
    obj = json.load(f)
    fips_codes = sorted(obj["FIPS"])

# One-hot encoder for county codes, used inside train_model / evaluate_model.
fips_encoder = OneHotEncoder(sparse_output=False).fit(np.array(fips).reshape(-1, 1))

model = CropYieldModel(len(fips_codes))

model = train_model(
    model,
    (train_sentinel_loader, train_hrrr_loader, train_usda_loader),
    epochs=20,
    patience=5,
    save_path="best_model.pth",
)

results = evaluate_model(model, (test_sentinel_loader, test_hrrr_loader, test_usda_loader))
print(crop_type)
pprint(results)

## Winter Wheat

In [None]:
def make_config(years: List[int], state: str, state_ansi: str, fips: str, crop_type: str, grow_season: List[int]):
    """
    Build the data-file configuration for winter wheat.

    For every harvest year the file lists cover a fixed window that crosses
    the year boundary:
      * HRRR: September-December of the previous year, then January-June of
        the harvest year (ten monthly CSV files).
      * Sentinel: Q4 of the previous year, then Q1 and Q2 of the harvest year.
      * USDA: the winter-wheat county file of the harvest year.

    grow_season is accepted for signature compatibility but is not read; the
    window above is hard-coded.

    NOTE: this definition replaces the earlier same-year ``make_config`` for
    every cell executed after it.
    """
    state_code = state.upper()
    hrrr_per_year, usda_per_year, sentinel_per_year = [], [], []

    for year in years:
        previous_year = year - 1

        # (calendar year, month) pairs in chronological order:
        # Sep-Dec of the previous year, then Jan-Jun of the harvest year.
        year_months = [(previous_year, m) for m in range(9, 13)]
        year_months += [(year, m) for m in range(1, 7)]
        hrrr_per_year.append([
            f"HRRR/{y}/{state_code}/HRRR_{state_ansi}_{state_code}_{y}-{m:02d}.csv"
            for y, m in year_months
        ])

        usda_per_year.append(f"USDA/{crop_type}/{year}/USDA_WinterWheat_County_{year}.csv")

        # (start date, end date) of each quarter that is read.
        quarter_spans = (
            (f"{previous_year}-10-01", f"{previous_year}-12-31"),  # Q4, previous year
            (f"{year}-01-01", f"{year}-03-31"),                    # Q1, harvest year
            (f"{year}-04-01", f"{year}-06-30"),                    # Q2, harvest year
        )
        # The directory name is the first four characters of the start date.
        sentinel_per_year.append([
            f"AG/{state_code}/{span_start[:4]}/Agriculture_{state_ansi}_{state_code}_{span_start}_{span_end}.h5"
            for span_start, span_end in quarter_spans
        ])

    return {
        "FIPS": fips,
        "years": years,
        "state": state_code,
        "crop_type": crop_type,
        "data": {
            "HRRR": {"short_term": hrrr_per_year},
            "USDA": usda_per_year,
            "sentinel": sentinel_per_year,
        },
    }

# Winter wheat in Illinois: shared settings for the train and test configurations.
state = "IL"
state_ansi = "17"
fips = [
    '17011', '17013', '17023', '17025', '17027', '17037', '17047', '17049', '17067',
    '17083', '17089', '17095', '17119', '17121', '17125', '17133', '17141', '17157', '17159',
    '17163', '17173', '17177', '17179', '17189', '17201',
]
crop_type = "WinterWheat"
grow_season = [9, 6]  # September to June (spanning across years)

# Train split: harvest years 2018-2021.
years = [2018, 2019, 2020, 2021]
train_config = make_config(years, state, state_ansi, fips, crop_type, grow_season)
with open('train_config.json', 'w') as file:
    file.write(json.dumps(train_config))

# Test split: harvest year 2022 only.
years = [2022]
test_config = make_config(years, state, state_ansi, fips, crop_type, grow_season)
with open('test_config.json', 'w') as file:
    file.write(json.dumps(test_config))

In [None]:
# Release memory held by earlier runs before building new loaders and a new model.
gc.collect()
torch.cuda.empty_cache()

# `fips` and `crop_type` are the notebook-level values set by the most recent
# configuration cell; the JSON files below are the ones that cell wrote.
base_dir = "/kaggle/input/cropnetv2"
train_config = "/kaggle/working/train_config.json"
test_config = "/kaggle/working/test_config.json"

# One loader per modality; batch_size=1 and no shuffling keep the three
# loaders aligned sample by sample.
train_sentinel_loader, train_hrrr_loader, train_usda_loader = (
    DataLoader(dataset, batch_size=1)
    for dataset in (
        Sentinel2Imagery(base_dir, train_config),
        HRRRComputedDataset(base_dir, train_config),
        USDACropDataset(base_dir, train_config, crop_type),
    )
)

test_sentinel_loader, test_hrrr_loader, test_usda_loader = (
    DataLoader(dataset, batch_size=1)
    for dataset in (
        Sentinel2Imagery(base_dir, test_config),
        HRRRComputedDataset(base_dir, test_config),
        USDACropDataset(base_dir, test_config, crop_type),
    )
)

with open(train_config, 'r') as f:
    obj = json.load(f)
    fips_codes = sorted(obj["FIPS"])

# One-hot encoder for county codes, used inside train_model / evaluate_model.
fips_encoder = OneHotEncoder(sparse_output=False).fit(np.array(fips).reshape(-1, 1))

model = CropYieldModel(len(fips_codes))

model = train_model(
    model,
    (train_sentinel_loader, train_hrrr_loader, train_usda_loader),
    epochs=20,
    patience=5,
    save_path="best_model.pth",
)

results = evaluate_model(model, (test_sentinel_loader, test_hrrr_loader, test_usda_loader))
print(crop_type)
pprint(results)