In [1]:
import json
import os
from torch.utils.data import Dataset, DataLoader
import torch
from tqdm import tqdm
import random
import re

In [2]:
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# 数据预处理

In [4]:
tokenizer = T5Tokenizer.from_pretrained('./t5-base/')

In [5]:
# # 获取字典
# CV_topic_dic = {}
# with open("./LectureBank/CV/CV.topics.tsv",'r',encoding='utf-8') as f:
#     lines = f.readlines()
#     for line in lines:
#         line = line.strip().split("\t")
#         CV_topic_dic[line[0]] = line[1]
# CV_topic_dic

In [6]:
# # 获取字典
# BIO_topic_dic = {}
# with open("./LectureBank/BIO/BIO.topics.tsv",'r',encoding='utf-8') as f:
#     lines = f.readlines()
#     for line in lines:
#         line = line.strip().split("\t")
#         BIO_topic_dic[line[0]] = line[1]
# BIO_topic_dic

In [7]:
# with open("./wiki80/wiki80_train.txt",'r',encoding='utf-8') as f:
#         lines = f.readlines()
#         for line in lines:
#             data = json.loads(line)
#             h_entity_replace = randomLetter()
#             t_entity_replace = randomLetter()
#             plus_text = data['token'][0:data['h']['pos'][0]] + [h_entity_replace] +  data['token'][data['h']['pos'][1]:data['t']['pos'][0]] + [t_entity_replace] + data['token'][data['h']['pos'][1]:] 
#             plus_text = " ".join(plus_text)
#             plus_text = "extract relation: " + plus_text.lower() + " </s>"
#             rel_text = h_entity_replace.lower() + " - " +  data['relation'].lower() + ' - ' + t_entity_replace.lower() + " </s>"
#             print(plus_text)
#             print(rel_text)
#             break

In [8]:
def extract_data(filepath,tokenizer):
    origin_texts = []
    rel_texts = []
    max_input_len = 0
    max_output_len = 0
    with open(filepath,'r',encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip().split(",")
            lecture1 = line[0].replace("_"," ")
            lecture2 = line[1].replace("_"," ")
            text = "judge prerequisite: " + lecture1 + " " + lecture2 + " </s>"
            if line[2] == "0":
                rel = "is not prerequisite"
            elif line[2] == "1":
                rel = "is prerequisite"
            rel_text = lecture1 + " - " +  rel + ' - ' + lecture2 + " </s>"
            origin_texts.append(text)
            rel_texts.append(rel_text)
            tokenized_inp = tokenizer.encode_plus(text, return_tensors="pt")
            tokenized_output = tokenizer.encode_plus(rel_text,return_tensors="pt")
            input_ids  = tokenized_inp["input_ids"]
            output_ids  = tokenized_output["input_ids"]
            max_input_len = max(max_input_len, input_ids.shape[1])
            max_output_len = max(max_output_len, output_ids.shape[1])
            
            # 数据增强
            # 用两个随机字母替换两个实体
#             h_entity_replace = randomLetter()
#             t_entity_replace = randomLetter()
#             first_entity_pos = []
#             second_entity_pos = []
#             if data['h']['pos'][0] < data['t']['pos'][0]:
#                 first_entity_pos = data['h']['pos']
#                 second_entity_pos = data['t']['pos']
#             else:
#                 first_entity_pos = data['t']['pos']
#                 second_entity_pos = data['h']['pos']
#             plus_text = data['token'][0:first_entity_pos[0]] + [h_entity_replace] +  data['token'][first_entity_pos[1]:second_entity_pos[0]] + [t_entity_replace] + data['token'][second_entity_pos[1]:] 
#             plus_text = " ".join(plus_text)
#             plus_text = "extract relation: " + plus_text.lower() + " </s>"
#             rel_text = h_entity_replace.lower() + " - " +  data['relation'].lower() + ' - ' + t_entity_replace.lower() + " </s>"
#             print(plus_text)
#             print(rel_text)
#             origin_texts.append(plus_text)
#             rel_texts.append(rel_text)
    print(max_input_len)
    print(max_output_len)
    return origin_texts,rel_texts

In [9]:
def extract_val_data(filepath,tokenizer):
    origin_texts = []
    rel_texts = []
    max_input_len = 0
    max_output_len = 0
    with open(filepath,'r',encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip().split(",")
            lecture1 = line[0].replace("_"," ")
            lecture2 = line[1].replace("_"," ")
            text = "judge prerequisite: " + lecture1 + " " + lecture2 + " </s>"
            if line[2] == "0":
                rel = "is not prerequisite"
            elif line[2] == "1":
                rel = "is prerequisite"
            rel_text = lecture1 + " - " +  rel + ' - ' + lecture2
            origin_texts.append(text)
            rel_texts.append(rel_text)
    return origin_texts,rel_texts

In [10]:
class Extract_Dataset(Dataset):
    def __init__(self, filepaths, tokenizer,max_input_len,max_output_len):
        self.origin_texts, self.rel_texts = [],[]
        for filepath in filepaths:
            o,r = extract_data(filepath,tokenizer)
            self.origin_texts.extend(o)
            self.rel_texts.extend(r)
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        
    def __len__(self):
        return len(self.origin_texts)
    
    def __getitem__(self, index):
        tokenized_input = tokenizer.encode_plus(self.origin_texts[index], max_length=self.max_input_len, pad_to_max_length=True, return_tensors="pt")
        tokenized_output = tokenizer.encode_plus(self.rel_texts[index], max_length=self.max_output_len, pad_to_max_length=True, return_tensors="pt")
        
        input_ids  = tokenized_input["input_ids"].squeeze()
        attention_mask = tokenized_input["attention_mask"].squeeze()

        output_ids = tokenized_output["input_ids"].squeeze()
        decoder_attention_mask=  tokenized_output["attention_mask"].squeeze()
        
        data = {
            'input_ids':input_ids,
            'attention_mask':attention_mask,
            'output_ids':output_ids,
            'decoder_attention_mask':decoder_attention_mask
        }
        return data

# 训练

In [11]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t5_model = T5ForConditionalGeneration.from_pretrained('./t5-base/').to(DEVICE)

In [12]:
# optimizer
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in t5_model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
    {
        "params": [p for n, p in t5_model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-4)

In [13]:
def train(model, device, train_loader, optimizer):   # 训练模型
    model.train()
    with tqdm(total=len(train_loader)) as bar:
        for idx,data in enumerate(train_loader):
            input_ids, attention_mask, output_ids, decoder_attention_mask = data['input_ids'].to(device), data['attention_mask'].to(device), data['output_ids'].to(device), data['decoder_attention_mask'].to(device)
            output = model(input_ids=input_ids, labels=output_ids,decoder_attention_mask=decoder_attention_mask,attention_mask=attention_mask)
            loss = output[0]
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            bar.set_postfix(loss=loss.item())
            bar.update(1)

In [14]:
def val(model, tokenizer,device, filepaths):
    origin_texts, rel_texts = [], []
    for filepath in filepaths:
        o,r= extract_val_data(filepath,tokenizer)
        origin_texts.extend(o)
        rel_texts.extend(r)
    model.eval()
    TP = 0
    FP = 0
    FN = 0
    with tqdm(total=len(origin_texts)) as bar:
        for idx, text in enumerate(origin_texts):
            input_ids = tokenizer(text, return_tensors='pt').input_ids.to(DEVICE)
            outputs = t5_model.generate(input_ids)
            result = tokenizer.decode(outputs[0], skip_special_tokens=True).split(" - ")
            rel_text = rel_texts[idx].split(" - ")
#             print(rel_text)
#             print(result)
            if len(result) > 1 and result[1] == "is prerequisite":
#                 print(1)
                if result[1] == rel_text[1]:
                    TP += 1
                else:
                    FP += 1
            else:
                if result[1] != rel_text[1]:
                    FN += 1
            bar.update(1)
    print(TP)
    print(FP)
    print(FN)
    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    F1 = 2 * precision * recall / (precision + recall)

    return precision,recall,F1

In [15]:
# def val(model, tokenizer,device, filepath):
#     origin_texts, rel_texts = extract_val_data(filepath,tokenizer)
#     model.eval()
#     rel_acc_num = 0
#     total = 0
#     all_acc_num = 0
#     entity_acc_num = 0
#     with tqdm(total=len(origin_texts)) as bar:
#         for idx, text in enumerate(origin_texts):
#             total += 1
#             input_ids = tokenizer(text, return_tensors='pt').input_ids.to(DEVICE)
#             outputs = t5_model.generate(input_ids)
#             result = tokenizer.decode(outputs[0], skip_special_tokens=True).split(" - ")
#             rel_text = rel_texts[idx].split(" - ")
# #             print(rel_text)
# #             print(result)
#             if len(result) > 1 and rel_text[1] == result[1]:
#                 rel_acc_num += 1
#             if len(result) > 2 and rel_text[0] == result[0] and rel_text[2] == result[2]:
#                 entity_acc_num += 1
#             if len(result) > 2 and rel_text[1] == result[1] and rel_text[0] == result[0] and rel_text[2] == result[2]:
#                 all_acc_num += 1
#             bar.update(1)
#     return str(rel_acc_num / total),str(entity_acc_num / total),str(all_acc_num / total)

In [16]:
#filepaths = ["./UniversityCourse/train_0.txt","./UniversityCourse/train_1.txt","./UniversityCourse/train_2.txt","./UniversityCourse/train_3.txt","./UniversityCourse/train_4.txt"]
filepaths = ["./UniversityCourse/train_0.txt","./UniversityCourse/train_1.txt","./UniversityCourse/train_2.txt"]
train_dataset = Extract_Dataset(filepaths=filepaths,tokenizer=tokenizer,max_input_len=40,max_output_len=40)
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=32)



24
28
24
28
24
28


In [17]:
NUM_EPOCHES = 20
with open("./log(universityCourse60%).txt",'w',encoding='utf-8') as f:
    for epoch in range(NUM_EPOCHES):
        train(t5_model,DEVICE,train_loader,optimizer)
        precision,recall,F1 = val(t5_model,tokenizer,DEVICE,["./UniversityCourse/test_0.txt","./UniversityCourse/test_1.txt","./UniversityCourse/test_2.txt","./UniversityCourse/test_3.txt","./UniversityCourse/test_4.txt"])
        f.write("EPOCH:" + str(epoch) + "\n")
        f.write("precision：" + str(precision) + "\n")
        f.write("recall：" + str(recall) + "\n")
        f.write("F1：" + str(F1) + "\n")

  0%|          | 0/171 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|██████████| 171/171 [00:26<00:00,  6.46it/s, loss=0.0241]
100%|██████████| 2255/2255 [06:55<00:00,  5.42it/s]
  1%|          | 1/171 [00:00<00:26,  6.32it/s, loss=0.0169]

788
454
217


100%|██████████| 171/171 [00:26<00:00,  6.57it/s, loss=0.00362]
100%|██████████| 2255/2255 [06:53<00:00,  5.45it/s]
  1%|          | 1/171 [00:00<00:26,  6.37it/s, loss=0.0149]

943
569
62


100%|██████████| 171/171 [00:25<00:00,  6.59it/s, loss=0.0179] 
100%|██████████| 2255/2255 [06:48<00:00,  5.51it/s]
  1%|          | 1/171 [00:00<00:27,  6.28it/s, loss=0.0123]

976
543
29


100%|██████████| 171/171 [00:26<00:00,  6.56it/s, loss=0.00932]
100%|██████████| 2255/2255 [06:54<00:00,  5.43it/s]
  1%|          | 1/171 [00:00<00:27,  6.28it/s, loss=0.00577]

986
470
19


100%|██████████| 171/171 [00:25<00:00,  6.60it/s, loss=0.0089] 
100%|██████████| 2255/2255 [07:00<00:00,  5.36it/s]
  1%|          | 1/171 [00:00<00:26,  6.37it/s, loss=0.00524]

923
206
82


100%|██████████| 171/171 [00:26<00:00,  6.54it/s, loss=0.00129]
100%|██████████| 2255/2255 [06:53<00:00,  5.46it/s]
  1%|          | 1/171 [00:00<00:26,  6.38it/s, loss=0.00403]

983
269
22


100%|██████████| 171/171 [00:26<00:00,  6.56it/s, loss=0.00024]
100%|██████████| 2255/2255 [06:53<00:00,  5.45it/s]
  1%|          | 1/171 [00:00<00:26,  6.34it/s, loss=0.00591]

1000
335
5


100%|██████████| 171/171 [00:25<00:00,  6.58it/s, loss=0.00051]
100%|██████████| 2255/2255 [06:57<00:00,  5.40it/s]
  1%|          | 1/171 [00:00<00:26,  6.35it/s, loss=0.00111]

1002
304
3


100%|██████████| 171/171 [00:25<00:00,  6.59it/s, loss=0.000212]
100%|██████████| 2255/2255 [06:53<00:00,  5.45it/s]
  1%|          | 1/171 [00:00<00:26,  6.31it/s, loss=0.00165]

1002
227
3


100%|██████████| 171/171 [00:26<00:00,  6.57it/s, loss=2.98e-5] 
100%|██████████| 2255/2255 [06:54<00:00,  5.44it/s]
  1%|          | 1/171 [00:00<00:26,  6.36it/s, loss=0.00609]

1003
258
2


100%|██████████| 171/171 [00:25<00:00,  6.60it/s, loss=7.62e-5] 
100%|██████████| 2255/2255 [06:53<00:00,  5.46it/s]
  1%|          | 1/171 [00:00<00:26,  6.44it/s, loss=0.000674]

1005
237
0


100%|██████████| 171/171 [00:25<00:00,  6.60it/s, loss=0.000156]
100%|██████████| 2255/2255 [06:57<00:00,  5.41it/s]
  1%|          | 1/171 [00:00<00:26,  6.34it/s, loss=0.000961]

1001
211
4


100%|██████████| 171/171 [00:25<00:00,  6.60it/s, loss=5.05e-5] 
100%|██████████| 2255/2255 [06:54<00:00,  5.45it/s]
  1%|          | 1/171 [00:00<00:26,  6.38it/s, loss=0.000551]

1002
189
3


100%|██████████| 171/171 [00:26<00:00,  6.57it/s, loss=0.000456]
100%|██████████| 2255/2255 [06:55<00:00,  5.43it/s]
  1%|          | 1/171 [00:00<00:27,  6.27it/s, loss=0.000207]

1004
230
1


100%|██████████| 171/171 [00:25<00:00,  6.58it/s, loss=0.0823]  
100%|██████████| 2255/2255 [06:52<00:00,  5.47it/s]
  1%|          | 1/171 [00:00<00:26,  6.38it/s, loss=0.00169]

1000
164
5


100%|██████████| 171/171 [00:25<00:00,  6.59it/s, loss=0.000198]
100%|██████████| 2255/2255 [06:53<00:00,  5.46it/s]
  1%|          | 1/171 [00:00<00:26,  6.33it/s, loss=0.000133]

1004
202
1


100%|██████████| 171/171 [00:25<00:00,  6.60it/s, loss=0.000139]
100%|██████████| 2255/2255 [06:52<00:00,  5.46it/s]
  1%|          | 1/171 [00:00<00:26,  6.34it/s, loss=0.000118]

1002
166
3


100%|██████████| 171/171 [00:26<00:00,  6.57it/s, loss=0.00243] 
100%|██████████| 2255/2255 [06:55<00:00,  5.43it/s]
  1%|          | 1/171 [00:00<00:26,  6.39it/s, loss=0.00125]

1000
155
5


100%|██████████| 171/171 [00:25<00:00,  6.70it/s, loss=0.00641] 
100%|██████████| 2255/2255 [06:54<00:00,  5.44it/s]
  1%|          | 1/171 [00:00<00:26,  6.37it/s, loss=0.00039]

1003
192
2


100%|██████████| 171/171 [00:25<00:00,  6.58it/s, loss=0.000125]
100%|██████████| 2255/2255 [06:54<00:00,  5.44it/s]

1004
277
1





In [18]:
# t =  " ".join(["Flight", "3400", "was", "bound", "for", "Moline", ",", "Il", ".", ",", "when", "it", "was", "diverted", "about", "9", "p.m.", "to", "Majors", "Airport", "in", "Greenville", "."])
# t = "extract relation: " + t.lower() + " </s>"
# print(t)

In [19]:
# input_ids = tokenizer(t, return_tensors='pt').input_ids.to(DEVICE)
# outputs = t5_model.generate(input_ids)
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [20]:
# result = tokenizer.decode(outputs[0], skip_special_tokens=True)
# result.split(" - ")

In [21]:
# print(val(t5_model,tokenizer,DEVICE,"./semeval/semeval_val.txt"))

In [22]:
# acc_rel,acc_entity,acc_all = val(t5_model,tokenizer,DEVICE,"./LectureBank/val.0.csv")

In [23]:
# acc_rel

In [24]:
# acc_all