In [3]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Put the parent of the working directory on sys.path so the project packages imported
# below (VisionTransformer, simpleCNN, ResNet, ...) can be resolved — presumably they live
# there. A notebook has no __file__, so the location is derived from os.getcwd().
# The membership check keeps re-runs of this cell from appending duplicate entries.
PARENT_DIR = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PARENT_DIR not in sys.path:
    sys.path.append(PARENT_DIR)
from VisionTransformer.VisionTransformerModel import ViTWheatModel
from simpleCNN.simpleCNNModel import WheatEarModel
from ResNet.ResNetModel import ResNetWheatModel
from EfficientNet.efficientNetModel import EfficientNetWheatModel
from CoAtNet.CoAtNetModel import CoAtNetWheatModel
from dataLoaderFunc import loadSplitData, createLoader
from modelTestFunc import evaluate_model, test_model

In [4]:
# Split the labelled CSV into train / validation / test frames, then wrap each in a loader.
# (The split and batch sizes shown in this cell's output appear to be printed by these helpers.)
DATA_CSV = "RGB_DSM_totEarNum.csv"
train_df, val_df, test_df = loadSplitData(DATA_CSV)
train_loader, val_loader, test_loader = createLoader(
    train_df,
    val_df,
    test_df,
)

# Choose the compute backend: Apple Metal (MPS) first, then NVIDIA CUDA, otherwise CPU.
# `device` is kept as a plain string, as before, because it is passed to `.to()` and to
# `evaluate_model` in the cells below.
if torch.backends.mps.is_available():
    device = "mps"  # ✅ Use Apple Metal (Mac M1/M2)
    # Pin the global default dtype to float32 — presumably because the MPS backend rejects
    # float64 tensors. torch.set_default_dtype replaces torch.set_default_tensor_type,
    # which is deprecated as of PyTorch 2.1.
    torch.set_default_dtype(torch.float32)
elif torch.cuda.is_available():
    device = "cuda"  # ✅ Use NVIDIA CUDA (Windows RTX 4060)
else:
    device = "cpu"  # ✅ Default to CPU if no GPU is available
print(f"✅ Using device: {device}")

Train Size: 47840, Validation Size: 5980, Test Size: 5980
Train Batches: 2990, Validation Batches: 374, Test Batches: 374
✅ Using device: cuda


In [None]:
# Restore the trained simple-CNN weights and score the model on the held-out test split.
# NOTE(review): this path is resolved relative to the kernel's working directory, while the
# import setup puts the *parent* of that directory on sys.path — confirm both assumptions hold.
SIMPLE_CNN_CHECKPOINT = "Model_Creation/totalEarsModel/simpleCNN/best_wheat_ear_model.pth"

simpleCNNModel = WheatEarModel()
# map_location remaps tensors that were saved on another device (e.g. a CUDA-trained
# checkpoint opened on an MPS/CPU machine) onto the device selected earlier.
# weights_only=True limits unpickling to tensors and plain containers, which is all a
# state_dict contains, so a tampered .pth file cannot execute arbitrary code.
simpleCNNModel.load_state_dict(
    torch.load(SIMPLE_CNN_CHECKPOINT, map_location=device, weights_only=True)
)
simpleCNNModel.to(device)
simpleCNNModel.eval()  # inference mode for layers such as dropout / batch-norm
evaluate_model(simpleCNNModel, test_loader, device, plot_predictions=True)

In [None]:
# Restore the trained Vision Transformer weights and score the model on the held-out test split.
# NOTE(review): this path is resolved relative to the kernel's working directory, while the
# import setup puts the *parent* of that directory on sys.path — confirm both assumptions hold.
VIT_CHECKPOINT = "Model_Creation/totalEarsModel/VisionTransformer/vit_wheat_model.pth"

vitModel = ViTWheatModel()
# map_location remaps tensors that were saved on another device (e.g. a CUDA-trained
# checkpoint opened on an MPS/CPU machine) onto the device selected earlier.
# weights_only=True limits unpickling to tensors and plain containers, which is all a
# state_dict contains, so a tampered .pth file cannot execute arbitrary code.
vitModel.load_state_dict(
    torch.load(VIT_CHECKPOINT, map_location=device, weights_only=True)
)
vitModel.to(device)
vitModel.eval()  # inference mode for layers such as dropout / batch-norm
evaluate_model(vitModel, test_loader, device, plot_predictions=True)

In [None]:
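# Disabled cell: per-sample spot check with `test_model`.
# NOTE(review): `model` is not defined anywhere in this notebook. Assign one of the loaded
# models first (e.g. `model = vitModel`). `evaluate_model` above is also passed `device`;
# check whether `test_model` needs it too before re-enabling this cell.
# # Run test
# preds, actuals = test_model(model, test_loader)

# # Print sample predictions
# for p, a in zip(preds[:10], actuals[:10]):
#     print(f"Predicted: {p:.2f}, Actual: {a:.2f}")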
