#### 1. 这个 notebook 利用人造数据集来测试模型对 ground truth 的拟合能力

创建虚拟数据集：输入 x 在 [0, 2π) 上均匀采样，目标 y 是四个三角函数特征的线性组合加上偏置和高斯噪声。

模型训练

In [1]:
##虚拟数据集
from torch.utils.data import Dataset
import torch 
import numpy as np
class Synthetic(Dataset):
    """Synthetic regression dataset with a known, additive ground truth.

    Each access draws a fresh input ``x`` of ``num_feat`` values uniformly from
    ``[0, 2*pi)`` and builds the target as::

        y = dot([sin(x0), cos(x1), sin(2*x2)*cos(x2), cos(2*x3)*sin(x3)], coefs)
            + bias + noise_std * N(0, 1)

    Samples are generated on the fly and are not tied to ``idx``: the same
    index yields a different sample on every call unless ``seed`` is passed.

    Args:
        num_feat: Number of raw input values per sample. Only ``x[0]``..``x[3]``
            enter the target, so values below 4 fail in ``get_feat_and_y``.
        coefs: Weights of the four shape functions (must have four entries
            for the dot product in ``get_feat_and_y`` to succeed).
        bias: Constant added to the weighted sum.
        noise_std: Standard deviation of the Gaussian noise added to ``y``.
        sample_size: Nominal dataset length reported by ``__len__``.
    """

    def __init__(self, 
                 num_feat = 4, 
                 coefs = [0.5, 0.5, 0.5, 0.5], 
                 bias = 0.5, 
                 noise_std = 0.05, 
                 sample_size = 1000
                 ):
        self.num_feat = num_feat
        # Copy list arguments so instances never share (and cannot corrupt) the
        # list object used as the default value; other types are kept as given.
        self.coefs = list(coefs) if isinstance(coefs, list) else coefs
        self.bias = bias
        self.noise_std = noise_std
        self.sample_size = sample_size
    
    def __len__(self):
        """Return the nominal number of samples."""
        return self.sample_size

    def get_feat_and_y(self, x):
        """Evaluate the noise-free ground truth for one input vector.

        Args:
            x: Indexable with at least four numeric entries.

        Returns:
            ``(feat, y)`` where ``feat`` is the array of the four shape-function
            values and ``y`` is ``dot(feat, coefs) + bias`` (no noise added).
        """
        f0 = np.sin(x[0])
        f1 = np.cos(x[1])
        f2 = np.sin(2*x[2]) * np.cos(x[2])
        f3 = np.cos(2*x[3]) * np.sin(x[3])

        feat = np.array([f0, f1, f2, f3])
        y = np.dot(feat, self.coefs) + self.bias
        return feat, y

    def __getitem__(self, idx, seed = None):
        """Draw one random ``(x, y)`` sample as float32 tensors.

        Args:
            idx: Requested position. It does not select the sample; it is only
                range-checked so that sequence-style iteration terminates.
            seed: Optional seed. When given (including ``0``) NumPy's global
                RNG is re-seeded first, which makes the returned sample
                reproducible.

        Returns:
            Tuple ``(x, y)``: ``x`` has ``num_feat`` entries, ``y`` is a scalar
            tensor holding the noisy target.

        Raises:
            IndexError: If an integer ``idx`` lies outside
                ``[-sample_size, sample_size)``.
        """
        # Without this check `for sample in dataset` / `list(dataset)` would
        # never stop, because Python's fallback iteration relies on IndexError.
        if isinstance(idx, (int, np.integer)) and not -self.sample_size <= idx < self.sample_size:
            raise IndexError(
                f'index {idx} is out of range for dataset of size {self.sample_size}'
            )

        # `is not None` rather than truthiness, so that seed=0 is honoured.
        if seed is not None:
            np.random.seed(seed)

        x = np.random.rand(self.num_feat) * 2 * np.pi
        _, y = self.get_feat_and_y(x)

        y = y + np.random.randn() * self.noise_std        
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

In [2]:
## model config
from models import BaseModel

# Training hyper-parameters kept as notebook-level names;
# `batch_size` is read again when the data module is built.
batch_size = 256
weight_decay = 1e-5

# Keyword arguments forwarded verbatim to BaseModel.
# NOTE(review): num_concepts=4 presumably mirrors the 4 inputs of the
# synthetic dataset and num_classes=1 the scalar target — confirm in models.py.
config = {
    'modelconfig': {
        'num_concepts': 4,
        'num_classes': 1,
        'modeltype': 'nbm',
        'dropout': 0.1,
        'batchnorm': True,
        'num_bases': 200,
        'hidden_dims': (32, 32, 16),
        'learning_rate': 0.0001,
        'weight_decay': weight_decay,
    },
}

model = BaseModel(**config['modelconfig'])


In [3]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from pytorch_lightning import Trainer

# Directory that receives both the checkpoints and the TensorBoard logs.
model_saved_path = './saved_model_New/SyntheticDataset_B200/'

# Synthetic draws its samples on the fly, so the three splits differ only in
# their nominal size (validation/test samples change on every pass).
# NOTE(review): no random seed is set anywhere, so runs are not reproducible.
train_dataset = Synthetic(sample_size = 10000)
val_dataset = Synthetic(sample_size = 1000)
test_dataset = Synthetic(sample_size = 1000)

datamodule = pl.LightningDataModule.from_datasets(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
    batch_size=batch_size,
)

# Keep the checkpoint with the lowest 'val_loss' plus the most recent one.
modelcheckpoint = ModelCheckpoint(
    dirpath=f'{model_saved_path}/',
    filename='best_model',
    monitor='val_loss',
    save_top_k=1,
    save_last=True,
    mode='min',
)

logger = pl.loggers.TensorBoardLogger(f'{model_saved_path}/logs/')

# Stop once 'val_loss' has not improved for 300 consecutive validation checks.
earlystop = EarlyStopping(monitor='val_loss', patience=300)

# The keyword `accelerator` was previously broken across two lines
# (`acceler` / `ator=`), which is a syntax error; it is written out in full here.
# NOTE(review): 'mps' only exists on Apple-silicon machines — consider
# accelerator='auto' if this notebook must run elsewhere.
trainer = Trainer(
    accelerator='mps',
    devices=1,
    max_epochs=10000,
    callbacks=[modelcheckpoint, earlystop],
    logger=logger,
    log_every_n_steps=5,
)
trainer.fit(model, datamodule=datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/menghan/mambaforge/envs/newtorch/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory ./saved_model_New/SyntheticDataset_B200// exists and is not empty.

  | Name  | Type           | Params
-----------------------------------------
0 | model | ConceptNBMNary | 6.4 K 
-----------------------------------------
6.4 K     Trainable params
0         Non-trainable params
6.4 K     Total params
0.026     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/menghan/mambaforge/envs/newtorch/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


                                                                           

/Users/menghan/mambaforge/envs/newtorch/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 9:  55%|█████▌    | 22/40 [00:00<00:00, 78.94it/s, v_num=1, train_loss=0.363, val_loss=0.122]

/Users/menghan/mambaforge/envs/newtorch/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
