In [None]:
import os

# Move the working directory two levels up from where the kernel started.
# The starting directory is remembered on first execution so that re-running
# this cell does not climb two more levels each time: a bare relative
# os.chdir("../..") is not idempotent.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()

os.chdir(os.path.join(NOTEBOOK_DIR, "..", ".."))

In [None]:
import torch

from astroclip.datasets.astroclip_dataloader import AstroClipDataset
from astroclip.datasets.preprocessing.spectrum import SpectrumCollator
from astroclip.specformer.model import SpecFormer
from astroclip import format_with_env

In [None]:
# Dataset location; the {ASTROCLIP_ROOT} placeholder is handed to
# format_with_env (presumably substituted from the environment — confirm).
DATASET_PATH = format_with_env("{ASTROCLIP_ROOT}/datasets/astroclip_file/")

# Collator passed to the dataset as its collate_fn.
collator = SpectrumCollator(chunk_width=50, num_chunks=0)

# Keyword arguments for the dataset, gathered in one place so they are easy
# to inspect or tweak before construction.
dataset_options = {
    "path": DATASET_PATH,
    "columns": "spectrum",
    "batch_size": 32,
    "num_workers": 0,
    "collate_fn": collator,
}
dataset = AstroClipDataset(**dataset_options)

In [None]:
# Prepare the dataset (no specific stage requested), then pull the first
# batch from the training loader for inspection.
dataset.setup(None)
dataloader = dataset.train_dataloader()
batch_iterator = iter(dataloader)
sample = next(batch_iterator)

In [None]:
# Resolve the checkpoint location through format_with_env instead of a
# cluster-specific absolute path, so the notebook is not tied to one machine.
# The previous code assigned this template and then immediately overwrote it
# with a hard-coded path, leaving the template as dead code.
# To inspect a different run, change the template below.
CHECKPOINT_PATH = format_with_env("{ASTROCLIP_ROOT}/pretrained/specformer.ckpt")

# map_location="cpu": no device is selected anywhere in this notebook, and
# this also lets a checkpoint written from a GPU run load on a CUDA-less host.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

In [None]:
# Hyperparameters used to build the network; the checkpoint's state dict is
# loaded into it below, so they have to be compatible with that checkpoint.
specformer_config = dict(
    input_dim=22,
    embed_dim=768,
    num_layers=6,
    num_heads=6,
    max_len=800,
    dropout=0.0,
    norm_first=False,
)
model = SpecFormer(**specformer_config)

# Restore the trained weights. Kept as the cell's last expression so the
# result of load_state_dict is still displayed as the cell output.
pretrained_weights = checkpoint["state_dict"]
model.load_state_dict(state_dict=pretrained_weights)

In [None]:
import matplotlib.pyplot as plt

NUM_EXAMPLES = 4  # number of panels; the batch must hold at least this many rows
FEATURE_INDEX = 6  # index along the last axis that is plotted (meaning not shown here — TODO document)

# Fetch one batch and run the model once. The previous version built a new
# dataloader iterator and ran a full forward pass on every loop iteration,
# only to plot a single row of each result.
batch = next(iter(dataloader))

# Inference only: switch to eval mode and skip autograd bookkeeping. Under
# no_grad the output carries no graph, so no .detach() is needed to plot it.
model.eval()
with torch.no_grad():
    reconstruction = model(batch["input"])

fig, axes = plt.subplots(1, NUM_EXAMPLES, figsize=(15, 5))
for samp, ax in enumerate(axes):
    ax.plot(batch["target"][samp, :, FEATURE_INDEX], label="original")
    ax.plot(
        batch["input"][samp, :, FEATURE_INDEX],
        label="dropped",
        linestyle="--",
        alpha=0.5,
    )
    ax.plot(reconstruction[samp, :, FEATURE_INDEX], label="output")
    ax.set_title(f"sample {samp}")
    ax.legend()

fig.tight_layout()