In [2]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import shutil
from bs4 import BeautifulSoup
import base64
from io import BytesIO
from PIL import Image
import requests
import asyncio
import aiohttp
from aiohttp import ClientSession, ClientConnectorError

plt.ion()   # turn on matplotlib's interactive mode for the rest of the notebook

In [19]:
def isBase64Image(text):
    """Return True when ``text`` is a string that begins with the JPEG base64 data-URI prefix.

    Any non-string value (including None) yields False.
    """
    if not isinstance(text, str):
        return False
    return text.startswith('data:image/jpeg;base64')

def isUrlImage(text):
    """Return True when ``text`` is a string starting with 'http'.

    Note that only the 'http' prefix is checked, so both 'http://' and
    'https://' sources qualify. Non-string values yield False.
    """
    if not isinstance(text, str):
        return False
    return text.startswith('http')

def get_base64_content(initial):
    """Return the text that follows the JPEG base64 data-URI header in ``initial``.

    The string is split on the header and the second piece is returned, so an
    IndexError is raised when the header does not occur in ``initial``.
    """
    header = 'data:image/jpeg;base64,'
    pieces = initial.split(header)
    return pieces[1]

def get_image_type(src):
    """Classify an image ``src`` string into a file extension and a format label.

    Returns a dict with the keys 'ext' and 'label':
      * ``src`` is None                  -> both values None
      * 'jpeg' occurs anywhere in src    -> ext 'jpg', label 'JPEG'
      * src starts with 'https://'       -> ext 'img', label None
      * anything else                    -> both values None
    """
    # Identity check: `== None` can be fooled by objects that override __eq__.
    if src is None:
        return {'ext': None, 'label': None}
    if 'jpeg' in src:
        return {'ext': 'jpg', 'label': 'JPEG'}
    if src.startswith('https://'):
        return {'ext': 'img', 'label': None}
    # TODO: recognise PNG sources ('png' in src -> ext 'png', label 'PNG').
    return {'ext': None, 'label': None}

def get_class_name_from_file(file_name: str):
    """Return the part of ``file_name`` that precedes '.html'.

    The name is split on '.html' and must yield exactly two pieces; otherwise
    the unpacking below raises ValueError (e.g. when '.html' is absent).
    """
    pieces = file_name.split('.html')
    class_name, _remainder = pieces
    return class_name

def get_data_dir():
    """Return the path '<current working directory>/data/vespa_mandarinia_data'."""
    working_dir = os.getcwd()
    return os.path.join(working_dir, 'data', 'vespa_mandarinia_data')

def get_sample_file_name(file_name):
    """Return the path of ``file_name`` inside the directory given by get_data_dir().

    Args:
        file_name: name (or relative path) of the file within the data directory.

    Returns:
        The joined path as a string.
    """
    # get_data_dir must be *called*; passing the function object itself to
    # os.path.join raises TypeError.
    data_dir = get_data_dir()
    return os.path.join(data_dir, file_name)
    
async def seed_images():
    """Recreate the data directory from scratch and fill it via process_html_files().

    Any existing directory at get_data_dir() is removed (with its contents)
    before a fresh, empty one is created.
    """
    print('Seeding...')
    target_dir = get_data_dir()
    # Wipe the output of any previous run so the directory starts empty.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.mkdir(target_dir)
    await process_html_files()
    print('Seeding complete')

    
async def stream_to_data(stream, chunk_size: int = 64 * 1024):
    """Read ``stream`` to exhaustion and return everything as one bytes object.

    Args:
        stream: object exposing ``async read(n)`` that returns an empty bytes
            value once no more data is available.
        chunk_size: maximum number of bytes requested per ``read`` call.

    Returns:
        The concatenation of all chunks, in the order they were read.
    """
    # Collect the chunks and join once at the end: repeated `bytes +=` copies
    # the whole buffer on every iteration, and 8-byte reads made that very slow.
    chunks = []
    while True:
        chunk = await stream.read(chunk_size)
        if not chunk:
            break
        chunks.append(chunk)
    return b''.join(chunks)
    
async def make_file_from_base64(src_string: str, file_name: str):
    """Decode a base64 data-URI image and save it as '<file_name>.<ext>'.

    The extension and the format label passed to ``image.save`` both come from
    get_image_type(src_string); the payload comes from get_base64_content().
    """
    image_type = get_image_type(src_string)
    raw_bytes = base64.b64decode(get_base64_content(src_string))
    image = Image.open(BytesIO(raw_bytes))
    image.save(f"{file_name}.{image_type.get('ext')}", image_type['label'])
    
async def make_file_from_url(url: str, file_name: str, session: ClientSession):
    """Download ``url`` and, if it decodes as a JPEG, save it as '<file_name>.jpg'.

    This is a best-effort helper: a failed request or an undecodable payload is
    reported with a printed message and the function returns without raising,
    so one bad URL does not abort the other downloads gathered alongside it.

    Args:
        url: address of the image to fetch.
        file_name: destination path without extension.
        session: aiohttp client session used to issue the GET request.
    """
    try:
        # `async with` releases the response/connection even when an error occurs.
        async with session.get(url) as resp:
            # Treat HTTP error statuses as failures instead of trying to
            # decode an error page as an image.
            resp.raise_for_status()
            image_contents = await resp.read()
    except (aiohttp.ClientError, asyncio.TimeoutError):
        print('Error getting image: ', url)
        return

    try:
        image = Image.open(BytesIO(image_contents))
    except OSError:
        # Pillow signals an unreadable/unrecognised payload with an OSError
        # (sub)class — NOTE(review): confirm for the installed Pillow version.
        print('Error decoding image: ', url)
        return

    # TODO: handle other image types
    if image.format == 'JPEG':
        image.save(file_name + '.jpg', image.format)
        
async def make_requests(items, **kwargs) -> None:
    """Create image files for every item, dispatching on item['request_type'].

    'base64' items are decoded with make_file_from_base64 and awaited first;
    'url' items are then downloaded with make_file_from_url through a single
    shared ClientSession. Items with any other request type are ignored.
    The gathered results of the URL downloads are returned.
    """
    inline_jobs = []
    remote_items = []

    for entry in items:
        kind = entry['request_type']
        if kind == 'base64':
            inline_jobs.append(
                make_file_from_base64(entry['src'], entry['file_name'])
            )
        elif kind == 'url':
            remote_items.append(entry)

    # Embedded images need no network access, so finish them up front.
    await asyncio.gather(*inline_jobs)

    async with ClientSession() as session:
        downloads = [
            make_file_from_url(entry['src'], entry['file_name'], session=session)
            for entry in remote_items
        ]
        return await asyncio.gather(*downloads)

async def process_html_files():
    """Collect image items from the saved pages of each class and create the files.

    For every class name the file 'data/html/<class>.txt' is parsed with
    process_html_file; each resulting item gets a 'file_name' of the form
    '<data dir>/<class_name>-<running index>' and the batch is handed to
    make_requests.
    """
    species = ['vespa_mandarinia', 'sphex_ichneumoneus', 'sphecius_speciosus']
    # The class name doubles as the stem of the source file name.
    all_items = [
        item
        for name in species
        for item in process_html_file(os.path.join('data', 'html', f'{name}.txt'), name)
    ]
    print('Attempting to create {} images'.format(len(all_items)))

    # The index runs across all classes, so every target file name is unique.
    next_items = [
        {**item, 'file_name': os.path.join(get_data_dir(), f"{item['class_name']}-{index}")}
        for index, item in enumerate(all_items)
    ]

    await make_requests(next_items)
                        
    

def process_html_file(file_name: str, class_name: str):
    """Extract downloadable image descriptors from one saved HTML page.

    Only <img> tags inside the element with id "islrg" are considered, and only
    those whose ``src`` is recognised by isBase64Image or isUrlImage are kept.

    Args:
        file_name: path of the HTML document to parse.
        class_name: label stored on every returned item.

    Returns:
        A list of dicts with the keys 'src', 'alt', 'class_name' and
        'request_type' ('base64' or 'url'). The list is empty when the page has
        no "islrg" element.
    """
    with open(file_name, 'r') as file:
        soup = BeautifulSoup(file.read(), 'html.parser')

    container = soup.find(id="islrg")
    if container is None:
        # find() yields None when nothing matches; report it and carry on so
        # one unexpected page does not break the whole run with AttributeError.
        print('No element with id "islrg" found in: ', file_name)
        return []

    items = []
    for image in container.find_all('img'):
        src = image.get('src')
        if isBase64Image(src):
            request_type = 'base64'
        elif isUrlImage(src):
            request_type = 'url'
        else:
            # Neither an embedded JPEG nor an http(s) URL: nothing to fetch.
            continue
        items.append({
            "src": src,
            "alt": image.get('alt'),
            'class_name': class_name,
            "request_type": request_type
        })
    return items


In [20]:
# Seed data
# NOTE: seed_images() removes the existing data directory before re-creating it,
# so re-running this cell discards any previously saved images.
# Top-level `await` needs an environment that supports it (e.g. an IPython/Jupyter kernel).
await seed_images()

Seeding...
Attempting to create 1384 images
Seeding complete


In [220]:
# Load pretrained model
# Builds a ResNet-18 through torchvision.models, requesting pretrained weights.
# NOTE(review): newer torchvision releases reportedly prefer `weights=` over
# `pretrained=` — confirm against the installed version.
resnet18 = models.resnet18(pretrained=True)

In [None]:
# Preprocessing pipelines, one per split.
# 'train' uses random augmentation (random resized crop + horizontal flip);
# 'val' uses a deterministic resize + centre crop.
# Both then convert to a tensor and apply the same per-channel normalization.
data_transforms = {
    split: transforms.Compose([
        *spatial_steps,
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    for split, spatial_steps in (
        ('train', (transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip())),
        ('val', (transforms.Resize(256), transforms.CenterCrop(224))),
    )
}

# NOTE(review): get_data_dir() (used by the seeding step) points at
# 'data/vespa_mandarinia_data', whereas this cell reads 'data/vespa_mandarinia'
# with 'train' / 'val' sub-folders — confirm how the seeded images get there.
data_dir = 'data/vespa_mandarinia'

# Everything below is keyed by split name ('train' / 'val'). These must be
# dicts: lists have no `.push` and cannot be indexed by a string.
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']
}
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4)
    for x in ['train', 'val']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Class labels as reported by the training dataset.
class_names = image_datasets['train'].classes

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")