In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

# Run on the GPU when PyTorch can see one, otherwise on the CPU; later cells move
# tensors and models onto this device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

device

device(type='cuda')

In [2]:
def get_dataset(func, n_samples=1024):
    """Sample ``func`` on ``n_samples`` evenly spaced float32 points in [-10, 10].

    Args:
        func: NumPy-vectorised callable; presumably returns one value per grid point.
        n_samples: number of grid points.

    Returns:
        ``(y, x)`` as tensors on ``device`` -- note the order, values first:
        ``y`` is reshaped to a single row ``(1, -1)``, ``x`` is the 1-D grid.
    """
    grid = np.linspace(-10, 10, n_samples).astype(np.float32)
    values = func(grid).astype(np.float32).reshape(1, -1)
    y_tensor = torch.from_numpy(values).to(device)
    x_tensor = torch.from_numpy(grid).to(device)
    return y_tensor, x_tensor

In [12]:
# import diffusers.schedulers.scheduling_ddpm 

# def get_beta_schedule(T):
#     return torch.linspace(1e-4, 0.02, T).to(device)

def custom_beta_schedule(T, kind='cosine', beta_start=1e-4, beta_end=0.02, target_device=None):
    """Build a length-``T`` noise-variance (beta) schedule.

    Args:
        T: number of diffusion steps.
        kind: one of
            - ``'cosine'``: ``beta_end * cos^2(t / T * pi / 2)`` for t = 0..T-1.
              NOTE: this *decreases* from ``beta_end`` towards 0 (most noise is added
              at the first steps); it is not the increasing cosine schedule of
              Nichol & Dhariwal (2021). ``beta_start`` is unused here.
            - ``'exp'``: geometric spacing from ``beta_start`` to ``beta_end``.
            - ``'linear'``: arithmetic spacing from ``beta_start`` to ``beta_end``.
        beta_start, beta_end: schedule endpoints; the defaults reproduce the
            previously hard-coded values.
        target_device: device for the result; defaults to the module-level ``device``.

    Returns:
        1-D float32 tensor of shape ``(T,)``.

    Raises:
        ValueError: for an unrecognised ``kind`` (previously any typo silently fell
            back to the linear schedule).

    NOTE(review): with the default endpoints the cumulative signal fraction
    alpha_bar_{T-1} = prod(1 - beta) stays around 0.36 for T=100 ('cosine' and
    'linear'), so the last forward step is far from pure noise at small T.
    """
    dev = device if target_device is None else target_device
    if kind == 'cosine':
        steps = torch.arange(T, dtype=torch.float32, device=dev)
        return torch.cos(steps / T * (torch.pi / 2)) ** 2 * beta_end
    if kind == 'exp':
        # Evenly spaced in log-space, i.e. a geometric progression.
        log_start = float(np.log(beta_start))
        log_end = float(np.log(beta_end))
        return torch.exp(torch.linspace(log_start, log_end, T, device=dev))
    if kind == 'linear':
        return torch.linspace(beta_start, beta_end, T, device=dev)
    raise ValueError(f"unknown beta schedule kind {kind!r}; expected 'cosine', 'exp' or 'linear'")



# def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
#   if beta_schedule == 'quad':
#     betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2
#   elif beta_schedule == 'linear':
#     betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
#   elif beta_schedule == 'warmup10':
#     betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
#   elif beta_schedule == 'warmup50':
#     betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
#   elif beta_schedule == 'const':
#     betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
#   elif beta_schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
#     betas = 1. / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
#   else:
#     raise NotImplementedError(beta_schedule)
#   assert betas.shape == (num_diffusion_timesteps,)
#   return betas

In [4]:
def forward_diffusion_sample(x0, t, betas):
    """Draw x_t ~ q(x_t | x_0) in closed form.

    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I),
    where alpha_bar = cumprod(1 - betas).

    Args:
        x0: clean signals, batch-first ``(B, ...)``; the training loop passes ``(B, L)``.
        t: integer timestep indices into ``betas``, one per batch element (shape ``(B,)``).
        betas: 1-D noise schedule.

    Returns:
        ``(xt, noise)``, both shaped like ``x0``: the noised signals and the eps that
        was added (the regression target during training).
    """
    noise = torch.randn_like(x0)
    # Follow x0's device rather than the global `device` so the caller controls placement.
    alpha_bar = torch.cumprod(1. - betas, dim=0).to(x0.device)
    # One coefficient per batch element, broadcast over every remaining dim of x0
    # (the former `.view(-1, 1)` only broadcast correctly for 2-D x0).
    coef_shape = (-1,) + (1,) * (x0.dim() - 1)
    alpha_bar_t = alpha_bar[t]
    signal_scale = alpha_bar_t.sqrt().view(coef_shape)
    noise_scale = (1. - alpha_bar_t).sqrt().view(coef_shape)
    xt = signal_scale * x0 + noise_scale * noise
    return xt, noise

In [None]:
class UNet1D(nn.Module):
    """Conditional 1-D conv encoder/decoder that predicts the noise added to a signal.

    A ``(B, L)`` input is conditioned by adding a learned length-L vector for the
    timestep and another for the function id, then passes through two stride-2
    convolutions (L -> L/4), a bottleneck convolution and two stride-2 transposed
    convolutions back to length L.

    NOTE(review): despite the name there are no skip connections -- encoder
    activations are not passed to the decoder. Adding them would change the learned
    function (saved checkpoints would need retraining), so the architecture is kept.

    Args:
        signal_length: L; must be divisible by 4 so the down/up stages restore it.
        n_func: number of function ids the model can be conditioned on.
        max_timesteps: rows in the timestep embedding table; timesteps must be < this.
    """

    def __init__(self, signal_length=1024, n_func=10, max_timesteps=1000):
        super().__init__()
        if signal_length % 4:
            raise ValueError(f"signal_length must be divisible by 4, got {signal_length}")
        # Module creation order matches the original so seeded inits are unchanged.
        self.time_embed = nn.Sequential(
            nn.Embedding(max_timesteps, 64),
            nn.Linear(64, signal_length),
        )
        self.func_embed = nn.Sequential(
            nn.Embedding(n_func, 64),
            nn.Linear(64, signal_length),
        )
        self.downs = nn.ModuleList([
            nn.Sequential(nn.Conv1d(1, 32, kernel_size=4, stride=2, padding=1), nn.ReLU()),
            nn.Sequential(nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=1), nn.ReLU()),
        ])
        self.mid = nn.Sequential(nn.Conv1d(64, 64, kernel_size=3, padding=1), nn.ReLU())
        self.ups = nn.ModuleList([
            nn.Sequential(nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1), nn.ReLU()),
            nn.ConvTranspose1d(32, 1, kernel_size=4, stride=2, padding=1),
        ])

    def forward(self, x, t, func_id):
        """Predict noise for ``x`` of shape ``(B, L)`` at timesteps ``t`` ``(B,)``
        for function ids ``func_id`` ``(B,)``; returns ``(B, L)``."""
        h = x.unsqueeze(1)  # (B, 1, L)
        h = h + self.time_embed(t).unsqueeze(1) + self.func_embed(func_id).unsqueeze(1)
        for down in self.downs:
            h = down(h)
        h = self.mid(h)
        for up in self.ups:
            h = up(h)
        return h.squeeze(1)


def get_model(signal_length=1024, n_func=10, max_timesteps=1000):
    """Build a ``UNet1D`` and move it to the global ``device``.

    Args:
        signal_length: signal length L (must be divisible by 4).
        n_func: number of conditioning function ids.
        max_timesteps: size of the timestep embedding; must be >= the T used to train
            (was hard-coded to 1000).
    """
    return UNet1D(signal_length, n_func, max_timesteps).to(device)

In [None]:
def train(model, dataset, betas, T=1000, epochs=1000, batch_size=1, lr=1e-4):
    """Train ``model`` to predict the noise added by ``forward_diffusion_sample``.

    Every step draws one uniform timestep t in [0, T) per example, noises the clean
    signal with ``betas`` and regresses the model output onto the true noise (MSE).

    Args:
        model: called as ``model(xt, t, func_id)``; its parameters are updated in place.
        dataset: yields ``(x0, func_id)`` pairs, e.g. from ``get_combined_dataset``.
        betas: 1-D noise schedule; must have at least ``T`` entries.
        T: number of diffusion steps to sample t from.
        epochs: passes over ``dataset``.
        batch_size: examples per optimiser step (default 1, as before).
        lr: Adam learning rate (default 1e-4, as before).

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if ``T`` exceeds ``len(betas)`` (t would index past the schedule).
    """
    if T > len(betas):
        raise ValueError(f"T={T} exceeds the beta schedule length {len(betas)}")
    # Use the model's own device instead of the global `device`.
    model_device = next(model.parameters()).device
    model.train()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    epoch_losses = []
    for _ in range(epochs):
        loss_sum = torch.zeros((), device=model_device)
        n_batches = 0
        for x0, func_id in loader:
            x0 = x0.to(model_device)
            func_id = func_id.to(model_device).long()
            t = torch.randint(0, T, (x0.size(0),), device=model_device)
            xt, noise = forward_diffusion_sample(x0, t, betas)
            loss = F.mse_loss(model(xt, t, func_id), noise)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Accumulate on-device; a single host sync per epoch instead of per step.
            loss_sum += loss.detach()
            n_batches += 1
        epoch_losses.append(loss_sum.item() / max(n_batches, 1))
    return epoch_losses


In [7]:
def compare_signals(original, test, label):
    """Print and return the mean squared error between ``test`` and ``original``.

    Args:
        original: reference tensor.
        test: tensor with the same shape as ``original``.
        label: prefix for the printed line.

    Returns:
        The MSE as a Python float (previously it was only printed).
    """
    mse = F.mse_loss(test, original).item()
    print(f"{label} mse: {mse:.6f}")
    return mse

In [None]:
def _noise_trajectory(y0, alpha_bars, steps):
    """Return x_t ~ q(x_t | x_0) for t = 0..steps-1, all sharing one noise draw, as NumPy rows."""
    noise = torch.randn_like(y0)
    return [
        (alpha_bars[t].sqrt() * y0 + (1 - alpha_bars[t]).sqrt() * noise).cpu().numpy()[0]
        for t in range(steps)
    ]


def _denoise_trajectory(model, xt, betas, alpha_bars, steps, func_id):
    """Run the mean-only (no added noise) DDPM reverse process from t = steps-1 to 0.

    Returns the iterate after every reverse step as NumPy rows; the first entry is the
    result of the t = steps-1 step, the last one the final estimate of x_0.
    """
    fid = None if func_id is None else torch.tensor([func_id], device=xt.device)
    recovered = []
    for t in reversed(range(steps)):
        t_tensor = torch.tensor([t], device=xt.device)
        eps = model(xt, t_tensor) if fid is None else model(xt, t_tensor, fid)
        # DDPM posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t).
        # The former update omitted the 1 / sqrt(1 - alpha_bar_t) factor on eps.
        xt = (xt - betas[t] / torch.sqrt(1 - alpha_bars[t]) * eps) / torch.sqrt(1 - betas[t])
        recovered.append(xt.cpu().numpy()[0])
    return recovered


def _plot_trajectories(x, y, noisy_versions, recovered_versions, steps):
    """Plot 11 evenly spaced snapshots of both trajectories over the original signal."""
    indices_to_show = [int(steps * i / 10) for i in range(11)]
    indices_to_show[-1] = steps - 1  # int(steps * 10 / 10) == steps is out of range
    fig, axs = plt.subplots(2, 11, figsize=(33, 6), sharey=True)
    rows = (
        (0, noisy_versions, 'blue', "Zaszum"),
        (1, recovered_versions, 'red', "Odszum"),
    )
    for col, idx in enumerate(indices_to_show):
        for row, versions, color, prefix in rows:
            ax = axs[row, col]
            ax.plot(x, y, label='Oryginał', linewidth=1, color='black')
            ax.plot(x, versions[idx], label=f"Krok {idx}", color=color)
            ax.set_title(f"{prefix} {idx}")
            ax.set_xticks([])
            ax.set_yticks([])
    axs[0, 0].legend(loc='upper right')
    return fig


@torch.no_grad()
def diff_process(model, func, betas, steps=100, x=None, func_id=None):
    """Visualise forward noising and model-based denoising of ``func``.

    Top row: q(x_t | x_0) with one fixed noise draw at sampled timesteps ("Zaszum" = noised).
    Bottom row: reverse-process iterates starting from the last noisy sample
    ("Odszum" = denoised; the column index counts reverse steps).
    Also prints the MSE of the final noisy and final denoised signals vs. the original.

    Args:
        model: noise predictor, called as ``model(xt, t, func_id)``; when ``func_id``
            is None it is called as ``model(xt, t)`` (unconditional models).
        func: NumPy-vectorised function to sample.
        betas: 1-D noise schedule (the one used in training), >= ``steps`` entries.
        steps: number of diffusion steps to run.
        x: optional NumPy float32 grid; defaults to 1024 points in [-10, 10].
        func_id: integer id of ``func`` -- its index in the ``f_names`` list the model
            was trained on. Required by ``UNet1D``, whose forward takes three arguments.

    Raises:
        ValueError: if ``steps`` is not in [1, len(betas)].
    """
    if not 1 <= steps <= len(betas):
        raise ValueError(f"steps={steps} must be between 1 and len(betas)={len(betas)}")
    model.eval()
    if x is None:
        x = np.linspace(-10, 10, 1024).astype(np.float32)
    y = func(x).astype(np.float32)
    y_torch = torch.from_numpy(y).to(device).view(1, -1)

    betas = betas.to(device)
    alpha_bars = torch.cumprod(1 - betas, dim=0)

    noisy_versions = _noise_trajectory(y_torch, alpha_bars, steps)
    start = torch.from_numpy(noisy_versions[-1]).to(device).view(1, -1)
    recovered_versions = _denoise_trajectory(model, start, betas, alpha_bars, steps, func_id)

    fig = _plot_trajectories(x, y, noisy_versions, recovered_versions, steps)

    noisy_final = torch.from_numpy(noisy_versions[-1]).to(device).view(1, -1)
    denoised_final = torch.from_numpy(recovered_versions[-1]).to(device).view(1, -1)
    compare_signals(y_torch, noisy_final, label='zaszumiony (ostatni krok)')
    compare_signals(y_torch, denoised_final, label='odszumiony (ostatni krok)')

    fig.tight_layout()
    plt.show()

In [5]:
# Test functions, sampled on [-10, 10] by get_dataset. The log/inverse variants clip
# their input at 1e-3, so every x <= 1e-3 maps to one value (f_inv -> 1000, f_log10 -> -3).
def f_sin(x): return np.sin(x)
def f_tan(x): return np.tan(x)
def f_sgn(x): return np.sign(x)
def f_sigmoid(x): return 1 / (1 + np.exp(-x))
def f_relu(x): return np.maximum(0, x)
def f_log10(x): return np.log10(np.clip(x, 1e-3, None))
def f_log2(x): return np.log2(np.clip(x, 1e-3, None))
def f_inv(x): return 1 / np.clip(x, 1e-3, None)
def f_exp(x): return np.exp(x)
def f_poly(x): return x**2 + 2*x + 1
def f_sin_1x(x): return np.sin(1/x)  # NaN (with a RuntimeWarning) if the grid contains x == 0
def f_sin_2(x): return np.sin(x)**2

# (name, function) pairs. A pair's position is the integer func_id used for
# conditioning (get_combined_dataset enumerates this list), so reordering it
# invalidates trained models. get_model's default n_func=10 equals len(f_names).
f_names = [
    ("sin", f_sin),
    ("tan", f_tan),
    ("exp", f_exp),
    ("log10", f_log10),
    ("poly", f_poly),
    ("sigmoid", f_sigmoid),
    ("relu", f_relu),
    ("inv", f_inv),
    ("sin_1x", f_sin_1x),
    ("sin_2", f_sin_2),  # was mis-spelled "zin_2"
    # ("log2", f_log2),  # defined above but not included
]

f_index = {name: idx for idx, (name, _) in enumerate(f_names)}


In [None]:
def get_combined_dataset(f_names, n_samples=1024):
    """Stack one sampled signal per function into a conditioning dataset.

    Args:
        f_names: sequence of ``(name, func)`` pairs; each pair's position becomes its
            integer func_id.
        n_samples: grid points per signal (forwarded to ``get_dataset``).

    Returns:
        ``TensorDataset`` of ``(signal, func_id)``: signals stacked along dim 0 (one row
        per function, as produced by ``get_dataset``) and long ids on the same device.

    Raises:
        ValueError: if ``f_names`` is empty, or a function yields NaN/inf on the grid
            (which would otherwise silently turn the training loss into NaN).
    """
    if not f_names:
        raise ValueError("f_names is empty")
    signals, ids = [], []
    for func_id, (name, func) in enumerate(f_names):
        y, _ = get_dataset(func, n_samples)
        if not torch.isfinite(y).all():
            raise ValueError(f"function {name!r} produced non-finite values on the sampling grid")
        signals.append(y)
        # Keep ids on the signals' device so both dataset tensors live together.
        ids.append(torch.full((y.shape[0],), func_id, dtype=torch.long, device=y.device))
    return TensorDataset(torch.cat(signals, dim=0), torch.cat(ids, dim=0))

In [None]:
if __name__ == '__main__':
    # Grid search: each (epochs, T, beta schedule) combination trains a fresh model
    # and saves its weights as unet_general_T{T}_ep{epochs}_kind{schedule}.pth.
    SEED = 0
    epochs = [1000, 5000, 10000]
    Ts = [100, 250, 500, 1000]
    schedule_kinds = ['cosine', 'exp', 'linear']

    # The dataset does not depend on any loop variable, so build it once.
    dataset = get_combined_dataset(f_names)

    for epoch in epochs:
        for ts in Ts:
            for beta_name in schedule_kinds:
                beta = custom_beta_schedule(ts, kind=beta_name)
                # Same init and noise stream for every configuration -> comparable runs.
                torch.manual_seed(SEED)
                model = get_model()
                train(model, dataset, beta, T=ts, epochs=epoch)
                run_name = f"unet_general_T{ts}_ep{epoch}_kind{beta_name}"
                print(run_name)
                torch.save(model.state_dict(), f"{run_name}.pth")

unet_general_T100_ep1000_kindcosine
unet_general_T100_ep1000_kindexp
unet_general_T100_ep1000_kindlinear
unet_general_T250_ep1000_kindcosine
unet_general_T250_ep1000_kindexp
unet_general_T250_ep1000_kindlinear
unet_general_T500_ep1000_kindcosine
unet_general_T500_ep1000_kindexp
unet_general_T500_ep1000_kindlinear
unet_general_T1000_ep1000_kindcosine
unet_general_T1000_ep1000_kindexp
unet_general_T1000_ep1000_kindlinear
unet_general_T100_ep5000_kindcosine
unet_general_T100_ep5000_kindexp
unet_general_T100_ep5000_kindlinear
unet_general_T250_ep5000_kindcosine
unet_general_T250_ep5000_kindexp
unet_general_T250_ep5000_kindlinear
unet_general_T500_ep5000_kindcosine
unet_general_T500_ep5000_kindexp
unet_general_T500_ep5000_kindlinear
unet_general_T1000_ep5000_kindcosine
unet_general_T1000_ep5000_kindexp
unet_general_T1000_ep5000_kindlinear
unet_general_T100_ep10000_kindcosine
unet_general_T100_ep10000_kindexp
unet_general_T100_ep10000_kindlinear
unet_general_T250_ep10000_kindcosine
unet_gene