In [1]:
import json
import math
import os
import shlex
import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions as dist

import models as m


In [2]:
torch.set_default_dtype(torch.float32)

# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS check; the getattr
# guard keeps this cell working on builds that ship without the MPS backend.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Kept for backward compatibility with any cell reading `dtype`; it was
# torch.FloatTensor on every branch, so it is set once here.
dtype = torch.FloatTensor


def t(x):
    """Convert `x` to a tensor of the default float dtype, placed on `device`.

    Integer inputs (Python ints, int64 arrays / long tensors) are cast to float,
    which avoids dtype-mismatch errors when mixing them with model tensors.
    """
    return torch.as_tensor(x, dtype=torch.get_default_dtype(), device=device)


print(device)


cuda:0


In [3]:
# Experiment layout: runs/<EXP_NAME>/<RUN_ID>/{weights,logs,videos}
BASE_PATH = os.getcwd()
RUNS_ROOT = os.path.join(BASE_PATH, "runs")
EXP_NAME = "baseline"
EXP_DIR = os.path.join(RUNS_ROOT, EXP_NAME)
RUN_ID = '001'
RUN_DIR = os.path.join(EXP_DIR, RUN_ID)
WEIGHTS_DIR = os.path.join(RUN_DIR, "weights")
LOGS_DIR = os.path.join(RUN_DIR, "logs")
VIDEO_DIR = os.path.join(RUN_DIR, 'videos')

# Hyper-parameters saved at training time; the `with` block closes the file.
with open(os.path.join(LOGS_DIR, 'hparams.json'), encoding='utf-8') as json_data:
    hparams = json.load(json_data)

hparams

{'batch_size': 128,
 'N_train': 10000,
 'EVAL_EVERY': 700,
 'lr': 0.001,
 'EPS_TRAINNING': False,
 'sigma': {'schedule': 'lin',
  'min': 0.1,
  'max': 1,
  'n_sigmas': 10,
  'values': [1.0,
   0.7742636799812317,
   0.5994842648506165,
   0.46415889263153076,
   0.35938137769699097,
   0.2782559394836426,
   0.2154434621334076,
   0.1668100506067276,
   0.1291549652814865,
   0.10000000149011612]},
 'device': 'cuda:0',
 'model': {'in_channel': 1,
  'base_ch': 16,
  'channel_mults': [1, 2, 4],
  'sigma_emb_dim': 16}}

In [4]:
SIGMAS = t(hparams['sigma']['values'])
EPS_TRAINNING = hparams['EPS_TRAINNING']
SCORE_NORM = np.load(os.path.join(LOGS_DIR, 'score_norm.npy'))

# Sanity checks: the sampler reads SCORE_NORM[i] for the i-th noise level and
# divides by it, and annealing goes from the largest to the smallest sigma.
assert len(SCORE_NORM) == len(SIGMAS), (
    f"score_norm.npy has {len(SCORE_NORM)} entries but there are {len(SIGMAS)} sigmas"
)
assert bool(torch.all(SIGMAS[:-1] >= SIGMAS[1:])), "SIGMAS must be non-increasing"
assert np.all(SCORE_NORM > 0), "score norms must be strictly positive"

In [5]:
# Load the trained weights and rebuild the architecture from the saved
# hyper-parameters. weights_only=True restricts torch.load to tensors and plain
# containers instead of arbitrary pickled objects.
w = torch.load(os.path.join(WEIGHTS_DIR, 'model.pt'), map_location=device, weights_only=True)
model = m.SmallUNetSigma(
    in_ch=hparams['model']['in_channel'],
    base_ch=hparams['model']['base_ch'],
    channel_mults=hparams['model']['channel_mults'],
    emb_dim=hparams['model']['sigma_emb_dim'],
).to(device)  # already on `device`; no second .to() needed after loading

model.load_state_dict(w)  # strict by default: fails on missing/unexpected keys
model.eval()  # evaluation mode: the sampler below only runs inference

SmallUNetSigma(
  (sigma_emb): SigmaEmbedding(
    (net): Sequential(
      (0): Linear(in_features=1, out_features=16, bias=True)
      (1): SiLU()
      (2): Linear(in_features=16, out_features=16, bias=True)
      (3): SiLU()
    )
  )
  (init_conv): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (down_blocks): ModuleList(
    (0): ResBlock(
      (norm1): GroupNorm(8, 16, eps=1e-05, affine=True)
      (act1): SiLU()
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm2): GroupNorm(8, 16, eps=1e-05, affine=True)
      (act2): SiLU()
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (emb_proj1): Sequential(
        (0): SiLU()
        (1): Linear(in_features=16, out_features=16, bias=True)
      )
      (emb_proj2): Sequential(
        (0): SiLU()
        (1): Linear(in_features=16, out_features=16, bias=True)
      )
      (skip): Identity()
    )
    (1): ResBlock(
      (norm1): GroupN

In [6]:
# No gradient is needed here: annealed Langevin dynamics only evaluates the score.
def make_score_from_model(model, sigma_scalar, eps_loss=False):
    """Freeze a noise level and return a score function ``x -> s(x, sigma)``.

    Args:
        model: network called as ``model(x, sigma)`` where ``sigma`` has shape
            (B, 1) and is filled with ``sigma_scalar``.
        sigma_scalar: noise level (Python float or 0-d tensor) shared by the batch.
        eps_loss: if True, the model output is divided by sigma (the variant used
            when the network was trained with the eps-style loss); otherwise the
            raw model output is returned as the score.

    Returns:
        A callable ``score(x)`` evaluated under ``torch.no_grad()``. Its output
        has the same shape as the model output.
    """
    sigma_scalar = float(sigma_scalar)

    @torch.no_grad()
    def score(x):
        # x: (B, ...) batch. new_full builds sigma on x's device with x's dtype.
        sigma = x.new_full((x.shape[0], 1), sigma_scalar)
        return model(x, sigma)

    @torch.no_grad()
    def score_eps(x):
        sigma = x.new_full((x.shape[0], 1), sigma_scalar)
        # Divide by the Python scalar, not by the (B, 1) tensor. A (B, 1) tensor
        # does not broadcast against image outputs (B, C, H, W) unless B == 1.
        return model(x, sigma) / sigma_scalar

    return score_eps if eps_loss else score

In [7]:
# One frozen score function per noise level, in the same order as SIGMAS.
estimated_distribution_scores = [
    make_score_from_model(model, sigma_value, eps_loss=EPS_TRAINNING)
    for sigma_value in SIGMAS.tolist()
]

In [8]:
# Gaussian prior N(0, sigma_max^2) on `device`, using the largest noise level.
sigma_prior = float(SIGMAS.max())
prior_normal = dist.Normal(
    torch.zeros((), device=device),
    torch.tensor(sigma_prior, device=device),
)
# Alternative uniform prior on [-1, 1] (its parameters are created on the CPU).
prior_unif = dist.Uniform(low=-1, high=1)

In [14]:
def _reset_dir(path):
    """Remove `path` and everything under it (if present), then recreate it empty."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def _emit_frame(img, title, save_path=None):
    """Draw one grayscale frame, then save it to `save_path` (and close it) or show it."""
    fig, ax = plt.subplots(figsize=(3, 3))
    ax.imshow(img, cmap='gray')
    ax.axis('off')
    ax.set_title(title)
    if save_path is None:
        plt.show()
    else:
        try:
            fig.savefig(save_path, dpi=150, bbox_inches='tight', pad_inches=0)
        finally:
            # Always release the figure: many frames are written per run.
            plt.close(fig)


def annealded_langevin_sampler_snr(prior, noisy_distrib_scores, noise_factor,
                                   SNR, norm, T, n_chain, save_dir=None,
                                   img_shape=(1, 28, 28), plot_every=100,
                                   idx_to_track=0):
    """Annealed Langevin dynamics whose step size is derived from a target SNR.

    For each noise level i, taken in the order of ``noisy_distrib_scores``, it
    runs ``T`` updates::

        X <- X + tau_i * score_i(X) + sqrt(2 * tau_i) * z,   z ~ N(0, I)
        tau_i = 2 * D * SNR / norm[i]

    where ``D`` is the number of values in one sample.

    Args:
        prior: distribution exposing ``.sample(shape)``; the initial chains are drawn from it.
        noisy_distrib_scores: list of score callables, one per noise level.
        noise_factor: noise levels indexed like the scores (used for plot titles only).
        SNR: target signal-to-noise ratio controlling the step size.
        norm: per-level normaliser indexed like the scores (presumably the score
            norms saved at training time -- TODO confirm).
        T: number of Langevin steps per noise level.
        n_chain: number of independent chains.
        save_dir: if given, the directory is wiped and recreated, frames are
            saved there as PNG files, and ``./make_ald_video.sh`` is run on it.
            Otherwise frames are shown inline.
        img_shape: (C, H, W) of one sample; the default matches 28x28 grayscale.
        plot_every: plot a frame every this many steps, plus the last step of each level.
        idx_to_track: index of the chain whose image is plotted.

    Returns:
        The final samples, a tensor of shape (n_chain, C, H, W) on `device`.
    """
    C, H, W = img_shape
    D = C * H * W  # equals H * W for the default single-channel shape
    X = prior.sample((n_chain, C, H, W)).to(device)

    if save_dir is not None:
        _reset_dir(save_dir)  # the video should only contain frames from this run

    frame = 0  # running frame index across all noise levels
    for i, score_fn in enumerate(noisy_distrib_scores):
        tau = float(2 * D * SNR / norm[i])  # step size for this noise level
        noise_std = math.sqrt(2 * tau)

        for step in range(T):
            X = X + tau * score_fn(X) + noise_std * torch.randn_like(X)

            # Track a single chain over time: every `plot_every` steps plus the last step.
            if step % plot_every == 0 or step == T - 1:
                img = X[idx_to_track].detach().squeeze().cpu()
                title = (
                    f"sigma={noise_factor[i]:.3f} | "
                    f"SNR={SNR:.2e} | "
                    f"noise_std={noise_std:.2e} | "
                    f"step={step}"
                )
                save_path = (None if save_dir is None
                             else os.path.join(save_dir, f"frame_{frame:05d}.png"))
                _emit_frame(img, title, save_path)
                frame += 1

    if save_dir is not None:
        # Quote every argument so paths containing spaces or shell metacharacters
        # reach the script verbatim instead of being re-split by the shell.
        cmd = shlex.join(["./make_ald_video.sh", save_dir, str(SNR), str(T)])
        status = os.system(cmd)
        if status != 0:
            print(f"warning: video script failed (status {status}): {cmd}")

    return X


In [10]:
# Frame output directory for the ALD video (the trailing slash is kept).
OUTDIR = os.path.join(VIDEO_DIR, 'out') + '/'

In [15]:
# Target signal-to-noise ratio for the Langevin step size. It is defined once
# and passed by name so that the value used by the sampler, the plot titles and
# the video arguments all match this variable. 0.01 is the value that was
# actually passed to the sampler in this run.
SNR = 0.01
ALD_estimated_score_snr = annealded_langevin_sampler_snr(
    prior_normal,
    estimated_distribution_scores,
    SIGMAS,
    SNR,
    SCORE_NORM,
    T=1000,
    n_chain=1,
    save_dir=OUTDIR,
)

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab