In [None]:
from torch.utils.data import Dataset, DataLoader
from glob import glob
import numpy as np
import astropy.io.fits as fits
import os
from tqdm.auto import tqdm
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
from collections import defaultdict
import os
import copy
import json

## Dataloader

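Note:
* Before making a major update to the backgrounds, first delete the cache directory `/mnt/onboard_data/classifier/.cache/background`; otherwise the previously cached backgrounds are loaded instead of being regenerated.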

In [None]:
def extract_images_within_time_range(events, image_paths):
    """Return the image paths whose timestamp falls inside a usable event window.

    Args:
        events: iterable of event dicts. Each must provide ``'visible'``,
            ``'faint'``, ``'datetime'`` (start, ``'%Y/%m/%d %H:%M'``) and
            ``'event_stop_time'`` (``'%Y%m%d_%H%M%S'``). Events that are not
            visible, or that are faint, are skipped.
        image_paths: iterable of file paths whose basename starts with
            ``YYYYMMDD_HHMMSS_``.

    Returns:
        list of image paths with ``start <= timestamp <= stop`` (inclusive),
        in event order. A path is listed once per event that covers it, so
        overlapping events can yield repeated paths.

    Notes:
        Only images dated on the event's start day or the following day are
        considered, and an event is skipped entirely when no image exists on
        its start day.
    """
    selected_images = []

    # Index image paths by calendar date so each event only scans nearby days.
    image_paths_by_date = {}
    for image_path in image_paths:
        image_date = datetime.strptime(os.path.basename(image_path).split('_')[0], '%Y%m%d').date()
        image_paths_by_date.setdefault(image_date, []).append(image_path)

    for event in events:
        if not event['visible'] or event['faint']:
            continue

        event_start_time = datetime.strptime(event['datetime'], '%Y/%m/%d %H:%M')
        event_stop_time = datetime.strptime(event['event_stop_time'], '%Y%m%d_%H%M%S')
        event_date = event_start_time.date()

        if event_date not in image_paths_by_date:
            continue

        # Build a NEW list: extending the indexed list in place would leak
        # next-day paths into the index and duplicate them for every later
        # event that starts on the same date.
        next_day = event_date + timedelta(days=1)
        candidates = image_paths_by_date[event_date] + image_paths_by_date.get(next_day, [])

        for image_path in candidates:
            stamp = '_'.join(os.path.basename(image_path).split('_')[:2])
            image_timestamp = datetime.strptime(stamp, '%Y%m%d_%H%M%S')
            if event_start_time <= image_timestamp <= event_stop_time:
                selected_images.append(image_path)

    return selected_images

In [None]:
class CMEDataset(Dataset):
    """Dataset of COR1 STEREO-A/B FITS frames labelled by CME events.

    Each item is a ``(difference_image, label)`` pair: the difference between
    frame ``i`` and frame ``i - 1`` of the sorted image list, and ``1`` when
    frame ``i`` lies inside a CME event window, else ``0``.

    Args:
        root: data root. Images are globbed from
            ``root/201402*/cor1/<satellite>/<angle>/*.fts`` and a ``.cache``
            directory is created under it.
        events: mapping ``{'stereo_a': [...], 'stereo_b': [...]}`` of event
            dicts as consumed by ``extract_images_within_time_range``. A
            missing/empty value means "no events" (every label is 0).
        pol: ``'all'`` (angle directories 0.0/120.0/240.0), ``'sum'`` (paths
            containing ``'n4'``) or the name of a single angle directory.
        size: required image side length; frames whose shape is not
            ``(size, size)`` are dropped.
    """

    # Angle directory names recognised in image paths, in lookup order.
    _ANGLES = ("0.0", "120.0", "240.0", "1001.0", "60.0", "90.0", "180.0", "300.0")

    def __init__(self, root, events=None, pol='all', size=512):
        self.cache_dir = os.path.join(root, '.cache')
        os.makedirs(self.cache_dir, exist_ok=True)

        # No shared mutable default; an empty/None value becomes an empty mapping.
        self.events = events if events else {}
        self.stereo_a = sorted(glob(os.path.join(root, "201402*", "cor1", "stereo_a", "*", "*.fts")))
        self.stereo_b = sorted(glob(os.path.join(root, "201402*", "cor1", "stereo_b", "*", "*.fts")))

        self.pol = pol

        # Select candidate frames and the mean/std later handed to Normalize.
        if pol == 'all':
            self.mean = 2691.3037070368546
            self.std = 2579.566574917962
            wanted = ("0.0", "120.0", "240.0")
            selected = [im for im in self.stereo_a if self._angle_dir(im) in wanted]
            selected += [im for im in self.stereo_b if self._angle_dir(im) in wanted]
        elif pol == 'sum':
            self.mean = 3658.224788149089
            self.std = 3399.0258091444553
            selected = [im for im in self.stereo_a if 'n4' in im]
            selected += [im for im in self.stereo_b if 'n4' in im]
        else:
            self.mean = 2691.3037070368546
            self.std = 2579.566574917962
            selected = [im for im in self.stereo_a if self._angle_dir(im) == pol]
            selected += [im for im in self.stereo_b if self._angle_dir(im) == pol]

        # Filter into a new list. Removing from a list while iterating over it
        # skips the element after every removal (so unwanted frames survive)
        # and raises ValueError when one frame fails both checks.
        kept = []
        for image in selected:
            if "/60.0/" in str(image) or fits.getdata(image).shape != (size, size):
                print(f"Removing {image}")
                continue
            kept.append(image)

        self.images = sorted(kept)
        # images_for_date[date][satellite][angle] -> list of image paths
        self.images_for_date = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        self.dates = set()
        # background[date][satellite][angle] -> median background array
        self.background = defaultdict(lambda: defaultdict(lambda: defaultdict()))

        for image_path in self.images:
            image_date = self._image_date(image_path)
            sat, angle = self._image_info(image_path)

            self.images_for_date[image_date][sat][angle].append(image_path)
            self.dates.add(image_date)

        # Built for convenience; __getitem__ currently returns un-normalised data.
        self.transform = transforms.Compose([
            transforms.ToTensor(),                   # Convert image to PyTorch tensor
            transforms.Normalize(mean=self.mean, std=self.std) # Normalize using mean and std
        ])

        self._get_labels()
        self._gen_background()

    @staticmethod
    def _angle_dir(image_path):
        """Return the name of the directory that directly contains the image."""
        return os.path.basename(os.path.dirname(image_path))

    def _image_date(self, image):
        """Parse the calendar date from the ``YYYYMMDD_`` prefix of the file name."""
        return datetime.strptime(os.path.basename(image).split('_')[0], '%Y%m%d').date()

    def _image_info(self, image_path):
        """Return ``(satellite, angle)`` parsed from the image path.

        Raises:
            ValueError: if the path names no known satellite or angle
                directory (previously this surfaced as UnboundLocalError).
        """
        if 'stereo_a' in image_path:
            sat = 'stereo_a'
        elif 'stereo_b' in image_path:
            sat = 'stereo_b'
        else:
            raise ValueError(f"Cannot determine satellite from path: {image_path}")

        for angle in self._ANGLES:
            if f"/{angle}/" in image_path:
                return sat, angle

        raise ValueError(f"Cannot determine polarisation angle from path: {image_path}")

    def _get_labels(self):
        """Collect the image paths that fall inside a CME event window."""
        self.positive_labels = set(extract_images_within_time_range(self.events.get('stereo_a', []), self.stereo_a))
        self.positive_labels |= set(extract_images_within_time_range(self.events.get('stereo_b', []), self.stereo_b))
        self.cme_images = set(self.images) & self.positive_labels

    def _gen_background(self, filter_cme=False):
        """
        Generate backgrounds for each satellite, each polarisation angle.

        The background for a date is the per-pixel median of the previous
        day's frames (same satellite and angle); when the dataset has no
        images for the previous day, the same day's frames are used instead.
        Results are cached as ``<cache>/background/<date>_<sat>_<angle>.npy``
        and reloaded on later runs, so the cache must be deleted to regenerate.

        Args:
            filter_cme: when True, frames labelled as CME are excluded from
                the median.
        """
        cache_root = os.path.join(self.cache_dir, "background")
        os.makedirs(cache_root, exist_ok=True)
        print("Cache:", cache_root)

        for date in tqdm(sorted(self.dates)):
            date_string = date.strftime("%Y%m%d")

            # Prefer the previous day as the source of background frames.
            image_date = date - timedelta(days=1)
            if image_date not in self.images_for_date:
                image_date = date

            for sat in ['stereo_a', 'stereo_b']:
                for angle in self.images_for_date[image_date][sat].keys():
                    cache_path = os.path.join(cache_root, f"{date_string}_{sat}_{angle}.npy")

                    if os.path.exists(cache_path):
                        self.background[date][sat][angle] = np.load(cache_path)
                        continue

                    imgs = self.images_for_date[image_date][sat][angle]
                    if filter_cme:
                        imgs = [im for im in imgs if im not in self.cme_images]

                    if not imgs:
                        # Nothing to take a median of; do not cache a NaN scalar.
                        continue

                    frames = np.array([fits.getdata(im) for im in imgs])
                    self.background[date][sat][angle] = np.median(frames, axis=0)
                    np.save(cache_path, self.background[date][sat][angle])

    def _get_background(self, image_path):
        """
        Given an image path, retrieve the matching background from the dict.

        Raises:
            KeyError: if no background exists for the image's date, satellite
                and angle.
        """
        image_date = self._image_date(image_path)
        sat, angle = self._image_info(image_path)

        try:
            return self.background[image_date][sat][angle]
        except KeyError:
            raise KeyError(
                f"No background for {image_date} {sat} {angle} (image: {image_path})"
            ) from None

    def _get_difference_image(self, i):
        """
        Return the background-subtracted running difference for image ``i``.

        Will not give viable results for the first image: index 0 is
        differenced against itself (and the first day uses a same-day
        background).

        TODO: ensure differencing over same sat, same polar with i-1
        """
        # read raw images for current and previous
        j = 0 if i == 0 else i - 1
        raw_img_i = fits.getdata(self.images[i]).astype(np.float32)
        raw_img_j = fits.getdata(self.images[j]).astype(np.float32)

        # get background from day before
        bg = self._get_background(self.images[i])

        # NOTE(review): the same bg is subtracted from both frames, so it
        # cancels algebraically and the result equals raw_img_i - raw_img_j up
        # to rounding. Confirm whether a per-frame background was intended.
        diff_img = (raw_img_i - bg) - (raw_img_j - bg)

        return diff_img

    def __getitem__(self, i):
        """Return ``(difference_image, label)`` for index ``i``; label is 0 or 1."""
        data  = self._get_difference_image(i)
        label = int(self.images[i] in self.cme_images)

        return data, label #self.transform(data), label

    def __len__(self):
        """Number of images kept after filtering."""
        return len(self.images)

In [None]:
# Event list scraped from the COR1 preliminary CME catalogue:
# https://cor1.gsfc.nasa.gov/catalog/cme/2014/Daniel_Hong_COR1_preliminary_event_list_2014-02.html
with open('events_201402.json', mode='r') as events_file:
    events = json.load(events_file)

In [None]:
# Give every CME a nominal (arbitrary) duration: stop time = start time + 3 hours.

date_format = "%Y%m%d_%H%M%S"

for sat in events.keys():
    for event in events[sat]:
        start = datetime.strptime(event["event_start_time"], date_format)
        end = (start + timedelta(hours=3)).strftime(date_format)
        # Overwrite the stop time in place, in the same string format as the start.
        event["event_stop_time"] = end

In [None]:
events

In [None]:
%%time
# Optional debugging shortcut: restrict each satellite to a single event.
#events['stereo_a'] = [events['stereo_a'][10]]
#events['stereo_b'] = [events['stereo_b'][10]]

# Alternative roots / single-polarisation runs kept for reference:
#ds = CMEDataset(root="/media/josh/josh_tuf_a/data/fdl/2023/onboard", pol='0.0', events=events)
#ds = CMEDataset(root="/mnt/onboard_data/classifier", pol='0.0', events=events)

# Author's timing note: 7 mins without cache, 2-3 with cache
# (the constructor reads every selected FITS file and builds/loads the backgrounds).
ds = CMEDataset(root="/mnt/onboard_data/classifier", pol='all', events=events)

In [None]:
# Peek at the nesting of images_for_date for the first date only:
# date -> satellite -> polarisation angle.
print(ds.images_for_date.keys())
for k in ds.images_for_date.keys():
    per_satellite = ds.images_for_date[k]
    print(per_satellite.keys())
    for sat_name in ("stereo_a", "stereo_b"):
        print(per_satellite[sat_name].keys())
    break

In [None]:
# Peek at the nesting of the background dict for the first date only,
# then show the STEREO-A background array for angle "0.0".
print(ds.background.keys())
for k in ds.background.keys():
    per_satellite = ds.background[k]
    print(per_satellite.keys())
    stereo_a_backgrounds = per_satellite["stereo_a"]
    print(stereo_a_backgrounds.keys())
    print(stereo_a_backgrounds["0.0"])
    break

In [None]:
len(ds.cme_images)

In [None]:
len(ds.images)

In [None]:
ds.images[:5]

In [None]:
ds.images[-5:]

In [None]:
ds.dates

In [None]:
raw_img_i = fits.getdata(ds.images[1]).astype(np.float32)
raw_img_j = fits.getdata(ds.images[0]).astype(np.float32)

In [None]:
fits.getdata(ds.images[-1]).astype(np.float32).shape

In [None]:
raw_img_i.shape, raw_img_j.shape

In [None]:
bg_i = ds._get_background(ds.images[1])
bg_i.shape

In [None]:
image_date = ds._image_date(ds.images[1])
image_date

In [None]:
image_date - timedelta(days=1)

## Check mean std for normalisation

TODO:
* Double-check the mean/std values and update them for the case where the background has been removed.

In [None]:
def compute_mean_std_from_histogram(bin_edges, bin_counts):
    """Estimate mean and (population) standard deviation from a histogram.

    Every sample is treated as if it sat at the midpoint of its bin.

    Args:
        bin_edges: 1-D sequence of ``n + 1`` bin edges (list or array).
        bin_counts: 1-D sequence of ``n`` bin counts (list or array).

    Returns:
        ``(mean, std)``. Both are NaN when the histogram holds no samples.

    Raises:
        ValueError: if the inputs are not 1-D or their lengths do not satisfy
            ``len(bin_edges) == len(bin_counts) + 1``.
    """
    # Coerce to arrays: with plain lists, ``edges[:-1] + edges[1:]`` would
    # concatenate instead of adding element-wise.
    edges = np.asarray(bin_edges, dtype=np.float64)
    counts = np.asarray(bin_counts)

    if edges.ndim != 1 or counts.ndim != 1 or edges.size != counts.size + 1:
        raise ValueError(
            "bin_edges must be 1-D with exactly one more element than bin_counts "
            f"(got shapes {edges.shape} and {counts.shape})"
        )

    # Calculate midpoints of each bin
    bin_midpoints = (edges[:-1] + edges[1:]) / 2

    total_count = np.sum(counts)
    if total_count == 0:
        # Empty histogram: mean and std are undefined (avoid 0/0 warnings).
        return np.float64(np.nan), np.float64(np.nan)

    # Count-weighted mean of the midpoints
    mean = np.sum(bin_midpoints * counts) / total_count

    # Count-weighted squared deviation from the mean
    weighted_squared_deviation = np.sum(counts * (bin_midpoints - mean) ** 2)

    std = np.sqrt(weighted_squared_deviation / total_count)

    return mean, std

def get_dataset_counts(dataset):
    """Accumulate a per-value histogram over the first element of every item.

    Args:
        dataset: iterable of items whose first element (``d[0]``) is an array
            of non-negative integer values. ``np.bincount`` rejects negative
            or non-integer data, so e.g. float difference images are not
            supported.

    Returns:
        1-D ``np.uint64`` array where entry ``v`` is the number of times value
        ``v`` occurred. It has at least 256 entries and grows to
        ``max_value + 1`` when larger values are present.
    """
    bit_depth = 1  # bytes per value used for the initial histogram size
    n_vals = 2**(8*bit_depth) # bit_depth of 8 causing this to explode
    counts = np.zeros(n_vals, dtype=np.uint64)

    for d in tqdm(dataset):
        values = np.asarray(d[0]).ravel()
        item_counts = np.bincount(values, minlength=counts.size).astype(np.uint64)

        # bincount returns more than ``minlength`` bins when a value exceeds
        # the current range; grow the accumulator instead of failing to
        # broadcast in the in-place add.
        if item_counts.size > counts.size:
            grown = np.zeros(item_counts.size, dtype=np.uint64)
            grown[:counts.size] = counts
            counts = grown

        counts += item_counts

    return counts

def get_dataloader_meanstd(ds):
    """Compute the mean and standard deviation of a dataset's pixel values.

    The values are histogrammed with ``get_dataset_counts`` (one bin per
    integer value) and the statistics are derived from that histogram with
    ``compute_mean_std_from_histogram``.

    Args:
        ds: dataset accepted by ``get_dataset_counts``.

    Returns:
        ``(mean, std)`` of all values in the dataset.
    """
    hist_counts = get_dataset_counts(ds)

    # Bin v counts occurrences of the integer value v exactly, so centre each
    # bin on v (edges v - 0.5 .. v + 0.5). Edges of v .. v + 1 would put the
    # midpoint at v + 0.5 and bias the mean by +0.5.
    hist_edges = np.arange(len(hist_counts) + 1) - 0.5

    # Compute mean and standard deviation from the histogram
    mean, std = compute_mean_std_from_histogram(hist_edges, hist_counts)

    return mean, std

In [None]:
%%time
# ds = CMEDataset(root="/mnt/onboard_data/classifier", pol='0.0', events=events)

In [None]:
get_dataloader_meanstd(ds)

In [None]:
ds.images[1]

In [None]:
ds[1][0]

In [None]:
plt.hist(ds[1][0].flatten())
plt.show()

In [None]:
len(ds.images)

In [None]:
ds.images[26360]

In [None]:
for d in tqdm(ds):
    print(d)
    break

In [None]:
x = []
for d in tqdm(ds):
    x.append(d[0].flatten())

In [None]:
npx = np.array(x)

In [None]:
npx.shape

In [None]:
np.mean(npx), np.std(npx) # too slow # killed kernel after 20 mins

## Check CME images

In [None]:
# Inspect the stored background for one date, satellite (STEREO-B) and angle ("0.0").
# `date` and `bg` are reused by the following cell.
date = datetime(2014, 2, 25).date()
bg = ds.background[date]['stereo_b']["0.0"]

plt.figure(figsize=(10,10))
# Left panel: the background image itself.
plt.subplot(121)
plt.imshow(bg)
plt.colorbar(location='bottom')

# Right panel: histogram of the background pixel values
# (result assigned to `_` to suppress the cell output).
plt.subplot(122)
_ = plt.hist(bg.flatten(), bins=256)

plt.tight_layout()

In [None]:
# Manually build a running-difference image from two consecutive frames of the
# same date / satellite / angle, using `date` and `bg` from the previous cell.
idx = 3
pol = "0.0"
img_curr = fits.getdata(ds.images_for_date[date]['stereo_b'][pol][idx])
img_prev = fits.getdata(ds.images_for_date[date]['stereo_b'][pol][idx-1])

# NOTE(review): bg is subtracted from both frames, so it cancels algebraically
# and this is img_curr - img_prev up to rounding.
img = (img_curr - bg) - (img_prev - bg)

plt.figure(figsize=(10,10))
# Left panel: the difference image.
plt.subplot(121)
plt.imshow(img)
plt.colorbar(location='bottom')

# Right panel: histogram of the difference values.
plt.subplot(122)
_ = plt.hist(img.flatten(), bins=256)

plt.tight_layout()

In [None]:
# Show one dataset item: the difference image and its value histogram.
# Keep idx > 0: item 0 is differenced against itself, so it is all zeros and
# says nothing about whether the background is being used.
idx = 25
# Fetch the item once; indexing ds twice would read and difference the FITS
# files twice.
img, label = ds[idx]

plt.figure(figsize=(10,10))
plt.subplot(121)
plt.imshow(img)
plt.colorbar(location='bottom')

plt.subplot(122)
_ = plt.hist(img.flatten(), bins=256)

plt.tight_layout()

## Further pre-process?

In [None]:
from sunpy.map.maputils import all_coordinates_from_map
from sunpy.map import Map
from astropy.coordinates import SkyCoord
from matplotlib.colors import LogNorm, Normalize, PowerNorm
from rich.progress import Progress
from sunpy.map import Map
import astropy.units as u

def generate_image(img_fname, ref_map, vmin=0, vmax=20):
    """
    Plot a single image after reference subtraction and radial weighting.

    The file is loaded as a map, ``ref_map`` is subtracted from its data and
    the result is divided by a mask that falls off with the squared distance
    from (0 deg, 0 deg) in the map's coordinate frame. Pixels whose radius is
    at least 0.9 times the maximum Tx coordinate are set to NaN.

    ``vmin`` and ``vmax`` are accepted but not used by the current body.
    Nothing is returned; the image is drawn with ``plt.imshow``.
    """
    sun_map = Map(img_fname)

    coords = all_coordinates_from_map(sun_map)
    centre = SkyCoord(0 * u.deg, 0 * u.deg, frame=sun_map.coordinate_frame)

    # Distance of every pixel from the centre coordinate
    delta_tx = coords.Tx - centre.Tx
    delta_ty = coords.Ty - centre.Ty
    radii = np.sqrt(
        (delta_tx) ** 2
        + (delta_ty) ** 2
    )

    # r2 masking
    relative_radius = radii / radii.max()
    weights = 1 - ((relative_radius) ** 2) * 0.5
    weights = weights.value
    outer_limit = 0.9 * coords.Tx.max().value
    weights[radii.value >= outer_limit] = np.nan

    data = (sun_map.data - ref_map) / weights

    # imshow is mirror to m.plot
    plt.imshow(data, origin="lower", cmap="stereocor2")

In [None]:
#generate_image(ds.images_for_date[date][20], bg)