In [35]:
# %load UNet.py
import torch
import torch.nn as nn

from einops import rearrange
from einops.layers.torch import Rearrange


class Swish(nn.Module):

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)

#######################################################################
#################### CLASSES FOR TIME CONDITIONING ####################
#######################################################################

class SinusoidalPositionEmbedding(nn.Module):

    def __init__(self, dim):
        
        super(SinusoidalPositionEmbedding, self).__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        
        x = rearrange(x, 'b -> b 1')    
        f = x * rearrange(self.weights, 'd -> 1 d') 
        w = f * 2 * torch.pi
        print(torch.cat((x, torch.sin(w), torch.cos(w)), dim = -1).shape)
        return torch.cat((x, torch.sin(w), torch.cos(w)), dim = -1)

class TimeConditioning(nn.Module):
    
    def __init__(self, unet_dim, time_embedding_dim=16, num_time_tokens = 2):
        super(TimeConditioning, self).__init__()
        
        self.to_time_hiddens = nn.Sequential(SinusoidalPositionEmbedding(time_embedding_dim),
                                             nn.Linear(time_embedding_dim+1, unet_dim*4),
                                             Swish())

        self.to_time_cond = nn.Linear(unet_dim*4, unet_dim*4)

        self.to_time_tokens = nn.Sequential(nn.Linear(unet_dim*4, unet_dim * num_time_tokens),
                                            Rearrange('b (r d) -> b r d', r = num_time_tokens))
        
    def forward(self, time):
        
        time_hiddens = self.to_time_hiddens(time)
        print(time_hiddens.shape)
        time_tokens = self.to_time_tokens(time_hiddens)
        t           = self.to_time_cond(time_hiddens)
        
        return t, time_tokens

In [36]:
tc = TimeConditioning(8, 10, 5)

x = torch.rand((10))
t, time_tokens = tc(x)

torch.Size([10, 11])
torch.Size([10, 32])


In [31]:
t.shape

torch.Size([10, 32])

In [32]:
time_tokens.shape

torch.Size([10, 5, 8])

In [None]:
class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        x = self.norm(x)
        latents = self.norm_latents(latents)

        b, h = x.shape[0], self.heads

        q = self.to_q(latents)

        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale

        # attention

        sim = einsum('... i d, ... j d  -> ... i j', q, k)

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim = -1, dtype = torch.float32)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)


In [62]:
class PerceiverAttention(nn.Module):
    def __init__(self, cond_dim, dim_head = 64, heads = 8):
        super().__init__()
        
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.norm_x = nn.LayerNorm(cond_dim)
        self.norm_l = nn.LayerNorm(cond_dim)

        self.Q  = nn.Linear(cond_dim, dim_head * heads, bias = False)
        self.KV = nn.Linear(cond_dim, dim_head * heads * 2, bias = False)

        self.output_layer = nn.Sequential(nn.Linear(dim_head * heads, cond_dim, bias = False),
                                          nn.LayerNorm(cond_dim))
        
    def attention(self, q, k, v, mask=None):
        
        score = torch.matmul(q, k.transpose(-1, -2))
        
        if mask is not None:            
            max_neg = -torch.finfo(score.dtype).max            
            mask    = mask[:, None, None, :]
            score   = score.masked_fill(~mask, max_neg)
            
        probs = score.softmax(dim = -1, dtype = torch.float32)
        
        return torch.matmul(probs, v)

    def forward(self, x, latents, mask = None):
        
        batch = x.shape[0]
        
        x = self.norm_x(x)
        l = self.norm_l(latents)

        q    = self.Q(l)
        k, v = self.KV(torch.cat((x, l), dim = -2)).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        q = q * self.scale
        
        if mask is not None:
            mask = F.pad(mask, (0, l.shape[-2]), value = True)
            att = self.attention(q, k, v, mask)
        else:
            att = self.attention(q, k, v)        

        out = rearrange(att, 'b h n d -> b n (h d)', h = self.heads)
        return self.output_layer(out)


In [55]:
K = torch.rand((10, 4, 3))
Q = torch.rand((10, 3, 5))

In [56]:
pp = torch.matmul(K, Q)
pp.shape

torch.Size([10, 4, 5])

In [57]:
pp = einsum('... i j, ... j d -> ... i d', K, Q)
pp.shape

torch.Size([10, 4, 5])

In [68]:
PA = PerceiverAttention(cond_dim=10, dim_head = 5, heads = 4)

x = torch.rand((4, 10, 10))
m = torch.ones((4, 10), dtype=torch.bool)
y = PA(x, x, m)
y.shape

torch.Size([4, 10, 10])

In [38]:
m = torch.rand((3, 2))
m[:, None, None, :]

tensor([[[[0.2584, 0.9313]]],


        [[[0.4733, 0.5533]]],


        [[[0.7753, 0.6512]]]])

In [39]:
rearrange(m, 'b j -> b 1 1 j')

tensor([[[[0.2584, 0.9313]]],


        [[[0.4733, 0.5533]]],


        [[[0.7753, 0.6512]]]])

In [15]:
y.shape

torch.Size([4, 10, 10])

In [25]:
scores = torch.matmul(Q, K.transpose(-1, -2))
scores

tensor([[[1.5745, 1.1404, 0.4792],
         [0.9358, 0.4613, 0.3689],
         [1.3690, 1.1262, 0.6740],
         [1.0577, 0.7342, 0.4657],
         [1.1156, 0.8266, 0.6988]],

        [[0.3407, 0.9385, 0.0908],
         [0.7362, 0.9733, 0.1799],
         [1.1299, 1.5349, 0.5749],
         [0.8635, 1.2521, 0.1771],
         [1.3654, 1.4977, 0.4404]],

        [[0.8268, 0.9538, 0.8726],
         [1.6497, 1.2049, 1.6238],
         [1.3238, 1.4486, 1.4216],
         [1.1053, 0.4499, 1.0812],
         [1.3090, 1.1986, 1.3285]],

        [[1.7625, 0.7689, 1.7370],
         [1.1853, 0.9257, 0.9524],
         [2.1698, 1.2840, 1.8145],
         [1.6721, 1.0293, 1.5588],
         [0.5936, 0.2678, 0.6266]],

        [[1.3417, 1.2941, 1.1318],
         [0.6657, 0.9366, 0.2966],
         [0.6792, 0.8422, 0.6958],
         [1.2936, 1.3798, 0.7328],
         [1.3701, 1.3936, 1.4157]],

        [[1.1298, 1.9560, 1.9029],
         [1.6332, 1.8136, 1.9687],
         [1.4505, 2.1276, 1.8893],
         [

In [22]:
sim = einsum('... i d, ... j d  -> ... i j', Q, K)
sim

tensor([[[1.5745, 1.1404, 0.4792],
         [0.9358, 0.4613, 0.3689],
         [1.3690, 1.1262, 0.6740],
         [1.0577, 0.7342, 0.4657],
         [1.1156, 0.8266, 0.6988]],

        [[0.3407, 0.9385, 0.0908],
         [0.7362, 0.9733, 0.1799],
         [1.1299, 1.5349, 0.5749],
         [0.8635, 1.2521, 0.1771],
         [1.3654, 1.4977, 0.4404]],

        [[0.8268, 0.9538, 0.8726],
         [1.6497, 1.2049, 1.6238],
         [1.3238, 1.4486, 1.4216],
         [1.1053, 0.4499, 1.0812],
         [1.3090, 1.1986, 1.3285]],

        [[1.7625, 0.7689, 1.7370],
         [1.1853, 0.9257, 0.9524],
         [2.1698, 1.2840, 1.8145],
         [1.6721, 1.0293, 1.5588],
         [0.5936, 0.2678, 0.6266]],

        [[1.3417, 1.2941, 1.1318],
         [0.6657, 0.9366, 0.2966],
         [0.6792, 0.8422, 0.6958],
         [1.2936, 1.3798, 0.7328],
         [1.3701, 1.3936, 1.4157]],

        [[1.1298, 1.9560, 1.9029],
         [1.6332, 1.8136, 1.9687],
         [1.4505, 2.1276, 1.8893],
         [

In [24]:
sim.max()

tensor(2.2724)

In [21]:
max_neg_value = -torch.finfo(sim.dtype).max
max_neg_value

-3.4028234663852886e+38

In [None]:
class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4,
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, mask = None):
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        if exists(self.to_latents_from_mean_pooled_seq):
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents

In [1]:
def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d
cond_dim = None
dim = 64

cond_dim = default(cond_dim, dim)

In [2]:
cond_dim

64

In [8]:
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops.layers.torch import Rearrange
from einops import rearrange, repeat
from einops_exts import rearrange_many

In [48]:
class PerceiverAttention(nn.Module):
    def __init__(self, cond_dim, dim_head = 64, heads = 8):
        super(PerceiverAttention, self).__init__()
        
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.norm_x = nn.LayerNorm(cond_dim)
        self.norm_l = nn.LayerNorm(cond_dim)

        self.Q  = nn.Linear(cond_dim, dim_head * heads, bias = False)
        self.KV = nn.Linear(cond_dim, dim_head * heads * 2, bias = False)

        self.output_layer = nn.Sequential(nn.Linear(dim_head * heads, cond_dim, bias = False),
                                          nn.LayerNorm(cond_dim))
        
    def attention(self, q, k, v, mask=None):
        
        score = torch.matmul(q, k.transpose(-1, -2))
        
        if mask is not None:            
            max_neg = -torch.finfo(score.dtype).max            
            mask    = mask[:, None, None, :]
            score   = score.masked_fill(~mask, max_neg)
            
        probs = score.softmax(dim = -1, dtype = torch.float32)
        
        return torch.matmul(probs, v)

    def forward(self, x, latents, mask = None):
        
        batch = x.shape[0]
        
        x = self.norm_x(x)
        l = self.norm_l(latents)

        q    = self.Q(l)
        k, v = self.KV(torch.cat((x, l), dim = -2)).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        q = q * self.scale
        
        if mask is not None:
            mask = F.pad(mask, (0, l.shape[-2]), value = True)
            att = self.attention(q, k, v, mask)
        else:
            att = self.attention(q, k, v)        

        out = rearrange(att, 'b h n d -> b n (h d)', h = self.heads)
        return self.output_layer(out)

In [42]:
class Swish(nn.Module):

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)

In [46]:
class MasterPerceiver(nn.Module):
    def __init__(self, cond_dim, depth, dim_head = 64, heads = 8, num_latents = 64,
                 num_latents_mean_pooled = 4, max_seq_len = 512, ff_mult = 4):        
        # num_latents_mean_pooled: number of latents derived from mean pooled representation of the sequence
        
        super(MasterPerceiver, self).__init__()
        
        self.embedding = nn.Embedding(max_seq_len, cond_dim)
        self.latents   = nn.Parameter(torch.randn(num_latents, cond_dim))

        self.Mean_Pooled = None
        if num_latents_mean_pooled > 0:
            self.Mean_Pooled = nn.Sequential(nn.LayerNorm(cond_dim),
                                             nn.Linear(cond_dim, cond_dim * num_latents_mean_pooled),
                                             Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled))

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([PerceiverAttention(cond_dim, dim_head, heads),
                               nn.Sequential(nn.LayerNorm(cond_dim),
                                             nn.Linear(cond_dim, cond_dim*ff_mult, bias = False),
                                             Swish(),
                                             nn.LayerNorm(cond_dim*ff_mult),
                                             nn.Linear(cond_dim*ff_mult, cond_dim, bias = False))]))

    def forward(self, x, mask = None):
        
        b, n, device = x.shape[0], x.shape[1], x.device
        
        pos_emb = self.embedding(torch.arange(n, device = device))
        X = x + pos_emb

        lat = repeat(self.latents, 'n d -> b n d', b = b)

        if self.Mean_Pooled is not None:
                
            pool_mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)
                
            denom     = pool_mask.sum(dim = 1, keepdim = True)
            pool_mask = pool_mask[:, :, None]
            masked_x  = x.masked_fill(~pool_mask, 0.)
            
            m_seq = masked_x.sum(dim = 1) / denom.clamp(min = 1e-5)
            m_lat = self.Mean_Pooled(m_seq)
            lat   = torch.cat((m_lat, lat), dim = -2)
            
            print(lat.shape)

        for att, nn in self.layers:
            lat = nn(att(X, lat, mask = mask) + lat) + lat

        return lat

In [47]:
mp = MasterPerceiver(cond_dim=20, depth=2, dim_head = 64, heads = 8, num_latents = 64,
                 num_latents_mean_pooled = 4, max_seq_len = 15, ff_mult = 4)

In [49]:
x = torch.rand((4, 15, 20))
m = torch.ones((4, 15), dtype=torch.bool)
y = mp(x, m)
y.shape

torch.Size([4, 68, 20])


torch.Size([4, 68, 20])

In [None]:
PA = PerceiverAttention(cond_dim=10, dim_head = 5, heads = 4)

x = torch.rand((4, 10, 10))
m = torch.ones((4, 10), dtype=torch.bool)
y = PA(x, x, m)
y.shape

In [53]:
m = torch.rand((2))
rearrange(m, 'b -> 1 b')

tensor([[0.9309, 0.9206]])

In [54]:
m[None, :]

tensor([[0.9309, 0.9206]])

In [None]:
def masked_mean(t, *, dim, mask = None):
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)

    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)

In [74]:
b = 8

print(b is not None)
print(b is None)

True
False


In [69]:
class cu():
    
    def __init__(self):
        
        self.me()
        
    def me(self):
        print('cu')

In [71]:
c = cu()

cu


In [36]:
def FeedForward(dim, mult = 2):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False)
    )

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
    
class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        x = self.norm(x)
        latents = self.norm_latents(latents)

        b, h = x.shape[0], self.heads

        q = self.to_q(latents)

        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale

        # attention

        sim = einsum('... i d, ... j d  -> ... i j', q, k)

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim = -1, dtype = torch.float32)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4,
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, mask = None):
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])
        
        print(latents.shape)

        if exists(self.to_latents_from_mean_pooled_seq):
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)
            print(latents.shape)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents
            
            print(latents.shape)

        return latents

In [None]:
mp = MasterPerceiver(cond_dim=20, depth=2, dim_head = 64, heads = 8, num_latents = 64,
                 num_latents_mean_pooled = 4, max_seq_len = 15, ff_mult = 4)

In [37]:
pr = PerceiverResampler(dim=20, depth=2, dim_head = 64, heads = 8, num_latents = 64,
                        num_latents_mean_pooled = 4, max_seq_len = 15, ff_mult = 4)

In [38]:
x = torch.rand((4, 15, 20))
m = torch.ones((4, 15), dtype=torch.bool)
y = pr(x, m)
y.shape

torch.Size([4, 64, 20])
torch.Size([4, 68, 20])
torch.Size([4, 68, 20])
torch.Size([4, 68, 20])


torch.Size([4, 68, 20])

In [30]:
def exists(val):
    return val is not None

In [34]:
def masked_mean(t, *, dim, mask = None):
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)

    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)

In [None]:
class TextConditioning(nn.Module):
    
    def __init__(self, dim, cond_dim, text_embed_dim, dim_head, heads,
                 num_latents, max_text_len, Ttype, device='cpu'):
        super(TextConditioning, self).__init__()
        
        self.max_text_len = max_text_len
        self.Ttype        = Ttype
        
        self.text_to_cond = nn.Linear(text_embed_dim, dim)
        self.attention    = MasterPerceiver(dim=dim, depth=2, dim_head=dim_head, heads=heads, num_latents=num_latents,
                                            num_latents_mean_pooled=4, max_seq_len=512, ff_mult=4)    
        
        self.null_text_embed  = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
        self.null_text_hidden = nn.Parameter(torch.randn(1, dim*4))
        
        self.text_to_hiddens = nn.Sequential(nn.LayerNorm(cond_dim),
                                             nn.Linear(cond_dim, dim*4),
                                             Swish(),
                                              nn.Linear(dim*4, time_cond_dim))
        
    def forward(self, text_embeds, text_mask=None):
        
        # Making Masks
        
        text_keep_mask        = torch.zeros((batch_size,), device = device).float().uniform_(0, 1) < 0.9
        text_keep_mask_embed  = text_keep_mask[:, None, None]
        text_keep_mask_hidden = text_keep_mask[:, None]
        
        # Text Tokens
        
        text_tokens = self.text_to_cond(text_embeds)[:, :self.max_text_len]
        
        remainder = self.max_text_len - text_tokens.shape[1]
        if remainder > 0: text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
        
        if text_mask is not None:
            if remainder > 0: text_mask = F.pad(text_mask, (0, remainder), value = False)
            
            text_mask = text_mask[:, :, None]
            text_keep_mask_embed = text_mask & text_keep_mask_embed
            
        null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working

        text_tokens = torch.where(text_keep_mask_embed, text_tokens, null_text_embed)        
        text_tokens = self.attention(text_tokens)
        
        # Text Hiddens
        
        text_hiddens     = self.text_to_hiddens(text_tokens.mean(dim = -2))
        null_text_hidden = self.null_text_hidden.to(self.Ttype)
        text_hiddens     = torch.where(text_keep_mask_hidden, text_hiddens, null_text_hidden)
        
        
        return text_tokens, text_hiddens

        
        