In [1]:
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from itertools import cycle
import matplotlib.patches as mpatches 
import random
import os

%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

from utils.Utilities import Utilities
from utils.EEGDataset import EEGDataset
from utils.Caltech101Dataset import Caltech101Dataset

Utilities_handler = Utilities()

In [2]:
from torchvision.models import resnet50, ResNet50_Weights

weights = ResNet50_Weights.DEFAULT
resnet50_model = resnet50(weights=weights)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Remove the final classification (softmax) layer
model = torch.nn.Sequential(*(list(resnet50_model.children())[:-1])) 
model.eval()
model = model.to(device)

cuda


In [3]:
SUBJECT =1
BATCH_SIZE = 1
learning_rate = 0.01
EPOCHS = 100
SaveModelOnEveryEPOCH = 100
EEG_DATASET_PATH = "./data/eeg/eeg_signals_raw_with_mean_std.pth"
# EEG_DATASET_PATH = "./data/eeg/eeg_55_95_std.pthh"

LSTM_INPUT_FEATURES = 2048 # should be image features output.
LSTM_HIDDEN_SIZE = 460  # should be same as sequence length

# custom_model_weights =  "./models/raw/FC__subject1_epoch_10.pth"
# custom_model_weights =  "./models/raw/FC__subject1_epoch_49.pth"
# custom_model_weights =  "./models/raw/50/subject1/FC__subject1_epoch_10.pth"
# custom_model_weights =  "./models/raw/mse50/FC__subject1_epoch_20.pth"
# custom_model_weights =  "./models/raw/FC__subject1_epoch_199.pth"

custom_model_weights =  "./logs/1/VIT_Head_finetuned_eeg_subject_1_epoch40.pth"

In [4]:
from utils.CustomModel import CustomModel
CustModel = CustomModel(input_size=(LSTM_INPUT_FEATURES),output_size=(LSTM_HIDDEN_SIZE*128))

if os.path.exists(custom_model_weights):
    CustModel = torch.load(custom_model_weights)
    print(f"loaded custom weights: {custom_model_weights}")

CustModel = CustModel.to(device)
CustModel.eval()

loaded custom weights: ./logs/1/VIT_Head_finetuned_eeg_subject_1_epoch40.pth


CustomModel(
  (fc): Sequential(
    (0): Linear(in_features=768, out_features=2000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2000, out_features=2000, bias=True)
    (3): ReLU()
    (4): Linear(in_features=2000, out_features=58880, bias=True)
  )
)

In [5]:
dataset = EEGDataset(subset="train",eeg_signals_path=EEG_DATASET_PATH, eeg_splits_path="./data/eeg/block_splits_by_image_all.pth", subject=1,preprocessin_fn=weights.transforms(), time_low=20, time_high=480)
test_dataset = Caltech101Dataset(images_path="./data/images/caltech/101_ObjectCategories",preprocessin_fn=weights.transforms())

100%|██████████| 9144/9144 [00:00<00:00, 380944.36it/s]


In [None]:
caltech_label_wise_data = {}

if os.path.exists("./preProcessedData/caltech101_preProcessed_eegPredicted_labelWise_49_epochs.pth"):
    caltech_label_wise_data = torch.load("./preProcessedData/caltech101_preProcessed_eegPredicted_labelWise_49_epochs.pth")
    print("loaded:./preProcessedData/caltech101_preProcessed_eegPredicted_labelWise_49_epochs.pth")
else:
    caltech_label_wise_data = Utilities_handler.load_data_label_wise(test_dataset,model=model,CustModel=CustModel,device=device,process_data_with_model=True)
# for data in test_dataset:
#     _, label, image,i = data
#     with torch.no_grad():
#         features = model(image.unsqueeze(0).to(device))
#         features = features.view(-1, features.size(1))
#         outputs = CustModel(features)
#     test_eeg = outputs.cpu().numpy() 
#     if not label["ClassId"] in caltech_label_wise_data:
#         caltech_label_wise_data[label["ClassId"]] = {"images":[], "eeg":[], "pred_eeg":[]}
#     caltech_label_wise_data[label["ClassId"]]["images"].append(image)
#     caltech_label_wise_data[label["ClassId"]]["pred_eeg"].append(test_eeg)

In [None]:
caltech_label_wise_data.keys()

In [None]:
# torch.save(caltech_label_wise_data, "./preProcessedData/caltech101_preProcessed_eegPredicted_labelWise_49_epochs.pth")

In [None]:
def prepareEEGData(labelWiseData, convert_to_numpy=True, flatten_eeg=True):
    eeg_features_ = []
    eeg_labels_ = []
    for label, labeData in labelWiseData.items():
        pred_eeg_fet = labeData["pred_eeg"]
        for idx,eeg in enumerate(pred_eeg_fet):
            eeg_features_.append(pred_eeg_fet[idx][0])
            eeg_labels_.append(label)
    if convert_to_numpy:
        eeg_features_  =np.array(eeg_features_, dtype=float)
    if flatten_eeg:
        eeg_features_ = eeg_features_.reshape(eeg_features_.shape[0], -1) 
    return eeg_features_, eeg_labels_


In [None]:
caltech_eeg_pred_features = []
caltech_eeg_labels = []
# for label, labeData in caltech_label_wise_data.items():
#     pred_eeg_fet = labeData["pred_eeg"]
#     for idx,eeg in enumerate(pred_eeg_fet):
#         caltech_eeg_pred_features.append(pred_eeg_fet[idx][0])
#         caltech_eeg_labels.append(label)
# caltech_eeg_pred_features  =np.array(caltech_eeg_pred_features, dtype=float)
# caltech_eeg_pred_features = caltech_eeg_pred_features.reshape(caltech_eeg_pred_features.shape[0], -1)
caltech_eeg_pred_features,caltech_eeg_labels = prepareEEGData(caltech_label_wise_data, convert_to_numpy=True, flatten_eeg=True)

In [None]:
# torch.save((caltech_eeg_pred_features,caltech_eeg_labels), "./preProcessedData/caltech101_preProcessed_eegfeatures_labels.pth")

In [None]:
caltech_eeg_pred_features.shape

In [None]:
X_tsne_flattned_preds_caltech = TSNE(n_components=3,perplexity=40, init="pca", learning_rate=0.1, n_iter=1000).fit_transform(caltech_eeg_pred_features)
# torch.save(X_tsne_flattned,"tsne_fltattned_raw_data_perplexity_20_init_random_lr_auto.pth")

In [None]:
plt.figure().set_size_inches(20,10)
plt.clf()

cmap = plt.cm.get_cmap("tab20c", int(1*len(caltech_label_wise_data.keys())))
cmap_pred = plt.cm.get_cmap("tab20c", int(1*len(caltech_label_wise_data.keys())))

gen_colors = []
handles = []
cmaps = []

gen_colors_pred = []
handles_pred = []
cmaps_pred = []

for eeg_label in list(caltech_label_wise_data.keys()):
    
    cmaps.append(cmap(eeg_label))
    cmaps_pred.append(cmap_pred(eeg_label))
    # if eeg_label==0 or eeg_label==15 or eeg_label==30:
    _patch = mpatches.Patch(color=cmap(eeg_label), label=f'Class {eeg_label}') 
    handles.append(_patch)
    _patch = mpatches.Patch(color=cmap_pred(eeg_label), label=f'Class {eeg_label} Pred') 
    handles_pred.append(_patch)


for i in range(caltech_eeg_pred_features.shape[0]):
    colorMap = cmaps[caltech_eeg_labels[i]]
    gen_colors.append(colorMap)

    colorMap = cmaps_pred[caltech_eeg_labels[i]]
    gen_colors_pred.append(colorMap)

In [None]:
plt.clf()
fig = plt.figure(figsize=(25, 25))
fig.set_size_inches(25,25)
ax = fig.add_subplot(111, projection="3d")
fig.add_axes(ax)

ax.set_title("EEG data")
# ax.view_init(azim=90, elev=1)
ax.view_init(azim=60, elev=5)
_ = ax.text2D(0.0, 1.0, s=f"n_samples={X_tsne_flattned_preds_caltech.shape[0]}", transform=ax.transAxes)

sel_channel = 97
ax.scatter(X_tsne_flattned_preds_caltech[:,0],X_tsne_flattned_preds_caltech[:,1],X_tsne_flattned_preds_caltech[:,2], c=gen_colors_pred, s=30, alpha=0.8)
ax.legend(handles=handles_pred, loc="best", fontsize=13,fancybox=True,ncol=10)
fig.savefig("Caltech_Predicted_EEG_Map.png",bbox_inches='tight')
plt.show()

In [None]:
label_wise_data = {}
model.eval()
CustModel.eval()
# label_wise_data = Utilities_handler.load_data_label_wise(dataset=dataset,model=model,CustModel=CustModel,device=device,process_data_with_model=True)

for data in dataset:
    eeg, label, image,i = data
    with torch.no_grad():
        features = model(image.unsqueeze(0).to(device))
        features = features.view(-1, features.size(1))
        outputs = CustModel(features)
    test_eeg = outputs.cpu().numpy() 

    if not label["ClassId"] in label_wise_data:
        label_wise_data[label["ClassId"]] = {"images":[], "eeg":[], "pred_eeg":[]}

    label_wise_data[label["ClassId"]]["images"].append(image)
    label_wise_data[label["ClassId"]]["pred_eeg"].append(test_eeg)
    label_wise_data[label["ClassId"]]["eeg"].append(eeg.numpy())

In [None]:
label_wise_data.keys()

In [None]:
eeg_features = []
eeg_pred_features = []
eeg_labels = []
for label, labeData in label_wise_data.items():
    eeg_fet = labeData["eeg"]
    pred_eeg_fet = labeData["pred_eeg"]
    for idx,eeg in enumerate(eeg_fet):
        # if not label==32 and not label==6:
            eeg_features.append(eeg)
            eeg_pred_features.append(pred_eeg_fet[idx][0])
            eeg_labels.append(label)
# for labeData in label_wise_data[0]["eeg"]:
#     eeg_features.append(labeData)
eeg_features  =np.array(eeg_features, dtype=float)
eeg_pred_features  =np.array(eeg_pred_features, dtype=float)

In [None]:
len(eeg_labels)

In [None]:
channelMap = Utilities_handler.read_channel_map(input_file="./channelmap.txt")
channel_names = list(channelMap.values())  # Add more channel names as needed
channel_types = ['eeg'] * len(channel_names)

In [None]:
eeg_features = eeg_features.reshape(eeg_features.shape[0], -1)
# noiseless_eeg = noiseless_eeg.reshape(noiseless_eeg.shape[0], -1)
# samples, time_length , channels  = denoised_ICA_EEG.shape
# denoised_ICA_EEG = denoised_ICA_EEG.reshape(samples, channels, time_length)
# denoised_ICA_EEG = denoised_ICA_EEG.reshape(samples, -1)

In [None]:
eeg_features.shape, eeg_pred_features.shape

In [None]:
selected_channels = []
for channe in range(128):
    if "O" in channelMap[channe+1]:
        # print(channe,channel_map[channe+1])
        selected_channels.append(channe)

# selected_channels.remove(27)  # 27 PO9
# selected_channels.remove(28)  # 28 O1
# selected_channels.remove(29)  # 29 Oz
# selected_channels.remove(30)  # 30 O2
# selected_channels.remove(31)  # 31 PO10
# selected_channels.remove(59)  # 59 PO7
# selected_channels.remove(60)  # 60 PO3
# selected_channels.remove(61)  # 61 POz
# selected_channels.remove(62)  # 62 PO4
# selected_channels.remove(63)  # 63 PO8
# selected_channels.remove(91)  # 91 POO9h
# selected_channels.remove(92)  # 92 POO1
# selected_channels.remove(93)  # 93 POO2
# selected_channels.remove(94)  # 94 POO10h
# selected_channels.remove(117) # 117 PPO9h
# selected_channels.remove(118) # 118 PPO5h
# selected_channels.remove(119) # 119 PPO1h
# selected_channels.remove(120) # 120 PPO2h
# selected_channels.remove(121) # 121 PPO6h
# selected_channels.remove(122) # 122 PPO10h
# selected_channels.remove(125) # 125 OI1h
# selected_channels.remove(126) # 126 OI2h

for chn in selected_channels:
    print(f"Channel : {chn}-{channelMap[chn+1]} will be displayed")

In [None]:
eeg_features.shape

In [None]:
X_tsne_flattned_imagenet = TSNE(n_components=3,perplexity=40, init="pca", learning_rate=0.1, n_iter=1000).fit_transform(eeg_features)
# torch.save(X_tsne_flattned,"tsne_fltattned_raw_data_perplexity_20_init_random_lr_auto.pth")

X_tsne_flattned_preds_imagenet = TSNE(n_components=3,perplexity=40, init="pca", learning_rate=0.1, n_iter=1000).fit_transform(eeg_pred_features)
# torch.save(X_tsne_flattned,"tsne_fltattned_raw_data_perplexity_20_init_random_lr_auto.pth")

# X_tsne_flattned_preds = TSNE(n_components=3,perplexity=40, init="pca", learning_rate=0.1, n_iter=500).fit_transform(caltech_eeg_pred_features)
# # torch.save(X_tsne_flattned,"tsne_fltattned_raw_data_perplexity_20_init_random_lr_auto.pth")

In [None]:
plt.figure().set_size_inches(20,10)
plt.clf()

cmap = plt.cm.get_cmap("Set1", int(1*len(label_wise_data.keys())))
cmap_pred = plt.cm.get_cmap("tab20c", int(1*len(label_wise_data.keys())))

gen_colors_imagenet = []
handles_imagenet = []
cmaps_imagenet = []

gen_colors_pred_imagenet = []
handles_pred_imagenet = []
cmaps_pred_imagenet = []

for eeg_label in list(label_wise_data.keys()):
    
    cmaps_imagenet.append(cmap(eeg_label))
    cmaps_pred_imagenet.append(cmap_pred(eeg_label))
    # if eeg_label==0 or eeg_label==15 or eeg_label==30:
    _patch = mpatches.Patch(color=cmap(eeg_label), label=f'Class {eeg_label}') 
    handles_imagenet.append(_patch)
    _patch = mpatches.Patch(color=cmap_pred(eeg_label), label=f'Class {eeg_label} Pred') 
    handles_pred_imagenet.append(_patch)


for i in range(eeg_features.shape[0]):
    colorMap = cmaps_imagenet[eeg_labels[i]]
    gen_colors_imagenet.append(colorMap)

    colorMap = cmaps_pred_imagenet[eeg_labels[i]]
    gen_colors_pred_imagenet.append(colorMap)

In [None]:
channel_cmaps = []
channel_cmap = plt.cm.get_cmap("hsv", 128)
for chn_range in range(128):
    channel_cmaps.append(channel_cmap(chn_range))

In [None]:
# X_tsne_flattned = TSNE(n_components=3,perplexity=40, init="pca", learning_rate=0.1, n_iter=1000).fit_transform(eeg_features)
# torch.save(X_tsne_flattned,"tsne_fltattned_raw_data_perplexity_20_init_random_lr_auto.pth")

# X_tsne_flattned_preds = TSNE(n_components=3,perplexity=40, init="pca",  learning_rate=0.1, n_iter=1000).fit_transform(eeg_pred_features)
# torch.save(X_tsne_flattned,"tsne_fltattned_raw_data_perplexity_20_init_random_lr_auto.pth")

plt.clf()
plt.figure(figsize=(10, 10))
plt.scatter(X_tsne_flattned_preds_imagenet[:,0], X_tsne_flattned_preds_imagenet[:,1], c=gen_colors_pred_imagenet, alpha=0.5)
# plt.scatter(X_tsne_flattned_imagenet[:,0], X_tsne_flattned_imagenet[:,1], c=gen_colors_imagenet, alpha=0.5)
# plt.plot([X_tsne_flattned_preds[:,0],X_tsne_flattned[:,0]],[X_tsne_flattned_preds[:,1],X_tsne_flattned[:,1]], c="blue", lw=0.5)
plt.legend(handles=handles_pred_imagenet, loc="lower right", fontsize=13, bbox_to_anchor=(2.2, 0.1),fancybox=True,ncol=5)
# plt.savefig("test_train_preds_fc.png")

In [None]:
# X_tsne_flattned = TSNE(n_components=3,perplexity=30, init="pca", learning_rate=0.1, n_iter=500, angle=0.5).fit_transform(eeg_features)
# torch.save(X_tsne_flattned,"tsne_fltattned_raw_data_perplexity_20_init_random_lr_auto.pth")

# X_tsne_flattned_preds = TSNE(n_components=3,perplexity=20, init="pca",  learning_rate=0.1, n_iter=1000).fit_transform(eeg_pred_features)
# torch.save(X_tsne_flattned,"tsne_fltattned_raw_data_perplexity_20_init_random_lr_auto.pth")

plt.clf()
fig = plt.figure(figsize=(20, 20))
fig.set_size_inches(20,20)
ax = fig.add_subplot(111, projection="3d")
fig.add_axes(ax)

ax.set_title("EEG data")
# ax.view_init(azim=90, elev=1)
ax.view_init(azim=60, elev=90)
_ = ax.text2D(0.0, 1.0, s="n_samples=1500", transform=ax.transAxes)


ax.scatter(X_tsne_flattned_preds_imagenet[:,0], X_tsne_flattned_preds_imagenet[:,1], X_tsne_flattned_preds_imagenet[:,2], c=gen_colors_pred_imagenet, s=30, alpha=0.8)
# ax.legend(handles=handles_pred_imagenet, loc="best", fontsize=13,fancybox=True,ncol=7)

# ax.scatter(X_tsne_flattned_imagenet[:,0], X_tsne_flattned_imagenet[:,1], X_tsne_flattned_imagenet[:,2], c=gen_colors_imagenet, s=30, alpha=0.8)
# ax.legend(handles=hanttned_imagenet[:,0], X_tsne_flattned_imagenet[:,1], X_tsne_flattned_imagenet[:,2], c=gen_colors_imagenet, s=30, alpha=0.8)
# ax.legend(handles=handles_imagenet, loc="best", fontsize=13,fancybox=True,ncol=7)

# ax.legend(handles=handles_imagenet+handles_pred_imagenet, loc="best", fontsize=13,fancybox=True,ncol=7)

# ax.scatter(X_tsne_flattned_preds_caltech[:,0],X_tsne_flattned_preds_caltech[:,1],X_tsne_flattned_preds_caltech[:,2], c=gen_colors_pred, s=30, alpha=0.8)
# ax.legend(handles=handles_pred, loc="best", fontsize=13,fancybox=True,ncol=10)

# ax.scatter(X_tsne_flattned_preds[:,0],X_tsne_flattned_preds[:,1],X_tsne_flattned_preds[:,2], c=gen_colors, s=30, alpha=0.8)
# ax.legend(handles=handles_imagenet+handles_pred_imagenet, loc="best", fontsize=13,fancybox=True,ncol=10)
ax.legend(handles=handles_pred_imagenet, loc="best", fontsize=13,fancybox=True,ncol=7)
# 
# fig.savefig("Imagenet_Map_Pred_EEGs_50epochs.png",bbox_inches='tight')
# fig.savefig("Imagenet_Map_Raw_EEGs_50epochs.png",bbox_inches='tight')
# fig.savefig("Imagenet_Map_Raw_and_Pred_EEGs_50epochs.png",bbox_inches='tight')
# fig.savefig("Imagenet_Map_Raw_EEGs_Caltech_Data_50epochs.png",bbox_inches='tight')


# fig.savefig("Imagenet_Map_Pred_EEGs_Caltech_Data_50epochs.png",bbox_inches='tight')

# fig.savefig("Imagenet_Map_Pred_RAW_EEGs_20epochs.png",bbox_inches='tight')
# fig.savefig("Imagenet_Map_Pred_EEGs_20epochs.png",bbox_inches='tight')

plt.show()

In [None]:
eeg_labels_np = np.array(eeg_labels)
print(eeg_labels_np.shape)

In [None]:
filtered = np.where(eeg_labels_np == 0)[0]

In [None]:
filtered.shape

In [None]:
np.take(tsne_channel_wise[0], filtered, 0).shape

In [None]:
# for label_idx, c_label in enumerate(list(set(eeg_labels))):
c_label = 0
for chn in range(128):
    fig = plt.figure(figsize=(20, 20))
    fig.set_size_inches(20,20)
    ax = fig.add_subplot(111, projection="3d")
    fig.add_axes(ax)

    ax.set_title("EEG data")
    ax.view_init(azim=-50, elev=50)
    _ = ax.text2D(0.8, 0.05, s="n_samples=1500", transform=ax.transAxes)

    handles.append(mpatches.Patch(color=channel_cmaps[int(random.random()*128)], label=f'Channel {chn}') )
    plt.legend(handles=handles, loc="lower right", fontsize=10,ncol=4, bbox_to_anchor=(1.0, 0.0),fancybox=True) 

    # filtered = np.where(eeg_labels_np == c_label)[0]
    # filtered_tsne = np.take(tsne_channel_wise[chn], filtered, 0)
    # ax.scatter(filtered_tsne[:,0], filtered_tsne[:,1], c=[cmaps[c_label] for i in range(filtered_tsne.shape[0])], s=50, alpha=0.8)
    ax.scatter(tsne_channel_wise[chn][:,0], tsne_channel_wise[chn][:,1],tsne_channel_wise[chn][:,2], c=gen_colors, s=50, alpha=0.8)
    
    fig.savefig(f"output/{c_label}_raw_eeg_channel_{chn}.png")
    fig.clf()

    del handles[-1]

In [None]:
# # EEG flattened features fit.
# X_tsne = TSNE(n_components=3,perplexity=50, init="pca", learning_rate='auto').fit_transform(eeg_features)

In [None]:
len(handles)

In [None]:
X_tsne.shape

In [None]:
# plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=gen_colors,cmap=plt.get_cmap("Spectral"),alpha=.4,edgecolor='k',projection="3d")

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection="3d")
fig.add_axes(ax)
ax.scatter(
    X_tsne[:, 0], X_tsne[:, 1], c=gen_colors, s=50, alpha=0.8
)
ax.set_title("EEG data")
ax.view_init(azim=-66, elev=12)
_ = ax.text2D(0.8, 0.05, s="n_samples=1500", transform=ax.transAxes)


In [None]:
X_tsne.shape

In [None]:
X_tsne.shape

In [None]:
SELECTED_SAMPLE = 0
SELECTED_CLASS = 13
SELECTED_CHANNEL = 1
TIME_SERIES_START = 0
TIEM_SERIES_END = 500

SELECTED_CLASSES = [2,13,19,25,30]


colorMapString = {}

for SELECTED_CLASS in SELECTED_CLASSES:

    label0data = np.array(label_wise_data[SELECTED_CLASS]["eeg"])
    cmap = plt.cm.get_cmap("tab20", label0data.shape[0])

    for sample_num in range(10):

        plt.figure().set_size_inches(20,10)
        plt.clf()
        f, axarr = plt.subplots(1,2)
        f.set_size_inches(20,10)

        handles = []
        
        SELECTED_SAMPLE = sample_num

        image = label_wise_data[SELECTED_CLASS]["images"][SELECTED_SAMPLE]
        for channe_num in selected_channels:
            rand_int = random.randint(0, label0data.shape[0])
            colorMap = cmap(rand_int)
            if not channel_map[channe_num+1] in colorMapString:
                colorMapString[channel_map[channe_num+1]] = colorMap

            colorMap = colorMapString[channel_map[channe_num+1]]

            axarr[1].plot(label0data[SELECTED_SAMPLE][TIME_SERIES_START:TIEM_SERIES_END,channe_num], c=colorMap)
            max_of_channel  =np.max(label0data[SELECTED_SAMPLE][TIME_SERIES_START:TIEM_SERIES_END,channe_num])
            max_of_channel = str(np.round(max_of_channel,2))
            min_of_channel  = np.min(label0data[SELECTED_SAMPLE][TIME_SERIES_START:TIEM_SERIES_END,channe_num])
            min_of_channel = str(np.round(min_of_channel,2))
            _channel_patch = mpatches.Patch(color=colorMap, label=f'Channel: {channe_num}:{channel_map[channe_num+1]} [{min_of_channel}:{max_of_channel}]')
            handles.append(_channel_patch)

        rand_int = random.randint(1, label0data.shape[0])
        _sample_patch = mpatches.Patch(color=cmap(rand_int), label=f'Sample: {SELECTED_SAMPLE}')
        handles.append(_sample_patch)

        if not f"class_{SELECTED_CLASS}" in colorMapString:
            rand_int = random.randint(1, label0data.shape[0])
            colorMap = cmap(rand_int)
            colorMapString[f"class_{SELECTED_CLASS}"] = colorMap

        _class_patch = mpatches.Patch(color=colorMapString[f"class_{SELECTED_CLASS}"], label=f"CLASS: {SELECTED_CLASS}")
        handles.append(_class_patch)

        if not "TS" in colorMapString:
            rand_int = random.randint(1, label0data.shape[0])
            colorMap = cmap(rand_int)
            colorMapString["TS"] = colorMap

        TS_patch = mpatches.Patch(color=colorMapString["TS"],label=f"T: {TIME_SERIES_START}:{TIEM_SERIES_END}") 
        handles.append(TS_patch)

        axarr[1].legend(handles=handles, loc="lower right", fontsize=13, bbox_to_anchor=(1.0, 0.0),fancybox=True,ncol=1)
        axarr[0].imshow(image)
        f.savefig(f"./output/Class_{SELECTED_CLASS}_Sample_{SELECTED_SAMPLE}_Channels_{selected_channels}.png")