In [1]:
!pip install zoobot
!pip install pyro-ppl
!pip install pytorch_lightning
!pip install timm

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [2]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from astropy.io import fits
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, roc_curve, auc
from sklearn.preprocessing import label_binarize
from zoobot.pytorch.training import finetune
from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier, LinearHead
from torchvision.models import efficientnet_b0

In [3]:
# Aggressive arcsinh scaling
def aggressive_arcsinh_scaling(image, central_half_size=40, clip_percentile=90):
    """
    Apply aggressive arcsinh scaling to enhance low surface brightness features.

    Steps:
      - If the input has more than 2 dimensions, keep only ``image[0]``.
      - Compute the arcsinh of the image.
      - Replace pixel values below the median with the median.
      - Clip high pixel values at the ``clip_percentile``-th percentile of the
        central region.
      - Normalize the image to the [0, 1] range.

    Args:
        image (array-like): 2-D image, or a stack whose first entry is the
            2-D image to scale.
        central_half_size (int): Half-width in pixels of the square region
            around the image centre used to derive the clipping threshold.
            The default of 40 gives a region of at most 80x80 pixels; the
            region is truncated at the image borders.
        clip_percentile (float): Percentile (0-100) of the central region used
            as the upper clipping value.

    Returns:
        np.ndarray: New 2-D array with values in [0, 1]. The input is not
        modified.

    Raises:
        ValueError: If the image is not 2-D after selecting ``image[0]``, or
            if ``central_half_size`` is smaller than 1.
    """
    if central_half_size < 1:
        raise ValueError("central_half_size must be at least 1")

    image = np.asarray(image)
    # If image has more than 2 dimensions; first dimension is the channel
    if image.ndim > 2:
        image = image[0]
    if image.ndim != 2:
        raise ValueError(f"expected a 2-D image, got an array with shape {image.shape}")

    # np.arcsinh allocates a new array, so the in-place edits below never
    # touch the caller's data.
    image_scaled = np.arcsinh(image)

    # Raise everything below the median to the median (suppresses the faint
    # half of the pixel distribution).
    median_val = np.median(image_scaled)
    image_scaled[image_scaled < median_val] = median_val

    # Central region used to pick the bright-end clipping threshold.
    n_rows, n_cols = image_scaled.shape
    row_c, col_c = n_rows // 2, n_cols // 2
    central_region = image_scaled[
        max(row_c - central_half_size, 0):min(row_c + central_half_size, n_rows),
        max(col_c - central_half_size, 0):min(col_c + central_half_size, n_cols),
    ]
    threshold = np.percentile(central_region, clip_percentile)
    image_scaled[image_scaled > threshold] = threshold

    # Normalize to [0, 1]; the epsilon keeps a constant image from dividing by zero.
    lo = np.min(image_scaled)
    hi = np.max(image_scaled)
    return (image_scaled - lo) / (hi - lo + 1e-8)

# Dataset for Classification
# class ClassificationDataset(Dataset):
#     def __init__(self, datadir, labels, transform):
#         """
#         Args:
#           datadir (str): Base directory containing the FITS images.
#           labels: FITS HDU with a 'data' attribute.
#           transform (callable): Function to apply to the raw image.
#         """
#         self.datadir = datadir
#         self.labels = labels  # FITS HDU with structured data (labels.data)
#         self.transform = transform
#         self.length = len(self.labels.data)
    
#     def __len__(self):
#         return self.length
    
#     def __getitem__(self, idx):
#         label_entry = self.labels.data[idx]
#         ID = label_entry["ID"]
#         snap = label_entry["snapnum"]
#         # Here, we use the 'time_before_merger' column.
#         # In the classification, we define three classes:
#         # 0: non-merger, 1: pre-merger, 2: post-merger.
#         tim = label_entry['time_before_merger']

#         # file path
#         file_path = os.path.join(
#             self.datadir,
#             f"mock_v4/F150W/L75n1820TNG/snapnum_0{snap}/xy/JWST_50kpc_F150W_TNG100_sn0{snap}_xy_broadband_{ID}.fits"
#         )

#         with fits.open(file_path) as hdul:
#             img_data = hdul[0].data  # assuming primary HDU contains the image
        
#         # Convert to float32 and ensure correct byte order
#         img_data = img_data.astype(np.float32, copy=False)
#         img_data = img_data.newbyteorder("=")
        
#         # Apply transformation
#         img_transformed = self.transform(img_data)
        
#         # Convert to tensor and add a channel dimension (assumes grayscale image)
#         img_tensor = torch.tensor(img_transformed).unsqueeze(0)  # shape: [1, H, W]

#         # Replicate the single channel to create a 3-channel image: shape becomes [3, H, W]
#         img_tensor = img_tensor.repeat(3, 1, 1)

#         # Define the class label based on merger time.
#         # For example, we adopt the paper’s default:
#         # pre-merger: t_merger between -0.8 and -0.1 Gyr -> label 1
#         # post-merger: t_merger between 0.1 and 0.3 Gyr -> label 2
#         # Otherwise, non-merger -> label 0
#         if (-0.8 <= tim <= -0.1):
#             class_label = 1  # pre-merger
#         elif (0.1 <= tim <= 0.3):
#             class_label = 2  # post-merger
#         else:
#             class_label = 0  # non-merger
        
#         label_tensor = torch.tensor(class_label, dtype=torch.long)
#         return img_tensor, label_tensor

class ClassificationDataset(Dataset):
    """
    Map-style dataset pairing each catalogue row with its FITS image and a
    3-class merger label.

    Label encoding:
      0 => non-merger
      1 => pre-merger
      2 => ongoing OR post-merger
    """

    def __init__(self, datadir, labels, transform):
        """
        Args:
          datadir (str): Base directory containing the FITS images.
          labels: FITS HDU (labels.data) with columns 'ID', 'snapnum',
                  'is_pre_merger', 'is_ongoing_merger', 'is_post_merger'.
          transform (callable): Function to apply to the raw image.
        """
        self.datadir = datadir
        self.labels = labels
        self.transform = transform
        self.length = len(self.labels.data)

    def __len__(self):
        """Return the number of catalogue rows."""
        return self.length

    def __getitem__(self, idx):
        """
        Load, transform and label the sample at position ``idx``.

        Returns:
          tuple: ``(img_tensor, label_tensor)`` where ``img_tensor`` is a
          float32 tensor of shape [3, H, W] (the transformed image repeated on
          three channels) and ``label_tensor`` is a scalar int64 tensor.
        """
        row = self.labels.data[idx]
        ID = row["ID"]
        snap = row["snapnum"]

        # Build file path for the FITS image.
        # NOTE(review): the "0" in front of {snap} is a literal character, not
        # zero-padding — confirm the directory naming for snapshots that are
        # not two digits long.
        file_path = os.path.join(
            self.datadir,
            f"mock_v4/F150W/L75n1820TNG/snapnum_0{snap}/xy/"
            f"JWST_50kpc_F150W_TNG100_sn0{snap}_xy_broadband_{ID}.fits"
        )

        # Load FITS image. astype(np.float32) returns a native-byte-order copy,
        # which is what torch needs; the former `.newbyteorder("=")` call was a
        # no-op on that copy and no longer exists on arrays in NumPy 2.0.
        with fits.open(file_path) as hdul:
            img_data = hdul[0].data.astype(np.float32)

        # Apply the transformation (e.g., arcsinh scaling)
        img_transformed = self.transform(img_data)

        # Convert to tensor and replicate channel to get shape [3, H, W].
        # The dtype is pinned so a transform that returns float64 cannot feed
        # double-precision input to a float32 model.
        img_tensor = torch.tensor(img_transformed, dtype=torch.float32).unsqueeze(0).repeat(3, 1, 1)

        # Read flags
        is_pre = (row['is_pre_merger'] == 1)
        is_ongoing = (row['is_ongoing_merger'] == 1)
        is_post = (row['is_post_merger'] == 1)

        # 3-class labeling logic (pre-merger takes precedence when several
        # flags are set):
        # 0 => non-merger
        # 1 => pre-merger
        # 2 => ongoing OR post
        if is_pre:
            class_label = 1
        elif is_ongoing or is_post:
            class_label = 2
        else:
            class_label = 0

        label_tensor = torch.tensor(class_label, dtype=torch.long)
        return img_tensor, label_tensor

# Training+evaluation functions

def train_epoch(model, loader, criterion, optimizer, device):
    """
    Run one optimisation pass over ``loader``.

    The model is put in training mode, and for every batch the gradients are
    reset, the loss is back-propagated and the optimizer takes one step.

    Returns:
        tuple: ``(mean_loss, accuracy)`` where the loss is averaged per sample
        (each batch loss is weighted by its batch size) and accuracy is the
        fraction of samples whose arg-max prediction equals the label.
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the epoch mean is per sample.
        loss_sum += batch_loss.item() * batch_x.size(0)
        predicted = torch.max(logits, 1)[1]
        n_correct += (predicted == batch_y).sum().item()
        n_seen += batch_y.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

def evaluate(model, loader, criterion, device):
    """
    Score ``model`` on ``loader`` without updating it.

    The model is switched to eval mode and gradients are disabled for the
    whole pass.

    Returns:
        tuple: ``(mean_loss, accuracy, predictions, targets)``. The loss is the
        per-sample mean, accuracy is the fraction of correct arg-max
        predictions, and ``predictions`` / ``targets`` are flat Python lists
        holding one entry per sample in loader order.
    """
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    collected_preds = []
    collected_targets = []

    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)

            # Weight by batch size so the final value is a per-sample mean.
            loss_sum += batch_loss.item() * batch_x.size(0)
            predicted = torch.max(logits, 1)[1]
            n_correct += (predicted == batch_y).sum().item()
            n_seen += batch_y.size(0)

            # Move to host memory before storing the per-sample values.
            collected_preds.extend(predicted.cpu().numpy())
            collected_targets.extend(batch_y.cpu().numpy())

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy, collected_preds, collected_targets

def compute_metrics(y_true, y_pred):
    """
    Summarise classification quality for a set of predictions.

    Returns:
        tuple: ``(confusion_matrix, accuracy, precision, recall)``. Precision
        and recall are requested per class (``average=None``) so each class
        can be inspected separately; ``zero_division=0`` is passed to both.
    """
    # Shared options: one score per class, and zero_division=0 for both scores.
    per_class_options = {"average": None, "zero_division": 0}
    return (
        confusion_matrix(y_true, y_pred),
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, **per_class_options),
        recall_score(y_true, y_pred, **per_class_options),
    )

def evaluate_with_probabilities(model, loader, device):
    """
    Collect softmax class probabilities and true labels for every sample.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: Module returning per-class scores of shape [batch, n_classes].
        loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        tuple: ``(probs, labels)`` as NumPy arrays concatenated over all
        batches in loader order; ``probs`` has one row per sample and its rows
        sum to 1.
    """
    model.eval()
    all_probs, all_labels = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            out = model(inputs)
            probs = torch.softmax(out, dim=1)
            all_probs.append(probs.cpu().numpy())
            # .cpu() mirrors `evaluate` and keeps this working when the loader
            # yields labels that already live on an accelerator.
            all_labels.append(labels.cpu().numpy())
    return np.concatenate(all_probs, axis=0), np.concatenate(all_labels, axis=0)

def plot_confusion_matrix(y_true, y_pred, classes=None, normalize=False, title=None, cmap=plt.cm.Blues):
    """
    From scikit-learn: plots a confusion matrix on a new figure.

    Args:
        y_true: True labels, either a 1-D sequence of class indices or a 2-D
            array of per-class scores / one-hot rows. Lists are accepted.
        y_pred: Predicted labels in the same format as ``y_true``.
        classes: Optional sequence of tick labels, one per class.
        normalize: If True, each row is divided by its sum so cells show the
            fraction of the true class; rows without samples are shown as 0.
        title: Text used as the colour-bar label. Defaults to
            "Normalized confusion matrix" or "Confusion matrix".
        cmap: Matplotlib colormap for the matrix image.

    Returns:
        None. The figure is created with ``plt.subplots()`` and left for the
        caller to show or save.
    """
    if not title:
        if normalize:
            title = "Normalized confusion matrix"
        else:
            title = "Confusion matrix"

    # Accept plain Python lists (e.g. the lists returned by `evaluate`), which
    # have no `.shape` attribute.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Compute confusion matrix; 2-D inputs are reduced to class indices first.
    if y_true.ndim > 1 and y_pred.ndim > 1:
        cm = confusion_matrix(np.argmax(y_true, axis=1), np.argmax(y_pred, axis=1))
    else:
        cm = confusion_matrix(y_true, y_pred)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A class that only occurs in the predictions has an all-zero row;
        # dividing by 1 there leaves zeros instead of producing NaN.
        row_sums[row_sums == 0] = 1
        cm = cm.astype("float") / row_sums

    fig, ax = plt.subplots()
    # origin="lower" puts the first class at the bottom of the y-axis.
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap, origin="lower")
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.set_label(title)

    ax.set(
        xticks=np.arange(cm.shape[1]),
        yticks=np.arange(cm.shape[0]),
        ylabel="True label",
        xlabel="Predicted label",
    )

    if classes is not None:
        # Set the labels on this Axes explicitly rather than through pyplot's
        # "current axes" state.
        tick_marks = np.arange(len(classes))
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(classes, rotation=45)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(classes)

    # Loop over data dimensions and create text annotations; the text colour
    # flips to white on cells darker than half the maximum.
    fmt = ".2f" if normalize else "d"
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt), ha="center", va="center", color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()

In [4]:
# Seed the RNGs so weight initialisation and batch shuffling are repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load labels and data
label_file = "/net/virgo01/data/users/spirov/Deep/catalog_tng100_jwst_all_50sns.fits"
# NOTE(review): the HDUList returned by fits.open is never closed; the dataset
# reads `labels.data` on every __getitem__, so presumably it has to stay open
# for as long as the loaders are used — confirm.
labels = fits.open(label_file)[1]  # second HDU

datadir = "/net/virgo01/data/users/mahesh/DeepLearning/data/"

# Create the dataset using the scaling
dataset = ClassificationDataset(datadir, labels, aggressive_arcsinh_scaling)

# dataset splits: first 81% train, next 9% validation, last 10% test, taken
# in catalogue order without shuffling.
# NOTE(review): if the catalogue is sorted (e.g. by snapshot) the three subsets
# cover different parts of it — confirm this is intended.
total_samples = len(dataset)
indices = list(range(total_samples))
train_split = int(0.81 * total_samples)
val_split = int(0.9 * total_samples)
train_indices = indices[:train_split]
val_indices = indices[train_split:val_split]
test_indices = indices[val_split:]

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Class counts on the training split, using the same precedence as
# ClassificationDataset: pre-merger -> 1, else ongoing/post -> 2, else 0.
# Computed on whole columns instead of one FITS row at a time.
train_idx = np.asarray(train_indices, dtype=np.intp)
train_is_pre = np.asarray(labels.data['is_pre_merger'])[train_idx] == 1
train_is_late = (
    (np.asarray(labels.data['is_ongoing_merger'])[train_idx] == 1)
    | (np.asarray(labels.data['is_post_merger'])[train_idx] == 1)
)
train_classes = np.where(train_is_pre, 1, np.where(train_is_late, 2, 0))
train_counts = [int(np.sum(train_classes == k)) for k in range(3)]

# An empty class would give an infinite weight, which turns into NaN after the
# normalisation below and silently poisons the loss.
if min(train_counts) == 0:
    raise ValueError(
        f"Every class needs at least one training sample to derive class weights; got counts {train_counts}."
    )

# Inverse-frequency class weights, normalised to sum to 1.
weights = 1.0 / np.array(train_counts)
weights /= weights.sum()
weights_t = torch.tensor(weights, dtype=torch.float)

# Model Setup with Zoobot (One-Stage Classification)
# load the Zoobot pre-trained EfficientNet-B0 model
# NOTE(review): n_blocks, learning_rate and lr_decay are only passed to the
# zoobot constructor here; nothing else in this cell reads them. If training
# is done with a hand-written loop and optimizer, confirm whether these
# settings (in particular "fine-tune only the head") actually take effect.
model = FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-efficientnet_b0',
    n_blocks=0,              # Fine-tune only the head
    learning_rate=1e-5,      # Use their learning rate
    lr_decay=0.5,
    num_classes=3            # Three classes: non-merger, pre-merger, post-merger
)
# uncomment below to use torchvision's EfficientNet-B0 (note: this model is pre-trained on ImageNet, not on Galaxy Zoo labels)
#model = efficientnet_b0(pretrained=True)

# Modify the classifier head to output three classes
if hasattr(model, 'head'):
    in_features = model.head.linear.in_features
    # Update the dropout probability (optional)
    model.head.dropout.p = 0.2

    # Fresh, randomly initialised 3-way output layer.
    model.head.linear = nn.Linear(in_features, 3)
else:
    raise ValueError("Model structure not recognized. Please adjust or face my wrath.")

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [5]:
#model

In [6]:
# Summarise how many catalogue entries carry each merger flag.
with fits.open(label_file) as hdul:
    data = hdul[1].data  # The second HDU

    # All counting happens while the file is still open.
    n_total = len(data)

    pre_mask = data['is_pre_merger'] == 1
    ongoing_mask = data['is_ongoing_merger'] == 1
    post_mask = data['is_post_merger'] == 1

    # Count how many are flagged as each type
    n_pre = np.sum(pre_mask)
    n_ongoing = np.sum(ongoing_mask)
    n_post = np.sum(post_mask)

    # A simple definition of 'non-merger' is galaxies that have none of the three flags.
    # Counting rows directly (instead of subtracting the three totals) stays
    # correct even if a galaxy carries more than one flag.
    n_non = np.sum(~(pre_mask | ongoing_mask | post_mask))

print(f"Total galaxies: {n_total}")
print(f"Pre-merger: {n_pre} ({n_pre / n_total * 100:.2f}%)")
print(f"Ongoing merger: {n_ongoing} ({n_ongoing / n_total * 100:.2f}%)")
print(f"Post-merger: {n_post} ({n_post / n_total * 100:.2f}%)")
print(f"Non-merger: {n_non} ({n_non / n_total * 100:.2f}%)")

Total galaxies: 58436
Pre-merger: 1267 (2.17%)
Ongoing merger: 511 (0.87%)
Post-merger: 636 (1.09%)
Non-merger: 56022 (95.87%)


In [7]:
# Hyperparameters (following the paper's one-stage configuration)
# NOTE(review): Adam is handed every parameter of `model`, with its own lr and
# weight decay set here. If only the classifier head is meant to be trained,
# confirm that the remaining parameters are frozen elsewhere.
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=0.005)
# Class-weighted cross-entropy; `weights_t` is moved to the model's device.
criterion = nn.CrossEntropyLoss(weight=weights_t.to(device))

# Training and validation loop

num_epochs = 5
best_val_loss = np.inf
# Per-epoch history, appended to once per epoch.
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

for epoch in range(num_epochs):
    # Timestamp taken at the start of the epoch, so `elapsed` below covers the
    # training pass plus the validation pass.
    last_print_time = time.time()
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, val_preds, val_labels = evaluate(model, val_loader, criterion, device)
    current_time = time.time()
    elapsed = current_time - last_print_time
    print(f"\n[{elapsed:.1f}s elapsed] Epoch {epoch+1}/{num_epochs}: "
          f"Train Loss = {train_loss:.3f}, Train Acc = {train_acc:.3f}; "
          f"Val Loss = {val_loss:.3f}, Val Acc = {val_acc:.3f}")

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    
    # Save the best model based on validation loss.
    # Only the state_dict is written to disk; `model` itself still holds the
    # last-epoch weights once the loop finishes.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")


[418.8s elapsed] Epoch 1/5: Train Loss = 1.093, Train Acc = 0.903; Val Loss = 1.055, Val Acc = 0.901

[218.1s elapsed] Epoch 2/5: Train Loss = 1.088, Train Acc = 0.828; Val Loss = 1.052, Val Acc = 0.808

[216.5s elapsed] Epoch 3/5: Train Loss = 1.082, Train Acc = 0.768; Val Loss = 1.051, Val Acc = 0.728

[237.6s elapsed] Epoch 4/5: Train Loss = 1.077, Train Acc = 0.742; Val Loss = 1.048, Val Acc = 0.687

[218.4s elapsed] Epoch 5/5: Train Loss = 1.071, Train Acc = 0.711; Val Loss = 1.045, Val Acc = 0.655


In [None]:
# Evaluation on Test Set
# Restore the checkpoint with the lowest validation loss *before* computing any
# test metric, so that the loss/accuracy, the confusion matrix and the ROC
# curves below all describe the same model. map_location keeps the load
# working when the checkpoint was written on a different device.
model.load_state_dict(torch.load("best_model.pth", map_location=device))

test_loss, test_acc, test_preds, test_labels = evaluate(model, test_loader, criterion, device)
# `evaluate` returns plain lists; convert once so array-based helpers can use them.
test_preds = np.asarray(test_preds)
test_labels = np.asarray(test_labels)
print(f"\nTest Loss: {test_loss:.3f}")
print(f"Test Accuracy: {test_acc:.3f}")

# Training history: one figure for the loss, one for the accuracy.
for metric_name, train_history, val_history in (
    ("Loss", train_losses, val_losses),
    ("Accuracy", train_accuracies, val_accuracies),
):
    plt.figure(figsize=(8, 6))
    plt.plot(range(1, num_epochs+1), train_history, label=f"Train {metric_name}")
    plt.plot(range(1, num_epochs+1), val_history, label=f"Validation {metric_name}")
    plt.xticks(range(num_epochs+1))
    plt.xlabel("Epoch")
    plt.ylabel(metric_name)
    plt.title(f"{metric_name} vs. Epoch")
    plt.legend()
    plt.show()

# Confusion matrix
cm = confusion_matrix(test_labels, test_preds)
print("Confusion Matrix:\n", cm)

# plot_confusion_matrix creates its own figure, so no plt.figure() call is
# needed here (it would only leave an empty figure behind).
# NOTE(review): class 2 is labelled 'Post-merger' here, while the dataset
# assigns class 2 to "ongoing OR post" — confirm the wording.
plot_confusion_matrix(
    y_true=test_labels,
    y_pred=test_preds,
    classes=['Non-merger', 'Pre-merger', 'Post-merger'],
    normalize=True
)
plt.show()

# Compute probabilities on the test set (same checkpoint as above)
test_probs, test_labels = evaluate_with_probabilities(model, test_loader, device)
test_preds = np.argmax(test_probs, axis=1)

# ROC Curves for Multi-class (One-vs-Rest)
n_classes = 3
# Binarize the true labels
test_labels_bin = label_binarize(test_labels, classes=[0, 1, 2])
fpr, tpr, roc_auc = {}, {}, {}

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(test_labels_bin[:,i], test_probs[:,i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Plot ROC curves for each class
plt.figure(figsize=(8,6))
colors = ['blue','red','green']
for i,c in zip(range(n_classes),colors):
    plt.plot(fpr[i], tpr[i], color=c, label=f'Class {i} (AUC={roc_auc[i]:.2f})')
# Diagonal = chance level.
plt.plot([0,1],[0,1],'k--')
plt.xlim([0,1]); plt.ylim([0,1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Multi-class ROC Curves')
plt.legend(loc="lower right")
plt.show()