# Initialization

In [None]:
from dataclasses import dataclass

import pandas as pd
import numpy as np
import random

# import plotly.io as pio
# pio.renderers.default = "vscode"
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import matplotlib.pyplot as plt

import matplotlib
matplotlib.rcParams['figure.dpi']=120
matplotlib.rcParams['font.size'] = 14

from collections import defaultdict
from contextlib import contextmanager

import pickle
import json
import os
from pathlib import Path
from os.path import join as pjoin
from tqdm.auto import tqdm
from copy import deepcopy
import shutil
from time import sleep

from sklearn.model_selection import ParameterGrid
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import OneHotEncoder

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.parametrize as nn_parametrize

from skorch import NeuralNetClassifier
from skorch.callbacks import EpochScoring, Callback
from skorch.dataset import ValidSplit

In [None]:
from typing import NamedTuple

import numpy as np

from pathlib import Path
from ATTOP.data.dataset import sample_negative as ATTOP_sample_negative

import torch
from torch.utils import data

from COSMO_utils import temporary_random_numpy_seed


def get_and_update_num_calls(func_ptr):
    """Return the running call counter kept for ``func_ptr``.

    Counters are stored in a ``defaultdict(int)`` attached to this function as
    the attribute ``num_calls_cnt`` and keyed by ``func_ptr``.

    The very first invocation (when the registry attribute does not exist yet)
    only creates the registry and does not increment, so it returns 0. Every
    later invocation increments the entry of ``func_ptr`` before returning it.
    """
    this = get_and_update_num_calls
    if hasattr(this, 'num_calls_cnt'):
        this.num_calls_cnt[func_ptr] += 1
    else:
        # Lazily create the registry on first use (no increment on this call)
        this.num_calls_cnt = defaultdict(int)

    return this.num_calls_cnt[func_ptr]

def categorical_histogram(data, labels_list, plot=True, frac=True, plt_show=False):
    """Histogram of integer category ids, with one bin per entry of ``labels_list``.

    Bin ``i`` holds how often the value ``i`` occurs in ``data`` (for ``i`` in
    ``0 .. len(labels_list) - 1``); ids that never occur get 0.

    Args:
        data: Sequence of category ids; it is wrapped in a ``pd.Series``.
        labels_list: One label per bin. Only its length and, when plotting,
            its values (used as bar labels) matter.
        plot: If true, draw a bar plot and return ``None``; otherwise return
            the histogram values.
        frac: If true, bins hold fractions of the total count instead of raw
            counts (and the plot's y-axis is fixed to ``[0, 1]``).
        plt_show: If true (and ``plot`` is true), call ``plt.show()``.

    Returns:
        ``None`` when ``plot`` is true, otherwise a ``float32`` numpy array
        with one value per label.
    """
    counts = pd.Series(data).value_counts()
    per_id = (counts / counts.sum() if frac else counts).to_dict()
    hist = [per_id.get(ix, 0) for ix, _ in enumerate(labels_list)]

    if not plot:
        return np.array(hist, dtype='float32')

    # Imported lazily so the numeric (plot=False) path does not need a
    # plotting backend to be installed or initialised.
    import matplotlib.pyplot as plt
    pd.Series(hist, index=labels_list).plot(kind='bar')
    if frac:
        plt.ylim((0, 1))
    if plt_show:
        plt.show()
    return None

class DataItem(NamedTuple):
    """A single dataset sample, as returned by ``CompDataFromDict.__getitem__``."""
    # Feature entry looked up in the dataset's ``activations`` by ``image_fname``
    feat: torch.Tensor
    # Index of the sample's attribute (via ``attr2idx``)
    pos_attr_id: int
    # Index of the sample's object (via ``obj2idx``)
    pos_obj_id: int
    # Attribute index of the sampled negative pair; -1 unless the phase is 'train'
    neg_attr_id: int
    # Object index of the sampled negative pair; -1 unless the phase is 'train'
    neg_obj_id: int
    # Filename of the image this sample refers to
    image_fname: str


class CompDataFromDict():
    """Map-style dataset over ``(image filename, attribute, object)`` samples.

    The split description (label vocabularies, pair lists, per-subset sample
    lists, ...) comes from a metadata dict whose entries are copied onto the
    instance; per-image features are read from ``<data_dir>/features.t7``.
    Indexing returns a ``DataItem``, or just ``(attr_id, obj_id)`` while the
    labels-only mode is switched on (see ``just_load_labels``).
    """

    # noinspection PyMissingConstructor
    def __init__(self, dict_data: dict, data_subset: str, data_dir: str):
        """Build the dataset.

        Args:
            dict_data: Metadata dict; every key/value is copied onto the
                instance, overriding the placeholder defaults set below.
            data_subset: Key of ``dict_data`` holding the samples to serve
                (the notebook passes e.g. ``'train_data'``).
            data_dir: Directory that contains ``features.t7``.
        """

        # define instance variables to be retrieved from struct_data_dict
        self.split: str = 'TBD'
        self.phase: str = 'TBD'
        self.feat_dim: int = -1
        self.objs: list = []
        self.attrs: list = []
        self.attr2idx: dict = {}
        self.obj2idx: dict = {}
        self.pair2idx: dict = {}
        self.seen_pairs: list = []
        self.all_open_pairs: list = []
        self.closed_unseen_pairs: list = []
        self.unseen_closed_val_pairs: list = []
        self.unseen_closed_test_pairs: list = []
        self.train_data: tuple = tuple()
        self.val_data: tuple = tuple()
        self.test_data: tuple = tuple()

        self.data_dir: str = data_dir

        # retrieve instance variables from struct_data_dict
        vars(self).update(dict_data)
        self.data = dict_data[data_subset]

        # Map each image filename to its entry of the precomputed features
        self.activations = {}
        features_dict = torch.load(Path(data_dir) / 'features.t7')
        for i, img_filename in enumerate(features_dict['files']):
            self.activations[img_filename] = features_dict['features'][i]

        self.input_shape = (self.feat_dim,)
        self.num_objs = len(self.objs)
        self.num_attrs = len(self.attrs)
        self.num_seen_pairs = len(self.seen_pairs)
        self.shape_obj_attr = (self.num_objs, self.num_attrs)

        self.flattened_seen_pairs_mask = self.get_flattened_pairs_mask(self.seen_pairs)
        self.flattened_closed_unseen_pairs_mask = self.get_flattened_pairs_mask(self.closed_unseen_pairs)
        self.flattened_all_open_pairs_mask = self.get_flattened_pairs_mask(self.all_open_pairs)
        # NOTE: np.where returns a tuple of index arrays (one per dimension)
        self.seen_pairs_joint_class_ids = np.where(self.flattened_seen_pairs_mask)

        self.y1_freqs, self.y2_freqs, self.pairs_freqs = self._calc_freqs()
        self._just_load_labels = False

        self.train_pairs = self.seen_pairs

    def sample_negative(self, attr, obj):
        """Draw a negative pair for ``(attr, obj)`` by delegating to the ATTOP sampler."""
        return ATTOP_sample_negative(self, attr, obj)

    def get_flattened_pairs_mask(self, pairs):
        """Return a flat boolean mask marking the given pairs.

        Args:
            pairs: Iterable of ``(attr, obj)`` names.

        Returns:
            1-D bool array of length ``num_objs * num_attrs``; the entry at
            ``obj_id * num_attrs + attr_id`` is True for every given pair
            (the same layout ``to_joint_label`` produces). An empty ``pairs``
            yields an all-False mask.
        """
        mask = np.zeros(self.shape_obj_attr, dtype=bool)  # init an array of False
        # Mark pairs one at a time. Building a single fancy index from an
        # empty pair list would give the empty tuple `()`, and
        # `mask[()] = True` would set the *whole* array to True.
        for attr, obj in pairs:
            mask[self.obj2idx[obj], self.attr2idx[attr]] = True
        return mask.flatten()

    def just_load_labels(self, just_load_labels=True):
        """Switch the labels-only mode of ``__getitem__`` on or off."""
        self._just_load_labels = just_load_labels

    def get_all_labels(self):
        """Collect the labels of every sample in this subset.

        Returns:
            ``(attrs, objs, joints)``: numpy arrays of attribute ids and
            object ids, and a list of the joint labels computed by
            ``to_joint_label``.
        """
        attrs = []
        objs = []
        joints = []
        self.just_load_labels(True)
        try:
            for attrs_batch, objs_batch in self:
                if isinstance(attrs_batch, torch.Tensor):
                    attrs_batch = attrs_batch.cpu().numpy()
                if isinstance(objs_batch, torch.Tensor):
                    objs_batch = objs_batch.cpu().numpy()
                joint = self.to_joint_label(objs_batch, attrs_batch)

                attrs.append(attrs_batch)
                objs.append(objs_batch)
                joints.append(joint)
        finally:
            # Always switch back to full items, even if iteration failed
            self.just_load_labels(False)

        attrs = np.array(attrs)
        objs = np.array(objs)
        return attrs, objs, joints

    def _calc_freqs(self):
        """Compute label frequencies over this subset.

        Returns:
            ``(y1_freqs, y2_freqs, pairs_freqs)``: fractional histograms of
            the object ids, the attribute ids and the joint labels. Labels
            that never occur are reported as NaN instead of 0.
        """
        # y1 denotes objects and y2 attributes (get_all_labels returns attrs first)
        y2_train, y1_train, ys_joint_train = self.get_all_labels()
        y1_freqs = categorical_histogram(y1_train, range(self.num_objs), plot=False, frac=True)
        y1_freqs[y1_freqs == 0] = np.nan
        y2_freqs = categorical_histogram(y2_train, range(self.num_attrs), plot=False, frac=True)
        y2_freqs[y2_freqs == 0] = np.nan

        pairs_freqs = categorical_histogram(ys_joint_train,
                                            range(self.num_objs * self.num_attrs),
                                            plot=False, frac=True)
        pairs_freqs[pairs_freqs == 0] = np.nan
        return y1_freqs, y2_freqs, pairs_freqs

    def get(self, name):
        """Return the instance attribute ``name``, or None if it is not set."""
        return vars(self).get(name)

    def __getitem__(self, idx):
        """Return sample ``idx``.

        Returns:
            ``(pos_attr_id, pos_obj_id)`` while the labels-only mode is on,
            otherwise a ``DataItem``. The negative ids of the item are -1
            unless ``self.phase == 'train'``.
        """
        image_fname, attr, obj = self.data[idx]
        pos_attr_id, pos_obj_id = self.attr2idx[attr], self.obj2idx[obj]
        if self._just_load_labels:
            return pos_attr_id, pos_obj_id

        num_calls_cnt = get_and_update_num_calls(self.__getitem__)

        negative_attr_id, negative_obj_id = -1, -1  # default values
        if self.phase == 'train':
            # We set a temporary numpy seed to work around an issue with
            # sample_negative() in __getitem__, where the sampled pairs could
            # not be reproduced deterministically: on each call the seed is
            # 834276 (an arbitrary constant) plus the call counter.
            with temporary_random_numpy_seed(834276 + num_calls_cnt):
                # draw a negative pair
                negative_attr_id, negative_obj_id = self.sample_negative(attr, obj)

        item = DataItem(
            feat=self.activations[image_fname],
            pos_attr_id=pos_attr_id,
            pos_obj_id=pos_obj_id,
            neg_attr_id=negative_attr_id,
            neg_obj_id=negative_obj_id,
            image_fname=image_fname,
        )
        return item

    def __len__(self):
        """Number of samples in the selected subset."""
        return len(self.data)

    def to_joint_label(self, y1_batch, y2_batch):
        """Combine object ids (``y1_batch``) and attribute ids (``y2_batch``) into joint ids."""
        return (y1_batch * self.num_attrs + y2_batch)


def get_data_loaders(train_dataset, valid_dataset, test_dataset, batch_size,
                     num_workers=10, test_batchsize=None, shuffle_eval_set=True):
    """Create the train / validation / test ``DataLoader`` objects.

    The training loader always shuffles; the validation and test loaders
    shuffle according to ``shuffle_eval_set`` and use ``test_batchsize``
    (falling back to ``batch_size`` when it is None). Memory pinning is
    enabled unless ``num_workers`` is 0. No validation loader is built when
    ``valid_dataset`` is None or empty.

    Returns:
        ``(test_loader, train_loader, valid_loader)`` -- in this order;
        ``valid_loader`` may be None.
    """
    eval_batch_size = batch_size if test_batchsize is None else test_batchsize
    pin_memory = False if num_workers == 0 else True
    print('num_workers = ', num_workers)
    print('pin_memory = ', pin_memory)

    worker_kwargs = dict(num_workers=num_workers, pin_memory=pin_memory)

    def make_eval_loader(dataset):
        # Shared construction of the validation and test loaders
        return data.DataLoader(dataset, batch_size=eval_batch_size,
                               shuffle=shuffle_eval_set, **worker_kwargs)

    train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                                   shuffle=True, **worker_kwargs)

    has_valid_samples = valid_dataset is not None and len(valid_dataset) > 0
    valid_loader = make_eval_loader(valid_dataset) if has_valid_samples else None
    test_loader = make_eval_loader(test_dataset)

    return test_loader, train_loader, valid_loader

# Loading

## Getting global combinations

In [None]:
# Dataset location and the label columns whose combinations are enumerated
dataset_name = 'ao_clevr'
data_dir = f'data/{dataset_name}'
label_cols = ['shape', 'color']
# --------------------------------

meta_df = pd.read_csv(pjoin(data_dir, 'objects_metadata.csv'))

# One row per distinct label combination present in the metadata, together
# with its number of samples; the row position serves as the combination index.
global_label_combs = (
    meta_df[label_cols]
    .groupby(label_cols)
    .size()
    .rename('samples')
    .reset_index()
)
global_label_combs.index.name = 'comb idx'
global_label_combs

In [None]:
dataset_variant = 'VT'
seen_seed = 0
num_split = 2000
# --------------------------------
meta_path = Path(f"{data_dir}/metadata_pickles")
random_state_path = Path(f"{data_dir}/np_random_state_pickles")
# meta_path = meta_path.expanduser()

# NOTE: unpickling can execute arbitrary code -- these files are assumed to be
# trusted, locally generated artifacts.
dict_data = dict()

# Load the metadata of every subset; the files are opened in a `with` block so
# the handles are closed right after reading.
for subset in ['train', 'valid', 'test']:
    metadata_full_filename = meta_path / f"metadata_{dataset_name}__{dataset_variant}_random__comp_seed_{num_split}__seen_seed_{seen_seed}__{subset}.pkl"
    with open(metadata_full_filename, 'rb') as metadata_file:
        dict_data[subset] = pickle.load(metadata_file)

# Restore the numpy global random state that was saved alongside this split
np_rnd_state_fname = random_state_path / f"np_random_state_{dataset_name}__{dataset_variant}_random__comp_seed_{num_split}__seen_seed_{seen_seed}.pkl"
with open(np_rnd_state_fname, 'rb') as random_state_file:
    np_seed_state = pickle.load(random_state_file)
np.random.set_state(np_seed_state)

# Build one dataset per phase; the 'val' phase reads the 'valid' metadata file
datasets = {}
for phase in ['train', 'val', 'test']:
    datasets[phase] = CompDataFromDict(dict_data[phase if phase!='val' else 'valid'],
        data_subset=f'{phase}_data', data_dir=data_dir)

## Preparing data

In [None]:
# Per-phase feature tensors (X) and combination-index targets (Y_comb)
X, Y_comb = {}, {}

for phase in ['train', 'val', 'test']:
    dataset = datasets[phase]

    # Attach to every sample the 'comb idx' of its label combination
    data_df = pd.DataFrame(dataset.data, columns=['filename', 'color', 'shape'])
    data_df = pd.merge(data_df, global_label_combs.reset_index(), on=label_cols)
    Y_comb[phase] = data_df['comb idx'].to_numpy().astype('int64')
    print(f"{phase} contains {len(data_df)} samples")

    # Stack the per-image features in the same row order as the targets
    features = dataset.activations
    X[phase] = torch.stack([features[filename] for filename in data_df['filename']], dim=0)