# ResNet18 Model Prediction Script

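This notebook loads a trained ResNet18 (or MixResNet18) model and runs predictions on a specified test set. Note that the helper code builds the standard torchvision ResNet18 architecture. A MixResNet18 checkpoint can therefore only be loaded if its `state_dict` keys and tensor shapes match that architecture.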

### Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from PIL import Image
import pandas as pd
import numpy as np
import os
import json
import random

### Configuration

In [2]:
# --- 1. Paths ---
# Folder containing the trained weights and the normalization statistics.
MODEL_DIR = "/path/to/your/data/folder"
# Trained network weights, loaded later with load_state_dict.
MODEL_PATH = os.path.join(MODEL_DIR, "model/ResNet18_best_model.pth")
# JSON holding the training-set channel "mean" and "std".
STATS_PATH = os.path.join(MODEL_DIR, "model/ResNet18_normalization_stat.json")

# Folder holding the images to predict on, plus the CSV that lists them.
ROOT_DIR = "/path/to/your/data/folder"
# Two-column CSV: image path (joined onto ROOT_DIR) and label.
TEST_CSV_PATH = os.path.join(ROOT_DIR, "data/test_info.csv")

# Where the per-sample prediction table is written.
RESULTS_DIR = "/path/to/your/data/folder"
PREDICTION_SAVE_PATH = os.path.join(RESULTS_DIR, "predictions_example.csv")

# --- 2. Model & Data Parameters ---
NUM_CLASSES = 10  # width of the classifier head; must match the checkpoint
BATCH_SIZE = 64   # no gradients are kept at inference, so a larger batch may fit

# --- 3. Device Setup ---
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# --- 4. Reproducibility ---
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)



Using device: cuda:0


### Helper Functions & Classes

In [3]:
def ResNet18(num_classes):
    """Build an untrained torchvision ResNet18 with a `num_classes`-way output layer.

    Args:
        num_classes: Number of output logits for the replacement ``fc`` layer.

    Returns:
        A torchvision ResNet18 created with ``weights=None`` (no pretrained
        weights), whose ``fc`` is ``nn.Linear(in_features, num_classes)``.
    """
    backbone = models.resnet18(weights=None)
    # Replace the classifier head, keeping the original feature width as its input size.
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone

def load_model_and_transform(model_path, stats_path, num_classes, device):
    """Load trained ResNet18 weights and build the matching inference transform.

    Args:
        model_path: Path to a saved ``state_dict`` for the ``ResNet18(num_classes)``
            architecture.
        stats_path: Path to a JSON file with ``"mean"`` and ``"std"`` lists
            (one value per image channel), as used by ``transforms.Normalize``.
        num_classes: Output width of the classifier head; must match the checkpoint.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        tuple: ``(model, inference_transform)``. The model is already in eval mode.
        The transform converts a PIL image to a tensor and normalizes it with the
        loaded statistics.

    Raises:
        KeyError: If the stats JSON lacks ``"mean"`` or ``"std"``.
        ValueError: If ``mean`` and ``std`` have different lengths.
        RuntimeError: If the checkpoint's keys/shapes do not match the architecture.
    """
    with open(stats_path, 'r', encoding='utf-8') as f:
        norm_stats = json.load(f)
    mean_loaded = norm_stats['mean']
    std_loaded = norm_stats['std']
    # Fail early with a clear message instead of erroring inside Normalize on the first image.
    if len(mean_loaded) != len(std_loaded):
        raise ValueError(
            f"Normalization stats mismatch: {len(mean_loaded)} mean values vs "
            f"{len(std_loaded)} std values in {stats_path}"
        )

    print("Loaded normalization stats:")
    print(f" - Mean: {mean_loaded}")
    print(f" - Std: {std_loaded}")

    model = ResNet18(num_classes).to(device)
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot execute code. It also silences the torch.load
    # FutureWarning emitted by recent PyTorch versions.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for BatchNorm/Dropout layers

    inference_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_loaded, std=std_loaded)
    ])

    return model, inference_transform

class PredictionDataset(Dataset):
    """Dataset yielding ``(image, label, full_path)`` for each row of a two-column CSV.

    The CSV is read without a header row. Column 1 is an image path that is joined
    onto ``root_dir``; column 2 is the ground-truth label. Labels are compared
    against predicted class names downstream.
    """

    def __init__(self, root_dir, csv_path, transform=None):
        """
        Args:
            root_dir: Directory that the CSV's image paths are joined onto.
            csv_path: Path to the header-less two-column CSV.
            transform: Optional callable applied to each RGB PIL image.
        """
        self.root_dir = root_dir
        # header=None: the first line is treated as data, so the file must not have
        # a header row. dtype=str keeps labels as strings even when they look numeric.
        # The default collate_fn then passes them through as a list rather than
        # stacking them into a tensor, and they stay comparable to class-name strings.
        self.df = pd.read_csv(csv_path, header=None, names=["Path", "Label"], dtype=str)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img_path_relative = row["Path"]
        label_str = row["Label"]  # kept as a string for the final CSV

        img_path_full = os.path.join(self.root_dir, img_path_relative)
        # The context manager closes the underlying file once the RGB copy exists,
        # so DataLoader workers do not accumulate open file handles.
        with Image.open(img_path_full) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label_str, img_path_full

### Load the Model and Data

In [4]:
# 1. Restore the trained network together with the normalization it expects
print("--- Loading Model and Normalization Stats ---")
model, inference_transform = load_model_and_transform(
    MODEL_PATH, STATS_PATH, NUM_CLASSES, device
)

# 2. Wrap the test CSV in a Dataset and an ordered DataLoader
print("\n--- Loading Test Data ---")
test_dataset = PredictionDataset(ROOT_DIR, TEST_CSV_PATH, transform=inference_transform)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep CSV order; shuffling buys nothing at prediction time
    num_workers=4,
)
print(f"Found {len(test_dataset)} samples in the test set.")


--- Loading Model and Normalization Stats ---
Loaded normalization stats:
 - Mean: [0.8576046824455261, 0.9898070096969604, 0.8576046824455261]
 - Std: [0.3346951901912689, 0.0391424335539341, 0.3346951901912689]

--- Loading Test Data ---
Found 10 samples in the test set.


  model.load_state_dict(torch.load(model_path, map_location=device))


### Run Prediction 

In [5]:
def run_prediction(model, dataloader, device):
    """Run inference over ``dataloader`` and collect one result record per sample.

    Args:
        model: Module mapping an image batch to logits of shape ``(batch, C)``.
        dataloader: Iterable yielding ``(images, labels, paths)`` batches.
        device: Device the image batches are moved to; should match the model's.

    Returns:
        list[dict]: One dict per sample, in dataloader order, with keys
        ``ImagePath``, ``ActualLabel``, ``PredictedLabelIndex`` and
        ``Prob_Class_0`` ... ``Prob_Class_{C-1}`` (softmax probabilities).
        ``C`` is taken from the model's output width, not from a global constant.
    """
    results_list = []
    model.eval()  # ensure inference behaviour even if the caller left it in train mode

    print("Prediction loop started... This may take a while depending on the test set size.")

    with torch.no_grad():
        for images, labels_str, paths in dataloader:
            outputs = model(images.to(device))
            probabilities = F.softmax(outputs, dim=1)
            _, preds_indices = torch.max(outputs, 1)

            probabilities_cpu = probabilities.cpu().numpy()
            preds_indices_cpu = preds_indices.cpu().numpy()
            # Derive the class count from the model output instead of reading the
            # notebook-global NUM_CLASSES, so the function does not rely on hidden state.
            num_classes = probabilities_cpu.shape[1]

            for i, path in enumerate(paths):
                result = {
                    'ImagePath': path,
                    'ActualLabel': labels_str[i],
                    'PredictedLabelIndex': preds_indices_cpu[i],
                }
                for j in range(num_classes):
                    result[f'Prob_Class_{j}'] = probabilities_cpu[i, j]
                results_list.append(result)

    print("Prediction loop finished.")
    return results_list

print("\n--- Running Prediction ---")
results_list = run_prediction(model, test_dataloader, device)

# One row per image; columns follow the key order of the result dicts.
results_df = pd.DataFrame(results_list)

# The bare expression on the last line renders as a rich table in Jupyter.
print("\nPrediction results preview:")
results_df.head()



--- Running Prediction ---
Prediction loop started... This may take a while depending on the test set size.
Prediction loop finished.

Prediction results preview:


Unnamed: 0,ImagePath,ActualLabel,PredictedLabelIndex,Prob_Class_0,Prob_Class_1,Prob_Class_2,Prob_Class_3,Prob_Class_4,Prob_Class_5,Prob_Class_6,Prob_Class_7,Prob_Class_8,Prob_Class_9
0,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,BiNonAge,9,2.716695e-07,0.0005336623,0.002718097,6.724039e-05,0.00152881,1.038607e-05,0.1013383,0.3451014,0.014574,0.5341275
1,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,TriAge,6,3.156298e-09,4.425022e-16,2.128968e-05,1.984568e-06,9.408601e-16,1.280801e-12,0.9972817,1.270269e-05,0.002682,3.198303e-11
2,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,LinAgeHet,3,1.024525e-12,5.767382e-15,2.490101e-11,0.9999981,6.501111e-11,3.216262e-09,6.019651e-10,3.652668e-15,2e-06,1.500561e-16
3,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,TriNonAge,7,1.154786e-08,0.001813379,0.0001826217,9.41744e-08,1.037695e-06,5.719502e-07,0.07913074,0.900133,0.000286,0.01845288
4,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,NonAgeHet,4,1.498795e-15,8.011916e-13,1.892851e-12,0.0002820744,0.9992661,2.665998e-13,7.39058e-11,4.916971e-09,1.9e-05,0.0004326642


### Post-process and Analyze Results  

In [6]:
# Class name -> model output index.
# NOTE(review): presumably identical to the mapping used at training time — confirm.
label_mapping = {
    'UniStable': 0, 'UniUnstable': 1,
    'LinAgeHomo': 2, 'LinAgeHet': 3, 'NonAgeHet': 4, 'Outlier': 5,
    'TriAge': 6, 'TriNonAge': 7, 'BiAge': 8, 'BiNonAge': 9
}
# Guard against the mapping drifting from NUM_CLASSES. A gap or extra index would
# silently produce NaN predicted names here and wrong column names when saving.
assert sorted(label_mapping.values()) == list(range(NUM_CLASSES)), (
    "label_mapping must assign each index 0..NUM_CLASSES-1 exactly once"
)

# Reverse mapping: output index -> class name
idx_to_class = {v: k for k, v in label_mapping.items()}
results_df['PredictedLabelName'] = results_df['PredictedLabelIndex'].map(idx_to_class)

# Accuracy = fraction of rows whose CSV label equals the predicted class name
correct_predictions = (results_df['ActualLabel'] == results_df['PredictedLabelName']).sum()
total_samples = len(results_df)
accuracy = correct_predictions / total_samples

print("\n--- Prediction Analysis ---")
print(f"Test Set Accuracy: {correct_predictions}/{total_samples} ({accuracy:.4f})")


--- Prediction Analysis ---
Test Set Accuracy: 10/10 (1.0000)


### Save Final Results

In [7]:
# Put identifying and label columns first, followed by the per-class probabilities.
cols_to_show = ['ImagePath', 'ActualLabel', 'PredictedLabelName', 'PredictedLabelIndex'] + [f'Prob_Class_{j}' for j in range(NUM_CLASSES)]
final_df = results_df[cols_to_show]

# Replace the generic Prob_Class_<i> headers with the class names from idx_to_class.
rename_mapping = {f'Prob_Class_{index}': name for index, name in idx_to_class.items()}
final_df = final_df.rename(columns=rename_mapping)

# to_csv does not create missing parent directories, so create them first.
# A bare filename has an empty dirname and is written to the working directory.
save_dir = os.path.dirname(PREDICTION_SAVE_PATH)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
final_df.to_csv(PREDICTION_SAVE_PATH, index=False)
print(f"\nComplete results saved to: {PREDICTION_SAVE_PATH}")

final_df


Complete results saved to: predictions_example.csv


Unnamed: 0,ImagePath,ActualLabel,PredictedLabelName,PredictedLabelIndex,UniStable,UniUnstable,LinAgeHomo,LinAgeHet,NonAgeHet,Outlier,TriAge,TriNonAge,BiAge,BiNonAge
0,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,BiNonAge,BiNonAge,9,2.716695e-07,0.0005336623,0.002718097,6.724039e-05,0.00152881,1.038607e-05,0.1013383,0.3451014,0.0145743,0.5341275
1,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,TriAge,TriAge,6,3.156298e-09,4.425022e-16,2.128968e-05,1.984568e-06,9.408601e-16,1.280801e-12,0.9972817,1.270269e-05,0.00268225,3.198303e-11
2,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,LinAgeHet,LinAgeHet,3,1.024525e-12,5.767382e-15,2.490101e-11,0.9999981,6.501111e-11,3.216262e-09,6.019651e-10,3.652668e-15,1.890021e-06,1.500561e-16
3,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,TriNonAge,TriNonAge,7,1.154786e-08,0.001813379,0.0001826217,9.41744e-08,1.037695e-06,5.719502e-07,0.07913074,0.900133,0.0002857295,0.01845288
4,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,NonAgeHet,NonAgeHet,4,1.498795e-15,8.011916e-13,1.892851e-12,0.0002820744,0.9992661,2.665998e-13,7.39058e-11,4.916971e-09,1.913379e-05,0.0004326642
5,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,BiAge,BiAge,8,2.864129e-09,5.599494e-13,0.0001206779,0.01749925,1.206675e-10,8.809237e-10,0.01848551,4.003854e-06,0.9638905,8.20778e-09
6,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,UniUnstable,UniUnstable,1,1.813829e-13,0.999984,4.494148e-06,1.90538e-07,6.820175e-07,1.098433e-06,2.904327e-07,1.446644e-06,3.979683e-07,7.347324e-06
7,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,LinAgeHomo,LinAgeHomo,2,2.362736e-08,3.211895e-07,0.9994349,3.45058e-05,1.026451e-11,3.393887e-09,5.199942e-06,5.479616e-09,0.000525015,8.630402e-09
8,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,UniStable,UniStable,0,0.9999963,9.983666e-09,9.573859e-09,2.557223e-06,2.165029e-07,8.719268e-07,3.019701e-08,1.820173e-09,3.262653e-08,1.09657e-09
9,/mnt/local-disk/data/duzhaozhen/category_DNAm_...,Outlier,Outlier,5,8.758417e-19,1.060473e-16,6.974279e-10,1.359096e-11,3.626419e-12,1.0,4.086044e-12,6.452494e-13,5.823053e-13,1.748979e-15
