### Setup: mount Google Drive, extract the project archives, and import dependencies

In [None]:
from google.colab import drive

# Mount Google Drive so the archives under MyDrive/InverseProblem can be read.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
%%capture
!unzip /content/drive/MyDrive/InverseProblem/InverseProblem1.zip -d /content/

In [None]:
%%capture
!unzip /content/drive/MyDrive/InverseProblem/DATA/parameters_base.zip -d /content/InverseProblem/data/

In [None]:
from astropy.io import fits
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import sys

In [None]:
# Make the project's modules importable. Skip entries that are already present so that
# re-running this cell does not keep appending duplicates to sys.path.
for _path in ('/content/InverseProblem/inverse_problem/normalising_flows/model',
              '/content/InverseProblem'):
    if _path not in sys.path:
        sys.path.append(_path)
from NFFitter import NFFitter
from helpfuncs import calculate_metrics, compare_metrics
from inverse_problem.milne_edington.me import read_full_spectra, HinodeME, BatchHinodeME
from inverse_problem.nn_inversion.posthoc import compute_metrics, open_param_file, plot_params

In [None]:
from sklearn.preprocessing import StandardScaler
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.autograd import Variable
from NormalizingFlow import NormalizingFlow
from RealNVP import RealNVP



# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = torch.device('cpu')

class NFFitterDistilled(object):
    
    def __init__(self, teacher_model, var_size=2, cond_size=2, normalize_y=True, n_layers=8,  batch_size=32, n_epochs=10, lr=0.0001):
        
        self.normalize_y = normalize_y
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.lr = lr
        self.loss_history = []

        self.teacher_model = teacher_model
        
        prior = torch.distributions.MultivariateNormal(torch.zeros(var_size), torch.eye(var_size))

        layers = []
        for i in range(n_layers):
            mask = ((torch.arange(var_size) + i) % 2)
            # mask = torch.ones(var_size, dtype=torch.int64) * torch.from_numpy(np.random.rand(var_size) < 0.5)
            layers.append(RealNVP(var_size=var_size, cond_size=cond_size, mask=mask, hidden=40))

        self.nf = NormalizingFlow(layers=layers, prior=prior)
        self.opt = torch.optim.Adam(self.nf.parameters(), lr=self.lr)
        
        
    def reshape(self, y):
        try:
            y.shape[1]
            return y
        except:
            return y.reshape(-1, 1)
    
    
    def fit(self, X, y):
        
        # reshape
        y = self.reshape(y)
        
        # normalize
        if self.normalize_y:
            self.ss_y = StandardScaler()
            y = self.ss_y.fit_transform(y)
            
        #noise = np.random.normal(0, 1, (y.shape[0], 1))
        #y = np.concatenate((y, noise), axis=1)
        
        # numpy to tensor
        y_real = torch.tensor(y, dtype=torch.float32, device=DEVICE)
        X_cond = torch.tensor(X, dtype=torch.float32, device=DEVICE)

        
        
        # tensor to dataset
        dataset_real = TensorDataset(y_real, X_cond)
        
        criterion = nn.MSELoss()
        

        # Fit GAN
        for epoch in range(self.n_epochs):
            for i, (y_batch, x_batch) in enumerate(DataLoader(dataset_real, batch_size=self.batch_size, shuffle=True)):
                
                noise = np.random.normal(0, 1, (len(y_batch), 1))
                noise = torch.tensor(noise, dtype=torch.float32, device=DEVICE)
                y_batch = torch.cat((y_batch, noise), dim=1)
                
                y_pred = self.nf.sample(x_batch)
                y_teacher = self.teacher_model.nf.sample(x_batch)
                
                # caiculate loss
                #loss = -self.nf.log_prob(y_batch, x_batch)
                # loss = criterion(y_batch, y_pred)

                loss = criterion(y_teacher, y_pred)
                
                # optimization step
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                    
                # caiculate and store loss
                self.loss_history.append(loss.detach().cpu())
                    
        
    def predict(self, X):
        #noise = np.random.normal(0, 1, (X.shape[0], 1))
        #X = np.concatenate((X, noise), axis=1)
        X = torch.tensor(X, dtype=torch.float32, device=DEVICE)
        y_pred = self.nf.sample(X).cpu().detach().numpy()#[:, 0]
        # normalize
        if self.normalize_y:
            y_pred = self.ss_y.inverse_transform(y_pred)
        return y_pred
    
    def predict_n_times(self, X, n_times=100):
        predictions = []
        for i in range(n_times):
            y_pred = self.predict(X)
            predictions.append(y_pred)
        predictions = np.array(predictions)
        mean = predictions.mean(axis=0)
        std = predictions.std(axis=0)
        return mean, std

    def predict_image_n_times(self, X, n_times=100):
        predictedImage = []
        for row in X:
            predictedRow, _ = self.predict_n_times(row, n_times=n_times)
            predictedImage.append(predictedRow)
        return np.array(predictedImage)

In [None]:
# FITS file of Milne-Eddington parameters extracted from parameters_base.zip.
DATA_PATH = '/'.join(['', 'content', 'InverseProblem', 'data', 'parameters_base.fits'])

In [None]:
def prepare_data(data_path, size_limit=None, batch_size=10000):
    """Load parameters from a FITS file and synthesize their spectra with BatchHinodeME.

    Parameters
    ----------
    data_path : str
        Path to a FITS file whose primary HDU holds the parameter array, one row per sample.
    size_limit : int or None
        Use only the first ``size_limit`` rows. ``None`` uses every row.
    batch_size : int
        Number of parameter rows passed to ``BatchHinodeME`` per call.

    Returns
    -------
    lines : np.ndarray (float32)
        Spectra from ``BatchHinodeME.compute_spectrum``, each sample flattened to one row.
    params : np.ndarray (float32)
        The parameter rows that were used, in the same order as ``lines``.

    Raises
    ------
    ValueError
        If ``size_limit`` is not positive or no parameter rows are selected.
    """
    if size_limit is not None and size_limit <= 0:
        raise ValueError(f'size_limit must be positive or None, got {size_limit}')

    # Read the file given by the ``data_path`` argument (not a global) and close the
    # handle once the rows are loaded.
    with fits.open(data_path) as hdul:
        # np.array copies the selected rows, so they are independent of the file after it
        # is closed. Slicing with None keeps every row; a too-large limit is clamped.
        params = np.array(hdul[0].data[:size_limit])

    n_samples = params.shape[0]
    if n_samples == 0:
        raise ValueError(f'no parameter rows found in {data_path}')

    # Synthesize spectra batch by batch and concatenate once at the end. Concatenating
    # inside the loop would re-copy the accumulated array on every iteration.
    spectra = [
        BatchHinodeME(params[start:start + batch_size]).compute_spectrum()
        for start in range(0, n_samples, batch_size)
    ]
    lines = np.concatenate(spectra, axis=0)

    # Flatten every per-sample spectrum into a single feature row.
    lines = lines.reshape(lines.shape[0], -1).astype(np.float32)
    return lines, params.astype(np.float32)


In [None]:
# Use the first 10,000 parameter rows, synthesizing spectra 1,000 rows at a time.
lines, params = prepare_data(DATA_PATH, batch_size=1000, size_limit=10000)

In [None]:
# Fixed random_state so the train/test split, and every metric computed from it, is
# reproducible under Restart & Run All.
lines_train, lines_test, params_train, params_test = train_test_split(
    lines, params, test_size=0.2, random_state=42
)

In [None]:
# Standardize each parameter column to zero mean and unit variance.
# NOTE(review): the scaler is fitted on all rows, including the test split. Confirm this
# matches how the teacher's scaled targets were produced before switching to params_train.
params_scaler = StandardScaler().fit(params)
sc_params = params_scaler.transform(params)

In [None]:
# Map both splits into the standardized space defined by params_scaler.
sc_params_train, sc_params_test = (params_scaler.transform(split) for split in (params_train, params_test))

In [None]:
# The checkpoint is a whole pickled model object (it is used for .predict_n_times below),
# not a state_dict. It therefore needs weights_only=False, whose default became True in
# PyTorch 2.6. Only load trusted files: unpickling can execute arbitrary code.
# map_location places the tensors on DEVICE even if the model was saved from a GPU session.
model = torch.load(
    '/content/drive/MyDrive/InverseProblem/saved_models/saved_mid_model_190.pth',
    map_location=DEVICE,
    weights_only=False,
)

In [None]:
predicted, _ = model.predict_n_times(lines_test)
# The teacher's samples carry one extra trailing column (presumably the padding/noise
# dimension). Drop it before comparing with the standardized ground truth.
sc_predicted = predicted[:, :-1]
unsc_predicted = params_scaler.inverse_transform(sc_predicted)

calculate_metrics(sc_params_test, sc_predicted)

Unnamed: 0,r2,mse,mae
Field_Strength,0.953,0.006182,0.0547
Field_Inclination,0.904,0.005375,0.054473
Field_Azimuth,0.71,0.020018,0.078421
Doppler_Width,0.931,0.001623,0.030462
Damping,0.92,0.002302,0.030834
Line_Strength,0.703,0.005249,0.031045
Original_Continuum_Intensity,0.636,0.006388,0.061855
Source_Function_Gradient,0.835,0.002624,0.040018
Doppler_Shift2,0.867,0.001998,0.030088
Stray_Light_Fill_Factor,0.762,0.013424,0.088449


In [None]:
dist_model = NFFitterDistilled(
    teacher_model=model,
    var_size=12,  # must equal the width of the teacher's samples (compared element-wise)
    cond_size=lines.shape[1],  # one conditioning feature per flattened spectrum value
    normalize_y=False,
    n_layers=8,
    batch_size=250,
    n_epochs=100,
    lr=0.003,
)

In [None]:
# Distil the teacher into the student using the training spectra as conditioning inputs.
dist_model.fit(X=lines_train, y=sc_params_train)

In [None]:
dist_predicted, _ = dist_model.predict_n_times(lines_test)
# Drop the trailing extra sample column, exactly as is done for the teacher's predictions.
dist_sc_predicted = dist_predicted[:, :-1]
dist_unsc_predicted = params_scaler.inverse_transform(dist_sc_predicted)

In [None]:
# Evaluate the trimmed predictions (trailing extra column removed), matching how the
# teacher is evaluated. dist_predicted still contains that extra column.
calculate_metrics(sc_params_test, dist_sc_predicted)

Unnamed: 0,r2,mse,mae
Field_Strength,0.909,0.009662,0.069515
Field_Inclination,0.867,0.007167,0.065016
Field_Azimuth,0.576,0.029648,0.113828
Doppler_Width,0.734,0.005934,0.059421
Damping,0.704,0.006666,0.053801
Line_Strength,0.349,0.009523,0.044319
Original_Continuum_Intensity,0.548,0.007754,0.068204
Source_Function_Gradient,0.725,0.004886,0.057232
Doppler_Shift2,0.718,0.004886,0.049503
Stray_Light_Fill_Factor,0.501,0.027775,0.129498


In [None]:
# Compare teacher and student on identically prepared predictions: both have the trailing
# extra sample column dropped.
compare_metrics(
    calculate_metrics(sc_params_test, sc_predicted),
    calculate_metrics(sc_params_test, dist_sc_predicted),
)

0.8998259305337928