### Tips for training a ViT:

- https://wandb.ai/dtamkus/posts/reports/5-Tips-for-Creating-Lightweight-Vision-Transformers--Vmlldzo0MjQyMzg0
- Also see the Keras tutorial: https://keras.io/examples/vision/image_classification_with_vision_transformer/
- Improving performance of ViTs on small datasets: https://keras.io/examples/vision/vit_small_ds/
- What if we use a pretrained ViT from Imagenet, for example? See [here](https://github.com/hananshafi/vits-for-small-scale-datasets):
> ...in contrast to convolutional neural networks, Vision Transformer lacks inherent inductive biases. Therefore, successful training of such models is mainly attributed to pre-training on large-scale datasets such as ImageNet with 1.2M or JFT with 300M images.
- The self-attention mechanism allows to learn the relationship between different patches in the input sequence. From [here](https://arxiv.org/pdf/2304.08192.pdf).
- ViTs: permutation-equivariant, not translation invariant, not feature heirarchy. See [the tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html).


From Lightning documentation
- See https://lightning.ai/docs/pytorch/2.1.3/notebooks/course_UvA-DL/11-vision-transformer.html
- Advanced training tips: https://lightning.ai/docs/pytorch/2.1.3/advanced/training_tricks.html

Papers:
- https://openreview.net/pdf?id=4nPswr1KcP
- https://arxiv.org/pdf/2203.09795.pdf
- https://arxiv.org/pdf/2304.08192.pdf

In [1]:
!pip install Pylians

Collecting Pylians
  Downloading Pylians-0.11.tar.gz (3.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Installing backend dependencies ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting hdf5plugin (from Pylians)
  Using cached hdf5plugin-4.4.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.9 kB)
Collecting Cython<3.0.0 (from Pylians)
  Using cached Cython-0.29.37-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl.metadata (3.1 kB)
Collecting pyfftw (from Pylians)
  Downloading pyFFTW-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.6 kB)
Using cached Cython-0.29.37-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (1.9 MB)
Using cac

In [2]:
!pip install torch_intermediate_layer_getter

Collecting torch_intermediate_layer_getter
  Downloading torch_intermediate_layer_getter-0.1.post1.tar.gz (3.0 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: torch_intermediate_layer_getter
  Building wheel for torch_intermediate_layer_getter (setup.py) ... [?25ldone
[?25h  Created wheel for torch_intermediate_layer_getter: filename=torch_intermediate_layer_getter-0.1.post1-py3-none-any.whl size=3698 sha256=0d797647894968156ba8d9c3f5eb3cb2d0e66e0f5e121c164d68d870c726a334
  Stored in directory: /root/.cache/pip/wheels/6a/11/c0/30d81aa26172d10d68ffaf352b0762eb9fe0a5f5dcf3de63e0
Successfully built torch_intermediate_layer_getter
Installing collected packages: torch_intermediate_layer_getter
Successfully installed torch_intermediate_layer_getter-0.1.post1


In [4]:
USE_COLAB = False
if USE_COLAB:
    base_dir = '/content'
else:
    base_dir = '/kaggle/working'

In [4]:
if USE_COLAB:
    !rm -rf /content/ViT-LSS
else:
    !rm -rf /kaggle/working/ViT-LSS
!git clone https://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@github.com/Yash-10/ViT-LSS.git

Cloning into 'ViT-LSS'...
remote: Enumerating objects: 587, done.[K
remote: Counting objects: 100% (228/228), done.[K
remote: Compressing objects: 100% (131/131), done.[K
remote: Total 587 (delta 130), reused 183 (delta 97), pack-reused 359[K
Receiving objects: 100% (587/587), 54.32 MiB | 34.96 MiB/s, done.
Resolving deltas: 100% (320/320), done.


In [5]:
if USE_COLAB:
    !cp /content/ViT-LSS/scripts/*.py /content
else:
    !cp /kaggle/working/ViT-LSS/scripts/*.py /kaggle/working

In [6]:
!pip install pytorch-lightning==2.1.3



In [2]:
## Standard libraries
import os
import numpy as np
import random
import math
import json
from functools import partial
from PIL import Image

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## tqdm for loading bars
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

USE_TENSORBOARD = False
if USE_TENSORBOARD:
    # Import tensorboard
    %load_ext tensorboard

# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "./saved_models"
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# Setting the seed
SEED = 42
pl.seed_everything(SEED)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

  set_matplotlib_formats('svg', 'pdf') # For export


Device: cuda:0


<Figure size 640x480 with 0 Axes>

In [5]:
import numpy as np
import gzip
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
if USE_COLAB:
    from google.colab import drive
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import time, sys, os
import matplotlib.pyplot as plt

# optimizer parameters
beta1 = 0.5
beta2 = 0.999

batch_size = 64
lr         = 1e-3
wd         = 1e-5  # value of weight decay
dr         = 0.2
epochs     = 100    # number of epochs to train the network

channels        = 1                #we only consider here 1 field
params          = [0,1,2,3,4]    #Omega_m, Omega_b, h, n_s, sigma_8. The code will be trained to predict all these parameters.
g               = params           #g will contain the mean of the posterior
h               = [5+i for i in g] #h will contain the variance of the posterior

image_size = 64
patch_size = 16

model_kwargs = {
    'embed_dim': 256,
    'hidden_dim': 512,
    'num_heads': 8,
    'num_layers': 6,
    'patch_size': patch_size,
    'num_channels': 1,
    'num_patches': (image_size // patch_size) ** 2,
    'num_classes': 10,
    'dropout': dr
}

GRID_SIZE = 64

num_maps_per_projection_direction = 64
num_sims = 1000

In [9]:
import seaborn as sns
sns.set_style('whitegrid')
sns.set(style='ticks')
sns.set_context("paper", font_scale = 2)

## Visualization of fields for different parameters

In [10]:
# from utils import read_hdf5
# import glob

# o, s = [], []
# sorted_files = sorted(glob.glob('/kaggle/input/density-fields-vit-lss-64/my_outputs/*.h5'))
# for f in sorted_files:
#     _, params = read_hdf5(f)
#     o.append(params[0])
#     s.append(params[4])  # 4 or -1 will give the same.

# den_max_om, den_max_om_params = read_hdf5(sorted_files[o.index(max(o))])
# den_min_om, den_min_om_params = read_hdf5(sorted_files[o.index(min(o))])
# den_max_s8, den_max_s8_params = read_hdf5(sorted_files[s.index(max(s))])
# den_min_s8, den_min_s8_params = read_hdf5(sorted_files[s.index(min(s))])

# # Showing the same (random) slice from all four cases.
# vmin = np.log10(min(den_max_om[:, 10, :].min(), den_min_om[:, 10, :].min(), den_max_s8[:, 10, :].min(), den_min_s8[:, 10, :].min()))
# vmax = np.log10(max(den_max_om[:, 10, :].max(), den_min_om[:, 10, :].max(), den_max_s8[:, 10, :].max(), den_min_s8[:, 10, :].max()))
# plt.imshow(np.log10(den_max_om[:, 10, :]), vmin=vmin, vmax=vmax); plt.title(den_max_om_params); plt.colorbar(); plt.show()
# plt.imshow(np.log10(den_min_om[:, 10, :]), vmin=vmin, vmax=vmax); plt.title(den_min_om_params); plt.colorbar(); plt.show()
# plt.imshow(np.log10(den_max_s8[:, 10, :]), vmin=vmin, vmax=vmax); plt.title(den_max_s8_params); plt.colorbar(); plt.show()
# plt.imshow(np.log10(den_min_s8[:, 10, :]), vmin=vmin, vmax=vmax); plt.title(den_min_s8_params); plt.colorbar(); plt.show()

# Pretraining

In [11]:
if USE_COLAB:
    !wget https://www.dropbox.com/scl/fi/jqyvpxl17hp7pinqtd68c/density_fields_3D_LH_z0_grid64_masCIC.tar.gz?rlkey=cvf3oxbd922xxrzv8tue0zoiv&dl=0

In [12]:
if USE_COLAB:
    !tar -xzf /content/density_fields_3D_LH_z0_grid64_masCIC.tar.gz?rlkey=cvf3oxbd922xxrzv8tue0zoiv
else:
    pass

In [7]:
import subprocess

prefix = ''
command = [
    'python', 'create_data.py', '--num_sims', f'{num_sims}', '--train_frac', '0.8', '--test_frac', '0.1',
    '--seed', f'{SEED}', '--path', '/content/my_outputs' if USE_COLAB else '/kaggle/input/density-fields-vit-lss-64/my_outputs', '--grid_size', f'{GRID_SIZE}',
    '--num_maps_per_projection_direction', f'{num_maps_per_projection_direction}', '--prefix', '',
    '--smallest_sim_number', '0'
]
result = subprocess.run(command)
result

In [8]:
# Store the mean, std, min_vals and max_vals into variables
MEAN = np.load(f'{prefix}_dataset_mean.npy')
STD = np.load(f'{prefix}_dataset_std.npy')
MIN_VALS = np.load(f'{prefix}_dataset_min_vals.npy')
MAX_VALS = np.load(f'{prefix}_dataset_max_vals.npy')
MEAN_DENSITIES = np.load(f'{prefix}_dataset_mean_densities.npy')
print(MEAN, STD, MIN_VALS, MAX_VALS)

-0.054211672 0.212188 [0.1003  0.03003 0.5003  0.8001  0.6001 ] [0.4997  0.06993 0.8999  1.1999  0.9985 ]


In [9]:
!ls train | wc -l
!ls val | wc -l
!ls test | wc -l

153602
19202
19202


In [10]:
from model_dataset import CustomImageDataset
from torchvision.transforms import v2

import torchvision.transforms.functional as TF

import random
# See https://pytorch.org/vision/0.15/transforms.html
class MyRotationTransform:
    """Rotate by 90/180/270 degrees."""

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        angle = random.choice(self.angles)
        x = torch.from_numpy(x).unsqueeze(0)
        return TF.rotate(x, angle).squeeze()

In [67]:
transform = v2.Compose([
    MyRotationTransform(angles=[90, 180, 270]),
    v2.ToDtype(torch.float32),#, scale=False),
])

train_dataset = CustomImageDataset(f'{base_dir}/train', normalized_cosmo_params_path=f'{base_dir}/train/train_normalized_params.csv', transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)

val_dataset = CustomImageDataset(f'{base_dir}/val', normalized_cosmo_params_path=f'{base_dir}/val/val_normalized_params.csv', transform=None)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

test_dataset = CustomImageDataset(f'{base_dir}/test', normalized_cosmo_params_path=f'{base_dir}/test/test_normalized_params.csv', transform=None)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=3)
# num_workers=3 was suggested by PyTorch Lightning while running the `train_model` function on Kaggle.

In [68]:
train_dataset = torch.utils.data.Subset(train_dataset, indices=[100, 200, 300])  # indices=list(range(2))
train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=False, num_workers=3)
len(train_dataset)

3

In [34]:
for x, y, _ in train_loader:
    print(y * (MAX_VALS-MIN_VALS) + MIN_VALS)

tensor([[0.1755, 0.0668, 0.7737, 0.8849, 0.6641],
        [0.4149, 0.0454, 0.5761, 0.9911, 0.6283]])


TODO: Sometime in future, we may want to separate these functions and modules into a python script and just load from there. Currently, it seemed a bit uncomfortable to do that since these modules/functions/classes depend on global variables defined only in the notebook. Although, this is not necessarily a problem, it may become difficult to understand the modules/functions/classes if it's not in the notebook.

In [14]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Inputs:
        x - torch.Tensor representing the image of shape [B, C, H, W]
        patch_size - Number of pixels per dimension of the patches (integer)
        flatten_channels - If True, the patches will be returned in a flattened format
                           as a feature vector instead of a image grid.
    """
    B, C, H, W = x.shape
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5) # [B, H', W', C, p_H, p_W]
    x = x.flatten(1,2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2,4)          # [B, H'*W', C*p_H*p_W]
    return x

# x = next(iter(train_loader))
# print(x[0].shape)

# # See original code in the original tutorial: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html
# img_patches = img_to_patch(x[0], patch_size=16, flatten_channels=False)

# fig, ax = plt.subplots(1, 1, figsize=(10,10))
# fig.suptitle("Images as input sequences of patches")
# for i in range(x[0].shape[0]):
#     img_grid = torchvision.utils.make_grid(img_patches[i], nrow=4, normalize=False, pad_value=0.9)
#     img_grid = img_grid.permute(1, 2, 0)
#     ax.imshow(img_grid)
#     ax.axis('off')
#     break

# plt.tight_layout()
# plt.show()
# plt.close()

TODO: Remember that the positional embedding (`self.pos_embedding`) below is learnable assuming we use fixed-resolution images. In the future, we would like to transfer learn to different resolution images, in which case we need to use the sine and cosine functions proposed in the original ViT paper.

In [15]:
from utils import get_rmse_score

2024-03-20 07:01:35.160034: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-20 07:01:35.160144: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-20 07:01:35.356009: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [16]:
class AttentionBlock(nn.Module):

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of input and attention feature vectors
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply in the feed-forward network
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
#             nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
#             nn.Dropout(dropout)
        )


    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x


class VisionTransformer(nn.Module):

    def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers, num_classes, patch_size, num_patches, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels*(patch_size**2), embed_dim)
        self.transformer = nn.Sequential(*[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers)])

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

#         self.mlp_head = nn.Sequential(
#             nn.LayerNorm(embed_dim),
#             nn.Linear(embed_dim, embed_dim//2),
#             nn.GELU(),
#             nn.Dropout(p=0.1),
#             nn.Linear(embed_dim//2, embed_dim//4),
#             nn.GELU(),
#             nn.Dropout(p=0.1),
#             nn.Linear(embed_dim//4, num_classes)
#         )
#         self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1,1,embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1,1+num_patches,embed_dim))


    def forward(self, x):
        # Preprocess input
        x = img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:,:T+1]

        # Apply Transforrmer
#         x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)

        # Perform regression prediction
        cls = x[0]
        out = self.mlp_head(cls)
        
        # enforce the errors to be positive
        y = torch.clone(out)
        y[:,5:10] = torch.square(out[:,5:10])

        return y


class ViT(pl.LightningModule):

    def __init__(self, model_kwargs, lr, wd, beta1, beta2, minimum, maximum):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)
        self.example_input_array = next(iter(train_loader))[0]

        self.maximum = maximum
        self.minimum = minimum

    def forward(self, x):
        # NOTE: See https://lightning.ai/docs/pytorch/2.1.3/starter/style_guide.html#forward-vs-training-step
        # forward is recommended to be used for prediction/inference, whereas for actual training, training_step is recommended.
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.wd, betas=(self.hparams.beta1, self.hparams.beta2))
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.3, patience=5)
#         return [optimizer], [lr_scheduler]

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "monitor": "val_loss"
            },
        }

    def _calculate_loss(self, batch, mode="train"):
        x, y, _ = batch
        p = self.model(x)
        y_NN = p[:,g]             #posterior mean
        e_NN = p[:,h]             #posterior std

        y_NN = y_NN[:, [0, 4]]
        e_NN = e_NN[:, [0, 4]]
        y = y[:, [0, 4]]

        loss1 = torch.mean((y_NN - y)**2,                axis=0)
        loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
        loss  = torch.mean(torch.log(loss1) + torch.log(loss2))
        # NOTE: See logging for more information: https://lightning.ai/docs/pytorch/2.1.3/extensions/logging.html
        # Not sure if the below logic is even needed, but should be fine.
        if mode == "train" or mode == 'val':
            # To match the CNN training code, we need to log the train and val
            # loss after each batch, and also after each epoch.
            self.log(f'{mode}_loss', loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        elif mode == 'test':
            # For testing, logging the loss after each step is not required.
            # So we only log after the epoch. For testing, there will be one epoch only.
            self.log(f'{mode}_loss', loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size)

        # TODO: Where are the logs/logged values stored? Need to find and print manually at end of training outside this function.

        if mode == 'val' or mode == 'train':
            # Untransform the parameters for the sake of calculating RMSE and sigma_bar.
            # `minimum` and `maximum` must be defined globally.
            y = y.cpu().detach().numpy() * (self.maximum - self.minimum) + self.minimum
            y_NN   = y_NN.cpu().detach().numpy() * (self.maximum - self.minimum) + self.minimum
            e_NN   = e_NN.cpu().detach().numpy() * (self.maximum - self.minimum)

            # Also log RMSE and sigma_bar for all parameters.
            rmse = get_rmse_score(y, y_NN)
            sigma_bar = np.mean(y_NN, axis=0)
            # Only log at the end of epoch instead of each step.
            # Logging is only done for Omega_m and sigma_8 since only these are interesting for DM density/DM halo fields.
            # But more can easily be added here if and when needed.
            metrics_to_log = {
                f'{mode}_omegam_rmse': rmse[0],
                f'{mode}_omegam_sigma_bar': sigma_bar[0],
                f'{mode}_sigma8_rmse': rmse[-1],
                f'{mode}_sigma8_sigma_bar': sigma_bar[-1]
            }
            self.log_dict(metrics_to_log, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size)

        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

In [23]:
import wandb
import os
YOUR_KEY = '2694f9864be3a11448f361be00fce534845723b9'
os.environ['WAND_AUTH_LOGIN_KEY'] = YOUR_KEY
wandb.login(key=os.environ.get('WAND_AUTH_LOGIN_KEY'))

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [24]:
from pytorch_lightning.loggers import WandbLogger
logger_csv = pl.loggers.CSVLogger(CHECKPOINT_PATH, name="lightning_logs_csv")
#WANDB_RUN_NAME = f'ViT_DiffSimsNomasscut-batchsize-{batch_size}_lr-{lr}_epochs-{epochs}_wd-{wd}_dr-{model_kwargs["dropout"]}'
WANDB_RUN_NAME = f'Vit-overfit2examples11'
wandb_logger = WandbLogger(name=WANDB_RUN_NAME, project='Cosmo-parameter-inference')
wandb_logger.experiment.config.update({"batch_size": batch_size, "epochs": epochs})

def train_model(**kwargs):
    # See https://lightning.ai/docs/pytorch/2.1.3/common/trainer.html#reproducibility
    pl.seed_everything(SEED, workers=True) # To be reproducible
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=epochs,
                         logger=[logger_csv, wandb_logger],
                         log_every_n_steps=int(len(train_dataset)/batch_size),
#                          progress_bar_refresh_rate=50,  # recommended for Kaggle/Colab here: https://www.youtube.com/watch?v=-XakoRiMYCg
#                          callbacks=[
#                              ModelCheckpoint(save_weights_only=True, mode="min", monitor="val_loss"),
#                              LearningRateMonitor("epoch")
#                          ],
#                          deterministic=True
                        )
#     trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
#     trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = ViT.load_from_checkpoint(pretrained_filename) # Automatically loads the model with the saved hyperparameters
    else:
        model = ViT(**kwargs)
        trainer.fit(model, train_loader, val_loader)
        # Since we use every_n_epochs=None, by default the model is checkpointed after each epoch and the best is selected.
        model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_loss"], "val": val_result[0]["test_loss"]}

    return model, result, trainer.checkpoint_callback.best_model_path

[34m[1mwandb[0m: Currently logged in as: [33myash10[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [32]:
!pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [33]:
from torchsummary import summary

def load_model_for_torch_summary(**kwargs):
    model_torch_summary = ViT(**kwargs)
    return model_torch_summary

model_torch_summary = load_model_for_torch_summary(model_kwargs=model_kwargs, lr=lr, wd=wd, beta1=beta1, beta2=beta2, minimum=MIN_VALS, maximum=MAX_VALS)
summary(model_torch_summary.to(device), (1, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1              [-1, 16, 256]          65,792
         LayerNorm-2               [-1, 2, 256]             512
MultiheadAttention-3  [[-1, 2, 256], [-1, 17, 17]]               0
         LayerNorm-4               [-1, 2, 256]             512
            Linear-5               [-1, 2, 512]         131,584
              GELU-6               [-1, 2, 512]               0
            Linear-7               [-1, 2, 256]         131,328
    AttentionBlock-8               [-1, 2, 256]               0
         LayerNorm-9               [-1, 2, 256]             512
MultiheadAttention-10  [[-1, 2, 256], [-1, 17, 17]]               0
        LayerNorm-11               [-1, 2, 256]             512
           Linear-12               [-1, 2, 512]         131,584
             GELU-13               [-1, 2, 512]               0
           Linear-14            

In [25]:
epochs=30

In [26]:
MIN_VALS=MIN_VALS[[0, 4]]
MAX_VALS=MAX_VALS[[0, 4]]
model, results, PRETRAINED_FILENAME = train_model(
    model_kwargs=model_kwargs,
    lr=1e-6, wd=wd, beta1=beta1, beta2=beta2,
    minimum=MIN_VALS, maximum=MAX_VALS
)
print("ViT results", results)
model.to(device)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

ViT results {'test': -3.6930603981018066, 'val': -4.03756046295166}


ViT(
  (model): VisionTransformer(
    (input_layer): Linear(in_features=256, out_features=256, bias=True)
    (transformer): Sequential(
      (0): AttentionBlock(
        (layer_norm_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (layer_norm_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (linear): Sequential(
          (0): Linear(in_features=256, out_features=512, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=512, out_features=256, bias=True)
        )
      )
      (1): AttentionBlock(
        (layer_norm_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (layer_norm_2): LayerNorm((256,), eps=1e-05, eleme

In [72]:
########### Manual training
def train(
        model, train_loader, epochs, optimizer, scheduler,
        fmodel='weights.pt', floss='loss.txt', g=[0,1,2,3,4], h=[5,6,7,8,9], device='cpu',
        minimum=None, maximum=None
):
    # do a loop over all epochs
    start = time.time()
    for epoch in range(epochs):
        # do training
#         train_loss1, train_loss2 = torch.zeros(len(g)).to(device), torch.zeros(len(g)).to(device)
        train_loss1, train_loss2 = torch.zeros(1).to(device), torch.zeros(1).to(device)
        train_loss, points = 0.0, 0
        model.train()
        for x, y, _ in train_loader:
            bs   = x.shape[0]         #batch size
            x    = x.to(device)       #maps
            y    = y.to(device)[:,g]  #parameters
            p    = model(x)           #NN output
            y_NN = p[:,g]             #posterior mean
            e_NN = p[:,h]             #posterior std

            y_NN = y_NN[:, [0]]
            e_NN = e_NN[:, [0]]
            y = y[:, [0]]

#             print(y_NN, y)

            loss1 = torch.mean((y_NN - y)**2,                axis=0)
#             loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
#             loss  = torch.mean(torch.log(loss1) + torch.log(loss2))
            loss = torch.mean(torch.log(loss1))
            train_loss1 += loss1*bs
#             train_loss2 += loss2*bs
            points      += bs
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #if points>18000:  break
#         train_loss = torch.log(train_loss1/points) + torch.log(train_loss2/points)
            train_loss = torch.log(train_loss1/points)
        train_loss = torch.mean(train_loss).item()
        print(f'train_loss: {train_loss}')

    stop = time.time()
    print('Time take (h):', "{:.4f}".format((stop-start)/3600.0))

    return model

class model_o3_err(nn.Module): # TODO: This is not used in the notebooks currently.
    def __init__(self, hidden, dr, channels):
        super(model_o3_err, self).__init__()

        # input: 1x64x64 ---------------> output: 2*hiddenx32x32  # These dimensions are written assuming 64^3 density field.
        self.C01 = nn.Conv2d(channels,  2*hidden, kernel_size=3, stride=2, padding=1,
                            padding_mode='circular', bias=True)
#         self.C02 = nn.Conv2d(2*hidden,  2*hidden, kernel_size=3, stride=1, padding=1,
#                             padding_mode='circular', bias=True)
#         self.C03 = nn.Conv2d(2*hidden,  2*hidden, kernel_size=2, stride=2, padding=0,
#                             padding_mode='circular', bias=True)
        self.B01 = nn.BatchNorm2d(2*hidden)
#         self.B02 = nn.BatchNorm2d(2*hidden)
#         self.B03 = nn.BatchNorm2d(2*hidden)

        # input: 2*hiddenx32x32 ----------> output: 4*hiddenx16x16
        self.C11 = nn.Conv2d(2*hidden, 4*hidden, kernel_size=3, stride=2, padding=1,
                            padding_mode='circular', bias=True)
#         self.C12 = nn.Conv2d(4*hidden, 4*hidden, kernel_size=3, stride=1, padding=1,
#                             padding_mode='circular', bias=True)
#         self.C13 = nn.Conv2d(4*hidden, 4*hidden, kernel_size=2, stride=2, padding=0,
#                             padding_mode='circular', bias=True)
        self.B11 = nn.BatchNorm2d(4*hidden)
#         self.B12 = nn.BatchNorm2d(4*hidden)
#         self.B13 = nn.BatchNorm2d(4*hidden)

        # input: 4*hiddenx16x16 --------> output: 8*hiddenx8x8
        self.C21 = nn.Conv2d(4*hidden, 8*hidden, kernel_size=3, stride=2, padding=1,
                            padding_mode='circular', bias=True)
#         self.C22 = nn.Conv2d(8*hidden, 8*hidden, kernel_size=3, stride=1, padding=1,
#                             padding_mode='circular', bias=True)
#         self.C23 = nn.Conv2d(8*hidden, 8*hidden, kernel_size=2, stride=2, padding=0,
#                             padding_mode='circular', bias=True)
        self.B21 = nn.BatchNorm2d(8*hidden)
#         self.B22 = nn.BatchNorm2d(8*hidden)
#         self.B23 = nn.BatchNorm2d(8*hidden)

        # input: 8*hiddenx8x8 ----------> output: 16*hiddenx4x4
        self.C31 = nn.Conv2d(8*hidden,  16*hidden, kernel_size=3, stride=2, padding=1,
                            padding_mode='circular', bias=True)
#         self.C32 = nn.Conv2d(16*hidden, 16*hidden, kernel_size=3, stride=1, padding=1,
#                             padding_mode='circular', bias=True)
#         self.C33 = nn.Conv2d(16*hidden, 16*hidden, kernel_size=2, stride=2, padding=0,
#                             padding_mode='circular', bias=True)
        self.B31 = nn.BatchNorm2d(16*hidden)
#         self.B32 = nn.BatchNorm2d(16*hidden)
#         self.B33 = nn.BatchNorm2d(16*hidden)

        # input: 16*hiddenx4x4 ----------> output: 32*hiddenx1x1
        self.C41 = nn.Conv2d(16*hidden, 32*hidden, kernel_size=4, stride=2, padding=0,
                            padding_mode='circular', bias=True)

        self.B41 = nn.BatchNorm2d(32*hidden)

        self.P0  = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

#         self.FC1  = nn.Linear(32*hidden, 16*hidden)
#         self.FC2  = nn.Linear(16*hidden, 10)

#         self.dropout   = nn.Dropout(p=dr)
        self.ReLU      = nn.ReLU()
        self.LeakyReLU = nn.LeakyReLU(0.2)
        self.tanh      = nn.Tanh()

        self.mlp_head = nn.Sequential(
                nn.Linear(32*hidden, 16*hidden),
                self.LeakyReLU,
#                 self.dropout,
                nn.Linear(16*hidden, 1)
        )
        self.mlp_head = nn.Sequential(
                nn.Linear(32*hidden, 1),
                # self.LeakyReLU,
                # self.dropout,
                # nn.Linear(16*hidden, 10)
        )

        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)


    def forward(self, image):

        x = self.LeakyReLU(self.C01(image))
#         x = self.LeakyReLU(self.B02(self.C02(x)))
#         x = self.LeakyReLU(self.B03(self.C03(x)))

        x = self.LeakyReLU(self.B11(self.C11(x)))
#         x = self.LeakyReLU(self.B12(self.C12(x)))
#         x = self.LeakyReLU(self.B13(self.C13(x)))

        x = self.LeakyReLU(self.B21(self.C21(x)))
#         x = self.LeakyReLU(self.B22(self.C22(x)))
#         x = self.LeakyReLU(self.B23(self.C23(x)))

        x = self.LeakyReLU(self.B31(self.C31(x)))
#         x = self.LeakyReLU(self.B32(self.C32(x)))
#         x = self.LeakyReLU(self.B33(self.C33(x)))

        x = self.LeakyReLU(self.B41(self.C41(x)))

        x = x.view(image.shape[0], -1)
#         x = self.dropout(x)

        # The MLP head implements the two commented lines below.
        x = self.mlp_head(x)
#         x = self.dropout(self.LeakyReLU(self.FC1(x)))
#         x = self.FC2(x)

        # enforce the errors to be positive
#         y = torch.clone(x)
#         y[:,5:10] = torch.square(x[:,5:10])

        return x

# model = model_o3_err(4, 0, 1)
model = VisionTransformer(
    model_kwargs['embed_dim'], model_kwargs['hidden_dim'], 1, 16,
    12, 10, 4, 256, dropout=model_kwargs['dropout']
)
# (embed_dim, hidden_dim, num_channels, num_heads, num_layers, num_classes, patch_size, num_patches, dropout=0.0)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6, weight_decay=wd, betas=(beta1, beta2))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.3, patience=5)

train(model, train_loader, epochs=100, optimizer=optimizer, scheduler=scheduler, minimum=MIN_VALS, maximum=MAX_VALS)

train_loss: -0.8501911759376526
train_loss: -1.1425760984420776
train_loss: -1.3234699964523315
train_loss: -1.5420111417770386
train_loss: -1.906543254852295
train_loss: -2.2279205322265625
train_loss: -2.901834726333618
train_loss: -3.7917938232421875
train_loss: -5.895630836486816
train_loss: -6.683720111846924
train_loss: -6.33296012878418
train_loss: -4.864938735961914
train_loss: -5.941388130187988
train_loss: -6.695052146911621
train_loss: -5.217763423919678
train_loss: -5.020942687988281
train_loss: -5.788515090942383
train_loss: -7.324041843414307
train_loss: -7.627298831939697
train_loss: -6.787846565246582
train_loss: -7.730579376220703
train_loss: -7.345180034637451
train_loss: -6.281937599182129
train_loss: -6.148342132568359
train_loss: -7.48093843460083
train_loss: -4.893425941467285
train_loss: -4.188638687133789
train_loss: -4.027434825897217
train_loss: -4.3242411613464355
train_loss: -4.126983165740967
train_loss: -4.219609260559082
train_loss: -4.566134929656982
tra

KeyboardInterrupt: 

In [54]:
model_kwargs

{'embed_dim': 256,
 'hidden_dim': 512,
 'num_heads': 8,
 'num_layers': 6,
 'patch_size': 16,
 'num_channels': 1,
 'num_patches': 16,
 'num_classes': 10,
 'dropout': 0.2}

In [70]:
for x, y, _ in train_loader:
    bs   = x.shape[0]         #batch size
#     x    = x.to(device)       #maps
#     y    = y.to(device)[:,g]  #parameters
    p    = model(x)           #NN output
    y_NN = p[:,g]             #posterior mean
    e_NN = p[:,h]             #posterior std

    y_NN = y_NN[:, [0, 4]]
    e_NN = e_NN[:, [0, 4]]
    y = y[:, [0, 4]]

    print(y)
    print(y_NN)

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 128, 1, 1])

In [None]:
from train_val_test_boilerplate import test

# Below values calculated during data preparation. See above.
minimum = MIN_VALS
maximum = MAX_VALS

params_true, params_NN, errors_NN, filenames = test(model, test_loader, g=g, h=h, device=device, minimum=minimum, maximum=maximum)

In [None]:
if USE_TENSORBOARD:
    # Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH!
    # The exact command is taken from https://lightning.ai/docs/pytorch/stable/visualize/logging_basic.html
    %reload_ext tensorboard
    %tensorboard --logdir ./saved_models/tensorboards/

In [None]:
from evaluation_analysis import post_test_analysis, get_cka

post_test_analysis(
    params_true, params_NN, errors_NN, filenames,
    params, num_sims, MEAN, STD, MEAN_DENSITIES, minimum, maximum,
    num_maps_per_projection_direction, test_results_filename='test_results.csv',
    smallest_sim_number=0
)

get_cka(
    model.model, test_loader,
    return_layers = {
        'transformer.0.linear.1': 'GELU',
        'transformer.1.linear.1': 'GELU',
        'transformer.2.linear.1': 'GELU',
        'transformer.3.linear.1': 'GELU',
        'transformer.4.linear.1': 'GELU',
        'transformer.5.linear.1': 'GELU'
    },
    cka_filename='cka_matrix_pretrained_ViT_grid64_test.png',
    device=device
)

## Interpreting the ViT

See [here](https://github.com/jacobgil/pytorch-grad-cam/issues/140) for more information.

TODO: Grad-CAM section below is commented since the code didn't work. Try to make it work.

In [None]:
# model.model.transformer[-1].layer_norm_1

In [None]:
# # from grad_cam_interpret import GradCAMRegressor
# import torch.nn.functional as F

# images, labels, _ = next(iter(test_loader))
# images = images.to(device)
# labels = labels.to(device)

# # Example usage:
# # Assuming 'model' is your regression model and 'target_layer' is the layer you want to visualize
# for img_idx in [5, 7, 10, 13, 15]:
#     # Only select one image
#     images_ = images[img_idx, :].unsqueeze(0)
#     print(images_.shape)
#     for index in [0, 4]:
#         if index == 0:
#             param = r'$\Omega_m$'
#         elif index == 4:
#             param = r'$\sigma_8$'

#         gradcam_regressor = GradCAMRegressor(model, target_layer=model.model.transformer[-1].layer_norm_1, ground_truth_param_value=labels[:, index][img_idx], index=index)

#         # Visualize Grad-CAM for the entire image in a regression context
#         gradcam = gradcam_regressor.generate_gradcam(images_)

#         # Remove hooks after usage
#         gradcam_regressor.remove_hooks()

#         fig, ax = plt.subplots(2, 3, figsize=(15, 7))
#         show_image = images_.squeeze().cpu().detach().numpy()
#         ax[0,0].imshow(show_image)
#         ax[0,1].imshow(gradcam.cpu().detach().numpy())

#         gradcam = gradcam.unsqueeze(0).unsqueeze(0)
#         gradcam = gradcam.squeeze().reshape(16,16).unsqueeze(0).unsqueeze(0)

#         # Resize Grad-CAM to match the input image size
#         gradcamI = F.interpolate(gradcam, size=show_image.shape, mode='bilinear', align_corners=False)
#         # Convert to numpy array for visualization
#         gradcamI = gradcamI.squeeze().cpu().detach().numpy()
#         # Normalize for visualization
#         gradcamI = (gradcamI - np.min(gradcamI)) / (np.max(gradcamI) - np.min(gradcamI) + 1e-8)
#         original_image = images_[0].permute(1, 2, 0).detach().cpu().numpy()
#         original_image = (original_image - np.min(original_image)) / (np.max(original_image) - np.min(original_image) + 1e-8)
#         ax[0,2].imshow(original_image)
#         # Overlay Grad-CAM on the original image
#         ax[0,2].imshow(gradcamI, cmap='jet', alpha=0.3, interpolation='bilinear')
#         ax[0,0].set_title(param)

#         # Plot normalized density and gradcam value.
#         den_gradcam_pairs = []
#         original_image = original_image.squeeze()
#         gradcam = gradcam.squeeze().cpu().detach().numpy()
#         receptive_size = (original_image.shape[0]//gradcam.shape[0], original_image.shape[1]//gradcam.shape[1])
#         for i in range(0, original_image.shape[0], original_image.shape[0]//gradcam.shape[0]):
#             for j in range(0, original_image.shape[1], original_image.shape[1]//gradcam.shape[1]):
#                 mean_den = original_image[i:i+receptive_size[0], j:j+receptive_size[1]].mean()
#                 this_grad_cam = gradcam[i//receptive_size[0], j//receptive_size[1]]
#                 den_gradcam_pairs.append((mean_den, this_grad_cam))

#         ax[1,0].scatter([d[0] for d in den_gradcam_pairs], [d[1] for d in den_gradcam_pairs]);
#         ax[1,0].set_xlabel('Normalized density')
#         ax[1,0].set_ylabel('Normalized Grad-CAM')

#         sns.kdeplot(x=original_image.flatten(), y=gradcamI.flatten(), ax=ax[1,1], fill=True)
#         ax[1,1].set_xlabel('Normalized density')
#         ax[1,1].set_ylabel('Interpolated normalized Grad-CAM')

#         plt.show()

In [None]:
import shutil
shutil.rmtree(f'{base_dir}/train')
shutil.rmtree(f'{base_dir}/val')
shutil.rmtree(f'{base_dir}/test')

In [27]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
test_loss,▁█
train_loss_epoch,▅█▆▅▆▆▆▇▄▇▄▄▅▅▅▅▄▃▄▃▁▃▄▄▄▃▃▁▃▃
train_omegam_rmse,█▆▃▄▇▄▄▄▅▄▅▆▆▄▄▅▃▁▄▂▃▅▂▆▁▅▃▆▃▅
train_omegam_sigma_bar,▁▅▆▅▁▆▅▇▃▆▄▄▄▆▆▄▆█▇▅▇▅▆▄█▃▆▆▅▄
train_sigma8_rmse,▅▆▅█▄▄█▂▁▆▅▄▅▃▅▅▄▂▄▄▄▂▇▃▆▃▄▃▃▄
train_sigma8_sigma_bar,▄▃▄▁▄▅▁▇█▃▃▆▄▆▅▅▅▇▅▅▅▇▂▇▄▆▄▇█▆
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_loss_epoch,███▇▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▄▃▃▃▂▂▁▁
val_loss_step,▅▆▆▄▃▇▆▂▆█▄▄▂▅█▆▆▃▇▂▅▄▄▃▃▁▅▆▂▇▅▄▇▅▇▁▆▇▆▆

0,1
epoch,30.0
test_loss,-3.69306
train_loss_epoch,-5.35604
train_omegam_rmse,0.18727
train_omegam_sigma_bar,0.14391
train_sigma8_rmse,0.12657
train_sigma8_sigma_bar,0.52933
trainer/global_step,30.0
val_loss_epoch,-4.03756
val_loss_step,-6.82894


## Transfer learning

In [None]:
import seaborn as sns
sns.set_context("paper", font_scale = 2)
sns.set_style('whitegrid')
sns.set(style='ticks')

In [None]:
from utils import smooth_3D_field

In [None]:
# halo-distribution-vit-lss-64 --> Different simulations (1000-1999) than DM density (0-999)
# halo-vit-lss-64-same-sims-no-mass-cuts --> Same simulations (0-999) as DM density (0-999)
# halo-vit-lss-64-different-sims-1e14-mass-cut --> Different simulations (1000-1999) than DM density (0-999) with mass cut: removing halos with mass < 1e14 solar masses. This value was decided using the HMF plot.
halo_dirname = '/kaggle/input/halo-distribution-vit-lss-64'

In [None]:
SAME_SIMS = False  # Whether the halo and DM density simulations exactly match.

In [None]:
if SAME_SIMS:
    # Analysis of the bias parameter
    import h5py
    import os
    import glob
    import numpy as np
    from utils import read_hdf5
    import matplotlib.pyplot as plt

    from scipy.stats import iqr

    DEN_FIELD_DIRECTORY = 'my_outputs'

    if USE_COLAB:
        dpath = os.path.join('/content/', f'{DEN_FIELD_DIRECTORY}', '*.h5')
        hpath = os.path.join('/content/', f'{DEN_FIELD_DIRECTORY}_halo', '*.h5')
    else:
        dpath = os.path.join('/kaggle/input/density-fields-vit-lss-64/', f'{DEN_FIELD_DIRECTORY}', '*.h5')
        hpath = os.path.join(f'{halo_dirname}', f'{DEN_FIELD_DIRECTORY}_halo', '*.h5')

    all_bias_params = []
    dens = []
    halos = []
    for i, filename in enumerate(
        zip(
          sorted(glob.glob(dpath)),
          sorted(glob.glob(hpath))
        )
    ):
        # print(filename[0], filename[1])
        den, dparams = read_hdf5(filename[0], dataset_name='3D_density_field')
        halo, hparams = read_hdf5(filename[1], dataset_name='3D_halo_distribution')

        den = smooth_3D_field(den)
        halo = smooth_3D_field(halo)

        den_contrast = den/den.mean() - 1
        halo_contrast = halo/halo.mean() - 1

        dens.append(den_contrast)
        halos.append(halo_contrast)
    #     bias_params = (halo_contrast[np.where(den_contrast < 1)]/den_contrast[np.where(den_contrast < 1)])
        bias_params = (halo_contrast/den_contrast)

        # # Remove outliers.
        # bias_params = bias_params[(bias_params > np.quantile(bias_params, 0.01)) & (bias_params < np.quantile(bias_params, 0.99))]
        all_bias_params.append(np.median(bias_params))

    all_bias_params = np.array(all_bias_params)
    bias = np.mean(all_bias_params)

#     fig, ax = plt.subplots(1, 1, figsize=(7, 5))
#     ax.hist(all_bias_params.ravel(), bins=20)
#     ax.set_yscale('log')
#     ax.set_title(f'Bias: {bias:.2f} +/- {np.std(all_bias_params):.2f}')
#     plt.show()

#     # fig, ax = plt.subplots(1, 1, figsize=(7, 5))
#     # ax.scatter(dens, halos, alpha=0.6)
#     # plt.show()

    print(f'Bias: {bias}')

In [None]:
# if SAME_SIMS:
#     from utils import power_spectrum

#     # Analysis of the bias parameter
#     import h5py
#     import os
#     import glob
#     import numpy as np
#     from utils import read_hdf5
#     import matplotlib.pyplot as plt
#     import contextlib

#     DEN_FIELD_DIRECTORY = 'my_outputs'

#     if USE_COLAB:
#         dpath = os.path.join('/content/', f'{DEN_FIELD_DIRECTORY}', '*.h5')
#         hpath = os.path.join('/content/', f'{DEN_FIELD_DIRECTORY}_halo', '*.h5')
#     else:
#         dpath = os.path.join('/kaggle/input/density-fields-vit-lss-64/', f'{DEN_FIELD_DIRECTORY}', '*.h5')
#         hpath = os.path.join(f'{halo_dirname}', f'{DEN_FIELD_DIRECTORY}_halo', '*.h5')

#     Pk_dens = []
#     Pk_halos = []
#     for i, filename in enumerate(
#         zip(
#           sorted(glob.glob(dpath)),
#           sorted(glob.glob(hpath))
#         )
#     ):
#         # print(filename[0], filename[1])
#         den, dparams = read_hdf5(filename[0], dataset_name='3D_density_field')
#         halo, hparams = read_hdf5(filename[1], dataset_name='3D_halo_distribution')

#         with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):  # Prevent unnecessary verbose output from printing on screen.
#             k_den, Pk_den = power_spectrum(den, dimensional=3)
#             k_halo, Pk_halo = power_spectrum(halo, dimensional=3)

#         # Discard scales below the Nyquist frequency since they are unreliable.
#         L = 1000   # Mpc/h
#         Ng = 64
#         k_Nq = (2 * np.pi / L) * (Ng / 2)
#         condition_k = k_den <= k_Nq
#         k_den = k_den[condition_k]
#         k_halo = k_halo[condition_k]

#         # Remove shot noise from halo power spectra: Pk' = Pk - (1/n_bar)
#         # The field `halo` contains the effective no. of particles inside a 2D cell (n_bar).
#         # And Pk has units (Mpc/h)^2, so n_bar must be divided by the area of each cell.
#         n_bar = halo.mean()
#         n_bar = n_bar / ((L / Ng) ** 2)
#         Pk_halo = Pk_halo - (1 / n_bar)

#         # Select Pk powers corresponding to the refined set of wavenumbers.
#         Pk_den = Pk_den[condition_k]
#         Pk_halo = Pk_halo[condition_k]

#         Pk_dens.append(Pk_den)
#         Pk_halos.append(Pk_halo)

#     Pk_den = np.vstack(Pk_dens).mean(axis=0)
#     Pk_halo = np.vstack(Pk_halos).mean(axis=0)

#     assert np.all(k_den == k_halo)

#     # Most likely bias from the power spectrum.
#     # Following this paper: https://iopscience.iop.org/article/10.1088/0004-637X/724/2/878 >>> "We calculate b2 as the average over the 10 largest wavelength modes in the simulation".
#     most_likely_bias_from_ps = np.sqrt(np.mean(Pk_halo[:10]/Pk_den[:10]))

#     print(f'Most likely bias from the power spectrum: {most_likely_bias_from_ps}')
#     print(f'Bias obtained from the halo_contrast/DM_density_contrast analysis in the above cell: {bias}')

#     fig, ax = plt.subplots(2, 1, figsize=(6, 5))
#     fig.subplots_adjust(hspace=0)
#     ax[0].loglog(k_den, Pk_den, c='black', label='DM', linewidth=4)
#     ax[0].loglog(k_den, Pk_halo, c='black', linestyle='--', label='Halo', linewidth=4);
#     ax[0].legend();
#     ax[0].set_ylabel(r'$P(k)$')
#     ax[1].loglog(k_den, np.sqrt(Pk_halo/Pk_den), linewidth=4)
#     ax[1].set_ylabel(r'$\sqrt{P_{halo}(k)/P_{dm}(k)}$')
#     ax[1].set_xlabel(r'$k$')
#     ax[1].axhline(y=bias, linestyle='--', c='gray', linewidth=2)
#     ax[1].axhline(y=most_likely_bias_from_ps, linestyle='--', c='red', linewidth=2)
#     # ax[1].set_ylim([np.mean(all_bias_params)-3*np.std(all_bias_params), np.mean(all_bias_params)+3*np.std(all_bias_params)])

Using the power spectrum is not the best choice since it leads to a degeneracy between bias and $\sigma_8$. See e.g., https://arxiv.org/pdf/2006.01146.pdf

QUESTION: It looks like the bias found here (~1.5-1.6) is a bit on the higher side than typically found values such as 1.15 or so (for e.g., here). Could this be because of the small grid size? Due to the small grid size, the halos may not be resolved well, and hence these estimates may not be very reliable. This is just my guess. Bias dependence on resolution was mentioned in this paper as well.

Understanding any correlation between and bias. So we plot both for each simulation in the below plot.

In [None]:
# if SAME_SIMS:
#     biases, sigma_8s, omegams = [], [], []
#     for i, filename in enumerate(
#         zip(
#           sorted(glob.glob(dpath)),
#           sorted(glob.glob(hpath))
#         )
#     ):
#         den, dparams = read_hdf5(filename[0], dataset_name='3D_density_field')
#         halo, hparams = read_hdf5(filename[1], dataset_name='3D_halo_distribution')

#         sigma_8 = dparams[-1]
#         omega_m = dparams[0]
#         assert sigma_8 == hparams[-1]
#         assert omega_m == hparams[0]

#         den = smooth_3D_field(den)
#         halo = smooth_3D_field(halo)

#         den_contrast = den/den.mean() - 1
#         halo_contrast = halo/halo.mean() - 1

#         bias_param = np.median(halo_contrast/den_contrast)

#         biases.append(bias_param)
#         sigma_8s.append(sigma_8)
#         omegams.append(omega_m)

#     # Fit a line to the bias vs sigma_8 data points.
#     from scipy import stats
#     res = stats.linregress(biases, sigma_8s)  # res contains slope, intercept, r_value, p_value, std_err.

#     from mpl_toolkits.axes_grid1 import make_axes_locatable

#     fig, ax = plt.subplots(1, 1)
#     im = ax.scatter(biases, sigma_8s, c=omegams, alpha=0.75)
#     ax.plot(biases, res.intercept + res.slope * np.array(biases), 'r', label='fitted line')
#     ax.set_xlabel(r'$b$')
#     ax.set_ylabel(r'$\sigma_8$')
#     ax.set_title(r'$b$ vs $\sigma_8$ for all 1000 simulations')
#     ax.legend()
#     print(fr'Equation of fitted line: $\sigma_8 = {res.slope:.2f} * bias + {res.intercept:.2f}$')

#     divider = make_axes_locatable(ax)
#     cax = divider.append_axes('right', size='5%', pad=0.05)
#     cbar = fig.colorbar(im, cax=cax, orientation='vertical')
#     cbar.ax.set_ylabel(r'$\Omega_m$', rotation=270)
#     cbar.ax.get_yaxis().labelpad = 15

#     fig, ax = plt.subplots(1, 1)
#     sns.kdeplot(x=biases, y=sigma_8s, ax=ax, fill=True)
#     ax.set_xlabel(r'$b$')
#     ax.set_ylabel(r'$\sigma_8$')
#     ax.set_title(r'$b$ vs $\sigma_8$ for all 1000 simulations')

One speculation: if we use bias (=1.58) to scale the halo distribution for transfer learning, then predictions for $\sigma_8$ corresponding to ranges corresponding to the bias value of 1.58, which from the above figure seems to be ~0.7-0.8, could be the most affected. One can look at the prediction plots below to ascertain this.

Plotting $\delta$ vs. $b$

In [None]:
# if SAME_SIMS:
#     from mpl_toolkits.axes_grid1 import make_axes_locatable

#     biases, dm_density_contrasts = [], []
#     counter = 0
#     for i, filename in enumerate(
#         zip(
#           sorted(glob.glob(dpath)),
#           sorted(glob.glob(hpath))
#         )
#     ):
#         den, dparams = read_hdf5(filename[0], dataset_name='3D_density_field')
#         halo, hparams = read_hdf5(filename[1], dataset_name='3D_halo_distribution')

#         den = smooth_3D_field(den)
#         halo = smooth_3D_field(halo)

#         den_contrast = den/den.mean() - 1
#         halo_contrast = halo/halo.mean() - 1

#         bias_param = halo_contrast/den_contrast

#         # Remove extreme (and probably unphysical) bias parameters.
#         lower_limit_condition = bias_param > np.quantile(bias_param, 0.01)
#         upper_limit_condition = bias_param < np.quantile(bias_param, 0.99)

#         bias_param = bias_param[(lower_limit_condition) & (upper_limit_condition)]
#         # Need to also do this for corresponding density contrast.
#         den_contrast = den_contrast[(lower_limit_condition) & (upper_limit_condition)]

#         bpf = bias_param.flatten()
#         dcf = den_contrast.flatten()
#         biases.append(bpf)
#         dm_density_contrasts.append(dcf)

#         fig, ax = plt.subplots(1, 2, figsize=(10, 5))
#         im = ax[0].scatter(dcf, bpf)
#         ax[0].set_xlabel(r'$\delta_{DM}$')
#         ax[0].set_ylabel(r'$b$')
#         ax[0].set_title(r'$\delta_{DM}$ vs $b$')
#         ax[0].set_ylim([0, 3])
#     #     divider = make_axes_locatable(ax)
#     #     cax = divider.append_axes('right', size='5%', pad=0.05)
#     #     fig.colorbar(im, cax=cax, orientation='vertical')
#     #     sns.kdeplot(x=den_contrast.flatten(), y=bias_param.flatten(), ax=ax[1], fill=True)
#         ax[1].hexbin(dcf, bpf)
#         ax[1].set_xlabel(r'$\delta_{DM}$')
#         ax[1].set_ylabel(r'$b$')
#         ax[1].set_title(r'$\delta_{DM}$ vs $b$')
#         ax[1].set_ylim([0, 3])
#         plt.show()

#     #     fig, ax = plt.subplots(1, 1)
#     #     sns.kdeplot(x=dm_density_contrasts, y=biases, ax=ax, fill=True)
#     #     ax.set_xlabel(r'$\delta_{DM}$')
#     #     ax.set_ylabel(r'$b$')
#     #     ax.set_title(r'$\delta_{DM}$ vs $b$ for all 1000 simulations')
#     #     ax.set_ylim([0, 2])
#     #     plt.show()

#         counter += 1
#         if counter == 5:
#             break

Plotting the PDF of density and halo distribution for comparison.

In the below plot, we use the mean DM density for calculating the density contrast and mean halo density for the halo contrast.

In [None]:
# if SAME_SIMS:
#     dens, halos = [], []
#     for i, filename in enumerate(
#         zip(
#           sorted(glob.glob(dpath)),
#           sorted(glob.glob(hpath))
#         )
#     ):
#         den, dparams = read_hdf5(filename[0], dataset_name='3D_density_field')
#         halo, hparams = read_hdf5(filename[1], dataset_name='3D_halo_distribution')

#         den = smooth_3D_field(den)
#         halo = smooth_3D_field(halo)

#         dens.append(den/den.mean())
#         halos.append(halo/halo.mean())

#     dens = np.array(dens)
#     halos = np.array(halos)
#     fig, ax = plt.subplots(1, 1)
#     ax.hist(dens.ravel(), histtype='step', linewidth=3, label='DM density')
#     ax.hist(halos.ravel(), histtype='step', linewidth=3, label='Halo')
#     ax.set_xlabel(r'$1 + \delta$')
#     ax.set_ylabel('Counts')
#     ax.set_yscale('log')
#     ax.legend()

In [None]:
# NOTE: use precomputed_mean, precomputed_stddev, precomputed_min_vals, and precomputed_max_vals only if you are using bias.
# Else it's better to use statistics of this new dataset for preprocessing.

#########################################################################################################

# The logic is that when the exact same sims are used for DM density and halo, we use the bias and also
# the precomputed statistics. Whereas if different simulations are used for DM and halo, we don't use the
# bias and also don't use the precomputed statistics. When same sims, the bias adds a positive value, thus
# log10(halo) does not give divide by zero error. When different sims, the same is handled by log(1+halo).

#########################################################################################################
if USE_COLAB:
    if SAME_SIMS:
        command = [
            'python', 'create_data.py', '--num_sims', f'{num_sims}', '--train_frac', '0.8', '--test_frac', '0.1', '--seed', '42', '--path', '/content/my_outputs_halo', '--grid_size', '64',
            '--num_maps_per_projection_direction', '10', '--prefix', 'halos', '--dataset_name', '3D_halo_distribution', '--bias', f'{bias}',
            '--precomputed_mean', f'{MEAN}', '--precomputed_stddev', f'{STD}',
            '--precomputed_min_vals', f'{MIN_VALS[0]}', f'{MIN_VALS[1]}', f'{MIN_VALS[2]}', f'{MIN_VALS[3]}', f'{MIN_VALS[4]}',
            '--precomputed_max_vals', f'{MAX_VALS[0]}', f'{MAX_VALS[1]}', f'{MAX_VALS[2]}', f'{MAX_VALS[3]}', f'{MAX_VALS[4]}',
            '--smallest_sim_number', '0',
        ]
    else:  # don't use bias.
        command = [
            'python', 'create_data.py', '--num_sims', f'{num_sims}', '--train_frac', '0.8', '--test_frac', '0.1', '--seed', '42', '--path', '/content/my_outputs_halo', '--grid_size', '64',
            '--num_maps_per_projection_direction', '10', '--prefix', 'halos', '--dataset_name', '3D_halo_distribution',
            '--smallest_sim_number', '1000', '--log_1_plus'
        ]
        #         --precomputed_mean {MEAN} --precomputed_stddev {STD} \
        #         --precomputed_min_vals {MIN_VALS[0]} {MIN_VALS[1]} {MIN_VALS[2]} {MIN_VALS[3]} {MIN_VALS[4]} \
        #         --precomputed_max_vals {MAX_VALS[0]} {MAX_VALS[1]} {MAX_VALS[2]} {MAX_VALS[3]} {MAX_VALS[4]} \
else:
    if SAME_SIMS:
        command = [
            'python', 'create_data.py', '--num_sims', f'{num_sims}', '--train_frac', '0.8', '--test_frac', '0.1', '--seed', '42', '--path', f'{halo_dirname}/my_outputs_halo', '--grid_size', '64',
            '--num_maps_per_projection_direction', '10', '--prefix', 'halos', '--dataset_name', '3D_halo_distribution', '--bias', f'{bias}',
            '--precomputed_mean', f'{MEAN}', '--precomputed_stddev', f'{STD}',
            '--precomputed_min_vals', f'{MIN_VALS[0]}', f'{MIN_VALS[1]}', f'{MIN_VALS[2]}', f'{MIN_VALS[3]}', f'{MIN_VALS[4]}',
            '--precomputed_max_vals', f'{MAX_VALS[0]}', f'{MAX_VALS[1]}', f'{MAX_VALS[2]}', f'{MAX_VALS[3]}', f'{MAX_VALS[4]}',
            '--smallest_sim_number', '0',
        ]
    else:  # don't use bias.'
        command = [
            'python', 'create_data.py', '--num_sims', f'{num_sims}', '--train_frac', '0.8', '--test_frac', '0.1', '--seed', '42', '--path', f'{halo_dirname}/my_outputs_halo', '--grid_size', '64',
            '--num_maps_per_projection_direction', '10', '--prefix', 'halos', '--dataset_name', '3D_halo_distribution',
            '--smallest_sim_number', '1000', '--log_1_plus'
        ]
        #         --precomputed_mean {MEAN} --precomputed_stddev {STD} \
        #         --precomputed_min_vals {MIN_VALS[0]} {MIN_VALS[1]} {MIN_VALS[2]} {MIN_VALS[3]} {MIN_VALS[4]} \
        #         --precomputed_max_vals {MAX_VALS[0]} {MAX_VALS[1]} {MAX_VALS[2]} {MAX_VALS[3]} {MAX_VALS[4]} \

import subprocess
result = subprocess.run(command)
result

In [None]:
# Store the mean, std, min_vals and max_vals into variables
prefix = 'halos'
MEAN = np.load(f'{prefix}_dataset_mean.npy')
STD = np.load(f'{prefix}_dataset_std.npy')
MIN_VALS = np.load(f'{prefix}_dataset_min_vals.npy')
MAX_VALS = np.load(f'{prefix}_dataset_max_vals.npy')
MEAN_DENSITIES = np.load(f'{prefix}_dataset_mean_densities.npy')
print(MEAN, STD, MIN_VALS, MAX_VALS)

In [None]:
import gzip
import numpy as np
import glob
filename = sorted(glob.glob(f'{base_dir}/train/processed_sim*_X1_LH_z0_grid64_masCIC.npy.gz'))[0]
f = gzip.GzipFile(filename, 'r'); halo = np.load(f)

In [None]:
import pandas as pd
df = pd.read_csv('train/train_original_params.csv')
df.head(2)

In [None]:
# import matplotlib.pyplot as plt
# v = df[df['0'] == '/'.join(filename.split('/')[-2:])]
# params = list(v[v.columns[-5:]].iloc[0])

# plt.imshow(halo); plt.title(np.round(params, 4)); plt.colorbar()

In [None]:
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

from model_dataset import CustomImageDataset
from torchvision.transforms import v2
from torchvision import transforms
transform = v2.Compose([
    MyRotationTransform(angles=[90, 180, 270]),
    v2.ToDtype(torch.float32)
])

train_dataset = CustomImageDataset(f'{base_dir}/train', normalized_cosmo_params_path=f'{base_dir}/train/train_normalized_params.csv', transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)

val_dataset = CustomImageDataset(f'{base_dir}/val', normalized_cosmo_params_path=f'{base_dir}/val/val_normalized_params.csv', transform=None)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

test_dataset = CustomImageDataset(f'{base_dir}/test', normalized_cosmo_params_path=f'{base_dir}/test/test_normalized_params.csv', transform=None)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

In [None]:
len(train_dataset), len(val_dataset), len(test_dataset)

In [None]:
# Updated parameters for the transfer learning come here.
epochs = 10  # Smaller no. of epochs than pretraining.
# dr = dr * 2
# wd = wd * 2
# lr = lr * 0.1

In [None]:
lr, wd, dr

In [None]:
FREEZE_LAYERS = False  # Whether to freeze all layers for transfer learning. If False, all layers are retrained on the new dataset.

In [None]:
from torch_intermediate_layer_getter import IntermediateLayerGetter as MidGetter

In [None]:
class ViT_FineTune(ViT):
    def __init__(self, PRETRAINED_FILENAME, model_kwargs, lr, wd, beta1, beta2, minimum, maximum):
        super(ViT_FineTune, self).__init__(model_kwargs=model_kwargs, lr=lr, wd=wd, beta1=beta1, beta2=beta2, minimum=MIN_VALS, maximum=MAX_VALS)
        self.save_hyperparameters()

        self.model = ViT.load_from_checkpoint(PRETRAINED_FILENAME)

        # Re-initialize the MLP head.
        # See, for example, https://pyimagesearch.com/2019/06/03/fine-tuning-with-keras-and-deep-learning/
        self.model.model.mlp_head = nn.Sequential(
            nn.LayerNorm(model_kwargs['embed_dim']),
            nn.Linear(model_kwargs['embed_dim'], model_kwargs['num_classes'])
        )
        
#         self.model.model.mlp_head = nn.Sequential(
#             nn.LayerNorm(model_kwargs['embed_dim']),
#             nn.Linear(model_kwargs['embed_dim'], model_kwargs['embed_dim']//2),
#             nn.GELU(),
#             nn.Dropout(p=0.1),
#             nn.Linear(model_kwargs['embed_dim']//2, model_kwargs['embed_dim']//4),
#             nn.GELU(),
#             nn.Dropout(p=0.1),
#             nn.Linear(model_kwargs['embed_dim']//4, model_kwargs['num_classes'])
#         )

        self.example_input_array = next(iter(train_loader))[0]

    def configure_optimizers(self):
        mlp_head_params = list(map(lambda x: x[1],list(filter(lambda kv: 'model.mlp_head' in kv[0], self.model.named_parameters()))))
        feature_params = list(map(lambda x: x[1],list(filter(lambda kv: 'model.mlp_head' not in kv[0], self.model.named_parameters()))))
        assert len(mlp_head_params) > 0  # Because we know there exists a MLP head in our model.
        assert len(feature_params) > 0  # Because we know there exists parameters corresponding to the transformer and input layers.
        optimizer = torch.optim.AdamW(
            [
                {'params': mlp_head_params, 'lr': 1e-2},
                {'params': feature_params, 'lr': 1e-4}
            ],
            weight_decay=self.hparams.wd, betas=(self.hparams.beta1, self.hparams.beta2)
        )
#         optimizer = torch.optim.AdamW(
#             [
#                 # Set a small learning rate for all pre-trained layers, but a larger
#                 # learning rate for the MLP head layers.
#                 {"params": self.model.model.input_layer.parameters(), "lr": 1e-4},
#                 {"params": self.model.model.transformer.parameters(), "lr": 1e-4},
#                 {"params": self.model.model.mlp_head.parameters(), "lr": 1e-2},
#             ],
#             weight_decay=self.hparams.wd, betas=(self.hparams.beta1, self.hparams.beta2)
#         )
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.3, patience=5)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "monitor": "val_loss"
            },
        }

In [None]:
from pytorch_lightning.loggers import WandbLogger
logger_csv = pl.loggers.CSVLogger(CHECKPOINT_PATH, name="lightning_logs_csv")
WANDB_RUN_NAME = f'ViT_TL_DiffSimsNomasscut-batchsize-{batch_size}_lr-{lr}_epochs-{epochs}_wd-{wd}_dr-{model_kwargs["dropout"]}'
wandb_logger = WandbLogger(name=WANDB_RUN_NAME, project='Cosmo-parameter-inference')
wandb_logger.experiment.config.update({"batch_size": batch_size, "epochs": epochs})

def finetune_model(PRETRAINED_FILENAME, **kwargs):
    pl.seed_everything(SEED, workers=True) # To be reproducable
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=epochs,
                         logger=[logger_csv, wandb_logger],
#                          progress_bar_refresh_rate=50,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="min", monitor="val_loss"),
                                    LearningRateMonitor("epoch")],
#                          deterministic=True
                        )
#     trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
#     trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it before fine-tuning. If no, raise an error.
    if os.path.isfile(PRETRAINED_FILENAME):
        print(f"Found pretrained model at {PRETRAINED_FILENAME}. This will be used to load the model.")
    else:
        raise ValueError("Finetuning requires a pretrained model file to be specified by the `PRETRAINED_FILENAME` argument!")
    model = ViT_FineTune(PRETRAINED_FILENAME, **kwargs) # Automatically loads the model with the saved hyperparameters
    # After loading the pretrained model, finetune it.
    trainer.fit(model, train_loader, val_loader)
    model = ViT_FineTune.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_loss"], "val": val_result[0]["test_loss"]}

    return model, result, trainer.checkpoint_callback.best_model_path

In [None]:
# PRETRAINED_FILENAME was already defined above, where it was set to
# `trainer.checkpoint_callback.best_model_path`, where trainer was for pre-training.
model, results, FINETUNED_FILENAME = finetune_model(
    PRETRAINED_FILENAME, model_kwargs=model_kwargs,
    lr=lr, wd=wd, beta1=beta1, beta2=beta2,
    minimum=MIN_VALS, maximum=MAX_VALS
)
print("ViT_FineTune results", results)
model.to(device)

In [None]:
if USE_TENSORBOARD:
    # Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH!
    # The exact command is taken from https://lightning.ai/docs/pytorch/stable/visualize/logging_basic.html
    %reload_ext tensorboard
    %tensorboard --logdir ./saved_models/tensorboards/

In [None]:
from train_val_test_boilerplate import test

# Below values calculated during data preparation. See above.
minimum = MIN_VALS
maximum = MAX_VALS

params_true, params_NN, errors_NN, filenames = test(model, test_loader, g=g, h=h, device=device, minimum=minimum, maximum=maximum)

In [None]:
from evaluation_analysis import post_test_analysis, get_cka

post_test_analysis(
    params_true, params_NN, errors_NN, filenames,
    params, num_sims, MEAN, STD, MEAN_DENSITIES, minimum, maximum,
    num_maps_per_projection_direction,
    test_results_filename='test_results_transfer_learning_ViT.csv',
    smallest_sim_number=0 if SAME_SIMS else 1000
)

get_cka(
    model.model.model, test_loader,
    return_layers = {
        'transformer.0.linear.1': 'GELU',
        'transformer.1.linear.1': 'GELU',
        'transformer.2.linear.1': 'GELU',
        'transformer.3.linear.1': 'GELU',
        'transformer.4.linear.1': 'GELU',
        'transformer.5.linear.1': 'GELU'
    },
    cka_filename='cka_matrix_transfer_learning_halo_ViT_grid64_test.png',
    device=device
)

In [None]:
import shutil
shutil.rmtree(f'{base_dir}/train')
shutil.rmtree(f'{base_dir}/val')
shutil.rmtree(f'{base_dir}/test')

In [None]:
wandb.finish()

## Training on transfer learning data FROM SCRATCH

In [None]:
# Since this is training from scratch, don't use the --bias option since we assume we don't have information of DM density.
if USE_COLAB:
    command = [
        'python', 'create_data.py', '--num_sims', f'{num_sims}', '--train_frac', '0.8', '--test_frac', '0.1', '--seed', '42', '--path', '/content/my_outputs_halo', '--grid_size', '64',
        '--num_maps_per_projection_direction', '10', '--prefix', 'halos', '--dataset_name', '3D_halo_distribution', '--log_1_plus', '--smallest_sim_number', f'{0 if SAME_SIMS else 1000}'
    ]
else:
    command = [
        'python', 'create_data.py', '--num_sims', f'{num_sims}', '--train_frac', '0.8', '--test_frac', '0.1', '--seed', '42', '--path', f'{halo_dirname}/my_outputs_halo', '--grid_size', '64',
        '--num_maps_per_projection_direction', '10', '--prefix', 'halos', '--dataset_name', '3D_halo_distribution', '--log_1_plus', '--smallest_sim_number', f'{0 if SAME_SIMS else 1000}'
    ]

import subprocess
result = subprocess.run(command)
result

In [None]:
# Store the mean, std, min_vals and max_vals into variables
prefix = 'halos'
MEAN = np.load(f'{prefix}_dataset_mean.npy')
STD = np.load(f'{prefix}_dataset_std.npy')
MIN_VALS = np.load(f'{prefix}_dataset_min_vals.npy')
MAX_VALS = np.load(f'{prefix}_dataset_max_vals.npy')
MEAN_DENSITIES = np.load(f'{prefix}_dataset_mean_densities.npy')
print(MEAN, STD, MIN_VALS, MAX_VALS)

In [None]:
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

from model_dataset import CustomImageDataset

from torchvision.transforms import v2
from torchvision import transforms
transform = v2.Compose([
    MyRotationTransform(angles=[90, 180, 270]),
    v2.ToDtype(torch.float32)#, scale=False),
])

train_dataset = CustomImageDataset(f'{base_dir}/train', normalized_cosmo_params_path=f'{base_dir}/train/train_normalized_params.csv', transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)

val_dataset = CustomImageDataset(f'{base_dir}/val', normalized_cosmo_params_path=f'{base_dir}/val/val_normalized_params.csv', transform=None)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

test_dataset = CustomImageDataset(f'{base_dir}/test', normalized_cosmo_params_path=f'{base_dir}/test/test_normalized_params.csv', transform=None)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

In [None]:
# Updated parameters for the transfer learning come here.
epochs = 25
# dr = dr * 2
# wd = wd * 2
# lr = lr * 0.1

In [None]:
from pytorch_lightning.loggers import WandbLogger
logger_csv = pl.loggers.CSVLogger(CHECKPOINT_PATH, name="lightning_logs_csv")
WANDB_RUN_NAME = f'ViT_TLScratch_DiffSimsNomasscut-batchsize-{batch_size}_lr-{lr}_epochs-{epochs}_wd-{wd}_dr-{model_kwargs["dropout"]}'
wandb_logger = WandbLogger(name=WANDB_RUN_NAME, project='Cosmo-parameter-inference')
wandb_logger.experiment.config.update({"batch_size": batch_size, "epochs": epochs})

model, results, PRETRAINED_FILENAME_SCRATCH = train_model(
    model_kwargs=model_kwargs,
    lr=lr, wd=wd, beta1=beta1, beta2=beta2,
    minimum=MIN_VALS, maximum=MAX_VALS
)
print("ViT results", results)
model.to(device)

In [None]:
if USE_TENSORBOARD:
    # Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH!
    # The exact command is taken from https://lightning.ai/docs/pytorch/stable/visualize/logging_basic.html
    %reload_ext tensorboard
    %tensorboard --logdir ./saved_models/tensorboards/

In [None]:
from train_val_test_boilerplate import test

# Below values calculated during data preparation. See above.
minimum = MIN_VALS
maximum = MAX_VALS

params_true, params_NN, errors_NN, filenames = test(model, test_loader, g=g, h=h, device=device, minimum=minimum, maximum=maximum)

In [None]:
from evaluation_analysis import post_test_analysis, get_cka

post_test_analysis(
    params_true, params_NN, errors_NN, filenames,
    params, num_sims, MEAN, STD, MEAN_DENSITIES, minimum, maximum,
    num_maps_per_projection_direction,
    test_results_filename='test_results_transfer_learning_from_scratch_ViT.csv',
    smallest_sim_number=0 if SAME_SIMS else 1000
)

get_cka(
    model.model, test_loader,
    return_layers = {
        'transformer.0.linear.1': 'GELU',
        'transformer.1.linear.1': 'GELU',
        'transformer.2.linear.1': 'GELU',
        'transformer.3.linear.1': 'GELU',
        'transformer.4.linear.1': 'GELU',
        'transformer.5.linear.1': 'GELU'
    },
    cka_filename='cka_matrix_transfer_learning_from_scratch_halo_ViT_grid64_test.png',
    device=device
)

In [None]:
import shutil
shutil.rmtree(f'{base_dir}/train')
shutil.rmtree(f'{base_dir}/val')
shutil.rmtree(f'{base_dir}/test')

In [None]:
wandb.finish()

TODO: Don't know why the ViT predicts omega_m and sigma_8 relatively well on DM density but fails to predict sigma_8 on DM halo (but still okay on omega_m)? 
**TODO: Try to redo CNN in the same interface like this notebook and see if those results are same as without lightning....**