In [3]:
import torch.nn as nn
import torch
from diffusers import AutoencoderDC
from diffusers import DDPMScheduler, DDIMScheduler
from CombinationFunctions import PatchEmbedding
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt



In [10]:
class Decoder(nn.Module):
    """Turn predicted-noise tokens into a clean-latent estimate and a decoded image.

    Holds three collaborators built in ``__init__``: a pretrained ``AutoencoderDC``
    (``self.dc_ae``), a ``DDPMScheduler`` (``self.noiseScheduler``) whose
    ``alphas_cumprod`` table is read in ``forward``, and a ``PatchEmbedding``
    (``self.patchEmbedding``) whose ``unPatchify`` maps tokens back to a latent grid.
    """

    def __init__(
        self,
        embedDimension,
        latentSize,
        latentChannels,
        patchSize,
        totalTimestamps=1000,
        beta_schedule="squaredcos_cap_v2",
        modelName="mit-han-lab/dc-ae-f64c128-in-1.0-diffusers",
    ):
        """Build the autoencoder, the noise scheduler and the patch embedding.

        Args:
            embedDimension: forwarded to ``PatchEmbedding(embedDimension=...)``.
            latentSize: forwarded to ``PatchEmbedding(imageSize=...)``.
            latentChannels: forwarded to ``PatchEmbedding(inChannels=...)``.
            patchSize: forwarded to ``PatchEmbedding(patchSize=...)``.
            totalTimestamps: ``num_train_timesteps`` of the ``DDPMScheduler``.
            beta_schedule: ``beta_schedule`` name of the ``DDPMScheduler``.
            modelName: identifier passed to ``AutoencoderDC.from_pretrained``.
        """
        super().__init__()
        # NOTE(review): if AutoencoderDC is an nn.Module, assigning it here registers
        # its weights under self.parameters() -- confirm whether it should be frozen.
        self.dc_ae = AutoencoderDC.from_pretrained(modelName, torch_dtype=torch.float32)
        self.noiseScheduler = DDPMScheduler(num_train_timesteps=totalTimestamps, beta_schedule=beta_schedule)
        self.patchEmbedding = PatchEmbedding(imageSize=latentSize, patchSize=patchSize, inChannels=latentChannels, embedDimension=embedDimension)

    def forward(self, x, timestep, noisyImage):
        """Recover the predicted noise grid and decode the implied clean latents.

        Args:
            x: token tensor handed to ``self.patchEmbedding.unPatchify``; the result
                must broadcast against ``noisyImage``.
            timestep: index into ``alphas_cumprod``. Either a single int / 0-d tensor
                (shared by the whole batch) or a 1-D tensor with one index per sample.
            noisyImage: noisy latents; assumed to be 4-D ``(batch, channels, h, w)``
                so that the per-sample coefficient broadcasts -- TODO confirm.

        Returns:
            ``(predictedNoise, originalImage)`` where ``predictedNoise`` is the
            un-patchified ``x`` and ``originalImage`` is ``dc_ae.decode(...).sample``
            of the clean-latent estimate, rescaled by ``* 0.5 + 0.5``.
        """
        predictedNoise = self.patchEmbedding.unPatchify(x)

        # alphas_cumprod is a plain tensor on the scheduler, so it does not follow
        # module.to(device); move it next to the latents before indexing. The dtype
        # is left untouched so type promotion behaves exactly as before.
        alphasCumprod = self.noiseScheduler.alphas_cumprod.to(noisyImage.device)
        timestepIndex = torch.as_tensor(timestep, dtype=torch.long, device=alphasCumprod.device)
        # (-1, 1, 1, 1) gives (1, 1, 1, 1) for a scalar timestep and (B, 1, 1, 1)
        # for per-sample timesteps, so both broadcast over (B, C, H, W).
        alphaT = alphasCumprod[timestepIndex].reshape(-1, 1, 1, 1)

        # Invert x_t = sqrt(a) * x_0 + sqrt(1 - a) * eps for x_0.
        # NOTE(review): latents are decoded as-is; if training multiplied them by the
        # autoencoder's scaling factor, they must be divided by it here -- confirm.
        originallatents = (noisyImage - torch.sqrt(1 - alphaT) * predictedNoise) / torch.sqrt(alphaT)
        originalImage = self.dc_ae.decode(originallatents).sample
        # Presumably maps a [-1, 1] decoder output to [0, 1]; values are not clamped.
        originalImage = originalImage * 0.5 + 0.5

        return predictedNoise, originalImage
        

# Smoke test: build a Decoder and push random tensors through it to check shapes.
dec = Decoder(
    embedDimension=784,
    latentSize=8,
    latentChannels=128,
    patchSize=2,
)

# Random stand-ins for the noisy latents and for the token sequence fed to forward().
noisyImage = torch.randn(1, 128, 8, 8)
latents = torch.randn(1, 16, 784)

predictedNoise, originalImage = dec(x=latents, timestep=10, noisyImage=noisyImage)

# Last expression of the cell: display both output shapes.
(predictedNoise.shape, originalImage.shape)

(torch.Size([1, 128, 8, 8]), torch.Size([1, 3, 512, 512]))