In [1]:
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, DiffusionPipeline

# model_name = "runwayml/stable-diffusion-v1-5"
# encoder=CLIPTextModel.from_pretrained(model_name,subfolder="text_encoder")
# vae = AutoencoderKL.from_pretrained(model_name, subfolder='vae')
# unet=UNet2DConditionModel.from_pretrained(model_name, subfolder='unet')
# scheduler=DDPMScheduler.from_pretrained(model_name, subfolder='scheduler')
# tokenizer=CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
# pipeline = DiffusionPipeline.from_pretrained(model_name)
# vae

In [2]:
class Resnet(nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 Conv) stages plus a skip connection.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels of the output feature map.
        num_groups: number of groups for both GroupNorm layers (default 32).
            Both ``in_channels`` and ``out_channels`` must be divisible by it.
        eps: epsilon passed to both GroupNorm layers (default 1e-6).

    Attributes:
        layers: ``nn.Sequential`` holding norm/act/conv/norm/act/conv at
            indices 0..5 (the index layout is relied on when loading weights).
        shortcut: 1x1 convolution used on the skip path when the channel count
            changes, otherwise ``None``.
    """

    def __init__(self, in_channels, out_channels, num_groups=32, eps=1e-6):
        super().__init__()
        self.layers = nn.Sequential(
            nn.GroupNorm(num_groups, in_channels, eps=eps),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.GroupNorm(num_groups, out_channels, eps=eps),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        # When the channel count changes, the skip path needs a 1x1 projection
        # so that it can be added to the output of `layers`.
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        """Return ``skip(x) + layers(x)``; spatial size is preserved."""
        res = x
        # Explicit None check instead of relying on the truthiness of a module.
        if self.shortcut is not None:
            res = self.shortcut(x)
        return res + self.layers(x)

In [3]:
# x=torch.tensor(torch.arange(16).reshape(1,2,2,4),dtype=torch.float32)
# x.flatten(start_dim=2) == x.reshape(1,2,2*4)

In [4]:
# Attention score scale for 512 channels: (1 / 512) ** 0.5, i.e. 1 / sqrt(512).
(1 / 512)**0.5

0.04419417382415922

In [5]:
class Attention(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, every spatial position becomes one token of
    size ``channels``, scaled dot-product attention is applied across tokens,
    and the result is added back to the input (residual connection).

    Args:
        channels: number of channels of the input feature map (default 512).
        num_groups: number of groups for the GroupNorm layer (default 32).

    Attributes:
        norm, q, k, v, out: the learnable layers; these names are relied on
            when pretrained weights are copied in.
    """

    def __init__(self, channels=512, num_groups=32):
        super().__init__()
        self.channels = channels
        # 1 / sqrt(channels); for the default 512 this is ~0.0441941738.
        self.scale = (1 / channels) ** 0.5
        self.norm = nn.GroupNorm(num_groups, channels, eps=1e-6)
        self.q = nn.Linear(channels, channels)
        self.k = nn.Linear(channels, channels)
        self.v = nn.Linear(channels, channels)
        self.out = nn.Linear(channels, channels)

    def forward(self, x):
        """Apply attention to ``x`` of shape (b, channels, h, w); same shape out."""
        res = x
        x = self.norm(x)
        # (b, c, h, w) -> (b, c, h*w) -> (b, h*w, c); flattening the spatial dims
        # generically lets any h, w work (not just 64x64).
        x = x.flatten(start_dim=2).transpose(1, 2)
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        # (b, h*w, c) x (b, c, h*w) -> (b, h*w, h*w), scaled by (1 / c) ** 0.5
        atten = torch.bmm(q, k.transpose(1, 2)) * self.scale
        atten = torch.nn.functional.softmax(atten, dim=2)

        # (b, h*w, h*w) x (b, h*w, c) -> (b, h*w, c)
        atten = torch.bmm(atten, v)
        atten = self.out(atten)

        # (b, h*w, c) -> (b, c, h*w) -> (b, c, h, w)
        atten = atten.transpose(1, 2).reshape(res.shape)
        return res + atten

# Smoke test: x is of shape (batch, 512, 64, 64); the output shape should match it.
Attention()(torch.randn(1, 512, 64, 64)).shape


torch.Size([1, 512, 64, 64])

In [6]:
class Pad(nn.Module):
    """Zero-pad the last two dimensions by one element on the trailing side."""

    def forward(self, x):
        # F.pad order is (last-dim before, last-dim after,
        #                 second-to-last before, second-to-last after).
        trailing_one = (0, 1, 0, 1)
        padded = nn.functional.pad(x, trailing_one, mode='constant', value=0)
        return padded

In [7]:
class VAE(torch.nn.Module):
    """Hand-written convolutional VAE built from Resnet, Attention and Pad blocks.

    ``encoder`` maps a 3-channel image to 8 channels at 1/8 of the spatial
    resolution (three stride-2 convolutions); ``sample`` interprets those
    8 channels as mean / log-variance and draws a 4-channel latent; ``decoder``
    maps the 4-channel latent back to a 3-channel image at 8x the resolution.

    The positions of the sub-modules inside ``encoder`` and ``decoder`` are
    part of the interface: the weight-loading code addresses them by index.
    """

    def __init__(self):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            #in
            torch.nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1),

            #down
            torch.nn.Sequential(
                Resnet(128, 128),
                Resnet(128, 128),
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(128, 128, 3, stride=2, padding=0),
                ),
            ),
            torch.nn.Sequential(
                Resnet(128, 256),
                Resnet(256, 256),
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(256, 256, 3, stride=2, padding=0),
                ),
            ),
            torch.nn.Sequential(
                Resnet(256, 512),
                Resnet(512, 512),
                torch.nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(512, 512, 3, stride=2, padding=0),
                ),
            ),
            # last down stage has no downsampling conv
            torch.nn.Sequential(
                Resnet(512, 512),
                Resnet(512, 512),
            ),

            #mid
            torch.nn.Sequential(
                Resnet(512, 512),
                Attention(),
                Resnet(512, 512),
            ),

            #out
            torch.nn.Sequential(
                torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6),
                torch.nn.SiLU(),
                torch.nn.Conv2d(512, 8, 3, padding=1),
            ),

            #norm distribution layer
            torch.nn.Conv2d(8, 8, 1),
        )

        self.decoder = torch.nn.Sequential(
            #norm distribution layer
            torch.nn.Conv2d(4, 4, 1),

            #in
            torch.nn.Conv2d(4, 512, kernel_size=3, stride=1, padding=1),

            #middle
            torch.nn.Sequential(Resnet(512, 512), Attention(), Resnet(512, 512)),

            #up
            torch.nn.Sequential(
                Resnet(512, 512),
                Resnet(512, 512),
                Resnet(512, 512),
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),
                torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ),
            torch.nn.Sequential(
                Resnet(512, 512),
                Resnet(512, 512),
                Resnet(512, 512),
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),
                torch.nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ),
            torch.nn.Sequential(
                Resnet(512, 256),
                Resnet(256, 256),
                Resnet(256, 256),
                torch.nn.Upsample(scale_factor=2.0, mode='nearest'),
                torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),
            ),
            # last up stage has no upsampling
            torch.nn.Sequential(
                Resnet(256, 128),
                Resnet(128, 128),
                Resnet(128, 128),
            ),

            #out
            torch.nn.Sequential(
                torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6),
                torch.nn.SiLU(),
                torch.nn.Conv2d(128, 3, 3, padding=1),
            ),
        )

    def sample(self, h):
        """Draw a latent with the reparameterization trick.

        Args:
            h: encoder output whose channel dimension (dim 1) holds the mean in
                the first 4 channels and the log-variance in the remaining ones,
                e.g. [1, 8, 64, 64].

        Returns:
            ``mean + std * noise`` with the shape, dtype and device of the mean,
            e.g. [1, 4, 64, 64]. The result is random (standard normal noise).
        """
        mean = h[:, :4]
        logvar = h[:, 4:]

        # exp(0.5 * logvar) equals sqrt(exp(logvar)) but does not overflow to
        # inf for large log-variances the way exp(logvar) ** 0.5 does.
        std = torch.exp(0.5 * logvar)

        # randn_like matches mean's dtype as well as its device, so the result
        # is not silently promoted when the model runs in reduced precision.
        noise = torch.randn_like(mean)
        return mean + std * noise

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it back to image space."""
        #x -> [1, 3, 512, 512]

        #[1, 3, 512, 512] -> [1, 8, 64, 64]
        h = self.encoder(x)

        #[1, 8, 64, 64] -> [1, 4, 64, 64]
        h = self.sample(h)

        #[1, 4, 64, 64] -> [1, 3, 512, 512]
        h = self.decoder(h)

        return h


# VAE()(torch.randn(1, 3, 512, 512)).shape
# import torch
# from torchview import draw_graph

# model = VAE()
# x = torch.randn(1, 3, 512, 512)
# draw_graph(model, x).visual_graph
# from torchsummary import summary
# summary(model, (3, 512, 512))

In [8]:
from diffusers import AutoencoderKL

# Load the pretrained VAE from the 'vae' subfolder of
# 'runwayml/stable-diffusion-v1-5'; it is the source of the weights copied below.
params = AutoencoderKL.from_pretrained(
    'runwayml/stable-diffusion-v1-5', subfolder='vae')

# Freshly initialised hand-written VAE that will receive those weights.
vae = VAE()


def load_res(model, param):
    """Copy the weights of a pretrained residual block into a local one.

    Args:
        model: destination block exposing ``layers`` (indexable; positions
            0, 2, 3 and 5 are loaded) and ``shortcut``.
        param: source block exposing ``norm1``, ``conv1``, ``norm2``, ``conv2``
            and, when ``model.shortcut`` is a module, ``conv_shortcut``.
    """
    # destination index inside model.layers -> attribute name on the source block
    layer_sources = ((0, 'norm1'), (2, 'conv1'), (3, 'norm2'), (5, 'conv2'))
    for position, source_name in layer_sources:
        source_layer = getattr(param, source_name)
        model.layers[position].load_state_dict(source_layer.state_dict())

    # The skip-path projection exists only on some blocks (it is None otherwise).
    skip = model.shortcut
    if isinstance(skip, torch.nn.Module):
        skip.load_state_dict(param.conv_shortcut.state_dict())


def load_atten(model, param):
    """Copy the weights of a pretrained attention block into a local one.

    Args:
        model: destination block exposing ``norm``, ``q``, ``k``, ``v``, ``out``.
        param: source block exposing ``group_norm``, ``to_q``, ``to_k``,
            ``to_v`` and an indexable ``to_out`` whose element 0 is loaded
            into ``model.out``.
    """
    # destination attribute -> source attribute
    name_pairs = (
        ('norm', 'group_norm'),
        ('q', 'to_q'),
        ('k', 'to_k'),
        ('v', 'to_v'),
    )
    for local_name, pretrained_name in name_pairs:
        weights = getattr(param, pretrained_name).state_dict()
        getattr(model, local_name).load_state_dict(weights)

    # Only the first entry of to_out is used.
    model.out.load_state_dict(param.to_out[0].state_dict())


# Copy the pretrained weights into the hand-written VAE, module by module.
# Sub-modules of `vae` are addressed by their position inside the nested
# nn.Sequential containers built in VAE.__init__.

# encoder.in: vae.encoder[0] is the input convolution
vae.encoder[0].load_state_dict(params.encoder.conv_in.state_dict())

# encoder.down: vae.encoder[1..4] are the four down stages, two Resnets each
for i in range(4):
    load_res(vae.encoder[i + 1][0], params.encoder.down_blocks[i].resnets[0])
    load_res(vae.encoder[i + 1][1], params.encoder.down_blocks[i].resnets[1])

    # The last down stage has no downsampler; in the others, index 2 is the
    # (Pad, Conv2d) pair and [1] selects the convolution.
    if i != 3:
        vae.encoder[i + 1][2][1].load_state_dict(
            params.encoder.down_blocks[i].downsamplers[0].conv.state_dict())

# encoder.mid: vae.encoder[5] is (Resnet, Attention, Resnet)
load_res(vae.encoder[5][0], params.encoder.mid_block.resnets[0])
load_res(vae.encoder[5][2], params.encoder.mid_block.resnets[1])
load_atten(vae.encoder[5][1], params.encoder.mid_block.attentions[0])

# encoder.out: vae.encoder[6] is (GroupNorm, SiLU, Conv2d); SiLU has no weights
vae.encoder[6][0].load_state_dict(params.encoder.conv_norm_out.state_dict())
vae.encoder[6][2].load_state_dict(params.encoder.conv_out.state_dict())

# encoder norm distribution layer: the final 1x1 convolution (8 -> 8 channels)
vae.encoder[7].load_state_dict(params.quant_conv.state_dict())

# decoder norm distribution layer: the leading 1x1 convolution (4 -> 4 channels)
vae.decoder[0].load_state_dict(params.post_quant_conv.state_dict())

# decoder in: vae.decoder[1] is the input convolution
vae.decoder[1].load_state_dict(params.decoder.conv_in.state_dict())

# decoder mid: vae.decoder[2] is (Resnet, Attention, Resnet)
load_res(vae.decoder[2][0], params.decoder.mid_block.resnets[0])
load_res(vae.decoder[2][2], params.decoder.mid_block.resnets[1])
load_atten(vae.decoder[2][1], params.decoder.mid_block.attentions[0])

# decoder up: vae.decoder[3..6] are the four up stages, three Resnets each
for i in range(4):
    load_res(vae.decoder[i + 3][0], params.decoder.up_blocks[i].resnets[0])
    load_res(vae.decoder[i + 3][1], params.decoder.up_blocks[i].resnets[1])
    load_res(vae.decoder[i + 3][2], params.decoder.up_blocks[i].resnets[2])

    # The last up stage has no upsampler; in the others, index 3 is the
    # weight-free Upsample and index 4 is the convolution that follows it.
    if i != 3:
        vae.decoder[i + 3][4].load_state_dict(
            params.decoder.up_blocks[i].upsamplers[0].conv.state_dict())

# decoder out: vae.decoder[7] is (GroupNorm, SiLU, Conv2d)
vae.decoder[7][0].load_state_dict(params.decoder.conv_norm_out.state_dict())
# Last expression of the cell: its return value is what the notebook displays.
vae.decoder[7][2].load_state_dict(params.decoder.conv_out.state_dict())

<All keys matched successfully>

In [9]:
# data = torch.randn(1, 3, 512, 512)

# a = vae.encoder(data)
# b = params.encode(data).latent_dist.parameters

# torch.allclose(a, b, atol=1e-5)

True

In [10]:
# data = torch.randn(1, 4, 64, 64)

# a = vae.decoder(data)
# b = params.decode(data).sample

# torch.allclose(a, b, atol=1e-4)

True