# TIME

In [None]:
# NOTE(review): reference snippet copied from inside an nn.Module `__init__`
# (it assigns to `self.*`). As pasted it cannot run on its own: the first
# statement is at column 0 while the rest is indented (IndentationError), and
# `self`, `sinu_pos_emb`, `sinu_pos_emb_input_dim`, `time_cond_dim`,
# `cond_dim`, `num_time_tokens` and `Rearrange` are not defined in this cell.
#
# time -> positional embedding -> Linear -> SiLU: shared hidden vector of
# width `time_cond_dim`.
self.to_time_hiddens = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
            nn.SiLU()
        )

        # hidden vector -> time conditioning vector (one linear layer, same width)
        self.to_time_cond = nn.Sequential(
            nn.Linear(time_cond_dim, time_cond_dim)
        )

        # project to time tokens as well as time hiddens: the Linear emits
        # num_time_tokens * cond_dim features and Rearrange splits them into
        # `num_time_tokens` tokens of width `cond_dim`.

        self.to_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
            Rearrange('b (r d) -> b r d', r = num_time_tokens)
        )

In [20]:
class LearnedSinusoidalPosEmb(nn.Module):
    """Learned sinusoidal embedding of scalar inputs (e.g. diffusion timesteps).

    Each value ``t`` is mapped to ``[t, sin(2*pi*t*w), cos(2*pi*t*w)]`` where
    ``w`` is a learned frequency vector of length ``dim // 2``. The raw input
    is passed through as the first feature, so the output has ``dim + 1``
    features.

    Args:
        dim: Total number of sinusoidal features (sin + cos); must be even.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if dim % 2 != 0:
            raise ValueError(f"dim must be even, got {dim}")
        # One learned frequency per sin/cos pair.
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        """Embed ``x`` of shape ``(*batch,)`` into ``(*batch, dim + 1)``.

        A 1-D input of shape ``(b,)`` yields ``(b, dim + 1)``; additional
        leading dimensions are supported as well.
        """
        # (*batch,) -> (*batch, 1) so it broadcasts against the (dim // 2,) frequencies.
        x = x.unsqueeze(-1)
        # Same multiplication order as before: ((x * w) * 2) * pi.
        angles = x * self.weights * 2 * torch.pi
        return torch.cat((x, torch.sin(angles), torch.cos(angles)), dim = -1)

In [22]:
from einops import rearrange
import math

# Quick check of LearnedSinusoidalPosEmb on 8 identical timesteps (all 1.0).
lspe = LearnedSinusoidalPosEmb(16)
x = torch.ones(8)
print(lspe.weights)  # the 16 // 2 = 8 learned frequencies
# Jupyter displays this last expression. Expected torch.Size([8, 17]):
# 1 raw-time column + 8 sin + 8 cos columns.
# NOTE(review): the saved output reads torch.Size([8, 18]) and also shows a
# printed column of ones that this code never prints. It looks stale relative
# to the 17-column embedding printed alongside it, so re-run the cell.
lspe(x).shape

Parameter containing:
tensor([-0.3285, -0.2634,  0.4139, -0.3833,  0.7469, -0.0939, -0.2985, -0.0953],
       requires_grad=True)
tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
tensor([[ 1.0000, -0.8808, -0.9964,  0.5149, -0.6694, -0.9998, -0.5565, -0.9539,
         -0.5636, -0.4735, -0.0844, -0.8573, -0.7429, -0.0197,  0.8309, -0.3000,
          0.8260],
        [ 1.0000, -0.8808, -0.9964,  0.5149, -0.6694, -0.9998, -0.5565, -0.9539,
         -0.5636, -0.4735, -0.0844, -0.8573, -0.7429, -0.0197,  0.8309, -0.3000,
          0.8260],
        [ 1.0000, -0.8808, -0.9964,  0.5149, -0.6694, -0.9998, -0.5565, -0.9539,
         -0.5636, -0.4735, -0.0844, -0.8573, -0.7429, -0.0197,  0.8309, -0.3000,
          0.8260],
        [ 1.0000, -0.8808, -0.9964,  0.5149, -0.6694, -0.9998, -0.5565, -0.9539,
         -0.5636, -0.4735, -0.0844, -0.8573, -0.7429, -0.0197,  0.8309, -0.3000,
          0.8260],
        [ 1.0000, -0.8808, -0.996

torch.Size([8, 18])

In [None]:
import torch
import torch.nn as nn

class Swish(nn.Module):
    """Element-wise Swish activation, ``x * sigmoid(x)`` (a.k.a. SiLU)."""

    def forward(self, x):
        # Logistic gate, then scale the input by it.
        gate = x.sigmoid()
        return x.mul(gate)
    
class LearnedSinusoidalPosEmb(nn.Module):
    """Learned sinusoidal embedding of scalar inputs (e.g. diffusion timesteps).

    Each value ``t`` is mapped to ``[t, sin(2*pi*t*w), cos(2*pi*t*w)]`` where
    ``w`` is a learned frequency vector of length ``dim // 2``. The raw input
    is passed through as the first feature, so the output has ``dim + 1``
    features.

    Args:
        dim: Total number of sinusoidal features (sin + cos); must be even.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if dim % 2 != 0:
            raise ValueError(f"dim must be even, got {dim}")
        # One learned frequency per sin/cos pair.
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        """Embed ``x`` of shape ``(*batch,)`` into ``(*batch, dim + 1)``.

        A 1-D input of shape ``(b,)`` yields ``(b, dim + 1)``; additional
        leading dimensions are supported as well.
        """
        # (*batch,) -> (*batch, 1) so it broadcasts against the (dim // 2,) frequencies.
        x = x.unsqueeze(-1)
        # Same multiplication order as before: ((x * w) * 2) * pi.
        angles = x * self.weights * 2 * torch.pi
        return torch.cat((x, torch.sin(angles), torch.cos(angles)), dim = -1)

class TimeConditioning(nn.Module):
    """Turn diffusion timesteps into U-Net conditioning signals.

    A shared hidden vector (width ``unet_dim * 4``) is computed from the
    timesteps. Two outputs are projected from it: a conditioning vector ``t``
    and ``num_time_tokens`` time tokens of width ``unet_dim``.

    Args:
        unet_dim: Base channel width of the U-Net.
        time_embedding_dim: Size of the learned sinusoidal embedding (must be
            even). The embedding emits ``time_embedding_dim + 1`` features
            because it also passes the raw time through.
        num_time_tokens: Number of time tokens returned by ``forward``.
    """

    def __init__(self, unet_dim, time_embedding_dim=16, num_time_tokens = 2):
        super().__init__()

        time_cond_dim = unet_dim * 4

        # time -> (b, time_embedding_dim + 1) -> (b, time_cond_dim)
        self.to_time_hiddens = nn.Sequential(
            LearnedSinusoidalPosEmb(time_embedding_dim),
            nn.Linear(time_embedding_dim + 1, time_cond_dim),
            # torch.nn has no `Swish` class; use the Swish module defined in this notebook.
            Swish(),
        )

        self.to_time_cond = nn.Linear(time_cond_dim, time_cond_dim)

        # (b, num_time_tokens * unet_dim) -> (b, num_time_tokens, unet_dim).
        # nn.Unflatten performs the same split as einops' 'b (r d) -> b r d'
        # without requiring an extra import.
        self.to_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, unet_dim * num_time_tokens),
            nn.Unflatten(-1, (num_time_tokens, unet_dim)),
        )

        # Sized to the token width (unet_dim). Not used in forward() here.
        self.norm_cond = nn.LayerNorm(unet_dim)

    def forward(self, time):
        """Return ``(t, time_tokens)`` for a batch of timesteps.

        Args:
            time: Timesteps; presumably shape ``(batch,)``, as the sinusoidal
                embedding expects a 1-D tensor (TODO confirm with caller).

        Returns:
            t: ``(batch, unet_dim * 4)`` conditioning vector.
            time_tokens: ``(batch, num_time_tokens, unet_dim)``.
        """
        time_hiddens = self.to_time_hiddens(time)

        time_tokens = self.to_time_tokens(time_hiddens)
        t = self.to_time_cond(time_hiddens)

        return t, time_tokens

In [2]:
import torch
import torch.nn as nn
# Exploration: a tensor wrapped in nn.Parameter (shown as "Parameter containing: ...").
# This is what LearnedSinusoidalPosEmb stores as its 16 // 2 = 8 frequencies.
nn.Parameter(torch.randn(16 // 2))

Parameter containing:
tensor([-0.1929,  1.2791, -0.5375,  1.5420,  2.4278,  1.8879, -0.4106,  0.3239],
       requires_grad=True)

In [3]:
# Compare with the previous cell: a plain tensor with requires_grad=True is not
# an nn.Parameter. If assigned as a Module attribute it would not be registered
# (it would not appear in .parameters()).
torch.randn(16 // 2, requires_grad=True)

tensor([ 0.4047, -1.2444, -1.2889, -0.4955,  0.7217, -1.7086, -0.1643,  1.3201],
       requires_grad=True)

In [None]:
class PerceiverAttention(nn.Module):
    """Multi-head cross-attention from latents to ``[inputs; latents]``.

    Queries come from the (normalized) latents. Keys and values come from the
    concatenation of the (normalized) input sequence and the latents along
    the sequence axis.

    Args:
        dim: Feature width of both ``x`` and ``latents``.
        dim_head: Width of each attention head.
        heads: Number of attention heads.
    """

    def __init__(self, *, dim, dim_head = 64, heads = 8):
        super().__init__()

        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm         = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.Q  = nn.Linear(dim, inner_dim, bias = False)
        self.KV = nn.Linear(dim, inner_dim * 2, bias = False)

        self.output_layer = nn.Sequential(nn.Linear(inner_dim, dim, bias = False), nn.LayerNorm(dim))

    def _split_heads(self, t):
        """Reshape ``(b, n, heads * d)`` to ``(b, heads, n, d)``."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)

    def forward(self, x, latents, mask = None):
        """Attend from ``latents`` to ``[x; latents]``.

        Args:
            x: Input sequence, ``(b, n, dim)``.
            latents: Latent queries, ``(b, m, dim)``.
            mask: Optional ``(b, n)`` mask over ``x``; truthy entries may be
                attended to. Latent positions are always attendable.

        Returns:
            ``(b, m, dim)`` attention output (the caller adds the residual).
        """
        x       = self.norm(x)
        latents = self.norm_latents(latents)
        kv_input = torch.cat((x, latents), dim = -2)

        q = self.Q(latents)
        k, v = self.KV(kv_input).chunk(2, dim = -1)
        q, k, v = (self._split_heads(t) for t in (q, k, v))

        q = q * self.scale

        # attention scores: (b, h, m, n + m)
        sim = torch.matmul(q, k.transpose(-2, -1))

        if mask is not None:
            # Extend the input mask with True for the latent keys appended above.
            latent_mask = mask.new_ones((mask.shape[0], latents.shape[-2]), dtype = torch.bool)
            mask = torch.cat((mask.bool(), latent_mask), dim = -1)
            mask = mask[:, None, None, :]  # (b, 1, 1, j): broadcast over heads and queries
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # Softmax in float32 for stability, then back to the value dtype so the
        # matmul below does not mix dtypes under half precision.
        attn = sim.softmax(dim = -1, dtype = torch.float32).to(v.dtype)

        out = torch.matmul(attn, v)            # (b, h, m, d)
        out = out.transpose(1, 2).flatten(2)   # (b, m, h * d)
        return self.output_layer(out)

    
class PerceiverResampler(nn.Module):
    """Resample a variable-length sequence into a fixed set of latents.

    ``num_latents`` learned latents, plus optionally
    ``num_latents_mean_pooled`` latents derived from the mean of the input
    sequence, are refined by ``depth`` blocks. Each block is
    PerceiverAttention followed by FeedForward, both with residual
    connections.

    Args:
        unet_dim: Feature width of the input sequence and of the latents.
        depth: Number of attention + feed-forward blocks.
        dim_head: Width of each attention head.
        heads: Number of attention heads.
        num_latents: Number of learned latent vectors.
        num_latents_mean_pooled: Extra latents computed from the mean-pooled
            input; ``0`` (or negative) disables this path.
        max_seq_len: Size of the positional-embedding table; inputs must not
            be longer than this.
        ff_mult: Hidden-width multiplier passed to FeedForward.
    """

    def __init__(self, unet_dim, depth=2, dim_head = 64, heads = 8, num_latents = 32,
                 num_latents_mean_pooled = 4, max_seq_len = 512, ff_mult = 4):
        super().__init__()

        self.unet_dim = unet_dim

        self.pos_emb = nn.Embedding(max_seq_len, unet_dim)
        self.latents = nn.Parameter(torch.randn(num_latents, unet_dim))

        # (b, unet_dim) mean-pooled sequence -> (b, num_latents_mean_pooled, unet_dim).
        # The LayerNorm here replaces a functional layer_norm in forward() whose
        # result was discarded.
        self.to_latents_from_mean_pooled_seq = None
        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                nn.LayerNorm(unet_dim),
                nn.Linear(unet_dim, unet_dim * num_latents_mean_pooled),
                nn.Unflatten(-1, (num_latents_mean_pooled, unet_dim)),
            )

        self.layers = nn.ModuleList([
            nn.ModuleList([
                PerceiverAttention(dim = unet_dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = unet_dim, mult = ff_mult),
            ])
            for _ in range(depth)
        ])

    def forward(self, x, mask = None):
        """Return refined latents for input ``x`` (assumed ``(b, n, unet_dim)``).

        ``mask`` (optional, ``(b, n)``) is forwarded to every attention layer.
        """
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        # Broadcast the shared learned latents across the batch (no copy needed).
        latents = self.latents.unsqueeze(0).expand(x.shape[0], -1, -1)

        if self.to_latents_from_mean_pooled_seq is not None:
            # NOTE(review): plain mean over all n positions, as before (the
            # original passed an all-ones mask); padded tokens are included
            # even when `mask` is given.
            meanpooled_seq = x.mean(dim = 1)
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents

In [None]:
class TextConditioning(nn.Module):
    
    def __init__(self, unet_dim, text_embedding_dim, attn_dim_head = 64, attn_heads = 8,
                 attn_pool_num_latents = 32, max_text_len = 512):
        """Build the text-conditioning projections.

        Args:
            unet_dim: Base channel width of the U-Net; text tokens are
                projected to this width.
            text_embedding_dim: Feature width of the incoming text embeddings.
            attn_dim_head: Head width of the attention-pooling resampler.
            attn_heads: Number of heads of the attention-pooling resampler.
            attn_pool_num_latents: Number of learned latents in the resampler.
            max_text_len: ``forward`` keeps at most this many text tokens. It is
                also used as the resampler's positional-table size so every
                kept token has a position.
        """
        super().__init__()

        # Read by forward() to truncate the token sequence.
        self.max_text_len = max_text_len

        self.text_to_cond  = nn.Linear(text_embedding_dim, unet_dim)
        # Text tokens have width unet_dim after text_to_cond, so the resampler works at that width.
        self.attn_pool     = PerceiverResampler(unet_dim = unet_dim, depth = 2, dim_head = attn_dim_head,
                                                heads = attn_heads, num_latents = attn_pool_num_latents,
                                                max_seq_len = max_text_len)
        self.non_attn_cond = nn.Sequential(nn.LayerNorm(unet_dim),
                                           nn.Linear(unet_dim, unet_dim*4),
                                           Swish(),
                                           nn.Linear(unet_dim*4, unet_dim*4))
        
    def forward(self, text_embeds, text_mask):
        
        text_tokens = self.text_to_cond(text_embeds)[:, :self.max_text_len]

        
        