-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
37 changed files
with
3,130 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
data | ||
checkpoint | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,31 @@ | ||
# shiftnet-cifar | ||
shiftnet on CIFAR task | ||
# ShiftResNet | ||
|
||
Train ResNet with shift operations on CIFAR10 using PyTorch. This uses the [original resnet codebase](https://github.com/kuangliu/pytorch-cifar.git) written by Kuang Liu. In this codebase, we replace each 3x3 convolutional layer with a conv-shift-conv module — a 1x1 convolutional layer, a set of shift operations, and a second 1x1 convolutional layer. From Liu, this repository boasts:
|
||
- Built-in data loading and augmentation, very nice! | ||
- Training is fast, maybe even a little bit faster. | ||
- Very memory efficient! | ||
|
||
## Accuracy | ||
|
||
Below, we run experiments using ResNet101, varying expansion used for all conv-shift-conv layers in the neural network. | ||
|
||
| Expansion | Acc | | ||
|-----------|-----| | ||
| 1 | | | ||
| 2 | | | ||
| 3 | | | ||
| 4 | | | ||
| 5 | | | ||
| 6 | | | ||
| 7 | | | ||
| 8 | | | ||
| 9 | | | ||
|
||
## Learning rate adjustment | ||
I manually change the `lr` during training: | ||
- `0.1` for epoch `[0,150)` | ||
- `0.01` for epoch `[150,250)` | ||
- `0.001` for epoch `[250,350)` | ||
|
||
Resume the training with `python main.py --resume --lr=0.01` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
'''Train CIFAR10 with PyTorch.''' | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torch.nn.functional as F | ||
import torch.backends.cudnn as cudnn | ||
|
||
import torchvision | ||
import torchvision.transforms as transforms | ||
|
||
import os | ||
import argparse | ||
|
||
from models import * | ||
from utils import progress_bar | ||
from torch.autograd import Variable | ||
|
||
|
||
# Command-line interface: learning rate, resume-from-checkpoint flag, batch size.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--batch_size', '-b', default=128, type=int, help='batch size')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0  # best test accuracy seen so far; updated by test()
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data: standard CIFAR10 augmentation (pad-4 random crop + horizontal flip)
# for training; both splits use the usual per-channel CIFAR10 mean/std.
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model: either restore a pickled network object from ./checkpoint or build fresh.
if args.resume:
    # Load checkpoint.  NOTE(review): the checkpoint stores the whole module
    # object, so the matching model code must be importable when loading.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t18')
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    # net = VGG('VGG19')
    net = ResNet20()
    # net = PreActResNet18()
    # net = GoogLeNet()
    # net = DenseNet121()
    # net = ResNeXt29_2x64d()
    # net = MobileNet()
    # net = DPN92()
    # net = ShuffleNetG2()
    # net = SENet18()

if use_cuda:
    net.cuda()
    # Replicate across all visible GPUs; cudnn.benchmark speeds up fixed-size inputs.
    net = torch.nn.DataParallel(
        net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
|
||
def adjust_learning_rate(epoch, lr, milestones=(80, 120)):
    """Step learning-rate schedule.

    Divides ``lr`` by 10 once for every milestone epoch already passed:
    with the defaults, ``lr`` before epoch 80, ``lr/10`` in [80, 120),
    and ``lr/100`` from epoch 120 on — identical to the original
    hard-coded schedule, but the milestones are now parameterized.

    Args:
        epoch: current (0-based) epoch index.
        lr: base learning rate.
        milestones: epochs at which the rate drops by a factor of 10.

    Returns:
        The learning rate to use for this epoch.
    """
    drops = sum(epoch >= m for m in milestones)
    return lr / (10 ** drops)
|
||
# Training
def train(epoch):
    """Run one epoch of training over ``trainloader``.

    Bug fix: the original rebuilt the SGD optimizer from scratch every
    epoch, which silently reset the momentum buffers at each epoch
    boundary.  The optimizer is now created once, cached on the function,
    and only its learning rate is updated from the schedule.
    """
    lr = adjust_learning_rate(epoch, args.lr)
    if not hasattr(train, 'optimizer'):
        train.optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    optimizer = train.optimizer
    for group in optimizer.param_groups:
        group['lr'] = lr
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Pre-0.4 PyTorch scalar access; kept to match the codebase's API level.
        train_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
|
||
def test(epoch):
    """Run one evaluation pass over ``testloader``; checkpoint the model
    whenever accuracy beats the best seen so far."""
    global best_acc
    net.eval()
    running_loss, hits, seen = 0, 0, 0
    for step, (images, labels) in enumerate(testloader):
        if use_cuda:
            images, labels = images.cuda(), labels.cuda()
        # volatile: inference only, no autograd graph (pre-0.4 PyTorch idiom).
        images = Variable(images, volatile=True)
        labels = Variable(labels)
        scores = net(images)
        batch_loss = criterion(scores, labels)

        running_loss += batch_loss.data[0]
        _, guesses = torch.max(scores.data, 1)
        seen += labels.size(0)
        hits += guesses.eq(labels.data).cpu().sum()

        progress_bar(step, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(step+1), 100.*hits/seen, hits, seen))

    # Save checkpoint.
    acc = 100.*hits/seen
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/resnet20.t7')
        best_acc = acc
|
||
|
||
# Train for 160 epochs total (resuming from start_epoch when --resume is set),
# evaluating and checkpointing after every epoch.
for epoch in range(start_epoch, 160):
    train(epoch)
    test(epoch)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
'''Train CIFAR10 with PyTorch.''' | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torch.nn.functional as F | ||
import torch.backends.cudnn as cudnn | ||
|
||
import torchvision | ||
import torchvision.transforms as transforms | ||
|
||
import os | ||
import argparse | ||
|
||
from models import * | ||
from utils import progress_bar | ||
from torch.autograd import Variable | ||
|
||
|
||
# Command-line interface: learning rate, resume-from-checkpoint flag, batch size.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--batch_size', '-b', default=128, type=int, help='batch size')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0  # best test accuracy seen so far; updated by test()
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data: standard CIFAR10 augmentation (pad-4 random crop + horizontal flip)
# for training; both splits use the usual per-channel CIFAR10 mean/std.
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model: either restore a pickled network object from ./checkpoint or build fresh.
if args.resume:
    # Load checkpoint.  NOTE(review): the checkpoint stores the whole module
    # object, so the matching model code must be importable when loading.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t18')
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    # net = VGG('VGG19')
    net = ResNet44()
    # net = PreActResNet18()
    # net = GoogLeNet()
    # net = DenseNet121()
    # net = ResNeXt29_2x64d()
    # net = MobileNet()
    # net = DPN92()
    # net = ShuffleNetG2()
    # net = SENet18()

if use_cuda:
    net.cuda()
    # Replicate across all visible GPUs; cudnn.benchmark speeds up fixed-size inputs.
    net = torch.nn.DataParallel(
        net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
|
||
def adjust_learning_rate(epoch, lr, milestones=(80, 120)):
    """Step learning-rate schedule.

    Divides ``lr`` by 10 once for every milestone epoch already passed:
    with the defaults, ``lr`` before epoch 80, ``lr/10`` in [80, 120),
    and ``lr/100`` from epoch 120 on — identical to the original
    hard-coded schedule, but the milestones are now parameterized.

    Args:
        epoch: current (0-based) epoch index.
        lr: base learning rate.
        milestones: epochs at which the rate drops by a factor of 10.

    Returns:
        The learning rate to use for this epoch.
    """
    drops = sum(epoch >= m for m in milestones)
    return lr / (10 ** drops)
|
||
# Training
def train(epoch):
    """Run one epoch of training over ``trainloader``.

    Bug fix: the original rebuilt the SGD optimizer from scratch every
    epoch, which silently reset the momentum buffers at each epoch
    boundary.  The optimizer is now created once, cached on the function,
    and only its learning rate is updated from the schedule.
    """
    lr = adjust_learning_rate(epoch, args.lr)
    if not hasattr(train, 'optimizer'):
        train.optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    optimizer = train.optimizer
    for group in optimizer.param_groups:
        group['lr'] = lr
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Pre-0.4 PyTorch scalar access; kept to match the codebase's API level.
        train_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
|
||
def test(epoch):
    """Run one evaluation pass over ``testloader``; checkpoint the model
    whenever accuracy beats the best seen so far."""
    global best_acc
    net.eval()
    running_loss, hits, seen = 0, 0, 0
    for step, (images, labels) in enumerate(testloader):
        if use_cuda:
            images, labels = images.cuda(), labels.cuda()
        # volatile: inference only, no autograd graph (pre-0.4 PyTorch idiom).
        images = Variable(images, volatile=True)
        labels = Variable(labels)
        scores = net(images)
        batch_loss = criterion(scores, labels)

        running_loss += batch_loss.data[0]
        _, guesses = torch.max(scores.data, 1)
        seen += labels.size(0)
        hits += guesses.eq(labels.data).cpu().sum()

        progress_bar(step, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(step+1), 100.*hits/seen, hits, seen))

    # Save checkpoint.
    acc = 100.*hits/seen
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/resnet44.t7')
        best_acc = acc
|
||
|
||
# Train for 160 epochs total (resuming from start_epoch when --resume is set),
# evaluating and checkpointing after every epoch.
for epoch in range(start_epoch, 160):
    train(epoch)
    test(epoch)
Oops, something went wrong.