In [1]:
import json
import os
from torch.utils.data import Dataset, DataLoader
import torch
from tqdm import tqdm
import random
import re

In [2]:
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# 数据预处理

In [4]:
tokenizer = T5Tokenizer.from_pretrained('./t5-base/')

In [5]:
# with open("./wiki80/wiki80_train.txt",'r',encoding='utf-8') as f:
#         lines = f.readlines()
#         for line in lines:
#             data = json.loads(line)
#             h_entity_replace = randomLetter()
#             t_entity_replace = randomLetter()
#             plus_text = data['token'][0:data['h']['pos'][0]] + [h_entity_replace] +  data['token'][data['h']['pos'][1]:data['t']['pos'][0]] + [t_entity_replace] + data['token'][data['h']['pos'][1]:] 
#             plus_text = " ".join(plus_text)
#             plus_text = "extract relation: " + plus_text.lower() + " </s>"
#             rel_text = h_entity_replace.lower() + " - " +  data['relation'].lower() + ' - ' + t_entity_replace.lower() + " </s>"
#             print(plus_text)
#             print(rel_text)
#             break

In [6]:
def extract_data(filepath,tokenizer):
    origin_texts = []
    rel_texts = []
    max_input_len = 0
    max_output_len = 0
    not_pre = 0
    pre = 0
    with open(filepath,'r',encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip().split("\t")
            lecture1 = line[0]
            lecture2 = line[2]
            text = "judge prerequisite: " + lecture1 + " " + lecture2 + " </s>"
            if line[4] == "-":
                rel = "is not prerequisite"
                not_pre += 1
            elif line[4] == "1-":
                rel = "is prerequisite"
                pre +=1
            elif line[4] == "-1":
                rel = "is prerequisite"
                lecture1,lecture2 = lecture2,lecture1
                pre +=1
            else:
                print(line)
                print("!")
            rel_text = lecture1 + " - " +  rel + ' - ' + lecture2 + " </s>"
            origin_texts.append(text)
            rel_texts.append(rel_text)
            tokenized_inp = tokenizer.encode_plus(text, return_tensors="pt")
            tokenized_output = tokenizer.encode_plus(rel_text,return_tensors="pt")
            input_ids  = tokenized_inp["input_ids"]
            output_ids  = tokenized_output["input_ids"]
            max_input_len = max(max_input_len, input_ids.shape[1])
            max_output_len = max(max_output_len, output_ids.shape[1])
            
            # 数据增强
            # 用两个随机字母替换两个实体
#             h_entity_replace = randomLetter()
#             t_entity_replace = randomLetter()
#             first_entity_pos = []
#             second_entity_pos = []
#             if data['h']['pos'][0] < data['t']['pos'][0]:
#                 first_entity_pos = data['h']['pos']
#                 second_entity_pos = data['t']['pos']
#             else:
#                 first_entity_pos = data['t']['pos']
#                 second_entity_pos = data['h']['pos']
#             plus_text = data['token'][0:first_entity_pos[0]] + [h_entity_replace] +  data['token'][first_entity_pos[1]:second_entity_pos[0]] + [t_entity_replace] + data['token'][second_entity_pos[1]:] 
#             plus_text = " ".join(plus_text)
#             plus_text = "extract relation: " + plus_text.lower() + " </s>"
#             rel_text = h_entity_replace.lower() + " - " +  data['relation'].lower() + ' - ' + t_entity_replace.lower() + " </s>"
#             print(plus_text)
#             print(rel_text)
#             origin_texts.append(plus_text)
#             rel_texts.append(rel_text)
#     print(max_input_len)
#     print(max_output_len)
    print(pre)
    print(not_pre)
    return origin_texts,rel_texts

In [7]:
def extract_val_data(filepath,tokenizer):
    origin_texts = []
    rel_texts = []
    max_input_len = 0
    max_output_len = 0
    with open(filepath,'r',encoding='utf-8') as f:
        lines = f.readlines()
        rel = ""
        for line in lines:
            line = line.strip().split("\t")
            lecture1 = line[0]
            lecture2 = line[2]
            text = "judge prerequisite: " + lecture1 + " " + lecture2 + " </s>"
            if line[4] == "-":
                rel = "is not prerequisite"
            elif line[4] == "-1" or line[4] == "1-":
                rel = "is prerequisite"
            else:
                print(line)
                print("!")
            rel_text = lecture1 + " - " +  rel + ' - ' + lecture2
            origin_texts.append(text)
            rel_texts.append(rel_text)
    return origin_texts,rel_texts

In [8]:
class Extract_Dataset(Dataset):
    def __init__(self, filepaths, tokenizer,max_input_len,max_output_len):
        self.origin_texts, self.rel_texts = [],[]
        for filepath in filepaths:
            o,r = extract_data(filepath,tokenizer)
            self.origin_texts.extend(o)
            self.rel_texts.extend(r)
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        
    def __len__(self):
        return len(self.origin_texts)
    
    def __getitem__(self, index):
        tokenized_input = tokenizer.encode_plus(self.origin_texts[index], max_length=self.max_input_len, pad_to_max_length=True, return_tensors="pt")
        tokenized_output = tokenizer.encode_plus(self.rel_texts[index], max_length=self.max_output_len, pad_to_max_length=True, return_tensors="pt")
        
        input_ids  = tokenized_input["input_ids"].squeeze()
        attention_mask = tokenized_input["attention_mask"].squeeze()

        output_ids = tokenized_output["input_ids"].squeeze()
        decoder_attention_mask=  tokenized_output["attention_mask"].squeeze()
        
        data = {
            'input_ids':input_ids,
            'attention_mask':attention_mask,
            'output_ids':output_ids,
            'decoder_attention_mask':decoder_attention_mask
        }
        return data

# 训练

In [9]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t5_model = T5ForConditionalGeneration.from_pretrained('./t5-base/').to(DEVICE)

In [10]:
# optimizer
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in t5_model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
    {
        "params": [p for n, p in t5_model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-4)

In [11]:
def train(model, device, train_loader, optimizer):   # 训练模型
    model.train()
    with tqdm(total=len(train_loader)) as bar:
        for idx,data in enumerate(train_loader):
            input_ids, attention_mask, output_ids, decoder_attention_mask = data['input_ids'].to(device), data['attention_mask'].to(device), data['output_ids'].to(device), data['decoder_attention_mask'].to(device)
            output = model(input_ids=input_ids, labels=output_ids,decoder_attention_mask=decoder_attention_mask,attention_mask=attention_mask)
            loss = output[0]
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            bar.set_postfix(loss=loss.item())
            bar.update(1)

In [12]:
def val(model, tokenizer,device, filepath):
    origin_texts, rel_texts = extract_val_data(filepath,tokenizer)
    model.eval()
    TP = 0
    FP = 0
    FN = 0
    with tqdm(total=len(origin_texts)) as bar:
        for idx, text in enumerate(origin_texts):
            input_ids = tokenizer(text, return_tensors='pt').input_ids.to(DEVICE)
            outputs = t5_model.generate(input_ids)
            result = tokenizer.decode(outputs[0], skip_special_tokens=True).split(" - ")
            rel_text = rel_texts[idx].split(" - ")
#             print(rel_text)
#             print(result)
            if result[1] == "is prerequisite":
                # print(1)
                if result[1] == rel_text[1]:
                    TP += 1
                else:
                    FP += 1
            else:
                if result[1] != rel_text[1]:
                    FN += 1
            bar.update(1)
    print(TP)
    print(FP)
    print(FN)
    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    F1 = 2 * precision * recall / (precision + recall)

    return precision,recall,F1

In [13]:
filepaths = ["./Mooc/ML/ML_LabeledFile"]
train_dataset = Extract_Dataset(filepaths=filepaths,tokenizer=tokenizer,max_input_len=40,max_output_len=40)
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=32)



1735
4977


In [14]:
NUM_EPOCHES = 20
with open("./log(mooc).txt",'w',encoding='utf-8') as f:
    for epoch in range(NUM_EPOCHES):
        train(t5_model,DEVICE,train_loader,optimizer)
        precision,recall,F1 = val(t5_model,tokenizer,DEVICE,"./Mooc/ML/W-ML_LabeledFile")
        f.write("EPOCH:" + str(epoch) + "\n")
        f.write("precision：" + str(precision) + "\n")
        f.write("recall：" + str(recall) + "\n")
        f.write("F1：" + str(F1) + "\n")

  0%|          | 0/210 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|██████████| 210/210 [00:32<00:00,  6.48it/s, loss=0.0203]
100%|██████████| 935/935 [02:42<00:00,  5.77it/s]
  0%|          | 1/210 [00:00<00:33,  6.32it/s, loss=0.0163]

60
29
216


100%|██████████| 210/210 [00:32<00:00,  6.56it/s, loss=0.0242] 
100%|██████████| 935/935 [02:40<00:00,  5.82it/s]
  0%|          | 1/210 [00:00<00:33,  6.29it/s, loss=0.0171]

79
36
197


100%|██████████| 210/210 [00:32<00:00,  6.54it/s, loss=0.0168] 
100%|██████████| 935/935 [02:40<00:00,  5.82it/s]
  0%|          | 1/210 [00:00<00:32,  6.36it/s, loss=0.0145]

168
65
108


100%|██████████| 210/210 [00:31<00:00,  6.57it/s, loss=0.0145] 
100%|██████████| 935/935 [02:41<00:00,  5.80it/s]
  0%|          | 1/210 [00:00<00:32,  6.34it/s, loss=0.0119]

176
49
100


100%|██████████| 210/210 [00:31<00:00,  6.57it/s, loss=0.0105] 
100%|██████████| 935/935 [02:40<00:00,  5.82it/s]
  0%|          | 1/210 [00:00<00:33,  6.25it/s, loss=0.01]

163
35
113


100%|██████████| 210/210 [00:32<00:00,  6.51it/s, loss=0.00967]
100%|██████████| 935/935 [02:39<00:00,  5.85it/s]
  0%|          | 1/210 [00:00<00:32,  6.36it/s, loss=0.00821]

204
28
72


100%|██████████| 210/210 [00:31<00:00,  6.57it/s, loss=0.0124] 
100%|██████████| 935/935 [02:39<00:00,  5.87it/s]
  0%|          | 1/210 [00:00<00:32,  6.37it/s, loss=0.0089]

205
18
71


100%|██████████| 210/210 [00:32<00:00,  6.56it/s, loss=0.00415]
100%|██████████| 935/935 [02:39<00:00,  5.85it/s]
  0%|          | 1/210 [00:00<00:33,  6.32it/s, loss=0.00505]

217
27
59


100%|██████████| 210/210 [00:32<00:00,  6.55it/s, loss=0.0211] 
100%|██████████| 935/935 [02:38<00:00,  5.89it/s]
  0%|          | 1/210 [00:00<00:33,  6.31it/s, loss=0.0112]

261
56
15


100%|██████████| 210/210 [00:31<00:00,  6.57it/s, loss=0.00227]
100%|██████████| 935/935 [02:40<00:00,  5.84it/s]
  0%|          | 1/210 [00:00<00:31,  6.56it/s, loss=0.00172]

264
36
12


100%|██████████| 210/210 [00:32<00:00,  6.54it/s, loss=0.0021] 
100%|██████████| 935/935 [02:39<00:00,  5.88it/s]
  0%|          | 1/210 [00:00<00:33,  6.31it/s, loss=0.00269]

256
18
20


100%|██████████| 210/210 [00:32<00:00,  6.53it/s, loss=0.00475] 
 57%|█████▋    | 535/935 [01:30<00:59,  6.73it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [15]:
# t =  " ".join(["Flight", "3400", "was", "bound", "for", "Moline", ",", "Il", ".", ",", "when", "it", "was", "diverted", "about", "9", "p.m.", "to", "Majors", "Airport", "in", "Greenville", "."])
# t = "extract relation: " + t.lower() + " </s>"
# print(t)

In [16]:
# input_ids = tokenizer(t, return_tensors='pt').input_ids.to(DEVICE)
# outputs = t5_model.generate(input_ids)
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [17]:
# result = tokenizer.decode(outputs[0], skip_special_tokens=True)
# result.split(" - ")

In [18]:
# print(val(t5_model,tokenizer,DEVICE,"./semeval/semeval_val.txt"))

In [19]:
# acc_rel,acc_entity,acc_all = val(t5_model,tokenizer,DEVICE,"./LectureBank/val.0.csv")

In [20]:
# acc_rel

In [21]:
# acc_all