Merge pull request #487 from EleutherAI/t5_rpe_scale
T5 rpe scale
StellaAthena committed Jan 11, 2022
2 parents 3ad6195 + 6dc7645 commit b754f75
Showing 2 changed files with 13 additions and 5 deletions.
megatron/model/gpt2_model.py (9 changes: 8 additions & 1 deletion)

@@ -18,6 +18,7 @@
 
 """GPT-2 model."""
 
+import math
 import torch
 import torch.nn as nn
 from collections import defaultdict
@@ -157,7 +158,13 @@ def insert_layers(self, layers: Union[nn.Module, nn.ModuleList, nn.Sequential, L
     def init_specs(self):
         weight_tying = not self.neox_args.no_weight_tying
         if self.embedding_type == 'rpe':
-            rpe_emb = ParallelRelativePositionBias(neox_args=self.neox_args, causal=True,
+            hidden_size_per_attention_head = mpu.divide(
+                self.neox_args.hidden_size, self.neox_args.num_attention_heads
+            )
+            rpe_scale = math.sqrt(hidden_size_per_attention_head)
+            rpe_emb = ParallelRelativePositionBias(neox_args=self.neox_args,
+                                                   scale=rpe_scale,
+                                                   causal=True,
                                                    num_buckets=self.neox_args.rpe_num_buckets,
                                                    max_distance=self.neox_args.rpe_max_distance,
                                                    heads=self.neox_args.num_attention_heads)
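
For context, the new scale argument is the square root of the per-head hidden size, and the relative position bias returned by ParallelRelativePositionBias is multiplied by it (see the second file below), presumably so the bias keeps its magnitude relative to the normalized attention scores. A minimal sketch of the arithmetic, using hypothetical config values rather than a real NeoX config (mpu.divide checks divisibility before integer division; plain // stands in for it here):

    import math

    # Hypothetical example values, not taken from the PR or any config file.
    hidden_size = 4096
    num_attention_heads = 32

    # Per-head width, mirroring mpu.divide(hidden_size, num_attention_heads).
    hidden_size_per_attention_head = hidden_size // num_attention_heads   # 128
    rpe_scale = math.sqrt(hidden_size_per_attention_head)                 # ~11.31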
megatron/mpu/layers.py (9 changes: 5 additions & 4 deletions)

@@ -37,7 +37,6 @@
 from .random import get_cuda_rng_tracker
 from .utils import divide
 from .utils import VocabUtility
-from einops import rearrange
 
 def _initialize_affine_weight_gpu(weight, init_method,
                                   partition_dim, stride=1):
@@ -163,14 +162,16 @@ class ParallelRelativePositionBias(torch.nn.Module):
     and adapted for megatron's model parallelism
     Arguments:
+        scale: scaling factor for the bias
         causal: flag for causal/non-causal language modelling.
         num_buckets: number of rp buckets.
         max_distance: max distance in sequence dim for each bucket.
         heads: number of attention heads (total)
     """
 
-    def __init__(self, neox_args, causal=True, num_buckets=32, max_distance=128, heads=8, init_method=init.xavier_normal_):
+    def __init__(self, neox_args, scale, causal=True, num_buckets=32, max_distance=128, heads=8, init_method=init.xavier_normal_):
         super().__init__()
+        self.scale = scale
         self.causal = causal
         self.num_buckets = num_buckets
         self.max_distance = max_distance
@@ -251,8 +252,8 @@ def forward(self, q_len, k_len):
             rp_bucket = self._rel_pos_bucket_cached
         values = F.embedding(rp_bucket, self.weight, self.padding_idx,
                              self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
-        bias = rearrange(values, 'i j h -> () h i j')
-        return bias
+        bias = values.movedim(2,0).unsqueeze(0)
+        return bias * self.scale
 
 
 class ColumnParallelLinear(torch.nn.Module):
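
Dropping the einops import works because values.movedim(2, 0).unsqueeze(0) produces the same (1, heads, q_len, k_len) layout that rearrange(values, 'i j h -> () h i j') did. A small standalone check, with illustrative shapes and an illustrative scale value; einops is imported here only to compare against the old behaviour:

    import torch
    from einops import rearrange  # only needed for this comparison

    q_len, k_len, heads = 7, 9, 8
    values = torch.randn(q_len, k_len, heads)            # (i, j, h)

    old_bias = rearrange(values, 'i j h -> () h i j')    # (1, h, i, j)
    new_bias = values.movedim(2, 0).unsqueeze(0)         # move h to the front, add a leading dim

    assert old_bias.shape == new_bias.shape == (1, heads, q_len, k_len)
    assert torch.equal(old_bias, new_bias)

    # The new forward() additionally multiplies the bias by self.scale,
    # e.g. sqrt(hidden_size_per_attention_head) as computed in gpt2_model.py.
    scaled_bias = new_bias * (128 ** 0.5)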
