In [1]:
# Importing the libraries 
import pandas as pd
import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import *
from torchvision import transforms
import torchvision
from tqdm import tqdm
from torchvision.utils import save_image
to_pil_image = transforms.ToPILImage(mode=None)  # tensor/ndarray -> PIL converter; not referenced in the cells shown here

from torchvision.utils import make_grid

In [3]:
class SelfAttention(nn.Module):
    """Transformer-style self-attention over the spatial positions of a feature map.

    The (B, C, H, W) input is flattened to a (B, H*W, C) token sequence, passed
    through pre-LayerNorm multi-head attention (4 heads) and a pre-LayerNorm
    feed-forward block, each with a residual connection, and reshaped back.

    Args:
        channels: number of feature channels C; must be divisible by the 4
            attention heads.
        size: nominal spatial side length. Kept as ``self.size`` for backward
            compatibility. ``forward`` reads the actual H and W from its input,
            so differently sized or non-square maps also work.
    """

    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, C, H, W); returns the same shape."""
        batch, channels, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position. Using the real
        # H*W (not self.size**2) avoids silently folding pixels into the batch axis.
        tokens = x.reshape(batch, channels, height * width).transpose(1, 2)
        tokens_ln = self.ln(tokens)
        attention_value, _ = self.mha(tokens_ln, tokens_ln, tokens_ln)
        attention_value = attention_value + tokens
        attention_value = self.ff_self(attention_value) + attention_value
        # (B, H*W, C) -> (B, C, H, W)
        return attention_value.transpose(1, 2).reshape(batch, channels, height, width)


class DoubleConv(nn.Module):
    """Two padded 3x3 convolutions with GroupNorm(1, ·) after each and GELU between them.

    Spatial size is preserved (padding=1). With ``residual=True`` the result is
    ``gelu(x + f(x))``, which requires ``in_channels == out_channels``.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        mid_channels: channels after the first convolution; any falsy value
            (e.g. None) means "use out_channels".
        residual: add the input back and apply GELU to the sum.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden_channels = mid_channels or out_channels
        layers = [
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden_channels),
            nn.GELU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.double_conv(x)
        if not self.residual:
            return out
        return F.gelu(x + out)


class Down(nn.Module):
    """Downsampling stage: 2x max-pool, two DoubleConvs, then a time-embedding bias.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        emb_dim: size of the time embedding ``t`` given to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )
        # Maps the embedding to one additive bias per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        features = self.maxpool_conv(x)
        channel_bias = self.emb_layer(t)[:, :, None, None]
        # Broadcasting over H and W adds the same values an explicit repeat would.
        return features + channel_bias


class Up(nn.Module):
    """Upsampling stage: 2x bilinear upsample, concat with a skip map, two DoubleConvs, time bias.

    Args:
        in_channels: channels after concatenating the skip map with the upsampled input.
        out_channels: channels of the output feature map.
        emb_dim: size of the time embedding ``t`` given to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )
        # Maps the embedding to one additive bias per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        # Skip features come first in the channel concatenation.
        merged = torch.cat([skip_x, self.up(x)], dim=1)
        features = self.conv(merged)
        channel_bias = self.emb_layer(t)[:, :, None, None]
        return features + channel_bias


class UNet(nn.Module):
    """Time-conditioned U-Net used by DKGM as its reconstruction network.

    Three ``Down`` stages (64 -> 128 -> 256 -> 512 channels), two bottleneck
    ``DoubleConv`` blocks, three ``Up`` stages with skip connections, and a 1x1
    output convolution. The time step ``t`` is embedded sinusoidally into
    ``time_dim`` features and added as a per-channel bias in every Down/Up stage.
    Input height/width must be divisible by 8 (three 2x poolings).

    Args:
        c_in: input image channels.
        c_out: output image channels.
        time_dim: size of the sinusoidal time embedding. Also passed to every
            Down/Up stage as ``emb_dim``; before, they always used 256, so any
            other value failed.
        device: stored as ``self.device``. The time embedding is built on the
            device of ``t``, so this no longer has to match where the model runs.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128, emb_dim=time_dim)
        self.sa1 = SelfAttention(128, 64)
        self.down2 = Down(128, 256, emb_dim=time_dim)
        self.sa2 = SelfAttention(256, 32)
        self.down3 = Down(256, 512, emb_dim=time_dim)
        self.sa3 = SelfAttention(512, 16)

        self.bot1 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128, emb_dim=time_dim)
        self.sa4 = SelfAttention(256, 8)
        self.up2 = Up(256, 64, emb_dim=time_dim)
        self.sa5 = SelfAttention(128, 16)
        self.up3 = Up(128, 64, emb_dim=time_dim)
        self.sa6 = SelfAttention(32, 64)
        # NOTE(review): sa1..sa6 are built but never called in forward, and sa4/sa5/sa6
        # do not match the up1/up2/up3 output channels (128/64/64). They are kept
        # unchanged so state_dict keys/shapes still match saved checkpoints
        # (e.g. DKGM_boost_CelebA128_best.pt).
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Sinusoidal embedding of time steps.

        Args:
            t: float tensor of shape (B, 1).
            channels: embedding size (expected even).

        Returns:
            (B, channels) tensor: ``channels // 2`` sine terms followed by the
            matching cosine terms.
        """
        # Build frequencies on t's own device. The old default self.device ("cuda")
        # crashed CPU-only runs and mismatched models moved with .to(...).
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t):
        """Predict a (B, c_out, H, W) image from ``x`` (B, c_in, H, W) and time steps ``t``.

        ``t`` is expected as a 1-D tensor of B time steps (as DKGM passes it).
        """
        t = t.unsqueeze(-1).type(torch.float)   # (B,) -> (B, 1)
        t = self.pos_encoding(t, self.time_dim)  # (B, time_dim)

        # Shapes below are for a (B, 3, 128, 128) input.
        x1 = self.inc(x)            # B x  64 x 128 x 128
        x2 = self.down1(x1, t)      # B x 128 x  64 x  64
        x3 = self.down2(x2, t)      # B x 256 x  32 x  32
        x4 = self.down3(x3, t)      # B x 512 x  16 x  16

        x4 = self.bot1(x4)          # B x 512 x  16 x  16
        x4 = self.bot3(x4)          # B x 256 x  16 x  16

        x = self.up1(x4, x3, t)     # B x 128 x  32 x  32
        x = self.up2(x, x2, t)      # B x  64 x  64 x  64
        x = self.up3(x, x1, t)      # B x  64 x 128 x 128
        return self.outc(x)         # B x c_out x 128 x 128



In [4]:
class DKGM(nn.Module):
    """Iterative residual refinement built on a single shared UNet.

    The UNet first reconstructs the input at time step 0. Then, for ``T``
    rounds, the remaining residual ("bias") is fed back through the same UNet at
    time step ``i + 1``. Each round's output is added to the running total with
    weight ``1 / (i + 1)`` and subtracted, with that weight, from the residual
    before the next round.

    Args:
        device: device for the time-step tensors; also passed to the UNet.
        T: number of refinement rounds (0 means a single UNet pass).
        latent_dim: stored as ``self.latent_dim``; not used by ``forward``.
            Previously read from a notebook global defined in a later cell.
    """

    def __init__(self, device, T=20, latent_dim=64):
        super(DKGM, self).__init__()
        self.latent_dim = latent_dim
        # Pass the device through; UNet's own default is "cuda".
        self.unet = UNet(device=device)
        self.T = T
        self.device = device

    def forward(self, x):
        """Return the accumulated reconstruction of ``x`` (same shape as ``x``)."""
        t0 = torch.zeros(x.size(dim=0), device=self.device)

        # Initial decoding at time step 0.
        reconstruction = self.unet(x, t0)

        total_recons = reconstruction
        bias = x - reconstruction
        recons_bias = torch.zeros_like(x, device=self.device)
        a_i = 0.0
        for i in range(self.T):
            # Remove what the previous round explained, with the weight it was added with.
            bias = bias - recons_bias * a_i
            recons_bias = self.unet(bias, t0 + i + 1)
            a_i = 1.0 / (i + 1.0)
            # Out-of-place: `+=` here used to mutate `reconstruction` through the alias.
            total_recons = total_recons + recons_bias * a_i

        return total_recons

In [5]:


def final_lossDKGM(mse_loss):
    """Combine the loss terms into the training objective.

    DKGM has a single term, so the reconstruction loss is returned unchanged.
    """
    total_loss = mse_loss
    return total_loss

def model_trainDKGM(model,dataloader,dataset,device,optimizer,criterion,a=1):
    """Train ``model`` for one epoch as a denoiser; return the mean per-batch loss.

    Each batch's images (element 0 of the batch) are corrupted with Gaussian
    noise scaled by ``a``. The model is trained to reconstruct the clean images
    under ``criterion``.

    Args:
        model: module mapping noisy images to reconstructions.
        dataloader: yields batches whose first element is the image tensor.
        dataset: dataset behind ``dataloader``. Kept for interface compatibility;
            the progress bar uses ``len(dataloader)``.
        device: device each batch is moved to.
        optimizer: optimizer over ``model``'s parameters.
        criterion: ``loss(prediction, target)``.
        a: multiplier of the standard-normal noise added to the inputs.

    Returns:
        Average of ``loss.item()`` over the epoch's batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    counter = 0
    # len(dataloader) counts the final partial batch, unlike len(dataset) / batch_size.
    for data in tqdm(dataloader, total=len(dataloader)):
        counter += 1
        images = data[0].to(device)
        optimizer.zero_grad()

        posterior_X = model(images + a * torch.randn_like(images))
        mse_loss = criterion(posterior_X, images)
        # Was `final_lossEVAE`, which is not defined anywhere (NameError).
        loss = final_lossDKGM(mse_loss)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    if counter == 0:
        raise ValueError("dataloader yielded no batches")
    return running_loss / counter

def model_validateDKGM(model,dataloader,dataset,device,optimizer,criterion,a=1):
    """Evaluate ``model`` on noise-corrupted batches without gradients.

    Args:
        model: module mapping noisy images to reconstructions.
        dataloader: yields batches whose first element is the image tensor.
        dataset: dataset behind ``dataloader``; kept for interface compatibility.
        device: device each batch is moved to.
        optimizer: unused (no gradients here); kept for interface compatibility.
        criterion: ``loss(prediction, target)``.
        a: multiplier of the standard-normal noise added to the inputs.

    Returns:
        ``(mean per-batch loss, reconstructions of the last batch)``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    counter = 0
    recon_images = None
    with torch.no_grad():
        for data in tqdm(dataloader, total=len(dataloader)):
            counter += 1
            images = data[0].to(device)

            posterior_X = model(images + a * torch.randn_like(images))
            running_loss += criterion(posterior_X, images).item()
            # Keep the latest batch. The old `i == len(dataset)/bs - 1` test left this
            # unbound for datasets smaller than one batch and skipped a partial last batch.
            recon_images = posterior_X

    if counter == 0:
        raise ValueError("dataloader yielded no batches")
    return running_loss / counter, recon_images

# training DKGM

In [7]:
# Configuration, single-pass DKGM model, CelebA loaders, optimiser and loss.
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init, shuffling and noise

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_dim = 64  # DKGM.__init__ may read this as a notebook global; define it before DKGM(...)
lr = 0.0003
epochs = 10
batch_size = 32

# T=0: one UNet pass, no residual-refinement rounds.
DKGMmodel = DKGM(T=0, device=device).to(device)

# Center-crop to 160x160, resize to 128x128; ToTensor yields float images in [0, 1].
transform = transforms.Compose([
    transforms.CenterCrop((160, 160)),
    transforms.Resize([128, 128]),
    transforms.ToTensor(),
])

train_set = torchvision.datasets.CelebA(root='./', split='train', download=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

# Validation split: no shuffling, so the batch returned for inspection is the same every epoch.
test_set = torchvision.datasets.CelebA(root='./', split='valid', download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

optimizerDKGM = optim.Adam(DKGMmodel.parameters(), lr=lr)
# Summed (not averaged) squared error over each batch.
criterion = nn.MSELoss(reduction="sum")

In [None]:
# Train the single-pass DKGM as a denoiser, logging losses and saving a
# reconstruction grid of the last validation batch every epoch.
grid_imagesDKGM = []
train_lossDKGM = []
valid_lossDKGM = []

for epoch in range(epochs):
    print(f"Epoch{epoch+1} of {epochs}")

    # The helpers take no `m` argument (passing m=1 raised TypeError).
    train_epoch_lossDKGM = model_trainDKGM(DKGMmodel, train_loader, train_set, device, optimizerDKGM, criterion, a=1)
    valid_epoch_lossDKGM, recon_images = model_validateDKGM(DKGMmodel, test_loader, test_set, device, optimizerDKGM, criterion, a=1)
    # Record DKGM losses (the *EVAE names used here were never defined).
    train_lossDKGM.append(train_epoch_lossDKGM)
    valid_lossDKGM.append(valid_epoch_lossDKGM)

    # `save_reconstructed_images` is not defined in this notebook; use torchvision directly.
    # Targets are ToTensor images in [0, 1], so clamp instead of mapping from [-1, 1].
    recon_grid = make_grid(torch.clamp(recon_images.detach(), 0, 1).cpu())
    grid_imagesDKGM.append(recon_grid)
    save_image(recon_grid, f"DKGM_recon_epoch{epoch+1}.png")
    print(f"train loss:{train_epoch_lossDKGM:.4f}")
    print(f"valid loss:{valid_epoch_lossDKGM:.4f}")
    

In [18]:
# Second model: DKGM with 3 residual-refinement rounds, with its own Adam optimizer.
DKGMmodel2 = DKGM(T=3, device=device).to(device)
optimizerDKGM2 = optim.Adam(DKGMmodel2.parameters(), lr=lr)

# Input corruption for DKGMmodel2: 5x5 Gaussian blur, sigma sampled from the (0.8, 1.2) range.
blur_kernel_size = (5, 5)
blur_sigma_range = (0.8, 1.2)
transform_blur = transforms.Compose([
    transforms.GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma_range),
])


In [9]:


# Train DKGMmodel2 to undo Gaussian blur; checkpoint whenever the epoch's
# mean training loss improves on the best seen so far.
min_loss = float("inf")
for epoch in range(epochs):
    print(f"Epoch{epoch+1} of {epochs}")
    DKGMmodel2.train()
    running_loss = 0.0
    counter = 0
    # len(train_loader) counts the final partial batch.
    for data in tqdm(train_loader, total=len(train_loader)):
        counter += 1
        images = data[0].to(device)
        optimizerDKGM2.zero_grad()

        # Reconstruct the sharp batch from its blurred version.
        posterior_Xt_1 = DKGMmodel2(transform_blur(images))
        mse_loss = criterion(posterior_Xt_1, images)
        loss = final_lossDKGM(mse_loss)
        loss.backward()
        optimizerDKGM2.step()

        running_loss += loss.item()

    train_loss = running_loss / counter
    if train_loss < min_loss:
        min_loss = train_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': DKGMmodel2.state_dict(),
            'optimizer_state_dict': optimizerDKGM2.state_dict(),
            # The epoch-mean loss that triggered this save, as a float. Previously this
            # stored the last batch's loss tensor, still attached to the autograd graph.
            'loss': train_loss,
        }, 'DKGM_boost_CelebA128_best.pt')
    print(train_loss)



Epoch1 of 10


20347it [3:03:44,  1.85it/s]                             


1080.2590066159205
Epoch2 of 10


20347it [3:02:18,  1.86it/s]                             


219.6051408356927
Epoch3 of 10


20347it [3:02:22,  1.86it/s]                             


80.79823357803367
Epoch4 of 10


20347it [3:02:18,  1.86it/s]                             


38.32189697246033
Epoch5 of 10


20347it [3:02:16,  1.86it/s]                             


23.239646034591782
Epoch6 of 10


20347it [3:03:00,  1.85it/s]                             


11.842965366574138
Epoch7 of 10


20347it [3:12:01,  1.77it/s]                             


8.316985123792128
Epoch8 of 10


20347it [3:03:08,  1.85it/s]                             


6.2605974176049894
Epoch9 of 10


20347it [3:03:06,  1.85it/s]                             


4.669031417296021
Epoch10 of 10


20347it [3:03:06,  1.85it/s]                             

6.758881488965174





In [None]:
# Evaluate the two-stage pipeline (DKGMmodel denoises, DKGMmodel2 refines) with
# Laplacian-variance sharpness, FID against the clean images, and Inception Score.
from scipy import signal
from torcheval.metrics import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

DKGMmodel.eval()
DKGMmodel2.eval()

# NOTE(review): this evaluates on the training split, as the original cell did,
# even though an old comment said "test set". Use test_loader / test_set for held-out numbers.
eval_loader, eval_set = train_loader, train_set

# 3x3 Laplacian; the variance of its absolute response is the per-image sharpness.
kernel = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])

fidDKGM = FrechetInceptionDistance(device=device)
inception = InceptionScore(normalize=True)
transform_grayscale = transforms.Grayscale(num_output_channels=1)

tota_sharpDKGM = 0.0
n_images = 0
with torch.no_grad():
    for data in tqdm(eval_loader, total=len(eval_loader)):
        images = data[0].to(device)

        noise = torch.randn_like(images)
        state1 = DKGMmodel(images + 1 * noise)
        reconstruction_DKGM = torch.clamp(DKGMmodel2(state1), 0, 1)

        gray = transform_grayscale(reconstruction_DKGM).cpu()
        for j in range(images.size(dim=0)):
            sharpnessDKGM = np.var(np.abs(signal.convolve2d(gray[j][0].numpy(), kernel, mode="same")))
            tota_sharpDKGM += sharpnessDKGM
        n_images += images.size(dim=0)

        inception.update(reconstruction_DKGM.cpu())
        # Was `fidEVAE`, which is never defined (NameError); the metric object is fidDKGM.
        fidDKGM.update(torch.clamp(images, 0, 1), is_real=True)
        fidDKGM.update(reconstruction_DKGM, is_real=False)

lossDKGM = fidDKGM.compute()
Is = inception.compute()
print(f"FIDDKGM: {float(lossDKGM)}")
print(f"shaprnessDKGM:{tota_sharpDKGM/n_images:.4f}")
print(f"IS (mean): {float(Is[0])}")
print(f"IS (std): {float(Is[1])}")