# 以@classmethod形式的多态去通用地构建对象

**示例1：**实现一套MapReduce流程，定义公共基类来表示输入的数据。

In [1]:
class InputData:
    """Common base class representing a source of input data."""

    def read(self):
        """Return the input data; concrete subclasses must override this."""
        raise NotImplementedError

In [2]:
class PathInputData(InputData):
    """Input data read from a single file identified by its path."""

    def __init__(self, path):
        super().__init__()
        self.path = path

    def read(self):
        """Return the whole content of the file at ``self.path`` as text."""
        # A context manager closes the handle as soon as the read finishes,
        # instead of leaving it open until the file object is collected.
        with open(self.path) as f:
            return f.read()

为MapReduce工作线程定义一套类似的抽象接口，以便用标准的方式来处理输入的数据。

In [3]:
class Worker:
    """Abstract MapReduce worker bound to one piece of input data."""

    def __init__(self, input_data):
        # The input this worker consumes, and the slot where map() is
        # expected to store its outcome (None until then).
        self.input_data = input_data
        self.result = None

    def map(self):
        """Process ``self.input_data``; concrete subclasses must override."""
        raise NotImplementedError

    def reduce(self, other):
        """Combine ``other`` into this worker; subclasses must override."""
        raise NotImplementedError

定义实现类，以实现我们想要的MapReduce功能。本例所实现的功能，是一个简单的换行符计数器。

In [4]:
class LineCountWorker(Worker):
    """Worker that counts newline characters in its input data."""

    def map(self):
        """Read the input and store its number of newline characters."""
        self.result = self.input_data.read().count('\n')

    def reduce(self, other):
        """Add the count held by ``other`` onto this worker's count."""
        self.result += other.result

通过辅助函数将对象联系起来。**示例：**列出某个目录的内容，并为该目录下的每个文件创建一个PathInputData实例。

In [5]:
import os

def generate_inputs(data_dir):
    """Yield a PathInputData for every entry listed in ``data_dir``."""
    yield from (
        PathInputData(os.path.join(data_dir, entry))
        for entry in os.listdir(data_dir)
    )

In [6]:
def create_workers(input_list):
    """Return a list with one LineCountWorker per item of ``input_list``."""
    return [LineCountWorker(item) for item in input_list]

**执行流程：**将MapReduce流程中的map步骤发到多个线程之中，然后反复调用reduce方法，将map步骤的结果合并成一个最终值。

In [7]:
from threading import Thread

def execute(workers):
    """Run every worker's map step in a thread, then reduce to one result.

    One thread is created per worker with the worker's ``map`` method as
    its target. After all threads have been started and joined, every
    remaining worker is folded into the first one through ``reduce``, and
    the first worker's ``result`` is returned.
    """
    map_threads = [Thread(target=worker.map) for worker in workers]
    for t in map_threads:
        t.start()
    for t in map_threads:
        t.join()

    accumulator = workers[0]
    for other in workers[1:]:
        accumulator.reduce(other)
    return accumulator.result

In [8]:
def mapreduce(data_dir):
    """Build inputs and workers for ``data_dir`` and return execute()'s result."""
    return execute(create_workers(generate_inputs(data_dir)))

## 创建100个临时文件，每个文件写入0到100个（数量随机的）换行符，最后统计换行符的总数

In [9]:
from tempfile import TemporaryDirectory
import random

def write_test_files(tmpdir):
    """Create 100 files named '0'..'99' in ``tmpdir``.

    Each file holds a random number (0 to 100 inclusive) of newline
    characters and nothing else.
    """
    for index in range(100):
        target = os.path.join(tmpdir, str(index))
        with open(target, 'w') as handle:
            newline_count = random.randint(0, 100)
            handle.write('\n' * newline_count)

# Fill a temporary directory with test files and count their newlines
# with the directory-based mapreduce(); the directory is removed when the
# with-block exits, so the result is computed before leaving it.
with TemporaryDirectory() as tmpdir:
    write_test_files(tmpdir)
    result = mapreduce(tmpdir)

# Report the total newline count gathered across all files.
print('There are', result, 'lines')

There are 5013 lines


**问题：**MapReduce函数不够通用，如果要编写其他的InputData或者Worker子类，就需要重写generate_inputs、create_workers和mapreduce函数，以便与之匹配。

## 多态形式的改造方案

In [10]:
class GenericInputData:
    """Base class for input data that can also construct its own instances."""

    def read(self):
        """Return the input data; concrete subclasses must override this."""
        raise NotImplementedError

    @classmethod
    def generate_inputs(cls, config):
        """Create new input-data instances through a common interface.

        ``config`` is a dictionary holding configuration parameters.
        Concrete subclasses must override this method.
        """
        raise NotImplementedError

In [11]:
class PathInputData(GenericInputData):
    """Input data read from a single file identified by its path."""

    def __init__(self, path):
        super().__init__()
        self.path = path

    def read(self):
        """Return the whole content of the file at ``self.path`` as text."""
        # A context manager closes the handle as soon as the read finishes,
        # instead of leaving it open until the file object is collected.
        with open(self.path) as f:
            return f.read()

    @classmethod
    def generate_inputs(cls, config):
        """Yield one instance of ``cls`` per entry in ``config['data_dir']``."""
        # Look up the directory holding the input files in the config dict
        data_dir = config['data_dir']
        for name in os.listdir(data_dir):
            yield cls(os.path.join(data_dir, name))

In [12]:
class GenericWorker:
    """Abstract MapReduce worker that can build workers of its own class."""

    def __init__(self, input_data):
        # The input this worker consumes, and the slot where map() is
        # expected to store its outcome (None until then).
        self.input_data = input_data
        self.result = None

    def map(self):
        """Process ``self.input_data``; concrete subclasses must override."""
        raise NotImplementedError

    def reduce(self, other):
        """Combine ``other`` into this worker; subclasses must override."""
        raise NotImplementedError

    @classmethod
    def create_workers(cls, input_class, config):
        """Return a list with one ``cls`` instance per generated input.

        The inputs come from ``input_class.generate_inputs(config)``.
        """
        # Class-level polymorphism: both the input class and the worker
        # class (cls) are chosen by the caller.
        return [cls(item) for item in input_class.generate_inputs(config)]

In [13]:
class LineCountWorker(GenericWorker):
    """Worker that counts newline characters in its input data."""

    def map(self):
        """Read the input and store its number of newline characters."""
        self.result = self.input_data.read().count('\n')

    def reduce(self, other):
        """Add the count held by ``other`` onto this worker's count."""
        self.result += other.result

In [14]:
def mapreduce(worker_class, input_class, config):
    """Create workers via ``worker_class`` and return execute()'s result.

    The workers are built by ``worker_class.create_workers(input_class,
    config)``.
    """
    return execute(worker_class.create_workers(input_class, config))

In [15]:
# Fill a temporary directory with test files and run the class-based
# mapreduce() on it; the directory is removed when the with-block exits.
with TemporaryDirectory() as tmpdir:
    write_test_files(tmpdir)
    config = {'data_dir': tmpdir}
    # Pass the worker class and the input class themselves as arguments
    result = mapreduce(LineCountWorker, PathInputData, config)
# Report the total newline count gathered across all files.
print('There are', result, 'lines')

There are 5562 lines
