In [1]:
from torch.autograd import Variable


from IPython.display import clear_output

import fedlern.utils as u
from fedlern.train_utils import *
from fedlern.quant_utils import *
from fedlern.models.resnet_v2 import ResNet18
from fedlern.quantize import quantize
from tqdm.notebook import tqdm

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
stats = (0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784)
batch_size_test = 250
batch_size_train = 128

quantize_nbits = 4
num_epochs = 200
learning_rate = 0.001
eta_rate = 1.05
eta = 1
global best_acc
best_acc = 0
best_count = 0

In [3]:
train_loader, test_loader = prepare_dataloader_cifar(num_workers=8, 
                                                     train_batch_size=batch_size_train, 
                                                     eval_batch_size=batch_size_test, 
                                                     stats=stats)

    

Files already downloaded and verified


In [4]:
net = ResNet18()
net.to(device)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1

In [5]:
net, criterion, optimizer = get_model2(net, learning_rate=0.1, weight_decay=5e-4)

    
all_G_kernels = [
    Variable(kernel.data.clone(), requires_grad=True)
    for kernel in optimizer.param_groups[1]['params']
]


all_W_kernels = [kernel for kernel in optimizer.param_groups[1]['params']]
kernels = [{'params': all_G_kernels}]
optimizer_quant = optim.SGD(kernels, lr=0)

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80,120,160], gamma=0.1)

In [6]:

#for epoch in tqdm(range(num_epochs)):    num_epochs = 200
for epoch in (pbar := tqdm(range(num_epochs))):
    #print('Epoch ID', epoch)
    #----------------------------------------------------------------------
    # Training
    #----------------------------------------------------------------------
    pbar.set_description(f"Training {epoch}",refresh=True)
    correct = 0; total = 0; train_loss = 0
    net.train()
    for batch_idx, (x, target) in enumerate(tqdm(train_loader, leave=False)):
    #for batch_idx, (x, target) in enumerate(train_loader):
        #if batch_idx < 1:
        optimizer.zero_grad()
        x, target = Variable(x.cuda()), Variable(target.cuda())
        all_W_kernels = optimizer.param_groups[1]['params']
        all_G_kernels = optimizer_quant.param_groups[0]['params']
        
        for i in range(len(all_W_kernels)):
            k_W = all_W_kernels[i]
            k_G = all_G_kernels[i]
            V = k_W.data
            #print(type(V))
            #####Binary Connect#########################
            #k_G.data = quantize_bw(V)
            ############################################
            
            ######Binary Relax##########################
            if epoch<120:
                #k_G.data = (eta*quantize_bw(V)+V)/(1+eta)
                k_G.data = (eta*quantize(V,num_bits=quantize_nbits)+V)/(1+eta)
                
            else:
                k_G.data = quantize(V, num_bits=quantize_nbits)
            #############################################
            
            k_W.data, k_G.data = k_G.data, k_W.data
            
            
        score = net(x)
        loss = criterion(score, target)
        loss.backward()
        
        for i in range(len(all_W_kernels)):
            k_W = all_W_kernels[i]
            k_G = all_G_kernels[i]
            k_W.data, k_G.data = k_G.data, k_W.data
        
        optimizer.step()
        
        train_loss += loss.data
        _, predicted = torch.max(score.data, 1)
        total += target.size(0)
        correct += predicted.eq(target.data).cpu().sum()
        pbar.write(f"Training {batch_idx} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})")
        
    #----------------------------------------------------------------------
    # Testing
    #----------------------------------------------------------------------
    pbar.set_description(f"Testing {epoch}",refresh=True)
    test_loss = 0; correct = 0; total = 0
    net.eval()
    
    for i in range(len(all_W_kernels)):
        k_W = all_W_kernels[i]
        k_quant = all_G_kernels[i]    
        k_W.data, k_quant.data = k_quant.data, k_W.data
    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(tqdm(test_loader, leave=False)):
        #for batch_idx, (x, target) in enumerate(test_loader):
            x, target = Variable(x.cuda()), Variable(target.cuda())
            score= net(x)
            
            loss = criterion(score, target)
            test_loss += loss.data
            _, predicted = torch.max(score.data, 1)
            total += target.size(0)
            correct += predicted.eq(target.data).cpu().sum()
            pbar.write(f"Testing {batch_idx} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})")

    
    #----------------------------------------------------------------------
    # Save the checkpoint
    #----------------------------------------------------------------------
    '''
    for i in range(len(all_W_kernels)):
        k_W = all_W_kernels[i]
        k_quant = all_G_kernels[i]    
        k_W.data, k_quant.data = k_quant.data, k_W.data
        '''  
    acc = 100.*correct/total
    #if acc > best_acc:
    if correct > best_count:
        # print('Saving model...')
        # state = {
        #     'state': net.state_dict(), #net,
        #     'acc': acc,
        #     'epoch': epoch,
        # }
        
        # torch.save(state, f'./saved_models/resnet_{quantize_nbits}bits_{u.time_stamp()}.pth')
        #net.save_state_dict('resnet20.pt')
        best_acc = acc
        best_count = correct

    for i in range(len(all_W_kernels)):
        k_W=all_W_kernels[i]
        k_quant=all_W_kernels[i]
        k_W.data, k_quant.data =k_quant.data,k_W.data
    clear_output(wait=True)
    
save_model(net, "saved_models", f'resnet_{quantize_nbits}bits_{u.time_stamp()}.pth')

  0%|          | 0/391 [00:00<?, ?it/s]

Training 0 Loss: 0.374 | Acc: 85.938% (110/128)
Training 1 Loss: 0.380 | Acc: 86.719% (222/256)
Training 2 Loss: 0.375 | Acc: 86.198% (331/384)
Training 3 Loss: 0.368 | Acc: 86.914% (445/512)
Training 4 Loss: 0.377 | Acc: 87.031% (557/640)
Training 5 Loss: 0.366 | Acc: 87.630% (673/768)
Training 6 Loss: 0.358 | Acc: 88.058% (789/896)
Training 7 Loss: 0.354 | Acc: 88.184% (903/1024)
Training 8 Loss: 0.351 | Acc: 88.281% (1017/1152)
Training 9 Loss: 0.343 | Acc: 88.359% (1131/1280)
Training 10 Loss: 0.346 | Acc: 88.068% (1240/1408)
Training 11 Loss: 0.354 | Acc: 87.695% (1347/1536)
Training 12 Loss: 0.354 | Acc: 88.041% (1465/1664)
Training 13 Loss: 0.350 | Acc: 88.114% (1579/1792)
Training 14 Loss: 0.354 | Acc: 88.229% (1694/1920)
Training 15 Loss: 0.350 | Acc: 88.428% (1811/2048)
Training 16 Loss: 0.352 | Acc: 88.097% (1917/2176)
Training 17 Loss: 0.351 | Acc: 88.064% (2029/2304)
Training 18 Loss: 0.361 | Acc: 87.788% (2135/2432)
Training 19 Loss: 0.363 | Acc: 87.734% (2246/2560)
Train

  0%|          | 0/40 [00:00<?, ?it/s]

Testing 0 Loss: 147.301 | Acc: 77.200% (193/250)
Testing 1 Loss: 73.651 | Acc: 77.400% (387/500)
Testing 2 Loss: 49.100 | Acc: 78.933% (592/750)
Testing 3 Loss: 36.825 | Acc: 79.200% (792/1000)
Testing 4 Loss: 29.460 | Acc: 78.320% (979/1250)
Testing 5 Loss: 24.550 | Acc: 78.533% (1178/1500)
Testing 6 Loss: 21.043 | Acc: 79.029% (1383/1750)
Testing 7 Loss: 18.413 | Acc: 79.050% (1581/2000)
Testing 8 Loss: 16.367 | Acc: 78.756% (1772/2250)
Testing 9 Loss: 14.730 | Acc: 78.680% (1967/2500)
Testing 10 Loss: 13.391 | Acc: 78.945% (2171/2750)
Testing 11 Loss: 12.275 | Acc: 79.133% (2374/3000)
Testing 12 Loss: 11.331 | Acc: 79.231% (2575/3250)
Testing 13 Loss: 10.522 | Acc: 79.171% (2771/3500)
Testing 14 Loss: 9.820 | Acc: 79.333% (2975/3750)
Testing 15 Loss: 9.206 | Acc: 79.675% (3187/4000)
Testing 16 Loss: 8.665 | Acc: 79.788% (3391/4250)
Testing 17 Loss: 8.183 | Acc: 79.711% (3587/4500)
Testing 18 Loss: 7.753 | Acc: 79.537% (3778/4750)
Testing 19 Loss: 7.365 | Acc: 79.640% (3982/5000)
Tes