In [None]:
from typing import List 
import torch 
import torch.nn as nn 
import torch.nn.functional as F
from labml_nn.diffusion.stable_diffusion.model.autaencoder import Decoder 
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder 
from labml nn.diffusion.stable diffusion.model.unet import UNetModel 
from transformers import CLIPProcessor, CLIPModel 
import torch.optim as optim 
from torch.utils.data import DataLoader, Dataset 
import loralib as lora 
from ISR.models import RDN

In [None]:
class DiffusionWrapper(nn.Module):
    """Thin ``nn.Module`` wrapper around the noise-prediction network.

    The wrapped network is stored as ``diffusion_model`` and every forward
    call is delegated to it unchanged.
    """

    def __init__(self, diffusion_model: "UNetModel"):
        """Store the network to wrap.

        Args:
            diffusion_model: Module called as
                ``diffusion_model(x, time_steps, context)``.
        """
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, context: torch.Tensor):
        """Run the wrapped network and return its output as-is."""
        return self.diffusion_model(x, time_steps, context)

In [None]:
class LatentDiffusion(nn.Module):
    """Bundle of UNet, image decoder and text embedder plus the noise schedule.

    Attributes set by ``__init__``:
        model: The UNet wrapped in ``DiffusionWrapper``.
        decode_model: Decoder that maps latents to images.
        text_embedding_model: Embedder that maps prompts to conditioning.
        latent_scaling_factor: Divisor applied to latents before decoding.
        n_steps: Number of diffusion steps in the schedule.
        beta: Per-step beta values (float32, not trainable).
        alpha_bar: Cumulative product of ``1 - beta`` (float32, not trainable).
    """

    model: "DiffusionWrapper"
    decode_model: "Decoder"
    text_embedding_model: "CLIPTextEmbedder"

    def __init__(self,
                 unet_model: "UNetModel",
                 decoder: "Decoder",
                 clip_embedder: "CLIPTextEmbedder",
                 latent_scaling_factor: float,
                 n_steps: int,
                 linear_start: float,
                 linear_end: float,
                 ):
        """Store the sub-networks and precompute the beta / alpha-bar schedule.

        Args:
            unet_model: Noise-prediction network; wrapped in ``DiffusionWrapper``.
            decoder: Latent-to-image decoder.
            clip_embedder: Prompt-to-embedding model.
            latent_scaling_factor: Latents are divided by this before decoding.
            n_steps: Length of the schedule.
            linear_start: First beta value of the schedule.
            linear_end: Last beta value of the schedule.
        """
        super().__init__()
        self.model = DiffusionWrapper(unet_model)
        self.decode_model = decoder
        self.latent_scaling_factor = latent_scaling_factor
        self.text_embedding_model = clip_embedder
        self.n_steps = n_steps

        # Betas are spaced linearly in sqrt-space and then squared; the schedule
        # is computed in float64 and stored as float32.
        beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2
        self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
        alpha = 1. - beta
        alpha_bar = torch.cumprod(alpha, dim=0)
        self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)

    @property
    def device(self):
        """Device of the first parameter of the wrapped UNet."""
        return next(iter(self.model.parameters())).device

    def get_text_conditioning(self, prompts: List[str]):
        """Embed ``prompts`` with the text embedding model."""
        return self.text_embedding_model(prompts)

    def decode(self, z: torch.Tensor):
        """Decode latent ``z`` to an image after undoing the latent scaling."""
        return self.decode_model(z / self.latent_scaling_factor)

    def forward(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor):
        """Run the wrapped UNet on latent ``x`` at step ``t`` with ``context``."""
        return self.model(x, t, context)

In [None]:
# Pretrained CLIP model; the checkpoint id must not contain whitespace.
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Pretrained RDN super-resolution network used to upscale generated images.
# NOTE(review): confirm that ISR's RDN constructor accepts a `scale` keyword and
# that the 'psnr-large' weights exist for 4x upscaling.
rdn_model = RDN(weights='psnr-large', scale=4)


In [None]:
# Latent-to-image decoder. It is fed the UNet output, so its latent channel
# count (z_channels) has to equal the UNet's out_channels below.
decoder = Decoder(out_channels=3,
                  z_channels=4,
                  channels=128,
                  channel_multipliers=[1, 2, 4, 4],
                  n_resnet_blocks=2)
# Noise-prediction UNet operating on 4-channel latents, conditioned on
# 768-wide text embeddings (d_cond).
unet_model = UNetModel(in_channels=4,
                       out_channels=4,
                       channels=320,
                       attention_levels=[0, 1, 2],
                       n_res_blocks=2,
                       channel_multipliers=[1, 2, 4, 4],
                       n_heads=8,
                       tf_layers=1,
                       d_cond=768)
# Text encoder that turns prompts into the UNet's conditioning input.
clip_embedder = CLIPTextEmbedder()

In [None]:
# Assemble the full pipeline: sub-networks first, then the latent scale and
# the noise-schedule settings (arguments listed in constructor order).
model = LatentDiffusion(
    unet_model=unet_model,
    decoder=decoder,
    clip_embedder=clip_embedder,
    latent_scaling_factor=0.18215,
    n_steps=1000,
    linear_start=0.00085,
    linear_end=0.0120,
)

In [None]:
# One Adam optimizer per network, all with the same learning rate.
decoder_optimizer = optim.Adam(decoder.parameters(), lr=2e-4)
unet_optimizer = optim.Adam(unet_model.parameters(), lr=2e-4)
# Only parameters that are trainable *at this point* are handed to the CLIP
# optimizer; parameters frozen later stay registered but receive no gradient.
clip_optimizer = optim.Adam(filter(lambda p: p.requires_grad, clip.parameters()), lr=2e-4)


In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that serves samples straight from an indexable container."""

    def __init__(self, data):
        """Keep a reference to ``data``; nothing is copied or transformed."""
        self.data = data

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.data[idx]
        return sample

    def __len__(self):
        """Number of samples in the underlying container."""
        n_samples = len(self.data)
        return n_samples
# Example dataset initialization.
# Placeholder images so this cell runs on a fresh kernel -- replace them with
# real (image, caption) pairs. Values are scaled to 0-255, the dynamic range
# `ssim` uses by default.
# NOTE(review): 3x512x512 is an assumed image size; confirm against real data.
image_tensor1 = torch.rand(3, 512, 512) * 255
image_tensor2 = torch.rand(3, 512, 512) * 255
data = [(image_tensor1, "a man wearing tuxedo"), (image_tensor2, "a girl wearing gown")]
dataset = CustomDataset(data)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


In [None]:
def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian kernel.

    Args:
        window_size: Number of taps; the kernel is centred on
            ``window_size // 2``.
        sigma: Standard deviation of the Gaussian, in taps.

    Returns:
        A float32 tensor of shape ``(window_size,)`` whose entries sum to 1.
    """
    # Offsets are built as a tensor so torch.exp receives a tensor argument;
    # calling it on a plain Python number raises TypeError.
    offsets = torch.arange(window_size, dtype=torch.float32) - window_size // 2
    gauss = torch.exp(-(offsets ** 2) / float(2 * sigma ** 2))
    return gauss / gauss.sum()

def create_window(window_size, channel):
    """Build a per-channel 2-D Gaussian filter bank.

    The 1-D kernel from ``gaussian(window_size, 1.5)`` is multiplied with its
    own transpose to form a 2-D kernel, which is then repeated once per channel.

    Returns:
        A contiguous tensor of shape ``(channel, 1, window_size, window_size)``.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = column.mm(column.t()).float()
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
    return kernel_4d.expand(channel, 1, window_size, window_size).contiguous()

def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=255):
    """Structural similarity (SSIM) between two batches of images.

    Args:
        img1: Tensor of shape ``(N, C, H, W)``.
        img2: Tensor with the same shape as ``img1``.
        window_size: Side of the Gaussian window; clipped to the image height
            and width when no ``window`` is given.
        window: Optional precomputed filter of shape ``(C, 1, k, k)``. When
            omitted, one is built with ``create_window``.
        size_average: If true, return the mean over the whole batch as a
            scalar tensor; otherwise return one score per image, shape ``(N,)``.
        full: Accepted but currently unused.
        val_range: Dynamic range of the pixel values (255 for 8-bit images,
            1 for images in ``[0, 1]``). It sets the stabilising constants.

    Returns:
        The SSIM score(s) as described under ``size_average``.
    """
    L = val_range
    padd = 0  # no padding: local statistics come from full windows only
    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        # Match the input's device and dtype so conv2d accepts the filter.
        window = create_window(real_size, channel).to(device=img1.device, dtype=img1.dtype)

    # Local means.
    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance, computed as E[xy] - E[x]E[y].
    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    # Constants that keep the ratio stable when the denominators are near zero.
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    # Previously this path fell through and returned None.
    return ssim_map.mean(dim=(1, 2, 3))

class SSIMLoss(nn.Module):
    """Loss module that returns ``1 - ssim(img1, img2)``."""

    def __init__(self, window_size=11, size_average=True):
        """Remember the SSIM settings that ``forward`` passes to ``ssim``."""
        super().__init__()
        self.size_average = size_average
        self.window_size = window_size

    def forward(self, img1, img2):
        """Compute the SSIM of the two inputs and turn it into a loss."""
        similarity = ssim(
            img1,
            img2,
            window_size=self.window_size,
            size_average=self.size_average,
        )
        return 1 - similarity


In [None]:
# Training objective: 1 - SSIM, with a window of up to 11 pixels per side.
ssim_window_size = 11
criterion = SSIMLoss(window_size=ssim_window_size)

In [None]:
def train(data, num_epochs, lora_in, lora_out):
    """Train the UNet and decoder to reproduce images from captions.

    Each batch is run as: captions -> text embeddings -> UNet -> decoder ->
    SSIM loss against the target images. The loss is back-propagated once and
    every optimizer then takes a step.

    Args:
        data: A ``DataLoader`` yielding ``(images, texts)`` batches, or an
            indexable collection of ``(image, text)`` samples. The latter is
            wrapped in a shuffled ``DataLoader`` with batch size 32, the same
            settings as the notebook's ``dataloader``.
        num_epochs: Number of passes over ``data``.
        lora_in: Input width of the LoRA linear layer.
        lora_out: Output width of the LoRA linear layer.
    """
    if isinstance(data, DataLoader):
        batches = data
    else:
        batches = DataLoader(data, batch_size=32, shuffle=True)

    # NOTE(review): this LoRA layer is created but never attached to a model,
    # so it is not trained. Decide where it belongs inside `clip`, or drop it.
    layer = lora.Linear(lora_in, lora_out, r=16)
    lora.mark_only_lora_as_trainable(clip)

    device = next(model.parameters()).device

    for epoch in range(num_epochs):
        model.model.train()
        model.decode_model.train()

        loss = None
        for images, texts in batches:
            images = images.to(device)

            unet_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            clip_optimizer.zero_grad()

            # The text encoder is not part of any optimizer, so no graph is
            # needed for it.
            with torch.no_grad():
                text_embeddings = model.text_embedding_model(list(texts))

            # NOTE(review): the previous code passed a CLIP output as the UNet
            # input and omitted time_steps, which does not match the UNet call
            # signature (x, time_steps, context). As a shape-consistent
            # stand-in, the UNet is fed Gaussian noise latents at the last
            # diffusion step. Replace with the intended latent source.
            # 4 latent channels and the factor 8 follow the notebook's
            # UNet/decoder configuration; image sides are assumed divisible by 8.
            batch_size, _, height, width = images.shape
            noise = torch.randn(batch_size, 4, height // 8, width // 8, device=device)
            time_steps = torch.full((batch_size,), model.n_steps - 1, dtype=torch.long, device=device)

            latent_images = model.model(noise, time_steps, text_embeddings)
            # Same latent rescaling that LatentDiffusion applies before decoding.
            generated_images = model.decode_model(latent_images / model.latent_scaling_factor)

            loss = criterion(generated_images, images)

            # Back-propagate once; a second backward() over the same graph
            # raises unless the graph is retained, and would add up gradients.
            loss.backward()
            unet_optimizer.step()
            decoder_optimizer.step()
            # NOTE(review): `clip` takes no part in this forward pass, so this
            # step has no gradients to apply until CLIP is wired into the graph.
            clip_optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}],Loss: {loss}")


In [None]:
def inference(prompts, latent_size=64):
    """Generate one super-resolved image per prompt.

    Args:
        prompts: List of caption strings.
        latent_size: Height and width of the square latent fed to the UNet.

    Returns:
        A list with one ``rdn_model.predict`` result per prompt; an empty list
        when ``prompts`` is empty.
    """
    if not prompts:
        return []

    # Everything runs without autograd: no graph is needed for generation, and
    # tensors that require grad cannot be converted to NumPy.
    with torch.no_grad():
        text_embeddings = model.get_text_conditioning(prompts)
        device = text_embeddings.device

        # NOTE(review): the previous code passed a CLIP output as the UNet
        # input and omitted time_steps, which does not match the UNet call
        # signature (x, time_steps, context). As a shape-consistent stand-in,
        # generation starts from Gaussian noise latents (4 channels, per the
        # notebook's UNet configuration) at the last diffusion step.
        noise = torch.randn(len(prompts), 4, latent_size, latent_size, device=device)
        time_steps = torch.full((len(prompts),), model.n_steps - 1, dtype=torch.long, device=device)

        latent_img = model.model(noise, time_steps, text_embeddings)
        generated_images = model.decode(latent_img)

    sr_images = []
    for image in generated_images:
        # The super-resolution model is given one HxWx3 uint8 array at a time.
        # NOTE(review): assumes decoder output is on the 0-255 scale that the
        # SSIM training loss is configured for -- confirm before relying on it.
        pixels = image.clamp(0, 255).round().to(torch.uint8).permute(1, 2, 0).cpu().numpy()
        sr_images.append(rdn_model.predict(pixels))
    return sr_images