# 1. Unchanged code

In [1]:
# Unchanged code
!pip install optuna==2.4.0

import tqdm
import numpy as np

import numpy as np
import optuna
from google.colab import drive
import torch
import torch.nn as nn
import sys,os
import random
from torch.utils.data import DataLoader

import shutil
import time

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 200

if torch.cuda.is_available():
    print("CUDA Available")
    device = torch.device('cuda')
else:
    print('CUDA Not Available')
    device = torch.device('cpu')


def unbiased_HSIC(K, L):
  '''Computes an unbiased estimator of HISC. This is equation (2) from the paper'''

  #create the unit **vector** filled with ones
  n = K.shape[0]
  ones = np.ones(shape=(n))

  #fill the diagonal entries with zeros
  np.fill_diagonal(K, val=0) #this is now K_tilde
  np.fill_diagonal(L, val=0) #this is now L_tilde

  #first part in the square brackets
  trace = np.trace(np.dot(K, L))

  #middle part in the square brackets
  nominator1 = np.dot(np.dot(ones.T, K), ones)
  nominator2 = np.dot(np.dot(ones.T, L), ones)
  denominator = (n-1)*(n-2)
  middle = np.dot(nominator1, nominator2) / denominator


  #third part in the square brackets
  multiplier1 = 2/(n-2)
  multiplier2 = np.dot(np.dot(ones.T, K), np.dot(L, ones))
  last = multiplier1 * multiplier2

  #complete equation
  unbiased_hsic = 1/(n*(n-3)) * (trace + middle - last)

  return unbiased_hsic

def CKA(X, Y):
  '''Computes the CKA of two matrices. This is equation (1) from the paper'''

  nominator = unbiased_HSIC(np.dot(X, X.T), np.dot(Y, Y.T))
  denominator1 = unbiased_HSIC(np.dot(X, X.T), np.dot(X, X.T))
  denominator2 = unbiased_HSIC(np.dot(Y, Y.T), np.dot(Y, Y.T))

  cka = nominator/np.sqrt(denominator1*denominator2)

  return cka

def calculate_CKA_for_two_matrices(activationA, activationB):
  '''Takes two activations A and B and computes the linear CKA to measure their similarity'''

  #unfold the activations, that is make a (n, h*w*c) representation
  shape = activationA.shape
  activationA = np.reshape(activationA, newshape=(shape[0], np.prod(shape[1:])))

  shape = activationB.shape
  activationB = np.reshape(activationB, newshape=(shape[0], np.prod(shape[1:])))

  #calculate the CKA score
  cka_score = CKA(activationA, activationB)

  del activationA
  del activationB

  return cka_score


# def get_all_layer_outputs_fn(model):
#   '''Builds and returns function that returns the output of every (intermediate) layer'''

#   return tf.keras.backend.function([model.layers[0].input],
#                                   [l.output for l in model.layers[1:]])

def compare_activations(modelA, modelB, data_batch):
  '''
  Calculate a pairwise comparison of hidden representations and return a matrix
  '''

  _, intermediate_outputs_A = modelA(data_batch)
  _, intermediate_outputs_B = modelB(data_batch)

  #create a placeholder array
  result_array = np.zeros(shape=(len(intermediate_outputs_A), len(intermediate_outputs_B)))

  i = 0
  for outputA in tqdm.tqdm_notebook(intermediate_outputs_A):
    j = 0
    for outputB in tqdm.tqdm_notebook(intermediate_outputs_B):
      print(outputA.shape, outputB.shape)
      cka_score = calculate_CKA_for_two_matrices(outputA, outputB)
      result_array[i, j] = cka_score
      j+=1
    i+= 1

  return result_array

# This routine returns the data loader need to train the network
def create_dataset_multifield(mode, seed, fmaps, fparams, batch_size, splits, fmaps_norm,
                              rot_flip_in_mem=True, shuffle=True, verbose=False):

    # whether rotations and flippings are kept in memory
    if rot_flip_in_mem:
        data_set = make_dataset_multifield(mode, seed, fmaps, fparams, splits, fmaps_norm, verbose)
    else:
        data_set = make_dataset_multifield2(mode, seed, fmaps, fparams, splits, fmaps_norm, verbose)

    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    g = torch.Generator()
    g.manual_seed(0)

    data_loader = DataLoader(dataset=data_set, batch_size=batch_size, shuffle=shuffle, worker_init_fn=seed_worker, generator=g)
    return data_loader


# This class creates the dataset. Rotations and flippings are precompued and stored in memory
class make_dataset_multifield():

    def __init__(self, mode, seed, fmaps, fparams, splits, fmaps_norm, verbose):

        # get the total number of sims and maps
        params_sims = np.loadtxt(fparams) #simulations parameters, NOT maps parameters
        total_sims, total_maps, num_params = \
                params_sims.shape[0], params_sims.shape[0]*splits, params_sims.shape[1]
        params_maps = np.zeros((total_maps, num_params), dtype=np.float32)
        print(total_sims, splits, params_sims.shape[0])
        for i in range(total_sims):
            for j in range(splits):
                params_maps[i*splits + j] = params_sims[i]

        # normalize the value of the cosmological & astrophysical parameters
        minimum     = np.array([0.1, 0.6, 0.25, 0.25, 0.5, 0.5])
        maximum     = np.array([0.5, 1.0, 4.00, 4.00, 2.0, 2.0])
        params_maps = (params_maps - minimum)/(maximum - minimum)

        # get the size and offset depending on the type of dataset
        if   mode=='train':  offset, size_sims = int(0.00*total_sims), int(0.90*total_sims)
        elif mode=='valid':  offset, size_sims = int(0.90*total_sims), int(0.05*total_sims)
        elif mode=='test':   offset, size_sims = int(0.95*total_sims), int(0.05*total_sims)
        elif mode=='all':    offset, size_sims = int(0.00*total_sims), int(1.00*total_sims)
        else:                raise Exception('Wrong name!')
        size_maps = size_sims*splits

        # randomly shuffle the simulations (not maps). Instead of 0 1 2 3...999 have a
        # random permutation. E.g. 5 9 0 29...342
        np.random.seed(seed)
        sim_numbers = np.arange(total_sims) #shuffle sims not maps
        np.random.shuffle(sim_numbers)
        sim_numbers = sim_numbers[offset:offset+size_sims] #select indexes of mode

        # get the corresponding indexes of the maps associated to the sims
        indexes = np.zeros(size_maps, dtype=np.int32)
        count = 0
        for i in sim_numbers:
            for j in range(splits):
                indexes[count] = i*splits + j
                count += 1

        print(len(indexes))
        print(indexes.shape)

        # keep only the value of the parameters of the considered maps
        params_maps = params_maps[indexes]

        # define the matrix containing the maps with rotations and flipings
        channels = len(fmaps)
        dumb     = np.load(fmaps[0])    #[number of maps, height, width]
        height, width = dumb.shape[1], dumb.shape[2];  del dumb
        data     = np.zeros((size_maps*8, channels, height, width), dtype=np.float32)
        params   = np.zeros((size_maps*8, num_params),              dtype=np.float32)

        # read the data
        print('Found %d channels\nReading data...'%channels)
        for channel, (fim, fnorm) in enumerate(zip(fmaps, fmaps_norm)):

            # read maps in the considered channel
            data_c = np.load(fim)
            if data_c.shape[0]!=total_maps:  raise Exception('sizes do not match')
            if verbose:  print('%.3e < F(all|orig) < %.3e'%(np.min(data_c), np.max(data_c)))

            # rescale maps
            if fim.find('Mstar')!=-1:  data_c = np.log10(data_c + 1.0)
            else:                      data_c = np.log10(data_c)
            if verbose:  print('%.3f < F(all|resc)  < %.3f'%(np.min(data_c), np.max(data_c)))

            # normalize maps
            if fnorm is None:
                mean,    std     = np.mean(data_c), np.std(data_c)
            else:
                # read data
                data_norm = np.load(fnorm)

                # rescale
                if fnorm.find('Mstar')!=-1:  data_norm = np.log10(data_norm + 1.0)
                else:                        data_norm = np.log10(data_norm)

                # compute mean and std
                mean,    std     = np.mean(data_norm), np.std(data_norm)
                del data_norm

            data_c = (data_c - mean)/std
            if verbose:  print('%.3f < F(all|norm) < %.3f'%(np.min(data_c), np.max(data_c)))

            # keep only the data of the chosen set
            data_c = data_c[indexes]

            # do a loop over all rotations (each is 90 deg)
            counted_maps = 0
            for rot in [0,1,2,3]:
                data_rot = np.rot90(data_c, k=rot, axes=(1,2))

                data[counted_maps:counted_maps+size_maps,channel,:,:] = data_rot
                params[counted_maps:counted_maps+size_maps]           = params_maps
                counted_maps += size_maps

                data[counted_maps:counted_maps+size_maps,channel,:,:] = \
                                                    np.flip(data_rot, axis=1)
                params[counted_maps:counted_maps+size_maps]           = params_maps
                counted_maps += size_maps

            if verbose:
                print('Channel %d contains %d maps'%(channel,counted_maps))
                print('%.3f < F < %.3f\n'%(np.min(data_c), np.max(data_c)))

        self.size = data.shape[0]
        self.x    = torch.tensor(data,   dtype=torch.float32)
        self.y    = torch.tensor(params, dtype=torch.float32)
        del data, data_c

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

# This class creates the dataset. It will read the maps and store them in memory.
# The rotations and flipings are done when calling the data
class make_dataset_multifield2():

    def __init__(self, mode, seed, fmaps, fparams, splits, fmaps_norm, verbose):

        # get the total number of simulations and maps
        params_sims = np.loadtxt(fparams) #simulations parameters, NOT maps parameters
        total_sims, total_maps, num_params = \
                params_sims.shape[0], params_sims.shape[0]*splits, params_sims.shape[1]
        params = np.zeros((total_maps, num_params), dtype=np.float32)
        for i in range(total_sims):
            for j in range(splits):
                params[i*splits + j] = params_sims[i]

        # normalize params
        minimum = np.array([0.1, 0.6, 0.25, 0.25, 0.5, 0.5])
        maximum = np.array([0.5, 1.0, 4.00, 4.00, 2.0, 2.0])
        params  = (params - minimum)/(maximum - minimum)

        # get the size and offset depending on the type of dataset
        if   mode=='train':  offset, size_sims = int(0.00*total_sims), int(0.90*total_sims)
        elif mode=='valid':  offset, size_sims = int(0.90*total_sims), int(0.05*total_sims)
        elif mode=='test':   offset, size_sims = int(0.95*total_sims), int(0.05*total_sims)
        elif mode=='all':    offset, size_sims = int(0.00*total_sims), int(1.00*total_sims)
        else:                raise Exception('Wrong name!')
        size_maps = size_sims*splits

        # randomly shuffle the simulations (not maps). Instead of 0 1 2 3...999 have a
        # random permutation. E.g. 5 9 0 29...342
        np.random.seed(seed)
        sim_numbers = np.arange(total_sims) #shuffle maps not rotations
        np.random.shuffle(sim_numbers)
        sim_numbers = sim_numbers[offset:offset+size_sims] #select indexes of mode

        # get the corresponding indexes of the maps associated to the sims
        indexes = np.zeros(size_maps, dtype=np.int32)
        count = 0
        for i in sim_numbers:
            for j in range(splits):
                indexes[count] = i*splits + j
                count += 1

        # keep only the value of the parameters of the considered maps
        params = params[indexes]

        # define the matrix containing the maps without rotations or flippings
        channels = len(fmaps)
        dumb     = np.load(fmaps[0])    #[number of maps, height, width]
        height, width = dumb.shape[1], dumb.shape[2];  del dumb
        data     = np.zeros((size_maps, channels, height, width), dtype=np.float32)

        # read the data
        print('Found %d channels\nReading data...'%channels)
        for channel, (fim, fnorm) in enumerate(zip(fmaps, fmaps_norm)):

            # read maps in the considered channel
            data_c = np.load(fim)
            print(data_c.shape, total_maps)
            if data_c.shape[0]!=total_maps:  raise Exception('sizes do not match')
            if verbose:
                print('%.3e < F(all|orig) < %.3e'%(np.min(data_c), np.max(data_c)))

            # rescale maps
            if fim.find('Mstar')!=-1:  data_c = np.log10(data_c + 1.0)
            else:                      data_c = np.log10(data_c)
            if verbose:
                print('%.3f < F(all|resc)  < %.3f'%(np.min(data_c), np.max(data_c)))

            # normalize maps
            if fnorm is None:
                mean,    std     = np.mean(data_c), np.std(data_c)
                minimum, maximum = np.min(data_c),  np.max(data_c)
            else:
                # read data
                data_norm     = np.load(fnorm)

                # rescale data
                if fnorm.find('Mstar')!=-1:  data_norm = np.log10(data_norm + 1.0)
                else:                        data_norm = np.log10(data_norm)

                # compute mean and std
                mean,    std     = np.mean(data_norm), np.std(data_norm)
                minimum, maximum = np.min(data_norm),  np.max(data_norm)
                del data_norm

            data_c = (data_c - mean)/std
            if verbose:  print('%.3f < F(all|norm) < %.3f'%(np.min(data_c), np.max(data_c)))

            # keep only the data of the chosen set
            data[:,channel,:,:] = data_c[indexes]

        self.size = data.shape[0]
        self.x    = torch.tensor(data,   dtype=torch.float32)
        self.y    = torch.tensor(params, dtype=torch.float32)
        del data, data_c

    def __len__(self):
        return self.size

    def __getitem__(self, idx):

        # choose a rotation angle (0-0, 1-90, 2-180, 3-270)
        # and whether do flipping or not
        rot  = np.random.randint(0,4)
        flip = np.random.randint(0,1)

        # rotate and flip the maps
        maps = torch.rot90(self.x[idx], k=rot, dims=[1,2])
        if flip==1:  maps = torch.flip(maps, dims=[1])

        return maps, self.y[idx]

class model_o3_err(nn.Module):
    def __init__(self, hidden, dr, channels):
        super(model_o3_err, self).__init__()

        # input: 1x256x256 ---------------> output: 2*hiddenx128x128
        self.C01 = nn.Conv2d(channels,  2*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C02 = nn.Conv2d(2*hidden,  2*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C03 = nn.Conv2d(2*hidden,  2*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B01 = nn.BatchNorm2d(2*hidden)
        self.B02 = nn.BatchNorm2d(2*hidden)
        self.B03 = nn.BatchNorm2d(2*hidden)

        # input: 2*hiddenx128x128 ----------> output: 4*hiddenx64x64
        self.C11 = nn.Conv2d(2*hidden, 4*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C12 = nn.Conv2d(4*hidden, 4*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C13 = nn.Conv2d(4*hidden, 4*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B11 = nn.BatchNorm2d(4*hidden)
        self.B12 = nn.BatchNorm2d(4*hidden)
        self.B13 = nn.BatchNorm2d(4*hidden)

        # input: 4*hiddenx64x64 --------> output: 8*hiddenx32x32
        self.C21 = nn.Conv2d(4*hidden, 8*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C22 = nn.Conv2d(8*hidden, 8*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C23 = nn.Conv2d(8*hidden, 8*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B21 = nn.BatchNorm2d(8*hidden)
        self.B22 = nn.BatchNorm2d(8*hidden)
        self.B23 = nn.BatchNorm2d(8*hidden)

        # input: 8*hiddenx32x32 ----------> output: 16*hiddenx16x16
        self.C31 = nn.Conv2d(8*hidden,  16*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C32 = nn.Conv2d(16*hidden, 16*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C33 = nn.Conv2d(16*hidden, 16*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B31 = nn.BatchNorm2d(16*hidden)
        self.B32 = nn.BatchNorm2d(16*hidden)
        self.B33 = nn.BatchNorm2d(16*hidden)

        # input: 16*hiddenx16x16 ----------> output: 32*hiddenx8x8
        self.C41 = nn.Conv2d(16*hidden, 32*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C42 = nn.Conv2d(32*hidden, 32*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C43 = nn.Conv2d(32*hidden, 32*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B41 = nn.BatchNorm2d(32*hidden)
        self.B42 = nn.BatchNorm2d(32*hidden)
        self.B43 = nn.BatchNorm2d(32*hidden)

        # input: 32*hiddenx8x8 ----------> output:64*hiddenx4x4
        self.C51 = nn.Conv2d(32*hidden, 64*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C52 = nn.Conv2d(64*hidden, 64*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C53 = nn.Conv2d(64*hidden, 64*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B51 = nn.BatchNorm2d(64*hidden)
        self.B52 = nn.BatchNorm2d(64*hidden)
        self.B53 = nn.BatchNorm2d(64*hidden)

        # input: 64*hiddenx4x4 ----------> output: 128*hiddenx1x1
        self.C61 = nn.Conv2d(64*hidden, 128*hidden, kernel_size=4, stride=1, padding=0,
                            padding_mode='circular', bias=True)
        self.B61 = nn.BatchNorm2d(128*hidden)

        # self.P0  = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

        self.FC1  = nn.Linear(128*hidden, 64*hidden)
        self.FC2  = nn.Linear(64*hidden,  12)

        self.dropout   = nn.Dropout(p=dr)
        # self.ReLU      = nn.ReLU()

        self.LeakyReLUs = []

        # We do this so that we can turn off some leaky relus according to our wish.
        self.LeakyReLU1 = nn.LeakyReLU(0.2)
        self.LeakyReLU2 = nn.LeakyReLU(0.2)
        self.LeakyReLU3 = nn.LeakyReLU(0.2)
        self.LeakyReLU4 = nn.LeakyReLU(0.2)
        self.LeakyReLU5 = nn.LeakyReLU(0.2)
        self.LeakyReLU6 = nn.LeakyReLU(0.2)
        self.LeakyReLU7 = nn.LeakyReLU(0.2)
        self.LeakyReLU8 = nn.LeakyReLU(0.2)
        self.LeakyReLU9 = nn.LeakyReLU(0.2)
        self.LeakyReLU10 = nn.LeakyReLU(0.2)
        self.LeakyReLU11 = nn.LeakyReLU(0.2)
        self.LeakyReLU12 = nn.LeakyReLU(0.2)
        self.LeakyReLU13 = nn.LeakyReLU(0.2)
        self.LeakyReLU14 = nn.LeakyReLU(0.2)
        self.LeakyReLU15 = nn.LeakyReLU(0.2)
        self.LeakyReLU16 = nn.LeakyReLU(0.2)
        self.LeakyReLU17 = nn.LeakyReLU(0.2)
        self.LeakyReLU18 = nn.LeakyReLU(0.2)
        self.LeakyReLU19 = nn.LeakyReLU(0.2)
        self.LeakyReLU20 = nn.LeakyReLU(0.2)

        # self.tanh      = nn.Tanh()

        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, image):
        x01 = self.C01(image)
        x01_ = self.LeakyReLU1(x01)
        x02 = self.C02(x01_)
        x02_ = self.LeakyReLU2(self.B02(x02))
        x03 = self.C03(x02_)
        x03_ = self.LeakyReLU3(self.B03(x03))

        x11 = self.C11(x03_)
        x11_ = self.LeakyReLU4(self.B11(x11))
        x12 = self.C12(x11_)
        x12_ = self.LeakyReLU5(self.B12(x12))
        x13 = self.C13(x12_)
        x13_ = self.LeakyReLU6(self.B13(x13))

        x21 = self.C21(x13_)
        x21_ = self.LeakyReLU7(self.B21(x21))
        x22 = self.C22(x21_)
        x22_ = self.LeakyReLU8(self.B22(x22))
        x23 = self.C23(x22_)
        x23_ = self.LeakyReLU9(self.B23(x23))

        x31 = self.C31(x23_)
        x31_ = self.LeakyReLU10(self.B31(x31))
        x32 = self.C32(x31_)
        x32_ = self.LeakyReLU11(self.B32(x32))
        x33 = self.C33(x32_)
        x33_ = self.LeakyReLU12(self.B33(x33))

        x41 = self.C41(x33_)
        x41_ = self.LeakyReLU13(self.B41(x41))
        x42 = self.C42(x41_)
        x42_ = self.LeakyReLU14(self.B42(x42))
        x43 = self.C43(x42_)
        x43_ = self.LeakyReLU15(self.B43(x43))

        x51 = self.C51(x43_)
        x51_ = self.LeakyReLU16(self.B51(x51))
        x52 = self.C52(x51_)
        x52_ = self.LeakyReLU17(self.B52(x52))
        x53 = self.C53(x52_)
        x53_ = self.LeakyReLU18(self.B53(x53))

        x61 = self.C61(x53_)
        x61_ = self.LeakyReLU19(self.B61(x61))

        x61__ = x61_.view(image.shape[0],-1)
        x61__ = self.dropout(x61__)

        x71 = self.FC1(x61__)
        x71_ = self.LeakyReLU20(x71)
        x71__ = self.dropout(x71_)
        x81 = self.FC2(x71__)

        # enforce the errors to be positive
        y = torch.clone(x81)
        y[:,6:12] = torch.square(x81[:,6:12])

        return y, [x01_, x02_, x03_, x11_, x12_, x13_, x21_, x22_, x23_, x31_, x32_, x33_, x41_, x42_, x43_, x51_, x52_, x53_, x61_, x71_, x81]


class model_o3_err_layer_output(nn.Module):
    def __init__(self, hidden, dr, channels, layer_index=0):
        super(model_o3_err_layer_output, self).__init__()

        self.layer_index = layer_index

        # input: 1x256x256 ---------------> output: 2*hiddenx128x128
        self.C01 = nn.Conv2d(channels,  2*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C02 = nn.Conv2d(2*hidden,  2*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C03 = nn.Conv2d(2*hidden,  2*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B01 = nn.BatchNorm2d(2*hidden)
        self.B02 = nn.BatchNorm2d(2*hidden)
        self.B03 = nn.BatchNorm2d(2*hidden)

        # input: 2*hiddenx128x128 ----------> output: 4*hiddenx64x64
        self.C11 = nn.Conv2d(2*hidden, 4*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C12 = nn.Conv2d(4*hidden, 4*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C13 = nn.Conv2d(4*hidden, 4*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B11 = nn.BatchNorm2d(4*hidden)
        self.B12 = nn.BatchNorm2d(4*hidden)
        self.B13 = nn.BatchNorm2d(4*hidden)

        # input: 4*hiddenx64x64 --------> output: 8*hiddenx32x32
        self.C21 = nn.Conv2d(4*hidden, 8*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C22 = nn.Conv2d(8*hidden, 8*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C23 = nn.Conv2d(8*hidden, 8*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B21 = nn.BatchNorm2d(8*hidden)
        self.B22 = nn.BatchNorm2d(8*hidden)
        self.B23 = nn.BatchNorm2d(8*hidden)

        # input: 8*hiddenx32x32 ----------> output: 16*hiddenx16x16
        self.C31 = nn.Conv2d(8*hidden,  16*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C32 = nn.Conv2d(16*hidden, 16*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C33 = nn.Conv2d(16*hidden, 16*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B31 = nn.BatchNorm2d(16*hidden)
        self.B32 = nn.BatchNorm2d(16*hidden)
        self.B33 = nn.BatchNorm2d(16*hidden)

        # input: 16*hiddenx16x16 ----------> output: 32*hiddenx8x8
        self.C41 = nn.Conv2d(16*hidden, 32*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C42 = nn.Conv2d(32*hidden, 32*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C43 = nn.Conv2d(32*hidden, 32*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B41 = nn.BatchNorm2d(32*hidden)
        self.B42 = nn.BatchNorm2d(32*hidden)
        self.B43 = nn.BatchNorm2d(32*hidden)

        # input: 32*hiddenx8x8 ----------> output:64*hiddenx4x4
        self.C51 = nn.Conv2d(32*hidden, 64*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C52 = nn.Conv2d(64*hidden, 64*hidden, kernel_size=3, stride=1, padding=1,
                            padding_mode='circular', bias=True)
        self.C53 = nn.Conv2d(64*hidden, 64*hidden, kernel_size=2, stride=2, padding=0,
                            padding_mode='circular', bias=True)
        self.B51 = nn.BatchNorm2d(64*hidden)
        self.B52 = nn.BatchNorm2d(64*hidden)
        self.B53 = nn.BatchNorm2d(64*hidden)

        # input: 64*hiddenx4x4 ----------> output: 128*hiddenx1x1
        self.C61 = nn.Conv2d(64*hidden, 128*hidden, kernel_size=4, stride=1, padding=0,
                            padding_mode='circular', bias=True)
        self.B61 = nn.BatchNorm2d(128*hidden)

        self.P0  = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

        self.FC1  = nn.Linear(128*hidden, 64*hidden)
        self.FC2  = nn.Linear(64*hidden,  12)

        self.dropout   = nn.Dropout(p=dr)
        self.ReLU      = nn.ReLU()
        self.LeakyReLU = nn.LeakyReLU(0.2)
        self.tanh      = nn.Tanh()

        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)


    def forward(self, image):

        x01 = self.C01(image)
        x01_ = self.LeakyReLU(x01)
        if self.layer_index == 0:
          return x01_
        x02 = self.C02(x01_)
        x02_ = self.LeakyReLU(self.B02(x02))
        if self.layer_index == 1:
          return x02_
        x03 = self.C03(x02_)
        x03_ = self.LeakyReLU(self.B03(x03))
        if self.layer_index == 2:
          return x03_

        x11 = self.C11(x03_)
        x11_ = self.LeakyReLU(self.B11(x11))
        if self.layer_index == 3:
          return x11_
        x12 = self.C12(x11_)
        x12_ = self.LeakyReLU(self.B12(x12))
        if self.layer_index == 4:
          return x12_
        x13 = self.C13(x12_)
        x13_ = self.LeakyReLU(self.B13(x13))
        if self.layer_index == 5:
          return x13_

        x21 = self.C21(x13_)
        x21_ = self.LeakyReLU(self.B21(x21))
        if self.layer_index == 6:
          return x21_
        x22 = self.C22(x21_)
        x22_ = self.LeakyReLU(self.B22(x22))
        if self.layer_index == 7:
          return x22_
        x23 = self.C23(x22_)
        x23_ = self.LeakyReLU(self.B23(x23))
        if self.layer_index == 8:
          return x23_

        x31 = self.C31(x23_)
        x31_ = self.LeakyReLU(self.B31(x31))
        if self.layer_index == 9:
          return x31_
        x32 = self.C32(x31_)
        x32_ = self.LeakyReLU(self.B32(x32))
        if self.layer_index == 10:
          return x32_
        x33 = self.C33(x32_)
        x33_ = self.LeakyReLU(self.B33(x33))
        if self.layer_index == 11:
          return x33_

        x41 = self.C41(x33_)
        x41_ = self.LeakyReLU(self.B41(x41))
        if self.layer_index == 12:
          return x41_
        x42 = self.C42(x41_)
        x42_ = self.LeakyReLU(self.B42(x42))
        if self.layer_index == 13:
          return x42_
        x43 = self.C43(x42_)
        x43_ = self.LeakyReLU(self.B43(x43))
        if self.layer_index == 14:
          return x43_

        x51 = self.C51(x43_)
        x51_ = self.LeakyReLU(self.B51(x51))
        if self.layer_index == 15:
          return x51_
        x52 = self.C52(x51_)
        x52_ = self.LeakyReLU(self.B52(x52))
        if self.layer_index == 16:
          return x52_
        x53 = self.C53(x52_)
        x53_ = self.LeakyReLU(self.B53(x53))
        if self.layer_index == 17:
          return x53_

        x61 = self.C61(x53_)
        x61_ = self.LeakyReLU(self.B61(x61))
        if self.layer_index == 18:
          return x61_

        x61__ = x61_.view(image.shape[0],-1)
        x61__ = self.dropout(x61__)

        x71 = self.FC1(x61__)
        x71_ = self.LeakyReLU(x71)
        if self.layer_index == 19:
          return x71_
        x71__ = self.dropout(x71_)
        x81 = self.FC2(x71__)

        # enforce the errors to be positive
        y = torch.clone(x81)
        y[:,6:12] = torch.square(x81[:,6:12])

        if self.layer_index == 20:
          return y


def load_two_models_for_similarity_measure(fweights_1, fweights_2, fdatabase1, study_name, trial_number_1, trial_number_2, fdatabase2=None, layer_output=False, layer_index=0):
  if fdatabase2 is None:
    print('Assuming same database for the two models...')

  if torch.cuda.is_available():
    print("GPUs available")
    device = torch.device('cuda')
  else:
    print('GPUs not available')
    device = torch.device('cpu')

  fweights = [fweights_1, fweights_2]
  fdatabases = [fdatabase1, fdatabase2]
  trial_numbers = [trial_number_1, trial_number_2]
  models = []

  for i in range(2):
    if fdatabase2 is None:
      study = optuna.load_study(study_name=study_name, storage=fdatabase1)
    else:
      study = optuna.load_study(study_name=study_name, storage=fdatabases[i])

    print(f"\nTrial number: {trial_numbers[i]}")
    trial_number = trial_numbers[i]
    trial = study.trials[trial_number]
    print("Trial number:  number {}".format(trial.number))
    print("Loss:          %.5e"%trial.value)
    print("Params: ")
    for key, value in trial.params.items():
      print("    {}: {}".format(key, value))

    if layer_output:
      model = model_o3_err_layer_output(trial.params['hidden'], trial.params['dr'], 1, layer_index=layer_index)
    else:
      model = model_o3_err(trial.params['hidden'], trial.params['dr'], 1)
    model = nn.DataParallel(model)
    model.to(device=device)
    network_total_params = sum(p.numel() for p in model.parameters())
    print('total number of parameters in the model = %d'%network_total_params)

    if os.path.exists(fweights[i]):
      model.load_state_dict(torch.load(fweights[i], map_location=torch.device(device)))
      print('Weights loaded')
      models.append(model)
    else:
      raise Exception('file with weights not found!!!')

  if len(models) != 2:
    print("Warning! Two models were not loaded...")
  return models

# def load_two_models_for_similarity_measure_modified(fweights_1, fweights_2, fdatabase1, study_name, trial_number_1, trial_number_2, fdatabase2=None):  # modified to also return linear layer.
#   if fdatabase2 is None:
#     print('Assuming same database for the two models...')

#   if torch.cuda.is_available():
#     print("GPUs available")
#     device = torch.device('cuda')
#   else:
#     print('GPUs not available')
#     device = torch.device('cpu')

#   fweights = [fweights_1, fweights_2]
#   fdatabases = [fdatabase1, fdatabase2]
#   trial_numbers = [trial_number_1, trial_number_2]
#   models = []
#   lr_models = []

#   for i in range(2):
#     if fdatabase2 is None:
#       study = optuna.load_study(study_name=study_name, storage=fdatabase1)
#     else:
#       study = optuna.load_study(study_name=study_name, storage=fdatabases[i])

#     print(f"\nTrial number: {trial_numbers[i]}")
#     trial_number = trial_numbers[i]
#     trial = study.trials[trial_number]
#     print("Trial number:  number {}".format(trial.number))
#     print("Loss:          %.5e"%trial.value)
#     print("Params: ")
#     for key, value in trial.params.items():
#       print("    {}: {}".format(key, value))

#     model = model_o3_err(trial.params['hidden'], trial.params['dr'], 1)
#     model = nn.DataParallel(model)
#     model.to(device=device)
#     network_total_params = sum(p.numel() for p in model.parameters())
#     print('total number of parameters in the model = %d'%network_total_params)

#     if os.path.exists(fweights[i]):
#       model.load_state_dict(torch.load(fweights[i], map_location=torch.device(device)))
#       print('Weights loaded')
#       models.append(model)
#     else:
#       raise Exception('file with weights not found!!!')

#     lr_models.append(
#         linear_reg()
#     )

#   if len(models) != 2:
#     print("Warning! Two models were not loaded...")
#   if len(lr_models) != 2:
#     print("Warning! Two linear layer models were not loaded...")
#   return models, lr_models

import seaborn as sns
sns.set_context("paper", font_scale = 2)
sns.set_style('whitegrid')
sns.set(style='ticks')
def plot_cka(sim):
  fig, ax = plt.subplots(1, 1, figsize=(5, 5))

  divider = make_axes_locatable(ax)
  cax = divider.append_axes('right', size='5%', pad=0.05)

  im = ax.imshow(sim, cmap='magma', vmin=0.0,vmax=1.0)
  layers = np.arange(1, 21, 2)
  ax.set_xticks(layers); ax.set_yticks(layers)
  ax.set_xticklabels(layers); ax.set_yticklabels(layers)
  ax.axes.invert_yaxis()
  fig.colorbar(im, cax=cax, orientation='vertical')
  return ax


def get_test_acc(model, test_loader):
  # get the number of maps in the test set
  num_maps = 0
  for x,y in test_loader:
        num_maps += x.shape[0]
  print('\nNumber of maps in the test set: %d'%num_maps)

  # define the arrays containing the value of the parameters
  params_true = np.zeros((num_maps,6), dtype=np.float32)
  params_NN   = np.zeros((num_maps,6), dtype=np.float32)
  errors_NN   = np.zeros((num_maps,6), dtype=np.float32)

  # get test loss
  test_loss1, test_loss2 = torch.zeros(len(g)).to(device), torch.zeros(len(g)).to(device)
  test_loss, points = 0.0, 0
  model.eval()
  for x, y in test_loader:
      with torch.no_grad():
          bs    = x.shape[0]    #batch size
          x     = x.to(device)  #send data to device
          y     = y.to(device)  #send data to device
          p     = model(x)[0]      #prediction for mean and variance  # EDIT: Take 0th element because we modify the architecture to return the features as well.
          y_NN  = p[:,:6]       #prediction for mean
          e_NN  = p[:,6:]       #prediction for error
          loss1 = torch.mean((y_NN[:,g] - y[:,g])**2,                     axis=0)
          loss2 = torch.mean(((y_NN[:,g] - y[:,g])**2 - e_NN[:,g]**2)**2, axis=0)
          test_loss1 += loss1*bs
          test_loss2 += loss2*bs

          # save results to their corresponding arrays
          params_true[points:points+x.shape[0]] = y.cpu().numpy()
          params_NN[points:points+x.shape[0]]   = y_NN.cpu().numpy()
          errors_NN[points:points+x.shape[0]]   = e_NN.cpu().numpy()
          points    += x.shape[0]
  test_loss = torch.log(test_loss1/points) + torch.log(test_loss2/points)
  test_loss = torch.mean(test_loss).item()
  print('Test loss = %.3e\n'%test_loss)
  return params_true, params_NN, errors_NN

def get_test_acc_modified(model_list, test_loader):  # for internal layer accuracy experiment
  lr_model = model_list[1]
  model = model_list[0]
  # get the number of maps in the test set
  num_maps = 0
  for x,y in test_loader:
        num_maps += x.shape[0]
  print('\nNumber of maps in the test set: %d'%num_maps)

  # define the arrays containing the value of the parameters
  params_true = np.zeros((num_maps,6), dtype=np.float32)
  params_NN   = np.zeros((num_maps,6), dtype=np.float32)
  errors_NN   = np.zeros((num_maps,6), dtype=np.float32)

  # get test loss
  test_loss1, test_loss2 = torch.zeros(len(g)).to(device), torch.zeros(len(g)).to(device)
  test_loss, points = 0.0, 0
  model.eval(); lr_model.eval()
  for x, y in test_loader:
      with torch.no_grad():
          bs    = x.shape[0]    #batch size
          x     = x.to(device)  #send data to device
          y     = y.to(device)  #send data to device
          pp   = model(x)[0]           #NN output
          p    = lr_model(pp)  # Logistic regression.
          y_NN  = p[:,:6]       #prediction for mean
          e_NN  = p[:,6:]       #prediction for error
          loss1 = torch.mean((y_NN[:,g] - y[:,g])**2,                     axis=0)
          loss2 = torch.mean(((y_NN[:,g] - y[:,g])**2 - e_NN[:,g]**2)**2, axis=0)
          test_loss1 += loss1*bs
          test_loss2 += loss2*bs

          # save results to their corresponding arrays
          params_true[points:points+x.shape[0]] = y.cpu().numpy()
          params_NN[points:points+x.shape[0]]   = y_NN.cpu().numpy()
          errors_NN[points:points+x.shape[0]]   = e_NN.cpu().numpy()
          points    += x.shape[0]
  test_loss = torch.log(test_loss1/points) + torch.log(test_loss2/points)
  test_loss = torch.mean(test_loss).item()
  print('Test loss = %.3e\n'%test_loss)
  return params_true, params_NN, errors_NN

from sklearn.metrics import r2_score, mean_squared_error

def get_r2_score(params_true, params_NN):
  r2_scores_params = []
  for i in range(params_true.shape[1]):
    r2_scores_params.append(
      r2_score(params_true[:, i], params_NN[:, i])
    )
  return r2_scores_params

def get_rmse(params_true, params_NN):
  rmse_scores_params = []
  for i in range(params_true.shape[1]):
    rmse_scores_params.append(
      mean_squared_error(params_true[:, i], params_NN[:, i])
    )
  return rmse_scores_params

class linear_reg(nn.Module):
  def __init__(self):
    super(linear_reg, self).__init__()
    self.linear_reg = nn.LazyLinear(12)

    # for m in self.modules():
    #   if isinstance(m, nn.LazyLinear):
    #     nn.init.kaiming_normal_(m.weight)

  def forward(self, x):
      x = self.linear_reg(x)
      # enforce the errors to be positive
      y = torch.clone(x)
      y[:,6:12] = torch.square(x[:,6:12])
      return y

# class model_o3_err_C01(nn.Module):
#   def __init__(self, hidden, dr, channels, state_dict):
#     super(model_o3_err_C01, self).__init__()
#       # input: 1x256x256 ---------------> output: 2*hiddenx128x128
#     self.C01 = nn.Conv2d(channels, 2*hidden, kernel_size=3, stride=1, padding=1,
#                             padding_mode='circular', bias=True)

#     with torch.no_grad():
#       self.C01.weight.copy_(state_dict['module.C01.weight'])
#       self.C01.bias.copy_(state_dict['module.C01.bias'])

#     self.C01.weight.requires_grad = False
#     self.C01.bias.requires_grad = False

#     self.LeakyReLU = nn.LeakyReLU(0.2)

#   def forward(self, image):
#     x01 = self.C01(image)
#     x01_ = self.LeakyReLU(x01)
#     return x01_

def perform_training(model_list, train_loader, valid_loader, epochs):
  model, lr_model = model_list
  optimizer = torch.optim.AdamW(lr_model.parameters(), lr=lr, weight_decay=wd, betas=(beta1, beta2))
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.3, patience=10)

  print('Computing initial validation loss')
  lr_model.eval(); model.eval()
  valid_loss1, valid_loss2 = torch.zeros(len(g)).to(device), torch.zeros(len(g)).to(device)
  min_valid_loss, points = 0.0, 0
  for x, y in valid_loader:
        with torch.no_grad():
            bs   = x.shape[0]                #batch size
            x    = x.to(device=device)       #maps
            y    = y.to(device=device)[:,g]  #parameters

            #with torch.no_grad():  # No weights must be updated. Only the linear regressor in the next line is trained.
            pp   = model(x)[0]           #NN output
            p    = lr_model(pp)  # Logistic regression.

            y_NN = p[:,g]                    #posterior mean
            e_NN = p[:,h]                    #posterior std
            loss1 = torch.mean((y_NN - y)**2,                axis=0)
            loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
            loss  = torch.mean(torch.log(loss1) + torch.log(loss2))
            valid_loss1 += loss1*bs
            valid_loss2 += loss2*bs
            points += bs
  min_valid_loss = torch.log(valid_loss1/points) + torch.log(valid_loss2/points)
  min_valid_loss = torch.mean(min_valid_loss).item()
  print('Initial valid loss = %.3e'%min_valid_loss)

  lr_model.train(); model.eval()
  # do a loop over all epochs
  start = time.time()
  for epoch in range(epochs):
      # do training
      train_loss1, train_loss2 = torch.zeros(len(g)).to(device), torch.zeros(len(g)).to(device)
      train_loss, points = 0.0, 0
      lr_model.train(); model.eval()
      for x, y in train_loader:
          bs   = x.shape[0]         #batch size
          x    = x.to(device)       #maps
          y    = y.to(device)[:,g]  #parameters

          #with torch.no_grad():  # No weights must be updated. Only the logistic regressor in the next line is trained.
          pp   = model(x)[0]           #NN output
          p    = lr_model(pp)  # Logistic regression.

          y_NN = p[:,g]             #posterior mean
          e_NN = p[:,h]             #posterior std
          loss1 = torch.mean((y_NN - y)**2,                axis=0)
          loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
          loss  = torch.mean(torch.log(loss1) + torch.log(loss2))
          train_loss1 += loss1*bs
          train_loss2 += loss2*bs
          points      += bs
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

          #if points>18000:  break
      train_loss = torch.log(train_loss1/points) + torch.log(train_loss2/points)
      train_loss = torch.mean(train_loss).item()

      # do validation: cosmo alone & all params
      valid_loss1, valid_loss2 = torch.zeros(len(g)).to(device), torch.zeros(len(g)).to(device)
      valid_loss, points = 0.0, 0
      lr_model.eval(); model.eval()
      for x, y in valid_loader:
          with torch.no_grad():
              bs    = x.shape[0]         #batch size
              x     = x.to(device)       #maps
              y     = y.to(device)[:,g]  #parameters

              pp    = model(x)[0]           #NN output
              p     = lr_model(pp)  # Logistic regression.

              y_NN  = p[:,g]             #posterior mean
              e_NN  = p[:,h]             #posterior std
              loss1 = torch.mean((y_NN - y)**2,                axis=0)
              loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
              loss  = torch.mean(torch.log(loss1) + torch.log(loss2))
              valid_loss1 += loss1*bs
              valid_loss2 += loss2*bs
              points     += bs
      valid_loss = torch.log(valid_loss1/points) + torch.log(valid_loss2/points)
      valid_loss = torch.mean(valid_loss).item()

      scheduler.step(valid_loss)

      # verbose
      print('%03d %.3e %.3e '%(epoch, train_loss, valid_loss), end='')

      # save model if it is better
      if valid_loss<min_valid_loss:
          torch.save(lr_model.state_dict(), fmodel)
          min_valid_loss = valid_loss
          print('(C) ', end='')
      print('')

      # save losses to file
      f = open(floss, 'a')
      f.write('%d %.5e %.5e\n'%(epoch, train_loss, valid_loss))
      f.close()

  stop = time.time()
  print('Time take (h):', "{:.4f}".format((stop-start)/3600.0))

  return lr_model

Collecting optuna==2.4.0
  Downloading optuna-2.4.0-py3-none-any.whl (282 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m282.7/282.7 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting alembic (from optuna==2.4.0)
  Downloading alembic-1.11.2-py3-none-any.whl (225 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m225.3/225.3 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cliff (from optuna==2.4.0)
  Downloading cliff-4.3.0-py3-none-any.whl (80 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m80.6/80.6 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cmaes>=0.6.0 (from optuna==2.4.0)
  Downloading cmaes-0.10.0-py3-none-any.whl (29 kB)
Collecting colorlog (from optuna==2.4.0)
  Downloading colorlog-6.7.0-py2.py3-none-any.whl (11 kB)
Collecting Mako (from alembic->optuna==2.4.0)
  Downloading Mako-1.2.4-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78

# 2. Analysis start

In [2]:
base_dir = '/content/drive/MyDrive/CAMELS/After_ICML_results'
field = 'T'
data_sim = 'IllustrisTNG'
model_sim = 'IllustrisTNG'
study_name = 'wd_dr_hidden_lr_o3'

# data parameters
fmaps      = ['/content/maps_T.npy'] #tuple containing the maps with the different fields to consider
fmaps_norm = [None if data_sim == model_sim else os.path.join(base_dir, f'Maps_{field}_Nbody_{model_sim}_LH_z=0.00.npy' if field == 'Mtot' else f'Maps_{field}_{model_sim}_LH_z=0.00.npy')] #if you want to normalize the maps according to the properties of some data set, put that data set here (This is mostly used when training on IllustrisTNG and testing on SIMBA, or vicerversa)
fparams    = os.path.join(base_dir, f'params_LH_Nbody_{data_sim}.txt' if field == 'Mtot' else f'params_LH_{data_sim}.txt')  # Note: f'params_LH_Nbody_{data_sim}.txt' ONLY for Nbody, else f'params_LH_{data_sim}.txt'
seed       = 1   #random seed to split maps among training, validation and testing
splits     = 1   #number of maps per simulation

channels        = 1                #we only consider here 1 field
params          = [0,1,2,3,4,5]    #0(Omega_m) 1(sigma_8) 2(A_SN1) 3 (A_AGN1) 4(A_SN2) 5(A_AGN2). The code will be trained to predict all these parameters.
g               = params           #g will contain the mean of the posterior
h               = [6+i for i in g] #h will contain the variance of the posterior
rot_flip_in_mem = False            #whether rotations and flipings are kept in memory. True will make the code faster but consumes more RAM memory.

fmaps2 = os.path.join(base_dir, f'Maps_{field}_Nbody_{data_sim}_LH_z=0.00.npy' if field == 'Mtot' else f'Maps_{field}_{data_sim}_LH_z=0.00.npy')
maps  = np.load(fmaps2)
print('Shape of the maps:',maps.shape)

# define the array that will contain the indexes of the maps
indexes = np.zeros(1000*splits, dtype=np.int32)

# do a loop over all maps and choose the ones we want
count = 0
for i in range(15000):
    if i%15 in np.arange(splits):
      indexes[count] = i
      count += 1
print('Selected %d maps out of 15000'%count)

# save these maps to a new file
maps = maps[indexes]
np.save('maps_T.npy', maps)
print(f'Shape of selected maps array: {maps.shape}')
del maps

# IMPORTANT: shuffle must be False to ensure iteration over train loader is stable.
# We need the iterations to be stable since after saving, the outputs will be used for training.
batch_size = 1
train_loader = create_dataset_multifield('train', seed, fmaps, fparams, batch_size, splits, fmaps_norm,
                                         rot_flip_in_mem=rot_flip_in_mem, verbose=True, shuffle=False)
test_loader  = create_dataset_multifield('test', seed, fmaps, fparams, batch_size, splits, fmaps_norm,
                                        rot_flip_in_mem=False, verbose=True, shuffle=False)

true_params = []
for i, (x, y) in enumerate(train_loader):
    with torch.no_grad():
      bs   = x.shape[0]         #batch size
      x    = x.to(device)       #maps
      y    = y.to(device)[:, g]  #parameters
      true_params.append(y)

Shape of the maps: (15000, 256, 256)
Selected 1000 maps out of 15000
Shape of selected maps array: (1000, 256, 256)
Found 1 channels
Reading data...
(1000, 256, 256) 1000
1.547e+03 < F(all|orig) < 8.150e+07
3.189 < F(all|resc)  < 7.911
-1.265 < F(all|norm) < 4.516
Found 1 channels
Reading data...
(1000, 256, 256) 1000
1.547e+03 < F(all|orig) < 8.150e+07
3.189 < F(all|resc)  < 7.911
-1.265 < F(all|norm) < 4.516


todo: we need to use chunking fit only for layer_index=0,1,2....so modify code.
todo: test R2 on test set not train set.

In [3]:
def chunker(seq, size):
  return (seq[pos:pos + size] for pos in range(0, len(seq), size))

def std_dev(array, mean, n):  # Use this instead of np.std since we cannot load the entire 900 features from earlier layers into memory at once.
  return np.sqrt((1/n) * np.sum((array-mean)**2))

param_index = 0  # First parameter only
layer_index = 18
BATCH_SIZE = 10  # For training regressor below
trial_num = 10  # or 28 for ICML quick experiment.

import glob
if layer_index <= 2:  # memory issues if entire array is loaded in memory. So read in 900 files, one for each example in the training set.
  files = f'/content/drive/MyDrive/CAMELS/After_ICML_results/layer_outputs_for_trials_used_for_ICML/{field}_{model_sim}_{data_sim}_trial{trial_num}_all900outputs_layerIndex{layer_index}batchnum*.pt'  # note: this path should ideally also include trial number info...
  assert len(files) == 900
  sorted_files = sorted(glob.glob(files), key = lambda x: int(x.split('/')[8].split('_')[6][8:].split('.')[0]))
  def yield_data_splits():
    for f, par in zip(
        chunker(sorted_files, BATCH_SIZE),
        chunker(true_params, BATCH_SIZE)
    ):
      t = torch.vstack([torch.load(ff) for ff in f])
      params = [p[0][param_index] for p in par]
      yield t, params, t.mean()
elif layer_index >= 3:
  files = f'/content/drive/MyDrive/CAMELS/After_ICML_results/layer_outputs_for_trials_used_for_ICML/{field}_{model_sim}_{data_sim}_trial{trial_num}_all900outputs_layerIndex{layer_index}.pt'  # note: this path should ideally also include trial number info...
  assert len(glob.glob(files)) == 1
  outputs = torch.load(files)
  assert outputs.shape[0] == 900
  def yield_data_splits():
    for t, par in zip(
        chunker(outputs, BATCH_SIZE),
        chunker(true_params, BATCH_SIZE)
    ):
      params = [p[0][param_index] for p in par]
      yield t, params, t.mean()

if layer_index <= 2:
  _t = next(yield_data_splits())
  t_shape = torch.numel(_t[0]) * (len(sorted_files) / BATCH_SIZE)
  mean_all = np.sum([ds[2]*torch.numel(ds[0]) for ds in yield_data_splits()]) / t_shape
elif layer_index >= 3:
  mean_all = outputs.mean()

if layer_index <= 2:
  def yield_meand_sums():
    for f, par in zip(
        chunker(sorted_files, BATCH_SIZE),
        chunker(true_params, BATCH_SIZE)
    ):
      t = torch.vstack([torch.load(ff) for ff in f])
      params = [p[0][param_index] for p in par]

      yield torch.sum((t - mean_all)**2)

  len_meand_sums = sum(1 for _ in yield_meand_sums())
  meand_sums = np.fromiter((ms for ms in yield_meand_sums()), dtype=float)

  assert len_meand_sums == (len(sorted_files) / BATCH_SIZE)
elif layer_index >= 3:
  def yield_meand_sums():
    for t, par in zip(
        chunker(outputs, BATCH_SIZE),
        chunker(true_params, BATCH_SIZE)
    ):
      params = [p[0][param_index] for p in par]

      yield torch.sum((t - mean_all)**2)

if layer_index <= 2:
  stddev_all = np.sqrt((1/t_shape) * np.sum(meand_sums))  # For formula, see https://en.wikipedia.org/wiki/Standard_deviation#Relationship_between_standard_deviation_and_mean
else:
  stddev_all = outputs.std()

if layer_index >= 3:
  _t = next(yield_data_splits())
  t_shape = torch.numel(_t[0]) * (900 / BATCH_SIZE)
  mean_all_check = np.sum([ds[2]*torch.numel(ds[0]) for ds in yield_data_splits()]) / t_shape
  assert np.allclose(
      mean_all_check,
      outputs.mean()
  )
  len_meand_sums = sum(1 for _ in yield_meand_sums())
  meand_sums = np.fromiter((ms for ms in yield_meand_sums()), dtype=float)
  std_all_check = np.sqrt((1/t_shape) * np.sum(meand_sums))
  assert np.allclose(
      std_all_check,
      outputs.std()
  )
print(outputs.min(), outputs.max(), mean_all, stddev_all)

tensor(-0.0859) tensor(0.2103) tensor(0.0024) tensor(0.0265)


todo: check this code works for small test cases...

In [4]:
stddev_all, mean_all

(tensor(0.0265), tensor(0.0024))

In [5]:
from sklearn.linear_model import Ridge, LinearRegression, SGDRegressor
from sklearn.preprocessing import StandardScaler

reg = SGDRegressor()
for _ in range(10): # 10 passes through the data
  for X, y, _ in yield_data_splits():
    X = (X - mean_all) / stddev_all
    X = X.reshape(X.shape[0], -1)
    reg.partial_fit(X, y)

In [6]:
y_preds = []
y_true = []
for X, y, _ in yield_data_splits():
  X = (X - mean_all) / stddev_all
  X = X.reshape(X.shape[0], -1)
  y_pred = reg.predict(X)
  y_preds.append(y_pred)
  y_true.append(y)

r2s = r2_score(np.hstack(y_true), np.hstack(y_preds))
print(f'R2 score with layer_index {layer_index}, trial {trial_num} = {r2s}')

R2 score with layer_index 18, trial 10 = -5.3750752205003535e+20


In [11]:
import gc
def get_outputs(model, train_loader, prefix=None, layer_index=0):  # index specifies output of which layer we want to extract.
  model.eval()
  for i, (x, y) in enumerate(train_loader):
    with torch.no_grad():
      bs   = x.shape[0]         #batch size
      x    = x.to(device)       #maps
      y    = y.to(device)[:, g]  #parameters

      out  = model(x)           #NN output
      if layer_index == 0 or layer_index == 1 or layer_index == 2:  # memory issues, so we save each output.
        # torch.save(out, f'{field}_{model_sim}_{data_sim}_trial{trial_number_1}_batchsize{batch_size}_layerIndex{layer_index}_batchnum{i}.pt')
        del out
        gc.collect()
      else:
        yield out

def yield_data_splits_for_test():
  for t, par in zip(
      chunker(outputs, BATCH_SIZE),
      chunker(true_params, BATCH_SIZE)
  ):
    params = [p[0][param_index] for p in par]
    yield t, params

true_params = []
for i, (_, y) in enumerate(test_loader):
    with torch.no_grad():
      bs   = x.shape[0]         #batch size
      # x    = x.to(device)       #maps
      y    = y.to(device)[:,g]  #parameters
      true_params.append(y)

# trials_1 = np.load(os.path.join(base_dir, 'trials_1.npy'))
# trials_2 = np.load(os.path.join(base_dir, 'trials_2.npy'))

# assert len(np.intersect1d(trials_1, trials_2)) == 0

fweights_1 = os.path.join(base_dir, f'weights_{model_sim}_{field}_{trial_num}_all_steps_500_500_o3.pt')
fdatabase1 = 'sqlite:////' + os.path.join(base_dir, f'{model_sim}_o3_{field}_all_steps_500_500_o3.db')

study = optuna.load_study(study_name=study_name, storage=fdatabase1)

trial_number = int(fweights_1.split('/')[6].split('_')[3])

model = load_model_for_similarity_measure(fweights_1, fdatabase1, study_name, trial_number, layer_output=True, layer_index=layer_index)

outputs = get_outputs(model, test_loader, layer_index=layer_index)
outputs = torch.vstack([arr for arr in outputs])

y_preds = []
y_true = []
for X, y in yield_data_splits_for_test():
  X = X.reshape(X.shape[0], -1)
  X = (X - mean_all) / stddev_all
  y_pred = reg.predict(X)
  y_preds.append(y_pred)
  y_true.append(y)

r2s = r2_score(np.hstack(y_true), np.hstack(y_preds))
print(f'R2 score with layer_index {layer_index}, trial {trial_num} = {r2s}')

GPUs not available

Trial number: 10
Trial number:  number 10
Loss:          -1.22274e+01
Params: 
    dr: 0.6695702848487869
    hidden: 6
    lr: 0.0011543514200712233
    wd: 3.859932400788152e-07
total number of parameters in the model = 8466876
Weights loaded
R2 score with layer_index 10, trial 10 = -1.5864497833485992e+28


todo: training happens for one trial but I am testing on two trials -- fix.

In [None]:
r2_score(np.hstack(y_true), np.hstack(y_preds))  # # with layer_index=0

0.21904098479058076

In [None]:
r2_score(np.hstack(y_true), np.hstack(y_preds))  # # with layer_index=20

0.8809224141966819

In [None]:
r2_score(np.hstack(y_true), np.hstack(y_preds))  # # with layer_index=10

0.38652268214278784

In [None]:
r2_score(np.hstack(y_true), np.hstack(y_preds))  # # with layer_index=16

0.9613126182628591

In [None]:
r2_score(np.hstack(y_true), np.hstack(y_preds))  # # with layer_index=18

0.9753024655199403

In [None]:
y_preds  # with layer_index=0

[array([-1.74921162e+14, -1.22941702e+15,  1.15113647e+15, -3.54342462e+14,
         1.26445390e+14,  3.02976815e+14, -3.28154966e+14,  8.32266892e+14,
         1.36467119e+15,  6.87493922e+14]),
 array([ 4.52157570e+14, -9.15411201e+14, -1.55280093e+14, -7.36774970e+14,
        -1.87113722e+14, -2.56441961e+14,  1.26406706e+14, -2.12743179e+14,
         1.13585875e+15,  1.30814196e+15]),
 array([-1.08765192e+15,  6.28530446e+14, -4.79874470e+14, -1.24491919e+15,
         1.82060987e+14, -3.51950213e+14,  3.46235231e+14,  2.63217652e+14,
        -2.65794876e+14,  9.40113987e+14]),
 array([-1.18786307e+15, -4.65894348e+14, -1.68375060e+14,  1.01303510e+15,
         4.28621519e+13,  3.89659114e+14,  6.75878222e+14,  2.92869915e+14,
         4.52482279e+14, -4.71503418e+14]),
 array([-8.09659679e+14, -2.73468372e+14,  5.92967614e+14, -5.51126195e+14,
         3.96729200e+14,  1.20832523e+15,  5.85339633e+14, -1.09643273e+15,
         2.54221537e+14,  8.68517490e+14]),
 array([ 3.45783870e

In [None]:
y_preds  # with layer_index=10

[array([ 1.01933446e+14,  8.05261820e+13,  1.84602267e+14,  1.05339799e+14,
         4.43008251e+13,  7.01698047e+13,  1.09704668e+14, -3.20906777e+13,
         2.05806608e+13, -1.88376388e+13]),
 array([8.56284749e+13, 1.03866768e+14, 5.86992902e+13, 1.19602031e+14,
        4.33631128e+13, 7.04106850e+13, 9.68914223e+13, 6.07550786e+13,
        4.64704957e+13, 2.06202365e+13]),
 array([ 8.10123404e+13, -4.36167139e+13,  7.28009635e+13,  1.37008496e+14,
        -3.97418926e+13, -3.76108309e+13,  6.91322120e+13,  1.86807056e+13,
         9.84340932e+13,  1.13996722e+13]),
 array([ 7.10702042e+13,  2.73719581e+13,  3.28367932e+13,  9.32909089e+13,
         4.87876590e+12,  8.08892370e+13, -4.87724506e+12,  3.46425689e+13,
         5.50322772e+13,  4.51992930e+13]),
 array([ 1.71248540e+13,  7.24101279e+13,  5.77935562e+13,  4.40714346e+13,
        -2.38598329e+13, -6.08996426e+12,  2.56322698e+13,  1.47027046e+14,
         8.83174080e+13, -4.05753798e+12]),
 array([ 4.42034166e+13,  1.28

In [None]:
y_preds  # with layer_index=16

[array([-4.45118357e+11,  9.53812006e+11,  1.14939835e+13, -2.93997318e+12,
        -1.47145224e+13, -2.38411666e+12, -1.14178486e+13,  5.05004536e+12,
         2.26684350e+12,  1.19420816e+13]),
 array([ 1.62113400e+13, -5.25175842e+12, -1.31925641e+13, -9.29885629e+12,
        -1.70820983e+13, -1.71609928e+13,  5.67826318e+12, -8.98114894e+11,
         2.39841036e+13, -8.73528483e+11]),
 array([-1.55433208e+13,  1.50842727e+13,  1.59793258e+13,  1.96340060e+13,
         3.51092186e+12, -2.57553615e+13, -6.75088408e+12, -2.89751580e+12,
        -1.77541530e+11, -2.78513684e+13]),
 array([-1.33739206e+13, -6.58147150e+12, -1.75108409e+13,  8.88851986e+12,
        -4.30321150e+12,  1.26082120e+13, -1.26120838e+13,  2.99891841e+13,
        -1.54913559e+13, -7.10903588e+12]),
 array([ 2.31293479e+13, -2.05318409e+13, -3.33764721e+12, -2.56637258e+13,
        -1.29731125e+13,  2.03272038e+12, -3.92356638e+12, -2.28697099e+13,
        -3.87492551e+13, -8.11679912e+12]),
 array([-8.41310168e

In [None]:
y_preds  # with layer_index=18

[array([0.96408858, 0.88510712, 0.74801715, 0.41726325, 0.14450293,
        0.94204061, 0.68642963, 0.23873456, 0.18076492, 0.89925655]),
 array([0.61956511, 0.82914544, 0.39426974, 0.74801086, 0.1350881 ,
        0.6596291 , 0.53251797, 0.48623539, 0.9467238 , 0.32439335]),
 array([0.85825694, 0.62147529, 0.76711667, 0.95234358, 0.45985393,
        0.63305692, 0.88612779, 0.49248477, 0.44888753, 0.21766493]),
 array([0.22425999, 0.30398587, 0.23830638, 0.382821  , 0.94705546,
        0.56443306, 0.62120502, 0.54131019, 0.07111339, 0.45079669]),
 array([0.73627888, 0.11213173, 0.32320059, 0.09850259, 0.17090231,
        0.05302753, 0.66926856, 0.92054819, 0.04092115, 0.0649834 ]),
 array([0.13063946, 0.95871142, 0.73914542, 0.29915338, 0.08589331,
        0.72396739, 0.67769806, 0.10211055, 0.11268701, 0.76155581]),
 array([0.87900742, 0.32301346, 0.88566844, 0.46056626, 0.95554547,
        0.18102891, 0.64117329, 0.43251788, 0.56286015, 0.09226081]),
 array([0.94012606, 0.67724452, 0.