In [25]:
%matplotlib inline
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import numpy as np
from  scipy.special import softmax as sm
from tqdm import tqdm

In [26]:
# Global Vars
CLASS_NO = 2
BIMODAL_CONCEPT_NO = 40
MODE_A_CONCEPT_NO = 300
MODE_A_CONCEPT_DIM_OG = 30
MODE_B_CONCEPT_NO = 300
MODE_B_CONCEPT_DIM_OG = 30
DATASET_SIZE = 100000

In [27]:
# Conditional Distributions [Can be scaled arbitrarily]
class_dist = sm(np.random.normal(size=(CLASS_NO)))
class_to_bimodal = sm(np.random.normal(size=(CLASS_NO, BIMODAL_CONCEPT_NO)),axis=1)
bimodal_to_mode_a = sm(np.random.normal(size=(BIMODAL_CONCEPT_NO, MODE_A_CONCEPT_NO)),axis=1)
bimodal_to_mode_b = sm(np.random.normal(size=(BIMODAL_CONCEPT_NO, MODE_B_CONCEPT_NO)),axis=1)
class_dist

array([0.49347702, 0.50652298])

In [28]:
# Generate labels
TRUE_LABELS = np.random.choice(CLASS_NO, DATASET_SIZE, p=class_dist)
# Generate bimodal concepts from labels
multimodal_concepts = np.zeros(TRUE_LABELS.shape)
modality_a_concept = np.zeros(TRUE_LABELS.shape)
modality_b_concept = np.zeros(TRUE_LABELS.shape)
for idx,label in tqdm(enumerate(TRUE_LABELS)):
  multimodal_concepts[idx] = np.random.choice(BIMODAL_CONCEPT_NO, 1, p=class_to_bimodal[label,:])
  mc = multimodal_concepts[idx]
  modality_a_concept[idx] = np.random.choice(MODE_A_CONCEPT_NO, 1, p=bimodal_to_mode_a[int(mc), :])
  modality_b_concept[idx] = np.random.choice(MODE_B_CONCEPT_NO, 1, p=bimodal_to_mode_b[int(mc), :])

# Here are the "true" concepts per each modality:
CONCEPT_DATA = np.stack([modality_a_concept, modality_b_concept], axis=1)

100000it [00:11, 9013.72it/s]


In [29]:
print(CONCEPT_DATA.shape)

(100000, 2)


In [30]:
# Generate random values for each concept's mean and variance:
MODALITY_A_CONCEPT_MEANS = np.random.uniform(-1.0,1.0,(MODE_A_CONCEPT_NO, MODE_A_CONCEPT_DIM_OG))
MODALITY_A_CONCEPT_COV = np.random.uniform(0.1,0.5,(MODE_A_CONCEPT_NO, MODE_A_CONCEPT_DIM_OG, MODE_A_CONCEPT_DIM_OG))
MODALITY_B_CONCEPT_MEANS = np.random.uniform(-1.0,1.0,(MODE_B_CONCEPT_NO, MODE_B_CONCEPT_DIM_OG))
MODALITY_B_CONCEPT_COV = np.random.uniform(0.1,0.5,(MODE_B_CONCEPT_NO, MODE_B_CONCEPT_DIM_OG, MODE_B_CONCEPT_DIM_OG))

modality_a_data = np.zeros((DATASET_SIZE, MODE_A_CONCEPT_DIM_OG))
for idx, concept in tqdm(enumerate(modality_a_concept)):
  concept_mean = MODALITY_A_CONCEPT_MEANS[int(concept), :]
  concept_std = MODALITY_A_CONCEPT_COV[int(concept), :, :]
  modality_a_data[idx,:]= np.random.multivariate_normal(concept_mean, concept_std)

modality_b_data = np.zeros((DATASET_SIZE, MODE_B_CONCEPT_DIM_OG))
for idx, concept in tqdm(enumerate(modality_a_concept)):
  concept_mean = MODALITY_B_CONCEPT_MEANS[int(concept), :]
  concept_std = MODALITY_B_CONCEPT_COV[int(concept), :, :]
  modality_b_data[idx,:]= np.random.multivariate_normal(concept_mean, concept_std)

X_a = modality_a_data
X_b = modality_b_data
X_comb = np.concatenate([modality_a_data,modality_b_data], axis=1)
X_a.shape, X_b.shape, X_comb.shape

  modality_a_data[idx,:]= np.random.multivariate_normal(concept_mean, concept_std)
100000it [00:53, 1883.07it/s]
  modality_b_data[idx,:]= np.random.multivariate_normal(concept_mean, concept_std)
100000it [00:53, 1872.43it/s]


((100000, 30), (100000, 30), (100000, 60))

In [31]:
import pickle

data = dict()
data['a'] = X_a
data['b'] = X_b
data['label'] = TRUE_LABELS

with open("data.pickle", "wb") as f:
    try:
        pickle.dump(data, f)
    except Exception as ex:
        print("Error during pickling object", ex)
f.close()

In [32]:
try:
    with open("data.pickle", "rb") as f:
        data = pickle.load(f)
        print(data.keys())
    f.close()
except Exception as ex:
    print("Error during unpickling object", ex)

dict_keys(['a', 'b', 'label'])


In [33]:
import pickle

data = dict()
data['concept'] = CONCEPT_DATA
data['label'] = TRUE_LABELS

with open("concept.pickle", "wb") as f:
    try:
        pickle.dump(data, f)
    except Exception as ex:
        print("Error during pickling object", ex)
f.close()

In [34]:
try:
    with open("concept.pickle", "rb") as f:
        concept = pickle.load(f)
        print(concept.keys())
    f.close()
except Exception as ex:
    print("Error during unpickling object", ex)

dict_keys(['concept', 'label'])


In [35]:
assert((concept['label'] == data['label']).all())