In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into a single one.

    Uses ``view``, so the input must be viewable (e.g. contiguous).
    """

    def forward(self, input):
        batch_size = input.shape[0]
        return input.view(batch_size, -1)
    
class Unflatten(nn.Module):
    """Reshape each sample of a batch into a (channel, height, width) feature map.

    Args:
        channel, height, width: target per-sample shape; ``channel * height * width``
            must equal the number of elements per sample for ``view`` to succeed.
    """

    def __init__(self, channel, height, width):
        super().__init__()
        self.channel, self.height, self.width = channel, height, width

    def forward(self, input):
        target_shape = (input.shape[0], self.channel, self.height, self.width)
        return input.view(*target_shape)

In [None]:
class convVAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel images.

    The layer arithmetic fixes the spatial size: with a 64x64 input the three
    encoder convolutions reach 128 x 4 x 4 = 2048 features (the width of the first
    encoder Linear), and the decoder upsamples 4 -> 8 -> 16 -> 32 -> 64, so
    reconstructions are 3 x 64 x 64 with values in (0, 1) from the final Sigmoid.

    Args:
        latent_size: dimensionality of the latent code z.
    """

    def __init__(self, latent_size):
        super().__init__()
        self.latent_size = latent_size

        def down(c_in, c_out, stride):
            # kernel 4 / padding 1 throughout; only the stride changes
            return [nn.Conv2d(c_in, c_out, 4, stride, 1), nn.BatchNorm2d(c_out), nn.ReLU()]

        def up(c_in, c_out):
            # kernel 4 / stride 2 / padding 1 doubles height and width
            return [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out)]

        # Layer order fixes the state_dict keys (encoder.N / decoder.N) that saved
        # weight files depend on — keep it stable.
        self.encoder = nn.Sequential(
            *down(3, 32, 4),
            *down(32, 64, 2),
            *down(64, 128, 2),
            Flatten(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
        )

        # Two linear heads produce the Gaussian posterior parameters.
        self.mu = nn.Linear(1024, latent_size)
        self.logvar = nn.Linear(1024, latent_size)

        decoder_layers = [
            nn.Linear(latent_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            Unflatten(128, 4, 4),
        ]
        for c_in, c_out in ((128, 64), (64, 32), (32, 16)):
            decoder_layers += up(c_in, c_out)
            decoder_layers.append(nn.ReLU())
        # NOTE(review): BatchNorm directly before the output Sigmoid is unusual
        # (output layers are normally left un-normalised); kept so existing
        # weight files still load.
        decoder_layers += up(16, 3)
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Map images to the posterior mean and log-variance, each (batch, latent_size)."""
        features = self.encoder(x)
        return self.mu(features), self.logvar(features)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I) so sampling stays differentiable."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes back to image space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample and decode; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar
            
# Prefer the GPU but fall back to CPU so the notebook still runs without CUDA
# (the original hard-coded "cuda" and crashed on CPU-only machines).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = convVAE(1024)  # latent_size = 1024
vae.to(device)

In [None]:
def loss_function(recon_x, x, mu, logvar, beta=1.0):
    """VAE objective: summed binary cross-entropy reconstruction loss plus KL divergence.

    Args:
        recon_x: decoder output; ``F.binary_cross_entropy`` expects probabilities in [0, 1].
        x: target images with the same shape as ``recon_x``, values in [0, 1].
        mu, logvar: encoder outputs parameterising q(z|x) = N(mu, exp(logvar)).
        beta: weight on the KL term. 1.0 gives the standard VAE loss; other values
            give a beta-VAE.

    Returns:
        0-dim tensor: the loss summed over every element of the batch (not averaged).
    """
    bce = F.binary_cross_entropy(recon_x, x, reduction="sum")
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latent dims.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Both terms are already 0-dim sums, so no further mean/reduction is needed.
    return bce + beta * kld

In [None]:
# RMSprop over all VAE parameters with a learning rate of 1e-3.
optimizer = optim.RMSprop(params=vae.parameters(), lr=1e-3)

In [None]:
# Training configuration.
EPOCHS = 10
# NOTE(review): absolute, machine-specific path kept from the original; move it to a
# config cell / relative path before sharing the notebook.
WEIGHTS_PATH = r"C:\Users\Abhrant\Desktop\abhrant\work\DEEP_LEARNING\vae_weight.pt"

# Train on whichever device the model already lives on instead of hard-coding "cuda".
device = next(vae.parameters()).device

loss_list = []   # mean per-image loss for each epoch
epoch_list = []

# Ensure BatchNorm uses batch statistics even if an earlier cell called vae.eval().
vae.train()

for epoch in range(EPOCHS):
    epoch_list.append(epoch)
    train_loss = 0.0
    n_images = 0

    for i in tqdm(image_loader, desc=f"epoch {epoch}"):
        # NOTE(review): assumes each batch is a bare image tensor (no labels) — confirm
        # against how image_loader is built.
        input_img = i.to(device)

        optimizer.zero_grad()
        # Call the module (not .forward) so nn.Module hooks run.
        output, mu, logvar = vae(input_img)
        loss = loss_function(output, input_img, mu, logvar)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_images += input_img.size(0)

    # loss_function sums over the batch, so normalise by the number of images actually
    # seen rather than a hard-coded dataset size (max() guards an empty loader).
    train_loss = train_loss / max(n_images, 1)
    loss_list.append(train_loss)
    print(epoch, train_loss)

torch.save(vae.state_dict(), WEIGHTS_PATH)

fig, ax = plt.subplots()
ax.plot(epoch_list, loss_list, marker="o")
ax.set(title="convVAE training loss", xlabel="epoch", ylabel="loss per image")
plt.show()