In [1]:
import numpy as np
import warnings
from nnlscit import *
from model import *
from network import *
# import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import warnings
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt 
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

In [2]:
device=torch.device('cuda:0')

In [3]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything(888)

In [4]:
def same(x):
    return x

def cube(x):
    return np.power(x, 3)

def negexp(x):
    return np.exp(-np.abs(x))


def generate_samples_random(size=1000, sType='CI', dx=1, dy=1, dz=20, nstd=0.7, fixed_function='linear',
                            debug=False, normalize = True, seed = None, dist_z = 'gaussian'):
    
    if seed == None:
        np.random.seed()
    else:
        np.random.seed(seed)

    if fixed_function == 'linear':
        f1 = same
        f2 = same
    else:
        I1 = random.randint(2, 5)
        I2 = random.randint(2, 5)

        if I1 == 2:
            f1 = np.square
        elif I1 == 3:
            f1 = cube
        elif I1 == 4:
            f1 = np.tanh
        elif I1 == 5:
            f1 = np.cos

        if I2 == 2:
            f2 = np.square
        elif I2 == 3:
            f2 = cube
        elif I2 == 4:
            f2 = np.tanh
        elif I2 == 5:
            f2 = np.cos
    if debug:
        print(f1, f2)

    num = size

    if dist_z =='gaussian':
        cov = np.eye(dz) 
        #mu = 0.7 * np.ones(dz)    # linear case
        mu = np.zeros(dz)     # nonlinear case
        Z = np.random.multivariate_normal(mu, cov, num)
        Z = np.matrix(Z)

    elif dist_z == 'laplace':
        Z = np.random.laplace(loc=0.0, scale=1.0, size=num*dz)
        Z = np.reshape(Z,(num,dz))
        Z = np.matrix(Z)

    elif dist_z == 'uniform':
        Z = np.random.uniform(-2.5, 2.5, (num,dz))
        Z = np.matrix(Z)
        
    Ax = np.random.rand(dz, dx)
    for i in range(dx):
        Ax[:, i] = Ax[:, i] / np.linalg.norm(Ax[:, i], ord=1) 
    Ax = np.matrix(Ax)

    Ay = np.random.rand(dz, dy)
    for i in range(dy):
        Ay[:, i] = Ay[:, i] / np.linalg.norm(Ay[:, i], ord=1) 
    Ay = np.matrix(Ay)

    Axy = np.random.rand(dx, dy)
    for i in range(dy):
        Axy[:, i] = Axy[:, i] / np.linalg.norm(Axy[:, i], ord=1) 
    Axy = np.matrix(Axy)
    
    Azy = np.random.normal(0., 1., (dz, dy))   
    alpha = 2


    if sType == 'CI':
        X = f1(Z * Ax + nstd * np.random.multivariate_normal(np.zeros(dx), np.eye(dx), num))
        Y = f2(Z * Ay + nstd * np.random.multivariate_normal(np.zeros(dy), np.eye(dy), num))
    elif sType == 'I':
        X = f1(nstd * np.random.multivariate_normal(np.zeros(dx), np.eye(dx), num))
        Y = f2(nstd * np.random.multivariate_normal(np.zeros(dy), np.eye(dy), num))
    else:
        eps=nstd * np.random.multivariate_normal(np.zeros(dy), np.eye(dy), num)
        X = f1(Z * Ax ) + eps
        Y = f2(Z * Ay ) + eps

    if normalize == True:
        Z = (Z - Z.min()) / (Z.max() - Z.min())
        X = (X - X.min()) / (X.max() - X.min())
        Y = (Y - Y.min()) / (Y.max() - Y.min())

    return np.array(X), np.array(Y), np.array(Z)

In [5]:
class DiffusionModelWithEmbedding(nn.Module):
    def __init__(self, 
                 input_dim, 
                 time_steps, 
                 embedding_dim,
                 cond_dim):
        super(DiffusionModelWithEmbedding, self).__init__()
        self.time_embedding = nn.Embedding(time_steps, embedding_dim)
        self.fc1 = nn.Linear(input_dim + embedding_dim+cond_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 128)
        self.fc4 = nn.Linear(128, input_dim)
        self.relu = nn.SELU()



    def forward(self, x, t,condition):
        t_emb = self.time_embedding(t).squeeze(1)
        # print(t_emb.shape)
        x = torch.cat([x, t_emb,condition], dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        return self.fc4(x)

In [6]:
def diffusion_loss_fn(model,batch,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps,device=device):
    batch_x=batch[:,:dx]
    batch_z=batch[:,dx:]
    batch_size=batch.shape[0]
    t=torch.randint(0,n_steps,size=(batch_size,),device=device)
    t=t.unsqueeze(-1)
    a=alphas_bar_sqrt[t].to(device)
    aml=one_minus_alphas_bar_sqrt[t].to(device)
    e=torch.randn_like(batch_x).to(device)
    x_t=batch_x*a+aml*e
#     print(x_t.device,t.device)
    output=model(x_t,t.squeeze(-1),batch_z)
    # print(torch.norm(output, p='fro'))
    return (e-output).square().mean()

In [7]:
# 加噪过程所需要的参数
def make_beta_schedule(schedule="linear", num_timesteps=1000, start=1e-5, end=2e-3):
    if schedule == "linear":
        betas = torch.linspace(start, end, num_timesteps)
    elif schedule == "const":
        betas = end * torch.ones(num_timesteps)
    elif schedule == "quad":
        betas = torch.linspace(start ** 0.5, end ** 0.5, num_timesteps) ** 2
    elif schedule == "jsd":
        betas = 1.0 / torch.linspace(num_timesteps, 1, num_timesteps)
    elif schedule == "sigmoid":
        betas = torch.linspace(-6, 6, num_timesteps)
        betas = torch.sigmoid(betas) * (end - start) + start
    elif schedule == "cosine" or schedule == "cosine_reverse":
        max_beta = 0.999
        cosine_s = 0.008
        betas = torch.tensor(
            [min(1 - (math.cos(((i + 1) / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2) / (
                    math.cos((i / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2), max_beta) for i in
             range(num_timesteps)])
        if schedule == "cosine_reverse":
            betas = betas.flip(0)  # starts at max_beta then decreases fast
    elif schedule == "cosine_anneal":
        betas = torch.tensor(
            [start + 0.5 * (end - start) * (1 - math.cos(t / (num_timesteps - 1) * math.pi)) for t in
             range(num_timesteps)])
    return betas

In [8]:
def sample_from_model(model, num_samples, input_dim,cond, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, betas, num_steps, device=device):
    traj=[]
    x_t = torch.randn(num_samples, input_dim).to(device)  # 从标准正态分布开始
#     print(x_t.mean(),x_t.std())
    traj.append(x_t)
    for t in reversed(range(num_steps)):
        t_tensor = torch.full((num_samples,), t, device=device, dtype=torch.long)
        noise_pred = model(x_t, t_tensor,cond)  # 预测的噪声
        # print(t,noise_pred.mean(),noise_pred.std(),torch.norm(noise_pred, p='fro'))

        if t > 0:
            beta_t = betas[t]
            alpha_t = 1 - beta_t
            alpha_t_sqrt = torch.sqrt(alpha_t)
            # 标准正态分布的噪声
            noise = torch.randn_like(x_t).to(device)
            # 使用逆过程公式
            x_t = (1 / alpha_t_sqrt) * (x_t - (1 - alpha_t) / one_minus_alphas_bar_sqrt[t] * noise_pred) + torch.sqrt(beta_t) * noise
            traj.append(x_t)
#             print(x_t.mean(),x_t.std())
        else:
            # 最后一步，不加噪声
            alpha_t = 1 - betas[t]
            alpha_t_sqrt = torch.sqrt(alpha_t)
            x_t = (1 / alpha_t_sqrt) * (x_t - (1 - alpha_t) / one_minus_alphas_bar_sqrt[t] * noise_pred)
            traj.append(x_t)

    return traj

In [9]:
def sample_from_model_ddim(model, num_samples, input_dim, cond, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps, eta=0.0, device='cuda'):
    traj = []
    x_t = torch.randn(num_samples, input_dim, device=device)
    traj.append(x_t)

    for t in reversed(range(num_steps)):
        t_tensor = torch.full((num_samples,), t, device=device, dtype=torch.long)
        noise_pred = model(x_t, t_tensor, cond)

        alpha_bar_t = alphas_bar_sqrt[t] ** 2
        sqrt_alpha_bar_t = alphas_bar_sqrt[t]
        sqrt_one_minus_alpha_bar_t = one_minus_alphas_bar_sqrt[t]

        # 根据DDIM重参数化公式计算x0预测
        x0_pred = (x_t - sqrt_one_minus_alpha_bar_t * noise_pred) / sqrt_alpha_bar_t

        if t > 0:
            alpha_bar_prev = alphas_bar_sqrt[t - 1] ** 2
            sqrt_alpha_bar_prev = alphas_bar_sqrt[t - 1]
            sqrt_one_minus_alpha_bar_prev = one_minus_alphas_bar_sqrt[t - 1]

            # DDIM sigma_t
            sigma_t = eta * torch.sqrt(
                (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)
            )

            noise = torch.randn_like(x_t) if eta > 0 else 0.0

            # 直接根据DDIM更新公式计算x_{t-1}
            x_t = (
                sqrt_alpha_bar_prev * x0_pred
                + torch.sqrt(1 - alpha_bar_prev - sigma_t**2) * noise_pred
                + sigma_t * noise
            )
        else:
            # 最后一步直接输出x0预测
            x_t = x0_pred

        traj.append(x_t)

    return traj

In [12]:
dx=5
dy=5
dz=10
num=1000
nstd=0.33
test_type='CI'


x_whole,y_whole,z_whole=generate_samples_random(size=num, sType=test_type, dx=dx, dy=dy, dz=dz, nstd=nstd, fixed_function='nonlinear',
                            debug=False, normalize = True, seed = 0, dist_z = 'gaussian')

x_crt=x_whole[int(num/2):,:]
y_crt=y_whole[int(num/2):,:]
z_crt=z_whole[int(num/2):,:]
x=torch.from_numpy(x_whole[:int(num/2),:]).float().to(device)
z=torch.from_numpy(z_whole[:int(num/2),:]).float().to(device)
data=torch.cat([x,z],dim=1)
dataloader = DataLoader(data.to(device), batch_size=2048, shuffle=True)  # 可以根据需要调整 batch_size
num_steps=1000
betas=make_beta_schedule(schedule="linear", num_timesteps=num_steps,start=1e-4, end=2e-2)
alphas=1-betas
alphas_bar=torch.cumprod(alphas,0).to(device)
alphas_bar_sqrt=torch.sqrt(alphas_bar)
one_minus_alphas_bar_sqrt=torch.sqrt(1-alphas_bar)
input_dim = dx
time_steps = 1000  # 时间步总数
embedding_dim = 16
num_samples = 1000
num_training_steps = 1000
learning_rate = 1e-3

# 定义噪声调度参数
betas = torch.linspace(1e-4, 0.02, time_steps)

model = DiffusionModelWithEmbedding(input_dim, time_steps, embedding_dim,cond_dim=dz).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for step in range(1500):
    total_loss=0
    for batch in dataloader:

        optimizer.zero_grad()
        loss = diffusion_loss_fn(model, batch, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps, device)
        loss.backward()
        optimizer.step()
        total_loss+=loss.item()

    # 每 100 步打印一次损失
# print(f"Step {step}: Loss = {total_loss}")

# crt


original = NNCMI(x_crt, y_crt, z_crt, dx, dy, dz,
                     classifier='xgb', normalize=False)

count = 0
from tqdm import tqdm
for iiiii in tqdm(range(100)):
    with torch.no_grad():

        # sample pseudo_y

#         x_seq_crt = sample_from_model(model, num_samples=z_crt.shape[0], input_dim=input_dim,cond=torch.tensor(z_crt).to(device).float(),
#                                       alphas_bar_sqrt=alphas_bar_sqrt,
#                                       one_minus_alphas_bar_sqrt=one_minus_alphas_bar_sqrt,
#                                       betas=betas,
#                                       num_steps=1000,
#                                       device=device)
        x_seq_crt = sample_from_model_ddim(model, 
                                           num_samples=z_crt.shape[0], 
                                           input_dim=input_dim, 
                                           cond=torch.tensor(z_crt).to(device).float(), 
                                           alphas_bar_sqrt=alphas_bar_sqrt, 
                                           one_minus_alphas_bar_sqrt=one_minus_alphas_bar_sqrt, 
                                           num_steps=50, 
                                           eta=0.0, 
                                           device=device)

    crt_stat = NNCMI(x_seq_crt[-1].detach().cpu().numpy(), y_crt,
                         z_crt, dx, dy, dz, classifier='xgb',
                         normalize=False)

    if crt_stat > original:
        count += 1
print(count/100)



100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:12<00:00,  7.96it/s]

0.77



