In [4]:
import os
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from model import WheatEarModel
import torch.nn.functional as F
from dataClass import WheatEarDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

In [5]:
def loadSplitData(dataPath, random_state=42):
    """Load the labelled CSV at ``dataPath`` and split it 80/10/10 into train/val/test.

    Args:
        dataPath: Path of the CSV file to read with ``pd.read_csv``.
        random_state: Seed passed to both ``train_test_split`` calls so the split
            is reproducible (defaults to the previously hardcoded 42).

    Returns:
        ``(train_df, val_df, test_df)``, each with a fresh 0..n-1 index.
    """
    # Load dataset from the path the caller supplied (previously ignored in favour
    # of a hardcoded file name).
    df = pd.read_csv(dataPath)

    # Train-Validation-Test Split (80%-10%-10%): hold out 20%, then halve the holdout.
    train_df, temp_df = train_test_split(df, test_size=0.2, random_state=random_state)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=random_state)

    # Reset index after splitting so downstream positional access (0..n-1) is valid.
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)
    test_df = test_df.reset_index(drop=True)

    # Print sizes
    print(f"Train Size: {len(train_df)}, Validation Size: {len(val_df)}, Test Size: {len(test_df)}")

    return train_df, val_df, test_df

def createLoader(train_df, val_df, test_df, batch_size=16, num_workers=4):
    """Wrap each split in a ``WheatEarDataset`` and a ``DataLoader``.

    Args:
        train_df, val_df, test_df: Split frames, passed straight to ``WheatEarDataset``.
        batch_size: Batch size used for all three loaders (default 16, as before).
        num_workers: Worker processes per loader (default 4, as before).

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader shuffles.
    """
    # Create dataset instances for each split
    train_dataset = WheatEarDataset(train_df)
    val_dataset = WheatEarDataset(val_df)
    test_dataset = WheatEarDataset(test_df)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # No shuffle for validation/testing so outputs line up with the frame order.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Check sizes
    print(f"Train Batches: {len(train_loader)}, Validation Batches: {len(val_loader)}, Test Batches: {len(test_loader)}")

    return train_loader, val_loader, test_loader

def _align_outputs(outputs, targets):
    """Return ``outputs`` reshaped to ``targets.shape`` when they differ only in layout.

    If the model emits ``(B, 1)`` while labels are ``(B,)`` (TODO confirm against
    WheatEarModel / WheatEarDataset), ``nn.MSELoss`` would broadcast to ``(B, B)``
    and compute a meaningless loss. Reshaping only happens when the element counts
    match, so genuinely incompatible shapes still reach the criterion unchanged.
    """
    if outputs.shape != targets.shape and outputs.numel() == targets.numel():
        return outputs.reshape(targets.shape)
    return outputs


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10,
                save_path="wheat_ear_model.pth", device=None, log_every=100):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Args:
        model: Module called as ``model(rgb_batch, dsm_batch)``.
        train_loader: Yields ``(rgb_batch, dsm_batch, label_batch)`` tuples.
        val_loader: Same tuple layout; evaluated after every epoch.
        criterion: Loss called as ``criterion(outputs, labels)``. Per-batch values are
            weighted by batch size, which yields the exact dataset mean for
            mean-reduction losses such as the default ``nn.MSELoss()``.
        optimizer: Optimizer already bound to ``model.parameters()``.
        num_epochs: Number of passes over ``train_loader``.
        save_path: File the best ``state_dict`` is written to via ``torch.save``.
        device: Target device; defaults to ``"cuda"`` when available, else ``"cpu"``.
        log_every: Print the batch loss every ``log_every`` batches; 0/None disables.

    Returns:
        The lowest epoch validation loss seen (``inf`` when ``num_epochs`` is 0).

    Raises:
        ValueError: If either loader yields no batches (epoch means would be undefined).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Fail fast instead of training a whole epoch and then dividing by zero.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")
    model.to(device)

    best_val_loss = float("inf")  # Track the best validation loss

    for epoch in range(num_epochs):
        model.train()  # Training mode
        train_loss_sum, train_count = 0.0, 0

        for batch_idx, (rgb_batch, dsm_batch, label_batch) in enumerate(train_loader):
            rgb_batch, dsm_batch, label_batch = rgb_batch.to(device), dsm_batch.to(device), label_batch.to(device)

            optimizer.zero_grad()
            outputs = _align_outputs(model(rgb_batch, dsm_batch), label_batch)
            loss = criterion(outputs, label_batch)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a short final batch does not skew the epoch mean.
            batch_loss = loss.item()
            batch_size = label_batch.size(0)
            train_loss_sum += batch_loss * batch_size
            train_count += batch_size

            if log_every and batch_idx % log_every == 0:
                print(f"Epoch {epoch+1}/{num_epochs} | Batch {batch_idx}/{len(train_loader)} | Loss: {batch_loss:.4f}")

        # Compute validation loss (without gradient updates)
        model.eval()  # Switch to evaluation mode
        val_loss_sum, val_count = 0.0, 0
        with torch.no_grad():
            for rgb_batch, dsm_batch, label_batch in val_loader:
                rgb_batch, dsm_batch, label_batch = rgb_batch.to(device), dsm_batch.to(device), label_batch.to(device)
                outputs = _align_outputs(model(rgb_batch, dsm_batch), label_batch)
                loss = criterion(outputs, label_batch)
                batch_size = label_batch.size(0)
                val_loss_sum += loss.item() * batch_size
                val_count += batch_size

        train_loss = train_loss_sum / train_count
        val_loss = val_loss_sum / val_count

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Save model if validation loss improves
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print(f"✅ Model saved with Val Loss: {best_val_loss:.4f}")

    return best_val_loss

def test_model(model, test_loader):
    model.load_state_dict(torch.load("best_wheat_ear_model.pth"))
    model.eval()
    predictions, actuals = [], []

    with torch.no_grad():
        for rgb_batch, dsm_batch, label_batch in test_loader:
            rgb_batch, dsm_batch = rgb_batch.to("cuda"), dsm_batch.to("cuda")
            outputs = model(rgb_batch, dsm_batch)
            predictions.extend(outputs.cpu().numpy().flatten())
            actuals.extend(label_batch.cpu().numpy().flatten())

    return predictions, actuals



In [6]:
# Read the labelled CSV, split it 80/10/10, then wrap each split in a DataLoader.
train_df, val_df, test_df = loadSplitData("RGB_DSM_totEarNum.csv")
train_loader, val_loader, test_loader = createLoader(
    train_df,
    val_df,
    test_df,
)

Train Size: 47840, Validation Size: 5980, Test Size: 5980
Train Batches: 2990, Validation Batches: 374, Test Batches: 374


In [7]:
# Seed torch's global RNG so weight initialisation and the shuffled train_loader
# order are repeatable across runs (the data split already uses random_state=42).
torch.manual_seed(42)

# Initialize model (GPU when available, otherwise CPU)
model = WheatEarModel().to("cuda" if torch.cuda.is_available() else "cpu")

# Loss function (MSE for regression)
criterion = nn.MSELoss()

# Adam optimizer; lr=1e-3 matches torch.optim.Adam's default learning rate
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, save_path="best_wheat_ear_model.pth")

Epoch 1/10 | Batch 0/2990 | Loss: 126572.0938
Epoch 1/10 | Batch 100/2990 | Loss: 5699.4736
Epoch 1/10 | Batch 200/2990 | Loss: 13809.6064
Epoch 1/10 | Batch 300/2990 | Loss: 5816.7788
Epoch 1/10 | Batch 400/2990 | Loss: 4677.2656
Epoch 1/10 | Batch 500/2990 | Loss: 4009.7156


KeyboardInterrupt: 

In [None]:
# Score the restored best checkpoint on the held-out test split.
preds, actuals = test_model(model, test_loader)

# Quick sanity check: the first ten predictions next to their labels.
for p, a in list(zip(preds, actuals))[:10]:
    print(f"Predicted: {p:.2f}, Actual: {a:.2f}")