In [52]:
import torch
import torch.nn as nn
import BasicModels as BM
from ComplexModels import TextConditioning, TimeConditioning
from einops import rearrange

In [2]:
time_cond = TimeConditioning(dim=20, cond_dim=20, time_embedding_dim=10, num_time_tokens = 68)

text_cond = TextConditioning(dim=20, cond_dim=20, text_embed_dim=10, dim_head=64, heads=8,
                             num_latents=64, max_text_len=15, Ttype=torch.float, device='cpu')

In [3]:
text_embeds = torch.randn(4, 15, 10)
text_masks  = torch.ones(4, 15).bool()
time        = torch.rand(4)

time_tokens, t = time_cond(time)
print(time_tokens.shape, t.shape)

text_tokens, text_hiddens = text_cond(text_embeds, text_masks)
print(text_tokens.shape, text_hiddens.shape)

torch.Size([4, 68, 20]) torch.Size([4, 80])
torch.Size([4, 68, 20]) torch.Size([4, 80])


In [135]:
import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config

In [136]:
T5_CONFIGS = {}

In [137]:
def get_encoded_dim(name='google/t5-v1_1-base'):
    if name not in T5_CONFIGS:
        print('1')
        # avoids loading the model if we only want to get the dim
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        print('2')
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        print('3')
        config = T5_CONFIGS[name]["model"].config
    else:
        print('4')
        assert False
    return config.d_model

get_encoded_dim()

1


768

In [139]:
T5Config.from_pretrained('google/t5-v1_1-base').d_model

768

In [138]:
T5_CONFIGS

{'google/t5-v1_1-base': {'config': T5Config {
    "_name_or_path": "/home/patrick/hugging_face/t5/t5-v1_1-base",
    "architectures": [
      "T5ForConditionalGeneration"
    ],
    "d_ff": 2048,
    "d_kv": 64,
    "d_model": 768,
    "decoder_start_token_id": 0,
    "dropout_rate": 0.1,
    "eos_token_id": 1,
    "feed_forward_proj": "gated-gelu",
    "initializer_factor": 1.0,
    "is_encoder_decoder": true,
    "layer_norm_epsilon": 1e-06,
    "model_type": "t5",
    "num_decoder_layers": 12,
    "num_heads": 12,
    "num_layers": 12,
    "output_past": true,
    "pad_token_id": 0,
    "relative_attention_num_buckets": 32,
    "tie_word_embeddings": false,
    "transformers_version": "4.3.0",
    "use_cache": true,
    "vocab_size": 32128
  }}}

In [19]:
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x
    
class Parallel(nn.Module):
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        outputs = [fn(x) for fn in self.fns]
        return sum(outputs)
    
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

In [124]:
class Unet(nn.Module):
    def __init__(
        self,
        *,
        dim,
        image_embed_dim = 1024,
        text_embed_dim = 40, # get_encoded_dim(DEFAULT_T5_NAME),
        num_resnet_blocks = 1,
        cond_dim = None,
        num_image_tokens = 4,
        num_time_tokens = 2,
        learned_sinu_pos_emb = True,
        learned_sinu_pos_emb_dim = 16,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        cond_images_channels = 0,
        channels = 3,
        channels_out = None,
        attn_dim_head = 64,
        attn_heads = 8,
        ff_mult = 2.,
        lowres_cond = False,                # for cascading diffusion - https://cascaded-diffusion.github.io/
        layer_attns = True,
        layer_attns_add_text_cond = True,   # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
        attend_at_middle = True,            # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
        layer_cross_attns = True,
        use_linear_attn = False,
        use_linear_cross_attn = False,
        cond_on_text = True,
        max_text_len = 256,
        init_dim = None,
        init_conv_kernel_size = 7,
        resnet_groups = 8,
        init_cross_embed_kernel_sizes = (3, 7, 15),
        cross_embed_downsample = False,
        cross_embed_downsample_kernel_sizes = (2, 4),
        attn_pool_text = True,
        attn_pool_num_latents = 32,
        dropout = 0.,
        memory_efficient = False,
        init_conv_to_final_conv_residual = False,
        use_global_context_attn = True,
        scale_skip_connection = True,
        final_resnet_block = True,
        final_conv_kernel_size = 3
    ):
        super().__init__()
        
        self.time_cond = TimeConditioning(dim                = dim,
                                          cond_dim           = cond_dim,
                                          time_embedding_dim = learned_sinu_pos_emb_dim,
                                          num_time_tokens    = num_time_tokens)

        self.text_cond = TextConditioning(dim            = dim,
                                          cond_dim       = cond_dim,
                                          text_embed_dim = text_embed_dim,
                                          dim_head       = attn_dim_head,
                                          heads          = attn_heads,
                                          num_latents    = attn_pool_num_latents,
                                          max_text_len   = max_text_len,
                                          Ttype          = torch.float,
                                          device         = 'cpu')
        
        self.norm_cond = nn.LayerNorm(cond_dim)
        
        self.init_conv = BM.CrossEmbedding(channels, dim, stride=1, kernel_sizes=init_cross_embed_kernel_sizes)
        
        # Params for UNet
        
        dims   = [dim, *[m*dim for m in dim_mults]]
        in_out = list(zip(dims[:-1], dims[1:]))
        
        num_resnet_blocks = (num_resnet_blocks,)*len(in_out)
        resnet_groups     = (resnet_groups,)*len(in_out)
        
        assert len(num_resnet_blocks) == len(in_out), 'num_resnet_blocks and in_out must be the same size'
        assert len(resnet_groups) == len(in_out),     'resnet_groups and in_out must be the same size'
        assert len(layer_attns) == len(in_out),       'layer_attns and in_out must be the same size'
        assert len(layer_cross_attns) == len(in_out), 'layer_cross_attns and in_out must be the same size'
        
        params   = [in_out, num_resnet_blocks, resnet_groups, layer_attns, layer_cross_attns]
        r_params = [reversed(in_out), reversed(num_resnet_blocks), reversed(resnet_groups), reversed(layer_attns), reversed(layer_cross_attns)]
        
        self.downs = nn.ModuleList([])
        self.ups   = nn.ModuleList([])
        
        skip_connect_dims = []
        
        # UNet Encoder ==========================================================================================================================

        for i, ((dim_in, dim_out), resnet_n, groups, layer_attn, layer_cross_attn) in enumerate(zip(*params)):
            
            is_last = i >= (len(in_out) - 1)
            
            layer_cond_dim = cond_dim if layer_cross_attn else None            
            current_dim    = dim_in
            
            skip_connect_dims.append(current_dim)           
             
            # First Resnet
            init_resnet = BM.ResnetLayer(current_dim, current_dim, cond_dim=layer_cond_dim,
                                         time_cond_dim=dim*4, groups=groups, linear_att=False, gca=False)
            
            # Multiples Resnets
            mult_resnet = nn.ModuleList([BM.ResnetLayer(current_dim, current_dim, time_cond_dim=dim*4,
                                                        groups=groups, linear_att=False, gca=use_global_context_attn) for _ in range(resnet_n)])
            
            # Transformer Layer
            if layer_attn:
                print(type(current_dim), type(attn_heads), type(attn_dim_head), type(ff_mult), type('normal'), type(cond_dim))
                
                transformerLayer = BM.TransformerLayer(dim=current_dim, heads=attn_heads, dim_head=attn_dim_head,
                                                       ff_mult=ff_mult, att_type='normal', context_dim=cond_dim)
            else: 
                transformerLayer = Identity()
                
            # Downsample
            if not is_last: 
                downsample = nn.Conv2d(current_dim, dim_out, 4, 2, 1)
            else:
                downsample = Parallel(nn.Conv2d(dim_in, dim_out, 3, padding = 1), nn.Conv2d(dim_in, dim_out, 1))
                
            # Append self.downs for Encoder
            self.downs.append(nn.ModuleList([init_resnet, mult_resnet, transformerLayer, downsample]))
            
        # UNet Bottleneck =======================================================================================================================
        
        mid_dim = dims[-1]
        
        self.mid_block1 = BM.ResnetLayer(mid_dim, mid_dim, cond_dim=cond_dim, time_cond_dim=dim*4, groups=resnet_groups[-1])
        self.mid_attn   = BM.AttentionTypes(mid_dim, heads=attn_heads, dim_head=attn_dim_head)
        self.mid_block2 = BM.ResnetLayer(mid_dim, mid_dim, cond_dim=cond_dim, time_cond_dim=dim*4, groups=resnet_groups[-1])
        
        # UNet Decoder ==========================================================================================================================
            
        for i, ((dim_in, dim_out), resnet_n, groups, layer_attn, layer_cross_attn) in enumerate(zip(*r_params)):
            
            is_last = i == (len(in_out) - 1)
            
            layer_cond_dim = cond_dim if layer_cross_attn else None           
            
            
            skip_connect_dim = skip_connect_dims.pop()
            # First Resnet
            # ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = False, time_cond_dim = time_cond_dim, groups = groups)
            init_resnet = BM.ResnetLayer(dim_out + skip_connect_dim, dim_out, cond_dim=layer_cond_dim,
                                         time_cond_dim=dim*4, groups=groups, linear_att=False, gca=False)
            
            # Multiples Resnets
            # nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim,
            #                            groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)])
            mult_resnet = nn.ModuleList([BM.ResnetLayer(dim_out + skip_connect_dim, dim_out, time_cond_dim=dim*4,
                                                        groups=groups, linear_att=False, gca=use_global_context_attn) for _ in range(resnet_n)])
            
            # Transformer Layer
            # transformer_block_klass(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, ff_mult = ff_mult, context_dim = cond_dim)
            # transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)
            if layer_attn:
                print(type(current_dim), type(attn_heads), type(attn_dim_head), type(ff_mult), type('normal'), type(cond_dim))
                
                transformerLayer = BM.TransformerLayer(dim=dim_out, heads=attn_heads, dim_head=attn_dim_head,
                                                       ff_mult=ff_mult, att_type='normal', context_dim=cond_dim)
            else: 
                transformerLayer = Identity()
                
            # Upsample
            # Upsample(dim_out, dim_in) if not is_last or memory_efficient else Identity()
            if not is_last: 
                upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                                         nn.Conv2d(dim_out, dim_in, 3, padding=1))
            else:
                upsample = Identity()               

            # Append self.ups for Decoder
            self.ups.append(nn.ModuleList([init_resnet, mult_resnet, transformerLayer, upsample]))
            
        # Final Layers ==========================================================================================================================

        self.final_resnet = BM.ResnetLayer(dim, dim, time_cond_dim=dim*4, groups=resnet_groups[0], linear_att=False, gca=True)
        self.final_conv = nn.Conv2d(dim, channels, 3, padding=1)
        
    def forward(self, x, time, text_embeds, text_mask):
        
        # Time Conditioning        
        time_tokens, t = self.time_cond(time)
        print(time_tokens.shape, t.shape)

        # Text Conditioning   
        text_tokens, text_hiddens = self.text_cond(text_embeds, text_mask)
        print(text_tokens.shape, text_hiddens.shape)
        
        
        # Concatenating Time and Text

        c = time_tokens if text_tokens is None else torch.cat((time_tokens, text_tokens), dim = -2)
        c = self.norm_cond(c)
        
        t = t + text_hiddens
        
        # Processing Image
        
        x = self.init_conv(x)
        
        # Encoder
        
        hiddens = []
        i = 0
        for init_resnet, mult_resnet, transformerLayer, downsample in self.downs:
            print(i)
            x = init_resnet(x, t, c)

            for resnet in mult_resnet:
                x = resnet(x, t)
                hiddens.append(x)

            print(x.shape, c.shape)
                
            x = transformerLayer(x, c)
            hiddens.append(x)

            x = downsample(x)
            i += 1
            
        print(x.shape)
            
        x = self.mid_block1(x, t, c)
        
        w = x.shape[-1]
        
        x = rearrange(x, 'b c h w -> b (h w) c')        
        x = self.mid_attn(x) + x
        x = rearrange(x, 'b (h w) c -> b c h w',w=w)

        x = self.mid_block2(x, t, c)
        
        print(x.shape)
        
        # add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)

        for init_resnet, mult_resnet, transformerLayer, upsample in self.ups:
            pp = hiddens.pop()
            print('pp', pp.shape)
            print('x', x.shape)
            x = torch.cat((x, pp * (2 ** -0.5)), dim = 1) # add_skip_connection(x)
            
            x = init_resnet(x, t, c)

            for resnet in mult_resnet:
                x = torch.cat((x, hiddens.pop() * (2 ** -0.5)), dim = 1) # add_skip_connection(x)
                x = resnet(x, t)

            x = transformerLayer(x, c)
            x = upsample(x)
            
        x = self.final_resnet(x)
        print('x1', x.shape)
        x = self.final_conv(x)
        print('x2', x.shape)
        
        return x, c, t
        

In [114]:
3//2

1

In [125]:
u = Unet(
    dim = 8,
    cond_dim = 40,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)
# print(type(current_dim), type(attn_heads), type(attn_dim_head), type(ff_mult), type('normal'), ctype(cond_dim))

<class 'int'> <class 'int'> <class 'int'> <class 'float'> <class 'str'> <class 'int'>
<class 'int'> <class 'int'> <class 'int'> <class 'float'> <class 'str'> <class 'int'>
<class 'int'> <class 'int'> <class 'int'> <class 'float'> <class 'str'> <class 'int'>
<class 'int'> <class 'int'> <class 'int'> <class 'float'> <class 'str'> <class 'int'>
<class 'int'> <class 'int'> <class 'int'> <class 'float'> <class 'str'> <class 'int'>
<class 'int'> <class 'int'> <class 'int'> <class 'float'> <class 'str'> <class 'int'>


In [126]:
x           = torch.randn(4, 3, 32, 32)
text_embeds = torch.randn(4, 32, 40)
text_masks  = torch.ones(4, 32).bool()
time        = torch.rand(4)

In [127]:
x, c, t = u(x, time, text_embeds, text_masks)

torch.Size([4, 2, 40]) torch.Size([4, 32])
torch.Size([4, 36, 40]) torch.Size([4, 32])
0
torch.Size([4, 8, 32, 32]) torch.Size([4, 38, 40])
1
torch.Size([4, 8, 16, 16]) torch.Size([4, 38, 40])
2
torch.Size([4, 16, 8, 8]) torch.Size([4, 38, 40])
3
torch.Size([4, 32, 4, 4]) torch.Size([4, 38, 40])
torch.Size([4, 64, 4, 4])
torch.Size([4, 64, 4, 4])
pp torch.Size([4, 32, 4, 4])
x torch.Size([4, 64, 4, 4])
pp torch.Size([4, 16, 8, 8])
x torch.Size([4, 32, 8, 8])
pp torch.Size([4, 8, 16, 16])
x torch.Size([4, 16, 16, 16])
pp torch.Size([4, 8, 32, 32])
x torch.Size([4, 8, 32, 32])
x1 torch.Size([4, 8, 32, 32])
x2 torch.Size([4, 3, 32, 32])


In [123]:
print(x.shape)
print(c.shape)
print(t.shape)

torch.Size([4, 3, 32, 32])
torch.Size([4, 38, 40])
torch.Size([4, 32])


In [None]:
dim_mults=(1, 2, 4, 8)
init_dim=32
dim=32

dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
in_out

In [None]:
dims = [dim, *[m*dim for m in dim_mults]]
in_out = list(zip(dims[:-1], dims[1:]))
in_out

In [73]:
in_out = [(1, 10), (2, 20), (30, 300)]
a = [1, 4, 5, 6]
b = [-1, -2, -3]

layer_params = [in_out, a, b]

for i, ((o, oo), a1, b2) in enumerate(zip(*layer_params)):
    print(i, o, oo, a1, b2)

0 1 10 1 -1
1 2 20 4 -2
2 30 300 5 -3


In [74]:
reversed_layer_params = list(map(reversed, layer_params))
reversed_layer_params

[<list_reverseiterator at 0x25502238280>,
 <list_reverseiterator at 0x25502238340>,
 <list_reverseiterator at 0x255022383d0>]

In [75]:
for i, ((o, oo), a1, b2) in enumerate(zip(*reversed_layer_params)):
    print(i, o, oo, a1, b2)

0 30 300 6 -3
1 2 20 5 -2
2 1 10 4 -1


In [72]:
in_out = [(1, 10), (2, 20), (30, 300)]
a = [1, 4, 5, 6]
b = [-1, -2, -3]

layer_params = [reversed(in_out), reversed(a), reversed(b)]

for i, ((o, oo), a1, b2) in enumerate(zip(*layer_params)):
    print(i, o, oo, a1, b2)

0 30 300 6 -3
1 2 20 5 -2
2 1 10 4 -1
