## 模型准备

In [1]:
import os

# Directory of the fine-tuned checkpoint handed to `from_pretrained` below.
# Set the NLP01_MODEL_PATH environment variable to point at a different
# location; when it is unset or empty the original local path is used.
model_path = os.environ.get('NLP01_MODEL_PATH') or 'D:/model/web/nlp01/ckpt/checkpoint-3070/'

In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [3]:
# Load the tokenizer and the seq2seq model from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)



In [4]:
# Display the module structure of the loaded model.
model

MarianMTModel(
  (model): MarianModel(
    (shared): Embedding(65001, 512, padding_idx=65000)
    (encoder): MarianEncoder(
      (embed_tokens): Embedding(65001, 512, padding_idx=65000)
      (embed_positions): MarianSinusoidalPositionalEmbedding(512, 512)
      (layers): ModuleList(
        (0-5): 6 x MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): SiLU()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05

## 模型推理：手动调用 tokenizer 和 generate

In [5]:
# Sample Chinese sentence reused as the input for the translation examples below.
text = '您的胸痛伴有以下任何症状吗?'

In [6]:
# Encode the sample sentence into token ids, returned as a PyTorch tensor.
input_ids = tokenizer.encode(text, return_tensors='pt')

In [7]:
# Generate output token ids; no generation arguments are passed, so the
# model's default generation settings apply.
result = model.generate(input_ids)

In [8]:
# Inspect the raw token ids of the first generated sequence.
result[0]

tensor([65000,   147,    37,    53,   189,     4,     3,   414, 37051,    27,
          168, 31542, 10883,    23,     0])

In [9]:
# Turn the generated ids back into text; skip_special_tokens=True leaves
# special tokens (e.g. the leading padding id 65000) out of the string.
translated_text = tokenizer.decode(result[0], skip_special_tokens=True)

In [10]:
# Show the decoded translation.
translated_text

'do you have any of the following symptoms with your chest pain?'

## pipeline写法

In [11]:
from transformers import pipeline

In [12]:
# Wrap the already-loaded model and tokenizer in a Chinese-to-English
# translation pipeline that runs on the CPU.
zh2en = pipeline(
    'translation_zh_to_en',
    model=model,
    tokenizer=tokenizer,
    device='cpu',
)

In [13]:
# The pipeline returns a list with one dict per input, keyed by 'translation_text'.
print(zh2en(text))

[{'translation_text': 'do you have any of the following symptoms with your chest pain?'}]


In [14]:
# Pull just the translated string out of the first result.
print(zh2en(text)[0]['translation_text'])

do you have any of the following symptoms with your chest pain?


## 用gradio搭建前端

In [15]:
import gradio as gr

In [16]:
def translate(input):
    """Translate Chinese text to English with the ``zh2en`` pipeline.

    Args:
        input: The text to translate. The parameter name shadows the
            built-in ``input`` but is kept unchanged so existing callers
            keep working.

    Returns:
        The ``translation_text`` of the first pipeline result, or an empty
        string when ``input`` is ``None``, empty, or only whitespace.
    """
    # A blank submission has nothing to translate, so answer with an empty
    # string instead of running the model on it.
    if input is None or not input.strip():
        return ''
    return zh2en(input)[0]['translation_text']

In [17]:
# Build the web UI: one multi-line input box, one output box, wired to `translate`.
demo = gr.Interface(
    fn=translate,
    inputs=[gr.Textbox(label='输入你的问题', lines=6)],
    outputs=[gr.Textbox(label='翻译结果', lines=3)],
    title='使用MarianMT模型微调实现问诊问题中英翻译',
    description='''
    本模型可以实现中英翻译，在左边的框中输入你想要翻译的句子，右方会显示翻译的结果。\n
    该模型已在`tico-19`上微调，可以解决较多的医学/疫情相关的问题。\n
    微调数据集为`tico-19`，使用的模型checkpoint为`checkpoint-3070`。
    ''',
)

In [18]:
# Start the Gradio app on a local URL; per the startup message, pass
# share=True to launch() to get a public link.
demo.launch()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


