In [253]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.ticker import MaxNLocator
from matplotlib.colors import ListedColormap
from itertools import combinations,product
# Route all figure text (titles, axis labels, legend entries) through LaTeX.
# NOTE(review): this needs a working TeX installation on the machine, and strings
# containing LaTeX special characters (e.g. `_`) must be escaped - confirm for new datasets.
plt.rcParams['text.usetex'] = True
import sys

In [254]:
# --- Locations (relative paths) ---------------------------------------------
EMBED_DIR = "../../Experiments/dimensionality_reduction_metrics/embeddings/"
DATA_DIR = "../../Experiments/dimensionality_reduction_metrics/normalized_data/"
LABEL_DIR = '../../Experiments/dimensionality_reduction_metrics/datasets/'
SAVE_DIR = './results/plots'

# --- Run selection ----------------------------------------------------------
DATASET = 'Reuters30k'
RANDOM_SEED = 42

# Fraction of rows kept when sub-sampling the embedding before plotting.
if DATASET in ['geneRNASeq']:
    SAMPLE_PERCENTAGE = 1
elif DATASET in ['hatespeech']:
    SAMPLE_PERCENTAGE = 0.8
else:
    SAMPLE_PERCENTAGE = 0.02

# Matplotlib colormap name used for the class colours of the chosen dataset.
if DATASET in ['Cifar10', 'FMnist', 'geneRNASeq', 'hatespeech']:
    COLORMAP = 'tab10'
else:
    COLORMAP = 'magma'

In [255]:
def load_data(dataset:str, EMBED_DIR:str, k1:int, k2:int, max_iter:int=100):
    """Load a saved embedding array from disk.

    The file read is ``<EMBED_DIR>/<dataset>/<k1>_<k2>_<max_iter>.npy``.

    Args:
        dataset: Name of the dataset sub-directory.
        EMBED_DIR: Root directory that holds one sub-directory per dataset.
        k1: First number in the file name.
        k2: Second number in the file name.
        max_iter: Third number in the file name (defaults to 100).

    Returns:
        Whatever ``np.load`` returns for that file.
    """
    file_name = f'{k1}_{k2}_{max_iter}.npy'
    embedding_path = os.path.join(EMBED_DIR, dataset, file_name)
    return np.load(embedding_path)

In [256]:
def convert_to_numerical(labels:np.ndarray)->np.ndarray:
    """Collapse a 2-D label matrix into one integer per row.

    For every row the position of the largest entry along axis 1 is taken and
    shifted by one, so the returned values start at 1 rather than 0.

    Args:
        labels: 2-D array with one row per sample.

    Returns:
        1-D integer array with one value per row of ``labels``.
    """
    winning_column = np.argmax(labels, axis=1)
    return winning_column + 1

In [257]:
def load_labels(dataset:str, LABEL_DIR:str):
    """Load the label array stored at ``<LABEL_DIR>/<dataset>/y.npy``.

    For the 'Reuters30k' dataset the loaded array is additionally passed
    through ``convert_to_numerical``; every other dataset is returned as
    loaded.

    Args:
        dataset: Name of the dataset sub-directory.
        LABEL_DIR: Root directory that holds one sub-directory per dataset.

    Returns:
        The label array for ``dataset``.
    """
    raw_labels = np.load(os.path.join(LABEL_DIR, dataset, 'y.npy'))
    if dataset in ['Reuters30k']:
        return convert_to_numerical(raw_labels)
    return raw_labels

In [258]:
# Embedding for DATASET saved under k1=2, k2=8 (max_iter left at its default).
data = load_data(dataset=DATASET, EMBED_DIR=EMBED_DIR, k1=2, k2=8)

In [259]:
# Labels for DATASET; later indexed with the same row indices as `data`.
labels = load_labels(dataset=DATASET, LABEL_DIR=LABEL_DIR)

In [260]:
# Scratch check: the unordered index pairs combinations() yields for three axes.
print([pair for pair in combinations(range(3), 2)])

[(0, 1), (0, 2), (1, 2)]


In [261]:
# Create the plot output directory (same location as SAVE_DIR), including any
# missing parent such as ./results. exist_ok=True makes re-running the cell a
# no-op and avoids the check-then-create race of exists() + mkdir().
os.makedirs('./results/plots', exist_ok=True)

In [262]:
def plot_data(x:np.ndarray, y:np.ndarray, labels:np.ndarray, title:str, xlabel:str, ylabel:str, save_path:str, alpha:float ,colormap:str, markersize:int):
    """Scatter-plot two coordinate arrays coloured by class label and save the figure.

    Each distinct value in ``labels`` gets one discrete colour sampled from
    ``colormap``. The scatter points, the colorbar and the legend all use the
    same label -> colour assignment, regardless of whether the label values
    are 0-based, 1-based or non-contiguous.

    Args:
        x: Values plotted on the horizontal axis.
        y: Values plotted on the vertical axis.
        labels: One class label per point; determines the point colour.
        title: Figure title.
        xlabel: Label of the horizontal axis.
        ylabel: Label of the vertical axis.
        save_path: File the figure is written to.
        alpha: Opacity passed to ``scatter`` (``None`` uses matplotlib's default).
        colormap: Name of the matplotlib colormap the class colours are drawn from.
        markersize: Size of the markers shown in the legend (the scatter
            points themselves keep matplotlib's default size).

    Returns:
        None. The figure is saved to ``save_path`` and then closed.
    """
    # Translate raw label values into positions 0..N-1 so that colours are
    # assigned by class *index* rather than by the numeric label value.
    unique_labels, label_index = np.unique(labels, return_inverse=True)
    label_index = np.asarray(label_index).reshape(-1)
    num_classes = len(unique_labels)

    # Integer input to a colormap indexes its lookup table directly, which
    # yields exactly one distinct colour per class.
    cmap = plt.get_cmap(colormap, num_classes)
    label_cmap = ListedColormap(cmap(np.arange(num_classes)))

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # vmin/vmax put every class index i at the centre of colour bin i, so
        # colorbar ticks at 0..N-1 line up with their tick labels.
        points = ax.scatter(
            x, y,
            c=label_index,
            cmap=label_cmap,
            vmin=-0.5,
            vmax=num_classes - 0.5,
            alpha=alpha,
        )
        cbar = fig.colorbar(points, ax=ax)
        cbar.set_ticks(np.arange(num_classes))
        cbar.set_ticklabels(unique_labels)

        handles = [
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=label_cmap(i), markersize=markersize)
            for i in range(num_classes)
        ]
        labels_legend = [f'Label {label}' for label in unique_labels]
        ax.legend(handles, labels_legend, title="Labels")

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if drawing or saving failed, so that
        # looping over many axis pairs does not accumulate open figures.
        plt.close(fig)


In [263]:
def make_plots(data:np.ndarray, labels:np.ndarray, k1:int, k2:int, dataset:str, alpha:float=None, colormap:str='viridis', markersize:int=6):
    """Save pairwise scatter plots of the columns of ``data`` under ``SAVE_DIR/<dataset>``.

    Three groups of column pairs are plotted, each into its own sub-directory:

    * ``k1``: every pair of columns within ``[0, k1)``
    * ``k2``: every pair of columns within ``[k1, k2)``
    * ``cross_combination_plots``: every (column in ``[0, k1)``, column in
      ``[k1, k2)``) pair

    Args:
        data: 2-D array; columns ``0 .. k2-1`` are indexed.
        labels: One label per row of ``data``, forwarded to ``plot_data``.
        k1: Number of columns in the first group.
        k2: Exclusive upper column index of the second group.
        dataset: Dataset name; used for the output directory and plot titles.
        alpha: Point opacity forwarded to ``plot_data``.
        colormap: Colormap name forwarded to ``plot_data``.
        markersize: Legend marker size forwarded to ``plot_data``.

    Returns:
        None. One PNG per column pair is written to disk.
    """
    parent_dir = os.path.join(SAVE_DIR, dataset)

    # (sub-directory name, column pairs to plot) for each group, in output order.
    plot_groups = (
        ('k1', combinations(range(k1), 2)),
        ('k2', combinations(range(k1, k2), 2)),
        ('cross_combination_plots', product(range(k1), range(k1, k2))),
    )

    for subdir_name, axis_pairs in plot_groups:
        out_dir = os.path.join(parent_dir, subdir_name)
        # makedirs also creates SAVE_DIR / parent_dir when they are missing and
        # is a no-op when the directory already exists.
        os.makedirs(out_dir, exist_ok=True)
        for first, second in axis_pairs:
            # NOTE: the two indices are concatenated without a separator, so
            # file names are only unambiguous while both indices are < 10.
            fig_path = os.path.join(out_dir, f'Axes{first}{second}.png')
            plot_data(
                data[:, first],
                data[:, second],
                labels,
                f'{dataset} Axes: {first}, {second}',
                f'Axes {first}',
                f'Axes {second}',
                fig_path,
                alpha,
                colormap,
                markersize,
            )


In [264]:
# Draw a reproducible random subset of rows (no repeats) and apply the same
# row selection to the embedding and to its labels so they stay aligned.
sample_percentage = SAMPLE_PERCENTAGE
np.random.seed(RANDOM_SEED)
random_indices = np.random.choice(
    data.shape[0],
    int(data.shape[0] * sample_percentage),
    replace=False,
)
data_sample, labels_sample = data[random_indices], labels[random_indices]

In [265]:
# Render every axis-pair plot for the sampled rows (k1=2, k2=8).
make_plots(
    data_sample,
    labels_sample,
    k1=2,
    k2=8,
    dataset=DATASET,
    alpha=None,
    colormap=COLORMAP,
    markersize=3,
)