In [1]:
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Get current working directory instead of __file__
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
from simpleCNNModel import WheatEarModel
from dataLoaderFunc import loadSplitData, createLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, save_path="wheat_ear_model.pth", device=None):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free pass over ``val_loader``. Whenever the mean validation loss
    improves on the best value seen so far, ``model.state_dict()`` is written
    to ``save_path``.

    Args:
        model: Module called as ``model(rgb_batch, dsm_batch)``.
        train_loader: Sized iterable yielding ``(rgb, dsm, label)`` batches.
        val_loader: Sized iterable yielding ``(rgb, dsm, label)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of full passes over ``train_loader``.
        save_path: Destination file for the best ``state_dict``.
        device: Device to train on. When ``None`` the function picks CUDA,
            then Apple MPS, then CPU, in that order.

    Raises:
        ValueError: If either loader contains no batches (the per-epoch mean
            loss would otherwise divide by zero after a full wasted epoch).
    """
    if device is None:
        # Checking only CUDA would drag a model that the caller already put on
        # MPS back to the CPU, so MPS is considered before falling back.
        mps_backend = getattr(torch.backends, "mps", None)  # absent on old builds
        if torch.cuda.is_available():
            device = "cuda"
        elif mps_backend is not None and mps_backend.is_available():
            device = "mps"
        else:
            device = "cpu"

    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)
    if num_train_batches == 0 or num_val_batches == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    model.to(device)

    best_val_loss = float("inf")  # Track the best validation loss

    for epoch in range(num_epochs):
        model.train()  # Training mode
        train_loss = 0.0

        for batch_idx, (rgb_batch, dsm_batch, label_batch) in enumerate(train_loader):
            rgb_batch, dsm_batch, label_batch = rgb_batch.to(device), dsm_batch.to(device), label_batch.to(device)

            optimizer.zero_grad()
            outputs = model(rgb_batch, dsm_batch)
            loss = criterion(outputs, label_batch)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Periodic progress line so long epochs are not silent
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}/{num_epochs} | Batch {batch_idx}/{num_train_batches} | Loss: {loss.item():.4f}")

        # Compute validation loss (without gradient updates)
        model.eval()  # Switch to evaluation mode
        val_loss = 0.0
        with torch.no_grad():
            for rgb_batch, dsm_batch, label_batch in val_loader:
                rgb_batch, dsm_batch, label_batch = rgb_batch.to(device), dsm_batch.to(device), label_batch.to(device)
                outputs = model(rgb_batch, dsm_batch)
                loss = criterion(outputs, label_batch)
                val_loss += loss.item()

        # Report the mean of the per-batch losses
        train_loss /= num_train_batches
        val_loss /= num_val_batches

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # ✅ Save model if validation loss improves
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print(f"✅ Model saved with Val Loss: {best_val_loss:.4f}")

def test_model(model, test_loader):
    model.load_state_dict(torch.load("best_wheat_ear_model.pth"))
    model.eval()
    predictions, actuals = [], []

    with torch.no_grad():
        for rgb_batch, dsm_batch, label_batch in test_loader:
            rgb_batch, dsm_batch = rgb_batch.to("cuda"), dsm_batch.to("cuda")
            outputs = model(rgb_batch, dsm_batch)
            predictions.extend(outputs.cpu().numpy().flatten())
            actuals.extend(label_batch.cpu().numpy().flatten())

    return predictions, actuals


In [3]:
# Split the sample index CSV into train/validation/test frames, then wrap each
# frame in a loader. Both helpers come from dataLoaderFunc.
DATA_CSV = "RGB_DSM_totEarNum.csv"
train_df, val_df, test_df = loadSplitData(DATA_CSV)
loaders = createLoader(train_df, val_df, test_df)
train_loader, val_loader, test_loader = loaders

Train Size: 47840, Validation Size: 5980, Test Size: 5980
Train Batches: 2990, Validation Batches: 374, Test Batches: 374


In [4]:
# ✅ Universal device selection
mps_backend = getattr(torch.backends, "mps", None)  # attribute is absent on older PyTorch builds
if mps_backend is not None and mps_backend.is_available():
    device = "mps"
    # Keep float32 as the default dtype for newly created tensors.
    # torch.set_default_tensor_type is deprecated (it emits a warning on
    # recent PyTorch); set_default_dtype is the supported replacement.
    torch.set_default_dtype(torch.float32)
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"✅ Using device: {device}")
# Initialize model
model = WheatEarModel().to(device)

# Loss function (MSE for regression)
criterion = nn.MSELoss()

# Optimizer (Adam works well for deep learning)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

  _C._set_default_tensor_type(t)


✅ Using device: mps


In [5]:
# Train for ten epochs, keeping the best-validation checkpoint on disk
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
    save_path="best_wheat_ear_model.pth",
)

Epoch 1/10 | Batch 0/2990 | Loss: 101077.9844
Epoch 1/10 | Batch 100/2990 | Loss: 7128.2104
Epoch 1/10 | Batch 200/2990 | Loss: 10878.1709
Epoch 1/10 | Batch 300/2990 | Loss: 11422.9902
Epoch 1/10 | Batch 400/2990 | Loss: 9385.7588
Epoch 1/10 | Batch 500/2990 | Loss: 8159.3750
Epoch 1/10 | Batch 600/2990 | Loss: 9748.4355
Epoch 1/10 | Batch 700/2990 | Loss: 7014.3569
Epoch 1/10 | Batch 800/2990 | Loss: 11091.7275
Epoch 1/10 | Batch 900/2990 | Loss: 3227.2292
Epoch 1/10 | Batch 1000/2990 | Loss: 8621.7129
Epoch 1/10 | Batch 1100/2990 | Loss: 7272.8960
Epoch 1/10 | Batch 1200/2990 | Loss: 9088.9443
Epoch 1/10 | Batch 1300/2990 | Loss: 7125.2275
Epoch 1/10 | Batch 1400/2990 | Loss: 10909.8594
Epoch 1/10 | Batch 1500/2990 | Loss: 3400.9009
Epoch 1/10 | Batch 1600/2990 | Loss: 3269.1514
Epoch 1/10 | Batch 1700/2990 | Loss: 4018.1504
Epoch 1/10 | Batch 1800/2990 | Loss: 6750.3384
Epoch 1/10 | Batch 1900/2990 | Loss: 6059.2178
Epoch 1/10 | Batch 2000/2990 | Loss: 4347.7017
Epoch 1/10 | Batch 

: 

In [None]:
# Run the test loader through the model and gather outputs alongside labels
preds, actuals = test_model(model, test_loader)

# Show up to the first ten prediction / ground-truth pairs
n_shown = min(10, len(preds), len(actuals))
for i in range(n_shown):
    predicted, actual = preds[i], actuals[i]
    print(f"Predicted: {predicted:.2f}, Actual: {actual:.2f}")