### Use correct version of Transformers

In [None]:
! pip uninstall -y transformers
! pip install --quiet transformers==4.48.3

### Import Libraries and Mount Drive

In [None]:
# Mount Google Drive under /content/drive so that paths rooted there
# (such as `root_dir` in the config cell) are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import pandas as pd
from datasets import load_dataset
from transformers import AutoModelForImageClassification, ViTHybridForImageClassification

### Initialize Test Config

In [None]:
# Root directory on the mounted Drive; `output_path` below is joined onto it
# to build the final report location.
root_dir = "/content/drive/Shareddrives/CS198-Drones/"

# `dataset_path` is the identifier handed to `load_dataset`.
# `output_path` is the Excel report file, relative to `root_dir`.
dataset_path = 'cvmil/rice-leaf-disease-augmented-v4'
output_path = '[TEST] Models/model_and_dataset_info.xlsx'

# Number of labels passed to `get_model_info` when loading each model
num_labels = 8

# Model identifiers to analyze; each entry is loaded with `from_pretrained`
# inside `get_model_info`. Names containing "hybrid" are loaded through
# `ViTHybridForImageClassification`, all others through
# `AutoModelForImageClassification`.
models = [
  "facebook/convnextv2-base-1k-224",
  "google/vit-hybrid-base-bit-384",
  "google/vit-base-patch16-224",
  "microsoft/swin-base-patch4-window7-224",
  "facebook/deit-base-patch16-224",
  "facebook/dinov2-base",
  "timm/vit_small_patch16_224.augreg_in21k",
  "microsoft/swin-tiny-patch4-window7-224",
  "facebook/deit-small-patch16-224",
  "facebook/convnextv2-tiny-1k-224",
  "apple/mobilevit-small",
  "timm/mobilevitv2_150.cvnets_in22k_ft_in1k",
  "google/efficientnet-b2",
  "timm/efficientvit_b1.r224_in1k",
  "timm/efficientvit_m4.r224_in1k",
  "timm/efficientformerv2_s2.snap_dist_in1k",
  "timm/efficientformer_l1.snap_dist_in1k"
]

### Get Model Metadata

In [11]:
def _model_size_mb(model, cache_dir, model_name):
    """Return the size of ``model``'s weights in MiB.

    The original lookup location
    (``<cache_dir>/<org>__<name>/pytorch_model.bin``) is honoured when that
    file exists. ``from_pretrained`` stores downloads in the Hub cache layout
    rather than at that path, so normally the file is absent; in that case
    the size is computed from the loaded parameters and buffers instead of
    being reported as missing.
    """
    bytes_per_mb = 1024 * 1024

    weight_file = os.path.join(cache_dir, model_name.replace("/", "__"), "pytorch_model.bin")
    if os.path.exists(weight_file):
        return os.path.getsize(weight_file) / bytes_per_mb

    # In-memory footprint: element count times bytes per element, summed over
    # every parameter and buffer of the loaded model.
    tensors = list(model.parameters()) + list(model.buffers())
    return sum(t.numel() * t.element_size() for t in tensors) / bytes_per_mb


def get_model_info(model_names, num_labels, cache_dir="./hf_models"):
    """Collect parameter counts and weight sizes for a list of models.

    Args:
        model_names: Iterable of model identifiers accepted by
            ``from_pretrained``.
        num_labels: Number of output labels to configure each model with.
        cache_dir: Directory used by ``from_pretrained`` to cache downloads.

    Returns:
        A ``pandas.DataFrame`` with the columns ``"Model"``,
        ``"Total Parameters"`` and ``"Model Size (MB)"``, one row per model
        that loaded successfully. The columns are present even when no model
        could be processed.
    """
    columns = ["Model", "Total Parameters", "Model Size (MB)"]
    model_data = []

    for model_name in model_names:
        print(f"\n🔍 Processing Model: {model_name}")
        try:
            ModelClass = ViTHybridForImageClassification if "hybrid" in model_name else AutoModelForImageClassification
            # `num_labels` and `cache_dir` must be keyword arguments: extra
            # positional arguments to `from_pretrained` are treated as
            # `*model_args` for the model constructor, not as these options.
            model = ModelClass.from_pretrained(
                model_name,
                num_labels=num_labels,
                cache_dir=cache_dir,
                ignore_mismatched_sizes=True,
            )

            total_params = sum(p.numel() for p in model.parameters())
            model_size_mb = _model_size_mb(model, cache_dir, model_name)
        except Exception as e:
            # Best-effort: report the failure and carry on with the other models.
            print(f"❌ Error processing {model_name}: {e}")
            continue

        model_data.append({
            "Model": model_name,
            "Total Parameters": total_params,
            "Model Size (MB)": model_size_mb,
        })

        # Drop the reference so the weights can be freed before the next load.
        del model

    return pd.DataFrame(model_data, columns=columns)

### Get Dataset Metadata

In [None]:
def get_dataset_info(dataset_path):
    """Count the samples per class in every split of a dataset.

    Args:
        dataset_path: Identifier or path passed to ``load_dataset``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``"Class"``,
        ``"Class Index"``, ``"Count"`` and ``"Split"``; one row per class
        that occurs in a split, with all splits concatenated. An empty frame
        with those columns is returned when the dataset has no splits.

    Notes:
        Each split is expected to have a ``"label"`` feature exposing
        ``.names`` (the index-to-name mapping).
    """
    print(f"\n🔍 Processing Dataset: {dataset_path}")
    dataset = load_dataset(dataset_path)

    columns = ["Class", "Class Index", "Count", "Split"]
    class_data = []

    for split, split_ds in dataset.items():
        # Read label names from the split itself, so a dataset without a
        # "train" split still works.
        label_names = split_ds.features["label"].names

        # Only the label column is needed for counting; selecting it first
        # avoids converting every other column to pandas.
        labels = split_ds.select_columns(["label"]).to_pandas()["label"]

        class_counts = labels.value_counts().reset_index()
        class_counts.columns = ["Class Index", "Count"]
        class_counts["Class"] = class_counts["Class Index"].map(lambda idx: label_names[idx])  # Map index to label
        class_counts["Split"] = split
        class_data.append(class_counts[columns])  # Reorder columns

    # `pd.concat` raises on an empty list, so handle "no splits" explicitly.
    if not class_data:
        return pd.DataFrame(columns=columns)

    # Merge all splits into one DataFrame
    return pd.concat(class_data, ignore_index=True)

### Runner Function

In [None]:
# Resolve the report location under the root directory and make sure its
# parent folder exists before anything is written.
output_file = os.path.join(root_dir, output_path)
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# Gather the two tables: per-model statistics and per-split class counts.
df_models = get_model_info(models, num_labels)
df_dataset = get_dataset_info(dataset_path)

# Write each table to its own sheet of a single Excel workbook.
with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
    for sheet_name, frame in (("Model Info", df_models), ("Dataset Info", df_dataset)):
        frame.to_excel(writer, sheet_name=sheet_name, index=False)

print(f"\n✅ Results saved to {output_file}")