#!/bin/python3

In [1]:
import pandas as pd
import numpy as np
import string
import json

from os import listdir
from os.path import isfile, isdir, join
from keras.preprocessing import image
from keras.applications.resnet import ResNet152, preprocess_input

Using TensorFlow backend.


In [2]:
def _globalMaxPool1D(tensor):
    """Collapse a 4-D tensor to one maximum per index of its last axis.

    Despite the name, the reduction runs over the first three axes at once.

    Args:
        tensor: 4-D array-like exposing ``shape`` (e.g. a NumPy array).
            Presumably (batch, height, width, channels) model output --
            confirm against the caller.

    Returns:
        A list with ``tensor.shape[3]`` entries; entry ``i`` is the maximum
        of ``tensor[:, :, :, i]``.

    Raises:
        ValueError: If ``tensor`` does not have exactly four dimensions, or
            if one of the reduced axes is empty while the last axis is not.
    """
    # Unpacking doubles as the rank check: anything but 4 dims raises here.
    _, _, _, size = tensor.shape
    if size == 0:
        return []
    # One vectorized reduction instead of a Python loop over every channel.
    return list(np.asarray(tensor).max(axis=(0, 1, 2)))

def _getImageFeatures(model, img_path):
    """Run one image through ``model`` and pool the prediction.

    Args:
        model: Object whose ``predict`` accepts the preprocessed batch
            (the caller passes a Keras ResNet152).
        img_path: Path of the image file to load.

    Returns:
        The result of ``_globalMaxPool1D`` applied to ``model.predict``'s
        output for this single image.
    """
    # target_size=None: the image is not resized on load.
    pixels = image.img_to_array(image.load_img(img_path, target_size=None))

    # Add a leading axis so the model receives a batch of one.
    batch = preprocess_input(np.expand_dims(pixels, axis=0))

    return _globalMaxPool1D(model.predict(batch))

def _getTextFeatures(text_path):
    """Load an article's JSON file and return its id and cleaned text.

    The file is expected to hold a JSON string that itself contains the
    article object (double-encoded JSON). A file that stores the object
    directly is accepted as well.

    Args:
        text_path: Path of the article's JSON file.

    Returns:
        A dict with keys ``'id'`` (copied unchanged) and ``'text'`` (the
        article text with newlines replaced by spaces and every character
        of ``string.punctuation`` removed).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file (or the embedded string) is not
            valid JSON.
        KeyError: If the article object lacks ``'id'`` or ``'text'``.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(text_path, encoding="utf-8") as json_file:
        data = json.load(json_file)

    # Double-encoded files decode to a str first; unwrap it once more.
    if isinstance(data, str):
        data = json.loads(data)

    text = data['text'].replace("\n", " ")
    return {
        'id': data['id'],
        'text': text.translate(str.maketrans('', '', string.punctuation)),
    }
    
def _getValidImagePaths(article_path):
    """List the ``.jpg`` files inside an article's ``img/`` directory.

    Args:
        article_path: Directory of one article; must contain ``img/``.

    Returns:
        Paths (``article_path/img/<name>``) of every regular file whose
        name ends in ``.jpg``, compared case-insensitively, in the order
        ``listdir`` reports them.
    """
    img_dir = join(article_path, 'img/')
    jpg_paths = []
    for entry in listdir(img_dir):
        candidate = join(img_dir, entry)
        # Skip sub-directories and anything else that is not a plain file.
        if not isfile(candidate):
            continue
        if entry[-4:].lower() != ".jpg":
            continue
        jpg_paths.append(candidate)
    return jpg_paths

def GetArticleData(model, article_path):
    """Assemble the text and image features of a single article.

    Args:
        model: Model handed through to ``_getImageFeatures``.
        article_path: Article directory holding ``text.json`` and ``img/``.

    Returns:
        The dict produced by ``_getTextFeatures`` for ``text.json``,
        extended with an ``"img"`` key: a list with one feature list per
        path returned by ``_getValidImagePaths``.
    """
    record = _getTextFeatures(join(article_path, 'text.json'))
    record["img"] = [
        _getImageFeatures(model, picture)
        for picture in _getValidImagePaths(article_path)
    ]
    return record

def PreprocessArticles(data_path, limit=None):
    """Extract features for the article directories found in ``data_path``.

    Every sub-directory of ``data_path`` is treated as one article and is
    passed to ``GetArticleData`` together with a shared ResNet152 model
    (ImageNet weights, built without its top layers).

    Args:
        data_path: Directory whose sub-directories are the articles.
        limit: Maximum number of articles to process. ``None`` (default)
            processes all of them; zero or a negative value processes none.

    Returns:
        A list with one ``GetArticleData`` result per processed article, in
        the order ``listdir`` reports the directories.
    """
    article_paths = [
        join(data_path, f) for f in listdir(data_path) if isdir(join(data_path, f))
    ]

    # Test against None explicitly so that limit=0 is not mistaken for
    # "no limit", and bail out before the model is built when there is
    # nothing to process.
    if not article_paths or (limit is not None and limit <= 0):
        return []

    model = ResNet152(weights='imagenet', include_top=False)

    articles = []
    for path in article_paths:
        articles.append(GetArticleData(model, path))
        if limit is not None and len(articles) >= limit:
            break

    return articles

In [3]:
# Run the pipeline on at most two article directories under ./data/.
articles = PreprocessArticles(data_path='./data/', limit=2)