In [1]:
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, ViTMAEModel
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, balanced_accuracy_score, matthews_corrcoef

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "facebook/vit-mae-base"

image_processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

model = ViTMAEModel.from_pretrained(model_name).to(device)


In [3]:
file_path = "/fs/ess/PAS2136/Hawaii-2025/beetles_intake/BeetlePalooza Data/Benchmarking/Beetlepalooza_beetles_image_only.csv"
df = pd.read_csv(file_path, sep="\t")
df = df[['ImageFilePath', 'ScientificName']]
df.head(2)

Unnamed: 0,ImageFilePath,ScientificName
0,/fs/ess/PAS2136/Rayeed/BeetlePalooza/individua...,Chlaenius aestivus
1,/fs/ess/PAS2136/Rayeed/BeetlePalooza/individua...,Chlaenius aestivus


In [4]:
def extract_features(image_path) :
    
    image = Image.open(image_path).convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt").to(device)
    
    with torch.no_grad():
        features = model(**inputs).last_hidden_state.mean(dim=1)
    
    return features.cpu().numpy()

In [5]:
X = np.vstack([extract_features(img) for img in tqdm(df["ImageFilePath"])])

le = LabelEncoder()

y = le.fit_transform(df["ScientificName"])

df_indices = df.index 

X_train, X_test, y_train, y_test, train_idx, test_idx = train_test_split(X, y, df_indices, test_size=0.2, random_state=42)

test_df = df.loc[test_idx].copy()

scaler = StandardScaler()

X_train_scaled = scaler.fit_transform(X_train)

X_test_scaled = scaler.transform(X_test)


100%|██████████| 11399/11399 [08:02<00:00, 23.65it/s]


In [6]:
seed = 99

models = {
    "NaiveBayes": GaussianNB(),
    "LogisticRegression": LogisticRegression(max_iter=100),
    "SVMLinear": SVC(kernel="linear"),
    "SVMPolynomial": SVC(kernel="poly", degree=4),
    "SVMRadialBasis": SVC(kernel="rbf", degree=4),
    "NearestNeighbor": KNeighborsClassifier(n_neighbors=5),
    "RandomForest": RandomForestClassifier(n_estimators=100, random_state=seed),    
    "MLP_Baseline": MLPClassifier(hidden_layer_sizes=(100,), activation='logistic', alpha=0.001, max_iter=300, random_state=seed)
}

predictions = {}

metrics = {}

for name, model in models.items():
    
    model.fit(X_train_scaled, y_train)
    preds = model.predict(X_test_scaled)
    predictions[name] = preds
    
    acc = accuracy_score(y_test, preds)
    prec = precision_score(y_test, preds, average="weighted")
    rec = recall_score(y_test, preds, average="weighted")
    f1 = f1_score(y_test, preds, average="weighted")
    bal_acc = balanced_accuracy_score(y_test, preds)
    mcc = matthews_corrcoef(y_test, preds)
    
    metrics[name] = {"Model": name, "Accuracy": acc, "Precision": prec, "Recall": rec, "F1-Score": f1, "Balanced Acc": bal_acc, "MCC": mcc}
    print(f"{name:<25} | Acc: {acc:.2%} | Prec: {prec:.2%} | Rec: {rec:.2%} | F1: {f1:.2%} | Bal Acc: {bal_acc:.2%} | MCC: {mcc:.4f}")


metrics_df = pd.DataFrame(metrics).T

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


NaiveBayes                | Acc: 40.75% | Prec: 53.96% | Rec: 40.75% | F1: 41.46% | Bal Acc: 36.68% | MCC: 0.3868


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


LogisticRegression        | Acc: 80.09% | Prec: 78.39% | Rec: 80.09% | F1: 78.87% | Bal Acc: 53.26% | MCC: 0.7867


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


SVMLinear                 | Acc: 82.24% | Prec: 81.23% | Rec: 82.24% | F1: 81.23% | Bal Acc: 58.29% | MCC: 0.8097


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


SVMPolynomial             | Acc: 49.25% | Prec: 56.06% | Rec: 49.25% | F1: 44.15% | Bal Acc: 19.74% | MCC: 0.4576


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


SVMRadialBasis            | Acc: 77.50% | Prec: 72.49% | Rec: 77.50% | F1: 73.27% | Bal Acc: 33.68% | MCC: 0.7582


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


NearestNeighbor           | Acc: 60.48% | Prec: 58.15% | Rec: 60.48% | F1: 57.18% | Bal Acc: 28.05% | MCC: 0.5741


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


RandomForest              | Acc: 54.25% | Prec: 49.86% | Rec: 54.25% | F1: 46.53% | Bal Acc: 15.60% | MCC: 0.5039
MLP_Baseline              | Acc: 83.38% | Prec: 81.78% | Rec: 83.38% | F1: 82.08% | Bal Acc: 58.48% | MCC: 0.8219


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [7]:
test_df = test_df.assign(**{f"Pred_{name}": le.inverse_transform(pred) for name, pred in predictions.items()})
test_df.head(2)

Unnamed: 0,ImageFilePath,ScientificName,Pred_NaiveBayes,Pred_LogisticRegression,Pred_SVMLinear,Pred_SVMPolynomial,Pred_SVMRadialBasis,Pred_NearestNeighbor,Pred_RandomForest,Pred_MLP_Baseline
1766,/fs/ess/PAS2136/Rayeed/BeetlePalooza/individua...,Synuchus impunctatus,,Cymindis neglecta,Synuchus impunctatus,Calathus advena,Calathus advena,Cymindis neglecta,Calathus advena,Calathus advena
2629,/fs/ess/PAS2136/Rayeed/BeetlePalooza/individua...,Agonum punctiforme,Discoderus parallelus,Calathus advena,Agonum punctiforme,Calathus advena,Calathus advena,Calathus advena,Calathus advena,Calathus advena


In [8]:
metrics_df

Unnamed: 0,Model,Accuracy,Precision,Recall,F1-Score,Balanced Acc,MCC
NaiveBayes,NaiveBayes,0.407456,0.539551,0.407456,0.414597,0.366765,0.386813
LogisticRegression,LogisticRegression,0.800877,0.78395,0.800877,0.788678,0.532554,0.786682
SVMLinear,SVMLinear,0.822368,0.81231,0.822368,0.812304,0.582895,0.809689
SVMPolynomial,SVMPolynomial,0.492544,0.560568,0.492544,0.441474,0.197424,0.457586
SVMRadialBasis,SVMRadialBasis,0.775,0.724875,0.775,0.732655,0.336838,0.758211
NearestNeighbor,NearestNeighbor,0.604825,0.581489,0.604825,0.571768,0.280451,0.574064
RandomForest,RandomForest,0.542544,0.498555,0.542544,0.46532,0.155977,0.503883
MLP_Baseline,MLP_Baseline,0.833772,0.817816,0.833772,0.820825,0.584762,0.821875


In [9]:
test_df.to_csv("/users/PAS2136/rayees/3. Benchmarking/BeetlePalooza/23.ViTMAE-linear-probing-species.csv", index=False)
metrics_df.to_csv("/users/PAS2136/rayees/3. Benchmarking/BeetlePalooza/23.ViTMAE-linear-probing-species-metrics.csv", index=False)