In [1]:
import os
import random
from abc import ABCMeta

import numpy as np

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
from torch import nn, List, Tensor
from torch import optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from torch.autograd import Variable
from torch import distributions as dist

In [2]:
# Dataset over a single directory of bitmaps.
class BMPLoader(Dataset):
    """Loads every ``.bmp`` file in one directory as a grayscale tensor.

    Each item is the image converted to mode 'L', passed through
    ``ToTensor`` (values in [0, 1]) and ``Normalize(mean=0.5, std=0.5)``,
    giving values in [-1, 1].
    """

    def __init__(self, base_directory):
        """
        :param base_directory: Directory scanned (non-recursively) for file
            names ending in ``.bmp``.
        """
        self.base_directory = base_directory
        # Sorted so the index -> file mapping is stable across runs and
        # filesystems (os.listdir returns entries in arbitrary order).
        self.file_names = sorted(
            os.path.join(base_directory, f)
            for f in os.listdir(base_directory)
            if f.endswith('.bmp')
        )

        self.transform = transforms.Compose([
            transforms.ToTensor(),  # Converts to [0, 1] range
            transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1] range
        ])

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        # The context manager closes the underlying file deterministically;
        # convert('L') returns a new, fully loaded image that outlives it.
        with Image.open(self.file_names[idx]) as raw_img:
            gray_img = raw_img.convert('L')  # Convert to grayscale
        return self.transform(gray_img)

Min/Max of the SWAE latent codes for ACL (Automatic-CL)

In [3]:
# Data for ACL: training bitmaps of the Automatic-CL set.
dataset = BMPLoader('../data/terrain_bitmaps/Automatic-CL/TrainSet-train')
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,  # fixed order; the min/max pass does not need shuffling
)

# Report how many bitmaps were found.
print(f"Number of images in the dataset: {len(dataset)}")

Number of images in the dataset: 50000


In [4]:
from torch import nn
from abc import abstractmethod
from typing import Callable, List, Any, Optional, Sequence, Type

class BaseVAE(nn.Module, metaclass=ABCMeta):
    """Common interface for the autoencoder models in this notebook.

    ``encode``, ``decode``, ``sample`` and ``generate`` raise
    NotImplementedError unless a subclass overrides them. ``forward`` and
    ``loss_function`` are abstract. The ABCMeta metaclass is what makes the
    ``@abstractmethod`` markers effective: a subclass that does not override
    both cannot be instantiated. With nn.Module's default metaclass the
    decorators were silently ignored.
    """

    def __init__(self) -> None:
        super(BaseVAE, self).__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        """Map an input batch to latent codes (subclass responsibility)."""
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        """Map latent codes back to input space (subclass responsibility)."""
        raise NotImplementedError

    def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
        """Draw ``batch_size`` samples from the model (optional)."""
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Return a model output for ``x`` (optional)."""
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        pass

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        pass

In [5]:
class SWAE(BaseVAE):
    """Sliced-Wasserstein autoencoder.

    A deterministic convolutional autoencoder. ``encode`` maps each image to
    a single latent vector ``z``. ``loss_function`` adds, to the MSE + L1
    reconstruction loss, a sliced Wasserstein distance between the batch of
    codes and samples drawn with ``torch.randn_like`` (a standard-normal prior).

    NOTE(review): ``fc_z``/``decoder_input`` use ``hidden_dims[-1] * 4``
    features and ``decode`` reshapes to a 2x2 map. The encoder output must
    therefore be 2x2 spatially (e.g. 64x64 inputs with the five default
    stride-2 convolutions). TODO: confirm the expected input size.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 reg_weight: int = 100,
                 wasserstein_deg: float= 2.,
                 num_projections: int = 200,
                 projection_dist: str = 'normal',
                    **kwargs) -> None:
        """
        :param in_channels: (Int) Channels of the input images (C).
        :param latent_dim: (Int) Dimensionality of the latent code (D).
        :param hidden_dims: (List[int]) Encoder channel widths, one stride-2
            conv stage each; defaults to [32, 64, 128, 256, 512]. The
            caller's list is copied, never modified.
        :param reg_weight: Weight of the SWD term; ``loss_function`` divides
            it by N*(N-1) for a batch of N.
        :param wasserstein_deg: (Float) Degree p of the Wasserstein distance.
        :param num_projections: (Int) Number of random projections (S).
        :param projection_dist: (Str) 'normal' or 'cauchy'; distribution the
            projection directions are sampled from.
        :param kwargs: Accepted and ignored.
        """
        super(SWAE, self).__init__()

        self.latent_dim = latent_dim
        self.reg_weight = reg_weight
        self.p = wasserstein_deg
        self.num_projections = num_projections
        self.proj_dist = projection_dist

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            # Copy: the list is reversed in place below and that mutation
            # must not leak back to the caller.
            hidden_dims = list(hidden_dims)

        # Channel width at the bottleneck. decode() reshapes to this instead
        # of a hard-coded 512, so non-default hidden_dims work. A plain int
        # attribute, so the state_dict is unchanged.
        self.bottleneck_channels = hidden_dims[-1]

        # Build Encoder: every stage halves the spatial size (stride 2).
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # "* 4" is the flattened 2x2 bottleneck feature map.
        self.fc_z = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()  # the decoder walks the widths in reverse order

        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        # Last upsampling stage, then a 3x3 conv to one output channel.
        # Tanh bounds the reconstruction to [-1, 1].
        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels=1,
                                      kernel_size=3, padding=1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> Tensor:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) Latent codes [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Deterministic encoder: a single linear head yields z directly
        # (no mu / log-variance split as in a VAE).
        z = self.fc_z(result)
        return z

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps latent codes back to image space.
        :param z: (Tensor) Latent codes [N x D]
        :return: (Tensor) Reconstructions with one channel, values in [-1, 1]
        """
        result = self.decoder_input(z)
        result = result.view(-1, self.bottleneck_channels, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        :param input: (Tensor) Input batch [N x C x H x W]
        :return: [reconstruction, input, z], matching loss_function's arguments
        """
        z = self.encode(input)
        return [self.decode(z), input, z]

    def loss_function(self, recons, input, z) -> Tensor:
        """
        Total training loss: MSE + L1 reconstruction plus the weighted SWD.
        :param recons: (Tensor) Reconstructions from ``decode``.
        :param input: (Tensor) Original inputs; dim 0 is the batch size N.
        :param z: (Tensor) Latent codes [N x D].
        :return: (Tensor) Scalar loss.
        """
        batch_size = input.size(0)
        # N*(N-1) normalisation of the SWD weight. Clamped to 1 so a batch of
        # a single sample does not raise ZeroDivisionError.
        bias_corr = max(batch_size * (batch_size - 1), 1)
        reg_weight = self.reg_weight / bias_corr

        recons_loss_l2 = F.mse_loss(recons, input)
        recons_loss_l1 = F.l1_loss(recons, input)

        swd_loss = self.compute_swd(z, self.p, reg_weight)

        loss = recons_loss_l2 + recons_loss_l1 + swd_loss
        return loss

    def get_random_projections(self, latent_dim: int, num_samples: int) -> Tensor:
        """
        Returns random unit-norm directions used to project the encoded
        samples and the prior samples.

        :param latent_dim: (Int) Dimensionality of the latent space (D)
        :param num_samples: (Int) Number of samples required (S)
        :return: (Tensor) [S x D] rows of unit L2 norm, created on the CPU
        :raises ValueError: if ``self.proj_dist`` is not 'normal' or 'cauchy'
        """
        if self.proj_dist == 'normal':
            rand_samples = torch.randn(num_samples, latent_dim)
        elif self.proj_dist == 'cauchy':
            # The sample has shape [S x D x 1]. Squeeze only the trailing dim
            # so S == 1 or D == 1 does not collapse the matrix further.
            rand_samples = dist.Cauchy(torch.tensor([0.0]),
                                       torch.tensor([1.0])).sample((num_samples, latent_dim)).squeeze(-1)
        else:
            raise ValueError('Unknown projection distribution.')

        # Normalise each row to unit L2 norm.
        rand_proj = rand_samples / rand_samples.norm(dim=1).view(-1, 1)
        return rand_proj  # [S x D]

    def compute_swd(self,
                    z: Tensor,
                    p: float,
                    reg_weight: float) -> Tensor:
        """
        Computes the Sliced Wasserstein Distance (SWD). It randomly projects
        the encoded and prior vectors and computes their Wasserstein distance
        along those projections.

        :param z: Latent samples [N x D]
        :param p: Value for the p^th Wasserstein distance
        :param reg_weight: Multiplier applied to the mean |.|^p distance
        :return: (Tensor) Scalar weighted SWD
        """
        prior_z = torch.randn_like(z)  # [N x D]
        device = z.device

        proj_matrix = self.get_random_projections(self.latent_dim,
                                                  num_samples=self.num_projections).transpose(0, 1).to(device)

        latent_projections = z.matmul(proj_matrix)  # [N x S]
        prior_projections = prior_z.matmul(proj_matrix)  # [N x S]

        # In 1-D the optimal coupling pairs sorted samples, so per projection
        # the distance is taken between the sorted projections.
        w_dist = torch.sort(latent_projections.t(), dim=1)[0] - \
                 torch.sort(prior_projections.t(), dim=1)[0]
        # |.|^p: without abs(), odd p yields negative terms and fractional p
        # yields NaN for negative differences. Identical for p = 2.
        w_dist = w_dist.abs().pow(p)
        return reg_weight * w_dist.mean()

    def freeze_encoder(self):
        """Stop gradient updates for the encoder convs and the fc_z head."""
        for param in list(self.encoder.parameters()) + list(self.fc_z.parameters()):
            param.requires_grad = False

In [6]:
# Pick the device first so the checkpoint can be mapped onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the ACL model and load its trained weights.
swae_model = SWAE(in_channels=1, latent_dim=64)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
swae_model.load_state_dict(torch.load('./SWAE_ACL64.pth', map_location=device))
swae_model.eval()  # BatchNorm layers use their running statistics
swae_model.to(device)

# Per-dimension running extrema, sized from the model instead of a literal 64.
min_vector = torch.full((swae_model.latent_dim,), float('inf'), device=device)
max_vector = torch.full((swae_model.latent_dim,), float('-inf'), device=device)

# Iterate through the dataset to find the min and max latent values
with torch.no_grad():  # inference only; no autograd graph needed
    for batch in dataloader:
        batch = batch.to(device)
        latent_vector = swae_model.encode(batch)  # [N x latent_dim]

        # Update min and max vectors
        min_vector = torch.min(min_vector, latent_vector.min(dim=0)[0])
        max_vector = torch.max(max_vector, latent_vector.max(dim=0)[0])

# An empty dataloader would leave +/-inf here; NaN codes would propagate.
assert torch.isfinite(min_vector).all() and torch.isfinite(max_vector).all(), \
    "latent min/max not finite: empty dataloader or NaN latent codes"

# Convert min and max values to numpy arrays for saving
min_vectorACL = min_vector.cpu().numpy()
max_vectorACL = max_vector.cpu().numpy()

np.save('min_vectorACL.npy', min_vectorACL)
np.save('max_vectorACL.npy', max_vectorACL)

print("Min values for ACL: ", min_vectorACL)
print("Max values for ACL: ", max_vectorACL)

Min values for ACL:  [-3.0846817 -3.5480852 -3.310703  -3.0203302 -3.5437503 -3.467721
 -4.337242  -3.4853144 -3.92666   -3.1474633 -3.3734627 -3.3316848
 -3.3131342 -3.136463  -3.452568  -3.7839847 -3.4500816 -4.136976
 -2.9277575 -3.8204045 -3.4796271 -3.2866995 -3.5500462 -2.2695384
 -3.2348194 -3.5750706 -3.260126  -3.6794882 -4.7874947 -3.0784323
 -3.7974613 -3.9324567 -3.629394  -3.590842  -3.905172  -3.7703118
 -3.9054127 -3.9840758 -3.340094  -3.1688256 -3.364424  -3.5373845
 -3.1199927 -3.3769403 -3.8936872 -3.7437787 -3.4582052 -3.8529396
 -4.635456  -3.5075243 -3.3206415 -4.2381206 -2.9492784 -4.135289
 -3.3254037 -3.443234  -3.5208828 -3.4595518 -3.641638  -3.1465628
 -3.4188154 -3.2668934 -3.3007722 -3.4350948]
Max values for ACL:  [3.3847892 3.8328092 4.1091824 3.720243  3.5649352 3.3531775 3.036339
 3.7109766 3.2968366 3.9416375 3.7054825 3.4662652 3.6412013 3.1793785
 3.8827918 3.064621  3.5510583 3.4603624 3.0272202 3.0003166 3.6909683
 4.3498893 3.5821972 6.2207127 3.

Min/Max of the SWAE latent codes for MCL (Manual-CL)

In [7]:
# Data for MCL: training bitmaps of the Manual-CL set.
dataset = BMPLoader('../data/terrain_bitmaps/Manual-CL/TrainSet')
# A min/max reduction does not depend on order, so shuffling only adds
# nondeterminism. Keep it off, matching the ACL loader.
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

# Print the number of images in the dataset
print(f"Number of images in the dataset: {len(dataset)}")

Number of images in the dataset: 50000


In [8]:
# Pick the device first so the checkpoint can be mapped onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the MCL model and load its trained weights.
swae_model = SWAE(in_channels=1, latent_dim=64)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
swae_model.load_state_dict(torch.load('./SWAE_64.pth', map_location=device))
swae_model.eval()  # BatchNorm layers use their running statistics
swae_model.to(device)

# Per-dimension running extrema, sized from the model instead of a literal 64.
min_vector = torch.full((swae_model.latent_dim,), float('inf'), device=device)
max_vector = torch.full((swae_model.latent_dim,), float('-inf'), device=device)

# Iterate through the dataset to find the min and max latent values
with torch.no_grad():  # inference only; no autograd graph needed
    for batch in dataloader:
        batch = batch.to(device)
        latent_vector = swae_model.encode(batch)  # [N x latent_dim]

        # Update min and max vectors
        min_vector = torch.min(min_vector, latent_vector.min(dim=0)[0])
        max_vector = torch.max(max_vector, latent_vector.max(dim=0)[0])

# An empty dataloader would leave +/-inf here; NaN codes would propagate.
assert torch.isfinite(min_vector).all() and torch.isfinite(max_vector).all(), \
    "latent min/max not finite: empty dataloader or NaN latent codes"

# Convert min and max values to numpy arrays for saving
min_vectorMCL = min_vector.cpu().numpy()
max_vectorMCL = max_vector.cpu().numpy()

np.save('min_vectorMCL.npy', min_vectorMCL)
np.save('max_vectorMCL.npy', max_vectorMCL)

print("Min values for MCL: ", min_vectorMCL)
print("Max values for MCL: ", max_vectorMCL)

Min values for MCL:  [-4.026846  -3.7844894 -3.113536  -3.2230625 -4.7366796 -4.057123
 -3.706329  -3.392455  -3.8743103 -3.8935816 -3.4339774 -3.3696833
 -4.183838  -3.5761437 -3.6986923 -3.6437447 -3.6003337 -3.6840518
 -4.4190817 -3.5381734 -4.2054825 -3.8751025 -4.0558653 -4.286704
 -4.062463  -3.2248898 -3.6386237 -3.290187  -3.7026496 -3.8115523
 -2.063816  -4.2707257 -3.9774156 -4.296227  -3.4535143 -3.6449165
 -3.4875813 -3.6147952 -3.4858165 -3.4618804 -3.7035217 -3.3803794
 -4.1484838 -3.5151718 -3.7985313 -3.2333407 -3.2580166 -3.656097
 -3.101282  -3.7456577 -3.6446877 -3.7860909 -3.8346455 -3.3757591
 -3.464241  -3.337626  -4.181528  -3.6549032 -3.5443556 -2.9420679
 -3.2472672 -3.2864485 -3.2318344 -3.0205083]
Max values for MCL:  [4.122861  3.4537418 3.812586  3.7506456 3.3736398 3.5027428 2.5564022
 3.6559896 3.690561  4.2581143 3.1773424 3.2667673 3.7388036 3.7841072
 4.3424907 3.53028   3.6651552 3.109188  3.8801513 3.5555248 3.111924
 3.1390214 3.2859235 2.8316622 4.