In [1]:
from model import CNN, loss_coteaching, JoCor_loss
from data.cifar import CIFAR10
from config import opt
import torch
import numpy as np
from tqdm import tqdm

In [2]:
# Seed every RNG the notebook relies on so Restart & Run All is reproducible.
SEED = 1
np.random.seed(SEED)              # numpy is imported above; seed its global RNG too
torch.manual_seed(SEED)           # CPU generator (and, in PyTorch, all devices)
torch.cuda.manual_seed_all(SEED)  # explicit: every visible GPU, not only the current one

In [3]:
# Mini-batch size shared by the training and test loaders.
BATCH_SIZE = 128

# Training split (this cell's output reports its label-noise rate) and test split.
data = CIFAR10()
train_loader = torch.utils.data.DataLoader(
    dataset=data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,  # drop the trailing incomplete batch so every batch has BATCH_SIZE items
)

test_data = CIFAR10(train=False)
test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
)

# Two peer networks on the GPU, each with its own Adam optimizer.
cnn1 = CNN().cuda()
optimizer1 = torch.optim.Adam(cnn1.parameters(), lr=opt.learning_rate)

cnn2 = CNN().cuda()
optimizer2 = torch.optim.Adam(cnn2.parameters(), lr=opt.learning_rate)

Actual noise rate is 0.20076


In [4]:
def set_schedule(opt):
    """Build the per-epoch optimizer and forget-rate schedules.

    Args:
        opt: config object; reads ``learning_rate``, ``n_epoch``, ``mom1``,
            ``mom2``, ``epoch_decay_start``, ``forget_rate``, ``num_gradual``
            and ``exponent``.

    Returns:
        ``(alpha_plan, beta_plan, rate_schedule)``:

        * ``alpha_plan`` -- list of ``n_epoch`` learning rates. The value is
          ``opt.learning_rate`` before ``epoch_decay_start``, then decays
          linearly toward 0.
        * ``beta_plan`` -- list of ``n_epoch`` first-moment betas (used as
          ``betas[0]`` by ``adjust_learning_rate``). The value is ``opt.mom1``
          before ``epoch_decay_start`` and ``opt.mom2`` from then on.
        * ``rate_schedule`` -- numpy array of ``n_epoch`` forget rates. It
          ramps linearly from 0 to ``forget_rate ** exponent`` over the first
          ``num_gradual`` epochs, and is ``forget_rate`` afterwards.
    """
    alpha_plan = [opt.learning_rate] * opt.n_epoch
    beta_plan = [opt.mom1] * opt.n_epoch
    for i in range(opt.epoch_decay_start, opt.n_epoch):
        alpha_plan[i] = float(opt.n_epoch - i) / (opt.n_epoch - opt.epoch_decay_start) * opt.learning_rate
        beta_plan[i] = opt.mom2

    rate_schedule = np.ones(opt.n_epoch) * opt.forget_rate
    ramp = np.linspace(0, opt.forget_rate ** opt.exponent, opt.num_gradual)
    # Clip the warm-up to the training length. Without this, num_gradual > n_epoch
    # makes the slice assignment fail with a shape-mismatch ValueError.
    n_ramp = min(opt.num_gradual, opt.n_epoch)
    rate_schedule[:n_ramp] = ramp[:n_ramp]
    return alpha_plan, beta_plan, rate_schedule

def adjust_learning_rate(optimizer, epoch, alpha_plan, beta_plan):
    """Apply the scheduled hyper-parameters for ``epoch`` to ``optimizer``.

    Every entry of ``optimizer.param_groups`` is mutated in place. Its ``'lr'``
    becomes ``alpha_plan[epoch]`` and its ``'betas'`` becomes
    ``(beta_plan[epoch], 0.999)``.
    """
    for group in optimizer.param_groups:
        group.update(lr=alpha_plan[epoch], betas=(beta_plan[epoch], 0.999))

def train(train_loader, epoch, cnn1, cnn2, optimizer1, optimizer2, rate_schedule, co_lambda=0.85):
    """Run one epoch of joint training of ``cnn1`` and ``cnn2`` with ``JoCor_loss``.

    For each batch, both networks run a forward pass. ``JoCor_loss`` combines
    their outputs into one loss, which is back-propagated once and used to
    step both optimizers.

    Args:
        train_loader: iterable yielding ``(imgs, labels, index)`` batches.
        epoch: index into ``rate_schedule`` selecting this epoch's forget rate.
        cnn1, cnn2: the two peer networks (the caller puts them in train mode).
        optimizer1, optimizer2: optimizers stepped after every batch.
        rate_schedule: per-epoch forget rates, e.g. from ``set_schedule``.
        co_lambda: 5th argument forwarded to ``JoCor_loss``. Defaults to the
            previously hard-coded 0.85.

    Returns:
        ``(loss, acc1, acc2)``:

        * ``loss`` -- the sum of ``loss_t.item()`` divided by the sum of
          ``num_remember`` over all batches.
        * ``acc1`` / ``acc2`` -- the fraction of samples whose argmax
          prediction equals the given training label.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    # Move inputs to wherever cnn1's parameters live (the GPU in this notebook).
    first_param = next(cnn1.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    correct_1 = 0
    correct_2 = 0
    loss_t_total = 0
    total_instance = 0
    loss_instance = 0
    forget_rate = rate_schedule[epoch]

    for imgs, labels, _ in train_loader:
        # One host-to-device copy per batch, shared by both networks.
        imgs = imgs.to(device)
        target = labels.long().to(device)

        y1 = cnn1(imgs)
        y2 = cnn2(imgs)
        loss_t, num_remember = JoCor_loss(y1, y2, target, forget_rate, co_lambda)

        # Use a tensor-level sum plus .item(). The builtin sum() iterated the
        # tensor element by element, forcing a device sync per element.
        correct_1 += (y1.argmax(dim=1) == target).sum().item()
        correct_2 += (y2.argmax(dim=1) == target).sum().item()

        total_instance += imgs.shape[0]
        loss_instance += num_remember
        loss_t_total += loss_t.item()

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        loss_t.backward()
        optimizer1.step()
        optimizer2.step()

    if total_instance == 0:
        raise ValueError('train_loader yielded no batches')

    # NOTE(review): if JoCor_loss returns a batch-mean loss, then dividing the
    # sum of batch means by the number of remembered samples does not give a
    # per-sample mean (it scales with 1/batch_size). TODO: confirm JoCor_loss's
    # reduction.
    return loss_t_total / loss_instance, correct_1 / total_instance, correct_2 / total_instance

def test(test_loader, cnn1, cnn2):
    correct_1 = 0
    correct_2 = 0
    
    total = 0
    for i,(imgs, labels, index) in enumerate(test_loader):
        y1 = cnn1(imgs.cuda())
        y2 = cnn2(imgs.cuda())
        target = labels.long().cuda()
        total += imgs.shape[0]
        correct_1 += sum(y1.argmax(axis = 1) ==  target)
        correct_2 += sum(y2.argmax(axis = 1) ==  target)
        
        
    return correct_1.long().item()/ total, correct_2.long().item()/ total

In [None]:
# Per-epoch learning rates, Adam beta1 values and forget rates.
alpha_plan, beta_plan, rate_schedule = set_schedule(opt)

for epoch in range(opt.n_epoch):
    # Training phase: train mode plus this epoch's optimizer hyper-parameters.
    cnn1.train()
    adjust_learning_rate(optimizer1, epoch, alpha_plan, beta_plan)
    cnn2.train()
    adjust_learning_rate(optimizer2, epoch, alpha_plan, beta_plan)

    # train() returns (loss, train accuracy of cnn1, train accuracy of cnn2).
    loss_1_total, correct_1, correct_2 = train(
        train_loader, epoch,
        cnn1, cnn2,
        optimizer1, optimizer2,
        rate_schedule,
    )

    # Evaluation phase.
    cnn1.eval()
    cnn2.eval()
    acc1_test, acc2_test = test(test_loader, cnn1, cnn2)

    print('epoch', epoch,
          '|loss:%.4f' % loss_1_total,
          '|acc1:%.3f' % correct_1,
          '| acc2:%.3f' % correct_2,
          '|acc1_t:%.3f' % acc1_test,
          '| acc2_t:%.3f' % acc2_test)

epoch 0 |loss:0.0049 |acc1:0.272 | acc2:0.274 |acc1_t:0.415 | acc2_t:0.416
epoch 1 |loss:0.0044 |acc1:0.382 | acc2:0.383 |acc1_t:0.502 | acc2_t:0.502
epoch 2 |loss:0.0042 |acc1:0.439 | acc2:0.439 |acc1_t:0.555 | acc2_t:0.554
epoch 3 |loss:0.0039 |acc1:0.478 | acc2:0.478 |acc1_t:0.608 | acc2_t:0.608
epoch 4 |loss:0.0037 |acc1:0.509 | acc2:0.509 |acc1_t:0.630 | acc2_t:0.638
epoch 5 |loss:0.0035 |acc1:0.529 | acc2:0.528 |acc1_t:0.654 | acc2_t:0.657
epoch 6 |loss:0.0034 |acc1:0.547 | acc2:0.549 |acc1_t:0.672 | acc2_t:0.669
epoch 7 |loss:0.0032 |acc1:0.560 | acc2:0.560 |acc1_t:0.685 | acc2_t:0.682
epoch 8 |loss:0.0030 |acc1:0.572 | acc2:0.572 |acc1_t:0.688 | acc2_t:0.699
epoch 9 |loss:0.0029 |acc1:0.581 | acc2:0.580 |acc1_t:0.701 | acc2_t:0.703
epoch 10 |loss:0.0028 |acc1:0.590 | acc2:0.590 |acc1_t:0.713 | acc2_t:0.709
epoch 11 |loss:0.0027 |acc1:0.599 | acc2:0.600 |acc1_t:0.722 | acc2_t:0.723
epoch 12 |loss:0.0026 |acc1:0.607 | acc2:0.608 |acc1_t:0.729 | acc2_t:0.728
epoch 13 |loss:0.0025 