In [1]:
import csv
from collections import defaultdict
import itertools
import time

class OptimizedApriori:
    """Apriori miner for frequent itemsets and association rules.

    Itemsets are tuples of item IDs in sorted order.  The support of an
    itemset is the fraction of loaded transactions that contain every item
    of the set; an item repeated inside one transaction counts once.
    """

    def __init__(self, min_sup=0.01, min_conf=0.5):
        """Create a miner.

        Args:
            min_sup: Minimum support (fraction of transactions) for an
                itemset to be considered frequent.
            min_conf: Minimum confidence for a rule to be reported.
        """
        self.min_sup = min_sup
        self.min_conf = min_conf
        # One list of item IDs per transaction, filled by load_data().
        self.transactions = []
        # Item ID -> display name, filled by load_data().
        self.goods_map = {}
        # Itemset tuple -> number of transactions containing it.  Filled by
        # run() so that rule generation does not rescan the transactions.
        self._support_counts = {}

    def load_data(self, goods_path, trans_path):
        """Load the goods catalogue and the transactions from CSV files.

        Args:
            goods_path: CSV file with a header row containing at least the
                columns ``ID`` and ``Name``.
            trans_path: CSV file with one header row; in every following
                row the first column is skipped and the remaining non-empty
                cells are taken as item IDs.
        """
        # newline='' lets the csv module do its own line-ending handling,
        # as the csv documentation recommends.
        with open(goods_path, 'r', encoding='utf-8', newline='') as f:
            for row in csv.DictReader(f):
                self.goods_map[row['ID']] = row['Name']

        # NOTE: this file is still read with the platform default encoding,
        # exactly as before.
        with open(trans_path, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # Skip header; tolerate a completely empty file
            for row in reader:
                if not row:
                    continue
                items = [item.strip() for item in row[1:] if item.strip()]
                self.transactions.append(items)

        print(f"已加载商品数: {len(self.goods_map)}, 交易数: {len(self.transactions)}")

    def run(self):
        """Mine frequent itemsets level by level, then derive the rules.

        Returns:
            A pair ``(freq_items, rules)``.  ``freq_items`` maps the itemset
            size k to the list of frequent k-itemsets (tuples); ``rules`` is
            the list produced by ``_generate_rules``.
        """
        start_time = time.time()
        n_trans = len(self.transactions)
        self._support_counts = {}

        # Frequent 1-itemsets.  dict.fromkeys() drops repeated items inside
        # one transaction (so support counts transactions, not occurrences)
        # while keeping the first-seen order of the items.
        item_counts = defaultdict(int)
        for trans in self.transactions:
            for item in dict.fromkeys(trans):
                item_counts[item] += 1

        freq_1 = []
        for item, cnt in item_counts.items():
            if cnt / n_trans >= self.min_sup:
                freq_1.append((item,))
                self._support_counts[(item,)] = cnt
        freq_items = {1: freq_1}
        print(f"1-项集数量: {len(freq_items[1])}")

        # Transactions still worth scanning.  A transaction that matched no
        # candidate at level k cannot contain any candidate at level k+1
        # (every k-subset of such a candidate was a level-k candidate), so
        # it is dropped from later passes.  The support denominator stays
        # the full transaction count.
        active = [frozenset(trans) for trans in self.transactions]

        k = 2
        while True:
            candidates = self._generate_candidates(freq_items[k-1])
            if not candidates:
                break

            # Build each candidate's set once instead of once per transaction.
            cand_sets = [(cand, frozenset(cand)) for cand in candidates]
            candidate_counts = defaultdict(int)
            still_active = []
            for trans_set in active:
                matched = False
                for cand, cand_set in cand_sets:
                    if cand_set <= trans_set:
                        candidate_counts[cand] += 1
                        matched = True
                if matched:
                    still_active.append(trans_set)
            active = still_active

            # Keep the candidates that reach the minimum support.
            freq_k = [
                itemset for itemset, cnt in candidate_counts.items()
                if cnt / n_trans >= self.min_sup
            ]

            if not freq_k:
                break

            for itemset in freq_k:
                self._support_counts[itemset] = candidate_counts[itemset]
            freq_items[k] = freq_k
            print(f"{k}-项集数量: {len(freq_k)}")
            k += 1

        rules = self._generate_rules(freq_items)

        print(f"总耗时: {time.time()-start_time:.2f}s")
        return freq_items, rules

    def _generate_candidates(self, prev_items):
        """Build the (k+1)-item candidates from the frequent k-itemsets.

        Args:
            prev_items: Frequent k-itemsets as sorted tuples, all of the
                same length.

        Returns:
            Sorted list of candidate tuples whose every k-item subset is in
            ``prev_items``.
        """
        sorted_items = sorted(prev_items)
        known = set(sorted_items)  # O(1) membership for the prune step

        # Join step: combine two itemsets that share their first k-1 items.
        candidates = set()
        for i, first in enumerate(sorted_items):
            prefix = first[:-1]
            for second in sorted_items[i+1:]:
                if second[:-1] != prefix:
                    # Equal-length tuples with the same prefix are adjacent
                    # after sorting, so no later itemset can match either.
                    break
                candidates.add(first + (second[-1],))

        # Prune step: every k-item subset must itself be frequent.  Sorting
        # makes the result independent of set iteration order.
        return [
            cand for cand in sorted(candidates)
            if all(subset in known
                   for subset in itertools.combinations(cand, len(cand) - 1))
        ]

    def _support_of(self, itemset):
        """Return how many transactions contain every item of ``itemset``.

        Counts recorded by ``run`` are reused; any other itemset is counted
        by scanning the transactions once and the result is remembered.
        """
        count = self._support_counts.get(itemset)
        if count is None:
            needed = set(itemset)
            count = sum(1 for t in self.transactions if needed.issubset(t))
            self._support_counts[itemset] = count
        return count

    def _generate_rules(self, freq_items):
        """Derive association rules from the frequent itemsets.

        Args:
            freq_items: Mapping of itemset size to list of frequent itemsets,
                as built by ``run``.

        Returns:
            List of ``(antecedent, consequent, support, confidence)`` tuples
            for every split of a frequent itemset (size >= 2) whose
            confidence is at least ``min_conf``.
        """
        rules = []
        n_trans = len(self.transactions)
        for k in sorted(size for size in freq_items if size >= 2):
            for itemset in freq_items[k]:
                total_count = self._support_of(itemset)
                all_items = set(itemset)

                for size in range(1, k):
                    for antecedent in itertools.combinations(itemset, size):
                        antecedent = tuple(sorted(antecedent))
                        consequent = tuple(sorted(all_items - set(antecedent)))

                        ante_count = self._support_of(antecedent)
                        if ante_count == 0:
                            continue

                        conf = total_count / ante_count
                        if conf >= self.min_conf:
                            rules.append((
                                antecedent,
                                consequent,
                                total_count / n_trans,
                                conf
                            ))
        return rules

    def _label(self, item_id):
        """Format an item as ``name(id)``, using the ID when no name is known."""
        return f"{self.goods_map.get(item_id, item_id)}({item_id})"

    def print_results(self, freq_items, rules, top_n=10):
        """Print the first ``top_n`` itemsets of every size and ``top_n`` rules.

        Args:
            freq_items: Mapping of itemset size to list of itemsets.
            rules: List of ``(antecedent, consequent, support, confidence)``.
            top_n: Maximum number of entries printed per section.
        """
        print("\n=== 高频项集 ===")
        for k in sorted(freq_items.keys()):
            print(f"{k}-项集 (前{top_n}个):")
            for itemset in freq_items[k][:top_n]:
                names = [self._label(item_id) for item_id in itemset]
                print("  ", ", ".join(names))

        print("\n=== 强关联规则 ===")
        for i, (antec, conseq, sup, conf) in enumerate(rules[:top_n], 1):
            ante_names = [self._label(item_id) for item_id in antec]
            conseq_names = [self._label(item_id) for item_id in conseq]
            print(f"规则#{i}: {', '.join(ante_names)} => {', '.join(conseq_names)}")
            print(f"  支持度: {sup:.3f}, 置信度: {conf:.3f}\n")

# ================== 使用示例 ==================
if __name__ == "__main__":
    # Thresholds for the demo run; for about 1000 transactions a minimum
    # support between 2% and 5% is the suggested range.
    apriori = OptimizedApriori(min_sup=0.03, min_conf=0.7)

    # Input files (replace with the real paths).
    apriori.load_data(goods_path="goods.csv", trans_path="out1.csv")

    # Mine the itemsets and rules, then show the first six of each.
    freq_items, rules = apriori.run()
    apriori.print_results(freq_items, rules, top_n=6)

已加载商品数: 50, 交易数: 1000
1-项集数量: 48
2-项集数量: 6
3-项集数量: 1
总耗时: 0.55s

=== 高频项集 ===
1-项集 (前6个):
   'Casino'(2)
   'Truffle'(5)
   'Marzipan'(27)
   'Walnut'(29)
   'Apricot'(32)
   'Cherry'(48)
2-项集 (前6个):
   'Blackberry'(15), 'Apple'(36)
   'Gongolais'(22), 'Napoleon'(9)
   'Apple'(12), 'Blueberry'(16)
   'Apple'(12), 'Berry'(14)
   'Berry'(14), 'Blueberry'(16)
   'Lemon'(1), 'Single'(49)
3-项集 (前6个):
   'Apple'(12), 'Berry'(14), 'Blueberry'(16)

=== 强关联规则 ===
规则#1: 'Blackberry'(15) => 'Apple'(36)
  支持度: 0.139, 置信度: 0.751

规则#2: 'Apple'(36) => 'Blackberry'(15)
  支持度: 0.139, 置信度: 0.799

规则#3: 'Gongolais'(22) => 'Napoleon'(9)
  支持度: 0.181, 置信度: 0.842

规则#4: 'Napoleon'(9) => 'Gongolais'(22)
  支持度: 0.181, 置信度: 0.804

规则#5: 'Apple'(12) => 'Blueberry'(16)
  支持度: 0.259, 置信度: 0.863

规则#6: 'Blueberry'(16) => 'Apple'(12)
  支持度: 0.259, 置信度: 0.915

