-
Notifications
You must be signed in to change notification settings - Fork 37
/
07_optimizing_multi_label_auroc_loss_with_densenet121_on_chexpert.py
129 lines (103 loc) · 4.27 KB
/
07_optimizing_multi_label_auroc_loss_with_densenet121_on_chexpert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
# Optimizing Multi-label AUROC loss on Chest X-Ray Dataset (CheXpert)
Author: Zhuoning Yuan
Reference:
If you find this tutorial helpful in your work, please acknowledge our library and cite the following paper:
@inproceedings{yuan2021large,
title={Large-scale robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification},
author={Yuan, Zhuoning and Yan, Yan and Sonka, Milan and Yang, Tianbao},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={3040--3049},
year={2021}
}
@misc{libauc2022,
title={LibAUC: A Deep Learning Library for X-Risk Optimization.},
author={Zhuoning Yuan, Zi-Hao Qiu, Gang Li, Dixian Zhu, Zhishuai Guo, Quanqi Hu, Bokun Wang, Qi Qi, Yongjian Zhong, Tianbao Yang},
year={2022}
}
"""
from libauc.losses import AUCM_MultiLabel, CrossEntropyLoss
from libauc.optimizers import PESG, Adam
from libauc.models import densenet121 as DenseNet121
from libauc.datasets import CheXpert
from libauc.metrics import auc_roc_score # for multi-task
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F
def set_all_seeds(SEED):
    """Make runs repeatable: seed the NumPy and PyTorch RNGs and pin
    cuDNN to its deterministic, non-autotuning kernels.

    Args:
        SEED (int): seed applied to both frameworks.
    """
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    # Force deterministic conv algorithms and turn off the benchmark
    # autotuner, which may pick different kernels from run to run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
# dataloader
# Expects the CheXpert-v1.0-small download to already be on disk at this path.
root = './CheXpert/CheXpert-v1.0-small/'
# Index: -1 denotes multi-label mode including 5 diseases
# (one binary label per disease, so the model below has num_classes=5).
traindSet = CheXpert(csv_path=root+'train.csv', image_root_path=root, use_upsampling=False, use_frontal=True, image_size=224, mode='train', class_index=-1, verbose=False)
testSet = CheXpert(csv_path=root+'valid.csv', image_root_path=root, use_upsampling=False, use_frontal=True, image_size=224, mode='valid', class_index=-1, verbose=False)
# Shuffle only the training split; evaluation order is kept fixed.
trainloader = torch.utils.data.DataLoader(traindSet, batch_size=32, num_workers=2, shuffle=True)
testloader = torch.utils.data.DataLoader(testSet, batch_size=32, num_workers=2, shuffle=False)
# check imbalance ratio for each task
# (positive-class ratio per disease, exposed by the CheXpert dataset wrapper)
print (traindSet.imratio_list )
# paramaters
SEED = 123
# NOTE(review): BATCH_SIZE duplicates the hard-coded batch_size=32 passed to
# the DataLoaders above — keep the two in sync if either is changed.
BATCH_SIZE = 32
lr = 0.1
# gamma and margin are PESG/AUCM hyperparameters (see the cited paper);
# gamma controls the regularizer strength, margin the surrogate AUC margin.
gamma = 500
weight_decay = 1e-5
margin = 1.0
# model
# Seed everything before model construction so weight init is reproducible.
set_all_seeds(SEED)
# last_activation=None: the head outputs raw logits; sigmoid is applied
# explicitly in the training/validation loops below.
model = DenseNet121(pretrained=True, last_activation=None, activations='relu', num_classes=5)
model = model.cuda()
# define loss & optimizer
# AUCM_MultiLabel keeps per-class auxiliary variables (a, b, alpha) that the
# PESG optimizer updates jointly with the network weights.
# NOTE(review): this is the older LibAUC PESG signature (model + a/b/alpha);
# newer LibAUC releases take (model.parameters(), loss_fn=...) — confirm the
# installed version before running.
Loss = AUCM_MultiLabel(num_classes=5)
optimizer = PESG(model,
                 a=Loss.a,
                 b=Loss.b,
                 alpha=Loss.alpha,
                 lr=lr,
                 gamma=gamma,
                 margin=margin,
                 weight_decay=weight_decay, device='cuda')
# training
# Train for 2 epochs; every 400 batches, evaluate mean AUROC on the valid
# split and checkpoint the best model.
best_val_auc = 0
for epoch in range(2):
    if epoch > 0:
        # PESG schedule: decay the learning rate and refresh the regularizer
        # reference point at each epoch boundary after the first.
        optimizer.update_regularizer(decay_factor=10)
    for idx, data in enumerate(trainloader):
        train_data, train_labels = data
        train_data, train_labels = train_data.cuda(), train_labels.cuda()
        y_pred = model(train_data)
        # Squash logits to (0, 1) before the AUC-margin loss — presumably
        # AUCM_MultiLabel expects probability-scaled scores; confirm against
        # the LibAUC docs for the installed version.
        y_pred = torch.sigmoid(y_pred)
        loss = Loss(y_pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # validation
        # Runs at idx == 0 too, giving a baseline AUROC before any updates
        # in this epoch.
        if idx % 400 == 0:
            model.eval()
            with torch.no_grad():
                test_pred = []
                test_true = []
                for jdx, data in enumerate(testloader):
                    test_data, test_labels = data
                    test_data = test_data.cuda()
                    y_pred = model(test_data)
                    y_pred = torch.sigmoid(y_pred)
                    test_pred.append(y_pred.cpu().detach().numpy())
                    # labels stay on CPU; only images were moved to GPU
                    test_true.append(test_labels.numpy())
                # NOTE: test_true / test_pred intentionally survive at module
                # scope — the per-task AUROC call after this loop reuses them.
                test_true = np.concatenate(test_true)
                test_pred = np.concatenate(test_pred)
                # sklearn averages AUROC over the 5 label columns here.
                val_auc_mean = roc_auc_score(test_true, test_pred)
                model.train()
                if best_val_auc < val_auc_mean:
                    best_val_auc = val_auc_mean
                    torch.save(model.state_dict(), 'aucm_multi_label_pretrained_model.pth')
                print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, Best_Val_AUC=%.4f'%(epoch, idx, val_auc_mean, best_val_auc))
# show auc roc scores for each task
# `auc_roc_score` returns one AUROC per disease; the original bare call
# discarded the result, so nothing was actually shown — print it.
# NOTE(review): test_true / test_pred are the arrays left over from the
# last validation pass inside the training loop above.
print(auc_roc_score(test_true, test_pred))