In [None]:
import sketch
import pandas as pd
import sqlite3
import datasketches
import datasketch
import base64
import uuid
import datetime
import logging
import heapq

In [None]:
class SketchBase:
    """Base class for a named summary ("sketch") wrapping a ``data`` payload.

    Subclasses provide ``from_series`` to compute the payload, and may
    override ``pack``/``unpack`` when the payload needs converting before it
    is placed in (or read back from) the dictionary form.
    """

    def __init__(self, data):
        # The sketch name is always the concrete class name; from_dict relies
        # on this to pick the right subclass again.
        self.name = self.__class__.__name__
        self.data = data

    @classmethod
    def from_series(cls, series):
        """Build a sketch from ``series``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, on the base implementation.
        """
        # This is a classmethod, so there is no `self` in scope; the message
        # has to be built from `cls` (using `self` raised NameError instead).
        raise NotImplementedError(f"Need from_series method for {cls}")

    def pack(self):
        """Return the payload in the form stored by ``to_dict`` (identity here)."""
        return self.data

    @staticmethod
    def unpack(data):
        """Inverse of ``pack``: turn a stored payload back into ``data`` (identity here)."""
        return data

    def to_dict(self):
        """Return ``{'name': <class name>, 'data': <packed payload>}``."""
        return {'name': self.__class__.__name__, 'data': self.pack()}

    @classmethod
    def from_dict(cls, data):
        """Rebuild a sketch from the dictionary produced by ``to_dict``.

        When ``data['name']`` differs from ``cls.__name__``, the direct
        subclasses of ``cls`` are searched for a class with that name. If
        none matches, ``cls`` itself is used.
        """
        tcls = cls
        if data['name'] != cls.__name__:
            # Deliberately no `break`: if several subclasses share the name,
            # the last one returned by __subclasses__() is used.
            for subclass in cls.__subclasses__():
                if subclass.__name__ == data['name']:
                    tcls = subclass
        return tcls(data=tcls.unpack(data['data']))
        

class Rows(SketchBase):
    """Sketch whose payload is the ``size`` of a series, stored as an ``int``."""

    @classmethod
    def from_series(cls, series):
        """Create the sketch from ``series.size``."""
        total_entries = int(series.size)
        return cls(data=total_entries)

class Count(SketchBase):
    """Sketch whose payload is ``series.count()``, stored as an ``int``."""

    @classmethod
    def from_series(cls, series):
        """Create the sketch from ``series.count()``.

        NOTE(review): for a pandas Series this is presumably the number of
        non-null entries -- confirm against the inputs actually passed in.
        """
        counted = int(series.count())
        return cls(data=counted)
    
    
class MinHash(SketchBase):
    """Sketch whose payload is a ``datasketch.LeanMinHash`` of a series' values."""

    @classmethod
    def from_series(cls, series):
        """Hash every value of ``series`` (as UTF-8 encoded ``str(value)``)."""
        encoded_values = [str(value).encode('utf-8') for value in series]
        hasher = datasketch.MinHash()
        hasher.update_batch(encoded_values)
        # Only the lean form is kept as the payload.
        return cls(data=datasketch.LeanMinHash(hasher))

    def pack(self):
        """Serialize the LeanMinHash and return it as a base64 text string."""
        lean = self.data
        raw = bytearray(lean.bytesize())
        lean.serialize(raw)
        encoded = base64.b64encode(raw)
        return encoded.decode('utf-8')

    @staticmethod
    def unpack(data):
        """Decode the base64 string from ``pack`` back into a LeanMinHash."""
        raw = base64.b64decode(data)
        return datasketch.LeanMinHash.deserialize(raw)

In [None]:
class SketchPad:
    """A set of sketches computed for one series, plus metadata and context.

    Class attributes:
        version: format version written by ``to_dict`` and checked by
            ``from_dict``.
        sketches: the sketch classes that ``from_series`` computes.

    Instance attributes (set in ``__init__``): ``version``, ``id`` (a UUID4
    string), ``metadata``, ``context`` and ``sketches``. The instance-level
    ``sketches`` is a list of sketch *instances* and shadows the class-level
    list of sketch classes.
    """
    version = '0.0.1'
    # Original (misspelled) attribute name, kept as an alias so anything that
    # read it keeps working. from_dict needs the correctly spelled `version`.
    verison = version
    sketches = [Rows, Count, MinHash]

    def __init__(self, context=None):
        """Create an empty sketchpad.

        Args:
            context: optional mapping of descriptive key/values. It is
                copied, so later writes to ``self.context`` never leak into
                the caller's dictionary (or into other sketchpads built from
                the same dictionary).
        """
        self.version = '0.0.1'
        self.id = str(uuid.uuid4())
        self.metadata = {
            'id': self.id,
            'creation_start': self._utc_now_iso()
        }
        # Copy rather than alias: from_series stores 'column_name' in here.
        self.context = dict(context) if context else {}
        # TODO: consider alternate naming convention
        # so can do dictionary lookups
        self.sketches = []

    @staticmethod
    def _utc_now_iso():
        """Return the current UTC time as a naive ISO-8601 string."""
        # Same output format as datetime.utcnow().isoformat(), without
        # calling the deprecated utcnow().
        now = datetime.datetime.now(datetime.timezone.utc)
        return now.replace(tzinfo=None).isoformat()

    @classmethod
    def from_series(cls, series, context=None):
        """Compute every sketch in ``cls.sketches`` for ``series``.

        Records ``creation_end`` in the metadata and ``series.name`` under
        ``context['column_name']``.
        """
        sp = cls(context=context)
        for skcls in cls.sketches:
            sp.sketches.append(skcls.from_series(series))
        sp.metadata['creation_end'] = cls._utc_now_iso()
        sp.context['column_name'] = series.name
        return sp

    def get_sketch_by_name(self, name):
        """Return the sketch called ``name``, or None if zero or several match."""
        sketches = [sk for sk in self.sketches if sk.name == name]
        if len(sketches) == 1:
            return sketches[0]
        return None

    def get_sketchdata_by_name(self, name):
        """Return the ``data`` of the sketch called ``name``, or None if absent."""
        sketch = self.get_sketch_by_name(name)
        return sketch.data if sketch else None

    def minhash_jaccard(self, other):
        """Return the MinHash ``jaccard`` value between this pad and ``other``.

        Returns None when either sketchpad has no usable 'MinHash' sketch.
        """
        self_minhash = self.get_sketchdata_by_name('MinHash')
        other_minhash = other.get_sketchdata_by_name('MinHash')
        if self_minhash is None or other_minhash is None:
            return None
        return self_minhash.jaccard(other_minhash)

    def to_dict(self):
        """Return a dictionary with version, metadata, packed sketches and context."""
        return {
            'version': self.version,
            'metadata': self.metadata,
            'sketches': [s.to_dict() for s in self.sketches],
            'context': self.context
        }

    @classmethod
    def from_dict(cls, data):
        """Rebuild a sketchpad from the output of ``to_dict``.

        Raises:
            ValueError: if ``data['version']`` is not ``cls.version``.
        """
        if data['version'] != cls.version:
            raise ValueError(
                f"Unsupported SketchPad version {data['version']!r}; expected {cls.version!r}"
            )
        sp = cls()
        sp.id = data['metadata']['id']
        sp.metadata = data['metadata']
        sp.context = data['context']
        # Each entry is a {'name', 'data'} dict; SketchBase.from_dict picks
        # the matching subclass and unpacks the payload. Wrapping the raw dict
        # in SketchBase(...) would lose both the name and the payload.
        sp.sketches = [SketchBase.from_dict(s) for s in data['sketches']]
        return sp

In [None]:
class Portfolio:
    """A collection of sketchpads indexed by their ``id``."""

    def __init__(self, sketchpads=None):
        """Start from an optional iterable of sketchpads."""
        self.sketchpads = {sp.id: sp for sp in (sketchpads or [])}

    def add_dataframe(self, df):
        """Add one sketchpad per column of ``df``, using ``df.attrs`` as context."""
        for col in df.columns:
            # Pass a copy: SketchPad.from_series writes 'column_name' into the
            # context, and sharing df.attrs would leave every column's
            # sketchpad pointing at the same (last-written) column name.
            sp = SketchPad.from_series(df[col], context=dict(df.attrs))
            self.add_sketchpad(sp)

    def add_dataframes(self, dfs):
        """Call ``add_dataframe`` for every dataframe in ``dfs``."""
        for df in dfs:
            self.add_dataframe(df)

    def add_sketchpad(self, sketchpad):
        """Store ``sketchpad`` under its ``id``, replacing any previous entry."""
        self.sketchpads[sketchpad.id] = sketchpad

    def add_sqlite(self, sqlite_db_path):
        """Sketch every table of a SQLite database file.

        Each table is read in full and tagged with ``table_name`` and
        ``source`` in ``df.attrs`` before being passed to ``add_dataframe``.

        Returns:
            list: the table names found, ordered by name.
        """
        conn = sqlite3.connect(sqlite_db_path)
        try:
            # sqlite_master is the schema table name understood by every
            # SQLite version; sqlite_schema is only a newer alias for it.
            tables = pd.read_sql(
                "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;", conn
            )
            logging.info(f'Found {len(tables)} tables in file {sqlite_db_path}')
            for table in tables['name']:
                # Quote as an identifier (double quotes, embedded quotes
                # doubled) so names containing quote characters still work.
                quoted = '"' + str(table).replace('"', '""') + '"'
                df = pd.read_sql(f"SELECT * from {quoted}", conn)
                df.attrs |= {'table_name': table, 'source': sqlite_db_path}
                self.add_dataframe(df)
        finally:
            # Always release the file handle, even when a query fails.
            conn.close()
        return list(tables['name'])

    def closest_overlap(self, sketchpad, n=5):
        """Return the ``n`` stored sketchpads with the highest MinHash overlap.

        Sketchpads for which ``sketchpad.minhash_jaccard`` returns None are
        skipped, since None cannot be ordered against numeric scores.

        Returns:
            list of ``(score, sketchpad)`` tuples, highest score first.
        """
        scores = []
        for sp in self.sketchpads.values():
            score = sketchpad.minhash_jaccard(sp)
            if score is None:
                continue
            scores.append((score, sp.id))
        top_n = heapq.nlargest(n, scores, key=lambda pair: pair[0])
        return [(s, self.sketchpads[i]) for s, i in top_n]

In [None]:
# Build an empty portfolio, then load the SQLite file into it via add_sqlite.
# Assigning the returned list of table names to `_` keeps the cell from
# displaying it.
pf = Portfolio()
_ = pf.add_sqlite('datasets/fivethirtyeight.db')

In [None]:
import random
# Pick one sketchpad from the portfolio to use as the query in the cells below.
# NOTE(review): no seed is set in the visible cells, so a different sketchpad
# may be chosen on every run.
random_sketchpad = random.choice(list(pf.sketchpads.values()))

In [None]:
# Display the context dictionary stored on the chosen sketchpad.
random_sketchpad.context

In [None]:
# Query the portfolio for the sketchpads closest to the chosen one, then
# show each score next to the matching sketchpad's context.
result = pf.closest_overlap(random_sketchpad)
[(score, pad.context) for score, pad in result]

In [None]:
# Print the grouped (distinct) values of the query column, then for each
# close match print the (score, sketchpad) pair followed by that match's
# grouped values.
# NOTE(review): get_uniques is defined in a later cell; on a fresh kernel this
# cell raises NameError unless that cell has been run first.
print(get_uniques(random_sketchpad))
print('---')
for x in result:
    print(x, get_uniques(x[1]))

In [None]:
# TODO: def cardinality_spectrogram(self, ...)  -- unfinished stub, not yet implemented

In [None]:
def run_sql(sql, path='datasets/fivethirtyeight.db'):
    """Run ``sql`` against the SQLite file at ``path`` and return a DataFrame.

    Args:
        sql: the SQL text, passed to ``pd.read_sql`` unchanged.
        path: SQLite database file to open.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_sql``.
    """
    conn = sqlite3.connect(path)
    try:
        return pd.read_sql(sql, conn)
    finally:
        # Close explicitly so repeated calls do not accumulate open handles.
        conn.close()

def get_uniques(sketchpad):
    """Return the grouped (distinct) values of the column a sketchpad describes.

    Reads ``source`` (SQLite file path), ``table_name`` and ``column_name``
    from ``sketchpad.context`` and selects that column grouped by itself.

    Returns:
        A single-column ``pandas.DataFrame`` with one row per distinct value.

    Raises:
        KeyError: if any of the three context keys is missing.
    """
    def quote_identifier(name):
        # SQL identifier quoting: wrap in double quotes and double any
        # embedded double quote so unusual names cannot break the statement.
        return '"' + str(name).replace('"', '""') + '"'

    column = quote_identifier(sketchpad.context['column_name'])
    table_name = quote_identifier(sketchpad.context['table_name'])
    conn = sqlite3.connect(sketchpad.context['source'])
    try:
        table = pd.read_sql(f"""
            select
                {column}
            from
                {table_name}
            group by {column}
        """, conn)
    finally:
        # Release the connection even when the query fails.
        conn.close()
    return table

In [None]:
# The original call was left unterminated (a SyntaxError); completed here with
# the sketchpad the earlier cells inspect.
get_uniques(random_sketchpad)

In [None]:
# Identifiers go in double quotes: in SQL a single-quoted 'F*G' is a string
# literal, so the query would return the constant text instead of the column.
run_sql('SELECT "F*G" FROM "classic-rock/classic-rock-song-list" LIMIT 10')

In [None]:
# Identifiers go in double quotes: in SQL a single-quoted 'rees-davies' is a
# string literal, so the query would return the constant text instead of the
# column.
run_sql('SELECT "rees-davies" FROM "next-bechdel/nextBechdel_allTests" LIMIT 10')