In [25]:
from torch.autograd import Variable


from IPython.display import clear_output

import fedlern.utils as u
from fedlern.train_utils import *
from fedlern.quant_utils import *
from fedlern.models.resnet_v2 import ResNet18
from fedlern.quantize import quantize
from tqdm.notebook import tqdm

In [26]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
stats = (0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784)
batch_size_test = 250
batch_size_train = 128

quantize_nbits = 8
num_epochs = 150
learning_rate = 0.001
eta_rate = 1.05
eta = 1
#global best_acc
best_acc = 0
best_count = 0

In [27]:
train_loader, test_loader = prepare_dataloader_cifar(num_workers=8, 
                                                     train_batch_size=batch_size_train, 
                                                     eval_batch_size=batch_size_test, 
                                                     stats=stats)

    

Files already downloaded and verified


In [28]:
net = ResNet18()
net.to(device)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1

In [29]:
optimizer = get_model_optimizer(net, learning_rate=0.1, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss().cuda()
    
# all_G_kernels = [
#     Variable(kernel.data.clone(), requires_grad=True)
#     for kernel in optimizer.param_groups[1]['params']
# ]
# Copy the parameters
all_G_kernels = [kernel.data.clone().requires_grad_(True)
                for kernel in optimizer.param_groups[1]['params']]

#all_W_kernels = [kernel for kernel in optimizer.param_groups[1]['params']]
all_W_kernels = optimizer.param_groups[1]['params']
kernels = [{'params': all_G_kernels}]
optimizer_quant = optim.SGD(kernels, lr=0)

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80,120,160], gamma=0.1)

In [30]:

#for epoch in tqdm(range(num_epochs)):    num_epochs = 200
for epoch in (pbar := tqdm(range(num_epochs))):
    #print('Epoch ID', epoch)
    #----------------------------------------------------------------------
    # Training
    #----------------------------------------------------------------------
    pbar.set_description(f"Training {epoch}",refresh=True)
    correct = 0; total = 0; train_loss = 0
    net.train()
    for batch_idx, (x, target) in enumerate(tqdm(train_loader, leave=False)):
    #for batch_idx, (x, target) in enumerate(train_loader):
        #if batch_idx < 1:
        optimizer.zero_grad()
        x, target = Variable(x.cuda()), Variable(target.cuda())
        all_W_kernels = optimizer.param_groups[1]['params']
        all_G_kernels = optimizer_quant.param_groups[0]['params']
        
        for k_W, k_G in zip(all_W_kernels, all_G_kernels):
            V = k_W.data
            #print(type(V))
            #####Binary Connect#########################
            #k_G.data = quantize_bw(V)
            ############################################
            
            ######Binary Relax##########################
            if epoch<120:
                #k_G.data = (eta*quantize_bw(V)+V)/(1+eta)
                k_G.data = (eta*quantize(V,num_bits=quantize_nbits)+V)/(1+eta)
                
            else:
                k_G.data = quantize(V, num_bits=quantize_nbits)
            #############################################
            
            k_W.data, k_G.data = k_G.data, k_W.data
            
            
        score = net(x)
        loss = criterion(score, target)
        loss.backward()
        
        for k_W, k_G in zip(all_W_kernels, all_G_kernels):
            k_W.data, k_G.data = k_G.data, k_W.data
        
        optimizer.step()
        
        train_loss += loss.data
        _, predicted = torch.max(score.data, 1)
        total += target.size(0)
        correct += predicted.eq(target.data).cpu().sum()
        pbar.write(f"Epoch: {epoch} Training {batch_idx} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})")
        
    #----------------------------------------------------------------------
    # Testing
    #----------------------------------------------------------------------
    pbar.set_description(f"Testing {epoch}",refresh=True)
    test_loss = 0; correct = 0; total = 0
    net.eval()
    
    for k_W, k_G in zip(all_W_kernels, all_G_kernels):
        k_W.data, k_G.data = k_G.data, k_W.data

    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(tqdm(test_loader, leave=False)):
        #for batch_idx, (x, target) in enumerate(test_loader):
            x, target = Variable(x.cuda()), Variable(target.cuda())
            score= net(x)
            
            loss = criterion(score, target)
            test_loss += loss.data
            _, predicted = torch.max(score.data, 1)
            total += target.size(0)
            correct += predicted.eq(target.data).cpu().sum()
            pbar.write(f"Testing {batch_idx} Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})")

    
    #----------------------------------------------------------------------
    # Save the checkpoint
    #----------------------------------------------------------------------
    '''
    for i in range(len(all_W_kernels)):
        k_W = all_W_kernels[i]
        k_quant = all_G_kernels[i]    
        k_W.data, k_quant.data = k_quant.data, k_W.data
        '''  
    acc = 100.*correct/total
    #if acc > best_acc:
    if correct > best_count:
        # print('Saving model...')
        # state = {
        #     'state': net.state_dict(), #net,
        #     'acc': acc,
        #     'epoch': epoch,
        # }
        
        # torch.save(state, f'./saved_models/resnet_{quantize_nbits}bits_{u.time_stamp()}.pth')
        #net.save_state_dict('resnet20.pt')
        best_acc = acc
        best_count = correct

    for k_W, k_G in zip(all_W_kernels, all_G_kernels):
        k_W.data, k_G.data = k_G.data, k_W.data
    clear_output(wait=True)
    
save_model(net, "saved_models", f'resnet_{quantize_nbits}bits_{u.time_stamp()}.pth')

  0%|          | 0/391 [00:00<?, ?it/s]

Epoch: 149 Training 0 Loss: 0.180 | Acc: 96.094% (123/128)
Epoch: 149 Training 1 Loss: 0.262 | Acc: 92.188% (236/256)
Epoch: 149 Training 2 Loss: 0.285 | Acc: 90.885% (349/384)
Epoch: 149 Training 3 Loss: 0.306 | Acc: 90.430% (463/512)
Epoch: 149 Training 4 Loss: 0.301 | Acc: 90.469% (579/640)
Epoch: 149 Training 5 Loss: 0.295 | Acc: 90.495% (695/768)
Epoch: 149 Training 6 Loss: 0.306 | Acc: 89.955% (806/896)
Epoch: 149 Training 7 Loss: 0.308 | Acc: 89.844% (920/1024)
Epoch: 149 Training 8 Loss: 0.323 | Acc: 89.410% (1030/1152)
Epoch: 149 Training 9 Loss: 0.327 | Acc: 89.297% (1143/1280)
Epoch: 149 Training 10 Loss: 0.324 | Acc: 89.205% (1256/1408)
Epoch: 149 Training 11 Loss: 0.319 | Acc: 89.323% (1372/1536)
Epoch: 149 Training 12 Loss: 0.313 | Acc: 89.483% (1489/1664)
Epoch: 149 Training 13 Loss: 0.321 | Acc: 89.174% (1598/1792)
Epoch: 149 Training 14 Loss: 0.317 | Acc: 89.271% (1714/1920)
Epoch: 149 Training 15 Loss: 0.318 | Acc: 89.307% (1829/2048)
Epoch: 149 Training 16 Loss: 0.31

  0%|          | 0/40 [00:00<?, ?it/s]

Testing 0 Loss: 120.287 | Acc: 83.600% (209/250)
Testing 1 Loss: 60.144 | Acc: 84.200% (421/500)
Testing 2 Loss: 40.096 | Acc: 85.733% (643/750)
Testing 3 Loss: 30.072 | Acc: 84.800% (848/1000)
Testing 4 Loss: 24.057 | Acc: 83.520% (1044/1250)
Testing 5 Loss: 20.048 | Acc: 83.667% (1255/1500)
Testing 6 Loss: 17.184 | Acc: 83.829% (1467/1750)
Testing 7 Loss: 15.036 | Acc: 83.500% (1670/2000)
Testing 8 Loss: 13.365 | Acc: 83.600% (1881/2250)
Testing 9 Loss: 12.029 | Acc: 83.640% (2091/2500)
Testing 10 Loss: 10.935 | Acc: 83.636% (2300/2750)
Testing 11 Loss: 10.024 | Acc: 83.367% (2501/3000)
Testing 12 Loss: 9.253 | Acc: 83.108% (2701/3250)
Testing 13 Loss: 8.592 | Acc: 83.200% (2912/3500)
Testing 14 Loss: 8.019 | Acc: 83.280% (3123/3750)
Testing 15 Loss: 7.518 | Acc: 83.225% (3329/4000)
Testing 16 Loss: 7.076 | Acc: 83.294% (3540/4250)
Testing 17 Loss: 6.683 | Acc: 83.511% (3758/4500)
Testing 18 Loss: 6.331 | Acc: 83.432% (3963/4750)
Testing 19 Loss: 6.014 | Acc: 83.520% (4176/5000)
Test