In [1]:
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.datasets as datasets
from models.qresnet import ResNet_cifar10, BasicBlock
from preprocess import get_transform
import os
import time
from datetime import datetime
import logging
from utils.log import setup_logging, ResultsLog, save_checkpoint
from utils.meters import AverageMeter, accuracy
from utils.optim import OptimRegime
from utils.misc import torch_dtypes



In [2]:

# --- Dataset / model configuration ---
dataset = "cifar10"
classes = 10        # number of output classes for ResNet_cifar10
input_size = 32     # input resolution handed to get_transform
depth = 18          # ResNet depth handed to ResNet_cifar10

# --- Run configuration ---
batch_size = 128    # samples per DataLoader batch
device = "cpu"      # device the model and criterion are moved to
seed = 123          # passed to torch.manual_seed for reproducibility

In [3]:
def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None,
            device=None, dtype=None, print_freq=10):
    """Run one pass over ``data_loader``, optionally updating the model, and log progress.

    Shared by ``train`` and ``validate``. For every batch it computes
    ``criterion(output, target)`` (plus ``model.regularization(model)`` when the
    model defines one), records loss / top-1 / top-5 precision in
    ``AverageMeter``s weighted by the batch size, and, when ``training`` is true,
    performs ``zero_grad`` / ``backward`` / ``step`` on ``optimizer``.
    The model's train/eval mode is NOT changed here; callers set it.

    Args:
        data_loader: iterable of ``(inputs, target)`` batches.
        model: network to run.
        criterion: loss called as ``criterion(output, target)``.
        epoch: epoch index, used only in the progress log line.
        training: if true, backpropagate and step ``optimizer`` every batch.
        optimizer: optimizer to step; required when ``training`` is true.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter (CPU if it has none), so batches follow
            wherever ``model.to(...)`` placed the model.
        dtype: dtype ``inputs`` are cast to. Defaults to the dtype of the
            model's first parameter (``torch.float`` if it has none).
        print_freq: log a progress line every ``print_freq`` batches; a value
            <= 0 disables per-batch logging.

    Returns:
        ``(losses.avg, top1.avg, top5.avg)`` accumulated over the whole pass.

    Raises:
        ValueError: if ``training`` is true but ``optimizer`` is None.
    """
    # Fail before touching any data rather than with an AttributeError after
    # the first forward pass.
    if training and optimizer is None:
        raise ValueError("forward(training=True) requires an optimizer")

    # Follow the model's placement instead of a hard-coded "cpu" and a global
    # `dtype` that is only defined in a later cell.
    if device is None or dtype is None:
        first_param = next(model.parameters(), None)
        if device is None:
            device = first_param.device if first_param is not None else torch.device("cpu")
        if dtype is None:
            dtype = first_param.dtype if first_param is not None else torch.float

    regularizer = getattr(model, 'regularization', None)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for i, (inputs, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        target = target.to(device)
        inputs = inputs.to(device, dtype=dtype)

        # compute output
        output = model(inputs)
        loss = criterion(output, target)
        if regularizer is not None:
            # Out-of-place add keeps the criterion's output tensor untouched.
            loss = loss + regularizer(model)

        # If the model returned a list, only its first element is used for the
        # accuracy metrics (the criterion above received the full output).
        if type(output) is list:
            output = output[0]

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(float(loss), inputs.size(0))
        top1.update(float(prec1), inputs.size(0))
        top5.update(float(prec5), inputs.size(0))

        if training:
            # Plain torch.optim optimizer: no per-step OptimRegime.update() hook.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if print_freq > 0 and i % print_freq == 0:
            logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                             epoch, i, len(data_loader),
                             phase='TRAINING' if training else 'EVALUATING',
                             batch_time=batch_time,
                             data_time=data_time, loss=losses, top1=top1, top5=top5))
    return losses.avg, top1.avg, top5.avg


In [4]:
def train(data_loader, model, criterion, epoch, optimizer):
    """Run one training epoch over ``data_loader``.

    Switches ``model`` into training mode, then delegates to ``forward`` with
    ``training=True`` so every batch does a backward pass and an optimizer step.

    Returns:
        The ``(loss_avg, top1_avg, top5_avg)`` tuple produced by ``forward``.
    """
    model.train()
    epoch_metrics = forward(data_loader, model, criterion, epoch,
                            optimizer=optimizer, training=True)
    return epoch_metrics


def validate(data_loader, model, criterion, epoch):
    """Evaluate ``model`` on ``data_loader`` without updating its weights.

    Puts the model in eval mode and runs ``forward`` with ``training=False``
    inside ``torch.no_grad()``, so no autograd graph is recorded.

    Returns:
        The ``(loss_avg, top1_avg, top5_avg)`` tuple produced by ``forward``.
    """
    model.eval()
    with torch.no_grad():
        eval_metrics = forward(data_loader, model, criterion, epoch,
                               optimizer=None, training=False)
    return eval_metrics


In [5]:
# Run-wide state: `best_prec1` is the best validation top-1 precision seen so
# far (read and updated by the training loop); `dtype` is the floating dtype the
# model and criterion are cast to below. A module-level `global` statement is a
# no-op in a notebook cell, so plain assignments suffice.
best_prec1 = 0
dtype = torch.float
torch.manual_seed(seed)

# Each run writes its log, results and checkpoints to a timestamped directory.
time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
save_path = os.path.join("./results", time_stamp)
os.makedirs(save_path, exist_ok=True)

setup_logging(os.path.join(save_path, 'log.txt'))
results_path = os.path.join(save_path, 'results')
results = ResultsLog(
    results_path, title='Training Results - %s' % time_stamp)

logging.info("saving to %s", save_path)

# To use cuda, uncomment below
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(seed)
#     torch.cuda.set_device(0)
#     cudnn.benchmark = True

# create model
model = ResNet_cifar10(num_classes=classes, block=BasicBlock, depth=depth)

# TODO: implement resume training and only validation features
# if evaluate:
#     if not os.path.isfile(args.evaluate):
#         parser.error('invalid checkpoint: {}'.format(args.evaluate))
#     checkpoint = torch.load(args.evaluate)
#     model.load_state_dict(checkpoint['state_dict'])
#     logging.info("loaded checkpoint '%s' (epoch %s)",
#                     args.evaluate, checkpoint['epoch'])
# elif resume:
#     checkpoint_file = args.resume
#     if os.path.isdir(checkpoint_file):
#         results.load(os.path.join(checkpoint_file, 'results.csv'))
#         checkpoint_file = os.path.join(
#             checkpoint_file, 'model_best.pth.tar')
#     if os.path.isfile(checkpoint_file):
#         logging.info("loading checkpoint '%s'", args.resume)
#         checkpoint = torch.load(checkpoint_file)
#         args.start_epoch = checkpoint['epoch'] - 1
#         best_prec1 = checkpoint['best_prec1']
#         model.load_state_dict(checkpoint['state_dict'])
#         logging.info("loaded checkpoint '%s' (epoch %s)",
#                         checkpoint_file, checkpoint['epoch'])
#     else:
#         logging.error("no checkpoint found at '%s'", args.resume)

num_parameters = sum(p.numel() for p in model.parameters())
logging.info("number of parameters: %d", num_parameters)

# Data loading code: the model may provide its own `input_transform`; otherwise
# use the dataset defaults (augmented for training, not augmented for eval).
default_transform = {
    'train': get_transform(dataset,
                           input_size=input_size, augment=True),
    'eval': get_transform(dataset,
                          input_size=input_size, augment=False)
}
transform = getattr(model, 'input_transform', default_transform)
# Optimizer schedule published by the model, if any. It is only logged and
# stored in checkpoints, so a model without one gets None instead of an error.
regime = getattr(model, 'regime', None)

# define loss function (criterion); the model may supply its own `criterion` class
criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
criterion.to(device, dtype)
# Last expression: the notebook displays the (moved) model's repr.
model.to(device, dtype)



saving to ./results/2024-02-22_02-44-33
number of parameters: 175274


ResNet_cifar10(
  (conv1): QConv2d(
    3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (quantize_input): QuantMeasure()
  )
  (bn1): RangeBN(
    (quantize_input): QuantMeasure()
  )
  (relu): ReLU()
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): QConv2d(
        16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (quantize_input): QuantMeasure()
      )
      (bn1): RangeBN(
        (quantize_input): QuantMeasure()
      )
      (relu): ReLU()
      (conv2): QConv2d(
        16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (quantize_input): QuantMeasure()
      )
      (bn2): RangeBN(
        (quantize_input): QuantMeasure()
      )
      (add_relu): FloatFunctional(
        (activation_post_process): Identity()
      )
    )
    (1): BasicBlock(
      (conv1): QConv2d(
        16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (quantize_input): QuantMeasure()
      )
      (bn

In [6]:

# CIFAR-10 must already be present under this root (download=False).
data_root = "./cifar10"

# Settings shared by both loaders. Pinned host memory only helps host-to-GPU
# copies; on a CPU run it is wasted work (and PyTorch warns when no accelerator
# is present), so enable it only when the configured device is CUDA.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=8,
    pin_memory=torch.device(device).type == "cuda",
)

val_data = datasets.CIFAR10(root=data_root,
                            train=False,
                            transform=transform['eval'],
                            download=False)
val_loader = torch.utils.data.DataLoader(val_data, shuffle=False, **loader_kwargs)

train_data = datasets.CIFAR10(root=data_root,
                              train=True,
                              transform=transform['train'],
                              download=False)
# Shuffle only the training split; evaluation order stays deterministic.
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_kwargs)





In [7]:
# Plain SGD whose hyper-parameters match the first entry of the model's `regime`
# (lr=0.1, weight_decay=1e-4, momentum=0.9).
# NOTE(review): later regime entries (the lr drops at epochs 81/122/164) are
# only logged here, not applied; the OptimRegime wrapper that would apply them
# is disabled:
# optimizer = OptimRegime(model.parameters(), regime)
optimizer = torch.optim.SGD(model.parameters(),
                            momentum=0.9,
                            weight_decay=1e-4,
                            lr=1e-1)
logging.info('training regime: %s', regime)
epochs = 20

# (y columns, title, y-axis label) of each curve re-plotted after every epoch.
curve_specs = (
    (('train_loss', 'val_loss'), 'Loss', 'loss'),
    (('train_error1', 'val_error1'), 'Error@1', 'error %'),
    (('train_error5', 'val_error5'), 'Error@5', 'error %'),
)

for epoch in range(epochs):
    train_loss, train_prec1, train_prec5 = train(
        train_loader, model, criterion, epoch, optimizer)
    val_loss, val_prec1, val_prec5 = validate(
        val_loader, model, criterion, epoch)

    # Track the best validation top-1 so far; save_checkpoint receives is_best.
    is_best = val_prec1 > best_prec1
    best_prec1 = max(val_prec1, best_prec1)
    checkpoint_state = {
        'epoch': epoch + 1,
        'model': "RESNET",
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'regime': regime,
    }
    save_checkpoint(checkpoint_state, is_best, path=save_path)

    epoch_summary = ('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n')
    logging.info(epoch_summary.format(
        epoch + 1,
        train_loss=train_loss, val_loss=val_loss,
        train_prec1=train_prec1, val_prec1=val_prec1,
        train_prec5=train_prec5, val_prec5=val_prec5))

    # Keyword order is kept as-is: it may determine the results column order.
    results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss,
                train_error1=100 - train_prec1, val_error1=100 - val_prec1,
                train_error5=100 - train_prec5, val_error5=100 - val_prec5)
    for y_columns, plot_title, y_label in curve_specs:
        results.plot(x='epoch', y=list(y_columns),
                     legend=['training', 'validation'],
                     title=plot_title, ylabel=y_label)
    results.save()

training regime: [{'epoch': 0, 'optimizer': 'SGD', 'lr': 0.1, 'weight_decay': 0.0001, 'momentum': 0.9}, {'epoch': 81, 'lr': 0.01}, {'epoch': 122, 'lr': 0.001, 'weight_decay': 0}, {'epoch': 164, 'lr': 0.0001}]
TRAINING - Epoch: [0][0/391]	Time 1.460 (1.460)	Data 0.249 (0.249)	Loss 2.3503 (2.3503)	Prec@1 8.594 (8.594)	Prec@5 46.875 (46.875)
TRAINING - Epoch: [0][10/391]	Time 0.805 (0.861)	Data 0.001 (0.025)	Loss 2.2445 (2.3216)	Prec@1 22.656 (10.938)	Prec@5 70.312 (61.364)
TRAINING - Epoch: [0][20/391]	Time 0.788 (0.832)	Data 0.002 (0.014)	Loss 2.0746 (2.2564)	Prec@1 21.875 (14.286)	Prec@5 76.562 (66.369)
TRAINING - Epoch: [0][30/391]	Time 0.765 (0.816)	Data 0.002 (0.010)	Loss 2.1097 (2.2098)	Prec@1 25.781 (17.036)	Prec@5 74.219 (69.178)
TRAINING - Epoch: [0][40/391]	Time 0.775 (0.819)	Data 0.002 (0.008)	Loss 2.0262 (2.1816)	Prec@1 20.312 (18.293)	Prec@5 83.594 (71.189)
TRAINING - Epoch: [0][50/391]	Time 0.750 (0.811)	Data 0.002 (0.007)	Loss 2.0238 (2.1520)	Prec@1 26.562 (19.332)	Prec@5 

In [8]:
# Evaluate plain top-1 accuracy on the CIFAR-10 test split, reusing `val_loader`
# (same eval transform as validation).
model.eval()  # Set model to evaluation mode
correct = 0
total = 0
with torch.no_grad():
    for images, labels in val_loader:
        # Match the device/dtype the model was moved to in the setup cell.
        images = images.to(device, dtype=dtype)
        labels = labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Stored as `test_accuracy`, NOT `accuracy`: rebinding `accuracy` would shadow
# utils.meters.accuracy, which forward() calls, and break any re-run of the
# training/validation cells.
test_accuracy = 100 * correct / total if total else float('nan')
print('Accuracy of the model on the test images: {}%'.format(test_accuracy))



Accuracy of the model on the test images: 82.99%


In [9]:
from torchinfo import summary
# Per-layer parameter table; no input_size/input_data is given, so only the
# module hierarchy and parameter counts are reported.
summary(model=model)

Layer (type:depth-idx)                        Param #
ResNet_cifar10                                --
├─QConv2d: 1-1                                448
│    └─QuantMeasure: 2-1                      --
├─RangeBN: 1-2                                32
│    └─QuantMeasure: 2-2                      --
├─ReLU: 1-3                                   --
├─Sequential: 1-4                             --
│    └─BasicBlock: 2-3                        --
│    │    └─QConv2d: 3-1                      2,304
│    │    └─RangeBN: 3-2                      32
│    │    └─ReLU: 3-3                         --
│    │    └─QConv2d: 3-4                      2,304
│    │    └─RangeBN: 3-5                      32
│    │    └─FloatFunctional: 3-6              --
│    └─BasicBlock: 2-4                        --
│    │    └─QConv2d: 3-7                      2,304
│    │    └─RangeBN: 3-8                      32
│    │    └─ReLU: 3-9                         --
│    │    └─QConv2d: 3-10                     2,304
│ 

In [10]:
from torchsummary import summary
# torchsummary defaults to device="cuda" (it builds CUDA input tensors whenever
# CUDA is available); pass the notebook's device so the probe input matches the
# model, and derive the spatial size from the configured `input_size`.
summary(model, input_size=(3, input_size, input_size), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
      QuantMeasure-1            [-1, 3, 32, 32]               0
           QConv2d-2           [-1, 16, 32, 32]             448
      QuantMeasure-3           [-1, 16, 32, 32]               0
           RangeBN-4           [-1, 16, 32, 32]              32
              ReLU-5           [-1, 16, 32, 32]               0
      QuantMeasure-6           [-1, 16, 32, 32]               0
           QConv2d-7           [-1, 16, 32, 32]           2,304
      QuantMeasure-8           [-1, 16, 32, 32]               0
           RangeBN-9           [-1, 16, 32, 32]              32
             ReLU-10           [-1, 16, 32, 32]               0
     QuantMeasure-11           [-1, 16, 32, 32]               0
          QConv2d-12           [-1, 16, 32, 32]           2,304
     QuantMeasure-13           [-1, 16, 32, 32]               0
          RangeBN-14           [-1, 16,