In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src.basemodule import BaseModel, BaseLightningModule, BaseSpecDataset, BaseDataModule

# Data loader

In [3]:
import yaml
def load_config(config_path):
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config

config  = load_config('../configs/config.yaml')

In [4]:
config

{'model': {'image_size': 4096,
  'patch_size': 100,
  'hidden_size': 128,
  'num_hidden_layers': 6,
  'num_attention_heads': 8,
  'num_labels': 2,
  'stride_ratio': 1,
  'proj_fn': 'SW'},
 'train': {'batch_size': 128, 'ep': 1, 'debug': 0, 'workers': 24},
 'loss': {'name': 'T1'},
 'opt': {'type': 'adam',
  'lr': 0.001,
  'lr_sch': 'plateau',
  'factor': 0.8,
  'patience': 2},
 'data': {'file_path': '/datascope/subaru/user/swei20/data/bosz50000/test/mag215/train_100k/dataset.h5',
  'val_path': '/datascope/subaru/user/swei20/data/bosz50000/mag215/train_1k/dataset.h5',
  'test_path': '/datascope/subaru/user/swei20/data/bosz50000/mag215/val_1k/dataset.h5',
  'num_samples': 1000,
  'num_test_samples': 1000,
  'param_idx': 1},
 'mask': {'mask_ratio': 0.85},
 'noise': {'noise_level': 0},
 'project': 'vit-test'}

In [3]:
config = {
    'image_size': ,
}

SyntaxError: expression expected after dictionary key and ':' (1493945288.py, line 2)

# Model

In [11]:
import yaml
def load_config(config_path):
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config

In [14]:
config  = load_config('../configs/config.yaml')

In [15]:
from transformers import ViTModel, ViTConfig
import torch.nn as nn

In [18]:
def get_model_config(config, num_classes=2):
    """
    Create a ViTConfig object based on the provided configuration.
    Args:
        config (dict): Configuration dictionary containing model parameters.
        num_classes (int): Number of output classes for classification tasks.
        image_size (int): Size of the input images.
    Returns:
        ViTConfig: Config object for the Vision Transformer model.
    """
    

    vit_config = ViTConfig(
        image_size=config['model']['image_size'],
        patch_size=config['model']['patch_size'],
        num_channels=1,
        hidden_size=config['model']['hidden_size'],
        num_hidden_layers=config['model']['num_hidden_layers'],
        num_attention_heads=config['model']['num_attention_heads'],
        intermediate_size=4 * config['model']['hidden_size'],
        stride_ratio=config['model']['stride_ratio'],
        proj_fn=config['model']['proj_fn'],

        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        is_encoder_decoder=False,
        use_mask_token=False,
        qkv_bias=True,
        num_labels=num_classes,
        noise_level=config['noise']['noise_level'],
        learning_rate=config['opt']['lr'],
    )
    return vit_config

In [19]:
vit_config = get_model_config(config)

In [None]:
model_config = {
    
}

In [None]:


class ViTModel(BaseModel):
    def __init__(self, config, num_labels=None):
        super().__init__(model_name='ViT', loss_name='cls')
        vit_config = ViTConfig(
            image_size=4096 or config['image_size'],
            patch_size=200 or config['patch_size'],
            num_channels=1,
            hidden_size=config['hidden_size'],
            num_hidden_layers=config['num_hidden_layers'],
            num_attention_heads=config['num_attention_heads'],
            intermediate_size=4 * config['hidden_size'],
            num_labels=num_labels or config.get('num_labels', 2)
        )
        self.vit = ViTModel(vit_config)
        self.classifier = nn.Linear(vit_config.hidden_size, vit_config.num_labels)
        
    def forward(self, x, labels=None):
        outputs = self.vit(x)
        logits = self.classifier(outputs.last_hidden_state[:, 0])
        return {'logits': logits, 'loss': self.compute_loss(logits, labels)}
    
    def compute_loss(self, logits, labels):
        return nn.CrossEntropyLoss()(logits, labels) if labels is not None else None
    
    def log_outputs(self, outputs, log_fn=print, stage=''):
        loss = outputs.get('loss')
        if loss is not None:
            log_fn(f'{stage}_{self.loss_name}_loss', loss.item())

In [None]:
m = ViTModel(config)

In [6]:
from torchmetrics import Accuracy
import torch

In [7]:
MASK_PATH = '/datascope/subaru/user/swei20/model/bosz50000_mask.npy'
import numpy as np

#region --DATA-----------------------------------------------------------
class SpecTrainDataset(BaseSpecDataset):
    def load_data(self, stage=None) -> None:
        super().load_data(stage=stage)
        if self.mask_ratio is not None:
            if self.mask_ratio < 1:
                self.mask = np.load(MASK_PATH)
                self.apply_mask()
    
class SpecTestDataset(BaseSpecDataset):
    @classmethod
    def from_dataset(cls, dataset, stage='test'):
        keys = ['file_path', 'val_path', 'test_path', 'num_samples', 'num_test_samples', 'root_dir', 'mask_ratio', 'mask_filler', 'mask', 'lvrg_num', 'lvrg_mask', 'noise_level', 'noise_max']
        c = cls(**{k: getattr(dataset, k) for k in keys}) 
        if stage == 'val': c.num_test_samples = min(c.num_test_samples, 1000) 
        return c
    def load_data(self, stage=None) -> None:
        super().load_data(stage=stage)
        if self.mask is None and self.mask_ratio is not None:
            if self.mask_ratio < 1:
                self.mask = np.load(MASK_PATH)
            # self.mask = self.create_quantile_mask(self.error, ratio=self.mask_ratio)
        if self.mask is not None: 
            self.mask_plot = {'wave': self.wave, 'error':self.error[0], 'mask': self.mask}
            self.apply_mask()
            self.mask_plot.update({'masked_error': self.error[0]})       
        self.set_noise()    
        
    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.noisy[idx], self.flux[idx], self.error[idx]
    
    def set_noise(self, seed=42):
        torch.manual_seed(seed)
        self.noise = torch.randn_like(self.flux) * self.error * self.noise_level
        self.noisy = self.flux + self.noise
        self.flux_rms = torch.norm(self.flux, dim=-1)
        self.snr0 = torch.div(self.flux_rms , torch.norm(self.noise, dim=-1))
        
    # def get_single_spectrum_noise_testset(self, sample_idx=0, repeat=1000, seed=42):
    #     flux_0, error_0  = self.flux[sample_idx], self.error[sample_idx]
    #     test_dataset = SingleSpectrumNoiseDataset(flux_0, error_0, noise_level=self.noise_level,repeat=repeat, seed=seed)
    #     return test_dataset
    
#endregion --DATA-----------------------------------------------------------
#region --DATAMODULE-----------------------------------------------------------
class SpecDataModule(BaseDataModule):
    @classmethod
    def from_config(cls, config):
        return super().from_config(dataset_cls=SpecTrainDataset, config=config)
    def setup_test_dataset(self, stage):
        if hasattr(self, 'train'):
            return SpecTestDataset.from_dataset(self.train, stage) 
        return SpecTestDataset.from_config(self.config)
#endregion --DATAMODULE-----------------------------------------------------------

In [None]:
class ClassSpecDataset(BaseSpecDataset):
    def __init__(self, param_idx=1, **kwargs):
        super().__init__(**kwargs)
        self.param_idx = param_idx  # 指定使用哪个参数作为标签
        
    def load_data(self, stage=None):
        super().load_data(stage)
        self.load_params(stage)
        
        # 将连续参数离散化为分类标签
        params = torch.tensor(getattr(self, ['teff', 'logg', 'mh'][self.param_idx]))
        self.labels = self.discretize_params(params)
        
    def discretize_params(self, params, bins=10):
        # 等频分箱创建分类标签
        quantiles = torch.linspace(0, 1, bins+1)
        bin_edges = torch.quantile(params, quantiles)
        return torch.bucketize(params, bin_edges[1:-1]).long()
    
    def __getitem__(self, idx):
        flux, error = super().__getitem__(idx)
        return flux, error, self.labels[idx]

In [10]:
c = ClassSpecDataset.from_config(config)

/home/swei20/SirenSpec/tests/spec/test_dataset.h5 10 None None None None ./results


In [None]:
c.load_data(stage='fit')

In [9]:
config = {
    'data': {'file_path': '/home/swei20/SirenSpec/tests/spec/test_dataset.h5', 'num_samples': 10},
}

In [17]:

class ViTLightningModule(BaseLightningModule):
    def __init__(self, model, data_module, config):
        super().__init__(model, data_module, config)
        self.accuracy = Accuracy(task='multiclass', num_classes=config['num_labels'])
        
    def training_step(self, batch, batch_idx):
        flux, error, labels = batch
        noise = torch.randn_like(flux) * error * self.config.get('noise_level', 0)
        inputs = flux + noise
        outputs = self.model(inputs, labels=labels)
        self.log(f'train_{self.model.loss_name}_loss', outputs['loss'])
        return outputs['loss']
    
    def validation_step(self, batch, batch_idx):
        flux, error, labels = batch
        outputs = self.model(flux, labels=labels)
        acc = self.accuracy(outputs['logits'], labels)
        self.log('val_loss', outputs['loss'])
        self.log('val_acc', acc, prog_bar=True)
        return outputs
    
    def test_step(self, batch, batch_idx):
        flux, error, labels = batch
        outputs = self.model(flux, labels=labels)
        acc = self.accuracy(outputs['logits'], labels)
        self.log('test_acc', acc)
        return outputs

In [None]:
ClassSpecDataset

In [19]:
class ViTDataModule(BaseDataModule):
    def __init__(self, **kwargs):
        super().__init__(dataset_cls=ClassSpecDataset, **kwargs)
        
    def setup(self, stage=None):
        super().setup(stage)
        # 添加通道维度 (B, L) -> (B, 1, L)
        self.train.flux = self.train.flux.unsqueeze(1)
        self.train.error = self.train.error.unsqueeze(1)
        self.val.flux = self.val.flux.unsqueeze(1)
        self.val.error = self.val.error.unsqueeze(1)

In [20]:
class Experiment:
    def __init__(self, config, use_wandb=False, num_gpus=None, sweep=False, ckpt_path=None):
        # 创建数据模块
        dm = ViTDataModule.from_config(config)
        dm.setup()
        
        # 创建模型
        model = ViTModel(config, num_labels=config.get('num_classes', 10))
        
        # 创建Lightning模块
        self.lightning_module = ViTLightningModule(model, dm, config)
        
        # 其余初始化保持不变...
        self.lightning_module.sweep = sweep
        if use_wandb:
            logger = L.pytorch.loggers.WandbLogger(
                project=config['project'],
                config=config,
                name=config.get('exp_name', 'ViT_experiment'),
                log_model=True
            )
        else:
            logger = None
            
        self.t = SpecTrainer(config=config, logger=logger, 
                            num_gpus=num_gpus, sweep=sweep)
        self.ckpt_path = ckpt_path

    # run方法保持不变...