In [None]:
import json
import os
from datetime import datetime, timedelta
from glob import glob

import astropy.io.fits as fits
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [None]:
def compute_mean_std_from_histogram(bin_edges, bin_counts):
    """Estimate the mean and population standard deviation of histogrammed data.

    Every count in a bin is treated as lying at that bin's midpoint, so the
    result is exact only when all samples in a bin share the midpoint value
    (e.g. integer-valued data binned with edges ``np.arange(n + 1) - 0.5``).

    Args:
        bin_edges: array-like of ``len(bin_counts) + 1`` bin edges.
        bin_counts: array-like of non-negative counts, one per bin.

    Returns:
        ``(mean, std)`` as NumPy floats; ``std`` uses the population
        (``ddof=0``) formula.

    Raises:
        ValueError: if the number of edges is not ``len(bin_counts) + 1`` or
            the histogram contains no counts at all.
    """
    # Convert up front: with plain Python lists ``+`` would concatenate the
    # slices instead of adding them element-wise.
    edges = np.asarray(bin_edges, dtype=np.float64)
    counts = np.asarray(bin_counts)
    if edges.ndim != 1 or counts.shape != (edges.size - 1,):
        raise ValueError(
            f"expected len(bin_edges) == len(bin_counts) + 1, got "
            f"edges shape {edges.shape} and counts shape {counts.shape}"
        )

    midpoints = (edges[:-1] + edges[1:]) / 2

    total_count = np.sum(counts)
    if total_count == 0:
        raise ValueError("histogram is empty (all bin counts are zero)")

    mean = np.sum(midpoints * counts) / total_count

    # Second pass around the mean rather than E[x^2] - E[x]^2, which suffers
    # from catastrophic cancellation when the mean is large relative to the spread.
    std = np.sqrt(np.sum(counts * (midpoints - mean) ** 2) / total_count)

    return mean, std

def get_dataset_counts(dataset):
    """Histogram every raw pixel value that occurs across ``dataset``.

    Each item is converted with ``np.asarray`` and flattened. The bin index is
    the pixel value itself, so the result has one bin per value the first
    item's dtype can represent (``2**(8 * itemsize)`` bins).

    Args:
        dataset: indexable, iterable, sized collection of integer-valued arrays.
            Its first item determines the number of bins. Values must be
            non-negative (an ``np.bincount`` requirement).

    Returns:
        ``np.ndarray`` of dtype ``uint64`` where entry ``v`` is the number of
        pixels equal to ``v``.

    Raises:
        ValueError: if ``dataset`` is empty.
        TypeError: if the first item is not integer-typed (float data would
            need 2**32+ bins and ``np.bincount`` rejects floats anyway).
    """
    if len(dataset) == 0:
        raise ValueError("dataset is empty; cannot infer the pixel bit depth")

    # Infer the bin count from the argument itself, not from any global dataset.
    first_dtype = np.asarray(dataset[0]).dtype
    if not np.issubdtype(first_dtype, np.integer):
        raise TypeError(f"expected integer pixel data, got dtype {first_dtype}")

    n_vals = 2 ** (8 * first_dtype.itemsize)
    counts = np.zeros(n_vals, dtype=np.uint64)

    for d in tqdm(dataset):
        counts += np.bincount(np.asarray(d).ravel(), minlength=n_vals).astype(np.uint64)

    return counts

# Usage: histogram a dataset of raw pixel arrays, then reduce it to mean/std.
# The edges are shifted by -0.5 so that bin ``v`` is centred on pixel value ``v``.
# Plain ``np.arange(len(counts) + 1)`` would centre it on ``v + 0.5`` and bias the
# mean upward by 0.5 (the std is unaffected).
# NOTE(review): ``CMEDataset`` items are (normalised tensor, label) tuples, so a
# CMEDataset cannot be passed here directly. Confirm which raw-pixel dataset the
# hard-coded normalisation constants were computed from.
#
# counts = get_dataset_counts(raw_pixel_dataset)
# hist_edges = np.arange(len(counts) + 1) - 0.5
# mean, std = compute_mean_std_from_histogram(hist_edges, counts)

In [None]:
def extract_images_within_time_range(events, image_paths):
    """Return the image paths whose timestamps fall inside a visible event window.

    Each image's basename must start with ``YYYYMMDD_HHMMSS_``. The first two
    underscore-separated fields are parsed as the acquisition time.

    Args:
        events: iterable of dicts with keys ``'visible'`` (truthy to include the
            event), ``'datetime'`` (start, format ``'%Y/%m/%d %H:%M'``) and
            ``'event_stop_time'`` (end, format ``'%Y%m%d_%H%M%S'``).
        image_paths: iterable of image file paths.

    Returns:
        List of paths in event order. A path appears once for every visible
        event whose inclusive ``[start, stop]`` window contains its timestamp,
        so paths may repeat when events overlap.

    Raises:
        ValueError: if a filename or event time does not match the expected format.
    """
    # Parse every filename once and bucket (timestamp, path) pairs by calendar day.
    # Each bucket is only read below, never mutated, so events cannot affect each other.
    images_by_date = {}
    for image_path in image_paths:
        stamp = '_'.join(os.path.basename(image_path).split('_')[:2])
        image_time = datetime.strptime(stamp, '%Y%m%d_%H%M%S')
        images_by_date.setdefault(image_time.date(), []).append((image_time, image_path))

    selected_images = []
    for event in events:
        if not event['visible']:
            continue

        start = datetime.strptime(event['datetime'], '%Y/%m/%d %H:%M')
        stop = datetime.strptime(event['event_stop_time'], '%Y%m%d_%H%M%S')

        # Visit every calendar day the event touches. This covers events that
        # cross midnight or span several days, even when the start day has no images.
        day = start.date()
        last_day = stop.date()
        while day <= last_day:
            for image_time, image_path in images_by_date.get(day, ()):
                if start <= image_time <= stop:
                    selected_images.append(image_path)
            day += timedelta(days=1)

    return selected_images

In [None]:
class CMEDataset(Dataset):
    """Dataset of STEREO COR1 FITS images paired with a boolean CME label.

    Images are discovered with the glob ``root/*/cor1/stereo_{a,b}/*/*.fts``.
    An image is labelled positive when ``extract_images_within_time_range``
    places it inside one of the events listed for its spacecraft.

    Args:
        root: directory to search for ``.fts`` files.
        events: mapping with optional ``'stereo_a'`` / ``'stereo_b'`` keys, each
            a list of event dicts. ``None`` (the default) or an empty value
            means no positive labels.
        pol: ``'all'`` keeps every image; ``'sum'`` keeps paths containing
            ``'n4'``; any other value keeps images whose parent directory name
            equals ``pol``.
    """

    # (mean, std) used by ``transforms.Normalize``. NOTE(review): hard-coded and
    # presumably computed offline from a pixel histogram; confirm provenance.
    # Selections by a specific ``pol`` directory reuse the 'all' statistics.
    _STATS_ALL = (2691.3037070368546, 2579.566574917962)
    _STATS_SUM = (3658.224788149089, 3399.0258091444553)

    def __init__(self, root, events=None, pol='all'):
        # Normalise "no events" to an empty mapping; the old ``[]`` default
        # crashed in ``_get_labels``, which needs key lookups.
        self.events = events if events else {}

        # glob returns files in arbitrary order; sort so indices are reproducible.
        self.stereo_a = sorted(glob(os.path.join(root, "*", "cor1", "stereo_a", "*", "*.fts")))
        self.stereo_b = sorted(glob(os.path.join(root, "*", "cor1", "stereo_b", "*", "*.fts")))
        all_paths = self.stereo_a + self.stereo_b

        if pol == 'all':
            self.mean, self.std = self._STATS_ALL
            self.images = all_paths
        elif pol == 'sum':
            self.mean, self.std = self._STATS_SUM
            self.images = [im for im in all_paths if 'n4' in im]
        else:
            self.mean, self.std = self._STATS_ALL
            self.images = [im for im in all_paths if os.path.basename(os.path.dirname(im)) == pol]

        self.transform = transforms.Compose([
            transforms.ToTensor(),                              # Convert image to PyTorch tensor
            transforms.Normalize(mean=self.mean, std=self.std)  # Normalize using mean and std
        ])

        self._get_labels()

    def _get_labels(self):
        """Populate ``positive_labels`` (every in-event path from either spacecraft)
        and ``cme_images`` (the subset that survived the ``pol`` filter)."""
        self.positive_labels = set(
            extract_images_within_time_range(self.events.get('stereo_a', []), self.stereo_a))
        self.positive_labels |= set(
            extract_images_within_time_range(self.events.get('stereo_b', []), self.stereo_b))
        self.cme_images = {im for im in self.images if im in self.positive_labels}

    def __getitem__(self, i):
        """Return ``(normalised image tensor, is_cme)`` for index ``i``."""
        path = self.images[i]
        # Cast to native float32 before the tensor transform. NOTE(review): the FITS
        # data is presumably 2-D, which ToTensor turns into a (1, H, W) tensor.
        raw_data = fits.getdata(path).astype(np.float32)

        return self.transform(raw_data), (path in self.positive_labels)

    def __len__(self):
        return len(self.images)

In [None]:
# Load the event catalogue and build the dataset over every polarisation.
# Both locations can be overridden through environment variables, so the notebook
# is not tied to one machine's mount point; the defaults are the original values.
# NOTE(review): the JSON is expected to be a mapping with 'stereo_a' / 'stereo_b'
# event lists, as CMEDataset reads those keys; confirm against the file.
EVENTS_PATH = os.environ.get('CME_EVENTS_PATH', 'events_201402.json')
DATA_ROOT = os.environ.get('CME_DATA_ROOT', '/media/josh/josh_tuf_a/data/fdl/2023/onboard')

with open(EVENTS_PATH, 'r') as fp:
    events = json.load(fp)

ds = CMEDataset(root=DATA_ROOT, pol='all', events=events)

In [None]:
# How many of the selected images fall inside a CME event window.
n_cme_images = len(ds.cme_images)
n_cme_images

In [None]:
# Total number of images kept by the `pol` filter (CME and non-CME).
n_images = len(ds)
n_images