# Load data

https://github.com/kimiyoung/transformer-xl/blob/44781ed21dbaec88b280f74d9ae2877f52b492a5/pytorch/train.py

In [4]:
import os, sys
import torch
from data_utils import get_lm_corpus

# sys.path.append("../") # go to parent dir

https://github.com/kimiyoung/transformer-xl/blob/44781ed21dbaec88b280f74d9ae2877f52b492a5/pytorch/train.py#L181

In [6]:
class TransformerLMParams:
    """Attribute-style hyperparameter bundle for the Transformer-XL enwik8 runs.

    Plays the role of the argparse namespace used by the upstream scripts so
    the notebook can run without a command line.

    Args:
        mode: ``'train'`` selects the training hyperparameters, ``'eval'``
            selects the evaluation ones. Note that ``'eval'`` does not define
            the model/optimizer attributes set only in the training branch
            (``n_layer``, ``batch_size``, ...).

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'eval'``.
    """

    def __init__(self, mode='train'):
        # Reject a bad mode before touching any state.
        if mode not in ('train', 'eval'):
            raise ValueError("Mode must be either 'train' or 'eval'")

        # Common parameters
        self.data = '../data/enwik8/'
        self.dataset = 'enwik8'
        self.cuda = True

        # The upstream train.py reads these two when it builds the model, so
        # they must exist in train mode as well. The values are the upstream
        # argparse defaults; the eval branch below overrides both.
        self.clamp_len = -1
        self.same_length = False

        if mode == 'train':
            # Training specific parameters
            self.n_layer = 12
            self.d_model = 512
            self.n_head = 8
            self.d_head = 64
            self.d_inner = 2048
            self.dropout = 0.1
            self.dropatt = 0.0
            self.optim = 'adam'
            self.lr = 0.00025
            self.warmup_step = 0
            self.max_step = 400000
            self.tgt_len = 512
            self.mem_len = 512
            self.eval_tgt_len = 128
            self.batch_size = 22
            self.multi_gpu = True
            self.gpu0_bsz = 4
        else:
            # Evaluation specific parameters
            self.tgt_len = 80
            self.mem_len = 2100
            self.clamp_len = 820
            self.same_length = True
            self.split = 'test'

        # Default parameters not specified in bash script
        self.d_embed = -1
        self.init = 'normal'
        self.emb_init = 'normal'
        self.init_range = 0.1
        self.emb_init_range = 0.01
        self.init_std = 0.02
        self.proj_init_std = 0.01
        self.mom = 0.0
        self.scheduler = 'cosine'
        self.decay_rate = 0.5
        self.lr_min = 0.0
        self.clip = 0.25
        self.clip_nonemb = False
        self.batch_chunk = 1
        self.ext_len = 0
        self.not_tied = False
        self.tied = True
        self.seed = 1111
        self.adaptive = False
        self.div_val = 1
        self.pre_lnorm = False
        self.varlen = False
        self.log_interval = 200
        self.eval_interval = 4000
        self.work_dir = 'LM-TFM'
        self.restart = False
        self.restart_dir = ''
        self.debug = False
        self.attn_type = 0
        self.eta_min = 0.0
        self.max_eval_steps = -1
        self.sample_softmax = -1
        self.patience = 0
        self.finetune_v2 = False
        self.finetune_v3 = False
        self.fp16 = False
        self.static_loss_scale = 1
        self.dynamic_loss_scale = False

    def update(self, **kwargs):
        """Override existing parameters by keyword.

        Only attributes already set on this instance may be updated. All keys
        are checked before anything is assigned, so a rejected call leaves the
        object untouched. ``tied`` and ``not_tied`` are kept as mutual
        negations of each other.

        Raises:
            AttributeError: if any key is not an existing parameter.
        """
        # Check against instance attributes only: hasattr() would also accept
        # method names such as 'update' and let a caller clobber them.
        known = vars(self)
        for key in kwargs:
            if key not in known:
                raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{key}'")

        for key, value in kwargs.items():
            setattr(self, key, value)
            # Special case for tied/not_tied
            if key == 'not_tied':
                self.tied = not value
            elif key == 'tied':
                self.not_tied = not value

# Usage examples.
# Training configuration; the cells below read it through the name `args`.
args = TransformerLMParams(mode='train')
# args.update(lr=0.001)  # optionally override parameters; unknown names raise AttributeError

# Evaluation configuration (left commented out, not instantiated here):
# eval_args = TransformerLMParams(mode='eval')
# eval_args.update(mem_len=2500)  # optionally override parameters

In [8]:
# Build the corpus with the data_utils helper and publish the vocabulary size
# on `args` as n_token.
corpus = get_lm_corpus(args.data, args.dataset)
ntokens = len(corpus.vocab)
args.n_token = ntokens

# Iterators are created on the CPU here, independent of args.cuda.
device = "cpu"

eval_batch_size = 10
# One iterator per split, created in train -> valid -> test order. Training
# uses the training batch size and tgt_len; validation and test share
# eval_batch_size and eval_tgt_len.
tr_iter, va_iter, te_iter = [
    corpus.get_iterator(split, bsz, length, device=device, ext_len=args.ext_len)
    for split, bsz, length in (
        ('train', args.batch_size, args.tgt_len),
        ('valid', eval_batch_size, args.eval_tgt_len),
        ('test', eval_batch_size, args.eval_tgt_len),
    )
]

# adaptive softmax / embedding
# Without args.adaptive the cutoff list stays empty and tie_projs is [False].
cutoffs = []
tie_projs = [False]
if args.adaptive:
    assert args.dataset in ['wt103', 'lm1b']
    if args.dataset == 'wt103':
        cutoffs = [20000, 40000, 200000]
        tie_projs.extend([True] * len(cutoffs))
    elif args.dataset == 'lm1b':
        cutoffs = [60000, 100000, 640000]
        tie_projs.extend([False] * len(cutoffs))

Loading cached dataset...


  corpus = torch.load(fn)


In [11]:
# Preview the first four training batches (indices 0 through 3).
# Each item unpacks to (data, target, seq_len):
#   data    - input tokens
#   target  - target tokens
#   seq_len - sequence length reported by the iterator for this batch
# With tgt_len=512, batch_size=22 and ext_len=0 both tensors print as
# [512, 22], i.e. the sequence dimension comes first and the batch second.
# After the loop, data/target/seq_len still hold the last batch seen.
for batch_idx, (data, target, seq_len) in enumerate(tr_iter):
    print(
        f"Batch {batch_idx}:",
        f"  Input data shape: {data.shape}",
        f"  Target shape: {target.shape}",
        f"  Sequence length: {seq_len}",
        sep="\n",
    )

    if batch_idx >= 3:
        break

Batch 0:
  Input data shape: torch.Size([512, 22])
  Target shape: torch.Size([512, 22])
  Sequence length: 512
Batch 1:
  Input data shape: torch.Size([512, 22])
  Target shape: torch.Size([512, 22])
  Sequence length: 512
Batch 2:
  Input data shape: torch.Size([512, 22])
  Target shape: torch.Size([512, 22])
  Sequence length: 512
Batch 3:
  Input data shape: torch.Size([512, 22])
  Target shape: torch.Size([512, 22])
  Sequence length: 512


In [12]:
# Show the input tokens of the last batch left over from the preview loop (batch 3).
data

tensor([[  7,   0,  20,  ...,  32,   7,   0],
        [  6,  12,  24,  ...,   0, 130,   2],
        [  3,   3,  15,  ...,   5,  90,  10],
        ...,
        [ 71,   0,   5,  ...,   0,  25,   0],
        [ 63,  15,  19,  ...,   3,   0,  26],
        [  0,  15,   0,  ...,   0,   5,   4]])

In [13]:
# Show the targets of that same batch. In the displayed values each row of
# `target` equals the next row of `data`, so the target looks like the input
# shifted by one position along the first dimension (verify in data_utils).
target

tensor([[  6,  12,  24,  ...,   0, 130,   2],
        [  3,   3,  15,  ...,   5,  90,  10],
        [ 16,   9,  15,  ...,  19,  69,   1],
        ...,
        [ 63,  15,  19,  ...,   3,   0,  26],
        [  0,  15,   0,  ...,   0,   5,   4],
        [ 19,  17,   2,  ...,   9,   7,   5]])