In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import h5py
import matplotlib.pyplot as plt
import math
import random

In [None]:
import os

if os.path.exists('./3dshapes.h5') :
  print("file already exi")
else :
    !wget https://storage.googleapis.com/3d-shapes/3dshapes.h5

!ls -lh 3dshapes.h5

In [None]:
class Shapes3d(Dataset):
  """Map-style dataset backed by the `images` and `labels` arrays of an HDF5 file.

  Both arrays are read fully into memory when the dataset is constructed; the
  file is closed again as soon as the read finishes.
  """

  def __init__(self, path):
    """Read `images` and `labels` from the HDF5 file at `path`."""
    self.path = path
    print("dataset to RAM")
    with h5py.File(path, 'r') as h5_file:
      # Indexing with `[()]` pulls the entire HDF5 dataset into an in-memory array.
      self.images = h5_file['images'][()]
      self.labels = h5_file['labels'][()]

  def __len__(self):
    """Number of stored images."""
    return len(self.images)

  def __getitem__(self, idx):
    """Return `(image, label)` for position `idx`.

    The image is converted to a float tensor, divided by 255 and its axes are
    permuted with (2, 0, 1) — presumably (H, W, C) -> (C, H, W); confirm against
    the stored layout. The label is returned exactly as stored.
    """
    raw_pixels = self.images[idx]
    label = self.labels[idx]
    scaled = torch.from_numpy(raw_pixels).float() / 255.0
    return scaled.permute(2, 0, 1), label

# Read the archive into memory once, then expose it through a shuffling mini-batch loader.
dataset = Shapes3d('./3dshapes.h5')
dataset_loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Model Architecture

class BetaTCVAE(nn.Module):
  """Convolutional variational autoencoder with a 12-dimensional latent space.

  Encoder: four stride-2 convolutions, one hidden linear layer, then two linear
  heads producing the posterior mean and log-variance. Decoder: two linear
  layers followed by four stride-2 transposed convolutions and a sigmoid.
  The shape comments in `__init__` assume 3x64x64 inputs.
  """

  def __init__(self):
    super().__init__()

    # ---- encoder --------------------------------------------- input (3, 64, 64)
    self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)     # -> (32, 32, 32)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)    # -> (64, 16, 16)
    self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)   # -> (128, 8, 8)
    self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)  # -> (256, 4, 4)

    self.hl1 = nn.Linear(256 * 4 * 4, 256)
    self.hl2_mu = nn.Linear(256, 12)      # posterior mean
    self.hl2_logvar = nn.Linear(256, 12)  # posterior log(variance)

    # ---- decoder ----------------------------------------------------------------
    self.hl3 = nn.Linear(12, 256)
    self.hl4 = nn.Linear(256, 256 * 4 * 4)

    self.convT1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
    self.convT2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
    self.convT3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
    self.convT4 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)

  def encode(self, x):
    """Return `(mu, logvar)` for `x`; `logvar` is clamped to [-6, 6]."""
    feats = x
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
      feats = F.leaky_relu(conv(feats))
    # Flatten the (256, 4, 4) feature map into one vector per sample.
    hidden = F.leaky_relu(self.hl1(feats.view(-1, 256 * 4 * 4)))
    mu = self.hl2_mu(hidden)
    logvar = torch.clamp(self.hl2_logvar(hidden), min=-6.0, max=6.0)
    return mu, logvar

  def reparameterize(self, mu, logvar):
    """Draw `mu + eps * exp(0.5 * logvar)` with `eps` sampled from a standard normal."""
    sigma = torch.exp(0.5 * logvar)
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

  def decode(self, z):
    """Map latent codes to images with values clamped to [1e-6, 1]."""
    hidden = F.leaky_relu(self.hl4(F.leaky_relu(self.hl3(z))))
    feats = hidden.view(-1, 256, 4, 4)
    for deconv in (self.convT1, self.convT2, self.convT3):
      feats = F.leaky_relu(deconv(feats))
    pixels = torch.sigmoid(self.convT4(feats))
    return torch.clamp(pixels, min=1e-6, max=1.0)

  def forward(self, x):
    """Encode, sample a latent code, decode; returns `(recon_x, z, mu, logvar)`."""
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)
    return self.decode(z), z, mu, logvar

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BetaTCVAE()
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source
# (or pass weights_only=True on PyTorch versions that support it).
checkpoint = torch.load("/kaggle/input/betatcvae-shapes3d/final_model.pth", map_location=device)
model.load_state_dict(checkpoint)
model.eval()
# A bare `torch.no_grad()` statement only builds a context manager and throws it away, so
# autograd stayed enabled. The cells visible after this one only run inference, so switch
# gradient tracking off globally instead.
torch.set_grad_enabled(False)
model.to(device)

In [None]:
# batch curation

def discretize(dataset) :
    """Replace every label value by the rank of that value within its column.

    For each label column the distinct values are sorted and each entry is mapped to
    the position of its value in that sorted list (0 .. k-1).

    Args:
        dataset: either an object exposing a 2-D `labels` array (one row per sample),
            or any iterable of `(image, label)` pairs. When `labels` is available it is
            read directly, which avoids building every image just to reach its label;
            this assumes `dataset.labels[i]` equals `dataset[i][1]`, as it does for
            `Shapes3d`.

    Returns:
        Integer array with the same shape as the label matrix.
    """
    labels = getattr(dataset, "labels", None)
    if labels is None:
        labels = [sample[1] for sample in dataset]
    all_floats = np.asarray(labels)
    all_uids = np.zeros_like(all_floats, dtype=int)

    # Use the actual number of label columns rather than assuming six factors.
    for col in range(all_floats.shape[1]) :
        all_uids[:, col] = np.unique(all_floats[:, col], return_inverse=True)[1]

    return all_uids

class CuratedSampler :
    """Draws image batches that mix "normal" and "outlier" samples chosen by label criteria.

    A sample is normal when every `{label_column: discrete_value}` pair in
    `normal_criteria` matches its discretised labels. Any other sample that matches every
    pair in `outlier_criteria` is an outlier, so an empty `outlier_criteria` makes every
    non-normal sample an outlier.

    Attributes set by the constructor:
        discrete: integer label matrix returned by `discretize(dataset)`.
        dataset: the dataset the indices refer to.
        normal_idx / outlier_idx: lists of dataset indices in ascending order.
    """

    def __init__(self, dataset, normal_criteria, outlier_criteria) :
        self.discrete = discretize(dataset)
        self.dataset = dataset

        # Boolean masks over all samples replace a per-row Python loop.
        normal_mask = self._matches(normal_criteria)
        outlier_mask = ~normal_mask & self._matches(outlier_criteria)
        self.normal_idx = np.flatnonzero(normal_mask).tolist()
        self.outlier_idx = np.flatnonzero(outlier_mask).tolist()

    def _matches(self, criteria) :
        """Boolean mask of samples whose discrete labels satisfy every pair in `criteria`."""
        mask = np.ones(len(self.discrete), dtype=bool)
        for label_idx, wanted in criteria.items() :
            mask &= self.discrete[:, label_idx] == wanted
        return mask

    def get_batch(self, num_normal, num_outlier, device) :
        """Sample images with replacement and stack them, normals first, on `device`.

        Args:
            num_normal: number of draws from the normal pool.
            num_outlier: number of draws from the outlier pool.
            device: target device for the stacked tensor.

        Returns:
            Tensor whose first `num_normal` rows are normal images and whose remaining
            `num_outlier` rows are outlier images.

        Raises:
            IndexError: if draws are requested from a pool that is empty.
        """
        if num_normal > 0 and not self.normal_idx :
            raise IndexError("cannot sample normal images: no sample matches normal_criteria")
        if num_outlier > 0 and not self.outlier_idx :
            raise IndexError("cannot sample outlier images: no sample matches outlier_criteria")

        # `random.choices` samples with replacement, so an index may repeat within a batch.
        chosen = random.choices(self.normal_idx, k=num_normal)
        chosen += random.choices(self.outlier_idx, k=num_outlier)
        return torch.stack([self.dataset[idx][0] for idx in chosen]).to(device)

In [None]:
# {label column: discrete value id} pairs that a sample must all match to count as "normal".
normal_criteria = {1: 3, 4: 2, 2 : 5}
# Empty criteria: every sample that is not normal is eligible as an outlier.
outlier_criteria = {}
num_normals = 300
num_outliers = 20
curator = CuratedSampler(dataset, normal_criteria, outlier_criteria)
# Seed the sampler so the drawn batch — and the threshold later read off its histogram —
# is the same on every Restart & Run All.
SEED = 0
random.seed(SEED)
batch = curator.get_batch(num_normals, num_outliers, device)

In [None]:
def calculate_z_scores(input_data, encoder=None) :
    """Leave-one-out z-score of every sample's latent mean, per latent dimension.

    Each sample is standardised with the mean and (unbiased) standard deviation of all
    the *other* samples in the batch, so a sample cannot mask itself.

    Args:
        input_data: batch of inputs accepted by `encoder.encode`.
        encoder: object with an `encode(x) -> (mu, logvar)` method. Defaults to the
            notebook-level `model`.

    Returns:
        Tensor with one row per sample and one column per latent dimension. It carries
        no autograd graph. With fewer than three samples the leave-one-out standard
        deviation is undefined and the result contains NaN.
    """
    net = model if encoder is None else encoder
    # Pure inference: do not record an autograd graph for the encoder pass.
    with torch.no_grad() :
        batch_z, _ = net.encode(input_data)
        loo_z_scores = []
        for i in range(batch_z.size(0)) :
            # All latent codes except the i-th one.
            rest = torch.cat((batch_z[:i], batch_z[i + 1:]), dim=0)
            mu = rest.mean(dim=0)
            std = rest.std(dim=0)
            # A (near-)constant dimension would divide by ~0; fall back to unit scale.
            std = torch.where(std < 1e-6, torch.ones_like(std), std)
            loo_z_scores.append((batch_z[i] - mu) / std)
        return torch.stack(loo_z_scores)

def calculate_MD(input_data, encoder=None) :
    """Leave-one-out Mahalanobis distance of every sample's latent mean.

    For sample i the mean and covariance are estimated from all the *other* samples in
    the batch, and the distance of sample i to that distribution is returned.

    Args:
        input_data: batch of inputs accepted by `encoder.encode`.
        encoder: object with an `encode(x) -> (mu, logvar)` method. Defaults to the
            notebook-level `model`.

    Returns:
        1-D tensor with one distance per sample, on the same device as the latent
        codes and without an autograd graph. At least three samples are needed for the
        leave-one-out covariance to be defined.
    """
    net = model if encoder is None else encoder
    # Pure inference: do not record an autograd graph for the encoder pass.
    with torch.no_grad() :
        batch_z, _ = net.encode(input_data)
        loo_md = []
        for i in range(batch_z.size(0)) :
            # All latent codes except the i-th one.
            rest = torch.cat((batch_z[:i], batch_z[i + 1:]), dim=0)
            disp_vec = batch_z[i] - rest.mean(dim=0)
            # torch.cov returns a 0-dim tensor for a single variable; pinv needs a matrix.
            covar = torch.atleast_2d(torch.cov(rest.T))
            # Pseudo-inverse keeps this defined when the covariance is singular.
            inv_covar = torch.linalg.pinv(covar)
            sq_dist = disp_vec @ inv_covar @ disp_vec
            # Round-off can push the quadratic form slightly below zero; avoid NaN from sqrt.
            loo_md.append(torch.sqrt(sq_dist.clamp_min(0.0)))
        if not loo_md :
            return batch_z.new_zeros(0)
        # Stack on-device instead of syncing one scalar at a time through .item().
        return torch.stack(loo_md)

In [None]:
# Per-dimension leave-one-out z-scores and leave-one-out Mahalanobis distances for the batch.
z_scores = calculate_z_scores(input_data=batch)
distances = calculate_MD(input_data=batch)

In [None]:
# Plot from a separate NumPy copy: rebinding `distances` itself would make this cell fail
# when it is run a second time (an ndarray has no `.cpu()`).
if isinstance(distances, torch.Tensor):
    distances_np = distances.detach().cpu().numpy()
else:
    distances_np = np.asarray(distances)

fig, ax = plt.subplots(figsize=(15, 10))
ax.hist(distances_np, bins=20, alpha=1, color="blue", density=True)
ax.set_title("Mahalanobis Distance Distribution")
ax.set_xlabel("MD Score (lower is normal)")
ax.set_ylabel("Density")
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
def differentiate_batch(batch, distances, threshold) :
    """Partition `batch` into outliers and normals by thresholding `distances`.

    A sample is an outlier when its distance is strictly greater than `threshold`.
    `distances` may be a tensor or anything `torch.tensor` accepts; non-tensors are
    converted onto the batch's device.

    Returns:
        `(outliers, normals, outlier_indices, normal_indices)` — the two selected
        sub-batches followed by the positions of their rows in `batch`.
    """
    if isinstance(distances, torch.Tensor) :
        scores = distances
    else :
        scores = torch.tensor(distances, device= batch.device)

    is_outlier = scores > threshold
    is_normal = ~is_outlier
    outlier_rows = is_outlier.nonzero(as_tuple=True)[0]
    normal_rows = is_normal.nonzero(as_tuple=True)[0]
    return batch[is_outlier], batch[is_normal], outlier_rows, normal_rows

def generate_counterfactual(batch, normal_idxes, outlier_idx, z_scores) :
    """Plot a latent traversal from an outlier to its counterfactual.

    The counterfactual is the outlier's latent mean with its most anomalous dimension
    (largest |z-score|) replaced by the mean of that dimension over the normal samples.
    Eleven evenly spaced interpolations between the two codes are decoded and shown in
    one row, from the original (left) to the counterfactual (right).

    Args:
        batch: image batch that `outlier_idx` and `normal_idxes` index into.
        normal_idxes: indices of the samples treated as normal.
        outlier_idx: index of the outlier to explain.
        z_scores: per-sample, per-dimension z-scores for `batch`.

    Returns:
        None; the figure is displayed with `plt.show()`.
    """
    n_frames = 11
    # Pure inference: no autograd graph is needed for encoding or decoding.
    with torch.no_grad() :
        batch_z, _ = model.encode(batch)
        outlier_z = batch_z[outlier_idx]
        # .item() yields a plain Python int, which is an unambiguous tensor index.
        outlier_dimension = torch.argmax(torch.abs(z_scores[outlier_idx])).item()
        normal_mu = torch.mean(batch_z[normal_idxes], dim=0)
        counterfactual = outlier_z.clone()
        counterfactual[outlier_dimension] = normal_mu[outlier_dimension]

        # Build every interpolation step, then decode them in a single batched call.
        steps = [i / (n_frames - 1) for i in range(n_frames)]
        path = torch.stack([outlier_z * (1 - w) + counterfactual * w for w in steps])
        frames = model.decode(path).permute(0, 2, 3, 1).cpu().numpy()

    fig, axes = plt.subplots(1, n_frames, figsize=(15, 15))
    for i, (ax, frame) in enumerate(zip(axes, frames)) :
        ax.imshow(frame)
        if i == 0 :
            ax.set_title("original outlier")
        if i == n_frames - 1 :
            ax.set_title("counterfactual image")
        ax.axis('off')
    plt.show()

In [None]:
# we infer the threshold from the Mahalanobis distance histogram produced above
threshold = 7
_, _, outlier_idxes, normal_idxes = differentiate_batch(batch, distances, threshold)

# For every flagged sample, the latent dimension with the largest absolute z-score.
outlier_z = z_scores[outlier_idxes]
outlier_dimension = torch.argmax(torch.abs(outlier_z), dim=1).cpu().numpy()

fig, ax = plt.subplots(figsize=(12, 6))
ax.hist(outlier_dimension, bins=range(outlier_z.shape[1] + 1),
        align='left', rwidth=0.8, color='teal', alpha=0.7, density=False)
ax.set_title("Anomaly Distribution across Dimensions in Latent Space")
ax.set_xlabel("Dimensions")
# density=False plots raw counts, so the axis is labelled accordingly.
ax.set_ylabel("Count")
ax.set_xticks(range(outlier_z.shape[1]))
ax.grid(axis='y', linestyle='--', alpha=0.3)
plt.show()

In [None]:
# Show traversals for up to the first six flagged samples; slicing avoids an IndexError
# when fewer than six samples exceed the threshold.
for outlier_idx in outlier_idxes[:6]:
    generate_counterfactual(batch, normal_idxes, outlier_idx, z_scores)