In [1]:
!git clone https://github.com/XzwHan/CARD.git

Cloning into 'CARD'...


In [2]:
import os
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import trange
from torch.optim import Adam
import torch.nn.functional as F
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler

from CARD.regression.model import (
    DeterministicFeedForwardNeuralNetwork,
    ConditionalLinear,
)
from CARD.regression.diffusion_utils import (
    make_beta_schedule,
    q_sample,
    p_sample_loop,
    p_sample,
)

scl = StandardScaler()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
import pandas as pd
df = pd.read_csv(r'C:\Users\dell\Desktop\MyDocs\Docs\MK\protein.csv')
X = (df.drop(columns=["RMSD"], axis=1)).values
y = (df["RMSD"]).values
y = y[:, np.newaxis]

X, y = torch.from_numpy(X).to(torch.float32), torch.from_numpy(y).to(torch.float32)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1
)
train_data = TensorDataset(X_train, y_train)
val_data = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(val_data, batch_size=128, shuffle=True)

In [4]:
n_steps = 500
cat_x = True
cat_y_pred = True

x_dim = X_train.shape[1]
y_dim = 1
z_dim = 2

hid_layers = [100, 50]

beta_schedule = "linear"
beta_start = 0.0001
beta_end = 0.02

In [5]:
class ConditionalGuidedModel(nn.Module):
    def __init__(
        self,
        n_steps: int,
        cat_x: bool,
        cat_y_pred: bool,
        x_dim: int,
        y_dim: int,
        z_dim: int,
    ):
        super(ConditionalGuidedModel, self).__init__()
        self.cat_x = cat_x
        self.cat_y_pred = cat_y_pred
        data_dim = y_dim
        if self.cat_x:
            data_dim += x_dim
        if self.cat_y_pred:
            data_dim += y_dim
        self.lin1 = ConditionalLinear(data_dim, 128, n_steps)
        self.lin2 = ConditionalLinear(128, 128, n_steps)
        self.lin3 = ConditionalLinear(128, 128, n_steps)
        self.lin4 = nn.Linear(128, 1)

    def forward(self, x, y_t, y_0_hat, t):
        if self.cat_x:
            if self.cat_y_pred:
                eps_pred = torch.cat((y_t, y_0_hat, x), dim=1)
            else:
                eps_pred = torch.cat((y_t, x), dim=1)
        else:
            if self.cat_y_pred:
                eps_pred = torch.cat((y_t, y_0_hat), dim=1)
            else:
                eps_pred = y_t
        eps_pred = F.softplus(self.lin1(eps_pred, t))
        eps_pred = F.softplus(self.lin2(eps_pred, t))
        eps_pred = F.softplus(self.lin3(eps_pred, t))
        return self.lin4(eps_pred)


diff_model = ConditionalGuidedModel(
    n_steps=n_steps,
    cat_x=cat_x,
    cat_y_pred=cat_y_pred,
    x_dim=x_dim,
    y_dim=y_dim,
    z_dim=z_dim,
)
diff_model.to(device)

ConditionalGuidedModel(
  (lin1): ConditionalLinear(
    (lin): Linear(in_features=11, out_features=128, bias=True)
    (embed): Embedding(500, 128)
  )
  (lin2): ConditionalLinear(
    (lin): Linear(in_features=128, out_features=128, bias=True)
    (embed): Embedding(500, 128)
  )
  (lin3): ConditionalLinear(
    (lin): Linear(in_features=128, out_features=128, bias=True)
    (embed): Embedding(500, 128)
  )
  (lin4): Linear(in_features=128, out_features=1, bias=True)
)

In [6]:
cond_pred_model = DeterministicFeedForwardNeuralNetwork(dim_in=x_dim, 
                                                        dim_out=y_dim, 
                                                        hid_layers=hid_layers
                                                        )
cond_pred_model.to(device)

DeterministicFeedForwardNeuralNetwork(
  (network): Sequential(
    (0): Linear(in_features=9, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Dropout(p=0, inplace=False)
    (3): Linear(in_features=100, out_features=50, bias=True)
    (4): LeakyReLU(negative_slope=0.01)
    (5): Dropout(p=0, inplace=False)
    (6): Linear(in_features=50, out_features=1, bias=True)
  )
)

In [7]:
n_pretrain_epochs = 10
aux_optimizer = Adam(cond_pred_model.parameters(), lr=0.01)
aux_cost_fn = nn.MSELoss()
cond_pred_model.train()

bar = trange(n_pretrain_epochs, leave=True)
for epoch in bar:
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        y_pred = cond_pred_model(x)
        aux_cost = aux_cost_fn(y_pred, y)

        aux_optimizer.zero_grad()
        aux_cost.backward()
        aux_optimizer.step()
        bar.set_description(f"Loss: {aux_cost.item()}")

Loss: 46.64861297607422: 100%|██████████| 10/10 [00:15<00:00,  1.58s/it]


In [8]:
betas = make_beta_schedule(beta_schedule, n_steps, beta_start, beta_end).to(device)
betas_sqrt = torch.sqrt(betas)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
alphas_bar_sqrt = torch.sqrt(alphas_cumprod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_cumprod)

In [9]:
n_epochs = 10

optimizer = Adam(diff_model.parameters(), lr=0.01)

diff_bar = trange(n_epochs, leave=True)
diff_model.train()

for epoch in diff_bar:
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        batch_size = x.shape[0]

        # antithetic sampling
        ant_samples_t = torch.randint(
            low=0, high=n_steps, size=(batch_size // 2 + 1,)
        ).to(device)
        ant_samples_t = torch.cat([ant_samples_t, n_steps - 1 - ant_samples_t], dim=0)[
            :batch_size
        ]

        # noise estimation loss
        y_0_hat = cond_pred_model(x)

        e = torch.randn_like(y)

        y_t_sample = q_sample(
            y,
            y_0_hat,
            alphas_bar_sqrt,
            one_minus_alphas_bar_sqrt,
            ant_samples_t,
            noise=e,
        )

        model_output = diff_model(x, y_t_sample, y_0_hat, ant_samples_t)

        # use the same noise sample e during training to compute loss
        loss = (e - model_output).square().mean()

        # optimize diffusion model that predicts eps_theta
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # optimize non-linear guidance model
        aux_cost = aux_cost_fn(cond_pred_model(x), y)
        aux_optimizer.zero_grad()
        aux_cost.backward()
        aux_optimizer.step()

        diff_bar.set_description(f"Loss: {loss.item()}", refresh=True)

Loss: 24.112424850463867: 100%|██████████| 10/10 [00:37<00:00,  3.78s/it]


In [10]:
n_z_samples = 10

y_0_hat = cond_pred_model(X_test.to(device))

# obtain y samples through reverse diffusion -- some pytorch version might not have torch.tile
# y_0_tile = torch.tile(y, (n_z_samples, 1))
y_0_hat_tile = torch.tile(y_0_hat, (n_z_samples, 1)).to(device)
test_x_tile = torch.tile(X_test, (n_z_samples, 1)).to(device)

z = torch.randn_like(y_0_hat_tile).to(device)

y_t = y_0_hat_tile + z

def extract(input, t, x):
    shape = x.shape
    out = torch.gather(input, 0, t)
    reshape = [t.shape[0]] + [1] * (len(shape) - 1)
    return out.reshape(*reshape)


def p_sample(
    x, y, y_0_hat, y_T_mean, t: int, alphas, one_minus_alphas_bar_sqrt, guidance_model
):
    z = torch.randn_like(y)
    t = torch.tensor([t]).to(device)
    alpha_t = extract(alphas, t, y)
    sqrt_one_minus_alpha_bar_t = extract(one_minus_alphas_bar_sqrt, t, y)
    sqrt_one_minus_alpha_bar_t_m_1 = extract(one_minus_alphas_bar_sqrt, t - 1, y)
    sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt()
    sqrt_alpha_bar_t_m_1 = (1 - sqrt_one_minus_alpha_bar_t_m_1.square()).sqrt()
    # y_t_m_1 posterior mean component coefficients
    gamma_0 = (
        (1 - alpha_t) * sqrt_alpha_bar_t_m_1 / (sqrt_one_minus_alpha_bar_t.square())
    )
    gamma_1 = (
        (sqrt_one_minus_alpha_bar_t_m_1.square())
        * (alpha_t.sqrt())
        / (sqrt_one_minus_alpha_bar_t.square())
    )
    gamma_2 = 1 + (sqrt_alpha_bar_t - 1) * (alpha_t.sqrt() + sqrt_alpha_bar_t_m_1) / (
        sqrt_one_minus_alpha_bar_t.square()
    )
    eps_theta = guidance_model(x, y, y_0_hat, t).detach()
    # y_0 reparameterization
    y_0_reparam = (
        1
        / sqrt_alpha_bar_t
        * (
            y
            - (1 - sqrt_alpha_bar_t) * y_T_mean
            - eps_theta * sqrt_one_minus_alpha_bar_t
        )
    )
    y_t_m_1_hat = gamma_0 * y_0_reparam + gamma_1 * y + gamma_2 * y_T_mean

    beta_t_hat = (
        (sqrt_one_minus_alpha_bar_t_m_1.square())
        / (sqrt_one_minus_alpha_bar_t.square())
        * (1 - alpha_t)
    )
    y_t_m_1 = y_t_m_1_hat.to(device) + beta_t_hat.sqrt().to(device) * z.to(device)
    return y_t_m_1


def p_sample_loop(
    x,
    y_0_hat,
    y_T_mean,
    n_steps,
    alphas,
    one_minus_alphas_bar_sqrt,
    only_last_sample,
    guidance_model,
):
    num_t, y_p_seq = None, None
    z = torch.randn_like(y_T_mean).to(device)
    cur_y = z + y_T_mean  # sampled y_T
    if only_last_sample:
        num_t = 1
    else:
        y_p_seq = [cur_y]
    for t in reversed(range(1, n_steps)):
        y_t = cur_y
        cur_y = p_sample(
            x,
            y_t,
            y_0_hat,
            y_T_mean,
            t,
            alphas,
            one_minus_alphas_bar_sqrt,
            guidance_model,
        )  # y_{t-1}
        if only_last_sample:
            num_t += 1
        else:
            y_p_seq.append(cur_y)
    if only_last_sample:
        assert num_t == n_steps
        y_0 = p_sample_t_1to0(
            x, cur_y, y_0_hat, y_T_mean, one_minus_alphas_bar_sqrt, guidance_model
        )
        return y_0
    else:
        assert len(y_p_seq) == n_steps
        y_0 = p_sample_t_1to0(
            x, y_p_seq[-1], y_0_hat, y_T_mean, one_minus_alphas_bar_sqrt, guidance_model
        )
        y_p_seq.append(y_0)
        return y_p_seq


def p_sample_t_1to0(x, y, y_0_hat, y_T_mean, one_minus_alphas_bar_sqrt, guidance_model):
    # corresponding to timestep 1 (i.e., t=1 in diffusion models)
    t = torch.tensor([0]).to(device)
    sqrt_one_minus_alpha_bar_t = extract(one_minus_alphas_bar_sqrt, t, y)
    sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt()
    eps_theta = guidance_model(x, y, y_0_hat, t).detach()
    # y_0 reparameterization
    y_0_reparam = (
        1
        / sqrt_alpha_bar_t
        * (
            y
            - (1 - sqrt_alpha_bar_t) * y_T_mean
            - eps_theta * sqrt_one_minus_alpha_bar_t
        )
    )
    y_t_m_1 = y_0_reparam.to(device)
    return y_t_m_1


def predict_step(
    x,
    y_0_hat,
    alphas,
    one_minus_alphas_bar_sqrt,
    cond_pred_model,
    n_z_samples,
    n_steps,
    guidance_model,
):
    with torch.no_grad():
        x = x.to(device)
        y_0_hat = cond_pred_model(x)

        y_0_hat_tile = torch.tile(y_0_hat, (n_z_samples, 1)).to(device)
        test_x_tile = torch.tile(x, (n_z_samples, 1)).to(device)

        z = torch.randn_like(y_0_hat_tile).to(device)

        # generate samples from all time steps for the mini-batch
        y_tile_seq = p_sample_loop(
            test_x_tile,
            y_0_hat_tile,
            y_0_hat_tile,
            n_steps,
            alphas.to(device),
            one_minus_alphas_bar_sqrt.to(device),
            False,
            guidance_model,
        )

        # put in shape [n_z_samples, batch_size, output_dimension]
        y_tile_seq = [
            arr.reshape(n_z_samples, x.shape[0], y_t.shape[-1]) for arr in y_tile_seq
        ]

        final_recoverd = y_tile_seq[-1]

        mean_pred = final_recoverd.mean(dim=0).detach().cpu().squeeze()
        std_pred = final_recoverd.std(dim=0).detach().cpu().squeeze()

        return {
            "pred": mean_pred,
            "pred_uct": std_pred,
            "aleatoric_uct": std_pred,
            "samples": y_tile_seq,
        }

In [11]:
y_0_hat = cond_pred_model(x.to(device))
pred = predict_step(
    X_test,
    y_0_hat,
    alphas,
    one_minus_alphas_bar_sqrt,
    cond_pred_model,
    n_z_samples,
    n_steps,
    diff_model,
)

In [13]:
print(f'RMSE score {mean_squared_error(y_test.detach().numpy().squeeze(), pred["pred"].detach().numpy().squeeze()):5f}')

RMSE score 190.108276
