Phase1 merge: vit optimizations + dataset enhancements + scaled_softmax kernel
kvareddy committed Jan 12, 2022
1 parent 9a8b89a commit 7a77abd
Showing 19 changed files with 1,012 additions and 132 deletions.
13 changes: 11 additions & 2 deletions megatron/arguments.py
@@ -835,11 +835,20 @@ def _add_vit_args(parser):

group.add_argument('--num-classes', type=int, default=1000,
help='num of classes in vision classification task')
group.add_argument('--img-dim', type=int, default=224,
help='Image size for vision classification task')
group.add_argument('--img-h', type=int, default=224,
help='Image height for vision classification task')
group.add_argument('--img-w', type=int, default=224,
help='Image width for vision classification task')
group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16,
help='patch dimension used in vit')
group.add_argument('--classes-fraction', type=float, default=1.0,
help='training with fraction of classes.')
group.add_argument('--data-per-class-fraction', type=float, default=1.0,
help='training with fraction of data per class.')
group.add_argument('--no-data-sharding', action='store_false',
help='Disable data sharding.',
dest='data_sharding')

return parser
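
For context, a minimal sketch of how the new --img-h, --img-w, --patch-dim, and --num-channels arguments are typically consumed by a ViT model: the image is split into a grid of patches whose count follows from the two image dimensions. The helper below is illustrative only and not part of this commit; the argument names match the diff, everything else (the function name, the divisibility assumption) is assumed.

# Illustrative only (not part of this commit): derive ViT patch-grid sizes
# from the new image-size arguments added above.
def vit_patch_counts(args):
    # Assumes both dimensions are divisible by the patch size.
    assert args.img_h % args.patch_dim == 0
    assert args.img_w % args.patch_dim == 0
    grid_h = args.img_h // args.patch_dim
    grid_w = args.img_w // args.patch_dim
    num_patches = grid_h * grid_w          # e.g. 224x224 with 16x16 patches -> 196
    patch_vector_size = args.num_channels * args.patch_dim * args.patch_dim
    return num_patches, patch_vector_size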
68 changes: 55 additions & 13 deletions megatron/data/data_samplers.py
@@ -16,8 +16,10 @@
"""Dataloaders."""


import torch
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from megatron import get_args
from megatron import mpu

@@ -39,11 +41,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
dataset,
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
data_parallel_size=mpu.get_data_parallel_world_size(),
data_sharding=args.data_sharding)
else:
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))
@@ -103,16 +107,40 @@ def __iter__(self):
yield batch[start_idx:end_idx]


class RandomSeedDataset(Dataset):

def __init__(self, dataset):
args = get_args()
self.base_seed = args.seed
self.curr_seed = args.seed
self.dataset = dataset

def __len__(self):
return len(self.dataset)

def set_epoch(self, epoch):
self.curr_seed = self.base_seed + epoch

def __getitem__(self, idx):
seed = idx + self.curr_seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
return self.dataset[idx]
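
A minimal usage sketch for RandomSeedDataset (illustrative, not part of the diff): wrapping a dataset makes every __getitem__ call reseed torch, random, and numpy from base_seed + epoch + idx, so randomized augmentations become reproducible across restarts. The ToyDataset below is hypothetical, and the sketch assumes Megatron's global args have already been initialized so that get_args() inside RandomSeedDataset can read the seed.

# Illustrative usage sketch (not part of this commit). Assumes Megatron's
# global args are initialized, so get_args() inside RandomSeedDataset works.
import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):          # hypothetical dataset, just for the example
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        # Any randomized augmentation here is driven by the freshly reseeded RNGs.
        return torch.rand(3)

wrapped = RandomSeedDataset(ToyDataset())
wrapped.set_epoch(0)
assert torch.equal(wrapped[3], wrapped[3])   # same idx + same epoch -> same sample
wrapped.set_epoch(1)                          # a new epoch shifts every per-sample seed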


class MegatronPretrainingRandomSampler:

def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, data_sharding):
# Keep a copy of input params for later use.
self.dataset = dataset
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.data_sharding = data_sharding
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
@@ -136,16 +164,30 @@ def __iter__(self):
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

if isinstance(self.dataset, RandomSeedDataset):
self.dataset.set_epoch(self.epoch)

# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size

g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
if self.data_sharding:
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size

g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
else:
full_bucket_size = (self.total_samples // self.micro_batch_size) \
* self.micro_batch_size
full_bucket_offset = current_epoch_samples
g = torch.Generator()
g.manual_seed(self.epoch)
idx_range_total = \
torch.randperm(full_bucket_size, generator=g).tolist()
idx_range_active = idx_range_total[full_bucket_offset:]
idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]

batch = []
# Last batch if not complete will be dropped.
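
To make the sharded vs. non-sharded index selection above concrete, here is a small self-contained sketch (illustrative only; the function name and the example constants are assumed) that reproduces the two branches of the epoch shuffle for one data-parallel rank.

import torch

# Illustrative sketch of the two shuffling strategies in
# MegatronPretrainingRandomSampler.__iter__ (constants chosen for the example).
def epoch_indices(total_samples, micro_batch_size, dp_rank, dp_size,
                  consumed_this_epoch, epoch, data_sharding):
    g = torch.Generator()
    g.manual_seed(epoch)
    if data_sharding:
        # Each rank shuffles only its own contiguous bucket of samples.
        bucket_size = (total_samples // (micro_batch_size * dp_size)) * micro_batch_size
        bucket_offset = consumed_this_epoch // dp_size
        start_idx = dp_rank * bucket_size
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        return [start_idx + x for x in random_idx[bucket_offset:]]
    # Without sharding, every rank shuffles the full range and takes a strided
    # slice, so each rank can draw from the whole dataset every epoch.
    full_bucket_size = (total_samples // micro_batch_size) * micro_batch_size
    idx_range_total = torch.randperm(full_bucket_size, generator=g).tolist()
    return idx_range_total[consumed_this_epoch:][dp_rank::dp_size]

# e.g. 2 data-parallel ranks, micro-batch 4, nothing consumed yet this epoch:
print(epoch_indices(32, 4, 0, 2, 0, epoch=0, data_sharding=True)[:4])
print(epoch_indices(32, 4, 0, 2, 0, epoch=0, data_sharding=False)[:4])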
