# You need to download the dataset first 
! git clone https://github.com/Yale-LILY/QMSum.git

In [1]:
import json, pdb
import numpy as np 
import torch
from transformers import BertModel, BertTokenizer
from nltk import word_tokenize

# define some global variables
# not a good style, but very convenient in notebook experiment
BERT_MODEL = None
NN_DEVICE = 'cuda:0'
data_root = 'QMSum/data/'


In [2]:
def tokenize(sent):
    tokens = ' '.join(word_tokenize(sent.lower()))
    return tokens

def clean_data(text):
    text = text.replace('{ vocalsound } ', '')
    text = text.replace('{ disfmarker } ', '')
    text = text.replace('a_m_i_', 'ami')
    text = text.replace('l_c_d_', 'lcd')
    text = text.replace('p_m_s', 'pms')
    text = text.replace('t_v_', 'tv')
    text = text.replace('{ pause } ', '')
    text = text.replace('{ nonvocalsound } ', '')
    text = text.replace('{ gap } ', '')
    return text

def encode(*args):
    return '<s> {} </s>'.format( ' </s> '.join(args))

def get_bert_embedding(text):
    global BERT_MODEL
    
    # initialize the model if first time run
    if BERT_MODEL is None:
        model_config = 'bert-base-uncased'
        tokenizer = BertTokenizer.from_pretrained(model_config)
        model = BertModel.from_pretrained(model_config, output_hidden_states=True)
        model.to(NN_DEVICE)
        model.eval()
        BERT_MODEL = (tokenizer, model)
        
    tokenizer, model = BERT_MODEL
    ids = tokenizer.encode(text)
    ids = torch.IntTensor(ids).unsqueeze(0).to(NN_DEVICE)
    with torch.no_grad():
        output = model(input_ids = ids)
    
    # extract the last 4 layer hidden state
    # and use that to form a sentence embedding
    hidden_states = output[2]
    features = torch.cat([hidden_states[-i] for i in range(4)] , dim=-1)
    features = features.squeeze().cpu().numpy()
    text_embedding = np.mean(features, axis = 0)
    return text_embedding


In [3]:
from collections import *

# similarity function for bert sentence embedding
def cosine_sim(x, y):
    nx = np.linalg.norm(x)
    ny = np.linalg.norm(y)
    return np.sum(x*y) / (nx * ny)

# similarity function for text span 
def iou_sim(x,y):
    a,b = x
    c,d = y
    union = max(b,d) - min(a,c)
    intersection = min(b,d) - max(a,c)
    intersection = max(0, intersection)
    return intersection / union
  
def get_related_query(query_feature, sim_func, train = False):
    '''
    args
    ---------
    query_feature - a list of bert_embedding/relavant_text_span of each query
    sim_func - a function used to compute the similarity between query features
    train - indicated whether is called for train data
    
    returns 
    ---------
    1. the order to evaluate the queries
    2. the related query of current query
    
    side note 
    ------------
    In practice, because we don't know all the question in advance, the order is 
    fixed, and what we can do is to find most relavant history queries.
    '''
    n = len(query_feature)
    similarity = np.zeros(shape = (n,n))
    for i in range(n):
        for j in range(i+1, n):
            similarity[i,j] = sim_func(query_feature[i], query_feature[j])
            similarity[j,i] = similarity[i,j]
    
    # use a heuristic algorithm
    weights = -np.sum(similarity, axis = 0)
    order =  np.argsort(weights)
    related = [0] * n
    
    for i, k in enumerate(order):
        if not train and i == 0:
            related[k] = None
            continue 
        
        if train:
            selected = order
        else:
            selected = order[0:i]
        
        j = np.argmax(similarity[k][selected])
        
        if k == selected[j]:
            related[k] = None
        else:
            related[k] = selected[j]
    
    
    return order, related
        

In [4]:
def load_data(root_dir, split, metric = None):
    '''
    metric - can only be bert/none
    '''
    
    json_data_path = f'{root_dir}/ALL/jsonl/{split}.jsonl'
    with open(json_data_path) as f:
        meetings = [json.loads(line) for line in f]
        
    print('Loaded {} meetings in {} set'.format(len(meetings), split))
    
    data = [] 
    eval_orders = []
    
    for cur_meet in meetings:
        turns = []
        for item in cur_meet['meeting_transcripts']:
            turns.append('{}: {}'.format(item['speaker'].lower(), tokenize(item['content']) ))
        entire_src = ' '.join(turns)
        
        for key_name in ['general_query_list', 'specific_query_list']:
            queries = [tokenize(item['query']) for item in cur_meet[key_name]]
            answers = [tokenize(item['answer']) for item in cur_meet[key_name]]
            
            offset = len(data)
            if metric is not None:
                query_feature = [get_bert_embedding(q) for q in queries] 
                order, related = get_related_query(query_feature, cosine_sim, split == 'train')
                order = np.array(order) + offset
            else:
                order, related = None, None
            eval_orders.append(order)

            for i, item in enumerate(cur_meet[key_name]):
                cur = dict()
                cur['tgt'] = answers[i]
                cur['query'] = queries[i]
                if 'general' in key_name:
                    text = entire_src
                else:
                    selected_turns = []
                    for st, ed in item['relevant_text_span']:
                        st, ed = int(st), int(ed)
                        for k in range(st, ed+1):
                            selected_turns.append( turns[k] )
                    text = ' '.join(selected_turns)
                cur['text'] = text
                
                #  ----------------------  #
                query = cur['query']
                if metric is not None:
                    j = related[i]

                    if j is None:
                        rq, ra = ' ', ' '
                        cur['rq_index'] = None
                    else:
                        rq, ra = queries[j], answers[j]
                        cur['rq_index'] = j + offset
                    
                    if split == 'train':
                        cur['src'] = clean_data(encode(rq, ra, query, text))
                        
                else:
                    cur['src'] = clean_data(encode(query, text))
                data.append(cur)
    
    if metric is not None:
        eval_orders = list(np.concatenate(eval_orders) )
    
    return data, eval_orders


In [5]:
# load all the required data
# we can load the original format in QMSum paper if we set load_metric to None
load_metric = None
train_data, _ = load_data(data_root, 'train', load_metric)
val_data, val_order = load_data(data_root, 'val', load_metric)
test_data, test_order = load_data(data_root, 'test', load_metric)

Loaded 162 meetings in train set
Loaded 35 meetings in val set
Loaded 35 meetings in test set


In [6]:
# construct the training dataset
import pandas as pd
from datasets import Dataset, DatasetDict

dataset = DatasetDict()
for name, data in [('train', train_data), ('val',val_data), ('test', test_data)]:
    d = {'src': [item['src'] for item in data], 'tgt':[item['tgt'] for item in data]}
    df = pd.DataFrame(data = d)    
    dataset[name] = Dataset.from_pandas(df, split = name)
dataset

DatasetDict({
    train: Dataset({
        features: ['src', 'tgt'],
        num_rows: 1257
    })
    val: Dataset({
        features: ['src', 'tgt'],
        num_rows: 272
    })
    test: Dataset({
        features: ['src', 'tgt'],
        num_rows: 281
    })
})

In [7]:
# define tokenizer and process the data
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

MAX_TOKENS = 2048      # same as the QMSum paper

def process(instance, **kwargs):
    return tokenizer(instance['src'], text_target=instance['tgt'], max_length = MAX_TOKENS, truncation=True, **kwargs)

model_name = 'facebook/bart-large-cnn'
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenized_data = dataset.map(process, batched = True, remove_columns=['src', 'tgt'])
tokenized_data

Map:   0%|          | 0/1257 [00:00<?, ? examples/s]

Map:   0%|          | 0/272 [00:00<?, ? examples/s]

Map:   0%|          | 0/281 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1257
    })
    val: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 272
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 281
    })
})

In [8]:
# load and modify the model
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# change model architecture a little bit to support 2048 tokens 
if MAX_TOKENS == 2048:
    sd = model.state_dict()
    ori_pe = sd['model.encoder.embed_positions.weight']
    new_pe = torch.cat([ori_pe[:-1], ori_pe[1:]], axis = 0)
    new_pe.requires_grad = True
    sd['model.decoder.embed_positions.weight'] = new_pe
    sd['model.encoder.embed_positions.weight'] = new_pe

    new_config = model.config
    new_config.max_position_embeddings = MAX_TOKENS
    new_model = AutoModelForSeq2SeqLM.from_config(new_config)
    new_model.load_state_dict(sd, strict=True)
    model = new_model
    
model.to(NN_DEVICE)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding='longest')

In [9]:
# evaluation related block of codes
import nltk
import evaluate
ROUGE = evaluate.load("rouge")

MAX_GENERATION = 160

# this function is from internet for rouge evaluation 
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # ROUGE expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = ROUGE.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}


def evaluate(model, data, order):
    n = len(data)
    summaries = [None] * n
    preds = []
    label = []
    for i in order:
        item = data[i]
        j = item['rq_index']
        if j is None:
            rq, ra = ' ', ' '
        else:
            rq = data[j]['query']
            ra = summaries[j]
            assert(ra is not None), "Error with evaluation order"
        
        item['src'] = clean_data(encode(rq, ra, item['query'], item['text']))
        inputs = process(item, return_tensors = 'pt')
        tensor_input = {k: inputs[k].to(NN_DEVICE) for k in inputs}
        output = model.generate(**tensor_input, num_beams = 4, min_length = 30, max_length = MAX_GENERATION, do_sample = True)
        summary = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
        summaries[i] = summary
        preds.append(output[0].cpu().numpy())
        label.append(inputs['labels'][0].cpu().numpy())
        
    max_pred_len = max(len(item) for item in preds)
    max_label_len =max(len(item) for item in label)
    pred_array =  np.zeros(shape = (len(preds), max_pred_len), dtype = 'i4' ) + 1
    label_array = np.zeros(shape = (len(label), max_label_len), dtype = 'i4' ) - 100
    for i, item in enumerate(preds):
        n = item.size
        pred_array[i][:n] = item
    
    for i, item in enumerate(label):
        n = item.size
        label_array[i][:n] = item
    
    print( compute_metrics( (pred_array, label_array) ) )
    return 
        

In [10]:
# due to special evaluation order of our dataset 
# we need to use callback to evaluate the performance on validation dataset
from transformers import TrainerCallback

class EvalCallback(TrainerCallback):
    def on_log(self, args, state, control, **kwargs):
        model = kwargs['model']
        evaluate(model, val_data, val_order)
    

In [11]:
# start fine tuning 
training_args = Seq2SeqTrainingArguments(
    output_dir = "./results",
    evaluation_strategy = 'steps',
    learning_rate = 1e-5,
    per_device_train_batch_size = 2,
    per_device_eval_batch_size = 2,
    logging_steps = 4000,
    weight_decay = 1e-4,
    save_total_limit = 3,
    num_train_epochs = 80,
    predict_with_generate = True
)

trainer = Seq2SeqTrainer(
    model = model,
    args = training_args,
    train_dataset = tokenized_data['train'],
    eval_dataset = tokenized_data['val'],
    tokenizer = tokenizer,
    data_collator = data_collator,
    compute_metrics = compute_metrics
)

In [12]:
trainer.train(resume_from_checkpoint=True)
# trainer.train()



  0%|          | 0/309 [00:00<?, ?it/s]

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss


TrainOutput(global_step=50320, training_loss=6.178790771525312e-06, metrics={'train_runtime': 303.0843, 'train_samples_per_second': 331.789, 'train_steps_per_second': 166.026, 'total_flos': 3.0681428033367245e+17, 'train_loss': 6.178790771525312e-06, 'epoch': 80.0})

In [13]:
# make a final evaluation on validation set
trainer.evaluate(tokenized_data['val'])

{'eval_loss': 6.210109710693359,
 'eval_rouge1': 39.5366,
 'eval_rouge2': 13.2804,
 'eval_rougeL': 25.1704,
 'eval_rougeLsum': 34.6919,
 'eval_gen_len': 85.7279,
 'eval_runtime': 293.4395,
 'eval_samples_per_second': 0.927,
 'eval_steps_per_second': 0.463,
 'epoch': 80.0}

In [14]:
trainer.evaluate(tokenized_data['test'])

{'eval_loss': 5.999124526977539,
 'eval_rouge1': 40.1961,
 'eval_rouge2': 14.6027,
 'eval_rougeL': 26.1468,
 'eval_rougeLsum': 35.7441,
 'eval_gen_len': 86.8327,
 'eval_runtime': 307.1277,
 'eval_samples_per_second': 0.915,
 'eval_steps_per_second': 0.459,
 'epoch': 80.0}