-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
1,074 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2019 Donghyun Kim and Bryan Plummer | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import os | ||
import json | ||
import numpy as np | ||
from data_loader import DatasetLoader | ||
|
||
class COCOLoader(DatasetLoader):
    """Dataset loader for the multilingual COCO caption files.

    Reads tokenized caption lists (JSON) and the matching image-id lists
    from ``data/<dataset>/tokenized`` and groups the captions by the image
    index assigned in ``self.image2index``.
    """

    def tokenize(self, language, token_filename, image_filename, tokens, vocab):
        """Append the captions of one caption file to ``tokens``.

        Args:
            language: language code; ``'en'`` captions are lower-cased
                before splitting, all others are split as-is.
            token_filename: JSON file holding one list of caption strings
                per image.
            image_filename: text file with one image id per line, aligned
                with the entries of ``token_filename``.
            tokens: list indexed by image index; each caption (as a list of
                words) is appended in place to its image's entry.
            vocab: set updated in place with every word seen.

        Returns:
            Length, in words, of the longest caption read (0 if none).
        """
        with open(token_filename, 'r') as f:
            all_sentences = json.load(f)

        # Close the image list explicitly instead of leaking the handle.
        with open(image_filename, 'r') as f:
            image_list = [im.strip() for im in f]

        assert len(all_sentences) == len(image_list)
        max_length = 0
        for im, sentences in zip(image_list, all_sentences):
            i = self.image2index[im]
            for sentence in sentences:
                if language == 'en':
                    sentence = sentence.lower()
                elif not isinstance(sentence, str):
                    # Python 2: json yields ``unicode``; encode it so the
                    # words are native ``str`` like the vocabulary read from
                    # the word-vector file. On Python 3 the caption already
                    # is ``str`` and encoding would produce ``bytes``.
                    sentence = sentence.encode('utf8')

                words = sentence.split()
                vocab.update(words)
                max_length = max(len(words), max_length)
                tokens[i].append(words)

        return max_length

    def get_tokens(self, args, language):
        """Load every caption of ``language`` for the current split.

        Args:
            args: options object; ``args.dataset`` selects the data
                directory and ``args.max_sentence_length`` caps the
                returned maximum length.
            language: language code of the captions to load.

        Returns:
            Tuple ``(tokens, sent2im, im2sent, vocab, max_length)`` where
            ``tokens[i]`` is the list of captions of image ``i``,
            ``sent2im`` maps each sentence index to its image index,
            ``im2sent[i]`` holds the sentence indices of image ``i``,
            ``vocab`` is the set of words seen and ``max_length`` is the
            longest caption length clipped to ``args.max_sentence_length``.
        """
        token_dir = os.path.join('data', args.dataset, 'tokenized')
        token_filename = os.path.join(
            token_dir, '%s_%s_caption_list.json' % (self.split, language))
        if self.split != 'train':
            image_filename = os.path.join('data', args.dataset, self.split + '.txt')
        elif language == 'en':
            # images that have no human-generated captions in the other
            # languages
            image_filename = os.path.join(
                token_dir, '%s_%s_coco.txt' % (self.split, language))
        else:
            image_filename = os.path.join(
                token_dir, '%s_en_%s_coco.txt' % (self.split, language))

        tokens = [[] for _ in range(len(self.image2index))]
        vocab = set()
        max_length = self.tokenize(language, token_filename, image_filename, tokens, vocab)

        if self.split == 'train':
            # (caption file, image list) pairs added on top of the base
            # training captions. The order is significant: it fixes the
            # order of the captions of each image and therefore the
            # sentence indices computed below.
            if language == 'en':
                extra_sources = [
                    # images that have human-generated japanese captions
                    ('train_en_jp_caption_list.json', 'train_en_jp_coco.txt'),
                    # images that have human-generated chinese captions
                    ('train_en_cn_caption_list.json', 'train_en_cn_coco.txt'),
                    # translations to english from other languages
                    ('train_cn_to_en_caption_list.json', 'train_en_cn_coco.txt'),
                    ('train_jp_to_en_caption_list.json', 'train_en_jp_coco.txt'),
                ]
            else:
                # translations from english
                extra_sources = [
                    ('train_%s_augment_caption_list.json' % language,
                     'train_en_%s_augment_coco.txt' % language),
                ]

            for caption_file, image_file in extra_sources:
                length = self.tokenize(language,
                                       os.path.join(token_dir, caption_file),
                                       os.path.join(token_dir, image_file),
                                       tokens, vocab)
                max_length = max(max_length, length)

        # Sentences are numbered consecutively, image by image.
        im2sent = {}
        sent2im = []
        num_sentences = 0
        for i, sentences in enumerate(tokens):
            im2sent[i] = np.arange(num_sentences, num_sentences + len(sentences))
            sent2im.append(np.ones(len(sentences), np.int32) * i)
            num_sentences += len(sentences)

        sent2im = np.hstack(sent2im)
        max_length = min(max_length, args.max_sentence_length)
        return tokens, sent2im, im2sent, vocab, max_length
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import os | ||
import pickle | ||
from abc import abstractmethod | ||
|
||
def get_sentence(vocab, tokens, token_length):
    """Convert tokenized captions into a zero-padded matrix of word indices.

    Args:
        vocab: mapping from word to integer index; words missing from it
            are dropped before truncation.
        tokens: list with one entry per image, each entry being a list of
            captions and each caption a list of words.
        token_length: number of columns; longer captions are truncated and
            shorter ones are padded with zeros.

    Returns:
        ``np.int32`` array of shape ``(total number of captions,
        token_length)`` with the captions of all images stacked in order.
        An empty ``tokens`` list yields an array with zero rows.
    """
    sent_feats = []
    for sentences in tokens:
        feats = np.zeros((len(sentences), token_length), np.int32)
        for row, words in enumerate(sentences):
            indices = [vocab[word] for word in words if word in vocab][:token_length]
            if indices:
                feats[row, :len(indices)] = indices

        sent_feats.append(feats)

    if not sent_feats:
        # np.concatenate refuses an empty list; keep the 2-D contract.
        return np.zeros((0, token_length), np.int32)

    return np.concatenate(sent_feats, axis=0)
|
||
def get_embeddings(args, language, vocab):
    """Return the word index and embedding matrix for ``language``.

    The result is cached in ``data/<args.dataset>/<language>_vecs.pkl``.
    When the cache is missing, the vectors are read from
    ``fasttext/cc.<code>.300.vec`` (``'cn'`` and ``'jp'`` are mapped to the
    ``'zh'`` and ``'ja'`` file codes) and the cache is written.

    Args:
        args: options object; only ``args.dataset`` is read.
        language: language code used for the cache and vector file names.
        vocab: set of words to keep; only words that also appear
            (lower-cased) in the vector file receive an index. Ignored
            when the cache file already exists.

    Returns:
        Tuple ``(word2index, vecs)``: ``word2index`` maps each kept word to
        a row of ``vecs`` starting at 1, and row 0 of ``vecs`` is all zeros.
    """
    cachefn = os.path.join('data', args.dataset, language + '_vecs.pkl')
    if os.path.exists(cachefn):
        # NOTE: pickle must only be used on trusted files; this cache is
        # expected to be the one written by the branch below.
        with open(cachefn, 'rb') as f:
            embedding_data = pickle.load(f)

        word2index = embedding_data['word2index']
        vecs = embedding_data['vecs']
    else:
        embedding_dims = 300
        if language == 'cn':
            language = 'zh'
        elif language == 'jp':
            language = 'ja'

        wordvec_file = os.path.join('fasttext', 'cc.%s.300.vec' % language)
        with open(wordvec_file, 'r') as f:
            w2v_dict = {}
            for i, line in enumerate(f):
                if i % 100000 == 0:
                    print('reading %s vector %i' % (language, i))

                line = line.strip()
                if not line:
                    continue

                # Lines that are not "<word> <300 floats>" are skipped.
                vec = line.split()
                if len(vec) != embedding_dims + 1:
                    continue

                label = vec[0].lower()
                if label not in vocab:
                    continue

                w2v_dict[label] = np.array([float(x) for x in vec[1:]], np.float32)

        vocab = list(vocab.intersection(set(w2v_dict.keys())))
        # Row 0 is the all-zero row; every following row is overwritten
        # with the vector of its word in the loop below.
        vecs = np.concatenate((np.zeros((1, embedding_dims), np.float32),
                               np.random.standard_normal((len(vocab), embedding_dims))))
        word2index = {}
        for i, tok in enumerate(vocab):
            vecs[i + 1] = w2v_dict[tok]
            word2index[tok] = i + 1

        # Close the cache file explicitly so the data is flushed to disk.
        with open(cachefn, 'wb') as f:
            pickle.dump({'word2index': word2index, 'vecs': vecs}, f)

    return word2index, vecs
|
||
class DatasetLoader(object):
    """Base loader for image features and per-language sentence features.

    Loads the image feature matrix and image id list of one split, asks
    the subclass for the tokenized sentences of every language (see
    ``get_tokens``) and serves batches of sentence indices together with
    the matching image and sentence features.
    """

    def __init__(self, args, split):
        """Load all features of ``split``.

        Args:
            args: options object; ``args.dataset`` and ``args.languages``
                are read here, and ``args`` is passed on to ``get_tokens``
                and ``get_embeddings``.
            split: split name; ``data/<dataset>/<split>.npy`` and
                ``data/<dataset>/<split>.txt`` are read.
        """
        im_feats = np.load(os.path.join('data', args.dataset, split + '.npy'))
        with open(os.path.join('data', args.dataset, split + '.txt'), 'r') as f:
            image_ids = [line.strip() for line in f.readlines()]

        assert len(im_feats) == len(image_ids)
        self.image2index = dict(zip(image_ids, range(len(image_ids))))
        self.split = split
        self.im_feat_shape = (im_feats.shape[0], im_feats.shape[-1])
        self.im_feats = im_feats
        self.languages = args.languages
        self.sent_feats = {}
        self.num_sentences = {}
        self.sent2im = {}
        self.im2sent = {}
        self.vecs = {}
        self.vocab = {}
        self.max_length = {}
        # Language with the most sentences; its sentence indices drive the
        # batches. Stays None when no language has any sentence.
        self.max_language = None
        max_sentences = 0
        for language in args.languages:
            tokens, sent2im, im2sent, vocab, max_length = self.get_tokens(args, language)
            self.vocab[language], self.vecs[language] = get_embeddings(args, language, vocab)
            language_features = get_sentence(self.vocab[language], tokens, max_length)
            self.max_length[language] = max_length
            self.sent_feats[language] = language_features
            num_sentences = len(sent2im)
            self.num_sentences[language] = num_sentences
            if num_sentences > max_sentences:
                self.sent2im = sent2im
                self.max_language = language
                max_sentences = num_sentences

            self.im2sent[language] = im2sent

        # A real list: np.random.shuffle needs item assignment, which a
        # Python 3 ``range`` object does not support.
        self.sent_inds = list(range(max_sentences))
        if split != 'train':
            # test_labels[language][s, i] is True when sentence s belongs
            # to image i.
            self.test_labels = {}
            for language, im2sent in self.im2sent.items():
                labels = np.zeros((self.num_sentences[language], len(self.image2index)), bool)
                for image_index, sentences in im2sent.items():
                    labels[sentences, image_index] = True

                self.test_labels[language] = labels

    @abstractmethod
    def get_tokens(self, args, language):
        """Return ``(tokens, sent2im, im2sent, vocab, max_length)``.

        Must be provided by subclasses. The class does not use an ABC
        metaclass, so the decorator alone does not prevent instantiation;
        the explicit error makes a missing override obvious.
        """
        raise NotImplementedError('subclasses must implement get_tokens')

    def shuffle_inds(self):
        '''
        shuffle the sentence indices in place (run this once per epoch)
        the shuffle is applied whatever the split is
        '''
        np.random.shuffle(self.sent_inds)

    def sample_items(self, sample_inds, sample_size):
        '''
        for each index, return the relevant image and sentence features
        sample_inds: a list of sent indices
        sample_size: number of neighbor sentences to sample per index.
        returns (im_feats, sent_feats) where sent_feats maps each language
        to its stacked sentence features, sample_size rows per index
        '''
        im_ids = [self.sent2im[i] for i in sample_inds]
        im_feats = self.im_feats[im_ids]
        sent_feats = {}
        for language in self.languages:
            feats = []
            for im, sent in zip(im_ids, sample_inds):
                im_sent = list(self.im2sent[language][im])
                if language == self.max_language:
                    if len(im_sent) < sample_size:
                        # Too few sentences: pad with random repeats.
                        for _ in range(sample_size - len(im_sent)):
                            im_sent.append(np.random.choice(im_sent))

                        sample_index = im_sent
                    else:
                        # Always include the query sentence itself and
                        # draw the others without replacement.
                        im_sent.remove(sent)
                        sample_index = np.random.choice(im_sent, sample_size - 1, replace=False)
                        # np.random.choice on an empty list returns a
                        # float array; cast so the result is a valid index.
                        sample_index = np.append(sample_index, sent).astype(np.int64)
                else:
                    sample_index = np.random.choice(im_sent, sample_size, replace=len(im_sent) < sample_size)

                feats.append(self.sent_feats[language][sample_index])

            sent_feats[language] = np.concatenate(feats, axis=0)

        return (im_feats, sent_feats)

    def get_batch(self, batch_index, batch_size, sample_size):
        '''
        return (im_feats, sent_feats, labels) for batch number batch_index
        labels always has batch_size * sample_size rows, even when the
        slice of sent_inds holds fewer than batch_size indices
        '''
        start_ind = batch_index * batch_size
        end_ind = start_ind + batch_size
        sample_inds = self.sent_inds[start_ind : end_ind]
        (im_feats, sent_feats) = self.sample_items(sample_inds, sample_size)

        # Each row of the labels is the label for one sentence,
        # with corresponding image index set to True.
        labels = np.repeat(np.eye(batch_size, dtype=bool), sample_size, axis=0)
        return (im_feats, sent_feats, labels)
|
||
|
||
|
||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#!/bin/bash
# Download and unpack the fastText Common Crawl word vectors (300-d, .vec
# text format) for every language used by the loaders into ./fasttext.

set -x
set -e

# -p: do not abort (set -e) when the directory is left over from an
# earlier run.
mkdir -p fasttext
cd fasttext

for lang in en de fr cs zh ja; do
    # Skip languages that were already downloaded and unpacked.
    if [ -f "cc.${lang}.300.vec" ]; then
        continue
    fi

    wget "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.${lang}.300.vec.gz"
    gunzip "cc.${lang}.300.vec.gz"
done

cd ..
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# taken from https://github.com/pumpikano/tf-dann | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import tensorflow as tf | ||
from tensorflow.python.framework import ops | ||
|
||
|
||
class FlipGradientBuilder(object):
    """Callable that wraps a tensor in an identity op with a flipped gradient.

    Every call registers a new gradient function under a unique name and
    uses it in place of the default gradient of the ``Identity`` op.
    """

    def __init__(self):
        # Number of calls so far; used to build unique gradient names.
        self.num_calls = 0

    def __call__(self, x, l=1.0):
        """Return ``tf.identity(x)`` whose gradient is negated and scaled.

        Args:
            x: tensor passed through unchanged in the forward direction.
            l: factor multiplied with the negated incoming gradient.

        Returns:
            The identity tensor created under the gradient override.
        """
        override_name = "FlipGradient%d" % self.num_calls

        @ops.RegisterGradient(override_name)
        def _negated_scaled_gradient(op, grad):
            return [tf.negative(grad) * l]

        graph = tf.get_default_graph()
        with graph.gradient_override_map({"Identity": override_name}):
            flipped = tf.identity(x)

        self.num_calls += 1
        return flipped


# Shared module-level instance used as a function: flip_gradient(x, l).
flip_gradient = FlipGradientBuilder()
Oops, something went wrong.