In [1]:
RUN_EXXA2 = True  # Unsupervised clustering of protoplanetary disks
RUN_EXXA4 = True  # Autoencoder for latent space analysis
RUN_EXXA3 = True  # Classifier for exoplanet transit light curves


In [2]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [3]:
!pip install numpy pandas torch torchvision astropy pytorch-msssim scikit-learn matplotlib seaborn tqdm


Collecting pytorch-msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB

In [4]:
from google.colab import drive
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from astropy.io import fits
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans, DBSCAN
from sklearn.preprocessing import StandardScaler
from pytorch_msssim import ssim, ms_ssim
from tqdm import tqdm
import os
from torchvision.utils import make_grid
from sklearn.model_selection import train_test_split
from datetime import datetime
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from torch.nn import init
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_samples, silhouette_score
SAVE_DIR = '/content/drive/MyDrive/EXXA_Results'
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(os.path.join(SAVE_DIR, 'training_progress'), exist_ok=True)
os.makedirs(os.path.join(SAVE_DIR, 'checkpoints'), exist_ok=True)
os.makedirs(os.path.join(SAVE_DIR, 'metrics'), exist_ok=True)

In [5]:
torch.manual_seed(42)
np.random.seed(42)


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


#Data Loading and Preprocessing


In [7]:
class DiskDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.files = [f for f in os.listdir(data_dir) if f.endswith('.fits')]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_dir, self.files[idx])
        with fits.open(file_path) as hdul:
            image = hdul[0].data[0]

        image = (image - np.min(image)) / (np.max(image) - np.min(image))

        if len(image.shape) > 2:
            image = image.squeeze()

        image = image[np.newaxis, :, :]

        if self.transform:
            image = self.transform(image)

        return torch.FloatTensor(image)

#Unsupervised Clustering Model


In [8]:
class DiskClustering:
    def __init__(self, n_clusters=5):
        self.n_clusters = n_clusters
        self.kmeans = KMeans(
            n_clusters=n_clusters,
            random_state=42,
            n_init=50,  # Increased for better initialization
            max_iter=1000,  # More iterations
            tol=1e-6  # Tighter convergence
        )
        self.scaler = StandardScaler()

    def preprocess_images(self, images):
        """Improved preprocessing pipeline"""
        # Flatten images
        n_samples = images.shape[0]
        flattened = images.reshape(n_samples, -1)

        # Remove low variance features
        std = np.std(flattened, axis=0)
        high_var_mask = std > np.percentile(std, 10)
        flattened = flattened[:, high_var_mask]

        # Remove outliers
        mean = np.mean(flattened, axis=0)
        std = np.std(flattened, axis=0)
        z_scores = np.abs((flattened - mean) / std)
        outlier_mask = np.all(z_scores < 3, axis=1)  # Remove points with z-score > 3
        flattened = flattened[outlier_mask]

        # Scale features
        scaled = self.scaler.fit_transform(flattened)

        pca = PCA(n_components=0.95)
        reduced = pca.fit_transform(scaled)

        self.high_var_mask = high_var_mask
        self.outlier_mask = outlier_mask
        self.pca = pca

        print(f"Preprocessed data shape: {reduced.shape}")
        print(f"Removed {np.sum(~outlier_mask)} outliers")
        print(f"Kept {reduced.shape[1]} components explaining 95% variance")

        return reduced

    def fit(self, images):
        print("Preprocessing data...")
        self.processed_data = self.preprocess_images(images)

        print(f"Fitting KMeans with {self.n_clusters} clusters...")
        self.kmeans.fit(self.processed_data)

        self.silhouette = silhouette_score(self.processed_data, self.kmeans.labels_)
        print(f"Silhouette Score: {self.silhouette:.3f}")

        scores = []
        for k in range(2, 8):
            kmeans_temp = KMeans(n_clusters=k, random_state=42, n_init='auto')
            labels_temp = kmeans_temp.fit_predict(self.processed_data)
            score = silhouette_score(self.processed_data, labels_temp)
            scores.append(score)
            print(f"Silhouette Score with {k} clusters: {score:.3f}")

        best_k = np.argmax(scores) + 2
        if best_k != self.n_clusters:
            print(f"\nNote: {best_k} clusters might give better results (score: {max(scores):.3f})")

        return self

    def predict(self, images):
        processed = self.preprocess_images(images)
        return self.kmeans.predict(processed)

    def visualize_clusters(self, images, labels):
        fig, axes = plt.subplots(2, 5, figsize=(20, 8))
        axes = axes.ravel()

        for i in range(self.n_clusters):
            cluster_images = images[labels == i]
            if len(cluster_images) > 0:
                axes[i].imshow(cluster_images[0], cmap='viridis')
                axes[i].set_title(f'Cluster {i}')
                axes[i].axis('off')

        plt.tight_layout()
        plt.show()

#Autoencoder Model

In [9]:
class DiskAutoencoder(nn.Module):
    def __init__(self, latent_dim=128):
        super(DiskAutoencoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 75 * 75, latent_dim)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128 * 75 * 75),
            nn.Unflatten(1, (128, 75, 75)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z)

#Transit Curve Classifier

In [10]:
class TransitClassifier(nn.Module):
    def __init__(self, input_size):
        super(TransitClassifier, self).__init__()

        # Feature extraction with attention
        self.features = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Sigmoid()
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1)
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Feature extraction
        features = self.features(x)

        # Apply attention
        attention_weights = self.attention(features)
        attended_features = features * attention_weights

        # Classification
        return self.classifier(attended_features)

def generate_transit_curves(n_samples=5000, n_points=100):
    curves = np.zeros((n_samples, n_points))
    labels = np.zeros(n_samples)

    for i in range(n_samples):
        t = np.linspace(0, 1, n_points)
        if np.random.random() > 0.5:  # Transit present
            depth = np.random.uniform(0.01, 0.05)  # Shallower transits
            duration = np.random.uniform(0.05, 0.15)  # More varied duration
            center = np.random.uniform(0.2, 0.8)  # More varied position

            transit = np.ones_like(t)
            mask = np.abs(t - center) < duration/2
            transit[mask] = 1 - depth

            # Add ingress/egress
            ingress_duration = duration * 0.1
            ingress_mask = (t > center - duration/2) & (t < center - duration/2 + ingress_duration)
            egress_mask = (t > center + duration/2 - ingress_duration) & (t < center + duration/2)

            # Smooth ingress/egress
            for mask in [ingress_mask, egress_mask]:
                if np.any(mask):
                    transit[mask] = 1 - depth * (1 - np.abs(t[mask] - center) / (duration/2))

            # More realistic noise and systematics
            noise = np.random.normal(0, 0.005, len(t))  # Increased noise
            trend = np.random.uniform(-0.01, 0.01) * t  # Larger trends
            oscillation = 0.002 * np.sin(2 * np.pi * t * np.random.uniform(1, 3))  # Add oscillations

            curve = transit + noise + trend + oscillation

            curves[i] = curve
            labels[i] = 1
        else:  # No transit

            noise = np.random.normal(0, 0.005, len(t))
            trend = np.random.uniform(-0.01, 0.01) * t
            oscillation = 0.002 * np.sin(2 * np.pi * t * np.random.uniform(1, 3))

            curve = 1 + noise + trend + oscillation

            curves[i] = curve
            labels[i] = 0

        # Normalize
        curves[i] = (curves[i] - np.mean(curves[i])) / np.std(curves[i])

    return curves, labels

#Training Functions

In [11]:
def visualize_training_progress(model, dataloader, epoch, save_dir=None):
    if save_dir is None:
        save_dir = os.path.join(SAVE_DIR, 'training_progress')

    model.eval()
    with torch.no_grad():
        batch = next(iter(dataloader))
        batch = batch.to(device)
        reconstructions = model(batch)
        n = min(8, batch.size(0))
        comparison = torch.cat([batch[:n], reconstructions[:n]])
        grid = make_grid(comparison.cpu(), nrow=n, normalize=True)

        plt.figure(figsize=(12, 4))
        plt.imshow(grid.permute(1, 2, 0))
        plt.axis('off')
        plt.savefig(os.path.join(save_dir, f'epoch_{epoch}.png'))
        plt.close()

In [12]:
def save_checkpoint(model, optimizer, epoch, loss, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)


In [13]:
def train_autoencoder(model, train_loader, num_epochs=50):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    checkpoint_dir = os.path.join(SAVE_DIR, 'checkpoints')

    history = {'train_loss': [], 'val_loss': []}

    train_size = int(0.8 * len(train_loader.dataset))
    val_size = len(train_loader.dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(train_loader.dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_train_loss = 0

        for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            batch = batch.to(device)

            optimizer.zero_grad()
            output = model(batch)
            loss = criterion(output, batch)

            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        # Validation phase
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                output = model(batch)
                loss = criterion(output, batch)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)

        # Update history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)

        print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

        # Save checkpoint if validation loss improved
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            save_checkpoint(model, optimizer, epoch, avg_val_loss,
                          os.path.join(checkpoint_dir, 'best_autoencoder.pth'))

        if (epoch + 1) % 10 == 0:
            save_checkpoint(model, optimizer, epoch, avg_train_loss,
                          os.path.join(checkpoint_dir, f'autoencoder_epoch_{epoch+1}.pth'))

        visualize_training_progress(model, train_loader, epoch)

    plt.figure(figsize=(10, 6))
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Training History')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(os.path.join(SAVE_DIR, 'autoencoder_training_history.png'))
    plt.close()

    return history

In [14]:
def evaluate_model(model, test_loader, device):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            preds = torch.sigmoid(outputs)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    # Calculate ROC curve and AUC
    fpr, tpr, _ = roc_curve(all_labels, all_preds)
    auc_score = auc(fpr, tpr)

    # Calculate precision-recall curve
    precision, recall, _ = precision_recall_curve(all_labels, all_preds)
    avg_precision = average_precision_score(all_labels, all_preds)

    return {
        'fpr': fpr,
        'tpr': tpr,
        'auc': auc_score,
        'precision': precision,
        'recall': recall,
        'avg_precision': avg_precision
    }



In [15]:
def plot_metrics(metrics, save_dir=None):
    """Plot and save evaluation metrics."""
    if save_dir is None:
        save_dir = os.path.join(SAVE_DIR, 'metrics')
    os.makedirs(save_dir, exist_ok=True)

    # Plot ROC curve
    plt.figure(figsize=(10, 6))
    plt.plot(metrics['fpr'], metrics['tpr'], color='darkorange', lw=2,
             label=f'ROC curve (AUC = {metrics["auc"]:.3f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.savefig(os.path.join(save_dir, 'roc_curve.png'))
    plt.close()

    # Plot precision-recall curve
    plt.figure(figsize=(10, 6))
    plt.plot(metrics['recall'], metrics['precision'], color='blue', lw=2,
             label=f'Precision-Recall curve (AP = {metrics["avg_precision"]:.3f})')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc="lower left")
    plt.savefig(os.path.join(save_dir, 'precision_recall_curve.png'))
    plt.close()



In [16]:

def train_classifier(model, train_loader, val_loader, num_epochs=100, device='cuda'):
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.0003, weight_decay=0.001)

    # Learning rate scheduler with warmup
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.001,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.1
    )

    best_val_acc = 0
    patience = 20
    patience_counter = 0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels.unsqueeze(1).float())

            # Add L1 regularization
            l1_lambda = 0.0001
            l1_norm = sum(p.abs().sum() for p in model.parameters())
            loss = loss + l1_lambda * l1_norm

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            train_loss += loss.item()
            preds = torch.sigmoid(outputs) > 0.5
            train_correct += (preds.squeeze() == labels).sum().item()
            train_total += labels.size(0)

        train_loss = train_loss / len(train_loader)
        train_acc = 100 * train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels.unsqueeze(1).float())

                val_loss += loss.item()
                preds = torch.sigmoid(outputs) > 0.5
                val_correct += (preds.squeeze() == labels).sum().item()
                val_total += labels.size(0)

        val_loss = val_loss / len(val_loader)
        val_acc = 100 * val_correct / val_total

        # Update history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            # Save best model
            model_path = os.path.join(SAVE_DIR, 'checkpoints', 'best_classifier.pth')
            torch.save(model.state_dict(), model_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping triggered after {epoch+1} epochs')
                break

    return history


In [21]:
def visualize_clustering_results(model, data, labels, save_dir=None):
    """Generate and save clustering visualizations."""
    if save_dir is None:
        save_dir = os.path.join(SAVE_DIR, 'clustering')
    os.makedirs(save_dir, exist_ok=True)

    data_2d = model.processed_data
    labels = model.kmeans.labels_

    # 1. PCA visualization
    print("Generating PCA visualization...")
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(data_2d[:, 0], data_2d[:, 1], c=labels, cmap='viridis')
    plt.colorbar(scatter)
    plt.title('Cluster Visualization (PCA)')
    plt.xlabel('First Principal Component')
    plt.ylabel('Second Principal Component')
    plt.savefig(os.path.join(save_dir, 'cluster_visualization.png'))
    plt.close()

    # 2. Silhouette plot
    print("Generating Silhouette plot...")
    silhouette_avg = silhouette_score(data_2d, labels)
    sample_silhouette_values = silhouette_samples(data_2d, labels)

    plt.figure(figsize=(10, 8))
    y_lower = 10
    for i in range(len(set(labels))):
        ith_cluster_values = sample_silhouette_values[labels == i]
        ith_cluster_values.sort()
        size_cluster_i = len(ith_cluster_values)
        y_upper = y_lower + size_cluster_i

        color = plt.cm.nipy_spectral(float(i) / len(set(labels)))
        plt.fill_betweenx(np.arange(y_lower, y_upper),
                         0, ith_cluster_values,
                         facecolor=color, edgecolor=color, alpha=0.7)
        y_lower = y_upper + 10

    plt.title(f'Silhouette Analysis (avg score: {silhouette_avg:.3f})')
    plt.xlabel('Silhouette Coefficient')
    plt.ylabel('Cluster')
    plt.axvline(x=silhouette_avg, color="red", linestyle="--")
    plt.savefig(os.path.join(save_dir, 'silhouette_plot.png'))
    plt.close()

    print(f"Clustering visualizations saved in {save_dir}")
    return silhouette_avg

In [24]:
def visualize_autoencoder_results(model, dataloader, save_dir='/content/drive/MyDrive/EXXA_Results/autoencoder'):
    """Visualize autoencoder results"""
    model.eval()
    os.makedirs(save_dir, exist_ok=True)

    data_iter = iter(dataloader)
    samples = next(data_iter)
    samples = samples.to(next(model.parameters()).device)

    with torch.no_grad():
        reconstructions = model(samples)

    samples = samples.cpu().numpy()
    reconstructions = reconstructions.cpu().numpy()

    # Plot original vs reconstructed images
    n_samples = min(5, len(samples))
    fig, axes = plt.subplots(2, n_samples, figsize=(15, 6))

    for i in range(n_samples):
        # Original
        axes[0, i].imshow(samples[i, 0], cmap='viridis')
        axes[0, i].set_title('Original')
        axes[0, i].axis('off')

        # Reconstructed
        axes[1, i].imshow(reconstructions[i, 0], cmap='viridis')
        axes[1, i].set_title('Reconstructed')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'reconstruction_comparison.png'))
    plt.close()

    # Plot loss distribution
    if hasattr(model, 'train_losses') and hasattr(model, 'val_losses'):
        plt.figure(figsize=(10, 6))
        plt.plot(model.train_losses, label='Training Loss')
        plt.plot(model.val_losses, label='Validation Loss')
        plt.title('Autoencoder Training Progress')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.savefig(os.path.join(save_dir, 'training_progress.png'))
        plt.close()

In [25]:
if __name__ == "__main__":
    DATA_DIR = '/content/drive/MyDrive/continuum_data_subset'

    # Create datasets and dataloaders
    disk_dataset = DiskDataset(DATA_DIR)
    disk_loader = DataLoader(disk_dataset, batch_size=32, shuffle=True)

    # Initialize models
    if RUN_EXXA2:
        print('Running EXXA2: Unsupervised Clustering')

        # Initialize clustering model with 2 clusters instead of 5
        clustering_model = DiskClustering(n_clusters=2)

        print("Loading disk data...")
        all_data = []
        for batch in disk_loader:
            all_data.append(batch.numpy())
        disk_data = np.concatenate(all_data, axis=0)
        print(f"Loaded disk data shape: {disk_data.shape}")

        # Fit the clustering model and get labels
        print("Fitting clustering model...")
        clustering_model.fit(disk_data)
        cluster_labels = clustering_model.kmeans.labels_

        # Generate visualizations
        print("Generating clustering visualizations...")
        visualize_clustering_results(clustering_model, disk_data, cluster_labels)
    if RUN_EXXA4:
        print('Running EXXA4: Autoencoder Training...')
        autoencoder = DiskAutoencoder().to(device)
        train_autoencoder(autoencoder, disk_loader)

        # Generate autoencoder visualizations
        print("Generating autoencoder visualizations...")
        visualize_autoencoder_results(autoencoder, disk_loader)

    if RUN_EXXA3:
        print('Running EXXA3: Transit Classifier Training...')
        # Generate more data for better training
        transit_curves, labels = generate_transit_curves(n_samples=10000)  # Increased samples

        # Create datasets
        dataset = torch.utils.data.TensorDataset(
            torch.FloatTensor(transit_curves),
            torch.FloatTensor(labels)
        )

        # Split into train, validation, and test sets
        train_size = int(0.7 * len(dataset))
        val_size = int(0.15 * len(dataset))
        test_size = len(dataset) - train_size - val_size

        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size, test_size]
        )

        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # Increased batch size
        val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

        # Initialize and train model
        classifier = TransitClassifier(100).to(device)
        history = train_classifier(classifier, train_loader, val_loader, num_epochs=150)  # More epochs

        # Evaluate on test set
        model_path = os.path.join(SAVE_DIR, 'checkpoints', 'best_classifier.pth')
        torch.save(classifier.state_dict(), model_path)
        classifier.load_state_dict(torch.load(model_path))
        test_metrics = evaluate_model(classifier, test_loader, device)
        plot_metrics(test_metrics)
        print(f'Test AUC: {test_metrics["auc"]:.3f}')
        print(f'Test Average Precision: {test_metrics["avg_precision"]:.3f}')

        # Plot training history
        plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.plot(history['train_loss'], label='Train Loss')
        plt.plot(history['val_loss'], label='Val Loss')
        plt.title('Loss History')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(history['train_acc'], label='Train Accuracy')
        plt.plot(history['val_acc'], label='Val Accuracy')
        plt.title('Accuracy History')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(SAVE_DIR, 'classifier_training_history.png'))
        plt.close()


Running EXXA2: Unsupervised Clustering
Loading disk data...
Loaded disk data shape: (150, 1, 600, 600)
Fitting clustering model...
Preprocessing data...
Preprocessed data shape: (59, 7)
Removed 91 outliers
Kept 7 components explaining 95% variance
Fitting KMeans with 2 clusters...
Silhouette Score: 0.433
Silhouette Score with 2 clusters: 0.433
Silhouette Score with 3 clusters: 0.403
Silhouette Score with 4 clusters: 0.307
Silhouette Score with 5 clusters: 0.329
Silhouette Score with 6 clusters: 0.370
Silhouette Score with 7 clusters: 0.358
Generating clustering visualizations...
Generating PCA visualization...
Generating Silhouette plot...
Clustering visualizations saved in /content/drive/MyDrive/EXXA_Results/clustering
Running EXXA4: Autoencoder Training...


Epoch 1/50: 100%|██████████| 4/4 [00:06<00:00,  1.61s/it]


Epoch 1, Train Loss: 0.2224, Val Loss: 0.1260


Epoch 2/50: 100%|██████████| 4/4 [00:05<00:00,  1.34s/it]


Epoch 2, Train Loss: 0.0588, Val Loss: 0.0178


Epoch 3/50: 100%|██████████| 4/4 [00:05<00:00,  1.33s/it]


Epoch 3, Train Loss: 0.0108, Val Loss: 0.0174


Epoch 4/50: 100%|██████████| 4/4 [00:05<00:00,  1.34s/it]


Epoch 4, Train Loss: 0.0113, Val Loss: 0.0173


Epoch 5/50: 100%|██████████| 4/4 [00:06<00:00,  1.51s/it]


Epoch 5, Train Loss: 0.0108, Val Loss: 0.0173


Epoch 6/50: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]


Epoch 6, Train Loss: 0.0110, Val Loss: 0.0173


Epoch 7/50: 100%|██████████| 4/4 [00:05<00:00,  1.28s/it]


Epoch 7, Train Loss: 0.0115, Val Loss: 0.0173


Epoch 8/50: 100%|██████████| 4/4 [00:04<00:00,  1.19s/it]


Epoch 8, Train Loss: 0.0113, Val Loss: 0.0173


Epoch 9/50: 100%|██████████| 4/4 [00:06<00:00,  1.62s/it]


Epoch 9, Train Loss: 0.0108, Val Loss: 0.0173


Epoch 10/50: 100%|██████████| 4/4 [00:05<00:00,  1.37s/it]


Epoch 10, Train Loss: 0.0107, Val Loss: 0.0173


Epoch 11/50: 100%|██████████| 4/4 [00:07<00:00,  1.76s/it]


Epoch 11, Train Loss: 0.0108, Val Loss: 0.0173


Epoch 12/50: 100%|██████████| 4/4 [00:06<00:00,  1.57s/it]


Epoch 12, Train Loss: 0.0107, Val Loss: 0.0173


Epoch 13/50: 100%|██████████| 4/4 [00:06<00:00,  1.63s/it]


Epoch 13, Train Loss: 0.0111, Val Loss: 0.0173


Epoch 14/50: 100%|██████████| 4/4 [00:05<00:00,  1.47s/it]


Epoch 14, Train Loss: 0.0111, Val Loss: 0.0173


Epoch 15/50: 100%|██████████| 4/4 [00:05<00:00,  1.43s/it]


Epoch 15, Train Loss: 0.0112, Val Loss: 0.0173


Epoch 16/50: 100%|██████████| 4/4 [00:05<00:00,  1.36s/it]


Epoch 16, Train Loss: 0.0110, Val Loss: 0.0173


Epoch 17/50: 100%|██████████| 4/4 [00:05<00:00,  1.25s/it]


Epoch 17, Train Loss: 0.0111, Val Loss: 0.0173


Epoch 18/50: 100%|██████████| 4/4 [00:05<00:00,  1.32s/it]


Epoch 18, Train Loss: 0.0107, Val Loss: 0.0173


Epoch 19/50: 100%|██████████| 4/4 [00:05<00:00,  1.32s/it]


Epoch 19, Train Loss: 0.0109, Val Loss: 0.0173


Epoch 20/50: 100%|██████████| 4/4 [00:05<00:00,  1.36s/it]


Epoch 20, Train Loss: 0.0110, Val Loss: 0.0173


Epoch 21/50: 100%|██████████| 4/4 [00:05<00:00,  1.48s/it]


Epoch 21, Train Loss: 0.0108, Val Loss: 0.0173


Epoch 22/50: 100%|██████████| 4/4 [00:07<00:00,  1.81s/it]


Epoch 22, Train Loss: 0.0110, Val Loss: 0.0173


Epoch 23/50: 100%|██████████| 4/4 [00:06<00:00,  1.53s/it]


Epoch 23, Train Loss: 0.0109, Val Loss: 0.0173


Epoch 24/50: 100%|██████████| 4/4 [00:05<00:00,  1.45s/it]


Epoch 24, Train Loss: 0.0111, Val Loss: 0.0173


Epoch 25/50: 100%|██████████| 4/4 [00:05<00:00,  1.28s/it]


Epoch 25, Train Loss: 0.0110, Val Loss: 0.0173


Epoch 26/50: 100%|██████████| 4/4 [00:05<00:00,  1.48s/it]


Epoch 26, Train Loss: 0.0107, Val Loss: 0.0173


Epoch 27/50: 100%|██████████| 4/4 [00:05<00:00,  1.34s/it]


Epoch 27, Train Loss: 0.0108, Val Loss: 0.0173


Epoch 28/50: 100%|██████████| 4/4 [00:05<00:00,  1.35s/it]


Epoch 28, Train Loss: 0.0110, Val Loss: 0.0173


Epoch 29/50: 100%|██████████| 4/4 [00:05<00:00,  1.32s/it]


Epoch 29, Train Loss: 0.0110, Val Loss: 0.0173


Epoch 30/50: 100%|██████████| 4/4 [00:05<00:00,  1.26s/it]


Epoch 30, Train Loss: 0.0108, Val Loss: 0.0173


Epoch 31/50: 100%|██████████| 4/4 [00:05<00:00,  1.38s/it]


Epoch 31, Train Loss: 0.0108, Val Loss: 0.0173


Epoch 32/50: 100%|██████████| 4/4 [00:07<00:00,  1.82s/it]


Epoch 32, Train Loss: 0.0106, Val Loss: 0.0173


Epoch 33/50: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 33, Train Loss: 0.0109, Val Loss: 0.0173


Epoch 34/50: 100%|██████████| 4/4 [00:05<00:00,  1.32s/it]


Epoch 34, Train Loss: 0.0106, Val Loss: 0.0173


Epoch 35/50: 100%|██████████| 4/4 [00:05<00:00,  1.47s/it]


Epoch 35, Train Loss: 0.0111, Val Loss: 0.0173


Epoch 36/50: 100%|██████████| 4/4 [00:05<00:00,  1.50s/it]


Epoch 36, Train Loss: 0.0112, Val Loss: 0.0173


Epoch 37/50: 100%|██████████| 4/4 [00:07<00:00,  1.86s/it]


Epoch 37, Train Loss: 0.0107, Val Loss: 0.0173


Epoch 38/50: 100%|██████████| 4/4 [00:04<00:00,  1.23s/it]


Epoch 38, Train Loss: 0.0109, Val Loss: 0.0173


Epoch 39/50: 100%|██████████| 4/4 [00:05<00:00,  1.39s/it]


Epoch 39, Train Loss: 0.0109, Val Loss: 0.0173


Epoch 40/50: 100%|██████████| 4/4 [00:05<00:00,  1.37s/it]


Epoch 40, Train Loss: 0.0109, Val Loss: 0.0173


Epoch 41/50: 100%|██████████| 4/4 [00:05<00:00,  1.39s/it]


Epoch 41, Train Loss: 0.0107, Val Loss: 0.0173


Epoch 42/50: 100%|██████████| 4/4 [00:07<00:00,  1.81s/it]


Epoch 42, Train Loss: 0.0113, Val Loss: 0.0173


Epoch 43/50: 100%|██████████| 4/4 [00:06<00:00,  1.69s/it]


Epoch 43, Train Loss: 0.0110, Val Loss: 0.0173


Epoch 44/50: 100%|██████████| 4/4 [00:05<00:00,  1.42s/it]


Epoch 44, Train Loss: 0.0112, Val Loss: 0.0173


Epoch 45/50: 100%|██████████| 4/4 [00:05<00:00,  1.31s/it]


Epoch 45, Train Loss: 0.0111, Val Loss: 0.0173


Epoch 46/50: 100%|██████████| 4/4 [00:05<00:00,  1.30s/it]


Epoch 46, Train Loss: 0.0111, Val Loss: 0.0173


Epoch 47/50: 100%|██████████| 4/4 [00:05<00:00,  1.29s/it]


Epoch 47, Train Loss: 0.0108, Val Loss: 0.0173


Epoch 48/50: 100%|██████████| 4/4 [00:05<00:00,  1.46s/it]


Epoch 48, Train Loss: 0.0109, Val Loss: 0.0173


Epoch 49/50: 100%|██████████| 4/4 [00:05<00:00,  1.28s/it]


Epoch 49, Train Loss: 0.0109, Val Loss: 0.0173


Epoch 50/50: 100%|██████████| 4/4 [00:05<00:00,  1.34s/it]


Epoch 50, Train Loss: 0.0111, Val Loss: 0.0173
Generating autoencoder visualizations...
Running EXXA3: Transit Classifier Training...
Epoch [1/150]
Train Loss: 2.6150, Train Acc: 52.13%
Val Loss: 0.6604, Val Acc: 60.13%
Epoch [2/150]
Train Loss: 2.5577, Train Acc: 55.83%
Val Loss: 0.5881, Val Acc: 72.67%
Epoch [3/150]
Train Loss: 2.4651, Train Acc: 64.19%
Val Loss: 0.4965, Val Acc: 78.60%
Epoch [4/150]
Train Loss: 2.3477, Train Acc: 73.11%
Val Loss: 0.4043, Val Acc: 80.80%
Epoch [5/150]
Train Loss: 2.2242, Train Acc: 80.46%
Val Loss: 0.3358, Val Acc: 84.07%
Epoch [6/150]
Train Loss: 2.1292, Train Acc: 84.69%
Val Loss: 0.2565, Val Acc: 88.47%
Epoch [7/150]
Train Loss: 2.0177, Train Acc: 87.79%
Val Loss: 0.2081, Val Acc: 91.60%
Epoch [8/150]
Train Loss: 1.9129, Train Acc: 90.87%
Val Loss: 0.1556, Val Acc: 94.73%
Epoch [9/150]
Train Loss: 1.8111, Train Acc: 92.27%
Val Loss: 0.1346, Val Acc: 95.07%
Epoch [10/150]
Train Loss: 1.7081, Train Acc: 93.74%
Val Loss: 0.1270, Val Acc: 96.00%
Epoch