### Tips for training a ViT:

- https://wandb.ai/dtamkus/posts/reports/5-Tips-for-Creating-Lightweight-Vision-Transformers--Vmlldzo0MjQyMzg0
- Improving performance of ViTs on small datasets: https://keras.io/examples/vision/vit_small_ds/
- What if we use a ViT pretrained on ImageNet, for example? See [here](https://github.com/hananshafi/vits-for-small-scale-datasets):
> ...in contrast to convolutional neural networks, Vision Transformer lacks inherent inductive biases. Therefore, successful training of such models is mainly attributed to pre-training on large-scale datasets such as ImageNet with 1.2M or JFT with 300M images.
- The self-attention mechanism allows the model to learn relationships between different patches in the input sequence. From [here](https://arxiv.org/pdf/2304.08192.pdf).
- ViTs are permutation-equivariant (absent positional embeddings), are not translation-invariant, and have no built-in feature hierarchy. See [the tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html).

In [None]:
# Install Pylians (third-party package; presumably a dependency of the ViT-LSS helper scripts — confirm).
!pip install Pylians

In [None]:
# Install torch_intermediate_layer_getter; presumably used by evaluation_analysis to extract the
# activations named in `return_layers` (see the post_test_analysis cell) — confirm.
!pip install torch_intermediate_layer_getter

In [None]:
USE_COLAB = False
# Working directory of the current platform: Colab's workspace vs. Kaggle's writable output dir.
base_dir = '/content' if USE_COLAB else '/kaggle/working'

In [None]:
if USE_COLAB:
    !rm -rf /content/ViT-LSS
else:
    !rm -rf /kaggle/working/ViT-LSS
!git clone https://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@github.com/Yash-10/ViT-LSS.git

In [None]:
if USE_COLAB:
    !cp /content/ViT-LSS/scripts/*.py /content
else:
    !cp /kaggle/working/ViT-LSS/scripts/*.py /kaggle/working

In [None]:
# Pin PyTorch Lightning so the Trainer / LightningModule API used below stays stable.
!pip install pytorch-lightning==2.1.3

In [None]:
## Standard libraries
import os
import numpy as np
import random
import math
import json
from functools import partial
from PIL import Image

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## tqdm for loading bars
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Import tensorboard
%load_ext tensorboard

# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "./"

# Setting the seed
SEED = 42
pl.seed_everything(SEED)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

In [None]:
import numpy as np
import gzip
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
if USE_COLAB:
    from google.colab import drive
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import time, sys, os
import matplotlib.pyplot as plt

# Experiment configuration shared by the data-preparation, training and evaluation cells below.

# optimizer parameters (AdamW betas, see ViT.configure_optimizers)
beta1 = 0.5
beta2 = 0.999

# TODO: Need to do hyperparameter optimization.
batch_size = 32
lr         = 1e-3
wd         = 1e-5  #value of weight decay
dr         = 0.2    #dropout probability, passed to the ViT through model_kwargs['dropout']
epochs     = 25    #number of epochs to train the network

channels        = 1                #we only consider here 1 field
params          = [0,1,2,3,4]    #Omega_m, Omega_b, h, n_s, sigma_8. The code will be trained to predict all these parameters.
g               = params           #output columns holding the posterior MEAN of each parameter (same list object as `params`)
h               = [5+i for i in g] #output columns holding the posterior STD of each parameter (trained as a std by loss2 in ViT._calculate_loss); unrelated to the Hubble parameter "h" above

model_kwargs = {
    'embed_dim': 256,
    'hidden_dim': 512,
    'num_heads': 8,
    'num_layers': 6,
    'patch_size': 16,
    'num_channels': 1,
    'num_patches': 16,    # (GRID_SIZE // patch_size)**2 = (64 // 16)**2
    'num_classes': 10,    # one mean (columns g) + one std (columns h) per entry of `params`
    'dropout': dr
}

GRID_SIZE = 64   # grid resolution passed to create_data.py; the ViT sees GRID_SIZE x GRID_SIZE maps

# output files names
floss  = 'loss.txt'   #file with the training and validation losses for each epoch
fmodel = 'weights.pt' #file containing the weights of the best-model

# Both passed to create_data.py (and to post_test_analysis).
num_maps_per_projection_direction = 10
num_sims = 1000

In [None]:
import seaborn as sns
# Final look: 'ticks' style with the 'paper' context at 2x font scale. The style and context
# settings touch disjoint rcParams, so one combined call gives the same result as the
# original sequence of calls.
sns.set(style='ticks', context='paper', font_scale=2)

## Visualization of fields for different parameters

In [None]:
from utils import read_hdf5
import glob

# Record Omega_m (index 0) and sigma_8 (index 4) of every simulation so that the simulations
# with the extreme values of each can be picked out.
# Per-file parameters are stored in `cosmo_params`. Reusing the name `params` here would overwrite
# the notebook-level index list `params = [0,1,2,3,4]`, which later cells (e.g. the
# post_test_analysis call) still rely on.
o, s = [], []
sorted_files = sorted(glob.glob('/kaggle/input/density-fields-vit-lss-64/my_outputs/*.h5'))
for f in sorted_files:
    _, cosmo_params = read_hdf5(f)
    o.append(cosmo_params[0])
    s.append(cosmo_params[4])  # 4 or -1 will give the same.

den_max_om, den_max_om_params = read_hdf5(sorted_files[o.index(max(o))])
den_min_om, den_min_om_params = read_hdf5(sorted_files[o.index(min(o))])
den_max_s8, den_max_s8_params = read_hdf5(sorted_files[s.index(max(s))])
den_min_s8, den_min_s8_params = read_hdf5(sorted_files[s.index(min(s))])

# Show the same fixed slice (index 10 along axis 1) of all four fields on a shared log10 color scale.
SLICE_IDX = 10
extreme_cases = [
    (den_max_om, den_max_om_params),
    (den_min_om, den_min_om_params),
    (den_max_s8, den_max_s8_params),
    (den_min_s8, den_min_s8_params),
]
slices = [den[:, SLICE_IDX, :] for den, _ in extreme_cases]
# log10 is monotonic, so the log of the raw min/max gives the shared color limits.
vmin = np.log10(min(sl.min() for sl in slices))
vmax = np.log10(max(sl.max() for sl in slices))
for sl, (_, title) in zip(slices, extreme_cases):
    plt.imshow(np.log10(sl), vmin=vmin, vmax=vmax); plt.title(title); plt.colorbar(); plt.show()

# Pretraining

In [None]:
if USE_COLAB:
    !wget https://www.dropbox.com/scl/fi/jqyvpxl17hp7pinqtd68c/density_fields_3D_LH_z0_grid64_masCIC.tar.gz?rlkey=cvf3oxbd922xxrzv8tue0zoiv&dl=0

In [None]:
if USE_COLAB:
    !tar -xzf /content/density_fields_3D_LH_z0_grid64_masCIC.tar.gz?rlkey=cvf3oxbd922xxrzv8tue0zoiv
else:
    pass

In [None]:
import subprocess
import sys

prefix=''
# Location of the raw simulation outputs on Colab vs. Kaggle; every other argument is shared.
data_path = '/content/my_outputs' if USE_COLAB else '/kaggle/input/density-fields-vit-lss-64/my_outputs'
command = [
    # sys.executable runs the script with the kernel's own interpreter/environment.
    sys.executable, 'create_data.py', '--num_sims', f'{num_sims}', '--train_frac', '0.8', '--test_frac', '0.1',
    '--seed', f'{SEED}', '--path', data_path, '--grid_size', f'{GRID_SIZE}',
    '--num_maps_per_projection_direction', f'{num_maps_per_projection_direction}', '--prefix', prefix,
]
# check=True raises CalledProcessError as soon as create_data.py fails. Without it, the failure
# would only show up later as confusing errors about missing '*_dataset_*.npy' files.
result = subprocess.run(command, check=True)
result

In [None]:
# Dataset statistics written by create_data.py as '<prefix>_dataset_<stat>.npy'.
def _load_dataset_stat(stat_name):
    """Load one saved dataset statistic, e.g. 'mean' -> f'{prefix}_dataset_mean.npy'."""
    return np.load(f'{prefix}_dataset_{stat_name}.npy')

MEAN = _load_dataset_stat('mean')
STD = _load_dataset_stat('std')
MIN_VALS = _load_dataset_stat('min_vals')
MAX_VALS = _load_dataset_stat('max_vals')
MEAN_DENSITIES = _load_dataset_stat('mean_densities')

In [None]:
# Sanity check: number of entries in each split directory (presumably written by create_data.py).
!ls train | wc -l
!ls val | wc -l
!ls test | wc -l

In [None]:
from model_dataset import CustomImageDataset
from torchvision.transforms import v2

import torchvision.transforms.functional as TF

import random
# See https://pytorch.org/vision/0.15/transforms.html
class MyRotationTransform:
    """Rotate one field by an angle (in degrees) drawn uniformly from `angles`.

    Used as a training-time augmentation with 90/180/270-degree rotations.
    """

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        chosen_angle = random.choice(self.angles)
        # assumes x is a 2-D (H, W) numpy array — TODO confirm; a leading axis is added for TF.rotate
        field = torch.from_numpy(x)[None]
        rotated = TF.rotate(field, chosen_angle)
        # squeeze() drops every size-1 axis, not only the one added above.
        return rotated.squeeze()

# Training-time augmentation: random 90/180/270-degree rotation, then cast to float32.
transform = v2.Compose([
    MyRotationTransform(angles=[90, 180, 270]),
    v2.ToDtype(torch.float32),
])


def _make_split(split, split_transform, shuffle):
    """Build the CustomImageDataset and its DataLoader for one split ('train' / 'val' / 'test')."""
    # NOTE(review): images are read from base_dir but the params CSV path is relative to the
    # current working directory. This only works when the CWD is base_dir — confirm.
    dataset = CustomImageDataset(
        f'{base_dir}/{split}',
        normalized_cosmo_params_path=f'{split}/{split}_normalized_params.csv',
        transform=split_transform,
    )
    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
    return dataset, loader


# Only the training split is augmented and shuffled.
train_dataset, train_loader = _make_split('train', transform, shuffle=True)
val_dataset, val_loader = _make_split('val', None, shuffle=False)
test_dataset, test_loader = _make_split('test', None, shuffle=False)

In [None]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """Cut a batch of images into non-overlapping square patches, in row-major patch order.

    Args:
        x: tensor of shape [B, C, H, W]; H and W must be divisible by ``patch_size``
            (the reshape fails otherwise).
        patch_size: side length of each patch in pixels.
        flatten_channels: if True, return [B, num_patches, C*patch_size*patch_size];
            otherwise return [B, num_patches, C, patch_size, patch_size].
    """
    batch, chans, height, width = x.shape
    rows, cols = height // patch_size, width // patch_size
    patches = (
        x.reshape(batch, chans, rows, patch_size, cols, patch_size)
        .permute(0, 2, 4, 1, 3, 5)   # [B, rows, cols, C, p, p]
        .flatten(1, 2)               # [B, rows*cols, C, p, p]
    )
    if not flatten_channels:
        return patches
    return patches.flatten(2, 4)     # [B, rows*cols, C*p*p]

x = next(iter(train_loader))
print(x[0].shape)

# Show the first image of the batch as a grid of 16x16 patches.
# See original code in the original tutorial: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html
img_patches = img_to_patch(x[0], patch_size=16, flatten_channels=False)

fig, ax = plt.subplots(1, 1, figsize=(10,10))
fig.suptitle("Images as input sequences of patches")
if x[0].shape[0] > 0:
    i = 0  # only the first image of the batch is displayed
    img_grid = torchvision.utils.make_grid(img_patches[i], nrow=4, normalize=False, pad_value=0.9).permute(1, 2, 0)
    ax.imshow(img_grid)
    ax.axis('off')

plt.tight_layout()
plt.show()
plt.close()

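TODO: The positional embedding (`self.pos_embedding`) below is a learned parameter of fixed length `1 + num_patches`, so it assumes fixed-resolution input images. To transfer-learn to images of a different resolution, we will need either fixed sine–cosine positional encodings (as in the original Transformer paper, Vaswani et al. 2017) or interpolation of the learned embeddings to the new patch grid (as the original ViT paper does when fine-tuning at a higher resolution).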

In [None]:
from utils import get_rmse_score

In [None]:
class AttentionBlock(nn.Module):
    """Pre-LayerNorm Transformer encoder block.

    Applies multi-head self-attention and then a GELU feed-forward network, each wrapped in a
    residual connection. Because nn.MultiheadAttention is built without ``batch_first``, inputs
    are expected as [seq_len, batch, embed_dim].

    Args:
        embed_dim: dimensionality of the input and attention feature vectors.
        hidden_dim: width of the feed-forward hidden layer (usually 2-4x ``embed_dim``).
        num_heads: number of attention heads.
        dropout: dropout probability used inside the attention and in the feed-forward network.
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()
        # Attribute names and module order are kept stable: checkpoints and `return_layers`
        # (e.g. 'transformer.0.linear.1') refer to them.
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        normed = self.layer_norm_1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        return x + self.linear(self.layer_norm_2(x))


class VisionTransformer(nn.Module):
    """Vision Transformer whose CLS token is mapped to ``num_classes`` regression outputs.

    Args:
        embed_dim: dimensionality of the token vectors fed to the Transformer.
        hidden_dim: width of the feed-forward hidden layer inside each block.
        num_channels: number of input channels (3 for RGB, 1 for a single field).
        num_heads: number of attention heads per block.
        num_layers: number of Transformer blocks.
        num_classes: number of output values.
        patch_size: side length of each square patch in pixels.
        num_patches: maximum number of patches per image (sizes the positional embedding).
        dropout: dropout probability for the token embeddings and inside every block.
    """

    def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers, num_classes, patch_size, num_patches, dropout=0.0):
        super().__init__()
        self.patch_size = patch_size

        # Layers are created in the same order as before so that seeded initialization and
        # state_dict keys are unchanged.
        patch_dim = num_channels * patch_size * patch_size
        self.input_layer = nn.Linear(patch_dim, embed_dim)
        blocks = [AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers)]
        self.transformer = nn.Sequential(*blocks)
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
        self.dropout = nn.Dropout(dropout)

        # Learned CLS token and learned positional embedding (one slot for CLS + one per patch).
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def forward(self, x):
        tokens = self.input_layer(img_to_patch(x, self.patch_size))  # [B, T, embed_dim]
        batch_size, num_tokens = tokens.shape[0], tokens.shape[1]

        # Prepend the CLS token, then add positions for CLS + the T patch tokens.
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embedding[:, :num_tokens + 1]

        # The attention blocks take [seq, batch, embed], hence the transpose.
        hidden = self.transformer(self.dropout(tokens).transpose(0, 1))

        # Regress from the CLS token (sequence position 0).
        return self.mlp_head(hidden[0])


class ViT(pl.LightningModule):
    """LightningModule that trains a VisionTransformer as a posterior-moment regressor.

    For each input map the network outputs ``model_kwargs['num_classes']`` numbers. Columns ``g``
    are the predicted posterior means of the normalized parameters, and columns ``h`` are the
    predicted posterior standard deviations (made non-negative by squaring in ``forward``).
    Relies on the notebook globals ``g``, ``h`` and ``train_loader``.
    """

    def __init__(self, model_kwargs, lr, wd, beta1, beta2):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)
        # One real batch of images. Lightning uses it e.g. to trace the graph for the logger.
        self.example_input_array = next(iter(train_loader))[0]

    def forward(self, x):
        """Run the ViT and square columns 5:10 so the predicted errors are non-negative."""
        x = self.model(x)
        # enforce the errors to be positive
        y = torch.clone(x)
        y[:,5:10] = torch.square(x[:,5:10])
        return y

    def configure_optimizers(self):
        """AdamW with ReduceLROnPlateau (factor 0.3, patience 5) driven by the logged ``val_loss``."""
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.wd, betas=(self.hparams.beta1, self.hparams.beta2))
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.3, patience=5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "monitor": "val_loss"
            },
        }

    def _calculate_loss(self, batch, mode="train"):
        """Compute and log the moment-network loss plus RMSE / sigma_bar diagnostics.

        ``batch`` is unpacked as ``(images, normalized_params, _)``; the third element is ignored.
        Metrics are logged under the ``{mode}_`` prefix.
        """
        x, y, _ = batch
        # Call forward() instead of self.model so the error columns are squared here exactly as
        # they are at inference time (model(x)). Otherwise training treats the raw outputs as the
        # std while inference reports their square.
        p = self(x)
        y_NN = p[:,g]             #posterior mean
        e_NN = p[:,h]             #posterior std
        # TODO: Ensure this below implementation of loss is actually the one optimized in CAMELS.
        loss1 = torch.mean((y_NN - y)**2,                axis=0)
        loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
        loss  = torch.mean(torch.log(loss1) + torch.log(loss2))
        self.log(f'{mode}_loss', loss)

        # TODO: Where are the logs/logged values stored? Presumably in the Trainer's default logger
        # directory under default_root_dir (see train_model) — confirm and print them after training.

        # Also log RMSE and sigma_bar for all parameters.
        y_true_np = y.detach().cpu().numpy()
        y_pred_np = y_NN.detach().cpu().numpy()
        e_pred_np = e_NN.detach().cpu().numpy()
        rmse = get_rmse_score(y_true_np, y_pred_np)
        # sigma_bar is the mean *predicted uncertainty*, so it averages e_NN, not the predicted means.
        sigma_bar = np.mean(e_pred_np, axis=0)
        # Only log at the end of epoch instead of each step.
        # Logging is only done for Omega_m and sigma_8 since only these are interesting for DM density/DM halo fields.
        # But more can easily be added here if and when needed.
        self.log(f'{mode}_omegam_rmse', rmse[0], on_step=False, on_epoch=True)
        self.log(f'{mode}_omegam_sigma_bar', sigma_bar[0], on_step=False, on_epoch=True)
        self.log(f'{mode}_sigma8_rmse', rmse[-1], on_step=False, on_epoch=True)
        self.log(f'{mode}_sigma8_sigma_bar', sigma_bar[-1], on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

In [None]:
def train_model(**kwargs):
    """Train a ViT (or load a cached checkpoint) and evaluate it on the validation and test sets.

    Args:
        **kwargs: forwarded verbatim to ``ViT(...)`` (``model_kwargs``, ``lr``, ``wd``, ``beta1``, ``beta2``).

    Returns:
        ``(model, result)``: the best model, and a dict with keys ``"val"`` and ``"test"`` holding
        the loss reported by ``trainer.test`` on each split.
    """
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=epochs,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="min", monitor="val_loss"),
                                    LearningRateMonitor("epoch")])
    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = ViT.load_from_checkpoint(pretrained_filename) # Automatically loads the model with the saved hyperparameters
    else:
        # Use the notebook-wide SEED (not a second hard-coded 42) so changing SEED changes training too.
        pl.seed_everything(SEED)
        model = ViT(**kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training

    # Test best model on validation and test set. trainer.test() runs test_step, which logs under
    # the "test" prefix, so both results expose the key "test_loss".
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_loss"], "val": val_result[0]["test_loss"]}

    return model, result

In [None]:
# torchsummary prints a per-layer summary of the model (used in the next cell).
!pip install torchsummary

In [None]:
from torchsummary import summary

def load_model_for_torch_summary(**kwargs):
    """Build a fresh, untrained ViT (same kwargs as train_model) purely to print a layer summary."""
    model_torch_summary = ViT(**kwargs)
    return model_torch_summary

model_torch_summary = load_model_for_torch_summary(model_kwargs=model_kwargs, lr=lr, wd=wd, beta1=beta1, beta2=beta2)
# Derive the input size from the configuration instead of hard-coding (1, 64, 64), and tell
# torchsummary explicitly which device the model is on (its `device` argument defaults to "cuda").
summary(
    model_torch_summary.to(device),
    (model_kwargs['num_channels'], GRID_SIZE, GRID_SIZE),
    device=device.type,
)

In [None]:
# Train (or load ViT.ckpt if it exists) and report the val/test losses.
model, results = train_model(
    model_kwargs=model_kwargs,
    lr=lr, wd=wd, beta1=beta1, beta2=beta2
)
print("ViT results", results)
# Move the returned model to the active device for the evaluation cells below.
model.to(device)

In [None]:
from train_val_test_boilerplate import test

# Per-parameter min/max computed during data preparation (loaded from the *_dataset_*.npy files).
# Presumably `test` uses them to undo the min-max normalization — confirm in train_val_test_boilerplate.
minimum, maximum = MIN_VALS, MAX_VALS

params_true, params_NN, errors_NN, filenames = test(
    model, test_loader, g=g, h=h, device=device, minimum=minimum, maximum=maximum,
)

In [None]:
# # Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH!
# %tensorboard --logdir ./tensorboards/

In [None]:
from evaluation_analysis import post_test_analysis

# Activations used for the CKA analysis: the GELU inside each transformer block's MLP.
# The mapping is built from the model itself, so it stays correct if num_layers changes.
# Every entry deliberately shares the output key 'GELU', exactly as in the original hand-written dict.
cka_return_layers = {
    f'transformer.{layer_idx}.linear.1': 'GELU'
    for layer_idx in range(len(model.model.transformer))
}

post_test_analysis(
    params_true, params_NN, errors_NN, filenames,
    test_loader, params, num_sims, MEAN, STD, MEAN_DENSITIES, minimum, maximum,
    num_maps_per_projection_direction, None, dr, channels, fmodel,
    # The model analysed here is the ViT, so label the CKA figure accordingly (previously "CNN").
    device=device, test_results_filename='test_results.csv', cka_filename='cka_matrix_pretrained_ViT_grid64_test.png',
    smallest_sim_number=0, model=model.model,
    return_layers=cka_return_layers,
)

## Interpreting the ViT

See [here](https://github.com/jacobgil/pytorch-grad-cam/issues/140) for more information.

TODO: The Grad-CAM section below is commented out because the code does not work yet. Try to make it work.

In [None]:
# model.model.transformer[-1].layer_norm_1

In [None]:
# import torch
# import torch.nn.functional as F
# import numpy as np
# import matplotlib.pyplot as plt

# class GradCAMRegressor:
#     def __init__(self, model, target_layer, ground_truth_param_value, index):
#         self.model = model
#         self.target_layer = target_layer
#         self.feature_maps = None  # Placeholder for feature maps
#         self.gradients = None  # Placeholder for gradients
#         self.ground_truth_param_value = ground_truth_param_value
#         self.index = index

#         # Set the model to evaluation mode
#         self.model.eval()

#         # Register a hook to capture gradients of the target layer
#         self.hook = self.register_hooks()

#     def register_hooks(self):
#         def forward_hook(module, input, output):
#             self.feature_maps = output

#         def backward_hook(module, grad_input, grad_output):
#             self.gradients = grad_output[0]

#         # Register hooks for both forward and backward passes
#         forward_hook_handle = self.target_layer.register_forward_hook(forward_hook)
#         backward_hook_handle = self.target_layer.register_full_backward_hook(backward_hook)

#         return forward_hook_handle, backward_hook_handle

#     def remove_hooks(self):
#         # Remove hooks after usage
#         self.hook[0].remove()
#         self.hook[1].remove()

#     def generate_gradcam(self, input_tensor, interpolate=False, target_size=(64, 64)):  # If interpolate=True, interpolates the Grad-CAM heatmap to target_size. target_size is used only when interpolate=True.
#         # Forward pass
#         input_tensor.requires_grad_()
#         self.model.zero_grad()
#         model_output = self.model(input_tensor)[:, self.index]

#         D = model_output - self.ground_truth_param_value
#         # d = 1 / D

#         # Backward pass to compute gradients
#         model_output.backward(torch.ones_like(model_output), retain_graph=True)

#         # The approach is from https://arxiv.org/pdf/2304.08192.pdf
#         # TODO: CHANGED FOR ViT BELOW LINE
#         self.gradients = self.gradients * (-1 / (D**2))

#         # Retrieve gradients and feature maps
#         gradients = self.gradients  # Gradients of the target layer
#         feature_maps = self.feature_maps  # Output of the target layer

#         print(gradients.shape, feature_maps.shape)
#         return

#         # TODO: Doing the below to prevent error for now. Need to find reliable solution.
#         # TODO: Why is gradients and features a 3d tensor instead of 4d? what each dimension in that represents?
#         # TODO: I get 17 in one dim = 1+4*4. So i need to smth like x[1:] to get 16 elements, and then reshape to 4X4.
#         gradients = gradients.unsqueeze(0)
#         feature_maps = feature_maps.unsqueeze(0)

#         print(f'Gradients shape: {gradients.shape}')
#         print(f'Feature maps shape: {feature_maps.shape}')

#         # pool the gradients across the channels
#         pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

#         # weight the channels by corresponding gradients
#         for i in range(feature_maps.size()[1]):
#             feature_maps[:, i, :, :] *= pooled_gradients[i]

#         # average the channels of the activations and squeeze across the batch dimension since we assume only one image is passed.
#         # Note: If want to support multiple images at once, need to remove squeeze(dim=0) and instead add mean(dim=0) to average cams across batch images.
#         cam = torch.mean(feature_maps, dim=1).squeeze(dim=0)

#         # relu on top of the heatmap
#         cam = F.relu(cam)

# #         # normalize the heatmap
# #         cam /= torch.max(cam)

# #         # Calculate weighted combination of feature maps
# #         weights = F.adaptive_avg_pool2d(gradients, 1)
# #         cam = torch.sum(weights * feature_maps, dim=1, keepdim=True)
# #         cam = F.relu(cam)
        
#         if interpolate:
#             cam = cam.unsqueeze(0).unsqueeze(0)
#             # Resize Grad-CAM to match the input image size
#             cam = F.interpolate(cam, size=target_size, mode='bilinear', align_corners=False)

#         return cam

#     def visualize_gradcam(self, input_tensor, target_size=(64, 64)):
#         gradcam = self.generate_gradcam(input_tensor)
#         print(f'Shape of the raw gradcam map: {gradcam.shape}')

#         gradcam = gradcam.unsqueeze(0).unsqueeze(0)

#         # Resize Grad-CAM to match the input image size
#         gradcam = F.interpolate(gradcam, size=target_size, mode='bilinear', align_corners=False)

#         print(f'Shape of the interpolated gradcam map: {gradcam.shape}')

#         # Convert to numpy array for visualization
#         gradcam = gradcam.squeeze().cpu().detach().numpy()

#         # Normalize for visualization
#         gradcam = (gradcam - np.min(gradcam)) / (np.max(gradcam) - np.min(gradcam) + 1e-8)

#         # Display the original image
#         original_image = input_tensor[0].permute(1, 2, 0).detach().cpu().numpy()
#         original_image = (original_image - np.min(original_image)) / (np.max(original_image) - np.min(original_image) + 1e-8)
#         plt.imshow(original_image)

#         # Overlay Grad-CAM on the original image
#         plt.imshow(gradcam, cmap='jet', alpha=0.3, interpolation='bilinear')
#         plt.show()

In [None]:
# # from grad_cam_interpret import GradCAMRegressor
# import torch.nn.functional as F

# images, labels, _ = next(iter(test_loader))
# images = images.to(device)
# labels = labels.to(device)

# # Example usage:
# # Assuming 'model' is your regression model and 'target_layer' is the layer you want to visualize
# for img_idx in [5, 7, 10, 13, 15]:
#     # Only select one image
#     images_ = images[img_idx, :].unsqueeze(0)
#     print(images_.shape)
#     for index in [0, 4]:
#         if index == 0:
#             param = r'$\Omega_m$'
#         elif index == 4:
#             param = r'$\sigma_8$'

#         gradcam_regressor = GradCAMRegressor(model, target_layer=model.model.transformer[-1].layer_norm_1, ground_truth_param_value=labels[:, index][img_idx], index=index)

#         # Visualize Grad-CAM for the entire image in a regression context
#         gradcam = gradcam_regressor.generate_gradcam(images_)

#         # Remove hooks after usage
#         gradcam_regressor.remove_hooks()

#         fig, ax = plt.subplots(2, 3, figsize=(15, 7))
#         show_image = images_.squeeze().cpu().detach().numpy()
#         ax[0,0].imshow(show_image)
#         ax[0,1].imshow(gradcam.cpu().detach().numpy())

#         gradcam = gradcam.unsqueeze(0).unsqueeze(0)
#         gradcam = gradcam.squeeze().reshape(16,16).unsqueeze(0).unsqueeze(0)

#         # Resize Grad-CAM to match the input image size
#         gradcamI = F.interpolate(gradcam, size=show_image.shape, mode='bilinear', align_corners=False)
#         # Convert to numpy array for visualization
#         gradcamI = gradcamI.squeeze().cpu().detach().numpy()
#         # Normalize for visualization
#         gradcamI = (gradcamI - np.min(gradcamI)) / (np.max(gradcamI) - np.min(gradcamI) + 1e-8)
#         original_image = images_[0].permute(1, 2, 0).detach().cpu().numpy()
#         original_image = (original_image - np.min(original_image)) / (np.max(original_image) - np.min(original_image) + 1e-8)
#         ax[0,2].imshow(original_image)
#         # Overlay Grad-CAM on the original image
#         ax[0,2].imshow(gradcamI, cmap='jet', alpha=0.3, interpolation='bilinear')
#         ax[0,0].set_title(param)

#         # Plot normalized density and gradcam value.
#         den_gradcam_pairs = []
#         original_image = original_image.squeeze()
#         gradcam = gradcam.squeeze().cpu().detach().numpy()
#         receptive_size = (original_image.shape[0]//gradcam.shape[0], original_image.shape[1]//gradcam.shape[1])
#         for i in range(0, original_image.shape[0], original_image.shape[0]//gradcam.shape[0]):
#             for j in range(0, original_image.shape[1], original_image.shape[1]//gradcam.shape[1]):
#                 mean_den = original_image[i:i+receptive_size[0], j:j+receptive_size[1]].mean()
#                 this_grad_cam = gradcam[i//receptive_size[0], j//receptive_size[1]]
#                 den_gradcam_pairs.append((mean_den, this_grad_cam))

#         ax[1,0].scatter([d[0] for d in den_gradcam_pairs], [d[1] for d in den_gradcam_pairs]);
#         ax[1,0].set_xlabel('Normalized density')
#         ax[1,0].set_ylabel('Normalized Grad-CAM')

#         sns.kdeplot(x=original_image.flatten(), y=gradcamI.flatten(), ax=ax[1,1], fill=True)
#         ax[1,1].set_xlabel('Normalized density')
#         ax[1,1].set_ylabel('Interpolated normalized Grad-CAM')

#         plt.show()

## Transfer learning

In [None]:
import seaborn as sns
# sns.set() resets the plotting context to "notebook" with font_scale=1. Calling it after
# set_context("paper", font_scale=2), as before, silently discarded that setting. Set style
# and context together in a single call instead.
sns.set(style='ticks', context='paper', font_scale=2)

In [None]:
from utils import smooth_3D_field

In [None]:
# Directory of the halo catalogues; in this section it is only used on the non-Colab (Kaggle) code paths.
halo_dirname = '/kaggle/input/halo-vit-lss-64-same-sims-no-mass-cuts'  # Sims 0-999. These are same simulations as the density fields (0-999).

In [None]:
# When True, the bias-analysis cells below run and create_data.py is called with --bias and the
# precomputed statistics; when False, the halo data are built with --log_1_plus instead.
SAME_SIMS = True  # Whether the halo and DM density simulations exactly match.

In [None]:
if SAME_SIMS:
    # Analysis of the bias parameter: for every simulation, estimate the bias as the median over
    # cells of delta_halo / delta_DM (both fields smoothed), then summarise over simulations.
    import h5py
    import os
    import glob
    import numpy as np
    from utils import read_hdf5
    import matplotlib.pyplot as plt

    from scipy.stats import iqr

    DEN_FIELD_DIRECTORY = 'my_outputs'

    if USE_COLAB:
        dpath = os.path.join('/content/', f'{DEN_FIELD_DIRECTORY}', '*.h5')
        hpath = os.path.join('/content/', f'{DEN_FIELD_DIRECTORY}_halo', '*.h5')
    else:
        dpath = os.path.join('/kaggle/input/density-fields-vit-lss-64/', f'{DEN_FIELD_DIRECTORY}', '*.h5')
        hpath = os.path.join(f'{halo_dirname}', f'{DEN_FIELD_DIRECTORY}_halo', '*.h5')

    den_files = sorted(glob.glob(dpath))
    halo_files = sorted(glob.glob(hpath))
    # Files are paired purely by sorted order; zip() would silently drop unmatched files.
    if len(den_files) != len(halo_files):
        raise ValueError(
            f'Found {len(den_files)} density files but {len(halo_files)} halo files; '
            'they must match one-to-one.'
        )

    # Only the per-simulation median bias is kept. The full 3D fields are deliberately not
    # accumulated in lists: they were never used and cost a large amount of memory.
    all_bias_params = []
    for den_file, halo_file in zip(den_files, halo_files):
        den, dparams = read_hdf5(den_file, dataset_name='3D_density_field')
        halo, hparams = read_hdf5(halo_file, dataset_name='3D_halo_distribution')

        den = smooth_3D_field(den)
        halo = smooth_3D_field(halo)

        den_contrast = den/den.mean() - 1
        halo_contrast = halo/halo.mean() - 1

        # Alternative: restrict to low-density cells, e.g. np.where(den_contrast < 1).
        bias_params = (halo_contrast/den_contrast)

        # # Remove outliers.
        # bias_params = bias_params[(bias_params > np.quantile(bias_params, 0.01)) & (bias_params < np.quantile(bias_params, 0.99))]
        all_bias_params.append(np.median(bias_params))

    all_bias_params = np.array(all_bias_params)
    bias = np.mean(all_bias_params)

    fig, ax = plt.subplots(1, 1, figsize=(7, 5))
    ax.hist(all_bias_params.ravel(), bins=20)
    ax.set_yscale('log')
    ax.set_title(f'Bias: {bias:.2f} +/- {np.std(all_bias_params):.2f}')
    plt.show()

In [None]:
if SAME_SIMS:
    from utils import power_spectrum

    # Bias estimate from the ratio of the halo and DM power spectra.
    import h5py
    import os
    import glob
    import numpy as np
    from utils import read_hdf5
    import matplotlib.pyplot as plt
    import contextlib

    DEN_FIELD_DIRECTORY = 'my_outputs'

    if USE_COLAB:
        dpath = os.path.join('/content/', f'{DEN_FIELD_DIRECTORY}', '*.h5')
        hpath = os.path.join('/content/', f'{DEN_FIELD_DIRECTORY}_halo', '*.h5')
    else:
        dpath = os.path.join('/kaggle/input/density-fields-vit-lss-64/', f'{DEN_FIELD_DIRECTORY}', '*.h5')
        hpath = os.path.join(f'{halo_dirname}', f'{DEN_FIELD_DIRECTORY}_halo', '*.h5')

    # Box size and grid resolution; identical for every simulation, so computed once.
    L = 1000   # Mpc/h
    Ng = 64
    # Nyquist wavenumber of the grid: modes with k above it are unreliable and are discarded.
    k_Nq = (2 * np.pi / L) * (Ng / 2)
    # Volume of one grid cell in (Mpc/h)^3 -- the fields and their power spectra are 3D.
    cell_volume = (L / Ng) ** 3

    Pk_dens = []
    Pk_halos = []
    for i, filename in enumerate(
        zip(
          sorted(glob.glob(dpath)),
          sorted(glob.glob(hpath))
        )
    ):
        den, dparams = read_hdf5(filename[0], dataset_name='3D_density_field')
        halo, hparams = read_hdf5(filename[1], dataset_name='3D_halo_distribution')

        with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):  # Prevent unnecessary verbose output from printing on screen.
            k_den, Pk_den = power_spectrum(den, dimensional=3)
            k_halo, Pk_halo = power_spectrum(halo, dimensional=3)

        # The ratio and averages below require identical k bins; check it for every simulation
        # (previously only the last simulation was checked, after the loop).
        assert np.all(k_den == k_halo)

        # Keep only modes up to the Nyquist wavenumber.
        condition_k = k_den <= k_Nq
        k_den = k_den[condition_k]
        k_halo = k_halo[condition_k]

        # Remove shot noise from the halo power spectrum: P'(k) = P(k) - 1/n_bar.
        # halo.mean() is taken as the mean (effective) number of halos per grid cell. A 3D power
        # spectrum has units of (Mpc/h)^3 (assuming power_spectrum follows the usual convention --
        # TODO confirm), so n_bar must be a number density: the per-cell count divided by the
        # cell *volume* (L/Ng)^3, not by the cell area (L/Ng)^2 as for a 2D field.
        n_bar = halo.mean() / cell_volume
        Pk_halo = Pk_halo - (1 / n_bar)

        # Select Pk powers corresponding to the refined set of wavenumbers.
        Pk_den = Pk_den[condition_k]
        Pk_halo = Pk_halo[condition_k]

        Pk_dens.append(Pk_den)
        Pk_halos.append(Pk_halo)

    # Average the spectra over all simulations.
    Pk_den = np.vstack(Pk_dens).mean(axis=0)
    Pk_halo = np.vstack(Pk_halos).mean(axis=0)

    # Most likely bias from the power spectrum.
    # Following this paper: https://iopscience.iop.org/article/10.1088/0004-637X/724/2/878 >>> "We calculate b2 as the average over the 10 largest wavelength modes in the simulation".
    most_likely_bias_from_ps = np.sqrt(np.mean(Pk_halo[:10]/Pk_den[:10]))

    print(f'Most likely bias from the power spectrum: {most_likely_bias_from_ps}')
    print(f'Bias obtained from the halo_contrast/DM_density_contrast analysis in the above cell: {bias}')

    fig, ax = plt.subplots(2, 1, figsize=(6, 5))
    fig.subplots_adjust(hspace=0)
    ax[0].loglog(k_den, Pk_den, c='black', label='DM', linewidth=4)
    ax[0].loglog(k_den, Pk_halo, c='black', linestyle='--', label='Halo', linewidth=4);
    ax[0].legend();
    ax[0].set_ylabel(r'$P(k)$')
    ax[1].loglog(k_den, np.sqrt(Pk_halo/Pk_den), linewidth=4)
    ax[1].set_ylabel(r'$\sqrt{P_{halo}(k)/P_{dm}(k)}$')
    ax[1].set_xlabel(r'$k$')
    ax[1].axhline(y=bias, linestyle='--', c='gray', linewidth=2)
    ax[1].axhline(y=most_likely_bias_from_ps, linestyle='--', c='red', linewidth=2)
    # ax[1].set_ylim([np.mean(all_bias_params)-3*np.std(all_bias_params), np.mean(all_bias_params)+3*np.std(all_bias_params)])

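Estimating the bias from the power spectrum is not the best choice, because it is degenerate with $\sigma_8$: on large scales $P_{halo}(k) \approx b^2 P_{DM}(k)$ with $P_{DM} \propto \sigma_8^2$, so the halo power spectrum alone only constrains the product $b\,\sigma_8$. See, e.g., https://arxiv.org/pdf/2006.01146.pdf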

QUESTION: The bias found here (~1.5–1.6) is somewhat higher than typically reported values such as ~1.15 (see, e.g., the reference whose link is missing here). Could this be due to the small grid size? With a coarse grid the halos may not be well resolved, so these estimates may not be very reliable. This is just my guess; the dependence of the bias on resolution was also mentioned in a paper (link missing).

To understand any correlation between $\sigma_8$ and the bias, we plot both quantities for each simulation below.

In [None]:
if SAME_SIMS:
    # For each simulation, pair the median bias with the simulation's sigma_8 and Omega_m so we
    # can look for a correlation between b and the cosmology.
    biases, sigma_8s, omegams = [], [], []
    for den_file, halo_file in zip(sorted(glob.glob(dpath)), sorted(glob.glob(hpath))):
        den, dparams = read_hdf5(den_file, dataset_name='3D_density_field')
        halo, hparams = read_hdf5(halo_file, dataset_name='3D_halo_distribution')

        # Omega_m is the first and sigma_8 the last stored parameter; both files must agree.
        omega_m, sigma_8 = dparams[0], dparams[-1]
        assert sigma_8 == hparams[-1]
        assert omega_m == hparams[0]

        smoothed_den = smooth_3D_field(den)
        smoothed_halo = smooth_3D_field(halo)
        ratio = (smoothed_halo/smoothed_halo.mean() - 1) / (smoothed_den/smoothed_den.mean() - 1)

        biases.append(np.median(ratio))
        sigma_8s.append(sigma_8)
        omegams.append(omega_m)

    from scipy import stats
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # Least-squares line sigma_8 = slope * b + intercept (res also has r_value, p_value, std_err).
    res = stats.linregress(biases, sigma_8s)
    fitted_sigma_8 = res.intercept + res.slope * np.array(biases)

    fig, ax = plt.subplots(1, 1)
    im = ax.scatter(biases, sigma_8s, c=omegams, alpha=0.75)
    ax.plot(biases, fitted_sigma_8, 'r', label='fitted line')
    ax.set_xlabel(r'$b$')
    ax.set_ylabel(r'$\sigma_8$')
    ax.set_title(r'$b$ vs $\sigma_8$ for all 1000 simulations')
    ax.legend()
    print(fr'Equation of fitted line: $\sigma_8 = {res.slope:.2f} * bias + {res.intercept:.2f}$')

    # Colour bar for Omega_m, in its own axis to the right of the scatter plot.
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    cbar = fig.colorbar(im, cax=cax, orientation='vertical')
    cbar.ax.set_ylabel(r'$\Omega_m$', rotation=270)
    cbar.ax.get_yaxis().labelpad = 15

    # Same data as a filled 2D kernel-density estimate.
    fig, ax = plt.subplots(1, 1)
    sns.kdeplot(x=biases, y=sigma_8s, ax=ax, fill=True)
    ax.set_xlabel(r'$b$')
    ax.set_ylabel(r'$\sigma_8$')
    ax.set_title(r'$b$ vs $\sigma_8$ for all 1000 simulations')

One speculation: if we use the bias (=1.58) to scale the halo distribution for transfer learning, then the predictions for $\sigma_8$ in the range that corresponds to a bias of 1.58 — which, from the figure above, seems to be ~0.7–0.8 — could be the most affected. One can check this in the prediction plots below.

Plotting $\delta$ vs. $b$

In [None]:
if SAME_SIMS:
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # Per-cell bias b = delta_halo / delta_DM plotted against delta_DM, for the first five
    # simulation pairs only (the loop breaks once five have been plotted).
    biases, dm_density_contrasts = [], []
    counter = 0
    for den_file, halo_file in zip(sorted(glob.glob(dpath)), sorted(glob.glob(hpath))):
        den, dparams = read_hdf5(den_file, dataset_name='3D_density_field')
        halo, hparams = read_hdf5(halo_file, dataset_name='3D_halo_distribution')

        den = smooth_3D_field(den)
        halo = smooth_3D_field(halo)

        den_contrast = den/den.mean() - 1
        halo_contrast = halo/halo.mean() - 1
        bias_param = halo_contrast/den_contrast

        # Drop extreme (probably unphysical) values: keep cells strictly between the 1st and
        # 99th percentiles of b, and apply the same mask to delta_DM so the pairs stay aligned.
        low, high = np.quantile(bias_param, 0.01), np.quantile(bias_param, 0.99)
        keep = (bias_param > low) & (bias_param < high)
        bias_param = bias_param[keep]
        den_contrast = den_contrast[keep]

        bpf = bias_param.flatten()
        dcf = den_contrast.flatten()
        biases.append(bpf)
        dm_density_contrasts.append(dcf)

        # Left: raw scatter; right: hexbin density of the same points.
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        im = ax[0].scatter(dcf, bpf)
        ax[1].hexbin(dcf, bpf)
        for panel in ax:
            panel.set_xlabel(r'$\delta_{DM}$')
            panel.set_ylabel(r'$b$')
            panel.set_title(r'$\delta_{DM}$ vs $b$')
            panel.set_ylim([0, 3])
        plt.show()

        counter += 1
        if counter == 5:
            break

Plotting the PDF of density and halo distribution for comparison.

In the below plot, we use the mean DM density for calculating the density contrast and mean halo density for the halo contrast.

In [None]:
if SAME_SIMS:
    # Gather 1 + delta for every smoothed simulation; each field is divided by its own mean.
    dens, halos = [], []
    for den_file, halo_file in zip(sorted(glob.glob(dpath)), sorted(glob.glob(hpath))):
        den, dparams = read_hdf5(den_file, dataset_name='3D_density_field')
        halo, hparams = read_hdf5(halo_file, dataset_name='3D_halo_distribution')

        den, halo = smooth_3D_field(den), smooth_3D_field(halo)
        dens.append(den/den.mean())
        halos.append(halo/halo.mean())

    dens, halos = np.array(dens), np.array(halos)

    # Step histograms of both distributions on a shared, log-scaled count axis.
    fig, ax = plt.subplots(1, 1)
    for values, label in ((dens, 'DM density'), (halos, 'Halo')):
        ax.hist(values.ravel(), histtype='step', linewidth=3, label=label)
    ax.set_xlabel(r'$1 + \delta$')
    ax.set_ylabel('Counts')
    ax.set_yscale('log')
    ax.legend()

In [None]:
# NOTE: use precomputed_mean, precomputed_stddev, precomputed_min_vals, and precomputed_max_vals only if you are using bias.
# Else it's better to use statistics of this new dataset for preprocessing.

#########################################################################################################

# The logic is that when the exact same sims are used for DM density and halo, we use the bias and also
# the precomputed statistics. Whereas if different simulations are used for DM and halo, we don't use the
# bias and also don't use the precomputed statistics. When same sims, the bias adds a positive value, thus
# log10(halo) does not give divide by zero error. When different sims, the same is handled by log(1+halo).
#
# NOTE(review): in the SAME_SIMS branches, MEAN/STD/MIN_VALS/MAX_VALS must already hold statistics from an
# earlier data-preparation step (presumably the DM density dataset -- confirm); the next cell overwrites
# them with the halo statistics. `bias` is the value computed in the bias-analysis cell above.

#########################################################################################################
if USE_COLAB:
    if SAME_SIMS:
        # Same sims (0-999) as the DM fields: scale by the bias and reuse the precomputed statistics.
        !python create_data.py --num_sims {num_sims} --train_frac 0.8 --test_frac 0.1 --seed 42 --path /content/my_outputs_halo --grid_size 64 \
        --num_maps_per_projection_direction 10 --prefix 'halos' --dataset_name '3D_halo_distribution' --bias {bias} \
        --precomputed_mean {MEAN} --precomputed_stddev {STD} \
        --precomputed_min_vals {MIN_VALS[0]} {MIN_VALS[1]} {MIN_VALS[2]} {MIN_VALS[3]} {MIN_VALS[4]} \
        --precomputed_max_vals {MAX_VALS[0]} {MAX_VALS[1]} {MAX_VALS[2]} {MAX_VALS[3]} {MAX_VALS[4]} \
        --smallest_sim_number 0
    else:  # don't use bias.
        # Different sims (starting at 1000): no bias, log(1 + halo) transform, statistics of this dataset.
        !python create_data.py --num_sims {num_sims} --train_frac 0.8 --test_frac 0.1 --seed 42 --path /content/my_outputs_halo --grid_size 64 \
        --num_maps_per_projection_direction 10 --prefix 'halos' --dataset_name '3D_halo_distribution' \
        --smallest_sim_number 1000 --log_1_plus
#         --precomputed_mean {MEAN} --precomputed_stddev {STD} \
#         --precomputed_min_vals {MIN_VALS[0]} {MIN_VALS[1]} {MIN_VALS[2]} {MIN_VALS[3]} {MIN_VALS[4]} \
#         --precomputed_max_vals {MAX_VALS[0]} {MAX_VALS[1]} {MAX_VALS[2]} {MAX_VALS[3]} {MAX_VALS[4]} \
else:
    if SAME_SIMS:
        # Same as the Colab branch, but reading the halo fields from the Kaggle input directory.
        !python create_data.py --num_sims {num_sims} --train_frac 0.8 --test_frac 0.1 --seed 42 --path {halo_dirname}/my_outputs_halo --grid_size 64 \
        --num_maps_per_projection_direction 10 --prefix 'halos' --dataset_name '3D_halo_distribution' --bias {bias} \
        --precomputed_mean {MEAN} --precomputed_stddev {STD} \
        --precomputed_min_vals {MIN_VALS[0]} {MIN_VALS[1]} {MIN_VALS[2]} {MIN_VALS[3]} {MIN_VALS[4]} \
        --precomputed_max_vals {MAX_VALS[0]} {MAX_VALS[1]} {MAX_VALS[2]} {MAX_VALS[3]} {MAX_VALS[4]} \
        --smallest_sim_number 0
    else:  # don't use bias.
        # Different sims (starting at 1000): no bias, log(1 + halo) transform, statistics of this dataset.
        !python create_data.py --num_sims {num_sims} --train_frac 0.8 --test_frac 0.1 --seed 42 --path {halo_dirname}/my_outputs_halo --grid_size 64 \
        --num_maps_per_projection_direction 10 --prefix 'halos' --dataset_name '3D_halo_distribution' \
        --smallest_sim_number 1000 --log_1_plus
#         --precomputed_mean {MEAN} --precomputed_stddev {STD} \
#         --precomputed_min_vals {MIN_VALS[0]} {MIN_VALS[1]} {MIN_VALS[2]} {MIN_VALS[3]} {MIN_VALS[4]} \
#         --precomputed_max_vals {MAX_VALS[0]} {MAX_VALS[1]} {MAX_VALS[2]} {MAX_VALS[3]} {MAX_VALS[4]} \

In [None]:
# Load the dataset statistics saved under this prefix by create_data.py (mean, std, per-parameter
# min/max values and mean densities) into the notebook-wide names used by later cells.
prefix = 'halos'
MEAN, STD, MIN_VALS, MAX_VALS, MEAN_DENSITIES = (
    np.load(f'{prefix}_dataset_{stat}.npy')
    for stat in ('mean', 'std', 'min_vals', 'max_vals', 'mean_densities')
)

In [None]:
import gzip
import numpy as np
import glob
# Pick the first processed training map (in sorted order) for a quick visual sanity check.
matches = sorted(glob.glob(f'{base_dir}/train/processed_sim*_X1_LH_z0_grid64_masCIC.npy.gz'))
if not matches:
    raise FileNotFoundError(f'No processed training maps found under {base_dir}/train')
filename = matches[0]
# The maps are gzip-compressed .npy files; read through a context manager so the handle is closed.
with gzip.open(filename, 'rb') as f:
    halo = np.load(f)

In [None]:
import pandas as pd
# Original (presumably un-normalised) cosmological parameters of the training maps. The next cell
# matches column '0' against '<split>/<file name>' paths and reads the last five columns.
df = pd.read_csv('train/train_original_params.csv')
df

In [None]:
import matplotlib.pyplot as plt
# Find the table row of the map loaded above: column '0' holds '<split>/<file name>' paths.
v = df[df['0'] == '/'.join(filename.split('/')[-2:])]
# Store this map's parameter values under a dedicated name. Assigning to `params` here would
# overwrite the notebook-wide `params` that ViT_FineTune (len(params)) and post_test_analysis use.
sim_params = list(v[v.columns[-5:]].iloc[0])

plt.imshow(halo); plt.title(np.round(sim_params, 4)); plt.colorbar()

In [None]:
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

from model_dataset import CustomImageDataset
from torchvision.transforms import v2
from torchvision import transforms
# Random 90/180/270-degree rotations as augmentation, applied to the training split only.
transform = v2.Compose([
    MyRotationTransform(angles=[90, 180, 270]),
    v2.ToDtype(torch.float32)
])

# The processed maps live under base_dir (/content on Colab, /kaggle/working on Kaggle). A
# hard-coded '/content/...' path only works on Colab.
train_dataset = CustomImageDataset(os.path.join(base_dir, 'train'), normalized_cosmo_params_path='train/train_normalized_params.csv', transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = CustomImageDataset(os.path.join(base_dir, 'val'), normalized_cosmo_params_path='val/val_normalized_params.csv', transform=None)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

test_dataset = CustomImageDataset(os.path.join(base_dir, 'test'), normalized_cosmo_params_path='test/test_normalized_params.csv', transform=None)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Sanity check: number of samples in the train / val / test splits.
len(train_dataset), len(val_dataset), len(test_dataset)

In [None]:
# Updated parameters for the transfer learning come here.
# `epochs` is used as max_epochs by the Trainer in finetune_model.
epochs = 10  # Smaller no. of epochs than pretraining.
# dr = dr * 2
# wd = wd * 2
# lr = lr * 0.1

# output files names
# NOTE(review): fmodel is passed to post_test_analysis; floss is not referenced in this section.
floss  = 'loss_transfer_learning_halo_grid64.txt'   #file with the training and validation losses for each epoch
fmodel = 'weights_transfer_learning_halo_grid64.pt' #file containing the weights of the best-model

In [None]:
# Display the current hyperparameters: learning rate, weight decay and `dr` (presumably the dropout rate).
lr, wd, dr

In [None]:
# NOTE(review): this flag is not read anywhere in this section (ViT_FineTune and finetune_model ignore it),
# so setting it to True currently has no effect -- TODO wire it up or remove it.
FREEZE_LAYERS = False  # Whether to freeze all layers for transfer learning. If False, all layers are retrained on the new dataset.

In [None]:
from torch_intermediate_layer_getter import IntermediateLayerGetter as MidGetter

In [None]:
class ViT_FineTune(ViT):
    """ViT fine-tuning module.

    Starts from a pretrained ``ViT`` checkpoint and appends a new linear layer (``FC_final``)
    after the pretrained MLP head.

    Args:
        pretrained_filename: path of a checkpoint saved from the ``ViT`` LightningModule.
        model_kwargs, lr, wd, beta1, beta2: forwarded to ``ViT.__init__`` and stored as
            hyperparameters (assumes ``ViT`` accepts these keyword arguments -- TODO confirm
            against the ViT definition).
    """

    def __init__(self, pretrained_filename, model_kwargs, lr, wd, beta1, beta2):
        # Forward the hyperparameters to ViT.__init__; a bare super().__init__() passed none.
        super().__init__(model_kwargs=model_kwargs, lr=lr, wd=wd, beta1=beta1, beta2=beta2)
        # Records pretrained_filename too, so ViT_FineTune.load_from_checkpoint can rebuild this module.
        self.save_hyperparameters()

        # load_from_checkpoint is a classmethod (Lightning 2.x refuses to call it on an instance).
        # Load the pretrained ViT module and keep only its inner network, so `self.model` is the same
        # kind of object as in ViT; post_test_analysis reads model.model with 'transformer.*' layers.
        self.model = ViT.load_from_checkpoint(pretrained_filename).model

        # Append a new output layer. The existing (pretrained) mlp_head module is wrapped rather
        # than re-created, so its pretrained weights are kept.
        self.model.FC_final = nn.Linear(len(params)*2, len(params)*2)
        self.model.mlp_head = nn.Sequential(
            self.model.mlp_head,
            self.model.FC_final
        )
        # FC_final is now registered twice (as model.FC_final and as model.mlp_head[1]). It is the
        # same module object, and Module.parameters() yields shared parameters only once.

        self.example_input_array = next(iter(train_loader))[0]

    def configure_optimizers(self):
        """Set up AdamW with separate learning rates for the backbone and the new layer.

        The new ``FC_final`` layer uses lr 1e-2 and everything else uses 1e-4, together with
        ReduceLROnPlateau on ``val_loss``.

        NOTE(review): the learning rates are hard-coded; ``self.hparams.lr`` is not used here.
        """
        head_params = list(self.model.FC_final.parameters())
        head_param_ids = {id(p) for p in head_params}
        # self.parameters() also yields FC_final's weights. torch.optim raises a ValueError when a
        # parameter appears in more than one group, so exclude them from the backbone group.
        backbone_params = [p for p in self.parameters() if id(p) not in head_param_ids]

        optimizer = torch.optim.AdamW(
            [
                {"params": backbone_params, "lr": 1e-4},
                {"params": head_params, "lr": 1e-2},
            ],
            weight_decay=self.hparams.wd, betas=(self.hparams.beta1, self.hparams.beta2)
        )
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.3, patience=5)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "monitor": "val_loss"
            },
        }

In [None]:
def finetune_model(pretrained_filename, **kwargs):
    """Fine-tune a pretrained ViT checkpoint on the current loaders and evaluate it.

    Args:
        pretrained_filename: checkpoint saved from the pretrained ``ViT`` module.
        **kwargs: remaining ``ViT_FineTune`` constructor arguments
            (model_kwargs, lr, wd, beta1, beta2).

    Returns:
        ``(model, result)``: the best fine-tuned model (lowest ``val_loss`` checkpoint) and a dict
        with the logged ``"test_loss"`` on the validation (``"val"``) and test (``"test"``) loaders.

    Raises:
        FileNotFoundError: if ``pretrained_filename`` does not exist. Previously this case fell
            through to ``trainer.test`` with ``model`` undefined.
    """
    if not os.path.isfile(pretrained_filename):
        raise FileNotFoundError(f"No pretrained model found at {pretrained_filename}")

    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=epochs,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="min", monitor="val_loss"),
                                    LearningRateMonitor("epoch")])
    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    print(f"Found pretrained model at {pretrained_filename}, loading...")
    pl.seed_everything(42) # To be reproducable (also seeds the init of the new FC_final layer)
    # Build the fine-tuning module directly; it loads the pretrained weights itself. Calling
    # ViT_FineTune.load_from_checkpoint on the *pretrained* checkpoint would rebuild ViT_FineTune
    # from ViT's saved hyperparameters, which presumably lack `pretrained_filename`, and it would
    # ignore **kwargs.
    model = ViT_FineTune(pretrained_filename, **kwargs)
    trainer.fit(model, train_loader, val_loader)
    # Reload the best checkpoint written by ModelCheckpoint (monitored on val_loss).
    model = ViT_FineTune.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_loss"], "val": val_result[0]["test_loss"]}

    return model, result

In [None]:
# Fine-tune from the pretrained ViT checkpoint and report the val/test losses.
pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
finetune_kwargs = dict(model_kwargs=model_kwargs, lr=lr, wd=wd, beta1=beta1, beta2=beta2)
model, results = finetune_model(pretrained_filename, **finetune_kwargs)
print("ViT_FineTune results", results)
model.to(device)

In [None]:
from train_val_test_boilerplate import test

# Below values calculated during data preparation. See above.
# (MIN_VALS / MAX_VALS were loaded from the halos_dataset_*.npy files after running create_data.py.)
minimum = MIN_VALS
maximum = MAX_VALS

# Evaluate the fine-tuned model on the test loader. minimum/maximum are presumably used by `test`
# to map normalised predictions back to parameter values -- TODO confirm.
params_true, params_NN, errors_NN, filenames = test(model, test_loader, g=g, h=h, device=device, minimum=minimum, maximum=maximum)

In [None]:
from evaluation_analysis import post_test_analysis

# Layers whose activations are captured for the CKA analysis: 'linear.1' (labelled 'GELU') in
# each of the six transformer blocks, in block order 0..5.
gelu_layers = {f'transformer.{block}.linear.1': 'GELU' for block in range(6)}

post_test_analysis(
    params_true, params_NN, errors_NN, filenames,
    test_loader, params, num_sims, MEAN, STD, MEAN_DENSITIES, minimum, maximum,
    num_maps_per_projection_direction, None, dr, channels, fmodel,
    device=device,
    test_results_filename='test_results_transfer_learning_ViT.csv',
    cka_filename='cka_matrix_transfer_learning_halo_ViT_grid64_test.png',
    smallest_sim_number=0,
    model=model.model,
    return_layers=gelu_layers,
)

## Training on transfer learning data FROM SCRATCH

In [None]:
# Since this is training from scratch, don't use the --bias option since we assume we don't have information of DM density.
# The maps are log(1 + halo)-transformed (--log_1_plus) and, with no --precomputed_* options, presumably
# normalised with statistics of this halo dataset itself; sims start at 0.
if USE_COLAB:
    !python create_data.py --num_sims {num_sims} --train_frac 0.8 --test_frac 0.1 --seed 42 --path /content/my_outputs_halo --grid_size 64 \
                        --num_maps_per_projection_direction 10 --prefix 'halos' --dataset_name '3D_halo_distribution' --log_1_plus --smallest_sim_number 0
else:
    !python create_data.py --num_sims {num_sims} --train_frac 0.8 --test_frac 0.1 --seed 42 --path {halo_dirname}/my_outputs_halo --grid_size 64 \
                        --num_maps_per_projection_direction 10 --prefix 'halos' --dataset_name '3D_halo_distribution' --log_1_plus --smallest_sim_number 0

In [None]:
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

from model_dataset import CustomImageDataset
from torchvision.transforms import v2
from torchvision import transforms
# Random 90/180/270-degree rotations as augmentation, applied to the training split only.
transform = v2.Compose([
    MyRotationTransform(angles=[90, 180, 270]),
    v2.ToDtype(torch.float32)
])

# The processed maps live under base_dir (/content on Colab, /kaggle/working on Kaggle). A
# hard-coded '/content/...' path only works on Colab.
train_dataset = CustomImageDataset(os.path.join(base_dir, 'train'), normalized_cosmo_params_path='train/train_normalized_params.csv', transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = CustomImageDataset(os.path.join(base_dir, 'val'), normalized_cosmo_params_path='val/val_normalized_params.csv', transform=None)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

test_dataset = CustomImageDataset(os.path.join(base_dir, 'test'), normalized_cosmo_params_path='test/test_normalized_params.csv', transform=None)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

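TODO: `floss` does not seem to be used anywhere; `fmodel`, however, is passed to `post_test_analysis`. Check whether `post_test_analysis` actually uses `fmodel` before removing either of them from the notebook.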

In [None]:
# Training settings for training on the halo data from scratch.
# NOTE(review): `epochs` is presumably read by train_model, which is defined outside this section.
epochs = 25
# dr = dr * 2
# wd = wd * 2
# lr = lr * 0.1

# output files names
# NOTE(review): fmodel is passed to post_test_analysis; floss is not referenced in this section.
floss  = 'loss_transfer_learning_data_train_from_scratch_halo_grid64.txt'   #file with the training and validation losses for each epoch
fmodel = 'weights_transfer_learning_data_train_from_scratch_halo_grid64.pt' #file containing the weights of the best-model

In [None]:
# Train a fresh ViT (no pretrained weights) on the halo maps; train_model is defined outside this section.
model, results = train_model(
    model_kwargs=model_kwargs,
    lr=lr, wd=wd, beta1=beta1, beta2=beta2
)
print("ViT results", results)
model.to(device)

In [None]:
from train_val_test_boilerplate import test

# The halo dataset was regenerated above with different create_data.py options (--log_1_plus,
# no --bias / --precomputed_* statistics). That run presumably rewrites the halos_dataset_*.npy
# statistics, just as the earlier run produced the files loaded after it. Reload them so this
# evaluation and the post-test analysis use the statistics of *this* dataset, not the stale ones
# loaded for the fine-tuning dataset.
prefix = 'halos'
MEAN = np.load(f'{prefix}_dataset_mean.npy')
STD = np.load(f'{prefix}_dataset_std.npy')
MIN_VALS = np.load(f'{prefix}_dataset_min_vals.npy')
MAX_VALS = np.load(f'{prefix}_dataset_max_vals.npy')
MEAN_DENSITIES = np.load(f'{prefix}_dataset_mean_densities.npy')

# Below values calculated during data preparation (see the reload above).
minimum = MIN_VALS
maximum = MAX_VALS

params_true, params_NN, errors_NN, filenames = test(model, test_loader, g=g, h=h, device=device, minimum=minimum, maximum=maximum)

In [None]:
from evaluation_analysis import post_test_analysis

# Layers whose activations are captured for the CKA analysis: 'linear.1' (labelled 'GELU') in
# each of the six transformer blocks, in block order 0..5.
gelu_layers = {f'transformer.{block}.linear.1': 'GELU' for block in range(6)}

post_test_analysis(
    params_true, params_NN, errors_NN, filenames,
    test_loader, params, num_sims, MEAN, STD, MEAN_DENSITIES, minimum, maximum,
    num_maps_per_projection_direction, None, dr, channels, fmodel,
    device=device,
    test_results_filename='test_results_transfer_learning_from_scratch_ViT.csv',
    cka_filename='cka_matrix_transfer_learning_from_scratch_halo_ViT_grid64_test.png',
    smallest_sim_number=0,
    model=model.model,
    return_layers=gelu_layers,
)