In [None]:
# MSG: rate-limited e621 API wrapper (E621Wrapper)
import requests
from requests.auth import HTTPBasicAuth
import time

# Root of the e621 API; route names such as 'posts.json' are appended to it.
BASE_URL: str = 'https://e621.net/'
# Minimum spacing, in seconds, between requests when the wrapper's
# rate-limit waiting is enabled.
RATE_TIME: float = 1.5

class E621WrapperError(Exception):
    """Exception type for E621Wrapper failures.

    NOTE(review): not raised anywhere in this module yet.
    """

class E621Wrapper:
    """Small rate-limited client for the e621 JSON API.

    All requests share one ``requests.Session`` that authenticates with HTTP
    basic auth (username + API key) and sends a User-Agent header. When
    ``wait_for_rate_limit`` is true, consecutive requests are spaced at least
    ``RATE_TIME`` seconds apart.
    """

    # Largest page size the wrapper will request; bigger limits are clamped.
    _MAX_LIMIT = 320

    def __init__(self, username: str, api_key: str, user_agent: str = None, wait_for_rate_limit: bool = True) -> None:
        """Create the authenticated session.

        Args:
            username: e621 account name, used for basic auth and the default
                User-Agent.
            api_key: API key paired with ``username`` for basic auth.
            user_agent: User-Agent header value. When falsy it defaults to
                ``'e621_scraper (by <username> on e621)'``.
            wait_for_rate_limit: if true, sleep between requests so they are
                at least ``RATE_TIME`` seconds apart.
        """
        # Bug fix: the default used to be assigned and then immediately
        # overwritten by str(user_agent) (i.e. 'None'). The session header was
        # also set from the raw argument (None by default), so the intended
        # User-Agent was never sent.
        if not user_agent:
            user_agent = f'e621_scraper (by {username} on e621)'
        self.user_agent = str(user_agent)

        self._wait_for_rate_limit = wait_for_rate_limit
        # Monotonic timestamp of the previous request. -inf means "never", so
        # the first request is not delayed.
        self._last_requested = float('-inf')

        self._wrapper_session = requests.Session()
        self._wrapper_session.auth = HTTPBasicAuth(username, api_key)
        self._wrapper_session.headers.update({'User-Agent': self.user_agent})

    def _wait(self) -> None:
        """Sleep until RATE_TIME seconds have passed since the previous request.

        Does nothing when rate-limit waiting is disabled. This replaces the
        former busy-wait loop, which spun a CPU core while waiting.
        """
        if not self._wait_for_rate_limit:
            return
        remaining = RATE_TIME - (time.monotonic() - self._last_requested)
        if remaining > 0:
            time.sleep(remaining)

    def _get(self, route: str, params: dict):
        """Rate-limit, GET ``BASE_URL + route`` with ``params``, and return the decoded JSON.

        NOTE(review): the HTTP status is not checked before decoding the body.
        """
        self._wait()
        self._last_requested = time.monotonic()
        response = self._wrapper_session.get(BASE_URL + route, params=params)
        return response.json()

    def get_posts(self, limit: int = 10, tags: list = None, page: str = None) -> list:
        """Fetch one page of posts from ``posts.json``.

        Args:
            limit: posts per page. Values above 320 are clamped to 320.
            tags: tag query terms, joined with spaces. ``None`` means no tags.
            page: optional ``page`` parameter, passed through unchanged
                (e.g. a ``'b<id>'`` cursor).
        Returns:
            The ``'posts'`` entry of the JSON response.
        """
        # ``tags`` defaults to None rather than a shared mutable [] default.
        params = {'limit': min(limit, self._MAX_LIMIT), 'tags': ' '.join(tags or ())}
        if page:
            params['page'] = page
        return self._get('posts.json', params)['posts']

    def get_tags(self, limit: int = 75, category: int = None, page: str = None) -> dict:
        """Fetch one page of tags from ``tags.json``.

        Args:
            limit: tags per page. Values above 320 are clamped to 320.
            category: optional integer category filter. ``None`` (the default)
                means no filter; integers outside 0..8 are ignored.
            page: optional ``page`` parameter, passed through unchanged.
        Returns:
            The decoded JSON response.
        Raises:
            ValueError: if ``category`` is given but is not an ``int``.
        """
        params = {'limit': min(limit, self._MAX_LIMIT)}
        # Bug fix: the None default used to fail the type check, so calling
        # get_tags() without a category always raised ValueError.
        if category is not None:
            if type(category) is not int:  # exact check: bool is rejected too
                raise ValueError('Category must be type int()')
            if 0 <= category <= 8:
                params['search[category]'] = category
        if page:
            params['page'] = page
        return self._get('tags.json', params)

In [None]:
# Collect post metadata from e621 into data/posts.json
from monosodium_glutamate import E621Wrapper
import json

# Scrape configuration. The credential values are placeholders; fill them in before running.
USERNAME: str = 'ENTERUSER'
API_KEY: str = 'ENTERAPIKEY'  # placeholder; the script below reads the real key from key.json

DATA_FILE: str = 'data/posts.json'  # JSON file the collected post metadata is saved to

def convert_rating(rating: str) -> list:
    """One-hot encode a post's rating letter as a 3-element list.

    's' maps to [1, 0, 0], 'e' maps to [0, 1, 0], and any other value maps to [0, 0, 1].
    """
    hot_index = 0 if rating == 's' else (1 if rating == 'e' else 2)
    return [1 if slot == hot_index else 0 for slot in range(3)]

def load_posts() -> dict:
    """Return the parsed contents of DATA_FILE, or {} when the file does not exist."""
    loaded = {}
    try:
        with open(DATA_FILE, encoding='utf-8') as posts_file:
            loaded = json.load(posts_file)
    except FileNotFoundError:
        pass  # nothing saved yet: keep the empty default
    return loaded

def save_posts(posts: dict) -> None:
    """Write ``posts`` to DATA_FILE as indented UTF-8 JSON, keeping non-ASCII characters as-is.

    The JSON is first written to ``DATA_FILE + '.tmp'`` and then moved into
    place with ``os.replace``. Interrupting the scrape mid-save therefore
    leaves the previous posts file intact, rather than a truncated file that
    ``json.load`` can no longer parse. The parent directory is created if it
    is missing.
    """
    import os  # local import: this cell does not import os

    print(f'saving {len(posts)} posts')
    directory = os.path.dirname(DATA_FILE)
    if directory:
        os.makedirs(directory, exist_ok=True)
    tmp_path = DATA_FILE + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as json_file:
        json.dump(posts, json_file, indent=2, ensure_ascii=False)
    os.replace(tmp_path, DATA_FILE)

def get_last_post() -> str:
    """Return a ``'b<id>'`` page cursor built from the last post saved in DATA_FILE.

    Returns None when DATA_FILE does not exist or holds an empty JSON object.
    Previously, an empty object raised IndexError, which made a run that had
    saved nothing impossible to resume.
    """
    try:
        with open(DATA_FILE, encoding='utf-8') as json_file:
            posts = json.load(json_file)
    except FileNotFoundError:
        return None
    if not posts:
        return None
    # json.load keeps the document's key order, so the last key is the last
    # post that was inserted before saving.
    return 'b' + str(list(posts)[-1])

# Stop once this many posts are stored; checkpoint to disk every N pages.
TARGET_POST_COUNT = 3200
SAVE_EVERY_N_PAGES = 10

with open('key.json', encoding='utf-8') as json_file:
    key = json.load(json_file)["key"]

# The API key comes from key.json. E621Wrapper's third positional parameter is
# the User-Agent, so the API_KEY placeholder must not be passed here (it used to
# end up as the User-Agent string).
e = E621Wrapper(USERNAME, key)

tags = ["-animated", "-webm", "-flash", "-3d_(artwork)", "-sketch", "-pixel_(artwork)", "~digital_media_(artwork)", "~traditional_media_(artwork)"]
print(' '.join(tags))
last_post = get_last_post()
print(last_post)
all_posts = load_posts()
page_count = 0
while len(all_posts) < TARGET_POST_COUNT:
    posts = e.get_posts(limit=320, tags=tags, page=last_post)
    if not posts:
        break

    # The next request continues from the last post of this page.
    last_post = 'b' + str(posts[-1]["id"])
    print(last_post)

    for post in posts:
        file_info = post['file']
        # Skip unsupported formats and posts without a file URL. A missing URL
        # would otherwise break the image-download step.
        if file_info['ext'] not in ('png', 'jpg') or not file_info['url']:
            continue
        all_posts[str(post['id'])] = {
            "url"     : file_info['url'],
            "rating"  : convert_rating(post['rating']),
            "score"   : post['score']['total'],
            "species" : post['tags']['species'],
            "general" : post['tags']['general'],
            }

    page_count += 1
    if page_count % SAVE_EVERY_N_PAGES == 0:
        save_posts(all_posts)

save_posts(all_posts)


In [None]:
# Download each saved post's image and store it as a resized PNG in images/
import json
import requests
from io import BytesIO
from PIL import Image

ML_SIZE: tuple = (512, 512)  # size passed to PIL's resize() for every downloaded image
DATA_FILE: str = 'data/posts.json'  # JSON file of collected post metadata

def load_posts() -> dict:
    """Read DATA_FILE as JSON; a missing file yields an empty dict."""
    try:
        source = open(DATA_FILE, encoding='utf-8')
    except FileNotFoundError:
        return {}
    else:
        with source:
            return json.load(source)

def create_image(url, timeout: float = 30):
    """Download ``url`` and return it as an RGB PIL image resized to ML_SIZE.

    Args:
        url: image URL. Falsy values (e.g. None) are rejected up front instead
            of being passed to ``requests.get``.
        timeout: seconds passed to ``requests.get``. Without a timeout, a stalled
            connection could hang the download loop indefinitely.
    Returns:
        The resized image. Returns None, after printing a message, if there is
        no URL, the request fails, or the response status is not 200.
    """
    if not url:
        print('Failed to collect image. (no URL)')
        return None
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Treat network errors like a bad status so that one failure does not
        # abort the whole download loop.
        print(f'Failed to collect image. ({exc})')
        return None
    if response.status_code != 200:
        print(f'Failed to collect image. ({response.status_code})')
        return None
    img = Image.open(BytesIO(response.content)).convert('RGB')
    return img.resize(ML_SIZE)

import os

IMAGE_DIR = 'images'
# PIL's save() does not create missing directories.
os.makedirs(IMAGE_DIR, exist_ok=True)

posts = load_posts()

for post_id, post in posts.items():
    image_path = os.path.join(IMAGE_DIR, f'{post_id}.png')
    # Skip images that were already downloaded, so an interrupted run can
    # resume cheaply. NOTE(review): assumes a post id's image never changes.
    if os.path.exists(image_path):
        continue

    url = post['url']
    print(f'creating image {post_id}.png')
    print(url)
    # Posts saved without a file URL cannot be downloaded.
    if not url:
        print('Failed to collect image. (no URL)')
        continue

    image = create_image(url)
    if image is not None:
        image.save(image_path)

In [None]:
# Create the one-hot species and general tag datasets (CSV) from the downloaded images
import json
import os 

DATA_FILE: str = 'data/posts.json'     # collected post metadata, keyed by post id
TAG_FILE: str = 'data/tags_list.json'  # tag vocabulary with 'species' and 'general' lists

def load_posts_saved(directory: str = 'images/') -> list:
    """Return the post ids of the images saved in ``directory``, sorted.

    Each file name's extension is stripped with ``os.path.splitext``. Names
    without an extension are kept intact; the previous ``rfind('.')`` slicing
    chopped off their last character. Hidden entries (names starting with '.',
    e.g. '.ipynb_checkpoints') are skipped. Sorting makes the dataset row
    order reproducible, because ``os.listdir`` order is arbitrary.
    """
    return sorted(
        os.path.splitext(name)[0]
        for name in os.listdir(directory)
        if not name.startswith('.')
    )

def load_posts_dict() -> dict:
    """Return the parsed contents of DATA_FILE ({} if the file is missing).

    The result is expected to map post id to that post's saved metadata.
    """
    try:
        posts_file = open(DATA_FILE, encoding='utf-8')
    except FileNotFoundError:
        return {}
    with posts_file:
        return json.load(posts_file)

def load_tags() -> dict:
    """Load TAG_FILE and return it with its 'species' and 'general' entries converted to lists."""
    with open(TAG_FILE, encoding='utf-8') as tag_file:
        vocabulary = json.load(tag_file)
    for category in ('species', 'general'):
        vocabulary[category] = list(vocabulary[category])
    return vocabulary

def create_onehot(length, positions):
    """Return a list of ``length`` zeros with a 1 at every index in ``positions``.

    Indices follow normal list indexing: negative values count from the end,
    and out-of-range values raise IndexError.
    """
    vector = [0 for _ in range(length)]
    for index in positions:
        vector[index] = 1
    return vector

tags = load_tags()
species_len = len(tags['species'])
general_len = len(tags['general'])


def _column_index(names: list) -> dict:
    """Map each tag name to its column, keeping the first position as list.index() does."""
    index = {}
    for position, name in enumerate(names):
        index.setdefault(name, position)
    return index


# Build the column lookups once. The original ran a linear list.index() scan
# for every matched tag of every post.
species_index = _column_index(tags['species'])
general_index = _column_index(tags['general'])

posts = load_posts_saved()
post_dict = load_posts_dict()

# NOTE(review): columns are joined with ', ' by hand, so a tag name containing
# a comma would shift the CSV columns. Consider the csv module.
s_header = f"id, {', '.join(tags['species'])}\n"
g_header = f"id, {', '.join(tags['general'])}\n"

s_lines = [s_header]
g_lines = [g_header]

for post in posts:
    # An image with no metadata entry used to abort the whole run with KeyError.
    if post not in post_dict:
        print(f'skipping {post}: not found in {DATA_FILE}')
        continue
    post_tags = post_dict[post]

    s_positions = [species_index[s] for s in set(post_tags['species']) if s in species_index]
    s_onehot = create_onehot(species_len, s_positions)

    g_positions = [general_index[g] for g in set(post_tags['general']) if g in general_index]
    g_onehot = create_onehot(general_len, g_positions)

    s_lines.append(f"{post}, {', '.join(map(str, s_onehot))}\n")
    g_lines.append(f"{post}, {', '.join(map(str, g_onehot))}\n")

for csv_path, csv_lines in (('data/species_dataset.csv', s_lines),
                            ('data/general_dataset.csv', g_lines)):
    with open(csv_path, 'w', encoding='utf-8') as csv_file:
        csv_file.writelines(csv_lines)


In [1]:
# Sort the tags by post count and keep those above the cutoff
import json

# Tags used on more than CUTOFF posts are kept, most-used first.
CUTOFF = 7000


def _load_popular_tags(path: str, cutoff: int) -> list:
    """Return the entries of the file's "tags" list whose post_count > cutoff, highest count first."""
    with open(path, encoding='utf-8') as json_file:
        entries = list(json.load(json_file)["tags"])
    # Filter before sorting: the result is the same (sort is stable) and there is less to sort.
    popular = [entry for entry in entries if entry['post_count'] > cutoff]
    popular.sort(key=lambda entry: entry['post_count'], reverse=True)
    return popular


# Both paths use forward slashes. 'data\general_tags.json' relied on the
# invalid '\g' escape and only resolved on Windows.
general = _load_popular_tags('data/general_tags.json', CUTOFF)
species = _load_popular_tags('data/species_tags.json', CUTOFF)

tags = {"species": [s['name'] for s in species], "general": [g['name'] for g in general]}

with open('data/tags_list.json', 'w', encoding='utf-8') as json_file:
    json.dump(tags, json_file, indent=2, ensure_ascii=False)