data state restoration, first pass
Mickus Timothee committed May 20, 2024
1 parent c6995b7 commit 441f13c
Showing 4 changed files with 59 additions and 11 deletions.
18 changes: 13 additions & 5 deletions mammoth/inputters/dataloader.py
@@ -66,6 +66,7 @@ def __init__(self, dataset, score_fn):
self._it = iter(self.dataset)
self._prev = next(self._it)
self._score = self.score_fn(self._prev)
self._current_line_idx = None

def peek_at_score(self):
return self._score
@@ -174,7 +175,7 @@ def __iter__(self):
for _ in range(self.batch_size):
_, example = self._sie.next()
minibatch.append(example)
yield self.collate_fn(minibatch)
yield self.collate_fn(accum, self._sie._current_line_idx)


class DynamicDatasetIter(object):
@@ -212,6 +213,7 @@ def __init__(
batch_size_multiple,
max_look_ahead_sentences=2048,
lookahead_minibatches=4,
line_idx_restore=None,
):
self.task_queue_manager = task_queue_manager
self.opts = opts
@@ -225,11 +227,10 @@ def __init__(
self.batch_size = batch_size
self.batch_size_multiple = batch_size_multiple
self.device = 'cpu'
self.max_look_ahead_sentences = max_look_ahead_sentences
self.lookahead_minibatches = lookahead_minibatches
self.line_idx_restore = dict() if line_idx_restore is None else line_idx_restore

@classmethod
def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_train):
def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_train, line_idx_restore):
"""Initilize `DynamicDatasetIter` with options parsed from `opts`."""
batch_size = opts.batch_size if is_train else opts.valid_batch_size
if opts.batch_size_multiple is not None:
@@ -248,6 +249,7 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra
batch_size_multiple,
max_look_ahead_sentences=opts.max_look_ahead_sentences,
lookahead_minibatches=opts.lookahead_minibatches,
line_idx_restore=line_idx_restore,
)

def _init_datasets(self):
@@ -272,7 +274,12 @@ def _init_datasets(self):
# is defined
if self.is_train or self.opts.tasks[task.corpus_id].get('path_valid_src', None) is not None:
corpus = get_corpus(
self.opts, task, src_vocab, tgt_vocab, is_train=self.is_train
self.opts,
task,
src_vocab,
tgt_vocab,
is_train=self.is_train,
line_idx_restore=self.line_idx_restore.get(task.corpus_id, None),
).to(device)

# iterator over minibatches
@@ -311,6 +318,7 @@ def __iter__(self):
batch = next(ordered_iter)
if batch_task_sample.training_step == 0 and self.opts.verbose:
# De-numericalize a few sentences for debugging
# FIXME should be debug, not warn
logger.warning(
f'src shape: {batch.src[0].shape} tgt shape: {batch.tgt.shape} '
f'batch size: {batch.batch_size}'
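For orientation, a minimal sketch of how the restored per-corpus line indices could be threaded from a checkpoint into `DynamicDatasetIter.from_opts`. The `data_state` key matches what `ModelSaver._save` writes below; the loading helper and its call site are assumptions, not part of this commit.

    import torch

    from mammoth.inputters.dataloader import DynamicDatasetIter


    def build_train_iter(task_queue_manager, transforms_cls, vocabs_dict, opts, checkpoint_path=None):
        # Hypothetical wiring: read the corpus_id -> line_idx mapping saved under
        # 'data_state' and hand it to the iterator so each corpus resumes mid-file.
        line_idx_restore = None
        if checkpoint_path is not None:
            frame = torch.load(checkpoint_path, map_location='cpu')
            line_idx_restore = frame.get('data_state', {})
        return DynamicDatasetIter.from_opts(
            task_queue_manager, transforms_cls, vocabs_dict, opts,
            is_train=True, line_idx_restore=line_idx_restore,
        )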
33 changes: 29 additions & 4 deletions mammoth/inputters/dataset.py
@@ -20,6 +20,7 @@ class Batch():
tgt: torch.Tensor
labels: torch.Tensor
batch_size: int
line_idx: int

def to(self, device):
self.src = (self.src[0].to(device), self.src[1].to(device))
@@ -40,13 +41,16 @@ def read_examples_from_files(
):
"""Helper function to read examples"""

line_idx_generator = itertools.count()

def _make_example_dict(packed):
"""Helper function to convert lines to dicts"""
src_str, tgt_str = packed
return {
'src': tokenize_fn(src_str, side='src'),
'tgt': tokenize_fn(tgt_str, side='tgt') if tgt_str is not None else None,
# 'align': None,
'line_idx': next(line_idx_generator)
}

if src_path.endswith('.gz'):
@@ -90,6 +94,7 @@ def __init__(
is_train=False,
task=None,
max_length=None,
line_idx_restore=None,
):
self.src_file = src_file
self.tgt_file = tgt_file
@@ -104,6 +109,7 @@ def __init__(
self.is_train = is_train
self.corpus_id = task.corpus_id
self.max_length = max_length
self._line_idx_restore = line_idx_restore

# FIXME: most likely redundant with mammoth.transforms.tokenize
def _tokenize(self, string, side='src'):
@@ -149,6 +155,17 @@ def _cast(example_dict):
if v is not None
}

# ensure we only restore the first time the corpus is restored
if self._line_idx_restore is not None:
if self.stride is not None:
# sanity check
assert self._line_idx_restore % self.stride == 0, \
'Stride is inconsistent with data restoration index'
offset = self._line_idx_restore
self._line_idx_restore = None
else:
offset = self.offset

examples = read_examples_from_files(
self.src_file,
self.tgt_file,
@@ -162,13 +179,13 @@ def _cast(example_dict):
if self.transforms is not None else lambda x: x
),
stride=self.stride,
offset=self.offset,
offset=offset,
)
examples = map(_cast, examples)
yield from examples

# FIXME: some RNN archs require sorting src's by length
def collate_fn(self, examples):
def collate_fn(self, examples, line_idx):
has_tgt = 'tgt' in examples[0].keys()
src_padidx = self.vocabs['src'][DefaultTokens.PAD]
tgt_padidx = self.vocabs['tgt'][DefaultTokens.PAD]
@@ -186,11 +203,18 @@ def collate_fn(self, examples):
else:
tgt = None
labels = None
batch = Batch(src, tgt, labels, len(examples))
batch = Batch(src, tgt, labels, len(examples), line_idx)
return batch


def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool = False):
def get_corpus(
opts,
task,
src_vocab: Vocab,
tgt_vocab: Vocab,
is_train: bool = False,
line_idx_restore: int = None,
):
"""build an iterable Dataset object"""
# get transform classes to infer special tokens
# FIXME ensure TQM properly initializes transform with global if necessary
@@ -223,6 +247,7 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool =
is_train=is_train,
task=task,
max_length=max_length,
line_idx_restore=line_idx_restore,
)
return dataset

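The stride/offset pair in `read_examples_from_files` is what makes the restore point work: stride keeps the corpus sharded across processes, while the restored line index replaces the usual shard offset. Below is a small, self-contained sketch of that skipping logic, assuming an `itertools.islice`-style implementation (the commit itself only shows the offset being swapped in).

    import itertools


    def iter_shard(lines, stride=None, offset=0):
        """Yield (line_idx, line) pairs starting at `offset`, stepping by `stride`.

        With stride = number of shards and offset = the restored line_idx,
        re-reading the file resumes where the interrupted run left off,
        while staying on the same shard.
        """
        numbered = enumerate(lines)
        step = stride if stride is not None else 1
        return itertools.islice(numbered, offset, None, step)


    # e.g. resuming shard-of-4 reading at restored line 108:
    # for line_idx, line in iter_shard(open('train.src'), stride=4, offset=108):
    #     ...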
13 changes: 11 additions & 2 deletions mammoth/models/model_saver.py
@@ -4,6 +4,7 @@
from mammoth.utils.logging import logger

import torch
import torch.distributed
import torch.nn as nn

from mammoth.utils.module_splitter import explode_model
@@ -65,7 +66,7 @@ def __init__(
self.device_context = device_context
self.all_gpus = all_gpus

def save(self, step, moving_average=None):
def save(self, step, data_state, moving_average=None):
"""Main entry point for model saver
It wraps the `_save` method with checks and applies `keep_checkpoint`
@@ -125,7 +126,8 @@ def _rm_checkpoint(self, name):
class ModelSaver(ModelSaverBase):
"""Simple model saver to filesystem"""

def _save(self, step, model, device_context):
# FIXME does not match the base class signature
def _save(self, step, model, device_context, data_state):
real_model = model.module if isinstance(model, nn.DataParallel) else model

model_state_dict = real_model.state_dict()
@@ -163,6 +165,13 @@ def _save(self, step, model, device_context):
torch.save(module, checkpoint_path)
tmp_checkpoint_paths.append(checkpoint_path)

# In a distributed context, aggregate all data states for corpus restoration
if device_context.is_distributed():
data_states = [None for _ in range(device_context.world_size)]
torch.distributed.all_gather_object(data_states, data_state)
data_state = {k: v for state in data_states for k, v in state.items()}

model_frame['data_state'] = data_state
if device_context.is_master():
# TODO: not sure how to deal with model_state_dict, fields, model_opts and optim.state_dict() in a multi-gpu
# setting. Is it OK to save only from master?
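The aggregation added to `_save` relies on `torch.distributed.all_gather_object`: every rank contributes the line indices of the corpora it actually trains on, and the union becomes the checkpoint's `data_state`. A standalone sketch of that merge, assuming the process group is already initialized as it is during training:

    import torch.distributed as dist


    def gather_data_state(local_data_state, world_size):
        # Each rank holds a partial dict {corpus_id: line_idx}; gather all of
        # them and merge into one dict covering every corpus in the run.
        gathered = [None] * world_size
        dist.all_gather_object(gathered, local_data_state)
        merged = {}
        for rank_state in gathered:
            merged.update(rank_state)
        return merged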
6 changes: 6 additions & 0 deletions mammoth/trainer.py
@@ -186,6 +186,8 @@ def __init__(

self.task_queue_manager = task_queue_manager

self._data_state = {}

for i in range(len(self.accum_count_l)):
assert self.accum_count_l[i] > 0

Expand Down Expand Up @@ -421,6 +423,10 @@ def _gradient_accumulation(
f'Received {metadata},\n expected {expected_metadata}'
)
seen_comm_batches.add(comm_batch)

# update data state
self._data_state[metadata.corpus_id] = batch.line_idx

if self.norm_method == "tokens":
num_tokens = (
batch.labels[1:, :, 0].ne(self.train_loss_md[f'trainloss{metadata.tgt_lang}'].padding_idx).sum()
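With `self._data_state` updated after every accumulated batch, the remaining step is to hand it to the saver, whose `save` signature now takes `data_state`. A hypothetical call site (not shown in this diff) could look like:

    def _maybe_save(self, step, save_checkpoint_steps):
        # Hypothetical helper: self._data_state maps corpus_id -> last line_idx
        # seen by this rank; the saver gathers it across ranks before writing
        # 'data_state' into the checkpoint.
        if self.model_saver is not None and save_checkpoint_steps != 0 and step % save_checkpoint_steps == 0:
            self.model_saver.save(step, self._data_state, moving_average=self.moving_average)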
