In [None]:
!pip install scenedetect
!pip install datasets
!pip install POT

Collecting scenedetect
  Downloading scenedetect-0.6.2-py3-none-any.whl (117 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/117.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.1/117.1 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: scenedetect
Successfully installed scenedetect-0.6.2
Collecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB

In [None]:
""" utility functions"""
import re
import os
from os.path import basename

import gensim
import torch
from torch import nn
import json
from statistics import median

# Reserved vocabulary ids for the special tokens. make_vocab() numbers regular
# words starting at 4, i.e. directly after these four ids.
PAD = 0
UNK = 1
START = 2
END = 3
# Surface strings of the special tokens; make_vocab() maps each one to the
# id constant of the same name above.
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
START_TOKEN = '<start>'
END_TOKEN = '<end>'

import torch
import numpy as np
from torch.autograd import Variable
from collections import defaultdict, Counter, OrderedDict
from os.path import join
from itertools import chain


class OrderedCounter(Counter, OrderedDict):
    """A ``Counter`` that also inherits ``OrderedDict``.

    Keys are reported in the order in which they were first counted.
    """

    def __repr__(self):
        # Show the class name followed by the counts as an OrderedDict.
        class_name = self.__class__.__name__
        ordered_counts = OrderedDict(self)
        return '%s(%r)' % (class_name, ordered_counts)

    def __reduce__(self):
        # Pickle support: rebuild the instance from an OrderedDict snapshot.
        constructor_args = (OrderedDict(self),)
        return self.__class__, constructor_args


def to_var(x):
    """Return ``x`` moved to the GPU when CUDA is available, otherwise ``x`` itself."""
    return x.cuda() if torch.cuda.is_available() else x


def idx2word(idx, i2w, pad_idx):
    """Turn rows of token ids into space-separated sentences.

    Args:
        idx: iterable of rows; every element of a row must support ``.item()``
            (e.g. a 2-D tensor of ids).
        i2w: mapping from the *string* form of an id to its word.
        pad_idx: id that terminates a row; it and everything after it is dropped.

    Returns:
        A list with one stripped string per row.
    """
    sentences = []
    for row in idx:
        words = []
        for word_id in row:
            if word_id == pad_idx:
                break
            # i2w is keyed by str(id), as it would be after a JSON round trip.
            words.append(i2w[str(word_id.item())])
        sentences.append(" ".join(words).strip())
    return sentences


def interpolate(start, end, steps):
    """Linearly interpolate between two vectors, one dimension at a time.

    Args:
        start: 1-D array holding the first point.
        end: 1-D array holding the last point.
        steps: number of intermediate points between ``start`` and ``end``.

    Returns:
        Array of shape ``(steps + 2, start.shape[0])``; row 0 is ``start`` and
        the last row is ``end``.
    """
    n_points = steps + 2
    table = np.zeros((start.shape[0], n_points))

    row = 0
    for lo, hi in zip(start, end):
        table[row] = np.linspace(lo, hi, n_points)
        row += 1

    # One row per interpolation point rather than one per dimension.
    return table.T


def expierment_name(args, ts):
    """Build the experiment identifier string from the run's hyper-parameters.

    Args:
        args: object exposing ``batch_size``, ``learning_rate``,
            ``embedding_size``, ``rnn_type``, ``hidden_size``, ``num_layers``,
            ``bidirectional``, ``latent_size``, ``word_dropout``,
            ``anneal_function``, ``k`` and ``x0``.
        ts: timestamp (or any value) appended as the final ``TS=`` field.

    Returns:
        The fields concatenated in a fixed order, each terminated by ``_``
        except the last.
    """
    fields = [
        "BS=%i_" % args.batch_size,
        "LR={}_".format(args.learning_rate),
        "EB=%i_" % args.embedding_size,
        "%s_" % args.rnn_type.upper(),
        "HS=%i_" % args.hidden_size,
        "L=%i_" % args.num_layers,
        "BI=%i_" % args.bidirectional,
        "LS=%i_" % args.latent_size,
        "WD={}_".format(args.word_dropout),
        "ANN=%s_" % args.anneal_function.upper(),
        "K={}_".format(args.k),
        "X0=%i_" % args.x0,
        "TS=%s" % ts,
    ]
    return "".join(fields)

def count_data(path):
    """Count the data files named ``<digits>.json`` directly inside ``path``.

    The whole file name has to match, so leftovers such as ``12.json.bak`` or
    ``3.jsonl`` are not counted (a prefix-only match would include them and
    inflate the count that callers use as ``range(count_data(path))``).

    Args:
        path: directory to scan (not recursive).

    Returns:
        Number of matching file names.
    """
    matcher = re.compile(r'[0-9]+\.json')
    return sum(1 for name in os.listdir(path) if matcher.fullmatch(name))

def make_vocab(wc, vocab_size):
    """Build the word -> id table from a word counter.

    Args:
        wc: object with a ``most_common(n)`` method yielding ``(word, count)``
            pairs (e.g. ``collections.Counter``).
        vocab_size: number of most frequent words to keep.

    Returns:
        Dict holding the four special tokens (ids ``PAD``, ``UNK``, ``START``,
        ``END``) followed by the kept words numbered from 4 in frequency order.
    """
    word2id = {
        PAD_TOKEN: PAD,
        UNK_TOKEN: UNK,
        START_TOKEN: START,
        END_TOKEN: END,
    }
    next_id = 4  # first id after the special tokens
    for word, _count in wc.most_common(vocab_size):
        word2id[word] = next_id
        next_id += 1
    return word2id

def convert_word2id(w, word2id):
    """Look up the id of ``w``, falling back to ``UNK``.

    ``UNK`` is returned when ``w`` is not in ``word2id`` and also when its id
    is 30000 or larger (ids past that limit are treated as out of vocabulary).

    Args:
        w: the word to convert.
        word2id: mapping from word to integer id.

    Returns:
        The id of ``w`` if it is known and below 30000, otherwise ``UNK``.
    """
    try:
        wid = word2id[w]
        return wid if wid < 30000 else UNK
    except (KeyError, TypeError):
        # Only lookup/comparison failures mean "unknown word". A bare
        # ``except`` here would also swallow KeyboardInterrupt and SystemExit.
        return UNK

def make_embedding(id2word, w2v_file, initializer=None):
    """Create an embedding matrix pre-filled from a saved gensim Word2Vec model.

    Args:
        id2word: list or dict indexed by the integer ids ``0 .. len(id2word)-1``.
        w2v_file: path of the saved model; the embedding size is parsed from
            its file name, which is expected to look like
            ``word2vec.{dim}d.{vsize}k.bin``.
        initializer: optional callable applied to the fresh weight before any
            pre-trained vector is copied in.

    Returns:
        ``(embedding, oovs)``: the weight of a new ``nn.Embedding`` and the
        list of ids for which no pre-trained vector was found.
    """
    # File name convention: word2vec.{dim}d.{vsize}k.bin
    name_parts = basename(w2v_file).split('.')
    w2v = gensim.models.Word2Vec.load(w2v_file).wv
    n_words = len(id2word)
    emb_dim = int(name_parts[-3][:-1])
    embedding = nn.Embedding(n_words, emb_dim).weight
    if initializer is not None:
        initializer(embedding)

    # START/END are stored under different keys in the word2vec vocabulary.
    special_keys = {START: '<s>', END: r'<\s>'}
    oovs = []
    with torch.no_grad():
        for wid in range(n_words):
            # NOTE: id2word can be list or dict
            if wid in special_keys:
                key = special_keys[wid]
            else:
                key = id2word[wid]
                if key not in w2v:
                    oovs.append(wid)
                    continue
            embedding[wid, :] = torch.Tensor(w2v[key])
    return embedding, oovs

def count_data_stat(path):
    """Print and return size statistics for the train/val/test splits.

    For each split, every ``<i>.json`` with ``i`` in ``range(count_data(...))``
    is read and its ``'article'`` entry collected. Two quantities are measured:
    ``len`` of each article (number of sentences) and ``len`` of each sentence
    inside the articles.

    Args:
        path: directory containing the ``train``, ``val`` and ``test`` folders.

    Returns:
        Tuple ``(max article size, max sentence size, largest per-split median
        article size, largest per-split median sentence size)``, each taken
        over the three splits.
    """
    article_maxima, sentence_maxima = [], []
    article_medians, sentence_medians = [], []

    for split in ('train', 'val', 'test'):
        split_dir = join(path, split)
        n_data = count_data(split_dir)

        articles = []
        for idx in range(n_data):
            with open(join(split_dir, '{}.json'.format(idx))) as f:
                record = json.loads(f.read())
                articles.append(record['article'])

        # number of sentences in an article
        article_size = [len(article) for article in articles]
        # length of every sentence across all articles
        sentence_size = [len(sentence) for article in articles for sentence in article]

        max_article_size = max(article_size)
        max_sentence_size = max(sentence_size)
        median_article_size = median(article_size)
        median_sentence_size = median(sentence_size)

        article_maxima.append(max_article_size)
        sentence_maxima.append(max_sentence_size)
        article_medians.append(median_article_size)
        sentence_medians.append(median_sentence_size)

        print('######## Statistics for', split,'split: ######')
        print('Number of data:', n_data)
        print('Max number of sentences in an article:', max_article_size)
        print('Median number of sentences in an article:', median_article_size)
        print('Max number of words in a sentence:', max_sentence_size)
        print('Median number of words in a sentence:', median_sentence_size)

    return (
        max(article_maxima),
        max(sentence_maxima),
        max(article_medians),
        max(sentence_medians),
    )

#MAX_ARTICLE_SIZE, MAX_SENTENCE_SIZE, MEAN_ARTICLE_SIZE, MEAN_SENTENCE_SIZE = count_data_stat(DATA_DIR)

In [None]:
from transformers import GPT2Tokenizer as GPT2Tok
from transformers import BertTokenizer as BertTok
#import sentencepiece as spm
import nltk

class Capita:
    """Reversible (up to mixed case) encoding of capitalisation as marker tokens.

    ``forward`` lower-cases every space-separated word and puts a marker in
    front of alphabetic words that were not already lower case: ``⇧`` for
    all-caps words, ``↑`` for everything else. ``backward`` consumes the
    markers and re-applies the casing to the word that follows.
    """

    def forward(self, text):
        # words = nltk.tokenize.word_tokenize(text)
        encoded = []
        for word in text.split(" "):
            # Only purely alphabetic words that are not lower case get a marker.
            if word.isalpha() and not word.islower():
                encoded.append("⇧" if word.isupper() else "↑")
            encoded.append(word.lower())
        return " ".join(encoded)

    def backward(self, text):
        restored = []
        shout = False   # a "⇧" marker is pending
        title = False   # a "↑" marker is pending
        for token in text.split(" "):
            if token == "⇧":
                shout = True
                continue
            if token == "↑":
                title = True
                continue
            # The all-caps marker wins when both markers are pending.
            if shout:
                token = token.upper()
            elif title:
                token = token[:1].upper() + token[1:]
            restored.append(token)
            shout = title = False
        return " ".join(restored)
"""
class BPETokenizer:
    def __init__(self, bpe_model, use_capita=True):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(bpe_model)
        self.use_capita = use_capita

        self.pad_tok, self.start_tok, self.end_tok = "<pad>", "<start>", "<end>"
        self.pad_id, self.start_id, self.end_id = tuple(self.sp.piece_to_id(p) for p in [self.pad_tok, self.start_tok, self.end_tok])

        self.vocab_size = self.sp.get_piece_size()

        if self.use_capita:
            self.cpt = Capita()

    def tokenize(self, text):
        if len(text) == 0:
            return []
        if text[:len(self.start_tok)] == self.start_tok and text[len(self.start_tok)] != " ":
            text = text.replace(self.start_tok, self.start_tok+" ")

        if self.use_capita:
            text = self.cpt.forward(text)
        tokens = self.sp.encode_as_pieces(text)
        tokens = [w for i, w in enumerate(tokens) if (i < (len(tokens)-1) and tokens[i+1] not in ["⇧", "↑"]) or i==(len(tokens)-1)]
        if tokens[0] == "▁":
            tokens = tokens[1:]
        return tokens

    def encode(self, text):
        tokens = self.tokenize(text)
        token_ids = [self.sp.piece_to_id(w) for w in tokens]
        return token_ids

    def decode(self, token_ids):
        text = self.sp.decode_ids(token_ids).replace("⇧", " ⇧").replace("↑", " ↑")
        if self.use_capita:
            text = self.cpt.backward(text)
        text = text.replace(self.start_tok+" ", self.start_tok)
        return text
"""
class BERTCacheTokenizer:
    """``bert-base-uncased`` tokenizer with a small first-in-first-out cache.

    ``cache`` maps a text to its encoding and ``cache_keys`` remembers the
    insertion order so that the oldest entry can be dropped.
    """

    def __init__(self):
        self.cache = {}
        self.cache_keys = []
        self.tokenizer = BertTok.from_pretrained("bert-base-uncased")
        # self.tokenizer.max_len = 10000 # This was removed in later transformer tokenizers

    def encode(self, text):
        """Return the encoding of ``text``, reusing a cached result when present."""
        try:
            return self.cache[text]
        except KeyError:
            pass

        token_ids = self.tokenizer.encode(text)

        # Once the cache holds more than 1000 entries, evict the oldest one
        # before storing the new result.
        if len(self.cache) > 1000:
            oldest = self.cache_keys.pop(0)
            del self.cache[oldest]
        self.cache_keys.append(text)
        self.cache[text] = token_ids
        return token_ids

class GPT2Tokenizer:
    """Thin wrapper around the pretrained ``gpt2`` tokenizer.

    Besides delegating ``tokenize``/``encode``/``decode``, it exposes the
    pad/start/end marker strings and their ids plus the vocabulary size.
    """

    def __init__(self):
        backend = GPT2Tok.from_pretrained("gpt2")
        self.tokenizer = backend
        # self.tokenizer.max_len = 10000

        self.pad_tok = "<PAD>"
        self.start_tok = " ST"
        self.end_tok = " END"

        # The pad id is fixed; start/end use the first id of their marker's encoding.
        self.pad_id = 0
        self.start_id = backend.encode(self.start_tok)[0]
        self.end_id = backend.encode(self.end_tok)[0]
        self.vocab_size = backend.vocab_size

    def tokenize(self, text):
        """Split ``text`` into the wrapped tokenizer's token strings."""
        tokens = self.tokenizer.tokenize(text)
        return tokens

    def encode(self, text):
        """Convert ``text`` into token ids with the wrapped tokenizer."""
        token_ids = self.tokenizer.encode(text)
        return token_ids

    def decode(self, token_ids):
        """Convert ``token_ids`` back into text with the wrapped tokenizer."""
        text = self.tokenizer.decode(token_ids)
        return text


In [None]:
import torch.utils.data
from os.path import join, exists
import re, json, cv2, os, sys, glob
import xml.etree.ElementTree as ET
#import torchvision
import random, itertools
#from torchvision import transforms as t
#from torchvision import transforms
from PIL import Image
#import torchvision.models as models
import numpy as np

  # Standard PySceneDetect imports:
from scenedetect.video_manager import VideoManager
from scenedetect.scene_manager import SceneManager

# For content-aware scene detection:
from scenedetect.detectors.content_detector import ContentDetector



class MultimodalDataset(torch.utils.data.dataset.Dataset):
    """Loads one multimodal sample (metadata, thumbnail, transcript, frames).

    All files of a sample live in ``join(path, split)`` and share a stem:
    ``<stem>.json``, ``<stem>.mp4``, ``<stem>.png`` and ``<stem> (a.en).xml``.
    """

    def __init__(self, split: str, path: str):
        self._data_path = join(path, split)
        self._n_data = _count_data(self._data_path)

    def __len__(self) -> int:
        return self._n_data

    def __getitem__(self, i: str):
        """Load the sample whose files are named after ``i``.

        Args:
            i: file stem of the sample (formatted into the file names).

        Returns:
            ``(js, thumbnail, transcript, video)`` where ``js`` is the parsed
            JSON, ``thumbnail`` the result of ``cv2.imread``, ``transcript``
            the XML root element (or ``''`` when the transcript file is
            missing or malformed) and ``video`` the stacked frame tensor (an
            empty tensor when no frame was collected).
        """
        with open(join(self._data_path, '{}.json'.format(i))) as f:
            js = json.load(f)

        original_frames = []
        vidcap = cv2.VideoCapture(join(self._data_path, '{}.mp4'.format(i)))
        try:
            success, image = vidcap.read()
            count = 0
            while success:
                # NOTE(review): ``count % 480`` is truthy for every count that
                # is NOT a multiple of 480, so together with the cap below this
                # keeps frames 1..100 and skips frame 0. If "every 480th frame"
                # was intended the test should be ``== 0``; left unchanged
                # because it alters the shape of the returned video tensor.
                if count % 480:
                    original_frames.append(torch.tensor(cv2.resize(image, (640, 360))))
                success, image = vidcap.read()
                count += 1
                if count > 100: #reach cuda limit
                    break
        finally:
            # Free the decoder/file handle even if resizing or reading fails.
            vidcap.release()

        # Stack it into a tensor
        if original_frames:
            video = torch.stack(original_frames, 0)
        else:
            video = torch.tensor([])

        thumbnail = cv2.imread(join(self._data_path, '{}.png'.format(i)))

        try:
            transcript = ET.parse(join(self._data_path, '{} (a.en).xml'.format(i))).getroot()
        except (OSError, ET.ParseError):
            # Missing/unreadable or malformed transcript: keep the sample and
            # signal the absence with an empty string.
            transcript = ''

        return js, thumbnail, transcript, video

class EXMSMODataset(MultimodalDataset):
    """Index-addressable view over ``MultimodalDataset``.

    The stems of all ``*.json`` files in the split directory are collected in
    ``file_id``; integer positions are translated to stems before delegating
    to the parent class.
    """

    def __init__(self, split, DATA_DIR):
        super().__init__(split, DATA_DIR)
        json_paths = glob.glob(join(self._data_path, "*.json"))
        self.file_id = [
            os.path.split(json_path)[1].replace('.json', '')
            for json_path in json_paths
        ]

    def __getitem__(self, i):
        """Return ``(file id, description, video, title, thumbnail, transcripts)``."""
        sample_id = self.file_id[i]
        js, thumbnail, transcript_xml, video = super().__getitem__(sample_id)

        if transcript_xml == '':
            # The parent signals a missing transcript with an empty string.
            transcripts = ''
        else:
            # Join the text of every child element of the transcript root.
            pieces = [w.text for w in transcript_xml]
            transcripts = '; '.join(pieces).replace('&#39;', '\'')

        title = js['title']
        description = js['description']

        return sample_id, description, video, title, thumbnail, transcripts



class MultimodalNoTruncateDataset(torch.utils.data.dataset.Dataset):
    """Loads sample metadata plus video statistics instead of decoded frames.

    All files of a sample live in ``join(path, split)`` and share a stem:
    ``<stem>.json``, ``<stem>.mp4``, ``<stem>.png`` and ``<stem> (a.en).xml``.
    """

    def __init__(self, split: str, path: str):
        self._data_path = join(path, split)
        self._n_data = _count_data(self._data_path)

    def __len__(self) -> int:
        return self._n_data

    def __getitem__(self, i: str):
        """Load the sample whose files are named after ``i``.

        Args:
            i: file stem of the sample (formatted into the file names).

        Returns:
            ``(js, thumbnail, transcript, vframe, vfps, vs)``: parsed JSON,
            thumbnail image, transcript XML root, reported frame count,
            reported frames per second and ``vframe / vfps``.

        Raises:
            ZeroDivisionError: if the capture reports a frame rate of 0.
            Exceptions from ``ET.parse`` if the transcript is missing/invalid.
        """
        with open(join(self._data_path, '{}.json'.format(i))) as f:
            js = json.load(f)

        vidcap = cv2.VideoCapture(join(self._data_path, '{}.mp4'.format(i)))
        try:
            vframe = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
            vfps = vidcap.get(cv2.CAP_PROP_FPS)
        finally:
            # Only properties are queried; free the capture right away.
            vidcap.release()
        vs = vframe / vfps

        thumbnail = cv2.imread(join(self._data_path, '{}.png'.format(i)))

        transcript = ET.parse(join(self._data_path, '{} (a.en).xml'.format(i))).getroot()

        return js, thumbnail, transcript, vframe, vfps, vs


class EXMSMONoTruncateDataset(MultimodalNoTruncateDataset):
    """Index-addressable view over ``MultimodalNoTruncateDataset``.

    The stems of all ``*.json`` files in the split directory are collected in
    ``file_id``; integer positions are translated to stems before delegating
    to the parent class.
    """

    def __init__(self, split, DATA_DIR):
        super().__init__(split, DATA_DIR)
        json_paths = glob.glob(join(self._data_path, "*.json"))
        self.file_id = [
            os.path.split(json_path)[1].replace('.json', '')
            for json_path in json_paths
        ]

    def __getitem__(self, i):
        """Return ``(file id, description, vframe, title, vfps, vs, thumbnail, transcripts)``."""
        sample_id = self.file_id[i]
        js, thumbnail, transcript_xml, vframe, vfps, vs = super().__getitem__(sample_id)

        # Join the text of every child element of the transcript root.
        pieces = [w.text for w in transcript_xml]
        transcripts = '; '.join(pieces).replace('&#39;', '\'')

        title = js['title']
        description = js['description']

        return sample_id, description, vframe, title, vfps, vs, thumbnail, transcripts



class MultimodalWithSceneDataset(torch.utils.data.dataset.Dataset):
    """Loads sample metadata, thumbnail, video frames and scene boundaries.

    All files of a sample live in ``join(path, split)`` and share a stem:
    ``<stem>.json``, ``<stem>.mp4`` and ``<stem>.png``. The transcript is not
    read by this variant.
    """

    def __init__(self, split: str, path: str):
        self._data_path = join(path, split)
        self._n_data = _count_data(self._data_path)

    def __len__(self) -> int:
        return self._n_data

    def __getitem__(self, i: str):
        """Load the sample whose files are named after ``i``.

        Args:
            i: file stem of the sample (formatted into the file names).

        Returns:
            ``(js, thumbnail, video, scene_list)``: parsed JSON, thumbnail
            image, stacked frame tensor (an empty tensor when no frame was
            collected) and the tensor returned by ``find_scenes``.
        """
        with open(join(self._data_path, '{}.json'.format(i))) as f:
            js = json.load(f)

        video_path = join(self._data_path, '{}.mp4'.format(i))

        original_frames = []
        vidcap = cv2.VideoCapture(video_path)
        try:
            success, image = vidcap.read()
            count = 0
            while success:
                # NOTE(review): ``count % 360`` is truthy for every count that
                # is NOT a multiple of 360, so together with the cap below this
                # keeps frames 1..100 and skips frame 0. If "every 360th frame"
                # was intended the test should be ``== 0``; left unchanged
                # because it alters the shape of the returned video tensor.
                if count % 360:
                    original_frames.append(torch.tensor(cv2.resize(image, (640, 360))))
                success, image = vidcap.read()
                count += 1
                if count > 100: #reach cuda limit
                    break
        finally:
            # Free the decoder/file handle even if resizing or reading fails.
            vidcap.release()

        # Stack it into a tensor; torch.stack rejects an empty list, so fall
        # back to an empty tensor exactly as MultimodalDataset does.
        if original_frames:
            video = torch.stack(original_frames, 0)
        else:
            video = torch.tensor([])

        scene_list = find_scenes(video_path)
        thumbnail = cv2.imread(join(self._data_path, '{}.png'.format(i)))

        return js, thumbnail, video, scene_list

class EXMSMOWithSceneDataset(MultimodalWithSceneDataset):
    """Index-addressable view over ``MultimodalWithSceneDataset``.

    The stems of all ``*.json`` files in the split directory are collected in
    ``file_id``; integer positions are translated to stems before delegating
    to the parent class. Transcripts are not part of the returned sample.
    """

    def __init__(self, split, DATA_DIR):
        super().__init__(split, DATA_DIR)
        json_paths = glob.glob(join(self._data_path, "*.json"))
        self.file_id = [
            os.path.split(json_path)[1].replace('.json', '')
            for json_path in json_paths
        ]

    def __getitem__(self, i):
        """Return ``(file id, description, video, title, thumbnail, scene_list)``."""
        sample_id = self.file_id[i]
        js, thumbnail, video, scene_list = super().__getitem__(sample_id)

        title = js['title']
        description = js['description']

        return sample_id, description, video, title, thumbnail, scene_list

class MSMO(MultimodalDataset):
    """MultimodalDataset variant that additionally reads article text files.

    NOTE(review): this class looks unfinished; see the notes in __getitem__
    before relying on it.
    """
    def __init__(self, split, DATA_DIR):
        super().__init__(split, DATA_DIR)
        # 'article' is concatenated onto the split path without a separator, so
        # the text files are searched in "<path>/<split>article/*.txt".
        # NOTE(review): __len__ (inherited) counts the *.json files of the split
        # directory, while file_id is built from these *.txt files -- confirm
        # the two always have the same size.
        files = glob.glob(join(self._data_path + 'article', "*.txt"))
        self.file_id = [os.path.split(x)[1].replace('.txt', '') for x in files]

    def __getitem__(self, i):
        """Return (file id, description, video, title, thumbnail, scene_list).

        NOTE(review): ``i`` is used in two incompatible ways below -- formatted
        directly into file names (as a file stem) and used as an index into
        ``self.file_id``. Confirm which one callers pass.
        """

        article_path = join(self._data_path + 'article' , '{}.txt'.format(i))
        # NOTE(review): document and extreme_summaries are never used afterwards.
        document, extreme_summaries = get_art_abs(article_path)

        # NOTE(review): original_frames, vframe, vfps and vs are computed but not
        # used or returned; the capture is also never released.
        original_frames = []
        vidcap = cv2.VideoCapture(join(self._data_path, '{}.mp4'.format(i)))
        vframe = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        vfps    = vidcap.get(cv2.CAP_PROP_FPS)
        vs    = vframe/vfps


        # This thumbnail is overwritten by the super().__getitem__ call below.
        thumbnail = cv2.imread(join(self._data_path, '{}.png'.format(i)))

        # NOTE(review): parsed but never used afterwards.
        transcript = ET.parse(join(self._data_path, '{} (a.en).xml'.format(i))).getroot()



        # NOTE(review): MultimodalDataset.__getitem__ returns
        # (js, thumbnail, transcript, video), so with this unpacking `video`
        # holds the transcript and `scene_list` holds the video tensor.
        js, thumbnail, video, scene_list = super().__getitem__(self.file_id[i])

        #transcript = []
        #for w in transcript_xml:
        #    transcript.append(w.text)
        #print("transcript_xml", transcript_xml)
        #transcripts = '; '.join(transcript).replace('&#39;', '\'')

        title = js['title']
        description= js['description']

        #return  self.file_id[i], description, video, title, thumbnail, transcripts, scene_list
        return  self.file_id[i], description, video, title, thumbnail, scene_list


def read_story_file(text_file):
    """Read ``text_file`` and return its chunks separated by blank lines.

    Sentences are separated by two newlines; a single newline may belong to an
    image caption, so the pieces on either side of it stay in one chunk.
    """
    with open(text_file, "r") as handle:
        content = handle.read()
    return content.split('\n\n')

def fix_missing_period(line):
    """Append ``" ."`` to ``line``.

    Empty lines and lines containing ``"@summary"`` are returned unchanged.
    The line is not checked for an existing trailing period.
    """
    untouched = line == "" or "@summary" in line
    return line if untouched else line + " ."

def get_art_abs(story_file):
    """Split a story file into its article text and its summary text.

    Every chunk returned by ``read_story_file`` is lower-cased, stripped,
    whitespace-normalised and passed through ``fix_missing_period``. A chunk
    starting with ``@body`` opens the article section and one starting with
    ``@summary`` opens the summary section; the tag is removed and ``'.'`` is
    appended to that chunk. Following untagged chunks go to whichever section
    was opened last; chunks seen before any tag are dropped.

    Returns:
        ``(article, summary)``, each a single space-joined string.
    """
    # Lowercase, trim and collapse whitespace, then terminate chunks with " ."
    # (many image captions lack a period and would otherwise run on).
    cleaned = [
        fix_missing_period(' '.join(raw.lower().strip().split()))
        for raw in read_story_file(story_file)
    ]

    article_lines = []
    highlights = []
    sink = None  # list receiving untagged chunks; None until a tag is seen
    for line in cleaned:
        if line == "":
            continue  # empty line
        if line.startswith("@body"):
            sink = article_lines
            sink.append(line.replace("@body", '') + '.')
        elif line.startswith("@summary"):
            sink = highlights
            sink.append(line.replace("@summary", '') + '.')
        elif sink is not None:
            sink.append(line)

    return ' '.join(article_lines), ' '.join(highlights)

def find_scenes(video_path, threshold=80.0):
    """Detect content-based scene cuts in a video.

    Args:
        video_path: path of the video file to analyse.
        threshold: threshold handed to ``ContentDetector``.

    Returns:
        ``torch.Tensor`` holding, for each detected scene, its end frame number
        divided by 360 and truncated to an integer (360 is also the
        ``frame_skip`` used during detection).
    """
    # Create our video & scene managers, then add the detector.
    video_manager = VideoManager([video_path])
    scene_manager = SceneManager()
    scene_manager.add_detector(ContentDetector(threshold=threshold))
    try:
        # Improve processing speed by downscaling before processing.
        video_manager.set_downscale_factor()
        # Start the video manager and perform the scene detection.
        video_manager.start()
        scene_manager.detect_scenes(frame_source=video_manager, frame_skip=360)
        # Each returned scene is a tuple of the (start, end) timecode.
        scene_list = scene_manager.get_scene_list()
    finally:
        # Always free the underlying capture, also when detection fails.
        video_manager.release()

    scene_frame_list = [int(scene[1].get_frames()/360) for scene in scene_list]

    return torch.Tensor(scene_frame_list)




def _count_data(path):
    """ count number of data in the given path"""
    files = glob.glob(join(path, "*.json"))
    n_data = len(files)
    return n_data

In [None]:
import numpy as np
import re, math
import pickle as pkl
from torch.nn import functional as F
#from transformers import BertTokenizer, BertModel
import os
import torch.nn.functional as F
import torch.nn as nn
import torch
import torch.nn.utils.rnn as rnn_utils
import time
# from utils import PAD, UNK, END, START, UNK_TOKEN, PAD_TOKEN
from torch.nn.functional import softplus
import torchvision.models as models
import nltk
from torchvision import transforms
from PIL import Image
from nltk import word_tokenize
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPTextModel
from nltk.tokenize import sent_tokenize
from transformers.pipelines import pipeline
from nltk.tag import pos_tag
import cv2
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence, then dropout.

    Even feature indices carry ``sin(position * div_term)`` and odd indices
    carry ``cos(position * div_term)``. The table is precomputed for
    ``max_len`` positions and registered as the (non-trainable) buffer ``pe``
    of shape ``[max_len, 1, d_model]``.

    Args:
        d_model: feature size of the inputs; both even and odd sizes work.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions to precompute; inputs must not be longer.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns;
        # trimming div_term avoids the shape mismatch (no-op for even sizes).
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor of the same shape: ``x`` plus the encoding of its first
            ``seq_len`` positions, passed through dropout.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class CLIPSum(nn.Module):
    """Extractive text + cover-frame summarizer built on pretrained CLIP encoders.

    ``forward`` assigns one logit to every word of each description and one
    logit to every frame of each video. The highest-scoring words (re-ordered
    into reading order) form the text summary and the highest-scoring frame is
    the visual summary.
    """

    # Token / POS sequences are cut to this length before padding ("clip max length").
    MAX_TEXT_TOKENS = 77
    # Only the first frames of every video are encoded ("max for cuda").
    MAX_VIDEO_FRAMES = 50
    # Words whose CLIP id is >= this value are never selected.
    # presumably CLIP's start/end-of-text ids (49406 / 49407) — TODO confirm
    FIRST_EXCLUDED_TOKEN_ID = 49406

    def __init__(self, text_hidden_size, video_hidden_size,
                max_summary_word, max_summary_pic):
        """Load the pretrained CLIP components and build the fusion layers.

        Args:
            text_hidden_size: accepted for interface compatibility; not used
                (the hidden size is fixed to 512 below).
            video_hidden_size: accepted for interface compatibility; not used.
            max_summary_word: word budget used by the selection loop in ``forward``.
            max_summary_pic: stored only; ``forward`` always returns one frame per video.

        Reads ``haarcascade_frontalface_default.xml`` and ``kmeans_model_100.pkl``
        from the current working directory.
        """
        super(CLIPSum, self).__init__()
        # Part-of-speech tags; a tag's 1-based position in this list is its POS id (0 = unknown).
        self.pos = ['LS', 'TO', 'VBN', "''", 'WP', 'UH', 'VBG', 'JJ', 'VBZ', '--', 'VBP', 'NN', 'DT', 'PRP', ':', 'WP$', 'NNPS', 'PRP$', 'WDT', '(', ')', '.', ',', '``', '$', 'RB', 'RBR', 'RBS', 'VBD', 'IN', 'FW', 'RP', 'JJR', 'JJS', 'PDT', 'MD', 'VB', 'WRB', 'NNP', 'EX', 'NNS', 'SYM', 'CC', 'CD', 'POS']

        self.face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_default.xml')

        # Kept for callers that read it; forward() uses the device the parameters live on.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.max_summary_word = max_summary_word
        self.max_summary_pic = max_summary_pic

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.featureextractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

        self.cliptext = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip.eval()
        self.cliptext.eval()

        self.hidden_size = 512
        self.topic_size = 100
        self.dropout = 0.1
        self.num_attention_head = 16

        # Text stream: positional encoding + three encoder/decoder transformers.
        self.word_pret_pos_encoder = PositionalEncoding(self.hidden_size)
        self.word_pret_transformer_model = nn.Transformer(self.hidden_size, nhead=self.num_attention_head, num_encoder_layers=12, batch_first=True)

        self.word_modalspecific_transformer_model = nn.Transformer(self.hidden_size, nhead=self.num_attention_head, num_encoder_layers=12, batch_first=True)

        self.word_transformer_model = nn.Transformer(self.hidden_size, nhead=self.num_attention_head, num_encoder_layers=12, batch_first=True)

        # Video stream: same layout as the text stream.
        self.image_pret_pos_encoder = PositionalEncoding(self.hidden_size)
        self.image_pret_transformer_model = nn.Transformer(self.hidden_size, nhead=self.num_attention_head, num_encoder_layers=12, batch_first=True)
        self.image_modalspecific_transformer_model = nn.Transformer(self.hidden_size, nhead=self.num_attention_head, num_encoder_layers=12, batch_first=True)

        self.image_transformer_model = nn.Transformer(self.hidden_size, nhead=self.num_attention_head, num_encoder_layers=12, batch_first=True)

        # Cross-modal attention (video -> text and text -> video).
        self.v2tattn = nn.MultiheadAttention(embed_dim=self.hidden_size, num_heads=self.num_attention_head, batch_first=True, dropout=self.dropout)
        self.v2tattn_layer_norm = nn.LayerNorm(self.hidden_size)
        self.v2tattn_linear = nn.Linear(self.hidden_size, self.hidden_size)
        self.v2tattn_linear_layer_norm = nn.LayerNorm(self.hidden_size )

        self.t2vattn = nn.MultiheadAttention(embed_dim=self.hidden_size, num_heads=self.num_attention_head, batch_first=True, dropout=self.dropout)
        self.t2vattn_layer_norm = nn.LayerNorm(self.hidden_size)
        # The attribute name (double "n") is part of the state_dict keys; do not rename.
        self.t2vattnn_linear = nn.Linear(self.hidden_size , self.hidden_size )
        self.t2vattn_linear_layer_norm = nn.LayerNorm(self.hidden_size )

        self.coattn = nn.Linear((self.hidden_size )*2, self.hidden_size)

        # SECURITY: pickle.load executes arbitrary code from the file; only load a
        # k-means model that comes from a trusted source.
        with open("kmeans_model_100.pkl", "rb") as f:
            self.k_means =  pkl.load(f)

        # One logit per word / per frame.
        self.outputs2vocab = nn.Linear(self.hidden_size, 1)

        self.outputs2coverframe = nn.Linear(self.hidden_size, 1)

    def _tokenize_texts(self, input_text):
        """Split every description into words and map them to CLIP ids and POS ids.

        Returns:
            ``(tokens, id_seqs, pos_seqs)``: per text, the full word list, a
            LongTensor with one CLIP id per word and a LongTensor with one POS
            id per word. Both tensors are truncated to ``MAX_TEXT_TOKENS``; the
            word list is not.
        """
        tokens, id_seqs, pos_seqs = [], [], []
        for text in input_text:
            # Strip channel boilerplate and isolate newlines so they tokenize separately.
            text = text.replace("Please subscribe HERE http://bit.ly/1rbfUog", "")
            text = text.replace("#BBCNews", "")
            text = text.replace("\n", " \n ")

            words, word_ids, pos_ids = [], [], []
            for sent in sent_tokenize(text):
                sent_words = word_tokenize(sent)
                for _, tag in nltk.pos_tag(sent_words):
                    try:
                        pos_ids.append(self.pos.index(tag) + 1)
                    except ValueError:
                        # Tag missing from self.pos -> reserved "unknown" id.
                        pos_ids.append(0)
                for word in sent_words:
                    encoded = self.tokenizer(word, return_tensors="pt", padding=True, truncation=True, max_length=50)
                    # Only the id at position 1 is kept per word
                    # (presumably position 0 is the start-of-text token — TODO confirm).
                    word_ids.append(int(encoded['input_ids'][0][1]))
                words.extend(sent_words)

            tokens.append(words)
            id_seqs.append(torch.LongTensor(word_ids[:self.MAX_TEXT_TOKENS]))
            pos_seqs.append(torch.LongTensor(pos_ids[:self.MAX_TEXT_TOKENS]))
        return tokens, id_seqs, pos_seqs

    def _scalar_target(self, values, device):
        """Embed a ``[batch, seq]`` tensor into channel 0 of a zero ``[batch, seq, hidden]`` tensor."""
        batch, seq = values.size()
        target = torch.zeros(batch, seq, self.hidden_size, device=device)
        target[:, :, :1] = values.unsqueeze(2).to(device)
        return target

    def _topic_target(self, features, device):
        """Return ``[batch, hidden]`` with the k-means cluster distances in the first ``topic_size`` channels."""
        distances = self.k_means.transform(features.detach().cpu().numpy())
        target = torch.zeros(features.size(0), self.hidden_size, device=device)
        target[:, :self.topic_size] = torch.from_numpy(distances).to(device)
        return target

    def forward(self, input_text, input_video, text_summary_length=11):
        """Score words and frames and extract the summaries.

        Args:
            input_text: sequence of description strings, one per sample.
            input_video: frames per sample; indexed as a 5-D tensor and permuted
                with (1, 0, 4, 2, 3) below, so it is assumed to be
                (batch, frames, H, W, C) — TODO confirm against the dataset.
                For more than one sample it is first passed through ``pad_sequence``.
            text_summary_length: unused; kept for interface compatibility.

        Returns:
            ``(output_text_summaries, output_text_summaries_pos, text_logp,
            output_video_summaries, output_video_summaries_pos, video_logp)``:
            summary strings, the selected word indices per text, word logits of
            shape [batch, words, 1], the selected frame per video, its index as a
            one-element tensor, and frame logits of shape [batch, frames, 1].
        """
        # Follow the module's actual device instead of hard-coding CUDA.
        device = next(self.parameters()).device
        batch_size = len(input_text)

        # ---------------- text stream ----------------
        text_token_list, text_id_seqs, pos_seqs = self._tokenize_texts(input_text)
        pos_id_list = torch.nn.utils.rnn.pad_sequence(pos_seqs, batch_first=True).to(device)
        text_id_list = torch.nn.utils.rnn.pad_sequence(text_id_seqs, batch_first=True).to(device)

        word_feature = self.cliptext(text_id_list).last_hidden_state
        # NOTE(review): PositionalEncoding documents a [seq_len, batch, dim] input but
        # receives a batch-first tensor here, so the table is indexed by batch
        # position. Left unchanged to stay compatible with trained checkpoints.
        word_feature = self.word_pret_pos_encoder(word_feature)
        sent_feature = self.clip.get_text_features(text_id_list)

        text_b, text_s, text_d = word_feature.size()

        # Condition the words first on their POS ids, then on the text's topic distances.
        word_feature = self.word_modalspecific_transformer_model(
            word_feature, self._scalar_target(pos_id_list, device))
        text_topic = self._topic_target(sent_feature, device)
        word_feature = self.word_pret_transformer_model(
            word_feature, text_topic.unsqueeze(1).expand(-1, text_s, -1))

        # ---------------- video stream ----------------
        if batch_size > 1:
            input_video = torch.nn.utils.rnn.pad_sequence(input_video, batch_first=True)

        scene_frame_pad_batch_list = input_video[:, :self.MAX_VIDEO_FRAMES, :, :, :]

        v_image_features = []
        v_num_face = []
        # Iterate frame position by frame position; `images` holds that frame for every sample.
        for images in scene_frame_pad_batch_list.permute(1, 0, 4, 2, 3):
            num_faces = []
            image_batch = []
            for frame in images:
                img = transforms.ToPILImage()(frame.squeeze(0))
                image_batch.append(img)
                gray = cv2.cvtColor(np.array(img), cv2.COLOR_BGR2GRAY)
                faces = self.face_cascade.detectMultiScale(gray, 1.1, 4)
                num_faces.append(len(faces))
            v_num_face.append(torch.LongTensor(num_faces))
            image_token = self.featureextractor(image_batch, return_tensors="pt")
            v_image_features.append(self.clip.get_image_features(image_token['pixel_values'].to(device)))

        v_num_face = torch.stack(v_num_face, dim=1)
        image_feature = torch.stack(v_image_features, dim=1)
        # Same batch-first caveat as for the word positional encoder above.
        image_feature = self.image_pret_pos_encoder(image_feature)

        video_feature = torch.mean(image_feature, dim=1)
        video_b, video_s, video_d = image_feature.size()

        # Condition the frames first on their face counts, then on the video's topic distances.
        image_feature = self.image_modalspecific_transformer_model(
            image_feature, self._scalar_target(v_num_face, device))
        video_topic = self._topic_target(video_feature, device)
        image_feature = self.image_pret_transformer_model(
            image_feature, video_topic.unsqueeze(1).expand(-1, video_s, -1))

        # ---------------- cross-modal fusion ----------------
        v2t_attn, _ = self.v2tattn(image_feature, word_feature, word_feature)
        v2t_attn = self.v2tattn_layer_norm(v2t_attn) + image_feature
        v2t_attn_linear = self.v2tattn_linear(v2t_attn.reshape(-1, video_d))
        v2t_attn = self.v2tattn_linear_layer_norm(v2t_attn_linear.view(video_b, video_s, video_d)) + v2t_attn

        t2v_attn, _ = self.t2vattn(word_feature, image_feature, image_feature)
        # NOTE(review): this branch reuses the v2t layer norm / linear layers; the
        # t2vattn_layer_norm, t2vattnn_linear and t2vattn_linear_layer_norm modules
        # created in __init__ are never called. It looks like a copy-paste slip, but
        # switching layers would invalidate existing checkpoints, so it is preserved.
        t2v_attn = self.v2tattn_layer_norm(t2v_attn) + word_feature
        t2v_attn_linear = self.v2tattn_linear(t2v_attn.reshape(-1, text_d))
        t2v_attn = self.v2tattn_linear_layer_norm(t2v_attn_linear.view(text_b, text_s, text_d)) + t2v_attn

        overall_feature = self.coattn(torch.cat([torch.mean(v2t_attn, dim=1), torch.mean(t2v_attn, dim=1)], dim=1))
        joint_topic = self._topic_target(overall_feature, device)

        # ---------------- per-word / per-frame logits ----------------
        word_attn_output = self.word_transformer_model(
            word_feature, joint_topic.unsqueeze(1).expand(-1, text_s, -1))
        text_logp = self.outputs2vocab(word_attn_output.reshape(-1, word_attn_output.size(2)))
        text_logp = text_logp.view(text_b, text_s, 1)

        image_attn_output = self.image_transformer_model(
            image_feature, joint_topic.unsqueeze(1).expand(-1, video_s, -1))
        video_logp = self.outputs2coverframe(image_attn_output.reshape(-1, image_attn_output.size(2)))
        video_logp = video_logp.view(video_b, video_s, 1)

        # ---------------- frame selection: best-scoring frame per video ----------------
        output_video_summaries = []
        output_video_summaries_pos = []
        for image, summary in zip(scene_frame_pad_batch_list, video_logp):
            rank = torch.argsort(summary, dim=0, descending=True)
            output_video_summaries_pos.append(rank[0])
            output_video_summaries.append(image[int(rank[0])])

        # ---------------- word selection: best-scoring words, emitted in reading order ----------------
        output_text_summaries = []
        output_text_summaries_pos = []
        for text, t_id, summary in zip(text_token_list, text_id_list, text_logp):
            word_count = 0
            rank = torch.argsort(summary, dim=0, descending=True)

            filtered_rank = []
            for ranked in rank:
                idx = int(ranked)
                # Skip padded positions and excluded token ids.
                if idx < len(text) and t_id[idx] < self.FIRST_EXCLUDED_TOKEN_ID:
                    filtered_rank.append(idx)
                    # Only alphanumeric words count against the budget; punctuation is free.
                    if text[idx] != PAD_TOKEN and text[idx].isalnum():
                        word_count += 1

                # NOTE(review): '>' lets max_summary_word + 1 counted words through;
                # kept as-is so outputs match earlier runs — confirm whether '>=' was intended.
                if word_count > self.max_summary_word:
                    break

            pos = sorted(filtered_rank)
            output_text_summaries.append(" ".join(text[i] for i in pos))
            output_text_summaries_pos.append(pos)

        return output_text_summaries, output_text_summaries_pos, text_logp, output_video_summaries, output_video_summaries_pos, video_logp


[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
from transformers import GPT2LMHeadModel, GPT2Config

import torch.utils.data.dataset
# import utils_tokenizer
import torch, tqdm, math

def pad(data, padval=0):
    """Stack variable-length tensors into one batch-first tensor, filling the gaps with ``padval``."""
    padded = torch.nn.utils.rnn.pad_sequence(
        data,
        batch_first=True,
        padding_value=padval,
    )
    return padded

class GeneTransformer:
    """Thin wrapper around a GPT-2 language model used to train on and score summaries.

    The tokenizer object must provide ``encode``, ``vocab_size``, ``start_id`` and
    ``end_id``; ``GPT2Tokenizer`` / ``BPETokenizer`` are expected to be defined
    elsewhere (the ``utils_tokenizer`` import is commented out in this notebook).
    """

    def __init__(self, max_output_length=25, max_input_length=300, device='cpu', tokenizer_type='gpt2', bpe_model="", starter_model=None, word_count=None):
        """Build the tokenizer and a freshly initialised GPT-2 model.

        Args:
            max_output_length: summaries are cut to ``max_output_length - 1`` tokens
                so that the end token still fits.
            max_input_length: bodies are truncated to this many tokens.
            device: device the model and every batch are moved to.
            tokenizer_type: ``"gpt2"`` or ``"bpecap"``.
            bpe_model: path handed to ``BPETokenizer`` when ``tokenizer_type == "bpecap"``.
            starter_model: optional state-dict file loaded right after construction.
            word_count: optional word statistics, stored on the instance.

        Raises:
            ValueError: if ``tokenizer_type`` is not recognised.
        """
        if tokenizer_type == "gpt2":
            self.tokenizer = GPT2Tokenizer()
            config = GPT2Config.from_pretrained("gpt2")

        elif tokenizer_type == "bpecap":
            self.tokenizer = BPETokenizer(bpe_model)
            config = GPT2Config.from_dict({"finetuning_task": None, "initializer_range": 0.02,
                            "layer_norm_epsilon": 1e-05, "n_ctx": 1024, "n_embd": 768, "n_head": 12, "n_layer": 12, "n_positions": 1024, "num_labels": 1,
                            "resid_pdrop": 0.1, "use_bfloat16": False, "vocab_size": self.tokenizer.vocab_size})
        else:
            # Raise instead of exit(): exit() would terminate the whole kernel/process.
            raise ValueError("Tokenizer unrecognized. Should be gpt2 or bpecap.")

        self.model = GPT2LMHeadModel(config)

        self.model.to(device)
        self.device = device
        if starter_model is not None:
            self.reload(starter_model)

        self.max_output_length = max_output_length
        self.max_input_length = max_input_length

        self.model.train()
        self.mode = "train"
        # Always defined (None when not supplied) so attribute access never fails.
        self.word_count = word_count

    def train_batch(self, bodies, summaries, special_append=None, no_preinput=False):
        """Return the cross-entropy loss of ``summaries``, optionally conditioned on ``bodies``.

        Unless ``no_preinput`` is set, the bodies are run through the model first
        and the resulting key/value cache is used as context for the summaries.
        """
        inputs, summ_inp, summ_out = self.preprocess_batch(bodies, summaries, special_append)
        past = None
        if not no_preinput:
            # `past` was renamed `past_key_values`; results are read by key,
            # the same way score() already does.
            past = self.model(input_ids=inputs, past_key_values=None)["past_key_values"]
        logits = self.model(input_ids=summ_inp, past_key_values=past)["logits"]
        crit = torch.nn.CrossEntropyLoss(ignore_index=-1)
        loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.contiguous().view(-1))
        return loss

    def train(self):
        """Switch the wrapped model to training mode."""
        self.model.train()
        self.mode = 'train'

    def eval(self):
        """Switch the wrapped model to evaluation mode."""
        self.model.eval()
        self.mode = 'eval'

    def reload(self, from_file):
        """Load a state dict from ``from_file`` and print the key-matching report."""
        # map_location lets a checkpoint saved on another device load onto self.device.
        state = torch.load(from_file, map_location=self.device)
        print(self.model.load_state_dict(state))

    def save(self, to_file):
        """Write the model's state dict to ``to_file``."""
        torch.save(self.model.state_dict(), to_file)

    def preprocess_input(self, bodies, special_append=None):
        """Encode the bodies, prepend optional special tokens, pad with 0 and truncate."""
        if special_append is None:
            special_append = [[] for _ in range(len(bodies))]
        inputs = [torch.LongTensor(spe+self.tokenizer.encode(body)) for body, spe in zip(bodies, special_append)]
        inputs = pad(inputs, padval=0)
        inputs = inputs[:, :self.max_input_length].to(self.device)
        return inputs

    def preprocess_batch(self, bodies, summaries, special_append=None):
        """Build ``(inputs, summ_inp, summ_out)`` for teacher forcing.

        ``summ_inp`` is the summary shifted right behind the start token (padded
        with 0); ``summ_out`` is the summary followed by the end token (padded
        with -1, the index the loss ignores).
        """
        inputs = self.preprocess_input(bodies, special_append)

        if special_append is None:
            special_append = [[] for _ in range(len(bodies))]

        summaries = [spe+self.tokenizer.encode(summ) for summ, spe in zip(summaries, special_append)]

        # Cut one short of the limit so the start/end token still fits.
        summaries = [summ[:(self.max_output_length-1)] for summ in summaries]

        summ_inp = pad([torch.LongTensor([self.tokenizer.start_id]+summ) for summ in summaries], padval=0).to(self.device)
        summ_out = pad([torch.LongTensor(summ+[self.tokenizer.end_id]) for summ in summaries], padval=-1).to(self.device)
        return inputs, summ_inp, summ_out

    def score(self, summaries, video_sum, bodies, videos, idx_batch=None, bodies_tokenized=None, lengths=None, extra=None):
        """Score summaries by their unconditional language-model loss.

        Only ``summaries`` and ``bodies`` are used; the remaining parameters are
        accepted for interface compatibility. The model is put into eval mode.

        Returns:
            ``(score, None)`` with ``score = (10 - mean token loss) / 10`` per summary.
        """
        self.model.eval()

        _, summ_inp, summ_out = self.preprocess_batch(bodies, summaries)
        summ_out = summ_out.contiguous()

        with torch.no_grad():
            # The original sliced `summ_inp[:1024]`, i.e. the batch axis; sequences are
            # already capped by max_output_length, so the tensor is passed unsliced.
            out = self.model(input_ids=summ_inp)
            logits = out["logits"]
            crit = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
            loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.view(-1)).view(summ_out.shape)
            # Positions holding the input padding value 0 do not count as tokens.
            mask = (summ_inp != 0).float()
            non_pad_count = torch.sum(mask, dim=1)

        loss_per = torch.sum(loss, dim=1) / non_pad_count

        score = (10.0 - loss_per) / 10.0
        return score, None

    def score_pairs(self, bodies, summaries):
        """Return ``(per-summary mean loss conditioned on its body, None)``."""
        if self.mode != 'eval':
            print("BEWARE. Model is not in eval mode.")

        inputs, summ_inp, summ_out = self.preprocess_batch(bodies, summaries)

        with torch.no_grad():
            past = self.model(input_ids=inputs, past_key_values=None)["past_key_values"]
            logits = self.model(input_ids=summ_inp, past_key_values=past)["logits"]

            crit = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
            loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.view(-1)).view(summ_out.shape)
            mask = (summ_inp != 0).float()
            non_pad_count = torch.sum(mask, dim=1)
            loss_per = torch.sum(loss, dim=1) / non_pad_count

        return loss_per.tolist(), None





In [None]:
""" run decoding of rnn-ext + abs + RL (+ rerank)"""
import argparse
import json
import os
from os.path import join, exists
from datetime import timedelta
from time import time
import pickle as pkl
from collections import Counter, defaultdict
from itertools import product
from functools import reduce
import operator as op
from torch.autograd import Variable
import numpy as np
import torch
from torch.utils.data import DataLoader

from torch import multiprocessing as mp, nn
#from utils import make_vocab
#from data.batcher import tokenize
# from clipsumcomp_topic_trans_3stream_modal_specific import CLIPSum
# from clipsum import CLIPSum
# from datasets import load_dataset
# from util_dataset import EXMSMODataset
import cv2

# Take the dataset root from the DATA environment variable; when it is not set
# only a hint is printed and DATA_DIR stays undefined.
if 'DATA' in os.environ:
    DATA_DIR = os.environ['DATA']
else:
    print('please use environment variable to specify data directories')

def collate_func(inps):
    """Identity collate: hand the batch items back as a new plain list, unmerged."""
    return list(inps)

def decode(params, dataset_folder, save_path, model_dir, model_name, split, batch_size, cuda, last_process=None):
    """Run a trained CLIPSum model over the dataset and write its summaries to disk.

    For every processed sample four files named after the sample's id are written
    below ``save_path``: ``outputText/<id>.dec`` (summary text),
    ``outputTextExtIdx/<id>.dec`` (comma-separated word positions),
    ``outputPic/<id>.png`` (selected frame) and ``outputPicExtIdx/<id>.dec``
    (comma-separated frame indices). Only the first sample of each batch is saved
    as text/picture, so the function is meant to be run with ``batch_size=1``.

    Args:
        params: keyword arguments for the ``CLIPSum`` constructor.
        dataset_folder: folder handed to ``EXMSMODataset``.
        save_path: root directory for the output folders (created if missing).
        model_dir: directory containing the checkpoint.
        model_name: checkpoint file name inside ``model_dir``.
        split: currently unused — the dataset is always opened with 'test'.
            NOTE(review): confirm whether ``split`` should be passed through.
        batch_size: DataLoader batch size.
        cuda: currently unused — the model is always moved to CUDA.
        last_process: id of the sample to resume from; samples before it are
            skipped. ``""`` processes everything. When omitted, the value is
            taken from a module-level ``args.last_process`` if one exists
            (the previous behaviour), otherwise ``""``.
    """
    if last_process is None:
        # Backward compatibility: the function used to read the global `args` directly.
        last_process = getattr(globals().get('args'), 'last_process', '')

    summarizer = CLIPSum(**params)
    print("-----------summarizer")
    summarizer.cuda()

    # Load on CPU first so a checkpoint saved from any device index can be restored;
    # load_state_dict copies the weights onto the model's own device.
    summarizer.load_state_dict(torch.load(join(model_dir, model_name), map_location='cpu'))
    summarizer.eval()

    dataset = EXMSMODataset('test', dataset_folder)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Prepare the output folders.
    for sub_dir in ('outputText', 'outputTextExtIdx', 'outputPic', 'outputPicExtIdx'):
        os.makedirs(join(save_path, sub_dir), exist_ok=True)

    print("loader length", len(dataloader))

    # Decoding
    start_process = last_process == ""
    print("start_process", start_process)
    with torch.no_grad():
        for d in dataloader:
            file_id = d[0][0]

            if last_process == file_id:
                start_process = True
            if not start_process:
                continue

            descriptions = d[1]
            videos = d[2]
            print("file_id", file_id)
            print("descriptions", descriptions)
            print("videos", videos.size())

            # Forward pass
            output_text_summaries, output_text_summaries_pos, text_logp, output_video_summaries, output_video_summaries_pos, video_logp = summarizer(descriptions, videos)

            with open(join(save_path, 'outputText/{}.dec'.format(file_id)), 'w') as f:
                f.write(output_text_summaries[0])

            with open(join(save_path, 'outputTextExtIdx/{}.dec'.format(file_id)), 'w') as f:
                f.write(','.join([str(sent) for sent in output_text_summaries_pos[0]]))

            cv2.imwrite(join(save_path, 'outputPic/{}.png'.format(file_id)), output_video_summaries[0].numpy())

            with open(join(save_path, 'outputPicExtIdx/{}.dec'.format(file_id)), 'w') as f:
                # The positions are one-element tensors; write the plain integer rather
                # than the tensor repr ("tensor([3], device='cuda:0')").
                f.write(','.join([str(int(pic)) for pic in output_video_summaries_pos]))
    print()




if __name__ == '__main__':
    # Command-line entry point. parse_args() reads sys.argv, so when this cell is
    # executed inside a notebook kernel the required options are missing and
    # argparse exits with status 2 (see the captured output of this cell).
    print("running")
    parser = argparse.ArgumentParser()

    parser.add_argument('--max_sequence_length', type=int, default=900)
    parser.add_argument('--max_article_length', type=int, default=5)
    parser.add_argument('--max_summary_pic', type=int, default=1)
    parser.add_argument('--max_summary_word', type=int, default=12)

    parser.add_argument('--test', action='store_true')
    parser.add_argument('-bs', '--batch', type=int, default=1)
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.00005)

    parser.add_argument('-ths', '--text_hidden_size', type=int, default=128)
    parser.add_argument('-vhs', '--video_hidden_size', type=int, default=128)

    # Help texts describe how decode() actually uses each value.
    parser.add_argument('--path', required=True, help='directory the decoded outputs are written to')
    parser.add_argument('--model_dir', required=True, help='directory containing the model checkpoint')
    parser.add_argument('--model_name', required=True, help='checkpoint file name inside --model_dir')
    parser.add_argument('--dataset_folder', type=str, help='folder of dataset')
    parser.add_argument('--last_process', type=str, default='')
    # `args` stays a module-level name: decode() reads args.last_process.
    args = parser.parse_args()

    print("torch.cuda.is_available()", torch.cuda.is_available())
    args.cuda = True

    params = dict(
        max_summary_word=args.max_summary_word,
        max_summary_pic=args.max_summary_pic,
        text_hidden_size=args.text_hidden_size,
        video_hidden_size=args.video_hidden_size,
    )
    # NOTE(review): decode() receives the split name but opens the 'test' split
    # regardless, so --test currently has no effect on which data is decoded.
    data_split = 'test' if args.test else 'val'
    decode(params, args.dataset_folder, args.path, args.model_dir, args.model_name,
           data_split, args.batch,
           args.cuda)


please use environment variable to specify data directories
running


usage: colab_kernel_launcher.py [-h] [--max_sequence_length MAX_SEQUENCE_LENGTH]
                                [--max_article_length MAX_ARTICLE_LENGTH]
                                [--max_summary_pic MAX_SUMMARY_PIC]
                                [--max_summary_word MAX_SUMMARY_WORD] [--test] [-bs BATCH]
                                [-lr LEARNING_RATE] [-ths TEXT_HIDDEN_SIZE]
                                [-vhs VIDEO_HIDDEN_SIZE] --path PATH --model_dir MODEL_DIR
                                --model_name MODEL_NAME [--dataset_folder DATASET_FOLDER]
                                [--last_process LAST_PROCESS]
colab_kernel_launcher.py: error: the following arguments are required: --path, --model_dir, --model_name


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)



In [None]:
import sys
#site_packages = next(p for p in sys.path if 'site-packages' in p)
#print(site_packages)

import numpy as np
import os, shutil
import re
import codecs
import os, pickle as pkl
# from util_dataset import EXMSMODataset
#from MMVAE import MMVAE
#from clipsumcomp import CLIPSum
#from clipsumcomp_topic import CLIPSum
# from clipsum import CLIPSum
import codecs
import argparse
import torch
import torch.nn as nn
import torch.nn.init as init
from torch import optim
from torch import save
from torch.nn import functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, RandomSampler
# from utils import make_vocab, make_embedding, convert_word2id, to_var, idx2word
#from transformers import BertModel
from transformers import BertTokenizer
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer
from torchvision import transforms
import os
import torch.nn.functional as F
import torch.nn as nn
import gensim
import torch.nn.utils.rnn as rnn_utils
from multiprocessing import cpu_count
import time
import torch, gc
from collections import OrderedDict, defaultdict
# from utils import PAD, UNK, END, START
import json
from torch.nn.functional import softplus
import torchvision.models as models
from datasets import load_dataset
import cv2 ,ot
import ssl
# from model_generator import GeneTransformer
from torch.utils.tensorboard import SummaryWriter
# Force a garbage-collection pass right after the imports.
gc.collect()


# SECURITY NOTE: this replaces the default HTTPS context factory with the
# unverified one, i.e. TLS certificates are no longer checked for any
# connection that relies on ssl's default context in this process.
# Presumably a workaround for download failures — TODO confirm it is still needed.
ssl._create_default_https_context = ssl._create_unverified_context
# Length of the bag-of-words vector allocated by construct_BOW
# (presumably the bert-base-uncased vocabulary size — TODO confirm).
BERT_NUM_TOKEN = 30522
# Fixed seed for PyTorch's global RNG.
torch.manual_seed(12345)



class TextCoverageLoss:
    """Scores text summaries against their source documents with optimal transport.

    Each summary/document pair is tokenized with the BERT tokenizer, turned
    into bag-of-words vectors via ``construct_BOW`` and compared with
    ``sparse_ot`` using a cost matrix loaded from a pickle file.
    """

    def __init__(self, device="cuda", costmatrix_filename="COST_MATRIX_bert.pickle"):
        """Load the tokenizer and the pickled cost matrix.

        Args:
            device: accepted for signature compatibility; not used by this class.
            costmatrix_filename: path of the pickle file holding the cost matrix.
        """
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # The file is read in pieces of at most 2**31 - 1 bytes and the pieces
        # are concatenated before unpickling.
        chunk_limit = 2**31 - 1
        raw = bytearray(0)
        total_bytes = os.path.getsize(costmatrix_filename)
        with open(costmatrix_filename, 'rb') as handle:
            offset = 0
            while offset < total_bytes:
                raw += handle.read(chunk_limit)
                offset += chunk_limit

            self.COST_MATRIX = pkl.loads(raw)

    def score(self, summaries, bodies):
        """Return the mean coverage score over all summary/document pairs.

        Args:
            summaries: sequence of summary strings.
            bodies: sequence of source documents, indexed in parallel with
                ``summaries``.

        Returns:
            The arithmetic mean of the per-pair scores. An empty summary
            contributes a score of 1.
        """
        per_pair = []
        with torch.no_grad():
            for position, summary in enumerate(summaries):
                doc = bodies[position]
                if len(summary) != 0:
                    summary_ids = self.tokenizer.encode(summary)
                    body_ids = self.tokenizer.encode(doc)
                    pair_score = sparse_ot(
                        construct_BOW(summary_ids),
                        construct_BOW(body_ids),
                        self.COST_MATRIX,
                    )
                else:
                    pair_score = 1
                per_pair.append(pair_score)

        print('text coverage score', per_pair)
        return sum(per_pair)/len(per_pair)


class MmCoverageLoss:
    """Optimal-transport score between CLIP text and image features of summaries.

    For every (frame, text) pair the CLIP text and image features are computed
    and compared with ``sparse_ot`` using a fixed cost matrix that is zero on
    the diagonal and ``1 / 512`` elsewhere.
    """

    def __init__(self, device="cuda"):
        """Load the CLIP tokenizer, feature extractor and model.

        Args:
            device: device the CLIP model and its inputs are moved to.
                Previously this argument was ignored and ``.cuda()`` was
                hard-coded; the default keeps that behaviour.
        """
        self.device = device
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.featureextractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip.eval()
        self.clip.to(self.device)
        # Size of the square cost matrix (presumably the CLIP feature length — TODO confirm).
        COSTMATRIX_DIM = 512
        self.COST_MATRIX = torch.ones(COSTMATRIX_DIM, COSTMATRIX_DIM) -  torch.eye(COSTMATRIX_DIM)
        self.COST_MATRIX = self.COST_MATRIX/COSTMATRIX_DIM


    def score(self, text_summaries, video_summaries, texts, videos):
        """Return the mean OT score over paired text and frame summaries.

        Args:
            text_summaries: iterable of summary strings.
            video_summaries: iterable of frame tensors; each is permuted with
                ``(2, 0, 1)`` before conversion to a PIL image (assumes a
                channels-last layout — TODO confirm).
            texts: unused; kept for signature compatibility.
            videos: unused; kept for signature compatibility.

        Returns:
            Mean of the per-pair ``sparse_ot`` values.

        NOTE(review): ``sparse_ot`` normalises its inputs as mass vectors; raw
        CLIP features are not guaranteed to be non-negative — confirm this is
        the intended use.
        """
        scores = []
        # The cost matrix does not change between pairs, so convert it once.
        cost_matrix = self.COST_MATRIX.numpy()
        with torch.no_grad():
            for v, t in zip(video_summaries, text_summaries):
                print("v", v.size())
                i = transforms.ToPILImage()(v.permute(2,0,1).squeeze_(0))
                text_t = self.tokenizer(t, return_tensors = "pt", padding=True, truncation=True)
                image_t = self.featureextractor(i, return_tensors = "pt")
                print("image_t['pixel_values']", image_t['pixel_values'].size())
                text_f = self.clip.get_text_features(text_t['input_ids'].to(self.device))
                image_f = self.clip.get_image_features(image_t['pixel_values'].to(self.device))

                print("text_f", text_f.size())
                print("image_f", image_f.size())
                score = sparse_ot(
                    text_f.squeeze(0).cpu().detach().numpy(),
                    image_f.squeeze(0).cpu().detach().numpy(),
                    cost_matrix,
                )
                scores.append(score)
        return sum(scores)/len(scores)

class MmAlignmentLoss:
    """Cosine-embedding loss between CLIP text and image features of summaries."""

    def __init__(self, device="cuda"):
        """Load the CLIP tokenizer, feature extractor and model.

        Args:
            device: device the CLIP model and its inputs are moved to.
                Previously this argument was ignored and ``.cuda()`` was
                hard-coded; the default keeps that behaviour.
        """
        self.device = device
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.featureextractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip.eval()
        self.clip.to(self.device)

        self.cosloss = nn.CosineEmbeddingLoss()

    def score(self, text_summaries, video_summaries):
        """Return the cosine-embedding loss between text and image features.

        Args:
            text_summaries: batch of summary strings, tokenized together with
                padding and truncation.
            video_summaries: iterable of frame tensors; each is permuted with
                ``(2, 0, 1)`` before conversion to a PIL image (assumes a
                channels-last layout — TODO confirm).

        Returns:
            The value returned by ``nn.CosineEmbeddingLoss`` with an all-ones
            target, i.e. every text/image pair is treated as a matching pair.
        """
        text_t = self.tokenizer(text_summaries, return_tensors = "pt", padding=True, truncation=True)
        image_batch = [
            transforms.ToPILImage()(frame.permute(2,0,1).squeeze_(0))
            for frame in video_summaries
        ]

        image_t = self.featureextractor(image_batch, return_tensors = "pt")
        text_f = self.clip.get_text_features(text_t['input_ids'].to(self.device))
        image_f = self.clip.get_image_features(image_t['pixel_values'].to(self.device))

        # Target of +1 for every row; created on the same device as the features.
        target = torch.ones(text_f.size(0), device=text_f.device)
        return self.cosloss(text_f, image_f, target)

def VideoCoverageLoss(summaries, bodies):
    """Mean earth mover's distance between summary frames and their videos.

    Each video is reduced to one image by averaging along its first axis
    (presumably the frame axis — TODO confirm). Both that image and the summary
    frame are converted to float32 grayscale and compared with ``cv2.EMD``
    using the L2 ground distance.

    Args:
        summaries: iterable of tensors exposing ``.detach().numpy()``.
        bodies: iterable of array-likes, paired with ``summaries``.

    Returns:
        The arithmetic mean of the per-pair scores.

    Raises:
        ZeroDivisionError: if no pairs are supplied.
    """
    # Penalty recorded for a pair whose EMD cannot be computed.
    fallback_score = 1.0 / 0.001
    scores = []

    for summary, video in zip(summaries, bodies):
        video = np.mean(np.array(video), axis=0).astype(np.float32)
        summary = summary.detach().numpy().astype(np.float32)

        video_bw = cv2.cvtColor(video, cv2.COLOR_BGR2GRAY)
        summary_bw = cv2.cvtColor(summary, cv2.COLOR_BGR2GRAY)

        try:
            score = cv2.EMD(summary_bw,video_bw,cv2.DIST_L2)[0]
        except Exception as exc:
            # Best effort: keep going with the fallback penalty, but report why.
            # Catching Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate so a long run can still be stopped.
            print("VideoCoverageLoss cannot compute", exc)
            score = fallback_score
        scores.append(score)

    print('VideoCoverageLoss', scores)
    return sum(scores)/len(scores)

class OT_topic():
    """Topic-level optimal-transport scores built on CLIP features and k-means.

    CLIP features are passed through a fitted k-means model's ``transform`` to
    obtain one vector per item; pairs of such vectors are compared with
    ``sparse_ot`` under a 0/1 cost matrix (zero on the diagonal).
    """

    def __init__(self):
        """Load the pickled k-means model and the CLIP components, and build the cost matrix."""
        # NOTE: unpickling executes arbitrary code from the file; only load a
        # "kmeans_model.pkl" that comes from a trusted source.
        with open("kmeans_model.pkl", "rb") as f:
            self.k_means =  pkl.load(f)
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.featureextractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip.eval()
        self.clip.cuda()

        # Side length of the cost matrix (presumably the number of k-means
        # clusters in the pickled model — TODO confirm they match).
        self.topic_size = 50

        self.COST_MATRIX = torch.ones(self.topic_size, self.topic_size) - torch.eye(self.topic_size)
        self.COST_MATRIX = self.COST_MATRIX.numpy()

    def score(self, summaries_f, bodies_f):
        """Return one ``sparse_ot`` value per (summary, body) feature pair.

        Args:
            summaries_f: tensor of summary features, one row per item.
            bodies_f: tensor of body features, one row per item.

        Returns:
            A list of per-pair scores. ``zip`` stops at the shorter input.
        """
        scores = []

        topic_distance_summaries = self.k_means.transform(summaries_f.detach().cpu().numpy())
        topic_distance_bodies = self.k_means.transform(bodies_f.detach().cpu().numpy())

        for summ, bod in zip(topic_distance_summaries, topic_distance_bodies):
            score = sparse_ot(summ, bod, self.COST_MATRIX)
            scores.append(score)
        return scores

    def score_text(self, summaries, bodies):
        """Mean topic OT score between text summaries and their source texts."""
        # The features are detached inside score(), so no autograd graph is
        # needed; no_grad avoids building one through CLIP.
        with torch.no_grad():
            bodies_t = self.tokenizer(bodies, return_tensors = "pt", padding=True, truncation=True)
            summaries_t = self.tokenizer(summaries, return_tensors = "pt", padding=True, truncation=True)

            bodies_f = self.clip.get_text_features(bodies_t['input_ids'].cuda())
            summaries_f = self.clip.get_text_features(summaries_t['input_ids'].cuda())

        scores = self.score(summaries_f, bodies_f)

        print('OT topic text coverage score', scores)
        return sum(scores)/len(scores)

    def score_image(self, summaries, bodies):
        """Mean topic OT score between summary images and their source videos.

        Each entry of ``bodies`` is averaged along its first axis before being
        passed to the feature extractor.
        """
        videos = []
        for b in bodies:
            videos.append(torch.from_numpy(np.mean(np.array(b), axis=0).astype(np.float32)))

        with torch.no_grad():
            summaries_t = self.featureextractor(summaries, return_tensors = "pt")
            summaries_f = self.clip.get_image_features(summaries_t['pixel_values'].cuda())

            # NOTE(review): torch.cat joins the per-video images along dim 0
            # into a single tensor; if the extractor treats that as one image,
            # score() would only see one body feature — confirm whether a list
            # of images was intended.
            bodies_t = self.featureextractor(torch.cat(videos), return_tensors = "pt")
            bodies_f = self.clip.get_image_features(bodies_t['pixel_values'].cuda())

        scores = self.score(summaries_f, bodies_f)
        print('OT topic visual coverage score', scores)
        return sum(scores)/len(scores)

    def score_image_text(self, v_summaries, t_summaries):
        """Mean topic OT score between image summaries and text summaries."""
        with torch.no_grad():
            summaries_vt = self.featureextractor(v_summaries, return_tensors = "pt")
            summaries_vf = self.clip.get_image_features(summaries_vt['pixel_values'].cuda())

            summaries_tt = self.tokenizer(t_summaries, return_tensors = "pt", padding=True, truncation=True)
            summaries_tf = self.clip.get_text_features(summaries_tt['input_ids'].cuda())

        scores = self.score(summaries_vf, summaries_tf)
        print('OT topic textual_visual coverage score', scores)
        return sum(scores)/len(scores)

def save_log(log_input):
    """Append ``log_input`` to ``log.txt`` in the directory named by ``MODEL_PATH``.

    ``MODEL_PATH`` is looked up as a module-level global at call time, so it
    must be defined at module scope before this function is called.

    Args:
        log_input: text to append to the log file.
    """
    file_name = MODEL_PATH + '/log.txt'
    # Plain context-managed write instead of exec() on a code string: same
    # append semantics, and the file is closed even if write() raises.
    with open(file_name, "a+") as text_file:
        text_file.write(log_input)


def main(args):
    """Train the CLIPSum model with OT-topic coverage and fluency losses.

    Builds the train/val ``EXMSMODataset`` loaders, instantiates ``CLIPSum``,
    optionally restores a checkpoint, writes ``model_params.json`` into a
    timestamped sub-directory of ``args.save_model_path`` and then, for every
    epoch and every split, runs the model, combines the four loss terms,
    steps AdamW on the train split, optionally logs to TensorBoard and saves a
    checkpoint every fifth epoch.

    Args:
        args: parsed command-line namespace. Attributes read here: ``test``,
            ``save_model_path``, ``dataset_folder``, ``max_summary_word``,
            ``max_summary_pic``, ``text_hidden_size``, ``video_hidden_size``,
            ``batch_size``, ``resume_training``, ``model_name``,
            ``learning_rate``, ``epochs``, ``tensorboard_logging`` and
            ``print_every``.

    NOTE(review): ``EXMSMODataset``, ``CLIPSum`` and ``GeneTransformer`` are
    used below but their import lines in this cell are commented out, so they
    must already be defined in the running kernel.
    """
    ts = time.strftime('%Y-%b-%d-%H-%M-%S', time.gmtime())

    splits = ['train', 'validation'] + (['test'] if args.test else [])
    # NOTE(review): MODEL_PATH is local to main(); code that reads a global
    # MODEL_PATH will not see this value.
    MODEL_PATH = args.save_model_path
    #wv = api.load('word2vec-google-news-300')
    dataset_folder = args.dataset_folder


    params = dict(
        max_summary_word=args.max_summary_word,
        max_summary_pic=args.max_summary_pic,
        text_hidden_size=args.text_hidden_size,
        video_hidden_size=args.video_hidden_size,
        #num_attention_head=args.num_attention_head,
        #num_layers=args.num_layers,
    )


    #for i in list(range(9)):
    #    train_dataset = EXMSMODataset('train_'+str(i), dataset_folder)
    #    with open('train_'+str(i)+'.pickle', 'wb') as handle:
    #        pkl.dump(train_dataset, handle , protocol=4)

    #    print("train_dataset"+str(i))

    #max_bytes = 4096
    max_bytes = 2**31 - 1
    #train_dataset_list = []

    #for i in list(range(9)):
    #    bytes_in = bytearray(0)
    #    input_size = os.path.getsize('train_'+str(i)+'.pickle')
    #    with open('train_'+str(i)+'.pickle', 'rb') as f_in:
    #        for _ in range(0, input_size, max_bytes):
    #            bytes_in += f_in.read(max_bytes)
    #    train_dataset = pkl.loads(bytes_in)
    #    train_dataset_list.extend(train_dataset)

    #val_dataset = EXMSMODataset('val', dataset_folder)
    #with open('val.pickle', 'wb') as handle:
    #    pkl.dump(val_dataset, handle , protocol=4)

    #bytes_in = bytearray(0)
    #input_size = os.path.getsize('test.pickle')
    #print("input_size", input_size)
    #with open('test.pickle', 'rb') as f_in:
     #   unpickler = pkl.Unpickler(f_in)
        # if file is not empty scores will be equal
        # to the value unpickled
    #    train_dataset = unpickler.load()

        #print("total", input_size/max_bytes)

        #for i in range(0, input_size, max_bytes):
        #    print("load data", i/max_bytes)
        #    bytes_in += f_in.read(max_bytes)
        #train_dataset =    pkl.load(f_in , protocol=4)
    #print("all loaded")
    #train_dataset = pkl.loads(bytes_in)

    # Identity collate: each batch is handed to the loop as a plain list of
    # dataset items instead of being stacked into tensors.
    def collate_func(inps):
        return [a for a in inps]

    #del bytes_in

    #val_dataset = EXMSMODataset('val_test', dataset_folder)
    train_dataset = EXMSMODataset('train', dataset_folder)
    val_dataset = EXMSMODataset('val', dataset_folder)
    #print("train_dataset", train_dataset)
    #print("val_dataset", val_dataset)
    print("Train Dataset size:", len(train_dataset))
    #print("Val Dataset size:", len(val_dataset))
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=RandomSampler(train_dataset), drop_last=True, collate_fn=collate_func)
    val_data_loader = DataLoader(val_dataset, batch_size=args.batch_size, drop_last=True, collate_fn=collate_func)


    model = CLIPSum(**params)

    #plain_model = MMHHGATGPO(**params)
    #plain_model.load_state_dict(torch.load("/share/home/ptan6545/multimodal_OTVAE/mmhhgatgpo_model/2022-Apr-03-01:24:05/E5-1.828360.ckpt"))

    if torch.cuda.is_available():
        model = model.cuda()

    #for target_param, param in zip(model.parameters(), plain_model.parameters()):
    #    target_param.data.copy_(param.data)

    # Resume from <save_model_path>/<model_name> when requested.
    if args.resume_training:
        model.load_state_dict(torch.load(os.path.join(MODEL_PATH,args.model_name)))
        model.train()


    print(model)


    # Checkpoints and model_params.json go into a per-run, timestamped directory.
    save_model_path = os.path.join(args.save_model_path, ts)
    # save_model_path = os.path.join(args.save_model_path, '1')
    os.makedirs(save_model_path)

    with open(os.path.join(save_model_path, 'model_params.json'), 'w') as f:
        json.dump(params, f, indent=4)

    def loss_fn(output_text_summaries, output_video_summaries, texts, videos, OT_topic):
        """Return (text coverage, video coverage, image-text coverage, mean fluency).

        The ``OT_topic`` parameter is the scorer instance passed by the caller.
        ``fluencyLoss`` and ``descriptions`` are read from the enclosing scope
        rather than from the arguments.
        """


        # cut-off unnecessary padding from target, and flatten
        #logp = logp[:, :torch.max(summary_length).item(), :].view(-1, logp.size(2))
        #logp = logp[:, :summary_length, :].contiguous().view(-1, logp.size(2))
        #text_logp = text_logp[:, :, :].contiguous().view(-1, text_logp.size(2))
        #print("loss target", target.size())
        #print("loss logp", logp.size())
        # Negative Log Likelihood

        #textcoverage_loss = textCoverageLoss.score(output_text_summaries, texts)
        textcoverage_loss = OT_topic.score_text(output_text_summaries, texts)
        #videocoverage_loss = VideoCoverageLoss(output_video_summaries, videos)
        videocoverage_loss = OT_topic.score_image(output_video_summaries, videos)
        #NLL_loss = NLL(logp, target)
        #print("text_z", text_z.size())
        #print("video_z", video_z.size())
        #print('target', Variable(torch.ones(text_z.size()[0])).size())
        #mmcoverage_loss = mmCoverageLoss.score(output_text_summaries, output_video_summaries, texts, videos)
        #mmcoverage_loss = mmCoverageLoss.score(output_text_summaries, output_video_summaries)
        mmcoverage_loss = OT_topic.score_image_text(output_video_summaries, output_text_summaries)
        fluency_loss, _ = fluencyLoss.score(output_text_summaries, output_video_summaries, descriptions, videos)
        # KL Divergence


        return textcoverage_loss, videocoverage_loss , mmcoverage_loss, sum(fluency_loss)/len(fluency_loss)

    #optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    #optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate)

    # Tensor constructor used as the defaultdict factory for `tracker`:
    # a missing key starts out as an empty tensor that torch.cat can extend.
    tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
    step = 0
    fluency_news_model_file = os.path.join("models", "gpt2_copier23.bin")

    #textCoverageLoss = TextCoverageLoss()
    #mmCoverageLoss = MmCoverageLoss()
    #mmCoverageLoss = MmAlignmentLoss()
    fluencyLoss = GeneTransformer(max_output_length=args.max_summary_word, device="cuda", starter_model=fluency_news_model_file)
    otTopic = OT_topic()
    writer = SummaryWriter()
    for epoch in range(args.epochs):

        for split in splits:
            print("split", split)
            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            # Every split other than 'train' (including 'test') iterates the
            # validation loader.
            if split == 'train':
                model.train()
                data_loader = train_data_loader
            else:
                model.eval()
                data_loader = val_data_loader

            #print("data_loader", len(data_loader))

            for iteration, data in enumerate(data_loader):
                gc.collect()
                torch.cuda.empty_cache()
                #print("iteration", iteration)
                #print("data", data)
                print("epoch ", epoch)
                print("iteration ", iteration)
                file_id = []
                descriptions = []
                videos = []
                #titles = []
                #scenes = []

                # Unpack the first three fields of every dataset item.
                for d in data:
                    file_id.append(d[0])
                    descriptions.append(d[1])
                    videos.append(d[2])
                    #titles.append(d[3])
                    #scenes.append(d[6])

                batch_size = len(data)

                # Forward pass
                #bodies = [doc[args.dataset_doc_field] for doc in documents]
                output_text_summaries, output_text_summaries_pos, text_logp, output_video_summaries, output_video_summaries_pos, video_logp = model(descriptions, videos)

                # loss calculation
                textcoverage_loss, videocoverage_loss, mmalignment_loss, fluency_loss = loss_fn(output_text_summaries, output_video_summaries, descriptions, videos, otTopic)

                #loss = torch.zeros(1, requires_grad=True)

                #loss = (NLL_loss + KL_weight * KL_loss) / batch_size
                print("textcoverage_loss", textcoverage_loss)
                print("videocoverage_loss", videocoverage_loss)
                print("mmalignment_loss", mmalignment_loss)
                print("fluency_loss", fluency_loss)
                #print("batch_size", batch_size)
                #loss = (textcoverage_loss + videocoverage_loss + mmalignment_loss+text_KL_weight * text_KL_loss+video_KL_loss*video_KL_weight) / batch_size
                loss = textcoverage_loss + 0.01 * videocoverage_loss + mmalignment_loss + 2* fluency_loss #previous videocoverage 0.001

                # NOTE(review): the three coverage terms come from OT_topic.score,
                # which detaches its inputs and works on NumPy arrays, so they
                # carry no gradient. Forcing requires_grad here does not connect
                # `loss` to the model parameters through those terms — confirm
                # that backward() actually updates the model.
                loss.requires_grad = True
                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeepeing
                tracker['LOSS'] = torch.cat((tracker['LOSS'], loss.data.view(1, -1)), dim=0)

                if args.tensorboard_logging:
                    writer.add_scalar("%s/Loss" % split.upper(), loss.item(), epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/Text Coverage Loss" % split.upper(), textcoverage_loss,
                                      epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/Video Coverage Loss" % split.upper(), videocoverage_loss,
                                      epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/MM Coverage Loss" % split.upper(), mmalignment_loss,
                                      epoch*len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration+1 == len(data_loader):
                    print("%s Batch %04d/%i, Loss %9.4f, Text-Coverage-Loss %9.4f, Video-Coverage-Loss %9.4f, MM-Coverage-Loss %9.4f"
                          % (split.upper(), iteration, len(data_loader)-1, loss.item(), textcoverage_loss, videocoverage_loss, mmalignment_loss))

                # NOTE(review): `splits` only ever contains 'train', 'validation'
                # and optionally 'test', so this branch never runs. Its body also
                # refers to `answer`, `self` and `z`, none of which are defined
                # in main().
                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    #tracker['target_sents'] += idx2word(answer, i2w=datasets['train'].get_i2w(),pad_idx=PAD)
                    tracker['target_sents'] += idx2word(answer, i2w=self.id2word,pad_idx=PAD)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            print("%s Epoch %02d/%i, Mean LOSS %9.4f" % (split.upper(), epoch, args.epochs, tracker['LOSS'].mean()))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/LOSS" % split.upper(), torch.mean(tracker['LOSS']), epoch)

            # save a dump of all sentences and the encoded latent space
            # NOTE(review): unreachable for the same reason as the 'valid'
            # branch inside the batch loop.
            if split == 'valid':
                dump = {'target_sents': tracker['target_sents'], 'z': tracker['z'].tolist()}
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/'+ts)
                with open(os.path.join('dumps/'+ts+'/valid_E%i.json' % epoch), 'w') as dump_file:
                    json.dump(dump,dump_file)

            # save checkpoint
            # Written after the train split of every fifth epoch (0, 5, 10, ...);
            # the mean training loss is embedded in the file name.
            if split == 'train' and epoch%5 ==0:
                checkpoint_path = os.path.join(save_model_path, "E%i-%9f.ckpt" % (epoch,tracker['LOSS'].mean()))
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)



def sparse_ot(weights1, weights2, M):
    """Compute the optimal-transport cost between two weight vectors.

    Both vectors are normalised to sum to one, then restricted to their
    non-zero entries; the cost matrix is reduced to the matching rows and
    columns before ``ot.emd2`` is called.

    Args:
        weights1: 1-D array of source weights.
        weights2: 1-D array of target weights.
        M: 2-D cost matrix indexed as ``M[i, j]`` for entry ``i`` of
            ``weights1`` and entry ``j`` of ``weights2``.

    Returns:
        The value returned by ``ot.emd2`` for the reduced problem.

    NOTE(review): a vector whose entries sum to zero is divided by zero during
    normalisation — callers are presumably expected to pass non-empty mass.
    """
    weights1 = weights1/weights1.sum()
    weights2 = weights2/weights2.sum()

    # Indices of the entries that actually carry mass.
    active1 = np.where(weights1)[0]
    active2 = np.where(weights2)[0]

    weights_1_active = weights1[active1]
    weights_2_active = weights2[active2]

    # Select rows first, then columns, and hand the solver a contiguous copy.
    # (The sub-matrix used to be extracted twice; the unused copy is gone.)
    M_reduced = np.ascontiguousarray(M[active1][:,active2])

    return ot.emd2(weights_1_active,weights_2_active,M_reduced)

def construct_BOW(tokens):
    """Build a normalised bag-of-words vector from a sequence of token ids.

    Returns an array of length ``BERT_NUM_TOKEN`` in which entry ``k`` is the
    number of times id ``k`` occurs in ``tokens`` divided by ``len(tokens)``.
    """
    histogram = np.zeros(BERT_NUM_TOKEN)
    total = len(tokens)
    for token_id in tokens:
        histogram[token_id] = histogram[token_id] + 1
    return histogram/total

def _parse_bool(value):
    """Convert a command-line token such as 'true'/'False'/'1'/'0' to a bool.

    ``type=bool`` cannot be used for this: argparse would call ``bool("False")``,
    which is ``True`` because the string is non-empty.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def build_train_arg_parser():
    """Create the argument parser for the training entry point."""
    parser = argparse.ArgumentParser()


    parser.add_argument('--max_sequence_length', type=int, default=900)
    parser.add_argument('--max_article_length', type=int, default=5)
    parser.add_argument('--max_summary_pic', type=int, default=1)
    parser.add_argument('--max_summary_word', type=int, default=12)

    # The default used to be the string 'False', which is truthy, so the test
    # split was enabled on every run. It is now a real boolean.
    parser.add_argument('--test', action='store_true', default=False)

    parser.add_argument('-ep', '--epochs', type=int, default=100000)
    parser.add_argument('-bs', '--batch_size', type=int, default=2)
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.01)

    parser.add_argument('-ths', '--text_hidden_size', type=int, default=128)
    parser.add_argument('-vhs', '--video_hidden_size', type=int, default=128)
    parser.add_argument('-nah', '--num_attention_head', type=int, default=2)
    parser.add_argument('-nl', '--num_layers', type=int, default=2)

    parser.add_argument('-v', '--print_every', type=int, default=50)
    parser.add_argument('-tb', '--tensorboard_logging', action='store_true')
    parser.add_argument('-log', '--logdir', type=str, default='logs')
    parser.add_argument('-bin', '--save_model_path', type=str, default='multimodal_model')
    parser.add_argument('--dataset_folder', type=str, help='folder of dataset', default='/content/drive/MyDrive/Research/data')
    # Was type=bool, which turned "--resume_training False" into True.
    parser.add_argument("--resume_training", type=_parse_bool, default=False)
    parser.add_argument("--model_name", type=str, default='bench')
    return parser


if __name__ == '__main__':
    import sys

    parser = build_train_arg_parser()
    # Jupyter/Colab start the kernel with their own "-f <connection file>"
    # argument, which parse_args() would reject. Unknown arguments are only
    # tolerated inside a notebook kernel; a plain command-line run stays strict.
    args, unknown_args = parser.parse_known_args()
    if unknown_args and 'ipykernel' not in sys.modules:
        parser.error('unrecognized arguments: %s' % ' '.join(unknown_args))


    main(args)





In [None]:
!