Skip to content
Permalink
Browse files

Add files via upload

  • Loading branch information...
BAIDU-USA-GAIT-LEOPARD committed May 7, 2019
1 parent 56447bc commit 856ec31bd48cc0338ce3d31dd5299cd58b1ceaaf
Showing with 540 additions and 0 deletions.
  1. +25 −0 logs_training.tsv
  2. +149 −0 torch_backend.py
  3. +94 −0 training.py
  4. +272 −0 utils.py
@@ -0,0 +1,25 @@
epoch hours top1Accuracy
1 0.00148373 52.61
2 0.00228371 62.45
3 0.00308774 77.08
4 0.00388793 66.63
5 0.00468890 74.39
6 0.00549573 83.39
7 0.00629734 81.37
8 0.00710241 81.29
9 0.00790268 84.33
10 0.00870396 82.36
11 0.00950842 87.64
12 0.01030985 84.10
13 0.01111024 88.50
14 0.01191096 87.62
15 0.01271271 84.25
16 0.01351768 88.41
17 0.01431947 86.36
18 0.01512118 89.79
19 0.01592666 90.87
20 0.01672992 91.01
21 0.01753578 91.84
22 0.01833951 92.59
23 0.01914275 93.64
24 0.01994991 94.10
@@ -0,0 +1,149 @@
import numpy as np
import torch
from torch import nn
import torchvision
from utils import build_graph, cat, to_numpy

torch.backends.cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@cat.register(torch.Tensor)
def _(*xs):
return torch.cat(xs)

@to_numpy.register(torch.Tensor)
def _(x):
return x.detach().cpu().numpy()

def warmup_cudnn(model, batch_size):
#run forward and backward pass of the model on a batch of random inputs
#to allow benchmarking of cudnn kernels
batch = {
'input': torch.Tensor(np.random.rand(batch_size,3,32,32)).cuda().half(),
'target': torch.LongTensor(np.random.randint(0,10,batch_size)).cuda()
}
model.train(True)
o = model(batch)
o['loss'].sum().backward()
model.zero_grad()
torch.cuda.synchronize()


#####################
## dataset
#####################

def cifar10(root):
train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True)
return {
'train': {'data': train_set.train_data, 'labels': train_set.train_labels},
'test': {'data': test_set.test_data, 'labels': test_set.test_labels}
}

#####################
## data loading
#####################

class Batches():
def __init__(self, dataset, batch_size, shuffle, set_random_choices=False, num_workers=0, drop_last=False):
self.dataset = dataset
self.batch_size = batch_size
self.set_random_choices = set_random_choices
self.dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=shuffle, drop_last=drop_last
)

def __iter__(self):
if self.set_random_choices:
self.dataset.set_random_choices()
return ({'input': x.to(device).half(), 'target': y.to(device).long()} for (x,y) in self.dataloader)

def __len__(self):
return len(self.dataloader)

#####################
## torch stuff
#####################

class Identity(nn.Module):
def forward(self, x): return x

class Mul(nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight
def __call__(self, x):
return x*self.weight

class Flatten(nn.Module):
def forward(self, x): return x.view(x.size(0), x.size(1))

class Add(nn.Module):
def forward(self, x, y): return x + y

class Concat(nn.Module):
def forward(self, *xs): return torch.cat(xs, 1)

class Correct(nn.Module):
def forward(self, classifier, target):
return classifier.max(dim = 1)[1] == target

def batch_norm(num_channels, bn_bias_init=None, bn_bias_freeze=False, bn_weight_init=None, bn_weight_freeze=False):
m = nn.BatchNorm2d(num_channels)
if bn_bias_init is not None:
m.bias.data.fill_(bn_bias_init)
if bn_bias_freeze:
m.bias.requires_grad = False
if bn_weight_init is not None:
m.weight.data.fill_(bn_weight_init)
if bn_weight_freeze:
m.weight.requires_grad = False

return m



class Network(nn.Module):
def __init__(self, net):
self.graph = build_graph(net)
super().__init__()
for n, (v, _) in self.graph.items():
setattr(self, n, v)

def forward(self, inputs):
self.cache = dict(inputs)
for n, (_, i) in self.graph.items():
self.cache[n] = getattr(self, n)(*[self.cache[x] for x in i])
return self.cache

def half(self):
for module in self.children():
if type(module) is not nn.BatchNorm2d:
module.half()
return self

trainable_params = lambda model:filter(lambda p: p.requires_grad, model.parameters())

class TorchOptimiser():
def __init__(self, weights, optimizer, step_number=0, **opt_params):
self.weights = weights
self.step_number = step_number
self.opt_params = opt_params
self._opt = optimizer(weights, **self.param_values())

def param_values(self):
return {k: v(self.step_number) if callable(v) else v for k,v in self.opt_params.items()}

def step(self):
self.step_number += 1
self._opt.param_groups[0].update(**self.param_values())
self._opt.step()

def __repr__(self):
return repr(self._opt)

def SGD(weights, lr=0, momentum=0, weight_decay=0, dampening=0, nesterov=False):
return TorchOptimiser(weights, torch.optim.SGD, lr=lr, momentum=momentum,
weight_decay=weight_decay, dampening=dampening,
nesterov=nesterov)
@@ -0,0 +1,94 @@
from utils import *
from torch_backend import *

#Network definition
def conv_bn(c_in, c_out, bn_weight_init=1.0, **kw):
return {
'conv': nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
'bn': batch_norm(c_out, bn_weight_init=bn_weight_init, **kw),
'relu': nn.ReLU(True)
}

def residual(c, **kw):
return {
'in': Identity(),
'res1': conv_bn(c, c, **kw),
'res2': conv_bn(c, c, **kw),
'add': (Add(), [rel_path('in'), rel_path('res2', 'relu')]),
}

def basic_net(channels, weight, pool, **kw):
return {
'prep': conv_bn(3, channels['prep'], **kw),
'layer1': dict(conv_bn(channels['prep'], channels['layer1'], **kw), pool=pool),
'layer2': dict(conv_bn(channels['layer1'], channels['layer2'], **kw), pool=pool),
'layer3': dict(conv_bn(channels['layer2'], channels['layer3'], **kw), pool=pool),
'pool': nn.MaxPool2d(8),
'flatten': Flatten(),
'linear': nn.Linear(channels['layer3'], 10, bias=False),
'classifier': Mul(weight),
}

def net(channels=None, weight=0.2, pool=nn.MaxPool2d(2), extra_layers=(), res_layers=('layer1', 'layer2'), **kw):
channels = channels or {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 256, }
n = basic_net(channels, weight, pool, **kw)
for layer in res_layers:
n[layer]['residual'] = residual(channels[layer], **kw)
for layer in extra_layers:
n[layer]['extra'] = conv_bn(channels[layer], channels[layer], **kw)
return n

losses = {
'loss': (nn.CrossEntropyLoss(reduce=False), [('classifier',), ('target',)]),
'correct': (Correct(), [('classifier',), ('target',)]),
}

class TSVLogger():
def __init__(self):
self.log = ['epoch\thours\ttop1Accuracy']
def append(self, output):
epoch, hours, acc = output['epoch'], output['total time']/3600, output['test acc']*100
self.log.append(f'{epoch}\t{hours:.8f}\t{acc:.2f}')
def __str__(self):
return '\n'.join(self.log)

def main():
DATA_DIR = './data'

print('Downloading datasets')
dataset = cifar10(DATA_DIR)

epochs = 24
lr_schedule = PiecewiseLinear([0, 5, epochs], [0, 0.4, 0.001])
batch_size = 512
train_transforms = [Crop(32, 32), FlipLR(), Cutout(8, 8)]

model = Network(union(net(), losses)).to(device).half()

print('Warming up cudnn on random inputs')
for size in [batch_size, len(dataset['test']['labels']) % batch_size]:
warmup_cudnn(model, size)

print('Starting timer')
timer = Timer()

print('Preprocessing training data')
train_set = list(zip(transpose(normalise(pad(dataset['train']['data'], 4))), dataset['train']['labels']))
print(f'Finished in {timer():.2} seconds')
print('Preprocessing test data')
test_set = list(zip(transpose(normalise(dataset['test']['data'])), dataset['test']['labels']))
print(f'Finished in {timer():.2} seconds')

TSV = TSVLogger()

train_batches = Batches(Transform(train_set, train_transforms), batch_size, shuffle=True, set_random_choices=True, drop_last=True)
test_batches = Batches(test_set, batch_size, shuffle=False, drop_last=False)
lr = lambda step: lr_schedule(step/len(train_batches))/batch_size
opt = SGD(trainable_params(model), lr=lr, momentum=0.9, weight_decay=5e-4*batch_size, nesterov=True)

train(model, opt, train_batches, test_batches, epochs, loggers=(TableLogger(), TSV), timer=timer, test_time_in_total=False)

with open('logs_training.tsv','w') as f:
f.write(str(TSV))

main()

0 comments on commit 856ec31

Please sign in to comment.
You can’t perform that action at this time.