In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install required packages
!pip install galsim==2.5.3   # Or 2.3.5
!pip install astropy==5.0.4
!pip install astroquery==0.4.6
!pip install torch torchvision
!pip install opencv-python
!pip install scikit-learn
!pip install tqdm
!pip install seaborn

# Create necessary directories
!mkdir -p data/slacs data/simulated data/processed
!mkdir -p checkpoints logs

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torchvision.transforms as transforms

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import KFold

import astropy.io.fits as fits
from astropy.visualization import ZScaleInterval
from astroquery.sdss import SDSS
from astropy.coordinates import SkyCoord
import astropy.units as u

import cv2
from PIL import Image
import galsim

import os
import random
from tqdm.notebook import tqdm
import logging
import json
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set random seeds
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

set_seed()


class Config:
    def __init__(self):
        self.data_params = {
            'image_size': 128,
            'batch_size': 16,  # Reduced for Colab
            'num_workers': 2,  # Reduced for Colab
            'train_split': 0.6,  # Adjusted for smaller datasets
            'val_split': 0.2,
            'test_split': 0.2
        }

        self.model_params = {
            'in_channels': 1,
            'base_filters': 32,
            'n_blocks': 4,
            'dropout_rate': 0.5,
            'use_residual': True,
            'use_se_block': True,  # Added SE blocks
            'se_reduction': 16
        }

        self.training_params = {
            'epochs': 100,
            'learning_rate': 3e-4,
            'weight_decay': 1e-5,
            'patience': 10,
            'min_delta': 1e-4,
            'warmup_epochs': 5,
            'use_mixed_precision': True,
            'gradient_clip_val': 1.0,
            'label_smoothing': 0.1
        }

        self.augmentation_params = {
            'rotation_range': 15,
            'zoom_range': 0.1,
            'shift_range': 0.1,
            'do_flip': True,
            'use_color_jitter': True,
            'use_gaussian_blur': True
        }

        self.paths = {
            'data_dir': './data',
            'slacs_dir': './data/slacs',
            'sim_dir': './data/simulated',
            'processed_dir': './data/processed',
            'checkpoints_dir': './checkpoints',
            'logs_dir': './logs'
        }

In [3]:
def download_slacs_data():
    """
    Downloads SLACS lens data from SDSS
    """
    config = Config()
    os.makedirs(config.paths['slacs_dir'], exist_ok=True)

    # Actual SLACS lens coordinates from the survey
    slacs_catalog = pd.DataFrame({
        'RA': [159.0668, 158.6518, 161.5276, 185.8116, 189.0877],
        'DEC': [39.2778, 44.3341, 42.4391, 8.8673, 11.8752],
        'name': ['SDSSJ1036+3927', 'SDSSJ1034+4432', 'SDSSJ1046+4224',
                'SDSSJ1213+0848', 'SDSSJ1216+1144']
    })

    logger.info("Downloading SLACS data...")
    for idx, row in tqdm(slacs_catalog.iterrows(), total=len(slacs_catalog)):
        try:
            coord = SkyCoord(ra=row['RA']*u.degree, dec=row['DEC']*u.degree)
            imgs = SDSS.get_images(coordinates=coord, radius=0.5*u.arcmin)

            if imgs and len(imgs) > 0:
                img = imgs[0]
                img.writeto(f"{config.paths['slacs_dir']}/slacs_lens_{row['name']}.fits",
                           overwrite=True)
        except Exception as e:
            logger.error(f"Error downloading image {row['name']}: {str(e)}")

def generate_simulated_lenses(n_samples=1000):
    """
    Generates simulated gravitational lenses using GalSim
    """
    config = Config()
    os.makedirs(config.paths['sim_dir'], exist_ok=True)

    logger.info(f"Generating {n_samples} simulated lenses...")
    for i in tqdm(range(n_samples)):
        try:
            # Create source galaxy with Sersic profile
            source = galsim.Sersic(n=4, half_light_radius=0.5, flux=100)

            # Create lens galaxy with Sersic profile
            lens = galsim.Sersic(n=4, half_light_radius=1.0, flux=200)

            # Apply shear to simulate lensing effect
            g1, g2 = 0.05, 0.02  # Shear parameters
            source = source.shear(g1=g1, g2=g2)

            # Draw both objects separately
            image_size = 128  # pixels
            pixel_scale = 0.2  # arcsec/pixel

            # Draw images
            lens_image = lens.drawImage(nx=image_size, ny=image_size, scale=pixel_scale)
            source_image = source.drawImage(nx=image_size, ny=image_size, scale=pixel_scale)

            # Add images
            final_image = lens_image + source_image

            # Add noise to the final image
            noise = galsim.GaussianNoise(sigma=0.1)
            final_image.addNoise(noise)

            # Save image
            final_image.write(f"{config.paths['sim_dir']}/sim_lens_{i}.fits")

        except Exception as e:
            logger.error(f"Error generating lens {i}: {str(e)}")

if __name__ == "__main__":
    download_slacs_data()
    generate_simulated_lenses(n_samples=10)  #  small number for testing

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [4]:
def apply_preprocessing(data):
    """Apply preprocessing steps to a single image"""
    try:
        # Z-score normalization
        zscale = ZScaleInterval()
        data = zscale(data)

        # Standardize size
        data = cv2.resize(data, (128, 128))

        # Apply noise reduction
        data = cv2.GaussianBlur(data, (3,3), 0)

        # Ensure correct data type and range
        data = (data * 255).astype(np.uint8)

        return data
    except Exception as e:
        logger.error(f"Error in preprocessing: {str(e)}")
        return None

def preprocess_images():
    """Preprocess all images in the data directories"""
    config = Config()
    os.makedirs(config.paths['processed_dir'], exist_ok=True)

    # Process SLACS data
    logger.info("Processing SLACS data...")
    slacs_files = os.listdir(config.paths['slacs_dir'])
    for fits_file in tqdm(slacs_files):
        if fits_file.endswith('.fits'):
            try:
                with fits.open(os.path.join(config.paths['slacs_dir'], fits_file)) as hdul:
                    data = hdul[0].data
                processed_data = apply_preprocessing(data)
                if processed_data is not None:
                    np.save(
                        os.path.join(config.paths['processed_dir'], f"slacs_{fits_file[:-5]}.npy"),
                        processed_data
                    )
            except Exception as e:
                logger.error(f"Error processing {fits_file}: {str(e)}")

    # Process simulated data
    logger.info("Processing simulated data...")
    sim_files = os.listdir(config.paths['sim_dir'])
    for fits_file in tqdm(sim_files):
        if fits_file.endswith('.fits'):
            try:
                with fits.open(os.path.join(config.paths['sim_dir'], fits_file)) as hdul:
                    data = hdul[0].data
                processed_data = apply_preprocessing(data)
                if processed_data is not None:
                    np.save(
                        os.path.join(config.paths['processed_dir'], f"sim_{fits_file[:-5]}.npy"),
                        processed_data
                    )
            except Exception as e:
                logger.error(f"Error processing {fits_file}: {str(e)}")

In [5]:
class GravitationalLensDataset(Dataset):
    def __init__(self, mode='train', transform=None):
        self.config = Config()
        self.transform = transform
        self.mode = mode

        # Get all processed files
        all_files = os.listdir(self.config.paths['processed_dir'])
        np_files = [f for f in all_files if f.endswith('.npy')]

        if len(np_files) == 0:
            raise RuntimeError("No data files found. Please run the data generation and preprocessing steps first.")

        # Split data
        np.random.seed(42)
        np.random.shuffle(np_files)

        n_files = len(np_files)
        train_idx = int(n_files * self.config.data_params['train_split'])
        val_idx = train_idx + int(n_files * self.config.data_params['val_split'])

        if mode == 'train':
            self.files = np_files[:train_idx]
        elif mode == 'val':
            self.files = np_files[train_idx:val_idx]
        else:  # test
            self.files = np_files[val_idx:]

        if len(self.files) == 0:
            raise RuntimeError(f"No files found for {mode} split. Check your data directory and split ratios.")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        try:
            file_path = os.path.join(self.config.paths['processed_dir'], self.files[idx])
            image = np.load(file_path)

            # Ensure image is in correct format (single channel)
            if len(image.shape) == 2:
                image = np.expand_dims(image, axis=2)

            # Convert to PIL Image
            image = Image.fromarray(image.squeeze().astype(np.uint8), mode='L')

            if self.transform:
                image = self.transform(image)

            # Get label (1 for real SLACS, 0 for simulated)
            is_real = 1 if self.files[idx].startswith('slacs') else 0
            label = torch.tensor(is_real, dtype=torch.long)

            # Create metadata tensor
            metadata = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)

            return image, label, metadata
        except Exception as e:
            logger.error(f"Error loading file {self.files[idx]}: {str(e)}")
            # Return a dummy sample in case of error
            return torch.zeros((1, self.config.data_params['image_size'],
                              self.config.data_params['image_size'])), torch.tensor(0), torch.zeros(3)

class DataAugmentation:
    def __init__(self, config):
        self.config = config

    def get_transforms(self, mode='train'):
        if mode == 'train':
            return transforms.Compose([
                transforms.Resize((self.config.data_params['image_size'],
                                 self.config.data_params['image_size'])),
                transforms.RandomRotation(self.config.augmentation_params['rotation_range']),
                transforms.RandomAffine(
                    degrees=0,
                    translate=(self.config.augmentation_params['shift_range'],
                             self.config.augmentation_params['shift_range']),
                    scale=(1-self.config.augmentation_params['zoom_range'],
                          1+self.config.augmentation_params['zoom_range'])
                ),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),  # Convert PIL Image to tensor
                transforms.Normalize(mean=[0.485], std=[0.229])
            ])
        else:
            return transforms.Compose([
                transforms.Resize((self.config.data_params['image_size'],
                                 self.config.data_params['image_size'])),
                transforms.ToTensor(),  # Convert PIL Image to tensor
                transforms.Normalize(mean=[0.485], std=[0.229])
            ])

In [6]:
class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, use_se=True):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                              stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                              stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # SE block
        self.se = SEBlock(out_channels) if use_se else nn.Identity()

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                         stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)  # Apply SE block
        out += self.shortcut(residual)
        out = F.relu(out)
        return out

class LensDetectionCNN(nn.Module):
    def __init__(self, config):
        super(LensDetectionCNN, self).__init__()

        # Initial convolution block
        self.initial = nn.Sequential(
            nn.Conv2d(config.model_params['in_channels'],
                     config.model_params['base_filters'],
                     kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(config.model_params['base_filters']),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # Residual blocks
        use_se = config.model_params.get('use_se_block', True)
        self.layer1 = self._make_layer(config.model_params['base_filters'],
                                     config.model_params['base_filters']*2,
                                     stride=2, use_se=use_se)
        self.layer2 = self._make_layer(config.model_params['base_filters']*2,
                                     config.model_params['base_filters']*4,
                                     stride=2, use_se=use_se)
        self.layer3 = self._make_layer(config.model_params['base_filters']*4,
                                     config.model_params['base_filters']*8,
                                     stride=2, use_se=use_se)

        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Conv2d(config.model_params['base_filters']*8,
                     1, kernel_size=1),
            nn.Sigmoid()
        )

        # Global average pooling and dense layers
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(config.model_params['base_filters']*8, 512)
        self.dropout1 = nn.Dropout(config.model_params['dropout_rate'])
        self.fc2 = nn.Linear(512, 128)
        self.dropout2 = nn.Dropout(config.model_params['dropout_rate'])
        self.fc3 = nn.Linear(128, 2)

    def _make_layer(self, in_channels, out_channels, stride, use_se):
        layers = []
        layers.append(ResidualBlock(in_channels, out_channels, stride, use_se))
        layers.append(ResidualBlock(out_channels, out_channels, 1, use_se))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        attention = self.attention(x)
        x = x * attention

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)

        return F.log_softmax(x, dim=1)

In [7]:
class Trainer:
    def __init__(self, model, train_loader, val_loader, config):
        self.config = config
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)

        self.criterion = nn.NLLLoss()
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.training_params['learning_rate'],
            weight_decay=config.training_params['weight_decay']
        )
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5,
            patience=5, verbose=True
        )

        self.best_val_loss = float('inf')
        self.patience_counter = 0

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, (data, target, metadata) in enumerate(pbar):
            data, target = data.to(self.device), target.to(self.device)
            metadata = metadata.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)

            pbar.set_postfix({
                'loss': total_loss/(batch_idx+1),
                'acc': 100.*correct/total
            })

        return total_loss/len(self.train_loader), correct/total

    def validate(self):
        self.model.eval()
        val_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target, metadata in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                metadata = metadata.to(self.device)

                output = self.model(data)
                val_loss += self.criterion(output, target).item()

                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)

        val_loss /= len(self.val_loader)
        accuracy = correct/total

        return val_loss, accuracy

    def train(self):
        logger.info("Starting training...")
        history = {
            'train_loss': [], 'train_acc': [],
            'val_loss': [], 'val_acc': []
        }

        for epoch in range(self.config.training_params['epochs']):
            logger.info(f"Epoch {epoch+1}/{self.config.training_params['epochs']}")

            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

            logger.info(
                f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
            )

            self.scheduler.step(val_loss)

            if val_loss < self.best_val_loss - self.config.training_params['min_delta']:
                self.best_val_loss = val_loss
                self.save_checkpoint(epoch, val_loss)
                self.patience_counter = 0
            else:
                self.patience_counter += 1

            if self.patience_counter >= self.config.training_params['patience']:
                logger.info("Early stopping triggered")
                break

        return history

    def save_checkpoint(self, epoch, val_loss):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
        }
        torch.save(checkpoint,
                  f"{self.config.paths['checkpoints_dir']}/model_epoch_{epoch}.pt")


class ModelEvaluator:
    def __init__(self, model, data_loader, device):
        self.model = model
        self.data_loader = data_loader
        self.device = device

    def evaluate(self):
        self.model.eval()
        predictions = []
        targets = []
        confidences = []

        with torch.no_grad():
            for data, target, _ in tqdm(self.data_loader, desc='Evaluating'):
                data = data.to(self.device)
                output = self.model(data)

                # Get predictions and confidences
                probs = torch.exp(output)
                pred = output.argmax(dim=1).cpu().numpy()
                conf = probs.max(dim=1)[0].cpu().numpy()

                predictions.extend(pred)
                targets.extend(target.numpy())
                confidences.extend(conf)

        # Convert to numpy arrays
        predictions = np.array(predictions)
        targets = np.array(targets)
        confidences = np.array(confidences)

        # Calculate metrics
        metrics = {
            'accuracy': accuracy_score(targets, predictions),
            'precision': precision_score(targets, predictions),
            'recall': recall_score(targets, predictions),
            'f1': f1_score(targets, predictions),
            'confusion_matrix': confusion_matrix(targets, predictions),
            'roc_auc': roc_auc_score(targets, confidences),
            'average_confidence': np.mean(confidences)
        }

        return metrics, predictions, confidences, targets

    def plot_results(self, metrics, predictions, confidences, targets):
        """Plot comprehensive evaluation results"""
        fig = plt.figure(figsize=(20, 12))

        # Confusion Matrix
        plt.subplot(2, 3, 1)
        sns.heatmap(metrics['confusion_matrix'], annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')

        # ROC Curve
        plt.subplot(2, 3, 2)
        fpr, tpr, _ = roc_curve(targets, confidences)
        plt.plot(fpr, tpr)
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'ROC Curve (AUC = {metrics["roc_auc"]:.3f})')

        # Precision-Recall Curve
        plt.subplot(2, 3, 3)
        precision, recall, _ = precision_recall_curve(targets, confidences)
        plt.plot(recall, precision)
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title(f'Precision-Recall Curve\n(F1 = {metrics["f1"]:.3f})')

        # Confidence Distribution
        plt.subplot(2, 3, 4)
        plt.hist(confidences[predictions == targets], alpha=0.5, label='Correct', bins=30)
        plt.hist(confidences[predictions != targets], alpha=0.5, label='Incorrect', bins=30)
        plt.title('Prediction Confidence Distribution')
        plt.xlabel('Confidence')
        plt.ylabel('Count')
        plt.legend()

        # Metrics Summary
        plt.subplot(2, 3, 5)
        metrics_display = {
            'Accuracy': metrics['accuracy'],
            'Precision': metrics['precision'],
            'Recall': metrics['recall'],
            'F1 Score': metrics['f1'],
            'ROC AUC': metrics['roc_auc'],
            'Avg Conf.': metrics['average_confidence']
        }
        plt.axis('off')
        plt.text(0.1, 0.9, 'Metrics Summary:', fontsize=12, fontweight='bold')
        for i, (metric, value) in enumerate(metrics_display.items()):
            plt.text(0.1, 0.8 - i*0.1, f'{metric}: {value:.3f}', fontsize=10)

        plt.tight_layout()
        return fig

In [8]:
def setup_data(config):
    """Setup datasets and dataloaders"""
    data_aug = DataAugmentation(config)

    # Create datasets
    train_dataset = GravitationalLensDataset(mode='train', transform=data_aug.get_transforms('train'))
    val_dataset = GravitationalLensDataset(mode='val', transform=data_aug.get_transforms('val'))
    test_dataset = GravitationalLensDataset(mode='test', transform=data_aug.get_transforms('test'))

    # Create dataloaders with custom collate function
    train_loader = DataLoader(train_dataset,
                            batch_size=config.data_params['batch_size'],
                            shuffle=True,
                            num_workers=config.data_params['num_workers'],
                            collate_fn=custom_collate)

    val_loader = DataLoader(val_dataset,
                          batch_size=config.data_params['batch_size'],
                          shuffle=False,
                          num_workers=config.data_params['num_workers'],
                          collate_fn=custom_collate)

    test_loader = DataLoader(test_dataset,
                           batch_size=config.data_params['batch_size'],
                           shuffle=False,
                           num_workers=config.data_params['num_workers'],
                           collate_fn=custom_collate)

    return train_loader, val_loader, test_loader


def custom_collate(batch):
    """Custom collate function to handle PIL images"""
    images = []
    labels = []
    metadata = []

    for item in batch:
        image, label, meta = item
        if isinstance(image, Image.Image):
            # Convert PIL Image to tensor if not already converted
            transform = transforms.ToTensor()
            image = transform(image)
        images.append(image)
        labels.append(label)
        metadata.append(meta)

    # Stack all items into tensors
    images = torch.stack(images)
    labels = torch.stack(labels)
    metadata = torch.stack(metadata)

    return images, labels, metadata


def plot_training_history(history):
    """Plot training history"""
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title('Loss History')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.title('Accuracy History')
    plt.legend()

    plt.tight_layout()
    plt.show()

def check_data_status(config):
    """Check status of data directories and files"""

    # Check SLACS directory
    if os.path.exists(config.paths['slacs_dir']):
        slacs_files = [f for f in os.listdir(config.paths['slacs_dir']) if f.endswith('.fits')]
        logger.info(f"Found {len(slacs_files)} SLACS files")
    else:
        logger.info("SLACS directory not found")

    # Check simulated directory
    if os.path.exists(config.paths['sim_dir']):
        sim_files = [f for f in os.listdir(config.paths['sim_dir']) if f.endswith('.fits')]
        logger.info(f"Found {len(sim_files)} simulated files")
    else:
        logger.info("Simulated directory not found")

    # Check processed directory
    if os.path.exists(config.paths['processed_dir']):
        proc_files = [f for f in os.listdir(config.paths['processed_dir']) if f.endswith('.npy')]
        logger.info(f"Found {len(proc_files)} processed files")
    else:
        logger.info("Processed directory not found")

def perform_cross_validation(model_class, dataset, config, device, n_folds=5):
    """Perform k-fold cross validation"""
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(range(len(dataset)))):
        logger.info(f"Training fold {fold + 1}/{n_folds}")

        # Create data loaders for this fold
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
        val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)

        train_loader = DataLoader(
            dataset,
            batch_size=config.data_params['batch_size'],
            sampler=train_subsampler,
            num_workers=config.data_params['num_workers'],
            collate_fn=custom_collate
        )
        val_loader = DataLoader(
            dataset,
            batch_size=config.data_params['batch_size'],
            sampler=val_subsampler,
            num_workers=config.data_params['num_workers'],
            collate_fn=custom_collate
        )

        # Initialize model and trainer
        model = model_class(config).to(device)
        trainer = Trainer(model, train_loader, val_loader, config)

        # Train model
        history = trainer.train()

        # Evaluate
        evaluator = ModelEvaluator(model, val_loader, device)
        metrics, predictions, confidences, targets = evaluator.evaluate()

        fold_results.append({
            'metrics': metrics,
            'history': history,
            'predictions': predictions,
            'confidences': confidences,
            'targets': targets
        })

        # Save fold results
        torch.save(model.state_dict(), f"{config.paths['checkpoints_dir']}/model_fold_{fold}.pt")

    return fold_results

def main():
    try:
        # Initialize configuration
        config = Config()

        # Set up device
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {device}")

        # Create directories
        for path in config.paths.values():
            os.makedirs(path, exist_ok=True)

        # Generate data if needed
        if not os.path.exists(config.paths['processed_dir']) or \
           len(os.listdir(config.paths['processed_dir'])) < 30:  # Minimum required files
            logger.info("Generating new data...")
            download_slacs_data()
            generate_simulated_lenses(n_samples=200)
            preprocess_images()

        # Create dataset
        full_dataset = GravitationalLensDataset(mode='train')  # Use all data for CV

        # Perform cross-validation
        cv_results = perform_cross_validation(LensDetectionCNN, full_dataset, config, device)

        # Analyze cross-validation results
        cv_metrics = {
            'accuracy': [],
            'f1': [],
            'roc_auc': []
        }

        for fold_result in cv_results:
            metrics = fold_result['metrics']
            cv_metrics['accuracy'].append(metrics['accuracy'])
            cv_metrics['f1'].append(metrics['f1'])
            cv_metrics['roc_auc'].append(metrics['roc_auc'])

        # Print cross-validation results
        logger.info("\nCross-validation Results:")
        for metric, values in cv_metrics.items():
            mean_val = np.mean(values)
            std_val = np.std(values)
            logger.info(f"{metric}: {mean_val:.3f} ± {std_val:.3f}")

        # Save results
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        results_path = os.path.join(config.paths['logs_dir'], f'cv_results_{timestamp}.json')

        with open(results_path, 'w') as f:
            json.dump({
                'cv_metrics': {k: list(map(float, v)) for k, v in cv_metrics.items()},
                'config': config.__dict__
            }, f, indent=4)

        logger.info(f"\nResults saved to {results_path}")

    except Exception as e:
        logger.error(f"Error in main execution: {str(e)}")
        raise

if __name__ == "__main__":
    main()



Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]



Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]



Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]



Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Training:   0%|          | 0/7 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
