Skip to content

Latest commit

 

History

History
23 lines (17 loc) · 700 Bytes

5. transformers使用.md

File metadata and controls

23 lines (17 loc) · 700 Bytes

Model

  • 源码

    • 内部调用model时,输出 = 隐藏层最后一层的输出
    • 将该输出传入输出层(全连接层),将输出拼接到原输出的最前面
    • 如果传入的数据有标签,则计算交叉熵,将结果再插入到output最前面
    image-20201125120115242
  • 使用

# 输出第一个为loss
loss, logits = model(
            b_input_ids,
            token_type_ids=None,
            attention_mask=b_input_mask,
            labels=b_labels)
# 输出第一个为outputs
logits = model(input_ids=input_ids)