# Imports

In [None]:
import sys
import os

# Pin this notebook to GPU 1. It is set before `torch` is imported in the
# next cell so that CUDA initialisation picks it up.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Make the project's packages importable from this notebook's directory:
#   ../..        -> 'protonet_mnist_add_utils' folder
#   ../../..     -> 'data' folder
#   ../../../..  -> 'models' and 'datasets' folders
sys.path.extend(
    os.path.abspath(relative_dir)
    for relative_dir in ("../..", "../../..", "../../../..")
)


print(sys.path)

In [None]:
import torch
import argparse
import itertools
import matplotlib.pyplot as plt

from argparse import Namespace
from datasets import get_dataset
from datasets.utils.base_dataset import BaseDataset

# Data Loading

In [None]:
# Configuration handed to the project's `get_dataset` factory to load the
# Kandinsky 'patterns' task.
args_protonet = Namespace(**{
    'dataset': 'kandinsky',
    'batch_size': 32,
    'preprocess': 0,
    # c_sup=1: presumably enables loading of concept supervision, used here to
    # simulate direct annotation of the prototypes — confirm in `datasets`.
    'c_sup': 1,
    'which_c': [-1],
    'model': 'kandsl',
    'task': 'patterns',
})

kand_dataset = get_dataset(args_protonet)
(kand_train_loader,
 kand_val_loader,
 kand_test_loader) = kand_dataset.get_data_loaders()
print(kand_dataset)

# Select Figures Whose Object Triplets Match the Target Concept Combinations

In [None]:


# Target combinations, each a frozenset of (shape, color) pairs, mapped to a
# human-readable title. The titles imply the integer encoding:
#   shape: 0=Square, 1=Circle, 2=Triangle
#   color: 0=Red,    1=Yellow, 2=Blue
# Each title lists its objects in shape order (Square, Circle, Triangle).
target_combos = {
    frozenset(pairs): ", ".join(
        f"{('Red', 'Yellow', 'Blue')[color]} {('Square', 'Circle', 'Triangle')[shape]}"
        for shape, color in sorted(pairs)
    )
    for pairs in (
        ((0, 0), (1, 2), (2, 1)),
        ((0, 1), (1, 0), (2, 2)),
        ((0, 2), (1, 1), (2, 0)),
    )
}

def get_specific_figures_with_labels(dataset, combos=None, fig_width=64):
    """Find one example figure for each target (shape, color) combination.

    Scans the training loader of ``dataset`` in order. For every sample ``b``
    and figure position ``k`` it builds the set of ``(shape, color)`` pairs
    from ``concepts[b, k]``: the first three entries are taken as shapes and
    the following entries as colors, paired position-wise. The first figure
    whose pair set equals a key of ``combos`` is recorded. Scanning stops as
    soon as every combination has been found.

    NOTE(review): if the train loader shuffles, the chosen exemplars can
    differ between runs — confirm whether determinism is required.

    Args:
        dataset: Object exposing ``get_data_loaders()`` that returns
            ``(train, val, test)`` loaders. The train loader must yield
            ``(images, labels, concepts)`` batches.
        combos: Mapping whose keys are frozensets of ``(shape, color)``
            pairs. Defaults to the module-level ``target_combos``.
        fig_width: Width in pixels of each figure panel in the horizontally
            concatenated image. Panel ``k`` spans columns
            ``[k * fig_width, (k + 1) * fig_width)``; the last panel extends
            to the right edge of the image. The default of 64 reproduces the
            original fixed 0:64 / 64:128 / 128: cuts.

    Returns:
        Three dicts keyed like ``combos``, each holding ``None`` for
        combinations that were not found:
        - the cropped figure tensor;
        - the ``(shape_list, color_list)`` labels;
        - the ``(batch_idx, sample_idx, figure_idx)`` location.
    """
    if combos is None:
        combos = target_combos

    train_loader, _, _ = dataset.get_data_loaders()
    collected = {combo: None for combo in combos}
    collected_labels = {combo: None for combo in combos}
    collected_indices = {combo: None for combo in combos}

    for batch_idx, (images, _, concepts) in enumerate(train_loader):
        n_figs = concepts.shape[1]
        for b in range(images.shape[0]):
            for k in range(n_figs):
                shape_triplet = concepts[b, k, :3].tolist()
                color_triplet = concepts[b, k, 3:].tolist()
                pairs = frozenset(
                    (shape_triplet[i], color_triplet[i]) for i in range(3)
                )

                if pairs in collected and collected[pairs] is None:
                    # Crop only the matching panel; the last panel runs to
                    # the image edge, as in the original 128: slice.
                    start = k * fig_width
                    stop = None if k == n_figs - 1 else start + fig_width
                    collected[pairs] = images[b][:, :, start:stop]
                    collected_labels[pairs] = (shape_triplet, color_triplet)
                    collected_indices[pairs] = (batch_idx, b, k)

            if all(value is not None for value in collected.values()):
                return collected, collected_labels, collected_indices

    return collected, collected_labels, collected_indices

# Get figures and their labels
collected_specific, collected_labels, collected_indices = get_specific_figures_with_labels(kand_dataset)
print(collected_indices)

# Gather the found figures in `target_combos` order, together with their
# titles and a flat label vector (shape ids followed by color ids).
collected_imgs = []
collected_titles = []
collected_info = []

for combo, title in target_combos.items():
    img = collected_specific.get(combo)
    if img is None:
        # Not every combination is guaranteed to appear in the training split.
        print(f"Combination not found: {title}")
        continue
    shape_ids, color_ids = collected_labels[combo]
    collected_imgs.append(img.unsqueeze(0))
    collected_titles.append(title)
    collected_info.append(shape_ids + color_ids)

collected_info = torch.tensor(collected_info)
if collected_imgs:
    collected_imgs = torch.cat(collected_imgs, dim=0)

    # Size the grid by what was actually found. The previous fixed range(3)
    # raised IndexError when a combination was missing. squeeze=False keeps
    # `axs` 2-D even for a single panel.
    n_found = len(collected_titles)
    fig, axs = plt.subplots(1, n_found, figsize=(4 * n_found, 6), squeeze=False)
    for ax, img, title, info in zip(axs[0], collected_imgs, collected_titles, collected_info):
        # assumes images are channel-first (C, H, W); imshow expects (H, W, C)
        img_np = img.permute(1, 2, 0).cpu().numpy()
        ax.imshow(img_np)
        ax.axis('off')
        ax.set_title(f"{title}\n{info}", fontsize=10)
    plt.tight_layout()
    plt.show()
else:
    print("No matching figures found in the dataset.")


# Saving

In [None]:
# When no target combination was found, the selection cell leaves
# `collected_imgs` as an empty list. Fail with a clear message instead of an
# AttributeError on `.shape`.
if not isinstance(collected_imgs, torch.Tensor):
    raise RuntimeError("No prototype figures were collected; nothing to save.")
if collected_imgs.shape[0] != collected_info.shape[0]:
    raise ValueError(
        f"Image/label count mismatch: {collected_imgs.shape[0]} images vs "
        f"{collected_info.shape[0]} label rows."
    )

print(collected_imgs.shape)
print(collected_info.shape)

prototype_dir = 'data/kand_prototypes'
# NOTE(review): nothing in this cell writes into 'pnet'; creating it also
# creates `prototype_dir`. The path is relative to the notebook's working
# directory, while the setup cell adds '../../..' for the 'data' folder.
# Confirm the intended save location.
os.makedirs(os.path.join(prototype_dir, 'pnet'), exist_ok=True)

torch.save(collected_imgs, os.path.join(prototype_dir, 'concepts_init_aggregated.pt'))
torch.save(collected_info, os.path.join(prototype_dir, 'labels_init_aggregated.pt'))