# `supervenn` playground notebook

Created on: Sunday June 12th, 2022  
Created by: Jacob A Rose

In [None]:
# One-time setup, left disabled: installs the supervenn package used by the plots below.
#!pip3 install supervenn

In [None]:
%load_ext autoreload
%autoreload 2


from omegaconf import DictConfig, OmegaConf
import os
from rich import print as pp

import numpy as np
from typing import *
import inspect
from tqdm.auto import tqdm
import pandas as pd
from pathlib import Path
import logging
from imutils.catalog_registry import available_datasets

In [None]:
from imutils.big.common_catalog_utils import DataETL

In [None]:
# Show the signature and docstring of the loader that the cells below call with
# `config_path=` and `dataset_name=` keyword arguments. (The unfinished call that
# used to be here was a SyntaxError and stopped "Run All" at this cell.)
help(DataETL.import_dataset_state)

In [None]:
# Root folder of the leavesdb v1.1 catalog; its directory entries are used as dataset names.
dataset_catalog_dir = "/media/data_cifs/projects/prj_fossils/users/jacob/data/leavesdb-v1_1"
dataset_names = sorted(os.listdir(dataset_catalog_dir))

# Keep the names that mention both "512" and "family" and none of the
# "_minus_", "_w_" or "original" markers.
main_datasets = [
    name
    for name in dataset_names
    if all(tag in name for tag in ("512", "family"))
    and not any(tag in name for tag in ("_minus_", "_w_", "original"))
]

In [None]:
%%time
# One keyword-argument bundle per selected dataset: the location of its
# CSVDataset config file and the name it is stored under.
data_assets = [
    dict(
        config_path=Path(dataset_catalog_dir, name, "CSVDataset-config.yaml"),
        dataset_name=name,
    )
    for name in main_datasets
]

# Import every dataset, keyed by name, echoing each name once it has loaded.
datasets = {}
for asset in tqdm(data_assets):
    loaded_name = asset["dataset_name"]
    datasets[loaded_name] = DataETL.import_dataset_state(**asset)
    pp(loaded_name)

In [None]:
# Report how many datasets were imported, then list their names.
n_loaded = len(datasets)
print(n_loaded)

pp(datasets.keys())

In [None]:
from supervenn import supervenn
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Every style sheet matplotlib knows about, with the two built-in baselines first.
style_list = ['default', 'classic'] + sorted(
    style for style in plt.style.available if style != 'classic'
)
pp(style_list)

# The seaborn style sheets were renamed with a "seaborn-v0_8-" prefix in
# matplotlib 3.6, so accept either spelling and fall back to the default style
# when neither is installed.
style_label = next(
    (
        name
        for name in ("seaborn-notebook", "seaborn-v0_8-notebook")
        if name in plt.style.available
    ),
    "default",
)

# `plt.style.context(...)` only builds a context manager; outside a `with` block
# it changes nothing. `plt.style.use` applies the style to the figures drawn below.
plt.style.use(style_label)

## Extant Leaves -- various settings gallery

In [None]:
# Gallery of supervenn layouts for the `family` labels of the Extant datasets:
# one row per chunks_ordering, one column per sets_ordering.
sets_orderings_list = ["size", "chunk count", "random", "minimize gaps", None]
# chunks_ordering has its own vocabulary in supervenn ('occurrence' instead of
# 'chunk count', and no None), so it cannot reuse the sets_ordering list.
chunks_orderings_list = ["size", "occurrence", "random", "minimize gaps"]

select_column = "family"
select_stem = "Extant"

# The compared sets are identical in every panel, so build them once.
selected_sets = {k: v for k, v in datasets.items() if select_stem in k}
sets = [set(v[0].samples_df[select_column].values) for v in selected_sets.values()]
labels = list(selected_sets)

# Grid shape follows the indexing used below: ax[row=chunks, col=sets].
n_rows, n_cols = len(chunks_orderings_list), len(sets_orderings_list)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(16 * n_cols, 8 * n_rows), dpi=150)

for j, chunks_ordering in enumerate(chunks_orderings_list):
    for i, sets_ordering in enumerate(sets_orderings_list):
        ax_title = f"Extant Leaves - range of thresholds - ({sets_ordering=}) - ({chunks_ordering=})"

        # Forward chunks_ordering too; without it every row repeats the default
        # chunk layout while its title claims a different one.
        sv = supervenn(
            sets,
            labels,
            chunks_ordering=chunks_ordering,
            sets_ordering=sets_ordering,
            ax=ax[j, i],
        )
        ax[j, i].set_title(ax_title)

# Save through the figure handle rather than pyplot's "current figure".
fig.savefig(f"Extant Leaves {select_column} categories.svg")

In [None]:
# Galleries of supervenn layouts for the Extant datasets, one figure per label
# column: rows are chunks_ordering values, columns are sets_ordering values.
sets_orderings_list = ["size", "chunk count", "random", "minimize gaps", None]
# chunks_ordering has its own vocabulary in supervenn ('occurrence' instead of
# 'chunk count', and no None), so it cannot reuse the sets_ordering list.
chunks_orderings_list = ["size", "occurrence", "random", "minimize gaps"]

select_stem = "Extant"
selected_sets = {k: v for k, v in datasets.items() if select_stem in k}
labels = list(selected_sets)

# Grid shape follows the indexing used below: ax[row=chunks, col=sets].
n_rows, n_cols = len(chunks_orderings_list), len(sets_orderings_list)

# One loop replaces the two pasted copies that differed only in the column name.
for select_column in ("genus", "species"):
    # Identical in every panel of this figure, so built once per column.
    sets = [set(v[0].samples_df[select_column].values) for v in selected_sets.values()]

    fig, ax = plt.subplots(n_rows, n_cols, figsize=(16 * n_cols, 8 * n_rows), dpi=150)

    for j, chunks_ordering in enumerate(chunks_orderings_list):
        for i, sets_ordering in enumerate(sets_orderings_list):
            ax_title = f"Extant Leaves - range of thresholds - ({sets_ordering=}) - ({chunks_ordering=})"

            # Forward chunks_ordering too; without it every row repeats the
            # default chunk layout while its title claims a different one.
            sv = supervenn(
                sets,
                labels,
                chunks_ordering=chunks_ordering,
                sets_ordering=sets_ordering,
                ax=ax[j, i],
            )
            ax[j, i].set_title(ax_title)

    # Save through the figure handle rather than pyplot's "current figure".
    fig.savefig(f"Extant Leaves {select_column} categories.svg")

## Fossil Leaves -- various settings gallery

In [None]:
# Gallery of supervenn layouts for the Fossil datasets whose name contains
# "family_3": one row per chunks_ordering, one column per sets_ordering.
sets_orderings_list = ["size", "chunk count", "minimize gaps"]
# chunks_ordering has its own vocabulary in supervenn ('occurrence' instead of
# 'chunk count'), so it cannot reuse the sets_ordering list.
chunks_orderings_list = ["size", "occurrence", "minimize gaps"]

select_column = "family"
select_stem = "Fossil"

# The compared sets are identical in every panel, so build them once.
selected_sets = {
    k: v
    for k, v in datasets.items()
    if (select_stem in k) and (f"{select_column}_3" in k)
}
sets = [set(v[0].samples_df[select_column].values) for v in selected_sets.values()]
labels = list(selected_sets)

# Grid shape follows the indexing used below: ax[row=chunks, col=sets].
n_rows, n_cols = len(chunks_orderings_list), len(sets_orderings_list)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(16 * n_cols, 8 * n_rows), dpi=150)

# tqdm wraps the list (not the enumerate object) so it can read the length
# and show a real progress total.
for j, chunks_ordering in enumerate(tqdm(chunks_orderings_list)):
    for i, sets_ordering in enumerate(tqdm(sets_orderings_list, leave=False)):
        ax_title = f"Fossil Leaves - range of thresholds - ({sets_ordering=}) - ({chunks_ordering=})"

        # Forward chunks_ordering too; without it every row repeats the default
        # chunk layout while its title claims a different one.
        sv = supervenn(
            sets,
            labels,
            chunks_ordering=chunks_ordering,
            sets_ordering=sets_ordering,
            ax=ax[j, i],
        )
        ax[j, i].set_title(ax_title)

# Save through the figure handle rather than pyplot's "current figure".
fig.savefig(f"Fossil Leaves {select_column} categories.svg")

## All datasets comparisons

In [None]:
# Gallery of supervenn layouts comparing every loaded dataset whose name
# contains "family_3": one row per chunks_ordering, one column per sets_ordering.
sets_orderings_list = ["size", "chunk count", "minimize gaps"]
# chunks_ordering has its own vocabulary in supervenn ('occurrence' instead of
# 'chunk count'), so it cannot reuse the sets_ordering list.
chunks_orderings_list = ["size", "occurrence", "minimize gaps"]

select_column = "family"

# No stem filter here: all dataset families (not only Fossil) are compared.
selected_sets = {k: v for k, v in datasets.items() if f"{select_column}_3" in k}
sets = [set(v[0].samples_df[select_column].values) for v in selected_sets.values()]
labels = list(selected_sets)

# Grid shape follows the indexing used below: ax[row=chunks, col=sets].
n_rows, n_cols = len(chunks_orderings_list), len(sets_orderings_list)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(16 * n_cols, 8 * n_rows), dpi=150)

# tqdm wraps the list (not the enumerate object) so it can read the length
# and show a real progress total.
for j, chunks_ordering in enumerate(tqdm(chunks_orderings_list)):
    for i, sets_ordering in enumerate(tqdm(sets_orderings_list, leave=False)):
        # Title names what is plotted; the pasted "Fossil Leaves" label was wrong here.
        ax_title = f"All datasets - {select_column}_3 - ({sets_ordering=}) - ({chunks_ordering=})"

        # Forward chunks_ordering too; without it every row repeats the default
        # chunk layout while its title claims a different one.
        sv = supervenn(
            sets,
            labels,
            chunks_ordering=chunks_ordering,
            sets_ordering=sets_ordering,
            ax=ax[j, i],
        )
        ax[j, i].set_title(ax_title)

# Own filename: the pasted "Fossil Leaves ..." name overwrote the Fossil-only figure.
fig.savefig(f"All datasets {select_column} categories.svg")

In [None]:
# Look the dataset up in `datasets` itself rather than in `selected_sets`, whose
# contents depend on which plotting cell happened to run last.
datasets['Fossil_family_3_512'][0].samples_df

In [None]:
# One supervenn diagram per label column, comparing every dataset whose name
# contains "family_<thresh>". Each diagram is saved to its own SVG file.
chunks_ordering = 'minimize gaps'
sets_ordering = 'size'
thresh = 3
columns = ["family", "genus", "species", "collection"]

# The dataset selection does not depend on the column, so compute it once.
selected_sets = {k: v for k, v in datasets.items() if f"family_{thresh}" in k}
labels = list(selected_sets)

for select_column in tqdm(columns):
    fig, ax = plt.subplots(1, 1, figsize=(16, 8), dpi=200)
    ax_title = f"Leavesdb v1.1 - {select_column}"
    sets = [set(v[0].samples_df[select_column].values) for v in selected_sets.values()]

    # chunks_ordering is now forwarded, so editing the setting above actually
    # changes the plot.
    sv = supervenn(
        sets,
        labels,
        chunks_ordering=chunks_ordering,
        sets_ordering=sets_ordering,
        ax=ax,
        widths_minmax_ratio=0.01,
    )
    ax.set_title(ax_title)

    # Save through the figure handle rather than pyplot's "current figure".
    fig.savefig(f"{select_column} label distribution across Leavesdbv1_1 datasets (thresholded at family={thresh} - supervenn diagram.svg")

In [None]:
# One supervenn diagram per label column, comparing every dataset whose name
# contains "family_<t>" for any threshold t in `thresh`. Each diagram is saved
# to its own SVG file.
chunks_ordering = 'minimize gaps'
sets_ordering = 'size'
thresh = [50,100]
columns = ["family", "genus", "species", "collection"]

# The dataset selection does not depend on the column, so compute it once.
# `any` accepts however many thresholds are listed.
selected_sets = {
    k: v
    for k, v in datasets.items()
    if any(f"family_{t}" in k for t in thresh)
}
labels = list(selected_sets)

for select_column in tqdm(columns):
    fig, ax = plt.subplots(1, 1, figsize=(16, 8), dpi=200)
    ax_title = f"Leavesdb v1.1 - {select_column}"
    sets = [set(v[0].samples_df[select_column].values) for v in selected_sets.values()]

    # chunks_ordering is now forwarded, so editing the setting above actually
    # changes the plot.
    sv = supervenn(
        sets,
        labels,
        chunks_ordering=chunks_ordering,
        sets_ordering=sets_ordering,
        ax=ax,
    )
    ax.set_title(ax_title)

    # Save through the figure handle rather than pyplot's "current figure".
    fig.savefig(f"{select_column} label distribution across Leavesdbv1_1 datasets (thresholded at family={thresh} - supervenn diagram.svg")

In [None]:
# Galleries of supervenn layouts for the Extant datasets, one figure per label
# column: rows are chunks_ordering values, columns are sets_ordering values.
# NOTE(review): this cell writes the same "Extant Leaves genus/species
# categories.svg" files as an earlier cell in this notebook -- confirm it is
# still needed.
sets_orderings_list = ["size", "chunk count", "random", "minimize gaps", None]
# chunks_ordering has its own vocabulary in supervenn ('occurrence' instead of
# 'chunk count', and no None), so it cannot reuse the sets_ordering list.
chunks_orderings_list = ["size", "occurrence", "random", "minimize gaps"]

select_stem = "Extant"
selected_sets = {k: v for k, v in datasets.items() if select_stem in k}
labels = list(selected_sets)

# Grid shape follows the indexing used below: ax[row=chunks, col=sets].
n_rows, n_cols = len(chunks_orderings_list), len(sets_orderings_list)

# One loop replaces the two pasted copies that differed only in the column name.
for select_column in ("genus", "species"):
    # Identical in every panel of this figure, so built once per column.
    sets = [set(v[0].samples_df[select_column].values) for v in selected_sets.values()]

    fig, ax = plt.subplots(n_rows, n_cols, figsize=(16 * n_cols, 8 * n_rows), dpi=150)

    for j, chunks_ordering in enumerate(chunks_orderings_list):
        for i, sets_ordering in enumerate(sets_orderings_list):
            ax_title = f"Extant Leaves - range of thresholds - ({sets_ordering=}) - ({chunks_ordering=})"

            # Forward chunks_ordering too; without it every row repeats the
            # default chunk layout while its title claims a different one.
            sv = supervenn(
                sets,
                labels,
                chunks_ordering=chunks_ordering,
                sets_ordering=sets_ordering,
                ax=ax[j, i],
            )
            ax[j, i].set_title(ax_title)

    # Save through the figure handle rather than pyplot's "current figure".
    fig.savefig(f"Extant Leaves {select_column} categories.svg")

In [None]:
# Single supervenn diagram of the `family` labels shared by the Fossil datasets.
select_stem = "Fossil"
select_column = "family"
suptitle = f"Fossil Leaves shared {select_column} categories across a range of thresholds"

# Datasets whose name contains the stem; the names double as row labels.
selected_sets = {name: state for name, state in datasets.items() if select_stem in name}
labels = list(selected_sets)
sets = [
    set(state[0].samples_df[select_column].values)
    for state in selected_sets.values()
]

plt.figure(figsize=(16, 8), dpi=150)
sv = supervenn(sets, labels)
plt.suptitle(suptitle, fontsize="xx-large")

### Meerkat dataset definition

In [None]:
import meerkat as mk
from meerkat.contrib.imagenette import download_imagenette

# download_imagenette(".")
# dp = mk.DataPanel.from_csv("imagenette2-160/imagenette.csv")
# dp["img"] = mk.ImageColumn.from_filepaths(dp["img_path"])

# dp[["label", "split", "img"]].lz[:3]

In [None]:
from torch.utils.data import Dataset


class MeerkatDataset(Dataset):
    """Torch dataset wrapper around a meerkat DataPanel (or any indexable of row mappings).

    Args:
        datapanel: Object supporting ``len()`` and ``datapanel[idx]``, where the
            indexed row supports ``row[column_name]``.
        xs: Name(s) of the input column(s). A single string is treated as one
            column name; otherwise a sequence of names is expected.
        ys: Name(s) of the target column(s), same conventions as ``xs``.

    Indexing returns an ``(x, y)`` tuple. Each side is the bare column value
    when exactly one name was given and a list of values (in name order) when
    several were given.
    """

    def __init__(self, datapanel, xs, ys):
        self.dataset = datapanel
        # A bare string would otherwise be measured and iterated character by
        # character, so wrap it into a one-element list.
        self.x_names = [xs] if isinstance(xs, str) else xs
        self.y_names = [ys] if isinstance(ys, str) else ys

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the row once and reuse it for every requested column, instead
        # of indexing the underlying panel again for each feature.
        row = self.dataset[idx]
        return (self._select(row, self.x_names), self._select(row, self.y_names))

    @staticmethod
    def _select(row, names):
        """Return ``row[name]`` for one name, or a list of values for several."""
        if len(names) > 1:
            return [row[name] for name in names]
        return row[names[0]]

In [None]:
# Build a meerkat DataPanel straight from the Fossil_family_3_1024 CSV catalog.
catalog_path = "/media/data_cifs/projects/prj_fossils/users/jacob/data/leavesdb-v1_1/Fossil_family_3_1024/Fossil_family_3_1024-full_dataset.csv"
dp = mk.DataPanel.from_csv(catalog_path)
print(dp.columns)

# Add an image column built from the file paths stored in the "path" column.
dp["img"] = mk.ImageColumn.from_filepaths(dp["path"])

# Preview the first three rows of the label and image columns.
preview_columns = ["family", "genus", "img"]
dp[preview_columns].lz[:3]

In [None]:
# Convert the DataPanel to a pandas object for the groupby in the next cell.
df = dp.to_pandas()

In [None]:
# Number of rows per family. `size()` is the built-in for this count and avoids
# routing every group through a Python-level `apply(len)`.
df.groupby("family").size()

In [None]:
# Slice [:4] through the `.lz` indexer to get a small panel for the write test below.
dp2 = dp.lz[:4]

In [None]:
# Call the panel's write method with "test_dp" (presumably the output path -- confirm),
# then display the panel.
dp2.write("test_dp")
dp2

## supervenn + fossil leaves

In [None]:
# Display the `tags` attribute of the dataset registry.
available_datasets.tags

In [None]:
import torchdatasets

In [None]:
# Query the registry with the terms "Fossil" and "v1_1" and display the result.
available_datasets.search("Fossil", "v1_1")

In [None]:
# Query the registry for the PNAS dataset name that is passed to `.get(...)` further down.
available_datasets.search("PNAS_family_100_original")



In [None]:
# List the PNAS dataset directory. It is resolved here directly because
# `pnas_path` is only assigned in a later cell, and there it points at the
# train.csv file rather than at a directory.
os.listdir(available_datasets.get("PNAS_family_100_original"))

In [None]:
# `general_fossil_path` is never assigned anywhere in this notebook, so list the
# Fossil_512 folder of the catalog instead -- the folder the Fossil CSV is read
# from further down.
os.listdir(Path(dataset_catalog_dir, "Fossil_512"))

In [None]:
# PNAS: the registry supplies the dataset location; the training split is train.csv inside it.
pnas_root = available_datasets.get("PNAS_family_100_original")
pnas_path = Path(pnas_root) / "train.csv"

# Fossil and Extant: full-dataset CSV catalogs addressed by absolute path.
extant_100_path = "/media/data_cifs/projects/prj_fossils/users/jacob/data/leavesdb-v1_1/Extant_Leaves_family_100_512/Extant_Leaves_family_100_512-full_dataset.csv"
fossil_path = "/media/data_cifs/projects/prj_fossils/users/jacob/data/leavesdb-v1_1/Fossil_512/Fossil_512-full_dataset.csv"

In [None]:
# Read the three CSV catalogs, in the same order as the names on the left.
pnas_df, fossil_df, extant_100_df = (
    pd.read_csv(csv_path)
    for csv_path in (pnas_path, fossil_path, extant_100_path)
)

In [None]:
# Summary of every column of the PNAS table (include='all' covers non-numeric columns too).
pnas_df.describe(include='all')

In [None]:
# Summary of every column of the Fossil table (include='all' covers non-numeric columns too).
fossil_df.describe(include='all')

In [None]:
# Summary of every column of the Extant table (include='all' covers non-numeric columns too).
extant_100_df.describe(include='all')

In [None]:
from supervenn import supervenn
import matplotlib.pyplot as plt

# Compare the distinct `family` values found in the PNAS, Fossil and Extant tables.
select_column = "family"

# Row label -> table, in the order the rows are passed to supervenn.
frames_by_label = {
    "PNAS_100": pnas_df,
    "Fossils": fossil_df,
    "Extant_100": extant_100_df,
}
labels = list(frames_by_label)
sets = [set(frame[select_column].values) for frame in frames_by_label.values()]

plt.figure(figsize=(16, 8))
sv = supervenn(sets, labels)

In [None]:
# Display the `chunks` attribute of the object returned by supervenn.
sv.chunks

In [None]:
# List the attribute names available on the object returned by supervenn.
dir(sv)

In [None]:
from supervenn import supervenn


# Three small overlapping integer sets: a quick smoke test of the default plot.
sets = [set(range(5)), {3, 4, 5}, {1, 6, 7, 8}]

supervenn(sets)

In [None]:
from supervenn import supervenn
import matplotlib.pyplot as plt
# Same toy sets on a wider figure, passing side_plots="right".
plt.figure(figsize=(16, 8))

sets = [set(range(5)), {3, 4, 5}, {1, 6, 7, 8}]
sv = supervenn(sets, side_plots="right")

# List the attribute names available on the returned object.
dir(sv)

In [None]:
from omegaconf import DictConfig, OmegaConf
import os
from rich import print as pp
import hydra


import numpy as np
from typing import *
import inspect
from tqdm.auto import tqdm
import pandas as pd
from pathlib import Path
import logging