In [1]:
import os
# Point Hugging Face Hub traffic at the hf-mirror.com mirror.
# NOTE(review): presumably this has to run before transformers/datasets are
# imported in the next cell so the setting is picked up -- confirm.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

In [2]:
import torch
from transformers import MarianMTModel, MarianTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.nn.functional import log_softmax, softmax

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the pretrained checkpoint used as the distillation teacher and move it
# to the selected device.
model_name = 'Helsinki-NLP/opus-mt-en-mul'
teacher_model = MarianMTModel.from_pretrained(model_name)
teacher_model.to(device)
# eval() puts the module in inference mode (e.g. dropout off); it does NOT
# disable gradient tracking by itself.
teacher_model.eval()

# Tokenizer loaded from the same checkpoint as the teacher.
tokenizer = MarianTokenizer.from_pretrained(model_name)



In [4]:
def read_text_file(file_path):
    """Read a UTF-8 text file into a list of whitespace-stripped lines.

    Args:
        file_path: Path of the file to read.

    Returns:
        One string per line of the file, in file order, with leading and
        trailing whitespace (including the newline) removed. Blank lines are
        kept as empty strings.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        stripped = []
        for raw_line in handle:
            stripped.append(raw_line.strip())
    return stripped

# Assumed locations of the parallel corpus files (Ukrainian side and English side).
uk_file_path = './data/en-uk/NLLB.en-uk.uk'
en_file_path = './data/en-uk/NLLB.en-uk.en'

# Read both files fully into memory, one list entry per line.
uk_sentences = read_text_file(uk_file_path)
en_sentences = read_text_file(en_file_path)

# Sanity check: the two files must contain the same number of lines.
assert len(uk_sentences) == len(en_sentences), "The number of sentences must be the same in both files."

In [5]:
from torch.utils.data import Dataset, DataLoader

class TranslationDataset(Dataset):
    """Parallel-sentence dataset producing fixed-length tensors for seq2seq training.

    Every item is a dict of 1-D tensors of length ``max_length``:

    * ``input_ids`` / ``attention_mask``: the tokenized *source* sentence only
      (the target is never fed to the encoder).
    * ``labels``: the tokenized target sentence, with padding positions set to
      -100 so that ``cross_entropy`` (default ``ignore_index=-100``) skips them.
    * ``decoder_input_ids``: the target shifted one position to the right and
      started with the pad token, so position ``t`` of the decoder input is
      used to predict position ``t`` of ``labels`` (teacher forcing).

    Args:
        src_sentences: Source-language sentences.
        tgt_sentences: Target-language sentences, aligned with ``src_sentences``.
        tokenizer: A Hugging Face tokenizer that supports ``text_target=`` and
            exposes ``pad_token_id`` (e.g. ``MarianTokenizer``).
        max_length: Length every sequence is padded / truncated to.
        lang_token: Target-language code prepended to each source sentence.
            Pass ``None`` or ``""`` to add no prefix.
    """

    def __init__(self, src_sentences, tgt_sentences, tokenizer, max_length=512, lang_token=">>ukr<<"):
        self.tokenizer = tokenizer
        # The language code must be the very first characters of the text:
        # the Marian tokenizer only picks up a ">>xxx<<" code at the start of
        # the string, so no leading space may precede it.
        prefix = f"{lang_token} " if lang_token else ""
        self.src_sentences = [prefix + sent for sent in src_sentences]
        self.tgt_sentences = tgt_sentences
        self.max_length = max_length

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        src_text = self.src_sentences[idx]
        tgt_text = self.tgt_sentences[idx]

        # Encode the source sentence on its own. Passing the target as
        # `text_pair` would append it to the encoder input and leak the answer.
        source = self.tokenizer(
            text=src_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )

        # Encode the target in target mode so the target-side vocabulary /
        # sentencepiece model is used.
        target = self.tokenizer(
            text_target=tgt_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )

        # Index [0] drops the batch dimension added by return_tensors="pt".
        target_ids = target["input_ids"][0]
        pad_id = self.tokenizer.pad_token_id

        # Shift right for teacher forcing. Marian starts decoding from the pad
        # token, so it is used as the first decoder input.
        # NOTE(review): assumes decoder_start_token_id == pad_token_id, as in
        # Marian checkpoints -- confirm if another model family is used.
        start = target_ids.new_full((1,), pad_id)
        decoder_input_ids = torch.cat([start, target_ids[:-1]])

        # Padding must not contribute to the loss.
        labels = target_ids.masked_fill(target_ids == pad_id, -100)

        return {
            "input_ids": source["input_ids"][0],
            "attention_mask": source["attention_mask"][0],
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }

# Create the dataset: English sentences are the source, Ukrainian sentences the target.
train_dataset = TranslationDataset(en_sentences, uk_sentences, tokenizer)

In [6]:
# Build the student directly from the teacher's config object: it gets the
# teacher's architecture but is constructed from the config rather than loaded
# with from_pretrained().
# NOTE(review): student_config is the *same object* as teacher_model.config, so
# the student is exactly as large as the teacher and any later edit to
# student_config also changes the teacher's config -- confirm this is intended.
student_config = teacher_model.config
student_model = MarianMTModel(student_config)
student_model.to(device)

MarianMTModel(
  (model): MarianModel(
    (shared): Embedding(64110, 512, padding_idx=64109)
    (encoder): MarianEncoder(
      (embed_tokens): Embedding(64110, 512, padding_idx=64109)
      (embed_positions): MarianSinusoidalPositionalEmbedding(512, 512)
      (layers): ModuleList(
        (0-5): 6 x MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): SiLU()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05

In [7]:
# Hyperparameters of the distillation run.
learning_rate = 0.001
batch_size = 4
num_epochs = 3
# Softmax temperature applied to student and teacher logits in calculate_loss.
temperature = 5
# Mixing weight: (1 - alpha) scales the cross-entropy term, alpha the distillation term.
alpha = 0.5

# Build the DataLoader from the batch_size hyperparameter above instead of a
# separate hard-coded value, so the configured and the effective batch size
# cannot drift apart.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=learning_rate)

def calculate_loss(student_outputs, teacher_outputs, labels, ignore_index=-100):
    """Combine hard-label cross-entropy with a soft-label distillation loss.

    The result is ``(1 - alpha) * CE + alpha * temperature**2 * KL`` where CE is
    the cross-entropy of the student against ``labels`` and KL is the
    per-token mean KL divergence between the temperature-softened teacher and
    student distributions. ``alpha`` and ``temperature`` are read from the
    notebook's global hyperparameters.

    Args:
        student_outputs: Object with a ``.logits`` tensor whose last dimension
            is the vocabulary.
        teacher_outputs: Object with a ``.logits`` tensor of the same shape.
        labels: Integer tensor with one entry per logit position.
        ignore_index: Label value marking positions (e.g. padding) that are
            excluded from both loss terms.

    Returns:
        A scalar loss tensor. If every position is ignored, a zero that is
        still connected to the student's graph is returned instead of NaN.
    """
    s_logits = student_outputs.logits
    # The teacher is a fixed target: never backpropagate into it.
    t_logits = teacher_outputs.logits.detach()

    vocab_size = s_logits.size(-1)
    flat_student = s_logits.reshape(-1, vocab_size)
    flat_teacher = t_logits.reshape(-1, vocab_size)
    flat_labels = labels.reshape(-1)

    # Drop ignored positions so they influence neither the CE nor the KL term.
    valid = flat_labels != ignore_index
    if not valid.any():
        return flat_student.sum() * 0.0
    flat_student = flat_student[valid]
    flat_teacher = flat_teacher[valid]
    flat_labels = flat_labels[valid]

    ce_loss = torch.nn.functional.cross_entropy(flat_student, flat_labels)

    student_log_probs = log_softmax(flat_student / temperature, dim=-1)
    teacher_probs = softmax(flat_teacher / temperature, dim=-1)
    # "batchmean" on the flattened tensor already averages over tokens, so no
    # further division by the batch size is applied.
    distill_loss = torch.nn.functional.kl_div(student_log_probs, teacher_probs, reduction="batchmean")

    # temperature**2 compensates for the 1/T**2 gradient scaling of the
    # softened distributions.
    loss = (1 - alpha) * ce_loss + alpha * temperature**2 * distill_loss
    return loss


In [None]:
# Distillation training loop: for each batch, compare the student's logits
# with the teacher's and update only the student.
for epoch in range(num_epochs):
    student_model.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}"):
        optimizer.zero_grad()

        # Make sure every tensor of the batch is on the training device.
        batch = {k: v.to(device) for k, v in batch.items()}

        # Only the student's parameters are in the optimizer, so the teacher
        # needs no gradients: run it under no_grad to avoid building and
        # keeping an autograd graph for it.
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids=batch['input_ids'], attention_mask=batch.get('attention_mask'), decoder_input_ids=batch['decoder_input_ids'])
        student_outputs = student_model(input_ids=batch['input_ids'], attention_mask=batch.get('attention_mask'), decoder_input_ids=batch['decoder_input_ids'])

        # Compute the loss and backpropagate through the student.
        loss = calculate_loss(student_outputs, teacher_outputs, batch['labels'])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Average loss: {total_loss / len(train_loader)}")

# Persist the trained student weights and config.
student_model.save_pretrained("distilled-opus-mt-translation-model")


In [4]:
def preprocess(text):
    """Tokenize ``text`` with the notebook-level ``tokenizer``.

    Args:
        text: A string (or batch of strings) to encode.

    Returns:
        The tokenizer output as PyTorch tensors, padded to the longest
        sequence in the call and truncated to at most 512 tokens.
    """
    return tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

In [5]:
def translate(text):
    """Translate ``text`` with the notebook-level ``model`` and ``tokenizer``.

    NOTE(review): ``model`` is a global that is not defined in any cell visible
    in this notebook (only ``teacher_model`` / ``student_model`` are); it must
    be bound before this function is called -- confirm which model is meant.

    Args:
        text: A string (or batch of strings), including any ``>>xxx<<``
            target-language prefix the model expects.

    Returns:
        A list with one decoded translation string per input sequence.
    """
    # Tokenize, then move the tensors to wherever the model lives; generate()
    # fails when inputs and weights are on different devices.
    encoded_text = preprocess(text).to(model.device)

    # Generate the output token ids.
    translated_tokens = model.generate(**encoded_text)

    # Decode the ids back to text, dropping special tokens such as padding/EOS.
    translated_text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
    return translated_text


In [8]:
# Smoke test of translate(); the ">>zho<<" prefix is presumably the
# target-language code requesting Chinese output (the printed result is Chinese).
text = ">>zho<< This is a sample text for translation."
translated_text = translate(text)
print(translated_text)

['这是翻译的样本']


In [11]:
# Load the training split of the WMT16 German-English ("de-en") dataset.
# NOTE(review): this rebinds `train_dataset`, which earlier in the notebook held
# the en-uk TranslationDataset; a `train_loader` created before this cell still
# wraps the old object -- confirm the name reuse is intended.
train_dataset = load_dataset("wmt16", "de-en", split="train")

In [13]:
# Flatten the nested "translation" dict into two top-level columns:
# German text as `src_text`, English text as `tgt_text`.
train_dataset = train_dataset.map(lambda x: {"src_text": x["translation"]["de"], "tgt_text": x["translation"]["en"]})

Map:   0%|          | 0/4548885 [00:00<?, ? examples/s]

Map: 100%|██████████| 4548885/4548885 [05:14<00:00, 14462.40 examples/s]


In [14]:
# Display the dataset summary (feature names and row count) as the cell output.
train_dataset

Dataset({
    features: ['translation', 'src_text', 'tgt_text'],
    num_rows: 4548885
})