# Setup

In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from google.colab import drive
drive.mount('/content/drive')

import json
import torch
import os
import datetime
import pickle
import pandas as pd
import numpy as np
import random

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from sklearn.model_selection import train_test_split

from transformers.models.bart.modeling_bart import shift_tokens_right
from transformers import BartTokenizer, BartModel, BartForConditionalGeneration

import nltk
# Fetch the NLTK resources named below into the runtime's nltk_data directory.
# NOTE(review): nothing else visible in this notebook calls nltk — confirm these
# downloads are still needed.
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('brown')
nltk.download('wordnet')
nltk.download('omw-1.4')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package brown to /root/nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


True

In [None]:
# Pick the compute device once; the model and every batch are moved onto it later.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

if use_cuda:
    print('We are using GPU.')
else:
    print('We are using CPU.')

We are using GPU.


# Data Loader

In [None]:
def collate_fn(data, pad_token_id=None):
    """
    Collate a list of dataset samples into one right-padded batch.

    The caller's list and the per-sample token lists are left untouched
    (padding is applied to copies), so samples can safely be cached and reused.

    Args:
        data: non-empty list of
            ``(input_ids, target_story_ids, story_id, storyline)`` tuples, where
            the first two entries are lists of token ids.
        pad_token_id: id used to pad shorter sequences. When ``None`` the pad id
            of the notebook-level ``tokenizer`` is used.

    Returns:
        dict with the keys
            "inputs": LongTensor (batch_size, input_max_len) of padded input ids.
            "gold_stories": LongTensor (batch_size, story_max_len) of padded
                target ids.
            "input_attention_ids": LongTensor (batch_size, input_max_len), 1 for
                real tokens and 0 for padding.
            "batch_max_lens": tuple ``(input_max_len, story_max_len)``.
            "story_ids": tuple of the samples' story ids.
            "text_inputs": tuple of the samples' storyline strings.
        Samples are ordered by target length, longest first; ties are broken by
        input length (longest first) and then by their original order.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if not data:
        raise ValueError("collate_fn received an empty batch")
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    # One stable sort on (target_len, input_len) reproduces the ordering of the
    # former two in-place sorts without reordering the caller's list.
    ordered = sorted(data, key=lambda sample: (len(sample[1]), len(sample[0])), reverse=True)
    input_ids, target_story_ids, story_ids, storyline = zip(*ordered)

    input_max_len = max(len(ids) for ids in input_ids)
    story_max_len = len(target_story_ids[0])  # longest target comes first after sorting

    def _right_pad(ids, max_len):
        # Build a new list so the sample's own token list is never extended.
        return list(ids) + [pad_token_id] * (max_len - len(ids))

    inputs_padded = torch.tensor([_right_pad(ids, input_max_len) for ids in input_ids])
    input_attention_ids = torch.tensor(
        [[1] * len(ids) + [0] * (input_max_len - len(ids)) for ids in input_ids])
    gold_stories_padded = torch.tensor(
        [_right_pad(ids, story_max_len) for ids in target_story_ids])

    batch_data = {"inputs": inputs_padded,                    # (batch_size, input_max_len)
                  "gold_stories": gold_stories_padded,        # (batch_size, story_max_len)
                  "input_attention_ids": input_attention_ids, # (batch_size, input_max_len)
                  "batch_max_lens": (input_max_len, story_max_len),
                  "story_ids": story_ids,
                  "text_inputs": storyline}                   # raw storyline strings

    return batch_data

In [None]:
class VSTDataLoader(Dataset):
    """
    Dataset pairing each story's SRL storyline (model input) with its
    ground-truth story text (target).

    Despite its name this is a ``torch.utils.data.Dataset``; wrap it in a
    ``DataLoader`` to obtain batches.
    """

    def __init__(self, storyinfo_path, storyline_path, tokenizer,
                 split = "train", cap_type = "CLIP", weights = "tgcn_cosine"):
        """
        Args:
            storyinfo_path: directory containing ``"<split>_stories.json"``.
            storyline_path: directory containing
                ``"<cap_type>_<weights>_srl.json"``.
            tokenizer: callable mapping a string to a mapping that has an
                ``"input_ids"`` entry.
            split: name of the split whose stories file is loaded.
            cap_type, weights: select which SRL file is loaded.
        """
        self.tokenizer = tokenizer

        # Ground truth stories, keyed by story id.
        stories_file = os.path.join(storyinfo_path, "{}_stories.json".format(split))
        with open(stories_file) as stories_fp:
            self.gt_stories = json.load(stories_fp)

        # The SRL file name does not depend on `split`; it is shared by all splits.
        srl_file = os.path.join(storyline_path, "{}_{}_srl.json".format(cap_type, weights))
        with open(srl_file) as srl_fp:
            self.srls = json.load(srl_fp)

        # Keep only stories of this split that also have a storyline.
        self.story_ids = [story_id for story_id in self.gt_stories if story_id in self.srls]

    def __len__(self):
        """Number of stories that have both a ground-truth text and a storyline."""
        return len(self.story_ids)

    def __getitem__(self, index):
        """
        Returns:
            ``(input_ids, target_story_ids, story_id, storyline)`` where
            ``storyline`` is the story's SRL entries joined with ``' </s> '``,
            ``input_ids`` is its tokenization, and ``target_story_ids`` is the
            tokenization of the ground-truth sentences joined with spaces.
        """
        story_id = self.story_ids[index]

        storyline = ' </s> '.join(self.srls[story_id])
        input_ids = self.tokenizer(storyline)['input_ids']

        joined_story = ' '.join(self.gt_stories[story_id]["story"])
        target_story_ids = self.tokenizer(joined_story)["input_ids"]

        return input_ids, target_story_ids, story_id, storyline

In [None]:
# Pretrained BART tokenizer shared by the datasets, the collate function and the models.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
print(f"Tokenizer size before adding tokens: {len(tokenizer)}")

Tokenizer size before adding tokens: 50265


In [None]:
%%time
storyinfo_path = "/content/drive"
srl_path = "/content/drive"

valid_dl = VSTDataLoader(storyinfo_path, srl_path, tokenizer, split = "valid")
train_dl = VSTDataLoader(storyinfo_path, srl_path, tokenizer, split = "train")
test_dl = VSTDataLoader(storyinfo_path, srl_path, tokenizer, split = "test")

CPU times: user 208 ms, sys: 79.6 ms, total: 287 ms
Wall time: 397 ms


In [None]:
### Spot-check a single test sample: token counts of input and target, and the story id.

index = 2414

print(f"input ids: {len(test_dl[index][0])}")
print(f"Length of story_ids: {len(test_dl[index][1])}")
print(f"Story id: {test_dl[index][2]}")

input ids: 108
Length of story_ids: 43
Story id: 47944


In [None]:
# Number of usable stories per split, one per line: valid, train, test.
print(len(valid_dl), len(train_dl), len(test_dl), sep="\n")

4988
40137
5055


In [None]:
# Only the training loader shuffles; validation and test keep dataset order.
valid_dataloader = DataLoader(valid_dl, batch_size=8, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dl, batch_size=16, shuffle=False, collate_fn=collate_fn)
train_dataloader = DataLoader(train_dl, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [None]:
%%time
### test collate function
for batch_idx, data in enumerate(valid_dataloader):
    # if batch_idx % 100 == 0: print(batch_idx)
    if batch_idx > 0: break
    batch_data = data

CPU times: user 28.2 ms, sys: 0 ns, total: 28.2 ms
Wall time: 27.7 ms


In [None]:
# Inspect the sample batch: per-example input shape, length of the first story id
# string, and the shapes of the attention mask and the padded targets.
print(f"input_ids: {batch_data['inputs'][0].shape}")
print(f"Story ids:  {len(batch_data['story_ids'][0])}")
print(f"input attention ids:  {batch_data['input_attention_ids'].shape}")
print(f"Gold Stories:  {batch_data['gold_stories'].shape}")

input_ids: torch.Size([122])
Story ids:  5
input attention ids:  torch.Size([8, 122])
Gold Stories:  torch.Size([8, 74])


# Bart Encoder Decoder

In [None]:
class StoryDecoder(nn.Module):
    """
    Event to Stories: wraps a BART conditional-generation model so it can be
    called directly on the batch dicts produced by ``collate_fn``.
    """

    def __init__(self, bart_model, tokenizer):
        """
        Args:
            bart_model: model exposing ``resize_token_embeddings`` and callable
                with ``input_ids``, ``attention_mask`` and ``decoder_input_ids``.
            tokenizer: tokenizer providing ``len()``, ``pad_token_id`` and
                ``eos_token_id``.
        """
        super(StoryDecoder, self).__init__()

        self.tokenizer = tokenizer
        self.bart_model = bart_model

        # Keep the embedding matrix the same size as the tokenizer vocabulary.
        self.bart_model.resize_token_embeddings(len(tokenizer))

    def forward(self, batch_data, device):
        """
        Args:
            batch_data: dict with "inputs", "input_attention_ids" and
                "gold_stories" tensors.
            device: device the tensors are moved to before the model call.

        Returns:
            First element of the model output, the LM logits
            (batch_size, max_story_len, vocab_size).
        """
        # Teacher forcing: decoder inputs are the targets shifted one step right,
        # starting with the EOS token. Uses the tokenizer this module was built
        # with rather than a notebook-level global.
        decoder_input_ids = shift_tokens_right(batch_data["gold_stories"],
                                               self.tokenizer.pad_token_id,
                                               self.tokenizer.eos_token_id)

        output = self.bart_model(input_ids = batch_data["inputs"].to(device),
                                 attention_mask = batch_data["input_attention_ids"].to(device),
                                 decoder_input_ids = decoder_input_ids.to(device))

        lm_logits = output[0] # (batch_size, max_story_len, vocab_size)

        return lm_logits

In [None]:
# Smoke test: one forward pass of StoryDecoder on the sample batch `batch_data`.
# The pretrained model is loaded here so the cell runs on a fresh kernel; it used to
# rely on a `bart_model` that is only created further down the notebook.
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
story_decoder = StoryDecoder(bart_model, tokenizer)

# The freshly loaded model is never moved off the CPU, so the batch is kept on the
# CPU as well; no gradients are needed for a shape check.
with torch.no_grad():
    lm_logits = story_decoder(batch_data, torch.device("cpu"))

# Plot2Story

In [None]:
class Plot2Story(nn.Module):
    """
    Event to Stories: loads a pretrained BART conditional-generation model and
    runs it on batch dicts.

    ``forward`` accepts both the space-separated batch keys this class was
    written for ("gold stories", "story lines", "input attention ids",
    "decoder attention ids") and the underscore keys produced by ``collate_fn``
    ("gold_stories", "inputs", "input_attention_ids").
    """

    def __init__(self, tokenizer, model_nm = "facebook/bart-large"):
        """
        Args:
            tokenizer: tokenizer providing ``len()``, ``pad_token_id`` and
                ``bos_token_id``.
            model_nm: name or path handed to
                ``BartForConditionalGeneration.from_pretrained``.
        """
        super(Plot2Story, self).__init__()

        self.tokenizer = tokenizer
        self.bart_model = BartForConditionalGeneration.from_pretrained(model_nm)
        # Keep the embedding matrix the same size as the tokenizer vocabulary.
        self.bart_model.resize_token_embeddings(len(tokenizer))

    @staticmethod
    def _lookup(batch_data, *keys, required = True):
        """Return the value of the first of ``keys`` present in ``batch_data``."""
        for key in keys:
            if key in batch_data:
                return batch_data[key]
        if required:
            raise KeyError("batch_data has none of the keys {}".format(keys))
        return None

    def forward(self, batch_data, device):
        """
        Args:
            batch_data: dict holding the target ids, the encoder input ids, the
                encoder attention mask and, optionally, a decoder attention mask.
            device: device the tensors are moved to before the model call.

        Returns:
            First element of the model output, the LM logits
            (batch_size, max_story_len, vocab_size).

        Raises:
            KeyError: if a required tensor is missing under both key spellings.
        """
        gold_stories = self._lookup(batch_data, "gold stories", "gold_stories")
        encoder_ids = self._lookup(batch_data, "story lines", "inputs")
        encoder_mask = self._lookup(batch_data, "input attention ids", "input_attention_ids")
        # collate_fn does not build a decoder mask, so this one is optional.
        decoder_mask = self._lookup(batch_data, "decoder attention ids",
                                    "decoder_attention_ids", required = False)

        # NOTE(review): the decoder is started with BOS here, whereas StoryDecoder
        # in this notebook starts it with EOS — confirm which one is intended.
        decoder_input_ids = shift_tokens_right(gold_stories,
                                               self.tokenizer.pad_token_id,
                                               self.tokenizer.bos_token_id)

        output = self.bart_model(input_ids = encoder_ids.to(device),
                                 attention_mask = encoder_mask.to(device),
                                 decoder_input_ids = decoder_input_ids.to(device),
                                 decoder_attention_mask = (decoder_mask.to(device)
                                                           if decoder_mask is not None else None))

        lm_logits = output[0] # (batch_size, max_story_len, vocab_size)

        return lm_logits

# Trainer

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, mode = 'train'):
    """
    Run one full pass over ``dataloader``, either training or evaluating.

    Args:
        model: module called as ``model(batch_data, device)`` (``device`` is the
            notebook-level device) that returns logits of shape
            (batch_size, max_sent_len, vocab_size).
        dataloader: either train or valid dataloader; each batch is a dict with
            a "gold_stories" tensor of target ids.
        criterion: loss function, applied to (N, vocab_size) logits and (N,)
            targets.
        optimizer: optimizer for training; only used when ``mode == "train"``.
        mode: "train" updates the parameters; any other value (e.g. "validate")
            puts the model in eval mode and disables gradient tracking.

    Returns:
        ``(epoch_loss, seconds)``: the mean of the per-batch losses and the
        wall-clock duration of the pass in seconds.

    Raises:
        ValueError: if ``dataloader`` contains no batches.
    """
    start_time = datetime.datetime.now()
    is_training = mode == 'train'

    num_batches = len(dataloader)  # number of batches, not examples
    if num_batches == 0:
        raise ValueError("dataloader is empty; cannot compute an epoch loss")

    if is_training:
        model.train()
    else: # put model in validation mode
        model.eval()

    running_loss = 0.0

    # mini-batch pass over the dataloader
    for batch_idx, batch_data in enumerate(dataloader):

        # gradients are only tracked when training
        with torch.set_grad_enabled(is_training):

            # forward data through model to get logits
            lm_logits = model(batch_data, device) # (batch_size, max_sent_len, vocab_size)

            # Compute the loss where the logits already live: moving the small
            # target tensor is far cheaper than copying the full
            # (batch_size * max_sent_len, vocab_size) logits to the CPU each batch.
            targets = batch_data['gold_stories'].view(-1).to(lm_logits.device)
            loss = criterion(lm_logits.view(-1, lm_logits.shape[-1]), targets)

            if is_training:
                loss.backward()       # backward the loss and calculate gradients for parameters.
                optimizer.step()      # update the parameters.
                optimizer.zero_grad() # zero the gradient to stop from accumulating

        batch_loss = loss.item()
        if (batch_idx + 1) % 10 == 0:
            print("Processed batch: {}. Loss: {}".format(batch_idx+1, batch_loss))

        running_loss += batch_loss

    epoch_loss = running_loss/num_batches
    end_time = datetime.datetime.now()

    return epoch_loss, (end_time-start_time).total_seconds()

def train_model(model, training_info, start_epoch = 0):
    """
    Train ``model`` with per-epoch validation, best-model checkpointing and
    early stopping.

    Args:
        model: initialised pytorch model (class)
        training_info: dict with the keys "num_epochs", "criterion",
            "optimizer", "scheduler", "num_es_epochs" (number of validations
            without improvement before stopping), "train_loader",
            "valid_loader" and "save_path" (checkpoint directory).
        start_epoch: starting epoch. Will be > 0 if loaded model checkpoint.

    Side effects:
        Whenever the validation loss improves, writes
        ``srl2story_CLIP_tgcn_cosine_epoch<N>.pth.tar`` into ``save_path``
        holding the epoch number, the model state dict and both losses.
    """
    EVAL_EVERY_EPOCH = 1

    min_loss = float('inf')
    epochs_without_improvement = 0

    scheduler = training_info["scheduler"]
    save_path = training_info["save_path"]

    # Create the checkpoint directory up front so that a missing folder cannot
    # make torch.save fail after a full epoch of training.
    os.makedirs(save_path, exist_ok = True)

    for epoch in range(start_epoch, training_info["num_epochs"]):

        # forward training data through model
        train_loss, runtime = train_epoch(model, training_info["train_loader"],
                                          training_info["criterion"],
                                          training_info["optimizer"],
                                          mode = 'train')

        print("Epoch:%d, train loss: %.4f, time: %.2fs" %(epoch+1, train_loss, runtime))

        if (epoch + 1) % EVAL_EVERY_EPOCH == 0:
            valid_loss, runtime = train_epoch(model, training_info["valid_loader"],
                                              training_info["criterion"],
                                              training_info["optimizer"],
                                              mode = 'validate')

            print('-'*60)
            print("Epoch:%d, valid loss: %.4f, time: %.2fs" %(epoch+1, valid_loss, runtime))
            print('-'*60)

            # check early stopping conditions
            if valid_loss < min_loss:
                min_loss = valid_loss
                epochs_without_improvement = 0

                # save the best model so far
                state = {"epoch": epoch + 1, "model": model.state_dict(), "valid_loss": valid_loss,
                         "train_loss": train_loss}
                model_name = "srl2story_CLIP_tgcn_cosine_epoch{}.pth.tar".format(epoch + 1)
                torch.save(state, os.path.join(save_path, model_name))
            else:
                epochs_without_improvement += 1

            if epochs_without_improvement >= training_info["num_es_epochs"]:
                break

        # Apply learning rate decay once per epoch, independent of how often
        # validation runs.
        scheduler.step()

In [None]:
# Load the pretrained BART checkpoint that is wrapped by StoryDecoder and fine-tuned below.
model_nm = "facebook/bart-large"
bart_model = BartForConditionalGeneration.from_pretrained(model_nm)

Downloading:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

In [None]:
# Wrap the pretrained BART model and place it on the selected device (GPU when available).
model = StoryDecoder(bart_model, tokenizer).to(device)

In [None]:
# Training hyper-parameters
learning_rate = 0.00002
weight_decay = 0.00001
lr_decay = 0.95    # multiplicative learning-rate decay applied by the scheduler
num_epochs = 20
early_stop = 3     # validations without improvement before stopping
save_path = "/content/drive/Model"

# Set to a checkpoint path to continue training, or leave False to start fresh.
#resume = "/content/drive/Model/srl2story_CLIP_cap_epoch2.pth.tar"
resume = False

# Padding positions in the targets do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index = tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr = learning_rate,
                        weight_decay = weight_decay)

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = lr_decay)

training_info  = {"num_epochs": num_epochs,
                "criterion": criterion,
                "optimizer": optimizer,
                "scheduler": scheduler,
                "num_es_epochs": early_stop,
                "train_loader": train_dataloader,
                "valid_loader": valid_dataloader,
                "save_path": save_path}

start_epoch = 0
if resume:
    ### start training model from loaded check point
    # Fail loudly on a wrong path: previously a missing checkpoint file meant
    # that no training was started at all, without any message.
    if not os.path.isfile(resume):
        raise FileNotFoundError("Checkpoint to resume from not found: {}".format(resume))

    checkpoint = torch.load(resume, map_location = device)
    model.load_state_dict(checkpoint['model'])
    start_epoch = checkpoint["epoch"]

    print("Loaded model checkpoint! Starting at epoch: {}".format(start_epoch+1))

### train model (from scratch when start_epoch is 0)
train_model(model, training_info, start_epoch = start_epoch)

Processed batch: 10. Loss: 3.499938488006592
Processed batch: 20. Loss: 3.3829314708709717
Processed batch: 30. Loss: 3.360431671142578
Processed batch: 40. Loss: 3.1429808139801025
Processed batch: 50. Loss: 3.55824613571167
Processed batch: 60. Loss: 3.2122206687927246
Processed batch: 70. Loss: 3.5974764823913574
Processed batch: 80. Loss: 3.423002243041992
Processed batch: 90. Loss: 3.4307570457458496
Processed batch: 100. Loss: 3.218897819519043
Processed batch: 110. Loss: 3.5450572967529297
Processed batch: 120. Loss: 2.9842424392700195
Processed batch: 130. Loss: 3.2806568145751953
Processed batch: 140. Loss: 3.0795178413391113
Processed batch: 150. Loss: 3.1647517681121826
Processed batch: 160. Loss: 3.138158082962036
Processed batch: 170. Loss: 3.3837642669677734
Processed batch: 180. Loss: 2.8943586349487305
Processed batch: 190. Loss: 3.209237575531006
Processed batch: 200. Loss: 3.146991729736328
Processed batch: 210. Loss: 3.5114927291870117
Processed batch: 220. Loss: 3.1

# Save Generated Stories

In [None]:
# Rebuild the model architecture, then restore the fine-tuned weights used for generation.
model_nm = "facebook/bart-large"
bart_model = BartForConditionalGeneration.from_pretrained(model_nm)

model = StoryDecoder(bart_model, tokenizer)
model = model.to(device) # move to GPU

resume = "/content/VIST Model/Stage 2 Model/srl2story_CLIP_pmi_epoch4.pth.tar"

# Stop here when the checkpoint is unavailable: continuing would silently generate
# stories with weights that were never fine-tuned. (Previously a wrong path
# produced no message at all, because the `else` was attached to `if resume`.)
if not resume or not os.path.isfile(resume):
    raise FileNotFoundError("Could not find model checkpoint! :( ({})".format(resume))

### load model
checkpoint = torch.load(resume, map_location = device)
model.load_state_dict(checkpoint['model'])
start_epoch = checkpoint["epoch"]

print("Loaded model checkpoint! Model obtained from epoch: {}".format(start_epoch))

In [None]:
# Unshuffled loader over the test split, used by the generation loop.
dataloader = DataLoader(test_dl, batch_size=16, shuffle=False, collate_fn=collate_fn)
print(f"Number of batches: {len(dataloader)}")

Number of batches: 316


In [None]:
### Generate one story per test storyline with the loaded model.
model.eval()
decoding_method = "greedy"
gen_stories = {}   # story_id -> generated story text

# Extra generate() arguments for each supported decoding strategy.
decoding_kwargs = {
    "beam": {"num_beams": 3},
    "nucleus": {"top_p": 0.9, "temperature": 0.9, "top_k": 0, "do_sample": True},
    # top_k / temperature only take effect when sampling is enabled; without
    # do_sample=True this strategy did not sample at all.
    "top_k": {"top_k": 50, "temperature": 1.5, "do_sample": True},
    # NOTE(review): no decoding arguments are passed here, so generate() uses the
    # model's own generation defaults — confirm those really are greedy.
    "greedy": {},
}

if decoding_method not in decoding_kwargs:
    raise ValueError("Unknown decoding_method {!r}; expected one of {}".format(
        decoding_method, sorted(decoding_kwargs)))

for batch_idx, batch_data in enumerate(dataloader):

    if (batch_idx + 1) % 50 == 0:
        print("Processing batch {}".format(batch_idx + 1))

    # Inference only: keep the whole generation call free of gradient tracking.
    with torch.no_grad():
        input_ids = batch_data["inputs"].to(device)
        attention_mask = batch_data["input_attention_ids"].to(device)

        outputs = model.bart_model.generate(input_ids = input_ids,
                                            attention_mask = attention_mask,
                                            max_length = 200,
                                            **decoding_kwargs[decoding_method])

    generated_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # generate() returns one sequence per input row, in batch order.
    for story_id, story_text in zip(batch_data["story_ids"], generated_outputs):
        gen_stories[story_id] = story_text

Processing batch 50
Processing batch 100
Processing batch 150
Processing batch 200
Processing batch 250
Processing batch 300


In [None]:
# Save the {story_id: generated story} mapping to a JSON file in the working directory.
# The decoding strategy is part of the file name so that results produced with
# different strategies cannot be mislabelled or overwrite each other
# ("greedy" still yields 'srl_pmi_greedy_stories.json').
with open('srl_pmi_{}_stories.json'.format(decoding_method), 'w') as fp:
    json.dump(gen_stories, fp)