# Asyncio与Aiohttp

In [27]:
# -*- coding:utf-8 -*-
import re
import os
import time
import copy
import logging
import logging.handlers
import asyncio
import aiohttp
import cchardet
import async_timeout
from html import unescape
from fnmatch import fnmatch
from datetime import datetime
from bs4 import BeautifulSoup
from urllib.parse import urlparse, urljoin
from pprint import pprint
from pymongo import MongoClient

HEADER = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/535.11 (KHTML, like Gecko) Chrome/17.0.963.84 Safari/535.11 LBBROWSER',
    'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
    'Accept-Language': 'en-US,en;q=0.5',
    'Connection': 'keep-alive',
    'Accept-Encoding': 'gzip, deflate'
}

# MONGODB参数
MONGODB = {
    "user": "xxx",
    "passwd": "xxx",
    "host": "127.0.0.1:27017",
    "dbname": "xxx"
}

def init_root_logger_settings(log_name='spiders', logConsole=True):
    LOG_FORMAT = "%(asctime)s [%(levelname)s] [%(filename)s] [%(lineno)d]: %(message)s"
    log_dir = os.path.join(os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))), "logs")
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt="%m/%d/%Y %H:%M:%S")

    fh = logging.handlers.TimedRotatingFileHandler(filename=os.path.join(log_dir, log_name),
                                                   when='midnight', interval=1, encoding='utf-8')
    fh.setLevel(logging.INFO)
    fh.suffix = "%Y-%m-%d.log"
    fh.setFormatter(formatter)
    root_logger.addHandler(fh)

    if logConsole:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        root_logger.addHandler(ch)

def connect_mongo(MONGODB):
    client = MongoClient(
        'mongodb://{}:{}@{}/{}'.format(MONGODB['user'],
                                       MONGODB['passwd'],
                                       MONGODB['host'],
                                       MONGODB['dbname']))
    return client[MONGODB['dbname']]

class Spider(object):
    def __init__(self, params=None):
        if params is None:
            self.params = {}
        else:
            self.params = params
        self.parser = Parser(params=params)

    def run(self):
        logging.info('Spider started!')
        start_time = datetime.now()
        loop = asyncio.get_event_loop()
        try:
            semaphore = asyncio.Semaphore(self.params.get('concurrency', 5))
            # tasks = asyncio.wait(self.parser.task(semaphore))
            loop.run_until_complete(self.parser.init_parse(semaphore))
            loop.run_until_complete(self.parser.task(loop, semaphore))
        except KeyboardInterrupt:
            for task in asyncio.Task.all_tasks():
                task.cancel()
            loop.run_forever()
        finally:
            end_time = datetime.now()
            # logging.info('Requests count: {}'.format(self.urls_count))
            # logging.info('Error count: {}'.format(len(self.error_urls)))
            logging.info('Time usage: {}'.format(end_time - start_time))
            logging.info('Spider finished!')
            loop.close()


class Parser(object):
    """解析url"""

    def __init__(self, params=None):
        if params is None:
            self.params = {}
        else:
            self.params = params
        self.item = Item(params=params)
        # 存放url的队列
        self.urlsQ = asyncio.Queue(maxsize=self.params.get('queue_size', 500))
        # 存放response的队列
        self.respQ = asyncio.Queue(maxsize=self.params.get('queue_size', 500))
        self.filter_urls = set()
        self.error_urls = set()
        self.urls_count = 0

    def is_running(self):
        is_running = False
        if not self.urlsQ.empty() or not self.respQ.empty():
            is_running = True
        return is_running

    async def init_parse(self, semaphore):
        async with aiohttp.ClientSession(cookies=self.params.get('cookies')) as session:
            logging.info('init first parse...')
            resp = await self.fetch(self.URL(self.params.get('start_url')), session, semaphore)
            await self.parse_urls(resp)

    def URL(self, url, params=None):
        if params is None:
            params = {}
        return {'url': url,
                'base_url': params.get('url', ''),
                'data': params.get('data'),
                'method': params.get('method', 'GET'),
                'depth': params.get('depth', 0) + 1,
                'end_type': self.params.get('end_type')}

    def normal_url(self, url, base_url):
        new_url = unescape(url.strip())
        if not re.match('(http|https)://', new_url):
            new_url = urljoin(base_url, new_url)
        return new_url[:-1] if new_url.endswith('/') else new_url

    async def do_with_302(self, resp):
        # 处理302重定向漏掉的页面
        for response in resp.history:
            await self.respQ.put_nowait(response)
            await self.item.save(response)

    async def parse_urls(self, resp):
        """提取urls"""
        logging.info('start parse urls...')
        if isinstance(resp, dict):
            base_url = resp['url']
            soup = BeautifulSoup(resp['text'], 'lxml')
            tags = soup.find_all(True)
            for tag in tags:
                params = {'method': 'GET', 'data': None, 'base_url': base_url}
                if tag.name == 'form':
                    url = tag.get('action', '')
                    params['method'] = tag.get('method')
                    params['data'] = self.parse_form_data(tag)
                elif tag.name == 'script':
                    url = tag.get('src', '')
                else:
                    url = tag.get('href', '')
                if url and url != base_url:
                    self.add(url, params)
        else:
            logging.warning('response is none.')
            return await resp

    def parse_form_data(self, tag):
        data = {}
        for input in tag.find_all('input'):
            name = input.get('name')
            if name and input.get('type') in ['text', 'password']:
                data[name] = input.get('value', '')
            elif input.get('type') == 'submit':
                name = 'submit'
                data[name] = input.get('value', '')
            else:
                if data.get(name) is None:
                    data[name] = list(input.get('value', ''))
                else:
                    data[name].append(input.get('value', ''))
        return data

    def add(self, url, params=None):
        # 格式化url
        new_url = self.normal_url(url, params.get('base_url'))
        # 根据域名列表过滤url和过滤重复url
        if not self.url_need_filter(new_url):
            url_obj = self.URL(new_url, params)
            self.urlsQ.put_nowait(url_obj)

    def url_need_filter(self, url):
        # 根据域名列表过滤url,如果域名列表为空则不过滤,否则根据列表来过滤
        if len(self.params.get('allowed_domains', [])) == 0:
            logging.info('no need to filter')
        else:
            domain = urlparse(url).netloc
            if domain not in self.params['allowed_domains']:
                return True
        # 过滤重复url
        if url not in self.filter_urls:
            self.filter_urls.add(url)
            if len(self.filter_urls) > self.params.get('amount', 500):
                return True
            # logging.info('filter urls count: {}'.format(len(self.filter_urls)))
            return False
        else:
            return True

    async def get_resp(self, resp):
        response = {}
        try:
            response['request'] = str(resp.request_info)
            # 统一url格式
            response['url'] = str(resp.url.with_port(resp.url.port))
            response['status'] = resp.status
            response['content'] = await resp.read()
            response['text'] = await resp.text()
            if len(resp.history) > 0:
                for history_item in  resp.history:
                    await self.get_resp(history_item)
        except Exception as e:
            logging.error('get response error: {}'.format(e))
        return response

    async def fetch(self, url_obj, session, semaphore, retry=1):
        # logging.info('start to fetch url_obj = {}'.format(url_obj))
        with (await semaphore):
            try:
                headers = HEADER
                proxy = self.params.get('proxy')
                if url_obj['method'] == 'GET':
                    async with session.get(url_obj['url'], headers=headers, proxy=proxy,
                                           timeout=self.params.get('timeout', 10)) as resp:
                        return await self.get_resp(resp)
                else:
                    async with session.post(url_obj['url'], headers=headers, proxy=proxy,
                                            data=url_obj['data'], timeout=self.params.get('timeout', 10)) as resp:
                        return await self.get_resp(resp)
            except Exception as e:
                logging.error('fetch error: {}'.format(e))
                if retry > 0:
                    logging.warning('fetch url failed, try twice.')
                    await self.fetch(url_obj, session, semaphore, retry - 1)

    async def execute_url(self, url_obj, session, semaphore):
        logging.info('start execute url...')
        resp = await self.fetch(url_obj, session, semaphore)
        try:
            if isinstance(resp, dict):
                # logging.info('respQ put in a resp')
                self.respQ.put_nowait(resp)
                await self.item.save(resp)
            else:
                self.error_urls.add(url_obj['url'])
                logging.info('Error url count: {}'.format(len(self.error_urls)))
            # logging.info('respQ has {} resp'.format(self.respQ.qsize()))
        except Exception as e:
            logging.error('execute url error: {}'.format(e))

        # logging.info('Parsed({}/{}): {}'.format(len(self.done_urls), len(self.filter_urls), url))
        # else:
        #     spider.parse(html)
        #     logging.info('Followed({}/{}): {}'.format(len(self.done_urls), len(self.filter_urls), url))

    async def produce_task(self, index, semaphore):
        # 拿取resp,并解析提取url,并保存数据
        # while self.is_running():
        logging.info('produce task {} start...'.format(index))
        while True:
        # while self.is_running():
            try:
                resp = await asyncio.wait_for(self.respQ.get(), 5)
                if resp is not None:
                    # logging.info(resp.get('url'))
                    asyncio.ensure_future(self.parse_urls(resp))
            except asyncio.TimeoutError:
                logging.error('produce task time out...')
                pass
                if not self.is_running():
                    break
                # await self.consume_task(semaphore)

    async def consume_task(self, index, semaphore):
        # 消费url并将response放到respQ队列中去
        logging.info('consume task {} start...'.format(index))
        async with aiohttp.ClientSession(cookies=self.params.get('cookies')) as session:
            # while self.is_running():
            while True:
                try:
                    url_obj = await asyncio.wait_for(self.urlsQ.get(), 5)
                    if url_obj is not None:
                        asyncio.ensure_future(self.execute_url(url_obj, session, semaphore))
                except asyncio.TimeoutError:
                    logging.error('consume task time out...')
                    if not self.is_running():
                        break
                    # await self.produce_task(semaphore)

    async def task(self, loop, semaphore):
        logging.info('start to create consume and produce tasks...')
        # while self.is_running():
        logging.info("=======================")
        consumers = [loop.create_task(self.consume_task(i, semaphore)) for i in range(5)]
        producers = [loop.create_task(self.produce_task(i, semaphore)) for i in range(2)]
        await asyncio.wait(consumers + producers)


class Item(object):
    """get data and save"""

    def __init__(self, params=None):
        if params is None:
            self.params = {}
        else:
            self.params = params
        db = connect_mongo(MONGODB)
        self.movie = db['my_crawler_urls']
        self.item_count = 0

    def allow_save(self, domain):
        # 允许域名列表为空,则全保存,不为空则过滤保存(不应在此处过滤)
        if not self.params.get('allowed_domains'):
            return True
        else:
            if domain in self.params.get('allowed_domains'):
                return True
            else:
                return False

    async def save(self, resp):
        if resp:
            self.movie.insert(resp)


if __name__ == '__main__':
    init_root_logger_settings()
    # start_url = 'https://www.baidu.com'
    # start_url = 'http://www.freebuf.com'
    # start_url = 'https://passport.csdn.net/account/verify;jsessionid=38DDDAFDD2567C69A26D12D482DC090B.tomcat2'
    params = {
        'start_url': 'https://www.douban.com',
        'end_type': 'PC',
        'allowed_domains': ['www.douban.com']
    }
    allsiteSpider = Spider(params)
    allsiteSpider.run()


  codeob = compile(source, filename, symbol, self.flags, 1)


RuntimeError: Event loop is closed

In [30]:
import aiohttp
import asyncio

async def fetch(session, url):
    async with session.get(url) as response:
        return await response.text()

async def main():
    async with aiohttp.ClientSession() as session:
        html = await fetch(session, 'http://python.org')
        print(html)


loop = asyncio.get_event_loop()
loop.run_until_complete(main())

RuntimeError: Event loop is closed