### Set GPU clocks

In [7]:
from core import *
from torch_backend import *

### Network definition

In [8]:
def conv_bn(c_in, c_out):
    """Build a 3x3 conv -> batch-norm -> ReLU node group.

    Args:
        c_in: number of input channels of the convolution.
        c_out: number of output channels (also the BatchNorm width).

    Returns:
        dict with keys 'conv', 'bn', 'relu' in that insertion order; the graph
        builder presumably wires consecutive entries in order (TODO confirm in
        build_graph).
    """
    conv = nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False)
    bn = BatchNorm(c_out)
    relu = nn.ReLU(inplace=True)
    return {'conv': conv, 'bn': bn, 'relu': relu}

def residual(c):
    """Residual branch: two conv_bn stages on `c` channels plus a skip connection.

    The 'add' node sums the branch input ('in', an Identity) with the output of
    the second stage's ReLU ('res2/relu').
    """
    branch = {'in': Identity()}
    branch['res1'] = conv_bn(c, c)
    branch['res2'] = conv_bn(c, c)
    branch['add'] = (Add(), ['in', 'res2/relu'])
    return branch

def net(channels=None, weight=0.125, pool=nn.MaxPool2d(2), extra_layers=(), res_layers=('layer1', 'layer3')):
    """Build the ResNet-style network graph as a dict of named nodes.

    Args:
        channels: mapping with keys 'prep', 'layer1', 'layer2', 'layer3' giving
            each stage's output width. When falsy, the widths 64/128/256/512 are used.
        weight: scale applied by the final Mul node to the linear outputs.
        pool: module appended (key 'pool') to layer1..layer3. The default is
            created once when `def` runs, so a single instance is shared by every
            layer and every call. That is harmless for parameter-free pooling.
        extra_layers: layer names that get an extra conv_bn stage under 'extra'.
        res_layers: layer names that get a residual() branch under 'residual'.

    Returns:
        dict of node name -> module or nested node dict; 'input' is (None, []).
    """
    if not channels:
        channels = {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 512}

    graph = {'input': (None, [])}
    graph['prep'] = conv_bn(3, channels['prep'])
    # Each layer widens from the previous stage's channel count, then pools.
    previous = 'prep'
    for name in ('layer1', 'layer2', 'layer3'):
        graph[name] = dict(conv_bn(channels[previous], channels[name]), pool=pool)
        previous = name
    graph['pool'] = nn.MaxPool2d(4)
    graph['flatten'] = Flatten()
    graph['linear'] = nn.Linear(channels['layer3'], 10, bias=False)
    graph['logits'] = Mul(weight)

    for name in res_layers:
        graph[name]['residual'] = residual(channels[name])
    for name in extra_layers:
        graph[name]['extra'] = conv_bn(channels[name], channels[name])
    return graph

### Download and preprocess data

In [9]:
DATA_DIR = './data'
dataset = cifar10(root=DATA_DIR)
timer = Timer()


def _as_examples(split, fns):
    """Apply `fns` to `split` via preprocess and return a list of per-example tuples.

    preprocess(...) returns a mapping whose values are zipped together
    (presumably parallel data/target columns -- TODO confirm in core).
    """
    columns = preprocess(split, fns)
    return list(zip(*columns.values()))


print('Preprocessing training data')
transforms = [
    partial(normalise,
            mean=np.array(cifar10_mean, dtype=np.float32),
            std=np.array(cifar10_std, dtype=np.float32)),
    partial(transpose, source='NHWC', target='NCHW'),
]
# Only the training split is padded (by 4 pixels) before the shared transforms.
train_set = _as_examples(dataset['train'], [partial(pad, border=4), *transforms])
print(f'Finished in {timer():.2} seconds')
print('Preprocessing test data')
valid_set = _as_examples(dataset['valid'], transforms)
print(f'Finished in {timer():.2} seconds')

Preprocessing training data
Finished in 2.4 seconds
Preprocessing test data
Finished in 0.046 seconds


### Network visualisation

In [10]:
# colors = ColorMap()
# draw = lambda graph: DotGraph({p: ({'fillcolor': colors[type(v)], 'tooltip': repr(v)}, inputs) for p, (v, inputs) in graph.items() if v is not None})

# draw(build_graph(net()))

### Training

NB: on the first run, the first epoch will be slower because initialisation and cuDNN benchmarking take place.

In [11]:
# Added to a negative RAPL reading to undo a counter wrap-around. These are
# presumably the max_energy_range_uj values of the machine the runs were recorded
# on -- TODO confirm; they are host-specific.
RAPL_PACKAGE_WRAP_UJ = 262143328850
RAPL_DRAM_WRAP_UJ = 65712999613


def _energy_joules(domain, energy):
    """Normalise one raw pyJoules reading for `domain`, returning raw / 1e6.

    NVIDIA GPU readings are first multiplied by 1000 so every domain shares one
    unit before the division (presumably mJ -> uJ -> J; TODO confirm against the
    pyJoules device docs). A negative 'package' or 'dram' reading is treated as a
    counter wrap-around and corrected by adding that domain's wrap range. Other
    negative readings are left untouched.
    """
    if 'nvidia_gpu' in domain:
        energy *= 1000
    if energy < 0:
        if 'package' in domain:
            energy += RAPL_PACKAGE_WRAP_UJ
        elif 'dram' in domain:
            energy += RAPL_DRAM_WRAP_UJ
    return energy / 1e6


def save_run(dirnames_csv, filename_csv, trace, train_accs, valid_accs):
    """Write one CSV row per energy-trace sample with accuracies, energy and power.

    Args:
        dirnames_csv: path components of the output directory; created if missing.
        filename_csv: path of the CSV file to (over)write.
        trace: iterable of samples exposing `.tag`, `.duration` and an `.energy`
            mapping of domain name -> raw reading (as produced by pyJoules).
        train_accs, valid_accs: accuracies indexed in step with `trace`.

    Columns: epoch (the sample tag), duration, train acc, valid acc, then one
    "<domain> energy" column (normalised by _energy_joules) and one
    "<domain> power" column (energy / duration) per domain, in the first
    sample's domain order. An empty trace produces an empty file.

    Raises:
        IndexError: if an accuracy list is shorter than the trace.
    """
    os.makedirs(os.path.join(*dirnames_csv), exist_ok=True)
    with open(filename_csv, 'w') as file:
        for index, sample in enumerate(trace):
            if index == 0:
                domains = list(sample.energy.keys())
                header = ['epoch', 'duration', 'train acc', 'valid acc']
                header += [f'{domain} energy' for domain in domains]
                header += [f'{domain} power' for domain in domains]
                file.write(','.join(header) + '\n')
            # Convert each domain once; both the energy and power columns reuse it.
            energies = [_energy_joules(domain, raw) for domain, raw in sample.energy.items()]
            row = [sample.tag, sample.duration, train_accs[index], valid_accs[index]]
            row += energies
            row += [energy / sample.duration for energy in energies]
            file.write(','.join(f'{value}' for value in row) + '\n')

In [12]:
import datetime
import os
import time  # used for run timing below; previously only available via `from core import *`

from pyJoules.device import DeviceFactory
from pyJoules.energy_meter import EnergyMeter

epochs = 1
N_runs = 1

DIRNAMES_CSV = ['csv', datetime.datetime.now().strftime('cifar10-fast-%Y-%m-%d-%H-%M-%S')]

# Linear warm-up to a peak LR of 0.4, then linear decay to 0 at `epochs`.
# PiecewiseLinear presumably interpolates with np.interp, which needs increasing
# knots. A fixed 5-epoch warm-up breaks that when epochs <= 5 (e.g. [0, 5, 1]),
# so short runs warm up over half the run instead.
warmup_epochs = 5 if epochs > 5 else epochs / 2
lr_schedule = PiecewiseLinear([0, warmup_epochs, epochs], [0, 0.4, 0])
batch_size = 512
train_transforms = [Crop(32, 32), FlipLR(), Cutout(8, 8)]

train_batches = DataLoader(Transform(train_set, train_transforms), batch_size, shuffle=True, set_random_choices=True, drop_last=True)
valid_batches = DataLoader(valid_set, batch_size, shuffle=False, drop_last=False)
# LR is divided by batch_size (and weight decay multiplied by it below), presumably
# because the loss is summed rather than averaged over the batch -- TODO confirm.
lr = lambda step: lr_schedule(step/len(train_batches))/batch_size

# Synchronise CUDA before timer readings so queued GPU work is counted; None on CPU.
synch = None if device == cpu else torch.cuda.synchronize

summaries = []
for i in range(N_runs):
    print(f'Starting run {i} at {localtime()}')
    devices = DeviceFactory.create_devices()
    meter = EnergyMeter(devices)

    start_time = time.time()
    model = Network(net()).to(device)
    if device != cpu:
        model = model.half()
    opts = [SGD(trainable_params(model).values(), {'lr': lr, 'weight_decay': Const(5e-4*batch_size), 'momentum': Const(0.9)})]
    logs, state = Table(), {MODEL: model, LOSS: x_ent_loss, OPTS: opts}
    train_accs = []
    valid_accs = []
    for epoch in range(epochs):
        # One energy sample per epoch, tagged 1..epochs: start() opens the first,
        # and record() presumably closes the previous one and opens the next.
        if epoch == 0:
            meter.start(tag=1)
        else:
            meter.record(tag=epoch+1)

        print(f'Epoch {epoch}/{epochs-1}')
        stats = train_epoch(state, Timer(synch), train_batches, valid_batches)

        train_accs.append(stats['train']['acc'])
        valid_accs.append(stats['valid']['acc'])

    meter.stop()

    print(f'Duration: {time.time() - start_time}')

    trace = meter.get_trace()

    FILENAME_CSV = os.path.join(
        *DIRNAMES_CSV,
        f'run-{i:02d}.csv')

    save_run(DIRNAMES_CSV, FILENAME_CSV, trace, train_accs, valid_accs)

Starting run 0 at 2021-11-24 10:32:42
Epoch 0/0


KeyboardInterrupt: 