Imports

In [1]:
# English: Imports for libraries
# 中文: 导入所需库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import time
from google.colab import drive # For Google Drive access
import matplotlib.pyplot as plt # For potential plotting later

Cell: Download, Extract, and Move CIFAR-10-C Data

In [12]:
# --- Configuration ---
# English: Define the target directory in your Google Drive where the .npy files should end up.
# 中文: 定义你 Google Drive 中存放 .npy 文件的目标目录路径。
# MAKE SURE THIS PATH IS EXACTLY WHERE YOU WANT THE DATA!
target_data_path = '/content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/Data/'

# English: Define the download URL for CIFAR-10-C .tar file
# 中文: 定义 CIFAR-10-C .tar 文件的下载链接
cifar_c_url = "https://zenodo.org/record/2535967/files/CIFAR-10-C.tar?download=1"
downloaded_tar_path = "/content/CIFAR-10-C.tar"
local_extract_folder = "/content/CIFAR-10-C_extracted" # Temporary local folder for extraction

# --- Step 1: Ensure target directory exists in Google Drive ---
# English: Create the target directory on Google Drive if it doesn't exist.
# 中文: 如果目标目录在 Google Drive 中不存在，则创建它。
import os
print(f"Ensuring target directory exists: {target_data_path}")
os.makedirs(target_data_path, exist_ok=True)

# --- Step 2: Download the .tar file to Colab temporary storage ---
print(f"\nDownloading CIFAR-10-C from {cifar_c_url}...")
# Use !wget command. -O specifies the output file path.
!wget "{cifar_c_url}" -O "{downloaded_tar_path}"
print("Download command executed.")

# --- Step 3: Verify Download ---
print("\nChecking if download was successful...")
if os.path.exists(downloaded_tar_path):
    print(f"File {downloaded_tar_path} downloaded successfully.")
    # --- Step 4: Extract the .tar file locally in Colab ---
    print(f"\nExtracting {downloaded_tar_path} locally...")
    # Create a temporary local directory for extraction
    os.makedirs(local_extract_folder, exist_ok=True)
    # Extract using tar command. xf = extract file, -C = change to directory.
    !tar xf "{downloaded_tar_path}" -C "{local_extract_folder}"
    print("Extraction command executed.")

    # --- Step 5: Identify the actual extracted content folder ---
    # tar usually extracts into a subfolder, typically named 'CIFAR-10-C'
    extracted_content_path = os.path.join(local_extract_folder, "CIFAR-10-C")
    if not os.path.isdir(extracted_content_path):
        # If the subfolder isn't 'CIFAR-10-C', try listing to find it
        possible_dirs = [d for d in os.listdir(local_extract_folder) if os.path.isdir(os.path.join(local_extract_folder, d))]
        if len(possible_dirs) == 1:
            extracted_content_path = os.path.join(local_extract_folder, possible_dirs[0])
            print(f"Found extracted content in: {extracted_content_path}")
        else:
            print(f"Error: Could not uniquely identify extracted content folder inside {local_extract_folder}. Found: {possible_dirs}")
            extracted_content_path = None

    if extracted_content_path and os.path.isdir(extracted_content_path):
        # --- Step 6: Move extracted files (.npy) to Google Drive target path ---
        # Use !mv command. The '*' moves all files/folders inside the source dir.
        # Using quotes around paths handles spaces.
        print(f"\nMoving extracted files from {extracted_content_path} to {target_data_path}...")
        # This step might take a while!
        !mv "{extracted_content_path}"/* "{target_data_path}"
        print("Move command executed.")

        # --- Step 7: Clean up downloaded tar and temporary folder ---
        print("\nCleaning up temporary files...")
        !rm "{downloaded_tar_path}"
        !rm -rf "{local_extract_folder}"
        print("Cleanup complete.")

        # --- Step 8: Verify files in Google Drive ---
        print(f"\nVerifying contents of target directory: {target_data_path}")
        !ls -lh "{target_data_path}" # List files to confirm move
        print("\nPlease check the list above to ensure labels.npy and corruption .npy files are present.")

    else:
        print(f"Error: Failed to find extracted CIFAR-10-C content after extraction.")
        print(f"Please check the contents of {local_extract_folder}")

else:
    print(f"Error: Download failed. File {downloaded_tar_path} not found.")

Ensuring target directory exists: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/Data/

Downloading CIFAR-10-C from https://zenodo.org/record/2535967/files/CIFAR-10-C.tar?download=1...
--2025-04-22 05:54:10--  https://zenodo.org/record/2535967/files/CIFAR-10-C.tar?download=1
Resolving zenodo.org (zenodo.org)... 188.185.45.92, 188.185.48.194, 188.185.43.25, ...
Connecting to zenodo.org (zenodo.org)|188.185.45.92|:443... connected.
HTTP request sent, awaiting response... 301 MOVED PERMANENTLY
Location: /records/2535967/files/CIFAR-10-C.tar [following]
--2025-04-22 05:54:11--  https://zenodo.org/records/2535967/files/CIFAR-10-C.tar
Reusing existing connection to zenodo.org:443.
HTTP request sent, awaiting response... 200 OK
Length: 2918471680 (2.7G) [application/octet-stream]
Saving to: ‘/content/CIFAR-10-C.tar’


2025-04-22 05:56:19 (21.7 MB/s) - ‘/content/CIFAR-10-C.tar’ saved [2918471680/2918471680]

Download command executed.

Checking if

Mount Drive & Setup Paths

In [15]:
# English: Mount Google Drive and define key directory paths
# 中文: 挂载 Google Drive 并定义关键目录路径

try:
    drive.mount('/content/drive', force_remount=True) # force_remount can help if needed
    print("Google Drive mounted successfully.")

    # --- Define BASE paths based on your previous setup ---
    DRIVE_PROJECT_BASE = '/content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation'
    MODELS_DIR_PATH = os.path.join(DRIVE_PROJECT_BASE, 'models')
    RESULTS_DIR_PATH = os.path.join(DRIVE_PROJECT_BASE, 'results') # For saving results/plots if needed

    # --- !!! IMPORTANT: UPDATE THIS PATH TO YOUR CIFAR-10-C DATA !!! ---
    # English: Set the path to the folder containing CIFAR-10-C .npy files
    # 中文: 设置包含 CIFAR-10-C .npy 文件的文件夹路径
    CIFAR_C_PATH = '/content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/Data/' # Example path, PLEASE UPDATE!
    LABELS_PATH = os.path.join(CIFAR_C_PATH, 'labels.npy')

    # --- Define Model File Paths ---
    BASELINE_PTH = os.path.join(MODELS_DIR_PATH, 'resnet50_baseline_cifar10.pth')
    BEST_SD_PTH = os.path.join(MODELS_DIR_PATH, 'resnet50_self_distill_Best_Config_01.pth') # Your best SD model filename

    print(f"Base Project Path: {DRIVE_PROJECT_BASE}")
    print(f"Models Path: {MODELS_DIR_PATH}")
    print(f"Results Path: {RESULTS_DIR_PATH}")
    print(f"CIFAR-C Path: {CIFAR_C_PATH}")
    print(f"Baseline Model: {BASELINE_PTH}")
    print(f"Best SD Model: {BEST_SD_PTH}")

    # Check if label file exists early
    if not os.path.exists(LABELS_PATH):
        print(f"\n---! WARNING !---")
        print(f"CIFAR-10-C labels file not found at: {LABELS_PATH}")
        print(f"Please ensure the path is correct and you have downloaded the data.")
    else:
        print("CIFAR-10-C labels file found.")

except Exception as e:
    print(f"Error during Drive Mount or Path Setup: {e}")

Mounted at /content/drive
Google Drive mounted successfully.
Base Project Path: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation
Models Path: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/models
Results Path: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/results
CIFAR-C Path: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/Data/
Baseline Model: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/models/resnet50_baseline_cifar10.pth
Best SD Model: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/models/resnet50_self_distill_Best_Config_01.pth
CIFAR-10-C labels file found.


In [21]:
# Cell 2.5: 独立验证 labels.npy

import numpy as np
import os

# 本单元格假设 LABELS_PATH 已在之前的单元格 (Cell 2) 中正确定义
# 如果没有，取消注释并在此处定义:
# CIFAR_C_PATH = '/content/drive/MyDrive/Data/CIFAR-10-C/' # 确保路径正确!
# LABELS_PATH = os.path.join(CIFAR_C_PATH, 'labels.npy')

print(f"--- Verifying labels.npy ---")
print(f"Attempting to load: {LABELS_PATH}")

# 检查路径变量是否存在
if 'LABELS_PATH' in locals() and LABELS_PATH:
    if os.path.exists(LABELS_PATH):
        try:
            # 从指定路径加载数组
            labels_array = np.load(LABELS_PATH)
            print(f"Successfully loaded labels.npy")
            # 打印关键信息
            print(f"Reported SHAPE: {labels_array.shape}")
            print(f"Reported DTYPE: {labels_array.dtype}")
            print(f"Number of elements (calculated from shape): {labels_array.size}")
            # 打印前10个元素进行健全性检查
            print(f"First 10 labels: {labels_array[:10]}")

            # 与预期形状进行显式检查
            if labels_array.shape == (10000,):
                 print("\n----> Shape (10000,) 与 CIFAR-10 测试集的预期相符。问题可能出在 Dataset 类代码中。")
            else:
                 print(f"\n---! 警告 !---")
                 print(f"----> Shape {labels_array.shape} 与预期的 (10000,) 不符。")
                 print(f"----> 这确认了被加载的标签文件包含 {labels_array.size} 个元素，而不是 10000 个。")
                 print(f"----> 请重新下载正确的 CIFAR-10-C labels.npy 文件。")

        except Exception as e:
            print(f"Error loading or inspecting labels.npy: {e}")
    else:
        print(f"File not found at the specified path: {LABELS_PATH}")
        print("Please double-check the CIFAR_C_PATH in Cell 2 and ensure labels.npy exists there.")
else:
    print("Error: LABELS_PATH variable not found or is empty. Ensure Cell 2 defining paths was run.")

print(f"--- End Verification ---")

--- Verifying labels.npy ---
Attempting to load: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/Data/labels.npy
Successfully loaded labels.npy
Reported SHAPE: (50000,)
Reported DTYPE: uint8
Number of elements (calculated from shape): 50000
First 10 labels: [3 8 8 0 6 6 1 6 3 1]

---! 警告 !---
----> Shape (50000,) 与预期的 (10000,) 不符。
----> 这确认了被加载的标签文件包含 50000 个元素，而不是 10000 个。
----> 请重新下载正确的 CIFAR-10-C labels.npy 文件。
--- End Verification ---


Cell 3: Device Setup

In [3]:
# English: Set the computation device (GPU if available, otherwise CPU)
# 中文: 设置计算设备（可用时使用 GPU，否则使用 CPU）
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


Cell 4: Model Definitions

In [4]:
# English: Define the necessary model architectures (ResNet50 Baseline and ResNet50_SD)
# 中文: 定义所需的模型架构（ResNet50 Baseline 和 ResNet50_SD）
# IMPORTANT: Ensure these definitions exactly match those used during training!

# --- Standard ResNet50 Definition Logic (for Baseline) ---
def get_baseline_resnet50(num_classes=10):
    model = torchvision.models.resnet50(weights=None)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
    return model

# --- ResNet50_SD Definition (Self-Distillation Version) ---
# Use the definition from your training notebook (copied here for completeness)
# Ensure base_model_struct is handled correctly if needed outside class init
class ResNet50_SD(nn.Module):
    def __init__(self, num_classes=10): # Simplified constructor for loading
        super(ResNet50_SD, self).__init__()
        base_model = torchvision.models.resnet50(weights=None) # Load base structure
        # Use layers from the base ResNet-50
        self.conv1 = base_model.conv1
        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        self.maxpool = base_model.maxpool
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4
        self.avgpool = base_model.avgpool
        self.fc_final = nn.Linear(base_model.fc.in_features, num_classes)
        # Define Aux layers needed for state_dict loading, even if not used in forward for eval
        self.avgpool_aux2 = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_aux2 = nn.Linear(512, num_classes)
        self.avgpool_aux3 = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_aux3 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # Evaluation forward pass only needs the final output
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        out_final = self.fc_final(x)
        return out_final # Only return final output

print("Model architectures defined (Baseline and Self-Distillation).")

Model architectures defined (Baseline and Self-Distillation).


Cell 5: Helper Functions (Loading, Dataset, Evaluation)

In [26]:
# English: Define helper functions for model loading, data handling, and evaluation
# 中文: 定义用于模型加载、数据处理和评估的辅助函数

# --- Model Loading Function ---
def load_model_robustness(model_type, pth_path, device, num_classes=10):
    print(f"Attempting to load model ({model_type}) from: {pth_path}")
    if model_type == 'baseline':
        model = get_baseline_resnet50(num_classes)
    elif model_type == 'self_distill':
        model = ResNet50_SD(num_classes) # Use simplified constructor
    else:
        print(f"Error: Unknown model_type '{model_type}'")
        return None

    if not os.path.exists(pth_path):
         print(f"Error: Model file not found at {pth_path}.")
         return None

    try:
        # Load state dict; ensure map_location for CPU/GPU compatibility
        model.load_state_dict(torch.load(pth_path, map_location=device))
        model.to(device)
        model.eval() # Set to evaluation mode crucial for testing
        print(f"Model ({model_type}) loaded successfully.")
        return model
    except Exception as e:
        print(f"An error occurred loading model {pth_path}: {e}")
        return None

# --- Test Transform ---
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# --- 修改 Cell 5 中的这部分 ---

# --- 替换 Cell 5 中 CIFAR10_C_Dataset 类的【整个】定义 ---

class CIFAR10_C_Dataset(Dataset):
    # 修正后的 __init__ 方法，用于处理 50k 标签/图像文件，按 10k 切片
    def __init__(self, images_npy_path, labels_npy_path, severity, transform=None):
        try:
            self.all_labels = np.load(labels_npy_path) # 加载 (50000,) 形状的标签数组
            self.all_images = np.load(images_npy_path) # 加载 (50000, 32, 32, 3) 形状的图像数组
        except FileNotFoundError as e:
            print(f"Error loading CIFAR-C data in Dataset Init: {e}")
            raise e # 文件找不到就停止

        num_images_per_severity = 10000 # 每个严重级别的固定大小
        expected_total_items = 5 * num_images_per_severity

        # 基本检查 (可选, 但推荐)
        if self.all_images.shape[0] < expected_total_items:
             print(f"    Warning: Image array shape {self.all_images.shape} has fewer images than expected ({expected_total_items}).")
        if self.all_labels.shape[0] < expected_total_items:
             print(f"    Warning: Label array shape {self.all_labels.shape} has fewer labels than expected ({expected_total_items}).")

        start_idx = (severity - 1) * num_images_per_severity
        end_idx = severity * num_images_per_severity

        # 边界检查
        if (start_idx >= self.all_images.shape[0] or end_idx > self.all_images.shape[0] or start_idx < 0 or
            start_idx >= self.all_labels.shape[0] or end_idx > self.all_labels.shape[0]):
             print(f"    Error: Init - Calculated index range [{start_idx}:{end_idx}] is out of bounds for image shape {self.all_images.shape} or label shape {self.all_labels.shape}.")
             self.data = np.array([]) # 出错时创建空数据
             self.labels = np.array([]) # 出错时创建空标签
        else:
             self.data = self.all_images[start_idx:end_idx] # 切片图像
             self.labels = self.all_labels[start_idx:end_idx] # 切片标签

        self.transform = transform
        print(f"Initialized CIFAR10_C_Dataset for severity {severity}: Found {len(self.data)} images/labels.")


    # 调试过的 __len__ 方法
    def __len__(self):
        # DEBUG: Check if this method is called and what self.data is
        # print(f"    DEBUG: __len__ called.") # 如果输出太多可以注释掉
        # print(f"    DEBUG: Type of self.data: {type(self.data)}")
        if hasattr(self.data, 'shape'):
            pass # print(f"    DEBUG: self.data.shape: {self.data.shape}") # 可以注释掉
        try:
            length = len(self.data)
            # print(f"    DEBUG: len(self.data) returned: {length}") # 可以注释掉
            return length
        except Exception as e:
            print(f"    DEBUG: Error calling len(self.data) inside __len__: {e}")
            return 0

    # --- !!! 把这个方法加回去 !!! ---
    # 标准的 __getitem__ 方法
    def __getitem__(self, idx):
        # 在这里进行边界检查也是好习惯
        if idx >= len(self.data):
             raise IndexError(f"Index {idx} out of bounds for dataset with length {len(self.data)}")

        image = self.data[idx]  # 通常来自 npy 的是 HWC uint8 格式
        label = self.labels[idx]

        # 在应用 transform 之前，将 HWC numpy uint8 转换为 PIL Image
        # 这是大多数 torchvision transform（包括 ToTensor）所要求的
        try:
            image = transforms.functional.to_pil_image(image)
        except TypeError as e:
             print(f"Error converting image at index {idx} to PIL Image. Image type: {type(image)}, shape: {image.shape if hasattr(image, 'shape') else 'N/A'}. Error: {e}")
             # 返回虚拟数据或引发错误以避免后续崩溃
             # 返回预期大小的虚拟张量 (例如 3x32x32) 和虚拟标签
             return torch.zeros((3, 32, 32), dtype=torch.float), torch.tensor(-1, dtype=torch.long)


        # 应用转换 (ToTensor, Normalize)
        if self.transform:
            image = self.transform(image)

        # 确保标签是 LongTensor 类型 (CrossEntropyLoss 需要)
        return image, torch.tensor(label, dtype=torch.long)

print("Complete CIFAR10_C_Dataset class defined.")


# --- Evaluation Function ---
def evaluate_accuracy(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images) # Assumes forward returns final output only now
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        return 0.0 # Avoid division by zero if dataloader is empty
    accuracy = 100. * correct / total
    return accuracy

print("Helper functions defined.")

Complete CIFAR10_C_Dataset class defined.
Helper functions defined.


In [18]:
print("--- Manual Load Test ---")
try:
    print(f"Testing baseline load from: {BASELINE_PTH}")
    # Ensure get_baseline_resnet50 is defined
    model_b_test = get_baseline_resnet50()
    model_b_test.load_state_dict(torch.load(BASELINE_PTH, map_location=DEVICE))
    print("Manual Baseline Load OK")
    print(f"Testing SD load from: {BEST_SD_PTH}")
    # Ensure ResNet50_SD is defined
    model_sd_test = ResNet50_SD()
    model_sd_test.load_state_dict(torch.load(BEST_SD_PTH, map_location=DEVICE))
    print("Manual SD Load OK")
except Exception as e:
    print(f"Manual load test failed: {e}")
    # Print detailed error
    import traceback
    traceback.print_exc()
print("--- End Manual Load Test ---")

--- Manual Load Test ---
Testing baseline load from: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/models/resnet50_baseline_cifar10.pth
Manual Baseline Load OK
Testing SD load from: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/models/resnet50_self_distill_Best_Config_01.pth
Manual SD Load OK
--- End Manual Load Test ---


Cell 6: Configuration & Model Loading

In [19]:
# English: Configure evaluation parameters and load the trained models
# 中文: 配置评估参数并加载训练好的模型

BATCH_SIZE_EVAL = 128 # Batch size for evaluation
NUM_WORKERS = 2       # Number of workers for DataLoader

# --- Load Models ---
print("\n--- Loading Models ---")
# Paths are defined in Cell 2
model_baseline = load_model_robustness('baseline', BASELINE_PTH, DEVICE)
model_sd = load_model_robustness('self_distill', BEST_SD_PTH, DEVICE)

# Check if models loaded successfully before proceeding
if not model_baseline or not model_sd:
    print("\n---! ERROR !---")
    print("One or both models failed to load. Please check paths and definitions.")
    print("Cannot proceed with evaluation.")
    # Stop execution or handle error
    # raise RuntimeError("Model loading failed.")
else:
    print("\nBaseline and Self-Distilled models loaded successfully.")


--- Loading Models ---
Attempting to load model (baseline) from: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/models/resnet50_baseline_cifar10.pth
Model (baseline) loaded successfully.
Attempting to load model (self_distill) from: /content/drive/MyDrive/ECE 661 Final Project: Knowledge Distillation/A_Jerry_SelfDistillation/models/resnet50_self_distill_Best_Config_01.pth
Model (self_distill) loaded successfully.

Baseline and Self-Distilled models loaded successfully.


Cell 7: Main Evaluation Loop (CIFAR-10-C)

In [27]:
# English: Run the evaluation loop over CIFAR-10-C corruptions and severities
# 中文: 运行遍历 CIFAR-10-C 损坏类型和严重级别的评估循环

# Check if models loaded and label path exists from previous cells
models_loaded = model_baseline is not None and model_sd is not None
labels_exist = os.path.exists(LABELS_PATH)

if models_loaded and labels_exist:
    print("\n--- Starting Robustness Evaluation on CIFAR-10-C ---")

    corruption_types = [
        'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
        'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
        'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'
    ]
    severities = [1, 2, 3, 4, 5]
    results = {} # Dictionary to store results

    evaluation_start_time = time.time()

    for corruption in corruption_types:
        print(f"\nEvaluating Corruption Type: {corruption}")
        corruption_file_path = os.path.join(CIFAR_C_PATH, f'{corruption}.npy')

        if not os.path.exists(corruption_file_path):
            print(f"  Warning: {corruption}.npy not found at {corruption_file_path}. Skipping.")
            continue

        results[corruption] = {}
        for severity in severities:
            print(f"  Severity: {severity}")
            try:
                # Create dataset and dataloader
                dataset = CIFAR10_C_Dataset(corruption_file_path, LABELS_PATH, severity, transform=transform_test)
                if len(dataset) == 0: # Check if dataset creation failed silently
                    print(f"    Warning: Dataset empty for severity {severity}. Check data loading logic or file content.")
                    continue
                dataloader = DataLoader(dataset, batch_size=BATCH_SIZE_EVAL, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

                # Evaluate Baseline
                t_start = time.time()
                baseline_acc = evaluate_accuracy(model_baseline, dataloader, DEVICE)
                t_end = time.time()
                print(f"    Baseline Acc: {baseline_acc:.2f}% (Time: {t_end - t_start:.2f}s)")

                # Evaluate Self-Distilled Model
                t_start = time.time()
                sd_acc = evaluate_accuracy(model_sd, dataloader, DEVICE)
                t_end = time.time()
                print(f"    Self-Distill Acc: {sd_acc:.2f}% (Time: {t_end - t_start:.2f}s)")

                results[corruption][severity] = {'baseline': baseline_acc, 'sd': sd_acc}

            except FileNotFoundError:
                 print(f"    Error: A data file was not found during Dataset initialization. Check paths.")
                 # Break inner loop or continue cautiously
                 results[corruption][severity] = {'baseline': 'Error - Data Missing', 'sd': 'Error - Data Missing'}
            except Exception as e:
                print(f"    Error during evaluation for {corruption} severity {severity}: {e}")
                results[corruption][severity] = {'baseline': 'Error', 'sd': 'Error'}

    evaluation_end_time = time.time()
    print(f"\n--- Evaluation Complete ---")
    print(f"Total Evaluation Time: {(evaluation_end_time - evaluation_start_time)/60:.2f} minutes")

else:
    print("\nEvaluation skipped due to model loading failure or missing label file.")

# The 'results' dictionary now holds all the accuracy data
# Example: Access accuracy -> results['gaussian_noise'][3]['sd']


--- Starting Robustness Evaluation on CIFAR-10-C ---

Evaluating Corruption Type: gaussian_noise
  Severity: 1
Initialized CIFAR10_C_Dataset for severity 1: Found 10000 images/labels.
    Baseline Acc: 83.78% (Time: 3.20s)
    Self-Distill Acc: 85.61% (Time: 2.78s)
  Severity: 2
Initialized CIFAR10_C_Dataset for severity 2: Found 10000 images/labels.
    Baseline Acc: 77.05% (Time: 2.77s)
    Self-Distill Acc: 79.40% (Time: 3.31s)
  Severity: 3
Initialized CIFAR10_C_Dataset for severity 3: Found 10000 images/labels.
    Baseline Acc: 66.91% (Time: 3.36s)
    Self-Distill Acc: 69.84% (Time: 2.81s)
  Severity: 4
Initialized CIFAR10_C_Dataset for severity 4: Found 10000 images/labels.
    Baseline Acc: 61.00% (Time: 2.74s)
    Self-Distill Acc: 64.65% (Time: 2.93s)
  Severity: 5
Initialized CIFAR10_C_Dataset for severity 5: Found 10000 images/labels.
    Baseline Acc: 55.08% (Time: 3.76s)
    Self-Distill Acc: 59.80% (Time: 2.82s)

Evaluating Corruption Type: shot_noise
  Severity: 1
Ini

Cell 8: Results Analysis & Summary

In [28]:
# English: Process and summarize the robustness results
#

if 'results' in locals() and results: # Check if results exist and is not empty
    print("\n--- Evaluation Summary ---")
    all_baseline_accs = []
    all_sd_accs = []
    valid_results_count = 0

    for corruption in results:
        for severity in results[corruption]:
             # Check if results are valid numbers
             baseline_val = results[corruption][severity].get('baseline', None)
             sd_val = results[corruption][severity].get('sd', None)
             if isinstance(baseline_val, (int, float)) and isinstance(sd_val, (int, float)):
                 all_baseline_accs.append(baseline_val)
                 all_sd_accs.append(sd_val)
                 valid_results_count += 1
             else:
                  print(f"Skipping invalid result for {corruption} severity {severity}: Baseline={baseline_val}, SD={sd_val}")


    if valid_results_count > 0:
        avg_baseline_acc_c = np.mean(all_baseline_accs)
        avg_sd_acc_c = np.mean(all_sd_accs)
        print(f"\nAverage Accuracy across {valid_results_count} valid CIFAR-10-C evaluations:")
        print(f"  Average Baseline Accuracy: {avg_baseline_acc_c:.2f}%")
        print(f"  Average Self-Distill Accuracy: {avg_sd_acc_c:.2f}%")
        print(f"  Average Difference (SD - Baseline) on Corruptions: {avg_sd_acc_c - avg_baseline_acc_c:.2f}% points")

        # --- Optional: Calculate Corruption Error metrics ---
        # Example: Clean Accuracy (assuming you have them)
        clean_baseline_acc = 87.68 # Your baseline acc
        clean_sd_acc = 88.29 # Your best SD acc

        # You can calculate relative accuracy drop or mCE if needed for reporting
        # Example: Relative accuracy
        relative_baseline = avg_baseline_acc_c / clean_baseline_acc
        relative_sd = avg_sd_acc_c / clean_sd_acc
        print(f"\nRelative Accuracy (Avg Corrupted / Clean):")
        print(f"  Relative Baseline: {relative_baseline:.4f}")
        print(f"  Relative Self-Distill: {relative_sd:.4f}")
        if relative_sd > relative_baseline:
             print("  Self-Distill maintained accuracy better on average.")
        else:
             print("  Baseline maintained accuracy better or equally on average.")


        # --- Optional: Plotting results ---
        # Example: Plot avg acc per corruption type
        # avg_acc_per_corruption = {}
        # for corruption in results:
        #     baseline_corrupt_accs = [results[corruption][s]['baseline'] for s in results[corruption] if isinstance(results[corruption][s].get('baseline'), (int, float))]
        #     sd_corrupt_accs = [results[corruption][s]['sd'] for s in results[corruption] if isinstance(results[corruption][s].get('sd'), (int, float))]
        #     if baseline_corrupt_accs and sd_corrupt_accs:
        #          avg_acc_per_corruption[corruption] = {'baseline': np.mean(baseline_corrupt_accs), 'sd': np.mean(sd_corrupt_accs)}
        #
        # print("\nAverage Accuracy per Corruption Type:")
        # for corruption, accs in avg_acc_per_corruption.items():
        #      print(f"  {corruption:<20}: Baseline={accs['baseline']:.2f}%, SD={accs['sd']:.2f}%")
        # (Could use matplotlib to create bar plots here)

    else:
        print("\nNo valid results found to summarize.")

else:
    print("\nNo evaluation results available to analyze.")


--- Evaluation Summary ---

Average Accuracy across 75 valid CIFAR-10-C evaluations:
  Average Baseline Accuracy: 73.37%
  Average Self-Distill Accuracy: 74.23%
  Average Difference (SD - Baseline) on Corruptions: 0.86% points

Relative Accuracy (Avg Corrupted / Clean):
  Relative Baseline: 0.8367
  Relative Self-Distill: 0.8407
  Self-Distill maintained accuracy better on average.
