In [1]:
import os
import sys
# Put the parent directory on the import path.
# NOTE(review): presumably this is what lets the `diffusion_policy` imports below resolve
# when the notebook lives one level below the repository root — confirm.
sys.path.append("../")

In [2]:
# Import all necessary packages
import numpy as np
from typing import Dict
import hydra
import dill
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.policy.diffusion_unet_hybrid_image_policy import DiffusionUnetHybridImagePolicy
from diffusion_policy.dataset.robomimic_replay_image_dataset import BaseImageDataset
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to

from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
import functools


In [3]:
## Load the training workspace to obtain (i) the image dataset and (ii) the vision encoder
checkpoint  = "../outputs/lift_image_ph_reproduction/2025.06.10_10.30.38_train_diffusion_unet_hybrid_lift_image/checkpoints/epoch=0050-test_mean_score=1.000.ckpt"
# Drop the '.ckpt' extension; the remainder is used as the workspace output directory
output_dir = os.path.splitext(checkpoint)[0]

# The payload is unpickled with dill, which can execute arbitrary code: only load trusted files.
# The context manager closes the file handle as soon as the payload is in memory.
with open(checkpoint, 'rb') as checkpoint_file:
    payload = torch.load(checkpoint_file, pickle_module=dill)
cfg = payload['cfg']
# NOTE(review): 'cuda:1' needs at least two visible GPUs — adjust for the local machine.
cfg.training.device = "cuda:1"

# Instantiate the workspace class recorded in the config and restore its saved state
cls = hydra.utils.get_class(cfg._target_)
workspace: BaseWorkspace = cls(cfg, output_dir=output_dir)
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# get dataset (the configured dataset path is re-rooted one directory up)
image_dataset = hydra.utils.instantiate(cfg.task.dataset, dataset_path=os.path.join("../", cfg.task.dataset.dataset_path))
assert isinstance(image_dataset, BaseImageDataset)
train_dataloader = DataLoader(image_dataset, **cfg.dataloader)

# get model
model = workspace.model

In [4]:
## Encode every training batch once and collect the features ("data") together with
## the action at horizon index 1 ("condition"); both are used to fit the normalizer.
data_dict = {"data": list(), "condition": list()}

# Pure inference: without autograd no graph is recorded for the encoder forward passes.
with torch.no_grad():
    for batch in tqdm(train_dataloader):
        batch = dict_apply(batch, lambda x: x.to(cfg.training.device, non_blocking=True))

        obs_dict = batch["obs"]
        nobs_features = model.encode_obs(obs_dict)
        # Store independent CPU copies so that the device batch can be released
        data_dict["data"].append(nobs_features.detach().cpu().clone())
        data_dict["condition"].append(batch["action"][:, 1, :].detach().cpu().clone())

        del batch # without del, the memory will be filled up

# torch.cat already returns a fresh CPU tensor, so it can be converted directly
data_dict["data"] = torch.cat(data_dict["data"], dim=0).numpy()
data_dict["condition"] = torch.cat(data_dict["condition"], dim=0).numpy()

In [5]:
# Build the normalizer: one single-field normalizer per entry of `data_dict`,
# each fitted with mode='limits' over the last dimension, bundled by field name.
data_normalizer = SingleFieldLinearNormalizer()
condition_normalizer = SingleFieldLinearNormalizer()
for field_key, field_normalizer in (("data", data_normalizer), ("condition", condition_normalizer)):
    field_normalizer.fit(data_dict[field_key], last_n_dims=1, mode='limits')

normalizer = LinearNormalizer()
normalizer["data"], normalizer["condition"] = data_normalizer, condition_normalizer

In [6]:
# Normalized NumPy copies of the features and conditions, reused by the plotting cells
normalized_features = normalizer['data'].normalize(data_dict["data"])
plot_data = normalized_features.detach().cpu().numpy()
normalized_conditions = normalizer['condition'].normalize(data_dict["condition"])
plot_condition = normalized_conditions.detach().cpu().numpy()

for normalized_array in (plot_data, plot_condition):
    print(normalized_array.shape)

In [8]:
class ImageDatasetWrapper(Dataset):
    """Map-style dataset over pre-computed features and their conditions.

    Args:
        dataset: Mapping with the keys ``"data"`` and ``"condition"``. Both values
            are indexed with the same ``idx``; the length is taken from ``"data"``.
    """

    def __init__(self, dataset:Dict):
        self.dataset = dataset

    def __len__(self):
        """Return the number of samples stored under ``"data"``."""
        return len(self.dataset["data"])

    def __getitem__(self, idx):
        """Return the ``idx``-th feature/condition pair as a dict."""
        return {key: self.dataset[key][idx] for key in ("data", "condition")}

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit one single-field normalizer per entry and bundle them.

        Args:
            mode: Fitting mode forwarded to ``SingleFieldLinearNormalizer.fit``.
                Previously this argument was accepted but silently ignored.
            **kwargs: Extra keyword arguments forwarded to ``fit``.
                ``last_n_dims`` defaults to 1 when not given.

        Returns:
            A ``LinearNormalizer`` with the fields ``"data"`` and ``"condition"``.
        """
        kwargs.setdefault("last_n_dims", 1)

        normalizer = LinearNormalizer()
        for key in ("data", "condition"):
            field_normalizer = SingleFieldLinearNormalizer()
            field_normalizer.fit(self.dataset[key], mode=mode, **kwargs)
            normalizer[key] = field_normalizer

        return normalizer

In [9]:
## Plot the data: scatter adjacent feature dimensions (x_i against x_{i+1})
Do = plot_data.shape[1]
# Show at most 16 panels, and never more than the number of adjacent pairs
num_subplots = max(0, min(16, Do - 1))
num_column = int(np.sqrt(num_subplots))+1
num_row = max(1, int(np.ceil(num_subplots / num_column)))

plt.figure(figsize=(5*num_column, 5*num_row))
# One panel per pair; a flat loop avoids running past the last valid column
for plot_iter in range(num_subplots):
    plt.subplot(num_row, num_column, plot_iter+1)
    plt.scatter(plot_data[:, plot_iter], plot_data[:, plot_iter+1], s=2)
    plt.xlabel(f"x_{plot_iter}")
    plt.ylabel(f"x_{plot_iter+1}")
    plt.grid(True)
plt.show()

In [10]:
## Plot the condition: scatter adjacent condition dimensions (x_i against x_{i+1})
Da = plot_condition.shape[1]
# One panel per adjacent pair of condition dimensions
num_subplots = max(0, Da - 1)
num_column = int(np.sqrt(num_subplots))+1
num_row = max(1, int(np.ceil(num_subplots / num_column)))

plt.figure(figsize=(5*num_column, 5*num_row))
# A flat loop over the pairs cannot index past column Da-1, unlike a nested
# row/column loop whose `break` only leaves the inner loop.
for plot_iter in range(num_subplots):
    plt.subplot(num_row, num_column, plot_iter+1)
    plt.scatter(plot_condition[:, plot_iter], plot_condition[:, plot_iter+1], s=1)
    plt.xlabel(f"x_{plot_iter}")
    plt.ylabel(f"x_{plot_iter+1}")
    plt.grid(True)
plt.show()

In [40]:
class GaussianFourierProjection(nn.Module):
  """Embed time steps with fixed random sinusoidal features.

  Args:
    embed_dim: Size of the embedding; half of it is sine, half cosine features.
    scale: Factor applied to the normally distributed frequencies.
  """
  def __init__(self, embed_dim, scale=30.):
    super().__init__()
    # The frequencies are drawn once and stored as a non-trainable parameter,
    # so they are saved with the model but never updated by the optimizer.
    frequencies = torch.randn(embed_dim // 2) * scale
    self.W = nn.Parameter(frequencies, requires_grad=False)

  def forward(self, x):
    """Map ``x`` to the concatenation of sin and cos features along the last axis."""
    angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
    return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Dense(nn.Module):
  """Linear layer whose output gets a trailing singleton axis.

  The extra axis lets the result broadcast over the length axis of a
  ``(batch, channels, length)`` feature map.
  """
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.dense = nn.Linear(input_dim, output_dim)

  def forward(self, x):
    """Apply the linear map and append one axis of size 1."""
    projected = self.dense(x)
    return projected.unsqueeze(-1)


class ScoreNet(nn.Module):
  """A time-dependent score-based model built upon a 1-D U-Net architecture.

  The network takes inputs of shape ``(batch, 2, length)`` and returns a tensor
  of the same shape. The skip connections concatenate encoder and decoder
  feature maps, so ``length`` must be compatible with the fixed kernel sizes
  and strides; this notebook uses ``length = 137``.

  Layer attribute names and creation order are part of the checkpoint format
  (``state_dict`` keys), so they must not be changed.
  """

  def __init__(self, marginal_prob_std, channels=(32, 64, 128, 256, 512, 1024), embed_dim=256):
    """Initialize a time-dependent score-based network.

    Args:
      marginal_prob_std: A function that takes time t and gives the standard
        deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
      channels: The number of channels for feature maps of each resolution
        (six entries). All but the first level use 32 normalization groups,
        and the first level is normalized with 4 and 32 groups, so the
        channel counts must be divisible accordingly.
      embed_dim: The dimensionality of Gaussian random feature embeddings.
    """
    super().__init__()
    # Gaussian random feature embedding layer for time
    self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
         nn.Linear(embed_dim, embed_dim))
    # Encoding layers where the resolution decreases
    kernel_size = 3
    self.conv1 = nn.Conv1d(2, channels[0], kernel_size=kernel_size, stride=1, bias=False)
    self.dense1 = Dense(embed_dim, channels[0])
    self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])

    self.conv2 = nn.Conv1d(channels[0], channels[1], kernel_size=kernel_size, stride=2, bias=False)
    self.dense2 = Dense(embed_dim, channels[1])
    self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])

    self.conv3 = nn.Conv1d(channels[1], channels[2], kernel_size=kernel_size, stride=2, bias=False)
    self.dense3 = Dense(embed_dim, channels[2])
    self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])

    self.conv4 = nn.Conv1d(channels[2], channels[3], kernel_size=kernel_size, stride=2, bias=False)
    self.dense4 = Dense(embed_dim, channels[3])
    self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

    self.conv5 = nn.Conv1d(channels[3], channels[4], kernel_size=kernel_size, stride=2, bias=False)
    self.dense5 = Dense(embed_dim, channels[4])
    self.gnorm5 = nn.GroupNorm(32, num_channels=channels[4])

    self.conv6 = nn.Conv1d(channels[4], channels[5], kernel_size=kernel_size, stride=2, bias=False)
    self.dense6 = Dense(embed_dim, channels[5])
    self.gnorm6 = nn.GroupNorm(32, num_channels=channels[5])

    # Decoding layers where the resolution increases
    self.tconv6 = nn.ConvTranspose1d(channels[5], channels[4], kernel_size=kernel_size, stride=2, bias=False)
    self.dense7 = Dense(embed_dim, channels[4])
    self.tgnorm6 = nn.GroupNorm(32, num_channels=channels[4])

    # output_padding=1 restores the one position lost by the matching strided encoder conv
    self.tconv5 = nn.ConvTranspose1d(channels[4] + channels[4], channels[3], kernel_size=kernel_size, stride=2, bias=False, output_padding=1)
    self.dense8 = Dense(embed_dim, channels[3])
    self.tgnorm5 = nn.GroupNorm(32, num_channels=channels[3])

    self.tconv4 = nn.ConvTranspose1d(channels[3] + channels[3], channels[2], kernel_size=kernel_size, stride=2, bias=False)
    self.dense9 = Dense(embed_dim, channels[2])
    self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])

    self.tconv3 = nn.ConvTranspose1d(channels[2] + channels[2], channels[1], kernel_size=kernel_size, stride=2, bias=False)
    self.dense10 = Dense(embed_dim, channels[1])
    self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])

    self.tconv2 = nn.ConvTranspose1d(channels[1] + channels[1], channels[0], kernel_size=kernel_size, stride=2, bias=False)
    self.dense11 = Dense(embed_dim, channels[0])
    self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])

    self.tconv1 = nn.ConvTranspose1d(channels[0] + channels[0], 2, kernel_size=kernel_size, stride=1)

    self.marginal_prob_std = marginal_prob_std

  def act(self, x):
    """The swish activation function, ``x * sigmoid(x)``.

    Defined as a method rather than a lambda stored on the instance, so the
    module carries no unpicklable attribute of its own.
    """
    return x * torch.sigmoid(x)

  def _block(self, layer, dense, norm, h, embed):
    """Run one stage: (transposed) convolution, add time embedding, group norm, swish."""
    h = layer(h)
    # Incorporate information from t (broadcast over the length axis)
    h = h + dense(embed)
    return self.act(norm(h))

  def forward(self, x, t):
    """Estimate the score of ``x`` at time ``t``.

    Args:
      x: Input of shape ``(batch, 2, length)``.
      t: Time steps of shape ``(batch,)``.

    Returns:
      Tensor with the shape of ``x``: the network output divided by
      ``marginal_prob_std(t)``.
    """
    # Obtain the Gaussian random feature embedding for t
    embed = self.act(self.embed(t))

    # Encoding path
    h1 = self._block(self.conv1, self.dense1, self.gnorm1, x, embed)
    h2 = self._block(self.conv2, self.dense2, self.gnorm2, h1, embed)
    h3 = self._block(self.conv3, self.dense3, self.gnorm3, h2, embed)
    h4 = self._block(self.conv4, self.dense4, self.gnorm4, h3, embed)
    h5 = self._block(self.conv5, self.dense5, self.gnorm5, h4, embed)
    h6 = self._block(self.conv6, self.dense6, self.gnorm6, h5, embed)

    # Decoding path; each stage after the first consumes the matching encoder output
    h = self._block(self.tconv6, self.dense7, self.tgnorm6, h6, embed)
    h = self._block(self.tconv5, self.dense8, self.tgnorm5, torch.cat([h, h5], dim=1), embed)
    h = self._block(self.tconv4, self.dense9, self.tgnorm4, torch.cat([h, h4], dim=1), embed)
    h = self._block(self.tconv3, self.dense10, self.tgnorm3, torch.cat([h, h3], dim=1), embed)
    h = self._block(self.tconv2, self.dense11, self.tgnorm2, torch.cat([h, h2], dim=1), embed)

    h = self.tconv1(torch.cat([h, h1], dim=1))

    # Normalize output
    h = h / self.marginal_prob_std(t)[:, None, None]
    return h
  
class TransformerScoreNet(nn.Module):
  """Time-dependent score model based on a Transformer encoder.

  Every sequence element is projected to ``embed_dim``, the time embedding is
  added to each position, and the encoder output is projected back to
  ``input_dim`` and divided by ``marginal_prob_std(t)``.
  """

  def __init__(self, input_dim, marginal_prob_std, embed_dim=128, num_heads=4, num_layers=4, ff_dim=256, dropout=0.1):
    """Build the network.

    Args:
      input_dim: Size of the last axis of the input (and of the output).
      marginal_prob_std: Function mapping time t to the standard deviation of
        the perturbation kernel.
      embed_dim: Model width and size of the time embedding.
      num_heads: Number of attention heads per encoder layer.
      num_layers: Number of encoder layers.
      ff_dim: Width of the feed-forward sub-layer.
      dropout: Dropout rate inside the encoder layers.
    """
    super().__init__()
    self.input_dim = input_dim
    self.marginal_prob_std = marginal_prob_std
    # Temporal embedding
    self.time_embed = nn.Sequential(
        GaussianFourierProjection(embed_dim),
        nn.Linear(embed_dim, embed_dim),
        nn.ReLU(),
        nn.Linear(embed_dim, embed_dim)
    )

    # Input projection
    self.input_proj = nn.Linear(input_dim, embed_dim)

    # Transformer encoder
    encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=dropout, batch_first=True)
    self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    # Output projection
    self.output_proj = nn.Linear(embed_dim, input_dim)

  def forward(self, x, t):
    """
    Args:
        x: Tensor of shape (B, L, D)
        t: Tensor of shape (B,) or (B, 1)
    Returns:
        score: Tensor of shape (B, L, D)
    """
    # Unpacking doubles as a check that x is three-dimensional
    _, seq_len, _ = x.shape

    # Flatten t so that both documented layouts, (B,) and (B, 1), behave the same;
    # a (B, 1) tensor would otherwise gain an extra axis in the embedding and scaling.
    t = t.reshape(-1)

    # Project input
    x_proj = self.input_proj(x)  # (B, L, embed_dim)

    # Get time embeddings and expand across sequence
    t_embed = self.time_embed(t).unsqueeze(1)   # (B, 1, embed_dim)
    t_embed = t_embed.expand(-1, seq_len, -1)   # (B, L, embed_dim)

    # Combine with input and run the encoder
    h = self.transformer(x_proj + t_embed)      # (B, L, embed_dim)

    # Project back to input dimension
    out = self.output_proj(h)                   # (B, L, D)
    # Normalize output
    out = out / self.marginal_prob_std(t)[:, None, None]
    return out

In [41]:
# Device on which the score model is placed and the training/sampling tensors are created.
# NOTE(review): 'cuda:1' addresses the second GPU and fails on machines with fewer
# than two visible GPUs — adjust as needed.
device = 'cuda:1' #@param ['cuda', 'cpu'] {'type':'string'}

def marginal_prob_std(t, sigma):
  r"""Compute the standard deviation of $p_{0t}(x(t) | x(0))$.

  Args:
    t: A vector of time steps. Tensors are used as they are (detached) and
      keep their own device; other inputs (floats, lists, arrays) are
      converted to a tensor on the notebook-level `device`.
    sigma: The $\sigma$ in our SDE.

  Returns:
    The standard deviation, a tensor with the shape of `t`.
  """
  if torch.is_tensor(t):
    # No copy and no "copy construct from a tensor" warning; the result stays
    # on the same device as the caller's tensors.
    t = t.detach()
  else:
    t = torch.as_tensor(t, device=device)
  return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
  r"""Compute the diffusion coefficient of our SDE, `sigma ** t`.

  Args:
    t: A vector of time t. Tensors are used as they are (detached) and keep
      their own device; other inputs are converted to a tensor on the
      notebook-level `device`.
    sigma: The $\sigma$ in our SDE.

  Returns:
    The vector of diffusion coefficients.
  """
  if torch.is_tensor(t):
    # Avoids re-wrapping an existing tensor with torch.tensor (copy + warning)
    return sigma ** t.detach()
  return torch.as_tensor(sigma**t, device=device)
  
# sigma parameterizes both SDE helpers: the diffusion coefficient is sigma**t
sigma =  50.0#@param {'type':'number'}
# Bind sigma so that the helpers can be called with the time argument only
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

In [42]:
#@title Define the loss function (double click to expand or collapse)
def loss_fn(model, x, marginal_prob_std, eps=1e-4):
  """Denoising score-matching loss at uniformly sampled times.

  Args:
    model: A time-dependent score model called as ``model(perturbed_x, t)``.
    x: A mini-batch of training data; the first axis is the batch axis.
    marginal_prob_std: A function giving the standard deviation of the
      perturbation kernel for a vector of times.
    eps: Lower bound of the sampled times, for numerical stability.

  Returns:
    A scalar tensor: the squared residual summed over axis 1 and averaged
    over all remaining axes.
  """
  batch_size = x.shape[0]
  # Times are drawn uniformly from [eps, 1)
  random_t = eps + (1. - eps) * torch.rand(batch_size, device=x.device)
  noise = torch.randn_like(x)

  # One std per sample, shaped to broadcast against x
  broadcast_shape = (-1,) + (1,) * (x.ndim - 1)
  std = marginal_prob_std(random_t).reshape(broadcast_shape)
  perturbed_x = x + noise * std

  score = model(perturbed_x, random_t)
  residual = score * std + noise
  return residual.pow(2).sum(dim=1).mean()

def loss_fn_fixed_time(model, x, marginal_prob_std, time):
  """Denoising score-matching loss evaluated at one fixed time for the whole batch.

  Args:
    model: A time-dependent score model called as ``model(perturbed_x, t)``.
    x: A mini-batch of training data; the first axis is the batch axis.
    marginal_prob_std: A function giving the standard deviation of the
      perturbation kernel for a vector of times.
    time: The time to evaluate at, as a Python number or a one-element tensor.
      Values below 1e-3 are raised to 1e-3.

  Returns:
    A scalar tensor: the squared residual summed over axis 1 and averaged
    over all remaining axes.
  """
  # float() accepts plain numbers as well as one-element tensors on any device
  time_value = max(float(time), 1e-3)
  batch_time = torch.full((x.shape[0],), time_value, device=x.device)
  z = torch.randn_like(x)
  std = marginal_prob_std(batch_time).reshape(-1, *[1 for _ in range(x.ndim - 1)])
  perturbed_x = x + z * std
  normalized_score = model(perturbed_x, batch_time)
  loss = torch.mean(torch.sum((normalized_score * std + z)**2, dim=1))
  return loss

In [43]:
# Sanity check: Do is the number of feature columns and Da the number of condition
# columns, both set by the plotting cells.
print(f"Do: {Do}, Da: {Da}")

In [58]:
# Check the dimensionality: build the feature dataset/loader and a fresh ScoreNet,
# then push one batch through the network and print the input and output shapes.
image_dataset = ImageDatasetWrapper(data_dict)
image_dataloader = DataLoader(image_dataset, batch_size=128, shuffle=True)

score_model = ScoreNet(marginal_prob_std=marginal_prob_std_fn)
score_model.to(device)

data = next(iter(image_dataloader))
smoke_batch_size = data['data'].shape[0]
# Normalize, pair up neighbouring features and move the pair axis to the channel
# position: (B, F) -> (B, F/2, 2) -> (B, 2, F/2)
x = normalizer['data'].normalize(data['data'])
x = x.reshape(smoke_batch_size, -1, 2).permute(0, 2, 1).to(device)
t = torch.ones(x.shape[0],  device=device)
print(x.shape)
score = score_model(x, t)
print(f"score.shape = {score.shape}")


In [47]:
# Train a fresh score network with the (unconditional) denoising score-matching loss.
score_model = ScoreNet(
    marginal_prob_std=marginal_prob_std_fn
)
score_model.to(device)

n_epochs = 50
## learning rate
lr=1e-3
optimizer = Adam(score_model.parameters(), lr=lr)
# The scheduler is stepped once per batch, so T_max counts optimizer steps, not epochs
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs*len(image_dataloader), eta_min=0.0)
tqdm_bar = tqdm(range(n_epochs))

lr_ls = list()
for epoch in tqdm_bar:

  avg_loss = 0.
  num_items = 0
  for data in image_dataloader:
    # Same layout as in the shape check: (B, F) -> (B, F/2, 2) -> (B, 2, F/2)
    x = normalizer['data'].normalize(data['data']).reshape(data['data'].shape[0], -1, 2).permute(0, 2, 1).to(device)

    loss = loss_fn(score_model, x, marginal_prob_std_fn, eps=1e-5)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Weight by batch size so that a smaller last batch is averaged correctly
    avg_loss += loss.item() * x.shape[0]
    num_items += x.shape[0]

  # Report the sample-weighted mean loss of this epoch. The former `sde_loss`
  # column was never updated and always printed 0, so it is no longer reported.
  epoch_loss = avg_loss / num_items
  tqdm_bar.set_description(desc=f'Average Loss: {epoch_loss:.5f}')
  print(f"Epoch {epoch + 1} / {n_epochs}:{epoch_loss:.5f} / lr:{scheduler.get_last_lr()}")
  # Update the checkpoint after each epoch of training.
  torch.save(score_model.state_dict(), 'ckpt.pth')
  lr_ls.append(scheduler.get_last_lr())

In [49]:
# Shape of the most recent batch `x` assigned by an earlier cell
x.shape

In [50]:
#@title Define the Euler-Maruyama sampler (double click to expand or collapse)

## The number of sampling steps.
num_steps =  500#@param {'type':'integer'}
def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff, 
                           batch_size=64, 
                           num_steps=num_steps, 
                           device='cuda', 
                           eps=1e-3,
                           z = None,
                           sample_shape=(2, 137)):
  """Generate samples from score-based models with the Euler-Maruyama solver.

  Args:
    score_model: A PyTorch model that represents the time-dependent score-based model.
    marginal_prob_std: A function that gives the standard deviation of
      the perturbation kernel.
    diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
    batch_size: The number of samplers to generate by calling this function once.
      Ignored when `z` is given.
    num_steps: The number of sampling steps. 
      Equivalent to the number of discretized time steps. Must be at least 2.
    device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
      Ignored when `z` is given.
    eps: The smallest time step for numerical stability.
    z: Optional initial state. When given, it is used as the start point and
      determines the batch size and the device.
    sample_shape: Shape of one sample (without the batch axis), used when `z`
      is None. The default matches the ScoreNet layout used in this notebook.
  
  Returns:
    Samples: the mean of the last Euler-Maruyama step (no noise added).

  Raises:
    ValueError: If `num_steps` is smaller than 2, because the step size is the
      difference of the first two time steps.
  """
  if num_steps < 2:
    raise ValueError("num_steps must be at least 2")

  if z is None:
    t = torch.ones(batch_size, device=device)
    # Start from the prior: Gaussian noise scaled by the std at t = 1
    prior_std = marginal_prob_std(t).reshape(-1, *([1] * len(sample_shape)))
    init_x = torch.randn(batch_size, *sample_shape, device=device) * prior_std
  else:
    init_x = z
    # Follow the given start point so that all per-step tensors match it
    batch_size, device = z.shape[0], z.device

  # Per-sample scalars are reshaped to broadcast against the sample axes
  broadcast_dims = (1,) * (init_x.ndim - 1)
  time_steps = torch.linspace(1., eps, num_steps, device=device)
  step_size = time_steps[0] - time_steps[1]
  x = init_x
  with torch.no_grad():
    for time_step in tqdm(time_steps):
      batch_time_step = torch.ones(batch_size, device=device) * time_step
      g = diffusion_coeff(batch_time_step).reshape(-1, *broadcast_dims)
      mean_x = x + g**2 * score_model(x, batch_time_step) * step_size
      x = mean_x + torch.sqrt(step_size) * g * torch.randn_like(x)
  # Do not include any noise in the last sampling step.
  return mean_x

In [53]:
# To sample from a saved run instead of the in-memory model, load the checkpoint first:
# score_model.load_state_dict(torch.load('ckpt.pth', map_location=device))

# Pick one random training item as a ground-truth reference
idx = torch.randint(0, len(image_dataset), size=(1,))
gt_data = image_dataset[idx]['data']

# Sampling is unconditional: ScoreNet.forward only takes (x, t), so no condition is built
condition = None
sampling_condition = None
sample_batch_size = 10000

sampler = Euler_Maruyama_sampler
#@param ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler'] {'type': 'raw'}

data = next(iter(image_dataloader))
data_dim = data['data'].shape[1:]

## Generate samples using the specified sampler.
samples = sampler(score_model,
                  marginal_prob_std_fn,
                  diffusion_coeff_fn, 
                  sample_batch_size, 
                  device=device,
                  ).detach().cpu()
# Undo the training layout: (B, 2, F/2) -> (B, F/2, 2) -> (B, F).
# No squeeze here, so a batch of one keeps its batch axis.
samples = samples.permute(0, 2, 1).reshape(sample_batch_size, -1)

unnormalized_samples = normalizer['data'].unnormalize(samples)

In [None]:
# Shapes of the ground-truth reference; the condition shape prints as None when no condition is set
print(f"gt_data.shape: {gt_data.shape} | gt_condition.shape: {condition.shape if condition is not None else None}")

In [None]:
# Shape of the generated samples
samples.shape

In [None]:
# with torch.no_grad():
#     scores_on_sampled_samples = score_model(samples.to(device).unsqueeze(-1), torch.ones(samples.shape[0], device=device)*1e-3).squeeze(-1).detach().cpu().numpy()
#     scores_on_training_data = score_model(torch.tensor(plot_data).to(device).unsqueeze(-1), torch.ones(plot_data.shape[0], device=device)*1e-3).squeeze(-1).detach().cpu().numpy()

# print(scores_on_sampled_samples.shape)
# print(scores_on_training_data.shape)

In [59]:
# Overlay the generated samples on the normalized training features, one panel per
# adjacent pair of feature dimensions. The panel count is derived from the arrays
# themselves, so the notebook-level `Do` is left untouched.
num_feature_dims = min(plot_data.shape[1], samples.shape[1])
num_subplots = max(0, num_feature_dims - 1)
num_column = int(np.sqrt(num_subplots))+1
num_row = max(1, int(np.ceil(num_subplots / num_column)))

plt.figure(figsize=(5*num_column, 5*num_row))
# Each pair is drawn exactly once; a flat loop needs no modulo wrap-around
for plot_iter in range(num_subplots):
    plt.subplot(num_row, num_column, plot_iter+1)

    plt.scatter(plot_data[:, plot_iter], plot_data[:, plot_iter+1], s=2)
    plt.scatter(samples[:, plot_iter], samples[:, plot_iter+1], alpha=0.3, s=2)
    plt.xlabel(f"x_{plot_iter}")
    plt.ylabel(f"x_{plot_iter+1}")
    plt.grid(True)
plt.show()

In [21]:
def langevin_mcmc(x, score_model, n_steps, step_scale=0.001, t=1e-3):
    """Run unadjusted Langevin dynamics (no rejection step) with a fixed time input.

    Each step computes ``x <- x + step * score(x, t) + sqrt(2 * step) * noise``
    with freshly drawn Gaussian noise.

    Args:
        x: Initial states; the first axis is the batch axis.
        score_model: Callable ``score_model(x, batch_t)`` returning the score of ``x``.
        n_steps: Number of Langevin steps.
        step_scale: Step size of every update.
        t: Time value passed to the score model for all samples.

    Returns:
        The states after ``n_steps`` updates, with the same shape as ``x``.
    """
    step_size = step_scale
    noise_scale = np.sqrt(2*step_size)
    batch_t = torch.ones(x.shape[0], device=x.device) * t

    with torch.no_grad():
        for _ in tqdm(range(n_steps)):
            # Draw new noise at every step; reusing one draw would add the same
            # offset repeatedly instead of performing a random walk.
            noise = torch.randn_like(x) * noise_scale
            score = score_model(x, batch_t)
            x = x + step_size * score + noise
    return x

In [None]:
sample_batch_size=5000
sample_dataset_loader = DataLoader(image_dataset, batch_size=sample_batch_size, shuffle=True)

data = next(iter(sample_dataset_loader))
# The loader may return fewer items than requested when the dataset is small
num_initial = data['data'].shape[0]
noisy_x = normalizer['data'].normalize(data['data']) + torch.randn_like(data['data'])
# ScoreNet expects the training layout (B, 2, F/2), not (B, 1, F).
# NOTE(review): the factor 0.00 zeroes the start point, so every chain starts at
# the origin — kept as in the original; confirm this is intended.
initial_x = noisy_x.reshape(num_initial, -1, 2).permute(0, 2, 1).to(device) * 0.00
samples = langevin_mcmc(initial_x, score_model, n_steps=100, step_scale=0.001, t=1e-3)
# Undo the layout change: (B, 2, F/2) -> (B, F/2, 2) -> (B, F)
samples = samples.permute(0, 2, 1).reshape(num_initial, -1).detach().cpu().numpy()

In [None]:
# Overlay the Langevin samples on the normalized training features for the first
# adjacent pairs of feature dimensions (at most 16 panels).
num_feature_dims = min(plot_data.shape[1], samples.shape[1])
num_subplots = max(0, min(16, num_feature_dims - 1))
num_column = int(np.sqrt(num_subplots))+1
num_row = max(1, int(np.ceil(num_subplots / num_column)))

plt.figure(figsize=(5*num_column, 5*num_row))
# Each pair is drawn exactly once; a flat loop needs no modulo wrap-around
for plot_iter in range(num_subplots):
    plt.subplot(num_row, num_column, plot_iter+1)

    plt.scatter(plot_data[:, plot_iter], plot_data[:, plot_iter+1], s=2)
    plt.scatter(samples[:, plot_iter], samples[:, plot_iter+1], alpha=0.3, s=2)
    plt.xlabel(f"x_{plot_iter}")
    plt.ylabel(f"x_{plot_iter+1}")
    plt.grid(True)
plt.show()

In [32]:
## Do you like this? Then save it!
# torch.save(score_model.state_dict(), 'score_with_trajectory_condition.pth')
# torch.save(normalizer.state_dict(), 'normalizer.pth')

In [None]:
# Copy the fitted normalizer's state into a freshly constructed LinearNormalizer.
# NOTE(review): presumably a check that the normalizer can be restored from its
# state_dict (e.g. after saving it with torch.save) — confirm.
dummy_normalizer = LinearNormalizer()
dummy_normalizer.load_state_dict(normalizer.state_dict())