In [None]:
import os

# PYTORCH_ENABLE_MPS_FALLBACK is consulted while the torch library is being
# loaded, so it has to be in the environment *before* `import torch`; setting
# it afterwards (as this cell used to do) leaves the fallback disabled.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import glob

import torch

from astropy.io import fits
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import Dataset, DataLoader

from datasets import SolarDataset, normalize_standard, preprocess_clip_wrapper, preprocess_dino,load_file_names_and_classes_for_test, prepare_dataloaders, find_all_fits_files, load_filenames

from utils import test_visualize_images_all_cluster_zero, visualize_batch_images_from_cluster, visualize_histograms_from_cluster, visualize_np_images_from_cluster

# Notebook-wide configuration.
im_size = 32   # image side length used when mode == "standard"
mode = "clip"  # feature-extraction mode forwarded to prepare_dataloaders
from general import DATA_PATH, TEST_PATH

# NOTE(review): the wildcard import is kept after the explicit imports, exactly
# as in the original cell, so any names it rebinds resolve the same way.
# Later cells rely on it (e.g. extract_features, combine_predictions).
from utils import *

from sklearn.svm import OneClassSVM
from sklearn.ensemble import IsolationForest


In [None]:
# Per-sample image shape, presumably (channels, height, width): "standard"
# mode uses a single im_size x im_size plane, every other mode 3 x 224 x 224.
im_shape = (1, im_size, im_size) if mode == "standard" else (3, 224, 224)

In [None]:
# "train" is the large unlabeled set, "test" the small labeled one.
train_dataloader, test_dataloader = prepare_dataloaders(
    mode=mode,
    im_size=im_size,
    batch_size=16,
)

# Properties collected next to the features for every sample.
property_names = ['thermal_component', 'class', "data", "filename"]
train_features, train_properties = extract_features(train_dataloader, property_names)
test_features, test_properties = extract_features(test_dataloader, property_names)

# Add a vectorized copy of the string-valued thermal component to both splits.
for split_properties in (train_properties, test_properties):
    split_properties['thermal_component_vectorized'] = replace_string_values(
        split_properties['thermal_component']
    )

# If the image data is not already 4-D, restore it to
# (n_samples, 3, height, width).
# NOTE(review): the channel count is hard-coded to 3 and ignores im_shape[0]
# (which is 1 in "standard" mode) - confirm this is intended.
if len(train_properties["data"].shape) != 4:
    spatial_dims = im_shape[1:]
    for split_features, split_properties in (
        (train_features, train_properties),
        (test_features, test_properties),
    ):
        split_properties["data"] = split_properties["data"].reshape(
            split_features.shape[0], 3, *spatial_dims
        )

# We first fit an isolation forest to the test set, then apply it to the same test set to check what is considered an anomaly, and finally to the big unlabeled set. This solution is highly dependent on the pre-chosen test set.

In [None]:
# The forest is fitted on the labeled test set and reused on the unlabeled set.
current_features = test_features
current_properties = test_properties

isolation_forest = IsolationForest(
    n_estimators=900,
    contamination='auto',
    random_state=42,
)
isolation_forest.fit(current_features)

# A negative decision-function score marks an outlier: label 1, otherwise 0.
distances = isolation_forest.decision_function(current_features)
cluster_labels_test = (distances < 0.0) * 1

print(combine_predictions(cluster_labels_test, test_properties["class"]))

# One image grid per predicted label on the test set.
for test_label in np.unique(cluster_labels_test):
    visualize_np_images_from_cluster(
        test_label,
        cluster_labels_test,
        current_properties["data"],
        max_images=300,
        norm_func=lambda x: x,
    )
    plt.show()

# Same forest, now scoring the big unlabeled set.
train_scores = isolation_forest.decision_function(train_features)
cluster_labels = (train_scores < 0.0) * 1

for train_label in np.unique(cluster_labels):
    visualize_np_images_from_cluster(
        train_label,
        cluster_labels,
        train_properties["data"],
        max_images=900,
        norm_func=lambda x: x,
    )
    plt.show()

In [None]:
def visualize_np_images_from_cluster(cluster_num, cluster_assignments, images, max_images=32, norm_func=np.log1p):
    """Draw up to ``max_images`` images of one cluster on a wide grid.

    Args:
        cluster_num: Label of the cluster to display.
        cluster_assignments: Per-sample labels; entry ``k`` belongs to
            ``images[k]``.
        images: Indexable as ``images[k, 0]``; only that first plane of each
            sample is drawn.
        max_images: Upper bound on the number of images shown.
        norm_func: Applied to every image before ``imshow``.

    Returns:
        None. The figure is left as the current pyplot figure so the caller
        can ``plt.show()`` or ``plt.savefig()`` it. When no sample carries
        ``cluster_num`` (or ``max_images`` is not positive) nothing is drawn.
    """
    # Local import: the notebook itself only imports ``math`` in a later cell.
    import math

    cluster_indices = [idx for idx, cluster_id in enumerate(cluster_assignments) if cluster_id == cluster_num]
    num_images = min(len(cluster_indices), max_images)
    if num_images <= 0:
        # A zero-column grid would make plt.subplots raise.
        return

    # Start from a square layout and widen it until a grid with
    # num_cols // 4 rows can hold every image (about four times wider than tall).
    num_cols = int(math.ceil(math.sqrt(num_images)))
    while num_images > num_cols * (num_cols // 4):
        num_cols += 1
    num_rows = max(1, num_cols // 4)

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(2 * num_cols, 2 * num_rows), squeeze=False)
    axs = axs.flatten()

    for ax, idx in zip(axs, cluster_indices[:num_images]):
        ax.imshow(norm_func(images[idx, 0]), cmap='inferno')
        ax.axis('off')

    # Hide the unused cells at the end of the grid.
    for ax in axs[num_images:]:
        ax.axis('off')



In [None]:

# Re-fit so the cell can be run on its own; the fixed random_state keeps the
# fit reproducible.
isolation_forest = IsolationForest(n_estimators=900, contamination='auto', random_state=42)
isolation_forest.fit(current_features)

distances = isolation_forest.decision_function(current_features)
cluster_labels_test = (distances < 0.0) * 1

# Ascending score order: the most anomalous samples come first in each grid.
sorted_indices = np.argsort(distances)

sorted_cluster_labels_test = cluster_labels_test[sorted_indices]
sorted_test_data = current_properties["data"][sorted_indices]

# savefig does not create missing directories.
os.makedirs("outputs", exist_ok=True)

for i in np.unique(sorted_cluster_labels_test):
    visualize_np_images_from_cluster(i, sorted_cluster_labels_test, sorted_test_data, max_images=300, norm_func=lambda x: x)
    plt.savefig(f"outputs/isolation_forest_test_set_{i}.png", dpi=300)

# Same treatment for the big unlabeled set, displayed instead of saved.
train_distances = isolation_forest.decision_function(train_features)
train_sorted_indices = np.argsort(train_distances)
sorted_cluster_labels = (train_distances[train_sorted_indices] < 0.0) * 1
sorted_train_data = train_properties["data"][train_sorted_indices]

for i in np.unique(sorted_cluster_labels):
    visualize_np_images_from_cluster(i, sorted_cluster_labels, sorted_train_data, max_images=900, norm_func=lambda x: x)
    plt.show()


# Here we train a small classifier on the labeled test set and then apply it to the big unlabeled set.

In [None]:
# Binary target for the classifier: True where the annotated class exceeds 1.
test_labels_for_training = test_properties["class"] > 1

import torch
import torch.nn as nn
import torch.nn.functional as F

class CLIPClassifier(nn.Module):
    """Fully connected classifier head operating on precomputed feature vectors.

    Layout: input_dim -> 512 -> 256 -> 256 -> 256 -> num_classes. Every hidden
    layer is Linear + BatchNorm1d + ReLU, the third hidden layer adds a skip
    connection from its own input, and dropout (p=0.6) follows hidden layers
    1, 3 and 4.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.fc1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 256)
        self.bn3 = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, 256)
        self.bn4 = nn.BatchNorm1d(256)
        self.fc5 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.6)

    def forward(self, x):
        """Return the raw class logits for a batch of feature vectors."""
        hidden = self.dropout(F.relu(self.bn1(self.fc1(x))))
        hidden = F.relu(self.bn2(self.fc2(hidden)))

        # Residual step: layer 3 output plus the activations that fed it.
        skip = hidden
        hidden = F.relu(self.bn3(self.fc3(hidden))) + skip
        hidden = self.dropout(hidden)

        hidden = self.dropout(F.relu(self.bn4(self.fc4(hidden))))
        return self.fc5(hidden)

# Classifier dimensions: 512 input features, two output classes.
num_classes = 2
input_dim = 512

model = CLIPClassifier(input_dim, num_classes)

from torchsummary import summary
summary(model, input_size=(input_dim,))

# Largest number of label-1 predictions for which a checkpoint has been saved;
# updated by the training loop.
max_ones = 0

In [None]:
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
import torch

class CLIPDataset(Dataset):
    """Pairs precomputed feature vectors with their labels, sample by sample."""

    def __init__(self, features, labels):
        self.features, self.labels = features, labels

    def __len__(self):
        # The dataset size is defined by the labels.
        n_samples = len(self.labels)
        return n_samples

    def __getitem__(self, idx):
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target

# Train the classifier on the labeled test set. After every epoch, predict on
# the big unlabeled set and save a checkpoint whenever the number of samples
# predicted as label 1 reaches a new maximum.
test_features_tensor = torch.tensor(test_features, dtype=torch.float)
test_labels_tensor = torch.tensor(test_labels_for_training, dtype=torch.long)  # Use torch.long for labels

clip_dataset = CLIPDataset(test_features_tensor, test_labels_tensor)
train_loader = DataLoader(clip_dataset, batch_size=32, shuffle=True)
# There is no held-out split: "validation" accuracy is measured on the
# training data itself.
valid_loader = train_loader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# The unlabeled-set loader does not change between epochs, so build it once.
clip_test_dataset = CLIPDataset(train_features, train_properties["class"])
train_test_loader = DataLoader(clip_test_dataset, batch_size=64, shuffle=False)

# torch.save does not create missing directories.
os.makedirs("weights", exist_ok=True)

num_epochs = 300
for epoch in range(num_epochs):
    # ---- optimisation pass over the labeled set ----
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    if epoch % 100 == 1:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")

    # ---- accuracy on the (training) validation loader ----
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if epoch % 100 == 1:
        print(f"Validation Accuracy: {100 * correct / total}%")
        print(max_ones)

    # ---- predictions on the big unlabeled set ----
    # Batches must follow the model onto its device, and no gradients are
    # needed for inference.
    batch_predictions = []
    with torch.no_grad():
        for features, _ in train_test_loader:
            features = features.to(device=device, dtype=torch.float)
            logits = model(features)
            batch_predictions.append(torch.argmax(logits, dim=1).cpu().numpy())
    if batch_predictions:
        predicted_labels_classifier = np.concatenate(batch_predictions).reshape(-1,)
    else:
        predicted_labels_classifier = np.zeros(0, dtype=np.int64)
    model.train()

    for label in np.unique(predicted_labels_classifier):
        nb_ones = np.sum(predicted_labels_classifier == label)
        # With probability 0.005 per label and epoch, forget the best count
        # so that later epochs can save checkpoints again.
        if np.random.rand() > 0.995:
            max_ones = 0
        if label != 1:
            continue
        print(nb_ones)
        # Checkpoint only a near-perfect fit that flags more than 100
        # samples and beats the previous best count.
        if nb_ones > max_ones and correct / total > 0.99 and nb_ones > 100:
            max_ones = nb_ones
            torch.save(model.state_dict(), f'weights/detector_at_samples_{max_ones}.pth')
            visualize_np_images_from_cluster(label, predicted_labels_classifier, train_properties["data"], max_images=300, norm_func=lambda x : x)
            plt.show()

print("Training Complete")


# As our test set is very small, we do not have a separate validation set. For a fair study, one could have been built manually from a small subset of the big unlabeled set. For this proof of concept, we suggest visually choosing a checkpoint (identified by the number of anomalies it finds in the big unlabeled set) whose anomalies look consistent (hence the many visualizations produced by the previous cell).

In [None]:
# Checkpoint picked by visual inspection; the number is the label-1 count on
# the unlabeled set that is embedded in the checkpoint file name.
max_ones = 534
checkpoint_path = f'weights/detector_at_samples_{max_ones}.pth'

# map_location="cpu" lets a checkpoint written on a GPU machine load on any
# machine; load_state_dict then copies the values into the model's existing
# parameters, wherever they live.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

model.eval()

In [None]:
# Predict a label for every sample of the big unlabeled set with the loaded
# checkpoint.
clip_test_dataset = CLIPDataset(train_features, train_properties["class"])
train_test_loader = DataLoader(clip_test_dataset, batch_size=64, shuffle=False)

model.eval()
# Send the batches to whatever device the model currently lives on.
model_device = next(model.parameters()).device

batch_predictions = []
with torch.no_grad():  # inference only, no gradients needed
    for features, _ in train_test_loader:
        features = features.to(device=model_device, dtype=torch.float)
        logits = model(features)
        # argmax over the logits selects the same class as argmax over softmax.
        batch_predictions.append(torch.argmax(logits, dim=1).cpu().numpy())

if batch_predictions:
    predicted_labels_classifier = np.concatenate(batch_predictions).reshape(-1,)
else:
    predicted_labels_classifier = np.zeros(0, dtype=np.int64)

# Writing the file names of the samples predicted as anomalies (multiple sources) to a text file. It is needed for further analysis.

In [None]:
# Label the classifier emits for the positive (class > 1) training target.
ANOMALY_LABEL = 1

# savefig does not create missing directories.
os.makedirs("outputs/isolation_forest", exist_ok=True)

# Count and plot the samples of every predicted label.
for i in np.unique(predicted_labels_classifier):
    print(i)
    print(np.sum(predicted_labels_classifier==i))
    visualize_np_images_from_cluster(i, predicted_labels_classifier, train_properties["data"], max_images=300, norm_func=lambda x : x)
    plt.savefig(f"outputs/isolation_forest/binary_classifier_class_{i}.png")

# Select by the explicit anomaly label instead of the loop variable left over
# from above: if no sample is predicted as 1, that variable ends at 0 and every
# normal file would be exported as an anomaly.
list_multiple_sources = train_properties["filename"][predicted_labels_classifier == ANOMALY_LABEL]

with open(f'outputs/multiple_sources_{max_ones}.txt', 'w') as file:
    for item in list_multiple_sources:
        file.write(item + '\n')


In [None]:
def parse_clusters_file(clusters_file_path):
    """Read a clusters text file into ``{cluster_number: [file paths]}``.

    A line made only of digits opens a new cluster; every following non-empty
    line is stored as a member of that cluster. Blank lines, and any path
    that appears before the first cluster number, are ignored.
    """
    parsed = {}
    active_id = None
    with open(clusters_file_path, 'r') as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if not entry:
                continue
            if entry.isdigit():
                # Cluster header: start a fresh member list.
                active_id = int(entry)
                parsed[active_id] = []
                continue
            if active_id is None:
                # Path without a preceding cluster number.
                continue
            parsed[active_id].append(entry)
    return parsed

def parse_anomalies_file(anomalies_file_path):
    """Return the set of whitespace-stripped lines of the anomalies file.

    Every line contributes an entry, so a blank line adds the empty string.
    """
    with open(anomalies_file_path, 'r') as handle:
        return {raw_line.strip() for raw_line in handle}

def compute_anomalies_percentage(clusters, anomalies):
    """Map each cluster to the percentage of its files listed in ``anomalies``.

    Args:
        clusters: Mapping of cluster id to a list of file names.
        anomalies: Container of anomalous file names (membership-tested).

    Returns:
        Dict of cluster id to percentage (0-100); an empty cluster maps to 0.
    """
    shares = {}
    for cluster_id, members in clusters.items():
        if not members:
            shares[cluster_id] = 0
            continue
        flagged = len([name for name in members if name in anomalies])
        shares[cluster_id] = (flagged / len(members)) * 100
    return shares


In [None]:
# Input files: the cluster assignment and the exported anomaly file names.
n_clusters = 30
clusters_file_path = f"outputs/clusters_seed_45_{n_clusters}.txt"
anomalies_file_path = f"outputs/multiple_sources_{max_ones}.txt"

clusters = parse_clusters_file(clusters_file_path)
anomalies = parse_anomalies_file(anomalies_file_path)
anomalies_percentage = compute_anomalies_percentage(clusters, anomalies)

# Distribution of the per-cluster anomaly share.
plt.hist(list(anomalies_percentage.values()), bins=15)
plt.show()

for cluster, percentage in anomalies_percentage.items():
    print(f'Cluster {cluster}: {percentage:.2f}% anomalies')


In [None]:
import math

# Clusters whose anomaly share is below this percentage are only printed.
ANOMALY_SHARE_THRESHOLD = 25
random_seed = 45


def visualize_grid(images, grid_size=(90, 90), figsize=(18, 18), save_path=None):
    """Tile equally sized 2-D images into one mosaic and display it.

    Args:
        images: Array indexed as (n_images, height, width).
        grid_size: (rows, cols) of the mosaic; images beyond rows * cols are
            dropped and unused tiles stay zero.
        figsize: Figure size passed to ``plt.figure``.
        save_path: When given, the figure is written there (dpi=350) before
            it is shown.
    """
    tile_h, tile_w = images.shape[1], images.shape[2]
    n_rows, n_cols = grid_size
    grid = np.zeros((n_rows * tile_h, n_cols * tile_w))

    for idx, img in enumerate(images[:n_rows * n_cols]):
        row, col = divmod(idx, n_cols)
        grid[row * tile_h:(row + 1) * tile_h,
             col * tile_w:(col + 1) * tile_w] = img

    plt.figure(figsize=figsize)
    plt.imshow(grid, cmap='inferno', aspect='auto')
    plt.axis('off')
    if save_path is not None:
        plt.savefig(save_path, dpi=350, )
    plt.show()


# savefig does not create missing directories.
os.makedirs("weights", exist_ok=True)

for i in anomalies_percentage.keys():
    print(f'Cluster {i}: {anomalies_percentage[i]:.2f}% anomalies')
    if anomalies_percentage[i] < ANOMALY_SHARE_THRESHOLD:
        continue
    file_names = clusters[i]
    nb_files = len(file_names)
    # Smallest square grid that holds every file of the cluster.
    grid_size = math.ceil(np.sqrt(nb_files))
    images = []
    for file_name in file_names:
        with fits.open(file_name) as hdul:
            image_data = hdul[0].data
            lowest = np.min(image_data)
            value_range = np.max(image_data) - lowest
            if value_range == 0:
                # Constant image: min-max scaling would divide by zero.
                image_data = np.zeros_like(image_data, dtype=float)
            else:
                image_data = (image_data - lowest)/value_range
            images.append(image_data)
    images = np.array(images)

    visualize_grid(
        images,
        figsize=(40,40),
        grid_size=(grid_size, grid_size),
        save_path=f"weights/anomalies_{random_seed}_{anomalies_percentage[i]}.png",
    )

In [None]:
# Re-order the mapping so the clusters with the largest anomaly share come first.
anomalies_percentage = dict(
    sorted(anomalies_percentage.items(), key=lambda pair: pair[1], reverse=True)
)
for cluster_id, share in anomalies_percentage.items():
    print(f'Cluster {cluster_id}: {share:.2f}% anomalies')


In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm

# Horizontal bar chart of the anomaly share of the first n clusters, in the
# current order of `anomalies_percentage`.
cluster_ids = list(anomalies_percentage.keys())
# pyplot.get_cmap(name, lut) is used because matplotlib.cm.get_cmap was
# removed in Matplotlib 3.9. One colour slot per possible cluster id.
colormap = plt.get_cmap('tab20b', max(cluster_ids) + 1)

n = 30
top_n_clusters = list(anomalies_percentage.items())[:n]

# Separate names are used on purpose: rebinding `clusters` here would
# overwrite the cluster -> files dict that the mosaic cell above reads.
cluster_names = [f"Cluster {cluster_id}" for cluster_id, _ in top_n_clusters]
percentages = [share for _, share in top_n_clusters]
colors = [colormap(cluster_id) for cluster_id, _ in top_n_clusters]

plt.figure(figsize=(12, 8))

bars = plt.barh(cluster_names, percentages, color=colors)
plt.gca().invert_yaxis()  # first entry at the top

for spine in plt.gca().spines.values():
    spine.set_visible(False)

plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
plt.tick_params(axis='y', which='both', left=False, right=False, labelleft=True)  # Keep y-axis labels

# Per-bar value labels are intentionally not drawn (the annotation code was
# disabled in the original cell).

plt.tight_layout()
plt.savefig("outputs/anomaly_percentages.png", dpi=350)
