In [13]:
import pandas as pd
import numpy as np

import rasterio
from skimage.transform import resize
from skimage.transform import rotate
import os

import torch
from torch.utils.data import Dataset, DataLoader

import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import KFold
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

from sklearn.model_selection import train_test_split

from datetime import timedelta
from skimage.draw import polygon
import matplotlib.pyplot as plt

from shapely.geometry import Polygon

from utils import process_yield_data
from pathlib import Path

import matplotlib.pyplot as plt

#### Import Yield Data

In [2]:
YIELD_DATA_PATH = Path("./combined_yield_data.csv")
yield_data_weekly = process_yield_data(YIELD_DATA_PATH)

            Volume (Pounds)  Cumulative Volumne (Pounds)  Pounds/Acre
Date                                                                 
2012-01-02          23400.0                      23400.0          2.0
2012-01-03          26064.0                      49464.0          3.0
2012-01-04          32382.0                      81846.0          3.0
2012-01-05          69804.0                     151650.0          7.0
2012-01-06          18000.0                     169650.0          2.0

Number of Yield Data Points:  3970

Column Names: Index(['Volume (Pounds)', 'Cumulative Volumne (Pounds)', 'Pounds/Acre'], dtype='object')
Number of Yield Data Points: 2879
Yield data with time features:
            Volume (Pounds)  Cumulative Volumne (Pounds)  Pounds/Acre  \
Date                                                                    
2012-03-04         525753.0                    1785843.0    18.333333   
2012-03-11        2949534.0                    4735377.0    51.666667   
2012-03-18   

#### Define the Model

In [3]:
target_shape = (512, 512)
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using {device} device")

Using mps device


### Old Model

In [4]:
class CNNFeatureExtractor(nn.Module):
    def __init__(self):
        super(CNNFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.flattened_size = self._get_conv_output((1, *target_shape))
        self.fc1 = nn.Linear(self.flattened_size, 512)

    def _get_conv_output(self, shape):
        x = torch.rand(1, *shape)
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))
        n_size = x.view(1, -1).size(1)
        return n_size

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))
        x = self.dropout(x)
        x = x.view(-1, self.flattened_size)
        x = F.relu(self.fc1(x))
        return x
    
class HybridModel(nn.Module):
    def __init__(self, cnn_feature_extractor, lstm_hidden_size=64, lstm_layers=1):
        super(HybridModel, self).__init__()
        self.cnn = cnn_feature_extractor
        self.lstm = nn.LSTM(input_size=512, hidden_size=lstm_hidden_size, num_layers=lstm_layers, batch_first=True)
        self.fc1 = nn.Linear(lstm_hidden_size + 4, 64)
        self.fc2 = nn.Linear(64, target_shape[0] * target_shape[1])  # Predict a value per pixel
        self.target_shape = target_shape

    def forward(self, x, time_features):
        batch_size, time_steps, C, H, W = x.size()
        c_in = x.view(batch_size * time_steps, C, H, W)
        c_out = self.cnn(c_in)
        r_in = c_out.view(batch_size, time_steps, -1)
        r_out, (h_n, c_n) = self.lstm(r_in)
        r_out = r_out[:, -1, :]
        x = torch.cat((r_out, time_features), dim=1)  # Concatenate LSTM output with time features
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = x.view(batch_size, *self.target_shape)  # Reshape to the target shape
        return x

#### Initialize Function

In [5]:
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# # Instantiate model with weight decay regularization
# cnn_feature_extractor = CNNFeatureExtractor()
# model = HybridModel(cnn_feature_extractor)
# model.apply(weights_init)
# model.to(device)

batch_size = 16
epochs = 50

# criterion = nn.MSELoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

### Functions for prediction

In [6]:
from inference_utils import (
    preprocess_image,
    compute_mean_std,
    load_evi_data_and_prepare_features,
    find_closest_date,
    find_closest_date_in_df,
    mask_evi_data,
    predict,
    predict_weekly_yield,
    augment_image,
    prepare_dataset,
    train_and_evaluate,
    sync_evi_yield_data,
    CustomDataset,
    load_evi_data,
    find_common_date_range
)


In [7]:
# Load EVI data and prepare time features
evi_data_dir = "./landsat_evi_monterey_masked"
train_loader, val_loader, mean, std = prepare_dataset(evi_data_dir, yield_data_weekly, target_shape, augment=True)

# Define directory and other variables
evi_data_dir = "./landsat_evi_monterey_masked"

# Load EVI data
evi_data_dict = {}
for file in os.listdir(evi_data_dir):
    if file.endswith('.tiff'):
        date_str = os.path.basename(file).split('_')[3]
        date = pd.to_datetime(date_str, format='%Y%m%d')
        evi_data = load_evi_data(os.path.join(evi_data_dir, file))
        evi_data_dict[date] = evi_data

# Compute mean and std metrics
mean, std = compute_mean_std(evi_data_dict, target_shape)

# Apply preprocessing to all images in the dictionary
for date, image in evi_data_dict.items():
    try:
        evi_data_dict[date] = preprocess_image(image, target_shape, mean, std)
    except AssertionError as e:
        print(f"Error processing image for date {date}: {e}")

# Determine common date range between EVI and yield data
start_date, end_date = find_common_date_range(evi_data_dict, yield_data_weekly)

# Filter yield_data_weekly to only include dates within the common range
yield_data_weekly_filtered = yield_data_weekly[(yield_data_weekly.index >= start_date) & (yield_data_weekly.index <= end_date)]

# Sync EVI and yield data
evi_data_dict_filtered, evi_reference_filtered = sync_evi_yield_data(evi_data_dict, yield_data_weekly_filtered)

# Initialize CustomDataset with filtered data
dataset = CustomDataset(evi_data_dict_filtered, evi_reference_filtered, yield_data_weekly_filtered)

Processed file 1/84 in 2.302029s
Processed file 2/84 in 1.793254s
Processed file 3/84 in 1.843476s
Processed file 4/84 in 1.610249s
Processed file 5/84 in 1.831314s
Processed file 6/84 in 1.950453s
Processed file 7/84 in 1.788520s
Processed file 8/84 in 1.666233s
Processed file 9/84 in 2.275727s
Processed file 10/84 in 1.803741s
Processed file 11/84 in 2.236039s
Processed file 12/84 in 1.965004s
Processed file 13/84 in 2.226741s
Processed file 14/84 in 2.216345s
Processed file 15/84 in 1.714607s
Processed file 16/84 in 1.891356s
Processed file 17/84 in 1.998704s
Processed file 19/84 in 2.017514s
Processed file 20/84 in 1.656819s
Processed file 21/84 in 2.118526s
Processed file 22/84 in 2.206108s
Processed file 23/84 in 2.093217s
Processed file 24/84 in 1.883898s
Processed file 25/84 in 2.072297s
Processed file 26/84 in 2.112418s
Processed file 27/84 in 2.314410s
Processed file 28/84 in 1.838153s
Processed file 29/84 in 2.059869s
Processed file 30/84 in 2.147930s
Processed file 31/84 in

### Model Evaluation (Cross Validation)

In [14]:
# Initialize TimeSeriesSplit for cross-validation
tscv = TimeSeriesSplit(n_splits=5)

mse_scores = []
rmse_scores = []
mae_scores = []
r2_scores = []

# Lists to store losses for plotting
all_train_losses = []
all_val_losses = []

epochs = 50
patience = 5 
early_stop = False

for fold, (train_index, val_index) in enumerate(tscv.split(yield_data_weekly_filtered)):
    print(f"Fold {fold + 1}")

    # Check if indices are within bounds
    if max(train_index) >= len(dataset) or max(val_index) >= len(dataset):
        print(f"Error: Indices out of range for fold {fold + 1}")
        continue 

    # Subsets for the current fold
    fold_train_subset = torch.utils.data.Subset(dataset, train_index)
    fold_val_subset = torch.utils.data.Subset(dataset, val_index)

    # DataLoaders for the current fold
    fold_train_loader = DataLoader(fold_train_subset, batch_size=batch_size, shuffle=True)
    fold_val_loader = DataLoader(fold_val_subset, batch_size=batch_size, shuffle=False)

    # Instantiate a new model for each fold
    model = HybridModel(CNNFeatureExtractor())
    model.apply(weights_init)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    criterion = nn.MSELoss()

    # Initialize variables for early stopping
    best_val_loss = float('inf')
    epochs_without_improvement = 0

    # Lists to store the losses for this fold
    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels, time_features, timestamp in tqdm(fold_train_loader):
            inputs, labels, time_features = inputs.to(device), labels.to(device), time_features.to(device)
            optimizer.zero_grad()
            outputs = model(inputs, time_features)
            labels = labels / (512 * 512)
            labels = labels.unsqueeze(1).unsqueeze(2).expand(-1, target_shape[0], target_shape[1])
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        epoch_loss = running_loss / len(fold_train_loader)
        train_losses.append(epoch_loss)
        print(f'Epoch {epoch + 1}, Loss: {epoch_loss}')

        # Evaluate on validation set
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, labels, time_features, timestamp in fold_val_loader:
                inputs, labels, time_features = inputs.to(device), labels.to(device), time_features.to(device)
                outputs = model(inputs, time_features)
                labels = labels.unsqueeze(1).unsqueeze(2).expand(-1, target_shape[0], target_shape[1])
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        val_loss /= len(fold_val_loader)
        val_losses.append(val_loss)
        print(f'Validation Loss: {val_loss}')

        scheduler.step(val_loss)

        # Check for early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"Early stopping triggered after {epoch + 1} epochs.")
            early_stop = True
            break

    if early_stop:
        break

    # Store the train and val losses for visualization
    all_train_losses.append(train_losses)
    all_val_losses.append(val_losses)

    # Model evaluation on the validation set
    model.eval()
    with torch.no_grad():
        outputs_val = []
        labels_val = []
        for evi_batch, label_batch, time_features_batch, timestamp in fold_val_loader:
            evi_batch, label_batch, time_features_batch = evi_batch.to(device), label_batch.to(device), time_features_batch.to(device)
            outputs_batch = model(evi_batch, time_features_batch) # lbs/pixel
            outputs_val.extend(outputs_batch.cpu().numpy().flatten())
            label_batch = label_batch.unsqueeze(1).unsqueeze(2).expand(-1, target_shape[0], target_shape[1])
            labels_val.extend(label_batch.cpu().numpy().flatten())

    outputs_val = np.array(outputs_val)
    labels_val = np.array(labels_val)

    # Calculate val metrics
    mse = mean_squared_error(labels_val, outputs_val)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(labels_val, outputs_val)
    r2 = r2_score(labels_val, outputs_val)

    mse_scores.append(mse)
    rmse_scores.append(rmse)
    mae_scores.append(mae)
    r2_scores.append(r2)

# Print results
print(f"Average MSE: {np.mean(mse_scores)}")
print(f"Average RMSE: {np.mean(rmse_scores)}")
print(f"Average MAE: {np.mean(mae_scores)}")
print(f"Average R-squared: {np.mean(r2_scores)}")

Fold 1


100%|██████████| 6/6 [00:13<00:00,  2.27s/it]


Epoch 1, Loss: 0.13158051172892252




Validation Loss: 0.16321770598491034


100%|██████████| 6/6 [00:13<00:00,  2.23s/it]


Epoch 2, Loss: 0.07084010665615399




Validation Loss: 0.1429088662068049


100%|██████████| 6/6 [00:13<00:00,  2.22s/it]


Epoch 3, Loss: 0.052040555203954376




Validation Loss: 0.12773683667182922


100%|██████████| 6/6 [00:13<00:00,  2.22s/it]


Epoch 4, Loss: 0.038451588402191796




Validation Loss: 0.11675967338184516


100%|██████████| 6/6 [00:13<00:00,  2.20s/it]


Epoch 5, Loss: 0.02824142078558604




Validation Loss: 0.10856030198434989


100%|██████████| 6/6 [00:13<00:00,  2.25s/it]


Epoch 6, Loss: 0.021477386665840942




Validation Loss: 0.10323445840428273


100%|██████████| 6/6 [00:14<00:00,  2.43s/it]


Epoch 7, Loss: 0.016581515781581402




Validation Loss: 0.0993532349045078


100%|██████████| 6/6 [00:13<00:00,  2.25s/it]


Epoch 8, Loss: 0.012992694197843472




Validation Loss: 0.09651622238258521


100%|██████████| 6/6 [00:13<00:00,  2.29s/it]


Epoch 9, Loss: 0.010404997350027164




Validation Loss: 0.0978635367937386


100%|██████████| 6/6 [00:14<00:00,  2.36s/it]


Epoch 10, Loss: 0.010395593009889126




Validation Loss: 0.09480611172815163


100%|██████████| 6/6 [00:13<00:00,  2.28s/it]


Epoch 11, Loss: 0.007463878253474832




Validation Loss: 0.09228268653775255


100%|██████████| 6/6 [00:13<00:00,  2.27s/it]


Epoch 12, Loss: 0.005378499239062269




Validation Loss: 0.09061341173946857


100%|██████████| 6/6 [00:12<00:00,  2.10s/it]


Epoch 13, Loss: 0.004042394498052697




Validation Loss: 0.08958852849900723


100%|██████████| 6/6 [00:12<00:00,  2.08s/it]


Epoch 14, Loss: 0.0031883688255523643




Validation Loss: 0.08889540715608746


100%|██████████| 6/6 [00:12<00:00,  2.12s/it]


Epoch 15, Loss: 0.0026071711520974836




Validation Loss: 0.08838166509910177


100%|██████████| 6/6 [00:12<00:00,  2.13s/it]


Epoch 16, Loss: 0.0021624588019525013




Validation Loss: 0.08798838690078507


100%|██████████| 6/6 [00:12<00:00,  2.14s/it]


Epoch 17, Loss: 0.0018065229329901438




Validation Loss: 0.08767508591214816


100%|██████████| 6/6 [00:12<00:00,  2.15s/it]


Epoch 18, Loss: 0.0016679828986525536




Validation Loss: 0.08741820623011638


100%|██████████| 6/6 [00:12<00:00,  2.09s/it]


Epoch 19, Loss: 0.0012853331475829084




Validation Loss: 0.08721349681339537


100%|██████████| 6/6 [00:13<00:00,  2.26s/it]


Epoch 20, Loss: 0.0010972718203750749




Validation Loss: 0.087047203463347


100%|██████████| 6/6 [00:12<00:00,  2.11s/it]


Epoch 21, Loss: 0.0009504805590646962




Validation Loss: 0.08690991757127146


100%|██████████| 6/6 [00:12<00:00,  2.16s/it]


Epoch 22, Loss: 0.0008184946297357479




Validation Loss: 0.08679347159340978


100%|██████████| 6/6 [00:12<00:00,  2.14s/it]


Epoch 23, Loss: 0.0007052996758526812




Validation Loss: 0.08669634178901713


100%|██████████| 6/6 [00:12<00:00,  2.16s/it]


Epoch 24, Loss: 0.0006232252150463561




Validation Loss: 0.08661424889578484


100%|██████████| 6/6 [00:12<00:00,  2.13s/it]


Epoch 25, Loss: 0.0005414131883298978




Validation Loss: 0.0865439226909075


100%|██████████| 6/6 [00:12<00:00,  2.14s/it]


Epoch 26, Loss: 0.00047982525332675624




Validation Loss: 0.08648393190621088


100%|██████████| 6/6 [00:12<00:00,  2.12s/it]


Epoch 27, Loss: 0.0004248476228288685




Validation Loss: 0.0864324920985382


100%|██████████| 6/6 [00:12<00:00,  2.13s/it]


Epoch 28, Loss: 0.00038168372217720997




Validation Loss: 0.08638806700279626


100%|██████████| 6/6 [00:13<00:00,  2.26s/it]


Epoch 29, Loss: 0.00033424857732219




Validation Loss: 0.08634893244015984


100%|██████████| 6/6 [00:12<00:00,  2.14s/it]


Epoch 30, Loss: 0.0002988776929366092




Validation Loss: 0.08631522860378027


100%|██████████| 6/6 [00:13<00:00,  2.25s/it]


Epoch 31, Loss: 0.00026395749106692773




Validation Loss: 0.08628575611995377


100%|██████████| 6/6 [00:13<00:00,  2.17s/it]


Epoch 32, Loss: 0.00023716818638301143




Validation Loss: 0.08626000831524532


100%|██████████| 6/6 [00:13<00:00,  2.22s/it]


Epoch 33, Loss: 0.0002144788644121339




Validation Loss: 0.08623751021999244


100%|██████████| 6/6 [00:13<00:00,  2.17s/it]


Epoch 34, Loss: 0.00019304103382940715




Validation Loss: 0.08621729913769134


100%|██████████| 6/6 [00:12<00:00,  2.16s/it]


Epoch 35, Loss: 0.00017479293213303512




Validation Loss: 0.08619943986316987


100%|██████████| 6/6 [00:13<00:00,  2.25s/it]


Epoch 36, Loss: 0.00015697164053563029




Validation Loss: 0.08618347605564243


100%|██████████| 6/6 [00:13<00:00,  2.22s/it]


Epoch 37, Loss: 0.00014430515633042282




Validation Loss: 0.08616946938127512


100%|██████████| 6/6 [00:13<00:00,  2.19s/it]


Epoch 38, Loss: 0.0001565299171488732




Validation Loss: 0.08651337333139963


100%|██████████| 6/6 [00:13<00:00,  2.22s/it]


Epoch 39, Loss: 0.0003447851462018055




Validation Loss: 0.08616314571069476


100%|██████████| 6/6 [00:13<00:00,  2.29s/it]


Epoch 40, Loss: 0.00010973241296596825




Validation Loss: 0.08613543809042312


100%|██████████| 6/6 [00:13<00:00,  2.20s/it]


Epoch 41, Loss: 9.973511259886436e-05




Validation Loss: 0.08612605276721297


100%|██████████| 6/6 [00:12<00:00,  2.13s/it]


Epoch 42, Loss: 9.122056629469928e-05




Validation Loss: 0.08611763372391579


100%|██████████| 6/6 [00:12<00:00,  2.12s/it]


Epoch 43, Loss: 8.30828539619688e-05




Validation Loss: 0.08611023823201928


100%|██████████| 6/6 [00:13<00:00,  2.19s/it]


Epoch 44, Loss: 7.676991905706625e-05




Validation Loss: 0.08610350921784023


100%|██████████| 6/6 [00:13<00:00,  2.18s/it]


Epoch 45, Loss: 6.97420885747609e-05




Validation Loss: 0.08609737941636315


100%|██████████| 6/6 [00:13<00:00,  2.23s/it]


Epoch 46, Loss: 6.498872911227711e-05




Validation Loss: 0.0860919697151985


100%|██████████| 6/6 [00:13<00:00,  2.24s/it]


Epoch 47, Loss: 5.987598585003676e-05




Validation Loss: 0.08608677495794836


100%|██████████| 6/6 [00:13<00:00,  2.24s/it]


Epoch 48, Loss: 5.492689221379502e-05




Validation Loss: 0.08608208717487287


100%|██████████| 6/6 [00:12<00:00,  2.04s/it]


Epoch 49, Loss: 5.1291989317784704e-05




Validation Loss: 0.08607789286543266


100%|██████████| 6/6 [00:12<00:00,  2.10s/it]


Epoch 50, Loss: 4.673300403131483e-05
Validation Loss: 0.08607392197639759




Fold 2


100%|██████████| 12/12 [00:22<00:00,  1.89s/it]


Epoch 1, Loss: 0.13110416010022163




Validation Loss: 0.19226237386465073


100%|██████████| 12/12 [00:22<00:00,  1.89s/it]


Epoch 2, Loss: 0.06429157747576635




Validation Loss: 0.16041402394572893


100%|██████████| 12/12 [00:22<00:00,  1.88s/it]


Epoch 3, Loss: 0.03824029009168347




Validation Loss: 0.14184824097901583


100%|██████████| 12/12 [00:22<00:00,  1.86s/it]


Epoch 4, Loss: 0.023742025562872488




Validation Loss: 0.13129415176808834


100%|██████████| 12/12 [00:22<00:00,  1.87s/it]


Epoch 5, Loss: 0.01542381476610899




Validation Loss: 0.12497924206157525


100%|██████████| 12/12 [00:22<00:00,  1.88s/it]


Epoch 6, Loss: 0.010413038621967038




Validation Loss: 0.12099024777611096


100%|██████████| 12/12 [00:22<00:00,  1.87s/it]


Epoch 7, Loss: 0.007183604485665758




Validation Loss: 0.11834725271910429


100%|██████████| 12/12 [00:22<00:00,  1.90s/it]


Epoch 8, Loss: 0.0050651030614972115




Validation Loss: 0.11658169468864799


100%|██████████| 12/12 [00:22<00:00,  1.88s/it]


Epoch 9, Loss: 0.0035785577298762896




Validation Loss: 0.1154710774620374


100%|██████████| 12/12 [00:22<00:00,  1.87s/it]


Epoch 10, Loss: 0.0026578357986484966




Validation Loss: 0.1145189261296764


100%|██████████| 12/12 [00:22<00:00,  1.87s/it]


Epoch 11, Loss: 0.0019407479849178344




Validation Loss: 0.11390805205640693


100%|██████████| 12/12 [00:22<00:00,  1.88s/it]


Epoch 12, Loss: 0.0013915518939029425




Validation Loss: 0.11348237857843439


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


Epoch 13, Loss: 0.0010647122350443776




Validation Loss: 0.1131818318584313


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


Epoch 14, Loss: 0.000792925760227566




Validation Loss: 0.112963329651393


100%|██████████| 12/12 [00:22<00:00,  1.91s/it]


Epoch 15, Loss: 0.0006105986831244081




Validation Loss: 0.11280394167018433


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


Epoch 16, Loss: 0.0004795308268512599




Validation Loss: 0.11268590631273885


100%|██████████| 12/12 [00:22<00:00,  1.91s/it]


Epoch 17, Loss: 0.00037737156526418403




Validation Loss: 0.11259639111813158


100%|██████████| 12/12 [00:22<00:00,  1.89s/it]


Epoch 18, Loss: 0.0003020859900667953




Validation Loss: 0.11252814391627908


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


Epoch 19, Loss: 0.0002461913488029192




Validation Loss: 0.11247576161986217


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


Epoch 20, Loss: 0.00020246292600252977




Validation Loss: 0.11243433208437636


100%|██████████| 12/12 [00:22<00:00,  1.87s/it]


Epoch 21, Loss: 0.00016627890242186064




Validation Loss: 0.11240167582097153


100%|██████████| 12/12 [00:23<00:00,  1.93s/it]


Epoch 22, Loss: 0.00013675196896656416




Validation Loss: 0.11237526531719293


100%|██████████| 12/12 [00:22<00:00,  1.91s/it]


Epoch 23, Loss: 0.00011454065073242721




Validation Loss: 0.1123541688430123


100%|██████████| 12/12 [00:23<00:00,  1.93s/it]


Epoch 24, Loss: 9.704066845491373e-05




Validation Loss: 0.11233706562779844


100%|██████████| 12/12 [00:22<00:00,  1.89s/it]


Epoch 25, Loss: 8.17766861776666e-05




Validation Loss: 0.11232276912778616


100%|██████████| 12/12 [00:22<00:00,  1.89s/it]


Epoch 26, Loss: 7.071778418321628e-05




Validation Loss: 0.11231102288002148


100%|██████████| 12/12 [00:22<00:00,  1.91s/it]


Epoch 27, Loss: 6.003060540630637e-05




Validation Loss: 0.11259902778935309


100%|██████████| 12/12 [00:22<00:00,  1.87s/it]


Epoch 28, Loss: 8.186452062849033e-05




Validation Loss: 0.11229311713638405


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


Epoch 29, Loss: 4.490940227697138e-05




Validation Loss: 0.11228611015637095


100%|██████████| 12/12 [00:22<00:00,  1.90s/it]


Epoch 30, Loss: 3.837972932766812e-05




Validation Loss: 0.11227934609632939


100%|██████████| 12/12 [00:22<00:00,  1.89s/it]


Epoch 31, Loss: 3.3903884589866116e-05




Validation Loss: 0.11227303719109234


100%|██████████| 12/12 [00:22<00:00,  1.91s/it]


Epoch 32, Loss: 2.9419090727363557e-05




Validation Loss: 0.11226847421494313


100%|██████████| 12/12 [00:22<00:00,  1.87s/it]


Epoch 33, Loss: 2.5421717509743758e-05




Validation Loss: 0.11226468890284498


100%|██████████| 12/12 [00:22<00:00,  1.89s/it]


Epoch 34, Loss: 2.1972770961535087e-05




Validation Loss: 0.11226145531206082


100%|██████████| 12/12 [00:22<00:00,  1.90s/it]


Epoch 35, Loss: 1.9174537707537336e-05




Validation Loss: 0.11225865665862027


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


Epoch 36, Loss: 1.7082805091680104e-05




Validation Loss: 0.11225633654976264


100%|██████████| 12/12 [00:22<00:00,  1.90s/it]


Epoch 37, Loss: 1.5034790067147696e-05




Validation Loss: 0.11225430557775933


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


Epoch 38, Loss: 1.3205548232993655e-05




Validation Loss: 0.11225255928972426


100%|██████████| 12/12 [00:22<00:00,  1.91s/it]


Epoch 39, Loss: 1.1673809467538376e-05




Validation Loss: 0.11225102407236894


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


Epoch 40, Loss: 1.0461560729405997e-05




Validation Loss: 0.11224971535072352


100%|██████████| 12/12 [00:22<00:00,  1.91s/it]


Epoch 41, Loss: 9.081969134664783e-06




Validation Loss: 0.11224854811249922


100%|██████████| 12/12 [00:22<00:00,  1.91s/it]


Epoch 42, Loss: 8.1459690666937e-06




Validation Loss: 0.11224754327364887


100%|██████████| 12/12 [00:22<00:00,  1.90s/it]


Epoch 43, Loss: 7.390377542530284e-06




Validation Loss: 0.11224665634411697


100%|██████████| 12/12 [00:22<00:00,  1.91s/it]


Epoch 44, Loss: 6.459842931387054e-06




Validation Loss: 0.11224592873865429


100%|██████████| 12/12 [00:22<00:00,  1.90s/it]


Epoch 45, Loss: 5.76154468490131e-06




Validation Loss: 0.11224516715932016


100%|██████████| 12/12 [00:23<00:00,  1.93s/it]


Epoch 46, Loss: 5.119920084932043e-06




Validation Loss: 0.11224453965163168


100%|██████████| 12/12 [00:22<00:00,  1.90s/it]


Epoch 47, Loss: 4.582196197588928e-06




Validation Loss: 0.11224402050720528


100%|██████████| 12/12 [00:22<00:00,  1.89s/it]


Epoch 48, Loss: 4.184225531389529e-06




Validation Loss: 0.11224356074429427


100%|██████████| 12/12 [00:22<00:00,  1.91s/it]


Epoch 49, Loss: 3.7307388917421727e-06




Validation Loss: 0.11224313436347681


100%|██████████| 12/12 [00:22<00:00,  1.89s/it]


Epoch 50, Loss: 3.2423543151101817e-06
Validation Loss: 0.11224271159153432




Fold 3


100%|██████████| 18/18 [00:35<00:00,  1.96s/it]


Epoch 1, Loss: 0.06430631731119421




Validation Loss: 0.1823284768809875


100%|██████████| 18/18 [00:35<00:00,  1.98s/it]


Epoch 2, Loss: 0.022410357267492347




Validation Loss: 0.16764985180149475


100%|██████████| 18/18 [00:34<00:00,  1.94s/it]


Epoch 3, Loss: 0.012163541900614897




Validation Loss: 0.16069566582640013


100%|██████████| 18/18 [00:34<00:00,  1.94s/it]


Epoch 4, Loss: 0.007038390134564704




Validation Loss: 0.15701438890149197


 33%|███▎      | 6/18 [00:12<00:24,  2.04s/it]

### Plot Loss Across Folds

In [None]:
for fold, (train_losses, val_losses) in enumerate(zip(all_train_losses, all_val_losses)):
    plt.plot(train_losses, label=f'Train Loss Fold {fold + 1}')
    plt.plot(val_losses, label=f'Val Loss Fold {fold + 1}', linestyle='--')

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Across Folds')
plt.legend()
plt.grid(True)
plt.show()

# Train on full dataset

In [8]:
# Instantiate a new model for each fold
model = HybridModel(CNNFeatureExtractor())
model.apply(weights_init)
model.to(device)

# Set up the optimizer, scheduler, and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = nn.MSELoss()

# Train and evaluate the model
val_loss = train_and_evaluate(model, train_loader, val_loader, optimizer, scheduler, criterion, epochs, device)

torch.save(model.state_dict(), "./trained-full-dataset-yield-density-no-leakage.pt")


# of samples - Training   - 510
# of samples - Validation - 128


  evi_sequence = torch.tensor(evi_sequence, dtype=torch.float32).unsqueeze(1)
100%|██████████| 128/128 [01:03<00:00,  2.01it/s]


Epoch 1, Loss: 0.02303801325973609




Validation Loss: 0.13716325513087213


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 2, Loss: 0.00021912637054555262




Validation Loss: 0.13683857419528067


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 3, Loss: 4.714069107225605e-05




Validation Loss: 0.13679931915248744


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 4, Loss: 1.2139776693148585e-05




Validation Loss: 0.13678972469642758


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 5, Loss: 2.449644118566064e-06




Validation Loss: 0.13678752846317366


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 6, Loss: 7.677901634311471e-07




Validation Loss: 0.13678681483725086


100%|██████████| 128/128 [01:00<00:00,  2.10it/s]


Epoch 7, Loss: 3.032403498042413e-07




Validation Loss: 0.13678655843250453


100%|██████████| 128/128 [01:01<00:00,  2.10it/s]


Epoch 8, Loss: 1.6920563471189976e-07




Validation Loss: 0.13678646262269467


100%|██████████| 128/128 [01:00<00:00,  2.10it/s]


Epoch 9, Loss: 1.0260573640036297e-07




Validation Loss: 0.13678642455488443


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 10, Loss: 5.5787742407435725e-08




Validation Loss: 0.1367864032217767


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 11, Loss: 2.9276606201833477e-08




Validation Loss: 0.13678638232522644


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 12, Loss: 1.2652823795644697e-08




Validation Loss: 0.1367863641353324


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 13, Loss: 4.771049414917239e-09




Validation Loss: 0.13678636751137674


100%|██████████| 128/128 [01:00<00:00,  2.12it/s]


Epoch 14, Loss: 1.5382479281230571e-09




Validation Loss: 0.13678635086398572


100%|██████████| 128/128 [01:00<00:00,  2.12it/s]


Epoch 15, Loss: 4.786846412416397e-10




Validation Loss: 0.13678636238910258


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 16, Loss: 1.3792824101279723e-10




Validation Loss: 0.1367863685300108


100%|██████████| 128/128 [01:00<00:00,  2.12it/s]


Epoch 17, Loss: 3.7930990881324296e-11




Validation Loss: 0.1367863482737448


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 18, Loss: 1.030097532846283e-11




Validation Loss: 0.13678636791883036


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 19, Loss: 3.4419171131595274e-12




Validation Loss: 0.1367863736813888


100%|██████████| 128/128 [01:00<00:00,  2.11it/s]


Epoch 20, Loss: 1.862709280804205e-12




Validation Loss: 0.13678635843098164


100%|██████████| 128/128 [01:00<00:00,  2.12it/s]


Epoch 21, Loss: 1.5365051628925621e-12




Validation Loss: 0.1367863556370139


100%|██████████| 128/128 [01:00<00:00,  2.12it/s]


Epoch 22, Loss: 1.483005749043814e-12
Validation Loss: 0.1367863569757901
Early stopping!




In [None]:
# Loss Curve
# Assuming you have these lists from your training process
train_losses = []  # Fill this with training loss for each epoch
val_losses = []    # Fill this with validation loss for each epoch

# Plotting the loss curves
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss over Epochs')
plt.legend()
plt.grid(True)
plt.show()

# Predicted vs. Actual Values
predicted_values = []  # Fill this with predicted values
actual_values = []     # Fill this with actual values

# Plotting predicted vs actual values
plt.figure(figsize=(10, 5))
plt.scatter(actual_values, predicted_values, alpha=0.5)
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Predicted vs. Actual Values')
plt.grid(True)
plt.plot([min(actual_values), max(actual_values)], [min(actual_values), max(actual_values)], 'r')  # Line y=x
plt.show()

# Residuals Plot
residuals = np.array(actual_values) - np.array(predicted_values)

# Plotting the residuals
plt.figure(figsize=(10, 5))
plt.scatter(range(len(residuals)), residuals, alpha=0.5)
plt.axhline(y=0, color='r', linestyle='-')
plt.xlabel('Index')
plt.ylabel('Residual')
plt.title('Residuals of Predictions')
plt.grid(True)
plt.show()

# Inference

In [9]:
import joblib

# load in model from file
# inf_model_weights = torch.load("trained-full-dataset.pt", weights_only=True)
inf_model_weights = torch.load("trained-full-dataset-yield-density-no-leakage.pt", weights_only=True)
inf_model = HybridModel(CNNFeatureExtractor())
inf_model.load_state_dict(inf_model_weights)
inf_model.to(device)
inf_model.eval()

scaler = joblib.load("yield_scaler.save")

In [10]:

# inf_output = inf_model(evi_val, time_features_val)

# print(f"{evi_val.shape = }")
# print(f"{time_features_val.shape = }")
# print(f"{inf_output.shape = }")

In [11]:
yield_data_weekly.iloc[0].name

Timestamp('2012-03-04 00:00:00')

In [12]:

evi_data_dir = "./landsat_evi_monterey_masked"
dataset_loader, _, mean, std = prepare_dataset(evi_data_dir, yield_data_weekly, target_shape, augment=True, full=True)

Processed file 1/83 in 4.125449s
Processed file 2/83 in 4.246559s
Processed file 3/83 in 4.898093s
Processed file 4/83 in 4.672359s
Processed file 5/83 in 5.460909s
Processed file 6/83 in 4.038937s
Processed file 7/83 in 3.490030s
Processed file 8/83 in 4.127463s
Processed file 9/83 in 3.800272s
Processed file 10/83 in 3.523895s
Processed file 11/83 in 3.932964s
Processed file 12/83 in 3.734574s
Processed file 13/83 in 4.148887s
Processed file 14/83 in 4.548136s
Processed file 15/83 in 3.748950s
Processed file 16/83 in 3.109021s
Processed file 17/83 in 4.015578s
Processed file 18/83 in 4.144617s
Processed file 19/83 in 3.917441s
Processed file 20/83 in 4.372727s
Processed file 21/83 in 4.537723s
Processed file 22/83 in 3.373776s
Processed file 23/83 in 3.701019s
Processed file 24/83 in 3.277262s
Processed file 25/83 in 3.982489s
Processed file 26/83 in 4.063423s
Processed file 27/83 in 3.969857s
Processed file 28/83 in 3.823087s
Processed file 29/83 in 4.417579s
Processed file 30/83 in

In [13]:
timestamps = torch.Tensor()
yield_labels = torch.Tensor()
predictions = torch.Tensor()

for idx, (inputs, labels, time_features, timestamp) in enumerate(dataset_loader):
    print(f"Running inference... {idx/len(dataset_loader)*100:.2f}%", end='\r')
    inputs, labels, time_features = inputs.to(device), labels.to(device), time_features.to(device)
    outputs = inf_model(inputs, time_features)
    summed_outputs = outputs.sum(dim=(1,2))

    if idx >0:
        break
    timestamps = torch.cat((timestamps, timestamp))
    yield_labels = torch.cat((yield_labels, labels.to("cpu")))
    predictions = torch.cat((predictions, summed_outputs.to("cpu")))

    # loss = criterion(outputs, labels)
    # val_loss += loss.item()

# val_loss /= len(val_loader)
# print(f'Validation Loss: {val_loss}')

Running inference... 0.62%

In [14]:
yield_labels.reshape(-1,1)

tensor([0.4385, 0.8434, 0.4935, 0.0000])

In [19]:
scaler.inverse_transform(yield_labels.reshape(-1, 1))

array([[20429036.3072927 ],
       [39287413.27969694],
       [22990931.39602876],
       [       0.        ]])

In [22]:
scaler.inverse_transform(predictions.detach().numpy().reshape(-1,1))

array([[1023663.94],
       [1020795.2 ],
       [1033442.56],
       [ 977205.5 ]], dtype=float32)

In [23]:
yield_labels

tensor([0.4385, 0.8434, 0.4935, 0.0000])

In [24]:
yield_data_weekly

Unnamed: 0_level_0,Volume (Pounds),Cumulative Volumne (Pounds),Pounds/Acre,month_sin,month_cos,day_of_year_sin,day_of_year_cos
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
2012-03-04,0.011286,1785843.0,18.333333,1.000000e+00,6.123234e-17,0.891981,0.452072
2012-03-11,0.063317,4735377.0,51.666667,1.000000e+00,6.123234e-17,0.939856,0.341571
2012-03-18,0.102446,9507645.0,83.500000,1.000000e+00,6.123234e-17,0.974100,0.226116
2012-03-25,0.067456,12649959.0,55.000000,1.000000e+00,6.123234e-17,0.994218,0.107381
2012-04-01,0.134627,18921357.0,93.857143,8.660254e-01,-5.000000e-01,0.999917,-0.012910
...,...,...,...,...,...,...,...
2024-05-12,0.767907,682790517.0,305.285714,5.000000e-01,-8.660254e-01,0.752667,-0.658402
2024-05-19,0.787426,682790517.0,365.166667,5.000000e-01,-8.660254e-01,0.668064,-0.744104
2024-05-26,0.827681,682790517.0,329.285714,5.000000e-01,-8.660254e-01,0.573772,-0.819015
2024-06-02,0.796377,682790517.0,316.571429,1.224647e-16,-1.000000e+00,0.471160,-0.882048


In [None]:
timestamps, yield_labels, predictions

In [None]:
yield_data_weekly

In [None]:
out_df = pd.DataFrame(data={"timestamp":timestamps.to_numpy(), "prediction":predictions.to_numpy(), "truth":yield_labels.to_numpy()})
out_df.to_csv("out.csv")