# PaleoAI Dataset Arithmetic

`paleoai_dataset_arithmetic.ipynb`

This notebook uses pandas DataFrames to add and subtract sets of data samples while keeping every specimen unique, where uniqueness is enforced on a user-specified id column. By using simple operator overloading in Python class definitions, complex queries can be expressed with minimal boilerplate.

Author: Jacob A Rose  
Created on: Monday July 19th, 2021

## Top

In [2]:
# !conda list
# !pip list
# Smoke test: assign a string and print it to confirm the kernel executes cells.
test = "won"
print(test)


won


In [1]:
# def left_union(data_df: pd.DataFrame, other_df: pd.DataFrame, id_col: str="catalog_number", suffixes=("_x", "_y")) -> pd.DataFrame:
#     """
#     Return a new dataframe containing all rows from `data_df`, concatenated with any rows that only exist in `other_df`. Any rows that are shared between the 2 default to only including the values from `data_df`.
    
#     """
#     return data_df.merge(other_df, how='outer', on=id_col, suffixes=suffixes)

In [2]:
import pandas as pd


def intersection(data_df: pd.DataFrame, other_df: pd.DataFrame, id_col: str="catalog_number", suffixes=("_x", "_y")) -> pd.DataFrame:
    """
    Inner-join `data_df` with `other_df` on `id_col`.

    Only rows whose `id_col` value occurs in both frames are kept (a set AND).
    Non-key columns that exist in both frames are disambiguated with `suffixes`
    (first suffix for `data_df`, second for `other_df`).
    """
    shared_rows = pd.merge(data_df,
                           other_df,
                           on=id_col,
                           how="inner",
                           suffixes=suffixes)
    return shared_rows


def left_exclusive(data_df: pd.DataFrame, other_df: pd.DataFrame, id_col: str="catalog_number") -> pd.DataFrame:
    """
    Return the rows of `data_df` whose `id_col` value does not occur in `other_df`.

    Equivalent to subtracting the set of `id_col` values in `other_df` from `data_df`.
    The original index, row order and columns of `data_df` are preserved, and
    duplicate ids in `data_df` are kept or dropped together.

    Note: membership uses `Series.isin`, so a missing id (NaN) that appears in
    both frames is treated as shared and removed.
    """
    # Vectorized membership test: avoids scanning a Python list once per row,
    # and always yields a boolean mask (even for an empty `data_df`), so the
    # selection below is by row and never by column label.
    shared = data_df[id_col].isin(other_df[id_col])

    return data_df[~shared]

In [3]:
# extant_df = pd.read_csv("/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v0_3/Extant-dataset_leavesdb-v0_3.csv", index_col=0)
# pnas_train = pd.read_csv("/media/data_cifs/projects/prj_fossils/data/processed_data/data_splits/PNAS_family_100/train.csv")
# pnas_test = pd.read_csv("/media/data_cifs/projects/prj_fossils/data/processed_data/data_splits/PNAS_family_100/test.csv")
# pnas_df = pd.concat([pnas_train, pnas_test])

# extant_in_pnas = intersection(data_df=extant_df,
#                               other_df=pnas_df,
#                               id_col="catalog_number",
#                               suffixes=("_extant", "_pnas"))

# # In order to only keep original columns
# suffixes=("_extant", "_pnas")
# extant_in_pnas = extant_in_pnas.drop(columns = [c for c in extant_in_pnas.columns if c.endswith(suffixes[1])])
# extant_in_pnas = extant_in_pnas.rename(columns = {c:c.split(suffixes[0])[0] for c in extant_in_pnas.columns})

In [4]:
# # extant_minus_pnas -> rows exclusive to extant dataset
# extant_minus_pnas = left_exclusive(data_df=extant_df,
#                                    other_df=pnas_df,
#                                    id_col="catalog_number",
#                                    suffixes=("_extant", "_pnas"))

# # pnas_minus_extant -> rows exclusive to pnas dataset
# pnas_minus_extant = left_exclusive(data_df=pnas_df,
#                                    other_df=extant_df,
#                                    id_col="catalog_number",
#                                    suffixes=("_pnas", "_extant"))

# extant_minus_pnas = left_exclusive(data_df=extant_df, other_df=pnas_df, id_col="catalog_number", suffixes=("_extant", "_pnas"))

# pnas_minus_extant = left_exclusive(data_df=pnas_df, other_df=extant_df, id_col="catalog_number", suffixes=("_pnas", "_extant"))

# extant_and_pnas = left_union(data_df=extant_df, other_df=pnas_df, id_col="catalog_number", suffixes=("_extant", "_pnas"))

## Code

In [5]:
import logging
import os.path
 
def initialize_logger():
    """
    Configure the root logger to emit INFO-and-above records to the console.

    The root logger level is set to DEBUG and a single stream handler,
    formatted as "<LEVEL> - <message>", is attached at INFO level. Calling
    this more than once (e.g. re-running the notebook cell) does not attach
    additional handlers, so messages are not printed multiple times.
    """
    handler_name = "paleoai_console"
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    # The handler is tagged with a name so a previous call can be detected.
    if any(h.get_name() == handler_name for h in logger.handlers):
        return

    # create console handler and set level to info
    handler = logging.StreamHandler()
    handler.set_name(handler_name)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter("%(levelname)s - %(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)

initialize_logger()
    
import torchdata
from typing import Union, List, Any, Tuple
# from collections import Counter
from lightning_hydra_classifiers.utils import template_utils
from lightning_hydra_classifiers.utils.common_utils import trainvaltest_split
import collections
from omegaconf import OmegaConf, DictConfig
from lightning_hydra_classifiers.data.common import CommonDataSelect, CommonDataset, LeavesLightningDataModule
from lightning_hydra_classifiers.data import fossil, extant, pnas
from rich import print as pp
import os

from typing import *
from pathlib import Path
import matplotlib.pyplot as plt

from IPython.display import display

# Logger for this notebook, obtained from the project's template_utils helper at DEBUG level.
log = template_utils.get_logger(__name__, level=logging.DEBUG)
import pandas as pd

# Raise pandas display limits so wide/long dataframes are shown with less truncation.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_colwidth', 200)

  rank_zero_deprecation(


In [6]:
# Absolute path of the hydra config directory; hard-coded to the author's machine.
config_dir = "/media/data/jacob/GitHub/lightning-hydra-classifiers/configs"

from hydra.experimental import compose, initialize, initialize_config_dir
from omegaconf import OmegaConf, DictConfig
# Side effect: changes the working directory of the whole kernel to the config directory.
os.chdir(config_dir)
print(f"cwd = {os.getcwd()}")

def initialize_config(config_dir: str,
                      overrides=None):
    """
    Compose the "multi-gpu" hydra config found in `config_dir`.

    `overrides` is forwarded unchanged to hydra's `compose`. Struct mode is
    switched off on the composed config before it is returned.
    """
    hydra_context = initialize_config_dir(config_dir=config_dir, job_name="multi-gpu_experiment")
    with hydra_context:
        composed = compose(config_name="multi-gpu", overrides=overrides)
        OmegaConf.set_struct(composed, False)
    return composed

cwd = /media/data/jacob/GitHub/lightning-hydra-classifiers/configs


In [7]:
from lightning_hydra_classifiers.utils.common_utils import LabelEncoder, trainval_split
from lightning_hydra_classifiers.data.common import CommonDataset, LeavesLightningDataModule, plot_split_distributions
# default_config = initialize_config(config_dir=config_dir,
#                                    overrides=["datamodule=default_datamodule"])
# config = DictConfig({"datamodule.dataset.name":"Extant_family_10_minus_PNAS_family_100_512"})
# default_config = DictConfig({'datamodule':LeavesLightningDataModule.default_config()})

# pp(OmegaConf.to_container(datamodule.datamodule_config, resolve=True))
# default_config = DictConfig({'datamodule':LeavesLightningDataModule.default_config()})
# user_config = DictConfig({"datamodule":
#                               {"dataset":
#                                    {"name":"Extant_family_10_minus_PNAS_family_100_512"}
#                               }
#                          })
# pp(OmegaConf.to_container(OmegaConf.merge(default_config, user_config), resolve=True))

In [8]:
# Directory holding a previously exported dataset (hard-coded to the author's machine).
output_dir = "/media/data/jacob/GitHub/prj_fossils_contrastive/notebooks/Extant_family_10_1024_minus_PNAS_family_100_1024"

#         config = DictConfig({"dataset":
#                                        {"name":"Extant_family_10_minus_PNAS_family_100_512"}
#                             })
# config.dataset.config.name = "Extant_family_10_1024_in_PNAS_family_100_1024"
# Build the datamodule from the exported directory, passing no config.
datamodule = LeavesLightningDataModule(config=None, #config, #default_config,
                                       data_dir=output_dir)
# NOTE(review): `config` is not assigned in this cell (its definition above is
# commented out), so the four assignments below rely on a `config` left over in
# the kernel. The output recorded for this cell ends in `KeyError: ''`.
config.hparams.classes = datamodule.classes
config.hparams.num_classes = len(config.hparams.classes)
config.dataset.config.classes = datamodule.classes
config.dataset.config.num_classes = len(config.hparams.classes)


data_loader = datamodule.data_loader

2021-07-22 17:50:20,667 lightning_hydra_classifiers.utils.common_utils INFO     LabelEncoder replacing 1 class encodings with that other an another class
INFO - LabelEncoder replacing 1 class encodings with that other an another class
2021-07-22 17:50:20,670 lightning_hydra_classifiers.utils.common_utils INFO     Replacing: {'Nothofagaceae': 'Fagaceae'}
INFO - Replacing: {'Nothofagaceae': 'Fagaceae'}


KeyError: ''

In [8]:
# Split settings reused by later cells: label column, random seed, and the
# `val_train_split` value handed to `trainval_split`.
y_col = 'family'
seed = 5687
val_train_split = 0.2

# pnas_name = "PNAS_family_100_512"
# extant_name = "Extant_family_10_512"

# Dataset names written into each datamodule config below (1024 variants; 512 kept for reference).
pnas_name = "PNAS_family_100_1024"#512"
extant_name = "Extant_family_10_1024" #512"


## Load primary Extant and PNAS datamodules
pnas_cfg = initialize_config(config_dir=config_dir,
                        overrides=["dataset=pnas_dataset",
                                  "datamodule=standalone_datamodule"])
# pnas_cfg[f"dataset.config.name"] = pnas_name

pnas_cfg.datamodule.config.dataset.name = pnas_name
# Pretty-print the composed PNAS config without resolving interpolations.
pp(OmegaConf.to_container(pnas_cfg, resolve=False))

# The whole composed config is passed here (not just `.datamodule.config`).
pnas_datamodule = LeavesLightningDataModule(pnas_cfg)#.datamodule.config)

extant_cfg = initialize_config(config_dir=config_dir,
                        overrides=["dataset=extant_dataset",
                                   "datamodule=standalone_datamodule"])
# pnas_cfg[f"dataset.config.name"] = pnas_name

extant_cfg.datamodule.config.dataset.name = extant_name
extant_datamodule = LeavesLightningDataModule(extant_cfg) #.datamodule.config)

2021-07-22 09:54:41,710 lightning_hydra_classifiers.data.common INFO     [SELECT DATASET] (name=PNAS_family_100_1024, num_files=5311), 
dataset_dirs=
    /media/data_cifs/projects/prj_fossils/data/processed_data/data_splits/PNAS_family_100_1024
INFO - [SELECT DATASET] (name=PNAS_family_100_1024, num_files=5311), 
dataset_dirs=
    /media/data_cifs/projects/prj_fossils/data/processed_data/data_splits/PNAS_family_100_1024
2021-07-22 09:54:43,058 lightning_hydra_classifiers.data.common INFO     [RUNNING] datamodule.setup(None)
INFO - [RUNNING] datamodule.setup(None)


train 2124
val 531
train 2124
val 531
train 2124
val 531
train 2124
val 531


2021-07-22 09:54:46,061 lightning_hydra_classifiers.data.common INFO     [SELECT DATASET] (name=Extant_family_10_1024, num_files=25496), 
dataset_dirs=
    /media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v0_3/catalog_files/extant_family_10/1024
INFO - [SELECT DATASET] (name=Extant_family_10_1024, num_files=25496), 
dataset_dirs=
    /media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v0_3/catalog_files/extant_family_10/1024
2021-07-22 09:54:50,559 lightning_hydra_classifiers.data.common INFO     [RUNNING] datamodule.setup(None)
INFO - [RUNNING] datamodule.setup(None)


train 11815
val 2954
train 11815
val 2954
train 11815
val 2954
train 11815
val 2954


In [9]:
#########################################

# Names for the two derived datasets and the directory the export is written to.
dataset_name = f"{extant_name}_minus_{pnas_name}" # Extant_family_10_512_minus_PNAS_family_100_512
test_dataset_name = f"{extant_name}_in_{pnas_name}"
output_dir = f"/media/data/jacob/GitHub/prj_fossils_contrastive/notebooks/{dataset_name}"

#########################################

# Pull the per-sample dataframes out of the two loaded datamodules.
extant_dataset = extant_datamodule.dataset
pnas_dataset = pnas_datamodule.dataset
extant_df = extant_dataset.samples_df
pnas_df = pnas_dataset.samples_df

#########################################
#########################################

# extant_minus_pnas -> rows exclusive to extant dataset
extant_minus_pnas = left_exclusive(data_df=extant_df,
                                   other_df=pnas_df,
                                   id_col="catalog_number")
# extant_in_pnas -> rows from extant dataset that share a catalog_number with PNAS
extant_in_pnas = intersection(data_df=extant_df,
                              other_df=pnas_df,
                              id_col="catalog_number",
                              suffixes=("_extant", "_pnas"))

# Keep only the Extant side of the merge: drop every "_pnas" column, then strip
# the "_extant" suffix so the column names match the original Extant dataframe.
suffixes=("_extant", "_pnas")
extant_in_pnas = extant_in_pnas.drop(columns = [c for c in extant_in_pnas.columns if c.endswith(suffixes[1])])
extant_in_pnas = extant_in_pnas.rename(columns = {c:c.split(suffixes[0])[0] for c in extant_in_pnas.columns})

#########################################
### GENERATE TRAIN VAL SPLIT INDICES
#########################################

y = extant_minus_pnas[y_col]
# At this point `data_splits` holds the result of `trainval_split`; it is
# rebound to a dict of datasets at the end of this cell.
data_splits = trainval_split(x=None,
                             y=y,
                             val_train_split=val_train_split,
                             random_state=seed,
                             stratify=True
                             )
# The first element of each split is used as positional row indices (see `.iloc` below).
train_idx, val_idx = data_splits['train'][0], data_splits['val'][0]

#########################################
### CREATE COMMONDATASETS FROM DATAFRAMES, USING THE SPLIT INDICES
#########################################


train_df = extant_minus_pnas.iloc[train_idx,:]
val_df = extant_minus_pnas.iloc[val_idx,:]

train_dataset_extant_minus_pnas = CommonDataset.from_dataframe(
                                                               sample_df=train_df,
                                                               config=None,
                                                               return_signature = ["image","target","path"],
                                                               subset_key="train")

val_dataset_extant_minus_pnas = CommonDataset.from_dataframe(
                                                             sample_df=val_df,
                                                             config=None,
                                                             return_signature = ["image","target","path"],
                                                             subset_key="val")

# The test split is the Extant/PNAS overlap, not a held-out part of extant_minus_pnas.
test_dataset_extant_in_pnas = CommonDataset.from_dataframe(sample_df=extant_in_pnas,
                                                            config=None,
                                                            return_signature = ["image","target","path"],
                                                            subset_key="test")

# Rebinds `data_splits`: from here on it maps subset name -> CommonDataset.
data_splits= {'train':train_dataset_extant_minus_pnas,
              'val': val_dataset_extant_minus_pnas,
              'test':test_dataset_extant_in_pnas}

In [None]:
# Record the dataset name and subset key on each split's config; the test split
# carries the "_in_" name, train and val carry the "_minus_" name.
data_splits['train'].config.name = dataset_name
data_splits['train'].config.subset_key = "train"

data_splits['val'].config.name = dataset_name
data_splits['val'].config.subset_key = "val"

data_splits['test'].config.name = test_dataset_name
data_splits['test'].config.subset_key = "test"

#########################################
### LABEL ENCODER
#########################################

replace = {"Nothofagaceae": "Fagaceae"}
label_encoder = LabelEncoder(replace=replace) # class2idx)
# NOTE(review): `fit` is called on the test targets and then on the train
# targets; whether the second call extends or replaces the first depends on
# LabelEncoder.fit, which is not visible here -- confirm.
label_encoder.fit(data_splits["test"].targets)
label_encoder.fit(data_splits["train"].targets)
# Share the single encoder instance across all three splits.
for d in list(data_splits.values()):
    d.label_encoder = label_encoder

    
#########################################
### EXPORT DATASET CONFIGURATION TO A COMBO OF CSV, JSON, YAML, AND JPG FILES.
#########################################
# This uses the library versions of the export/import helpers; a later cell in
# this notebook defines functions with the same names, which shadow these once run.
from lightning_hydra_classifiers.data.common import export_dataset_to_csv, import_dataset_from_csv
# output_dir = "/media/data/jacob/GitHub/prj_fossils_contrastive/notebooks/Extant_family_10_512_minus_PNAS_family_100_512"
export_dataset_to_csv(data_splits=data_splits,
                          label_encoder=label_encoder,
                          output_dir=output_dir)

In [7]:
# loaded_data_splits, conf = import_dataset_from_csv(data_catalog_dir = output_dir)
# print(conf)
# for k,v in loaded_data_splits.items():
#     print(k, repr(v))


# Directory holding the exported dataset (hard-coded to the author's machine).
output_dir = "/media/data/jacob/GitHub/prj_fossils_contrastive/notebooks/Extant_family_10_1024_minus_PNAS_family_100_1024"

#         config = DictConfig({"dataset":
#                                        {"name":"Extant_family_10_minus_PNAS_family_100_512"}
#                             })
# config.dataset.config.name = "Extant_family_10_1024_in_PNAS_family_100_1024"
# Build the datamodule from the exported directory, passing no config.
datamodule = LeavesLightningDataModule(config=None, #config, #default_config,
                                       data_dir=output_dir)
# NOTE(review): `config` is not assigned in this cell (its definition above is
# commented out), so the assignments below use whatever `config` is left in the
# kernel. The output recorded for this cell ends in
# `ConfigAttributeError: Missing key dataset`.
config.hparams.classes = datamodule.classes
config.hparams.num_classes = len(config.hparams.classes)
config.dataset.config.classes = datamodule.classes
config.dataset.config.num_classes = len(config.hparams.classes)


data_loader = datamodule.data_loader

2021-07-22 17:43:22,848 lightning_hydra_classifiers.utils.common_utils INFO     LabelEncoder replacing 1 class encodings with that other an another class
INFO - LabelEncoder replacing 1 class encodings with that other an another class
2021-07-22 17:43:22,851 lightning_hydra_classifiers.utils.common_utils INFO     Replacing: {'Nothofagaceae': 'Fagaceae'}
INFO - Replacing: {'Nothofagaceae': 'Fagaceae'}


ConfigAttributeError: Missing key dataset
    full_key: dataset
    object_type=dict

In [None]:
# 

In [None]:
# Same export directory as above, this time with an explicit minimal config.
output_dir = "/media/data/jacob/GitHub/prj_fossils_contrastive/notebooks/Extant_family_10_1024_minus_PNAS_family_100_1024"

# NOTE(review): the dataset name here has no "_1024" resolution segments and
# ends in "_512", while `output_dir` points at the 1024 export -- confirm which is intended.
config = DictConfig({"dataset":
                               {"name":"Extant_family_10_minus_PNAS_family_100_512"}
                    })
datamodule = LeavesLightningDataModule(config=config, #default_config,
                                       data_dir=output_dir)

In [None]:
# def left_exclusive(data_df: pd.DataFrame, other_df: pd.DataFrame, id_col: str="catalog_number") -> pd.DataFrame:
#     """
#     Return a new dataframe containing only rows from `data_df` that do not share an `id_col` value with any row from `other_df`.
    
#     Equivalent to subtracting the set of `id_col` values in `other_df` from `data_df`
#     """
#     omit = list(other_df[id_col].values)
    
#     return data_df[data_df[id_col].apply(lambda x: x not in omit)]

In [None]:
# Alternate loading path: only the dataset override is applied, and just the
# `.datamodule.config` sub-config is handed to the datamodule (the earlier cell
# passes the whole composed config instead).
pnas_cfg = initialize_config(config_dir=config_dir,
                        overrides=["dataset=pnas_dataset"])
pnas_data = LeavesLightningDataModule(pnas_cfg.datamodule.config)

extant_cfg = initialize_config(config_dir=config_dir,
                        overrides=["dataset=extant_dataset"])
extant_data = LeavesLightningDataModule(extant_cfg.datamodule.config)

extant_dataset = extant_data.dataset
pnas_dataset = pnas_data.dataset

# Report the size of each dataset.
print(len(extant_dataset), len(pnas_dataset))

# Per-sample dataframes used by the set-arithmetic cells below.
extant_df = extant_dataset.samples_df
pnas_df = pnas_dataset.samples_df

In [None]:
# extant_minus_pnas -> rows exclusive to extant dataset
extant_minus_pnas = left_exclusive(data_df=extant_df,
                                   other_df=pnas_df,
                                   id_col="catalog_number")

# pnas_minus_extant -> rows exclusive to pnas dataset
# pnas_minus_extant = left_exclusive(data_df=pnas_df,
#                                    other_df=extant_df,
#                                    id_col="catalog_number")


# print(extant_minus_pnas.shape, pnas_minus_extant.shape)

# pnas_in_extant = intersection(data_df=pnas_df,
#                               other_df=extant_df,
#                               id_col="catalog_number",
#                               suffixes=("_pnas", "_extant"))

# extant_in_pnas -> rows whose catalog_number occurs in both datasets; columns
# that share a name come back with the "_extant" / "_pnas" suffixes.
extant_in_pnas = intersection(data_df=extant_df,
                              other_df=pnas_df,
                              id_col="catalog_number",
                              suffixes=("_extant", "_pnas"))


# print(extant_in_pnas.shape, pnas_in_extant.shape)

# Keep only the Extant side of the merge: drop every "_pnas" column, then strip
# the "_extant" suffix from the remaining column names.
suffixes=("_extant", "_pnas")
extant_in_pnas = extant_in_pnas.drop(columns = [c for c in extant_in_pnas.columns if c.endswith(suffixes[1])])
extant_in_pnas = extant_in_pnas.rename(columns = {c:c.split(suffixes[0])[0] for c in extant_in_pnas.columns})

# Display the resulting dataframe as the cell output.
extant_in_pnas

In [None]:
# Split settings: label column, random seed, and the `val_train_split` value
# handed to `trainval_split`.
y_col = 'family'
seed = 5687
val_train_split = 0.2

y = extant_minus_pnas[y_col]

# Stratified train/val split computed from the labels only (no features passed).
data_splits = trainval_split(x=None,
                             y=y,
                               val_train_split=val_train_split,
                               random_state=seed,
                               stratify=True
                               )

# Display the shape of the first element of the train split as the cell output.
data_splits['train'][0].shape

In [None]:
# The first element of each split is used as positional row indices into extant_minus_pnas.
train_idx, val_idx = data_splits['train'][0], data_splits['val'][0]

train_df = extant_minus_pnas.iloc[train_idx,:]
val_df = extant_minus_pnas.iloc[val_idx,:]

train_dataset_extant_minus_pnas = CommonDataset.from_dataframe(
                                                               sample_df=train_df,
                                                               config=None,
                                                               return_signature = ["image","target","path"],
                                                               subset_key="train")
#                                                                subset_key=None) #"train")


val_dataset_extant_minus_pnas = CommonDataset.from_dataframe(
                                                             sample_df=val_df,
                                                             config=None,
                                                             return_signature = ["image","target","path"],
                                                             subset_key="val")


# The test split is the Extant/PNAS overlap, not a held-out part of extant_minus_pnas.
test_dataset_extant_in_pnas = CommonDataset.from_dataframe(sample_df=extant_in_pnas,
                                                            config=None,
                                                            return_signature = ["image","target","path"],
                                                            subset_key="test")

# Rebinds `data_splits`: it now maps subset name -> CommonDataset, replacing the
# `trainval_split` result read at the top of this cell (so this cell cannot be
# re-run without re-running the split cell first).
data_splits= {'train':train_dataset_extant_minus_pnas,
              'val': val_dataset_extant_minus_pnas,
              'test':test_dataset_extant_in_pnas}

In [None]:
# for k,v in data_splits.items():    
#     print(k, v.__repr__())

# import_dataset_from_csv(self, data_dir: str)

In [None]:
# OmegaConf.to_container(test_dataset_extant_in_pnas.config, resolve=True)
# label_encoder = LabelEncoder() # class2idx)
# label_encoder.fit(data_splits["train"].targets)

# for d in list(data_splits.values()):
#     d.label_encoder = label_encoder
# test_dataset_extant_in_pnas.label_encoder
# label_encoder = LabelEncoder() # class2idx)
# label_encoder.fit(data_splits["test"].targets)



# test_dataset_extant_in_pnas.label_encoder
# label_encoder

# test_df = data_splits["test"].samples_df

# replace = {"Nothofagaceae": "Fagaceae"}
# label_encoder = LabelEncoder(replace=replace) # class2idx)
# label_encoder.fit(data_splits["test"].targets)

# label_encoder.fit(data_splits["train"].targets)



# for d in list(data_splits.values()):
#     d.label_encoder = label_encoder
    
# label_encoder
# test_df[test_df.family=="Nothofagaceae"].replace(replace)

In [None]:
# from typing import *
# from omegaconf import DictConfig
# from pathlib import Path
# import matplotlib.pyplot as plt
# pnas_df[pnas_df.catalog_number=="Wolfe_8535"]

# set(data_splits["test"].samples_df.family.astype(pd.CategoricalDtype()).cat.categories) - set(pnas_df.family.astype(pd.CategoricalDtype()).cat.categories)

# test_dataset_extant_in_pnas.label_encoder

In [None]:
def save_config(config: DictConfig, config_path: str):
    """
    Write `config` to `config_path` as YAML, with interpolations resolved.

    Any existing file at `config_path` is overwritten.
    """
    # Serialize before opening the file: opening in "w" mode truncates it, so a
    # failure while resolving the config would otherwise leave an empty file
    # in place of a previously saved config.
    yaml_text = OmegaConf.to_yaml(config, resolve=True)
    with open(config_path, "w") as f:
        f.write(yaml_text)

def load_config(config_path: str) -> DictConfig:
    """Open the file at `config_path` and return what `OmegaConf.load` parses from it."""
    with open(config_path, "r") as stream:
        return OmegaConf.load(stream)


def export_image_data_diagnostics(data_splits: Dict[str,CommonDataset],
                                  output_dir: str='.',
                                  max_samples: int = 64,
                                  export_sample_images: bool=True,
                                  export_class_distribution_plots: bool=True) -> Dict[str,Dict[str,str]]:
    """
    Save diagnostic figures for a set of dataset splits under `output_dir`.

    Creates `<output_dir>/images` and `<output_dir>/plots` (even when both
    export flags are False) and writes:
      * one sample-image figure per split, via each dataset's `show_batch`
        (when `export_sample_images` is True);
      * one class-distribution figure covering all splits, via
        `plot_split_distributions` (when `export_class_distribution_plots` is True).

    Args:
        data_splits: mapping of subset name -> dataset.
        output_dir: root directory for the exported figures.
        max_samples: passed to `show_batch` as `indices` and used in the figure title.
        export_sample_images: whether to write the per-split sample figures.
        export_class_distribution_plots: whether to write the distribution figure.

    Returns:
        {"images": {subset: path}, "class_distribution_plots": {"all": path}}
        containing the paths of the files that were written.
    """
    image_paths = {"images": {},
                   "class_distribution_plots":{}}

    image_dir = os.path.join(output_dir, "images")
    plot_dir = os.path.join(output_dir, "plots")
    os.makedirs(image_dir, exist_ok = True)
    os.makedirs(plot_dir, exist_ok = True)

    if export_sample_images:
        for subset, dataset in data_splits.items():
            title = f"subset: {subset}, {max_samples} random images"
            dataset.show_batch(indices=max_samples, include_colorbar=False,
                               suptitle = title)
            img_path = os.path.join(image_dir, f"{title}.jpg")
            image_paths["images"][subset] = img_path
            # Saves the current pyplot figure, i.e. whatever `show_batch` left active.
            plt.savefig(img_path)

    if export_class_distribution_plots:
        plot_split_distributions(data_splits=data_splits)
        plot_name = f"class_distribution_plots_{list(data_splits)}"
        # Spell out the extension: when a file name has none, matplotlib appends
        # rcParams["savefig.format"] on save, so a path recorded without it
        # would not match the file on disk.
        extension = plt.rcParams["savefig.format"]
        class_distribution_plot_path = os.path.join(plot_dir, f"{plot_name}.{extension}")
        image_paths["class_distribution_plots"]["all"] = class_distribution_plot_path
        plt.savefig(class_distribution_plot_path)

    return image_paths



def export_dataset_to_csv(data_splits: Dict[str,CommonDataset],
                          label_encoder: Optional[LabelEncoder]=None,
                          output_dir: str='.',
                          export_sample_images: bool=True,
                          export_class_distribution_plots: bool=True) -> Dict[str,Dict[str,str]]:
    """
    Export every split in `data_splits` to files under `output_dir`.

    For each subset `k` this writes:
      * `<k>_data_table.csv`   -- the split's `samples_df`;
      * `<k>_config.yaml`      -- the split's `config`, if it has one;
      * `<k>_label_encoder.json` -- the split's own label encoder, only when
        no shared `label_encoder` is passed.
    When `label_encoder` is given, it is saved once as `label_encoder.json`.
    Diagnostic figures are then written by `export_image_data_diagnostics`.

    Args:
        data_splits: mapping of subset name -> dataset.
        label_encoder: optional encoder shared by all splits.
        output_dir: destination directory (created if missing).
        export_sample_images: forwarded to `export_image_data_diagnostics`.
        export_class_distribution_plots: forwarded to `export_image_data_diagnostics`.

    Returns:
        {"tables": {...}, "class_labels": {...}, "configs": {...}} mapping each
        subset name (or "full" for the shared encoder) to the written file path.
    """
    output_paths = {"tables":{},
                    "class_labels":{},
                    "configs":{}}
    os.makedirs(output_dir, exist_ok=True)
    for k, data in data_splits.items():
        subset_data_path = os.path.join(output_dir, f"{k}_data_table.csv")
        data.samples_df.to_csv(subset_data_path)
        output_paths["tables"][k] = subset_data_path

        if hasattr(data, "config"):
            subset_config_path = os.path.join(output_dir, f"{k}_config.yaml")
            save_config(config=data.config, config_path=subset_config_path)
            output_paths["configs"][k] = subset_config_path

        # Per-subset encoders are only written when no shared encoder is supplied.
        if hasattr(data, 'label_encoder') and (label_encoder is None):
            subset_label_path = os.path.join(output_dir, k + "_label_encoder.json")
            data.label_encoder.save(subset_label_path)
            # Record the encoder file itself, not the subset's CSV table.
            output_paths["class_labels"][k] = subset_label_path

    if label_encoder is not None:
        full_label_encoder_path = os.path.join(output_dir, "label_encoder.json")
        label_encoder.save(full_label_encoder_path)
        output_paths["class_labels"]["full"] = full_label_encoder_path

    # The figure paths returned by the diagnostics export are not added to `output_paths`.
    export_image_data_diagnostics(data_splits=data_splits,
                                  output_dir=output_dir,
                                  max_samples = 64,
                                  export_sample_images=export_sample_images,
                                  export_class_distribution_plots=export_class_distribution_plots)

    return output_paths
    

def import_dataset_from_csv(data_catalog_dir: str) -> Dict[str, CommonDataset]:
    """
    Rebuild the train/val/test datasets from a directory written by `export_dataset_to_csv`.

    Expects, directly inside `data_catalog_dir`, one `*.csv` table and one
    `*.yaml` config per subset (file names starting with "train", "val" and
    "test") and exactly one `*.json` label encoder.

    Args:
        data_catalog_dir: directory containing the exported files.

    Returns:
        Mapping of subset name -> dataset, each sharing the loaded label encoder.

    Raises:
        AssertionError: if the number of CSV and YAML files differ.
        IndexError: if a subset has no matching CSV or YAML file.
        ValueError: if the directory does not contain exactly one JSON label encoder.
    """
    catalog_dir = Path(data_catalog_dir)
    data_paths = list(catalog_dir.glob("*.csv"))
    config_paths = list(catalog_dir.glob("*.yaml"))
    label_encoder_paths = list(catalog_dir.glob("*.json"))

    assert len(data_paths) == len(config_paths), \
        f"Found {len(data_paths)} csv tables but {len(config_paths)} yaml configs in {data_catalog_dir}"
    input_paths = {"tables":{},
                   "class_labels":{},
                   "configs":{}}
    subsets = ["train", "val", "test"]
    for subset in subsets:
        input_paths["tables"][subset] = [p for p in data_paths if p.stem.startswith(subset)][0]
        input_paths["configs"][subset] = [p for p in config_paths if p.stem.startswith(subset)][0]

    if len(label_encoder_paths) != 1:
        # Raise a real exception: raising the bare message string is itself a TypeError.
        raise ValueError(f'Currently cannot distinguish between multiple label_encoders, please delete all but 1 in experiment directory. Contents: {label_encoder_paths}')
    label_encoder = LabelEncoder.load(label_encoder_paths[0])

    data_splits = {}
    for subset in subsets:
        # NOTE(review): `export_dataset_to_csv` writes the dataframe index with
        # `to_csv`'s defaults, so reading without `index_col=0` adds an
        # "Unnamed: 0" column -- confirm `CommonDataset.from_dataframe` tolerates it.
        sample_df = pd.read_csv(input_paths["tables"][subset])
        config = load_config(input_paths["configs"][subset])
        data_splits[subset] = CommonDataset.from_dataframe(sample_df,
                                                           config=config)
        data_splits[subset].label_encoder = label_encoder

    return data_splits

In [None]:
# train_d = data_splits['train']


# save_config(config=config, config_path=subset_config_path)
# loaded = load_config(config_path=subset_config_path)

# pp(OmegaConf.to_container(config, resolve=True))

# pp(OmegaConf.to_container(loaded))

# pp(data_paths, config_paths, label_encoder_paths)

In [None]:
# Class-name replacement handed to the encoder: "Nothofagaceae" -> "Fagaceae".
replace = {"Nothofagaceae": "Fagaceae"}
label_encoder = LabelEncoder(replace=replace) # class2idx)
# NOTE(review): `fit` is called on the test targets and then on the train
# targets; whether the second call extends or replaces the first depends on
# LabelEncoder.fit, which is not visible here -- confirm.
label_encoder.fit(data_splits["test"].targets)
label_encoder.fit(data_splits["train"].targets)
# Share the single encoder instance across all splits.
for d in list(data_splits.values()):
    d.label_encoder = label_encoder
    
# Display the encoder as the cell output.
label_encoder

In [None]:
# NOTE(review): this path names the 512 variant, while the dataset names set
# earlier in the notebook are the 1024 variants -- confirm the intended target.
output_dir = "/media/data/jacob/GitHub/prj_fossils_contrastive/notebooks/Extant_family_10_512_minus_PNAS_family_100_512"

# Round trip: export the splits, load them back from the same directory, and
# print each reloaded split.
export_dataset_to_csv(data_splits=data_splits,
                          label_encoder=label_encoder,
                          output_dir=output_dir)

loaded_data_splits = import_dataset_from_csv(data_catalog_dir = output_dir)

# `k` and `v` remain bound to the last split after this loop; a later cell reads them.
for k,v in loaded_data_splits.items():
    print(k, repr(v))

In [None]:
import matplotlib.pyplot as plt
# Bare expression: not the last statement of the cell, so its value is not displayed.
plt.style.available

# plt.style.use("seaborn-notebook")
# NOTE(review): newer matplotlib releases renamed the seaborn styles to
# "seaborn-v0_8-*"; confirm this name still resolves in the target environment.
plt.style.use("seaborn-white")

In [None]:
# self = data_splits['train']
# import torch
# import numpy as np
# # indices = [0,1, 2,3,4,5]
# indices = np.array(indices)
# batch = [self[idx] for idx in indices]

# # y = [batch[idx][1] for idx in indices]
# y = torch.Tensor(np.array([(item[1]) for item in batch])).to(int)
# y

# batch = (torch.stack([item[0] for item in batch]),
#          torch.stack([torch.Tensor(item[1]) for item in batch]))
# return batch

In [None]:


# Inline version of the body of `export_image_data_diagnostics`, run directly
# against the notebook's `data_splits` and `output_dir`.
image_dir = os.path.join(output_dir, "images")
plot_dir = os.path.join(output_dir, "plots")
os.makedirs(image_dir, exist_ok = True)
os.makedirs(plot_dir, exist_ok = True)


# Collects the path of every figure written below.
image_paths = {"images": {},
               "class_distribution_plots":{}}

max_samples = 64
subsets = ['train', 'val', 'test']
# One sample-image figure per split; `plt.savefig` saves the current pyplot figure.
for subset in subsets:    
    fig, ax = data_splits[subset].show_batch(indices=max_samples, include_colorbar=False,
                                             suptitle = f"subset: {subset}, {max_samples} random images")
    img_path = os.path.join(image_dir, f"subset: {subset}, {max_samples} random images.jpg")
    image_paths["images"][subset] = img_path
    plt.savefig(img_path)

# One class-distribution figure covering all splits.
fig, ax = plot_split_distributions(data_splits=data_splits)
# The file name embeds the list repr of the split keys and has no extension.
class_distribution_plot_path = os.path.join(plot_dir, f"class_distribution_plots_{[subset for subset in data_splits.keys()]}")
image_paths["class_distribution_plots"]["all"] = class_distribution_plot_path

plt.savefig(class_distribution_plot_path)


# dir(ax)
# cb=plt.colorbar()
# cb.remove()
# plt.draw()

# plt.subplots_adjust(left=None, bottom=0.0, right=None, top=0.95, wspace=None, hspace=None)

In [None]:
# `v` and `k` are not defined in this cell; they are the loop variables left
# over from the earlier `for k,v in loaded_data_splits.items()` cell.
v.display_grid(indices=64,
               label_font_size="medium")
plt.suptitle(k, fontsize="large")
rows = 5
# Vertical spacing between subplots is scaled by `rows`.
plt.subplots_adjust(left=None, bottom=0.0, right=None, top=0.9, wspace=None, hspace=0.05*rows) #wspace=0.05, hspace=0.1)

import torch

# Bare expression with no effect (not the last statement of the cell).
torch.stack

# Bare expression with no effect: the set of PNAS family names is built and discarded.
set(pnas_df.family) #- set(label_encoder.classes[:20])

import seaborn as sns
import matplotlib.pyplot as plt

plt.style.use('fivethirtyeight')
sns_context = "talk"
# `sns_style` is only referenced by the commented-out `plt.style.use(sns_style)` below.
sns_style = "seaborn-bright"
sns.set_context(context=sns_context, font_scale=0.8)

sns.set_palette("Accent")
# valid contexts = paper, notebook, talk, poster - 
# with notebook being 1:1 and paper being smaller and poster being largest
# sns.set_style('darkgrid')
# sns.set_palette('Set2')

# plt.style.use(sns_style)
# Plot the class distribution of the three splits built earlier in the notebook.
fig, ax = plot_split_distributions(data_splits= {'train':train_dataset_extant_minus_pnas,
                                                 'val': val_dataset_extant_minus_pnas,
                                                 'test':test_dataset_extant_in_pnas})

# for label in ax[1].xaxis.get_ticklabels()[::2]:
#     label.set_visible(False)

In [None]:
import numpy as np


# Unique family names in the training dataframe; set iteration order is not
# guaranteed, so the column order can differ between runs.
classes = list(set(list(train_df['family'])))#[:100]
num_classes = len(classes)
# Square matrix of random values with one column per class (unseeded placeholder data).
df = pd.DataFrame(np.random.random((num_classes,num_classes)), columns=classes)

In [None]:
!where latex

In [None]:
# plt.style.use('dark_background')
plt.style.use('fivethirtyeight')

# Render the random class-by-class matrix `df` as an image.
plt.figure(figsize = (15,15))
plt.imshow(df.values, cmap="BrBG")


# Only referenced by the commented-out tick-label code further down.
label_format = '{:,.0f}'

# nothing done to ax1 as it is a "control chart."
ax = plt.gca()


# Only referenced by the commented-out tick-locator code further down.
import matplotlib.ticker as mticker

# fixing yticks with "set_yticks"
# ticks_loc = ax.get_yticks().tolist()
# ax.set_yticklabels([label_format.format(x) for x in ticks_loc])

# # fixing yticks with matplotlib.ticker "FixedLocator"
# ticks_loc = ax3.get_yticks().tolist()
# ax3.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
# ax3.set_yticklabels([label_format.format(x) for x in ticks_loc])

# # fixing xticks with FixedLocator but also using MaxNLocator to avoid cramped x-labels
# ax.xaxis.set_major_locator(mticker.MaxNLocator(75))
# ticks_loc = ax.get_xticks().tolist()
# ax.xaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
# ax.set_xticklabels([label_format.format(x) for x in ticks_loc])





# ax = plt.gca()

# ax.set_xticklabels(classes)
# plt.xticks(
# rotation=90, #45, 
# horizontalalignment='right',
# fontweight='light',
# fontsize='small'
# )


# plot a heatmap with annotation
# sns.heatmap(df, annot=True, annot_kws={"size": 7})

In [None]:
plt.style.available

In [None]:
assert pnas_df.shape[0] == 5311, "Expected full PNAS_family_100 dataset to have 5311 samples"

assert pnas_minus_extant.shape[0] == 2518
assert pnas_in_extant.shape[0] == 2793

assert pnas_in_extant.shape[0] == extant_in_pnas.shape[0]

assert pnas_in_extant.merge(pnas_minus_extant, on="catalog_number", how="inner").shape[0] == 0

In [None]:
pnas_df.shape[0]

In [None]:
import numpy as np

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path

# Show the first matched specimen side by side, using the image paths stored in
# the `path_extant` and `path_pnas` columns of `extant_in_pnas`.
fig, ax = plt.subplots(1, 2, figsize=(24, 12))
first_match = extant_in_pnas.iloc[0]
panels = (("path_extant", "Extant (No-Crop)"),
          ("path_pnas", "PNAS (Cropped)"))
for axis, (path_col, title) in zip(ax, panels):
    # The context manager closes the file handle once the grayscale copy exists.
    with Image.open(first_match[path_col]) as img:
        axis.imshow(img.convert("L"), cmap="gray")
    axis.set_title(title)
plt.suptitle(Path(first_match.path_extant).stem)
# extant_in_pnas.iloc[:1,:].apply(lambda x: print(type(x)), axis=1)

In [None]:
2793+22704

In [None]:
# In order to only keep original columns:
#   1. drop every column carrying the second suffix ("_pnas"),
#   2. strip the first suffix ("_extant") from the remaining column names.
suffixes = ("_extant", "_pnas")
keep_suffix, drop_suffix = suffixes
extant_in_pnas = extant_in_pnas.drop(
    columns=[c for c in extant_in_pnas.columns if c.endswith(drop_suffix)]
)
# Only a *trailing* suffix is removed; splitting on the suffix would also truncate
# names that merely contain it somewhere in the middle.
extant_in_pnas = extant_in_pnas.rename(
    columns={c: c[:-len(keep_suffix)] for c in extant_in_pnas.columns if c.endswith(keep_suffix)}
)

In [None]:
# data_df = extant_df
# other_df = pnas_df
# id_col = "catalog_number"


data_df.sort_values("catalog_number")
intersected = data_df.merge(other_df, on=id_col, how='inner').sort_values(id_col)

In [None]:
intersected

In [None]:
other_df.sort_values("catalog_number")

In [None]:
# extant_minus_pnas

pnas_minus_extant
pnas_df

In [None]:
data_df = pd.DataFrame(self.samples)
data_df = data_df.convert_dtypes()

other_df = pd.DataFrame(other.samples)
other_df = other_df.convert_dtypes()

In [None]:
data_df.describe(include='all')

In [None]:
other_df.describe(include='all')

In [None]:
# class CommonDataArithmetic(CommonDataset):
    
    
#     @property
#     def samples_df(self):        
#         data_df = pd.DataFrame(self.samples)
#         data_df = data_df.convert_dtypes()
#         return data_df

    
# other_df = pd.DataFrame(other.samples)
# other_df = other_df.convert_dtypes()
    
    
#     def intersection(self, other):
#         samples_df = self.samples_df
#         other_df = other.samples_df
        
#         intersection = data_df.merge(other_df, how='inner', on=self.id_col)
#         return intersection

#     def __sub__(self, other)
    
#         intersection = self.intersection(other)
#         samples_df = self.samples_df
        
#         remainder = samples_df[samples_df[self.id_col].apply(lambda x: x not in intersection[self.id_col])]
        
#         return remainder
    
    
    
    
    
data_df = pd.DataFrame(self.samples)
data_df = data_df.convert_dtypes()

other_df = pd.DataFrame(other.samples)
other_df = other_df.convert_dtypes()
        
        
#         init_params = self.init_params
#         init_params["files"] = data_df.iloc[:,0]

In [None]:
int.__sub__

In [None]:
pp(OmegaConf.to_container(self.config, resolve=True))

In [None]:
union_data = data_df.merge(other_df, how='inner', on=self.id_col)

d

In [None]:
concat_data = extant_train + pnas_train
concat_data

In [None]:
dir(concat_data)

In [None]:
concat_data.datasets

In [None]:
from typing import *
import collections
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context(context='talk', font_scale=0.8)
# sns.set_style("whitegrid")

In [None]:
def compute_class_counts(targets: Sequence,
                         sort_by: Optional[Union[str, bool, Sequence]]="count"
                        ) -> Dict[str, int]:
    """
    Count how many times each label occurs in `targets`.

    Args:
        targets: Iterable of hashable class labels.
        sort_by: Controls the key order of the returned mapping.
            - "count": most frequent label first.
            - True: labels sorted in descending label order.
            - an iterable of labels (list, tuple, dict, dict keys view, ...):
              exactly those labels, in that order; labels that never occur in
              `targets` are reported with a count of 0.
            - None, False, or any other value: first-seen order, unsorted.

    Returns:
        Mapping of label -> count (a `collections.Counter` when no re-ordering
        is applied, a plain dict otherwise).
    """
    counts = collections.Counter(targets)

    # Strings are checked first so a label-order iterable is never compared to "count"
    # (and so a string is never mistaken for an iterable of single-character labels).
    if isinstance(sort_by, str):
        if sort_by == "count":
            return dict(sorted(counts.items(), key=lambda item: item[1], reverse=True))
        return counts

    if sort_by is True:
        return dict(sorted(counts.items(), key=lambda item: item[0], reverse=True))

    if sort_by is None or sort_by is False or not hasattr(sort_by, "__iter__"):
        return counts

    # Any other iterable is an explicit label order. Iterating a dict yields its keys,
    # so a previously computed counts mapping can be passed directly.
    return {label: counts[label] for label in sort_by}

def plot_class_distributions(targets: List[Any], 
                             sort_by: Optional[Union[str, bool, Sequence]]="count",
                             ax=None,
                             xticklabels: bool=True):
    """
    Draw a per-class count histogram of `targets` and return the counts.

    Args:
        targets: Class labels to count.
        sort_by: Forwarded unchanged to `compute_class_counts`; controls bar order.
        ax: Axes to draw on. When None, a new 16x12 figure is created first.
        xticklabels: If True, rotate and right-align the x tick labels;
            if False, blank them out.

    Returns:
        The label -> count mapping produced by `compute_class_counts`.

    Example:
        counts = plot_class_distributions(targets=data.targets, sort_by="count")
    """
    counts = compute_class_counts(targets, sort_by=sort_by)

    labels = list(counts.keys())
    frequencies = list(counts.values())

    if ax is None:
        plt.figure(figsize=(16,12))
    # One bar per label, weighted by its pre-computed count.
    ax = sns.histplot(x=labels, weights=frequencies, discrete=True, ax=ax)
    # Make `ax` current so the pyplot-level xticks call below targets it.
    plt.sca(ax)
    if xticklabels:
        plt.xticks(
            rotation=45, 
            horizontalalignment='right',
            fontweight='light',
            fontsize='medium'
        )
    else:
        ax.set_xticklabels([])

    return counts


def plot_split_distributions(data_splits: Dict[str, CommonDataset]):
    """
    Create one count plot per data split (e.g. train, val, test) showing its class label distribution.

    Up to three splits are stacked vertically in a single column; four or more
    are laid out on a two-column grid. When exactly one key contains "train",
    every panel uses that split's classes, ordered by descending train count.
    Otherwise each panel is sorted by its own counts.

    Args:
        data_splits: Mapping of split name -> dataset. Each dataset must expose
            `.targets` and support `len()`.

    Returns:
        (fig, ax) where `ax` is the flattened array of all created axes.
    """
    assert isinstance(data_splits, dict)
    num_splits = len(data_splits)

    if num_splits < 4:
        rows, cols = num_splits, 1
    else:
        # Two columns, with enough rows to hold every split.
        cols = 2
        rows = (num_splits + 1) // 2
    # squeeze=False always yields a 2-D array, so flatten() also works for one split.
    fig, ax = plt.subplots(rows, cols, figsize=(16*cols,8*rows), squeeze=False)
    ax = ax.flatten()

    train_key = [k for k in data_splits if "train" in k]
    if len(train_key) == 1:
        train_counts = compute_class_counts(data_splits[train_key[0]].targets,
                                            sort_by="count")
        # Pass the order as a plain list of labels.
        class_order = list(train_counts.keys())
    else:
        # No unambiguous reference split: sort each panel by its own counts.
        class_order = "count"

    num_samples = 0
    last_index = num_splits - 1
    for i, (k, v) in enumerate(data_splits.items()):
        # Only the final panel shows x tick labels.
        plot_class_distributions(targets=v.targets,
                                 sort_by=class_order,
                                 ax=ax[i],
                                 xticklabels=(i == last_index))
        ax[i].set_title(f"{k} (n={len(v)})", fontsize='large')
        num_samples += len(v)

    # Hide the spare axes left over when an odd number of splits is on the grid.
    for spare_ax in ax[num_splits:]:
        spare_ax.set_visible(False)

    plt.suptitle('-'.join(list(data_splits.keys())) + f"_splits (total={num_samples})", fontsize='x-large')
    plt.subplots_adjust(bottom=0.1, top=0.95, wspace=None, hspace=0.07)

    return fig, ax

In [None]:
from lightning_hydra_classifiers.data.common import plot_split_distributions, compute_class_counts


data_splits = {"train": data.train_dataset,
               "val": data.val_dataset,
               "test": data.test_dataset}

# plot_split_distributions(data_splits=data_splits)

In [None]:
# for k in list(data_splits.keys()):
#     data_splits[k] = pd.DataFrame([data_splits[k]]).assign(split = k)
    
# target_splits = pd.concat(list(data_splits.values()))
# target_splits.reset_index().describe(include='all')
import numpy as np    


# y_col = "target"
y_col = "family"
target_splits = pd.concat([pd.DataFrame(v.targets).assign(split = k) for k, v in data_splits.items()]).rename(columns={0:y_col})
target_splits.reset_index().describe(include='all')

# pd.DataFrame(target_splits.groupby("family"))

# pd.DataFrame(target_splits.groupby("split"))

pd.DataFrame(target_splits.groupby("split")["family"])#.agg([len]))

import seaborn as sns

sns.countplot(data=target_splits,
              x="family",
              hue="split")

In [None]:
from dataclasses import dataclass, field
import dataclasses
from omegaconf import DictConfig, OmegaConf
from rich import print as pp

In [None]:
ax = None
xticklabels = True

sns.set_style('darkgrid')
sns.set_palette('Set2')

In [None]:
# PLOT_TYPES = DictConfig({
#     "unstacked_grouped_countplot": {"multiple":"dodge",
#                                     "stat":"count", "kde":True, "shrink":0.95, "binwidth":1.5},
#     "stacked_filled_grouped_histplot": {"multiple":"fill",
#                                         "stat":"probability", "shrink":0.95, "binwidth":0.6},
#     "stacked_grouped_countplot": {"multiple":"stack",
#                                   "stat":"count", "shrink":0.9, "binwidth":1.5}
#     })
        
# kwargs = PLOT_TYPES
# pp(dict(kwargs))
# pp(OmegaConf.to_container(cfg, resolve=True))

In [None]:
from pathlib import Path



def plot_grouped_class_distributions(data: pd.DataFrame,
                                     x_col: str="family",
                                     group_col: Optional[str]=None,
                                     suptitle: Optional[str]=None,
                                     savefig: Optional[str]=None,
                                     single_fig_plot: Optional[bool]=True,
                                     log_dir: Union[Path, str]=".",
                                     height = 13,
                                     width = 25,
                                     kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]=None):
    """
    Plot one or more grouped histograms of `data[x_col]`, coloured by `group_col`.

    Args:
        data: Frame holding one row per sample.
        x_col: Column with the class label plotted on the x axis.
        group_col: Column used as the `hue` grouping (None for no grouping).
        suptitle: Figure title when `single_fig_plot` is True; also used as the
            output file stem when `savefig` is not given.
        savefig: Output path for the combined figure (single-figure mode only).
        single_fig_plot: If True, stack all plots as rows of one figure;
            otherwise create one figure per plot option.
        log_dir: Directory used when saving under the `suptitle` name.
        height, width: Per-row figure size in single-figure mode.
        kwargs: One plot option or a list of them. Each option is a dict with
            optional keys "kwargs" (passed to `sns.histplot`), "title", and
            "savefig" (per-figure output path, multi-figure mode only).

    Returns:
        (fig, axes). In multi-figure mode `fig` is the last figure created.
    """
    if isinstance(kwargs, dict):
        kwargs = [kwargs]
    elif kwargs is None:
        kwargs = [{"kwargs":{}}]

    default_kwargs = {"shrink":0.9, "binwidth":3.0}
    fig = None
    axes = []

    if single_fig_plot:
        rows = len(kwargs); cols = 1
        figsize = (width*cols, height*rows)
        # squeeze=False keeps `axes` an array even when there is a single row.
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        axes = axes.flatten()

    # Tick labels become unreadable with many classes; only show them below 100.
    show_xticklabels = bool(data[x_col].nunique() < 100)

    for i, option in enumerate(kwargs):
        # Build a fresh dict per plot so one option's settings never leak into the next.
        plot_kwargs = {**default_kwargs, **option.get("kwargs", {})}

        if single_fig_plot:
            ax = axes[i]
            if "title" in option:
                ax.set_title(option["title"], fontsize="large")
            plt.subplots_adjust(bottom=0.05, top=0.96, wspace=None, hspace=0.25)
        else:
            fig, ax = plt.subplots(1, 1, figsize=(20,12))
            axes.append(ax)
            if "title" in option:
                plt.suptitle(option["title"], fontsize="large")
            plt.subplots_adjust(bottom=0.15, top=0.95, wspace=None, hspace=0.3)

        ax = sns.histplot(data=data,
                          x=x_col,
                          hue=group_col,
                          ax=ax,
                          pthresh=0.1,
                          **plot_kwargs)

        plt.sca(ax)
        sns.despine()
        if show_xticklabels:
            plt.xticks(
                rotation=45, 
                horizontalalignment='right',
                fontweight='light',
                fontsize='small'
            )
        else:
            ax.set_xticklabels([])

        if not single_fig_plot and "savefig" in option:
            plt.savefig(option["savefig"])

    if single_fig_plot:
        if suptitle is not None:
            plt.suptitle(suptitle, fontsize="x-large")
        if isinstance(savefig, (Path, str)):
            print(f'Saving: savefig={savefig}')
            plt.savefig(savefig)
        elif isinstance(suptitle, str):
            print(f'Saving: suptitle={suptitle}')
            plt.savefig(os.path.join(log_dir, f"{suptitle}.png"))
    return fig, axes

#         if "title" in kwargs[i]:
#             ax.set_title(kwargs[i]["title"], fontsize="large")        
#     plt.suptitle(suptitle)    
#     plt.subplots_adjust(bottom=0.1, top=0.95, wspace=None, hspace=0.3)

## Latest data distribution plots -- July 18th, 2021

In [None]:
# Build one config per dataset, then save grouped class-distribution plots for
# the train/val/test subsets of each (one figure per plot style plus a combined figure).
configs = [initialize_config(config_dir=config_dir,
                             overrides=overrides)
           for overrides in (["dataset=pnas_dataset"],
                             ["dataset=extant_dataset"],
                             ["dataset=fossil_dataset", "hparams.image_size=1024"])
          ]

logdir = "/media/data/jacob/GitHub/lightning-hydra-classifiers/outputs/data_distribution_logs"
os.makedirs(logdir, exist_ok=True)


for cfg in configs:

    data = LeavesLightningDataModule(cfg.datamodule.config)
    data_splits = {"train": data.train_dataset,
                   "val": data.val_dataset,
                   "test": data.test_dataset}

    dataset_name = cfg.dataset.config.name
    label_col = cfg.dataset.config.class_type
    group_col = "subset"

    kwargs_options = [{"kwargs":{"multiple":"dodge", "stat":"count", "kde":True, "shrink":0.9, "binwidth":2*1.5},
                       "title":f"{dataset_name}, Per-class countplot, grouped by subset",
                       "savefig":os.path.join(logdir, f"{dataset_name}_{label_col}_unstacked_subset_count_distributions.png")},
                      {"kwargs":{"multiple":"fill", "stat":"probability", "shrink":0.95, "binwidth":2*0.6},
                       "title":f"{dataset_name}, Per-class filled histograms, grouped by subset",
                       "savefig":os.path.join(logdir, f"{dataset_name}_{label_col}_stacked_subset_prior_probabilities.png")},
                      {"kwargs":{"multiple":"stack", "stat":"count", "shrink":0.9, "binwidth":2*1.5},
                       "title":f"{dataset_name}, Per-class countplot, grouped by subset",
                       "savefig":os.path.join(logdir, f"{dataset_name}_{label_col}_stacked_subset_count_distributions.png")}
                     ]

    # One row per sample: its label plus the name of the subset it belongs to.
    target_splits = pd.concat([pd.DataFrame(v.targets).assign(**{group_col:k}) for k, v in data_splits.items()
                              ]).rename(columns={0:label_col})

    ### Sort classes by count
    data_df = target_splits
    counts = compute_class_counts(targets=data_df[label_col],
                                  sort_by="count")
    class_order = {label: rank for rank, label in enumerate(counts.keys())}

    # Rank rows through `label_col` (not a hard-coded "family" attribute) so the
    # sort also works when the configured class type is a different column.
    target_splits = (data_df.assign(_class_rank=data_df[label_col].map(class_order))
                            .sort_values(by=["_class_rank"], ascending=True)
                            .drop(columns=["_class_rank"]))

    # One standalone figure per plot style, each saved to its own "savefig" path.
    plot_grouped_class_distributions(data=target_splits,
                                     x_col=label_col,
                                     group_col=group_col,
                                     single_fig_plot=False,
                                     log_dir=logdir,
                                     kwargs=kwargs_options)

    # All plot styles stacked in one figure, saved under the suptitle name in `logdir`.
    plot_grouped_class_distributions(data=target_splits,
                                     x_col=label_col,
                                     group_col=group_col,
                                     single_fig_plot=True,
                                     suptitle=f"Dataset={dataset_name} {label_col} class distributions",
                                     log_dir=logdir,
                                     kwargs=kwargs_options)

## End

In [None]:
# Inline version of the per-split class distribution plot: one panel per entry
# of `data_splits`, all sharing the train split's class order when available.
plt.style.use('fivethirtyeight')
assert isinstance(data_splits, dict)
num_splits = len(data_splits)

if num_splits < 4:
    rows, cols = num_splits, 1
else:
    # Two columns, with enough rows to hold every split.
    rows, cols = (num_splits + 1) // 2, 2
# squeeze=False always yields a 2-D array, so flatten() also works for one split.
fig, ax = plt.subplots(rows, cols, figsize=(16*cols,8*rows), squeeze=False)
ax = ax.flatten()


train_key = [k for k in data_splits if "train" in k]
if len(train_key) == 1:
    train_counts = compute_class_counts(data_splits[train_key[0]].targets,
                                        sort_by="count")
    # Pass the order as a plain list of labels.
    class_order = list(train_counts.keys())
else:
    # No unambiguous reference split: sort each panel by its own counts.
    class_order = "count"

num_samples = 0
counts = {}
for i, (k, v) in enumerate(data_splits.items()):
    # Only the final panel shows x tick labels.
    xticklabels = (i == num_splits - 1)
    counts[k] = plot_class_distributions(targets=v.targets, 
                                         sort_by=class_order,
                                         ax=ax[i],
                                         xticklabels=xticklabels)
    ax[i].set_title(f"{k} (n={len(v)})", fontsize='large')

    num_samples += len(v)

# Hide the spare axes left over when an odd number of splits is on the grid.
for spare_ax in ax[num_splits:]:
    spare_ax.set_visible(False)

plt.suptitle('-'.join(list(data_splits.keys())) + f"_splits (total={num_samples})", fontsize='x-large')
plt.subplots_adjust(bottom=0.1, top=0.95, wspace=None, hspace=0.07)


##########################



# Stand-alone matplotlib demo: grouped bar chart of random summer/winter
# temperatures for five cities, with each bar annotated by its value.
import random
plt.style.use('fivethirtyeight')

width = 0.5

temp_summer = [random.uniform(20, 40) for _ in range(5)]
temp_winter = [random.uniform(0, 10) for _ in range(5)]

fig = plt.figure(figsize=(10,6))

city = ['City A','City B','City C','City D','City E']
x_pos_summer = list(range(1, 6))
# Winter bars sit immediately to the right of the matching summer bars.
x_pos_winter = [x + width for x in x_pos_summer]

graph_summer = plt.bar(x_pos_summer, temp_summer, color='tomato', label='Summer', width=width)
graph_winter = plt.bar(x_pos_winter, temp_winter, color='dodgerblue', label='Winter', width=width)

# Centre each city label between its two bars.
plt.xticks([x + width/2 for x in x_pos_summer], city)
plt.title('City Temperature')
# Raw strings: "\c" is not a valid escape sequence in a normal string literal.
plt.ylabel(r'Temperature ($^\circ$C)')

#Annotating graphs
for summer_bar, winter_bar, ts, tw in zip(graph_summer, graph_winter, temp_summer, temp_winter):
    plt.text(summer_bar.get_x() + summer_bar.get_width()/2.0, summer_bar.get_height(), rf'{ts:.2f}$^\circ$C', ha='center', va='bottom')
    plt.text(winter_bar.get_x() + winter_bar.get_width()/2.0, winter_bar.get_height(), rf'{tw:.2f}$^\circ$C', ha='center', va='bottom')

plt.legend()  
plt.show()

In [None]:
display(data.train_dataset)
display(data.val_dataset)
display(data.test_dataset)

In [None]:
train_dataloader = data.train_dataloader()

train_dataloader

In [None]:
batch = next(iter(train_dataloader))

In [None]:
import pandas as pd


dir(pd.DataFrame)

In [None]:
pd.DataFrame({"0":[0,1,2,3,4], "1":[0,1,2,3,4]}).T.to_records()[0]

In [None]:
len(batch)

In [None]:
data.dataset.plot_trainvaltest_splits(data.train_dataset,
                                     data.val_dataset,
                                     data.test_dataset)

In [None]:
display(data.dataset)

In [None]:
import torchvision
from torchvision import transforms

torchvision.transforms.ToPILImage()(data.train_dataset[387][0])

In [None]:
# import re
# from pathlib import Path
# path_schema: str = "{family}_{genus}_{species}_{collection}_{catalog_number}"
# # path_schema: str = "{family}_{genus}_{species}_{catalog_number}"

# # # path_schema = Path("/media/data_cifs/projects/prj_fossils/data/processed_data/data_splits/PNAS_family_100_1536/train/Fabaceae")
# # filepath = 'Fabaceae_Derris_alborubra_Wolfe_9829.jpg'
# sep = "_"

# # from dataclasses import dataclass
# # from typing import *


# # @dataclass 
# # class PathSchema:
# #     path_schema: str = Path("{family}_{genus}_{species}_{collection}_{catalog_number}")
# #     schema_parts: List[str] = path_schema.split(sep)
# #     maxsplit = len(schema_parts) - 2
    
# #     def parse(self, path: Union[Path, str], sep: str="_"):
    
# #         parts = Path(path).stem.split(sep, maxsplit=maxsplit)
# #         if len(parts) == 5:
# #             family, genus, species, collection, catalog_number = parts
# #         if len(parts) == 4:
# #             family, genus, species, catalog_number = parts
# #             collection = catalog_number.split("_")

# #         return family, genus, species, collection, catalog_number

In [None]:
filepath = Path("/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v0_3/Extant_Leaves/Aizoaceae/Aizoaceae_Galenia_pubescens_Hickey_Hickey_8097.jpg").stem

path_schema: str = "{family}_{genus}_{species}_{collection}_{catalog_number}"
schema_parts = path_schema.split(sep)
maxsplit = len(schema_parts) - 2

print(f"schema_parts={schema_parts}")
print(f"maxsplit={maxsplit}")
family, genus, species, collection, catalog_number = Path(filepath).stem.split("_", maxsplit=maxsplit)
print(family, genus, species, collection, catalog_number)

In [None]:
filepath = Path("/media/data_cifs/projects/prj_fossils/data/processed_data/data_splits/PNAS_family_100_1536/train/Fabaceae/Fabaceae_Derris_alborubra_Wolfe_9829.jpg").stem

path_schema: str = "{family}_{genus}_{species}_{catalog_number}"
schema_parts = path_schema.split(sep)
maxsplit = len(schema_parts) - 2

print(f"schema_parts={schema_parts}")
print(f"maxsplit={maxsplit}")

family, genus, species, catalog_number = Path(filepath).stem.split("_", maxsplit=maxsplit)
print(family, genus, species, catalog_number)

In [None]:
toPIL = torchvision.transforms.ToPILImage("RGB")

In [None]:
data.show_batch()

In [None]:
data.show_batch(stage='val')

data.show_batch(stage='test')

In [None]:
train_dataset = data.train_dataset

train_dataloader = data.train_dataloader()

train_dataset = data.get_dataset("train")

train_dataset.show_batch()

# train_dataset = 
data.get_dataset("val")

# train_dataset = 
data.get_dataset("test")

In [None]:
data.dataset

train_dataloader

In [None]:
pp(OmegaConf.to_container(cfg.dataset, resolve=True))

pp(OmegaConf.to_container(cfg.datamodule.config.dataset, resolve=True))

pp(OmegaConf.to_container(cfg.datamodule, resolve=True))

In [None]:
dataset_name = "Fossil"
pp([k for k in CommonDataset.available_datasets.keys() if dataset_name in k])

dataset_name = "PNAS"
pp([k for k in CommonDataset.available_datasets.keys() if dataset_name in k])

dataset_name = "Extant"
pp([k for k in CommonDataset.available_datasets.keys() if dataset_name in k])

In [None]:
Extant_config = OmegaConf.create({"name": "Extant_family_10_512",
                            "val_split": 0.2,
                            "test_split": 0.3,
                            "threshold": 3,
                            "seed": 987485,
                            "class_type": "family",
                            "x_col":"path",
                            "y_col":"${.class_type}",
                            "id_col":"catalog_number"
})


config = OmegaConf.create({"name": "Fossil_512",
                            "val_split": 0.2,
                            "test_split": 0.3,
                            "threshold": 3,
                            "seed": 987485,
                            "class_type": "family",
                            "x_col":"path",
                            "y_col":"${.class_type}",
                            "id_col":"catalog_number"
})




pp(OmegaConf.to_container(config, resolve=True))
data = CommonDataset(config=config,
                     files=None,
                     class2idx=None)

data[1].image
print(data.__repr__())

data.label_encoder

In [None]:
# config_path = "/media/data/jacob/GitHub/lightning-hydra-classifiers/configs/datamodule/fossil_datamodule.yaml"
config_path = "/media/data/jacob/GitHub/lightning-hydra-classifiers/configs/multi-gpu.yaml"
config = OmegaConf.load(config_path)

In [None]:
pp(OmegaConf.to_container(config, resolve=True))

# Scratch

In [None]:
name = "Fossil_512"
dataset_dirs = CommonDataSelect.available_datasets[name]

dataset_dirs

In [None]:
# CommonDataset.available_datasets.keys()

# CommonDataset.available_datasets["Fossil_512"]#.keys()

# fossil.available_datasets

# d0 = CommonDataSelect.select_dataset_by_name("Fossil_512")

# dir(OmegaConf)

# config_path = "/media/data/jacob/GitHub/lightning-hydra-classifiers/configs/datamodule/fossil_datamodule.yaml"
# config = OmegaConf.load(config_path)

In [None]:
from pathlib import Path

# d1 = torchdata.datasets.Files.from_folder(Path(dataset_dirs[0]), regex="*/*.jpg")
d2 = torchdata.datasets.Files.from_folder(Path(dataset_dirs[1]), regex="*/*.jpg")
# d2

from itertools import repeat, chain
from more_itertools import collapse, flatten


cls = torchdata.datasets.Files

log.info(f"Concatenating dataset_dirs located at: {dataset_dirs}")
file_list = list(flatten(
                    [cls.from_folder(Path(root),
                                     regex="*/*.jpg").files
                     for root in dataset_dirs]
                                            ))
data = cls(files=file_list,
           name=name)


In [None]:
file_list

In [None]:
x_train, x_val, x_test = (split[0] for split in data_splits.values())
y_train, y_val, y_test = (split[1] for split in data_splits.values())



from rich import print as pp


pp(data_splits)

In [None]:
val_split = 0.2
test_split = 0.3

train_split = 1 - (val_split + test_split)

val_relative_split = val_split/(train_split + val_split)
train_relative_split = train_split/(train_split + val_split)
random_state = 0


x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_split, random_state=random_state, stratify=y)
print(f"x_train.shape={x_train.shape}, x_test.shape={x_test.shape}, y_train.shape={y_train.shape}, y_test.shape={y_test.shape}")


x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=val_relative_split, random_state=random_state, stratify=y_train)

print(f"x_train.shape={x_train.shape}, x_val.shape={x_val.shape}, y_train.shape={y_train.shape}, y_val.shape={y_val.shape}")


print(f'Absolute splits: {[train_split, val_split, test_split]}')
print(f'Relative splits: [{train_relative_split:.2f}, {val_relative_split:.2f}, {test_split}]')

print(f'train+val={train_split+val_split} | test={test_split}')
print(f'train={train_relative_split:.2f} | val={val_relative_split:.2f}')

In [None]:
skf = StratifiedKFold(n_splits=n_splits)
skf.get_n_splits(x, y)
print(skf)

In [None]:
for train_index, test_index in skf.split(x, y):
#     print("TRAIN:", train_index, "TEST:", test_index)
    print("TRAIN:", train_index.shape, "TEST:", test_index.shape)
    X_train, X_test = x[train_index], x[test_index]
    y_train, y_test = y[train_index], y[test_index]
    
    print(f'y_test: {y_test}')

In [None]:
class Base:
    """Shared root class for `torchtraining` objects.

    Supplies default `__str__` and `__repr__` implementations. Subclasses are
    expected to override `__str__` where they need something more specific;
    custom objects commonly rely on `yaml.dump` to inspect parameters and the
    whole pipeline.
    """

    def __str__(self) -> str:
        # "<module>.<ClassName>" of the concrete class.
        cls = type(self)
        return ".".join((cls.__module__, cls.__name__))

    def __repr__(self) -> str:
        # List every public instance attribute; names starting with "_" are skipped.
        public_attrs = [
            f"{name}={value}"
            for name, value in vars(self).items()
            if not name.startswith("_")
        ]
        return f"{self}({', '.join(public_attrs)})"

In [None]:
!pip list | grep fast

In [None]:
print(b)

In [None]:
import torchdata
torchdata.datasets.Files

In [None]:
import torchdata
dir(torchdata.datasets)

In [None]:
threshold=3
test_split=0.3
val_train_split=0.2
batch_size=32
num_workers=0
seed=8567
debug=False
normalize=True
image_size = 'auto'
channels=3
dataset_dir=None
predict_on_split="val"

print(dataset_name)

In [None]:
datamodule = FossilLightningDataModule(name=dataset_name,
                                       threshold=threshold,
                                       test_split=test_split,
                                       val_train_split=val_train_split,
                                       batch_size=batch_size,
                                       num_workers=num_workers,
                                       seed=seed,
                                       debug=debug,
                                       normalize=normalize,
                                       image_size=image_size,
                                       channels=channels,
                                       predict_on_split=predict_on_split)

datamodule

datamodule.setup("fit")
datamodule.setup("test")

# datamodule.show_batch("train")
# datamodule.show_batch("val")
# datamodule.show_batch("test")

In [None]:
ckpt_dir = "/media/data_cifs/projects/prj_fossils/users/jacob/experiments/July2021-Nov2021/Fossil_512_train-test/2021-07-12/06-03/model/checkpoints/"
ckpt_path = os.path.join(ckpt_dir, "best_model-epoch-epoch=05--val_loss-val_loss=96.52.ckpt")

print(os.path.isfile(ckpt_path))

In [None]:
ckpt_state = torch.load(ckpt_path)

print(type(ckpt_state))

In [None]:
print(ckpt_state.keys())

In [None]:
pl.__version__

In [None]:
print(ckpt_state['state_dict'].keys())

In [None]:
model = TransferLearningModel.load_from_checkpoint(ckpt_path)

In [None]:
train_data, val_data, test_data = data.create_trainvaltest_splits(dataset=data,
                                                                  test_split=0.3,
                                                                  val_train_split=0.2,
                                                                  shuffle=True,
                                                                  seed=3654,
                                                                  plot_distributions=True)

In [None]:

config_dir = "/media/data/jacob/GitHub/lightning-hydra-classifiers/configs"

from hydra.experimental import compose, initialize, initialize_config_dir
from omegaconf import OmegaConf
import os
from rich import print as pp
os.chdir(config_dir)

# context initialization
# with initialize(config_path="../configs", job_name="test_app"):

with initialize_config_dir(config_dir=config_dir, job_name="multi-gpu_experiment"):
    
    cfg = compose(config_name="multi-gpu")
#     print(OmegaConf.to_yaml(cfg))
    
    pp(OmegaConf.to_container(cfg, resolve=True))
    
    pp(os.environ)

In [None]:
plt.style.available

In [None]:

sns_context = "talk"
sns_style = "seaborn-bright"
sns.set_context(context=sns_context, font_scale=0.8)
# valid contexts = paper, notebook, talk, poster - 
# with notebook being 1:1 and paper being smaller and poster being largest
plt.style.use(sns_style)



In [None]:
# def plot_class_distributions(targets: List[Any], 
#                              sort: Union[bool,Sequence]=True,
#                              ax=None,
#                              xticklabels: bool=True):
#     """
#     Example:
#         counts = plot_class_distributions(targets=data.targets, sort=True)
#     """
#     counts = collections.Counter(targets)
#     if hasattr(sort, "__len__"):
#         counts = {k: counts[k] for k in sort}
#     elif sort is True:
#         counts = dict(sorted(counts.items(), key = lambda x:x[1], reverse=True))

#     keys = list(counts.keys())
#     values = list(counts.values())

#     if ax is None:
#         plt.figure(figsize=(16,12))
#     ax = sns.histplot(x=keys, weights=values, discrete=True, ax=ax)
#     plt.sca(ax)
#     if xticklabels:
#         plt.xticks(
#             rotation=45, 
#             horizontalalignment='right',
#             fontweight='light',
#             fontsize='medium'
#         )
#     else:
#         ax.set_xticklabels([])
    
#     return counts


# def plot_trainvaltest_splits(train_data,
#                              val_data,
#                              test_data):
#     """
#     Create 3 vertically-stacked count plots of train, val, and test dataset class label distributions
#     """
#     fig, ax = plt.subplots(3, 1, figsize=(16,8*3))

#     train_counts = plot_class_distributions(targets=train_data.targets, sort=True, ax = ax[0], xticklabels=False)
#     plt.gca().set_title(f"train (n={len(train_data)})", fontsize='large')
#     sort_classes = train_counts.keys()

#     val_counts = plot_class_distributions(targets=val_data.targets, ax = ax[1], sort=sort_classes, xticklabels=False)
#     plt.gca().set_title(f"val (n={len(val_data)})", fontsize='large')
#     test_counts = plot_class_distributions(targets=test_data.targets, ax = ax[2], sort=sort_classes)
#     plt.gca().set_title(f"test (n={len(test_data)})", fontsize='large')

#     plt.suptitle(f"Train-Val-Test_splits (total={len(data)})", fontsize='x-large')

#     plt.subplots_adjust(bottom=0.1, top=0.95, wspace=None, hspace=0.07)
    
#     return fig, ax



In [None]:
plot_trainvaltest_splits(train_data,
                         val_data,
                         test_data)

# End

In [None]:
train_data = select_from_dataset(data,
                                 indices=train_idx,
                                 update_class2idx=False,
                                 x_col = 'path',
                                 y_col = "family")

val_data = select_from_dataset(data,
                               indices=val_idx,
                               update_class2idx=False,
                               x_col = 'path',
                               y_col = "family")
val_data

test_data = select_from_dataset(data,
                                indices=test_idx,
                                update_class2idx=False,
                                x_col = 'path',
                                y_col = "family")



# The live `plot_class_distributions` takes `sort_by`, not the `sort` keyword of the
# earlier, now commented-out version. "count" gives most-frequent-first ordering.
train_counts = plot_class_distributions(targets=train_data.targets, sort_by="count")
# Reuse the train class order for val/test so the three plots are directly comparable.
sort_classes = list(train_counts.keys())
val_counts = plot_class_distributions(targets=val_data.targets, sort_by=sort_classes)
test_counts = plot_class_distributions(targets=test_data.targets, sort_by=sort_classes)


train_data = (train_val_samples)
train_samples = np.array(train_val_samples)[train_idx]
val_samples = np.array(train_val_samples)[val_idx]


counts = plot_class_distributions(targets=data.targets, sort_by="count")

train_val_idx.shape
test_idx.shape

In [None]:
class DataSplit:
    """
    Partition the indices of `dataset` into train/validation/test subsets and serve one DataLoader per subset.

    Arguments:
        dataset:
            Any object supporting `len()` and exposing a `targets` sequence with one label per sample.
        test_split: default=0.3
            Forwarded to `train_test_split` as `test_size` for the held-out test subset.
        val_train_split: default=0.2
            Fraction of the remaining (non-test) indices assigned to validation.
        shuffle (bool): default=False
            Shuffle before splitting. Class-stratified test splitting is only requested when this is True.
        seed (int): default=None
            Forwarded to `train_test_split` as `random_state`.

    Attributes:
        indices, train_indices, val_indices, test_indices:
            Index arrays into `dataset`.
        train_sampler, val_sampler, test_sampler:
            SubsetRandomSampler built over the matching index array.
    """

    def __init__(self,
                 dataset,
                 test_split=0.3,
                 val_train_split=0.2,
                 shuffle: bool=False,
                 seed: int=None):

        self.dataset = dataset

        dataset_size = len(dataset)
        # np.arange takes the size itself; passing a range object raises TypeError.
        self.indices = np.arange(dataset_size)

        # The indices are deliberately left in their natural order here: train_test_split stratifies
        # by position, so shuffling them beforehand would detach each index from its target.
        # scikit-learn also refuses `stratify` together with shuffle=False, hence the conditional.
        targets = np.asarray(dataset.targets) if shuffle else None

        train_val_indices, self.test_indices = train_test_split(
                                               self.indices,
                                               test_size=test_split,
                                               random_state=seed,
                                               shuffle=shuffle,
                                               stratify=targets)

        # The validation subset is the tail of the non-test indices (not stratified).
        train_size = len(train_val_indices)
        validation_split = int(np.floor((1 - val_train_split) * train_size))

        self.train_indices = train_val_indices[:validation_split]
        self.val_indices = train_val_indices[validation_split:]

        self.train_sampler = SubsetRandomSampler(self.train_indices)
        self.val_sampler = SubsetRandomSampler(self.val_indices)
        self.test_sampler = SubsetRandomSampler(self.test_indices)

        # Per-instance loader cache, keyed by (subset label, batch_size, num_workers).
        self._loaders = {}

    def get_train_split_point(self):
        """Return the combined number of train and validation samples."""
        return len(self.train_sampler) + len(self.val_indices)

    def get_validation_split_point(self):
        """Return the number of train samples."""
        return len(self.train_sampler)

    def _cached_loader(self, label, sampler, batch_size, num_workers):
        """Build the DataLoader for one subset, reusing an earlier one created with the same settings."""
        # setdefault keeps this usable on instances whose cache dict has not been created yet.
        cache = vars(self).setdefault("_loaders", {})
        key = (label, batch_size, num_workers)
        if key not in cache:
            logging.debug(f'Initializing {label} dataloader')
            cache[key] = torch.utils.data.DataLoader(self.dataset,
                                                     batch_size=batch_size,
                                                     sampler=sampler,
                                                     shuffle=False,
                                                     num_workers=num_workers)
        return cache[key]

    def get_split(self, batch_size=50, num_workers=4):
        """Return the (train, validation, test) DataLoaders as a tuple."""
        logging.debug('Initializing train-validation-test dataloaders')
        self.train_loader = self.get_train_loader(batch_size=batch_size, num_workers=num_workers)
        self.val_loader = self.get_validation_loader(batch_size=batch_size, num_workers=num_workers)
        self.test_loader = self.get_test_loader(batch_size=batch_size, num_workers=num_workers)
        return self.train_loader, self.val_loader, self.test_loader

    def get_train_loader(self, batch_size=50, num_workers=4):
        """Return the DataLoader that samples from the train indices."""
        self.train_loader = self._cached_loader('train', self.train_sampler, batch_size, num_workers)
        return self.train_loader

    def get_validation_loader(self, batch_size=50, num_workers=4):
        """Return the DataLoader that samples from the validation indices."""
        self.val_loader = self._cached_loader('validation', self.val_sampler, batch_size, num_workers)
        return self.val_loader

    def get_test_loader(self, batch_size=50, num_workers=4):
        """Return the DataLoader that samples from the test indices."""
        self.test_loader = self._cached_loader('test', self.test_sampler, batch_size, num_workers)
        return self.test_loader

In [None]:
df = pd.DataFrame(data.samples)#.iloc[:,0]
df = df.assign(sub_dataset = df.apply(lambda x: x[0].parts[-3], axis=1)) #.value_counts()

df = df.rename(columns={0:"path",
                        1:"family",
                        2:"genus",
                        3:"species",
                        4:"collection",
                        5:"catalog_number"})#.value_counts()



In [None]:
from sklearn.model_selection import train_test_split

targets = dataset.targets

train_idx, valid_idx = train_test_split(
                                        indices,
                                        test_size=test_split,
                                        random_state=seed,
                                        shuffle=True,
                                        stratify=targets)

print(np.unique(np.array(targets)[train_idx], return_counts=True))
print(np.unique(np.array(targets)[valid_idx], return_counts=True))


# val_split = 0.2
# test_split = 0.3
# total = 1.0
# trainval_split = total-test_split
# print(trainval_split)
# print(trainval_split - val_split)
# print((val_split/(trainval_split)))# - val_split)

(val_split*0.7)# + 0.7

In [None]:
# class FossilDatasetSubset(FossilDataset):
    
#     def __init__(self,
#                  split
#                  files: List[Path]=None,
#                  name: Optional[str]=None,
#                  return_items: List[str] = ["image","target","path"],
#                  image_return_type: str = "tensor",
#                  *args, **kwargs):
#                 ):

In [None]:
print('starting')

model = models.resnet18()
# inputs = torch.randn(5, 3, 224, 224)


In [None]:
from torch.utils.data import DataLoader
from torch import nn
batch_size = 64

dataloader = DataLoader(data,
                        batch_size=batch_size,
                        shuffle=False)
#                         sampler=None,
#                         batch_sampler=None,
#                         num_workers=0,
#                         collate_fn=None,
#                         pin_memory=False,
#                         drop_last=False,
#                         timeout=0,
#                         worker_init_fn=None)

# idx = [0,10,20,50,100]
# idx = 10
idx = list(range(0,1000,100))
print(len(idx))
data.display_grid(idx, repeat_n=1)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# data.samples[0][0].parts[-3]

df = pd.DataFrame(data.samples)#.iloc[:,0]
df = df.assign(sub_dataset = df.apply(lambda x: x[0].parts[-3], axis=1)) #.value_counts()

df = df.rename(columns={0:"path",
                        1:"family",
                        2:"genus",
                        3:"species",
                        4:"collection",
                        5:"catalog_number"})#.value_counts()

# df.value_counts().plot(kind='bar')

chart = sns.catplot(
    data=df, #[data['Year'].isin([1980, 2008])],
    x='family',
    kind='count',
    palette='Set1',
    row='sub_dataset',
    aspect=3,
    height=3
)
chart.set_xticklabels(rotation=65, horizontalalignment='right')

In [None]:
# Drop families with 10 or fewer samples, then count the remaining samples per
# (family, genus) pair and pivot genus into columns.
# (alternative grouping that was tried: ['genus', 'species'])
frequent_families = df.groupby('family').filter(lambda x : len(x) > 10)
samples_per_genus = frequent_families.groupby(['family', 'genus']).size()
by_sport = samples_per_genus.unstack()
by_sport

In [None]:
# def select_from_indices(data,
#                         indices: Sequence,
#                         update_class2idx: bool=False,
#                         x_col = 'path',
#                         y_col = "family") -> "FossilDataset":
#     """
#     Helper method to create a new FossilDataset containing only samples contained in `indices`
#     Useful for producing train/val/test splits
    
#     """
#     if update_class2idx:
#         class2idx=None
#     else:
#         class2idx=data.class2idx

    
#     df = pd.DataFrame(data.samples)
#     df = df.rename(columns={0:"path",
#                             1:"family",
#                             2:"genus",
#                             3:"species",
#                             4:"collection",
#                             5:"catalog_number"})#.value_counts()
    
#     df = df.iloc[indices,:]
    
#     files = df[x_col].to_list()

#     return FossilDataset(files=files,
#                          name=data.name,
#                          return_items=data.return_items,
#                          image_return_type=data.image_return_type,
#                          class2idx=class2idx)



# def filter_df_by_threshold(df: pd.DataFrame,
#                            threshold: int,
#                            y_col: str='family'):
#     """
#     Filter rare classes from dataset in a pd.DataFrame
    
#     Input:
#         df (pd.DataFrame):
#             Must contain at least 1 column with name given by `y_col`
#         threshold (int):
#             Exclude any rows from df that contain a `y_col` value with fewer than `threshold` members in all of df.
#         y_col (str): default="family"
#             The column in df to look for rare classes to exclude.
#     Output:
#         (pd.DataFrame):
#             Returns a dataframe with the same number of columns as df, and an equal or lower number of rows.
#     """
#     return df.groupby(y_col).filter(lambda x: len(x) >= threshold)



# def filter_samples_by_threshold(data: FossilDataset,
#                                 threshold: int=1,
#                                 update_class2idx: bool=True,
#                                 x_col = 'path',
#                                 y_col = "family") -> FossilDataset:
#     if update_class2idx:
#         class2idx=None
#     else:
#         class2idx=data.class2idx

        
#     df = pd.DataFrame(data.samples)
#     df = df.rename(columns={0:"path",
#                             1:"family",
#                             2:"genus",
#                             3:"species",
#                             4:"collection",
#                             5:"catalog_number"})#.value_counts()
    
#     df = filter_df_by_threshold(df=df,
#                                 threshold=threshold,
#                                 y_col=y_col)
        
#     files = df[x_col].to_list()

#     return FossilDataset(files=files,
#                          name=data.name,
#                          return_items=data.return_items,
#                          image_return_type=data.image_return_type,
#                          class2idx=class2idx)


In [None]:
# # dataset.targets

# @classmethod
# def create_trainvaltest_splits(cls,
#                                dataset,
#                                test_split: float=0.3,
#                                val_train_split: float=0.2,
#                                shuffle: bool=True,
#                                seed: int=3654):

#     dataset_size = len(dataset)
#     indices = np.arange(dataset_size)

#     samples = np.array(dataset.samples)
#     targets = np.array(dataset.targets)

#     train_val_idx, test_idx = train_test_split(
#                                                indices,
#                                                test_size=test_split,
#                                                random_state=seed,
#                                                shuffle=shuffle,
#                                                stratify=targets)

#     train_val_targets = targets[train_val_idx]

#     trainval_indices = np.arange(len(train_val_targets))
#     train_idx, val_idx = train_test_split(
#                                           trainval_indices,
#                                           test_size=val_train_split,
#                                           random_state=seed,
#                                           shuffle=shuffle,
#                                           stratify=train_val_targets)

#     train_data = data.select_from_indices(indices=train_idx,
#                                           update_class2idx=False,
#                                           x_col = 'path',
#                                           y_col = "family")


#     val_data = data.select_from_indices(indices=val_idx,
#                                         update_class2idx=False,
#                                         x_col = 'path',
#                                         y_col = "family")


#     test_data = data.select_from_indices(indices=test_idx,
#                                          update_class2idx=False,
#                                          x_col = 'path',
#                                          y_col = "family")


#     return train_data, val_data, test_data

In [None]:
plt.figure(figsize=(10,10))
g = sns.heatmap(
    by_sport, 
    square=True, # make cells square
    cbar_kws={'fraction' : 0.01}, # shrink colour bar
    cmap='OrRd', # use orange/red colour map
    linewidth=1 # space between cells
)

In [None]:
by_sport.shape

In [None]:
plt.figure(figsize=(10,10))
g = sns.heatmap(
    by_sport, 
    square=True,
    cbar_kws={'fraction' : 0.01},
    cmap='OrRd',
    linewidth=1
)

g.set_xticklabels(g.get_xticklabels(), rotation=45, horizontalalignment='right')
g.set_yticklabels(g.get_yticklabels(), rotation=45, horizontalalignment='right')

In [None]:
# NOTE(review): this looks like an unadapted tutorial snippet -- it filters on a 'Year' column
# and plots a 'Sport' column, neither of which appears in the dataframes built in this notebook.
# The catplot cell that uses `data=df`, `x='family'` and `row='sub_dataset'` appears to be the
# adapted version; confirm whether this cell is still needed.
chart = sns.catplot(
    data=data[data['Year'].isin([1980, 2008])],
    x='Sport',
    kind='count',
    palette='Set1',
    row='Year',
    aspect=3,
    height=3
)

In [None]:
import collections
count_dist = collections.Counter(data.targets)
# count_dist.update(data.targets)

import matplotlib.pyplot as plt
import seaborn as sns

# NOTE: an unfinished `def plot_class_distributions(targets: List[Any])` header (no colon, no body)
# used to sit here and made the whole cell a syntax error. The statements below are plain
# top-level script and were never part of that function.
test = count_dist #{1:1,2:1,3:1,4:2,5:3,6:5,7:4,8:2,9:1,10:1}
# with matplotlib
plt.hist(list(test.keys()), weights=list(test.values()))

# Re-bind `test` to (class, count) pairs ordered from the most to the least frequent class.
test = sorted(test.items(), key = lambda x:x[1], reverse=True)

sns.set_style("whitegrid")

keys = [name for name, _ in test]
values = [count for _, count in test]

plt.figure(figsize=(16,12))
chart = sns.histplot(x=keys, weights=values, discrete=True)
plt.xticks(
    rotation=45,
    horizontalalignment='right',
    fontweight='light',
    fontsize='x-large'
)








# chart.set_xticklabels(
#     chart.get_xticklabels(), 
#     rotation=45, 
#     horizontalalignment='right',
#     fontweight='light',
#     fontsize='x-large'
    
# )

# None 

In [None]:
# with seaborn (use hist_kws to send arguments to plt.hist, used underneath)
# `test` is first a Counter and is later re-bound to a sorted list of (class, count) pairs;
# dict() accepts either form, whereas calling .keys()/.values() on the list raises AttributeError.
weights_by_class = dict(test)
sns.distplot(range(len(weights_by_class)), hist_kws={"weights":list(weights_by_class.values())})

In [None]:
from itertools import islice

from pytorch_memlab import MemReporter


max_batches = 10

reporter = MemReporter()

# islice stops after exactly `max_batches` batches. The former `if i>=max_batches: break` check
# sat at the end of the loop body, so one extra batch was profiled before the loop exited.
for i, batch in enumerate(islice(dataloader, max_batches)):
    print(i, len(batch), batch[0].shape)

    print('========= before backward =========')
    reporter.report()
    out = model(batch[0])

    loss = nn.functional.cross_entropy(out, batch[1])
    loss.backward()
    print('========= after backward =========')
    reporter.report()
    



In [None]:
reporter = MemReporter(model)


print('========= before loop =========')
reporter.report()
# The loop header here was left unfinished (`for batch in data[]`), which is a syntax error, and
# the following `out.backward()` ran on whatever `out` an earlier cell had left behind.
# NOTE(review): replaced with a single forward/backward pass on the first batch, following the
# per-batch steps of the MemReporter loop over `dataloader` -- confirm this matches the intent.
batch = next(iter(dataloader))
out = model(batch[0])
loss = nn.functional.cross_entropy(out, batch[1])
loss.backward()
print('========= after backward =========')
reporter.report()
###################################################
import torch
from pytorch_memlab import MemReporter

lstm = torch.nn.LSTM(1024, 1024).cuda()
reporter = MemReporter(lstm)
reporter.report(verbose=True)
inp = torch.Tensor(10, 10, 1024).cuda()
out, _ = lstm(inp)
out.mean().backward()
reporter.report(verbose=True)





with torch.autograd.profiler.profile(use_cuda=True) as prof:
# with torch.autograd.profiler.profile() as prof:
    print('starting')
    inputs = torch.randn(5,3,224,224, device='cuda')
    print('half way')
    outputs = inputs + torch.randn(5,3,224,224, device='cuda')
    print('ending')
    
# print(prof)
print(prof.key_averages().table(sort_by="self_cpu_time_total"))

In [None]:
# from pytorch_memlab import MemReporter
# linear = torch.nn.Linear(1024, 1024).cuda()
# linear2 = torch.nn.Linear(1024, 1024).cuda()


# def inner():
#     torch.nn.Linear(100, 100).cuda()

# def outer():
#     linear = torch.nn.Linear(100, 100).cuda()
#     linear2 = torch.nn.Linear(100, 100).cuda()
#     inner()
# reporter = MemReporter()
# reporter.report()

linear = torch.nn.Linear(1024, 1024).cuda()
inp = torch.Tensor(512, 1024).cuda()
# pass in a model to automatically infer the tensor names
reporter = MemReporter(linear)
out = linear(inp).mean()
print('========= before backward =========')
reporter.report()
out.backward()
print('========= after backward =========')
reporter.report()
###################################################
import torch
from pytorch_memlab import MemReporter

lstm = torch.nn.LSTM(1024, 1024).cuda()
reporter = MemReporter(lstm)
reporter.report(verbose=True)
inp = torch.Tensor(10, 10, 1024).cuda()
out, _ = lstm(inp)
out.mean().backward()
reporter.report(verbose=True)

# import torch
# from pytorch_memlab import LineProfiler

# def inner():
#     torch.nn.Linear(100, 100).cuda()

# def outer():
#     linear = torch.nn.Linear(100, 100).cuda()
#     linear2 = torch.nn.Linear(100, 100).cuda()
#     inner()

# with LineProfiler(outer, inner) as prof:
#     outer()

In [None]:
%load_ext pytorch_memlab

In [None]:
prof.print_stats()

dir(prof)
# type(prof)

prof.display()

In [None]:
with torch.autograd.profiler.profile(use_cuda=True) as prof:
# with torch.autograd.profiler.profile() as prof:
    print('starting')
    inputs = torch.randn(5,3,224,224, device='cuda')
    print('half way')
    outputs = inputs + torch.randn(5,3,224,224, device='cuda')
    print('ending')
    
# print(prof)
print(prof.key_averages().table(sort_by="self_cpu_time_total"))

# if i % 1000 == 0:
#     print("Iteration: {}, memory: {}".format(i, psutil.virtual_memory()))

import torch
import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity

model = models.resnet18()
inputs = torch.randn(5, 3, 224, 224)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        model(inputs)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

In [None]:
# pl.profiler.PyTorchProfiler(dirpath=None,
#                             filename=None,
#                             group_by_input_shapes=False,
#                             emit_nvtx=False,
#                             export_to_chrome=True,
#                             row_limit=20,
#                             sort_by_key=None,
#                             record_functions=None,
#                             record_module_names=True,
#                             profiled_functions=None,
#                             output_filename=None, 
#                             **profiler_kwargs)

In [None]:
# idx = [0,1,2,3,4]
# idx = 10
# data.display_grid(idx, repeat_n=5)
data.display_grid(idx, repeat_n=2)

In [None]:
%%time
self=data
indices=10
repeat_n=5
from itertools import repeat, chain
from more_itertools import collapse
import random
indices = random.sample(range(self.num_samples), indices)
# Each sampled index is repeated `repeat_n` times in a row. The result is materialised as a list:
# the lazy iterator built here before was exhausted by its first consumer, so later cells that
# read `idx` again (e.g. to format it into a title) saw nothing useful.
idx = list(chain.from_iterable(repeat(i, repeat_n) for i in indices))

# print([i for i in idx])

In [None]:
# self = data
# indices = idx

# if isinstance(indices, int):
#     indices = random.sample(range(self.num_samples), indices)
# indices = list(indices)
# images = [self[idx][0] for idx in indices]
# labels = [self.classes[self[idx][1]] for idx in indices]
# labels

In [None]:
data.display_grid(idx)
plt.suptitle(f"{idx} random images")
plt.tight_layout()

In [None]:
# import matplotlib.pyplot as plt
# from PIL.Image import Image as PilImage
# import textwrap, os

# def display_images(
#     images: [PilImage], 
#     columns=5, max_images=15,
#     width=20, height=8,    
#     label_wrap_length=50, 
#     label_font_size=8):

#     if not images:
#         print("No images to display.")
#         return 

#     if len(images) > max_images:
#         print(f"Showing {max_images} images of {len(images)}:")
#         images=images[0:max_images]

#     rows = int(len(images)/columns)
        
#     height = max(height, rows * height)
#     plt.figure(figsize=(width, height))
#     for i, image in enumerate(images):

#         plt.subplot(rows + 1, columns, i + 1)
#         plt.imshow(image)

#         if hasattr(image, 'filename'):
#             title=image.filename
#             if title.endswith("/"): title = title[0:-1]
#             title=os.path.basename(title)
#             title=textwrap.wrap(title, label_wrap_length)
#             title="\n".join(title)
#             plt.title(title, fontsize=label_font_size); 

In [None]:
data[5][0]

import random
num_display = 12

indices = random.sample(range(data.num_samples), num_display)

indices

# indices = range(0,4)

display_images(
    images = [data[idx][0] for idx in indices],
    columns=5, max_images=15,
    width=20, height=8,    
    label_wrap_length=50, 
    label_font_size=8)

In [None]:
#     def parse_item(self, index: int):
#         path = self.files[index]
#         family, genus, species, collection, catalog_number = path.stem.split("_", maxsplit=4)
#         item = {"path":path,
#                 "target":None,
#                 "family":family,
#                 "genus":genus,
#                 "species":species,
#                 "collection":collection,
#                 "catalog_number":catalog_number}
#         item["target"] = item[self.class_type]

## Previous Fossil class def code, now relocated to fossil.py

In [None]:
from more_itertools import flatten
from dataclasses import dataclass

In [None]:
# Every processed fossil image folder lives under one versioned root, laid out as
#     <root>/ccrop_<resolution>/<Collection>_Fossil
# Deriving the paths from these pieces replaces sixteen hand-copied absolute paths.
_FOSSIL_ROOT = "/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v0_3/Fossil"
_FOSSIL_RESOLUTIONS = (512, 1024, 1536, 2048)
_FOSSIL_COLLECTION_ORDER = ("Wilf", "Florissant")


def _fossil_dir(collection: str, resolution: int) -> str:
    """Return the image folder of one fossil collection at one crop resolution."""
    return f"{_FOSSIL_ROOT}/ccrop_{resolution}/{collection}_Fossil"


# Single-collection datasets map a name to one folder path (str).
available_datasets = {
    f"{collection}_Fossil_{resolution}": _fossil_dir(collection, resolution)
    for collection in _FOSSIL_COLLECTION_ORDER
    for resolution in _FOSSIL_RESOLUTIONS
}

# Combined datasets map a name to the list of both collections' folders at that resolution.
available_datasets.update({
    f"Fossil_{resolution}": [_fossil_dir(collection, resolution)
                             for collection in _FOSSIL_COLLECTION_ORDER]
    for resolution in _FOSSIL_RESOLUTIONS
})

fossil_collections = {"Florissant":"florissant_fossil",
                      "Wilf":"wilf_fossil"}

@dataclass
class DatasetConfig:
    """
    Summary of a dataset, derived from its underscore-separated name.

    The name is split on "_" and read as:
        "<Collection>_<Dataset>_<resolution>"  e.g. "Wilf_Fossil_512"
        "<Dataset>_<resolution>"               e.g. "Fossil_512" (spans every known collection)
    Any extra keyword argument is stored as an attribute and overrides the parsed value of the same name.
    """
    name: str
    dataset: Optional[str]=None
    # A single "<Collection>_<Dataset>" string for 3-part names, a list of them for 2-part names.
    collection: str=None
    resolution: Optional[int]=None

    num_files: Optional[int]=None
    num_classes: Optional[int]=None
    class_type: str="family"
    path_schema: str = "{family}_{genus}_{species}_{collection}_{catalog_number}"


    def __init__(self, name: str, **kwargs):
        self.name = name
        parts = self.name.split("_")
        # A name without a numeric suffix (including the empty string) simply has no resolution,
        # instead of aborting construction with a ValueError.
        try:
            self.resolution = int(parts[-1])
        except ValueError:
            self.resolution = None

        if len(parts)==3:
            self.dataset = parts[1]
            self.collection = "_".join(parts[:2])
        elif len(parts)==2:
            self.dataset = parts[0]
            self.collection = ["_".join([c, self.dataset]) for c in fossil_collections.keys()]
        # Any other number of parts leaves `dataset` and `collection` at their None defaults.

        self.__dict__.update(kwargs)

    def __repr__(self):
        # type(self).__name__ is the bare class name regardless of how deeply the defining module
        # is nested; splitting str(type(self)) on "." picked a module name for nested modules.
        lines = [f"<{type(self).__name__}>:", "Fields:"]
        lines.extend(f"    {k}: {getattr(self,k)}" for k in self.__dataclass_fields__)
        return "\n".join(lines) + "\n"
    

DatasetConfig("Fossil_512")

class FossilDataset(torchdata.datasets.Files): #ImageDataset):
    """
    Image dataset over fossil image files whose labels are encoded in the file name.

    Each file stem is split on "_" into family, genus, species, collection and catalog_number
    (see `parse_item`). The family is the classification target.

    Indexing returns an `(image, target, path)` tuple where `target` is the integer index of the
    family in the sorted list of families found in `files`.
    """

    transform = None
    target_transform = None

    class_type: str="family"
    totensor: Callable = torchvision.transforms.ToTensor()
    def __init__(self,
                 files: List[Path],
                 return_items: Optional[List[str]] = None,
                 image_return_type: str = "tensor",
                 *args, **kwargs):
        """
        Arguments:
            files:
                Paths of the image files; each stem must follow the naming scheme read by `parse_item`.
            return_items: default=["image","target","path"]
                Stored on the instance; not consulted by `__getitem__` in this class.
            image_return_type: default="tensor"
                When "tensor", images are converted with `totensor` before `transform` is applied.
            *args, **kwargs:
                Forwarded to the parent constructor. `name` is also read from kwargs.
                NOTE(review): assumes the parent constructor accepts a `name` keyword -- confirm.
        """
        super().__init__(files=files, *args, **kwargs)

        self.name = kwargs.get("name","")
        # Build the default list per instance rather than sharing one list object across instances.
        self.return_items = ["image","target","path"] if return_items is None else return_items
        self.image_return_type = image_return_type

        self.samples = [self.parse_item(idx) for idx in range((len(self)))]
        self.targets = [sample[1] for sample in self.samples]
        self._imgs = None
        self.classes = sorted(set(self.targets))
        self.class2idx = {name:idx for idx, name in enumerate(self.classes)}

        self.config = DatasetConfig(self.name,
                                    class_type=self.class_type,
                                    num_files=len(self.files),
                                    num_classes=len(self.classes)
                                   )

    def getitem(self, index: int):
        """Open the image of sample `index` and return `(PIL image, family name, path)`."""
        path, family = self.samples[index][:2]
        img = Image.open(path)
        return img, family, path

    def __getitem__(self, index: int):
        """Return `(image, target index, path)` with the configured conversions applied."""
        img, family, path = self.getitem(index)
        target = self.class2idx[family]

        if self.image_return_type == "tensor":
            img = self.totensor(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, path


    def parse_item(self, index: int):
        """Return `(path, family, genus, species, collection, catalog_number)` for file `index`."""
        path = self.files[index]
        # maxsplit=4 keeps any further underscores inside catalog_number.
        family, genus, species, collection, catalog_number = path.stem.split("_", maxsplit=4)
        return path, family, genus, species, collection, catalog_number

    def __repr__(self):
        return self.config.__repr__()


    @classmethod
    def create_dataset(cls, name: str) -> "FossilDataset":
        """
        Build a dataset from the folder(s) registered under `name` in `available_datasets`.

        An entry may be a single folder path or a list of folder paths; the "*/*.jpg" files of
        every listed folder are concatenated in listing order.
        """
        dataset_dirs = available_datasets[name]
        if isinstance(dataset_dirs, str):
            dataset_dirs = [dataset_dirs]
        assert isinstance(dataset_dirs, list)
        file_list = list(flatten(
                            [torchdata.datasets.Files.from_folder(Path(root),
                                                                  regex="*/*.jpg").files
                             for root in dataset_dirs]
                                                    ))
        # Instantiate `cls` so that subclasses inheriting this constructor get their own type back.
        data = cls(file_list,
                   name=name)

        return data #.map(lambda x: (torchvision.transforms.ToTensor()(x[0]), x[1]))
    
    
#     @classmethod
#     def create_dataset(cls, name: str) -> "ImageDataset":
#         dataset_dirs = available_datasets[name]
#         if isinstance(available_datasets[name], list):
#             file_list = list(flatten(
#                                 [torchdata.datasets.Files.from_folder(Path(root),
#                                                         regex="*/*.jpg").files
#                                for root in available_datasets[name]]
#                                                             ))
            
#         elif isinstance(available_datasets[name], str):
#             file_list = torchdata.datasets.Files.from_folder(Path(available_datasets[name]),
#                                                              regex="*/*.jpg").files

#         data = FossilDataset(file_list,
#                                  name=name)

#         return data.map(torchvision.transforms.ToTensor())

fossil_data = FossilDataset.create_dataset(name="Fossil_1024")
fossil_data

# fossil_data = FossilDataset.create_dataset(name="Florissant_Fossil_1024")
# fossil_data

# fossil_data = FossilDataset.create_dataset(name="Wilf_Fossil_1024")
# fossil_data

### Future todo: Separate subclass of simpler Leaves/Fossil Dataset class to allow for more customization of return signatures (allowing dict records instead of tuple, multiple labels per sample)

In [None]:
class MultiLabelFossilDataset(FossilDataset): #ImageDataset):
    """
    Draft FossilDataset variant that keeps every label parsed from a file name in a dict record.

    `samples` holds one dict per file with the keys "path", "target", "family", "genus",
    "species", "collection" and "catalog_number"; "target" mirrors the field named by `class_type`.
    Indexing returns `(image, target index, path)`.
    """

    # staticmethod: a plain function stored on a class is bound on instance access, so
    # `self.loader(path)` would otherwise call Image.open(self, path).
    loader: Callable = staticmethod(Image.open)
    transform = torchvision.transforms.ToTensor()
    target_transform = None

    class_type: str="family"

    def __init__(self,
                 files: List[Path],
                 return_items: Optional[List[str]] = None,
                 *args, **kwargs):
        """
        Arguments:
            files:
                Paths of the image files; each stem must follow the naming scheme read by `parse_item`.
            return_items: default=["image","target","path"]
                Stored on the instance; not consulted by `__getitem__` yet.
            *args, **kwargs:
                Forwarded to the file-listing base constructor. `name` is also read from kwargs.
        """
        # FossilDataset.__init__ reads targets positionally (sample[1]), which cannot work with the
        # dict records produced by this class, so the file-listing base is initialised directly.
        # NOTE(review): confirm nothing else from FossilDataset.__init__ is required here.
        torchdata.datasets.Files.__init__(self, files=files, *args, **kwargs)

        self.name = kwargs.get("name","")
        self.return_items = ["image","target","path"] if return_items is None else return_items

        self.samples = [self.parse_item(idx) for idx in range((len(self)))]
        self.targets = [sample["target"] for sample in self.samples]
        self.classes = sorted(set(self.targets))
        self.class2idx = {name:idx for idx, name in enumerate(self.classes)}

        self.config = DatasetConfig(self.name,
                                    class_type=self.class_type,
                                    num_files=len(self.files),
                                    num_classes=len(self.classes)
                                    )

    def get_item(self, index: int):
        """Load sample `index` untransformed and return `(image, target index, path)`."""
        item = self.samples[index]
        img = self.loader(item["path"])
        return img, self.class2idx[item["target"]], item["path"]

    def __getitem__(self, index: int):
        """Return `(image, target index, path)` after applying `transform` / `target_transform`."""
        img, target, path = self.get_item(index)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, path

    def parse_item(self, index: int):
        """Parse the stem of file `index` into a dict record holding all of its labels."""
        path = self.files[index]
        # maxsplit=4 keeps any further underscores inside catalog_number.
        family, genus, species, collection, catalog_number = path.stem.split("_", maxsplit=4)
        item = {"path":path,
                "target":None,
                "family":family,
                "genus":genus,
                "species":species,
                "collection":collection,
                "catalog_number":catalog_number}
        item["target"] = item[self.class_type]
        return item

    def __repr__(self):
        return self.config.__repr__()

    @classmethod
    def create_dataset(cls, name: str) -> "MultiLabelFossilDataset":
        """
        Build a dataset from the folder(s) registered under `name` in `available_datasets`.

        An entry may be a single folder path or a list of folder paths; the "*/*.jpg" files of
        every listed folder are concatenated in listing order.
        """
        dataset_dirs = available_datasets[name]
        if isinstance(dataset_dirs, str):
            dataset_dirs = [dataset_dirs]
        assert isinstance(dataset_dirs, list)
        file_list = list(flatten(
                            [torchdata.datasets.Files.from_folder(Path(root),
                                                                  regex="*/*.jpg").files
                             for root in dataset_dirs]
                                                    ))
        # Instantiate `cls` (this class), not the parent. No extra ToTensor mapping is applied:
        # `transform` already converts the image in __getitem__, and ToTensor rejects tensor input.
        return cls(file_list,
                   name=name)
    
    
#     @classmethod
#     def create_dataset(cls, name: str) -> "ImageDataset":
#         dataset_dirs = available_datasets[name]
#         if isinstance(available_datasets[name], list):
#             file_list = list(flatten(
#                                 [torchdata.datasets.Files.from_folder(Path(root),
#                                                         regex="*/*.jpg").files
#                                for root in available_datasets[name]]
#                                                             ))
            
#         elif isinstance(available_datasets[name], str):
#             file_list = torchdata.datasets.Files.from_folder(Path(available_datasets[name]),
#                                                              regex="*/*.jpg").files

#         data = FossilDataset(file_list,
#                                  name=name)

#         return data.map(torchvision.transforms.ToTensor())

In [None]:
Image.open(fossil_data.samples[0][0])

In [None]:
fossil_data.class2idx

In [None]:
# class ImageDataset(torchdata.datasets.Files):
    
#     def __getitem__(self, index):
#         return Image.open(self.files[index])
    
# #     def __repr__(self):
# #         return f"""{self.kwargs['name']}"""
# #         return f"""ImageDataset: {self.kwargs['name']}"""

#     def __init__(self, files: List[Path], *args, **kwargs):
#         super().__init__(files=files, *args, **kwargs)
        
# #         self.name = kwargs.get("name","")
# #         self.cfg = DatasetConfig(self.name)


#     @classmethod
#     def create_dataset(cls, name: str) -> "ImageDataset":
#         dataset_dirs = available_datasets[name]

#         if isinstance(available_datasets[name], str):
#             data = ImageDataset.from_folder(Path(available_datasets[name]),
#                                             regex="*/*.jpg",
#                                             name=name)
#         return data.map(torchvision.transforms.ToTensor())    

In [None]:
from IPython.core.debugger import set_trace

import os
import types
# Expose only GPU "0" to CUDA. This is assigned before the torch /
# pytorch_lightning imports below, presumably so the restriction is in place
# before anything initializes CUDA (confirm if the import order is changed).
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import numpy as np
import pytorch_lightning as pl
from torchvision import models
import torchvision
import torch
import timm
from rich import print
import matplotlib.pyplot as plt
from contrastive_learning.data.pytorch.pnas import PNASLightningDataModule
from contrastive_learning.data.pytorch.extant import ExtantLightningDataModule
from contrastive_learning.data.pytorch.common import DataStageError, LeavesLightningDataModule

from lightning_hydra_classifiers.callbacks.wandb_callbacks import WatchModelWithWandb, LogPerClassMetricsToWandb, WandbClassificationCallback # LogConfusionMatrixToWandb
from lightning_hydra_classifiers.models.resnet import ResNet, get_scalar_metrics
import lightning_hydra_classifiers
from torch import nn
import inspect

import wandb
# Seed through pytorch_lightning's seed_everything helper with a fixed seed (9).
# NOTE(review): the dataset config in a later cell uses seed = 257; confirm the
# two values are meant to differ.
pl.trainer.seed_everything(seed=9)

    
# class Config:
#     pass


# config = Config()

# # config.model_name = 'resnet50'
# # config.dataset_name = 'PNAS_family_100_512'
# config.dataset_name = '(Extant-PNAS)_family_10_512'
# config.normalize = True
# config.num_workers = 4
# config.batch_size = 16

# config = Box({
#     "dataset":{
#         "name": f"PNAS_{label_type}_{pnas_threshold}"
#     }
    
# })

########################################
# if 'Extant' in config.dataset_name:
#     datamodule = ExtantLightningDataModule(name=config.dataset_name, batch_size=config.batch_size, debug=False, normalize=config.normalize, num_workers=config.num_workers)
# elif 'PNAS' in config.dataset_name:
#     datamodule = PNASLightningDataModule(name=config.dataset_name, batch_size=config.batch_size, debug=False, normalize=config.normalize, num_workers=config.num_workers)#, normalize=False)#True)
# datamodule.setup('fit')
# ########################################
# num_classes = len(datamodule.classes)
# config.num_classes = num_classes

In [None]:
from box import Box
import os

# Local wandb cache directory; it is also reused below as the wandb run directory.
os.environ['WANDB_CACHE_DIR'] = "/media/data/jacob/wandb_cache"

# Knobs that determine the dataset names assembled below.
class_type = "family"
extant_threshold = 10
pnas_threshold = 100
image_size = 512
seed = 257

config = Box({})

# One entry per dataset; the two entries differ only in their "name".
config.datasets = [
    {
        "name": f"PNAS_{class_type}_{pnas_threshold}_{image_size}",
        "batch_size": 32,
        "val_split": None,  # TODO specify split explicitly in wandb report
        "num_workers": 4,
        "image_size": image_size,
        "channels": 3,
        "class_type": class_type,
        "debug": False,
        "normalize": True,
        "seed": seed,
        "dataset_dir": None,
        "predict_on_split": "val",
    },
    {
        "name": f"Extant_{class_type}_{extant_threshold}_{image_size}",  # f"PNAS_{label_type}_{pnas_threshold}_{image_size}"
        "batch_size": 32,
        "val_split": None,  # TODO specify split explicitly in wandb report
        "num_workers": 4,
        "image_size": image_size,
        "channels": 3,
        "class_type": class_type,
        "debug": False,
        "normalize": True,
        "seed": seed,
        "dataset_dir": None,
        "predict_on_split": "val",
    },
]

# wandb settings. The `None` placeholders (artifacts.root_dir and each input
# artifact's root_dir / uri) are filled in right after this literal.
config.wandb = {
    "init": {
        "entity": "jrose",
        "project": "image_classification_datasets",
        "job_type": 'create-dataset',
        "group": None,
        "run_dir": os.environ['WANDB_CACHE_DIR'],
        "tags": [d.name for d in config.datasets],
    },
    "artifacts": {
        "root_dir": None,
    },
    "input_artifacts": [
        {
            "entity": "jrose",
            "project": "image_classification_datasets",
            "name": config.datasets[0].name,
            "version": "v6",
            "type": "raw_data",
            "root_dir": None,
            "uri": None,
        },
    ],
}

# All artifacts live under "<run_dir>/artifacts".
config.wandb.artifacts.root_dir = os.path.join(config.wandb.init.run_dir,
                                               "artifacts")

# Derive `uri` and `root_dir` for EVERY input artifact rather than only the
# first one, so entries added to the list later do not keep `None` values.
# Values are written through the same indexed path the config is read from.
for i, artifact_cfg in enumerate(config.wandb.input_artifacts):
    versioned_name = f"{artifact_cfg.name}:{artifact_cfg.version}"

    # "<entity>/<project>/<name>:<version>"
    config.wandb.input_artifacts[i].uri = "/".join([artifact_cfg.entity,
                                                    artifact_cfg.project,
                                                    versioned_name])

    # "<artifacts root>/datasets/<name>:<version>"
    config.wandb.input_artifacts[i].root_dir = os.path.join(config.wandb.artifacts.root_dir,
                                                            "datasets",
                                                            versioned_name)


# def fetch_datamodule_from_dataset_artifact(config: Box, run_or_api=None) -> LeavesLightningDataModule:
#     run = run_or_api or wandb.Api()
#     artifact = run.use_artifact(config.wandb.input_artifact.uri,
#                                 type=config.wandb.input_artifact.type)
#     dataset_artifact_dir = artifact.download(root=config.wandb.input_artifact.root_dir)


#     datamodule = get_datamodule(config.dataset)
#     datamodule.setup('fit')
#     datamodule.setup('test')
#     ########################
#     config.model.num_classes = config.dataset.num_classes

def fetch_datamodule_from_dataset_artifact(config: Box, run_or_api=None) -> LeavesLightningDataModule:
    """
    Download the configured input artifact and return a datamodule set up for 'fit' and 'test'.

    Args:
        config: Configuration that is read at `config.wandb.input_artifact`
            (`uri`, `type`, `root_dir`) and `config.dataset`, and written at
            `config.model.num_classes`.
        run_or_api: Object exposing `use_artifact(...)`. When falsy, a fresh
            `wandb.Api()` is used instead.

    Returns:
        The datamodule built by `get_datamodule(config.dataset)` after
        `setup('fit')` and `setup('test')` have been called on it.

    Side effects:
        Downloads the artifact to `config.wandb.input_artifact.root_dir` and
        sets `config.model.num_classes` from `config.dataset.num_classes`.

    NOTE(review): this reads the singular `config.wandb.input_artifact` and
    `config.dataset`, while the config cell in this notebook defines the lists
    `input_artifacts` / `datasets` -- confirm which config shape is intended.
    NOTE(review): confirm that the `wandb.Api()` fallback actually provides
    `use_artifact`; the call is written for a run-like object.
    """
    run = run_or_api or wandb.Api()
    artifact = run.use_artifact(config.wandb.input_artifact.uri,
                                type=config.wandb.input_artifact.type)
    # Only the download side effect is needed; the returned local path is not used.
    artifact.download(root=config.wandb.input_artifact.root_dir)

    datamodule = get_datamodule(config.dataset)
    datamodule.setup('fit')
    datamodule.setup('test')
    ########################
    config.model.num_classes = config.dataset.num_classes

    # The original fell off the end and implicitly returned None despite the
    # declared return type.
    return datamodule


In [None]:
import pandas as pd
import numpy as np
from PIL import Image, ImageStat
import seaborn_image as isns
import scipy

def image_stat(img: np.ndarray) -> dict:
    """
    Summarize the pixel values and dimensions of an image array.

    Args:
        img: A 3-D array unpacked as (height, width, channels), or a 2-D array
            unpacked as (height, width) and treated as having 1 channel.

    Returns:
        dict with keys "min", "max", "var", "mean" (numpy reductions over all
        elements), "mode" (the result object returned by `scipy.stats.mode`
        computed over the flattened array), "height", "width", "channels" and
        "num_pixels" (height * width * channels).
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee
    # that `scipy.stats` is available as an attribute on every SciPy version.
    from scipy import stats as scipy_stats

    if img.ndim == 3:
        h, w, c = img.shape
    else:
        # No channel axis present: count the image as single-channel.
        h, w, c = (*img.shape, 1)

    return {
        "min": np.min(img),
        "max": np.max(img),
        "var": np.var(img),
        "mean": np.mean(img),
        "mode": scipy_stats.mode(img, axis=None),
        "height": h,
        "width": w,
        "channels": c,
        "num_pixels": h * w * c,
    }


def load_and_analyze_image(image_path: str):
    """
    Load the image at `image_path` and compute its statistics.

    Args:
        image_path: Path handed to PIL's `Image.open`.

    Returns:
        Tuple `(img, stats)`: `img` is the image converted to a numpy array
        and `stats` is the dict produced by `image_stat(img)`.
    """
    # The context manager releases the underlying file as soon as the pixel
    # data has been copied into the array, instead of leaving it to the GC.
    with Image.open(image_path) as pil_img:
        img = np.array(pil_img)
    return img, image_stat(img)

# def load_analyze_and_save_annotated_image(image_path: str):
#     img, stats = load_and_analyze_image(image_path)
#     return img, image_stat(img)




def fig2data(fig):
    """
    @brief Render a Matplotlib figure and return its pixels as a 3D numpy array with RGBA channels
    @param fig a matplotlib figure
    @return a uint8 numpy array of shape (height, width, 4) holding RGBA values
    """
    # draw the renderer
    fig.canvas.draw()

    # Get the ARGB buffer from the figure. np.frombuffer replaces the
    # deprecated binary mode of np.fromstring.
    w, h = fig.canvas.get_width_height()
    argb = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)

    # The buffer is laid out row by row, so the leading axis is the height.
    # (The original reshaped to (w, h, 4), which scrambles non-square figures
    # when the array is indexed as an image.)
    argb = argb.reshape(h, w, 4)

    # canvas.tostring_argb gives the pixmap in ARGB mode. Roll the ALPHA
    # channel to the end to obtain RGBA. np.roll returns a new, writable array.
    return np.roll(argb, 3, axis=2)


def fig2img(fig):
    """
    @brief Convert a Matplotlib figure to a PIL Image in RGBA format and return it
    @param fig a matplotlib figure
    @return a Python Imaging Library ( PIL ) image
    """
    # put the figure pixmap into a numpy array of shape (height, width, 4)
    buf = fig2data(fig)
    h, w, _ = buf.shape
    # PIL expects the size as (width, height); tobytes replaces the deprecated
    # ndarray.tostring alias.
    return Image.frombytes("RGBA", (w, h), buf.tobytes())

# f = isns.imghist(img_array,
#                  describe=True)
# results = fig2img(f)

In [None]:
# Show the URI assembled in the config cell for the first input artifact
# ("<entity>/<project>/<name>:<version>").
config.wandb.input_artifacts[0].uri

In [None]:
# def trainvaltest_split(x: Union[List[Any],np.ndarray]=None,
#                        y: Union[List[Any],np.ndarray]=None,
#                        splits: List[float]=(0.5, 0.2, 0.3),
#                        random_state: int=None,
#                        stratify: bool=True
#                        ) -> Dict[str,Tuple[np.ndarray]]:
#     """
#     Wrapper function to split data into 3 stratified subsets specified by `splits`.
    
#     User specifies absolute fraction of total requested for each subset (e.g. splits=[0.5, 0.2, 0.3])
    
#     Function calculates adjusted fractions necessary in order to use sklearn's builtin train_test_split function over a sequence of 2 steps.
    
#     Step 1: Separate test set from the rest of the data (constituting the union of train + val)
    
#     Step 2: Separate the train and val sets from the remainder produced by step 1.
    
    
    
#     Output:
#         Dict: {'train':(x_train, y_train),
#                 'val':(x_val, y_val),
#                 'test':(x_test, y_test)}
    
#     """
    
    
#     assert len(splits) == 3, "Must provide exactly 3 float values for `splits`"
#     assert np.isclose(np.sum(splits), 1.0), f"Sum of all splits values {splits} = {np.sum(splits)} must be 1.0"
    
#     train_split, val_split, test_split = splits
#     val_relative_split = val_split/(train_split + val_split)
#     train_relative_split = train_split/(train_split + val_split)
    
    
#     if stratify and (y is None):
#         raise ValueError("If y is not provided, stratify must be set to False.")
    
#     y = np.array(y)
#     if x is None:
#         x = np.arange(len(y))
#     else:
#         x = np.array(x)
    
#     stratify_y = y if stratify else None    
#     x_train_val, x_test, y_train_val, y_test = train_test_split(x, y,
#                                                         test_size=test_split, 
#                                                         random_state=random_state,
#                                                         stratify=y)
#     log.info(f"(x_train+x_val).shape={x_train_val.shape}, (y_train+y_val).shape={y_train_val.shape}")
#     log.info(f"x_test.shape={x_test.shape}, y_test.shape={y_test.shape}")
    
# #     print(f"(x_train+x_val).shape={x_train_val.shape}, (y_train+y_val).shape={y_train_val.shape}")
# #     print(f"x_test.shape={x_test.shape}, y_test.shape={y_test.shape}")

#     stratify_y_train = y_train_val if stratify else None
#     x_train, x_val, y_train, y_val = train_test_split(x_train_val, y_train_val,
#                                                       test_size=val_relative_split,
#                                                       random_state=random_state, 
#                                                       stratify=y_train_val)
    
#     x = np.concatenate((x_train, x_val, x_test)).tolist()
#     assert len(set(x)) == len(x), f"[Warning] Check for possible data leakage. len(set(x))={len(set(x))} != len(x)={len(x)}"
    
#     log.info(f"x_train.shape={x_train.shape}, y_train.shape={y_train.shape}")
#     log.info(f"x_val.shape={x_val.shape}, y_val.shape={y_val.shape}")
    
#     log.info(f'Absolute splits: {[train_split, val_split, test_split]}')
#     log.info(f'Relative splits: [{train_relative_split:.2f}, {val_relative_split:.2f}, {test_split}]')
    
#     return {"train":(x_train, y_train),
#             "val":(x_val, y_val),
#             "test":(x_test, y_test)}

#####################

# y = data.targets

# data_splits = trainvaltest_split(x=None,
#                                  y=y,
#                                  splits=(0.5, 0.2, 0.3),
#                                  random_state=0,
#                                  stratify=True)

In [None]:
import wandb
# Look the artifact up through a wandb.Api() object (no wandb.init call is made
# in this cell), using the URI built in the config cell.
api = wandb.Api()
artifact = api.artifact(config.wandb.input_artifacts[0].uri)

In [None]:
# Inspect the artifact object's `data` attribute.
artifact.data

In [None]:
# dir(artifact)
# dir(artifact.manifest)
# artifact.manifest.entries

In [None]:
# Fetch the artifact entry stored under 'dataset/test.table.json'
# (presumably the test-split table logged with the dataset; confirm).
# NOTE(review): this rebinds the notebook-global name `data`, which later cells
# use as a datamodule (`data.setup(...)`); mind the execution order.
data = artifact.get('dataset/test.table.json')

In [None]:
# Check what kind of object `artifact.get` returned.
type(data)

In [None]:
# Raw contents of the fetched table object, taken from its `data` attribute.
df=data.data

print(data.columns)


# Rebuild the table as a pandas DataFrame, labelling columns with the table's own column names.
data_df = pd.DataFrame(data=df, columns=data.columns)

In [None]:
# `wide_samples` is an alias of `data_df` (same object, no copy).
wide_samples = data_df
# itertuples() is called with its defaults, so each tuple carries the row index
# first, followed by the `image` and `label` values.
samples = list(data_df[['image', 'label']].itertuples())

In [None]:

# Number of distinct label values in the table.
num_classes = len(set(wide_samples.label.values))
# One bar per label: heights are the per-label counts of non-null
# `catalog_number` values, in the order groupby returns the labels;
# x positions are simply 0..num_classes-1 (not the label values themselves).
plt.bar(range(num_classes), wide_samples.groupby("label")["catalog_number"].count())

In [None]:
# Pull the private `_image` attribute out of every object in the `image` column
# (presumably the in-memory PIL image held by each table image; confirm).
# in_mem = data_df.image.apply(lambda x: np.array(x._image))
in_mem = data_df.image.apply(lambda x: x._image)

In [None]:
# Display the entry labelled 199 in `in_mem` (Series label lookup, not position).
in_mem[199]

In [None]:
# Array shape of the first image. As this is not the last expression of the
# cell, the notebook does not display it -- wrap it in print() to see it.
np.array(data_df.image[0]._image).shape


# Displayed cell output: the first in-memory image.
in_mem[0]

In [None]:


from IPython.display import display

# Rebinds `df` to the table's 'image' column (previously `df` held `data.data`).
df = data.get_column('image')

# Displayed cell output: the table's column names.
data.columns

In [None]:
# Result is discarded (not the last expression of the cell); kept from exploration.
dir(artifact)

# Check the artifact out into the per-artifact directory computed in the config
# cell ("<artifacts root>/datasets/<name>:<version>") and keep the return value.
downloaded_artifact = artifact.checkout(root=config.wandb.input_artifacts[0].root_dir)

In [None]:
# os.path.abspath
# Display whatever `artifact.checkout` returned (the parentheses are a leftover
# of the commented-out `os.path.abspath` call above and have no effect).
(downloaded_artifact)

In [None]:
from contrastive_learning.data.pytorch.pnas import PNASLeavesDataset
from contrastive_learning.data.pytorch.extant import ExtantLeavesDataset
# from contrastive_learning.data.pytorch.common import DataStageError
from paleoai_data.dataset_drivers import base_dataset

# Step 1. Instantiate PyTorch Datasets for each of Extant Leaves & PNAS, separately
# The original call pasted signature-style annotations (`split: str="train"`, ...)
# into the argument list, which is a SyntaxError; they are plain keyword
# arguments here, with the same values.
# NOTE(review): `label_type` is not defined in the nearby cells (the config cell
# uses `class_type`) -- confirm which name is intended.
pnas_torch = PNASLeavesDataset(name=f"PNAS_{label_type}_{pnas_threshold}",
                               split="train",
                               dataset_dir=None,
                               return_paths=False)
# NOTE(review): this binds the class itself, not an instance -- presumably an
# unfinished step; instantiate it once the intended arguments are known.
extant_torch = ExtantLeavesDataset




# def create_dataset_by_name(name: str,
#                            version: str='v0.2',
#                            exclude_classes = ['notcataloged','notcatalogued', 'II. IDs, families uncertain', 'Unidentified']):
#     data_df = query_db(version=version, **{'dataset':name})
#     dataset = base_dataset.BaseDataset.from_dataframe(df=data_df, name=name, exclude_classes=exclude_classes)
#     return dataset

In [None]:
# Shape of element [0] of the first training sample (presumably the image tensor; confirm).
# NOTE(review): `datamodule` is only created in commented-out code in the nearby
# cells -- it must already exist in the notebook session for this to run.
datamodule.train_dataset[0][0].shape

In [None]:
# Open a wandb run with job_type "model_result_analysis"; the `with` block ends
# the run on exit.
# NOTE(review): WANDB_PROJECT is not defined in the nearby cells -- it must be
# set before this cell is executed.
with wandb.init(project=WANDB_PROJECT, job_type="model_result_analysis") as run:
    
    # Retrieve the original raw dataset
    dataset_artifact = run.use_artifact("raw_data:latest")
    data_table = dataset_artifact.get("raw_examples")
    
    # Retrieve the train and test score tables
    # (only the train table is fetched here; no test table is retrieved in this cell)
    train_artifact = run.use_artifact("train_results:latest")
    train_table = train_artifact.get("train_iou_score_table")

In [None]:
# data = PNASLightningDataModule(batch_size=16)
# data = ExtantLightningDataModule(batch_size=16, num_workers=12)
# data.setup(stage='fit')

# data.setup(stage='test')

# data.setup(stage=None)

# try:
#     data.setup(stage='other')
#     print('success')
# except DataStageError as e:
#     print(e.with_traceback(None))

In [None]:
# Build the train/val dataloaders after setup('fit') and the test dataloader
# after setup('test').
# NOTE(review): `data` must be a datamodule here (see the commented-out
# constructors in the previous cell); an earlier cell rebinds `data` to the
# result of `artifact.get(...)`, so execution order matters.
data.setup(stage='fit')
train_dataloader = data.get_dataloader(stage='train')
val_dataloader = data.get_dataloader(stage='val')
data.setup(stage='test')
test_dataloader = data.get_dataloader(stage='test')

# train_dataloader
#         if stage=='train': return self.train_dataloader()
#         if stage=='val': return self.val_dataloader()
#         if stage=='test': return self.test_dataloader()


In [None]:


# Optional: override the train transform before sampling (left disabled).
# data.train_dataset.transform = None #data.default_train_transforms() #None
# Unpack the first training sample into its two components.
x, y = data.train_dataset[0]
# print(x.shape)

In [None]:
# from PIL import ImageOps

# print(x.max(), x.min())
# plt.imshow(ImageOps.invert(x))#.permute(1,2,0))

In [None]:
%%time
batch_idx = 0

# Show the same batch index for each split in turn and save the current figure
# as 'ExtantLeaves v0_3 <split> batch <batch_idx>.png' in the working directory.
for split_name in ('train', 'val', 'test'):
    data.show_batch(split_name, batch_idx=batch_idx)
    plt.savefig(f'ExtantLeaves v0_3 {split_name} batch {batch_idx}.png')

# Colormap variants tried earlier (left disabled):
# data.show_batch('train', cmap='plasma')
# data.show_batch('train', cmap='magma')
# data.show_batch('train', cmap='cividis')

In [None]:
# Scratch version of a batch-grid plot. `self = data` lets the statements below
# read like a method body (the trailing commented-out `return fig, ax` suggests
# they were lifted from, or are destined for, a datamodule method -- confirm).
self = data
stage = 'test'
batch_idx = 0

x, y = self.get_batch(stage=stage, batch_idx=batch_idx)

# Keep at most the first 12 items of the batch.
x = x[:12,...]

batch_size = x.shape[0]

fig, ax = plt.subplots(1,1, figsize=(24,24))
# Tile the images into a roughly square grid: ceil(sqrt(batch_size)) images per row.
grid_img = torchvision.utils.make_grid(x, nrow=int(np.ceil(np.sqrt(batch_size))))

img_min, img_max = grid_img.min(), grid_img.max()
print(img_min, img_max)

# Min-max rescale the grid to [0, 1], then print the new extrema as a check.
# NOTE(review): divides by zero when every pixel has the same value.
grid_img = (grid_img - img_min)/(img_max - img_min)
img_min, img_max = grid_img.min(), grid_img.max()
print(img_min, img_max)



print('before:', grid_img.shape)

# If the smallest dimension comes first (taken to mean a channels-first layout),
# move it last so the tensor can be indexed as [row, col, channel] below.
if torch.argmin(torch.Tensor(grid_img.shape)) == 0:
    grid_img = grid_img.permute(1,2,0)
print('after:', grid_img.shape)

# Only channel 0 is drawn, as a single-channel image with the viridis colormap.
img_ax = ax.imshow(grid_img[:,:,0], cmap='viridis')#, vmin = img_min, vmax = img_max)
fig.colorbar(img_ax, ax=ax)#)#cax=ax)
plt.axis('off')
plt.suptitle(f'{stage} batch')
#         return fig, ax

# Print matplotlib's documentation for plt.imshow (exploration aid only).
help(plt.imshow)

%debug

# Pull a single batch from a fresh iterator over the train dataloader.
x, y = next(iter(train_dataloader))

# Smallest value in the batch (only displayed if it is the last expression of its cell).
x.min()

# Show the second item of the batch; permute(1,2,0) moves its first axis last
# (presumably channels-first -> channels-last, as imshow expects; confirm).
plt.imshow(x[1,...].permute(1,2,0))