# mT5-small Fine-tuning

### **The model checkpoint can be downloaded from: [Google Drive](https://drive.google.com/file/d/1uuevTEvrMhLiW4cfijcMkpfz3IrdwgG-/view?usp=sharing)*

## Imports, Device Setting and Weight and Bias Display

In [1]:
! pip install transformers
! pip3 install wandb
! pip install sentencepiece

import wandb
import os
import torch
import re
from torch import cuda, nn, optim
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import TrainingArguments, Trainer, logging
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m54.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m106.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1
Looking in indexes: https://pypi.org/simple, https:/

In [2]:
from google.colab import drive
drive.mount('/content/gdrive')
manual_seed = 585
torch.manual_seed(manual_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

Mounted at /content/gdrive
cuda


In [3]:
wandb.login()
wandb.init(project="Zootopia", entity="qmygrace")


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mqmygrace[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Load the Pre-trained Model

In [4]:
# https://huggingface.co/google/mt5-small

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
model.to(device)

Downloading (…)okenizer_config.json:   0%|          | 0.00/82.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/553 [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/4.31M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]



Downloading pytorch_model.bin:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

MT5ForConditionalGeneration(
  (shared): Embedding(250112, 512)
  (encoder): MT5Stack(
    (embed_tokens): Embedding(250112, 512)
    (block): ModuleList(
      (0): MT5Block(
        (layer): ModuleList(
          (0): MT5LayerSelfAttention(
            (SelfAttention): MT5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): MT5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): MT5LayerFF(
            (DenseReluDense): MT5DenseGatedActDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
          

## Preprocess data

In [5]:
# path = '../data/'    # change the path as needed
path = '/content/gdrive/My Drive/585data/'

def read_data(file):
    with open (path+file) as t:
        data = t.readlines()
    return data

train_set = read_data('train_data.txt')[:20000]
dev_set = read_data('dev_data.txt')[:3000]
# test_set = read_data('test_data.txt')[:3000]

# type(train_set)
print(train_set[:2], '\n', len(train_set), len(dev_set))

['{"groundTruth": ["发扬光大", "平易近人", "温文尔雅"], "candidates": [["意气风发", "街谈巷议", "人才辈出", "一脉相传", "后继有人", "发扬光大", "腥风血雨"], ["平易近人", "落落大方", "八仙过海", "彬彬有礼", "史无前例", "盛气凌人", "好自为之"], ["不拘小节", "风流潇洒", "无病呻吟", "言谈举止", "壮志凌云", "关门闭户", "温文尔雅"]], "content": "由实力派演员刘威饰演的清华第三任校长蒋南翔，是我国著名的青年运动家和教育家，他跟清华终身校长梅贻琦一样，都是由清华人自己培养出来的校长。历史上的蒋南翔是著名的“一二九”学生救亡运动的领导人之一，他在清华校长之位14年期间，不但很好的继承了清华建校之初的优秀传统与理念，而且更加的#idiom#，他把清华的教师队伍扩大了将近5倍，将清华本科人数破万，为新中国培养了大量的有用人才。在《天行健》中饰演蒋南翔的刘威是观众所熟悉的著名实力派演员，早在1987年刘威就在《关东大侠》中饰演豪爽仗义的关云天一角而获得了金鸡奖最佳男主角的提名，后来更是因在《唐明皇》中精湛的表演而一举夺得金鹰奖最佳男演员奖。此次《天行健》选定刘威来出演正是看中了他#idiom#的表演方式和对人物深入内心的刻画。至此，《天行健》中涉及的三位清华校长的人选都已经曝光，#idiom#的第一任校长赵文?、稳重坚毅的第二任校长孙逊、亲切务实的第三任校长刘威，再加上梁思成、林徽因、朱自清、闻一多等一批“大师”的加盟，相信作为清华百年校庆重点项目之一的《天行健》一定会带领观众重温那段不能抹去的历史。", "realCount": 3}\n', '{"groundTruth": ["肥头大耳"], "candidates": [["超凡入圣", "骨瘦如柴", "青面獠牙", "虎背熊腰", "成人之美", "肥头大耳", "神不守舍"]], "content": "#idiom#的掌柜只穿一件衬衫，坐在柜台里。几个堂倌穿着脏得发黑的白工作服，因为没有顾客，都散坐在桌子旁。这当儿看到这位不寻常的客人，都露出好奇的神色列宁曾批评他理论上的错误，同时认为他“所写的全部哲学，赶紧迎上前来伺候。聂赫留朵夫要了一瓶矿泉水，在离窗较远的地方挨着一张

In [6]:
# preprocess_idx = -1
# def replace(match):
#     global preprocess_idx
#     preprocess_idx += 1
#     return 'extra {}'.format(preprocess_idx)

# text = '由实力派演员刘威饰演的清华第三任校长蒋南翔，是我国著名的青年运动家和教育家，他跟清华终身校长梅贻琦一样，都是由清华人自己培养出来的校长。历史上的蒋南翔是著名的“一二九”学生救亡运动的领导人之一，他在清华校长之位14年期间，不但很好的继承了清华建校之初的优秀传统与理念，而且更加的#idiom#，他把清华的教师队伍扩大了将近5倍，将清华本科人数破万，为新中国培养了大量的有用人才。在《天行健》中饰演蒋南翔的刘威是观众所熟悉的著名实力派演员，早在1987年刘威就在《关东大侠》中饰演豪爽仗义的关云天一角而获得了金鸡奖最佳男主角的提名，后来更是因在《唐明皇》中精湛的表演而一举夺得金鹰奖最佳男演员奖。此次《天行健》选定刘威来出演正是看中了他#idiom#的表演方式和对人物深入内心的刻画。至此，《天行健》中涉及的三位清华校长的人选都已经曝光，#idiom#的第一任校长赵文?、稳重坚毅的第二任校长孙逊、亲切务实的第三任校长刘威，再加上梁思成、林徽因、朱自清、闻一多等一批“大师”的加盟，相信作为清华百年校庆重点项目之一的《天行健》一定会带领观众重温那段不能抹去的历史。'
# re.sub(r'#idiom#', replace, text)

In [7]:
def preprocess(data):
    text_input = []
    idiom_output = []
    for i in range(len(data)):
        data[i] = eval(data[i])
        input_text = data[i]['content']
        ground_truth = data[i]['groundTruth']
        candidates = data[i]['candidates']

        candidate_str = ''
        for candidate in candidates:
            candidate_str += '('+'|'.join(candidate)+')'
        
        preprocess_idx = -1
        def replace(match):
            nonlocal preprocess_idx
            preprocess_idx += 1
            return 'extra{}'.format(preprocess_idx)
        input_text = re.sub(r'#idiom#', replace, input_text)

        instruction = '请从下列括号中分别选择合适的成语填入空缺处：{}'.format(candidate_str)
        # input_text = input_text.replace('#idiom#', '_')
        output_text = ','.join(ground_truth)
        
        text_input.append(instruction+'\n'+input_text)
        idiom_output.append(output_text)
    
    print(text_input[0], idiom_output[0])    
    input_tok = tokenizer.batch_encode_plus(text_input,
                                            add_special_tokens=False, 
                                            return_token_type_ids=False)
    output_tok = tokenizer.batch_encode_plus(idiom_output, 
                                             add_special_tokens=False,
                                             return_token_type_ids=False)
    return input_tok, output_tok

In [8]:
train_input, train_output = preprocess(train_set)
dev_input, dev_output = preprocess(dev_set)
# test_input, test_output = preprocess(test_set)

请从下列括号中分别选择合适的成语填入空缺处：(意气风发|街谈巷议|人才辈出|一脉相传|后继有人|发扬光大|腥风血雨)(平易近人|落落大方|八仙过海|彬彬有礼|史无前例|盛气凌人|好自为之)(不拘小节|风流潇洒|无病呻吟|言谈举止|壮志凌云|关门闭户|温文尔雅)
由实力派演员刘威饰演的清华第三任校长蒋南翔，是我国著名的青年运动家和教育家，他跟清华终身校长梅贻琦一样，都是由清华人自己培养出来的校长。历史上的蒋南翔是著名的“一二九”学生救亡运动的领导人之一，他在清华校长之位14年期间，不但很好的继承了清华建校之初的优秀传统与理念，而且更加的extra0，他把清华的教师队伍扩大了将近5倍，将清华本科人数破万，为新中国培养了大量的有用人才。在《天行健》中饰演蒋南翔的刘威是观众所熟悉的著名实力派演员，早在1987年刘威就在《关东大侠》中饰演豪爽仗义的关云天一角而获得了金鸡奖最佳男主角的提名，后来更是因在《唐明皇》中精湛的表演而一举夺得金鹰奖最佳男演员奖。此次《天行健》选定刘威来出演正是看中了他extra1的表演方式和对人物深入内心的刻画。至此，《天行健》中涉及的三位清华校长的人选都已经曝光，extra2的第一任校长赵文?、稳重坚毅的第二任校长孙逊、亲切务实的第三任校长刘威，再加上梁思成、林徽因、朱自清、闻一多等一批“大师”的加盟，相信作为清华百年校庆重点项目之一的《天行健》一定会带领观众重温那段不能抹去的历史。 发扬光大,平易近人,温文尔雅
请从下列括号中分别选择合适的成语填入空缺处：(深恶痛绝|人人自危|恨入骨髓|不胜枚举|嗤之以鼻|走马看花|不屑一顾)
另据了解，北京一个对垃圾短信extra0的老人，利用该软件总共呼死了近2000个号码。20分钟呼上万号码记者昨天在百度里输入“呼死你软件”，出现了7000多个相关网页，随机登录几个网站，发现软件均需花钱购买，价格从200元至500元不等。 深恶痛绝


In [9]:
print(train_input.keys(), train_output.keys())

dict_keys(['input_ids', 'attention_mask']) dict_keys(['input_ids', 'attention_mask'])


In [10]:
print(train_input['input_ids'][0], '\n', train_input['attention_mask'][0])

[259, 20256, 5229, 2446, 19783, 149527, 3688, 1223, 91486, 29133, 5072, 92396, 493, 4449, 20139, 82635, 4484, 8123, 47193, 13746, 267, 312, 10691, 11755, 8893, 5685, 409, 16628, 30014, 158795, 26896, 409, 47145, 130509, 2371, 409, 1374, 117723, 6497, 19946, 409, 3592, 88623, 64947, 409, 5685, 54029, 4491, 1146, 409, 239243, 8893, 18643, 16189, 4829, 5064, 18272, 8659, 1193, 409, 12265, 12265, 1146, 4222, 409, 7704, 39098, 6994, 4093, 409, 204126, 204126, 1637, 21158, 409, 22695, 5941, 2884, 13733, 409, 19510, 11755, 91708, 1193, 409, 3586, 5081, 2037, 2904, 4829, 1597, 104151, 2144, 16984, 409, 8893, 8041, 239151, 185028, 409, 5941, 14469, 242248, 172615, 409, 9812, 30014, 49425, 32401, 409, 78590, 16706, 91708, 9896, 409, 14428, 8394, 80855, 27841, 409, 17794, 4565, 8216, 24582, 271, 259, 10135, 77006, 19700, 138014, 18538, 18157, 73502, 20936, 493, 9060, 8423, 42258, 12307, 155410, 128057, 4938, 93781, 261, 1543, 44635, 237766, 44283, 61449, 3203, 1107, 9716, 3203, 261, 3763, 18475, 

In [11]:
print(train_output['input_ids'][0], '\n', train_output['attention_mask'][0])

[259, 5685, 54029, 4491, 1146, 261, 5064, 18272, 8659, 1193, 261, 17794, 4565, 8216, 24582] 
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [12]:
class IdiomDataset(Dataset):
    def __init__(self, inputs, outputs):
        self.inputs = inputs
        self.outputs = outputs

    def __len__(self):
        return len(self.inputs["input_ids"])

    def __getitem__(self, idx):
        input_ids = self.inputs['input_ids'][idx]
        attention_mask = self.inputs['attention_mask'][idx]

        target_ids = self.outputs['input_ids'][idx]
        target_attention_mask = self.outputs['attention_mask'][idx]
        return {"input_ids": input_ids, "attention_mask":attention_mask, "output_ids":target_ids}


def collate_fn(batch):
    batch_input = [torch.LongTensor(example['input_ids']) for example in batch]
    batch_output = [torch.LongTensor(example['output_ids']) for example in batch]
    batch_mask = [torch.LongTensor(example['attention_mask']) for example in batch]

    padded_batch_input_ids = pad_sequence(batch_input, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_batch_label = pad_sequence(batch_output, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_batch_att_mask = pad_sequence(batch_mask, batch_first=True, padding_value=0)

    return {"input_ids": padded_batch_input_ids, "attention_mask": padded_batch_att_mask, "labels": padded_batch_label}

def to_device(data, device):
    new_data = {}
    for k in data:
        # k = k.to(device)
        new_data[k] = data[k].to(device)
    return new_data

In [13]:
train_dataset = IdiomDataset(train_input, train_output)
train_loader = DataLoader(train_dataset, batch_size=8, collate_fn=collate_fn, shuffle=True)

dev_dataset = IdiomDataset(dev_input, dev_output)
dev_loader = DataLoader(dev_dataset, batch_size=8, collate_fn=collate_fn, shuffle=False)


## Training

In [14]:
@torch.no_grad()
def evaluate(model:nn.Module, eval_loader:DataLoader):
    eval_loss = 0.0
    correct = 0
    total = 0
    model.eval()
    print("eval_loader len:", len(eval_loader))
    for batch in eval_loader:
        batch = to_device(batch, device)
        output = model(**batch)
        loss = output.loss
        eval_loss += loss.item()
        pred = output.logits.argmax(-1)
        label = batch["labels"]
        correct += torch.where(label!=0, pred==label, 0).sum().item()
        total += torch.sum(label!=0).item()

    eval_acc = correct / total
    eval_loss = eval_loss / len(eval_loader) 
    print(total, correct)
    return eval_acc, eval_loss

In [15]:
epoches = 5       
optimizer = optim.Adam(model.parameters(), lr=5e-5)
model.to(device)

model.train()
for epoch in range(epoches):
    epoch_loss = 0.0
    log_loss = 0.0
    for idx, batch in enumerate(train_loader):
        model.zero_grad()
        batch = to_device(batch, device)
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        log_loss += loss.item()

        wandb.log({'batch':idx, 'train_loss': loss.item()})
        wandb.log({'batch':idx, 'accumulated_train_loss_in_this_1k_batches': log_loss})

        if idx % 100 == 0:
            print(f"Train Step: {idx} Loss: {log_loss / 100}")
            log_loss = 0.0
    print(f"Epoch: {epoch+1} Loss is: {epoch_loss}")
    eval_acc, eval_loss = evaluate(model, dev_loader)
    print(f"Epoch {epoch+1} Eval Acc: {eval_acc}; Eval Loss: {eval_loss}")

Train Step: 0 Loss: 0.3624109268188477
Train Step: 100 Loss: 27.885444202423095
Train Step: 200 Loss: 20.266760244369507
Train Step: 300 Loss: 16.646393766403197
Train Step: 400 Loss: 14.34015974998474
Train Step: 500 Loss: 12.613938312530518
Train Step: 600 Loss: 10.975036611557007
Train Step: 700 Loss: 10.338030376434325
Train Step: 800 Loss: 9.289741277694702
Train Step: 900 Loss: 8.410621366500855
Train Step: 1000 Loss: 7.64582968711853
Train Step: 1100 Loss: 6.7045935535430905
Train Step: 1200 Loss: 5.868394136428833
Train Step: 1300 Loss: 5.172049405574799
Train Step: 1400 Loss: 4.590031695365906
Train Step: 1500 Loss: 4.2740762996673585
Train Step: 1600 Loss: 3.9429748368263247
Train Step: 1700 Loss: 3.383105719089508
Train Step: 1800 Loss: 3.169616365432739
Train Step: 1900 Loss: 3.0035462272167206
Train Step: 2000 Loss: 2.489843978881836
Train Step: 2100 Loss: 2.2652867710590363
Train Step: 2200 Loss: 2.05966717004776
Train Step: 2300 Loss: 1.8886240816116333
Train Step: 2400 

In [16]:
torch.save(model.state_dict(), path+"mT5-small_model_5epoches.pt")

In [17]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 300,176,768 trainable parameters


## Evaluation

In [18]:
@torch.no_grad()
def fill_idiom(model, loader):

    all_preds = []
    all_labels = []
    model.eval()
    for batch in loader:
        batch = to_device(batch, device)
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        outputs = model.generate(input_ids=input_ids, 
                                 attention_mask=attention_mask, 
                                 return_dict_in_generate=True, 
                                 pad_token_id=tokenizer.pad_token_id, 
                                 max_length=512, 
                                 top_k=15)
        truncated_outputs = []

        decode_texts = tokenizer.batch_decode([l[l != 0] for l in outputs['sequences']])
        gold_texts = tokenizer.batch_decode([l[l != 0] for l in labels])
        # print(decode_texts, gold_texts)
        for gold, decode in zip(gold_texts, decode_texts):
            l = set(gold.replace(' ', '').replace('[CLS]', '').split(','))
            p = set(decode.replace(' ', '').replace('[CLS]', '').split(','))
            # print(l, p)
            all_labels.append(l)
            all_preds.append(p)
        # print(decode_texts)
        # print(gold_texts)
        # break
    
    return all_preds, all_labels

def f1_score(sys, gold):
    tp = 0
    total = 0
    pos = 0
    for s, g in zip(sys, gold):
        total += len(g)
        pos += len(s)
        tp += len(g & s)
    precision = tp / pos if pos != 0 else 0
    recall = tp / total if total != 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
    return precision, recall, f1, tp

In [19]:
sys, gold = fill_idiom(model, dev_loader)
p, r, f1, tp = f1_score(sys, gold)

In [20]:
total = 0
for s, g in zip(sys, gold):
    total += len(g)

In [21]:
print(f"Accurate amount for Validation set is {tp} out of {total}")
print(f"Accuracy for Validation set is {tp/total}")
print(f"F1 score for Validation set is {f1}")

Accurate amount for Validation set is 1589 out of 3668
Accuracy for Validation set is 0.43320610687022904
F1 score for Validation set is 0.43367903930131


In [22]:
sys[:10]

[{'深恶痛绝'},
 {'井井有条'},
 {'跃跃欲试'},
 {'无与伦比'},
 {'一语道破', '不胜其烦', '评头品足'},
 {'千篇一律'},
 {'罪魁祸首'},
 {'聪明才智'},
 {'千载难逢'},
 {'苦中作乐'}]

In [23]:
gold[:10]

[{'深恶痛绝'},
 {'杂乱无章'},
 {'磨刀霍霍'},
 {'独一无二'},
 {'一语道破', '不厌其烦', '品头题足'},
 {'大同小异'},
 {'罪魁祸首'},
 {'聪明才智'},
 {'千载难逢'},
 {'酸甜苦辣'}]

In [24]:
with open ('mT5-small_outputs_final.txt', 'w', encoding='utf-8') as mt5:
    for s in sys:
        line = ','.join(s)
        mt5.write(str(line)+'\n')