In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file found under the read-only Kaggle input directory.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in files):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install git+https://github.com/openai/CLIP.git

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.distributions as distributions
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as utils
import numpy as np
import matplotlib.pyplot as plt

class PlanarFlow(nn.Module):
    """One step of planar flow: f(z) = z + u_hat * tanh(w^T z + b).

    Reference:
    Variational Inference with Normalizing Flows
    Danilo Jimenez Rezende, Shakir Mohamed
    (https://arxiv.org/abs/1505.05770)
    """

    def __init__(self, dim):
        """Creates the learnable parameters u, w (each 1 x dim) and b (1 element).

        Args:
            dim: input dimensionality.
        """
        super().__init__()
        self.u = nn.Parameter(torch.randn(1, dim))
        self.w = nn.Parameter(torch.randn(1, dim))
        self.b = nn.Parameter(torch.randn(1))

    def forward(self, x):
        """Applies the flow step.

        Args:
            x: input tensor (B x D).
        Returns:
            (transformed x, log|det J|), the latter of shape (B x 1).
        """
        # Reparameterise u so that w^T u_hat = softplus(w^T u) - 1 >= -1, the
        # sufficient condition for invertibility with a tanh non-linearity.
        wu = torch.sum(self.w * self.u)
        u_hat = self.u + (F.softplus(wu) - 1. - wu) * self.w / self.w.norm() ** 2

        pre = torch.sum(self.w * x, dim=1, keepdim=True) + self.b
        act = torch.tanh(pre)
        out = x + u_hat * act

        # psi(z) = h'(w^T z + b) * w with h' = 1 - tanh^2; det J = 1 + u_hat^T psi.
        psi = (1. - act ** 2) * self.w
        log_det = torch.abs(1. + torch.sum(u_hat * psi, dim=1, keepdim=True)).log()
        return out, log_det

class RadialFlow(nn.Module):
    """One step of radial flow: f(z) = z + beta * h(r) * (z - c), r = |z - c|.

    Reference:
    Variational Inference with Normalizing Flows
    Danilo Jimenez Rezende, Shakir Mohamed
    (https://arxiv.org/abs/1505.05770)
    """

    def __init__(self, dim):
        """Creates the learnable parameters.

        Args:
            dim: input dimensionality.
        """
        super().__init__()
        self.a = nn.Parameter(torch.randn(1))       # alpha = exp(a) > 0
        self.b = nn.Parameter(torch.randn(1))       # beta = softplus(b) - alpha
        self.c = nn.Parameter(torch.randn(1, dim))  # reference point c
        self.d = dim

    def forward(self, x):
        """Applies the flow step.

        Args:
            x: input tensor (B x D).
        Returns:
            (transformed x, log|det J|), the latter of shape (B x 1).
        """
        alpha = self.a.exp()
        # softplus(b) > 0, so beta > -alpha, which keeps the map invertible.
        beta = F.softplus(self.b) - alpha
        offset = x - self.c
        radius = offset.norm(dim=1, keepdim=True)
        h = 1. / (alpha + radius)
        scale = beta * h
        out = x + scale * offset
        # det J = (1 + beta*h)^(D-1) * (1 + beta*h + beta*h'(r)*r), with h'(r) = -h^2.
        log_det = (self.d - 1) * torch.log(1. + scale) \
            + torch.log(1. + scale + beta * (-h ** 2) * radius)
        return out, log_det

class HouseholderFlow(nn.Module):
    def __init__(self, dim):
        """Instantiates one step of householder flow.

        Reference:
        Improving Variational Auto-Encoders using Householder Flow
        Jakub M. Tomczak, Max Welling
        (https://arxiv.org/abs/1611.09630)

        Args:
            dim: input dimensionality.
        """
        super(HouseholderFlow, self).__init__()

        self.v = nn.Parameter(torch.randn(1, dim))  # reflection vector (1 x dim)
        self.d = dim

    def forward(self, x):
        """Reflects every row of x through the hyperplane orthogonal to v.

        Equivalent to x -> H x with H = I - 2 v v^T / ||v||^2, computed row-wise so
        no D x D matrix is materialised and the result stays on x's device.

        Args:
            x: input tensor (B x D).
        Returns:
            transformed x and log-determinant of Jacobian (0, since H is orthogonal).
        """
        v_sqr = self.v.norm() ** 2
        proj = (x * self.v).sum(dim=1, keepdim=True)  # (B x 1): x_i . v
        x = x - 2. * proj * self.v / v_sqr

        return x, 0

class NiceFlow(nn.Module):
    """One step of NICE flow: an additive coupling layer, or the final diagonal scaling.

    Reference:
    NICE: Non-linear Independent Components Estimation
    Laurent Dinh, David Krueger, Yoshua Bengio
    (https://arxiv.org/abs/1410.8516)
    """

    def __init__(self, dim, mask, final=False):
        """Builds either the coupling network or the scaling vector.

        Args:
            dim: input dimensionality (must be even for coupling steps).
            mask: truthy -> update the even-indexed half, falsy -> the odd-indexed half.
            final: True if the final step, False otherwise.
        """
        super().__init__()
        self.final = final
        if not final:
            self.mask = mask
            hidden = dim * 5
            self.coupling = nn.Sequential(
                nn.Linear(dim // 2, hidden), nn.ReLU(),
                nn.Linear(hidden, hidden), nn.ReLU(),
                nn.Linear(hidden, dim // 2))
        else:
            self.scale = nn.Parameter(torch.zeros(1, dim))

    def forward(self, x):
        """Applies the step.

        Args:
            x: input tensor (B x D).
        Returns:
            (transformed x, log|det J|): sum(scale) for the final step, 0 for couplings.
        """
        if self.final:
            return x * torch.exp(self.scale), torch.sum(self.scale)

        batch, width = x.shape
        # Column 0 of the (B, D/2, 2) view holds even indices, column 1 the odd ones.
        pairs = x.reshape(batch, width // 2, 2)
        first, second = pairs[:, :, 0], pairs[:, :, 1]
        if self.mask:
            first = first + self.coupling(second)
        else:
            second = second + self.coupling(first)
        return torch.stack((first, second), dim=2).reshape(batch, width), 0

class Flow(nn.Module):
    def __init__(self, dim, type, length):
        """Instantiates a chain of flows.

        Args:
            dim: input dimensionality.
            type: type of flow: 'planar', 'radial', 'householder' or 'nice'. Any
                other value (e.g. None) yields an empty chain, i.e. the identity.
            length: length of flow.
        """
        super(Flow, self).__init__()

        if type == 'planar':
            steps = [PlanarFlow(dim) for _ in range(length)]
        elif type == 'radial':
            steps = [RadialFlow(dim) for _ in range(length)]
        elif type == 'householder':
            steps = [HouseholderFlow(dim) for _ in range(length)]
        elif type == 'nice':
            # Alternate the coupling mask 0, 1, 0, 1, ... so consecutive layers update
            # different halves of the input; the last step is the diagonal scaling layer.
            steps = [NiceFlow(dim, i % 2, i == length - 1) for i in range(length)]
        else:
            steps = []
        self.flow = nn.ModuleList(steps)

    def forward(self, x):
        """Forward pass.

        Args:
            x: input tensor (B x D).
        Returns:
            transformed x and the summed log-determinant of Jacobian (B x 1).
        """
        batch, _ = x.shape
        # Allocate on x's device/dtype instead of assuming CUDA.
        log_det = x.new_zeros(batch, 1)
        for step in self.flow:
            x, inc = step(x)
            log_det = log_det + inc

        return x, log_det

class GatedLayer(nn.Module):
    """Gated linear layer: linear(x) * sigmoid(gate_linear(x))."""

    def __init__(self, in_dim, out_dim):
        """Builds the value projection and its sigmoid gate.

        Args:
            in_dim: input dimensionality.
            out_dim: output dimensionality.
        """
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.gate = nn.Sequential(nn.Linear(in_dim, out_dim), nn.Sigmoid())

    def forward(self, x):
        """Forward pass.

        Args:
            x: input tensor (B x in_dim).
        Returns:
            gated projection (B x out_dim).
        """
        values = self.linear(x)
        weights = self.gate(x)
        return values * weights

class MLPLayer(nn.Module):
    """A single MLP layer: either a GatedLayer or Linear followed by ReLU."""

    def __init__(self, in_dim, out_dim, gate):
        """Selects the layer implementation.

        Args:
            in_dim: input dimensionality.
            out_dim: output dimensionality.
            gate: truthy -> GatedLayer, falsy -> Linear + ReLU.
        """
        super().__init__()
        self.layer = (GatedLayer(in_dim, out_dim) if gate
                      else nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU()))

    def forward(self, x):
        """Forward pass.

        Args:
            x: input tensor (B x in_dim).
        Returns:
            transformed x (B x out_dim).
        """
        out = self.layer(x)
        return out

class VAE_with_flow(nn.Module):
    def __init__(self, dataset, layer, in_dim, hidden_dim, latent_dim, gate, flow, length):
        """Instantiates a VAE whose approximate posterior is refined by a flow.

        Args:
            dataset: dataset name; 'mnist' selects a BCE-on-logits reconstruction loss
                and sigmoid outputs in `sample`, anything else squared error.
            layer: number of hidden layers in the encoder and in the decoder.
            in_dim: input dimensionality.
            hidden_dim: hidden dimensionality.
            latent_dim: latent dimensionality.
            gate: whether to use gating mechanism.
            flow: type of the flow (None if do not use flow).
            length: length of the flow.
        """
        super(VAE_with_flow, self).__init__()

        self.dataset = dataset
        self.latent_dim = latent_dim
        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_var = nn.Linear(hidden_dim, latent_dim)

        # Encoder: in_dim -> hidden_dim, then (layer - 1) hidden_dim -> hidden_dim layers.
        self.encoder = nn.ModuleList(
            [MLPLayer(in_dim, hidden_dim, gate)] +
            [MLPLayer(hidden_dim, hidden_dim, gate) for _ in range(layer - 1)])
        self.flow = Flow(latent_dim, flow, length)
        # Decoder mirrors the encoder and ends with an unactivated projection to in_dim.
        self.decoder = nn.ModuleList(
            [MLPLayer(latent_dim, hidden_dim, gate)] +
            [MLPLayer(hidden_dim, hidden_dim, gate) for _ in range(layer - 1)] +
            [nn.Linear(hidden_dim, in_dim)])

    def encode(self, x):
        """Encodes input.

        Args:
            x: input tensor (B x D).
        Returns:
            mean and log-variance of the gaussian approximate posterior.
        """
        for module in self.encoder:
            x = module(x)
        return self.mean(x), self.log_var(x)

    def transform(self, mean, log_var):
        """Samples from the gaussian posterior (reparameterisation) and applies the flow.

        Args:
            mean: mean of the gaussian approximate posterior.
            log_var: log-variance of the gaussian approximate posterior.
        Returns:
            transformed latent codes and the log-determinant of the Jacobian (B x 1).
        """
        std = torch.exp(.5 * log_var)
        eps = torch.randn_like(std)
        z = eps.mul(std).add_(mean)

        return self.flow(z)

    def decode(self, z):
        """Decodes latent codes.

        Args:
            z: latent codes.
        Returns:
            reconstructed input (raw decoder output, no output activation).
        """
        for module in self.decoder:
            z = module(z)
        return z

    def sample(self, size):
        """Generates samples by decoding draws from the N(0, I) prior.

        Args:
            size: number of samples to generate.
        Returns:
            generated samples (sigmoid-squashed when dataset == 'mnist').
        """
        # Draw on the model's own device rather than assuming CUDA is available.
        z = torch.randn(size, self.latent_dim, device=self.mean.weight.device)
        if self.dataset == 'mnist':
            return torch.sigmoid(self.decode(z))
        else:
            return self.decode(z)

    def reconstruction_loss(self, x, x_hat):
        """Computes reconstruction loss.

        Args:
            x: original input (B x D).
            x_hat: reconstructed input (B x D).
        Returns:
            per-example reconstruction loss summed over features (B x 1).
        """
        if self.dataset == 'mnist':
            return nn.BCEWithLogitsLoss(reduction='none')(x_hat, x).sum(dim=1, keepdim=True)
        else:
            return nn.MSELoss(reduction='none')(x_hat, x).sum(dim=1, keepdim=True)

    def latent_loss(self, mean, log_var, log_det):
        """Computes KL loss.

        Args:
            mean: mean of the gaussian approximate posterior.
            log_var: log-variance of the gaussian approximate posterior.
            log_det: log-determinant of the Jacobian.
        Returns:
            per-example closed-form KL(q0 || N(0, I)) minus log_det (B x 1).
        """
        # NOTE(review): with a non-empty flow the exact single-sample term is
        # log q0(z0) - log_det - log p(zK). The closed-form KL below evaluates the prior
        # against the *pre-flow* distribution, so it is exact only when no flow is used
        # (log_det == 0). Confirm whether this approximation is intended.
        kl = -.5 * torch.sum(1. + log_var - mean.pow(2) - log_var.exp(), dim=1, keepdim=True)
        return kl - log_det

    def loss(self, x, x_hat, mean, log_var, log_det):
        """Computes overall loss.

        Args:
            x: original input (B x D).
            x_hat: reconstructed input (B x D).
            mean: mean of the gaussian approximate posterior.
            log_var: log-variance of the gaussian approximate posterior.
            log_det: log-determinant of the Jacobian.
        Returns:
            per-example sum of reconstruction and KL loss (B x 1).
        """
        return self.reconstruction_loss(x, x_hat) + self.latent_loss(mean, log_var, log_det)

    def forward(self, x):
        """Forward pass.

        Args:
            x: input tensor (B x D).
        Returns:
            tuple (x_hat, loss, z): the reconstruction (B x D), the loss averaged over
            the minibatch (scalar), and the post-flow latent codes (B x latent_dim).
        """
        mean, log_var = self.encode(x)
        z, log_det = self.transform(mean, log_var)
        x_hat = self.decode(z)

        return x_hat, self.loss(x, x_hat, mean, log_var, log_det).mean(), z

def logit_transform(x, constraint=0.9, reverse=False):
    '''Transforms data from [0, 1] into unbounded space (or back).

    Forward: dequantizes x with U[0, 1) noise as (255 x + u) / 256 (assumes x holds
    8-bit values scaled to [0, 1] -- confirm for your data), squeezes the result into
    [(1 - constraint) / 2, (1 + constraint) / 2] ([0.05, 0.95] for the default
    constraint) and applies logit.

    Args:
        x: input tensor; must be 4-D (B x C x H x W) when reverse=False.
        constraint: data constraint before logit.
        reverse: True if transform data back to [0, 1].
    Returns:
        transformed tensor and the batch-mean log-determinant of Jacobian of the
        squeeze + logit step (if reverse=True, 0 is returned in its place).
    Raises:
        ValueError: if reverse=False and x is not 4-D.
    '''
    if reverse:
        x = torch.sigmoid(x)                       # [(1-c)/2, (1+c)/2]
        x = ((x * 2. - 1.) / constraint + 1.) / 2.  # undo the squeeze -> [0, 1]
        return x, 0

    if x.dim() != 4:
        raise ValueError(f'expected a 4-D (B, C, H, W) tensor, got shape {tuple(x.shape)}')

    # dequantization; rand_like keeps the noise on x's device and dtype
    noise = torch.rand_like(x)
    x = (x * 255. + noise) / 256.

    # restrict data: [0, 1] -> [(1-c)/2, (1+c)/2]
    x = ((x * 2. - 1.) * constraint + 1.) / 2.

    # logit data
    logit_x = torch.log(x) - torch.log(1. - x)

    # log|d logit(y)/dy| = softplus(l) + softplus(-l); the squeeze contributes
    # log(constraint) = -softplus(-logit(constraint)).
    pre_logit_scale = torch.tensor(
        np.log(constraint) - np.log(1. - constraint), dtype=x.dtype, device=x.device)
    log_diag_J = F.softplus(logit_x) + F.softplus(-logit_x) \
        - F.softplus(-pre_logit_scale)

    return logit_x, torch.sum(log_diag_J, dim=(1, 2, 3)).mean()

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from types import SimpleNamespace
def compute_similarity(img_z, text_z):
    """Cosine similarity of img_z and text_z along dim 1 (one score per row for 2-D inputs)."""
    return F.cosine_similarity(img_z, text_z, dim=1)

import clip
import os
import csv
import re
# optimizer = torch.optim.Adam(list(vae.parameters()) + list(flow.parameters()), lr=1e-4)
def get_clip_features(images, texts, clip_model, preprocess, device):
    """Encodes a batch of images and captions with CLIP, without tracking gradients.

    Args:
        images: iterable of image tensors accepted by ToPILImage; each is converted
            back to PIL so CLIP's own `preprocess` can be applied.
        texts: list of caption strings.
        clip_model: model returned by `clip.load`.
        preprocess: preprocessing transform returned by `clip.load`.
        device: device to run CLIP on.
    Returns:
        (image_features, text_features), one row per input.
    """
    # CLIP is used as a fixed feature extractor (main() does not give its parameters
    # to the optimizer), so skip building an autograd graph through it; otherwise
    # every loss.backward() would also backpropagate through CLIP.
    with torch.no_grad():
        pixels = torch.stack([preprocess(transforms.ToPILImage()(img.cpu())) for img in images])
        image_features = clip_model.encode_image(pixels.to(device))

        # truncate=True cuts captions longer than CLIP's context length instead of raising.
        text_tokens = clip.tokenize(texts, truncate=True).to(device)
        text_features = clip_model.encode_text(text_tokens)

    return image_features, text_features

class Flickr8kCaptionDataset(Dataset):
    def __init__(self, image_folder, captions_dict, transform=None):
        """Dataset of (image, caption) pairs.

        Args:
            image_folder: directory containing the image files.
            captions_dict: {image_filename: [caption1, caption2, ...]}. Only the first
                caption of each image is returned by __getitem__.
            transform: optional callable applied to the loaded RGB PIL image.
        """
        self.image_folder = image_folder
        self.captions_dict = captions_dict
        self.transform = transform
        self.image_files = list(captions_dict.keys())

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_file = self.image_files[idx]
        img_path = os.path.join(self.image_folder, img_file)

        # Close the file handle right away; convert() loads the pixel data first,
        # so the returned image no longer needs the open file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        # Always the first caption for this image.
        caption = self.captions_dict[img_file][0]

        return image, caption

import os # Import os module for path manipulation

from collections import defaultdict

def load_captions(caption_file, delimiter=',', has_header=True):
    """
    Loads image captions from a file, skipping comments and malformed lines.

    Each data line is ``<filename><delimiter><caption>``; only the first delimiter
    splits, so captions may themselves contain the delimiter.

    Parameters:
    - caption_file (str): Path to the caption file.
    - delimiter (str): Delimiter separating filename and caption.
    - has_header (bool): Whether the first line is a header to ignore.

    Returns:
    - captions_dict (dict): Dictionary mapping filenames to lists of captions (file order).
    - skipped (int): Number of skipped lines due to formatting issues.
    """
    captions_dict = defaultdict(list)
    skipped = 0

    with open(caption_file, 'r', encoding='utf-8') as f:
        first_line_num = 1
        if has_header:
            next(f, None)  # the default keeps an empty file from raising StopIteration
            first_line_num = 2

        # line_num matches the 1-based physical line number in the file.
        for line_num, line in enumerate(f, start=first_line_num):
            line = line.strip()

            # Skip empty lines and comments
            if not line or line.startswith('#'):
                continue

            # Split only on the first delimiter
            filename, sep, caption = line.partition(delimiter)
            if not sep:
                skipped += 1
                print(f"Line {line_num} skipped (no delimiter found): {line}")
                continue

            filename = filename.strip()
            caption = caption.strip()
            if filename and caption:
                captions_dict[filename].append(caption)
            else:
                skipped += 1
                print(f"Line {line_num} skipped (missing filename or caption): {line}")

    return dict(captions_dict), skipped

# Module-level {image filename: [captions]} mapping; main() reads `captions` as a global.
captions, skipped = load_captions("/kaggle/input/flickr8k/captions.txt")

def main(args):
    """Trains an image VAE and a caption VAE on CLIP ViT-B/32 embeddings of Flickr8k.

    Args:
        args: namespace providing dataset, batch_size, layer, hidden_dim, latent_dim,
            gate, flow, length, lr, momentum, decay and max_iter. Optional: epochs
            (default 99) and checkpoint_dir (default '.', the working directory).

    Reads the module-level ``captions`` dict built by ``load_captions``.
    """
    from tqdm import tqdm

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} device")

    # model hyperparameters
    dataset_name = args.dataset
    batch_size = args.batch_size
    layer = args.layer
    hidden_dim = args.hidden_dim
    latent_dim = args.latent_dim
    gate = args.gate
    flow = args.flow
    length = args.length

    # optimization hyperparameters
    lr = args.lr
    momentum = args.momentum
    decay = args.decay
    max_iter = args.max_iter
    epochs = getattr(args, 'epochs', 99)  # runs epochs 1..epochs
    checkpoint_dir = getattr(args, 'checkpoint_dir', '.')

    clip_model, preprocess = clip.load("ViT-B/32", device=device)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    image_folder = "/kaggle/input/flickr8k/Images"
    dataset = Flickr8kCaptionDataset(image_folder, captions, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    # Both VAEs model 512-d CLIP embeddings. VAE_with_flow compares its `dataset`
    # argument against strings (e.g. 'mnist'), so pass the name, not the Dataset object.
    feature_dim = 512
    vae_image = VAE_with_flow(dataset_name, layer, feature_dim, hidden_dim, latent_dim,
                              gate, flow, length).to(device)
    vae_text = VAE_with_flow(dataset_name, layer, feature_dim, hidden_dim, latent_dim,
                             gate, flow, length).to(device)
    optimizer = optim.Adam(list(vae_image.parameters()) + list(vae_text.parameters()),
                           lr=lr, betas=(momentum, decay))
    vae_image.train()
    vae_text.train()

    total_iter = 0
    running_loss = 0.0
    for epoch in range(1, epochs + 1):
        if total_iter >= max_iter:
            break
        for images, texts in tqdm(dataloader, desc=f"Epoch {epoch}"):
            # Check the iteration budget before spending any work on the batch.
            if total_iter >= max_iter:
                break

            # CLIP is a fixed feature extractor (its parameters are not in the
            # optimizer), so no autograd graph is built through it.
            with torch.no_grad():
                image_features, text_features = get_clip_features(
                    images, texts, clip_model, preprocess, device)

            total_iter += 1
            optimizer.zero_grad()

            # Cast to float32 and L2-normalise both modalities so the two VAEs see
            # inputs on the same scale.
            x_in_image = F.normalize(image_features.float(), dim=1)
            x_in_text = F.normalize(text_features.float(), dim=1)

            # forward pass
            x_hat_image, loss_image, z_image = vae_image(x_in_image)
            x_hat_text, loss_text, z_text = vae_text(x_in_text)
            loss = loss_image + loss_text
            running_loss += loss.item()

            loss.backward()
            optimizer.step()

            if total_iter % 1000 == 0:
                mean_loss = running_loss / 1000
                print('iter %s:' % total_iter,
                      'loss = %.3f' % mean_loss,
                      )
                running_loss = 0.0
                with torch.no_grad():
                    # Cosine similarity between image and text latent vectors
                    similarity = compute_similarity(z_image, z_text)
                    avg_sim = similarity.mean().item()

                    # Reconstruction quality: element-wise MSE
                    mse_image = F.mse_loss(x_hat_image, x_in_image).item()
                    mse_text = F.mse_loss(x_hat_text, x_in_text).item()

                print(f"[Iter {total_iter}] Loss: {loss.item():.4f} | Cosine Sim: {avg_sim:.4f} | MSE (img): {mse_image:.4f} | MSE (txt): {mse_text:.4f}")

                if total_iter % 20000 == 0:
                    # torch.save requires a destination; write one file per checkpoint.
                    checkpoint_path = os.path.join(
                        checkpoint_dir, f'checkpoint_iter_{total_iter}.pt')
                    torch.save({
                        'total_iter': total_iter,
                        'loss': mean_loss,
                        'vae_image_state_dict': vae_image.state_dict(),
                        'vae_text_state_dict': vae_text.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'batch_size': batch_size,
                        'layer': layer,
                        'hidden_dim': hidden_dim,
                        'latent_dim': latent_dim,
                        'gate': gate,
                        'flow': flow,
                        'length': length,
                        'similarity': similarity.cpu(),
                    }, checkpoint_path)
                    print('Checkpoint saved.')

    print('Training finished.')

if __name__ == '__main__':
    args = {
        'dataset': 'flickr8k',
        'batch_size': 32,
        'layer': 3,
        'hidden_dim': 512,
        'latent_dim': 256,
        'gate': 1,
        'flow': None,
        'length': 2,
        'lr': 1e-4,
        'momentum': 0.9,
        'decay': 0.990,
        'max_iter': 100000,
        'sample_size': 16
    }

    # Kept inside the guard: `args` only exists when the file runs as a script
    # (including as a notebook cell), so importing the module must not call main().
    args = SimpleNamespace(**args)
    main(args)