In [3]:
import os
import torch
from torch import nn
import argparse
from fastprogress.fastprogress import IN_NOTEBOOK
from datasets import load_from_disk, load_metric, load_dataset
from transformers import AutoTokenizer, BartConfig
from architectures import Autoencoder, CBartForConditionalGeneration

In [4]:
if not IN_NOTEBOOK:
    # Script mode: every setting except --test must be supplied on the command line.
    parser = argparse.ArgumentParser()

    parser.add_argument('--checkpoint_dir', type=str, help="The cBART model directory.", required=True)
    parser.add_argument('--exp_name', type=str, help="The experiment name.", required=True)
    parser.add_argument('--batch_size', type=int, help="The batch_size.", required=True)
    parser.add_argument('--first', type=int, help="The AE's first projection.", required=True)
    parser.add_argument('--second', type=int, help="The AE's second projection.", required=True)
    parser.add_argument('--third', type=int, help="The AE's third projection.", required=True)
    parser.add_argument('--test', action='store_true', help="A test run with just one sample.")

    args = parser.parse_args()

else:
    # Notebook mode: fixed settings exposed through the same attribute names
    # the command-line parser would produce.
    args = argparse.Namespace(
        batch_size=8,
        exp_name="384-1024",
        checkpoint_dir="./cbart-checkpoints/384",
        first=576,
        second=480,
        third=384,
        test=True,
    )

In [5]:
# Base checkpoint (config + tokenizer source) and the fixed encoder context size.
bart_checkpoint = 'facebook/bart-base'
encoder_max_length = 1024

# Per-experiment output folder, e.g. ./results/<exp_name>; created if missing.
log_directory = os.path.join("./results/", args.exp_name)
os.makedirs(log_directory, exist_ok=True)

# NOTE(review): `datasets.load_metric` is deprecated (removed in datasets 3.x);
# `evaluate.load("rouge")` is the replacement if the environment is upgraded.
rouge = load_metric("rouge")

# The AutoEncoder

In [6]:
# The AE's input width (`bart_encoder_emb_size`) must equal the BART encoder's
# hidden size, i.e. the base checkpoint's `d_model` (768 for facebook/bart-base).
# Read it from the config instead of hard-coding 768, so changing `bart_checkpoint`
# cannot silently desynchronise the AE from the encoder.
ae = Autoencoder(bart_encoder_emb_size=BartConfig.from_pretrained(bart_checkpoint).d_model,
                 first_proj=args.first,
                 second_proj=args.second,
                 third_proj=args.third,
                 max_len=encoder_max_length)

# Initialize compressed BART (cBART)

In [7]:
cbart_config = BartConfig.from_pretrained(bart_checkpoint)
# Remember the original width as `enc_d_model` (presumably read by the CBart encoder;
# verify in architectures), then shrink `d_model` to the AE's third projection size.
# The right-hand side is evaluated first, so enc_d_model receives the *original* d_model.
cbart_config.enc_d_model, cbart_config.d_model = cbart_config.d_model, args.third

In [8]:
# Load the cBART weights from the checkpoint directory, using the modified config
# and the freshly built AE. With ignore_mismatched_sizes=True, tensors whose shapes
# do not match the model are newly initialised instead of raising. The warning
# below lists newly initialised AE weights; they are loaded separately from
# ae-checkpoint.pth afterwards.
CBart_model = CBartForConditionalGeneration.from_pretrained(
    args.checkpoint_dir,
    config=cbart_config,
    ae=ae,
    ignore_mismatched_sizes=True,
)

Some weights of CBartForConditionalGeneration were not initialized from the model checkpoint at ./cbart-checkpoints/384 and are newly initialized: ['model.encoder.ae.encoder.1.bias', 'model.encoder.ae.encoder.7.bias', 'model.encoder.ae.decoder.4.num_batches_tracked', 'model.encoder.ae.encoder.1.weight', 'model.encoder.ae.decoder.7.running_mean', 'model.encoder.ae.decoder.4.running_mean', 'model.encoder.ae.encoder.0.weight', 'model.encoder.ae.encoder.7.weight', 'model.encoder.ae.encoder.7.running_var', 'model.encoder.ae.decoder.6.weight', 'model.encoder.ae.decoder.1.running_var', 'model.encoder.ae.decoder.7.running_var', 'model.encoder.ae.encoder.7.num_batches_tracked', 'model.encoder.ae.encoder.7.running_mean', 'model.encoder.ae.encoder.4.running_var', 'model.encoder.ae.encoder.4.bias', 'model.encoder.ae.decoder.7.bias', 'model.encoder.ae.encoder.1.running_var', 'model.encoder.ae.decoder.7.weight', 'model.encoder.ae.encoder.4.running_mean', 'model.encoder.ae.decoder.1.weight', 'model.e

In [11]:
# Load the separately saved AE weights into the model's autoencoder submodule.
# The model is still on the CPU here (it is moved to the GPU in a later cell), so
# always map the checkpoint to the CPU. This avoids materialising it on the GPU
# only to copy it back. load_state_dict copies across devices anyway, so this also
# works if the model was already moved.
ae_checkpoint_path = os.path.join(args.checkpoint_dir, "ae-checkpoint.pth")
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
ae_checkpoint = torch.load(ae_checkpoint_path, map_location="cpu")

CBart_model.model.encoder.ae.load_state_dict(ae_checkpoint['model'])
del ae_checkpoint  # release the checkpoint dict; its tensors are now copied into the model

In [12]:
# Inference only: eval() disables dropout and makes batch-norm layers use their
# running statistics. Then move to the GPU if one exists; .to("cpu") is a no-op
# for a model already on the CPU. The trailing ";" suppresses the module repr.
CBart_model.eval().to("cuda" if torch.cuda.is_available() else "cpu");

# Tokenizer

In [13]:
# Tokenizer of the base BART checkpoint, cached under ./hf-cache/bart-base.
tokenizer = AutoTokenizer.from_pretrained(
    bart_checkpoint,
    cache_dir="./hf-cache/bart-base",
)

# Load the Data

In [14]:
# CNN/DailyMail test split, previously written with `Dataset.save_to_disk`.
# NOTE(review): this path is relative to the working directory and differs from the
# "./hf-cache" root used for the tokenizer; consider making it configurable.
test_dataset = load_from_disk(os.path.join("../../hf-cache/cnn_dailymail", "test"))

In [15]:
# test_dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0", cache_dir="./hf-cache/cnn_dailymail",
#                              split="test", ignore_verifications=True)

In [16]:
# For a --test run, keep only the first example so the whole pipeline runs quickly.
test_dataset = test_dataset.select(range(1)) if args.test else test_dataset

In [17]:
def generate_summary(batch):
    """Generate summaries for a batch of articles (for ``Dataset.map(batched=True)``).

    Args:
        batch: batch dict from ``Dataset.map``. ``batch["article"]`` is presumably
            a list of article strings.

    Returns:
        The same batch with a ``"predicted_highlights"`` column added, holding the
        decoded summaries (special tokens stripped).
    """
    # Pad every input to exactly `encoder_max_length` tokens. The AE was built with
    # max_len=encoder_max_length and presumably needs a fixed length (TODO confirm).
    inputs_dict = tokenizer(batch["article"], padding="max_length", max_length=encoder_max_length,
                            return_tensors="pt", truncation=True)

    # Send inputs to wherever the model actually lives, instead of assuming it was
    # moved to CUDA whenever CUDA is available.
    model_device = CBart_model.device
    input_ids = inputs_dict.input_ids.to(model_device)
    attention_mask = inputs_dict.attention_mask.to(model_device)

    # No decoding kwargs are passed, so generate() uses the model's configured
    # defaults (beam size, lengths, ...).
    predicted_abstract_ids = CBart_model.generate(input_ids, attention_mask=attention_mask, use_cache=True)
    batch["predicted_highlights"] = tokenizer.batch_decode(predicted_abstract_ids, skip_special_tokens=True)

    return batch

# Decoding with Beam Search

In [18]:
# Beam-search decoding parameters. generate_summary calls generate() without
# decoding kwargs, so these model-level defaults drive generation.
decoding_params = {
    "max_length": 144,
    "min_length": 55,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "length_penalty": 2.0,
    "num_beams": 4,
}

# Recent transformers versions take generate() defaults from `model.generation_config`.
# They honour edits to `model.config` only through a deprecated fallback, which is
# skipped when e.g. the checkpoint ships its own generation_config.json. Set both,
# so the values apply on old and new versions alike.
for _param_name, _param_value in decoding_params.items():
    setattr(CBart_model.config, _param_name, _param_value)
    if getattr(CBart_model, "generation_config", None) is not None:
        setattr(CBart_model.generation_config, _param_name, _param_value)

result = test_dataset.map(generate_summary, batched=True, batch_size=args.batch_size)



  0%|          | 0/1 [00:00<?, ?ba/s]

In [19]:
# Sanity check: display the first generated summary.
result[0]["predicted_highlights"]

'James Best played Rosco P. Coltrane on "The Dukes of Hazzard"\nHe was 88.\nHe died in hospice in North Carolina, of complications from pneumonia.\nBest was best known for his role, which still lives on in reruns.'

In [20]:
# When run as a script, persist the dataset with its predictions for later inspection.
if not IN_NOTEBOOK:
    result.to_csv(os.path.join(log_directory, 'beam.csv'))

beam_scores = rouge.compute(
    predictions=result["predicted_highlights"],
    references=result["highlights"],
    rouge_types=["rouge1", "rouge2", "rouge3", "rougeLsum"],
)

# For each ROUGE type, print its name on one line and the `mid` F-measure
# (3 decimals) on the next.
for rouge_type, aggregate in beam_scores.items():
    print(rouge_type, "{:.3f}".format(aggregate.mid.fmeasure), sep="\n")

rouge1
0.382
rouge2
0.152
rouge3
0.094
rougeLsum
0.353
