Skip to content

Commit

Permalink
[Caffe] Fix batch accumulation bug
Browse files Browse the repository at this point in the history
  • Loading branch information
lukeyeager committed Nov 11, 2016
1 parent 4308a53 commit 4063a1e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
38 changes: 38 additions & 0 deletions digits/model/images/classification/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import itertools
import json
import math
import os
import shutil
import tempfile
Expand All @@ -16,13 +17,16 @@
from StringIO import StringIO

from bs4 import BeautifulSoup
from google.protobuf import text_format

from digits.config import config_value
import digits.dataset.images.classification.test_views
from digits.frameworks import CaffeFramework
import digits.test_views
from digits import test_utils
import digits.webapp

import caffe_pb2

# May be too short on a slow system
TIMEOUT_DATASET = 45
Expand Down Expand Up @@ -101,6 +105,10 @@ def model_exists(cls, job_id):
def model_status(cls, job_id):
return cls.job_status(job_id, 'models')

@classmethod
def model_info(cls, job_id):
return cls.job_info(job_id, 'models')

@classmethod
def abort_model(cls, job_id):
return cls.abort_job(job_id, job_type='models')
Expand Down Expand Up @@ -1254,3 +1262,33 @@ def test_sweep(self):
assert self.model_wait_completion(job_id) == 'Done', 'create failed'
assert self.delete_model(job_id) == 200, 'delete failed'
assert not self.model_exists(job_id), 'model exists after delete'


@unittest.skipIf(
not CaffeFramework().can_accumulate_gradients(),
'This version of Caffe cannot accumulate gradients')
class TestBatchAccumulationCaffe(BaseViewsTestWithDataset, test_utils.CaffeMixin):
TRAIN_EPOCHS = 1
IMAGE_COUNT = 10 # per class

def test_batch_accumulation_calculations(self):
batch_size = 10
batch_accumulation = 2

job_id = self.create_model(
batch_size=batch_size,
batch_accumulation=batch_accumulation,
)
assert self.model_wait_completion(job_id) == 'Done', 'create failed'
info = self.model_info(job_id)
solver = caffe_pb2.SolverParameter()
with open(os.path.join(info['directory'], info['solver file']), 'r') as infile:
text_format.Merge(infile.read(), solver)
assert solver.iter_size == batch_accumulation, \
'iter_size is %d instead of %d' % (solver.iter_size, batch_accumulation)
max_iter = int(math.ceil(
float(self.TRAIN_EPOCHS * self.IMAGE_COUNT * 3) /
(batch_size * batch_accumulation)
))
assert solver.max_iter == max_iter,\
'max_iter is %d instead of %d' % (solver.max_iter, max_iter)
12 changes: 8 additions & 4 deletions digits/model/tasks/caffe_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,10 @@ def save_files_classification(self):
solver.iter_size = self.batch_accumulation

# Epochs -> Iterations
train_iter = int(math.ceil(float(self.dataset.get_entry_count(
constants.TRAIN_DB)) / train_data_layer.data_param.batch_size))
train_iter = int(math.ceil(
float(self.dataset.get_entry_count(constants.TRAIN_DB)) /
(train_data_layer.data_param.batch_size * solver.iter_size)
))
solver.max_iter = train_iter * self.train_epochs
snapshot_interval = self.snapshot_interval * train_iter
if 0 < snapshot_interval <= 1:
Expand Down Expand Up @@ -753,8 +755,10 @@ def save_files_generic(self):
solver.iter_size = self.batch_accumulation

# Epochs -> Iterations
train_iter = int(math.ceil(float(self.dataset.get_entry_count(constants.TRAIN_DB)) /
train_image_data_layer.data_param.batch_size))
train_iter = int(math.ceil(
float(self.dataset.get_entry_count(constants.TRAIN_DB)) /
(train_image_data_layer.data_param.batch_size * solver.iter_size)
))
solver.max_iter = train_iter * self.train_epochs
snapshot_interval = self.snapshot_interval * train_iter
if 0 < snapshot_interval <= 1:
Expand Down

0 comments on commit 4063a1e

Please sign in to comment.