In [1]:
import os
import sys
sys.path.append(os.path.expanduser('~'))
from variational_diffusion_cdm.data.astro_dataset import get_astro_data
from variational_diffusion_cdm.model.utils.utils import draw_figure,compute_pk
import comet_ml
import torch
from torch import nn
from torch.nn.functional import mse_loss
from torch import autograd, Tensor
from lightning.pytorch import LightningModule, Trainer, seed_everything
from lightning.pytorch.loggers import CometLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from typing import Optional, Tuple
from torch.special import expm1
from tqdm import trange
from torch.distributions.normal import Normal
import numpy as np
import matplotlib.pyplot as plt

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# and fall back to the CPU when neither is present (the previous version
# returned "mps" even on machines without MPS support).
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

ModuleNotFoundError: No module named 'data'

In [None]:
# Run configuration: seed the RNGs via Lightning, then the data / optimiser knobs.
seed_everything(7)

cropsize = 256        # spatial side length; later passed as image_shape=(1, cropsize, cropsize)
batch_size = 12
num_workers = 8

dataset = 'Astrid'    # simulation suite handed to get_astro_data
learning_rate = 1e-3  # previously tried: 1e-4


In [None]:
# Build the data module, then report the split sizes and the shape of one training batch.
dm = get_astro_data(
    dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    # resize=cropsize,  (left disabled)
)

print(np.shape(dm.train_data), len(dm.valid_data), len(dm.test_data))

one_batch = next(iter(dm.train_dataloader()))
print(np.shape(one_batch))

In [None]:
class SimpleNet(nn.Module):
    """Small encoder/decoder CNN mapping an image to an image of the same size.

    Pipeline (shapes for the default arguments, batch N, even H and W):
        conv1         (N, 1, H, W)        -> (N, 64, H, W)
        downsample    (N, 64, H, W)       -> (N, 128, H/2, W/2)
        upsample      (N, 128, H/2, W/2)  -> (N, 64, H, W)
        output_layer  (N, 64, H, W)       -> (N, 1, H, W)

    Args:
        in_channels: Channels of the input map. Defaults to 1.
        out_channels: Channels of the predicted map. Defaults to 1.
        hidden_channels: Width of the full-resolution feature maps; the
            half-resolution bottleneck uses twice this many. Defaults to 64.

    The defaults reproduce the original hard-coded architecture (same layer
    names and parameter shapes), so existing checkpoints keep loading.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 1, hidden_channels: int = 64):
        super().__init__()
        bottleneck_channels = 2 * hidden_channels

        # 3x3 conv -> BatchNorm -> ReLU at full resolution. Circular padding wraps
        # around the image edges (presumably because the maps are periodic — confirm).
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1,
                      padding=1, padding_mode='circular'),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
        )

        # 2x2 stride-2 conv halves H and W while doubling the channel count.
        self.downsample = nn.Sequential(
            nn.Conv2d(hidden_channels, bottleneck_channels, kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(),
        )

        # 2x2 stride-2 transposed conv restores the full resolution.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(bottleneck_channels, hidden_channels, kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
        )

        # 3x3 transposed conv (stride 1, padding 1) projecting to the output
        # channels without changing H and W.
        self.output_layer = nn.ConvTranspose2d(
            hidden_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the prediction for ``x``; its spatial size matches ``x`` (H and W must be even)."""
        out = self.conv1(x)
        out = self.downsample(out)
        out = self.upsample(out)
        # output_size pins the transposed-conv output to the input's spatial size.
        return self.output_layer(out, output_size=x.size())

 "one_batch = next(iter( dm.train_dataloader()))\n",
    "conditioning, x = one_batch\n",
    "print(f'input: {np.shape(conditioning)}, output: {np.shape(x)}')\n",
    "model = SimpleNet()\n",
    "model = model.eval()\n",
    "out_1 = model.conv1(conditioning)\n",
    "print(f'out_layer1:{np.shape(out_1)}')\n",
    "out_2 = model.downsample(out_1)\n",
    "print(f'out_layer2:{np.shape(out_2)}')\n",
    "out_3 = model.upsample(out_2)\n",
    "print(f'out_layer2:{np.shape(out_3)}')\n",
    "final = model.output_layer(out_3)\n",
    "print(f'final_layer:{np.shape(final)}

In [None]:
class BasicCNN(LightningModule):
    """LightningModule that trains ``SimpleNet`` with an MSE loss.

    Each batch is a ``(conditioning, target)`` pair: the network predicts the
    target map from the conditioning map and is scored with ``mse_loss``.

    Args:
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        n_sampling_steps: Only recorded in ``hparams``; not used by this module.
        draw_figure: Optional callable invoked as
            ``draw_figure(target, prediction, conditioning, dataset)`` on the first
            validation batch; the returned figure is sent to the logger. When
            ``None``, a blank 5x5-inch figure is produced instead.
        dataset: Simulation-suite name passed through to ``draw_figure``.
        **kwargs: Extra hyperparameters (e.g. ``image_shape``), recorded by
            ``save_hyperparameters`` and otherwise unused.
    """

    def __init__(
        self,
        learning_rate: float = 3.0e-4,
        weight_decay: float = 1.0e-5,
        n_sampling_steps: int = 250,
        draw_figure=None,
        dataset='illustris',
        **kwargs
    ):
        super().__init__()
        # draw_figure is a callable, not a hyperparameter, so keep it out of hparams.
        self.save_hyperparameters(ignore=["draw_figure"])

        self.model = SimpleNet()
        self.dataset = dataset
        print("suite:", self.dataset)

        if draw_figure is None:
            # Must accept the same arguments validation_step passes to a real
            # plotting function; a single-argument fallback raised TypeError.
            def draw_figure(*args, **kwargs):
                return plt.figure(figsize=(5, 5))
        self.draw_figure = draw_figure

    def forward(self, x) -> Tensor:
        return self.model(x)

    def evaluate(self, batch: Tuple, stage: Optional[str] = None) -> Tensor:
        """Return the MSE between the prediction for ``batch`` and its target.

        Args:
            batch: ``(conditioning, target)`` pair.
            stage: When given (e.g. ``"train"``), the loss is also logged as
                ``f"{stage}_loss"``.
        """
        cdm_map, true_map = batch
        mtot_pred = self(cdm_map)
        loss = mse_loss(mtot_pred, true_map)
        if stage is not None:
            self.log(f"{stage}_loss", loss, on_epoch=True, batch_size=cdm_map.shape[0])
        return loss

    def training_step(
        self,
        batch: Tuple,
        batch_idx: int,
    ) -> Tensor:
        return self.evaluate(batch, "train")

    def validation_step(self, batch: Tuple, batch_idx: int) -> Tensor:
        """Compute and log the validation loss; plot the first batch of each epoch.

        Args:
            batch (Tuple): ``(conditioning, target)`` pair
            batch_idx (int): index of the batch within the validation epoch

        Returns:
            Tensor: MSE loss for this batch
        """
        conditioning, x = batch
        sample = self(conditioning)
        loss = mse_loss(x, sample)
        # Log on every batch so the epoch-level val_loss averages over the whole
        # validation set; ModelCheckpoint(monitor="val_loss") ranks on it.
        self.log("val_loss", loss, on_epoch=True, batch_size=conditioning.shape[0])

        if batch_idx == 0:
            fig = self.draw_figure(x, sample, conditioning, self.dataset)
            if self.logger is not None:
                self.logger.experiment.log_figure(figure=fig)
            plt.close(fig)
        return loss

    def test_step(self, batch: Tuple, batch_idx: int) -> Tensor:
        """Compute the test loss once and log it as ``test_loss``."""
        return self.evaluate(batch, "test")

    def configure_optimizers(self):
        """AdamW with cosine annealing and warm restarts every ``T_0=10`` scheduler steps."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer,
            T_0=10,
        )
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}

In [None]:
cnn = BasicCNN(
    dataset=dataset,
    learning_rate=learning_rate,
    image_shape=(1, cropsize, cropsize),
    draw_figure=draw_figure,
)

# Keep the checkpoint with the lowest validation loss seen so far.
val_checkpoint = ModelCheckpoint(
    filename="{epoch}-{step}-{val_loss:.3f}",
    monitor="val_loss",
    mode="min",
)

# Periodic checkpoint every 60 training steps (6000 in full-length runs),
# keeping the 10 checkpoints with the highest step.
latest_checkpoint = ModelCheckpoint(
    filename="latest-{epoch}-{step}",
    monitor="step",
    mode="max",
    every_n_train_steps=60,  # 6000 for full-length runs
    save_top_k=10,
)

# The Comet API key is read from the environment; it is never hard-coded.
comet_logger = CometLogger(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="baryonize_DM",
    experiment_name='SimpleCNN',
)

trainer = Trainer(
    logger=comet_logger,
    accelerator="auto",
    max_epochs=10,  # 1000 for full-length runs
    gradient_clip_val=0.5,
    callbacks=[
        LearningRateMonitor(),
        latest_checkpoint,
        val_checkpoint,
    ],
)

In [None]:
# Train the CNN on the data module built above.
trainer.fit(cnn, datamodule=dm)

In [None]:
# Evaluate the trained model on the test split.
test_loss = trainer.test(cnn, datamodule=dm)

# Close the Comet experiment once testing is done.
trainer.logger.experiment.end()

In [None]:
# NOTE(review): the preceding test cell already calls trainer.logger.experiment.end();
# this cell repeats that call. Confirm a second end() is harmless for the Comet
# experiment, or delete this cell.
trainer.logger.experiment.end()

In [None]:
# Restore model weights from a saved Lightning checkpoint.
# NOTE(review): absolute path to one specific run on one machine; adjust locally.
ckpt = '/Users/link/PycharmProjects/bccp/baryonize_DM/8ad6ca35ce4d4f029b73dccbbe16282e/checkpoints/epoch=9-step=10000-val_loss=0.019.ckpt'
# map_location="cpu" lets a checkpoint written from a CUDA/MPS device load on any
# machine; load_state_dict then copies the tensors onto the model's own device.
state_dict = torch.load(ckpt, map_location="cpu")["state_dict"]
cnn.load_state_dict(state_dict)
cnn.eval()

In [None]:
# Run the restored model on the first test batch and plot target, prediction
# and conditioning.
# NOTE(review): an earlier run reported "ValueError: too many values to unpack
# (expected 2)" on the next line; the shape printout in the last cell shows test
# batches with a leading dimension of 2, so that note may be stale — confirm.
conditioning, x = next(iter(dm.test_dataloader()))

with torch.no_grad():
    sample = cnn(conditioning)

fig = draw_figure(x, sample, conditioning, cnn.dataset)
fig.savefig('/Users/link/PycharmProjects/bccp/test_IllustrisTNG_1P.png')
plt.show()

In [None]:
# Shapes seen in a previous run:
#   test_data (2100, 2, 1, 256, 256)
#   one test batch (2, 100, 1, 256, 256)
#   one train batch (2, 12, 1, 256, 256)
print(
    np.shape(dm.test_data),
    np.shape(next(iter(dm.test_dataloader()))),
    np.shape(next(iter(dm.train_dataloader()))),
)