In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import pymongo
from numpy import log10

# Module-level MongoDB handles. `collection` (the `data` collection of the
# `solve` database on the local server) is what train_gw_model() queries.
client = pymongo.MongoClient("mongodb://localhost:27017/")
db = client["solve"]
collection = db["data"]


class GWDataset(Dataset):
    """Dataset of (scaled parameter vector, scaled curve) pairs.

    Every record in ``data`` must provide the keys ``r``, ``n_t``, ``kappa10``,
    ``T_re``, ``DN_re``, ``f_interp`` and ``log10OmegaGW_interp``.  The
    parameter vector is ``[log10(r), n_t, log10(kappa10), log10(T_re), DN_re]``
    and the curve stacks ``f_interp`` and ``log10OmegaGW_interp`` as two
    channels.

    Args:
        data: Sized sequence of record dicts (must not be empty).
        x_scaler: Fitted scaler for the flattened ``f_interp`` values; only
            used when ``fit_scalers`` is False.
        y_scaler: Fitted scaler for the flattened ``log10OmegaGW_interp``
            values; only used when ``fit_scalers`` is False.
        fit_scalers: When True (default), or when ``x_scaler`` is None, fresh
            ``StandardScaler`` objects are fitted on ``data`` and any supplied
            scalers are ignored.
        param_scaler: Fitted scaler for the 5-column parameter matrix; required
            when existing scalers are reused.

    Raises:
        ValueError: If ``data`` is empty, or if scalers are reused but
            ``y_scaler`` or ``param_scaler`` is missing.
    """

    def __init__(self, data, x_scaler=None, y_scaler=None, fit_scalers=True,
                 param_scaler=None):
        if len(data) == 0:
            raise ValueError("GWDataset requires at least one record")
        self.data = data

        params = np.array([[log10(item['r']), item['n_t'], log10(item['kappa10']),
                            log10(item['T_re']), item['DN_re']] for item in data])
        # One row per record; x and y are kept apart because each has its own scaler.
        curves_x = np.asarray([item['f_interp'] for item in data], dtype=float)
        curves_y = np.asarray([item['log10OmegaGW_interp'] for item in data], dtype=float)

        if fit_scalers or x_scaler is None:
            self.param_scaler = StandardScaler()
            self.param_scaler.fit(params)
            # Curve scalers are fitted on the flattened values, i.e. a single
            # mean/std shared by every point of every curve.
            self.x_scaler = StandardScaler()
            self.x_scaler.fit(curves_x.reshape(-1, 1))
            self.y_scaler = StandardScaler()
            self.y_scaler.fit(curves_y.reshape(-1, 1))
        else:
            # Reuse path.  The parameter scaler must be its own object: the
            # curve x-scaler is fitted on a single column and cannot transform
            # the 5-column parameter matrix.
            if y_scaler is None or param_scaler is None:
                raise ValueError(
                    "fit_scalers=False requires x_scaler, y_scaler and param_scaler")
            self.param_scaler = param_scaler
            self.x_scaler = x_scaler
            self.y_scaler = y_scaler

        self.params = self.param_scaler.transform(params)
        curves_x_scaled = self.x_scaler.transform(curves_x.reshape(-1, 1)).reshape(curves_x.shape)
        curves_y_scaled = self.y_scaler.transform(curves_y.reshape(-1, 1)).reshape(curves_y.shape)
        # Last axis: channel 0 is scaled f_interp, channel 1 is scaled log10OmegaGW_interp.
        self.curves = np.stack([curves_x_scaled, curves_y_scaled], axis=2)

    def __len__(self):
        """Return the number of records."""
        return len(self.params)

    def __getitem__(self, idx):
        """Return ``(params, curve)`` for record ``idx`` as float32 tensors."""
        params = torch.tensor(self.params[idx], dtype=torch.float32)
        curve = torch.tensor(self.curves[idx], dtype=torch.float32)
        return params, curve


def collate_fn(batch):
    """Stack a list of ``(params, curve)`` samples into two batch tensors.

    Returns a ``(params_batch, curves_batch)`` tuple whose leading dimension is
    the batch size.
    """
    param_seq, curve_seq = zip(*batch)
    params_batch = torch.stack(param_seq)
    curves_batch = torch.stack(curve_seq)
    return params_batch, curves_batch


class ResidualBlock(nn.Module):
    """Width-preserving residual unit: ``x + LayerNorm(GELU(Linear(x)))``."""

    def __init__(self, dim):
        super().__init__()
        # Layer order is kept so the state-dict keys stay block.0 / block.2.
        layers = [nn.Linear(dim, dim), nn.GELU(), nn.LayerNorm(dim)]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Add the transformed input back onto ``x`` (skip connection)."""
        transformed = self.block(x)
        return x + transformed


class CurvePredictor(nn.Module):
    """Map a 5-element parameter vector to a two-channel curve.

    Pipeline: MLP encoder (5 -> 128 -> 256), the encoding repeated ``seq_len``
    times, a 2-layer unidirectional LSTM (256 -> 256), and a per-step MLP
    decoder (256 -> 128 -> 2).

    Args:
        seq_len: Number of points in the generated curve.  Defaults to 256,
            the length that was previously hard-coded in ``forward``.
    """

    def __init__(self, seq_len=256):
        super().__init__()
        self.seq_len = seq_len

        # Parameter encoder.
        self.encoder = nn.Sequential(
            nn.Linear(5, 128),
            nn.GELU(),
            nn.LayerNorm(128),
            nn.Linear(128, 256),
            nn.GELU(),
            nn.LayerNorm(256)
        )

        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=256,
            num_layers=2,
            bidirectional=False,
            batch_first=True
        )

        self.decoder = nn.Sequential(
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 2)
        )

    def forward(self, x):
        """Return a ``[B, seq_len, 2]`` tensor for a ``[B, 5]`` input."""
        # Encode parameters: [B, 5] -> [B, 256]
        encoded = self.encoder(x)

        # Feed the same encoding at every step: [B, 256] -> [B, seq_len, 256]
        repeated = encoded.unsqueeze(1).repeat(1, self.seq_len, 1)

        # Unidirectional LSTM: [B, seq_len, 256] -> [B, seq_len, 256]
        lstm_out, _ = self.lstm(repeated)

        # Per-step decoding: [B, seq_len, 256] -> [B, seq_len, 2]
        return self.decoder(lstm_out)


from tqdm import tqdm
# Per-epoch loss histories: train_gw_model() appends one value to each list per
# epoch, and the saving/plotting cells further down read them.
train_losses = []  
val_losses = []   

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-sample loss.

    With an ``optimizer`` the model is trained (gradient norm clipped to 1.0
    before every step); without one it is evaluated with gradients disabled.
    """
    training = optimizer is not None
    model.train(training)
    running = 0.0
    with torch.set_grad_enabled(training):
        for params, curves in loader:
            params = params.to(device)
            curves = curves.to(device)
            if training:
                optimizer.zero_grad()
            loss = criterion(model(params), curves)
            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            # Weight by batch size so a short final batch does not skew the mean.
            running += loss.item() * params.size(0)
    return running / len(loader.dataset)


def train_gw_model(condition=None, epochs=200, batch_size=32, checkpoint_path=None):
    """Train a ``CurvePredictor`` on the records matching ``condition``.

    Args:
        condition: MongoDB filter passed to ``collection.find``; ``None`` (the
            default) selects every document.
        epochs: Number of training epochs.
        batch_size: Batch size for both the training and validation loaders.
        checkpoint_path: Optional file path.  When given, the model state and
            the three fitted scalers are saved there every time the validation
            loss improves, under the keys ``GWPredictor`` reads.  ``None``
            (default) disables checkpointing.

    Returns:
        The model after the final epoch (not necessarily the best one).

    Raises:
        ValueError: If the query returns no documents.

    Side effects:
        Resets the module-level ``train_losses`` / ``val_losses`` lists and
        then appends one entry per epoch, so they always describe the most
        recent run only.
    """
    # A None default avoids sharing one mutable dict between calls.
    query = {} if condition is None else condition
    raw_data = list(collection.find(query))
    print(f'data num:{len(raw_data)}')
    if not raw_data:
        raise ValueError("no documents matched the given condition")

    # NOTE(review): the scalers are fitted on the full dataset before the
    # train/validation split, so validation statistics leak into the scaling.
    full_dataset = GWDataset(raw_data)

    train_idx, val_idx = train_test_split(
        np.arange(len(full_dataset)),
        test_size=0.2,
        random_state=42
    )
    train_data = torch.utils.data.Subset(full_dataset, train_idx)
    val_data = torch.utils.data.Subset(full_dataset, val_idx)

    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_data,
        batch_size=batch_size,
        collate_fn=collate_fn
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CurvePredictor().to(device)

    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    # Halve the learning rate after 5 epochs without validation improvement.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5
    )
    criterion = nn.MSELoss()
    print('start training')

    # Start from empty histories so a second call (or a re-run after an
    # interrupt) does not mix its losses with those of an earlier run.
    train_losses.clear()
    val_losses.clear()
    best_loss = float('inf')

    for epoch in tqdm(range(epochs)):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss = _run_epoch(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"Epoch {epoch + 1}/{epochs}")
        print(f"Train Loss: {train_loss:.4e} | Val Loss: {val_loss:.4e}")

        if val_loss < best_loss:
            best_loss = val_loss
            if checkpoint_path is not None:
                torch.save({
                    'model_state': model.state_dict(),
                    'x_scaler': full_dataset.x_scaler,
                    'y_scaler': full_dataset.y_scaler,
                    'param_scaler': full_dataset.param_scaler
                }, checkpoint_path)

    return model


class GWPredictor:
    """Inference wrapper around a saved ``CurvePredictor`` checkpoint.

    The checkpoint must be a dict with the keys ``model_state``, ``x_scaler``,
    ``y_scaler`` and ``param_scaler``.  The model is kept on the CPU.

    Args:
        model_path: Path of the checkpoint file.
    """

    def __init__(self, model_path='best_gw_model.pth'):
        # The checkpoint holds pickled scaler objects next to the weights, so
        # it cannot be read by the weights-only loader that newer PyTorch
        # releases use by default; request full unpickling explicitly.
        # SECURITY: full unpickling can execute arbitrary code - only load
        # checkpoint files from a trusted source.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

        self.model = CurvePredictor()
        self.model.load_state_dict(checkpoint['model_state'])
        self.model.eval()
        self.x_scaler = checkpoint['x_scaler']
        self.y_scaler = checkpoint['y_scaler']
        self.param_scaler = checkpoint['param_scaler']

    def predict(self, params_dict):
        """Predict one curve from physical parameters.

        Args:
            params_dict: Mapping with the keys ``r``, ``n_t``, ``kappa10``,
                ``T_re`` and ``DN_re``.  ``r``, ``kappa10`` and ``T_re`` are
                passed through ``log10`` before scaling.

        Returns:
            ``{'f': [...], 'log10OmegaGW': [...]}`` with both lists in the
            scalers' original (inverse-transformed) units.
        """
        params = np.array([
            log10(params_dict['r']),
            params_dict['n_t'],
            log10(params_dict['kappa10']),
            log10(params_dict['T_re']),
            params_dict['DN_re']
        ]).reshape(1, -1)

        scaled_params = self.param_scaler.transform(params)

        with torch.no_grad():
            inputs = torch.tensor(scaled_params, dtype=torch.float32)
            outputs = self.model(inputs).cpu().numpy()

        # Channel 0 is the scaled frequency, channel 1 the scaled spectrum;
        # each is mapped back with its own scaler.
        batch = outputs.shape[0]
        denorm_x = self.x_scaler.inverse_transform(outputs[..., 0].reshape(-1, 1)).reshape(batch, -1)
        denorm_y = self.y_scaler.inverse_transform(outputs[..., 1].reshape(-1, 1)).reshape(batch, -1)

        return {
            'f': denorm_x[0].tolist(),
            'log10OmegaGW': denorm_y[0].tolist()
        }

In [21]:
# Train for 100 epochs on every document: the filter dict is empty because all
# of the example range constraints inside it are commented out.
trained_model = train_gw_model(
    {
        # 'r': {'$gte': 1e-6, '$lte': 1e-4},
        # 'n_t': {'$gte': 0, '$lte': 1},
        # 'kappa10': {'$gte': 1e2, '$lte': 2e2},
        # 'T_re': {'$gte': 0, '$lte': 1e3},
        # 'DN_re': {'$gte': 20, '$lte': 40}
    }, epochs=100)

data num:25689
start training


  1%|          | 1/100 [00:27<44:55, 27.23s/it]

Epoch 1/100
Train Loss: 3.1100e-01 | Val Loss: 1.1336e-01


  1%|          | 1/100 [00:31<52:11, 31.63s/it]


KeyboardInterrupt: 

In [None]:
# Persist the per-epoch loss histories filled in by train_gw_model().
np.save('trainloss',train_losses)
np.save('valloss',val_losses)

In [None]:
import os

import matplotlib.pyplot as plt

# Take the epoch count from the recorded history rather than a hard-coded 100,
# so the x and y lengths still match after a shortened or interrupted run.
epochs = len(train_losses)
# The lower panel zooms in on (at most) the last 50 epochs.
tail_start = max(epochs - 50, 0)

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8), sharex=False)

# Style is given only through keywords: a colour inside a format string
# (e.g. 'g--') would conflict with the explicit `color` argument.
ax1.plot(range(1, epochs + 1), train_losses, linestyle='--', marker='*',
         color="royalblue", label='Train Loss')
ax1.plot(range(1, epochs + 1), val_losses, linestyle='--', marker='.',
         color="red", label='Validation Loss')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss Curves')
ax1.legend()
ax1.grid(False)

ax2.plot(range(tail_start + 1, epochs + 1), train_losses[tail_start:], linestyle='--',
         marker='*', color="royalblue", label='Train Loss')
ax2.plot(range(tail_start + 1, epochs + 1), val_losses[tail_start:], linestyle='--',
         marker='.', color="red", label='Validation Loss')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss')
ax2.set_title(f'Training and Validation Loss Curves after {tail_start} Epochs')
ax2.legend()
ax2.grid(False)

plt.tight_layout()

# savefig does not create missing directories.
os.makedirs('./image', exist_ok=True)
plt.savefig('./image/combined_train_loss.eps', format='eps', dpi=50, bbox_inches='tight')
plt.show()