In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src.basemodule import BaseModel, BaseLightningModule, BaseSpecDataset, BaseDataModule

In [3]:
import yaml
def load_config(config_path):
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config

config  = load_config('../configs/config.yaml')

In [4]:
config

{'model': {'image_size': 4096,
  'patch_size': 100,
  'hidden_size': 128,
  'num_hidden_layers': 6,
  'num_attention_heads': 8,
  'num_labels': 2,
  'stride_ratio': 1,
  'proj_fn': 'SW'},
 'train': {'batch_size': 128, 'ep': 1, 'debug': 0, 'workers': 24},
 'loss': {'name': 'T1'},
 'opt': {'type': 'adam',
  'lr': 0.001,
  'lr_sch': 'plateau',
  'factor': 0.8,
  'patience': 2},
 'data': {'file_path': '/datascope/subaru/user/swei20/data/bosz50000/test/mag215/train_100k/dataset.h5',
  'val_path': '/datascope/subaru/user/swei20/data/bosz50000/mag215/train_1k/dataset.h5',
  'test_path': '/datascope/subaru/user/swei20/data/bosz50000/mag215/val_1k/dataset.h5',
  'num_samples': 1000,
  'num_test_samples': 1000,
  'param_idx': 1},
 'mask': {'mask_ratio': 0.85},
 'noise': {'noise_level': 0},
 'project': 'vit-test'}

# Model

from src.model import MyViT, runs perfectly

In [5]:
from transformers import ViTModel, ViTConfig
import torch.nn as nn

In [6]:
def get_model_config(config, num_classes=2):
    """
    Create a ViTConfig object based on the provided configuration.
    Args:
        config (dict): Configuration dictionary containing model parameters.
        num_classes (int): Number of output classes for classification tasks.
        image_size (int): Size of the input images.
    Returns:
        ViTConfig: Config object for the Vision Transformer model.
    """
    

    vit_config = ViTConfig(
        image_size=config['model']['image_size'],
        patch_size=config['model']['patch_size'],
        num_channels=1,
        hidden_size=config['model']['hidden_size'],
        num_hidden_layers=config['model']['num_hidden_layers'],
        num_attention_heads=config['model']['num_attention_heads'],
        intermediate_size=4 * config['model']['hidden_size'],
        stride_ratio=config['model']['stride_ratio'],
        proj_fn=config['model']['proj_fn'],

        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        is_encoder_decoder=False,
        use_mask_token=False,
        qkv_bias=True,
        num_labels=num_classes,
        noise_level=config['noise']['noise_level'],
        learning_rate=config['opt']['lr'],
    )
    return vit_config

In [9]:
vit_config = get_model_config(config, num_classes=2)

In [10]:
from src.model import MyViT

In [11]:
m = MyViT(vit_config)

In [12]:
import torch
a = torch.randn(2, 4096)
label = torch.randint(0, 2, (2,))
out = m(a, labels = label)

In [14]:
out.loss

tensor(0.6387, grad_fn=<NllLossBackward0>)

In [15]:
out

ImageClassifierOutput(loss=tensor(0.6387, grad_fn=<NllLossBackward0>), logits=tensor([[ 0.0029,  0.0407],
        [-0.1544,  0.0349]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

# Dataset

In [7]:
MASK_PATH = '/datascope/subaru/user/swei20/model/bosz50000_mask.npy'
import numpy as np

#region --DATA-----------------------------------------------------------
class SpecTrainDataset(BaseSpecDataset):
    def load_data(self, stage=None) -> None:
        super().load_data(stage=stage)
        if self.mask_ratio is not None:
            if self.mask_ratio < 1:
                self.mask = np.load(MASK_PATH)
                self.apply_mask()
    
class SpecTestDataset(BaseSpecDataset):
    @classmethod
    def from_dataset(cls, dataset, stage='test'):
        keys = ['file_path', 'val_path', 'test_path', 'num_samples', 'num_test_samples', 'root_dir', 'mask_ratio', 'mask_filler', 'mask', 'lvrg_num', 'lvrg_mask', 'noise_level', 'noise_max']
        c = cls(**{k: getattr(dataset, k) for k in keys}) 
        if stage == 'val': c.num_test_samples = min(c.num_test_samples, 1000) 
        return c
    def load_data(self, stage=None) -> None:
        super().load_data(stage=stage)
        if self.mask is None and self.mask_ratio is not None:
            if self.mask_ratio < 1:
                self.mask = np.load(MASK_PATH)
            # self.mask = self.create_quantile_mask(self.error, ratio=self.mask_ratio)
        if self.mask is not None: 
            self.mask_plot = {'wave': self.wave, 'error':self.error[0], 'mask': self.mask}
            self.apply_mask()
            self.mask_plot.update({'masked_error': self.error[0]})       
        self.set_noise()    
        
    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.noisy[idx], self.flux[idx], self.error[idx]
    
    def set_noise(self, seed=42):
        torch.manual_seed(seed)
        self.noise = torch.randn_like(self.flux) * self.error * self.noise_level
        self.noisy = self.flux + self.noise
        self.flux_rms = torch.norm(self.flux, dim=-1)
        self.snr0 = torch.div(self.flux_rms , torch.norm(self.noise, dim=-1))
        
    # def get_single_spectrum_noise_testset(self, sample_idx=0, repeat=1000, seed=42):
    #     flux_0, error_0  = self.flux[sample_idx], self.error[sample_idx]
    #     test_dataset = SingleSpectrumNoiseDataset(flux_0, error_0, noise_level=self.noise_level,repeat=repeat, seed=seed)
    #     return test_dataset
    
#endregion --DATA-----------------------------------------------------------
#region --DATAMODULE-----------------------------------------------------------
class SpecDataModule(BaseDataModule):
    @classmethod
    def from_config(cls, config):
        return super().from_config(dataset_cls=SpecTrainDataset, config=config)
    def setup_test_dataset(self, stage):
        if hasattr(self, 'train'):
            return SpecTestDataset.from_dataset(self.train, stage) 
        return SpecTestDataset.from_config(self.config)
#endregion --DATAMODULE-----------------------------------------------------------

In [16]:
class ClassSpecDataset(BaseSpecDataset):
    def __init__(self, param_idx=1, **kwargs):
        super().__init__(**kwargs)
        self.param_idx = param_idx  # 指定使用哪个参数作为标签
        
    def load_data(self, stage=None):
        super().load_data(stage)
        self.load_params(stage)
        self.labels = (torch.tensor(self.logg > 2.5)).long() 
        
        # 将连续参数离散化为分类标签
        # params = torch.tensor(getattr(self, ['teff', 'logg', 'mh'][self.param_idx]))
        # self.labels = self.discretize_params(params)
        
    def discretize_params(self, params, bins=10):
        # 等频分箱创建分类标签
        quantiles = torch.linspace(0, 1, bins+1)
        bin_edges = torch.quantile(params, quantiles)
        return torch.bucketize(params, bin_edges[1:-1]).long()
    
    def __getitem__(self, idx):
        flux, error = super().__getitem__(idx)
        return flux, error, self.labels[idx]

In [17]:
c = ClassSpecDataset.from_config(config)

/datascope/subaru/user/swei20/data/bosz50000/test/mag215/train_100k/dataset.h5 1000 /datascope/subaru/user/swei20/data/bosz50000/mag215/val_1k/dataset.h5 1000 /datascope/subaru/user/swei20/data/bosz50000/mag215/train_1k/dataset.h5 None ./results


In [18]:
c.load_data('train')

loading data from /datascope/subaru/user/swei20/data/bosz50000/test/mag215/train_100k/dataset.h5 1000
torch.Size([1000, 4096]) torch.Size([1000, 4096]) torch.Size([4096]) 1000 4096


In [19]:
c[0]

(tensor([0.5135, 0.6785, 0.5748,  ..., 0.4817, 0.4453, 0.4898]),
 tensor([0.0587, 0.0587, 0.0572,  ..., 0.0632, 0.0564, 0.0564]),
 tensor(1))

In [32]:
class ViTDataModule(BaseDataModule):
    @classmethod
    def from_config(cls, config):
        return super().from_config(dataset_cls=ClassSpecDataset, config=config)
    def setup_test_dataset(self, stage):
        return ClassSpecDataset.from_config(self.config)
    # def setup(self, stage=None):
    #     super().setup(stage)
    #     # 添加通道维度 (B, L) -> (B, 1, L)
    #     self.train.flux = self.train.flux.unsqueeze(1)
    #     self.train.error = self.train.error.unsqueeze(1)
    #     self.val.flux = self.val.flux.unsqueeze(1)
    #     self.val.error = self.val.error.unsqueeze(1)

In [33]:
dd = ViTDataModule.from_config(config)
dd.setup('fit')

/datascope/subaru/user/swei20/data/bosz50000/test/mag215/train_100k/dataset.h5 1000 /datascope/subaru/user/swei20/data/bosz50000/mag215/val_1k/dataset.h5 1000 /datascope/subaru/user/swei20/data/bosz50000/mag215/train_1k/dataset.h5 None ./results
loading data from /datascope/subaru/user/swei20/data/bosz50000/test/mag215/train_100k/dataset.h5 1000
torch.Size([1000, 4096]) torch.Size([1000, 4096]) torch.Size([4096]) 1000 4096
/datascope/subaru/user/swei20/data/bosz50000/test/mag215/train_100k/dataset.h5 1000 /datascope/subaru/user/swei20/data/bosz50000/mag215/val_1k/dataset.h5 1000 /datascope/subaru/user/swei20/data/bosz50000/mag215/train_1k/dataset.h5 None ./results
loading data from /datascope/subaru/user/swei20/data/bosz50000/mag215/train_1k/dataset.h5 1000
torch.Size([1000, 4096]) torch.Size([1000, 4096]) torch.Size([4096]) 1000 4096


In [37]:
dd.train

<__main__.ClassSpecDataset at 0x7f945ad64190>

In [51]:
from torch.utils.data import DataLoader
d = DataLoader(c, batch_size=100, num_workers=0, shuffle=True)
with torch.no_grad():
    for batch in d:
        print(batch)
        flux, error, labels = batch
        flux = flux.to(m.device)
        error = error.to(m.device)
        labels = labels.to(m.device)

        output = m(flux, labels=labels)
        
        break
output.loss


[tensor([[0.3198, 0.2749, 0.3066,  ..., 0.5203, 0.5208, 0.5208],
        [0.5668, 0.5674, 0.5675,  ..., 0.4417, 0.4414, 0.4416],
        [0.6054, 0.6077, 0.6191,  ..., 0.4398, 0.4393, 0.4391],
        ...,
        [0.5923, 0.5635, 0.5696,  ..., 0.4658, 0.4704, 0.4721],
        [0.6074, 0.6073, 0.6073,  ..., 0.4152, 0.4156, 0.4155],
        [0.1181, 0.1411, 0.1168,  ..., 0.7641, 0.8004, 0.7934]]), tensor([[0.1139, 0.1139, 0.1131,  ..., 0.1299, 0.1167, 0.1167],
        [0.1397, 0.1397, 0.1384,  ..., 0.1555, 0.1389, 0.1389],
        [0.0603, 0.0603, 0.0599,  ..., 0.0679, 0.0598, 0.0598],
        ...,
        [0.0399, 0.0399, 0.0397,  ..., 0.0443, 0.0395, 0.0395],
        [0.1311, 0.1311, 0.1298,  ..., 0.1475, 0.1309, 0.1309],
        [0.0731, 0.0731, 0.0721,  ..., 0.0891, 0.0801, 0.0801]]), tensor([1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0,
        1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0,
        1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 

tensor(0.6826)

In [52]:
get_acc =Accuracy(task='multiclass', num_classes=2)

In [53]:
get_acc(output.logits, labels)

tensor(0.5600)

In [19]:
import lightning as L

In [None]:
from src.basemodule import BaseTrainer
t = BaseTrainer(config, num_gpus=1)


In [None]:
from src.basemodule import BaseLightningModule
lm = BaseLightningModule.from_config(config, model_cls=MyViT, dataset_cls=ClassSpecDataset)

In [None]:
from torchmetrics import Accuracy

class ViTLModule(BaseLightningModule):
    def __init__(self, model, data_module, config):
        super().__init__(model=model, data_module=data_module, config=config)
        self.save_hyperparameters()
        self.loss_name = 'train'  # Set the loss name for logging
        self.model.loss_name = self.loss_name  # Ensure the model has the loss name set
        self.get_accuracy = Accuracy(task='multiclass', num_classes=config['num_labels'])

    def forward(self, flux, labels, loss_only=True):
        outputs = self.model(flux, labels=labels)
        if loss_only:
            return outputs.loss
        else:
            return outputs
        
    def training_step(self, batch, batch_idx):
        flux, _, labels = batch
        loss = self.forward(flux, labels, loss_only=True)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        flux, _, labels = batch
        outputs = self.forward(flux, labels, loss_only=False)
        self.log('val_loss', outputs.loss, on_step=True, on_epoch=True, prog_bar=True)
        accuracy = self.get_accuracy(outputs.logits, labels)
        self.log('val_acc', accuracy, on_step=False, on_epoch=True,  prog_bar=False)
        return outputs.loss
        
lm = ViTLModule(model=m, data_module=dd, config=config)
t.fit(lm,datamodule=lm.data_module) 

In [43]:
m.loss_name = 'train'

In [54]:
lm = ViTLModule(model=m, data_module=dd, config=config)
t.fit(lm,datamodule=lm.data_module)


/srv/local/tmp/swei20/miniconda3/envs/viska-torch-3/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /srv/local/tmp/swei20/miniconda3/envs/viska-torch-3/ ...
/srv/local/tmp/swei20/miniconda3/envs/viska-torch-3/lib/python3.13/site-packages/lightning/pytorch/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.


/datascope/subaru/user/swei20/data/bosz50000/test/mag215/train_100k/dataset.h5 1000 /datascope/subaru/user/swei20/data/bosz50000/mag215/val_1k/dataset.h5 1000 /datascope/subaru/user/swei20/data/bosz50000/mag215/train_1k/dataset.h5 None ./results
loading data from /datascope/subaru/user/swei20/data/bosz50000/test/mag215/train_100k/dataset.h5 1000
torch.Size([1000, 4096]) torch.Size([1000, 4096]) torch.Size([4096]) 1000 4096
/datascope/subaru/user/swei20/data/bosz50000/test/mag215/train_100k/dataset.h5 1000 /datascope/subaru/user/swei20/data/bosz50000/mag215/val_1k/dataset.h5 1000 /datascope/subaru/user/swei20/data/bosz50000/mag215/train_1k/dataset.h5 None ./results
loading data from /datascope/subaru/user/swei20/data/bosz50000/mag215/train_1k/dataset.h5 1000


/srv/local/tmp/swei20/miniconda3/envs/viska-torch-3/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/swei20/VIT/evals/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [4,5,6,7]

  | Name  | Type  | Params | Mode 
----------------------------------------
0 | model | MyViT | 1.2 M  | train
----------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.834     Total estimated model params size (MB)
112       Modules in train mode
0         Modules in eval mode


torch.Size([1000, 4096]) torch.Size([1000, 4096]) torch.Size([4096]) 1000 4096


`Trainer.fit` stopped: `max_epochs=10` reached.


In [9]:
config = {
    'data': {'file_path': '/home/swei20/SirenSpec/tests/spec/test_dataset.h5', 'num_samples': 10},
}

In [20]:
class Experiment:
    def __init__(self, config, use_wandb=False, num_gpus=None, sweep=False, ckpt_path=None):
        # 创建数据模块
        dm = ViTDataModule.from_config(config)
        dm.setup()
        
        # 创建模型
        model = ViTModel(config, num_labels=config.get('num_classes', 10))
        
        # 创建Lightning模块
        self.lightning_module = ViTLightningModule(model, dm, config)
        
        # 其余初始化保持不变...
        self.lightning_module.sweep = sweep
        if use_wandb:
            logger = L.pytorch.loggers.WandbLogger(
                project=config['project'],
                config=config,
                name=config.get('exp_name', 'ViT_experiment'),
                log_model=True
            )
        else:
            logger = None
            
        self.t = SpecTrainer(config=config, logger=logger, 
                            num_gpus=num_gpus, sweep=sweep)
        self.ckpt_path = ckpt_path

    # run方法保持不变...