In [None]:
!pip install wikipedia

In [None]:
import wikipediaapi

wiki_wiki = wikipediaapi.Wikipedia(language='en')

def print_categorymembers(categorymembers, level=0, max_level=3):
    depth_result = []
    for c in categorymembers.values():
        print("%s: %s (ns: %d)" % ("*" * (level + 1), c.title, c.ns))
        if c.ns == wikipediaapi.Namespace.CATEGORY and level < max_level:
            depth_result.append(print_categorymembers(c.categorymembers, level=level + 1, max_level=max_level))
        else:
            depth_result.append(c)
    return depth_result

cat = wiki_wiki.page("Category:Salads")
result = print_categorymembers(cat.categorymembers)

In [None]:
import wikipedia

article_data = []

def get_data_from_article(article_obj):
    result = {}

    result['pageid'] = article_obj.pageid
    result['title'] = article_obj.title
    result['text'] = article_obj.text
    result['summary'] = article_obj.summary

    try:
        page = wikipedia.page(pageid=article_obj.pageid, auto_suggest=True)
        result['images'] = page.images
    except Exception as e:
        print(e)

    return result

def dfs(article_graph):
    if isinstance(article_graph, list):
        for article in article_graph:
            dfs(article)
    else:
        title = article_graph.title
        if 'List of' not in title:
            print(article_graph.title)
            article_data.append(get_data_from_article(article_graph))

dfs(result)

In [1]:
import json

with open('article_text.json') as f_in:
    article_data = json.load(f_in)

In [2]:
article_data[0].keys()

dict_keys(['pageid', 'title', 'text', 'summary', 'images'])

In [3]:
def is_valid_image(link):
    if '/commons/' not in link:
        return False
    
    if 'Flag of' in link:
        return False
    
    if 'Wiki' in link:
        return False

    if link.endswith('.svg'):
        return False

    return True

def prune_image_links(ex):
    if 'images' not in ex:
        ex['images'] = []
        return ex
    
    image_links = ex['images']
    image_links = list(filter(is_valid_image, image_links))
    ex['images'] = image_links
    return ex

article_data = list(map(prune_image_links, article_data))

In [5]:
with open('article_text_cleanimages.json', 'w') as f_out:
    json.dump(article_data, f_out)

In [4]:
from itertools import chain

just_urls = list(chain(*[ex['images'] for ex in article_data]))
ids = list(range(len(just_urls)))

url_to_id = {url: idx for url, idx in zip(just_urls, ids)}

with open('url_to_id.json', 'w') as f_out:
    json.dump(url_to_id, f_out)

In [11]:
import requests
import os

headers = {
    'User-Agent': open('USER_AGENT.txt').read(),
    'Accept':'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7'
}

base_dir = 'salad_images/'

for url, idx in zip(just_urls, ids):
    extension = url.split('.')[-1]
    out_fp = os.path.join(base_dir, f'{idx}.{extension}')
    print(f'{url} -> {out_fp}')

    img_data = requests.get(url, headers=headers).content

    try:
        if '<!DOCTYPE html>' in img_data.decode():
            raise Exception(f'Not an image: {img_data}')
    except (UnicodeDecodeError, AttributeError):
        pass

    with open(out_fp, 'wb') as img_out:
        img_out.write(img_data)

https://upload.wikimedia.org/wikipedia/commons/4/4f/A_large_mixed_salad.jpg -> salad_images/0.jpg
https://upload.wikimedia.org/wikipedia/commons/3/3c/Ambrosia_salad.jpg -> salad_images/1.jpg
https://upload.wikimedia.org/wikipedia/commons/4/45/Fruit_salad.JPG -> salad_images/2.JPG
https://upload.wikimedia.org/wikipedia/commons/5/58/GreenSalad.jpg -> salad_images/3.jpg
https://upload.wikimedia.org/wikipedia/commons/0/04/Potato_salad_with_egg_and_mayonnaise.jpg -> salad_images/4.jpg
https://upload.wikimedia.org/wikipedia/commons/a/a4/Rocket_lettuce%2C_Butternut_squash%2C_Beetroot%2C_Green_beans%2C_whipped_cream_salad.jpg -> salad_images/5.jpg
https://upload.wikimedia.org/wikipedia/commons/c/cb/Salad_Dressing-2_%2823044326320%29.jpg -> salad_images/6.jpg
https://upload.wikimedia.org/wikipedia/commons/b/be/Treska_s_majonezou.jpg -> salad_images/7.jpg
https://upload.wikimedia.org/wikipedia/commons/9/94/Salad_platter.jpg -> salad_images/8.jpg
https://upload.wikimedia.org/wikipedia/commons/0/0

In [6]:
def print_categories(page):
    categories = page.categories
    for title in sorted(categories.keys()):
        print("%s: %s" % (title, categories[title]))

page_py = wiki_wiki.page('Category:Food_and_drink')
print("Categories")
print_categories(page_py)

NameError: name 'wiki_wiki' is not defined

In [21]:
from PIL import Image, UnidentifiedImageError
from glob import glob
from datasets import Dataset

def dataset_generator():
    for ex in article_data:
        for url in ex['images']:
            url_idx = url_to_id[url]

            found_fp = glob(f'salad_images/{url_idx}.*')
            assert len(found_fp) == 1

            img_fp = found_fp[0]
            
            try:
                image = Image.open(img_fp)
            except UnidentifiedImageError:
                continue

            ex_copy = dict(ex)
            del ex_copy['images']

            ex_copy['image'] = image
            yield ex_copy

img_dataset = Dataset.from_generator(dataset_generator)

Using custom data configuration default-1ec43adc7fb56049


Downloading and preparing dataset generator/default to /home/alexwan/.cache/huggingface/datasets/generator/default-1ec43adc7fb56049/0.0.0...


                                                                

Dataset generator downloaded and prepared to /home/alexwan/.cache/huggingface/datasets/generator/default-1ec43adc7fb56049/0.0.0. Subsequent calls will reuse this data.




In [7]:
label_to_idx = {label: idx for idx, label in enumerate(set(img_dataset['title']))}

def add_label(ex):
    label_idx = label_to_idx[ex['title']]
    ex['label'] = label_idx

    return ex

img_dataset = img_dataset.map(add_label)

NameError: name 'img_dataset' is not defined

In [29]:
img_dataset['image'][0]

KeyboardInterrupt: 

In [15]:
from datasets import load_from_disk

dset = load_from_disk('salad_dataset/')

dset.train_test_split(test_size=0.1, train_size=0.9, shuffle=True).save_to_disk('salad_dataset_split/')

Flattening the indices: 100%|██████████| 1/1 [00:00<00:00,  1.22ba/s]
Flattening the indices: 100%|██████████| 1/1 [00:00<00:00, 13.96ba/s]                     
                                                                                        

In [10]:
dset = load_from_disk('salad_dataset/')

Dataset({
    features: ['pageid', 'title', 'text', 'summary', 'image', 'label'],
    num_rows: 650
})