## Image Sequence to Story Generation Demo

### 1. Import Dependencies

In [1]:
import os
import glob
import numpy as np
import pandas as pd
import random
import time
import json
import nltk
import ipyplot
from pprint import pprint
import torch
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset, Dataset, load_metric
import warnings
warnings.filterwarnings('ignore')

### 2. Read Input Data

In [2]:
# Choose which demo input to run: 'VIST' or 'COCO'.
#
#   'VIST' -> a fixed image sequence sampled from the VIST test dataset.
#   'COCO' -> a random set of similar images sampled from the MS-COCO test split
#             (gives more diverse generations).

data_example = "VIST"

In [3]:
def _index_first_by_key(generations, key_fn):
    """Return {key_fn(gen): gen}, keeping the first generation seen for each key."""
    index = {}
    for gen in generations:
        index.setdefault(key_fn(gen), gen)
    return index


def _select_generations(index, wanted_ids, source):
    """Return the generations for `wanted_ids` in order; fail loudly if any id is missing."""
    missing = [wanted for wanted in wanted_ids if wanted not in index]
    if missing:
        # IndexError is kept (the previous `[...][0]` lookup raised it) but now names the culprits.
        raise IndexError("No presaved {} generation found for: {}".format(source, missing))
    return [index[wanted] for wanted in wanted_ids]


if data_example == 'VIST':

    # Path to generated Flickr Captions
    # NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load trusted files.
    flickr_gens = np.load('../pretrained/sub_gc_flickr/captions_16000.npy', allow_pickle=True)

    # Path to VIST Images
    # glob returns files in arbitrary order; sort so the image (and caption) sequence,
    # and therefore the generated story, is reproducible across runs and machines.
    VIST_demo_path = 'demo_utils/VIST_Flickr'
    VIST_demo_imgs = sorted(glob.glob(os.path.join(VIST_demo_path, '*.jpg')))
    VIST_demo_ids = [os.path.basename(p).split('_')[0] for p in VIST_demo_imgs]

    # Load Presaved Generations (with scores) from Ids
    # Index once by image id instead of rescanning every generation for every image.
    flickr_index = _index_first_by_key(flickr_gens, lambda gen: str(gen['image_id']))
    VIST_demo_generations = _select_generations(flickr_index, VIST_demo_ids, 'Flickr')

    # Display Images
    ipyplot.plot_images(VIST_demo_imgs, max_images=20, img_width=180)

    # Ground Truth Story (If Present) -- genuinely optional: skip it when the file is missing
    gtstory = None
    gtstory_path = os.path.join(VIST_demo_path, 'ground_truth_story.txt')
    if os.path.exists(gtstory_path):
        with open(gtstory_path, encoding='utf-8') as f:
            gtstory = f.read()
        print("Ground Truth Story: {}".format(gtstory))

    # Prepare Data
    captions = VIST_demo_generations

elif data_example == 'COCO':

    # Path to generated MS-COCO Captions
    # NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load trusted files.
    coco_gens = np.load('../pretrained/sub_gc_MRNN/captions_60000.npy', allow_pickle=True)

    # Path to MS-COCO Images: pick one demo folder at random (unseeded on purpose,
    # so every run can show a different image set), then sort its images for a stable order.
    COCO_demo_dir = 'demo_utils/MSCOCO'
    COCO_demo_path = random.choice(sorted(glob.glob(os.path.join(COCO_demo_dir, '*'))))
    COCO_demo_imgs = sorted(glob.glob(os.path.join(COCO_demo_path, '*.jpg')))
    COCO_demo_ids = [os.path.basename(p) for p in COCO_demo_imgs]

    # Load Presaved Generations (with scores) from Ids
    # Generations are matched on the file name of their recorded image path.
    coco_index = _index_first_by_key(coco_gens, lambda gen: os.path.basename(gen['image_path']))
    COCO_demo_generations = _select_generations(coco_index, COCO_demo_ids, 'MS-COCO')

    # Display Images
    ipyplot.plot_images(COCO_demo_imgs, max_images=20, img_width=180)

    # Prepare Data
    captions = COCO_demo_generations

else:
    raise ValueError("Data Split not recognised ...")

Ground Truth Story: It was my grandmother's birthday last month, so we threw her a party. My uncle showed up and made sure everyone had fun. My niece was getting so big. It was amazing to see how grown up she'd become. My brother and his wife were also there. They are always great to see. The kids had so much fun, but we were worn out trying to keep them out of trouble. It was a great day.


### 3. Load Pretrained Model

In [4]:
# Here we load the C2S-LM (caption-to-story) model.
# Sub-GC caption predictions are pre-saved locally and loaded together with the data, so no additional preprocessing is needed.

In [5]:
# Restore the fine-tuned C2S-LM weights and move them to the GPU when one is available
model_checkpoint = "pretrained/t5-large-finetuned-caption-to-story-gen/checkpoint-5000/"

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model = model.to(DEVICE)

In [6]:
# Load tokenizer
# The tokenizer is loaded under the name "t5-large", not from the fine-tuned checkpoint directory.
# NOTE: this rebinds `model_checkpoint` (previously the fine-tuned checkpoint path). The
# Seq2SeqTrainingArguments cell further down interpolates `model_checkpoint` into its output
# directory, so it relies on the value assigned here -- keep the cells in this order.
model_checkpoint = "t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

### 4. Preprocess Inputs

In [7]:
# Build one caption set: for every image take the first entry of its 'caption' field,
# capitalise it and terminate it with a period.
capset = []
for cap in captions:
    first_caption = cap['caption'][0]
    capset.append(first_caption.capitalize() + '.')

# `data` is a list of caption sets; here it holds exactly one
data = [capset]

In [8]:
# Display the caption set that will be turned into the model input
capset

['A group of people are sitting in a tent.',
 'A man in a red shirt is holding a microphone.',
 'A woman in a white shirt is holding a bottle of wine.',
 'A group of women in white dresses are dancing.',
 'A group of young girls are standing in a circle.']

In [9]:
def data_gen():
    """Yield one record per caption set in `data`.

    Each record joins the captions of a set into a single space-separated string
    under "captions" and carries the placeholder "N.A." under "story".
    """
    for caption_set in data:
        yield {"captions": " ".join(caption_set), "story": "N.A."}


# Huggingface Dataset Object for one data-point
dataset = Dataset.from_generator(data_gen)

# Instruction prepended to every input, plus the token limits for inputs and targets
prefix = "generate a short story using the following descriptions of events: "
max_input_length = 256
max_target_length = 256

def preprocess(datapoint):
    """Tokenize a batch of records for the seq2seq model.

    The entries of datapoint["captions"] are prefixed with `prefix` and tokenized
    (truncated to `max_input_length`); the entries of datapoint["story"] are tokenized
    as targets (truncated to `max_target_length`) and stored under "labels".
    """
    prompts = []
    for caption in datapoint["captions"]:
        prompts.append(prefix + caption)
    encoded = tokenizer(prompts, max_length=max_input_length, truncation=True)

    # Targets are tokenized inside the tokenizer's target-mode context
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(datapoint["story"], max_length=max_target_length, truncation=True)

    encoded["labels"] = target_encoding["input_ids"]
    return encoded


# inference_dataset
tokenized_dataset = dataset.map(preprocess, batched=True)

Using custom data configuration default-92d1744028e4b235
Found cached dataset generator (/home/bsantra/.cache/huggingface/datasets/generator/default-92d1744028e4b235/0.0.0)
Loading cached processed dataset at /home/bsantra/.cache/huggingface/datasets/generator/default-92d1744028e4b235/0.0.0/cache-21dbee97b397507f.arrow


### 5. Run C2S-LM on generated captions

In [10]:
# Seq2seq collator built from the tokenizer and the loaded model; it is handed to the
# Seq2SeqTrainer below to batch the tokenized examples.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [11]:
def compute_metrics(eval_pred):
    """Compute ROUGE scores and the mean generated length for a set of predictions.

    Args:
        eval_pred: a (predictions, labels) pair of token-id arrays. `predictions`
            may also be a tuple, in which case its first element is used.

    Returns:
        dict mapping each ROUGE key to its mid F-measure (as a percentage) plus
        "gen_len" (mean number of non-pad tokens per prediction), all rounded to
        4 decimals.
    """
    metric = load_metric("rouge")
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Replace -100 in the predictions and labels as we can't decode them.
    # (Without this, -100 padding in predictions breaks decoding and is counted in gen_len.)
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [12]:
# Trainer configuration. In this notebook the trainer is only used for `predict`;
# the output directory name is derived from the current value of `model_checkpoint`.
output_dir = f"pretrained/{model_checkpoint}-finetuned-caption-to-story-gen"

args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="steps",
    eval_steps=1000,
    save_steps=1000,
    learning_rate=2e-5,
    auto_find_batch_size=True,
    weight_decay=0.01,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=False,
    report_to="none",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Generate stories (up to 256 tokens) for the tokenized caption sets
prediction_output = trainer.predict(test_dataset=tokenized_dataset, max_length=256)
predictions = prediction_output.predictions
label_ids = prediction_output.label_ids
metrics = prediction_output.metrics

The following columns in the test set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: captions, story. If captions, story are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 1
  Batch size = 8
You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [13]:
# Print every generated story under its own banner, with special tokens stripped
for generated_ids in predictions:
    story = tokenizer.decode(generated_ids, skip_special_tokens=True)
    print("##### Predicted Story #####", story, sep="\n")

##### Predicted Story #####
The group of friends gathered for the dance. The speaker was very funny. The group danced for a while. Then they all gathered for a group photo.
