Commit 856ec31 (1 parent: 56447bc)
Showing 4 changed files with 540 additions and 0 deletions.
@@ -0,0 +1,25 @@
epoch  hours       top1Accuracy
1      0.00148373  52.61
2      0.00228371  62.45
3      0.00308774  77.08
4      0.00388793  66.63
5      0.00468890  74.39
6      0.00549573  83.39
7      0.00629734  81.37
8      0.00710241  81.29
9      0.00790268  84.33
10     0.00870396  82.36
11     0.00950842  87.64
12     0.01030985  84.10
13     0.01111024  88.50
14     0.01191096  87.62
15     0.01271271  84.25
16     0.01351768  88.41
17     0.01431947  86.36
18     0.01512118  89.79
19     0.01592666  90.87
20     0.01672992  91.01
21     0.01753578  91.84
22     0.01833951  92.59
23     0.01914275  93.64
24     0.01994991  94.10
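Read from the final row: the run reaches 94.10% top-1 accuracy after 0.01994991 hours, i.e. roughly 72 seconds of training (0.01994991 × 3600 ≈ 71.8 s).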
@@ -0,0 +1,149 @@
import numpy as np
import torch
from torch import nn
import torchvision
from utils import build_graph, cat, to_numpy

torch.backends.cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@cat.register(torch.Tensor)
def _(*xs):
    return torch.cat(xs)

@to_numpy.register(torch.Tensor)
def _(x):
    return x.detach().cpu().numpy()

def warmup_cudnn(model, batch_size):
    # run a forward and backward pass on a batch of random inputs so that
    # cudnn benchmarking (enabled above) selects kernels for this batch shape
    batch = {
        'input': torch.Tensor(np.random.rand(batch_size, 3, 32, 32)).cuda().half(),
        'target': torch.LongTensor(np.random.randint(0, 10, batch_size)).cuda()
    }
    model.train(True)
    o = model(batch)
    o['loss'].sum().backward()
    model.zero_grad()
    torch.cuda.synchronize()

#####################
## dataset
#####################

def cifar10(root):
    train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True)
    test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True)
    # note: newer torchvision releases expose these arrays as .data/.targets
    return {
        'train': {'data': train_set.train_data, 'labels': train_set.train_labels},
        'test': {'data': test_set.test_data, 'labels': test_set.test_labels}
    }

#####################
## data loading
#####################

class Batches():
    def __init__(self, dataset, batch_size, shuffle, set_random_choices=False, num_workers=0, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.set_random_choices = set_random_choices
        self.dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=shuffle, drop_last=drop_last
        )

    def __iter__(self):
        # re-randomise augmentation choices once per epoch, then yield
        # half-precision inputs and integer targets on the training device
        if self.set_random_choices:
            self.dataset.set_random_choices()
        return ({'input': x.to(device).half(), 'target': y.to(device).long()} for (x, y) in self.dataloader)

    def __len__(self):
        return len(self.dataloader)
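# Hedged usage sketch (assumes a sequence of (image, label) pairs such as the
# preprocessed train_set built in the training script below):
#   batches = Batches(train_set, batch_size=512, shuffle=True, drop_last=True)
#   for batch in batches:          # batch == {'input': ..., 'target': ...}
#       output = model(batch)      # Network.forward consumes this dict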

#####################
## torch stuff
#####################

class Identity(nn.Module):
    def forward(self, x): return x

class Mul(nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.weight = weight
    def forward(self, x):
        return x * self.weight

class Flatten(nn.Module):
    def forward(self, x): return x.view(x.size(0), x.size(1))

class Add(nn.Module):
    def forward(self, x, y): return x + y

class Concat(nn.Module):
    def forward(self, *xs): return torch.cat(xs, 1)

class Correct(nn.Module):
    def forward(self, classifier, target):
        return classifier.max(dim=1)[1] == target

def batch_norm(num_channels, bn_bias_init=None, bn_bias_freeze=False, bn_weight_init=None, bn_weight_freeze=False):
    m = nn.BatchNorm2d(num_channels)
    if bn_bias_init is not None:
        m.bias.data.fill_(bn_bias_init)
    if bn_bias_freeze:
        m.bias.requires_grad = False
    if bn_weight_init is not None:
        m.weight.data.fill_(bn_weight_init)
    if bn_weight_freeze:
        m.weight.requires_grad = False
    return m
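# Hedged usage note: batch_norm(64, bn_weight_init=1.0, bn_bias_freeze=True)
# returns a BatchNorm2d whose weight starts at 1.0 and whose bias is frozen
# (requires_grad=False) — frozen tensors are exactly what trainable_params
# below filters out.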

class Network(nn.Module):
    def __init__(self, net):
        self.graph = build_graph(net)
        super().__init__()
        # register each graph node as a submodule so its parameters are tracked
        for n, (v, _) in self.graph.items():
            setattr(self, n, v)

    def forward(self, inputs):
        # evaluate nodes in graph order, caching every intermediate output
        self.cache = dict(inputs)
        for n, (_, i) in self.graph.items():
            self.cache[n] = getattr(self, n)(*[self.cache[x] for x in i])
        return self.cache

    def half(self):
        # convert to half precision, but keep BatchNorm layers in full
        # precision for numerical stability
        for module in self.children():
            if type(module) is not nn.BatchNorm2d:
                module.half()
        return self
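# A hedged sketch of the graph convention assumed here (build_graph is defined
# in utils, which is not shown in this diff): a network is a nested dict whose
# leaves are modules or (module, input_paths) pairs; unannotated nodes take the
# previous node's output, while explicit paths (see residual() in the training
# script) wire up skip connections.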

trainable_params = lambda model: filter(lambda p: p.requires_grad, model.parameters())

class TorchOptimiser():
    def __init__(self, weights, optimizer, step_number=0, **opt_params):
        self.weights = weights
        self.step_number = step_number
        self.opt_params = opt_params
        self._opt = optimizer(weights, **self.param_values())

    def param_values(self):
        # hyperparameters may be callables of the step number (e.g. an lr
        # schedule); evaluate them for the current step
        return {k: v(self.step_number) if callable(v) else v for k, v in self.opt_params.items()}

    def step(self):
        self.step_number += 1
        self._opt.param_groups[0].update(**self.param_values())
        self._opt.step()

    def __repr__(self):
        return repr(self._opt)

def SGD(weights, lr=0, momentum=0, weight_decay=0, dampening=0, nesterov=False):
    return TorchOptimiser(weights, torch.optim.SGD, lr=lr, momentum=momentum,
                          weight_decay=weight_decay, dampening=dampening,
                          nesterov=nesterov)
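A minimal usage sketch of the callable-hyperparameter mechanism above (the schedule here is a hypothetical stand-in, not part of this commit):

schedule = lambda step: 0.1 if step < 100 else 0.01   # hypothetical two-phase schedule
opt = SGD(trainable_params(model), lr=schedule, momentum=0.9)
opt.step()  # re-evaluates callables and updates param_groups[0]['lr'] before stepping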
@@ -0,0 +1,94 @@
from utils import *
from torch_backend import *

# Network definition
def conv_bn(c_in, c_out, bn_weight_init=1.0, **kw):
    return {
        'conv': nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
        'bn': batch_norm(c_out, bn_weight_init=bn_weight_init, **kw),
        'relu': nn.ReLU(True)
    }

def residual(c, **kw):
    return {
        'in': Identity(),
        'res1': conv_bn(c, c, **kw),
        'res2': conv_bn(c, c, **kw),
        'add': (Add(), [rel_path('in'), rel_path('res2', 'relu')]),
    }

def basic_net(channels, weight, pool, **kw):
    return {
        'prep': conv_bn(3, channels['prep'], **kw),
        'layer1': dict(conv_bn(channels['prep'], channels['layer1'], **kw), pool=pool),
        'layer2': dict(conv_bn(channels['layer1'], channels['layer2'], **kw), pool=pool),
        'layer3': dict(conv_bn(channels['layer2'], channels['layer3'], **kw), pool=pool),
        'pool': nn.MaxPool2d(8),
        'flatten': Flatten(),
        'linear': nn.Linear(channels['layer3'], 10, bias=False),
        'classifier': Mul(weight),
    }

def net(channels=None, weight=0.2, pool=nn.MaxPool2d(2), extra_layers=(), res_layers=('layer1', 'layer2'), **kw):
    channels = channels or {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 256}
    n = basic_net(channels, weight, pool, **kw)
    for layer in res_layers:
        n[layer]['residual'] = residual(channels[layer], **kw)
    for layer in extra_layers:
        n[layer]['extra'] = conv_bn(channels[layer], channels[layer], **kw)
    return n
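# Counting weight layers in the default net(): 1 prep conv + 3 per-stage convs
# + 2 residual blocks of 2 convs each + the final linear layer gives a
# 9-layer residual network, which is what gets trained below.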

losses = {
    # per-sample losses: reduction='none' replaces the deprecated reduce=False
    'loss': (nn.CrossEntropyLoss(reduction='none'), [('classifier',), ('target',)]),
    'correct': (Correct(), [('classifier',), ('target',)]),
}

class TSVLogger():
    def __init__(self):
        self.log = ['epoch\thours\ttop1Accuracy']
    def append(self, output):
        epoch, hours, acc = output['epoch'], output['total time']/3600, output['test acc']*100
        self.log.append(f'{epoch}\t{hours:.8f}\t{acc:.2f}')
    def __str__(self):
        return '\n'.join(self.log)
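# Hedged example: appending {'epoch': 1, 'total time': 3600.0, 'test acc': 0.9}
# yields the row "1\t1.00000000\t90.00" — the same epoch/hours/top1Accuracy
# format as the log file added in this commit (the dict keys are assumed to be
# supplied by train() in utils, which is not shown here).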

def main():
    DATA_DIR = './data'

    print('Downloading datasets')
    dataset = cifar10(DATA_DIR)

    epochs = 24
    lr_schedule = PiecewiseLinear([0, 5, epochs], [0, 0.4, 0.001])
    batch_size = 512
    train_transforms = [Crop(32, 32), FlipLR(), Cutout(8, 8)]

    model = Network(union(net(), losses)).to(device).half()

    print('Warming up cudnn on random inputs')
    # warm up both the full batch size and the final partial test batch
    # (len(test) % batch_size) so no new shape triggers benchmarking mid-run
    for size in [batch_size, len(dataset['test']['labels']) % batch_size]:
        warmup_cudnn(model, size)

    print('Starting timer')
    timer = Timer()

    print('Preprocessing training data')
    train_set = list(zip(transpose(normalise(pad(dataset['train']['data'], 4))), dataset['train']['labels']))
    print(f'Finished in {timer():.2f} seconds')
    print('Preprocessing test data')
    test_set = list(zip(transpose(normalise(dataset['test']['data'])), dataset['test']['labels']))
    print(f'Finished in {timer():.2f} seconds')

    TSV = TSVLogger()

    train_batches = Batches(Transform(train_set, train_transforms), batch_size, shuffle=True, set_random_choices=True, drop_last=True)
    test_batches = Batches(test_set, batch_size, shuffle=False, drop_last=False)
    # the schedule is indexed by epoch and divided by batch_size because the
    # un-reduced per-sample losses are summed over the batch (cf. loss.sum()
    # in warmup_cudnn); weight decay is scaled up by the same factor
    lr = lambda step: lr_schedule(step/len(train_batches))/batch_size
    opt = SGD(trainable_params(model), lr=lr, momentum=0.9, weight_decay=5e-4*batch_size, nesterov=True)

    train(model, opt, train_batches, test_batches, epochs, loggers=(TableLogger(), TSV), timer=timer, test_time_in_total=False)

    with open('logs_training.tsv', 'w') as f:
        f.write(str(TSV))

if __name__ == '__main__':
    main()
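A worked reading of the learning-rate schedule above, hedged since PiecewiseLinear is defined in utils and not shown in this diff: assuming linear interpolation between the knots, the rate ramps from 0 to 0.4 over the first 5 epochs, then decays to 0.001 by epoch 24, with every value divided by the batch size. For example, at epoch 3:

# lr_schedule(3) = 0.4 * 3/5 = 0.24     (assumed linear interpolation)
# per-step lr    = 0.24 / 512 ≈ 4.7e-4  (after the 1/batch_size scaling)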