In [None]:
import logging

import h5py
import mlflow
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader

from autoencoder.concrete_autoencoder import Decoder
from autoencoder.dataset import MRIMemoryDataset, MRIMemorySHDataset
from autoencoder.logger import logger, set_log_level
from autoencoder.spherical.CG import real_clebsch_gordan_all
from autoencoder.spherical.convolution import (
    QuadraticNonLinearity,
    S2Convolution,
    SO3Convolution,
)

In [None]:
# Verbose logging while developing the decoders; logging.DEBUG == 10, the
# value previously passed as a bare magic number.
set_log_level(logging.DEBUG)

In [None]:
# use gpu if available, else cpu
has_cuda = torch.cuda.is_available()

logger.info("Is the GPU available? %s", has_cuda)

device = torch.device("cuda" if has_cuda else "cpu")
if has_cuda:
    logger.info("Current device: %s", torch.cuda.current_device())
    logger.info("Device count: %s", torch.cuda.device_count())
    logger.info("Using device: %s", torch.cuda.get_device_properties(device))
else:
    logger.warning("No GPU dectected! Training will be slow")

In [None]:
class SphericalDecoder(pl.LightningModule):
    """Decoder built from spherical convolutions followed by a linear layer.

    An S2 convolution and an SO(3) convolution, each followed by a quadratic
    non-linearity, produce features that a final ``torch.nn.Linear`` maps to
    the output. Trained with an MSE loss against ``batch["target"]``.
    """

    def __init__(self, *, learning_rate: float = 1e-4, profiler=None) -> None:
        """Build the network.

        Args:
            learning_rate (float): learning rate used by Adam.
            profiler: optional profiler passed through to the convolution layers.
        """
        super().__init__()

        self.learning_rate = learning_rate

        # Per-stage degree settings (presumably the maximum spherical-harmonic
        # degree l of each stage; the SH datasets use l_max=2) — confirm.
        L = [2, 2, 0]
        # Build the Clebsch-Gordan tables on the GPU only when one exists:
        # hard-coding "cuda" made the model impossible to construct on a
        # CPU-only machine.
        cg_device = "cuda" if torch.cuda.is_available() else "cpu"
        CG_r, CG_l = real_clebsch_gordan_all(L[0], L[1], device=cg_device)

        # NOTE(review): both non-linearities share the CG tables built for
        # (L[0], L[1]); confirm this is valid for the second one (L[1] -> L[2]).
        self.spherical = torch.nn.Sequential(
            S2Convolution(28, 3, L[0], 5, 8, profiler=profiler),
            QuadraticNonLinearity(L[0], L[1], CG_r, CG_l),
            SO3Convolution(28, 3, L[1], 8, 16, profiler=profiler),
            QuadraticNonLinearity(L[1], L[2], CG_r, CG_l),
        )
        # NOTE(review): 2688 must match the size of the feature tensor returned
        # by `self.spherical`; 1344 is the decoder's output size. Update both
        # together if the layer configuration above changes.
        self.linear = torch.nn.Linear(2688, 1344)

    def forward(self, x: dict[int, torch.Tensor]) -> torch.Tensor:
        """Decode a batch.

        Args:
            x (dict[int, torch.Tensor]): input tensors keyed by int (presumably
                the spherical-harmonic degree — confirm against the dataset).

        Returns:
            torch.Tensor: output of the final linear layer.
        """
        # The spherical stack returns a pair; only the second element is decoded.
        _, features = self.spherical(x)
        return self.linear(features)

    def configure_optimizers(self) -> torch.optim.Adam:
        """Optimize all parameters with Adam at ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(
        self,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
    ) -> torch.Tensor:
        """Return the training loss for one batch (logged as ``train_loss``)."""
        return self._shared_eval(batch, batch_idx, "train")

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Return the validation loss for one batch (logged as ``val_loss``)."""
        return self._shared_eval(batch, batch_idx, "val")

    def _shared_eval(
        self, batch: dict[str, torch.Tensor], batch_idx: int, prefix: str
    ) -> torch.Tensor:
        """Calculate the loss for a batch.

        Args:
            batch (dict[str, torch.Tensor]): batch with the model input under
                ``"data"`` and the reconstruction target under ``"target"``.
            batch_idx (int): batch id.
            prefix (str): prefix for logging.

        Returns:
            torch.Tensor: calculated loss.
        """
        data, target = batch["data"], batch["target"]

        decoded = self.forward(data)
        loss = F.mse_loss(decoded, target)

        self.log(f"{prefix}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)

        return loss

In [None]:
# Data for the spherical decoder: spherical-harmonic datasets (l_max=2) built
# from the HDF5 file, passing the indices from latent_features.txt as `include`.
data_path = "../data/data.hdf5"
num_workers = 6
latent_features = np.loadtxt("latent_features.txt", dtype=int)

batch_size = 256
train_subjects = [11, 12, 13, 14]
validate_subjects = [15]

train_dataset = MRIMemorySHDataset(
    data_path,
    train_subjects,
    include=latent_features,
    l_max=2,
)

validate_dataset = MRIMemorySHDataset(
    data_path,
    validate_subjects,
    include=latent_features,
    l_max=2,
)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    pin_memory=True,
    batch_size=batch_size,
    num_workers=num_workers,
)
# Pin validation batches as well, so both loaders hand data to the GPU the same way.
validate_dataloader = DataLoader(
    validate_dataset,
    pin_memory=True,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
experiment_name = "spherical_decoder"

# Uncomment to send runs to a remote tracking server (also needs `import os`).
# mlflow.set_tracking_uri(os.environ["MLFLOW_ENDPOINT_URL"])

mlflow.set_experiment(experiment_name)
mlflow.pytorch.autolog()

# Set to True to write an AdvancedProfiler report to test_sh.txt. The profiler
# used to be constructed and then immediately thrown away by `profiler = None`.
use_profiler = False
profiler = pl.profiler.AdvancedProfiler(filename="test_sh.txt") if use_profiler else None

model = SphericalDecoder(profiler=profiler)
trainer = pl.Trainer(
    # Request all GPUs only when CUDA is available; an unconditional -1 asks
    # for GPUs on CPU-only machines and the Trainer refuses to start.
    gpus=-1 if torch.cuda.is_available() else None,
    profiler=profiler,
    max_epochs=2000,
    callbacks=[
        EarlyStopping(monitor="val_loss"),
        ModelCheckpoint(monitor="val_loss"),
    ],
)
trainer.fit(model, train_dataloader, validate_dataloader)

In [None]:
class LinearDecoder(pl.LightningModule):
    """Fully connected baseline decoder (500 -> 800 -> 1344) trained with MSE."""

    def __init__(self, *, learning_rate: float = 1e-3, profiler=None) -> None:
        """Build the MLP.

        Args:
            learning_rate (float): learning rate used by Adam.
            profiler: accepted so the constructor mirrors SphericalDecoder's;
                this model does not use it.
        """
        super().__init__()

        self._learning_rate = learning_rate

        hidden_width = 800
        output_width = 1344
        # Module order fixes the state-dict keys (`_linear.0` ... `_linear.3`).
        # NOTE(review): the output layer is also followed by LeakyReLU; confirm
        # that is intended for an MSE regression target.
        self._linear = torch.nn.Sequential(
            *[
                torch.nn.Linear(500, hidden_width),
                torch.nn.LeakyReLU(0.2),
                torch.nn.Linear(hidden_width, output_width),
                torch.nn.LeakyReLU(0.2),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to a batch of input vectors."""
        return self._linear(x)

    def configure_optimizers(self) -> torch.optim.Adam:
        """Optimize all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self._learning_rate)

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Return the training loss for one batch (logged as ``train_loss``)."""
        return self._shared_eval(batch, batch_idx, "train")

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Return the validation loss for one batch (logged as ``val_loss``)."""
        return self._shared_eval(batch, batch_idx, "val")

    def _shared_eval(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int, prefix: str
    ) -> torch.Tensor:
        """Compute and log the MSE loss for one batch.

        Args:
            batch (tuple[torch.Tensor, torch.Tensor]): ``(target, data)`` pair;
                the first element is the target and the second is the model input.
            batch_idx (int): index of the batch (unused).
            prefix (str): prefix for the logged metric name (``"<prefix>_loss"``).

        Returns:
            torch.Tensor: the MSE between the decoded input and the target.
        """
        target, data = batch
        loss = F.mse_loss(self.forward(data), target)
        self.log(f"{prefix}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

In [None]:
# Data for the linear decoder: in-memory datasets built from the HDF5 file,
# passing the indices from latent_features.txt as `include`; kept off the GPU
# (`do_store_in_gpu=False`).
data_path = "../data/data.hdf5"
num_workers = 6
latent_features = np.loadtxt("latent_features.txt", dtype=int)

batch_size = 256
train_subjects = [11, 12, 13, 14]
validate_subjects = [15]

train_dataset = MRIMemoryDataset(
    data_path,
    train_subjects,
    include=latent_features,
    do_store_in_gpu=False,
)

validate_dataset = MRIMemoryDataset(
    data_path,
    validate_subjects,
    include=latent_features,
    do_store_in_gpu=False,
)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    pin_memory=True,
    batch_size=batch_size,
    num_workers=num_workers,
)
# Pin validation batches as well, so both loaders hand data to the GPU the same way.
validate_dataloader = DataLoader(
    validate_dataset,
    pin_memory=True,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
experiment_name = "linear_decoder"

# Uncomment to send runs to a remote tracking server (also needs `import os`).
# mlflow.set_tracking_uri(os.environ["MLFLOW_ENDPOINT_URL"])

mlflow.set_experiment(experiment_name)
mlflow.pytorch.autolog()

model = LinearDecoder()
trainer = pl.Trainer(
    # Request all GPUs only when CUDA is available; an unconditional -1 asks
    # for GPUs on CPU-only machines and the Trainer refuses to start.
    gpus=-1 if torch.cuda.is_available() else None,
    max_epochs=2000,
    callbacks=[
        EarlyStopping(monitor="val_loss"),
        ModelCheckpoint(monitor="val_loss"),
    ],
)
trainer.fit(model, train_dataloader, validate_dataloader)

In [None]:
# Stand-alone S2 convolution built on the CPU for interactive experiments.
L = [2, 2, 0]
CG_r, CG_l = real_clebsch_gordan_all(*L[:2], device="cpu")

# NOTE(review): this passes L[1] and the CG tables positionally (8 arguments),
# whereas SphericalDecoder calls S2Convolution(28, 3, L[0], 5, 8, profiler=...).
# One of the two call sites is likely stale — confirm against S2Convolution.
s2_conv = S2Convolution(28, 3, *L[:2], 5, 8, CG_r, CG_l)

In [None]:
# Random stand-ins keyed 0 and 2; they differ only in the size of the
# fifth axis (1 vs 5).
rh = {
    0: torch.rand((256, 28, 3, 5, 1, 5)),
    2: torch.rand((256, 28, 3, 5, 5, 5)),
}

In [None]:
# For each key l of `rh`:
#   1. square the entries and sum over the last two axes;
#   2. flatten everything after the batch axis;
#   3. concatenate the per-key results along dim 1 into `rh_features`.
# The original loop concatenated onto `x` before `x` was ever assigned. That
# raised a NameError on a fresh kernel, or a shape mismatch against the `x`
# defined further down. The pieces are now collected first and joined once.
rh_parts = []
for l in sorted(rh):
    rh_n_l_p = torch.pow(rh[l], 2)
    rh_n_l_s = torch.sum(rh_n_l_p, (5, 4))
    rh_parts.append(torch.flatten(rh_n_l_s, start_dim=1))
    print(rh_n_l_p.shape, rh_n_l_s.shape, rh_parts[-1].shape)

rh_features = torch.cat(rh_parts, dim=1)
rh_features.shape

In [None]:
# Random operands for experimenting with the einsum contraction in the next
# cells; x is indexed there as (n, a, b, i, l) and w as (a, b, i, o, k).
x, w = torch.rand((165, 28, 3, 5, 5)), torch.rand((28, 3, 5, 8, 5))

In [None]:
# Contract the shared `i` axis of x (n, a, b, i, l) with w (a, b, i, o, k),
# keeping a and b as batch axes -> f has axes (n, a, b, o, l, k).
f = torch.einsum("nabil, abiok->nabolk", [x, w])

In [None]:
# Check that a (1, 28, 3, 8, 1, 1) tensor broadcasts against f's (n, a, b, o, l, k) axes.
torch.add(f, torch.zeros(1, 28, 3, 8, 1, 1))

In [None]:
# Display the contracted tensor's dimensions.
f.size()

In [None]:
# Concatenating a single tensor just returns a copy of x (same shape).
torch.cat([x], dim=1)