In [1]:
import h5py
import mlflow
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader

from autoencoder.dataset import MRIMemorySHDataset
from autoencoder.logger import logger, set_log_level
from autoencoder.spherical.CG import real_clebsch_gordan_all
from autoencoder.spherical.convolution import S2Convolution, SO3Convolution

In [2]:
# Set the project logger's verbosity. 10 is the numeric value of DEBUG in the
# standard `logging` module — assumes set_log_level uses that scale; TODO confirm.
set_log_level(10)

In [3]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
has_cuda = torch.cuda.is_available()

logger.info("Is the GPU available? %s", has_cuda)

device = torch.device("cuda" if has_cuda else "cpu")
if has_cuda:
    # Log which GPU is in use so a run can be traced back to its hardware.
    logger.info("Current device: %s", torch.cuda.current_device())
    logger.info("Device count: %s", torch.cuda.device_count())
    logger.info("Using device: %s", torch.cuda.get_device_properties(device))
else:
    logger.warning("No GPU detected! Training will be slow")

[38;21m2021-12-15 15:38:40,883 - MUDI - INFO - Is the GPU available? True (4121038090.py:4)[0m
[38;21m2021-12-15 15:38:40,886 - MUDI - INFO - Current device: 0 (4121038090.py:8)[0m
[38;21m2021-12-15 15:38:40,887 - MUDI - INFO - Device count: 1 (4121038090.py:9)[0m
[38;21m2021-12-15 15:38:40,888 - MUDI - INFO - Using device: _CudaDeviceProperties(name='NVIDIA GeForce GTX 1080', major=6, minor=1, total_memory=8116MB, multi_processor_count=20) (4121038090.py:10)[0m


In [4]:
# Integer values read from latent_features.txt; handed to the datasets as `include`.
latent_features = np.loadtxt("latent_features.txt", dtype=int)

batch_size = 256
train_subjects = [11, 12, 13, 14]
validate_subjects = [15]

# Both splits read the same archive with the same settings; only the subject
# list differs. The training dataset is built first, then the validation one.
train_dataset, validate_dataset = (
    MRIMemorySHDataset(
        "../data/data.hdf5",
        subjects,
        include=latent_features,
        l_max=2,
    )
    for subjects in (train_subjects, validate_subjects)
)

# Options common to both loaders; only the training loader shuffles and pins memory.
loader_options = dict(batch_size=batch_size, num_workers=6)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    pin_memory=True,
    **loader_options,
)
validate_dataloader = DataLoader(validate_dataset, **loader_options)

  Y[:, n_sph : n_sph + 2 * i + 1] = Y_n_m


In [7]:
class SphericalDecoder(pl.LightningModule):
    """Decoder built from an S2 convolution, an SO3 convolution and a linear head.

    The two convolutions run in sequence. The feature tensors they return are
    concatenated along dim 1 and mapped through a single linear layer. Training
    and validation both use the mean squared error against ``batch["target"]``.

    Args:
        learning_rate (float): learning rate handed to the Adam optimizer.
        profiler: optional profiler object forwarded to both convolution layers.
    """

    def __init__(self, *, learning_rate: float = 1e-3, profiler=None) -> None:
        super().__init__()

        self.learning_rate = learning_rate

        # Build the Clebsch-Gordan terms on the GPU only when one exists; a
        # hard-coded "cuda" makes construction fail on CPU-only machines.
        cg_device = "cuda" if torch.cuda.is_available() else "cpu"

        # Degrees passed to the layers below (presumably spherical-harmonic
        # degrees for input, hidden and output — TODO confirm).
        L = [2, 2, 0]
        CG_r, CG_l = real_clebsch_gordan_all(L[0], L[1], device=cg_device)

        self.s2_conv = S2Convolution(
            28, 3, L[0], L[1], 5, 8, CG_r, CG_l, profiler=profiler
        )
        self.so3_conv1 = SO3Convolution(
            28, 3, L[1], L[2], 8, 16, CG_r, CG_l, profiler=profiler
        )

        # in_features must equal the width of the concatenated feature tensors
        # produced in `forward`.
        self.linear = torch.nn.Sequential(
            torch.nn.Linear(2688, 1344),
        )

    def forward(self, x: dict[int, torch.Tensor]) -> torch.Tensor:
        """Run both convolutions and decode their concatenated features.

        Args:
            x (dict[int, torch.Tensor]): input handed to the S2 convolution.

        Returns:
            torch.Tensor: output of the linear head.
        """
        rh0, feats0 = self.s2_conv(x)
        # Only the feature output of the second convolution is used.
        _, feats1 = self.so3_conv1(rh0)

        features = torch.cat((feats0, feats1), dim=1)
        decoded = self.linear(features)

        return decoded

    def configure_optimizers(self) -> torch.optim.Adam:
        """Create the Adam optimizer over all parameters of the module."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(
        self,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
    ) -> torch.Tensor:
        """Compute and log the loss of one training batch as ``train_loss``."""
        return self._shared_eval(batch, batch_idx, "train")

    def validation_step(
        self,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
    ) -> torch.Tensor:
        """Compute and log the loss of one validation batch as ``val_loss``."""
        return self._shared_eval(batch, batch_idx, "val")

    def _shared_eval(
        self, batch: dict[str, torch.Tensor], batch_idx: int, prefix: str
    ) -> torch.Tensor:
        """Calculate the loss for a batch.

        Args:
            batch (dict): mapping with a "data" entry (model input) and a
                "target" entry (regression target).
            batch_idx (int): batch id.
            prefix (str): prefix for logging.

        Returns:
            torch.Tensor: calculated loss.
        """
        data, target = batch["data"], batch["target"]

        # Call the module rather than `forward` directly so that registered
        # hooks are executed.
        decoded = self(data)
        loss = F.mse_loss(decoded, target)

        # Logged once per epoch rather than per step.
        self.log(f"{prefix}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)

        return loss

In [8]:
experiment_name = "spherical_decoder"

# mlflow.set_tracking_uri(os.environ["MLFLOW_ENDPOINT_URL"])

mlflow.set_experiment(experiment_name)
mlflow.pytorch.autolog()

# Explicit switch instead of building a profiler and immediately overwriting it
# with None. Set to True to profile the run with output file "test_sh.txt".
use_profiler = False
profiler = (
    pl.profiler.AdvancedProfiler(filename="test_sh.txt") if use_profiler else None
)

model = SphericalDecoder(profiler=profiler)
trainer = pl.Trainer(
    # -1 selects every visible GPU; fall back to CPU training when none is
    # present, since an unconditional -1 errors on machines without a GPU.
    gpus=-1 if torch.cuda.is_available() else None,
    profiler=profiler,
    max_epochs=2000,
    callbacks=[
        # Both callbacks watch the "val_loss" metric logged by the model.
        EarlyStopping(monitor="val_loss"),
        ModelCheckpoint(monitor="val_loss"),
    ],
)
trainer.fit(model, train_dataloader, validate_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
2021/12/15 15:39:15 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '9c9345ffada4436e9bb7a7e714b92f65', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type           | Params
---------------------------------------------
0 | s2_conv   | S2Convolution  | 20.8 K
1 | so3_conv1 | SO3Convolution | 280 K 
2 | linear    | Sequential     | 3.6 M 
---------------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
15.663    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_deprecation(


Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



In [None]:
import h5py

from autoencoder.spherical.harmonics import (
    convert_cart_to_s2,
    gram_schmidt_sh_inv,
    sh_basis_real,
)

# Read the whole "scheme" dataset from the archive into memory.
with h5py.File("../data/data.hdf5", "r") as archive:
    scheme = archive.get("scheme")[()]

# Boolean row mask over scheme columns 3, 4 and 5 (presumably acquisition
# parameters — TODO confirm which). It is not used within this cell.
filter_scheme = (
    (scheme[:, 3] == 2000)
    & (scheme[:, 4] == 20)
    & (scheme[:, 5] == 80)
)

# Degree handed to both spherical-harmonics helpers below.
l = 2

# The first three scheme columns are converted to S2 coordinates, expanded in
# the real SH basis, passed through gram_schmidt_sh_inv, and finally wrapped in
# a tensor with a new leading axis of size 1.
a = torch.from_numpy(
    gram_schmidt_sh_inv(
        sh_basis_real(convert_cart_to_s2(scheme[:, :3]), l),
        l,
        n_iters=1000,
    )
)[np.newaxis, :, :]