In [1]:
from torch.autograd import Variable


from IPython.display import clear_output

import fedlern.utils as u
from fedlern.train_utils import *
from fedlern.quant_utils import *
from fedlern.models.resnet_v2 import ResNet18
from fedlern.quantize import quantize
from tqdm.notebook import tqdm

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
stats = (0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784)
batch_size_test = 64 #250
batch_size_train = 64 #128

quantize_nbits = 16
num_epochs = 150
learning_rate = 0.0001
eta_rate = 1.05
eta = 1
#global best_acc
best_acc = 0
best_count = 0

In [3]:
train_loader, test_loader = prepare_dataloader_cifar(num_workers=8, 
                                                     train_batch_size=batch_size_train, 
                                                     eval_batch_size=batch_size_test, 
                                                     stats=stats)

    

Files already downloaded and verified


In [4]:
net = ResNet18()
net.to(device)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (skip_add): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (skip_add): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (

In [5]:
optimizer = get_model_optimizer(net, learning_rate=0.1, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss().cuda()
    
# all_G_kernels = [
#     Variable(kernel.data.clone(), requires_grad=True)
#     for kernel in optimizer.param_groups[1]['params']
# ]
# Copy the parameters
all_G_kernels = [kernel.data.clone().requires_grad_(True)
                for kernel in optimizer.param_groups[1]['params']]

#all_W_kernels = [kernel for kernel in optimizer.param_groups[1]['params']]
all_W_kernels = optimizer.param_groups[1]['params']
kernels = [{'params': all_G_kernels}]
optimizer_quant = optim.SGD(kernels, lr=0)

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80,120,160], gamma=0.1)

In [6]:

#for epoch in tqdm(range(num_epochs)):    num_epochs = 200
for epoch in (pbar := tqdm(range(num_epochs))):
    #print('Epoch ID', epoch)
    #----------------------------------------------------------------------
    # Training
    #----------------------------------------------------------------------
    pbar.set_description(f"Training {epoch}",refresh=True)
    correct = 0; total = 0; train_loss = 0
    net.train()
    for batch_idx, (x, target) in enumerate(tqdm(train_loader, leave=False)):
    #for batch_idx, (x, target) in enumerate(train_loader):
        #if batch_idx < 1:
        optimizer.zero_grad()
        x, target = Variable(x.cuda()), Variable(target.cuda())
        all_W_kernels = optimizer.param_groups[1]['params']
        all_G_kernels = optimizer_quant.param_groups[0]['params']
        
        for k_W, k_G in zip(all_W_kernels, all_G_kernels):
            V = k_W.data
            #print(type(V))
            #####Binary Connect#########################
            #k_G.data = quantize_bw(V)
            ############################################
            
            ######Binary Relax##########################
            if epoch<120:
                #k_G.data = (eta*quantize_bw(V)+V)/(1+eta)
                k_G.data = (eta*quantize(V,num_bits=quantize_nbits)+V)/(1+eta)
                
            else:
                k_G.data = quantize(V, num_bits=quantize_nbits)
            #############################################
            
            k_W.data, k_G.data = k_G.data, k_W.data
            
            
        score = net(x)
        loss = criterion(score, target)
        loss.backward()
        
        for k_W, k_G in zip(all_W_kernels, all_G_kernels):
            k_W.data, k_G.data = k_G.data, k_W.data
        
        optimizer.step()
        
        train_loss += loss.data
        _, predicted = torch.max(score.data, 1)
        total += target.size(0)
        correct += predicted.eq(target.data).cpu().sum()
        pbar.write(f"Epoch: {epoch} Training {batch_idx} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})")
        
    #----------------------------------------------------------------------
    # Testing
    #----------------------------------------------------------------------
    pbar.set_description(f"Testing {epoch}",refresh=True)
    test_loss = 0; correct = 0; total = 0
    net.eval()
    
    for k_W, k_G in zip(all_W_kernels, all_G_kernels):
        k_W.data, k_G.data = k_G.data, k_W.data

    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(tqdm(test_loader, leave=False)):
            x, target = Variable(x.cuda()), Variable(target.cuda())
            score= net(x)
            
            loss = criterion(score, target)
            test_loss += loss.data
            _, predicted = torch.max(score.data, 1)
            total += target.size(0)
            correct += predicted.eq(target.data).cpu().sum()
            pbar.write(f"Testing {batch_idx} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})")

    
    #----------------------------------------------------------------------
    # Save the checkpoint
    #----------------------------------------------------------------------
    '''
    for i in range(len(all_W_kernels)):
        k_W = all_W_kernels[i]
        k_quant = all_G_kernels[i]    
        k_W.data, k_quant.data = k_quant.data, k_W.data
        '''  
    acc = 100.*correct/total
    #if acc > best_acc:
    if correct > best_count:
        # print('Saving model...')
        # state = {
        #     'state': net.state_dict(), #net,
        #     'acc': acc,
        #     'epoch': epoch,
        # }
        
        # torch.save(state, f'./saved_models/resnet_{quantize_nbits}bits_{u.time_stamp()}.pth')
        #net.save_state_dict('resnet20.pt')
        best_acc = acc
        best_count = correct

    for k_W, k_G in zip(all_W_kernels, all_G_kernels):
        k_W.data, k_G.data = k_G.data, k_W.data
    clear_output(wait=True)
    
save_model(net, "saved_models", f'resnet_{quantize_nbits}bits_{u.time_stamp()}.pth')

  0%|          | 0/391 [00:00<?, ?it/s]

Epoch: 149 Training 0 Loss: 0.328 | Acc: 88.281% (113/128)
Epoch: 149 Training 1 Loss: 0.307 | Acc: 90.625% (232/256)
Epoch: 149 Training 2 Loss: 0.277 | Acc: 90.885% (349/384)
Epoch: 149 Training 3 Loss: 0.293 | Acc: 90.039% (461/512)
Epoch: 149 Training 4 Loss: 0.295 | Acc: 89.688% (574/640)
Epoch: 149 Training 5 Loss: 0.292 | Acc: 89.714% (689/768)
Epoch: 149 Training 6 Loss: 0.295 | Acc: 89.732% (804/896)
Epoch: 149 Training 7 Loss: 0.284 | Acc: 90.137% (923/1024)
Epoch: 149 Training 8 Loss: 0.283 | Acc: 90.104% (1038/1152)
Epoch: 149 Training 9 Loss: 0.291 | Acc: 90.078% (1153/1280)
Epoch: 149 Training 10 Loss: 0.298 | Acc: 89.986% (1267/1408)
Epoch: 149 Training 11 Loss: 0.292 | Acc: 90.299% (1387/1536)
Epoch: 149 Training 12 Loss: 0.298 | Acc: 90.144% (1500/1664)
Epoch: 149 Training 13 Loss: 0.293 | Acc: 90.234% (1617/1792)
Epoch: 149 Training 14 Loss: 0.290 | Acc: 90.208% (1732/1920)
Epoch: 149 Training 15 Loss: 0.285 | Acc: 90.332% (1850/2048)
Epoch: 149 Training 16 Loss: 0.28

  0%|          | 0/40 [00:00<?, ?it/s]

Testing 0 Loss: 121.079 | Acc: 79.200% (198/250)
Testing 1 Loss: 60.539 | Acc: 79.200% (396/500)
Testing 2 Loss: 40.360 | Acc: 79.600% (597/750)
Testing 3 Loss: 30.270 | Acc: 79.800% (798/1000)
Testing 4 Loss: 24.216 | Acc: 78.800% (985/1250)
Testing 5 Loss: 20.180 | Acc: 79.267% (1189/1500)
Testing 6 Loss: 17.297 | Acc: 79.257% (1387/1750)
Testing 7 Loss: 15.135 | Acc: 79.350% (1587/2000)
Testing 8 Loss: 13.453 | Acc: 79.600% (1791/2250)
Testing 9 Loss: 12.108 | Acc: 79.720% (1993/2500)
Testing 10 Loss: 11.007 | Acc: 79.527% (2187/2750)
Testing 11 Loss: 10.090 | Acc: 79.467% (2384/3000)
Testing 12 Loss: 9.314 | Acc: 79.231% (2575/3250)
Testing 13 Loss: 8.648 | Acc: 79.029% (2766/3500)
Testing 14 Loss: 8.072 | Acc: 79.147% (2968/3750)
Testing 15 Loss: 7.567 | Acc: 79.200% (3168/4000)
Testing 16 Loss: 7.122 | Acc: 79.459% (3377/4250)
Testing 17 Loss: 6.727 | Acc: 79.444% (3575/4500)
Testing 18 Loss: 6.373 | Acc: 79.368% (3770/4750)
Testing 19 Loss: 6.054 | Acc: 79.460% (3973/5000)
Testi