# Item 39 Use @classmethod Polymorphism to Construct Objects

- 在有些场合中, 可以用 `@classmethod` 在类的内部实现通用的对象生成(构造)方法, 从而以多态的方式创建对象

In [1]:
import os

## 首先定义输入数据类, 其 `read` 方法负责读取并返回文件中的数据

In [2]:
class InputData:
    """Abstract source of input data.

    Subclasses override :meth:`read` to supply the actual data.
    """

    def read(self):
        """Return the input data; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

class PathInputData(InputData):
    """Input data backed by a file on disk."""

    def __init__(self, path):
        """Remember which file to read.

        Args:
            path: Location handed to ``open`` when :meth:`read` is called.
        """
        super().__init__()
        self.path = path

    def read(self):
        """Open the file in text mode and return its whole contents."""
        with open(self.path) as handle:
            contents = handle.read()
        return contents

## Worker的作用是对输入的文件进行处理; 此处的技巧是先在基类中定义接口(结构), 再通过继承实现具体逻辑

- 补充: reduce实际上起到一个合并worker计数的作用

In [3]:
class Worker:
    """Base class for one unit of map/reduce work over a single input."""

    def __init__(self, input_data):
        """Store the input and start with no result.

        Args:
            input_data: The object this worker will process.
        """
        self.input_data = input_data
        self.result = None

    def map(self):
        """Process the input; must be implemented by subclasses."""
        raise NotImplementedError()

    def reduce(self, other):
        """Combine with ``other``; must be implemented by subclasses."""
        raise NotImplementedError()

class LineCountWorker(Worker):
    """Worker that counts the newline characters in its input."""

    def map(self):
        """Read the input and store its newline count in ``self.result``."""
        text = self.input_data.read()
        self.result = text.count('\n')

    def reduce(self, other):
        """Add ``other.result`` onto this worker's ``result`` in place."""
        self.result += other.result

## 可以将其看成一个pipeline

- 从以下`generate_inputs`与`create_workers`两个函数可以看出, 流程是: 先为目录下的每个文件创建一个输入对象, 再为每个输入对象创建一个worker; 之后各worker在`map`阶段读取并处理文件内容, 最后通过`reduce`将结果整合起来.

In [4]:
def generate_inputs(data_dir):
    """Lazily yield one ``PathInputData`` per entry found in ``data_dir``.

    The directory is listed only when iteration starts, so a missing
    directory raises on the first ``next()`` rather than at call time.

    Args:
        data_dir: Directory whose entries become inputs.
    """
    entries = os.listdir(data_dir)
    for entry in entries:
        full_path = os.path.join(data_dir, entry)
        yield PathInputData(full_path)

def create_workers(input_list):
    """Wrap every input in a ``LineCountWorker``.

    Args:
        input_list: Iterable of input-data objects.

    Returns:
        A list with one worker per input, in iteration order.
    """
    return [LineCountWorker(item) for item in input_list]

In [5]:
from threading import Thread

def execute(workers):
    """Run each worker's ``map`` in its own thread, then fold the results.

    Args:
        workers: Iterable of objects exposing ``map()``, ``reduce(other)``
            and a ``result`` attribute. Any iterable is accepted,
            including generators.

    Returns:
        The ``result`` of the first worker after every other worker has
        been reduced into it.

    Raises:
        ValueError: If ``workers`` is empty.
    """
    # Materialize once: the workers are walked twice below (to build the
    # threads and to reduce), which would exhaust a one-shot iterable.
    workers = list(workers)
    if not workers:
        raise ValueError('execute() requires at least one worker')

    threads = [Thread(target=worker.map) for worker in workers]
    for thread in threads:
        thread.start()
    # Wait for every map step before reading any result.
    for thread in threads:
        thread.join()

    first, *rest = workers
    for worker in rest:
        first.reduce(worker)
    return first.result

In [6]:
def mapreduce(data_dir):
    """Run the whole pipeline over the files in ``data_dir``.

    Builds the inputs, wraps them in workers and returns whatever
    ``execute`` produces for those workers.

    Args:
        data_dir: Directory containing the input files.
    """
    return execute(create_workers(generate_inputs(data_dir)))

In [7]:
import os
import random

# @helper function
def write_test_files(tmpdir, *, num_files=100, max_lines=100):
    """Create ``tmpdir`` and fill it with files of random line counts.

    Files are named ``0`` through ``num_files - 1``. Each one contains
    only newline characters, between 0 and ``max_lines`` of them
    (inclusive), chosen with ``random.randint``.

    Args:
        tmpdir: Directory to create; it must not exist yet.
        num_files: How many files to write (default 100).
        max_lines: Upper bound on newlines per file (default 100).

    Raises:
        FileExistsError: If ``tmpdir`` already exists.
    """
    os.makedirs(tmpdir)
    for index in range(num_files):
        file_path = os.path.join(tmpdir, str(index))
        with open(file_path, 'w') as f:
            f.write('\n' * random.randint(0, max_lines))
    
# Directory holding the generated sample files; reused by a later cell.
tmpdir = 'test_inputs'

# Generate the sample files only when the directory does not exist yet.
if not os.path.exists(tmpdir):
    write_test_files(tmpdir)

# Run the pipeline over the directory and report the combined line count.
result = mapreduce(tmpdir)
print(f'There are {result} lines')

There are 5227 lines


## 改进方法

若想减少样板代码, 可以用 `@classmethod` 直接在类的内部生成 inputs 和 workers, 这样 `mapreduce` 只需接收类本身即可, 对任意子类都通用.

In [8]:
class GenericInputData:
    """Abstract input source that also knows how to build its instances."""

    def read(self):
        """Return the input data; must be implemented by subclasses."""
        raise NotImplementedError()

    @classmethod
    def generate_inputs(cls, config):
        """Produce input objects from ``config``; subclasses must implement.

        Args:
            config: Settings describing where the inputs come from.
        """
        raise NotImplementedError()


class PathInputData(GenericInputData):
    """Input data backed by a file on disk."""

    def __init__(self, path):
        """Remember which file to read.

        Args:
            path: Location handed to ``open`` when :meth:`read` is called.
        """
        super().__init__()
        self.path = path

    def read(self):
        """Open the file in text mode and return its whole contents."""
        with open(self.path) as handle:
            contents = handle.read()
        return contents

    @classmethod
    def generate_inputs(cls, config):
        """Lazily yield one instance per entry in ``config['data_dir']``.

        Instances are built with ``cls``, so a subclass yields objects of
        its own type.

        Args:
            config: Mapping with a ``'data_dir'`` key naming the directory.
        """
        data_dir = config['data_dir']
        entries = os.listdir(data_dir)
        for entry in entries:
            full_path = os.path.join(data_dir, entry)
            yield cls(full_path)

In [9]:
class GenericWorker:
    """Base map/reduce worker that can also build workers of its own type."""

    def __init__(self, input_data):
        """Store the input and start with no result.

        Args:
            input_data: The object this worker will process.
        """
        self.input_data = input_data
        self.result = None

    def map(self):
        """Process the input; must be implemented by subclasses."""
        raise NotImplementedError()

    def reduce(self, other):
        """Combine with ``other``; must be implemented by subclasses."""
        raise NotImplementedError()

    @classmethod
    def create_workers(cls, input_class, config):
        """Build one worker of type ``cls`` per generated input.

        Args:
            input_class: Class providing ``generate_inputs(config)``.
            config: Passed through unchanged to ``generate_inputs``.

        Returns:
            A list of ``cls`` instances, one per input, in iteration order.
        """
        return [cls(item) for item in input_class.generate_inputs(config)]


class LineCountWorker(GenericWorker):
    """Worker that counts the newline characters in its input."""

    def map(self):
        """Read the input and store its newline count in ``self.result``."""
        text = self.input_data.read()
        self.result = text.count('\n')

    def reduce(self, other):
        """Add ``other.result`` onto this worker's ``result`` in place."""
        self.result += other.result

In [10]:
def mapreduce(worker_class, input_class, config):
    """Run the pipeline for any worker/input class pair.

    Args:
        worker_class: Class providing ``create_workers(input_class, config)``.
        input_class: Input class handed to ``create_workers``.
        config: Settings passed through to ``create_workers``.

    Returns:
        Whatever ``execute`` returns for the created workers.
    """
    return execute(worker_class.create_workers(input_class, config))

In [11]:
# PathInputData.generate_inputs reads the directory from the 'data_dir' key.
config = {'data_dir': tmpdir}
# The worker and input types are passed as classes, not as instances.
result = mapreduce(LineCountWorker, PathInputData, config)
print(f'There are {result} lines')

There are 5227 lines
