# Setup

## Packages and Downloading ATOMIC-2020

In [None]:
# Fetch the COMET-ATOMIC 2020 repository; the setup cell below installs its
# requirements.txt and copies its BART utils.py out of this checkout.
!git clone https://github.com/allenai/comet-atomic-2020.git

Cloning into 'comet-atomic-2020'...
remote: Enumerating objects: 190, done.[K
remote: Counting objects: 100% (77/77), done.[K
remote: Compressing objects: 100% (38/38), done.[K
remote: Total 190 (delta 56), reused 42 (delta 39), pack-reused 113[K
Receiving objects: 100% (190/190), 7.15 MiB | 22.00 MiB/s, done.
Resolving deltas: 100% (74/74), done.


In [None]:
### download pretrained model : https://github.com/allenai/comet-atomic-2020
!wget --header="Host: storage.googleapis.com" --header="User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" --header="Accept-Language: en-GB,en-US;q=0.9,en;q=0.8" --header="Referer: https://github.com/allenai/comet-atomic-2020/issues/12" "https://storage.googleapis.com/ai2-mosaic-public/projects/mosaic-kgs/comet-atomic_2020_BART.zip" -c -O 'comet-atomic_2020_BART.zip'
!unzip comet-atomic_2020_BART.zip

!pip install -r comet-atomic-2020/requirements.txt

### copy utils script to current directory
shutil.copy("comet-atomic-2020/models/comet_atomic2020_bart/utils.py", "/content/utils.py")

!pip install transformers==3.0.2

## Imports

In [None]:

import argparse
import datetime
import json
import shutil
import time
from pathlib import Path

import torch
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch


# Atomic-2020

In [None]:
def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of at most ``n`` items.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    total = len(lst)
    yield from (lst[start:start + n] for start in range(0, total, n))


class Comet:
    """Wrapper around a pretrained seq2seq checkpoint used for batched generation.

    Loads the model and tokenizer from ``model_path`` and exposes ``generate``
    to decode a list of query strings batch by batch.
    """

    def __init__(self, model_path, batch_size=100):
        """Load model and tokenizer and apply the "summarization" task params.

        Args:
            model_path: Value passed to ``from_pretrained`` for both the model
                and the tokenizer.
            batch_size: Number of queries tokenized and decoded together.
                Defaults to 100, the value that used to be hard-coded.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Validate before the (expensive) checkpoint load.
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {}".format(batch_size))

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        task = "summarization"
        use_task_specific_params(self.model, task)
        self.batch_size = batch_size
        self.decoder_start_token_id = None

    def generate(
            self, 
            queries,
            decode_method="beam", 
            num_generate=5, 
            ):
        """Generate ``num_generate`` decoded strings for every query.

        Args:
            queries: List of input strings; it is sliced into batches of
                ``self.batch_size``.
            decode_method: Kept for interface compatibility. It is never read:
                the only decoding controls passed to ``model.generate`` are
                ``num_beams`` and ``num_return_sequences``.
            num_generate: Used as both the beam width and the number of
                sequences returned per query.

        Returns:
            A list with one element per batch; each element is the flat list
            of decoded strings for that batch. Callers in this notebook
            flatten it and expect ``num_generate`` strings per query.
        """
        decs = []

        with torch.no_grad():
            # Batches are consumed lazily; batch_idx is 1-based for the progress log.
            for batch_idx, batch in enumerate(chunks(queries, self.batch_size), start=1):
                time1 = datetime.datetime.now()

                # Pad to max_length, then trim_batch is applied to the padded
                # tensors before they are handed to the model.
                batch = self.tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(self.device)
                input_ids, attention_mask = trim_batch(**batch, pad_token_id=self.tokenizer.pad_token_id)

                summaries = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_start_token_id=self.decoder_start_token_id,
                    num_beams=num_generate,
                    num_return_sequences=num_generate
                    )
                dec = self.tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                decs.append(dec)

                time2 = datetime.datetime.now()

                # Report the duration of every 50th batch only.
                if batch_idx % 50 == 0:
                    print("Processed batch: {}, time takens: {}".format(batch_idx, (time2-time1).total_seconds()))

        return decs


# Relation names, listed in ASCII-sorted order. The query-building cells in
# this notebook use a subset of these (see `rels`). Both "HasSubEvent" and
# "HasSubevent" are present, exactly as in the original list.
all_relations = """
    AtLocation CapableOf Causes CausesDesire CreatedBy DefinedAs DesireOf
    Desires HasA HasFirstSubevent HasLastSubevent HasPainCharacter
    HasPainIntensity HasPrerequisite HasProperty HasSubEvent HasSubevent
    HinderedBy InheritsFrom InstanceOf IsA LocatedNear LocationOfAction
    MadeOf MadeUpOf MotivatedByGoal NotCapableOf NotDesires NotHasA
    NotHasProperty NotIsA NotMadeOf ObjectUse PartOf ReceivesAction
    RelatedTo SymbolOf UsedFor isAfter isBefore isFilledBy oEffect oReact
    oWant xAttr xEffect xIntent xNeed xReact xReason xWant
""".split()


In [None]:
# Sample usage (reproducing the demo): build the wrapper from the unzipped
# checkpoint directory and clear any gradients on the loaded model.
model_dir = "./comet-atomic_2020_BART"

print("model loading ...")
comet = Comet(model_dir)
comet.model.zero_grad()
print("model loaded")

model loading ...
model loaded


In [None]:
# Single demo query: caption head + relation + the [GEN] marker.
head = "a bouquet of white peonies in a vase"
rel = "xNeed"
query = "{} {} [GEN]".format(head, rel)
queries = [query]
print(queries)

results = comet.generate(queries, decode_method="beam", num_generate=5)
print(results)

['a bouquet of white peonies in a vase xNeed [GEN]']
[[' put in vase', ' bring to the store', ' bring to the party', ' to buy', ' buy']]


  beam_id = beam_token_id // vocab_size


# Extract Commonsense


In [None]:
caption_type = "CLIP" #@param ["CLIP", "OFA", "BLIP"]
split = "train" #@param ["train", "valid", "test"]
BEAM_SIZE = 5 #@param

# Load the caption lookup and the stories for the chosen split. Context
# managers make sure the file handles are closed once parsing is done.
with open("/content/{}_captions.json".format(caption_type)) as captions_file:
    captions = json.load(captions_file)
with open("/content/{}_stories.json".format(split)) as stories_file:
    stories = json.load(stories_file)

In [None]:
%%time
keys = list(stories.keys())

heads = []

for key in keys: 
    story = stories[key]['story']
    images = stories[key]['images']

    image_formats = ['.jpg', '.gif', '.png', '.bmp']

    for img in images: 
        for f in image_formats: 
            img_name = img + f 
            try:
                caption = captions[img_name]
                if caption[-1] == ".":
                    caption = caption[:-1] 
                if caption == '': continue
                heads.append(caption)
            except: 
                continue 

CPU times: user 393 ms, sys: 1.01 ms, total: 394 ms
Wall time: 393 ms


In [None]:
print(len(heads))
# dict.fromkeys removes duplicates while keeping first-seen order, so the
# resulting list (and everything built from it) is the same on every run;
# list(set(...)) orders strings differently from one interpreter run to the next.
heads = list(dict.fromkeys(heads))
print(len(heads))

200746
34204


In [None]:
# Relations queried for every caption head.
rels = ["AtLocation", "CapableOf", "xNeed", "xIntent", "xWant", "xEffect", "xReact", "xAttr"]

# One query per (head, relation) pair; heads vary slowest, relations fastest.
queries = ["{} {} [GEN]".format(head, rel) for head in heads for rel in rels]

print("Number of queries: {}".format(len(queries)))

Number of queries: 273632


In [None]:
%%time 
# Run generation over every query with BEAM_SIZE outputs per query.
# Author's note: takes 3 hours for training (presumably the "train" split;
# this cell only runs inference) — TODO confirm.
results = comet.generate(queries, decode_method="beam", num_generate=BEAM_SIZE)

Processed batch: 50, time takens: 2.154699
Processed batch: 100, time takens: 2.743159
Processed batch: 150, time takens: 2.00381
Processed batch: 200, time takens: 2.104369
Processed batch: 250, time takens: 2.260292
Processed batch: 300, time takens: 2.217628
Processed batch: 350, time takens: 2.484714
Processed batch: 400, time takens: 2.436287
Processed batch: 450, time takens: 2.25503
Processed batch: 500, time takens: 2.143347
Processed batch: 550, time takens: 2.023096
Processed batch: 600, time takens: 2.247681
Processed batch: 650, time takens: 2.144781
Processed batch: 700, time takens: 2.653079
Processed batch: 750, time takens: 2.046963
Processed batch: 800, time takens: 2.300652
Processed batch: 850, time takens: 2.140453
Processed batch: 900, time takens: 1.99728
Processed batch: 950, time takens: 2.123757
Processed batch: 1000, time takens: 2.287488
Processed batch: 1050, time takens: 2.440983
Processed batch: 1100, time takens: 2.099444
Processed batch: 1150, time taken

In [None]:
# Flatten the per-batch lists returned by generate into one flat list.
all_results = []
for batch_outputs in results:
    all_results.extend(batch_outputs)

# Every query must have produced exactly BEAM_SIZE generations.
assert len(all_results) == len(queries) * BEAM_SIZE
len(queries)


273632

In [None]:
### save results 
data_dict = {} 

r_idx = 0
for i in range(len(queries)): 
    data_dict[queries[i]] = all_results[r_idx:r_idx+BEAM_SIZE]
    r_idx += BEAM_SIZE
    assert len(data_dict[queries[i]]) == BEAM_SIZE

print(len(data_dict))

with open("/content/Atomic2020/{}_{}_comet_ck.json".format(caption_type, split), "w") as outfile:
    json.dump(data_dict, outfile)

273632


# Reformat Commonsense to Dictionary Format

In [None]:
def get_cap_and_imgs(stories, story_id, img2cap, cap_type = "VIST", display = False): 
    """
    Retrieve the image names and corresponding captions 
    for given story id. 

    Args:
        stories: Mapping from story id (string) to a dict with an 'images' list.
        story_id: Story id; converted with ``str`` before the lookup.
        img2cap: Mapping from image name to caption.
        cap_type: Caption source. For "VIST" the image id itself is the lookup
            key; for any other value the id is tried with each of the
            extensions '.jpg', '.gif', '.png', '.bmp'. For "CLIP" a single
            trailing full stop is removed from the caption.
        display: When truthy, print every image name / caption pair found.

    Returns:
        ``(image_list, cap_list)``, two lists of equal length holding the
        matched image names and their captions, or ``None`` (after printing a
        message) when the story id is unknown.
    """
    story_key = str(story_id)

    if story_key not in stories:
        print("This story id does not exist.")
        return

    image_formats = ['.jpg', '.gif', '.png', '.bmp']
    image_list = [] 
    cap_list = []

    for img in stories[story_key]['images']: 
        # VIST keys carry no extension, so there is exactly one candidate;
        # looping over the extensions there would record the image four times.
        if cap_type == "VIST":
            candidates = [img]
        else:
            candidates = [img + f for f in image_formats]

        for img_name in candidates:
            if img_name not in img2cap:
                continue
            caption = img2cap[img_name]
            if display: 
                print(img_name, ":" , caption)

            if cap_type == "CLIP":  # get rid of last full stop for CLIP captions
                # An empty caption is skipped entirely so image_list and
                # cap_list never go out of step.
                if not caption:
                    continue
                if caption[-1] == ".": 
                    caption = caption[:-1]

            image_list.append(img_name)
            cap_list.append(caption)

    return image_list, cap_list

In [None]:
split = "test"
caption_type = "CLIP"

# NOTE(review): the save cell in the extraction section writes to
# "/content/Atomic2020/..."; this cell reads from "/content/..." — confirm
# the file is expected at this location.
# Context managers close each file as soon as it has been parsed.
with open("/content/{}_{}_comet_ck.json".format(caption_type, split)) as cks_file:
    cks = json.load(cks_file)
with open("/content/{}_stories.json".format(split)) as stories_file:
    stories = json.load(stories_file)
with open("/content/{}_captions.json".format(caption_type)) as captions_file:
    img2cap = json.load(captions_file)

In [None]:
rels = ["AtLocation", "CapableOf", "xNeed", "xIntent", "xWant", "xEffect", "xReact", "xAttr"]

cks_dict = {}          # story_id -> {image index -> {relation -> generations, "caption" -> text}}
problem_stories = []   # stories that do not resolve to exactly 5 captions

for story_id in stories:
    _, cap_list = get_cap_and_imgs(stories, story_id, img2cap, cap_type = "CLIP")
    if len(cap_list) != 5:
        problem_stories.append(story_id)
        continue

    per_image = {}
    for img_num, caption in enumerate(cap_list):
        # Look up the stored generations per relation, dropping " none"
        # entries and stripping surrounding whitespace from the rest.
        image_ck = {
            rel: [x.strip() for x in cks[caption + " " + rel + " " + "[GEN]"] if x != " none"]
            for rel in rels
        }
        image_ck["caption"] = caption
        per_image[img_num] = image_ck
    cks_dict[story_id] = per_image

print(len(problem_stories))
print(len(cks_dict))

In [None]:
# Write the dictionary straight to its final location. The earlier two-step
# version wrote to the current working directory and then moved the file
# from "/content/", which only works when the cwd is /content and the
# destination directory already exists.
out_dir = Path("/content/Common Sense Dicts")
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / 'cks_{}_{}.json'.format(caption_type, split), 'w') as f:
    json.dump(cks_dict, f)