In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import Sequential
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:
class Dataset(Dataset):
  # 'Characterizes a dataset for PyTorch'
    def __init__(self, samp_idxs, base_path = "./data/"):
        # 'Initialization'
        self.base_path = base_path
        self.samp_idxs = samp_idxs # a list of the indexs of the samples to use 
        
    def __len__(self):
        # 'Denotes the total number of samples'
        return len(self.samp_idxs)

    def __getitem__(self, index):
        i = self.samp_idxs[index]
        df = np.load(f"{self.base_path}dens{i:05d}.npy") # density field
        labels = np.load(f"{self.base_path}labels.npy")
        label = labels[i, :]
        
        df = np.expand_dims(df, 0)
        return df, label

class QuickDataset(Dataset):
    def __init__(self, samp_idxs, base_path = "./data/"):
        # 'Initialization'
        self.base_path = base_path
        self.samp_idxs = samp_idxs # a list of the indexs of the samples to use 

        self.dfs = np.zeros((len(samp_idxs), 64, 64), dtype=np.float32)
        self.labels = np.zeros((len(samp_idxs), 2))
        
        all_labels = np.load(f"{self.base_path}labels.npy")
        
        for i in tqdm(range(len(samp_idxs))):
            data_idx = samp_idxs[i]
            self.dfs[i, :, :] = np.load(f"{self.base_path}dens{data_idx:05d}.npy")
            self.labels[i, :] = all_labels[data_idx, :]
        
    def __len__(self):
        # 'Denotes the total number of samples'
        return len(self.samp_idxs)

    def __getitem__(self, index):
        df = self.dfs[index, :, :]
        label = self.labels[index, :]
        
        df = np.expand_dims(df, 0)
        return df, label

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
bs = 64

train_idxs = np.arange(0, 7000)
val_idxs = np.arange(7000, 8500)
test_idxs = np.arange(8500, 10000)

# trainset = Dataset(train_idxs)
# testset = Dataset(test_idxs)

trainset = QuickDataset(train_idxs)
valset = QuickDataset(val_idxs)
testset = QuickDataset(test_idxs)

train_loader = DataLoader(trainset, batch_size=bs, shuffle=True, pin_memory = True)
val_loader = DataLoader(valset, batch_size=bs, shuffle=True, pin_memory=True)
test_loader = DataLoader(testset, batch_size=bs, shuffle=True, pin_memory = True)

100%|██████████| 7000/7000 [01:34<00:00, 74.35it/s] 
100%|██████████| 1500/1500 [00:17<00:00, 83.38it/s]
100%|██████████| 1500/1500 [00:20<00:00, 74.09it/s]


In [5]:
#uses some parts from https://github.com/tonyduan/mixture-density-network/blob/master/src/blocks.py
class MyMDM(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.cnn_layers = Sequential( #given an input shape of 64x64
            nn.Conv2d(1, 4, (3,3), padding = 1),
            nn.ReLU(),
            nn.Conv2d(4, 8, (3,3), padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(4, 4),
            nn.Conv2d(8, 16, (3,3), padding = 1),
            nn.ReLU(),
            nn.Conv2d(16, 32, (3,3), padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(4, 4),
        )
        
        self.dim_in = 4*4*32 # size of output from cnn_layers
        self.dim_out = 2 #number of parameters in output (A_s_1e9 and Omega_m)
        self.n_components = 1 #number of components in mixture model(each parameter explained by single gaussian)
        num_sigma_channels = self.dim_out * self.n_components # because we assume diagonal covariance  
        
        self.pi_network = nn.Sequential(
            nn.Linear(self.dim_in, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, self.n_components),
        )
        self.normal_network = nn.Sequential(
            nn.Linear(self.dim_in, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, self.dim_out * self.n_components + num_sigma_channels)
        )

    def forward(self, x, eps=1e-6):
        #
        # Returns
        # -------
        # log_pi: (bsz, n_components)
        # mu: (bsz, n_components, dim_out)
        # sigma: (bsz, n_components, dim_out)
        #
        x = self.cnn_layers(x)
        x = x.reshape(x.shape[0], -1)
        log_pi = torch.log_softmax(self.pi_network(x), dim=-1)
        normal_params = self.normal_network(x)
        mu = normal_params[..., :self.dim_out * self.n_components]
        sigma = normal_params[..., self.dim_out * self.n_components:]
        sigma = torch.exp(sigma + eps) # <-- this will need to change if we are not assuming diagonal covariance
        mu = mu.reshape(-1, self.n_components, self.dim_out)
        sigma = sigma.reshape(-1, self.n_components, self.dim_out)
        return log_pi, mu, sigma

    def loss(self, y, log_pi, mu, sigma):
        # print(y[0, :])
        # print(log_pi[0,:])
        # print(mu[0,:, :])
        # print(sigma[0,:,:])
        
        z_score = (y.unsqueeze(1) - mu) / sigma
        normal_loglik = (
            -0.5 * torch.einsum("bij,bij->bi", z_score, z_score)
            -torch.sum(torch.log(sigma), dim=-1)
        )
        loglik = torch.logsumexp(log_pi + normal_loglik, dim=-1)
        return -loglik

    def sample(self, x):
        log_pi, mu, sigma = self.forward(x)
        cum_pi = torch.cumsum(torch.exp(log_pi), dim=-1)
        rvs = torch.rand(len(x), 1).to(x)
        rand_pi = torch.searchsorted(cum_pi, rvs)
        rand_normal = torch.randn_like(mu) * sigma + mu
        samples = torch.gather(rand_normal, index=rand_pi.unsqueeze(-1), dim=1).squeeze(dim=1)
        return samples 
        

In [23]:
def train_model(model):
    epochs = 750
    optim = torch.optim.Adam(model.parameters(), lr = 1e-4)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, verbose = False)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 50)
    
    train_ep_losses = [] #stores the sum of train_loss
    val_ep_losses = [] #stores the sum of val_loss 
    val_epochs = [] # stores the epochs at which validation set is evaluated
    min_val_loss = np.inf
    for epoch in tqdm(range(epochs)):

        train_ep_loss = []
        val_ep_loss = []
        model.train()

        for b, batch in enumerate(train_loader):

            df, label = batch

            df = df.to(device) / 1000 #some normalization
            label = label.to(device)
            label[:, 0] = label[:, 0]/10 #Bring A_s_1e9 range from [0.5, 5.5] to [0.05, 0.55]      

            pred = model(df)
            loss = model.loss(label, *pred).mean()

            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optim.step()
            
            train_ep_loss.append(loss.item())
            
        train_ep_losses.append(sum(train_ep_loss)/len(train_ep_loss))
        
        if epoch % 5 == 0:
            
            model.eval()
            
            for b, batch in enumerate(val_loader):

                df, label = batch
                
                df = df.to(device) / 1000 #some normalization
                label = label.to(device)
                label[:, 0] = label[:, 0]/10 #Bring A_s_1e9 range from [0.5, 5.5] to [0.05, 0.55]      

                pred = model(df)
                loss = model.loss(label, *pred).mean()
                
                val_ep_loss.append(loss.item())
                
            val_ep_losses.append(sum(val_ep_loss)/len(val_ep_loss))
            val_epochs.append(epoch)
            scheduler.step(val_ep_losses[-1])
        
        if val_ep_losses[-1] < min_val_loss:
           
            min_val_loss = val_ep_losses[-1]
            torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optim.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'train_loss': train_ep_loss,
                    'val_loss': val_ep_loss
                    }, f"./model_ckpt.tar")
        
        plt.plot(train_ep_losses, label = "Train Loss")
        plt.plot(val_epochs, val_ep_losses, label = "Val Loss")
        plt.legend()
        plt.savefig("./training_val_loss.png")
        plt.close()
    

In [24]:
torch.manual_seed(0)
model = MyMDM()

print("Num trainable params:",  sum(p.numel() for p in model.parameters() if p.requires_grad))

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        y = m.in_features
        m.weight.data.normal_(0.0,1/np.sqrt(y))
        m.bias.data.fill_(0)

model.apply(weights_init_normal)

model = model.to(device = device)
train_model(model)

Num trainable params: 171141


100%|██████████| 750/750 [12:31<00:00,  1.00s/it]


### Evaluate Model

In [16]:
model = MyMDM()
model_cp = torch.load("./model_ckpt.tar")
model.load_state_dict(model_cp["model_state_dict"])
model.eval()
model = model.to(device = device)

pred = None
for b, batch in enumerate(test_loader):

    df, label = batch

    df = df.to(device) / 1000 #some normalization
    label[:, 0] = label[:, 0]/10 #Bring A_s_1e9 range from [0.5, 5.5] to [0.05, 0.55]    
    
    pred = model(df)
    break

In [17]:
mus = pred[1].detach().cpu().numpy()[:,0,:]
sigmas = pred[2].detach().cpu().numpy()[:,0,:]

labels = label.detach().cpu().numpy()

In [20]:
for i in range(15):
    # print("A_s_1e8 | Label:", labels[i,0], "| Pred:", mus[i,0], "+/-", sigmas[i,0])
    print("Omega_m | Label:", labels[i,1], "| Pred:", mus[i,1], "+/-", sigmas[i,1])

Omega_m | Label: 0.35000000000000003 | Pred: 0.37367725 +/- 0.061160143
Omega_m | Label: 0.45 | Pred: 0.47929844 +/- 0.08339021
Omega_m | Label: 0.25 | Pred: 0.29973027 +/- 0.06623718
Omega_m | Label: 0.5 | Pred: 0.4699641 +/- 0.04503933
Omega_m | Label: 0.25 | Pred: 0.24811687 +/- 0.069735296
Omega_m | Label: 0.2 | Pred: 0.33428103 +/- 0.07652944
Omega_m | Label: 0.35000000000000003 | Pred: 0.38484484 +/- 0.070030704
Omega_m | Label: 0.2 | Pred: 0.16515684 +/- 0.036609013
Omega_m | Label: 0.05 | Pred: 0.053703614 +/- 0.003204
Omega_m | Label: 0.35000000000000003 | Pred: 0.37044168 +/- 0.10709389
Omega_m | Label: 0.15000000000000002 | Pred: 0.22771005 +/- 0.052512005
Omega_m | Label: 0.3 | Pred: 0.19820632 +/- 0.04593917
Omega_m | Label: 0.25 | Pred: 0.29672006 +/- 0.07785397
Omega_m | Label: 0.5 | Pred: 0.33930328 +/- 0.06986257
Omega_m | Label: 0.15000000000000002 | Pred: 0.13587587 +/- 0.02919797


In [49]:
mus.shape

(28, 2)

In [50]:
label.shape

torch.Size([28, 2])

In [67]:
for b, batch in enumerate(train_loader):

    df, label = batch
    print(df.shape)
    print(label.shape)
    break

torch.Size([64, 1, 64, 64])
torch.Size([64, 2])
