In [1]:
import numpy as np
import cupy as cp
def cross_entropy_loss(logits, y_true, model, reg_lambda):
    """Compute softmax cross-entropy loss with L2 regularization and its logit gradient.

    Args:
        logits: 2-D array of raw class scores, one row per sample and classes
            along axis 1 (rows are indexed with ``cp.arange(m)`` below).
        y_true: Integer class index for each sample; its length ``m`` is the
            batch size. It must be usable as an integer index into axis 1 of
            ``logits``.
        model: Object exposing a ``layers`` iterable; every layer that is an
            instance of ``Linear`` contributes ``sum(layer.W ** 2)`` to the
            penalty.
        reg_lambda: L2 regularization strength.

    Returns:
        A tuple ``(total_loss, grad)``:
        - ``total_loss`` is the mean negative log-likelihood over the batch
          plus ``0.5 * reg_lambda * sum(||W||^2)``.
        - ``grad`` is the gradient of the data term only with respect to
          ``logits``, with the same shape as ``logits``.
        The regularization gradient with respect to the weights is not
        produced here. NOTE(review): presumably it is applied in the layers'
        backward pass; confirm with the caller.
    """
    m = y_true.shape[0]
    rows = cp.arange(m)

    # Numerically stable softmax: subtracting the row max keeps exp() from
    # overflowing without changing the resulting probabilities.
    shifted_logits = logits - cp.max(logits, axis=1, keepdims=True)
    exp_logits = cp.exp(shifted_logits)
    # Reduce once and reuse the normalizer for both the probabilities and
    # the log-sum-exp term below.
    sum_exp = cp.sum(exp_logits, axis=1, keepdims=True)
    probs = exp_logits / sum_exp

    # Per-sample negative log-likelihood, with s = shifted logits:
    #   -log p_y = log(sum_j exp(s_j)) - s_y
    log_sum_exp = cp.log(sum_exp[:, 0])
    neg_log_probs = log_sum_exp - shifted_logits[rows, y_true]
    loss = cp.mean(neg_log_probs)

    # L2 penalty over the weights of every Linear layer.
    reg_loss = 0.5 * reg_lambda * sum(
        cp.sum(layer.W ** 2) for layer in model.layers if isinstance(layer, Linear)
    )
    total_loss = loss + reg_loss

    # d(mean NLL)/d(logits) = (softmax - one_hot(y)) / m.
    # probs is not used after this point, so it is updated in place
    # rather than copied.
    grad = probs
    grad[rows, y_true] -= 1
    grad /= m  # average over the batch, matching cp.mean above

    return total_loss, grad