# MixResNet18 Model Tranining Script
### This Notebook is used to train ResNet18 model.

#### Import

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import CrossEntropyLoss
from torchvision.transforms import transforms
from PIL import Image
import pandas as pd
import numpy as np
import os
import random
from time import time
import json


#### 1. Global Settings


In [2]:
# Hyperparameters
BATCH_SIZE = 128
EPOCH = 500
NUM_CLASSES = 10
RANDOM_SEED = 42
EARLY_STOPPING_PATIENCE = 30
EARLY_STOPPING_DELTA = 0.0001
SHOW_TRAINING_PROCESS = False #Whether to display the training process

# Path settings
ROOT_DIR = "/path/to/your/data/folder" 
TRAIN_CSV_PATH = os.path.join(ROOT_DIR, "meta/training_info.csv") # There are two columns, the first is Path and the second is label
VAL_CSV_PATH = os.path.join(ROOT_DIR, "meta/val_info.csv") # There are two columns, the first is Path and the second is label

# Path for saving the best model and the results CSV
RESULTS_DIR = "/path/to/your/results/folder" ### Paths where training set files and validation set files and their images are stored
MODEL_SAVE_PATH = os.path.join(RESULTS_DIR, "MixResNet18_best_model.pth") # model file
STATS_SAVE_PATH = os.path.join(RESULTS_DIR, "MixResNet18_normalization_stats.json") # mean and std
CSV_SAVE_PATH = os.path.join(RESULTS_DIR, "MixResNet18_training_log.csv") ## Training logs with the loss and accuracy of the training and test sets in each epoch
os.makedirs(RESULTS_DIR, exist_ok=True) # Ensure the results directory exists


# Device setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seeds for reproducibility
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

# Label and class mapping
CUSTOM_LABEL_MAPPING =  {'UniStable': 0, 'UniUnstable': 1,'LinAgeHomo': 2, 'LinAgeHet': 3,'NonAgeHet': 4,
                        'Outlier': 5,'TriAge': 6,'TriNonAge': 7,'BiAge':8, 'BiNonAge': 9}

CLASSES = tuple(CUSTOM_LABEL_MAPPING.keys())

Using device: cuda:0


### 2.Dataset Definition

In [3]:
class CpGImageDataset(Dataset):
    """
    Loads image paths and labels from a CSV file and creates a dataset.
    """
    def __init__(self, root_dir, path_csv, label_mapping, transform=None):
        self.root_dir = root_dir
        self.path_csv = path_csv
        self.transform = transform
        self.label_mapping = label_mapping
        self.img_info = [] # [(path, label), ... , ]
        self._get_img_info()

    def __getitem__(self, index):
        path_img, label = self.img_info[index]
        img = Image.open(path_img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        if len(self.img_info) == 0:
            raise Exception(f"\ndata_dir:{self.root_dir} is an empty dir! Please check your path to images!")
        return len(self.img_info)

    def _get_img_info(self):
        """
        Reads the CSV file and parses image paths and labels.
        """
        column_names = ["Path","Label"]
        df = pd.read_csv(self.path_csv,
                header=None,
                index_col=False,
                names=column_names,
                sep=","
                )
        df.reset_index(inplace=True)

        for idx in range(len(df)):
            path_img = os.path.join(self.root_dir, df.loc[idx, "Path"])
            label_str = df.at[idx,"Label"]
            if label_str in self.label_mapping:
                label_int = self.label_mapping[label_str]
            else:
                raise ValueError(f"Unknown label on line {idx}: {label_str}")
            self.img_info.append((path_img, label_int))

### 3. Calculate Mean and Std

In [4]:
def get_stat(dataset):
    """
    Computes the mean and standard deviation per channel for the dataset.
    :param dataset: A PyTorch Dataset object.
    :return: (mean, std)
    """
    print('Computing mean and variance for training data...')
    loader = DataLoader(dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=4)
    
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for x, _ in loader:
        for d in range(3):
            mean[d] += x[:, d, :, :].mean()
            std[d] += x[:, d, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return list(mean.numpy()), list(std.numpy())

# Step 3.1: Create a temporary dataset with only ToTensor transformation to calculate stats
temp_transform = transforms.Compose([transforms.ToTensor()])
stat_dataset = CpGImageDataset(ROOT_DIR, TRAIN_CSV_PATH, CUSTOM_LABEL_MAPPING, transform=temp_transform)
# Step 3.2: Calculate the mean and std
mean, std = get_stat(stat_dataset)
print(f"Calculated Mean: {mean}")
print(f"Calculated Std: {std}")

print("Saving normalization stats...")
norm_stats = {
    'mean': [float(x) for x in mean],
    'std': [float(x) for x in std]
}
with open(STATS_SAVE_PATH, 'w') as f:
    json.dump(norm_stats, f, indent=4)
#print(f"Stats saved to: {STATS_SAVE_PATH}")
print(f"Stats saved to: ~/model/MixResNet18_normalization_stats.json")

Computing mean and variance for training data...
Calculated Mean: [np.float32(0.8675819), np.float32(0.99111927), np.float32(0.8675819)]
Calculated Std: [np.float32(0.3242125), np.float32(0.03585337), np.float32(0.3242125)]
Saving normalization stats...
Stats saved to: ~/model/MixResNet18_normalization_stats.json


### 4. Data Preprocessing and Loading

In [5]:
# Define data transformations using the calculated mean and std
data_transform = {
    "train": transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
    "val": transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
}

# Create the final training and validation datasets
train_data = CpGImageDataset(ROOT_DIR, TRAIN_CSV_PATH, CUSTOM_LABEL_MAPPING, transform=data_transform["train"])
val_data = CpGImageDataset(ROOT_DIR, VAL_CSV_PATH, CUSTOM_LABEL_MAPPING, transform=data_transform["val"])

# Create DataLoaders
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_dataloader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

print(f"Number of training samples: {len(train_data)}")
print(f"Number of validation samples: {len(val_data)}")

Number of training samples: 18000
Number of validation samples: 2000


 
### 5. Model Definition


In [6]:

def ResNet18(num_classes):
    """
    Loads the ResNet18 model and adapts its fully connected layer
    to the specified number of classes.
    """
    res18 = models.resnet18(weights=None) # Train from scratch, not using pre-trained weights
    num_ftrs = res18.fc.in_features
    res18.fc = nn.Linear(num_ftrs, num_classes)
    return res18

# Instantiate the model and move it to the specified device
model = ResNet18(num_classes=NUM_CLASSES).to(device)

# Define optimizer with weight decay
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=3)
loss_function = CrossEntropyLoss()


### 6. Evaluation Function and Early Stopping Class


In [7]:


def evaluate(model, data_loader, loss_fn, device):
    """Evaluates the model on a given dataset."""
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    
    with torch.no_grad():  # No need to track gradients for evaluation
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()
            
            prediction = outputs.argmax(dim=1)
            correct_predictions += torch.eq(prediction, labels).sum().item()
            total_samples += labels.size(0)

    avg_loss = total_loss / len(data_loader)

    accuracy = correct_predictions / total_samples
    return accuracy, avg_loss

import torch
import numpy as np

class EarlyStopping:
    """Early stops the training if validation accuracy doesn't improve after a given patience."""
    def __init__(self, save_path, patience=7, verbose=False, delta=0):
        """
        Args:
            save_path (str):  Path for saving the best model.。
            patience (int): How long to wait after last time validation accuracy improved.
            verbose (bool):  If True, prints a message for each validation accuracy improvement.
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
        """
        self.save_path = save_path
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_acc_max = float('-inf')
        self.delta = delta

    def __call__(self, val_acc, model):
        score = val_acc

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_acc, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_acc, model)
            self.counter = 0

    def save_checkpoint(self, val_acc, model):
        '''When the accuracy of the validation set improves, save the model'''
        if self.verbose:
            print(f'Validation accuracy increased ({self.val_acc_max:.6f} --> {val_acc:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.save_path)
        self.val_acc_max = val_acc




### 7. Training and Validation Loop

In [8]:
# Lists to store metrics for plotting and analysis
all_epochs_train_loss = []
all_epochs_train_acc = []
all_epochs_val_loss = []
all_epochs_val_acc = []
epoch_times = []

# Instantiate EarlyStopping
early_stopping = EarlyStopping(
    save_path=MODEL_SAVE_PATH, 
    patience=EARLY_STOPPING_PATIENCE, 
    verbose=SHOW_TRAINING_PROCESS:, 
    delta=EARLY_STOPPING_DELTA
)

start_time = time()
print("Starting training...")

for epoch in range(EPOCH):
    model.train()  # Set model to training mode
    running_loss = 0.0
    train_correct = 0
    train_total = 0
    epoch_start_time = time()

    if SHOW_TRAINING_PROCESS:
        print(f"--- Epoch {epoch+1}/{EPOCH} ---")
        print(f"Current Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

    for step, (images, labels) in enumerate(train_dataloader):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        prediction = outputs.argmax(dim=1)
        train_correct += torch.eq(prediction, labels).sum().item()
        train_total += labels.size(0)

        if SHOW_TRAINING_PROCESS:
            # Progress bar
            rate = (step + 1) / len(train_dataloader)
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print(f"\rTraining: {int(rate*100):>3d}%[{a}>{b}] Loss: {loss.item():.4f}", end="")

    epoch_train_loss = running_loss / len(train_dataloader)
    epoch_train_acc = train_correct / train_total
    all_epochs_train_loss.append(epoch_train_loss)
    all_epochs_train_acc.append(epoch_train_acc)

    # Validation phase
    val_acc, val_loss = evaluate(model, val_dataloader, loss_function, device)
    all_epochs_val_acc.append(val_acc)
    all_epochs_val_loss.append(val_loss)
    
    scheduler.step(val_acc) # Adjust learning rate based on validation loss
    
    epoch_end_time = time()
    epoch_duration = epoch_end_time - epoch_start_time
    epoch_times.append(epoch_duration)

    if SHOW_TRAINING_PROCESS:
        print(f"\nEpoch {epoch+1} Summary | Time: {epoch_duration:.2f}s")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Valid Loss: {val_loss:.4f},   Valid Acc: {val_acc:.4f}")

    # Early stopping check
    early_stopping(val_acc, model)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

print("\n--- Training Finished! ---")
total_duration = time() - start_time
print(f"Total training time: {total_duration/60:.2f} minutes")
print(f"Best model saved at {MODEL_SAVE_PATH}" )

Starting training...
Early stopping triggered

--- Training Finished! ---
Total training time: 35.18 minutes
Best model saved at ~/model/Best model/MixResNet18_best_model.pth



### 8. Save Results to CSV


In [9]:
# Create a DataFrame for the epochs that actually ran
completed_epochs = len(all_epochs_train_loss)
df = pd.DataFrame({
    'epoch': range(1, completed_epochs + 1),
    'train_loss': all_epochs_train_loss,
    'train_acc': all_epochs_train_acc,
    'val_loss': all_epochs_val_loss,
    'val_acc': all_epochs_val_acc,
    'epoch_duration_s': epoch_times
})

df.to_csv(CSV_SAVE_PATH, index=False)
print(f"\nTraining log saved to {CSV_SAVE_PATH}")


Training log saved to ~/model/log/MixResNet18_training_log.csv
