In [1]:
from torch.autograd import Variable


from IPython.display import clear_output

import fedlern.utils as u
from fedlern.train_utils import *
from fedlern.quant_utils import *
from fedlern.models.resnet_v2 import ResNet18

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
stats = (0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784)
batch_size_test = 250
batch_size_train = 128

quantize_nbits = 8
num_epochs = 200
learning_rate = 0.001
eta_rate = 1.05
eta = 1
#global best_acc
best_acc = 0
best_count = 0

In [3]:
train_loader, test_loader = prepare_dataloader_cifar(num_workers=8, 
                                                     train_batch_size=batch_size_train, 
                                                     eval_batch_size=batch_size_test, 
                                                     stats=stats)

    

Files already downloaded and verified


In [4]:
net = ResNet18()
net.to(device)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1

In [5]:
net, criterion, optimizer = get_model2(net, learning_rate=0.1, weight_decay=5e-4)

    
all_G_kernels = [
    Variable(kernel.data.clone(), requires_grad=True)
    for kernel in optimizer.param_groups[1]['params']
]


all_W_kernels = [kernel for kernel in optimizer.param_groups[1]['params']]
kernels = [{'params': all_G_kernels}]
optimizer_quant = optim.SGD(kernels, lr=0)

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80,120,160], gamma=0.1)

In [6]:

#for epoch in tqdm(range(num_epochs)):    num_epochs = 200
for epoch in (pbar := tqdm(range(num_epochs))):
    #print('Epoch ID', epoch)
    #----------------------------------------------------------------------
    # Training
    #----------------------------------------------------------------------
    pbar.set_description(f"Training {epoch}",refresh=True)
    correct = 0; total = 0; train_loss = 0
    net.train()
    for batch_idx, (x, target) in enumerate(tqdm(train_loader, leave=False)):
    #for batch_idx, (x, target) in enumerate(train_loader):
        #if batch_idx < 1:
        optimizer.zero_grad()
        x, target = Variable(x.cuda()), Variable(target.cuda())
        all_W_kernels = optimizer.param_groups[1]['params']
        all_G_kernels = optimizer_quant.param_groups[0]['params']
        
        for i in range(len(all_W_kernels)):
            k_W = all_W_kernels[i]
            k_G = all_G_kernels[i]
            V = k_W.data
            #print(type(V))
            #####Binary Connect#########################
            #k_G.data = quantize_bw(V)
            ############################################
            
            ######Binary Relax##########################
            if epoch<120:
                #k_G.data = (eta*quantize_bw(V)+V)/(1+eta)
                k_G.data = (eta*quantize(V,num_bits=quantize_nbits)+V)/(1+eta)
                
            else:
                k_G.data = quantize(V, num_bits=quantize_nbits)
            #############################################
            
            k_W.data, k_G.data = k_G.data, k_W.data
            
            
        score = net(x)
        loss = criterion(score, target)
        loss.backward()
        
        for i in range(len(all_W_kernels)):
            k_W = all_W_kernels[i]
            k_G = all_G_kernels[i]
            k_W.data, k_G.data = k_G.data, k_W.data
        
        optimizer.step()
        
        train_loss += loss.data
        _, predicted = torch.max(score.data, 1)
        total += target.size(0)
        correct += predicted.eq(target.data).cpu().sum()
        pbar.write(f"Training {batch_idx} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})")
        
    #----------------------------------------------------------------------
    # Testing
    #----------------------------------------------------------------------
    pbar.set_description(f"Testing {epoch}",refresh=True)
    test_loss = 0; correct = 0; total = 0
    net.eval()
    
    for i in range(len(all_W_kernels)):
        k_W = all_W_kernels[i]
        k_quant = all_G_kernels[i]    
        k_W.data, k_quant.data = k_quant.data, k_W.data
    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(tqdm(test_loader, leave=False)):
        #for batch_idx, (x, target) in enumerate(test_loader):
            x, target = Variable(x.cuda()), Variable(target.cuda())
            score= net(x)
            
            loss = criterion(score, target)
            test_loss += loss.data
            _, predicted = torch.max(score.data, 1)
            total += target.size(0)
            correct += predicted.eq(target.data).cpu().sum()
            pbar.write(f"Testing {batch_idx} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})")

    
    #----------------------------------------------------------------------
    # Save the checkpoint
    #----------------------------------------------------------------------
    '''
    for i in range(len(all_W_kernels)):
        k_W = all_W_kernels[i]
        k_quant = all_G_kernels[i]    
        k_W.data, k_quant.data = k_quant.data, k_W.data
        '''  
    acc = 100.*correct/total
    #if acc > best_acc:
    if correct > best_count:
        # print('Saving model...')
        # state = {
        #     'state': net.state_dict(), #net,
        #     'acc': acc,
        #     'epoch': epoch,
        # }
        
        # torch.save(state, f'./saved_models/resnet_{quantize_nbits}bits_{u.time_stamp()}.pth')
        #net.save_state_dict('resnet20.pt')
        best_acc = acc
        best_count = correct

    for i in range(len(all_W_kernels)):
        k_W=all_W_kernels[i]
        k_quant=all_W_kernels[i]
        k_W.data, k_quant.data =k_quant.data,k_W.data
    clear_output(wait=True)
    
save_model(net, "saved_models", f'resnet_{quantize_nbits}bits_{u.time_stamp()}.pt')

  0%|          | 0/391 [00:00<?, ?it/s]

Training 0 Loss: 0.367 | Acc: 85.938% (110/128)
Training 1 Loss: 0.300 | Acc: 89.062% (228/256)
Training 2 Loss: 0.300 | Acc: 88.542% (340/384)
Training 3 Loss: 0.294 | Acc: 89.453% (458/512)
Training 4 Loss: 0.281 | Acc: 90.156% (577/640)
Training 5 Loss: 0.269 | Acc: 91.016% (699/768)
Training 6 Loss: 0.281 | Acc: 90.513% (811/896)
Training 7 Loss: 0.273 | Acc: 90.820% (930/1024)
Training 8 Loss: 0.281 | Acc: 90.625% (1044/1152)
Training 9 Loss: 0.269 | Acc: 91.172% (1167/1280)
Training 10 Loss: 0.279 | Acc: 90.767% (1278/1408)
Training 11 Loss: 0.273 | Acc: 91.016% (1398/1536)
Training 12 Loss: 0.276 | Acc: 90.805% (1511/1664)
Training 13 Loss: 0.274 | Acc: 90.848% (1628/1792)
Training 14 Loss: 0.278 | Acc: 90.781% (1743/1920)
Training 15 Loss: 0.284 | Acc: 90.576% (1855/2048)
Training 16 Loss: 0.281 | Acc: 90.717% (1974/2176)
Training 17 Loss: 0.275 | Acc: 90.929% (2095/2304)
Training 18 Loss: 0.279 | Acc: 90.666% (2205/2432)
Training 19 Loss: 0.280 | Acc: 90.664% (2321/2560)
Train

  0%|          | 0/40 [00:00<?, ?it/s]

Testing 0 Loss: 118.244 | Acc: 86.000% (215/250)
Testing 1 Loss: 59.122 | Acc: 85.200% (426/500)
Testing 2 Loss: 39.415 | Acc: 84.933% (637/750)
Testing 3 Loss: 29.561 | Acc: 84.100% (841/1000)
Testing 4 Loss: 23.649 | Acc: 83.280% (1041/1250)
Testing 5 Loss: 19.707 | Acc: 83.667% (1255/1500)
Testing 6 Loss: 16.892 | Acc: 83.600% (1463/1750)
Testing 7 Loss: 14.780 | Acc: 83.600% (1672/2000)
Testing 8 Loss: 13.138 | Acc: 83.733% (1884/2250)
Testing 9 Loss: 11.824 | Acc: 83.280% (2082/2500)
Testing 10 Loss: 10.749 | Acc: 83.382% (2293/2750)
Testing 11 Loss: 9.854 | Acc: 83.367% (2501/3000)
Testing 12 Loss: 9.096 | Acc: 83.508% (2714/3250)
Testing 13 Loss: 8.446 | Acc: 83.457% (2921/3500)
Testing 14 Loss: 7.883 | Acc: 83.547% (3133/3750)
Testing 15 Loss: 7.390 | Acc: 83.675% (3347/4000)
Testing 16 Loss: 6.956 | Acc: 83.788% (3561/4250)
Testing 17 Loss: 6.569 | Acc: 83.889% (3775/4500)
Testing 18 Loss: 6.223 | Acc: 83.874% (3984/4750)
Testing 19 Loss: 5.912 | Acc: 83.720% (4186/5000)
Testi