In [2]:
from abc import ABC, abstractmethod
import math
import os
import torch
from torch import utils
from torch import nn
from torch import distributions
from torch import optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToPILImage

In [None]:
# Root directory for datasets. Set the DATA_DIR environment variable to point
# elsewhere instead of editing this machine-specific default.
data_dir = os.environ.get('DATA_DIR', '/Users/armandli/data/')

In [None]:
# Pick the compute device: CUDA if present, then Apple MPS, otherwise CPU.
use_cuda = torch.cuda.is_available()
# is_available() checks that an MPS device can actually be used on this machine;
# is_built() only reports whether this PyTorch build was compiled with MPS support.
use_mps = torch.backends.mps.is_available()
if use_cuda:
    device = torch.device('cuda')
elif use_mps:
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# explicit CPU handle, e.g. for moving results back off the accelerator
cpu = torch.device('cpu')

In [None]:
# Keyword-argument sets, presumably for torch.utils.data.DataLoader:
# `loader_args` shuffles (training), `score_args` keeps a fixed order (scoring).
default_batch_size = 256
loader_args = dict(batch_size=default_batch_size, shuffle=True)
score_args = dict(batch_size=default_batch_size, shuffle=False)
if use_cuda:
    # pinned host memory is only useful when copying batches to a CUDA device
    loader_args['pin_memory'] = score_args['pin_memory'] = True

In [None]:
class Reporter(ABC):
    """Abstract sink for typed metric entries.

    Subclasses decide how entries are stored; ``report`` records one entry and
    ``reset`` discards everything recorded so far.
    """

    @abstractmethod
    def report(self, typ, **metric):
        """Record one entry of kind ``typ`` with keyword metrics ``metric``."""

    @abstractmethod
    def reset(self):
        """Forget all previously reported entries."""

In [None]:
class SReporter(Reporter):
    """In-memory Reporter that keeps every report as a ``(typ, data)`` tuple.

    Query helpers look only at entries of one type, in the order they were
    reported. A non-negative ``idx`` counts from the oldest entry; a negative
    ``idx`` counts back from the newest (``-1`` is the latest), like list
    indexing. The ``loss`` helpers return the full list of losses when ``idx``
    is omitted.
    """
    def __init__(self):
        self.log = []  # (typ, data) tuples in report order

    def report(self, typ, **data):
        """Append one entry of kind ``typ`` whose payload is the keyword arguments."""
        self.log.append((typ, data))

    def reset(self):
        """Drop every recorded entry."""
        self.log.clear()

    def _records(self, t):
        # payload dicts of all entries of type t, oldest first
        return [data for (typ, data) in self.log if typ == t]

    def loss(self, t, idx=None):
        """Return recorded ``'loss'`` values for entries of type ``t``.

        Args:
            t: entry type, e.g. ``'train'`` or ``'eval'``.
            idx: ``None`` for every loss as a list (oldest first); otherwise the
                position of a single entry (negative counts from the newest).

        Returns:
            The list of losses when ``idx`` is None; otherwise the selected
            entry's loss, or ``float('inf')`` if no entry exists at ``idx``.

        Raises:
            KeyError: if a selected entry has no ``'loss'`` field.
        """
        # A single method with an optional idx: two same-named definitions
        # would silently shadow each other.
        records = self._records(t)
        if idx is None:
            return [data['loss'] for data in records]
        if -len(records) <= idx < len(records):
            return records[idx]['loss']
        return float("inf")

    def eval_loss(self, idx=None):
        """``loss('eval', idx)``: all eval losses, or the one at ``idx``."""
        return self.loss('eval', idx)

    def train_loss(self, idx=None):
        """``loss('train', idx)``: all train losses, or the one at ``idx``."""
        return self.loss('train', idx)

    def get_record(self, t, idx):
        """Return the payload dict of the ``idx``-th entry of type ``t``, or an empty dict if out of range."""
        records = self._records(t)
        if -len(records) <= idx < len(records):
            return records[idx]
        return dict()

    def eval_record(self, idx):
        """Payload of the ``idx``-th ``'eval'`` entry (empty dict if none)."""
        return self.get_record('eval', idx)

    def train_record(self, idx):
        """Payload of the ``idx``-th ``'train'`` entry (empty dict if none)."""
        return self.get_record('train', idx)

Models

In [None]:
class GaussianDistributionV1:
    """Diagonal Gaussian described by one tensor of stacked parameters.

    ``parameters`` is split into two equal halves along dim 1: the first half
    is the mean, the second the log-variance. The log-variance is clamped to
    [-30, 20] before the standard deviation exp(0.5 * log_var) is derived.
    """
    def __init__(self, parameters):
        mean, raw_log_var = parameters.chunk(2, dim=1)
        self.mean = mean
        self.log_var = raw_log_var.clamp(-30., 20.)
        self.std = (0.5 * self.log_var).exp()

    def sample(self):
        """Draw one reparameterised sample ``mean + std * eps`` with ``eps ~ N(0, I)``."""
        eps = torch.randn_like(self.std)
        return self.mean + self.std * eps

In [None]:
class CrossAttentionV1(nn.Module):
    """Multi-head attention whose keys/values come from ``cond`` (or from ``x``).

    Args:
        d_model: feature size of the query tokens ``x`` and of the output.
        d_cond: feature size of ``cond``. Must equal ``d_model`` when
            ``forward`` is called without ``cond`` (self-attention).
        n_heads: number of attention heads.
        d_head: feature size per head.
        is_inplace: if True, the softmax runs on the two halves of the batch
            separately and is written back into the score tensor. This is
            numerically the same as a single softmax over the last dim.
    """
    def __init__(self, d_model, d_cond, n_heads, d_head, is_inplace=True):
        super(CrossAttentionV1, self).__init__()
        inner_dim = n_heads * d_head
        self.is_inplace = is_inplace
        self.n_heads = n_heads
        self.scale = d_head ** -0.5  # 1/sqrt(d_head) dot-product scaling
        self.to_q = nn.Linear(d_model, inner_dim, bias=False)
        self.to_k = nn.Linear(d_cond, inner_dim, bias=False)
        self.to_v = nn.Linear(d_cond, inner_dim, bias=False)
        # single-layer Sequential keeps parameter names as `to_out.0.*`
        self.to_out = nn.Sequential(nn.Linear(inner_dim, d_model))

    def forward(self, x, cond=None):
        """Attend from ``x`` to ``cond`` (or to ``x`` itself when ``cond`` is None)."""
        context = x if cond is None else cond
        return self.attention(self.to_q(x), self.to_k(context), self.to_v(context))

    def _split_heads(self, t):
        # (batch, seq, heads * d_head) -> (batch, seq, heads, d_head)
        return t.view(*t.shape[:2], self.n_heads, -1)

    def attention(self, q, k, v):
        """Scaled dot-product attention on already-projected q/k/v, followed by ``to_out``."""
        q, k, v = (self._split_heads(t) for t in (q, k, v))
        # scores[b, h, i, j]: query position i vs key position j in head h
        scores = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
        if not self.is_inplace:
            scores = scores.softmax(dim=-1)
        else:
            mid = scores.shape[0] // 2
            scores[mid:] = scores[mid:].softmax(dim=-1)
            scores[:mid] = scores[:mid].softmax(dim=-1)
        mixed = torch.einsum('bhij,bjhd->bihd', scores, v)
        # merge the heads back: (batch, seq, heads * d_head)
        return self.to_out(mixed.reshape(*mixed.shape[:2], -1))

In [None]:
class GeGLUV1(nn.Module):
    """GELU-gated linear unit (GeGLU).

    A single linear layer produces ``2 * d_out`` features. They are split into
    a value half and a gate half, and the result is ``value * gelu(gate)``.
    """
    def __init__(self, d_in, d_out):
        super(GeGLUV1, self).__init__()
        self.proj = nn.Linear(d_in, d_out * 2)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return F.gelu(gate) * value

In [None]:
class FeedForwardV1(nn.Module):
    """Position-wise feed-forward net.

    A GeGLU expands the features to ``d_model * d_mult``; a linear layer
    projects them back to ``d_model``.
    """
    def __init__(self, d_model, d_mult=4):
        super(FeedForwardV1, self).__init__()
        hidden = d_model * d_mult
        stages = [
            GeGLUV1(d_model, hidden),
            # p=0 makes this a no-op; it is kept so parameter names stay net.0 / net.2
            nn.Dropout(0.),
            nn.Linear(hidden, d_model),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

In [None]:
class BasicTransformerBlockV1(nn.Module):
    """Pre-norm transformer block with three residual sub-layers.

    The sub-layers are self-attention, cross-attention on ``cond``, and a GeGLU
    feed-forward.

    Args:
        d_model: token feature size.
        n_heads: attention heads in both attention layers.
        d_head: per-head feature size.
        d_cond: feature size of the conditioning sequence ``cond``.
    """
    def __init__(self, d_model, n_heads, d_head, d_cond):
        super(BasicTransformerBlockV1, self).__init__()
        # attn1 is called without cond, so it attends over the tokens themselves
        self.attn1 = CrossAttentionV1(d_model, d_model, n_heads, d_head)
        self.norm1 = nn.LayerNorm(d_model)
        # attn2 takes keys/values from cond
        self.attn2 = CrossAttentionV1(d_model, d_cond, n_heads, d_head)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = FeedForwardV1(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, cond):
        x = x + self.attn1(self.norm1(x))
        x = x + self.attn2(self.norm2(x), cond=cond)
        return x + self.ff(self.norm3(x))

In [None]:
class SpatialTransformerV1(nn.Module):
    """Runs transformer blocks over the spatial positions of a feature map.

    The map (b, c, h, w) is normalised, projected with a 1x1 conv and flattened
    to h*w tokens of size c. The tokens pass through ``n_layers``
    BasicTransformerBlockV1 (cross-attending to ``cond``), are reshaped back,
    projected again, and added to the input.
    ``channels`` must be divisible by 32 (GroupNorm with 32 groups).
    """
    def __init__(self, channels, n_heads, n_layers, d_cond):
        super(SpatialTransformerV1, self).__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        head_dim = channels // n_heads
        self.transformer_blocks = nn.ModuleList(
            BasicTransformerBlockV1(channels, n_heads, head_dim, d_cond=d_cond)
            for _ in range(n_layers)
        )
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, cond):
        residual = x
        batch, chans, height, width = x.shape
        tokens = self.proj_in(self.norm(x))
        # (b, c, h, w) -> (b, h*w, c): one token per spatial position
        tokens = tokens.permute(0, 2, 3, 1).view(batch, height * width, chans)
        for blk in self.transformer_blocks:
            tokens = blk(tokens, cond)
        # (b, h*w, c) -> (b, c, h, w)
        fmap = tokens.view(batch, height, width, chans).permute(0, 3, 1, 2)
        return self.proj_out(fmap) + residual

In [None]:
class TimeEmbeddingV1(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer MLP.

    Args:
        n_channels: output embedding size. The sinusoidal part has
            ``n_channels // 4`` features (``n_channels // 8`` sines followed by
            as many cosines), so ``n_channels`` should be a multiple of 8.
        max_period: the sinusoid frequencies decay geometrically from 1 down
            to ``1 / max_period``.
    """
    def __init__(self, n_channels, max_period=10_000):
        super(TimeEmbeddingV1, self).__init__()
        self.n_channels = n_channels
        self.max_period = max_period
        self.layers = nn.Sequential(
            nn.Linear(self.n_channels // 4, self.n_channels),
            nn.SiLU(inplace=True),
            nn.Linear(self.n_channels, self.n_channels),
        )

    def forward(self, t):
        """Embed a 1-D tensor of timesteps ``t`` into shape (len(t), n_channels)."""
        half_dim = self.n_channels // 8
        # freq_i = exp(-i * log(max_period) / (half_dim - 1)), i = 0 .. half_dim-1.
        # The step must be *divided* by (half_dim - 1); multiplying collapses every
        # frequency after the first to ~0. max(..., 1) guards half_dim == 1.
        step = math.log(self.max_period) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -step)
        args = t[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=1)
        return self.layers(emb)

In [None]:
class TimeEmbeddingV2(nn.Sequential):
    """``nn.Sequential`` that forwards extra inputs to the children that need them.

    Despite its name this is a container, not an embedding. Each child is
    called according to its type:
      * ResidualBlockV2 (defined in a later cell of this notebook) -> ``child(x, t_emb)``
      * SpatialTransformerV1                                     -> ``child(x, cond)``
      * anything else                                            -> ``child(x)``
    """
    def forward(self, x, t_emb, cond=None):
        for module in self:
            if isinstance(module, ResidualBlockV2):
                extra = (t_emb,)
            elif isinstance(module, SpatialTransformerV1):
                extra = (cond,)
            else:
                extra = ()
            x = module(x, *extra)
        return x

In [None]:
# GroupNorm that always normalises in float32 and casts the result back to the
# input dtype -- presumably so that reduced-precision activations (e.g. float16)
# get their mean/variance statistics computed at full precision.
class GroupNorm32(nn.GroupNorm):
    """``nn.GroupNorm`` evaluated on ``x.float()``, returned in ``x``'s original dtype."""
    def forward(self, x):
        in_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(in_dtype)

In [None]:
class ResidualBlockV1(nn.Module):
    """Two 3x3 conv layers with a per-channel time-embedding bias and a residual skip.

    Args:
        in_channels: input channels (must be divisible by ``n_groups``).
        out_channels: output channels (must be divisible by ``n_groups``).
        time_channels: size of the time-embedding vector ``t``.
        n_groups: GroupNorm groups. The default lets the 3-argument calls in
            DownBlockV1 / UpBlockV1 / MiddleBlockV1 work.
        dropout: dropout probability before the second conv.
    """
    def __init__(self, in_channels, out_channels, time_channels, n_groups=32, dropout=0.1):
        super(ResidualBlockV1, self).__init__()
        self.layer1 = nn.Sequential(
            nn.GroupNorm(n_groups, in_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)),
        )
        self.layer2 = nn.Sequential(
            nn.GroupNorm(n_groups, out_channels,),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)),
        )
        self.t_layer = nn.Sequential(
            # Must NOT be in-place: `t` is the caller's time embedding, which the
            # UNet passes to every block; an in-place SiLU would alter it for
            # all later blocks (and for the caller).
            nn.SiLU(),
            nn.Linear(time_channels, out_channels),
        )
        # 1x1 projection on the skip path only when the channel count changes
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.skip = nn.Identity()

    def forward(self, x, t):
        """``x``: (B, in_channels, H, W); ``t``: (B, time_channels). Returns (B, out_channels, H, W)."""
        h = self.layer1(x)
        # broadcast the projected time embedding over the spatial dims
        h = h + self.t_layer(t)[:, :, None, None]
        h = self.layer2(h)
        return h + self.skip(x)

In [None]:
# Same structure as ResidualBlockV1 but without a time-embedding input;
# EncoderV1 / DecoderV1 build their resolution levels from it.
class ResidualBlockV3(nn.Module):
    """Two GroupNorm -> SiLU -> 3x3 conv stages plus a residual skip.

    Args:
        in_channels: input channels (divisible by ``n_groups``).
        out_channels: output channels (divisible by ``n_groups``).
        n_groups: GroupNorm groups.
        dropout: dropout probability before the second conv.
    """
    def __init__(self, in_channels, out_channels, n_groups, dropout):
        super(ResidualBlockV3, self).__init__()
        self.layer1 = nn.Sequential(
            nn.GroupNorm(n_groups, in_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        )
        self.layer2 = nn.Sequential(
            nn.GroupNorm(n_groups, out_channels),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        # identity skip when widths match, otherwise a 1x1 projection
        self.skip = (nn.Identity() if in_channels == out_channels
                     else nn.Conv2d(in_channels, out_channels, kernel_size=1))

    def forward(self, x):
        shortcut = self.skip(x)
        return self.layer2(self.layer1(x)) + shortcut

In [None]:
class AttentionBlockV1(nn.Module):
    """Multi-head self-attention over spatial positions, with a residual connection.

    Args:
        n_channels: channels of the input feature map.
        n_heads: attention heads. Defaults to 1 so the one-argument calls in
            DownBlockV1 / UpBlockV1 / MiddleBlockV1 work.
        d_k: per-head feature size (defaults to ``n_channels``).
        n_groups: accepted for signature compatibility; no normalisation layer
            is built or applied.
    """
    def __init__(self, n_channels, n_heads=1, d_k=None, n_groups=32):
        super(AttentionBlockV1, self).__init__()
        if d_k is None:
            d_k = n_channels
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        self.output = nn.Linear(n_heads * d_k, n_channels)
        self.scale = d_k ** -0.5
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x, t=None):
        """``x``: (B, C, H, W) -> same shape. ``t`` is accepted but unused."""
        batch_size, n_channels, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position
        x = x.reshape(batch_size, n_channels, -1).permute(0, 2, 1)
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # attn[b, i, j, h]: weight of key position j for query position i in head h
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        attn = attn.softmax(dim=2)
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # reshape (not view): the einsum result is not guaranteed to be contiguous
        res = res.reshape(batch_size, -1, self.n_heads * self.d_k)
        res = self.output(res)
        res = res + x
        return res.permute(0, 2, 1).reshape(batch_size, n_channels, height, width)

In [None]:
class AttentionBlockV2(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    q/k/v and the output projection are 1x1 convolutions over the normalised
    input; every spatial position attends to every other. ``channels`` must
    be divisible by 32 (GroupNorm with 32 groups).
    """
    def __init__(self, channels):
        super(AttentionBlockV2, self).__init__()
        self.norm = nn.GroupNorm(32, channels)
        self.q = nn.Conv2d(channels, channels, 1)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)
        self.proj_out = nn.Conv2d(channels, channels, 1)
        self.scale = channels ** -0.5

    def forward(self, x):
        """``x``: (B, C, H, W) -> same shape, ``x + attention(norm(x))``."""
        # Keep the normalised tensor separate: the residual must add the block's
        # original input, not its normalised copy.
        x_norm = self.norm(x)
        q = self.q(x_norm)
        k = self.k(x_norm)
        v = self.v(x_norm)
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        k = k.reshape(b, c, h * w)
        v = v.reshape(b, c, h * w)
        # attn[b, i, j]: weight of position j for query position i
        attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
        attn = F.softmax(attn, dim=2)
        out = torch.einsum('bij,bcj->bci', attn, v)
        out = out.reshape(b, c, h, w)
        out = self.proj_out(out)
        return x + out

In [None]:
class DownBlockV1(nn.Module):
    """One down-path stage of UNetV1: ResidualBlockV1, optionally followed by AttentionBlockV1.

    Args:
        in_channels / out_channels: channel counts in and out of the residual block.
        time_channels: size of the time embedding ``t``.
        has_attn: add self-attention after the residual block.
        n_groups, dropout: forwarded to ResidualBlockV1.
        n_heads: forwarded to AttentionBlockV1.
    """
    def __init__(self, in_channels, out_channels, time_channels, has_attn,
                 n_groups=32, dropout=0.1, n_heads=1):
        super(DownBlockV1, self).__init__()
        # pass every argument explicitly: ResidualBlockV1 / AttentionBlockV1
        # declare n_groups, dropout and n_heads as required parameters
        self.res = ResidualBlockV1(in_channels, out_channels, time_channels, n_groups, dropout)
        if has_attn:
            self.attn = AttentionBlockV1(out_channels, n_heads)
        else:
            self.attn = nn.Identity()

    def forward(self, x, t):
        x = self.res(x, t)
        x = self.attn(x)
        return x

In [None]:
class DownSampleV1(nn.Module):
    """Halve the spatial size with an unpadded 3x3, stride-2 convolution.

    Before the conv, the input is zero-padded by one column on the right and
    one row at the bottom only (asymmetric padding), so even H/W map to H/2, W/2.

    NOTE(review): a later cell defines another class named ``DownSampleV1``
    (a 3x3 stride-2 conv with padding 1). On Restart & Run All that later
    definition replaces this one, so models never use this version. Consider
    renaming one of them.
    """
    def __init__(self, channels):
        super(DownSampleV1, self).__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)

    def forward(self, x, t=None):
        """Downsample ``x``.

        ``t`` is optional and ignored, so both the ``m(x)`` (EncoderV1) and
        ``m(x, t)`` (UNetV1) call forms work.
        """
        x = F.pad(x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(x)

In [None]:
class UpBlockV1(nn.Module):
    """One up-path stage of UNetV1: ResidualBlockV1 on [x, skip], optionally followed by AttentionBlockV1.

    The residual block takes ``in_channels + out_channels`` inputs because
    UNetV1 concatenates a skip connection onto ``x`` before calling this block.

    Args:
        in_channels / out_channels: see above.
        time_channels: size of the time embedding ``t``.
        has_attn: add self-attention after the residual block.
        n_groups, dropout: forwarded to ResidualBlockV1.
        n_heads: forwarded to AttentionBlockV1.
    """
    def __init__(self, in_channels, out_channels, time_channels, has_attn,
                 n_groups=32, dropout=0.1, n_heads=1):
        super(UpBlockV1, self).__init__()
        # pass every argument explicitly: ResidualBlockV1 / AttentionBlockV1
        # declare n_groups, dropout and n_heads as required parameters
        self.res = ResidualBlockV1(in_channels + out_channels, out_channels, time_channels, n_groups, dropout)
        if has_attn:
            self.attn = AttentionBlockV1(out_channels, n_heads)
        else:
            self.attn = nn.Identity()

    def forward(self, x, t):
        x = self.res(x, t)
        x = self.attn(x)
        return x

In [None]:
class MiddleBlockV1(nn.Module):
    """Bottleneck of UNetV1: residual block -> self-attention -> residual block, at constant width.

    Args:
        n_channels: feature-map channels (unchanged through the block).
        time_channels: size of the time embedding ``t``.
        n_groups, dropout: forwarded to both ResidualBlockV1.
        n_heads: forwarded to AttentionBlockV1.
    """
    def __init__(self, n_channels, time_channels, n_groups=32, dropout=0.1, n_heads=1):
        super(MiddleBlockV1, self).__init__()
        # pass every argument explicitly: ResidualBlockV1 / AttentionBlockV1
        # declare n_groups, dropout and n_heads as required parameters
        self.res1 = ResidualBlockV1(n_channels, n_channels, time_channels, n_groups, dropout)
        self.attn = AttentionBlockV1(n_channels, n_heads)
        self.res2 = ResidualBlockV1(n_channels, n_channels, time_channels, n_groups, dropout)

    def forward(self, x, t):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x

In [None]:
class UpSampleV1(nn.Module):
    """2x spatial upsampling with a learned 4x4 transposed convolution (stride 2,
    padding 1); the channel count is unchanged."""
    def __init__(self, n_channels):
        super(UpSampleV1, self).__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels,
                                       kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))

    def forward(self, x, t):
        # t is ignored; it exists because UNetV1 calls UpSampleV1 as m(x, t)
        upsampled = self.conv(x)
        return upsampled

In [None]:
class UpSampleV2(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation, then apply a
    3x3 convolution (padding 1) that keeps the channel count."""
    def __init__(self, channels):
        super(UpSampleV2, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(enlarged)

In [None]:
class DownSampleV1(nn.Module):
    """Halve the spatial size with a 3x3, stride-2, padding-1 convolution.

    NOTE(review): this cell redefines a ``DownSampleV1`` declared in an earlier
    cell and replaces it on Run All. Callers in this notebook use both
    ``m(x, t)`` (UNetV1) and ``m(x)`` (EncoderV1, and UNetV2 through
    TimeEmbeddingV2), so ``t`` is optional and ignored.
    """
    def __init__(self, n_channels):
        super(DownSampleV1, self).__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x, t=None):
        return self.conv(x)

In [None]:
class ResidualBlockV2(nn.Module):
    """Residual conv block conditioned on a time embedding (building block of UNetV2).

    Args:
        channels: input channels (divisible by 32, GroupNorm32 with 32 groups).
        d_t_emb: size of the time-embedding vector ``t_emb``.
        out_channels: output channels; defaults to ``channels``.
    """
    def __init__(self, channels, d_t_emb, out_channels=None):
        super(ResidualBlockV2, self).__init__()
        if out_channels is None:
            out_channels = channels
        self.in_layers = nn.Sequential(
            GroupNorm32(32, channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )
        self.emb_layers = nn.Sequential(
            # Not in-place: UNetV2 passes the same t_emb tensor to every block, so
            # an in-place activation here would alter it for all later blocks.
            nn.SiLU(),
            nn.Linear(d_t_emb, out_channels),
        )
        self.out_layers = nn.Sequential(
            GroupNorm32(32, out_channels),
            nn.SiLU(inplace=True),
            nn.Dropout(0.),  # was `nn.Groupout`, which does not exist in torch.nn
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )
        if out_channels == channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(channels, out_channels, 1)

    def forward(self, x, t_emb):
        """``x``: (B, channels, H, W); ``t_emb``: (B, d_t_emb). Returns (B, out_channels, H, W)."""
        h = self.in_layers(x)
        t_emb = self.emb_layers(t_emb).type(h.dtype)
        # broadcast the per-channel time bias over the spatial dims
        h = h + t_emb[:, :, None, None]
        h = self.out_layers(h)
        return self.skip(x) + h

In [None]:
class EncoderV1(nn.Module):
    """Convolutional encoder: image -> ``2 * z_channels`` feature maps.

    Each resolution level runs ``n_resnet_blocks`` ResidualBlockV3 and then
    downsamples (except after the last level). A middle section of
    residual / attention / residual blocks follows, then GroupNorm, SiLU and a
    3x3 conv to ``2 * z_channels`` outputs.

    Args:
        channels: width after the input conv.
        channel_multipliers: per-level multipliers of ``channels`` (list or tuple).
        n_resnet_blocks: residual blocks per level.
        in_channels: image channels.
        z_channels: half the number of output channels.

    NOTE(review): ``forward`` calls ``down.downsample(x)`` with one argument,
    so the active ``DownSampleV1`` definition must accept ``x`` alone.
    """
    def __init__(self, channels, channel_multipliers, n_resnet_blocks, in_channels, z_channels):
        super(EncoderV1, self).__init__()
        n_resolutions = len(channel_multipliers)
        # list(...) so tuple multipliers work too; entry 0 is the width after conv_in
        channels_list = [m * channels for m in [1] + list(channel_multipliers)]
        self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)
        self.down = nn.ModuleList()
        for i in range(n_resolutions):
            resnet_blocks = nn.ModuleList()
            for _ in range(n_resnet_blocks):
                resnet_blocks.append(ResidualBlockV3(channels, channels_list[i + 1], 32, 0.0))
                channels = channels_list[i + 1]
            down = nn.Module()
            down.block = resnet_blocks
            # spatial downsampling between levels, not after the last one
            if i != n_resolutions - 1:
                down.downsample = DownSampleV1(channels)
            else:
                down.downsample = nn.Identity()
            self.down.append(down)
        self.mid = nn.Sequential(
            ResidualBlockV3(channels, channels, 32, 0.0),
            AttentionBlockV2(channels),
            ResidualBlockV3(channels, channels, 32, 0.0),
            nn.GroupNorm(32, channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1),
        )

    def forward(self, x):
        x = self.conv_in(x)
        for down in self.down:
            for block in down.block:
                x = block(x)
            x = down.downsample(x)
        x = self.mid(x)
        return x
        

In [None]:
class DecoderV1(nn.Module):
    """Convolutional decoder: latent ``z`` (``z_channels`` maps) -> image.

    A 3x3 conv and a residual / attention / residual middle section run at the
    deepest width. Then, from the deepest level to the shallowest, each level
    runs ``n_resnet_blocks + 1`` ResidualBlockV3 and upsamples (except at
    level 0). GroupNorm, SiLU and a 3x3 conv produce ``out_channels`` maps.

    All arguments are keyword-only:
        channels: base width; level ``i`` uses ``channels * channel_multipliers[i]``.
        channel_multipliers: per-level multipliers (non-empty).
        n_resnet_blocks: residual blocks per level, minus one.
        out_channels: image channels.
        z_channels: latent channels.
    """
    def __init__(self, *, channels, channel_multipliers, n_resnet_blocks, out_channels, z_channels):
        super(DecoderV1, self).__init__()
        num_resolutions = len(channel_multipliers)
        channels_list = [m * channels for m in channel_multipliers]
        # Start at the deepest (widest) level, mirroring EncoderV1, whose last
        # level ends at channels * channel_multipliers[-1].
        channels = channels_list[-1]
        self.mid = nn.Sequential(
            nn.Conv2d(z_channels, channels, 3, stride=1, padding=1),
            ResidualBlockV3(channels, channels, 32, 0.0),
            AttentionBlockV2(channels),
            ResidualBlockV3(channels, channels, 32, 0.0),
        )
        self.up = nn.ModuleList()
        for i in reversed(range(num_resolutions)):
            resnet_blocks = nn.ModuleList()
            for _ in range(n_resnet_blocks + 1):
                resnet_blocks.append(ResidualBlockV3(channels, channels_list[i], 32, 0.0))
                channels = channels_list[i]
            up = nn.Module()
            up.block = resnet_blocks
            if i != 0:
                up.upsample = UpSampleV2(channels)
            else:
                up.upsample = nn.Identity()
            # insert at the front so that self.up[i] holds level i
            self.up.insert(0, up)
        self.out = nn.Sequential(  # was `nn.Sequentail` (AttributeError)
            nn.GroupNorm(32, channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(channels, out_channels, 3, stride=1, padding=1),
        )

    def forward(self, z):
        h = self.mid(z)
        # deepest level first
        for up in reversed(self.up):
            for block in up.block:
                h = block(h)
            h = up.upsample(h)
        x = self.out(h)
        return x

In [None]:
class AutoEncoderV1(nn.Module):
    """Encoder/decoder pair with 1x1 convs mapping into and out of the latent space.

    encode(): encoder output (``2 * z_channels`` maps) -> ``quant_conv`` ->
              ``2 * emb_channels`` maps, wrapped in a GaussianDistributionV1
              (first half mean, second half log-variance).
    decode(): ``emb_channels`` latent -> ``post_quant_conv`` -> ``z_channels``
              maps -> decoder.
    """
    def __init__(self, encoder, decoder, emb_channels, z_channels):
        super(AutoEncoderV1, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, kernel_size=1)
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, kernel_size=1)

    def encode(self, x):
        """Return the latent distribution for the input images ``x``."""
        return GaussianDistributionV1(self.quant_conv(self.encoder(x)))

    def decode(self, z):
        """Map a latent tensor ``z`` back through the decoder."""
        return self.decoder(self.post_quant_conv(z))

In [None]:
class UNetV1(nn.Module):
    """U-Net over images, conditioned on a timestep through TimeEmbeddingV1.

    Args:
        image_channels: channels of the input and of the output image.
        n_channels: width after the input projection; the time embedding has
            ``n_channels * 4`` features.
        ch_mults: per-resolution multipliers, each applied to the previous
            resolution's width.
        is_attn: per-resolution flag enabling self-attention in its blocks.
        n_blocks: DownBlockV1 per resolution (the up path uses ``n_blocks + 1``
            UpBlockV1 per resolution).
    """
    def __init__(self, image_channels=3, n_channels=64, ch_mults=(1,2,2,4), is_attn=(False,False,True,True), n_blocks=2):
        super(UNetV1, self).__init__()
        n_resolution = len(ch_mults)
        self.img_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1,1))
        self.time_emb = TimeEmbeddingV1(n_channels * 4)
        # Down path: n_blocks DownBlockV1 per resolution, DownSampleV1 between resolutions.
        out_channels = in_channels = n_channels
        down = []
        for i in range(n_resolution):
            out_channels = in_channels * ch_mults[i]
            for _ in range(n_blocks):
                down.append(DownBlockV1(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            if i < n_resolution - 1:
                down.append(DownSampleV1(in_channels))
        self.down = nn.ModuleList(down)
        self.middle = MiddleBlockV1(out_channels, n_channels * 4)
        # Up path mirrors the down path. Every UpBlockV1 consumes one skip
        # connection; the last block of each resolution divides the width by
        # that resolution's multiplier.
        in_channels = out_channels
        up = []
        for i in reversed(range(n_resolution)):
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlockV1(in_channels, out_channels, n_channels * 4, is_attn[i]))
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlockV1(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            if i > 0:
                up.append(UpSampleV1(in_channels))
        self.up = nn.ModuleList(up)

        self.final = nn.Sequential(
            nn.GroupNorm(8, n_channels),
            nn.SiLU(inplace=True),  # was `nn.SilU`, which does not exist in torch.nn
            nn.Conv2d(in_channels, image_channels, kernel_size=(3,3), padding=(1,1)),
        )

    def forward(self, x, t):
        """Run the U-Net on images ``x`` (B, image_channels, H, W).

        ``t`` holds one timestep per image; TimeEmbeddingV1 indexes it as
        ``t[:, None]``, so it is expected to be 1-D.
        """
        t = self.time_emb(t)
        x = self.img_proj(x)  # attribute is `img_proj` (was `self.image_proj`)
        # skip connections: the projected input plus every down-path output
        skips = [x]
        for m in self.down:
            x = m(x, t)
            skips.append(x)
        x = self.middle(x, t)
        for m in self.up:
            if isinstance(m, UpSampleV1):
                x = m(x, t)
            else:
                x = torch.cat((x, skips.pop()), dim=1)
                x = m(x, t)
        return self.final(x)

In [None]:
class UNetV2(nn.Module):
    """Conditional U-Net built from ResidualBlockV2 and SpatialTransformerV1 stages.

    The stages are grouped in TimeEmbeddingV2 containers, which thread the
    time embedding and ``cond`` to the blocks that need them.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        channels: base width; the time embedding has ``channels * 4`` features.
        n_res_blocks: residual stages per level on the way down
            (``n_res_blocks + 1`` on the way up).
        attention_levels: level indices whose stages get a SpatialTransformerV1.
        channel_multipliers: per-level multipliers of ``channels``.
        n_heads: attention heads per SpatialTransformerV1.
        tf_layers: transformer blocks per SpatialTransformerV1.
        d_cond: feature size of the conditioning sequence ``cond``.
    """
    def __init__(self, in_channels, out_channels, channels, n_res_blocks, attention_levels, channel_multipliers, n_heads, tf_layers=1, d_cond=768):
        super(UNetV2, self).__init__()
        self.channels = channels
        d_time_emb = channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(channels, d_time_emb),
            nn.SiLU(),  # was `nn.SilU`, which does not exist in torch.nn
            nn.Linear(d_time_emb, d_time_emb),
        )
        # Down path. input_block_channels records the width of every output that
        # forward() stores as a skip connection; the up path pops them in reverse.
        self.input_blocks = nn.ModuleList()
        levels = len(channel_multipliers)
        channels_list = [channels * m for m in channel_multipliers]
        input_block_channels = [channels]
        self.input_blocks.append(TimeEmbeddingV2(nn.Conv2d(in_channels, channels, 3, padding=1)))
        for i in range(levels):
            for _ in range(n_res_blocks):
                layers = [ResidualBlockV2(channels, d_time_emb, out_channels=channels_list[i])]
                channels = channels_list[i]
                if i in attention_levels:
                    layers.append(SpatialTransformerV1(channels, n_heads, tf_layers, d_cond))
                self.input_blocks.append(TimeEmbeddingV2(*layers))
                input_block_channels.append(channels)
            if i != levels - 1:
                self.input_blocks.append(TimeEmbeddingV2(DownSampleV1(channels)))
                input_block_channels.append(channels)
        self.middle_block = TimeEmbeddingV2(
            ResidualBlockV2(channels, d_time_emb),
            SpatialTransformerV1(channels, n_heads, tf_layers, d_cond),
            ResidualBlockV2(channels, d_time_emb),
        )
        # Up path: each stage takes the current features concatenated with one skip.
        self.output_blocks = nn.ModuleList()
        for i in reversed(range(levels)):
            for j in range(n_res_blocks + 1):
                layers = [ResidualBlockV2(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
                channels = channels_list[i]
                if i in attention_levels:
                    layers.append(SpatialTransformerV1(channels, n_heads, tf_layers, d_cond))
                # upsample at the end of every level except the shallowest
                if i != 0 and j == n_res_blocks:
                    layers.append(UpSampleV2(channels))  # was `apppend` (AttributeError)
                self.output_blocks.append(TimeEmbeddingV2(*layers))
        self.out = nn.Sequential(
            GroupNorm32(32, channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )

    # TODO: similar to TimeEmbeddingV1's sinusoidal part (cos/sin order differs)
    def time_step_embedding(self, time_steps, max_period=10000):
        """Sinusoidal embedding of 1-D ``time_steps``.

        Returns shape (len(time_steps), 2 * (channels // 2)): cosines first,
        then sines, with frequencies ``max_period ** (-i / half)``.
        """
        half = self.channels // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=time_steps.device)
        args = time_steps[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    def forward(self, x, time_steps, cond):
        """Run the U-Net on ``x`` with timesteps ``time_steps`` and conditioning sequence ``cond``."""
        t_emb = self.time_step_embedding(time_steps)
        t_emb = self.time_embed(t_emb)
        x_input_block = []
        for module in self.input_blocks:
            x = module(x, t_emb, cond)
            x_input_block.append(x)
        x = self.middle_block(x, t_emb, cond)
        for module in self.output_blocks:
            x = torch.cat([x, x_input_block.pop()], dim=1)
            x = module(x, t_emb, cond)
        return self.out(x)
