# Import Libraries

In [1]:
### Importing basic libraries
import os   
import sys
import torch
import glob
import numpy as np
import matplotlib.pyplot as plt
import healpy as hp
from tqdm.auto import tqdm
from packaging import version
import datetime
from collections import namedtuple
from functools import wraps, partial
from inspect import isfunction

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
### Importing torch libraries
import torch
import torch.nn as nn
from torch import einsum
import torch.nn.functional as F
import torch.utils.data as data
import pytorch_lightning as pl
from torchvision.transforms import Compose, ToTensor, Normalize
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger


In [3]:
### Importing needed libraries
from einops import rearrange
import pickle

# Prefix Parameters

In [4]:
pl.seed_everything(1234)

[rank: 0] Global seed set to 1234


1234

In [5]:
### training params
num_epochs = 300
batch_size = 2
learning_rate = 1e-5
learning_rate_decay = 0.99

### diffusion params
timesteps = 1000
beta_start = 0.0001
beta_end = 0.02

### data params
LR_dir = "/gpfs02/work/tanimura/ana/UNet/data/dens_magneticum_snap25_Box128_grid32_CIC_noRSD/"
HR_dir = "/gpfs02/work/tanimura/ana/UNet/data/dens_magneticum_snap25_Box128_grid128_CIC_noRSD/"
mid_dir = "/gpfs02/work/tanimura/ana/UNet/data/dens_magneticum_snap25_Box128_grid64_CIC_noRSD/"

n_maps=100
rate_train =0.8

### model parameters
save_dir = './tests/'

In [6]:
### model parameters
save_dir = '/gpfs02/work/akira.tokiwa/gpgpu/instruction/'
logger = TensorBoardLogger(save_dir=save_dir, name='diffusion')

In [7]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Diffusion Code

## Diffusion step
Code snippets ported from:
https://huggingface.co/blog/annotated-diffusion

In [8]:
### Functions for diffusion
def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

def identity(t, *args, **kwargs):
    return t

def cycle(dl):
    while True:
        for data in dl:
            yield data

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    shape = (batch_size, *((1,) * (len(x_shape) - 1)))
    return out.reshape(shape).to(t.device)

In [9]:
class Diffusion():
    def __init__(self, betas):
        #NOTE: you're choosing to make loss-type an argument of p_losses and not the diffusion model class
        #you should then save it separately when saving a run
        # calculations for diffusion q(x_t | x_0) and others
        self.betas = betas
        self.alphas = 1. - betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)  # alpha_bar
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        # x_t = sqrt_alphas_cumprod* x_0 + sqrt_one_minus_alphas_cumprod * eps_t
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.timesteps = len(self.betas)

    # forward diffusion (using the nice property)
    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, denoise_model, x_start, t, noise=None, loss_type="l1", labels=None):
        # L_CE <= L_VLB ~ Sum[eps_t - MODEL(x_t(x_0, eps_t), t) ]
        if noise is None:
            noise = torch.randn_like(x_start)

        x_t = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = denoise_model(x_t, t, labels)

        if loss_type == 'l1':
            loss = F.l1_loss(noise, predicted_noise)
        elif loss_type == 'l2':
            loss = F.mse_loss(noise, predicted_noise)
        elif loss_type == "huber":
            loss = F.smooth_l1_loss(noise, predicted_noise)
        else:
            raise NotImplementedError()
        return loss

    @torch.no_grad()
    def timewise_loss(self, denoise_model, x_start, t, noise=None, loss_type="l1", labels=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        x_t = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = denoise_model(x_t, t, labels)
        if loss_type == 'l1':
            loss = F.l1_loss(noise, predicted_noise, reduction='none')
        elif loss_type == 'l2':
            loss = F.mse_loss(noise, predicted_noise, reduction='none')
        elif loss_type == "huber":
            loss = F.smooth_l1_loss(noise, predicted_noise, reduction='none')
        else:
            raise NotImplementedError()
        loss = torch.mean(loss, dim=[-3, -2, -1]) #mean over all spatial dims
        return loss

    @torch.no_grad()
    def p_sample(self, model, x, t, t_index, label=None):
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        # Equation 11 in the paper
        # Use our model (noise predictor) to predict the mean
        model_output = model(x, t) if label is None else model(x, t, label)
        model_mean = sqrt_recip_alphas_t * (
                x - betas_t * model_output / sqrt_one_minus_alphas_cumprod_t
        )

        if t_index == 0:
            return model_mean
        else:
            posterior_variance_t = extract(self.posterior_variance, t, x.shape)
            noise = torch.randn_like(x)
            # Algorithm 2 line 4:
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, model, shape, labels=None):
        device = next(model.parameters()).device
        print('sample device', device)
        b = shape[0]
        # start from pure noise (for each example in the batch)
        img = torch.randn(shape, device=device)
        imgs = []
        if labels is not None:
            assert labels.shape[0] == shape[0]

        for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i, labels)
            imgs.append(img.cpu().numpy())
        return imgs

    @torch.no_grad()
    def sample(self, model, image_size, batch_size=16, channels=1, labels=None):
        return self.p_sample_loop(model, shape=(batch_size, image_size, channels), labels=labels)

## Diffusion scheduler
based on https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/resample.py

In [10]:
### Functions for scheduler

def linear_beta_schedule(timesteps, beta_start, beta_end):
    #beta_start = 0.0001
    #beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps, beta_start, beta_end):
    #beta_start = 0.0001
    #beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps, beta_start, beta_end):
    #beta_start = 0.0001
    #beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

def cosine_beta_schedule(timesteps, s=0.015):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0, 0.999) 

In [11]:
class TimestepSampler():
    def __init__(self, sampler_type='uniform', history=None, nstart=None, timesteps=None, uniweight=None, device='cuda'):
        self.type= sampler_type
        print('Sampler type', self.type)
        if self.type not in ['uniform', 'loss_aware']:
            raise NotImplementedError()
        if self.type=='loss_aware':
            self.sqloss_history = torch.ones((history, timesteps), device=device)*np.nan #L^2[b, t]
            self.nstart = nstart
            self.uniweight = 1/timesteps if uniweight is None else uniweight
            print('Nstart', nstart, type(nstart))
            print('History', history, type(history))
            print('Uniweight', self.uniweight, type(self.uniweight))

        self.timesteps = timesteps
        self.uniform = 1/timesteps
        self.device=device
        self.history_per_term = torch.zeros(timesteps, device=device, dtype=int)
        self.not_enough_history = True

    def get_weights(self, batch_size, iteration):
        if (iteration<self.nstart) or self.not_enough_history:
            return np.ones(self.timesteps)*self.uniform
        else:
            laweights = torch.sqrt(torch.mean(self.sqloss_history**2, dim=0))
            laweights /= laweights.sum()
            laweights *= (1-self.uniweight)
            laweights += self.uniweight/self.timesteps
            return laweights
        #fast way to evaluate / store the loss for different timesteps??
        #do you need a different sampler for storing history?

    def update_history(self, tl, loss_timewise):
        if self.not_enough_history:  # not-full loss history array
            for (t, tloss) in zip(tl, loss_timewise):
                if self.history_per_term[t] == self.sqloss_history.shape[0]:  # enough history
                    self.sqloss_history[:-1, t] = self.sqloss_history[1:, t]
                    self.sqloss_history[-1, t] = tloss
                else:
                    self.sqloss_history[self.history_per_term[t], t] = tloss
                    self.history_per_term[t] += 1
                    if self.history_per_term.min()==self.sqloss_history.shape[0]:
                        self.not_enough_history = False
                        print('Enough history for all')
        else:#enough history for all terms
            #test if this works fine
            self.sqloss_history[:-1, tl] = self.sqloss_history[1:, tl]
            self.sqloss_history[-1, tl] = loss_timewise
        return

    def get_timesteps(self, batch_size, iteration):
        if self.type=='uniform':
            return torch.randint(0, self.timesteps, (batch_size,), device=self.device).long()
        elif self.type=='loss_aware':
            weights = self.get_weights(batch_size, iteration)
            return torch.tensor(list(torch.utils.data.WeightedRandomSampler(weights, batch_size, replacement=True)), device=self.device).long()
        else:
            raise NotImplementedError()

# U-net Architecture 

In [12]:
class Upsample(nn.Module):
    def __init__(self, dim, dim_out = None, kernel_size=3):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor = 2, mode='trilinear')
        self.conv = nn.Conv3d(dim, default(dim_out, dim), kernel_size, padding=1)

    def forward(self, x):
        x = self.upsample(x)
        x = self.conv(x)
        return x

class Downsample(nn.Module):
    def __init__(self, dim, dim_out = None, kernel_size=1):
        super().__init__()
        self.conv = nn.Conv3d(dim*8, default(dim_out, dim), kernel_size)

    def forward(self, x):
        x = rearrange(x, 'b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w', p1 = 2, p2 = 2, p3 = 2)
        x = self.conv(x)
        return x

In [13]:
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)

In [14]:
AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        if self.flash:
            return self.flash_attn(q, k, v)

        scale = q.shape[-1] ** -0.5

        # similarity

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out

In [15]:
class SinusoidalPositionEmbeddings(nn.Module):
    #embeds time in the phase
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = np.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None].float() * embeddings[None, :] #t1: [40, 1], t2: [1, 32]. Works on cpu, not on mps
        #^ is matmul: torch.allclose(res, torch.matmul(t1.float(), t2)): True when cpu
        #NM: added float for mps
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings #Bx64

In [16]:
class Block(nn.Module):
    """
    Basic building block for the Unet architecture.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, num_groups=8):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding)
        self.norm = nn.GroupNorm(num_groups, out_channels)
        self.act = nn.LeakyReLU(0.1) if out_channels > 1 else nn.Identity()

    def forward(self, x, scale_shift = None):
        x = self.conv(x)
        x = self.norm(x) 

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

In [17]:
class ResnetBlock(nn.Module):
    """
    Residual block composed of two basic blocks. https://arxiv.org/abs/1512.03385
    """
    def __init__(self, in_channels, out_channels, time_emb_dim, kernel_size=3, padding=1, groups = 8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels * 2))
            if exists(time_emb_dim)
            else None
        )
        self.block1 = Block(in_channels, out_channels, kernel_size, padding, groups)
        self.block2 = Block(out_channels, out_channels, kernel_size, padding, groups) 
        self.res_conv = nn.Conv3d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()


    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)

        h = self.block2(h)

        return h + self.res_conv(x)

In [18]:
class LinearAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Conv3d(dim, hidden_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv3d(hidden_dim, dim, 1),
            RMSNorm(dim)
        )

    def forward(self, x):
        b, c, d, h, w = x.shape

        #x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z -> b h c (x y z)', h = self.heads), qkv)

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y z) -> b (h c) x y z', h = self.heads, x = d, y = h, z = w)
        return self.to_out(out)

In [19]:
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        flash = False
    ):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = RMSNorm(dim)
        self.attend = Attend(flash = flash)

        self.to_qkv = nn.Conv3d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv3d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, d, h, w = x.shape

        #x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z -> b h (x y z) c', h = self.heads), qkv)

        out = self.attend(q, k, v)

        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = d, y = h, z = w, h = self.heads)
        return self.to_out(out)

In [30]:
class Unet(pl.LightningModule):
    """
    Full Unet architecture composed of an encoder (downsampler), a bottleneck, and a decoder (upsampler).
    """
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        dim_mults = (1, 2, 4),
        channels = 3,
        self_condition = False,
        resnet_block_groups = 8,
        full_attn = (False, False, True),
        flash_attn = False
    ):
        super().__init__()

        # determine dimensions

        self.channels = channels
        self.self_condition = self_condition
        input_channels =  channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv3d(input_channels, init_dim, 7, padding = 3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # time embeddings

        time_dim = dim * 4
        fourier_dim = dim

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(fourier_dim),
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # attention

        full_attn = cast_tuple(full_attn, length = len(dim_mults))
        assert len(full_attn) == len(dim_mults)

        FullAttention = partial(Attention, flash = flash_attn)

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(in_out, full_attn)):
            is_last = ind >= (num_resolutions - 1)

            attn_klass = FullAttention if layer_full_attn else LinearAttention

            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                attn_klass(dim_in),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv3d(dim_in, dim_out, 3, padding = 1)
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = FullAttention(mid_dim)
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(reversed(in_out), reversed(full_attn))):
            is_last = ind == (len(in_out) - 1)

            attn_klass = FullAttention if layer_full_attn else LinearAttention

            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                attn_klass(dim_out),
                Upsample(dim_out, dim_in) if not is_last else  nn.Conv3d(dim_out, dim_in, 3, padding = 1)
            ]))

        default_out_dim = channels * 2
        self.out_dim = default(out_dim, default_out_dim)

        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv3d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x, x_self_cond), dim = 1)
            
        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            print("x_block1", x.shape)
            h.append(x)

            x = block2(x, t)
            print("x_block2", x.shape)
            #tmp_attn = attn(x)
            #print("tmp_attn", tmp_attn.shape)
            #x = tmp_attn + x
            #print("x_attn", x.shape)
            h.append(x)

            x = downsample(x)
            print("x_downsample", x.shape)

        x = self.mid_block1(x, t)
        print("x_mid_block1", x.shape)
        #x = self.mid_attn(x) + x
        x = self.mid_block2(x, t)
        print("x_mid_block2", x.shape)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            print("x_cat", x.shape)
            x = block1(x, t)
            print("x_block1", x.shape)

            x = torch.cat((x, h.pop()), dim = 1)
            print("x_cat", x.shape)
            x = block2(x, t)
            print("x_block2", x.shape)
            #x = attn(x) + x

            x = upsample(x)
            print("x_upsample", x.shape)

        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)

In [21]:
class Unet_pl(pl.LightningModule):
    def __init__(self, 
                channels = 1,  
                dim = 64,
                init_dim = None,
                out_dim = None,
                batch_size = 64,
                learning_rate = 1e-5,
                learning_rate_decay = 0.99,
                num_epochs = 300,
                timesteps = 1000,
                beta_start = 0.0001,
                beta_end = 0.02,
                loss_type="huber", 
                sampler=None, 
                conditional=False):
        super().__init__()
        self.model = Unet(dim=dim, channels=dim, init_dim=init_dim, out_dim=out_dim, self_condition=conditional)
        self.batch_size = batch_size
        self.lr_init = learning_rate
        self.learning_rate_decay = learning_rate_decay
        self.num_epochs = num_epochs

        betas = linear_beta_schedule(timesteps, beta_start, beta_end)
        self.diffusion = Diffusion(betas)

        self.init_conv = nn.Conv3d(channels, dim, 7, padding = 3)
        self.init_conv_condition = nn.Sequential(
            Upsample(channels, dim),
            Upsample(dim, dim)
        )

        self.loss_type = loss_type
        self.sampler = sampler
        self.conditional = conditional
        self.loss_spike_flg = 0

    def training_step(self, batch, batch_idx):
        if self.conditional:
            hr, lr = batch 
            hr = self.init_conv(hr)
            lr = self.init_conv_condition(lr)
            x = hr -lr
            labels = lr
        else:
            x = batch

        t = self.sampler.get_timesteps(x.shape[0], self.current_epoch)
        loss = self.diffusion.p_losses(self.model, x, t, loss_type=self.loss_type, labels=labels if self.conditional else None)
        self.log('train_loss', loss)

        if self.sampler.type == 'loss_aware':
            loss_timewise = self.diffusion.timewise_loss(self.model, x, t, loss_type=self.loss_type, labels=labels if self.conditional else None)
            self.sampler.update_history(t, loss_timewise)

        if loss.item() > 0.1 and self.current_epoch > 300 and (self.loss_spike_flg < 2):
            badbdict = {'batch': batch.detach().cpu().numpy(), 'itn': self.current_epoch, 't': t.detach().cpu().numpy(), 'loss': loss.item()}
            pickle.dump(badbdict, open(f'largeloss_{self.current_epoch}.pkl', 'wb'))
            self.loss_spike_flg += 1
        return loss
    
    def validation_step(self, batch, batch_idx):
        if self.conditional:
            hr, lr = batch 
            hr = self.init_conv(hr)
            lr = self.init_conv_condition(lr)
            x = hr - lr
            labels = lr
        else:
            x = batch

        t = self.sampler.get_timesteps(x.shape[0], self.current_epoch)
        loss = self.diffusion.p_losses(self.model, x, t, loss_type=self.loss_type, labels=labels if self.conditional else None)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr_init)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=self.learning_rate_decay)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}  

# prepare data

## Functions

In [8]:
def get_minmax_transform(rangemin, rangemax):
    """
    Function to get a pair of transforms that normalize and denormalize tensors.
    """
    transform = Compose([lambda t: (t - rangemin) / (rangemax - rangemin) * 2 - 1])
    inverse_transform = Compose([lambda t: (t + 1) / 2 * (rangemax - rangemin) + rangemin])
    
    return transform, inverse_transform

In [9]:
class MapDataset(data.Dataset):
    """
    Class for the map dataset.
    """
    def __init__(self, mapdir, n_maps=100):
        self.maps = sorted(glob.glob(f'{mapdir}*.npy'))[:n_maps]

    def __getitem__(self, index):
        dmaps = np.array([np.load(map) for map in self.maps])  
        tensor_map = torch.from_numpy(dmaps).float()
        
        return tensor_map

## Data

In [10]:
data_hr = MapDataset(HR_dir, n_maps=10).__getitem__(0)
data_lr = MapDataset(LR_dir, n_maps=10).__getitem__(0)

In [11]:
RANGE_MIN, RANGE_MAX = data_hr.min().clone().detach(), data_hr.max().clone().detach()
transforms, inverse_transforms = get_minmax_transform(RANGE_MIN, RANGE_MAX)
combined_dataset = data.TensorDataset(transforms(data_hr).unsqueeze(1), transforms(data_lr).unsqueeze(1))

In [13]:
len_train = int(rate_train * len(data_lr))
len_val = len(data_lr) - len_train
train, val = data.random_split(combined_dataset, [len_train, len_val])
loaders = {x: data.DataLoader(ds, batch_size=batch_size, shuffle=x=='train', num_workers=os.cpu_count()) for x, ds in zip(('train', 'val'), (train, val))}

train_loader, val_loader= loaders['train'], loaders['val']

# Train Model

In [27]:
### Functions for training
def setup_trainer(num_epochs, logger=None, patience=10):
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        patience=patience,
        verbose=0,
        mode="min"
    )

    dt = datetime.datetime.now()
    name = dt.strftime('Run_%m-%d_%H-%M')

    checkpoint_callback = ModelCheckpoint(
        filename= name + "{epoch:02d}-{val_loss:.2f}",
        save_top_k=1,
        monitor="val_loss",
        save_last=True,
        mode="min"
    )

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        callbacks=[checkpoint_callback, early_stop_callback],
        num_sanity_val_steps=0,
        accelerator='gpu', 
        devices=1,
        logger=logger
    )
    return trainer

In [29]:
model = Unet_pl(dim = 64,
                init_dim = 64,
                out_dim = 1,
                batch_size = 2,
                learning_rate = 1e-5,
                learning_rate_decay = 0.99,
                num_epochs = 300,
                timesteps = 1000,
                beta_start = 0.0001,
                beta_end = 0.02,
                loss_type="huber", 
                sampler=sampler, 
                conditional=True).to(device)

In [31]:
trainer = setup_trainer(num_epochs, logger=logger)
trainer.fit(model, train_loader, val_loader)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]

  | Name                | Type       | Params
---------------------------------------------------
0 | model               | Unet       | 28.6 M
1 | init_conv           | Conv3d     | 22.0 K
2 | init_conv_condition | Sequential | 112 K 
---------------------------------------------------
28.8 M    Trainable params
0         Non-trainable params
28.8 M    Total params
115.045   Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|          | 0/5 [00:00<?, ?it/s] x_block1 torch.Size([2, 64, 128, 128, 128])
x_block2 torch.Size([2, 64, 128, 128, 128])


RuntimeError: Given groups=1, weight of size [64, 256, 1, 1, 1], expected input[2, 512, 64, 64, 64] to have 256 channels, but got 512 channels instead

# Load Model

In [None]:
map_dir = '/gpfs02/work/akira.tokiwa/gpgpu/instruction/maps/'
if not os.path.exists(map_dir):
    os.makedirs(map_dir)

In [None]:
ckpt_path = trainer.checkpoint_callback.best_model_path

In [None]:
model.load_state_dict(torch.load(ckpt_path)['state_dict'], strict=False)

_IncompatibleKeys(missing_keys=['model.init_conv.laplacian', 'model.init_conv_lr.laplacian', 'model.down_blocks.0.0.block1.conv.laplacian', 'model.down_blocks.0.0.block2.conv.laplacian', 'model.down_blocks.0.1.block1.conv.laplacian', 'model.down_blocks.0.1.block2.conv.laplacian', 'model.down_blocks.1.0.block1.conv.laplacian', 'model.down_blocks.1.0.block2.conv.laplacian', 'model.down_blocks.1.0.res_conv.laplacian', 'model.down_blocks.1.1.block1.conv.laplacian', 'model.down_blocks.1.1.block2.conv.laplacian', 'model.down_blocks.2.0.block1.conv.laplacian', 'model.down_blocks.2.0.block2.conv.laplacian', 'model.down_blocks.2.0.res_conv.laplacian', 'model.down_blocks.2.1.block1.conv.laplacian', 'model.down_blocks.2.1.block2.conv.laplacian', 'model.down_blocks.3.0.block1.conv.laplacian', 'model.down_blocks.3.0.block2.conv.laplacian', 'model.down_blocks.3.0.res_conv.laplacian', 'model.down_blocks.3.1.block1.conv.laplacian', 'model.down_blocks.3.1.block2.conv.laplacian', 'model.mid_block1.block

In [None]:
model = model.to(device)