In [None]:
# ==================== PRECISION, RECALL, F1-SCORE ANALYSIS ====================
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from transformers import ViTImageProcessor, ViTForImageClassification
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Device setup: run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ==================== SETUP ====================
# Preprocessor loaded from the same hub checkpoint name as the model below,
# so test images are prepared the way that checkpoint expects.
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

def vit_transformers(image):
    """Turn one image into a ViT ``pixel_values`` tensor.

    Passed to ``ImageFolder`` as its ``transform``, so it is called once per
    sample and must return a single-sample tensor. ``squeeze(0)`` removes the
    leading batch dimension the processor's output carries (presumably of size
    1 for a single image — the DataLoader re-batches later).
    """
    encoded = processor(images=image, return_tensors="pt")
    pixel_values = encoded['pixel_values']
    return pixel_values.squeeze(0)

# Machine-specific locations; edit these two constants to run elsewhere.
TEST_DATA_DIR = r"C:\Users\Ahmed Pasha\OneDrive\Desktop\garbage\data\test"
CHECKPOINT_PATH = r"C:\Users\Ahmed Pasha\OneDrive\Desktop\garbage\model\best_vit_model.pth"

test_dataset = ImageFolder(
    root=TEST_DATA_DIR,
    transform=vit_transformers
)
# shuffle=False keeps evaluation in a deterministic dataset order.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# ImageFolder label id i corresponds to class_names[i].
class_names = test_dataset.classes

# Load model
# Build the ViT architecture with a classifier head sized for our classes;
# ignore_mismatched_sizes=True avoids an error when that head's shape differs
# from the pretrained checkpoint's head. All weights are then replaced by the
# fine-tuned checkpoint.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(class_names),
    ignore_mismatched_sizes=True
)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot run arbitrary code. This assumes the file holds a
# plain state_dict (which load_state_dict below requires anyway).
# NOTE(review): the weights_only argument needs torch >= 1.13.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Get predictions
def get_predictions(model, loader, device=None):
    """Run ``model`` over every batch in ``loader`` and collect class predictions.

    Args:
        model: A ``torch.nn.Module`` called as ``model(pixel_values=images)``.
            Its output must expose a ``logits`` tensor; the prediction is the
            argmax over dim 1.
        loader: Iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        device: Device the images are moved to before the forward pass.
            Defaults to the device of the model's first parameter, or the CPU
            if the model has no parameters.

    Returns:
        ``(predictions, labels)``: two 1-D int64 NumPy arrays aligned
        sample-by-sample in loader order. Both are empty when ``loader``
        yields no batches.

    The model is evaluated in eval mode, and its previous train/eval mode is
    restored afterwards, even if an exception is raised.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    batch_predictions = []
    batch_labels = []
    was_training = model.training
    model.eval()  # disables dropout etc. so predictions are deterministic
    try:
        with torch.no_grad():
            for images, labels in loader:
                outputs = model(pixel_values=images.to(device))
                batch_predictions.append(outputs.logits.argmax(dim=1).cpu().numpy())
                batch_labels.append(torch.as_tensor(labels).cpu().numpy())
    finally:
        model.train(was_training)  # hand the model back in the caller's mode

    if not batch_predictions:
        # np.concatenate rejects an empty list; keep the integer dtype.
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)
    return np.concatenate(batch_predictions), np.concatenate(batch_labels)

# Single pass over the test split: predicted and ground-truth label ids.
predictions, true_labels = get_predictions(model, test_loader)

# ==================== CALCULATE METRICS ====================
# Per-class scores (average=None). Listing every class id in `labels` makes the
# returned arrays have one entry per class, in class-id order, even for a class
# that never appears in the data.
precision, recall, f1, support = precision_recall_fscore_support(
    y_true=true_labels,
    y_pred=predictions,
    labels=range(len(class_names)),
    average=None,
)

# Overall metrics
# Macro averages: the unweighted mean of the per-class scores, so every class
# counts equally regardless of its support. They are labelled "(macro)" in the
# printout because micro- or support-weighted averages can differ noticeably on
# an imbalanced test set. Note that f1_macro is the mean of per-class F1 scores,
# not the harmonic mean of precision_macro and recall_macro.
precision_macro = np.mean(precision)
recall_macro = np.mean(recall)
f1_macro = np.mean(f1)
accuracy = accuracy_score(true_labels, predictions)

print("="*80)
print("OVERALL METRICS")
print("="*80)
print(f"Accuracy:          {accuracy*100:.2f}%")
print(f"Precision (macro): {precision_macro*100:.2f}%")
print(f"Recall (macro):    {recall_macro*100:.2f}%")
print(f"F1-Score (macro):  {f1_macro*100:.2f}%")

# ==================== METRICS TABLE ====================
# One row per class (row i <-> class_names[i]); column order is fixed by the
# header list below.
metrics_df = pd.DataFrame(
    dict(
        zip(
            ['Class', 'Precision', 'Recall', 'F1-Score', 'Support'],
            [class_names, precision, recall, f1, support],
        )
    )
)

print("\n" + "=" * 80)
print("PER-CLASS METRICS")
print("=" * 80)
print(metrics_df.to_string(index=False))

# Save to CSV in the current working directory.
metrics_df.to_csv('precision_recall_f1.csv', index=False)
print("\n✅ Metrics saved to 'precision_recall_f1.csv'")

# ==================== VISUALIZATION 1: Grouped Bar Chart ====================
# Three bars per class tick: precision shifted one bar-width left, recall
# centred, F1 shifted one bar-width right.
x = np.arange(len(class_names))
width = 0.25

fig, ax = plt.subplots(figsize=(14, 6))
bars1, bars2, bars3 = [
    ax.bar(x + offset * width, scores, width, label=series_name, color=series_color)
    for offset, scores, series_name, series_color in (
        (-1, precision, 'Precision', '#2E86AB'),
        (0, recall, 'Recall', '#A23B72'),
        (1, f1, 'F1-Score', '#F18F01'),
    )
]

ax.set_title('Precision, Recall, and F1-Score by Class', fontsize=14, fontweight='bold')
ax.set_xlabel('Class', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.set_ylim([0, 1.1])  # headroom above 1.0 so full-height bars are not clipped
ax.grid(axis='y', alpha=0.3)
ax.legend()

plt.tight_layout()
plt.savefig('precision_recall_f1_bars.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ Bar chart saved to 'precision_recall_f1_bars.png'")

# ==================== VISUALIZATION 2: Heatmap ====================
# Rows: Precision / Recall / F1-Score; columns: classes in class-id order.
metrics_matrix = np.vstack([precision, recall, f1])

fig, ax = plt.subplots(figsize=(12, 5))
sns.heatmap(
    metrics_matrix,
    ax=ax,
    annot=True,
    fmt='.3f',
    cmap='YlGnBu',
    xticklabels=class_names,
    yticklabels=['Precision', 'Recall', 'F1-Score'],
    cbar_kws={'label': 'Score'},
)
ax.set_title('Metrics Heatmap', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig('metrics_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ Heatmap saved to 'metrics_heatmap.png'")

# ==================== VISUALIZATION 3: Radar Chart ====================
from math import pi

fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))

# One spoke per class, evenly spaced around the full circle.
categories = class_names
N = len(categories)
angles = [spoke / float(N) * 2 * pi for spoke in range(N)]
angles.extend(angles[:1])  # repeat the first angle so each polygon closes

# Close every series the same way: append its first value at the end.
precision_plot, recall_plot, f1_plot = (
    list(series) + [series[0]] for series in (precision, recall, f1)
)

for values, series_name, series_color in (
    (precision_plot, 'Precision', '#2E86AB'),
    (recall_plot, 'Recall', '#A23B72'),
    (f1_plot, 'F1-Score', '#F18F01'),
):
    ax.plot(angles, values, 'o-', linewidth=2, label=series_name, color=series_color)
    ax.fill(angles, values, alpha=0.15, color=series_color)

ax.set_xticks(angles[:-1])  # drop the duplicated closing angle
ax.set_xticklabels(categories)
ax.set_ylim(0, 1)
ax.set_title('Metrics Radar Chart', fontsize=14, fontweight='bold', pad=20)
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))
ax.grid(True)

plt.tight_layout()
plt.savefig('metrics_radar.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ Radar chart saved to 'metrics_radar.png'")

# ==================== SUPPORT DISTRIBUTION ====================
# Number of true test samples per class, with the count written above each bar.
label_lift = max(support) * 0.01  # raise count labels by 1% of the tallest bar

plt.figure(figsize=(10, 6))
plt.bar(class_names, support, color='coral', edgecolor='black')
plt.title('Test Set Distribution', fontsize=14, fontweight='bold')
plt.xlabel('Class', fontsize=12)
plt.ylabel('Number of Samples', fontsize=12)
plt.xticks(rotation=45, ha='right')
# Categorical bars sit at x = 0..N-1 in class order, so the enumerate index
# lines each label up with its bar.
for bar_pos, count in enumerate(support):
    plt.text(bar_pos, count + label_lift, str(count), ha='center', va='bottom')
plt.tight_layout()
plt.savefig('test_distribution.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ Distribution plot saved to 'test_distribution.png'")