# Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
from matplotlib.colors import LogNorm, Normalize
import torch
from torch import nn
from torchvision import models
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import confusion_matrix, brier_score_loss
import glob
import os
from tqdm import tqdm
import pandas as pd
from skimage.transform import resize
import umap.umap_ as umap # Recommended way to import UMAP
from matplotlib.colors import LinearSegmentedColormap
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from SetRandomSeed import set_random_seeds, GeneratorSeed
from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier

In [None]:
# Show up to 40 columns whenever a DataFrame is displayed.
pd.options.display.max_columns = 40

# Seed the project's random-number helpers (see SetRandomSeed) and create the
# generator object that is handed to the DataLoaders further down.
set_random_seeds(626)
g = GeneratorSeed(626)

In [None]:
# Figure dimensions in inches for a single AASTeX column (3.25 in wide).
# The 2.5 in height is only a suggestion and may vary from plot to plot.
width, height = 3.25, 2.5
matplotlib.rcParams.update({"font.size": "10"})


# Training Curves

# Read In Network

In [None]:
CNNName = 'Adam_Cyclic'

# Choose the inference device first so the model and every batch end up on
# the same device (batches are moved to `device` in the inference loop).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the saved checkpoint onto the CPU; the weights are moved to `device`
# together with the model below.
checkpoint = torch.load('ResNet_' + CNNName + '.pth', map_location=torch.device('cpu'))
print("Best model was from epoch:", checkpoint['epoch'])

# Hyperparameters required to rebuild the network with the same structure as
# during training (they are not used for optimisation in this notebook).
dropout_rate = 0.2
learning_rate = 1e-5

# NOTE: the original cell first built a torchvision `models.resnet18(weights=True)`
# that was overwritten immediately; that unused construction has been removed.
model = FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-resnet18', learning_rate=learning_rate,  # use a low learning rate
    layer_decay=0.5,  # reduce the learning rate from lr to lr^0.5 for each block deeper in the network
    # arguments specific to FinetuneableZoobotClassifier
    num_classes=2
)

# Restore the trained weights, place the model on the inference device and
# switch to evaluation mode.
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

In [None]:
from BinaryMergerDataset import BinaryMergerDataset, get_transforms
# Root location passed to BinaryMergerDataset below. The value was removed for
# anonymity and must be filled in before this cell can run.
path = ### input path here! removed to be anonymous
# Batch size used by every DataLoader created in this notebook.
BATCH_SIZE = 64

Now use the custom dataset class to load the images, ensuring that the training set includes data augmentation while the validation and test sets do not.

In [None]:
def _load_split(split, aug):
    """Return the (mergers, non-mergers) datasets for one split.

    `split` is the split name handed to BinaryMergerDataset and `aug` selects
    the augmenting or the plain transform from get_transforms.
    """
    return tuple(
        BinaryMergerDataset(path, split, mergers=is_merger,
                            transform=get_transforms(aug=aug), codetest=False)
        for is_merger in (True, False)
    )


# --- Test set: no augmentation -------------------------------------------
test_mergers_dataset_orig, test_nonmergers_dataset_orig = _load_split('test', aug=False)
test_dataset_full = torch.utils.data.ConcatDataset(
    [test_mergers_dataset_orig, test_nonmergers_dataset_orig]
)
# Mix the two classes once with a fixed permutation; the loader itself does
# not shuffle, so the order stays the same on every pass.
indices = np.random.permutation(len(test_dataset_full))
shuffled_test_dataset = Subset(test_dataset_full, indices)
test_dataloader = DataLoader(shuffled_test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=0, generator=g)

# --- Training set: augmented copies plus the original images --------------
train_mergers_dataset_augment, train_nonmergers_dataset_augment = _load_split('train', aug=True)
train_mergers_dataset_orig, train_nonmergers_dataset_orig = _load_split('train', aug=False)
train_dataset_full = torch.utils.data.ConcatDataset([
    train_mergers_dataset_augment,
    train_nonmergers_dataset_augment,
    train_mergers_dataset_orig,
    train_nonmergers_dataset_orig,
])
train_dataloader = DataLoader(train_dataset_full, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=0, generator=g)

# --- Validation set: no augmentation --------------------------------------
validation_mergers_dataset_orig, validation_nonmergers_dataset_orig = _load_split('validation', aug=False)
validation_dataset_full = torch.utils.data.ConcatDataset(
    [validation_mergers_dataset_orig, validation_nonmergers_dataset_orig]
)
validation_dataloader = DataLoader(validation_dataset_full, batch_size=BATCH_SIZE,
                                   shuffle=False, num_workers=0, generator=g)



In [None]:
# Image counts, one per line: training mergers and non-mergers (original plus
# augmented), then validation mergers/non-mergers, then test mergers/non-mergers.
split_sizes = [
    len(train_mergers_dataset_orig) + len(train_mergers_dataset_augment),
    len(train_nonmergers_dataset_orig) + len(train_nonmergers_dataset_augment),
    len(validation_mergers_dataset_orig),
    len(validation_nonmergers_dataset_orig),
    len(test_mergers_dataset_orig),
    len(test_nonmergers_dataset_orig),
]
for size in split_sizes:
    print(size)

In [None]:
#define accuracy and a confusion matrix
def get_accuracy(pred, original):
    """Return the percentage of predictions that equal the true labels.

    Parameters
    ----------
    pred : array-like
        Predicted labels.
    original : array-like
        True labels, same length as `pred`.

    Returns
    -------
    float
        Element-wise agreement between the two inputs, in percent (0-100).
    """
    # Convert first: comparing two plain lists with `==` yields a single bool
    # (whole-list equality), which would make the result either 0 or 100.
    pred = np.asarray(pred)
    original = np.asarray(original)
    return np.mean(pred == original) * 100

def plot_confusion_matrix(cm, classes, epoch): #help from chat GPT
    """Draw `cm` as an annotated square heat map, save it, and show it.

    `cm` holds the values to annotate (the colour scale is fixed to 0-100 and
    labelled "Percentage"); `classes` supplies the tick labels on both axes.
    `epoch` is accepted but not used: the output file name is fixed to the
    test set and the global `CNNName`.
    """
    heatmap_options = dict(
        annot=True,
        fmt=".2f",
        cmap='Purples',
        xticklabels=classes,
        yticklabels=classes,
        vmin=0,
        vmax=100,
        square=True,
        cbar_kws={'label': 'Percentage', "shrink": 0.65},
    )
    output_file = 'ConfusionMatrix_TestSet_' + CNNName + '.png'

    # Square figure, one column wide (uses the global `width`).
    plt.figure(figsize=(width, width))
    sns.heatmap(cm, **heatmap_options)
    plt.xlabel('Predicted label')
    plt.ylabel('True label')
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    plt.show()

In [None]:
# List the graph node names that can be tapped for intermediate outputs; we
# want the name of the fully connected layer whose output feeds the UMAP.
print(get_graph_node_names(model))
# Wrap the model so that a forward pass returns the output of the
# 'encoder.fc' node under the key 'extracted_features'.
feature_extractor_model = create_feature_extractor(model, return_nodes={'encoder.fc': 'extracted_features',})

# Put the feature extractor in evaluation mode and on the inference device.
feature_extractor_model.eval()
feature_extractor_model.to(device)

In [None]:
# Per-image accumulators, filled batch by batch and converted to arrays below.
all_labels = []
all_preds = []
all_names = []
all_probabilities = []
all_logits = []
all_extracted_features = []
all_labels_for_isomap_plot = []

model.eval()
with torch.no_grad():  # No need to track gradients during inference
    # Each batch yields the images, their true labels and their image names.
    for images, labels, names in tqdm(test_dataloader):
        images = images.to(dtype=torch.float32).to(device) 
        labels = labels.to(dtype=torch.long).to(device)
        # Forward pass through the classifier: raw class scores (logits).
        outputs = model(images)
        all_logits.extend(outputs.cpu().numpy())
        # Second forward pass through the feature extractor to get the
        # intermediate output registered as 'extracted_features'.
        features_dict = feature_extractor_model(images)
        features = features_dict['extracted_features']

        # Flatten everything after the batch dimension so each image has a
        # 1-D feature vector (a no-op if the output is already 2-D).
        # NOTE(review): the tapped node is 'encoder.fc', not avgpool - confirm
        # the expected feature shape.
        features = torch.flatten(features, 1)
        
        all_extracted_features.extend(features.cpu().numpy())
        all_labels_for_isomap_plot.extend(labels.cpu().numpy())
        # Class probabilities and the predicted class (index of the largest logit).
        probabilities = torch.softmax(outputs, dim=1)
        pred = torch.argmax(outputs, dim=1)   # Convert to binary (0 or 1)
        pred = pred.to(device=device) 
        # Same arg-max again together with the maximum values; not used below.
        maxvals, pred_index = torch.max(outputs, 1)
        # Collect labels and predictions
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(pred.cpu().numpy())
        all_names.extend(names)
        all_probabilities.extend(probabilities.cpu().numpy())
# Convert the accumulated lists to numpy arrays and compute the test accuracy.
all_labels = np.array(all_labels)
all_preds = np.array(all_preds)
all_names = np.array(all_names)
all_probabilities = np.array(all_probabilities)
all_logits = np.array(all_logits)
test_accuracy = get_accuracy(all_preds, all_labels)
all_extracted_features = np.array(all_extracted_features)
all_labels_for_isomap_plot = np.array(all_labels_for_isomap_plot)
print(f"Test Accuracy: {test_accuracy:.2f}%")


In [None]:
print(f"Test Accuracy: {test_accuracy:.2f}%")
print(test_accuracy)

print("Best model was from epoch:", checkpoint['epoch'])

In [None]:
# Convert lists to numpy arrays (a no-op copy when the inference cell has
# already converted them).
all_labels = np.array(all_labels)
all_preds = np.array(all_preds)
all_names = np.array(all_names)

In [None]:
print(np.shape(all_labels))

In [None]:
print(np.shape(all_preds))

In [None]:
print(all_preds)

In [None]:
# Raw confusion matrix: rows are true labels, columns are predicted labels.
# Passing labels=[0, 1] pins the matrix to 2x2 (label 0 = merger, label 1 =
# non-merger) even if one class never occurs, so the 4-way unpack below
# cannot fail on a 1x1 matrix.
cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
# Row-normalised version in percent: each row sums to 100.
cmn = (cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]) *100 
# Mergers (label 0) are treated as the positive class.
TP, FN, FP, TN = cm.ravel()
plot_confusion_matrix(cmn, classes=['Merger', 'Non-merger'], epoch='Test')


In [None]:
# Completeness: fraction of true positives that were recovered.
# Purity: fraction of positive predictions that are true positives.
completeness = TP / (TP + FN)
purity = TP / (TP + FP)
print(purity, completeness)

In [None]:
print(np.shape(all_labels))
print(np.shape(all_probabilities[:,0]))
print(np.max(all_probabilities[:,0]))

In [None]:
# 'yes' where the prediction matches the true label, otherwise 'no'.
correct = [
    'yes' if predicted == truth else 'no'
    for predicted, truth in zip(all_preds, all_labels)
]

In [None]:
# Scratch check: dropping the last two characters of an image name
# ('102575_1') leaves '102575'.
string = '102575_1'
string[:-2]

Making a big data frame for statistics of which galaxies are identified correctly from which angles

In [None]:
# One row per test image: name, true label, predicted label and whether the
# prediction was right.
ClassificationInfo = pd.DataFrame({
    'Image Name': all_names,
    'True Label': all_labels,
    'Predicted Label': all_preds,
    'Correct?': correct,
})

ClassificationInfo

In [None]:
# Drop the last two characters of each image name (e.g. '102575_1' ->
# '102575') to obtain the subhalo ID of every image.
image_prefixes = [name[:-2] for name in ClassificationInfo['Image Name']]
ClassificationInfo['SubhaloID'] = np.array(image_prefixes).astype(int)

# Unique subhalo IDs as integers (unordered at this point).
shids = np.array(list(set(image_prefixes))).astype(int)
print(shids)

In [None]:
# Sort the unique subhalo IDs in place (ascending) and check their element type.
shids.sort()
print(shids)
print(type(shids[0]))

In [None]:
# Order the rows by subhalo ID so that all images of one galaxy are adjacent.
ClassificationInfo = ClassificationInfo.sort_values(by='SubhaloID')
ClassificationInfo

In [None]:
# Subhalo ID and 'Type' for the two mock-image lists (files suffixed 40 and 50).
table40 = pd.read_csv('/code/SubhaloListForMakeMocks40.csv', usecols = ['Subfind_ID', 'Type'])
table50 = pd.read_csv('/code/SubhaloListForMakeMocks50.csv', usecols = ['Subfind_ID', 'Type'])
# ignore_index=True gives the combined table a fresh 0..N-1 index. Keeping the
# original indices would leave every row label duplicated, so a label lookup
# such as bigtable['Subfind_ID'][0] would return two rows instead of one.
bigtable = pd.concat([table40, table50], ignore_index=True)
bigtable

In [None]:
# Compare the element types of the two ID sources before matching on them.
# .iloc[0] takes the first row by position; a label lookup ([0]) returns a
# Series rather than a scalar when the index contains the label 0 twice.
print(type(bigtable['Subfind_ID'].iloc[0]))
print(type(shids[0]))

In [None]:
# Map every Subfind_ID to its 'Type'. When an ID occurs more than once in
# bigtable the first occurrence wins, matching a first-match lookup.
type_by_subhalo = {}
for subfind_id, subhalo_type in zip(bigtable['Subfind_ID'], bigtable['Type']):
    type_by_subhalo.setdefault(subfind_id, subhalo_type)

# Build one entry per row of ClassificationInfo, in row order. Looking the
# type up per row (instead of repeating each value a fixed six times) keeps
# the list aligned with the table regardless of how the rows are sorted or
# how many images each subhalo has.
types = []
for s in ClassificationInfo['SubhaloID']:
    if s not in type_by_subhalo:
        raise KeyError(f'Subhalo {s} is missing from bigtable')
    types.append(type_by_subhalo[s])
    

In [None]:
# `types` must hold exactly one entry per row of ClassificationInfo, in row
# order, because a list is assigned to a column by position.
print(len(types))
ClassificationInfo['Type'] = types
ClassificationInfo

### Add Merger Mass Ratio

In [None]:
def _mergers_table(s_main, s):
    """Path of the merger table stored under directory `s_main` for `s`."""
    return '/TNGProjects/merger_tables/' + str(s_main) + '/all_mergers_at_' + str(s) + '_enviro.txt'


def _nonmergers_table(s_main, s):
    """Path of the matched non-merger table stored under `s_main` for `s`."""
    return '/TNGProjects/merger_tables/' + str(s_main) + '/nonmergers_matched_at_' + str(s) + '_no_enviro.txt'


def _read_column(table, column, dtype):
    """Read a single column of a text table, skipping its header row."""
    return np.loadtxt(table, skiprows=1, usecols=column, dtype=dtype)


# Tables for 40. Merger table columns used: 1 = Subfind ID, 3 = mass ratio,
# 5 = stellar mass, 8 = SFR. Non-merger table columns used: 1 = Subfind ID,
# 2 = stellar mass, 3 = SFR (meanings taken from the variable names).
s_main = 40
s = 40
Subfind_ID_mergers40 = _read_column(_mergers_table(s_main, s), 1, int)
Subfind_ID_nonmergers40 = _read_column(_nonmergers_table(s_main, s), 1, int)
q_mergers40 = _read_column(_mergers_table(s_main, s), 3, float)
sm_mergers40 = _read_column(_mergers_table(s_main, s), 5, float)
sm_nonmergers40 = _read_column(_nonmergers_table(s_main, s), 2, float)
sfr_mergers40 = _read_column(_mergers_table(s_main, s), 8, float)
sfr_nonmergers40 = _read_column(_nonmergers_table(s_main, s), 3, float)

# Same columns for 50.
s_main = 50
s = 50
Subfind_ID_mergers50 = _read_column(_mergers_table(s_main, s), 1, int)
Subfind_ID_nonmergers50 = _read_column(_nonmergers_table(s_main, s), 1, int)
q_mergers50 = _read_column(_mergers_table(s_main, s), 3, float)
sm_mergers50 = _read_column(_mergers_table(s_main, s), 5, float)
sm_nonmergers50 = _read_column(_nonmergers_table(s_main, s), 2, float)
sfr_mergers50 = _read_column(_mergers_table(s_main, s), 8, float)
sfr_nonmergers50 = _read_column(_nonmergers_table(s_main, s), 3, float)

In [None]:
#print(sfr_nonmergers40)

In [None]:
print()

In [None]:
# Stack the 40 and 50 tables so each quantity can be searched in one array.
Subfind_ID = np.concatenate((Subfind_ID_mergers40, Subfind_ID_mergers50))
Subfind_ID_nonmergers = np.concatenate((Subfind_ID_nonmergers40, Subfind_ID_nonmergers50))
q_mergers = np.concatenate((q_mergers40, q_mergers50))
sm_mergers = np.concatenate((sm_mergers40, sm_mergers50))
sm_nonmergers = np.concatenate((sm_nonmergers40, sm_nonmergers50))
sfr_mergers = np.concatenate((sfr_mergers40, sfr_mergers50))
sfr_nonmergers = np.concatenate((sfr_nonmergers40, sfr_nonmergers50))


def _subhalo_properties(subhalo_id):
    """Return (mass ratio, log10 stellar mass, SFR) for one subhalo ID.

    The merger tables are searched first; an ID found there gets its mass
    ratio folded into the range below 1 (ratios >= 1 are inverted). An ID
    found only in the non-merger tables gets a mass ratio of 0.0. The first
    matching table row is used in either case.

    Raises
    ------
    KeyError
        If the ID appears in neither the merger nor the non-merger tables.
    """
    merger_hits = np.flatnonzero(Subfind_ID == subhalo_id)
    if merger_hits.size > 0:
        index = merger_hits[0]
        ratio = q_mergers[index]
        folded_ratio = ratio if ratio < 1 else 1 / ratio
        return folded_ratio, np.log10(sm_mergers[index]), sfr_mergers[index]

    nonmerger_hits = np.flatnonzero(Subfind_ID_nonmergers == subhalo_id)
    if nonmerger_hits.size == 0:
        raise KeyError(
            f'Subhalo {subhalo_id} is in neither the merger nor the non-merger tables'
        )
    index = nonmerger_hits[0]
    return 0.0, np.log10(sm_nonmergers[index]), sfr_nonmergers[index]


# Look each distinct subhalo up once ...
properties_by_id = {
    int(subhalo_id): _subhalo_properties(subhalo_id)
    for subhalo_id in ClassificationInfo['SubhaloID'].unique()
}

# ... then emit one value per table row, in row order. Building the columns
# per row (rather than repeating each value a fixed six times in sorted-ID
# order) keeps them aligned with ClassificationInfo however it is sorted and
# however many images each subhalo has.
q = []
sm = []
sfr = []
for subhalo_id in ClassificationInfo['SubhaloID']:
    ratio, stellarmass, starformation = properties_by_id[int(subhalo_id)]
    q.append(ratio)
    sm.append(stellarmass)
    sfr.append(starformation)

# Name the mass-ratio regime: exactly 0 -> non-merger, >= 0.25 -> major,
# anything else -> minor.
q_name = [
    'non' if r == 0.0 else ('major' if r >= 0.25 else 'minor')
    for r in q
]

ClassificationInfo['Mass Ratio'] = q
ClassificationInfo['Ratio Name'] = q_name
ClassificationInfo['Stellar Mass'] = sm
ClassificationInfo['SFR'] = sfr

In [None]:
ClassificationInfo

In [None]:
# Human-readable class names: label 0 is a merger, any other label is a
# non-merger.
t_class = [
    'merger' if label == 0.0 else 'nonmerger'
    for label in ClassificationInfo['True Label']
]
p_class = [
    'merger' if label == 0.0 else 'nonmerger'
    for label in ClassificationInfo['Predicted Label']
]

ClassificationInfo['True Class'] = t_class
ClassificationInfo['Predicted Class'] = p_class

In [None]:
# Put the columns into their final order for display and export.
column_order = [
    'SubhaloID', 'Image Name',
    'True Class', 'Predicted Class',
    'True Label', 'Predicted Label',
    'Correct?', 'Type',
    'Mass Ratio', 'Ratio Name',
    'Stellar Mass', 'SFR',
]
ClassificationInfo = ClassificationInfo.loc[:, column_order]

In [None]:
ClassificationInfo

In [None]:
# Write the per-image results table for this network, without the row index.
results_file = f'CNN_ResultsTable_{CNNName}.csv'
ClassificationInfo.to_csv(results_file, index=False)

#### How do we do on major vs minor mergers?

In [None]:
# Accuracy (percent, two decimals) over every image with a non-zero mass ratio.
has_mass_ratio = ClassificationInfo['Mass Ratio'] != 0.0
allmergers = ClassificationInfo[has_mass_ratio]
allmergers_correct = allmergers[allmergers['Correct?'] == 'yes']
allmergers_accuracy = np.round(len(allmergers_correct) / len(allmergers) * 100, 2)
print(allmergers_accuracy)

In [None]:
# Split the table by mass-ratio regime.
major, minor, non = (
    ClassificationInfo[ClassificationInfo['Ratio Name'] == regime]
    for regime in ('major', 'minor', 'non')
)

# Correctly classified rows within each regime.
major_correct, minor_correct, non_correct = (
    subset[subset['Correct?'] == 'yes'] for subset in (major, minor, non)
)

# Accuracy per regime in percent, rounded to two decimals.
major_accuracy, minor_accuracy, non_accuracy = (
    np.round(len(hits) / len(subset) * 100, 2)
    for hits, subset in (
        (major_correct, major),
        (minor_correct, minor),
        (non_correct, non),
    )
)

for regime_accuracy in (major_accuracy, minor_accuracy, non_accuracy):
    print(regime_accuracy)

In [None]:
# Overall fraction (0-1, not a percentage) of correctly classified images.
is_correct = ClassificationInfo['Correct?'] == 'yes'
corrects = ClassificationInfo[is_correct]
print(len(corrects) / len(ClassificationInfo))

#### How do we do on different merger stages?

In [None]:
# Split by the 'Type' column: 'Merger', the two progenitor types (early
# stage) and 'Descendant' (late stage).
stage = ClassificationInfo['Type']
merger = ClassificationInfo[stage == 'Merger']
early = ClassificationInfo[stage.isin(['first_progenitor', 'next_progenitor'])]
late = ClassificationInfo[stage == 'Descendant']

# Correctly classified rows within each stage.
merger_correct, early_correct, late_correct = (
    subset[subset['Correct?'] == 'yes'] for subset in (merger, early, late)
)

# Accuracy per stage in percent, rounded to two decimals.
merger_accuracy = np.round(len(merger_correct) / len(merger) * 100, 2)
early_accuracy = np.round(len(early_correct) / len(early) * 100, 2)
late_accuracy = np.round(len(late_correct) / len(late) * 100, 2)

for stage_accuracy in (merger_accuracy, early_accuracy, late_accuracy):
    print(stage_accuracy)

In [None]:
# 'Merger' and 'Descendant' types combined. Note that this value is a
# fraction (0-1), unlike the percentages printed for the individual stages.
is_post = ClassificationInfo['Type'].isin(['Merger', 'Descendant'])
post = ClassificationInfo[is_post]
post_correct = post[post['Correct?'] == 'yes']
post_accuracy = len(post_correct) / len(post)
print(post_accuracy)

# Which galaxies are classified correctly from every angle?

In [None]:
# Verdict ('yes'/'no') for every image name, built once so each view is a
# dictionary lookup instead of a filter over the whole table. If a name
# occurs more than once the first row wins.
verdict_by_image = {}
for image_name, verdict in zip(ClassificationInfo['Image Name'], ClassificationInfo['Correct?']):
    verdict_by_image.setdefault(image_name, verdict)

# Subhalos whose six views ('<id>_1' ... '<id>_6') are ALL classified
# correctly. Every one of the six views is tested; the earlier version
# tested view 2 twice and never looked at view 3.
every_angle_correct = []
for s in shids:
    if all(verdict_by_image[f'{s}_{angle}'] == 'yes' for angle in range(1, 7)):
        every_angle_correct.append(int(s))

In [None]:
print(every_angle_correct)

In [None]:
# All rows belonging to the subhalos listed in every_angle_correct.
EveryAngleCorrect = ClassificationInfo[ClassificationInfo['SubhaloID'].isin(every_angle_correct)]
EveryAngleCorrect

In [None]:
# Save the always-correct galaxies for this network, without the row index.
every_angle_correct_file = f'EveryAngleCorrect_{CNNName}.csv'
EveryAngleCorrect.to_csv(every_angle_correct_file, index=False)

# Which galaxies are classified wrong from every angle?

In [None]:
# Verdict ('yes'/'no') for every image name, built once so each view is a
# dictionary lookup instead of a filter over the whole table. If a name
# occurs more than once the first row wins.
verdict_by_image = {}
for image_name, verdict in zip(ClassificationInfo['Image Name'], ClassificationInfo['Correct?']):
    verdict_by_image.setdefault(image_name, verdict)

# Subhalos whose six views ('<id>_1' ... '<id>_6') are ALL misclassified.
# Every one of the six views is tested; the earlier version tested view 2
# twice and never looked at view 3.
every_angle_wrong = []
for s in shids:
    if all(verdict_by_image[f'{s}_{angle}'] == 'no' for angle in range(1, 7)):
        every_angle_wrong.append(int(s))
print(every_angle_wrong)

In [None]:
# All rows belonging to the subhalos listed in every_angle_wrong.
EveryAngleWrong = ClassificationInfo[ClassificationInfo['SubhaloID'].isin(every_angle_wrong)]
EveryAngleWrong

In [None]:
# Save the always-wrong galaxies for this network, without the row index.
every_angle_wrong_file = f'EveryAngleWrong_{CNNName}.csv'
EveryAngleWrong.to_csv(every_angle_wrong_file, index=False)

# What is the average number of correct angles per galaxy? Does this change with mass ratio?

In [None]:
# For every subhalo: how many of its six views were classified correctly,
# together with the subhalo's mass ratio (read from the row of view 6).
num_angle_correct = []
num_angle_correct_nm = []
mass_ratio = []
for s in shids:
    s = s.astype(str)
    views = [
        ClassificationInfo[ClassificationInfo['Image Name'] == f'{s}_{angle}']
        for angle in range(1, 7)
    ]
    count = sum(1 for view in views if view['Correct?'].values[0] == 'yes')
    num_angle_correct.append(count)
    mass_ratio.append(views[-1]['Mass Ratio'].values[0])

num_angle_correct = np.array(num_angle_correct)
mass_ratio = np.array(mass_ratio)

# Mean, median and standard deviation of the per-galaxy counts.
for statistic in (np.mean, np.median, np.std):
    print(statistic(np.array(num_angle_correct)))

In [None]:
# Unit-width bin edges from -0.5 to 6.5, i.e. one bin centred on each
# possible count 0..6.
bins = np.arange(-0.5,7,1)

In [None]:
# Boolean mask of the galaxies with a non-zero mass ratio (the mergers).
# A mask, unlike the integer positions returned by np.where, can be negated
# with `~` to select the complementary (non-merger) galaxies; `~` on integer
# positions would instead produce negative indices.
inds = mass_ratio != 0.0

In [None]:
# Turn `inds` into a boolean mask before negating it. Applying `~` directly
# to integer positions gives negative indices (-1 - i), which selects
# elements counted from the end of the array, not the non-mergers.
merger_mask = np.zeros(len(num_angle_correct), dtype=bool)
merger_mask[inds] = True

# Per-galaxy counts of correct views: non-mergers first, then mergers.
print(num_angle_correct[~merger_mask])
print(num_angle_correct[merger_mask])
# Mean and median for the mergers ...
print(np.mean(num_angle_correct[merger_mask]))
print(np.median(num_angle_correct[merger_mask]))

# ... and for the non-mergers.
print(np.mean(num_angle_correct[~merger_mask]))
print(np.median(num_angle_correct[~merger_mask]))

In [None]:
# Mergers only: number of correctly classified views next to the mass ratio.
angles_df_q = pd.DataFrame({
    'Number of Angles Correctly Classified': num_angle_correct[inds],
    'Mass Ratio': mass_ratio[inds],
})

In [None]:
# Number of merging galaxies with exactly 0, 1, ..., 6 correct views, as
# strings, for use as tick labels on the secondary axis of the plot.
correct_view_counts = angles_df_q['Number of Angles Correctly Classified']
angles_q_ax2_ylabels = [
    str(int((correct_view_counts == c).sum())) for c in range(7)
]

print(angles_q_ax2_ylabels)

In [None]:
# Horizontal box plot: distribution of merger mass ratio for each number of
# correctly classified views.
fig, ax = plt.subplots(figsize = (8,6), constrained_layout= True)
sns.boxplot(data = angles_df_q, x="Mass Ratio", y = "Number of Angles Correctly Classified", 
            dodge = False, orient = 'h', palette='Purples').set(xlabel = 'Merger Mass Ratio')
# Dotted line at 0.25, the threshold used above to separate 'minor' from 'major'.
ax.axvline(x = 0.25, ymin = 0, ymax = 1, color = 'black', linestyle = ':')
ax.set_xlabel('Merger Mass Ratio', fontsize = 'x-large')
ax.set_ylabel("Number of Angles Correctly Classified", fontsize = 'x-large')
ax.tick_params(labelsize = 'x-large')
plt.ylim(-.9, 6.9)
# Secondary y axis that mirrors the primary ticks but labels them with the
# number of galaxies per count.
# NOTE(review): this assumes one primary tick per count 0-6 (seven ticks) so
# the seven labels line up - confirm when a count has no galaxies.
ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim())
ax2.set_yticks(ax.get_yticks())
ax2.tick_params(axis=u'both', which=u'both',length=0, labelsize = 'x-large')
ax2.set_yticklabels(angles_q_ax2_ylabels)
ax2.set_ylabel('Number of Galaxies in Bin', fontsize = 'x-large')
plt.savefig('AnglesCorrect_MassRatio' + CNNName + '.png', dpi = 300, bbox_inches="tight")
plt.show()


# What is the average number of angles correct based on mass?

In [None]:
# For every subhalo: how many of its six views were classified correctly,
# together with the subhalo's stellar mass (read from the row of view 6).
num_angle_correct = []
num_angle_correct_nm = []
mass = []
for s in shids:
    s = s.astype(str)
    views = [
        ClassificationInfo[ClassificationInfo['Image Name'] == f'{s}_{angle}']
        for angle in range(1, 7)
    ]
    count = sum(1 for view in views if view['Correct?'].values[0] == 'yes')
    num_angle_correct.append(count)
    mass.append(views[-1]['Stellar Mass'].values[0])

num_angle_correct = np.array(num_angle_correct)
mass = np.array(mass)

# Mean, median and standard deviation of the per-galaxy counts.
for statistic in (np.mean, np.median, np.std):
    print(statistic(np.array(num_angle_correct)))

In [None]:
# Boolean mask of the galaxies with a non-zero mass ratio (the mergers).
# A mask, unlike the integer positions returned by np.where, can be negated
# with `~` to select the complementary (non-merger) galaxies; `~` on integer
# positions would instead produce negative indices.
inds = mass_ratio != 0.0

In [None]:
# Turn `inds` into a boolean mask before negating it. Applying `~` directly
# to integer positions gives negative indices (-1 - i), which selects
# elements counted from the end of the array, not the non-mergers.
merger_mask = np.zeros(len(num_angle_correct), dtype=bool)
merger_mask[inds] = True

# Mergers: number of correct views next to the (log) stellar mass.
angles_df_sm_mergers = pd.DataFrame([])
angles_df_sm_mergers['Number of Angles Correctly Classified'] = num_angle_correct[merger_mask]
angles_df_sm_mergers['Stellar Mass'] = mass[merger_mask]

# Non-mergers: the complementary set of galaxies.
angles_df_sm_nonmergers = pd.DataFrame([])
angles_df_sm_nonmergers['Number of Angles Correctly Classified'] = num_angle_correct[~merger_mask]
angles_df_sm_nonmergers['Stellar Mass'] = mass[~merger_mask]

In [None]:
# Number of galaxies with exactly 0, 1, ..., 6 correct views, as strings,
# for the secondary-axis tick labels: first mergers, then non-mergers.
merger_view_counts = angles_df_sm_mergers['Number of Angles Correctly Classified']
angles_sm_mergers_ax2_ylabels = [
    str(int((merger_view_counts == c).sum())) for c in range(7)
]

print(angles_sm_mergers_ax2_ylabels)

nonmerger_view_counts = angles_df_sm_nonmergers['Number of Angles Correctly Classified']
angles_sm_nonmergers_ax2_ylabels = [
    str(int((nonmerger_view_counts == c).sum())) for c in range(7)
]

print(angles_sm_nonmergers_ax2_ylabels)

In [None]:
# --- Mergers: stellar-mass distribution per number of correct views -------
fig, ax = plt.subplots(figsize = (8,6), constrained_layout=True)
sns.boxplot(data = angles_df_sm_mergers, x="Stellar Mass", y = "Number of Angles Correctly Classified", 
            dodge = False, orient = 'h', palette='Purples').set(xlabel = r"M$_\star$ of Merger [LogM$_\odot$]")
plt.ylim(-.9, 6.9)
ax.set_xlabel(r"M$_\star$ of Merger [LogM$_\odot$]", fontsize = 'x-large')
ax.set_ylabel("Number of Angles Correctly Classified",fontsize = 'x-large')
ax.tick_params(labelsize = 'x-large')
# Secondary y axis mirroring the primary ticks, labelled with the number of
# galaxies per count.
# NOTE(review): assumes seven primary ticks (counts 0-6) - confirm.
ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim())
ax2.set_yticks(ax.get_yticks())
ax2.tick_params(axis=u'both', which=u'both',length=0, labelsize = 'x-large')
ax2.set_yticklabels(angles_sm_mergers_ax2_ylabels, fontsize = 'x-large')
ax2.set_ylabel('Number of Galaxies in Bin', fontsize = 'x-large')
plt.savefig('AnglesCorrect_StellarMass_Merger' + CNNName + '.png', dpi = 300, bbox_inches="tight")
plt.show()


# --- Non-mergers: same plot for the non-merger galaxies -------------------
fig, ax = plt.subplots(figsize = (8,6), constrained_layout=True)
sns.boxplot(data = angles_df_sm_nonmergers, x="Stellar Mass", y = "Number of Angles Correctly Classified", 
            dodge = False, orient = 'h', palette='Purples').set(xlabel = r"M$_\star$ of Nonmerger [LogM$_\odot$]",)
plt.ylim(-.9, 6.9)
ax.set_xlabel(r"M$_\star$ of Nonmerger [LogM$_\odot$]", fontsize = 'x-large')
ax.set_ylabel("Number of Angles Correctly Classified",fontsize = 'x-large')
ax.tick_params(labelsize = 'x-large')
# Secondary y axis with the per-count galaxy totals, as above.
ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim())
ax2.set_yticks(ax.get_yticks())
ax2.tick_params(axis=u'both', which=u'both',length=0, labelsize = 'x-large')
ax2.set_yticklabels(angles_sm_nonmergers_ax2_ylabels, fontsize = 'x-large')
ax2.set_ylabel('Number of Galaxies in Bin', fontsize = 'x-large')
plt.savefig('AnglesCorrect_StellarMass_Nonmerger' + CNNName + '.png', dpi = 300, bbox_inches="tight")
plt.show()


# Calibration Error

## Brier Score

In [None]:
# Brier score with label 0 as the positive class, scored against the
# probability in column 0 of the softmax output.
brier = brier_score_loss(y_true = all_labels, y_proba = all_probabilities[:,0], pos_label = 0)
print(brier)

## Expected Calibration Error

In [None]:
# Confidences = probabilities (all_probabilities is what gets passed to ECE).
# Quick look at the probability array, its minimum and the label element type.
print(all_probabilities)
print(np.min(all_probabilities))
print(type(all_labels[0]))

In [None]:
#https://towardsdatascience.com/expected-calibration-error-ece-a-step-by-step-visual-explanation-with-python-code-c3e9aa12937d/
def ECE(samples, labels, numbins = 10):
    """Compute the expected calibration error with equal-width confidence bins.

    Parameters
    ----------
    samples : array of shape (n_samples, n_classes)
        Class probabilities; the row maximum is the confidence and its column
        index is the predicted label.
    labels : array of shape (n_samples,)
        True labels, compared against the predicted column index.
    numbins : int, default 10
        Number of equal-width bins covering [0, 1].

    Returns
    -------
    ece : ndarray of shape (1,)
        Sum over bins of |mean confidence - accuracy| * (fraction of samples).
    bin_accs, bin_confs : ndarray of shape (numbins,)
        Accuracy and mean confidence per bin; NaN for empty bins.
    bin_nums : list of length numbins
        Number of samples per bin; NaN for empty bins.
    """
    edges = np.linspace(0, 1, numbins + 1)

    # Confidence is the largest probability of each sample; a sample counts
    # as accurate when the position of that maximum equals its label.
    confidences = np.max(samples, axis=1)
    accuracies = np.argmax(samples, axis=1) == labels

    ece = np.zeros(1)
    bin_accs, bin_confs, bin_nums = [], [], []
    for lower, upper in zip(edges[:-1], edges[1:]):
        # Samples whose confidence lies in the half-open interval (lower, upper].
        in_bin = (confidences > lower.item()) & (confidences <= upper.item())
        # Empirical probability of a sample falling into this bin: |Bm| / n.
        prob_in_bin = in_bin.mean()

        if not prob_in_bin.item() > 0:
            # Empty bin: record NaN placeholders so all outputs keep one
            # entry per bin.
            bin_accs.append(np.nan)
            bin_confs.append(np.nan)
            bin_nums.append(np.nan)
            continue

        bin_nums.append(in_bin.sum())
        accuracy_in_bin = accuracies[in_bin].mean()        # acc(Bm)
        avg_confidence_in_bin = confidences[in_bin].mean() # conf(Bm)
        bin_accs.append(accuracy_in_bin)
        bin_confs.append(avg_confidence_in_bin)
        # Add |acc(Bm) - conf(Bm)| * (|Bm| / n) to the running total.
        ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prob_in_bin
    return ece, np.array(bin_accs), np.array(bin_confs), bin_nums

In [None]:
# Run the ECE computation with the default number of bins.
ece, accs, confs, nums = ECE(all_probabilities, all_labels)
# Per-bin mean confidence minus accuracy (NaN where a bin is empty).
gap = confs - accs

In [None]:
# Per-bin sample counts returned by ECE.
print(nums)

In [None]:
print(accs)

# Rebuild ten equal-width bins on [0, 1] and compute the midpoint of each one
# (lower edge plus half of the bin width).
bin_boundaries = np.linspace(0, 1, 10 + 1)
bin_lowers, bin_uppers = bin_boundaries[:-1], bin_boundaries[1:]

half_widths = (bin_uppers - bin_lowers) / 2
bin_centers = half_widths + bin_boundaries[:-1]
print(bin_centers)

In [None]:
# Per-bin difference confs - accs.
print(gap)

# Isomap

In [None]:
# Two-colour map for the merger / non-merger classes, built from two samples
# of the magma colormap (positions 0.2 and 0.7).
cmap = matplotlib.colormaps['magma']

mergers_color, nonmergers_color = cmap(0.2), cmap(0.7)
colors = [mergers_color, nonmergers_color]
cmap_name = "my_cmap"
cmap_binary = LinearSegmentedColormap.from_list(cmap_name, colors)
      


In [None]:
# Sanity check: shapes of the extracted features and of the logits.
print(np.shape(all_extracted_features))
print(np.shape(all_logits))

In [None]:
# Peek at the first image name and at the contents of all_labels_for_isomap_plot.
print(all_names[0])
print(all_labels_for_isomap_plot)

In [None]:
# Custom text labels for the true classes: 0 -> 'merger', anything else -> 'nonmerger'.
# The comparison must be made on each label value; comparing a one-element
# list such as [index] with 0 is always False and would label everything 'nonmerger'.
wordlabels = [
    'merger' if label == 0 else 'nonmerger'
    for label in all_labels_for_isomap_plot
]

print(wordlabels)

In [None]:
# Elementwise comparison of the two label collections; as the last expression
# of the cell its result is displayed by the notebook.
np.array(all_labels) == np.array(all_labels_for_isomap_plot)

In [None]:
# Boolean masks selecting the mergers (label 0) and the non-mergers (label 1).
# np.asarray guarantees an elementwise comparison: a plain Python list compared
# with an int evaluates to a single False instead of a per-sample mask.
merger_mask = (np.asarray(all_labels_for_isomap_plot) == 0)
nonmerger_mask = (np.asarray(all_labels_for_isomap_plot) == 1)


In [None]:
# Text version of the predicted classes: 0 -> 'merger', anything else -> 'nonmerger'.
pred_wordlabels = [
    'merger' if all_preds[position] == 0 else 'nonmerger'
    for position in range(len(all_preds))
]

print(pred_wordlabels)

In [None]:
# 'Stellar Mass' of every image, taken from the first ClassificationInfo row
# whose 'Image Name' matches, kept in the order of all_names.
masses_in_order = np.array([
    ClassificationInfo.loc[ClassificationInfo['Image Name'] == str(name), 'Stellar Mass'].iloc[0]
    for name in all_names
])
# Colour normalisation over the mass range, with the upper limit padded by 0.5.
norm_mass = Normalize(
    vmin=np.min(masses_in_order),
    vmax=np.max(masses_in_order) + 0.5,
)

In [None]:
# 'Mass Ratio' of every image (first matching ClassificationInfo row), in the
# order of all_names.
ratios_in_order = [
    ClassificationInfo.loc[ClassificationInfo['Image Name'] == str(name), 'Mass Ratio'].iloc[0]
    for name in all_names
]

print(ratios_in_order)
ratios_in_order = np.array(ratios_in_order)
# Positions of the entries with a mass ratio of exactly zero, and of all others.
has_zero_ratio = ratios_in_order == 0.0
has_non_zero_ratio = ratios_in_order != 0.0
zero_ratio_indices = np.where(has_zero_ratio)[0]
non_zero_ratio_indices = np.where(has_non_zero_ratio)[0]



In [None]:
# 'Type' of every image (first matching ClassificationInfo row), in the order
# of all_names.
types_in_order = [
    ClassificationInfo.loc[ClassificationInfo['Image Name'] == str(name), 'Type'].iloc[0]
    for name in all_names
]

print(types_in_order)
types_in_order = np.array(types_in_order)

# Index groups by type string: '0.0' -> no stage; the two progenitor types ->
# early stage; 'Merger' / 'Descendant' -> late stage.
is_nostage = types_in_order == '0.0'
is_early = (types_in_order == 'first_progenitor') | (types_in_order == 'next_progenitor')
is_late = (types_in_order == 'Merger') | (types_in_order == 'Descendant')
nostage = np.where(is_nostage)[0]
earlystage = np.where(is_early)[0]
latestage = np.where(is_late)[0]

In [None]:
# 'SFR' of every image (first matching ClassificationInfo row), in the order
# of all_names.
sfrs_in_order = [
    ClassificationInfo.loc[ClassificationInfo['Image Name'] == str(name), 'SFR'].iloc[0]
    for name in all_names
]

# log10 of the SFR values and a colour normalisation spanning their range.
log_sfr_array = np.log10(np.array(sfrs_in_order))
norm_sfr = Normalize(
    vmin=np.min(log_sfr_array),
    vmax=np.max(log_sfr_array),
)



In [None]:
# Rebuild the per-image 'SFR' list (first matching ClassificationInfo row per
# name), in the order of all_names.
sfrs_in_order = [
    ClassificationInfo.loc[ClassificationInfo['Image Name'] == str(name), 'SFR'].iloc[0]
    for name in all_names
]

In [None]:
# Per-image 'SFR' values looked up by image name in ClassificationInfo,
# ordered like all_names.
sfrs_in_order = [
    ClassificationInfo.loc[ClassificationInfo['Image Name'] == str(name), 'SFR'].iloc[0]
    for name in all_names
]

In [None]:
# Per-image 'Stellar Mass' values looked up by image name in
# ClassificationInfo, ordered like all_names and stored as an array.
masses_in_order = np.array([
    ClassificationInfo.loc[ClassificationInfo['Image Name'] == str(name), 'Stellar Mass'].iloc[0]
    for name in all_names
])

# UMAP

In [None]:
# Two-component UMAP reducer with 15 neighbours; random_state is pinned to 626,
# the same value passed to set_random_seeds at the top of the notebook.
reducer = umap.UMAP(n_components=2, n_neighbors=15, random_state=626)

# Fit the reducer on the extracted features and keep the resulting embedding.
umap_embedding = reducer.fit_transform(all_extracted_features)

In [None]:
# Sanity check on the embedding dimensions.
print("UMAP embedding shape:", umap_embedding.shape) # Should be (num_samples, 2)

In [None]:
plt.figure(figsize=(8, 6))

# Legend text for every label value present in all_labels, in sorted order.
custom_labels = {0: 'Merger', 1: 'Non-merger'}
legend_labels = [custom_labels[value] for value in sorted(np.unique(all_labels))]

# Scatter of the two UMAP components, coloured by the true labels with the
# two-colour map (swap c to all_preds to colour by prediction instead).
component_1 = umap_embedding[:, 0]
component_2 = umap_embedding[:, 1]
scatter = plt.scatter(
    component_1,
    component_2,
    c=all_labels,
    cmap=cmap_binary,
    s=25,
    alpha=0.7,
)

# Reuse the scatter's own legend handles, paired with the custom text.
handles, _ = scatter.legend_elements()
plt.legend(handles=handles, labels=legend_labels, title="True Labels")

plt.title('UMAP Projection of Model')
plt.xlabel('UMAP Component 1')
plt.ylabel('UMAP Component 2')
plt.grid(True, linestyle='--', alpha=0.6)
plt.show()

In [None]:
plt.figure(figsize=(8,6))

# Split the samples into zero and non-zero mass ratio, ignoring NaN ratios.
valid_indices = ~np.isnan(ratios_in_order)
zero_ratio_indices = np.where((ratios_in_order == 0) & valid_indices)[0]
non_zero_ratio_indices = np.where((ratios_in_order != 0) & valid_indices)[0]

# 1. Non-zero mass ratios: colour and marker area both follow the ratio.
merger_ratios = ratios_in_order[non_zero_ratio_indices]
scatter_non_zero = plt.scatter(
    umap_embedding[non_zero_ratio_indices, 0],
    umap_embedding[non_zero_ratio_indices, 1],
    c=merger_ratios,
    cmap='magma',
    alpha=0.7,
    s=np.array(merger_ratios)*100,
    label='Mergers',
)

# Colorbar for the continuous mass ratios.
cbar = plt.colorbar(scatter_non_zero)
cbar.set_label('Mass Ratio')

# 2. Zero mass ratio: one fixed colour and a triangle marker.
scatter_zero = plt.scatter(
    umap_embedding[zero_ratio_indices, 0],
    umap_embedding[zero_ratio_indices, 1],
    color='cornflowerblue',
    alpha=0.8,
    s=20,
    marker='^',
    label='Nonmergers',
)

plt.title('UMAP Projection (colored by Mass Ratio)')
plt.xlabel('UMAP Component 1')
plt.ylabel('UMAP Component 2')
plt.grid(True, linestyle='--', alpha=0.6)

# Legend for both groups; 'best' lets matplotlib choose the position.
plt.legend(loc='best')

plt.show()


In [None]:
# Shared colour limits so both groups use the same stellar-mass scale.
vmin, vmax = masses_in_order.min(), masses_in_order.max()
plt.figure(figsize=(8,6))

# 1. Samples with a non-zero mass ratio, coloured by stellar mass.
scatter_non_zero = plt.scatter(
    umap_embedding[non_zero_ratio_indices, 0],
    umap_embedding[non_zero_ratio_indices, 1],
    c=masses_in_order[non_zero_ratio_indices],
    cmap='magma',
    vmin=vmin,
    vmax=vmax,
    alpha=0.8,
    s=30,
    label='Mergers',
)

# Colorbar for the stellar-mass colour scale.
cbar = plt.colorbar(scatter_non_zero)
cbar.set_label(r'Log[Stellar Mass $M_\odot$]', fontsize='x-large')
cbar.ax.tick_params(labelsize='large')

# 2. Samples with zero mass ratio: same colour scale, black-edged triangles.
scatter_zero = plt.scatter(
    umap_embedding[zero_ratio_indices, 0],
    umap_embedding[zero_ratio_indices, 1],
    c=masses_in_order[zero_ratio_indices],
    cmap='magma',
    vmin=vmin,
    vmax=vmax,
    alpha=0.8,
    s=40,
    marker='^',
    edgecolors='black',
    linewidths=0.5,
    label='Nonmergers',
)

# Hide every tick mark and the bottom/left tick labels.
plt.tick_params(
    axis='both',
    which='both',
    bottom=False,
    top=False,
    left=False,
    right=False,
    labelbottom=False,
    labelleft=False,
)
plt.legend()
plt.tight_layout()
plt.savefig('umap_stellarmass.png', dpi=300)
plt.show()

In [None]:
# Specific star-formation rate: SFR divided by the linear stellar mass.
# masses_in_order is plotted under "Log[Stellar Mass]" labels, i.e. it holds
# log10(M_star), so the linear mass is 10**mass (not mass**10).
ssfrs_in_order = np.array(sfrs_in_order) / (10 ** np.array(masses_in_order))
ssfrs_in_order = np.log10(ssfrs_in_order)
# Shared colour limits so both groups use the same sSFR scale.
vmin = ssfrs_in_order.min()
vmax = ssfrs_in_order.max()
plt.figure(figsize=(8,6))

# 1. Samples with a non-zero mass ratio, coloured by log sSFR.
scatter_non_zero = plt.scatter(
    umap_embedding[non_zero_ratio_indices, 0],
    umap_embedding[non_zero_ratio_indices, 1],
    c=ssfrs_in_order[non_zero_ratio_indices],
    cmap='magma',
    vmin=vmin, vmax=vmax,
    alpha=0.8,
    s=30,
    label='Mergers'
)

# Colorbar for the log sSFR colour scale.
cbar = plt.colorbar(scatter_non_zero)
cbar.set_label(r'Log[sSFR $M_\odot/yr/M_\star$]', fontsize = 'x-large')
cbar.ax.tick_params(labelsize='large')

# 2. Samples with zero mass ratio: same colour scale, black-edged triangles.
scatter_zero = plt.scatter(
    umap_embedding[zero_ratio_indices, 0],
    umap_embedding[zero_ratio_indices, 1],
    c=ssfrs_in_order[zero_ratio_indices],
    alpha=0.8,
    cmap='magma',
    vmin=vmin, vmax=vmax,
    s=40,
    marker='^',
    edgecolors='black',
    linewidths=0.5,
    label='Nonmergers'
)

# Hide every tick mark and the bottom/left tick labels.
plt.tick_params(axis='both',
                which='both',
                bottom=False,
                top=False,
                left=False,
                right=False,
                labelbottom=False,
                labelleft=False)
plt.legend()
plt.tight_layout()
plt.savefig('umap_ssfr.png', dpi = 300)

plt.show()

In [None]:
# Side-by-side UMAP panels: left coloured by stellar mass, right by sSFR.
vmin_mass, vmax_mass = masses_in_order.min(), masses_in_order.max()
vmin_ssfr, vmax_ssfr = ssfrs_in_order.min(), ssfrs_in_order.max()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# One entry per panel: axes, colour values, colour limits, colorbar label.
panel_specs = [
    (ax1, masses_in_order, vmin_mass, vmax_mass, r'M$_\star$ [Log$M_\odot$]'),
    (ax2, ssfrs_in_order, vmin_ssfr, vmax_ssfr, r'sSFR [Logyr$^{-1}$]'),
]
panel_colorbars = []
for panel_ax, color_values, color_min, color_max, colorbar_label in panel_specs:
    # 1. Samples with a non-zero mass ratio.
    scatter_non_zero = panel_ax.scatter(
        umap_embedding[non_zero_ratio_indices, 0],
        umap_embedding[non_zero_ratio_indices, 1],
        c=color_values[non_zero_ratio_indices],
        cmap='magma',
        vmin=color_min,
        vmax=color_max,
        alpha=0.8,
        s=30,
        label='Mergers',
    )

    # Colorbar attached to this panel, sharing the scale of both scatters.
    panel_cbar = fig.colorbar(scatter_non_zero, ax=panel_ax)
    panel_cbar.set_label(colorbar_label, fontsize='x-large')
    panel_cbar.ax.tick_params(labelsize='large')
    panel_colorbars.append(panel_cbar)

    # 2. Samples with zero mass ratio: same colour scale, black-edged triangles.
    scatter_zero = panel_ax.scatter(
        umap_embedding[zero_ratio_indices, 0],
        umap_embedding[zero_ratio_indices, 1],
        c=color_values[zero_ratio_indices],
        cmap='magma',
        vmin=color_min,
        vmax=color_max,
        alpha=0.8,
        s=40,
        marker='^',
        edgecolors='black',
        linewidths=0.5,
        label='Nonmergers',
    )

    # Hide every tick mark and the bottom/left tick labels.
    panel_ax.tick_params(
        axis='both',
        which='both',
        bottom=False,
        top=False,
        left=False,
        right=False,
        labelbottom=False,
        labelleft=False,
    )
    panel_ax.legend()

# Keep the colorbars available under their usual names.
cbar, cbar2 = panel_colorbars

plt.tight_layout()
plt.savefig('UMAP_mass_ssfr_panel.png', dpi=300)
plt.show()

## UMAP clump

In [None]:
# Figure out which galaxies are in that clump so we can look at why they are so different.
# Selects the names whose first UMAP component is greater than 6.
# NOTE(review): boolean-mask indexing assumes all_names is a NumPy array
# (a plain Python list would raise a TypeError) - confirm.
print(all_names[umap_embedding[:,0] > 6])