In [1]:
import torch
from collections import Counter
from scipy.special import softmax
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# 手工构造样本不均衡的数据集

In [2]:
weights=[0.05, 0.1, 0.1, 0.25, 0.5]

X, y = make_classification(n_samples=20000,        # 样本个数
                           n_features=20,          # 特征个数
                           n_informative=5,        # 有效特征个数
                           n_redundant=2,          # 冗余特征个数（有效特征的随机组合）
                           n_repeated=0,           # 重复特征个数（有效特征和冗余特征的随机组合）
                           n_classes=5,            # 样本类别
                           n_clusters_per_class=2, # 簇的个数
                           weights=weights,
                           class_sep=1.5,
#                            class_sep=0.3,
                           random_state=0)

x_train, x_test, y_train, y_test = train_test_split(X,y,random_state = 33,test_size = 0.2)
x_train, x_valid, y_train, y_valid = train_test_split(x_train,y_train,random_state = 33,test_size = 0.25)

print(f"training set: {Counter(y_train)}")
print(f"testing set: {Counter(y_valid)}")
print(f"valid set: {Counter(y_test)}")

training set: Counter({4: 6038, 3: 3004, 2: 1179, 1: 1164, 0: 615})
testing set: Counter({4: 1973, 3: 951, 1: 444, 2: 430, 0: 202})
valid set: Counter({4: 1934, 3: 1048, 2: 415, 1: 411, 0: 192})


# 尝试一：不考虑样本不均衡问题，直接训练

In [82]:
# 训练过程中不做 balance
model = LogisticRegression(class_weight=None, solver='lbfgs').fit(x_train,y_train)
y_train_pred = model.predict(x_train)
print(classification_report(y_pred=y_train_pred, y_true=y_train))
print(confusion_matrix(y_pred=y_train_pred, y_true=y_train))

y_test_pred = model.predict(x_test)
print(classification_report(y_pred=y_test_pred, y_true=y_test))
print(confusion_matrix(y_pred=y_test_pred, y_true=y_test))



              precision    recall  f1-score   support

           0       0.72      0.21      0.33       615
           1       0.55      0.30      0.39      1164
           2       0.84      0.81      0.83      1179
           3       0.59      0.50      0.54      3004
           4       0.72      0.89      0.80      6038

    accuracy                           0.69     12000
   macro avg       0.69      0.54      0.58     12000
weighted avg       0.68      0.69      0.67     12000

[[ 129   40  122   79  245]
 [   2  347    3  401  411]
 [  24    0  953   26  176]
 [   1  183   32 1513 1275]
 [  22   56   18  561 5381]]
              precision    recall  f1-score   support

           0       0.76      0.23      0.36       192
           1       0.59      0.30      0.40       411
           2       0.88      0.80      0.84       415
           3       0.60      0.51      0.55      1048
           4       0.70      0.89      0.78      1934

    accuracy                           0.69 

# 尝试二: 在训练过程中做 balance

In [4]:
# 训练过程中做 balance
model = LogisticRegression(class_weight='balanced', solver='lbfgs').fit(x_train,y_train)
y_train_pred = model.predict(x_train)
print(classification_report(y_pred=y_train_pred, y_true=y_train))
print(confusion_matrix(y_pred=y_train_pred, y_true=y_train))

y_test_pred = model.predict(x_test)
print(classification_report(y_pred=y_test_pred, y_true=y_test))
print(confusion_matrix(y_pred=y_test_pred, y_true=y_test))



              precision    recall  f1-score   support

           0       0.45      0.56      0.50       615
           1       0.40      0.66      0.50      1164
           2       0.68      0.82      0.75      1179
           3       0.54      0.53      0.54      3004
           4       0.85      0.69      0.76      6038

    accuracy                           0.66     12000
   macro avg       0.58      0.65      0.61     12000
weighted avg       0.69      0.66      0.67     12000

[[ 347   50  190   16   12]
 [   3  770    4  287  100]
 [ 146    5  971   16   41]
 [  17  651  161 1590  585]
 [ 262  457   96 1027 4196]]
              precision    recall  f1-score   support

           0       0.46      0.56      0.50       192
           1       0.43      0.66      0.52       411
           2       0.69      0.82      0.75       415
           3       0.56      0.55      0.55      1048
           4       0.84      0.71      0.77      1934

    accuracy                           0.66 

# 方案三: 阈值移动（threshold-moving）

即：训练阶段不做 balance, 在预测阶段 后处理调整了 threshold

参考：[分类任务中数据类别不平衡问题](https://blog.csdn.net/kuaizi_sophia/article/details/84894363)

In [5]:
# 训练阶段不做 balance, 在预测阶段 后处理调整了 threshold

model = LogisticRegression(class_weight=None, solver='lbfgs').fit(x_train,y_train)
pred = model.predict_proba(x_train)
pred = pred / weights
y_train_pred = np.argmax(pred, axis=1)

print(classification_report(y_pred=y_train_pred, y_true=y_train))
print(confusion_matrix(y_pred=y_train_pred, y_true=y_train))

pred = model.predict_proba(x_test)
pred = pred / weights
y_test_pred = np.argmax(pred, axis=1)
print(classification_report(y_pred=y_test_pred, y_true=y_test))
print(confusion_matrix(y_pred=y_test_pred, y_true=y_test))



              precision    recall  f1-score   support

           0       0.33      0.65      0.44       615
           1       0.38      0.72      0.50      1164
           2       0.67      0.80      0.73      1179
           3       0.53      0.52      0.53      3004
           4       0.88      0.62      0.72      6038

    accuracy                           0.62     12000
   macro avg       0.56      0.66      0.58     12000
weighted avg       0.70      0.62      0.64     12000

[[ 401   47  145    9   13]
 [   4  836   24  233   67]
 [ 199    6  941    7   26]
 [  54  806  172 1566  406]
 [ 569  506  123 1120 3720]]
              precision    recall  f1-score   support

           0       0.32      0.63      0.42       192
           1       0.41      0.71      0.52       411
           2       0.67      0.80      0.73       415
           3       0.55      0.54      0.54      1048
           4       0.86      0.62      0.72      1934

    accuracy                           0.62 

# 方案三：温度后调整

## 使用启发式的方法做温度调整

使用启发式的方法做温度调整时，损失函数并不直接与优化目标（预测样本类别占比和真实样本类别占比相近）一致，最终的结果比启发式方法的结果要差

In [6]:
# 训练过程中不做 balance，在预测阶段，为每个类别调整 temperature

x_valid_new = x_valid
y_valid_new = y_valid
x_test_new = x_test
y_test_new = y_test

model = LogisticRegression(class_weight=None, solver='lbfgs').fit(x_train,y_train)

raw_valid_probs = model.predict_proba(x_valid_new)
raw_test_probs = model.predict_proba(x_test_new)

temperature = [1.0] * 5
y_valid_counter = Counter(y_valid_new)
target_distrib = [y_valid_counter[key]/len(y_valid_new) for key in sorted(Counter(y_valid_new).keys())]


def auto_tuning_temperature_heuristic(model, x_val, y_val, max_diff=0.005, max_iter=100):
    ''' 调整温度，使得 model 的预测输出分布 和 y_val 中分布接近
    @params model: 模型，需要有方法 predict_proba
    '''
    label_counter = Counter(y_val)
    temperature = [1.0 for _ in range(len(label_counter))]
    y_prob = model.predict_proba(x_val)
    y_true_distrib = [label_counter[key]/len(y_valid) for key in sorted(label_counter.keys())]
    
    learning_rate = 1e-1
    learning_rate = 0.9
    for t in range(100):
        y_prob_tuned = softmax(np.log(y_prob) / temperature, axis=1)
        y_pred_tuned = np.argmax(y_prob_tuned, axis=1)
        
        tuned_label_counter = Counter(y_pred_tuned)
        y_pred_tuned_distrib = [tuned_label_counter[key]/len(y_valid) 
                                for key in sorted(tuned_label_counter.keys())]
        diff = np.array(y_true_distrib) - np.array(y_pred_tuned_distrib)
        if t % 10 == 9:
            print(f"========= iteration: {t} diff: {np.max(np.abs(diff))}, temperature: {temperature}")
        if max(abs(diff)) < max_diff:
            break
        # temperature = temperature * (np.exp(diff))
        temperature += learning_rate * (temperature * (np.exp(diff)) - temperature)

    return temperature

temperature = auto_tuning_temperature_heuristic(model, x_valid_new, y_valid_new)

# use best temperature when test
tuned_test_probs = softmax(np.log(raw_test_probs)/temperature, axis=1)
tuned_test_preds = np.argmax(tuned_test_probs, axis=1)
print(classification_report(y_pred=tuned_test_preds, y_true=y_test_new))
print(confusion_matrix(y_pred=tuned_test_preds, y_true=y_test_new))



              precision    recall  f1-score   support

           0       0.57      0.48      0.52       192
           1       0.46      0.52      0.49       411
           2       0.80      0.81      0.80       415
           3       0.59      0.55      0.57      1048
           4       0.79      0.81      0.80      1934

    accuracy                           0.69      4000
   macro avg       0.64      0.63      0.64      4000
weighted avg       0.69      0.69      0.69      4000

[[  92   20   49    6   25]
 [   0  215    1  114   81]
 [  43    3  336    8   25]
 [   0  161   27  576  284]
 [  25   67    9  273 1560]]


## 基于梯度的温度后调整策略

In [7]:
def auto_tuning_temperature(model, x_val, y_val):
    ''' 调整温度，使得 model 的预测输出分布 和 y_val 中分布接近
    @params model: 模型，需要有方法 predict_proba
    '''
    label_counter = Counter(y_val)
    temperature = torch.tensor([1.0] * len(label_counter), requires_grad=True)
    y_prob = torch.from_numpy(model.predict_proba(x_val))
    y_true_distrib = torch.tensor([label_counter[key]/len(y_valid) for key in sorted(label_counter.keys())])
    
    learning_rate = 1e-1
    for t in range(10000):
        y_prob_tuned = torch.softmax(torch.log(y_prob) / temperature, axis=1)
        y_pred_distrib = y_prob_tuned.sum(axis=0) / y_prob_tuned.sum()
        loss = torch.sum(torch.square(y_pred_distrib - y_true_distrib))
        if t % 1000 == 999:
            print(f"========= iteration: {t} ========= {torch.max(torch.abs(y_pred_distrib - y_true_distrib))}")
        loss.backward()
        with torch.no_grad():
            temperature -= learning_rate * temperature.grad
            # Manually zero the gradients after updating weights
            temperature.grad.zero_()

    return temperature.detach().numpy()


In [8]:
auto_tuned_temperature = auto_tuning_temperature(model, x_valid, y_valid)


def evaluate_with_temperature(model, x_test, y_test, temperature):
    y_prob = model.predict_proba(x_test)
    y_prob_tuned = softmax(np.log(y_prob) / temperature, axis=1)
    y_pred_tuned = np.argmax(y_prob_tuned, axis=1)

    true_label_counter = Counter(y_test)
    y_true_distrib = [true_label_counter[key]/len(y_test) for key in sorted(true_label_counter.keys())]
    pred_label_counter = Counter(y_pred_tuned)
    y_tune_distrib = [pred_label_counter[key]/len(y_test) for key in sorted(pred_label_counter.keys())]

    print("diff:", np.max(np.abs(np.array(y_true_distrib) - np.array(y_tune_distrib))))
    print(classification_report(y_pred=y_pred_tuned, y_true=y_test))
    print(confusion_matrix(y_pred=y_pred_tuned, y_true=y_test))

base_temperature = [1.0] * len(Counter(y_valid))
evaluate_with_temperature(model, x_valid, y_valid, auto_tuned_temperature)
evaluate_with_temperature(model, x_valid, y_valid, base_temperature)

base_temperature = [1.0] * len(Counter(y_test))
evaluate_with_temperature(model, x_test, y_test, auto_tuned_temperature)
evaluate_with_temperature(model, x_test, y_test, base_temperature)

diff: 0.10674999999999996
              precision    recall  f1-score   support

           0       0.79      0.26      0.39       202
           1       0.61      0.36      0.45       444
           2       0.79      0.83      0.81       430
           3       0.57      0.50      0.53       951
           4       0.72      0.88      0.79      1973

    accuracy                           0.69      4000
   macro avg       0.70      0.57      0.60      4000
weighted avg       0.69      0.69      0.67      4000

[[  53   11   56   13   69]
 [   0  160    1  135  148]
 [   5    0  358    9   58]
 [   1   64   22  471  393]
 [   8   26   14  193 1732]]
diff: 0.12925000000000003
              precision    recall  f1-score   support

           0       0.78      0.24      0.37       202
           1       0.64      0.28      0.39       444
           2       0.87      0.79      0.83       430
           3       0.56      0.50      0.53       951
           4       0.70      0.89      0.78    