In [1]:
from CombinationFunctions import TimeEmbedding, TextEmbedding
import torch
import torch.nn as nn
import torch.nn.functional as Fn



In [2]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [9]:
# Width shared by every embedding and block in this notebook.
embedDimension = 768

# Build the timestep embedder and place it on the selected device.
tEmbed = TimeEmbedding(embedDimension=embedDimension)
tEmbed.to(device)

# A single timestep (batch of one), created directly on the target device.
time = torch.tensor([1000], device=device)
tout = tEmbed(time)
tout.shape

torch.Size([1, 768])

In [10]:
class AdaptiveLayerNorm(nn.Module):
    """Turn a conditioning embedding into six modulation vectors.

    The input is passed through SiLU and a linear layer that widens it to
    ``6 * embedDimension``. The result is viewed as six rows per sample,
    added to a learned ``(6, embedDimension)`` table and handed back as six
    separate ``(batch, embedDimension)`` tensors.

    Both the linear layer and the table are initialised to zero, so a freshly
    constructed module returns all-zero vectors.
    """

    def __init__(self, embedDimension):
        super().__init__()
        self.embedDimension = embedDimension

        projection = nn.Linear(embedDimension, 6 * embedDimension)
        self.adaLN = nn.Sequential(nn.SiLU(), projection)
        self.scaleShiftParameters = nn.Parameter(torch.zeros(6, embedDimension))

        # Zero-initialise the projection (index 1 of the Sequential).
        for tensor in (projection.weight, projection.bias):
            nn.init.zeros_(tensor)

    def forward(self, t):
        """Return ``(gamma_msa, beta_msa, alpha_msa, gamma_mlp, beta_mlp, alpha_mlp)``.

        ``t`` must be two-dimensional (the shape is unpacked into two values).
        """
        batchSize, _width = t.shape

        # (batch, 6 * dim) -> (batch, 6, dim), then add the learned table.
        modulation = self.adaLN(t).reshape(batchSize, 6, -1)
        modulation = self.scaleShiftParameters[None] + modulation

        # Splitting along dim 1 yields six (batch, dim) tensors.
        gamma_msa, beta_msa, alpha_msa, gamma_mlp, beta_mlp, alpha_mlp = modulation.unbind(dim=1)
        return gamma_msa, beta_msa, alpha_msa, gamma_mlp, beta_mlp, alpha_mlp
    
# `tout` was produced on `device`, so the module has to live there as well;
# otherwise this cell raises a device-mismatch error on a CUDA machine.
adaNorm = AdaptiveLayerNorm(embedDimension).to(device)
g1, b1, a1, g2, b2, a2 = adaNorm(tout)
g1.shape, b1.shape, a1.shape, g2.shape, b2.shape, a2.shape

(torch.Size([1, 768]),
 torch.Size([1, 768]),
 torch.Size([1, 768]),
 torch.Size([1, 768]),
 torch.Size([1, 768]),
 torch.Size([1, 768]))

In [11]:
def shiftModulate(x, scale, shift):
    """Apply ``x * (1 + scale) + shift`` with ``scale``/``shift`` broadcast over dim 1.

    A singleton axis is inserted at position 1 of ``scale`` and ``shift``
    before they are combined with ``x``.
    """
    scaleExpanded = scale.unsqueeze(1)
    shiftExpanded = shift.unsqueeze(1)
    scaled = x * (1 + scaleExpanded)
    return scaled + shiftExpanded

class ScaleShiftBlock(nn.Module):
    """Parameter-free LayerNorm followed by a scale-and-shift modulation.

    Note the argument order of ``forward``: ``beta`` (used as the shift) comes
    before ``gamma`` (used as the scale).
    """

    def __init__(self, embedDimension):
        super().__init__()
        self.embedDimension = embedDimension
        # No learnable affine here; scale and shift are supplied per call.
        self.norm = nn.LayerNorm(embedDimension, eps=1e-6, elementwise_affine=False)

    def forward(self, x, beta, gamma):
        """Normalise ``x`` and return ``norm(x) * (1 + gamma) + beta``."""
        # Unpacking three values rejects inputs that are not 3-dimensional.
        _batch, _tokens, _width = x.shape
        return shiftModulate(self.norm(x), gamma, beta)
    
# Create the latents on the same device as the modulation vectors, which were
# derived from the `device`-resident time embedding.
patchify_latents = torch.randn(1, 16, embedDimension, device=device)
scShft = ScaleShiftBlock(embedDimension)
# ScaleShiftBlock.forward is declared as (x, beta, gamma). Passing by keyword
# keeps g1 as the scale and b1 as the shift instead of silently swapping them.
out = scShft(patchify_latents, beta=b1, gamma=g1)
out.shape

torch.Size([1, 16, 768])

In [12]:
def scaleModulate(x, scale):
    """Multiply ``x`` by ``1 + scale`` with ``scale`` broadcast over dim 1.

    A singleton axis is inserted at position 1 of ``scale`` before the product.
    """
    gain = 1 + scale.unsqueeze(1)
    return x * gain

class ScaleBlock(nn.Module):
    """Parameter-free LayerNorm followed by a multiplicative ``(1 + alpha)`` scale."""

    def __init__(self, embedDimension):
        super().__init__()
        self.embedDimension = embedDimension
        # No learnable affine here; the scale is supplied per call.
        self.norm = nn.LayerNorm(embedDimension, eps=1e-6, elementwise_affine=False)

    def forward(self, x, alpha):
        """Normalise ``x`` and return ``norm(x) * (1 + alpha)``."""
        # Unpacking three values rejects inputs that are not 3-dimensional.
        _batch, _tokens, _width = x.shape
        return scaleModulate(self.norm(x), alpha)
    
# Create the latents on `device` so they can be combined with `a1`, which was
# derived from the `device`-resident time embedding.
patchify_latents = torch.randn(1, 16, embedDimension, device=device)
scShft = ScaleBlock(embedDimension)
out = scShft(patchify_latents, a1)
out.shape

torch.Size([1, 16, 768])

In [13]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, embedDimension)`` tensor.

    Args:
        embedDimension: Width of the input and output features. Must be
            divisible by ``numHeads``.
        numHeads: Number of attention heads.
        dropout: Dropout probability applied to the projected output.

    Raises:
        AssertionError: If ``embedDimension`` is not divisible by ``numHeads``.
    """

    def __init__(self, embedDimension, numHeads, dropout = 0.2):
        super().__init__()

        assert embedDimension%numHeads == 0, "Embedding Dimension is Not Divisible By NumHeads"
        self.embedDimension = embedDimension
        self.numHeads = numHeads
        self.headDim = embedDimension//numHeads

        # One fused projection produces queries, keys and values.
        self.queryKeyValue = nn.Linear(embedDimension, embedDimension * 3, bias=False)
        self.drop = nn.Dropout(dropout)
        self.scale = self.headDim ** -0.5
        self.outProjection = nn.Linear(embedDimension, embedDimension)

    def forward(self, x):
        """Attend every token to every other token and return a same-shaped tensor."""
        BatchSize, N, EmbedDim = x.shape

        qkv = self.queryKeyValue(x)
        qkv = qkv.reshape(BatchSize, N, 3, self.numHeads, self.headDim)
        # Move the head axis in front of the token axis so each of q, k, v is
        # (batch, heads, tokens, headDim). Without this the matmul below would
        # mix heads within a token instead of tokens within a head.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        attentionScore = (q @ k.transpose(-2, -1)) * self.scale  # (batch, heads, tokens, tokens)
        attn = attentionScore.softmax(dim=-1)
        out = attn @ v  # (batch, heads, tokens, headDim)

        # Back to (batch, tokens, heads, headDim), then merge the heads.
        out = out.transpose(1, 2).reshape(BatchSize, N, EmbedDim)
        out = self.outProjection(out)
        out = self.drop(out)
        return out
    
# Smoke test: a batch of one with 16 tokens of width 768 through self-attention.
input = torch.randn((1, 16, 768))
msa = MultiHeadSelfAttention(768, 8)
out = msa(input)
out.shape

torch.Size([1, 16, 768])

In [14]:
class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from ``x``, keys and values from ``textCondition``.

    Args:
        embedDimension: Feature width of both inputs and of the output. Must
            be divisible by ``numHeads``.
        numHeads: Number of attention heads.
        dropout: Dropout probability applied to the projected output.

    Raises:
        AssertionError: If ``embedDimension`` is not divisible by ``numHeads``.

    Note:
        The previously declared ``kv`` linear layer was never used in
        ``forward`` and has been removed. It only added parameters that never
        received gradients. Checkpoints saved with the old layout contain
        extra ``kv.*`` entries that must be dropped (or loaded with
        ``strict=False``).
    """

    def __init__(self, embedDimension, numHeads, dropout = 0.2):
        super().__init__()

        assert embedDimension%numHeads == 0, "Embedding Dimension is Not Divisible By NumHeads"
        self.embedDimension = embedDimension
        self.numHeads = numHeads
        self.headDim = embedDimension//numHeads

        self.q = nn.Linear(embedDimension, embedDimension)
        self.k = nn.Linear(embedDimension, embedDimension)
        self.v = nn.Linear(embedDimension, embedDimension)

        self.outProjection = nn.Linear(embedDimension, embedDimension)
        self.drop = nn.Dropout(dropout)
        self.scale = self.headDim ** -0.5

    def forward(self, x, textCondition):
        """Return a tensor shaped like ``x`` in which every token of ``x`` has attended to ``textCondition``.

        ``x`` is ``(batch, Nimg, embedDimension)`` and ``textCondition`` is
        ``(textBatch, Ntext, embedDimension)``. The two batch sizes normally
        match; a ``textCondition`` with batch size 1 is broadcast across the
        batch of ``x``.
        """
        # Keep the two shapes in separate names so the text tensor's batch
        # size does not overwrite the one used to reshape the queries.
        batch, Nimg, embedDim = x.shape
        textBatch, Ntext, _ = textCondition.shape

        q = self.q(x)
        k = self.k(textCondition)
        v = self.v(textCondition)

        # Split the feature axis into heads: (batch, heads, tokens, headDim).
        q = q.view(batch, Nimg, self.numHeads, self.headDim).transpose(1, 2)
        k = k.view(textBatch, Ntext, self.numHeads, self.headDim).transpose(1, 2)
        v = v.view(textBatch, Ntext, self.numHeads, self.headDim).transpose(1, 2)

        attn_scores = (q @ k.transpose(-2, -1)) * self.scale  # (batch, heads, Nimg, Ntext)
        attn_probs = Fn.softmax(attn_scores, dim=-1)

        out = attn_probs @ v
        # Merge the heads back into a single feature axis.
        out = out.transpose(1, 2).contiguous().view(batch, Nimg, embedDim)
        out = self.outProjection(out)
        out = self.drop(out)

        return out


# Smoke test: cross-attend 16 random latent tokens to the embedding of one prompt.
text = ["Generate an Image of a Dog Eating"]
textModel = TextEmbedding()
textembed = textModel(text)
xinp = torch.randn((1, 16, 768))

mca = MultiHeadCrossAttention(embedDimension=embedDimension, numHeads=8)
out = mca(x=xinp, textCondition=textembed)
out.shape

torch.Size([1, 16, 768])

In [15]:
class FeedForwardBlock(nn.Module):
    """Two-layer feed-forward network: expand to 4x the width, apply GELU, project back."""

    def __init__(self, embedDimension):
        super().__init__()

        hiddenDimension = embedDimension * 4
        self.linear1 = nn.Linear(embedDimension, hiddenDimension)
        self.linear2 = nn.Linear(hiddenDimension, embedDimension)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Return ``linear2(gelu(linear1(x)))``."""
        hidden = self.gelu(self.linear1(x))
        return self.linear2(hidden)
    
# Smoke test: the feed-forward block keeps the (1, 16, 768) shape.
latents = torch.randn((1, 16, 768))
ff = FeedForwardBlock(embedDimension=embedDimension)
out = ff(latents)
out.shape

torch.Size([1, 16, 768])

In [16]:
class DiTModule(nn.Module):
    """One transformer block: modulated self-attention, cross-attention, modulated MLP.

    Args:
        embedDimension: Feature width of the latent tokens.
        numHeads: Number of heads for both attention layers.
        dropout: Dropout probability forwarded to both attention layers.
    """

    def __init__(self, embedDimension, numHeads, dropout = 0.2):
        super().__init__()
        self.embedDimension = embedDimension
        self.numHeads = numHeads

        self.scaleShift1 = ScaleShiftBlock(embedDimension)
        self.multiHeadselfAtten = MultiHeadSelfAttention(embedDimension, numHeads, dropout)
        self.scale1 = ScaleBlock(embedDimension)

        self.multiHeadcrossAtten = MultiHeadCrossAttention(embedDimension, numHeads, dropout)

        self.scaleShift2 = ScaleShiftBlock(embedDimension)
        self.pointwiseFeedForward = FeedForwardBlock(embedDimension)
        self.scale2 = ScaleBlock(embedDimension)

    def forward(self, imageLatents, textEmbeddings, sharedParameters):
        """Run the block and return a tensor shaped like ``imageLatents``.

        Args:
            imageLatents: Latent tokens, ``(batch, tokens, embedDimension)``.
            textEmbeddings: Conditioning tokens consumed by cross-attention.
            sharedParameters: Six-element sequence
                ``(gamma1, beta1, alpha1, gamma2, beta2, alpha2)``.
        """
        gamma1, beta1, alpha1, gamma2, beta2, alpha2 = sharedParameters

        x = imageLatents

        # Self-attention branch with a residual connection.
        # ScaleShiftBlock.forward is declared as (x, beta, gamma), so the
        # modulation vectors are passed by keyword. Passing them positionally
        # as (gamma, beta) would use gamma as the shift and beta as the scale.
        scaleShiftOut1 = self.scaleShift1(x, beta=beta1, gamma=gamma1)
        selfAttnOut = self.multiHeadselfAtten(scaleShiftOut1)
        scaleOut1 = self.scale1(selfAttnOut, alpha1)

        x =  x + scaleOut1

        # Cross-attention to the text tokens, again with a residual connection.
        y = self.multiHeadcrossAtten(x, textEmbeddings)

        y = y + x

        # Feed-forward branch with a residual connection.
        scaleShiftOut2 = self.scaleShift2(y, beta=beta2, gamma=gamma2)
        mlpOut = self.pointwiseFeedForward(scaleShiftOut2)
        scaleOut2 = self.scale2(mlpOut, alpha2)

        z = y + scaleOut2

        return z
    

# Smoke test for a single DiTModule.
# `tout` lives on `device`, so every module and tensor combined with it is
# placed there too; otherwise this cell fails on a CUDA machine.
text = ["Generate an Image of a Dog Eating"]
textModel = TextEmbedding()
# The text embedding is consumed as a tensor by cross-attention; move it to
# `device` in case the text model produced it elsewhere.
textembed = textModel(text).to(device)
noisedlatents = torch.randn(1, 16, embedDimension, device=device)
adaNorm = AdaptiveLayerNorm(embedDimension).to(device)
sharedParameters = adaNorm(tout)

dit = DiTModule(embedDimension, numHeads=12).to(device)

out = dit(noisedlatents, textembed, sharedParameters)
out.shape

torch.Size([1, 16, 768])

In [17]:
class NDiTModule(nn.Module):
    """A stack of ``blocks`` DiTModule layers driven by one shared AdaptiveLayerNorm.

    The modulation vectors are computed once per forward pass from the time
    embedding and handed unchanged to every block. A parameter-free LayerNorm
    is applied to the output of the last block.
    """

    def __init__(self, blocks, embedDimension, numHeads, dropout = 0.2):
        super().__init__()

        self.adaNorm = AdaptiveLayerNorm(embedDimension)

        self.eachBlocks = nn.ModuleList()
        for _ in range(blocks):
            self.eachBlocks.append(DiTModule(embedDimension, numHeads, dropout))

        self.finalNorm = nn.LayerNorm(embedDimension, eps=1e-6, elementwise_affine=False)

    def forward(self, imageLatents, textEmbeddings, timeEmbeddings):
        """Run all blocks in order and return the normalised result."""
        modulation = self.adaNorm(timeEmbeddings)

        hidden = imageLatents
        for layer in self.eachBlocks:
            hidden = layer(hidden, textEmbeddings, modulation)

        return self.finalNorm(hidden)
        
# Smoke test for the full 12-block stack.
# The time embedding is computed on `device`, so the model and its other
# inputs are placed there as well; otherwise this cell fails on a CUDA machine.
nDit = NDiTModule(blocks=12, embedDimension=embedDimension, numHeads=12, dropout=0.2).to(device)
tEmbed = TimeEmbedding(embedDimension=embedDimension)
tEmbed.to(device)
time = torch.tensor([1000]).to(device)
tout = tEmbed(time)


text = ["Generate an Image of a Dog Eating"]
textModel = TextEmbedding()
# The text embedding is consumed as a tensor by cross-attention; move it to
# `device` in case the text model produced it elsewhere.
textembed = textModel(text).to(device)
noisedlatents = torch.randn(1, 16, embedDimension, device=device)

out = nDit(noisedlatents, textembed, tout)
out.shape

torch.Size([1, 16, 768])