In [None]:
! pip install transformers
! pip3 install datasets

Collecting transformers
  Downloading transformers-4.35.2-py3-none-any.whl (7.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m29.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.19.4-py3-none-any.whl (311 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.7/311.7 kB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.19,>=0.14 (from transformers)
  Downloading tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m75.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m68.6 MB/s[0m eta [36m0:00:00[0m
Ins

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Use the (current) CUDA GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
import os

from datasets import load_dataset

# Winoground is gated. Never hardcode the Hugging Face access token in the notebook.
# Export HF_TOKEN, or leave it unset and run `huggingface-cli login` so the cached token is used.
# SECURITY: the token previously pasted here was committed in plain text and should be revoked.
token = os.environ.get("HF_TOKEN")
winoground = load_dataset("facebook/winoground", token=token)["test"]



In [None]:
from PIL import Image, ImageDraw
from tqdm import tqdm

# utility function to plot examples
def plot_example(i, winoground):
    """Show both images of Winoground example ``i`` side by side and print its two captions.

    Args:
        i: index of the example in ``winoground``.
        winoground: indexable collection of examples with ``image_0``/``image_1`` (PIL images) and
            ``caption_0``/``caption_1`` fields.

    Returns:
        A white RGB canvas, twice as wide as the wider image and as tall as the taller one.
        ``image_0`` is pasted at the top-left and ``image_1`` at the start of the right half.
    """
    # Index the dataset once. Each `winoground[i]` lookup builds the row again (decoding both images),
    # and the previous version performed eight such lookups.
    example = winoground[i]
    img0, img1 = example["image_0"], example["image_1"]

    half_width = max(img0.width, img1.width)
    canvas_height = max(img0.height, img1.height)
    canvas = Image.new('RGB', (half_width * 2, canvas_height), (255, 255, 255))

    print(f"Left caption: {example['caption_0']}")
    print(f"Right caption: {example['caption_1']}")

    canvas.paste(img0, (0, 0))
    canvas.paste(img1, (half_width, 0))
    return canvas

In [None]:
from transformers import FlavaProcessor, FlavaForPreTraining, FlavaModel

# Single source of truth for the checkpoint, so the model and its processor can never drift apart.
FLAVA_CHECKPOINT = "facebook/flava-full"

model = FlavaForPreTraining.from_pretrained(FLAVA_CHECKPOINT).to(device)
processor = FlavaProcessor.from_pretrained(FLAVA_CHECKPOINT)

`text_config_dict` is provided which will be used to initialize `FlavaTextConfig`. The value `text_config["id2label"]` will be overriden.
`multimodal_config_dict` is provided which will be used to initialize `FlavaMultimodalConfig`. The value `multimodal_config["id2label"]` will be overriden.
`image_codebook_config_dict` is provided which will be used to initialize `FlavaImageCodebookConfig`. The value `image_codebook_config["id2label"]` will be overriden.


In [None]:
# Score every Winoground (caption, image) combination with FLAVA's contrastive head and ITM head.
# Each example has two captions and two images, giving four pairs keyed "c{caption}_i{image}".
N_EXAMPLES = 1  # debugging run on the first example only; set to None to score the whole test split


def flava_pair_outputs(caption, image):
    """Run FLAVA pretraining on one (caption, image) pair without gradients and return the model output.

    Masking is disabled: ``input_ids_masked`` is a copy of ``input_ids`` and ``bool_masked_pos`` is all
    zeros, so the masked branch of the model sees the same inputs as the unmasked one.
    """
    # Some Winoground images are RGBA and some are RGB; convert all of them to RGB.
    inputs = processor(
        text=[caption],
        images=[image.convert("RGB")],
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        return_codebook_pixels=True,
        return_image_mask=True,
    ).to(device)  # follow the model's device instead of hard-coding "cuda"
    inputs["input_ids_masked"] = inputs["input_ids"].detach().clone()
    inputs["bool_masked_pos"] = torch.zeros_like(inputs["bool_masked_pos"])
    with torch.no_grad():
        return model(**inputs)


winoground_flava_contrastive_scores = []
winoground_flava_itm_scores = []
eval_examples = winoground if N_EXAMPLES is None else winoground.select(range(N_EXAMPLES))
for example in tqdm(eval_examples):
    # Key order c0_i0, c0_i1, c1_i0, c1_i1 matches the score dicts' column order.
    pair_outputs = {
        f"c{c}_i{i}": flava_pair_outputs(example[f"caption_{c}"], example[f"image_{i}"])
        for c in (0, 1)
        for i in (0, 1)
    }

    contrastive_scores = {"id": example["id"]}
    itm_scores = {"id": example["id"]}
    for key, output in pair_outputs.items():
        contrastive_scores[key] = output.contrastive_logits_per_image.item()
        # Probability of ITM class 1 (the positive / matching class) for this pair.
        itm_scores[key] = torch.nn.functional.softmax(output.itm_logits, dim=-1)[0][1].item()
    winoground_flava_contrastive_scores.append(contrastive_scores)
    winoground_flava_itm_scores.append(itm_scores)

    # Keep the latest raw outputs at module level; later cells read `outputs_c0_i0`.
    outputs_c0_i0, outputs_c0_i1 = pair_outputs["c0_i0"], pair_outputs["c0_i1"]
    outputs_c1_i0, outputs_c1_i1 = pair_outputs["c1_i0"], pair_outputs["c1_i1"]
  0%|          | 0/400 [00:06<?, ?it/s]


In [None]:
# Unimodal (unmasked) embeddings of the (caption_0, image_0) pair from the most recent evaluation.
image_embeddings = outputs_c0_i0.image_embeddings
text_embeddings = outputs_c0_i0.text_embeddings

In [None]:
from transformers import FlavaProcessor, FlavaForPreTraining, FlavaModel

import collections
import math
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers.models.flava.configuration_flava import (
    FlavaConfig,
    FlavaImageCodebookConfig,
    FlavaImageConfig,
    FlavaMultimodalConfig,
    FlavaTextConfig,
)
from transformers.models.flava.modeling_flava import (
    FlavaForPreTrainingOutput,
    FlavaLosses,
    FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
    FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST,
    LOGIT_SCALE_CLAMP_MIN,
    LOGIT_SCALE_CLAMP_MAX,
    FlavaPossibleConfigs,
    logger
)

@dataclass
class FlavaGOTLosses(FlavaLosses):
    """`FlavaLosses` with one extra optional entry, `got`, for the GOT regularization loss."""

    # GOT loss value; stays None when the loss was not computed.
    got: Optional[torch.FloatTensor] = None


class FlavaGOTConfig(FlavaConfig):
    """`FlavaConfig` plus the weight of the GOT regularization loss.

    Args:
        got_weight: scale factor for the GOT loss (default 1.0). It is the first positional parameter,
            so pass every `FlavaConfig` argument by keyword to avoid it being captured here.
        *args, **kwargs: forwarded unchanged to `FlavaConfig.__init__`.
    """

    def __init__(self, got_weight=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Extra hyper-parameter on top of the standard FLAVA configuration, set once the parent is initialised.
        self.got_weight = got_weight

@dataclass
class FlavaGOTForPreTrainingOutput(FlavaForPreTrainingOutput):
    """`FlavaForPreTrainingOutput` whose `loss_info` is a `FlavaGOTLosses`, so it can carry the GOT loss."""

    # Re-declared only to narrow the annotated type; a dataclass override keeps the parent field's position.
    loss_info: Optional[FlavaGOTLosses] = None


class FlavaGOTForPreTraining(FlavaForPreTraining):
    """`FlavaForPreTraining` with an additional GOT regularization loss.

    `forward` computes the usual FLAVA pretraining losses (MIM, MLM, ITM, MMM image/text, global
    contrastive). When losses are requested, it also computes `self.got_loss` on the unmasked image/text
    embeddings and reports it as `loss_info.got`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # NOTE(review): `GOTLoss` is not defined in this cell and is stored as-is (not instantiated).
        # It must exist before this class is instantiated.
        # It is called below as `self.got_loss(image_embeddings, text_embeddings)` and must return a loss tensor.
        self.got_loss = GOTLoss

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        input_ids_masked: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        codebook_pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        skip_unmasked_multimodal_encoder: bool = None,
        mlm_labels: Optional[torch.Tensor] = None,
        mim_labels: Optional[torch.Tensor] = None,
        itm_labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: bool = True,
        return_dict: Optional[bool] = None,
        return_loss: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], FlavaGOTForPreTrainingOutput]:
        """
        Run FLAVA pretraining plus the GOT regularization loss.

        The GOT loss is `self.got_loss(image_embeddings, text_embeddings)` on the *unmasked* unimodal
        embeddings. It is computed only when `return_loss` is true and both embeddings are present.
        It is scaled by `config.got_weight` (1.0 when the config has no such attribute) and is included
        in the total loss.

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> model = FlavaGOTForPreTraining.from_pretrained("facebook/flava-full")
        >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")

        >>> inputs = processor(
        ...     images=[image],
        ...     text=["a photo of a cat"],
        ...     return_image_mask=True,
        ...     return_codebook_pixels=True,
        ...     padding=True,
        ...     max_length=77,
        ...     return_tensors="pt",
        ... )
        >>> output = model(**inputs, return_loss=True)
        ```

        Return:
            A `FlavaGOTForPreTrainingOutput` when `return_dict` is true (defaults to `config.use_return_dict`).
            Otherwise, a tuple of the non-`None` outputs, prefixed with `(total_loss, loss_info)` when any
            loss was computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        return_loss = return_loss if return_loss is not None else self.config.return_loss

        skip_unmasked_multimodal_encoder = (
            skip_unmasked_multimodal_encoder
            if skip_unmasked_multimodal_encoder is not None
            else self.skip_unmasked_multimodal_encoder
        )

        if input_ids_masked is None and input_ids is not None:
            logger.warning(
                "`input_ids_masked` isn't passed which means MLM loss won't be calculated correctly. Setting it to"
                " `input_ids` so that model can work. Please pass it if this is unintentional. This is usually OKAY if"
                " you are doing inference on unmasked text..."
            )
            input_ids_masked = input_ids

        flava_output = self.flava(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            image_attention_mask=image_attention_mask,
            # Don't need unmasked multimodal embedding for anything so skip it
            # NOTE: ITM uses masked version
            skip_multimodal_encoder=skip_unmasked_multimodal_encoder,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Pass true to have deterministic outputs
            return_dict=True,
        )

        flava_masked_output = self.flava(
            input_ids=input_ids_masked,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            image_attention_mask=image_attention_mask,
            bool_masked_pos=bool_masked_pos,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        pos_mask = None

        image_embeddings = flava_output.image_embeddings
        text_embeddings = flava_output.text_embeddings
        image_masked_embeddings = flava_masked_output.image_embeddings
        text_masked_embeddings = flava_masked_output.text_embeddings
        multimodal_masked_embeddings = flava_masked_output.multimodal_embeddings

        total_loss = mim_loss = mlm_loss = mmm_text_loss = mmm_image_loss = gc_loss = itm_loss = got_loss = None
        mim_logits = mlm_logits = mmm_text_logits = mmm_image_logits = None
        itm_logits = logits_per_image = logits_per_text = None

        # Calculate mim_labels if necessary from the image_codebook
        if image_masked_embeddings is not None or multimodal_masked_embeddings is not None:
            if mim_labels is None and return_loss:
                if self.image_codebook is None:
                    raise RuntimeError(
                        "`return_loss` is set to True but the image codebook is not initialized and no `mim_labels` "
                        " have been passed. Reinstantiate the model with `init_codebook` set to True or "
                        "pass in your custom `mim_labels`"
                    )
                if codebook_pixel_values is None:
                    raise ValueError(
                        "`codebook_pixel_value` are required to generate `mim_labels` if loss is expected. "
                        "Call `AutoProcessor` with `return_codebook_pixels` set to True"
                    )
                mim_labels = self.image_codebook.get_codebook_indices(codebook_pixel_values)
        # Unimodal MIM Loss
        # If multimodal embeddings are present, we will calculate MMM loss
        if self.mim_weight > 0 and image_masked_embeddings is not None and multimodal_masked_embeddings is None:
            sequence_for_image = image_masked_embeddings

            if mim_labels is not None:
                mim_labels = self._resize_to_2d(mim_labels)
                bool_masked_pos = self._resize_to_2d(bool_masked_pos)
                mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index

                sequence_for_image = sequence_for_image[:, -mim_labels.size(1) :, :]
                masked_tokens = mim_labels.ne(self.ce_ignore_index)
                mim_labels_filtered = mim_labels[masked_tokens]
                sequence_for_image = sequence_for_image[masked_tokens, :]
                mim_logits = self.mim_head(sequence_for_image)
                if return_loss:
                    mim_loss = nn.functional.cross_entropy(
                        mim_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
                    )
                    mim_loss *= self.mim_weight
            else:
                mim_logits = self.mim_head(sequence_for_image)

        # Unimodal MLM Loss
        if self.mlm_weight > 0 and text_masked_embeddings is not None and multimodal_masked_embeddings is None:
            sequence_for_text = text_masked_embeddings
            if mlm_labels is not None:
                mlm_labels = self._resize_to_2d(mlm_labels)
                sequence_for_text = sequence_for_text[:, -mlm_labels.size(1) :, :]
                masked_tokens = mlm_labels.ne(self.ce_ignore_index)
                mlm_labels_filtered = mlm_labels[masked_tokens]
                sequence_for_text = sequence_for_text[masked_tokens, :]
                mlm_logits = self.mlm_head(sequence_for_text)
                if return_loss:
                    mlm_loss = nn.functional.cross_entropy(
                        mlm_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
                    )
                    mlm_loss *= self.mlm_weight
            else:
                mlm_logits = self.mlm_head(sequence_for_text)

        # ITM Loss
        if self.itm_weight > 0 and multimodal_masked_embeddings is not None:
            itm_logits = self.itm_head(multimodal_masked_embeddings)

            if itm_labels is not None:
                # Non-zero ITM labels mark positive (matching) pairs; MMM and contrastive losses use only those.
                pos_pairs = itm_labels.ne(0)
                pos_mask = torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True]))
                if return_loss:
                    itm_loss = nn.functional.cross_entropy(itm_logits, itm_labels)
                    itm_loss *= self.itm_weight

                if multimodal_masked_embeddings is not None:
                    multimodal_masked_embeddings = multimodal_masked_embeddings[pos_mask]

                if mlm_labels is not None:
                    mlm_labels = mlm_labels[pos_mask]

                if mim_labels is not None:
                    mim_labels = mim_labels[pos_mask]
                    # Keep the patch mask aligned with the filtered labels; the MMM image loss indexes both together.
                    if bool_masked_pos is not None:
                        bool_masked_pos = bool_masked_pos[pos_mask]

        # GOT Regularization Loss (TODO: Integrate GOT code here)
        # Like every other term it is only computed when losses are requested, and it is weighted by the config.
        got_weight = getattr(self.config, "got_weight", 1.0)
        if return_loss and got_weight > 0 and image_embeddings is not None and text_embeddings is not None:
            got_loss = self.got_loss(image_embeddings, text_embeddings) * got_weight

        # MMM Image Loss
        if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
            sequence_for_image = multimodal_masked_embeddings
            end_index = image_masked_embeddings.size(1) - 1
            sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]
            # No pos_mask here: multimodal_masked_embeddings is already restricted to positive pairs (ITM section).
            # Applying the full-batch mask a second time fails whenever the batch contains negatives.

            if mim_labels is not None:
                mim_labels = self._resize_to_2d(mim_labels)
                bool_masked_pos = self._resize_to_2d(bool_masked_pos)
                mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index

                masked_tokens = mim_labels.ne(self.ce_ignore_index)
                mim_labels_filtered = mim_labels[masked_tokens]
                sequence_for_image = sequence_for_image[masked_tokens, :]
                mmm_image_logits = self.mmm_image_head(sequence_for_image)
                if return_loss:
                    mmm_image_loss = nn.functional.cross_entropy(
                        mmm_image_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
                    )
                    mmm_image_loss *= self.mmm_image_weight
            else:
                mmm_image_logits = self.mmm_image_head(sequence_for_image)

        # MMM Text Loss
        if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
            sequence_for_text = multimodal_masked_embeddings
            sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
            # Already filtered to positive pairs (see the MMM image section); don't re-apply pos_mask.

            if mlm_labels is not None:
                mlm_labels = self._resize_to_2d(mlm_labels)
                masked_tokens = mlm_labels.ne(self.ce_ignore_index)
                mlm_labels_filtered = mlm_labels[masked_tokens]
                sequence_for_text = sequence_for_text[masked_tokens, :]
                mmm_text_logits = self.mmm_text_head(sequence_for_text)
                if return_loss:
                    mmm_text_loss = nn.functional.cross_entropy(
                        mmm_text_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
                    )
                    mmm_text_loss *= self.mmm_text_weight
            else:
                mmm_text_logits = self.mmm_text_head(sequence_for_text)

        # Global Contrastive Loss
        if image_embeddings is not None and text_embeddings is not None and self.global_contrastive_weight > 0:
            text_embedding = self.flava.text_projection(text_embeddings[:, 0, :])
            text_embedding = nn.functional.normalize(text_embedding, dim=-1)

            image_embedding = self.flava.image_projection(image_embeddings[:, 0, :])
            image_embedding = nn.functional.normalize(image_embedding, dim=-1)

            self.flava.logit_scale.data.clamp_(LOGIT_SCALE_CLAMP_MIN, LOGIT_SCALE_CLAMP_MAX)

            logits_per_image, logits_per_text, gc_labels = self.global_contrastive_head(
                image_embedding, text_embedding, self.flava.logit_scale
            )

            # Apply ITM negative mask if any (these logits are full-batch, so the mask applies exactly once)
            if pos_mask is not None:
                logits_per_image = logits_per_image[pos_mask]
                logits_per_text = logits_per_text[pos_mask]
                gc_labels = gc_labels[pos_mask]

            if return_loss:
                gc_loss_image = nn.functional.cross_entropy(logits_per_image, gc_labels)
                gc_loss_text = nn.functional.cross_entropy(logits_per_text, gc_labels)
                gc_loss = (gc_loss_image + gc_loss_text) / 2
                gc_loss *= self.global_contrastive_weight

        flava_losses = FlavaGOTLosses(
            got=got_loss,
            mim=mim_loss,
            mlm=mlm_loss,
            itm=itm_loss,
            global_contrastive=gc_loss,
            mmm_image=mmm_image_loss,
            mmm_text=mmm_text_loss,
        )

        if return_loss and not flava_losses.all_none():
            total_loss = sum(loss if loss is not None else 0 for loss in flava_losses.values())

        if not return_dict:
            output = (
                image_embeddings,
                flava_output.image_output.to_tuple() if flava_output.image_output is not None else None,
                text_embeddings,
                flava_output.text_output.to_tuple() if flava_output.text_output is not None else None,
                flava_output.multimodal_embeddings,
                flava_output.multimodal_output.to_tuple() if flava_output.multimodal_output is not None else None,
                image_masked_embeddings,
                flava_masked_output.image_output.to_tuple() if flava_masked_output.image_output is not None else None,
                text_masked_embeddings,
                flava_masked_output.text_output.to_tuple() if flava_masked_output.text_output is not None else None,
                multimodal_masked_embeddings,
                flava_masked_output.multimodal_output.to_tuple()
                if flava_masked_output.multimodal_output is not None
                else None,
                mim_logits,
                mlm_logits,
                itm_logits,
                logits_per_image,
                logits_per_text,
                mmm_image_logits,
                mmm_text_logits,
            )
            if return_loss and not flava_losses.all_none():
                output = (
                    total_loss,
                    flava_losses,
                ) + output

            # Filter None as transformer by default won't handle it (keep only the non-None entries)
            return tuple(x for x in output if x is not None)

        return FlavaGOTForPreTrainingOutput(
            loss=total_loss,
            loss_info=flava_losses,
            image_embeddings=image_embeddings,
            image_output=flava_output.image_output,
            text_embeddings=text_embeddings,
            text_output=flava_output.text_output,
            multimodal_embeddings=flava_output.multimodal_embeddings,
            multimodal_output=flava_output.multimodal_output,
            image_masked_embeddings=image_masked_embeddings,
            image_masked_output=flava_masked_output.image_output,
            text_masked_embeddings=text_masked_embeddings,
            text_masked_output=flava_masked_output.text_output,
            multimodal_masked_embeddings=multimodal_masked_embeddings,
            multimodal_masked_output=flava_masked_output.multimodal_output,
            mim_logits=mim_logits,
            mlm_logits=mlm_logits,
            itm_logits=itm_logits,
            contrastive_logits_per_image=logits_per_image,
            contrastive_logits_per_text=logits_per_text,
            mmm_image_logits=mmm_image_logits,
            mmm_text_logits=mmm_text_logits,
        )

## Optimal Transport

In [None]:
import numpy as np
import torch
from functools import partial
from sklearn.metrics.pairwise import euclidean_distances
from torch.autograd import Variable
import pdb

def cost_matrix_torch(x, y):
	"""Cosine-distance cost matrix between the columns of ``x`` and the columns of ``y``.

	``x`` is flattened to ``(D, -1)`` and ``y`` must share the leading dimension ``D``.
	Every column is L2-normalised, with 1e-12 added to the norm to avoid division by zero.
	The result is ``1 - y_hat.T @ x_hat``: one row per column of ``y``, one column per column of ``x``.
	(In this notebook x is the image embedding and y the text embedding.)
	"""
	feat_dim = x.size(0)
	x_cols = x.view(feat_dim, -1)
	assert x_cols.size(0) == y.size(0)
	x_unit = x_cols / (x_cols.norm(p=2, dim=0, keepdim=True) + 1e-12)
	y_unit = y / (y.norm(p=2, dim=0, keepdim=True) + 1e-12)
	similarity = y_unit.t().mm(x_unit)
	# Distance is minimised when the vectors point the same way.
	return 1 - similarity

def IPOT_torch(C, n, m, miu, nu, beta=0.5, iteration=20):
	"""Approximate optimal-transport plan using the IPOT (inexact proximal point) iteration.

	Args:
		C: ``(n, m)`` cost (distance) matrix.
		n, m: number of source / target points.
		miu: 1-D source marginal of length ``n``.
		nu: 1-D target marginal of length ``m``.
		beta: proximal step size; the kernel is ``exp(-C / beta)``.
		iteration: number of proximal iterations (20 was previously hard-coded).

	Returns:
		The ``(n, m)`` transport plan, detached from the autograd graph.
	"""
	# Allocate the work tensors on C's device instead of hard-coding .cuda(), so this also runs on CPU.
	sigma = torch.ones(int(m), 1, dtype=torch.float32, device=C.device) / m
	T = torch.ones(n, m, device=C.device)
	kernel = torch.exp(-C / beta).float()
	for _ in range(iteration):
		T = kernel * T
		# One Sinkhorn-style scaling step per proximal iteration.
		delta = miu / torch.squeeze(torch.matmul(T, sigma))
		sigma = torch.unsqueeze(nu, 1) / torch.matmul(torch.transpose(T, 0, 1), torch.unsqueeze(delta, 1))
		# Equivalent to diag(delta) @ T @ diag(sigma).
		T = torch.unsqueeze(delta, 1) * T * sigma.transpose(1, 0)
	return T.detach()

def IPOT_distance_torch(C, n, m, miu, nu):
	"""Negative IPOT transport cost, ``-trace(C^T T)``, for an ``(n, m)`` cost matrix ``C``.

	``T`` is the plan returned by ``IPOT_torch(C, n, m, miu, nu)``.
	"""
	C = C.float()
	# Previously always .cuda(): keep moving to the GPU when one exists, otherwise stay on C's device.
	if torch.cuda.is_available():
		C = C.cuda()
	T = IPOT_torch(C, n, m, miu, nu)
	# trace(C^T T) == sum_ij C_ij * T_ij, computed without forming the (m, m) product.
	distance = torch.sum(C * T)
	return -distance


def IPOT_distance_torch_batch(C, n, m, miu, nu, iteration):
	"""Negative batched IPOT transport cost.

	Args:
		C: ``(bs, n, m)`` cost matrices, or a single ``(n, m)`` matrix shared across the batch.
		n, m: number of source / target points.
		miu, nu: source / target marginals; the batch size is taken from ``miu.size(0)``.
		iteration: number of IPOT iterations.

	Returns:
		``-batch_trace(C^T T, m, bs)``, where ``T`` are the plans from ``IPOT_torch_batch``.
	"""
	C = C.float()
	# Previously always .cuda(): keep moving to the GPU when one exists, otherwise stay on C's device.
	if torch.cuda.is_available():
		C = C.cuda()
	bs = miu.size(0)
	if C.dim() == 2:
		# A single (1, n, m) cost matrix broadcasts against the (bs, n, m) plans.
		C = torch.unsqueeze(C, 0)
	T = IPOT_torch_batch(C, bs, n, m, miu, nu, iteration)
	temp = torch.matmul(torch.transpose(C, 1, 2), T)
	distance = batch_trace(temp, m, bs)
	return -distance


def IPOT_torch_batch(C, bs, n, m, miu, nu, iteration=20, beta=0.5):
	"""Batched IPOT transport plans.

	Args:
		C: ``(bs, n, m)`` cost matrices (a ``(1, n, m)`` matrix broadcasts over the batch).
		bs, n, m: batch size and number of source / target points.
		miu: source marginals with ``bs * n`` elements, e.g. ``(bs, n)`` or ``(bs, n, 1)``.
		nu: target marginals with ``bs * m`` elements, e.g. ``(bs, m)`` or ``(bs, m, 1)``.
		iteration: number of proximal iterations.
		beta: proximal step size; the kernel is ``exp(-C / beta)``.

	Returns:
		``(bs, n, m)`` transport plans, detached from the autograd graph.
	"""
	# Allocate on C's device instead of hard-coding .cuda(), so this also runs on CPU.
	sigma = torch.ones(bs, int(m), 1, device=C.device) / float(m)
	Q = torch.ones(bs, n, m, device=C.device).float()
	kernel = torch.exp(-C / beta)
	# Reshape the marginals explicitly. The old squeeze/unsqueeze sequence dropped the batch
	# dimension when bs == 1 and then failed in unsqueeze(..., 2).
	miu = miu.reshape(bs, n, 1)
	nu = nu.reshape(bs, m, 1)
	for _ in range(iteration):
		Q = kernel * Q
		# 1e-6 guards both scaling divisions against zeros.
		delta = miu / (torch.bmm(Q, sigma) + 1e-6)
		a = torch.bmm(torch.transpose(Q, 1, 2), delta) + 1e-6
		sigma = nu / a
		Q = delta * Q * sigma.transpose(2, 1)
	return Q.detach()

def IPOT_torch_uniform(C, n, m, beta=0.5, iteration=50):
	"""IPOT transport plan between uniform marginals (1/n per source point, 1/m per target point).

	Args:
		C: ``(n, m)`` cost (distance) matrix.
		n, m: number of source / target points.
		beta: proximal step size; the kernel is ``exp(-C / beta)``.
		iteration: number of proximal iterations (50 was previously hard-coded).

	Returns:
		The ``(n, m)`` transport plan, detached from the autograd graph.
	"""
	# Allocate on C's device instead of hard-coding .cuda(), so this also runs on CPU.
	sigma = torch.ones(int(m), 1, device=C.device) / m
	T = torch.ones(n, m, device=C.device)
	kernel = torch.exp(-C / beta)
	for _ in range(iteration):
		Q = kernel * T
		delta = 1 / (n * torch.mm(Q, sigma))
		a = torch.mm(torch.transpose(Q, 0, 1), delta)
		sigma = 1 / (float(m) * a)
		# Broadcasting replaces diag(delta) @ Q @ diag(sigma). It gives the same result without building
		# (n, n) / (m, m) matrices, and it no longer fails when n or m is 1 (squeeze gave a 0-d tensor for torch.diag).
		T = delta * Q * sigma.t()
	return T.detach()

def IPOT_distance_torch_uniform(C, n, m):
	"""IPOT transport cost between uniform marginals, ``trace(C^T T)`` with ``T = IPOT_torch_uniform(C, n, m)``.

	NOTE: unlike the other IPOT_distance_* helpers, this one returns the cost with a positive sign.
	"""
	C = C.float()
	# Previously always .cuda(): keep moving to the GPU when one exists, otherwise stay on C's device.
	if torch.cuda.is_available():
		C = C.cuda()
	T = IPOT_torch_uniform(C, n, m)
	# trace(C^T T) == sum_ij C_ij * T_ij, computed without forming the (m, m) product.
	distance = torch.sum(C * T)
	return distance


def cost_matrix_batch_torch(x, y):
	"""Batched cosine-distance cost matrices.

	Args:
		x: ``(bs, D, ...)`` features; trailing dimensions are flattened to ``(bs, D, N_x)``.
		y: ``(bs, D, N_y)`` features with the same ``D``.

	Returns:
		``(bs, N_y, N_x)`` tensor whose ``[b, j, i]`` entry is ``1 - cos(x[b, :, i], y[b, :, j])``.
	"""
	batch, feat_dim = x.size(0), x.size(1)
	assert feat_dim == y.size(1)
	x_flat = x.contiguous().view(batch, feat_dim, -1)
	# Normalise along the feature axis (1e-12 guards against zero vectors).
	x_unit = x_flat / (x_flat.norm(p=2, dim=1, keepdim=True) + 1e-12)
	y_unit = y / (y.norm(p=2, dim=1, keepdim=True) + 1e-12)
	distance = 1 - torch.bmm(x_unit.transpose(1, 2), y_unit)
	return distance.transpose(2, 1)


def cost_matrix_batch_torch_acos(x, y):
	"""Batched angular-distance cost matrices: ``acos`` of the cosine similarity.

	Args:
		x: ``(bs, D, ...)`` features; trailing dimensions are flattened to ``(bs, D, N_x)``.
		y: ``(bs, D, N_y)`` features with the same ``D``.

	Returns:
		``(bs, N_y, N_x)`` tensor of angles in ``[0, pi]``.
	"""
	bs = x.size(0)
	D = x.size(1)
	assert(x.size(1)==y.size(1))
	x = x.contiguous().view(bs, D, -1) # bs * d * m^2
	x = x.div(torch.norm(x, p=2, dim=1, keepdim=True) + 1e-12)
	y = y.div(torch.norm(y, p=2, dim=1, keepdim=True) + 1e-12)
	cos_sim = torch.bmm(torch.transpose(x, 1, 2), y)
	# Rounding can push normalised dot products slightly outside [-1, 1], where acos returns NaN.
	# NOTE: the gradient of acos is still infinite at exactly +/-1.
	cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
	return torch.acos(cos_sim).transpose(2, 1)

def cos_batch_torch(x, y, beta=0.1):
	"""Thresholded batched cosine-distance matrices.

	Computes ``1 - cos`` between the (flattened) columns of ``x`` and the columns of ``y``.
	It then subtracts a threshold placed a fraction ``beta`` of the way from the smallest to the
	largest distance (taken over the whole batch, not per sample), and clips negative values to zero.

	Args:
		x: ``(bs, D, ...)`` features; trailing dimensions are flattened to ``(bs, D, N_x)``.
		y: ``(bs, D, N_y)`` features with the same ``D``.
		beta: threshold position within ``[min, max]`` of the distances (0.1 was previously hard-coded).

	Returns:
		``(bs, N_y, N_x)`` non-negative tensor.
	"""
	bs = x.size(0)
	D = x.size(1)
	assert(x.size(1)==y.size(1))
	x = x.contiguous().view(bs, D, -1)
	x = x.div(torch.norm(x, p=2, dim=1, keepdim=True) + 1e-12)
	y = y.div(torch.norm(y, p=2, dim=1, keepdim=True) + 1e-12)
	cos_dis = 1 - torch.bmm(torch.transpose(x, 1, 2), y)
	min_score = cos_dis.min()
	max_score = cos_dis.max()
	threshold = min_score + beta * (max_score - min_score)
	return torch.nn.functional.relu((cos_dis - threshold).transpose(2, 1))


def pairwise_distances(x, y=None):
	'''
	Squared Euclidean distance between every row of ``x`` and every row of ``y``.

	Input: x is an N x d matrix
	       y is an optional M x d matrix; when omitted, y = x
	Output: N x M matrix with dist[i, j] = ||x[i, :] - y[j, :]||^2.
	        The expansion ||a||^2 + ||b||^2 - 2 a.b can produce small negative values, so the result is
	        clamped at 0. The diagonal is not forced to exactly 0 when y is omitted.
	'''
	x_sq = (x ** 2).sum(1).view(-1, 1)
	if y is None:
		other_t = torch.transpose(x, 0, 1)
		other_sq = x_sq.view(1, -1)
	else:
		other_t = torch.transpose(y, 0, 1)
		other_sq = (y ** 2).sum(1).view(1, -1)
	dist = x_sq + other_sq - 2.0 * torch.mm(x, other_t)
	return torch.clamp(dist, 0.0, np.inf)

def row_pairwise_distances(x, y=None, dist_mat=None):
    """Squared Euclidean distances between every row of ``x`` and every row of ``y``, one row at a time.

    Uses O(M * D) temporary memory per row instead of materialising an (N, M, D) difference tensor.

    Args:
        x: ``(N, D)`` tensor.
        y: optional ``(M, D)`` tensor; defaults to ``x``.
        dist_mat: optional ``(N, M)`` output buffer, filled in place and returned.

    Returns:
        ``(N, M)`` tensor with ``[i, j] = ||x[i] - y[j]||^2``.
    """
    if y is None:
        y = x
    if dist_mat is None:
        # Same dtype and device as x; replaces the deprecated Variable(torch.Tensor(...).type(...)) construction.
        dist_mat = x.new_empty((x.size(0), y.size(0)))

    for i, row in enumerate(x.split(1)):
        # row is (1, D); broadcasting against y (M, D) replaces the explicit expand_as.
        dist_mat[i] = torch.sum((row - y) ** 2, 1)
    return dist_mat

def IPOT_barycenter(p, C, q, iteration=20, beta=0.5, iteration_inner = 1):
	'''
	IPOT-style iterative barycenter update.

	:param p: probability vector set, K x n. NOTE(review): it is divided by (K, n, 1) tensors below,
	          so a (K, n, 1) tensor broadcasts cleanly -- confirm the intended shape.
	:param C: cost matrices, K x n x n
	:param q: initial q, mean of all support, n x d. It is unsqueezed to (1, n, d);
	          NOTE(review): the divisions below broadcast consistently for d == 1 -- confirm.
	:param iteration: number of outer iterations
	:param beta: proximal step size; the kernel is exp(-C / beta)
	:param iteration_inner: scaling steps per outer iteration
	:return: the updated q after the last inner step
	'''
	K = p.size(0)
	n = p.size(1)
	assert(C.size(1)==C.size(2))
	assert(C.size(1)==p.size(1))
	# Allocate on C's device instead of hard-coding .cuda(), so this also runs on CPU.
	b = torch.ones(K, int(n), 1, device=C.device) / float(n) # K * n * 1
	C = torch.exp(-C/beta)
	T = torch.ones(K, n, n, device=C.device).float()
	q = torch.unsqueeze(q, 0)
	for _ in range(iteration):
		H = T * C
		for _ in range(iteration_inner):
			a = q/torch.bmm(H, b)
			b = p/torch.bmm(torch.transpose(H, 2, 1), a)
			q = a * (torch.bmm(H, b))
		T = a * H * b.transpose(2,1)
	return q


def IPOT_distance_torch_batch_uniform(C, bs, n, m, iteration=50):
	"""Negative batched IPOT cost between uniform marginals.

	Returns ``-batch_trace(C^T T, m, bs)``, where ``T = IPOT_torch_batch_uniform(C, bs, n, m)``
	and ``C`` is ``(bs, n, m)``.
	"""
	C = C.float()
	# Previously always .cuda(): keep moving to the GPU when one exists, otherwise stay on C's device.
	if torch.cuda.is_available():
		C = C.cuda()
	T = IPOT_torch_batch_uniform(C, bs, n, m, iteration=iteration)
	temp = torch.bmm(torch.transpose(C, 1, 2), T)
	distance = batch_trace(temp, m, bs)
	return -distance

def IPOT_distance_torch_batch_uniform_T(C, bs, n, m, iteration=50):
	"""Return the batched uniform-marginal IPOT plans ``T`` (not the distance) for ``(bs, n, m)`` costs ``C``."""
	C = C.float()
	# Previously always .cuda(): keep moving to the GPU when one exists, otherwise stay on C's device.
	if torch.cuda.is_available():
		C = C.cuda()
	return IPOT_torch_batch_uniform(C, bs, n, m, iteration=iteration)


def IPOT_torch_batch_uniform(C, bs, n, m, beta=0.5, iteration=50, iteration_inner=1):
	'''
	Batched IPOT (inexact proximal point OT) with uniform marginals 1/n and 1/m.

	:param C: cost matrices, bs x n x m
	:param bs, n, m: batch size and the two support sizes (must match C's shape)
	:param beta: proximal step size; the kernel is exp(-C / beta)
	:param iteration: number of outer proximal-point iterations
	:param iteration_inner: scaling updates per outer iteration (default 1, the
	                        previously hard-coded value; must be >= 1)
	:return: transport plans T, bs x n x m, on C's device. After the final update each
	         column of T sums to 1/m. With iteration=0 this is all ones.
	:raises ValueError: if iteration_inner < 1
	'''
	if iteration_inner < 1:
		raise ValueError("iteration_inner must be >= 1")
	A = torch.exp(-C / beta).float()  # Gibbs kernel, bs x n x m
	# Allocate the state on the cost's device instead of hard-coding CUDA.
	sigma = torch.ones(bs, int(m), 1, device=A.device) / float(m)  # bs x m x 1
	T = torch.ones(bs, n, m, device=A.device)
	for _ in range(iteration):
		Q = A * T  # proximal kernel, bs x n x m
		for _ in range(iteration_inner):
			delta = 1 / (n * torch.bmm(Q, sigma))  # bs x n x 1
			a = torch.bmm(torch.transpose(Q, 1, 2), delta)  # bs x m x 1
			sigma = 1 / (float(m) * a)
		T = delta * Q * sigma.transpose(2, 1)
	return T


def GW_distance(X, Y, p, q, lamda=0.5, iteration=5, OT_iteration=20):
	'''
	Gromov-Wasserstein distance between two batches of embedding sets.

	:param X: source embeddings, batchsize x embed_dim x n
	:param Y: target embeddings, batchsize x embed_dim x m
	:param p: source weights, batchsize x n x 1
	:param q: target weights, batchsize x m x 1
	:param lamda: proximal step size, forwarded to GW_torch_batch as ``beta``
	:param iteration: number of GW linearization steps
	:param OT_iteration: IPOT iterations per GW step
	:return: batchsize x 1 tensor, trace(Cgamma^T gamma) for the final coupling gamma
	'''
	# Intra-set structure matrices of each side; cos_batch_torch is defined elsewhere
	# (presumably pairwise cosine scores - confirm).
	sim_src = cos_batch_torch(X, X).float().cuda()
	sim_tgt = cos_batch_torch(Y, Y).float().cuda()
	batch = sim_src.size(0)
	n_src, n_tgt = sim_src.size(2), sim_tgt.size(2)
	plan, cost = GW_torch_batch(sim_src, sim_tgt, batch, n_src, n_tgt, p, q,
	                            beta=lamda, iteration=iteration, OT_iteration=OT_iteration)
	return batch_trace(torch.bmm(cost.transpose(1, 2), plan), n_tgt, batch)

def GW_torch_batch(Cs, Ct, bs, n, m, p, q, beta=0.5, iteration=5, OT_iteration=20):
	'''
	Batched proximal Gromov-Wasserstein coupling for the square loss (a - b)^2.

	:param Cs: intra-source structure matrices, bs x n x n
	:param Ct: intra-target structure matrices, bs x m x m
	:param p: source weights, bs x n x 1
	:param q: target weights, bs x m x 1
	:param beta: proximal step size forwarded to IPOT_torch_batch_uniform
	:param iteration: number of GW linearization steps
	:param OT_iteration: IPOT iterations per GW step
	:return: (gamma, Cgamma): the detached coupling, bs x n x m, and the linearized
	         GW cost evaluated at it, bs x n x m

	NOTE(review): IPOT_torch_batch_uniform assumes uniform marginals, so non-uniform p, q
	only affect Cst and the initial coupling - confirm this is intended.
	'''
	# Follow Cs's device instead of hard-coding CUDA.
	one_m = torch.ones(bs, m, 1, dtype=torch.float32, device=Cs.device)
	one_n = torch.ones(bs, n, 1, dtype=torch.float32, device=Cs.device)

	# Coupling-independent part of the square-loss tensor product, bs x n x m.
	Cst = (torch.bmm(torch.bmm(Cs**2, p), torch.transpose(one_m, 1, 2))
	       + torch.bmm(one_n, torch.bmm(torch.transpose(q, 1, 2), torch.transpose(Ct**2, 1, 2))))
	gamma = torch.bmm(p, q.transpose(2, 1))  # outer product p q^T as the initial coupling
	for _ in range(iteration):
		# Linearize GW at the current coupling, then solve the resulting OT problem.
		C_gamma = Cst - 2 * torch.bmm(torch.bmm(Cs, gamma), torch.transpose(Ct, 1, 2))
		gamma = IPOT_torch_batch_uniform(C_gamma, bs, n, m, beta=beta, iteration=OT_iteration)
	Cgamma = Cst - 2 * torch.bmm(torch.bmm(Cs, gamma), torch.transpose(Ct, 1, 2))
	return gamma.detach(), Cgamma

# def GW_torch_batch(Cs, Ct, bs, n, m, beta=0.5, iteration=5, OT_iteration=20):
# 	one_m = torch.ones(bs, m, 1).float().cuda()
# 	one_n = torch.ones(bs, n, 1).float().cuda()
# 	p = (torch.ones(bs, m, 1)/m).cuda()
# 	q = (torch.ones(bs, n, 1)/n).cuda()

# 	Cst = torch.bmm(torch.bmm(Cs**2, p), torch.transpose(one_m, 1, 2)) + \
# 	      torch.bmm(one_n, torch.bmm(torch.transpose(q,1,2), torch.transpose(Ct**2, 1, 2))) # bs by n by m
# 	gamma = torch.bmm(p, q.transpose(2,1)) # outer product, init
# 	# gamma = torch.einsum('bi,bj->bij', (torch.squeeze(p), torch.squeeze(q))) # outer product, initialization
# 	for i in range(iteration):
# 		C_gamma = Cst - 2 * torch.bmm(torch.bmm(Cs, gamma), torch.transpose(Ct, 1, 2))
# 		gamma = IPOT_torch_batch_uniform(C_gamma, bs, n, m, beta=beta, iteration=OT_iteration)
# 	Cgamma = Cst - 2 * torch.bmm(torch.bmm(Cs, gamma), torch.transpose(Ct, 1, 2))
# 	return gamma.detach(), Cgamma

def GW_distance_uniform(X, Y, lamda=1e-1, iteration=5, OT_iteration=20):
	'''
	GW_distance with uniform weights over the columns of X and over the columns of Y.

	:param X: source embeddings, batchsize x embed_dim x n_src
	:param Y: target embeddings, batchsize x embed_dim x n_tgt
	:param lamda, iteration, OT_iteration: forwarded unchanged to GW_distance
	:return: result of GW_distance, batchsize x 1
	'''
	batch = X.size(0)
	n_src, n_tgt = X.size(2), Y.size(2)
	# Uniform probability vectors, placed on CUDA like the rest of this GW path.
	src_weights = (torch.ones(batch, n_src, 1) / n_src).cuda()
	tgt_weights = (torch.ones(batch, n_tgt, 1) / n_tgt).cuda()
	return GW_distance(X, Y, src_weights, tgt_weights,
	                   lamda=lamda, iteration=iteration, OT_iteration=OT_iteration)


def batch_diag(a_emb, n, bs):
	'''
	Build a batch of diagonal matrices.

	:param a_emb: diagonal entries, bs x n (a bs x 1 input is broadcast to n copies)
	:param n: matrix size
	:param bs: batch size (kept for call compatibility; the batch size is read from a_emb)
	:return: bs x n x n tensor with a_emb[b] on the diagonal of slice b and zeros elsewhere
	'''
	# diag_embed writes only the diagonal: it follows a_emb's device (no CUDA-only identity
	# mask), and inf/nan entries no longer turn off-diagonal zeros into nan via 0 * inf.
	return torch.diag_embed(a_emb.expand(-1, n))
	# diagonal bs by n by n

def batch_trace(input_matrix, n, bs):
	'''
	Trace of every matrix in a batch.

	:param input_matrix: batch of square matrices, bs x n x n
	:param n: matrix size; must match the last two dims of input_matrix
	:param bs: batch size (kept for call compatibility; read from input_matrix instead)
	:return: traces, shape bs x 1
	:raises ValueError: if the last two dims of input_matrix are not both n
	'''
	if input_matrix.size(-1) != n or input_matrix.size(-2) != n:
		raise ValueError("batch_trace expects ... x %d x %d matrices, got %s"
		                 % (n, n, tuple(input_matrix.shape)))
	# Read the diagonal directly instead of masking with a CUDA identity: works on any
	# device, avoids a bs x n x n temporary, and inf/nan off the diagonal can no longer
	# turn the result into nan via 0 * inf.
	return torch.diagonal(input_matrix, dim1=-2, dim2=-1).sum(dim=-1, keepdim=True)

In [None]:
class GOTLoss(nn.Module):
    '''Graph Optimal Transport (GOT) alignment loss between two batches of embedding sets.

    The loss is ``got_lambda * mean(GW) + got_lambda * mean(W)``. GW is a
    Gromov-Wasserstein term computed by GW_distance_uniform. W is an IPOT Wasserstein
    term on the thresholded cross-set cost matrix.

    Keyword args:
        got_lambda (float): required; the weight applied to both terms.
        got_threshold_beta (float, optional): the threshold is placed this fraction of
            the way from the global min to the global max cost. The cost is shifted down
            by it and clamped at zero before the W term. Defaults to 0.1, the previously
            hard-coded value.
        got_ot_iteration (int, optional): IPOT iterations for the W term. Defaults to 30,
            the previously hard-coded value.
    All keyword arguments are also kept verbatim in ``self.args``.
    '''

    def __init__(self, **kwargs):
        super().__init__()
        self.args = kwargs
        self.got_lambda = self.args["got_lambda"]
        self.got_threshold_beta = self.args.get("got_threshold_beta", 0.1)
        self.got_ot_iteration = self.args.get("got_ot_iteration", 30)

    def GOT(self, v_, q_):
        '''Compute the GOT loss.

        :param v_: first embedding set; assumed bs x n x d (dims 1 and 2 are swapped below)
        :param q_: second embedding set; assumed bs x m x d
        :return: 0-dim tensor, got_lambda * mean(GW) + got_lambda * mean(W)
        '''
        # cost_matrix_batch_torch is defined elsewhere in the notebook. NOTE(review):
        # confirm its output, after the transpose, is bs x n x m as the IPOT call expects.
        cos_distance = cost_matrix_batch_torch(v_.transpose(2, 1), q_.transpose(2, 1))
        cos_distance = cos_distance.transpose(1, 2)
        # Sparsify the cost: the min/max are taken over the whole batch, not per example.
        min_score = cos_distance.min()
        max_score = cos_distance.max()
        threshold = min_score + self.got_threshold_beta * (max_score - min_score)
        cos_dist = F.relu(cos_distance - threshold)

        # IPOT_distance_torch_batch_uniform returns the negated transport cost; negate it back.
        wd = -IPOT_distance_torch_batch_uniform(
            cos_dist, v_.size(0), v_.size(1), q_.size(1), iteration=self.got_ot_iteration)
        gwd = GW_distance_uniform(v_.transpose(2, 1), q_.transpose(2, 1))
        return self.got_lambda * torch.mean(gwd) + self.got_lambda * torch.mean(wd)

    def forward(self, v_, q_):
        '''nn.Module entry point; identical to ``GOT(v_, q_)``.'''
        return self.GOT(v_, q_)


In [None]:
# GOT alignment loss between image_embeddings and text_embeddings (both defined outside
# this chunk). got_lambda=0.1 scales both the GW and the W term; the result is a 0-dim tensor.
got = GOTLoss(got_lambda=0.1).GOT(image_embeddings, text_embeddings)

{'got_lambda': 0.1}


In [None]:
# Display the GOT loss computed in the previous cell.
got

tensor(0.1199, device='cuda:0')