In [1]:
from typing import Any, Optional

import matplotlib.pyplot as plt
import wandb, torch
from lightning import Callback, LightningModule, Trainer
from lightning.pytorch.cli import (
    ArgsType,
    LightningArgumentParser,
    LightningCLI,
    LRSchedulerTypeUnion,
)
from lightning.pytorch.loggers import WandbLogger
from torch.optim import Optimizer

from astroclip import format_with_env
from astroclip.callbacks import CustomSaveConfigCallback

A matching Triton is not available, some optimizations will not be enabled
Traceback (most recent call last):
  File "c:\Users\mi3se\AppData\Local\Programs\Python\Python310\lib\site-packages\xformers\__init__.py", line 57, in _is_triton_available
    import triton  # noqa
ModuleNotFoundError: No module named 'triton'


In [2]:
from astroclip.models.astroclip import AstroClipModel, ImageHead, SpectrumHead

In [3]:
# Image branch: ImageHead is pointed at the AstroDINO config file and checkpoint,
# and is given "outputs/astrodino" as its save directory.
image_encoder = ImageHead(
    save_directory="outputs/astrodino",
    model_weights="pretrained/astrodino.ckpt",
    config="astroclip/astrodino/config.yaml",
)

# Spectrum branch: SpectrumHead is loaded from the SpecFormer checkpoint.
spectrum_encoder = SpectrumHead(model_path="pretrained/specformer.ckpt")

# Joint model wrapping both branches; this is the module handed to the Trainer.
model = AstroClipModel(
    spectrum_encoder=spectrum_encoder,
    image_encoder=image_encoder,
)

In [4]:
from astroclip.data.datamodule import AstroClipDataloader, AstroClipCollator

In [5]:
# Data module over the "mhsotoudeh/astroclip-mini" dataset, restricted to the
# image and spectrum columns. num_workers=0 keeps loading in the main process.
data_loader = AstroClipDataloader(
    path="mhsotoudeh/astroclip-mini",
    columns=["image", "spectrum"],
    collate_fn=AstroClipCollator(),
    num_workers=0,
    batch_size=32,
)

# Prepare the "fit" stage so the train/val dataloaders can be requested next.
data_loader.setup("fit")

In [6]:
# Materialise the training and validation loaders from the prepared data module
# (train is requested first, then val).
train, val = data_loader.train_dataloader(), data_loader.val_dataloader()

In [7]:
# Sanity check: pull the first batch from the training loader and list the
# shape of every entry it contains.
batch_iterator = iter(train)
item = next(batch_iterator)

for k, v in item.items():
    print(k, v.shape)

In [8]:
# NOTE: an unparenthesized `from ... import a, b` list must not end with a
# trailing comma (it is a SyntaxError), so the list ends at ModelCheckpoint.
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from astroclip.callbacks import CustomWandbLogger

# Trainer for the image/spectrum alignment run.
#   - Artifacts (checkpoints, logger files) are rooted under "outputs".
#   - Gradients are clipped at 1.0 and training stops after 10 epochs.
#   - "16-mixed" is the explicit spelling of 16-bit automatic mixed precision
#     (the bare integer 16 is only a legacy alias for the same setting).
trainer = Trainer(
    default_root_dir="outputs",
    enable_checkpointing=True,
    gradient_clip_val=1.0,
    max_epochs=10,
    precision="16-mixed",
    callbacks=[
        # Checkpoint once per epoch, ranking checkpoints by "val_loss_nologit"
        # (ModelCheckpoint's default mode is "min"); keep the two best plus the
        # most recent one.
        ModelCheckpoint(
            save_last=True,
            save_top_k=2,
            every_n_epochs=1,
            monitor="val_loss_nologit",
        ),
        # Record the learning rate at every optimizer step.
        LearningRateMonitor(logging_interval="step"),
    ],
    # Metrics are sent to the "astroclip-alignment" Weights & Biases project.
    logger=CustomWandbLogger(
        project="astroclip-alignment", save_dir="outputs"
    ),
    enable_progress_bar=True,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
# Run the training loop on the joint model, validating with the held-out loader.
# The loaders built from the data module are passed in explicitly.
trainer.fit(model, train_dataloaders=train, val_dataloaders=val)

You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
wandb: Currently logged in as: khairulislamtanim (khairulislamtanim-university-of-virginia). Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type         | Params | Mode 
----------------------------------------------------------
0 | image_encoder    | ImageHead    | 315 M  | train
1 | spectrum_encoder | SpectrumHead | 55.2 M | train
2 | criterion        | CLIPLoss     | 0      | train
----------------------------------------------------------
24.7 M    Trainable params
346 M     Non-trainable params
370 M     Total params
1,483.193 Total estimated model params size (MB)
102       Modules in train mode
422       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [13]:
# Access the logged metrics: read the Trainer's `logged_metrics` mapping and
# display it as the cell output.
# NOTE(review): the values recorded for this run (train ~3.485, val ~3.466) sit
# at roughly ln(32) ~ 3.466, which looks like the chance level of a contrastive
# loss over batches of 32 pairs -- presumably the alignment did not improve over
# random during these 10 epochs. TODO confirm against the wandb loss curves.
logged_metrics = trainer.logged_metrics
logged_metrics

{'train_loss_withlogit': tensor(3.4853),
 'train_loss_nologit': tensor(3.4853),
 'scale': tensor(2.7408),
 'val_loss_nologit': tensor(3.4664),
 'val_loss_withlogit': tensor(3.4664)}