-
Notifications
You must be signed in to change notification settings - Fork 1
/
learning.py
138 lines (117 loc) · 5.18 KB
/
learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# -*- coding:utf-8 -*-
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from data.datasets import input_dataset
from models import *
from utils import *
import argparse
import time
import os
# Command-line interface. Defaults reproduce the CIFAR-N benchmark setup;
# the --noise_type shorthands are translated to label keys later in the script.
parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--val_ratio', type=float, default=0.1)
parser.add_argument('--noise_type', type=str, default='clean',
                    help='clean, aggre, worst, rand1, rand2, rand3, clean100, noisy100')
parser.add_argument('--noise_path', type=str, default=None,
                    help='path of CIFAR-10_human.pt')
parser.add_argument('--dataset', type=str, default='cifar10',
                    help=' cifar10 or cifar100')
parser.add_argument('--n_epoch', type=int, default=100)
# Evaluation protocol: the code is run with 5 different seeds, generated
# randomly and fixed for all participants.
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--print_freq', type=int, default=50)
parser.add_argument('--num_workers', type=int, default=4,
                    help='how many subprocesses to use for data loading')
# Train the Model
def train(epoch, train_loader, model, optimizer):
    """Train `model` for one epoch over `train_loader`.

    Args:
        epoch: zero-based epoch index (used only for progress logging).
        train_loader: iterable yielding (images, labels, indexes) batches.
        model: network being optimized.
        optimizer: torch optimizer bound to `model`'s parameters.

    Returns:
        Training accuracy averaged over all batches of the epoch, in the
        units returned by the project `accuracy` helper (presumably top-1
        percent — confirm against utils).

    Relies on module globals: `args.device`, `args.print_freq`, `args.n_epoch`,
    and the project `accuracy` helper imported via `from utils import *`.
    """
    train_total = 0
    train_correct = 0
    model.train()
    for i, (images, labels, indexes) in enumerate(train_loader):
        batch_size = indexes.shape[0]
        images = images.to(args.device)
        labels = labels.to(args.device)
        # Forward + Backward + Optimize
        logits = model(images)
        prec, _ = accuracy(logits, labels, topk=(1, 5))
        train_total += 1
        train_correct += prec
        # Bug fix: the `reduce` kwarg was deprecated in PyTorch 0.4 and has been
        # removed; `reduction='mean'` is the supported equivalent of reduce=True.
        loss = F.cross_entropy(logits, labels, reduction='mean')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % args.print_freq == 0:
            # Bug fix: the original printed len(train_dataset)//batch_size, but
            # `train_dataset` is never defined in this file (main code keeps only
            # the test split), so the print raised NameError. len(train_loader)
            # is the batch count directly. Also: loss.item() replaces the
            # deprecated loss.data access.
            print('Epoch [%d/%d], Iter [%d/%d] Training Accuracy: %.4F, Loss: %.4f'
                  % (epoch + 1, args.n_epoch, i + 1, len(train_loader), prec, loss.item()))
    train_acc = float(train_correct) / float(train_total)
    return train_acc
# Evaluate the Model
def evaluate(loader, teacher, student, save=False, best_acc=0.0):
    """Evaluate the teacher/student agreement prediction on `loader`.

    For each batch, both models' softmax outputs are multiplied element-wise
    and renormalized; the argmax of this "agreement" distribution is the
    prediction. (Since argmax is invariant to the positive normalization, this
    equals the argmax of the raw product.)

    Args:
        loader: iterable yielding (images, labels, _) batches.
        teacher: first model; switched to eval mode here.
        student: second model; switched to eval mode here.
        save: unused; kept for backward compatibility with existing callers
              (checkpoint-saving code was removed as dead).
        best_acc: unused; kept for backward compatibility.

    Returns:
        Agreement accuracy in percent (float).

    Relies on module global `args.device`.
    """
    teacher.eval()
    student.eval()
    correct = 0
    total = 0
    # Inference only: no_grad avoids building autograd graphs (the original
    # wrapped inputs in the long-deprecated torch.autograd.Variable, which is
    # a no-op since PyTorch 0.4 and did not disable gradient tracking).
    with torch.no_grad():
        for images, labels, _ in loader:
            images = images.to(args.device)
            pred_teacher = teacher(images)
            pred_student = student(images)
            p_teacher = pred_teacher.softmax(dim=1)
            p_student = pred_student.softmax(dim=1)
            # Element-wise product = unnormalized agreement between the models.
            p_agree = p_teacher * p_student
            # `keepdim` is the canonical torch kwarg (was numpy-style `keepdims`).
            p_agree = p_agree / p_agree.sum(dim=1, keepdim=True)
            p_id_agree = torch.argmax(p_agree, dim=1)
            total += labels.size(0)
            correct += (p_id_agree.cpu() == labels).sum().item()
    acc = 100 * float(correct) / float(total)
    return acc
##################################### main code ################################################
args = parser.parse_args()
# Seed
# Fix all RNG seeds for reproducibility (project helper from utils).
set_global_seeds(args.seed)
# Select the compute device once; stored on `args` so train/evaluate can
# read it as a module global.
args.device = set_device()
time_start = time.time()
# Hyper Parameters
batch_size = 128
learning_rate = args.lr
# Directory holding this run's trained teacher/student checkpoints.
path = f"./results/{args.dataset}_{args.noise_type}_seed_{args.seed}"
# Map CLI shorthand noise names to the label keys stored in the CIFAR-N
# human-annotation files.
noise_type_map = {'clean':'clean_label', 'worst': 'worse_label', 'aggre': 'aggre_label', 'rand1': 'random_label1', 'rand2': 'random_label2', 'rand3': 'random_label3', 'clean100': 'clean_label', 'noisy100': 'noisy_label'}
args.noise_type = noise_type_map[args.noise_type]
# load dataset
# No explicit --noise_path given: fall back to the bundled CIFAR-N label file
# that matches the chosen dataset.
if args.noise_path is None:
    _default_noise_paths = {
        'cifar10': './data/CIFAR-10_human.pt',
        'cifar100': './data/CIFAR-100_human.pt',
    }
    if args.dataset not in _default_noise_paths:
        raise NameError(f'Undefined dataset {args.dataset}')
    args.noise_path = _default_noise_paths[args.dataset]
# Only the held-out test split is used by this script; the train/val splits
# returned by input_dataset are discarded.
_, _, _, test_dataset, num_classes, num_training_samples = input_dataset(args.dataset,args.noise_type, args.noise_path, is_human = True, val_ratio = args.val_ratio)
# load model
print('building model...')
# Teacher/student pair: two PreAct ResNet-18 backbones from the project's
# model zoo, evaluated jointly in evaluate().
teacher = PreResNet18(num_classes)
teacher.to(args.device)
student = PreResNet18(num_classes)
student.to(args.device)
print('building model done')
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=64,
                                          num_workers=args.num_workers,
                                          shuffle=False)
# we will test the model by the following code
# loading teacher state
# Checkpoints are expected under `path`; loaded onto CPU first, then moved to
# the active device.
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — fine for self-produced checkpoints, unsafe for untrusted files.
state_dict_teacher = torch.load(os.path.join(path, "teacher.pth.tar"), map_location="cpu")
teacher.load_state_dict(state_dict_teacher['state_dict'])
teacher.to(args.device)
# loading student state
state_dict_student = torch.load(os.path.join(path, "student.pth.tar"), map_location="cpu")
student.load_state_dict(state_dict_student['state_dict'])
student.to(args.device)
# calculate test acc
test_acc = evaluate(loader=test_loader, teacher=teacher, student=student, save=False)
print(f'Best test acc selected by val is {test_acc}')