In [4]:
import os
import sys
from tqdm import tqdm
import torch
import pandas as pd
import glob
from torch.utils.data import DataLoader
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error

base_path = os.path.abspath(os.path.join(os.getcwd(), ".."))  # parent of MetaWeightingModel
sys.path.append(base_path)
from TraitPredictionModel.ModelArchitecture.DenseNetModel import DenseNet121WheatModel
from TraitPredictionModel.ModelArchitecture.EfficientNetV2Model import EfficientNetV2SWheatCountWithConfidence
from TraitPredictionModel.ModelArchitecture.EfficientNetV2MModel import EfficientNetV2MWheatModelWithConfidence
from TraitPredictionModel.ModelArchitecture.RegNetY8GFModel import RegNetY8GFModel
from TraitPredictionModel.ModelArchitecture.EfficientNetV2MAddextrainputModel import EfficientNetV2MConfidenceAddeonextrainput
from TraitPredictionModel.ModelArchitecture.RegNetY8GFAddextrainputModel import RegNetY8GFConfidenceAddoneextrainput
from TraitPredictionModel.dataClass import WheatEarDataset

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def setDevice():
    """
    Pick the compute device to run on and return its name.

    Preference order is Apple Metal ("mps"), then NVIDIA CUDA ("cuda"),
    then plain "cpu". When MPS is selected, float32 is made the default
    floating-point dtype.

    Returns:
        str: "mps", "cuda" or "cpu".
    """
    if torch.backends.mps.is_available():
        device = "mps"  # Use Apple Metal (Mac M1/M2)
        # torch.set_default_tensor_type() is deprecated; set_default_dtype()
        # is the supported way to make float32 the default dtype.
        # NOTE(review): presumably needed because MPS lacks float64 support — confirm.
        torch.set_default_dtype(torch.float32)
    elif torch.cuda.is_available():
        device = "cuda"  # Use NVIDIA CUDA (Windows RTX 4060)
    else:
        device = "cpu"  # Default to CPU if no GPU is available
    print(f"Using device: {device}")
    return device

def loadFullData(dataPath):
    """
    Read the CSV at `dataPath` into a DataFrame and report its size.

    Args:
        dataPath: Path (or anything `pd.read_csv` accepts) of the CSV file.

    Returns:
        pd.DataFrame: The full, unfiltered contents of the file.
    """
    frame = pd.read_csv(dataPath)
    row_count, column_count = frame.shape
    print(f"Loaded full dataset → Rows: {row_count}, Columns: {column_count}")
    return frame

def createLoader_full(df, traitName, extra_input_cols=None, batch_size=16, shuffle=False, num_workers=4):
    """
    Create a single DataLoader from the full DataFrame.

    Args:
        df (pd.DataFrame): Rows to wrap; passed unchanged to WheatEarDataset.
        traitName (str): Column used as the label (`label_col`).
        extra_input_cols: Forwarded to WheatEarDataset as `extra_input_cols`.
        batch_size (int): Samples per batch.
        shuffle (bool): Whether the loader shuffles samples (off by default).
        num_workers (int): Number of DataLoader worker processes. Defaults
            to 4 (the previously hard-coded value); pass 0 to load in the
            main process.

    Returns:
        DataLoader: Loader over the whole DataFrame.
    """
    dataset = WheatEarDataset(df, label_col=traitName, extra_input_cols=extra_input_cols)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    print(f"Full DataLoader created → Batches: {len(loader)}")
    return loader

def testModelForPrepareData(model, test_loader, device, traitName, output_csv):
    model.eval()
    preds, stds, targets, datakeys = [], [], [], []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            if len(batch) == 4:
                rgb_batch, dsm_batch, label_batch, datakey_batch = batch
            else:
                rgb_batch, dsm_batch, _, label_batch, datakey_batch = batch

            rgb_batch = rgb_batch.to(device)
            dsm_batch = dsm_batch.to(device)

            output = model(rgb_batch, dsm_batch)
            pred_mean = output[:, 0].cpu().numpy()
            pred_std = (torch.exp(0.5 * output[:, 1])).cpu().numpy()
            label_batch = label_batch.squeeze().cpu().numpy()

            preds.extend(pred_mean)
            stds.extend(pred_std)
            targets.extend(label_batch)
            datakeys.extend(datakey_batch)

    r2 = r2_score(targets, preds)
    mae = mean_absolute_error(targets, preds)
    rmse = root_mean_squared_error(targets, preds)

    print(f"\nTest Results:")
    print(f"R² Score : {r2:.4f}")
    print(f"MAE      : {mae:.4f}")
    print(f"RMSE     : {rmse:.4f}")

    df = pd.DataFrame({
        "DataKey": datakeys,
        "true_" + traitName : targets,
        "predicted_" + traitName : preds,
        "predicted_std_" + traitName : stds,
    })
    df.to_csv(output_csv, index=False)
    print(f"Saved predictions to: {output_csv}")

    return df, r2, mae, rmse

def predictModelFullData(dataPath, traitName, model, modelPath):
    '''
    Load the full dataset, restore a trained model and save its predictions.

    Args:
        dataPath: CSV file with the full dataset (read by loadFullData).
        traitName (str): Label column to predict; also used in the output file name.
        model: Model class; instantiated with no arguments.
        modelPath: Path of the saved state dict to load into the model.

    Side effects:
        Writes "predicted_<traitName>_<model.__name__>.csv" to the current
        working directory (via testModelForPrepareData).
    '''
    # get data
    full_df = loadFullData(dataPath)
    full_loader = createLoader_full(full_df, traitName)

    # set device
    device = setDevice()

    # load model
    loadedModel = model().to(device)
    # Always deserialize the checkpoint onto the CPU: without map_location,
    # torch.load tries to restore tensors onto the device they were saved
    # from, which fails when that device does not exist on this machine
    # (e.g. an MPS-trained checkpoint opened on a CUDA box).
    # load_state_dict then copies the values into the model's own device.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(modelPath, map_location=torch.device("cpu"))
    loadedModel.load_state_dict(state_dict)
    loadedModel.eval()

    print("traitName: ", traitName)
    print("model: ", model.__name__)

    saveFileName = "predicted_" + traitName + "_" + model.__name__ + ".csv"

    # Run test
    testModelForPrepareData(loadedModel, full_loader, device, traitName, saveFileName)

In [None]:
# Run each trained trait model over the full dataset, in the listed order.
# Every entry is (trait column, model class, saved weights path).
_full_data_csv = "DataKey_RGB_DSM_LAI_SPAD_Height_totalSeedNum_totalSeedWeightAfterDry_totEarNum_totEarWeight_day_Raw1.csv"
_trait_runs = [
    ("days", DenseNet121WheatModel, "./TraitModel/days_DenseNet121WheatModel_from3.pth"),
    ("Height", DenseNet121WheatModel, "./TraitModel/Height_DenseNet121WheatModel.pth"),
    ("LAI", EfficientNetV2SWheatCountWithConfidence, "./TraitModel/LAI_EfficientNetV2SWheatCountWithConfidence.pth"),
    ("SPAD", EfficientNetV2SWheatCountWithConfidence, "./TraitModel/SPAD_EfficientNetV2SWheatCountWithConfidence.pth"),
    ("totalSeedNum", DenseNet121WheatModel, "./TraitModel/totalSeedNum_DenseNet121WheatModel.pth"),
    ("totalSeedWeightAfterDry", DenseNet121WheatModel, "./TraitModel/totalSeedWeightAfterDry_DenseNet121WheatModel.pth"),
    ("totEarNum", DenseNet121WheatModel, "./TraitModel/totEarNum_DenseNet121WheatModel.pth"),
    ("totEarWeight", DenseNet121WheatModel, "./TraitModel/totEarWeight_DenseNet121WheatModel0.pth"),
]
for _trait, _model_cls, _weights_path in _trait_runs:
    predictModelFullData(_full_data_csv, _trait, _model_cls, _weights_path)

In [8]:
def merge_trait_predictions(csv_paths, output_path="merged_predictions.csv"):
    """
    Merge multiple prediction CSV files horizontally based on 'DataKey'.

    Files are joined side by side by row position, so every file must list
    the same 'DataKey' values in the same order as the first file. This is
    verified for each file; a mismatch raises instead of silently pairing
    predictions that belong to different samples.

    Args:
        csv_paths (list): List of CSV file paths to merge.
        output_path (str): File path to save the merged result.

    Returns:
        pd.DataFrame: 'DataKey' (from the first file) followed by the
        remaining columns of every file, in the order given.

    Raises:
        ValueError: If `csv_paths` is empty, or a file's 'DataKey' column
            differs from the first file's.
        KeyError: If a file after the first has no 'DataKey' column.
    """
    frames = []

    for path in csv_paths:
        df = pd.read_csv(path)

        if frames:
            # Guard the positional join: same keys, same order, same length.
            if not df["DataKey"].equals(frames[0]["DataKey"]):
                raise ValueError(
                    f"'DataKey' column of {path} does not match the first file; "
                    "refusing to merge misaligned rows"
                )
            # Drop duplicate 'DataKey' columns except from first file
            df = df.drop(columns=["DataKey"])

        frames.append(df)
        print(f"Loaded: {path} → Shape: {df.shape}")

    if not frames:
        raise ValueError("csv_paths must contain at least one CSV file")

    # Concatenate once rather than growing the frame inside the loop.
    merged_df = pd.concat(frames, axis=1)

    print(f"\nFinal merged shape: {merged_df.shape}")
    merged_df.to_csv(output_path, index=False)
    print(f"Saved merged CSV to: {output_path}")
    return merged_df

In [9]:
# Per-trait prediction files to merge, named "predicted_<trait>_<ModelClass>.csv"
# (the naming used by predictModelFullData), in the order their columns should
# appear in the merged table.
csv_files = [
    f"predicted_{trait}_{model_name}.csv"
    for trait, model_name in (
        ("days", "DenseNet121WheatModel"),
        ("Height", "DenseNet121WheatModel"),
        ("LAI", "EfficientNetV2SWheatCountWithConfidence"),
        ("SPAD", "EfficientNetV2SWheatCountWithConfidence"),
        ("totalSeedNum", "DenseNet121WheatModel"),
        ("totalSeedWeightAfterDry", "DenseNet121WheatModel"),
        ("totEarNum", "DenseNet121WheatModel"),
        ("totEarWeight", "DenseNet121WheatModel"),
    )
]

# Join them side by side into the meta-model input table.
merged_df = merge_trait_predictions(csv_files, output_path="meta_input_data.csv")


Loaded: predicted_days_DenseNet121WheatModel.csv → Shape: (26780, 4)
Loaded: predicted_Height_DenseNet121WheatModel.csv → Shape: (26780, 3)
Loaded: predicted_LAI_EfficientNetV2SWheatCountWithConfidence.csv → Shape: (26780, 3)
Loaded: predicted_SPAD_EfficientNetV2SWheatCountWithConfidence.csv → Shape: (26780, 3)
Loaded: predicted_totalSeedNum_DenseNet121WheatModel.csv → Shape: (26780, 3)
Loaded: predicted_totalSeedWeightAfterDry_DenseNet121WheatModel.csv → Shape: (26780, 3)
Loaded: predicted_totEarNum_DenseNet121WheatModel.csv → Shape: (26780, 3)
Loaded: predicted_totEarWeight_DenseNet121WheatModel.csv → Shape: (26780, 3)

Final merged shape: (26780, 25)
Saved merged CSV to: meta_input_data.csv
