In [9]:
import collections
import datetime as dt
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm

# Compute device shared by the whole script: CUDA when PyTorch can see a GPU, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

class GalaxyDataset(Dataset):
    """Image dataset over a flat directory of galaxy pictures.

    Every file directly inside ``root_dir`` whose name ends in ``.png``,
    ``.jpg`` or ``.jpeg`` (case-sensitive) is one sample. Images are converted
    to RGB, resized to 128x128 and returned as float32 tensors in
    channel-first (C, H, W) layout with values in [0, 1].

    Labels are parsed from the filename: the integer after the last underscore
    of the stem, e.g. ``image_10000_6.png`` -> 6. Files without such a suffix
    get the label -1.

    Args:
        root_dir: Directory containing the image files.
        labeled: If True, ``__getitem__`` returns ``(image, label, filename)``;
            otherwise ``(image, filename)``.

    NOTE(review): ``image_files`` keeps ``os.listdir`` order, which depends on
    the filesystem. A seeded ``random_split`` over this dataset is therefore
    only reproducible for the same directory listing order.
    """

    # Target size for every image; four 2x poolings down to 8x8 in GalaxyCNN
    # mean the network expects exactly 128x128 inputs.
    IMAGE_SIZE = (128, 128)
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # A trailing "_<digits>" at the end of the filename stem carries the class label.
    _LABEL_PATTERN = re.compile(r'_(\d+)$')

    def __init__(self, root_dir, labeled=True):
        self.root_dir = root_dir
        self.labeled = labeled
        self.image_files = [f for f in os.listdir(root_dir) if f.endswith(self._IMAGE_EXTENSIONS)]

    def extract_label_from_filename(self, filename):
        """Return the integer label encoded at the end of *filename*'s stem, or -1 if absent."""
        stem = os.path.splitext(filename)[0]
        match = self._LABEL_PATTERN.search(stem)
        return int(match.group(1)) if match else -1

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.root_dir, img_name)

        # The context manager closes the underlying file as soon as the pixels
        # are decoded, instead of leaving it to Pillow's lazy loading / GC.
        with Image.open(img_path) as img:
            image = img.convert("RGB").resize(self.IMAGE_SIZE)

        image_array = np.array(image) / 255.0
        # HWC in [0, 1] -> CHW float32 tensor.
        image_tensor = torch.from_numpy(image_array).float().permute(2, 0, 1)

        if self.labeled:
            return image_tensor, self.extract_label_from_filename(img_name), img_name
        return image_tensor, img_name

class GalaxyCNN(nn.Module):
    """VGG-style convolutional classifier for 3x128x128 galaxy images.

    ``features`` stacks Conv3x3 -> BatchNorm -> ReLU units with four 2x2 max
    pools (channels 64, 128, 256, 512), taking a 128x128 input down to a
    512x8x8 map. ``classifier`` is a dropout-regularised MLP
    (32768 -> 2048 -> 512 -> num_classes) that returns raw logits.

    The module indices inside both ``nn.Sequential`` containers define the
    ``state_dict`` keys, so the layer order here must not change or saved
    checkpoints will no longer load.
    """

    def __init__(self, num_classes):
        super().__init__()

        def conv_unit(c_in, c_out):
            # One 3x3 same-padding convolution followed by batch norm and ReLU.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        layers = []
        # Three stages of two conv units each, every stage halving H and W.
        for c_in, c_out in ((3, 64), (64, 128), (128, 256)):
            layers += conv_unit(c_in, c_out)
            layers += conv_unit(c_out, c_out)
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # Final stage: a single conv unit to 512 channels, then the fourth pool.
        layers += conv_unit(256, 512)
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.features = nn.Sequential(*layers)

        flat_dim = 512 * 8 * 8
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(flat_dim, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        feature_map = self.features(x)
        return self.classifier(torch.flatten(feature_map, 1))

def load_model(model_path, num_classes):
    """Build a GalaxyCNN, load weights from *model_path*, and return it ready for inference.

    The file is passed to ``load_state_dict``, so it must hold a ``state_dict``
    from a GalaxyCNN with the same ``num_classes``. Tensors are mapped onto the
    module-level ``device`` while loading. The returned model lives on
    ``device`` and is in eval mode.
    """
    network = GalaxyCNN(num_classes=num_classes)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints; on torch >= 1.13, weights_only=True restricts what it can load.
    state = torch.load(model_path, map_location=device)
    network.load_state_dict(state)
    return network.to(device).eval()

def make_predictions(model, dataloader, labeled=True):
    """Run *model* over every batch of *dataloader* and collect its class predictions.

    Inference runs under ``torch.no_grad()``. The model is not switched to eval
    mode here; callers are expected to have done that (``load_model`` does).

    Args:
        model: Classifier returning one row of logits per input.
        dataloader: Yields ``(inputs, labels, names)`` batches when *labeled*
            is True, otherwise ``(inputs, names)``.
        labeled: Whether the batches carry ground-truth labels.

    Returns:
        ``(predictions, true_labels, filenames)`` if *labeled*, else
        ``(predictions, filenames)``; each prediction is the highest-scoring
        class index of its row.
    """
    pred_list, name_list, label_list = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Making Predictions"):
            if labeled:
                inputs, labels, img_names = batch
                label_list.extend(labels.numpy())
            else:
                inputs, img_names = batch
            name_list.extend(img_names)

            logits = model(inputs.to(device))
            _, top_class = torch.max(logits, 1)
            pred_list.extend(top_class.cpu().numpy())

    if not labeled:
        return pred_list, name_list
    return pred_list, label_list, name_list

def print_validation_accuracy(true_labels, predictions):
    """Print a framed accuracy banner for aligned label/prediction sequences.

    Despite the name it works for any labelled split; the script also calls it
    on the training set.

    Returns:
        The accuracy reported by ``sklearn.metrics.accuracy_score`` (a fraction).
    """
    rule = "=" * 50
    accuracy = accuracy_score(true_labels, predictions)
    n_total = len(true_labels)
    n_correct = sum(1 for truth, guess in zip(true_labels, predictions) if truth == guess)

    print("\n" + rule)
    print(f"VALIDATION ACCURACY: {accuracy * 100:.2f}%")
    print(f"Correct predictions: {n_correct}/{n_total}")
    print(rule)

    return accuracy

def visualize_predictions(predictions, num_classes, title="Prediction Distribution"):
    """Save a per-class bar chart of *predictions* and print the class distribution.

    Args:
        predictions: Iterable-with-length of class ids (predicted or true labels).
        num_classes: Classes ``0 .. num_classes-1`` are plotted; other ids
            (e.g. -1) are not counted in any bar.
        title: Plot title. The PNG is written to the current directory as the
            lower-cased title with spaces replaced by underscores, e.g.
            ``"Model Predictions"`` -> ``model_predictions.png``.
    """
    counter = collections.Counter(predictions)
    classes = list(range(num_classes))
    counts = [counter.get(cls, 0) for cls in classes]
    out_path = f'{title.lower().replace(" ", "_")}.png'

    plt.figure(figsize=(12, 6))
    try:
        plt.bar(classes, counts)
        plt.xlabel('Class')
        plt.ylabel('Count')
        plt.title(title)
        plt.xticks(classes)
        plt.savefig(out_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()
    print(f"Distribution saved to '{out_path}'")

    print(f"\n{title} Summary:")
    total = len(predictions)
    for cls, count in zip(classes, counts):
        # An empty input would otherwise raise ZeroDivisionError here.
        share = count / total * 100 if total else 0.0
        print(f"Class {cls}: {count} images ({share:.2f}%)")

def evaluate_model(true_labels, predictions, num_classes, filenames=None, *,
                   cm_path="confusion_matrix.png",
                   errors_path="incorrect_predictions_final.txt"):
    """Print a full evaluation of *predictions* and save supporting artefacts.

    Prints overall accuracy, an sklearn classification report and per-class
    accuracy (recall), and saves a confusion-matrix heatmap to *cm_path*.

    Args:
        true_labels: Ground-truth class ids.
        predictions: Predicted class ids, aligned with ``true_labels``.
        num_classes: Report rows and matrix axes cover ``range(num_classes)``;
            labels outside that range (e.g. -1) are excluded from the matrix.
        filenames: Optional names aligned with the labels. When given, every
            misclassified sample is written to *errors_path*; the file is
            written only if there is at least one mistake.
        cm_path: Output path of the confusion-matrix PNG.
        errors_path: Output path of the misclassification list.

    NOTE: with the default paths, successive calls (e.g. training then
    validation) overwrite each other's files; pass distinct paths to keep both.
    """
    accuracy = accuracy_score(true_labels, predictions)
    print(f"\nModel Accuracy: {accuracy * 100:.2f}%")

    print("\nClassification Report:")
    report = classification_report(true_labels, predictions, labels=range(num_classes), zero_division=0)
    print(report)

    cm = confusion_matrix(true_labels, predictions, labels=range(num_classes))

    plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=range(num_classes), yticklabels=range(num_classes))
        plt.xlabel('Predicted Class')
        plt.ylabel('True Class')
        plt.title('Confusion Matrix')
        plt.savefig(cm_path)
    finally:
        plt.close()
    # Report the path that was actually written (previously this message named
    # a different file than the one saved).
    print(f"Confusion matrix saved to '{cm_path}'")

    print("\nPer-Class Accuracy:")
    for i in range(num_classes):
        # Row i of the confusion matrix holds every sample whose true class is i.
        class_correct = cm[i, i]
        class_total = np.sum(cm[i, :])
        if class_total > 0:
            class_accuracy = class_correct / class_total
            print(f"Class {i}: {class_accuracy * 100:.2f}% ({class_correct}/{class_total})")
        else:
            print(f"Class {i}: N/A (0 samples)")

    if filenames is not None:
        incorrect_predictions = [
            f"{fname}: Predicted {pred}, True {true}"
            for true, pred, fname in zip(true_labels, predictions, filenames)
            if true != pred
        ]
        if incorrect_predictions:
            with open(errors_path, "w", encoding="utf-8") as f:
                f.writelines(f"{line}\n" for line in incorrect_predictions)
            print(f"\nSaved {len(incorrect_predictions)} incorrect predictions to '{errors_path}'")

# Test-case report generation: a Markdown report with per-sample figures,
# built by create_test_cases_section and the private helpers below.

def _predict_single(model, image_tensor):
    """Classify one un-batched image tensor; return ``(predicted_label, probabilities)``.

    ``predicted_label`` is an ``int`` and ``probabilities`` is a 1-D NumPy array
    of softmax scores. The model is switched to eval mode first.
    """
    model.eval()
    with torch.no_grad():
        output = model(image_tensor.unsqueeze(0).to(device))
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]
        predicted_label = torch.argmax(output, dim=1).item()
    return predicted_label, probabilities.cpu().numpy()


def _save_test_case_figure(image_tensor, probs, filename, predicted_label,
                           true_label, cmap, out_path):
    """Save a two-panel PNG: the image on top, per-class probability bars below."""
    plt.figure(figsize=(12, 10))
    try:
        # Top panel: the input image, CHW -> HWC for imshow.
        plt.subplot(2, 1, 1)
        plt.imshow(image_tensor.permute(1, 2, 0).numpy())
        plt.title(f"Galaxy Image: {filename}")
        plt.axis('off')

        # Bottom panel: predicted class drawn in the "high" end of the colormap.
        plt.subplot(2, 1, 2)
        n_classes = len(probs)
        bar_colors = [cmap(0.8) if c == predicted_label else cmap(0.2) for c in range(n_classes)]
        bars = plt.bar(range(n_classes), probs, color=bar_colors)

        # Outline the true class. The lower bound matters: GalaxyDataset uses -1
        # for unparseable filenames, and bars[-1] would mark the last class.
        if 0 <= true_label < n_classes:
            bars[true_label].set_color('blue')
            bars[true_label].set_edgecolor('black')
            bars[true_label].set_linewidth(2)

        plt.xticks(range(n_classes))
        plt.xlabel("Galaxy Class")
        plt.ylabel("Probability")
        plt.title(f"Class Probabilities (Predicted: {predicted_label}, True: {true_label})")
        plt.savefig(out_path)
    finally:
        plt.close()


def _write_test_case_markdown(report_file, case_number, result, image_name):
    """Append the Markdown section for one test case (ASCII-only markers)."""
    filename = result["filename"]
    report_file.write(f"### Test Case {case_number}: {filename}\n\n")
    report_file.write(f"![Test Case {case_number}]({image_name})\n\n")
    report_file.write(f"- **Filename:** {filename}\n")
    report_file.write(f"- **True Class:** {result['true_label']}\n")
    report_file.write(f"- **Predicted Class:** {result['predicted_label']}\n")
    report_file.write(f"- **Correct:** {'YES' if result['accuracy'] == 1.0 else 'NO'}\n")
    report_file.write(f"- **Confidence:** {result['confidence']*100:.2f}%\n\n")
    report_file.write("| Class | Probability |\n")
    report_file.write("|-------|-------------|\n")
    for class_idx, prob in enumerate(result["all_probabilities"]):
        report_file.write(f"| {class_idx} | {prob*100:.2f}% |\n")
    report_file.write("\n---\n\n")


def _write_summary_markdown(report_file, test_results):
    """Append overall metrics and an overview table; return accuracy in percent.

    *test_results* must be non-empty.
    """
    n_cases = len(test_results)
    correct_count = sum(1 for result in test_results if result["accuracy"] == 1.0)
    overall_accuracy = correct_count / n_cases * 100
    avg_confidence = sum(result["confidence"] for result in test_results) / n_cases * 100

    report_file.write("## Summary of Test Results\n\n")
    report_file.write(f"- **Total Test Cases:** {n_cases}\n")
    report_file.write(f"- **Correctly Classified:** {correct_count}\n")
    report_file.write(f"- **Test Accuracy:** {overall_accuracy:.2f}%\n")
    report_file.write(f"- **Average Confidence:** {avg_confidence:.2f}%\n\n")

    report_file.write("## Test Cases Overview\n\n")
    report_file.write("| # | Image | True Class | Predicted | Correct | Confidence |\n")
    report_file.write("|---|-------|------------|-----------|---------|------------|\n")
    for case_number, result in enumerate(test_results, start=1):
        img_path = f"test_case_{case_number}.png"
        correct_symbol = "YES" if result["accuracy"] == 1.0 else "NO"
        report_file.write(
            f"| {case_number} | [{result['filename']}]({img_path}) | {result['true_label']} "
            f"| {result['predicted_label']} | {correct_symbol} | {result['confidence']*100:.2f}% |\n"
        )
    return overall_accuracy


def _export_results_csv(test_results, save_dir):
    """Write per-case results, one ``prob_class_<k>`` column per class, to a CSV.

    Skipped with a message when pandas is not installed.
    """
    try:
        import pandas as pd
    except ImportError:
        print("pandas not installed, skipping detailed CSV export")
        return

    df_detailed = pd.DataFrame(test_results)
    # Expand each probability vector into one column per class.
    prob_cols = pd.DataFrame([
        {f'prob_class_{c}': p for c, p in enumerate(result["all_probabilities"])}
        for result in test_results
    ])
    df_detailed = pd.concat([df_detailed.drop(columns='all_probabilities'), prob_cols], axis=1)

    csv_path = os.path.join(save_dir, "test_results_detailed.csv")
    df_detailed.to_csv(csv_path, index=False)
    print(f"Detailed test results saved to {csv_path}")


def create_test_cases_section(model, dataset, num_examples=10, save_dir="test_cases"):
    """
    Creates a test cases section with example images, predictions, and metrics.

    Randomly samples up to ``num_examples`` items (Python ``random``, unseeded)
    and writes into ``save_dir``: one ``test_case_<n>.png`` figure per sample,
    ``test_report.md`` (per-case sections, summary, overview table) and, when
    pandas is available, ``test_results_detailed.csv``.

    Args:
        model: Trained model
        dataset: Labelled dataset to sample from; items must be
            ``(image_tensor, label, filename)``
        num_examples: Number of examples to include
        save_dir: Directory to save test case images and results

    Returns:
        A list of per-case dicts with keys ``index``, ``filename``,
        ``true_label``, ``predicted_label``, ``accuracy`` (1.0 or 0.0),
        ``confidence`` and ``all_probabilities``. The list is empty, and no
        report is written, when the dataset has no samples.
    """
    print(f"\n{'='*60}")
    print("GENERATING TEST CASES AND RESULTS SECTION")
    print(f"{'='*60}")

    os.makedirs(save_dir, exist_ok=True)

    indices = random.sample(range(len(dataset)), min(num_examples, len(dataset)))
    if not indices:
        # Nothing to report; the summary metrics would divide by zero.
        print("No samples available; skipping test case generation")
        return []

    # Red -> yellow -> green ramp used to colour the probability bars.
    colors = [(0.6, 0.1, 0.1), (1.0, 1.0, 0.2), (0.1, 0.5, 0.1)]
    cmap = LinearSegmentedColormap.from_list("custom_cmap", colors, N=100)

    test_results = []
    with open(os.path.join(save_dir, "test_report.md"), "w", encoding="utf-8") as report_file:
        report_file.write("# Galaxy Classification Test Cases and Results\n\n")
        # `dt` is the module-level `datetime` alias; the bare name `datetime`
        # is not bound yet when this function first runs.
        report_file.write(f"Date: {dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        report_file.write("## Individual Test Cases\n\n")

        for case_number, idx in enumerate(indices, start=1):
            print(f"Processing test case {case_number}/{len(indices)}...")

            image_tensor, true_label, filename = dataset[idx]
            predicted_label, probs = _predict_single(model, image_tensor)

            result = {
                "index": idx,
                "filename": filename,
                "true_label": true_label,
                "predicted_label": predicted_label,
                # Per-example accuracy: 1.0 if correct, else 0.0.
                "accuracy": 1.0 if predicted_label == true_label else 0.0,
                "confidence": float(probs[predicted_label]),
                "all_probabilities": probs,
            }
            test_results.append(result)

            image_name = f"test_case_{case_number}.png"
            _save_test_case_figure(image_tensor, probs, filename, predicted_label,
                                   true_label, cmap, os.path.join(save_dir, image_name))
            _write_test_case_markdown(report_file, case_number, result, image_name)

        overall_accuracy = _write_summary_markdown(report_file, test_results)

    _export_results_csv(test_results, save_dir)

    print(f"\nTest cases and report generated successfully in '{save_dir}' directory")
    print(f"Overall test accuracy: {overall_accuracy:.2f}%")

    return test_results

def _load_model_or_exit(model_path, num_classes):
    """Load the checkpoint via ``load_model``; on any failure print the error and exit with status 1."""
    try:
        model = load_model(model_path, num_classes)
    except Exception as e:  # script boundary: report the problem and stop
        print(f"Error loading model: {e}")
        # SystemExit directly: the site-provided exit() helper is meant for the
        # interactive shell (it also closes sys.stdin before exiting).
        raise SystemExit(1)
    print(f"Model loaded successfully from {model_path}")
    return model


def _save_predictions_csv(filenames, true_labels, predictions, csv_path, description):
    """Write filename / true_class / predicted_class / correct rows to *csv_path*.

    *description* is the noun phrase used in the confirmation message. Skipped
    with a message when pandas is not installed.
    """
    try:
        import pandas as pd
    except ImportError:
        print("pandas not installed, skipping CSV export")
        return
    results_df = pd.DataFrame({
        "filename": filenames,
        "true_class": true_labels,
        "predicted_class": predictions,
        "correct": [p == t for p, t in zip(predictions, true_labels)],
    })
    results_df.to_csv(csv_path, index=False)
    # Print the path that was actually written.
    print(f"{description} saved to '{csv_path}'")


if __name__ == "__main__":
    data_dir = "C:\\Users\\Aseem\\Desktop\\BE Project\\Decals_data_images"
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Directory {data_dir} does not exist!")

    model_path = "best_model.pth"
    num_classes = 10
    validation_split = 0.2
    use_validation = True

    print("Setting up dataset with labels extracted from filenames")
    full_dataset = GalaxyDataset(root_dir=data_dir, labeled=True)

    # Spot-check that labels are parsed from the first few filenames as expected.
    for i in range(min(10, len(full_dataset))):
        _, label, filename = full_dataset[i]
        print(f"Sample {i}: Filename = {filename}, Extracted Label = {label}")

    if use_validation:
        dataset_size = len(full_dataset)
        val_size = int(validation_split * dataset_size)
        train_size = dataset_size - val_size

        # Fixed seed so the same files land in the validation split on every run
        # (given the same directory listing order).
        train_dataset, val_dataset = random_split(
            full_dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(42)
        )

        print(f"Dataset split into {train_size} training and {val_size} validation samples")

        # Evaluation only: no shuffling, so output files come out in a stable order.
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

        model = _load_model_or_exit(model_path, num_classes)

        print("\nEvaluating model on training set...")
        train_predictions, train_true_labels, train_filenames = make_predictions(model, train_loader, labeled=True)
        print_validation_accuracy(train_true_labels, train_predictions)
        evaluate_model(train_true_labels, train_predictions, num_classes, train_filenames)

        print("\nEvaluating model on validation set...")
        val_predictions, val_true_labels, val_filenames = make_predictions(model, val_loader, labeled=True)

        val_accuracy = print_validation_accuracy(val_true_labels, val_predictions)
        val_correct = sum(1 for t, p in zip(val_true_labels, val_predictions) if t == p)

        with open("validation_accuracy.txt", "w") as f:
            f.write(f"Validation Accuracy: {val_accuracy * 100:.2f}%\n")
            f.write(f"Correct: {val_correct}\n")
            f.write(f"Total: {len(val_true_labels)}\n")
            f.write(f"Date: {dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

        evaluate_model(val_true_labels, val_predictions, num_classes, val_filenames)

        visualize_predictions(train_predictions, num_classes, "Training Set Predictions")
        visualize_predictions(train_true_labels, num_classes, "Training Set True Labels")
        visualize_predictions(val_predictions, num_classes, "Validation Set Predictions")
        visualize_predictions(val_true_labels, num_classes, "Validation Set True Labels")

        print("\nGenerating test cases from validation set...")
        # random_split already returns a Subset of full_dataset over the
        # validation indices, so it can be sampled directly.
        test_results = create_test_cases_section(model, val_dataset, num_examples=15,
                                                 save_dir="test_cases_report")

        all_predictions = train_predictions + val_predictions
        all_true_labels = train_true_labels + val_true_labels
        all_filenames = train_filenames + val_filenames
        _save_predictions_csv(all_filenames, all_true_labels, all_predictions,
                              "all_predictions_with_accuracy_final.csv", "All predictions")

    else:
        dataloader = DataLoader(full_dataset, batch_size=32, shuffle=False)

        print(f"Found {len(full_dataset)} images for classification")
        model = _load_model_or_exit(model_path, num_classes)

        predictions, true_labels, filenames = make_predictions(model, dataloader, labeled=True)

        print_validation_accuracy(true_labels, predictions)

        evaluate_model(true_labels, predictions, num_classes, filenames)
        visualize_predictions(predictions, num_classes, "Model Predictions")
        visualize_predictions(true_labels, num_classes, "True Labels")

        print("\nGenerating test cases...")
        test_results = create_test_cases_section(model, full_dataset, num_examples=15,
                                                 save_dir="test_cases_report")

        _save_predictions_csv(filenames, true_labels, predictions,
                              "predictions_with_accuracy.csv", "Predictions")

import datetime

Using device: cuda
Setting up dataset with labels extracted from filenames
Sample 0: Filename = image_0_0.png, Extracted Label = 0
Sample 1: Filename = image_10000_6.png, Extracted Label = 6
Sample 2: Filename = image_10001_6.png, Extracted Label = 6
Sample 3: Filename = image_10002_6.png, Extracted Label = 6
Sample 4: Filename = image_10003_6.png, Extracted Label = 6
Sample 5: Filename = image_10004_6.png, Extracted Label = 6
Sample 6: Filename = image_10005_6.png, Extracted Label = 6
Sample 7: Filename = image_10006_6.png, Extracted Label = 6
Sample 8: Filename = image_10007_6.png, Extracted Label = 6
Sample 9: Filename = image_10008_6.png, Extracted Label = 6
Dataset split into 14189 training and 3547 validation samples
Model loaded successfully from best_model.pth

Evaluating model on training set...


Making Predictions: 100%|██████████| 444/444 [01:26<00:00,  5.15it/s]



VALIDATION ACCURACY: 90.77%
Correct predictions: 12879/14189

Model Accuracy: 90.77%

Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.70      0.75       888
           1       0.95      0.97      0.96      1515
           2       0.94      0.97      0.95      2102
           3       0.93      0.97      0.95      1613
           4       0.76      1.00      0.87       263
           5       0.86      0.94      0.89      1624
           6       0.86      0.86      0.86      1484
           7       0.89      0.74      0.81      2069
           8       0.95      0.99      0.97      1134
           9       0.97      0.97      0.97      1497

    accuracy                           0.91     14189
   macro avg       0.89      0.91      0.90     14189
weighted avg       0.91      0.91      0.91     14189

Confusion matrix saved to 'confusion_matrix_final.png'

Per-Class Accuracy:
Class 0: 70.50% (626/888)
Class 1: 96.96% (1469/1515)
C

Making Predictions: 100%|██████████| 111/111 [00:22<00:00,  4.85it/s]



VALIDATION ACCURACY: 90.19%
Correct predictions: 3199/3547

Model Accuracy: 90.19%

Classification Report:
              precision    recall  f1-score   support

           0       0.75      0.69      0.72       193
           1       0.94      0.96      0.95       338
           2       0.94      0.97      0.95       543
           3       0.94      0.97      0.96       414
           4       0.77      1.00      0.87        71
           5       0.87      0.90      0.88       419
           6       0.82      0.90      0.86       345
           7       0.89      0.73      0.80       559
           8       0.95      0.98      0.97       289
           9       0.96      0.97      0.96       376

    accuracy                           0.90      3547
   macro avg       0.88      0.91      0.89      3547
weighted avg       0.90      0.90      0.90      3547

Confusion matrix saved to 'confusion_matrix_final.png'

Per-Class Accuracy:
Class 0: 68.91% (133/193)
Class 1: 95.86% (324/338)
Class