In [1]:
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the "common_gen" configuration of the GEM benchmark and display its splits.
DATASET_NAME = "common_gen"
data = load_dataset(path="gem", name=DATASET_NAME)
data

Downloading builder script: 65.1kB [00:00, 34.1MB/s]                   
Downloading metadata: 166kB [00:00, 82.3MB/s]                    


Downloading and preparing dataset gem/common_gen (download: 1.84 MiB, generated: 10.87 MiB, post-processed: Unknown size, total: 12.71 MiB) to /home/hrenduchinta/.cache/huggingface/datasets/gem/common_gen/1.1.0/982a54473b12c6a6e40d4356e025fb7172a5bb2065e655e2c1af51f2b3cf4ca1...


Downloading data: 100%|██████████| 1.85M/1.85M [00:00<00:00, 8.23MB/s]
Downloading data: 100%|██████████| 87.8k/87.8k [00:00<00:00, 1.85MB/s]
Downloading data files: 100%|██████████| 2/2 [00:00<00:00,  2.51it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 43.95it/s]
                                                                                            

Dataset gem downloaded and prepared to /home/hrenduchinta/.cache/huggingface/datasets/gem/common_gen/1.1.0/982a54473b12c6a6e40d4356e025fb7172a5bb2065e655e2c1af51f2b3cf4ca1. Subsequent calls will reuse this data.


100%|██████████| 6/6 [00:00<00:00, 615.66it/s]


DatasetDict({
    train: Dataset({
        features: ['gem_id', 'gem_parent_id', 'concept_set_id', 'concepts', 'target', 'references'],
        num_rows: 67389
    })
    validation: Dataset({
        features: ['gem_id', 'gem_parent_id', 'concept_set_id', 'concepts', 'target', 'references'],
        num_rows: 993
    })
    test: Dataset({
        features: ['gem_id', 'gem_parent_id', 'concept_set_id', 'concepts', 'target', 'references'],
        num_rows: 1497
    })
    challenge_train_sample: Dataset({
        features: ['gem_id', 'gem_parent_id', 'concept_set_id', 'concepts', 'target', 'references'],
        num_rows: 500
    })
    challenge_validation_sample: Dataset({
        features: ['gem_id', 'gem_parent_id', 'concept_set_id', 'concepts', 'target', 'references'],
        num_rows: 500
    })
    challenge_test_scramble: Dataset({
        features: ['gem_id', 'gem_parent_id', 'concept_set_id', 'concepts', 'target', 'references'],
        num_rows: 500
    })
})

In [3]:
# Peek at the first training record to see the fields of a single example.
data['train'][0]

{'gem_id': 'common_gen-train-0',
 'gem_parent_id': 'common_gen-train-0',
 'concept_set_id': 0,
 'concepts': ['mountain', 'ski', 'skier'],
 'target': 'Skier skis down the mountain',
 'references': []}

In [4]:
def construct_input_for_batch(batch):
    """Turn a batch into parallel (source, target) sequences.

    Each entry of ``batch["concepts"]`` (an iterable of strings) is joined with
    single spaces to form one source string; ``batch["target"]`` is returned
    unchanged as the targets.
    """
    joined_concepts = list(map(' '.join, batch["concepts"]))
    return joined_concepts, batch["target"]

In [5]:
def batch_tokenize(batch, tokenizer, max_length=32, label_pad_token_id=None):
    """Construct the batch (source, target) and run them through a tokenizer.

    Args:
        batch: Mapping accepted by ``construct_input_for_batch``.
        tokenizer: Callable tokenizer returning a mapping with ``"input_ids"``
            (and ``"attention_mask"`` when ``label_pad_token_id`` is used).
        max_length: Length the target ids are padded / truncated to.
        label_pad_token_id: Optional replacement id for padded label positions.
            ``None`` (default) keeps the tokenizer's own padding ids. Pass
            ``-100`` to emit the ignore index that ``compute_rouge_metrics``
            in this notebook already expects to find in the labels.

    Returns:
        dict with ``"input_ids"`` (source ids, unpadded and untruncated) and
        ``"labels"`` (target ids padded / truncated to ``max_length``).
    """
    source, target = construct_input_for_batch(batch)
    source_ids = tokenizer(source)["input_ids"]
    encoded_target = tokenizer(
        target,
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    labels = encoded_target["input_ids"]
    if label_pad_token_id is not None:
        # Use the attention mask rather than comparing against pad_token_id:
        # this notebook sets pad_token to eos_token, so an id comparison could
        # not tell padding apart from a genuine end-of-sequence token.
        labels = [
            [
                token_id if is_real else label_pad_token_id
                for token_id, is_real in zip(ids, mask)
            ]
            for ids, mask in zip(labels, encoded_target["attention_mask"])
        ]
    return {
        "input_ids": source_ids,
        "labels": labels,
    }

In [7]:
from transformers import AutoTokenizer

In [8]:
# Checkpoint name passed to the tokenizer and model `from_pretrained` loaders.
MODEL_NAME = "gpt2"
# Token budget passed as `max_length` to target tokenization and to generation.
MAX_LENGTH = 32

In [12]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the end-of-sequence token as the padding token so that the padded
# tokenizer calls in this notebook have a pad token available.
# NOTE(review): presumably the "gpt2" tokenizer ships without a pad token — confirm.
tokenizer.pad_token=tokenizer.eos_token

In [13]:
def _tokenize_split(split_name):
    """Tokenize one split of `data` in batches with the shared tokenizer settings."""
    def _encode(batch):
        return batch_tokenize(batch, tokenizer, max_length=MAX_LENGTH)

    return data[split_name].map(_encode, batched=True)


train_data_tokenized = _tokenize_split('train')
valid_data_tokenized = _tokenize_split('validation')

100%|██████████| 68/68 [00:03<00:00, 19.76ba/s]
100%|██████████| 1/1 [00:00<00:00, 20.94ba/s]


In [14]:
from transformers import GPT2LMHeadModel

In [15]:
from datasets import load_metric

In [17]:
# Load the ROUGE metric; `compute_rouge_metrics` reads this module-level object.
# NOTE(review): `load_metric` is deprecated in recent `datasets` releases in favour of
# the separate `evaluate` package — confirm the installed version still provides it.
rouge_scorer = load_metric("rouge")

Downloading builder script: 5.60kB [00:00, 4.59MB/s]                   


In [18]:
def rouge_metric_builder(tokenizer):
    """Build a ROUGE metric callback bound to ``tokenizer``.

    Args:
        tokenizer: Object providing ``batch_decode`` and ``pad_token_id``.

    Returns:
        A function that takes an object with ``predictions`` and ``label_ids``
        attributes and returns ``{"rouge2": float, "rougeL": float}`` holding
        the mid F-measures rounded to 4 decimals.
    """
    import copy  # local import so this cell does not depend on another cell's imports

    def compute_rouge_metrics(pred):
        """Utility to compute ROUGE during training.

        NOTE(review): ``pred.predictions`` is decoded directly, so it is
        assumed to already hold token ids (not logits) — confirm at the caller.
        """
        pred_ids = pred.predictions
        # Work on a copy: the masking below must not overwrite the caller's
        # label array, which may be reused after this callback returns.
        labels_ids = copy.deepcopy(pred.label_ids)
        # -100 is not a valid token id; swap it for the pad id so decoding
        # works and the padding is dropped as a special token.
        labels_ids[labels_ids == -100] = tokenizer.pad_token_id

        # All special tokens are removed.
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

        # Compute the metric. The aggregator flag is deliberately left at its
        # default (enabled): its keyword is spelled `use_agregator` in older
        # releases of the metric and `use_aggregator` in newer ones, so passing
        # either spelling fails on the other.
        rouge_results = rouge_scorer.compute(
            predictions=pred_str,
            references=label_str,
            rouge_types=["rouge2", "rougeL"],
            use_stemmer=False,
        )
        return {
            "rouge2": round(rouge_results['rouge2'].mid.fmeasure, 4),
            "rougeL": round(rouge_results['rougeL'].mid.fmeasure, 4),
        }
    return compute_rouge_metrics

In [19]:
# Bind the ROUGE helper to this notebook's tokenizer.
# NOTE(review): presumably intended as a trainer `compute_metrics` hook — confirm at the call site.
rouge_metric_fn = rouge_metric_builder(tokenizer)

In [20]:
import torch

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE(review): RANDOM_SEED is only defined here; no seeding call is visible in this part of the notebook.
RANDOM_SEED = 42
# Number of beams passed to beam-search generation.
BEAM_SIZE = 4

In [None]:
# Load the pretrained LM-head model for MODEL_NAME and place it on the selected device.
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(DEVICE)

In [None]:
def beam_generate_sentences(
    batch,
    model,
    tokenizer,
    num_beams=4,
    max_length=32,
    device='cuda:0'
):
    """Generate outputs from a model with beam search decoding.

    Args:
        batch: Mapping accepted by ``construct_input_for_batch``.
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used both to encode the sources and to decode the
            generated ids.
        num_beams: Beam width forwarded to ``model.generate``.
        max_length: ``max_length`` forwarded to ``model.generate``.
            NOTE(review): for a decoder-only model this budget presumably
            includes the prompt tokens, and the decoded text then starts with
            the prompt — confirm this is intended.
        device: Device the encoded inputs are moved to before generation.

    Returns:
        list[str]: one decoded string per generated sequence.
    """
    # Create batch inputs.
    source, _ = construct_input_for_batch(batch)

    # A decoder-only model continues from the last position of each row, so
    # padding must go on the left; otherwise pad tokens sit between the prompt
    # and its continuation. Encoder-decoder models keep the tokenizer's setting.
    config = getattr(model, "config", None)
    pad_left = not getattr(config, "is_encoder_decoder", False)
    previous_side = getattr(tokenizer, "padding_side", None)
    if pad_left:
        tokenizer.padding_side = "left"
    try:
        # Use the model's tokenizer to create the batch input_ids.
        batch_features = tokenizer(source, padding=True, return_tensors='pt')
    finally:
        # Restore the caller's setting so other cells using this tokenizer
        # are unaffected.
        if pad_left:
            tokenizer.padding_side = previous_side

    # Move all inputs to the device.
    batch_features = {k: v.to(device) for k, v in batch_features.items()}

    # Generate with beam search. The pad id is passed explicitly so generate
    # does not have to guess it when the model config defines none.
    generated_ids = model.generate(
        **batch_features,
        num_beams=num_beams,
        max_length=max_length,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Use model tokenizer to decode to text.
    generated_sentences = [
        tokenizer.decode(gen_ids.tolist(), skip_special_tokens=True)
        for gen_ids in generated_ids
    ]
    return generated_sentences

In [None]:
def _generate_for_batch(batch):
    """Return the beam-search generations for one batch under the 'generated' key."""
    sentences = beam_generate_sentences(
        batch,
        model,
        tokenizer,
        num_beams=BEAM_SIZE,
        max_length=MAX_LENGTH,
        device=DEVICE,
    )
    return {'generated': sentences}


# Run generation over the validation split in batches of 128 examples.
valid_output = data['validation'].map(
    _generate_for_batch,
    batched=True,
    batch_size=128,
)