In [5]:
# 训练
VOCAB_DIR="generated/vocab"
VOCAB_FILE="songci.subword"


from train import training_loop
from preprocessing.preprocessing import generate_data
from model import ReformerLM
import trax
import numpy as np
from trax import layers as tl

In [6]:
train_stream, eval_stream, vocab_size = generate_data()

print("vocab size:", vocab_size)

# the stream generators will yield (input, target, weights). let's just grab the input for inspection
inp, _, _ = next(train_stream)

# print the shape. format is (batch size, token length)
print("input shape: ", inp.shape)

# detokenize the first element
print(trax.data.detokenize(inp[0], vocab_dir=VOCAB_DIR, vocab_file=VOCAB_FILE))


vocab size: 5943
input shape:  (8, 256)
题：摸鱼儿 作：张炎 文：又孤吟、灞桥深雪，千山绝尽飞鸟。梅花也著东风笑，一夜瘦添多少。春悄悄。正断梦愁诗，忘却池塘草。前村路杳。看野水流冰，舟闲渡口，何必见安道。慵登眺。脉脉霏霏未了。寒威犹自清峭。终须几日开晴去，无奈此时怀抱。空暗恼。料酒兴歌情，未肯随人老。惜花起早。拚醉□忘归，接B565更好，一笑任倾倒。


In [7]:
temp_model = ReformerLM('train')
print(str(temp_model))

# free memory
del temp_model 

Serial[
  Serial[
    Serial[
      ShiftRight(1)
    ]
    Embedding_train_512
    Dropout
    PositionalEncoding
    Dup_out2
    ReversibleSerial_in2_out2[
      ReversibleHalfResidual_in2_out2[
        Serial[
          LayerNorm
        ]
        SelfAttention
      ]
      ReversibleSwap_in2_out2
      ReversibleHalfResidual_in2_out2[
        Serial[
          LayerNorm
          Dense_2048
          Dropout
          Serial[
            FastGelu
          ]
          Dense_512
          Dropout
        ]
      ]
      ReversibleSwap_in2_out2
      ReversibleHalfResidual_in2_out2[
        Serial[
          LayerNorm
        ]
        SelfAttention
      ]
      ReversibleSwap_in2_out2
      ReversibleHalfResidual_in2_out2[
        Serial[
          LayerNorm
          Dense_2048
          Dropout
          Serial[
            FastGelu
          ]
          Dense_512
          Dropout
        ]
      ]
      ReversibleSwap_in2_out2
    ]
    Concatenate_in2
    LayerNorm
    Dropo

## 从头训练
直接从头开始训练

In [None]:
loop = training_loop()
loop.run(10000)

In [11]:
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)

def attention(*args, **kwargs):
    kwargs['predict_mem_len'] = 120  # max length for predictions
    kwargs['predict_drop_len'] = 120  # never drop old stuff
    return tl.SelfAttention(*args, **kwargs)

model = ReformerLM(
    vocab_size=vocab_size,
    n_layers=6,
    mode='predict',
    attention_type=attention,
)


# initialize from file
model.init_from_file('model/model.pkl.gz',
                     weights_only=True, input_signature=shape11)

# save the starting state
STARTING_STATE = model.state

In [12]:
def tokenize(sentence, vocab_file, vocab_dir):
    return list(trax.data.tokenize(iter([sentence]), vocab_file=vocab_file, vocab_dir=vocab_dir))[0]

def detokenize(tokens, vocab_file, vocab_dir):
    return trax.data.detokenize(tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)

In [13]:
def ReformerLM_output_gen(ReformerLM, start_sentence, vocab_file, vocab_dir, temperature):
    """
    Args:
        ReformerLM:  the Reformer language model you just trained
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        generator: yields the next symbol generated by the model
    """
    
    ### START CODE HERE (REPLACE INSTANCES OF 'None' WITH YOUR CODE) ###
    
    # Create input tokens using the the tokenize function
    input_tokens = tokenize(start_sentence, vocab_file, vocab_dir)
    
    # Add batch dimension to array. Convert from (n,) to (x, n) where 
    # x is the batch size. Default is 1. (hint: you can use np.expand_dims() with axis=0)
    input_tokens_with_batch = np.expand_dims(input_tokens, axis=0)
    
    # call the autoregressive_sample_stream function from trax
    output_gen = trax.supervised.decoding.autoregressive_sample_stream( 
        # model
        ReformerLM,
        # inputs will be the tokens with batch dimension
        inputs=input_tokens_with_batch,
        # temperature
        temperature=temperature
    )
    
    ### END CODE HERE ###
    
    return output_gen

In [14]:
def generate_dialogue(ReformerLM, model_state, start_sentence, vocab_file, vocab_dir, max_len, temperature):
    """
    Args:
        ReformerLM:  the Reformer language model you just trained
        model_state (np.array): initial state of the model before decoding
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        max_len (int): maximum number of tokens to generate 
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        generator: yields the next symbol generated by the model
    """  
    
    # define the delimiters we used during training
    delimiter_1 = '作：' 
    delimiter_2 = '文：'
    
    # initialize detokenized output
    sentence = ''
    
    # token counter
    counter = 0
    
    # output tokens. we insert a ': ' for formatting
    result = [tokenize(': ', vocab_file=vocab_file, vocab_dir=vocab_dir)]
    
    # reset the model state when starting a new dialogue
    ReformerLM.state = model_state
    
    # calls the output generator implemented earlier
    output = ReformerLM_output_gen(ReformerLM, start_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, temperature=temperature)
    
    # print the starting sentence
    print(start_sentence.split(delimiter_2)[0].strip())
    
    # loop below yields the next tokens until max_len is reached. the if-elif is just for prettifying the output.
    for o in output:
        
        result.append(o)
        
        sentence = detokenize(np.concatenate(result, axis=0), vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
        
        if sentence.endswith(delimiter_1):
            sentence = sentence.split(delimiter_1)[0]
            print(f'{delimiter_2}{sentence}')
            sentence = ''
            result.clear()
        
        elif sentence.endswith(delimiter_2):
            sentence = sentence.split(delimiter_2)[0]
            print(f'{delimiter_1}{sentence}')
            sentence = ''
            result.clear()

        counter += 1
        
        if counter > max_len:
            break    

In [None]:
sample_sentence = ' 题：'
generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)