In [1]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

def show(path):
    img_arr = np.array(Image.open(path))
    plt.imshow(img_arr)
    plt.axis('off')
    plt.show()

In [2]:
import torch
import os
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import transforms

class UnderWaterImages(Dataset):
    def __init__(self,folder,transforms=None):
        self.dir = folder
        self.images = os.listdir(self.dir)
        self.transforms = transforms
    
    def __len__(self):
        return len(self.images)
        
    def __getitem__(self,idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.dir,img_name)
        img = Image.open(img_path).convert('RGB')

        if self.transforms:
            img = self.transforms(img)
        return img

In [3]:
transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])

In [4]:
from torch.utils.data import random_split, DataLoader
train_dataset = UnderWaterImages('/kaggle/input/underwaterimagestrain/Raw',transforms)
dataset_size = len(train_dataset)
val_size = int(0.2 * dataset_size)   
train_size = dataset_size - val_size
train_subset, val_subset = random_split(train_dataset, [train_size, val_size])
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self,in_channels,out_channels,stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels,out_channels,3,stride=stride,padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels,out_channels,3,stride=1,padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Identity()
        if in_channels!=out_channels or stride!=1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels,out_channels,1,stride=stride),
                nn.BatchNorm2d(out_channels)
            )
            
    def forward(self,x):
        identity = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + identity
        out = F.relu(out)
        return out 

In [None]:
epochs = 200 # 
best_val_loss = float('inf')
patience = 10 
patience_counter = 0


for epoch in range(1, epochs + 1):
    train(epoch)
    
    
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss = model.loss_function(recon_batch, data, mu, logvar)
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(f'====> Epoch: {epoch} Average validation loss: {avg_val_loss:.4f}')

    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
        print('Validation loss decreased. Saving model.')
    else:
        patience_counter += 1
        print(f'Validation loss did not improve. Patience: {patience_counter}/{patience}')

    if patience_counter >= patience:
        print('Early stopping triggered!')
        break 

In [7]:
from torch import optim
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else 'cpu'
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
log_interval = 10

In [8]:
def train(epoch):
    model.train() #what this does
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = model.loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))
    

In [99]:
epochs = 200 # 
best_val_loss = float('inf')
patience = 10 
patience_counter = 0


for epoch in range(1, epochs + 1):
    train(epoch)
    
    
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss = model.loss_function(recon_batch, data, mu, logvar)
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(f'====> Epoch: {epoch} Average validation loss: {avg_val_loss:.4f}')

    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
        print('Validation loss decreased. Saving model.')
    else:
        patience_counter += 1
        print(f'Validation loss did not improve. Patience: {patience_counter}/{patience}')

    if patience_counter >= patience:
        print('Early stopping triggered!')
        break 

====> Epoch: 1 Average loss: 260.4387
====> Epoch: 1 Average validation loss: 252.8565
Validation loss decreased. Saving model.
====> Epoch: 2 Average loss: 206.4026
====> Epoch: 2 Average validation loss: 214.4464
Validation loss decreased. Saving model.
====> Epoch: 3 Average loss: 184.0703
====> Epoch: 3 Average validation loss: 197.1457
Validation loss decreased. Saving model.
====> Epoch: 4 Average loss: 178.7315
====> Epoch: 4 Average validation loss: 194.2405
Validation loss decreased. Saving model.
====> Epoch: 5 Average loss: 173.1309
====> Epoch: 5 Average validation loss: 190.5095
Validation loss decreased. Saving model.
====> Epoch: 6 Average loss: 172.1929
====> Epoch: 6 Average validation loss: 184.8078
Validation loss decreased. Saving model.
====> Epoch: 7 Average loss: 166.0671
====> Epoch: 7 Average validation loss: 178.5216
Validation loss decreased. Saving model.
====> Epoch: 8 Average loss: 157.1750
====> Epoch: 8 Average validation loss: 171.6220
Validation loss d

In [20]:
import os

for root, dirs, files in os.walk("/kaggle/input"):
    for f in files:
        if "reconstruction_76.png" in f:
            print(os.path.join(root, f))

/kaggle/input/results/reconstruction_76.png


**Results: Check img1 on GITHUB**

**The blurry images is mostly because of:**
1. The KL divergence is not correctly scaled down. Model is trying to squeeze in too much of info a limited gaussian
2. The model's complexity can be increased a bit more. It might be finding it difficult to reconstruct the image back


**Experiment: Scale down the KL loss by multiplying by a beta factor**

In [6]:
class VAE(nn.Module):
    def __init__(self):
        super().__init__()

        ##decoder
        self.conv1 = nn.Conv2d(3,16,7,stride=2,padding=3)
        
        #each residual block halfs the dimension
        self.res1 = ResidualBlock(16,32,2) 
        self.res2 = ResidualBlock(32,64,2)
        self.res3 = ResidualBlock(64,128,2)
        self.res4 = ResidualBlock(128,256,2)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc_mean = nn.Linear(256,128)
        self.fc_var = nn.Linear(256,128)
        # self.fc_mean = nn.Linear(128,64)
        # self.fc_var = nn.Linear(128,64)        

        ##encoder
        self.fc1 = nn.Linear(128,256*7*7)

        #kernel=4 stride=2 padding=1 always doubles the dim
        self.up1 = nn.ConvTranspose2d(256,128,4,stride=2,padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.up2 = nn.ConvTranspose2d(128,64,4,stride=2,padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.up3 = nn.ConvTranspose2d(64,32,4,stride=2,padding=1)   
        self.bn3 = nn.BatchNorm2d(32)
        self.up4 = nn.ConvTranspose2d(32,16,4,stride=2,padding=1)  
        self.bn4 = nn.BatchNorm2d(16)
        self.up5 = nn.ConvTranspose2d(16,3,4,stride=2,padding=1)      

        
    def encode(self,x):
        x = F.relu(self.conv1(x))
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.res4(x)
        x = torch.flatten(self.avgpool(x),1)
        mean = self.fc_mean(x)
        log_var = self.fc_var(x)
        return mean,log_var

    def reparameterize(self,mu,log_var):
        std = torch.exp(0.5*log_var) #0.5 prevents very huge value 
        eps = torch.randn_like(std)
        z = mu + std*eps
        return z

    def decode(self,z):
        out = F.relu(self.fc1(z))
        out = out.view(-1,256,7,7)
        out = F.relu(self.bn1(self.up1(out)))
        out = F.relu(self.bn2(self.up2(out)))
        out = F.relu(self.bn3(self.up3(out)))
        out = F.relu(self.bn4(self.up4(out)))
        out = torch.sigmoid(self.up5(out))
        return out

    def loss_function(self,out,x,mu,log_vars,beta=0.5):
        loss = nn.MSELoss(reduction='mean')
        recon_loss = loss(x,out)*(x.shape[1] * x.shape[2] * x.shape[3])
        
        kld = -0.5 * torch.sum(1 + log_vars - mu.pow(2) - log_vars.exp())

        return recon_loss+(kld*beta)

    def forward(self,x):
        mu,log_vars = self.encode(x)
        z = self.reparameterize(mu,log_vars)
        out = self.decode(z)

        return out,mu,log_vars

In [9]:
epochs = 100
best_val_loss = float('inf')
patience = 10 
patience_counter = 0


for epoch in range(1, epochs + 1):
    train(epoch)
    
    
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss = model.loss_function(recon_batch, data, mu, logvar)
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(f'Epoch: {epoch} Average validation loss: {avg_val_loss:.4f}')

    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
        print('Validation loss decreased. Saving model.')
    else:
        patience_counter += 1
        print(f'Validation loss did not improve. Patience: {patience_counter}/{patience}')

    if patience_counter >= patience:
        print('Early stopping triggered!')
        break 

====> Epoch: 1 Average loss: 343.7460
Epoch: 1 Average validation loss: 391.1387
Validation loss decreased. Saving model.
====> Epoch: 2 Average loss: 274.7474
Epoch: 2 Average validation loss: 284.1609
Validation loss decreased. Saving model.
====> Epoch: 3 Average loss: 226.7657
Epoch: 3 Average validation loss: 223.9309
Validation loss decreased. Saving model.
====> Epoch: 4 Average loss: 188.8889
Epoch: 4 Average validation loss: 192.5239
Validation loss decreased. Saving model.
====> Epoch: 5 Average loss: 169.3959
Epoch: 5 Average validation loss: 176.6779
Validation loss decreased. Saving model.
====> Epoch: 6 Average loss: 161.9433
Epoch: 6 Average validation loss: 172.9578
Validation loss decreased. Saving model.
====> Epoch: 7 Average loss: 158.5534
Epoch: 7 Average validation loss: 168.4836
Validation loss decreased. Saving model.
====> Epoch: 8 Average loss: 156.7319
Epoch: 8 Average validation loss: 166.3069
Validation loss decreased. Saving model.
====> Epoch: 9 Average l

In [10]:
from torchvision.utils import save_image

def test(epoch):
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(train_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)

            if i == 0:
                n = min(data.size(0), 8)   # take up to 8 samples

                # concat original and reconstructions along the batch dimension
                comparison = torch.cat([data[:n],
                                        recon_batch[:n]])

                save_image(comparison.cpu(),
                           f"/kaggle/working/reconstruction_{epoch}.png",
                           nrow=n)
            break   # only first batch


In [11]:
test(95)

**Results: Check Img2 on GITHUB**

**Comments:**
The images show that there is definitely an improvement in the generation but it is not very good

**Experiment:** Trying to increase the complexity of the model

In [6]:
class VAE2(nn.Module):
    def __init__(self):
        super().__init__()

        ## Encoder
        self.conv1 = nn.Conv2d(3, 32, 7, stride=2, padding=3) 
        
        self.res1 = ResidualBlock(32, 64, 2)    
        self.res2 = ResidualBlock(64, 128, 2)   
        self.res3 = ResidualBlock(128, 256, 2)  
        self.res4 = ResidualBlock(256, 512, 2)  
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        
        self.fc_mean = nn.Linear(512, 128)
        self.fc_var = nn.Linear(512, 128)

        ## Decoder
        self.fc1 = nn.Linear(128, 512 * 7 * 7) 

        self.up1 = nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(256)
        self.up2 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.up3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.up4 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(32)
        self.up5 = nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1)

    def encode(self,x):
        x = F.relu(self.conv1(x))
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.res4(x)
        x = torch.flatten(self.avgpool(x),1)
        mean = self.fc_mean(x)
        log_var = self.fc_var(x)
        return mean,log_var

    def reparameterize(self,mu,log_var):
        std = torch.exp(0.5*log_var) #0.5 prevents very huge value 
        eps = torch.randn_like(std)
        z = mu + std*eps
        return z

    def decode(self,z):
        out = F.relu(self.fc1(z))
        out = out.view(-1,512,7,7)
        out = F.relu(self.bn1(self.up1(out)))
        out = F.relu(self.bn2(self.up2(out)))
        out = F.relu(self.bn3(self.up3(out)))
        out = F.relu(self.bn4(self.up4(out)))
        out = torch.sigmoid(self.up5(out))
        return out

    def loss_function(self,out,x,mu,log_vars,beta=0.5):
        loss = nn.MSELoss(reduction='mean')
        recon_loss = loss(x,out)*(x.shape[1] * x.shape[2] * x.shape[3])
        
        kld = -0.5 * torch.sum(1 + log_vars - mu.pow(2) - log_vars.exp())

        return recon_loss+(kld*beta)

    def forward(self,x):
        mu,log_vars = self.encode(x)
        z = self.reparameterize(mu,log_vars)
        out = self.decode(z)
        return out,mu,log_vars

In [11]:
from torch import optim
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else 'cpu'
model2 = VAE2().to(device)
model2.load_state_dict(torch.load("/kaggle/input/variationalautoencoderunderwaterimages/pytorch/default/1/best_model2.pth"))
optimizer = optim.Adam(model2.parameters(), lr=1e-3)
log_interval = 10

In [12]:
def train2(epoch):
    model2.train() #what this does
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model2(data)
        loss = model2.loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

In [None]:
epochs = 100
best_val_loss = float('inf')
patience = 10 
patience_counter = 0


for epoch in range(1, epochs + 1):
    train2(epoch)
    
    
    model2.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model2(data)
            loss = model2.loss_function(recon_batch, data, mu, logvar)
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(f'Epoch: {epoch} Average validation loss: {avg_val_loss:.4f}')

    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model2.state_dict(), 'best_model2.pth')
        print('Validation loss decreased. Saving model.')
    else:
        patience_counter += 1
        print(f'Validation loss did not improve. Patience: {patience_counter}/{patience}')

    if patience_counter >= patience:
        print('Early stopping triggered!')
        break 

====> Epoch: 1 Average loss: 238.8863
Epoch: 1 Average validation loss: 285.3499
Validation loss decreased. Saving model.
====> Epoch: 2 Average loss: 184.6625
Epoch: 2 Average validation loss: 195.3419
Validation loss decreased. Saving model.
====> Epoch: 3 Average loss: 178.0129
Epoch: 3 Average validation loss: 181.7257
Validation loss decreased. Saving model.
====> Epoch: 4 Average loss: 172.1455
Epoch: 4 Average validation loss: 178.3906
Validation loss decreased. Saving model.
====> Epoch: 5 Average loss: 166.7676
Epoch: 5 Average validation loss: 177.2870
Validation loss decreased. Saving model.
====> Epoch: 6 Average loss: 157.9762
Epoch: 6 Average validation loss: 169.9976
Validation loss decreased. Saving model.
====> Epoch: 7 Average loss: 153.7622
Epoch: 7 Average validation loss: 200.9513
Validation loss did not improve. Patience: 1/10
====> Epoch: 8 Average loss: 151.3638
Epoch: 8 Average validation loss: 183.7894
Validation loss did not improve. Patience: 2/10
====> Epoc

In [18]:
epochs = 100
best_val_loss = float('inf')
patience = 10 
patience_counter = 0


for epoch in range(36, epochs + 1):
    train2(epoch)
    
    
    model2.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model2(data)
            loss = model2.loss_function(recon_batch, data, mu, logvar)
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(f'Epoch: {epoch} Average validation loss: {avg_val_loss:.4f}')

    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model2.state_dict(), 'best_model2.pth')
        print('Validation loss decreased. Saving model.')
    else:
        patience_counter += 1
        print(f'Validation loss did not improve. Patience: {patience_counter}/{patience}')

    if patience_counter >= patience:
        print('Early stopping triggered!')
        break 

====> Epoch: 36 Average loss: 100.8290
Epoch: 36 Average validation loss: 99.9969
Validation loss decreased. Saving model.
====> Epoch: 37 Average loss: 97.9063
Epoch: 37 Average validation loss: 97.4970
Validation loss decreased. Saving model.
====> Epoch: 38 Average loss: 97.3558
Epoch: 38 Average validation loss: 99.3302
Validation loss did not improve. Patience: 1/10
====> Epoch: 39 Average loss: 94.7040
Epoch: 39 Average validation loss: 97.7283
Validation loss did not improve. Patience: 2/10
====> Epoch: 40 Average loss: 95.5497
Epoch: 40 Average validation loss: 109.3431
Validation loss did not improve. Patience: 3/10
====> Epoch: 41 Average loss: 94.8704
Epoch: 41 Average validation loss: 96.8546
Validation loss decreased. Saving model.
====> Epoch: 42 Average loss: 92.2838
Epoch: 42 Average validation loss: 95.3490
Validation loss decreased. Saving model.
====> Epoch: 43 Average loss: 92.0407
Epoch: 43 Average validation loss: 97.1793
Validation loss did not improve. Patience:

In [16]:
from torchvision.utils import save_image

def test2(epoch):
    model2.eval()
    with torch.no_grad():
        for i, data in enumerate(train_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model2(data)

            if i == 0:
                n = min(data.size(0), 8)   # take up to 8 samples

                # concat original and reconstructions along the batch dimension
                comparison = torch.cat([data[:n],
                                        recon_batch[:n]])

                save_image(comparison.cpu(),
                           f"/kaggle/working/reconstruction2_{epoch}.png",
                           nrow=n)
            break   # only first batch

In [19]:
test2(98)

**Results: Check Img3 on GITHUB**

**Comments:** The performance of this model is significantly better than the last 2. So this is the best model among all the three.

**Hypothesis:**
1. A larger embedding space could improve perfomance even more but computationally required will be more.


**Conclusion: Introducing a beta factor to the KL loss and increasing the model complexity can definitely improve model perfomance**
