# Attention Recap

In [1]:
import torch
import torch.nn as nn
# Seed the global RNG so the random scores below are reproducible.
torch.manual_seed(0)

# Random stand-ins for the raw Q.K^T attention scores and for the values.
QKT = torch.rand(5, 5)
v = torch.rand(5, 1)  # not used by the rest of this cell

# Add a leading batch dimension: (5, 5) -> (1, 5, 5).
attmap = QKT[None]
print(attmap.shape)

# Additive causal mask with the same shape as the score map: triu(1) keeps
# -inf strictly above the main diagonal and zeroes everything else.
mask = torch.full(tuple(attmap.shape), float('-inf')).triu(diagonal=1)
print(mask.shape)
print(mask)

torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
tensor([[[0., -inf, -inf, -inf, -inf],
         [0., 0., -inf, -inf, -inf],
         [0., 0., 0., -inf, -inf],
         [0., 0., 0., 0., -inf],
         [0., 0., 0., 0., 0.]]])


In [16]:
# Apply the additive mask: entries above the diagonal become -inf, the rest
# keep their raw score.
masked_attmap = mask + attmap
print(masked_attmap)

# Normalise each row over the last axis; the -inf entries come out as 0.
masked_attmap = masked_attmap.softmax(dim=-1)
print(masked_attmap)

tensor([[[0.4963,   -inf,   -inf,   -inf,   -inf],
         [0.6341, 0.4901,   -inf,   -inf,   -inf],
         [0.3489, 0.4017, 0.0223,   -inf,   -inf],
         [0.5185, 0.6977, 0.8000, 0.1610,   -inf],
         [0.6816, 0.9152, 0.3971, 0.8742, 0.4194]]])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5359, 0.4641, 0.0000, 0.0000, 0.0000],
         [0.3603, 0.3798, 0.2599, 0.0000, 0.0000],
         [0.2369, 0.2834, 0.3140, 0.1657, 0.0000],
         [0.2001, 0.2528, 0.1506, 0.2426, 0.1540]]])


In [23]:
# Build every layer with its own constructor call. `[nn.Linear(10, 10)] * 2`
# would repeat a reference to ONE module object, so both Sequential slots
# would share the same weight and bias tensors.
temp = [nn.Conv2d(3, 3, 4)] + [nn.Linear(10, 10) for _ in range(2)]
# Last expression of the cell: displays the container's repr.
nn.Sequential(*temp)

Sequential(
  (0): Conv2d(3, 3, kernel_size=(4, 4), stride=(1, 1))
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Linear(in_features=10, out_features=10, bias=True)
)

# VAE

In [2]:
# Random stand-in tensor with 8 channels on a 64x64 grid, batch of 2.
h = torch.rand(2, 8, 64, 64)

# Split the channel axis (dim 1) into two equal halves of 4 channels each;
# both halves are views into `h`.
h_mean_part, h_log_std_part = h.chunk(2, dim=1)
print(h_mean_part.shape, h_log_std_part.shape)

torch.Size([2, 4, 64, 64]) torch.Size([2, 4, 64, 64])


In [4]:
from image_encoder import Resnet, Pad, Atten

In [6]:
# --- Encoder, assembled stage by stage so every output shape can be traced ---

# Stem: 3 input channels -> 128 feature channels (3x3 conv, padding=1).
in_layer = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1)

# down_1 .. down_3: two Resnet blocks followed by Pad + an unpadded stride-2
# 3x3 conv. The trace printed by this cell shows each of them halving H and W.
down_1 = torch.nn.Sequential(
    Resnet(128, 128),
    Resnet(128, 128),
    torch.nn.Sequential(
        Pad(),
        torch.nn.Conv2d(128, 128, 3, stride=2, padding=0),
    ),
)
down_2 = torch.nn.Sequential(
    Resnet(128, 256),
    Resnet(256, 256),
    torch.nn.Sequential(
        Pad(),
        torch.nn.Conv2d(256, 256, 3, stride=2, padding=0),
    ),
)
down_3 = torch.nn.Sequential(
    Resnet(256, 512),
    Resnet(512, 512),
    torch.nn.Sequential(
        Pad(),
        torch.nn.Conv2d(512, 512, 3, stride=2, padding=0),
    ),
)

# down_4: two more Resnet blocks, no downsampling conv.
down_4 = torch.nn.Sequential(
    Resnet(512, 512),
    Resnet(512, 512),
)

# Middle: Resnet -> Atten -> Resnet at 512 channels.
mid = torch.nn.Sequential(
    Resnet(512, 512),
    Atten(),
    Resnet(512, 512),
)

# Output head: GroupNorm -> SiLU -> 3x3 conv down to 8 channels ...
out_1 = torch.nn.Sequential(
    torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6),
    torch.nn.SiLU(),
    torch.nn.Conv2d(512, 8, 3, padding=1),
)

# ... followed by a 1x1 conv that keeps the 8 channels.
out_2 = torch.nn.Conv2d(8, 8, 1)

# Push a random batch through every stage and print the shape after each one.
# The pass only inspects shapes, so it runs under no_grad(): no autograd graph
# is recorded for the 512x512 activations.
input = torch.rand(2, 3, 512, 512)
print('input dim:', input.shape)
out = input
with torch.no_grad():
    for stage_name, stage in (
        ('input_layer', in_layer),
        ('down_1', down_1),
        ('down_2', down_2),
        ('down_3', down_3),
        ('down_4', down_4),
        ('mid', mid),
        ('out_1', out_1),
        ('out_2', out_2),  # this line used to be printed as "after out_1:"
    ):
        out = stage(out)
        print(f'after {stage_name}:', out.shape)

input dim: torch.Size([2, 3, 512, 512])
after input_layer: torch.Size([2, 128, 512, 512])
after down_1: torch.Size([2, 128, 256, 256])
after down_2: torch.Size([2, 256, 128, 128])
after down_3: torch.Size([2, 512, 64, 64])
after down_4: torch.Size([2, 512, 64, 64])
after mid: torch.Size([2, 512, 64, 64])
after out_1: torch.Size([2, 8, 64, 64])
after out_1: torch.Size([2, 8, 64, 64])


In [7]:
# Normal (reparameterisation) transformation: h = mean + std * noise, with
# noise drawn from a standard normal.
#
# The two halves are taken from `h_mean_part` / `h_log_std_part` (the slices
# made right after `h` was created) instead of slicing `h` here: `h` is rebound
# to the 4-channel sample at the end of this cell, so slicing it again on a
# second run would yield an empty second half and a shape error.
# [2, 4, 64, 64]
mean = h_mean_part
# [2, 4, 64, 64] -- consumed as a log-variance here, despite the `log_std` name
logvar = h_log_std_part
# sqrt(exp(logvar)) written as exp(logvar / 2): same quantity, but the
# intermediate exp() overflows far later for large logvar.
std = torch.exp(0.5 * logvar)

# [2, 4, 64, 64]
noise = torch.randn(mean.shape, device=mean.device)
h = mean + std * noise

In [12]:
def _up_stage(dim_in, dim_out, upsample=True):
    """Three Resnet blocks, optionally followed by 2x nearest upsampling and a 3x3 conv."""
    layers = [
        Resnet(dim_in, dim_out),
        Resnet(dim_out, dim_out),
        Resnet(dim_out, dim_out),
    ]
    if upsample:
        layers.append(torch.nn.Upsample(scale_factor=2.0, mode='nearest'))
        layers.append(torch.nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1))
    return torch.nn.Sequential(*layers)


# Input: 1x1 conv on the 4 input channels, then a 3x3 conv up to 512 channels.
in_1 = torch.nn.Conv2d(4, 4, kernel_size=1)
in_2 = torch.nn.Conv2d(4, 512, 3, stride=1, padding=1)

# Middle: Resnet -> Atten -> Resnet at 512 channels.
middle_1 = torch.nn.Sequential(
    Resnet(512, 512),
    Atten(),
    Resnet(512, 512),
)

# Up path: the first three stages end with upsample + conv, the last does not.
up_1 = _up_stage(512, 512)
up_2 = _up_stage(512, 512)
up_3 = _up_stage(512, 256)
up_4 = _up_stage(256, 128, upsample=False)

# Output head: GroupNorm -> SiLU -> 3x3 conv down to 3 channels.
out_1 = torch.nn.Sequential(
    torch.nn.GroupNorm(32, 128, eps=1e-6),
    torch.nn.SiLU(),
    torch.nn.Conv2d(128, 3, kernel_size=3, padding=1),
)

In [13]:
# Push `h` through the decoder stage by stage and print the shape after each.
# The pass only inspects shapes, so it runs under no_grad(): no autograd graph
# is recorded for the upsampled activations.
input = h
print('input dim:', input.shape)
out = input
with torch.no_grad():
    for stage_name, stage in (
        ('in_1', in_1),
        ('in_2', in_2),
        ('middle_1', middle_1),
        ('up_1', up_1),
        ('up_2', up_2),
        ('up_3', up_3),
        ('up_4', up_4),
        ('out_1', out_1),
    ):
        out = stage(out)
        print(f'after {stage_name}:', out.shape)


input dim: torch.Size([2, 4, 64, 64])
after in_1: torch.Size([2, 4, 64, 64])
after in_2: torch.Size([2, 512, 64, 64])
after middle_1: torch.Size([2, 512, 64, 64])
after up_1: torch.Size([2, 512, 128, 128])
after up_2: torch.Size([2, 512, 256, 256])
after up_3: torch.Size([2, 256, 512, 512])
after up_4: torch.Size([2, 128, 512, 512])
after out_1: torch.Size([2, 3, 512, 512])


# Unet

In [3]:
from unet import DownBlock
import torch

In [11]:
# Down block with 320 channels in and 320 channels out.
down_0 = DownBlock(320, 320)

# Random stand-in inputs, batch of 2.
z_img = torch.rand((2, 320, 64, 64))   # feature map: 320 channels, 64x64
z_text = torch.rand((2, 77, 768))      # presumably the text conditioning -- confirm against DownBlock
z_time = torch.rand((2, 1280))         # presumably the time-step embedding -- confirm against DownBlock

# The block returns its output plus a collection of saved tensors.
out, save = down_0(z_img, z_text, z_time)
print('input for the next down block:', out.shape)
print('input for the corresponding up block:', [skip.shape for skip in save])

input for the next down block: torch.Size([2, 320, 32, 32])
input for the corresponding up block: [torch.Size([2, 320, 64, 64]), torch.Size([2, 320, 64, 64]), torch.Size([2, 320, 32, 32])]


In [12]:
# Down block with 640 channels in and 1280 channels out.
down_2 = DownBlock(dim_in=640, dim_out=1280)

# Random stand-in inputs, batch of 2.
z_img = torch.rand((2, 640, 16, 16))   # feature map: 640 channels, 16x16
z_text = torch.rand((2, 77, 768))      # presumably the text conditioning -- confirm against DownBlock
z_time = torch.rand((2, 1280))         # presumably the time-step embedding -- confirm against DownBlock

# The block returns its output plus a collection of saved tensors.
out, save = down_2(z_img, z_text, z_time)
print('input for the next down block:', out.shape)
print('input for the corresponding up block:', [skip.shape for skip in save])

input for the next down block: torch.Size([2, 1280, 8, 8])
input for the corresponding up block: [torch.Size([2, 1280, 16, 16]), torch.Size([2, 1280, 16, 16]), torch.Size([2, 1280, 8, 8])]


In [18]:
class Resnet(torch.nn.Module):
    """Residual block conditioned on a time-step embedding.

    The input goes through two GroupNorm -> SiLU -> 3x3 conv stages. The time
    embedding is projected to ``dim_out`` channels and added after the first
    stage, broadcasting over height and width. The block input is added back
    at the end, passing through a 1x1 conv first when the channel count
    changes.

    Args:
        dim_in: Channels of the incoming feature map. GroupNorm uses 32
            groups, so this must be divisible by 32.
        dim_out: Channels of the returned feature map (also divisible by 32).
        dim_time: Width of the time-step embedding. Defaults to 1280, the
            value that used to be hard-coded.
    """

    def __init__(self, dim_in, dim_out, dim_time=1280):
        super().__init__()

        # Time embedding: [N, dim_time] -> [N, dim_out, 1, 1]
        self.time = torch.nn.Sequential(
            torch.nn.SiLU(),
            torch.nn.Linear(dim_time, dim_out),
            torch.nn.Unflatten(dim=1, unflattened_size=(dim_out, 1, 1)),
        )

        # First stage: changes the channel count, keeps H x W
        # (3x3 conv with stride 1 and padding 1).
        self.s0 = torch.nn.Sequential(
            torch.nn.GroupNorm(num_groups=32,
                               num_channels=dim_in,
                               eps=1e-05,
                               affine=True),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_in,
                            dim_out,
                            kernel_size=3,
                            stride=1,
                            padding=1),
        )

        # Second stage: shape-preserving.
        self.s1 = torch.nn.Sequential(
            torch.nn.GroupNorm(num_groups=32,
                               num_channels=dim_out,
                               eps=1e-05,
                               affine=True),
            torch.nn.SiLU(),
            torch.nn.Conv2d(dim_out,
                            dim_out,
                            kernel_size=3,
                            stride=1,
                            padding=1),
        )

        # Skip-path projection, only needed when the channel count changes.
        self.res = None
        if dim_in != dim_out:
            self.res = torch.nn.Conv2d(dim_in,
                                       dim_out,
                                       kernel_size=1,
                                       stride=1,
                                       padding=0)

    def forward(self, x, time):
        '''
        Why use time embedding?
        https://www.reddit.com/r/MachineLearning/comments/101s5kj/r_on_time_embeddings_in_diffusion_models/

        Args:
            x: feature map, [N, dim_in, H, W], e.g. [1, 320, 64, 64]
            time: time step embedding, [N, dim_time], e.g. [1, 1280]

        Returns:
            Tensor of shape [N, dim_out, H, W].

        The two print calls are a shape trace and write to stdout on every
        forward pass.
        '''
        res = x

        # [1, 1280] -> [1, 640, 1, 1]
        time = self.time(time)
        print(time.shape)

        # [1, 320, 64, 64] -> [1, 640, 64, 64]
        # Evaluate the first stage once and reuse the result for both the
        # shape trace and the sum.
        h = self.s0(x)
        print(h.shape)
        x = h + time

        # Shape unchanged
        # [1, 640, 64, 64]
        x = self.s1(x)

        # [1, 320, 64, 64] -> [1, 640, 64, 64]
        if self.res is not None:
            res = self.res(res)

        # Shape unchanged
        # [1, 640, 64, 64]
        return res + x

In [19]:
# Smoke test with a channel-changing configuration (320 -> 640); the forward
# pass prints the intermediate shapes itself.
res = Resnet(dim_in=320, dim_out=640)
ans = res(x=torch.rand(1, 320, 64, 64), time=torch.rand(1, 1280))

torch.Size([1, 640, 1, 1])
torch.Size([1, 640, 64, 64])
