In [None]:
import torch

from astroclip.datasets import AstroClipDataloader
from astroclip.models.specformer import SpecFormer
from astroclip import format_with_env

In [None]:
# Dataset location, written relative to the {ASTROCLIP_ROOT} placeholder.
# NOTE(review): format_with_env presumably fills the placeholder from the
# environment -- confirm against its definition in the astroclip package.
DATASET_PATH = format_with_env("{ASTROCLIP_ROOT}/datasets/astroclip_file/")

# Data module restricted to the "spectrum" column, configured with a batch
# size of 32 and num_workers=0.
dataset = AstroClipDataloader(
    num_workers=0,
    batch_size=32,
    columns="spectrum",
    path=DATASET_PATH,
)

In [None]:
# Run the data module's setup step (called with None), then obtain the
# training dataloader from it.
dataset.setup(None)
dataloader = dataset.train_dataloader()

# Pull a single batch from a fresh iterator over the training loader so its
# contents can be inspected.
_batch_iter = iter(dataloader)
sample = next(_batch_iter)

In [None]:
# Convert an old-format checkpoint to the new format.
#
# The old file keeps its weights under the "model" key. They are loaded into a
# freshly constructed SpecFormer and written back out as a dict with
# "state_dict" and "hyper_parameters" entries.
#
# The constructor arguments below have to describe the architecture the old
# checkpoint was trained with; load_state_dict raises if the parameter names
# or shapes do not line up.
model = SpecFormer(
    input_dim=22,
    embed_dim=768,
    num_layers=6,
    num_heads=6,
    max_len=800,
    dropout=0.0,
    norm_first=False,
)

# map_location="cpu" keeps the load working on machines without a GPU in case
# the checkpoint tensors were saved from a CUDA device.
checkpoint = torch.load(
    "/mnt/home/sgolkar/ceph/saves/fillm/run-seqformer-2708117/ckpt.pt",
    map_location="cpu",
)
model.load_state_dict(checkpoint["model"])

# Hyperparameters are stored as a plain dict so the file can be unpickled
# without the hparams container class and expanded directly with
# SpecFormer(**hyper_parameters).
# NOTE(review): assumes model.hparams is a mapping -- confirm in SpecFormer.
torch.save(
    {"state_dict": model.state_dict(), "hyper_parameters": dict(model.hparams)},
    format_with_env("{ASTROCLIP_ROOT}/pretrained/specformer.ckpt"),
)

In [None]:
# Rebuild a SpecFormer from a checkpoint that stores the constructor arguments
# under "hyper_parameters" and the weights under "state_dict".
#
# NOTE(review): this loads the last checkpoint of a specific training run, not
# the converted file at "{ASTROCLIP_ROOT}/pretrained/specformer.ckpt". Point
# CHECKPOINT_PATH at that template instead if the converted weights are the
# ones to inspect.
CHECKPOINT_PATH = format_with_env(
    "/mnt/ceph/users/polymathic/astroclip/outputs/astroclip-spectrum/wfms4nfu/checkpoints/last.ckpt"
)

# The path is already formatted above, so it is passed through as-is.
# map_location="cpu" keeps the load working on machines without a GPU in case
# the checkpoint tensors were saved from a CUDA device.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
model = SpecFormer(**checkpoint["hyper_parameters"])
model.load_state_dict(checkpoint["state_dict"])

In [None]:
import matplotlib.pyplot as plt

# Compare, for a few items of one training batch, the preprocessed input, the
# masked version fed to the model, and the model's prediction.

n_examples = 4  # number of batch items to plot, one panel each
# Index along the last axis of the preprocessed sequence that gets plotted.
# NOTE(review): what feature 6 represents is not visible here -- confirm
# against SpecFormer.preprocess.
channel = 6

# Fetch one batch and run the model once; every panel below is a slice of
# these same tensors. Nothing is trained here, so autograd is switched off
# instead of detaching each result separately.
batch = next(iter(dataloader))["spectrum"]
with torch.no_grad():
    in_ = model.preprocess(batch)
    sp_ = model.mask_sequence(in_)
    out_ = model.forward_without_preprocessing(sp_)["predictions"]

fig, axes = plt.subplots(1, n_examples, figsize=(15, 5))
for samp, ax in enumerate(axes):
    ax.plot(in_[samp, :, channel], label="original")
    ax.plot(sp_[samp, :, channel], label="dropped", linestyle="--", alpha=0.5)
    ax.plot(out_[samp, :, channel], label="output")
    ax.set_title(f"batch item {samp}")
    ax.legend()