# Instrument-to-Instrument translation

ITI provides translations between image domains of solar instruments (image enhancement, super-resolution, cross-calibration, estimation of observables). This notebook provides two training examples of ITI.

Colab offers free online computation power. The training requires an active GPU. This can be changed in the menu (Runtime -> Change runtime type -> Hardware accelerator -> GPU).

## Installation and imports

In [None]:
# Pinned sunpy release plus the zarr/gcsfs packages used below to read the cloud-hosted data.
!pip install sunpy==3.0 zarr gcsfs
# ITI itself, installed from GitHub at tag v0.1.0.
!pip install git+https://github.com/RobertJaro/InstrumentToInstrument.git@v0.1.0
# NOTE(review): presumably provides the `sdo` package imported in the next cell -- confirm.
!pip install git+https://github.com/vale-salvatelli/sdo-autocal_pub.git

In [None]:
import gcsfs
import zarr
import dask.array as da

import glob
import os
import logging

import numpy as np

from torch.utils.data import Dataset, DataLoader
from multiprocessing import get_context

from sdo.datasets.sdo_dataset import SDO_Dataset
from sdo.pytorch_utilities import create_dataloader

from iti.data.editor import *
from iti.data.dataset import BaseDataset, StackDataset, sdo_norms
from iti.train.model import DiscriminatorMode
from iti.trainer import Trainer, loop
from urllib.request import urlretrieve

from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.colors import LogNorm
from astropy.visualization import ImageNormalize, LinearStretch, AsinhStretch

from sunpy.visualization.colormaps import cm

from datetime import datetime, timedelta

import pandas as pd

from astropy import units as u
from astropy.coordinates import SkyCoord
import warnings
warnings.filterwarnings('ignore')
os.makedirs('data', exist_ok=True)

from tqdm import tqdm

import torch



## Download SDOML dataset

The SDOML dataset is publicly available for download. The data is stored as compressed Numpy arrays.

For this demo we use observations from 2011, where we select one observation per day. For practical applications all available years should be considered (2010 - 2021).

We use the HMI magnetograms as reference and select the corresponding EUV observations.

In [None]:
# Google Cloud Storage file-system handle, opened with read-only access;
# it is reused by the HMI and AIA loading cells below.
gcs = gcsfs.GCSFileSystem(access="read_only")

In [None]:
# load HMI data
# Open the 2011 HMI group of the SDOML v2 zarr store through the read-only GCS handle.
loc_hmi = "fdl-sdoml-v2/sdomlv2_hmi.zarr/2011"
store = gcsfs.GCSMap(loc_hmi, gcs=gcs, check=False)
root = zarr.group(store)

# Observation timestamps are stored as the `T_OBS` attribute of the Bx array.
hmi_times = root["Bx"].attrs["T_OBS"]
# Aim for roughly 365 samples (one per day). The step is clamped to 1 because
# a store with fewer than 365 entries would otherwise yield a step of 0, and
# slicing with a zero step raises ValueError.
sampling_step = max(1, len(hmi_times) // 365)
hmi_times = hmi_times[::sampling_step] # subsample
hmi_times = pd.to_datetime(hmi_times, format='%Y.%m.%d_%H:%M:%S_TAI').to_pydatetime()

# Wrap the three field components as dask arrays and apply the same
# subsampling, so that index i of each array corresponds to hmi_times[i].
hmi_Bx = da.from_array(root["Bx"])[::sampling_step] # subsample
hmi_By = da.from_array(root["By"])[::sampling_step] # subsample
hmi_Bz = da.from_array(root["Bz"])[::sampling_step] # subsample

In [None]:
# load AIA data
loc = "fdl-sdoml-v2/sdomlv2.zarr/2011"
store = gcsfs.GCSMap(loc, gcs=gcs, check=False)
aia_root = zarr.group(store, synchronizer=zarr.ThreadSynchronizer())

In [None]:
# align AIA data
# For every selected HMI time step, collect the closest frame of each EUV
# channel. A channel is skipped for a time step when its closest frame is more
# than 15 minutes away, so incomplete entries end up with fewer than
# len(aia_keys) frames.
aia_keys = ['171A', '193A', '211A', '304A']# all keys: aia_root.array_keys()
max_offset = timedelta(minutes=15)
time_data_mapping = {t: [] for t in hmi_times}
for key in aia_keys:
  channel = aia_root[key]
  obs_times = pd.to_datetime(channel.attrs['T_OBS'], format='%Y-%m-%dT%H:%M:%S.%fZ').to_pydatetime()
  da_array = da.from_array(channel)
  for t in hmi_times:
    # compute the time offsets once and reuse them for both the tolerance
    # check and the index lookup
    offsets = np.abs(obs_times - t)
    idx = np.argmin(offsets)
    if offsets[idx] > max_offset:
      continue
    time_data_mapping[t].append(da_array[idx])


From our selection we download the data and save it as numpy arrays.

In [None]:
# Write one channel-first cube per time step to data/<timestamp>.npy.
cube_iterator = tqdm(enumerate(time_data_mapping.items()), desc='loading data cubes', total=len(hmi_times))
for row, (obs_time, aia_channels) in cube_iterator:
  # skip time steps where not every EUV channel found a matching frame
  if len(aia_channels) != len(aia_keys):
    continue
  save_path = 'data/%s.npy' % obs_time.isoformat('T')
  # skip cubes that were already written to disk
  if os.path.exists(save_path):
    continue
  # keep every second pixel along both image axes (subsample to 256x256)
  magnetic = [component[row, ::2, ::2] for component in (hmi_Bx, hmi_By, hmi_Bz)]
  euv = [channel[::2, ::2] for channel in aia_channels]
  # channel order in the saved cube: Bx, By, Bz, followed by the EUV channels
  cube = np.stack(magnetic + euv)
  np.save(save_path, cube)

## Training - EUV-to-304

In the first example we use the EUV channels 171, 193 and 211 to generate synthetic 304 observations. 

In [None]:
# Output directory of this experiment; the log file, the callbacks and
# trainer.resume() below all receive this path.
base_dir = "sdo_to_sdo"
os.makedirs(base_dir, exist_ok=True)

We first create a data set that reads the numpy arrays and normalizes them to [-1, 1].

In [None]:
# Per-channel value ranges that are mapped onto [-1, 1].
# Channel order follows the saved cubes: Bx, By, Bz, 171, 193, 211, 304.
norm_min = np.array([-1500, -1500, -1500, 0, 0, 0, 0])
norm_max = np.array([1500, 1500, 1500, 2000, 2500, 2000, 1000])


class SDODataset(Dataset):
    """Dataset of saved data cubes, normalized to [-1, 1].

    Each item is loaded from a ``.npy`` file with channels on the first axis,
    restricted to the requested channel range, scaled linearly with
    ``norm_min``/``norm_max`` and clipped to [-1, 1].

    Args:
        data: sequence of paths to ``.npy`` cubes.
        channel_slice: ``(start, stop)`` pair selecting ``cube[start:stop]``.
        **kwargs: accepted for call compatibility and ignored.
    """

    def __init__(self, data, channel_slice=(0, 7), **kwargs):
        super().__init__()
        self.data = data
        self.channel_slice = channel_slice

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        first, last = self.channel_slice[0], self.channel_slice[1]
        # select the channels first, so that only the returned channels are
        # normalized
        data = np.load(self.data[idx])[first:last]
        lower = norm_min[first:last, None, None]
        upper = norm_max[first:last, None, None]
        data = (data - lower) / (upper - lower)
        data = data * 2 - 1
        # values outside the configured range saturate at the interval bounds
        return np.clip(data, -1, 1)

We create two data sets:

A(low-quality) --> channels 3 to 6 (171, 193, 211)

B(high-quality) --> channels 3 to 7 (171, 193, 211, 304)


Therefore, we synthesize the 304 channel in addition to the other channels. The advantage of this approach is that the 304 channel is generated to be consistent with the other channels.

In [None]:
# all cubes written by the download step
files = glob.glob('data/*.npy')
# A: stored channels 3-5 (171, 193, 211)
sdo_train_A = SDODataset(files, (3, 6))
# B: stored channels 3-6 (171, 193, 211, 304)
sdo_train_B = SDODataset(files, (3, 7))

We create a basic logging to monitor the progress of our training.

In [None]:
# Init model
logging.basicConfig(
    level=logging.INFO,
    handlers=[
        logging.FileHandler("{0}/{1}.log".format(base_dir, "info_log")),
        logging.StreamHandler()
    ])

The trainer is the central component of ITI. Here, we translate from data with 3 channels to 4 channels. The discriminator mode uses a single discriminator for each channel and a separate discriminator for the combined set of channels. Since we do not expect large instrumental noise we set the diversity training to 0. The layer normalization is important when dealing with image patches. For our application we use full-disk observations, which does not require a tracking of norm statistics. For the training with image patches the 'in_aff_rs' norm is suggested. 

In [None]:
# Translation from 3 channels (171, 193, 211) to 4 channels (171, 193, 211, 304).
trainer = Trainer(
    input_dim_a=3,
    input_dim_b=4,
    discriminator_mode=DiscriminatorMode.CHANNELS,
    lambda_diversity=0,
    norm='in_aff',
)

For monitoring our progress we initialize callbacks that plot intermediate results and save the model state.

In [None]:
# `random` is used for the preview samples below, but the notebook never
# imports it directly, so import it here explicitly.
import random

ds_A = sdo_train_A
ds_B = sdo_train_B
# number of worker processes used by both data loaders
num_workers = 4

# Move the model to the GPU and pick up an existing state from base_dir.
# NOTE(review): resume() presumably returns the iteration to continue from
# (it is used as the start of the training range) -- confirm.
trainer.cuda()
trainer.train()
start_it = trainer.resume(base_dir)

# Init Callbacks
from iti.callback import HistoryCallback, ProgressCallback, SaveCallback, PlotBAB, PlotABA, ValidationHistoryCallback
history_callback = HistoryCallback(trainer, base_dir)
progress_callback = ProgressCallback(trainer)
save_callback = SaveCallback(trainer, base_dir)

# One plot configuration per channel of domain B; domain A uses all but the
# last entry (it has no 304 channel). Every entry gets its own norm instance.
channel_styles = [
    (cm.sdoaia171, "AIA 171"),
    (cm.sdoaia193, "AIA 193"),
    (cm.sdoaia211, "AIA 211"),
    (cm.sdoaia304, "AIA 304"),
]
plot_settings = [
    {"cmap": cmap, "title": title, 'norm': ImageNormalize(vmin=-1, vmax=1, stretch=AsinhStretch(0.01))}
    for cmap, title in channel_styles
]

# Four randomly chosen samples of each domain are used for the preview plots.
random_sample = [ds_A[i] for i in random.sample(range(len(ds_A)), 4)]
plot_ABA_callback = PlotABA(random_sample, trainer, base_dir, log_iteration=100, plot_settings_A=plot_settings[:-1], plot_settings_B=plot_settings)
plot_ABA_callback.call(0)

random_sample = [ds_B[i] for i in random.sample(range(len(ds_B)), 4)]
plot_BAB_callback = PlotBAB(random_sample, trainer, base_dir, log_iteration=100, plot_settings_A=plot_settings[:-1], plot_settings_B=plot_settings)
plot_BAB_callback.call(0)

callbacks = [plot_ABA_callback, plot_BAB_callback, history_callback, progress_callback, save_callback]
# init data loaders
B_iterator = loop(DataLoader(ds_B, batch_size=1, shuffle=True, num_workers=num_workers))
A_iterator = loop(DataLoader(ds_A, batch_size=1, shuffle=True, num_workers=num_workers))

With this we can start the main training loop, where we randomly sample from our data sets. We iteratively use the trainer to update the generator and discriminator networks. The results are automatically logged to the filesystem (`base_dir`).

In [None]:
# start update cycle
for it in range(start_it, 10000):
    trainer.train()

    # discriminator step on one sample from each domain
    batch_a = next(A_iterator).float().cuda().detach()
    batch_b = next(B_iterator).float().cuda().detach()
    trainer.discriminator_update(batch_a, batch_b)

    # generator step on a fresh pair of samples
    batch_a = next(A_iterator).float().cuda().detach()
    batch_b = next(B_iterator).float().cuda().detach()
    trainer.generator_update(batch_a, batch_b)
    torch.cuda.synchronize()

    # run the callbacks in eval mode and without gradient tracking
    trainer.eval()
    with torch.no_grad():
        for callback in callbacks:
            callback(it)

## Training - EUV-to-Magnetogram

In [None]:
# Per-channel value ranges that are mapped onto [-1, 1].
# Channel order follows the saved cubes: Bx, By, Bz, 171, 193, 211, 304.
norm_min = np.array([-1500, -1500, -1500, 0, 0, 0, 0])
norm_max = np.array([1500, 1500, 1500, 2000, 2500, 2000, 1000])


def _normalize_channels(data, first_channel=0):
    """Scale a channel-first cube linearly to [-1, 1] and clip values outside.

    ``data`` must hold the channels ``first_channel`` to the last entry of
    ``norm_min``/``norm_max`` on its first axis.
    """
    lower = norm_min[first_channel:, None, None]
    upper = norm_max[first_channel:, None, None]
    scaled = (data - lower) / (upper - lower)
    scaled = scaled * 2 - 1
    return np.clip(scaled, -1, 1)


class EUVDataset(Dataset):
    """Normalized EUV channels (stored channel 3 onward) of the saved cubes.

    Args:
        data: sequence of paths to ``.npy`` cubes.
        **kwargs: accepted for call compatibility and ignored.
    """

    def __init__(self, data, **kwargs):
        super().__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # drop the three magnetic field channels before normalizing
        return _normalize_channels(np.load(self.data[idx])[3:], first_channel=3)


class MagDataset(Dataset):
    """All stored channels (Bx, By, Bz and the EUV channels), normalized.

    Args:
        data: sequence of paths to ``.npy`` cubes.
        **kwargs: accepted for call compatibility and ignored.
    """

    def __init__(self, data, **kwargs):
        super().__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return _normalize_channels(np.load(self.data[idx]))

In [None]:
# all cubes written by the download step
files = glob.glob('data/*.npy')
# A: the EUV channels only (stored channel 3 onward)
sdo_train_A = EUVDataset(files)
# B: every stored channel (Bx, By, Bz and the EUV channels)
sdo_train_B = MagDataset(files)

In [None]:
# Use a separate output directory for this experiment. Reusing "sdo_to_sdo"
# would make trainer.resume(base_dir) pick up the state saved by the
# EUV-to-304 training, which belongs to a model with different channel counts.
base_dir = "euv_to_mag"
os.makedirs(base_dir, exist_ok=True)

# Init model
# logging.basicConfig() does nothing when the root logger already has
# handlers, so detach and close the existing ones first; otherwise the log of
# this experiment would keep going to the previously configured file.
root_logger = logging.getLogger()
for handler in list(root_logger.handlers):
    root_logger.removeHandler(handler)
    handler.close()
logging.basicConfig(
    level=logging.INFO,
    handlers=[
        logging.FileHandler("{0}/{1}.log".format(base_dir, "info_log")),
        logging.StreamHandler()
    ])

In [None]:
# Translation from the 4 EUV channels to all 7 channels (Bx, By, Bz + EUV).
trainer = Trainer(
    input_dim_a=4,
    input_dim_b=7,
    discriminator_mode=DiscriminatorMode.CHANNELS,
    lambda_diversity=0,
    norm='in_aff',
)

In [None]:
# `random` is used for the preview samples below, but the notebook never
# imports it directly, so import it here explicitly.
import random

ds_A = sdo_train_A
ds_B = sdo_train_B
# number of worker processes used by both data loaders
num_workers = 4

# Move the model to the GPU and pick up an existing state from base_dir.
# NOTE(review): resume() presumably returns the iteration to continue from
# (it is used as the start of the training range) -- confirm.
trainer.cuda()
trainer.train()
start_it = trainer.resume(base_dir)

# Init Callbacks
from iti.callback import HistoryCallback, ProgressCallback, SaveCallback, PlotBAB, PlotABA, ValidationHistoryCallback
history_callback = HistoryCallback(trainer, base_dir)
progress_callback = ProgressCallback(trainer)
save_callback = SaveCallback(trainer, base_dir)

# One plot configuration per channel of domain B (Bx, By, Bz + EUV); domain A
# only holds the EUV channels and therefore uses the entries from index 3 on.
plot_settings = [
    {"cmap": 'gray', "title": "Bx", 'vmin': -1, 'vmax': 1},
    {"cmap": 'gray', "title": "By", 'vmin': -1, 'vmax': 1},
    {"cmap": 'gray', "title": "Bz", 'vmin': -1, 'vmax': 1},
    {"cmap": cm.sdoaia171, "title": "AIA 171", 'vmin': -1, 'vmax': 1},
    {"cmap": cm.sdoaia193, "title": "AIA 193", 'vmin': -1, 'vmax': 1},
    {"cmap": cm.sdoaia211, "title": "AIA 211", 'vmin': -1, 'vmax': 1},
    {"cmap": cm.sdoaia304, "title": "AIA 304", 'vmin': -1, 'vmax': 1},
]

# Four randomly chosen samples of each domain are used for the preview plots.
random_sample = [ds_A[i] for i in random.sample(range(len(ds_A)), 4)]
plot_ABA_callback = PlotABA(random_sample, trainer, base_dir, log_iteration=100, plot_settings_A=plot_settings[3:], plot_settings_B=plot_settings)
plot_ABA_callback.call(0)

random_sample = [ds_B[i] for i in random.sample(range(len(ds_B)), 4)]
plot_BAB_callback = PlotBAB(random_sample, trainer, base_dir, log_iteration=100, plot_settings_A=plot_settings[3:], plot_settings_B=plot_settings)
plot_BAB_callback.call(0)

callbacks = [plot_ABA_callback, plot_BAB_callback, history_callback, progress_callback, save_callback]
# init data loaders
B_iterator = loop(DataLoader(ds_B, batch_size=1, shuffle=True, num_workers=num_workers))
A_iterator = loop(DataLoader(ds_A, batch_size=1, shuffle=True, num_workers=num_workers))

In [None]:
# start update cycle
for it in range(start_it, 10000):
    trainer.train()

    # discriminator step on one sample from each domain
    sample_a = next(A_iterator).float().cuda().detach()
    sample_b = next(B_iterator).float().cuda().detach()
    trainer.discriminator_update(sample_a, sample_b)

    # generator step on a fresh pair of samples
    sample_a = next(A_iterator).float().cuda().detach()
    sample_b = next(B_iterator).float().cuda().detach()
    trainer.generator_update(sample_a, sample_b)
    torch.cuda.synchronize()

    # run the callbacks in eval mode and without gradient tracking
    trainer.eval()
    with torch.no_grad():
        for callback in callbacks:
            callback(it)