In [None]:
import os, sys

# Add the directory two levels up (relative to the current working directory) to
# the import search path. Presumably this is what lets the `astroclip.astroclip.*`
# imports below resolve when the notebook is run from its own folder — TODO confirm.
sys.path.append("../..")

from astroclip.astroclip.preprocessing import ImageSpectrumCollator
from astroclip.astroclip.data import AstroClipDataloader
from astroclip.astroclip.modules import ImageModule, SpectrumModule

In [None]:
# Location handed to the dataloader as its first positional argument.
dataset_root = "/mnt/ceph/users/polymathic/mmoma/datasets/astroclip_file/"

# Collator applied to each batch, configured with a center crop of 144.
batch_collator = ImageSpectrumCollator(center_crop=144)

loader = AstroClipDataloader(dataset_root, collate_fn=batch_collator, batch_size=10)

# Run the "fit" stage setup so the training dataloader can be requested afterwards.
loader.setup("fit")

In [None]:
# Grab the first batch from the training dataloader; later cells use it as a
# small sample input.
train_batches = iter(loader.train_dataloader())
dummy = next(train_batches)

In [None]:
%pylab inline

imshow(dummy["image"][0].permute(1, 2, 0))

In [None]:
# Paths passed to ImageModule: an output directory, a YAML config file and a
# checkpoint file.
image_module_kwargs = {
    "save_directory": "/mnt/home/lparker/ceph/dino_training",
    "config": "/mnt/home/lparker/Documents/AstroFoundationModel/AstroDino_legacy/astrodino/configs/ssl_default_config.yaml",
    "model_weights": "/mnt/home/lparker/ceph/astrodino/vitl12_simplified_better_wd/training_199999/teacher_checkpoint.pth",
}

imagemodule = ImageModule(**image_module_kwargs)

In [None]:
# Move the image module to the GPU before feeding it GPU tensors.
imagemodule.cuda()

# Copy the sample images to the GPU and run them through the module; the result
# is the cell's displayed output.
gpu_images = dummy["image"].cuda()
imagemodule.forward(gpu_images)

In [None]:
from astroclip.astroclip.modules import CrossAttentionHead, MLP
import torch.nn as nn
import torch
import yaml


class SpectrumModule(nn.Module):
    """Spectrum embedder: pretrained backbone -> cross-attention head -> residual MLP."""

    def __init__(
        self,
        config: str,
        model_weights: str,
        embed_dim: int = 1024,
        n_head: int = 4,
        model_embed_dim: int = 768,
        dropout: float = 0.1,
        freeze_backbone: bool = True,
        backbone: nn.Module = None,
        preprocess=None,
    ):
        """
        Cross-attention spectrum module that takes a spectrum and passes it through a pretrained SpecFormer model and
        then through a cross-attention mechanism and MLP to get the final embedding.

        Args:
            config (str): Path to a YAML configuration file. It is parsed and kept on ``self.config``;
                nothing in this class consumes it yet.
            model_weights (str): Path to a checkpoint readable by ``torch.load`` that contains a
                ``"state_dict"`` entry with the backbone weights.
            embed_dim (int): Dimension of the AstroCLIP embedding.
            n_head (int): Number of heads in the multihead attention.
            model_embed_dim (int): Dimension of the SpecFormer embedding.
            dropout (float): Dropout rate for MLP layers.
            freeze_backbone (bool): Whether to freeze the backbone of the SpecFormer model.
            backbone (nn.Module): Un-initialised backbone instance that receives the checkpoint
                weights. Its forward pass must return a mapping with an ``"embedding"`` entry.
                Required in practice: this class does not know how to build the backbone itself.
            preprocess (callable, optional): Function applied to ``x.unsqueeze(-1)`` at the start of
                ``forward``. When omitted, ``forward`` passes its input to the backbone untouched,
                i.e. the input is assumed to be prepared upstream (e.g. by a collate function).

        Raises:
            ValueError: If no ``backbone`` is supplied.
        """
        super().__init__()

        # Fail early, and before any file I/O, with an actionable message instead
        # of an AttributeError on a backbone attribute that was never created.
        if backbone is None:
            raise ValueError(
                "SpectrumModule requires a `backbone` module to load the checkpoint weights into."
            )

        # Load the model from the checkpoint.
        # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
        # should be passed here. map_location keeps loading independent of the
        # device the checkpoint was saved from; load_state_dict copies the values
        # into the backbone's own parameters afterwards.
        checkpoint = torch.load(model_weights, map_location="cpu")

        # Parse the YAML config with a context manager so the handle is closed.
        with open(config) as config_file:
            self.config = yaml.safe_load(config_file)

        # load_state_dict is the nn.Module API for restoring weights from a state
        # dict. Presumably the checkpoint keys match the backbone's parameter
        # names; a mismatch raises here — TODO confirm against the checkpoint.
        self.backbone = backbone
        self.backbone.load_state_dict(checkpoint["state_dict"])

        self.preprocess = preprocess

        # Freeze backbone if necessary
        self.freeze_backbone = freeze_backbone
        if self.freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Set up cross-attention
        self.cross_attention = CrossAttentionHead(
            embed_dim=embed_dim,
            n_head=n_head,
            model_embed_dim=model_embed_dim,
            dropout=dropout,
        )

        # Set up MLP
        self.mlp = MLP(
            embed_dim=embed_dim,
            dropout=dropout,
        )

    def forward(
        self, x: torch.Tensor, y: torch.Tensor = None, return_weights: bool = False
    ):
        """
        Embed a batch of spectra.

        Args:
            x (torch.Tensor): Input spectra. Passed through ``self.preprocess`` first when one is set.
            y (torch.Tensor): Unused; kept so the call signature stays unchanged.
            return_weights (bool): If True, also return ``attentions[1]`` from the cross-attention head.

        Returns:
            The squeezed embedding, or a ``(embedding, attention_weights)`` tuple when
            ``return_weights`` is True.
        """
        # Slice the spectrum
        # TODO: use spectrum collate function
        if self.preprocess is not None:
            x = self.preprocess(x.unsqueeze(-1))

        # Embed the spectrum using the pretrained model; skip graph construction
        # when the backbone is frozen.
        if self.freeze_backbone:
            with torch.no_grad():
                embedding = self.backbone(x)["embedding"]
        else:
            embedding = self.backbone(x)["embedding"]

        # Pass through cross-attention
        x, attentions = self.cross_attention(embedding)

        # Pass through MLP and residual connection. Out-of-place on purpose: an
        # in-place add would overwrite the attention output that autograd may
        # still need for the backward pass.
        x = x + self.mlp(x)

        # NOTE(review): squeeze() drops every size-1 dimension, so a batch of one
        # also loses its batch axis — confirm callers expect that.
        if return_weights:
            return x.squeeze(), attentions[1]

        return x.squeeze()

In [None]:
import yaml


def load_config(file_path):
    """Parse the YAML file at ``file_path`` with ``yaml.safe_load`` and return the result."""
    with open(file_path, "r") as handle:
        return yaml.safe_load(handle)


# YAML file to inspect; the parsed content is the cell's displayed output.
specformer_config_path = "/mnt/home/lparker/Documents/AstroFoundationModel/AstroCLIP/astroclip/specformer/config.yaml"

load_config(specformer_config_path)