In [193]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

In [194]:
text_embd = torch.from_numpy(np.random.random(1024))
img_embd = torch.from_numpy(np.random.random(768))
audio_embd = torch.from_numpy(np.random.random(128))
graph_embd = torch.from_numpy(np.random.random(32))

In [195]:
class Fairouz(nn.Module):
    def __init__(self, out_shape):
        super().__init__()
        self.linear_1 = nn.Linear(1952, 1952*3)
        self.linear_2 = nn.Linear(1952*3, 1952)
        self.linear_3 = nn.Linear(1952, out_shape)

        self.relu = nn.ReLU()

    def forward(self, all_embds):

        all_embds = [torch.concat(all_embd) for all_embd in all_embds]
        all_embds = torch.from_numpy(np.array(all_embds))
        all_embds_reshaped = all_embds.type(torch.FloatTensor)

        x = self.relu(self.linear_1(all_embds_reshaped))
        x = self.relu(self.linear_2(x))
        return self.linear_3(x)

In [198]:
model = Fairouz(256)
embds = model([[text_embd, img_embd, audio_embd, graph_embd],
               [text_embd, img_embd, audio_embd, graph_embd],
               [text_embd, img_embd, audio_embd, graph_embd],
               [text_embd, img_embd, audio_embd, graph_embd]])
embds.shape

torch.Size([4, 256])

#### Loss

In [189]:
class ContrastiveLoss(nn.Module):
    def __init__(self, model, margin =0.5, size_average: bool = True):

      super(ContrastiveLoss, self).__init__()
      self.distance_metric = lambda x, y: 1 - F.cosine_similarity(x, y)
      self.margin = margin
      self.model = model
      self.size_average = size_average

    def forward(self, embedding_anchor, embedding_other, label):
      reps = self.model([embedding_anchor, embedding_other])
      assert len(reps) == 2
      rep_anchor, rep_other = reps
      distances = self.distance_metric(rep_anchor.reshape(-1, 1), rep_other.reshape(-1, 1))
      losses = 0.5 * (
          label.float() * distances.pow(2) + (1 - label).float() * F.relu(self.margin - distances).pow(2)
      )
      return losses.mean() if self.size_average else losses.sum()

In [205]:
loss = ContrastiveLoss(model, margin = 0.5)

# 1 if positive, 0 if negative
labels = torch.tensor([0])

loss([text_embd*-332, img_embd*-335, audio_embd*-334, graph_embd*-332], \
     [text_embd, img_embd, audio_embd, graph_embd], \
     label = labels)

tensor(0.0737, grad_fn=<MeanBackward0>)