In [22]:
import torch
import torchvision
import torchvision.transforms as transforms
import os

# Preprocessing pipeline: torchvision's ToTensor conversion only
# (no Normalize step is included).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

# MNIST training split, read from an existing local copy (download is disabled).
train_dataset = torchvision.datasets.MNIST(
    root="/home/dataset/mnist",
    train=True,
    transform=transform,
    download=False,
)

# Shuffled mini-batch loader over the training split, 1000 samples per batch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1000,
    shuffle=True,
)

def loader(dl):
    """
    Cycle over an iterable of batches indefinitely.

    Each time `dl` is exhausted it is iterated again from the start, so the
    generator never finishes on its own while `dl` keeps producing items.

    dl: a re-iterable object (e.g. a DataLoader); every item it produces is
        yielded unchanged.

    Raises ValueError if a complete pass over `dl` produces no items. Without
    this guard an empty (or one-shot, already consumed) iterable would make the
    `while True` loop spin forever without ever yielding.
    """
    while True:
        yielded_any = False
        for batch in dl:
            yielded_any = True
            yield batch
        if not yielded_any:
            raise ValueError("loader(): iterating over `dl` produced no batches")

In [23]:
def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    timesteps: number of diffusion steps; the returned tensor has this length.
    s: small offset added to the normalized time before the cosine is taken.

    Returns a float64 tensor of betas clipped to the range [0, 0.999].
    """
    steps = timesteps + 1
    # normalized time grid 0, 1/timesteps, ..., 1
    t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
    # torch.pi is used instead of math.pi so that this function does not depend
    # on `math` having been imported by some other cell before it is called.
    alphas_cumprod = torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** 2
    # normalize so that the cumulative product starts at exactly 1
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

def linear_beta_schedule(timesteps):
    """
    Linearly spaced betas, following the original DDPM paper.

    The endpoints 0.0001 and 0.02 are multiplied by 1000 / timesteps, so they
    are used as-is when timesteps == 1000.
    Returns a float64 tensor of length `timesteps`.
    """
    rescale = 1000 / timesteps
    first_beta, last_beta = rescale * 0.0001, rescale * 0.02
    return torch.linspace(first_beta, last_beta, timesteps, dtype = torch.float64)

def exists(x):
    """Return False only for None; every other value (including 0, '' and []) counts as existing."""
    if x is None:
        return False
    return True

def default(val, d):
    """
    Return `val` unless it is None, otherwise fall back to `d`.

    When the fallback is needed and `d` is callable, it is called with no
    arguments and its result is returned; a non-callable `d` is returned as-is.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val

def extract(a, t, x_shape):
    """
    Gather the entries of `a` at indices `t` along the last dimension and
    reshape them for broadcasting.

    The result has shape (b, 1, ..., 1) with len(x_shape) dimensions in total,
    where b is the first dimension of `t`.
    """
    batch_size, *_ = t.shape
    picked = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch_size, *singleton_dims)

def gamma_cosine(t, lambda_val=2):
    """
    Compute gamma_t for the cosine schedule.

    With s(t) = (pi / 2) * t ** lambda_val, the value returned is
        (pi / 4) * lambda_val * t ** (lambda_val - 1) * sin(2 s(t)) / cos(s(t)) ** 2

    t: continuous time as a tensor (intended range 0 to 1)
    lambda_val: exponent that controls how steep the cosine decay is
    """
    angle = (torch.pi / 2) * t ** lambda_val  # s(t) = (pi/2) * t^lambda
    prefactor = (torch.pi / 4) * lambda_val
    ratio = torch.sin(2 * angle) / torch.cos(angle) ** 2
    return prefactor * t ** (lambda_val - 1) * ratio



In [24]:
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
import math

class Schedule:
    """
    Precomputed diffusion coefficients derived from a linear beta schedule.

    Every tensor attribute is 1-D with one entry per timestep and is derived
    from the beta tensor after it has been moved to `device`.
    """

    def __init__(self, device, timesteps=1000):
        """
        device: device the beta tensor is moved to before anything is derived from it.
        timesteps: number of diffusion steps. Defaults to 1000, the value that
            used to be hard-coded, so Schedule(device) behaves as before.
        """
        self.num_timesteps = timesteps
        self.betas = linear_beta_schedule(timesteps).to(device)

        alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        # cumulative product shifted right by one step, with 1.0 filled in at index 0
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value = 1.)

        # computed once and reused below instead of re-evaluating 1 - alphas_cumprod
        self.one_minus_alphas_cumprod = 1. - self.alphas_cumprod

        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(self.one_minus_alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(self.one_minus_alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod - 1)

        self.posterior_variance = (
            self.betas * (1. - self.alphas_cumprod_prev) / self.one_minus_alphas_cumprod
        )

        # posterior_variance is 0 at index 0 (alphas_cumprod_prev[0] == 1), so it is
        # clamped before the log to avoid -inf
        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min =1e-20))
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / self.one_minus_alphas_cumprod
        )
        self.posterior_mean_coef2 = (
            (1. - self.alphas_cumprod_prev) * torch.sqrt(alphas) / self.one_minus_alphas_cumprod
        )



In [25]:
import torch.nn.functional as F


def Weight_post(image_t, dataset, sqrt_at, one_minus_at):
    """
    Softmax weights over the dataset entries for a given noisy sample.

    For each entry x_i along dim 0 of `dataset` the logit is
        -sum((image_t - sqrt_at * x_i) ** 2) / (2 * one_minus_at)
    and the logits are normalized with a softmax across dim 0.

    Returns the weights viewed as (b, 1, ..., 1), with as many dimensions as
    `dataset`, where b = dataset.shape[0].
    """
    n_items, *_ = dataset.shape

    residual = image_t - sqrt_at * dataset
    # squared distance per dataset entry: flatten everything but dim 0 and sum
    sq_dist = (residual ** 2).view(dataset.shape[0], -1).sum(dim=1)
    logits = sq_dist * (- 1. / (2 * one_minus_at))

    weights = F.softmax(logits, dim=0)
    broadcast_shape = (n_items, *((1,) * (len(dataset.shape) - 1)))
    return weights.view(broadcast_shape)


In [26]:
def score_f_estimator(image_t, dataset, t, schedule, gene = None):
    """
    Estimate the score at timestep `t` directly from the dataset.

    The estimate is the sum over dataset entries x_i, weighted by the softmax
    weights from Weight_post, of
        -(x_t - sqrt_at * x_i) / one_minus_at

    image_t: when `gene` is None this tensor is first noised to timestep `t`
        (x_t = sqrt_at * image_t + sqrt_one_minus_at * noise); otherwise it is
        used as the noisy sample as-is.
    dataset: tensor of reference samples; dim 0 indexes the samples.
    t: index into the schedule tensors.
    schedule: object exposing sqrt_alphas_cumprod, one_minus_alphas_cumprod and
        sqrt_one_minus_alphas_cumprod (e.g. a Schedule instance).
    gene: any non-None value skips the noising step.

    Returns (score, max_weight, argmax_weight): the score with a leading
    dimension of size 1, the largest weight, and the flat index of that weight.
    """
    sqrt_at = schedule.sqrt_alphas_cumprod[t]
    one_minus_at = schedule.one_minus_alphas_cumprod[t]

    if gene is None:
        # This branch used to read the undefined names `noise` and
        # `sqrt_one_minus_at` and raised NameError. Both are now defined here:
        # the coefficient comes from the schedule, and the noise is drawn as
        # standard Gaussian noise of the same shape as the image
        # (NOTE(review): Gaussian noise is inferred from the names; confirm).
        sqrt_one_minus_at = schedule.sqrt_one_minus_alphas_cumprod[t]
        noise = torch.randn_like(image_t)
        image_t = image_t * sqrt_at + noise * sqrt_one_minus_at
    weight_schedule = Weight_post(image_t, dataset, sqrt_at, one_minus_at)

    one_recip = -1. / one_minus_at

    score = (image_t - sqrt_at * dataset) * weight_schedule
    score = score * one_recip
    # collapse the dataset dimension, then restore a leading dimension of size 1
    score = torch.sum(score, dim=0).unsqueeze(0)

    # amax over dim 0 gives the same value and shape as Python's max() over the
    # (b, 1, ..., 1) weight tensor, without iterating element by element.
    return score, weight_schedule.amax(dim=0), torch.argmax(weight_schedule)



In [28]:
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)  # "cuda" or "cpu"

cuda


In [29]:
# When CUDA is available, bind `device` explicitly to "cuda:0" and report it;
# when it is not, this cell does nothing and leaves `device` untouched.
if torch.cuda.is_available():
    device = torch.device("cuda:0")  # use the first GPU
    print(f"Using device: {device}")

Using device: cuda:0


In [30]:
# Report which GPU sits at index 0 (printed only when CUDA is available).
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # print the name of the first GPU

NVIDIA A40
