In [28]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import random

In [35]:
class FeatureDataset(torch.utils.data.Dataset):
    def __init__(self, image_features, text_features):
        self.image_features = image_features
        self.text_features = text_features

    def __len__(self):
        return len(self.image_features)

    def __getitem__(self, index):
        return self.image_features[index], self.text_features[index]

In [36]:
class BranchNetwork(nn.Module):
    def __init__(self, img_feature_size, txt_feature_size, fc_dim = 2048, embed_dim = 512):
        super(BranchNetwork, self).__init__()
        self.img_fc = self._branch_network(img_feature_size, fc_dim, embed_dim)
        self.txt_fc = self._branch_network(txt_feature_size, fc_dim, embed_dim)
    
    def _branch_network(self, input_dim, fc_dim, embed_dim):
        return nn.Sequential(
            nn.Linear(input_dim, fc_dim),
            nn.BatchNorm1d(fc_dim, momentum=0.1, eps=1e-5),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(fc_dim, embed_dim)
        )

    def forward(self, image_features, text_features):
        img_out = self.img_fc(image_features)
        img_embedding = F.normalize(img_out, p=2, dim=1, eps=1e-10)
        txt_out = self.txt_fc(text_features)
        txt_embedding = F.normalize(txt_out, p=2, dim=1, eps=1e-10)

        return img_embedding, txt_embedding

In [37]:
class EmbeddingLoss(nn.Module):
    def __init__(self, margin=0.2, num_neg_sample=5, im_loss_factor=1.5, txt_only_loss_factor=0.05):
        super(EmbeddingLoss, self).__init__()

        self.im_loss_factor = im_loss_factor
        self.txt_only_loss_factor = txt_only_loss_factor
        self.margin = margin
        self.num_neg_sample = num_neg_sample

    def forward(self, im_embeds, txt_embeds, im_labels, sample_size):
        txt_im_ratio = sample_size
        num_img = im_embeds.shape[0]
        num_txt = num_img * txt_im_ratio

        txt_im_dist = self._pdist(txt_embeds, im_embeds)

        # Image loss
        pos_pair_dist = txt_im_dist[im_labels].view(num_img, 1)
        neg_pair_dist = txt_im_dist[~im_labels].view(num_img, -1)
        im_loss = F.relu(self.margin + pos_pair_dist - neg_pair_dist).topk(self.num_neg_sample)[0].mean()

        # Sentence loss
        neg_pair_dist = txt_im_dist.t()[~im_labels.t()].view(num_img, -1).repeat(1, txt_im_ratio)
        txt_loss = F.relu(self.margin + pos_pair_dist - neg_pair_dist).topk(self.num_neg_sample)[0].mean()

        # Sentence only loss
        txt_txt_dist = self._pdist(txt_embeds, txt_embeds)
        txt_txt_mask = torch.reshape(im_labels.t().repeat(1, txt_im_ratio), (num_txt, num_txt))
        pos_pair_dist = txt_txt_dist[txt_txt_mask].max(dim=0, keepdim=True)[0]
        neg_pair_dist = txt_txt_dist[~txt_txt_mask].view(num_txt, -1)
        sent_only_loss = F.relu(self.margin + pos_pair_dist - neg_pair_dist).topk(self.num_neg_sample)[0].mean()

        loss = im_loss * self.im_loss_factor + txt_loss + sent_only_loss * self.txt_only_loss_factor
        return loss

    def _pdist(self, x1, x2):
        x1_square = torch.sum(x1**2, dim=1).view(-1, 1)
        x2_square = torch.sum(x2**2, dim=1).view(1, -1)
        return torch.sqrt(x1_square - 2 * torch.mm(x1, x2.t()) + x2_square + 1e-4)

In [38]:
def generate_neg_pairs(image_embedding, text_embedding, sample_size):
    m, _ = image_embedding.shape
    txt_embeds = torch.zeros(sample_size * m, text_embedding.shape[1])
    im_labels = torch.zeros(sample_size * m, m, dtype=bool)
    all_indices = np.arange(m)
    for i in range(m):
        indices = np.delete(all_indices, i)
        random_indices = np.random.choice(indices, size=sample_size-1, replace=False)
        txt_embeds[i * sample_size, :] = text_embedding[i, :]
        txt_embeds[(i * sample_size + 1) : ((i + 1) * sample_size), :] = text_embedding[random_indices, :]
        im_labels[0, i] = True
    return image_embedding, txt_embeds, im_labels

In [39]:
def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    # transformers.set_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [40]:
SEED = 595
set_seed(SEED)

image_features = np.load('features/image_features_resnet50_remove.npy')
text_features = np.load('features/text_LaBSE_merge.npy')
batch_size = 256
img_feature_size = image_features.shape[1]
txt_feature_size = text_features.shape[1]
learning_rate = 0.0001
weight_decay = 0.01
epochs = 10
sample_size = 5
embed_dim = 1024

dataset = FeatureDataset(image_features, text_features)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
branch_net = BranchNetwork(img_feature_size, txt_feature_size, embed_dim = embed_dim)
# optimizer = torch.optim.Adam(branch_net.parameters(), lr=learning_rate)
# contrastive_loss = ContrastiveLoss(0.07)
optimizer = torch.optim.AdamW(branch_net.parameters(),
                                lr=learning_rate)#,
                                # weight_decay=weight_decay)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                    T_max=epochs,
                                                    eta_min=learning_rate/50)
embed_loss = EmbeddingLoss()

# Training
for epoch in range(epochs):
    for image_batch, text_batch in dataloader:
        image_embedding, text_embedding = branch_net(image_batch, text_batch)
        im_embeds, txt_embeds, im_labels = generate_neg_pairs(image_embedding, text_embedding, sample_size)
        loss = embed_loss(im_embeds, txt_embeds, im_labels, sample_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    lr_scheduler.step()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

Epoch 1/10, Loss: 0.5996758341789246
Epoch 2/10, Loss: 0.5843840837478638
Epoch 3/10, Loss: 0.6054803729057312
Epoch 4/10, Loss: 0.5952678322792053
Epoch 5/10, Loss: 0.5595031380653381
Epoch 6/10, Loss: 0.5649318099021912
Epoch 7/10, Loss: 0.5687665343284607
Epoch 8/10, Loss: 0.5615719556808472
Epoch 9/10, Loss: 0.5653002858161926
Epoch 10/10, Loss: 0.5565263032913208


In [31]:
i_result, t_result = branch_net.forward(torch.tensor(image_features), torch.tensor(text_features))

In [32]:
np.save('trained_img_features.npy',i_result.detach().numpy())

In [33]:
np.save('trained_txt_features.npy',t_result.detach().numpy())

In [41]:
!python3.11 train.py --text_feature='features/text_mBERT_merge.npy' --model='BinaryRelevance(LogisticRegression(random_state=42))'

(11854, 768)
BinaryRelevance(LogisticRegression(random_state=42))
0.5426593689632968
Sample-based Accuracy: 0.41611845187892416
Sample-based Precision: 0.5734232967856808
Sample-based Recall: 0.5037578483740981
Sample-based F1-score: 0.5009937927002869
Label-based Accuracy: 0.35450781571043094
Label-based Precision: 0.6374963051564114
Label-based Recall: 0.443306598027435
Label-based F1-score: 0.520058912991418


In [42]:
!python3.11 train.py --image_feature='features/image_features_resnet50_remove.npy' --model='BinaryRelevance(LogisticRegression(random_state=42))'

(11854, 2048)
BinaryRelevance(LogisticRegression(random_state=42))
0.5026655334930078
Sample-based Accuracy: 0.37796698662596895
Sample-based Precision: 0.5286336800674726
Sample-based Recall: 0.48488426576703214
Sample-based F1-score: 0.4677519400600649
Label-based Accuracy: 0.3028654724292546
Label-based Precision: 0.5109079695300565
Label-based Recall: 0.4062938233821338
Label-based F1-score: 0.4487280098957174


In [43]:
!python3.11 train.py --text_feature='features/text_mBERT_merge.npy' --image_feature='features/image_features_resnet50_remove.npy' --model='BinaryRelevance(LogisticRegression(random_state=42))'

(11854, 768)
(11854, 2048)
BinaryRelevance(LogisticRegression(random_state=42))
0.5969511746078504
Sample-based Accuracy: 0.4678380657857745
Sample-based Precision: 0.607928029238122
Sample-based Recall: 0.6000749695436229
Sample-based F1-score: 0.5647585132825503
Label-based Accuracy: 0.4147852419374694
Label-based Precision: 0.6221411222292456
Label-based Recall: 0.5453912901237937
Label-based F1-score: 0.5801559515731219


In [47]:
!python3.11 train.py --text_feature='features/text_mBERT_merge.npy' --model='OneVsRestClassifier(LogisticRegression(random_state=42))'

(11854, 768)
OneVsRestClassifier(LogisticRegression(random_state=42))
0.5426593689632968
Sample-based Accuracy: 0.41611845187892416
Sample-based Precision: 0.5734232967856808
Sample-based Recall: 0.5037578483740981
Sample-based F1-score: 0.5009937927002869
Label-based Accuracy: 0.35450781571043094
Label-based Precision: 0.6374963051564114
Label-based Recall: 0.443306598027435
Label-based F1-score: 0.520058912991418


In [48]:
!python3.11 train.py --image_feature='features/image_features_resnet50_remove.npy' --model='OneVsRestClassifier(LogisticRegression(random_state=42))'

(11854, 2048)
OneVsRestClassifier(LogisticRegression(random_state=42))
0.5026655334930078
Sample-based Accuracy: 0.37796698662596895
Sample-based Precision: 0.5286336800674726
Sample-based Recall: 0.48488426576703214
Sample-based F1-score: 0.4677519400600649
Label-based Accuracy: 0.3028654724292546
Label-based Precision: 0.5109079695300565
Label-based Recall: 0.4062938233821338
Label-based F1-score: 0.4487280098957174


In [49]:
!python3.11 train.py --text_feature='features/text_mBERT_merge.npy' --image_feature='features/image_features_resnet50_remove.npy' --model='OneVsRestClassifier(LogisticRegression(random_state=42))'

(11854, 768)
(11854, 2048)
OneVsRestClassifier(LogisticRegression(random_state=42))
0.5969511746078504
Sample-based Accuracy: 0.4678380657857745
Sample-based Precision: 0.607928029238122
Sample-based Recall: 0.6000749695436229
Sample-based F1-score: 0.5647585132825503
Label-based Accuracy: 0.4147852419374694
Label-based Precision: 0.6221411222292456
Label-based Recall: 0.5453912901237937
Label-based F1-score: 0.5801559515731219


xlm

In [50]:
!python3.11 train.py --text_feature='features/text_XLM-RoBERTa_merge.npy' --model='BinaryRelevance(LogisticRegression(random_state=42))'

(11854, 768)
BinaryRelevance(LogisticRegression(random_state=42))
0.48500365764447695
Sample-based Accuracy: 0.3595305032330615
Sample-based Precision: 0.5279261549995314
Sample-based Recall: 0.40648486552338114
Sample-based F1-score: 0.42957280747553445
Label-based Accuracy: 0.27399551726596105
Label-based Precision: 0.7213468475990187
Label-based Recall: 0.3130077107263505
Label-based F1-score: 0.4218470926741217


In [51]:
!python3.11 train.py --image_feature='features/image_features_resnet50_remove.npy' --model='BinaryRelevance(LogisticRegression(random_state=42))'

(11854, 2048)
BinaryRelevance(LogisticRegression(random_state=42))
0.5026655334930078
Sample-based Accuracy: 0.37796698662596895
Sample-based Precision: 0.5286336800674726
Sample-based Recall: 0.48488426576703214
Sample-based F1-score: 0.4677519400600649
Label-based Accuracy: 0.3028654724292546
Label-based Precision: 0.5109079695300565
Label-based Recall: 0.4062938233821338
Label-based F1-score: 0.4487280098957174


In [52]:
!python3.11 train.py --text_feature='features/text_XLM-RoBERTa_merge.npy' --image_feature='features/image_features_resnet50_remove.npy' --model='BinaryRelevance(LogisticRegression(random_state=42))'

(11854, 768)
(11854, 2048)
BinaryRelevance(LogisticRegression(random_state=42))
0.5433685176686553
Sample-based Accuracy: 0.4157161599528763
Sample-based Precision: 0.5632227532564895
Sample-based Recall: 0.52908349732921
Sample-based F1-score: 0.5072483946253977
Label-based Accuracy: 0.3473223671162996
Label-based Precision: 0.5798377517626594
Label-based Recall: 0.456364042268391
Label-based F1-score: 0.5055567773541776


In [53]:
!python3.11 train.py --text_feature='features/text_XLM-RoBERTa_merge.npy' --model='OneVsRestClassifier(LogisticRegression(random_state=42))'

(11854, 768)
OneVsRestClassifier(LogisticRegression(random_state=42))
0.48500365764447695
Sample-based Accuracy: 0.3595305032330615
Sample-based Precision: 0.5279261549995314
Sample-based Recall: 0.40648486552338114
Sample-based F1-score: 0.42957280747553445
Label-based Accuracy: 0.27399551726596105
Label-based Precision: 0.7213468475990187
Label-based Recall: 0.3130077107263505
Label-based F1-score: 0.4218470926741217


In [54]:
!python3.11 train.py --image_feature='features/image_features_resnet50_remove.npy' --model='OneVsRestClassifier(LogisticRegression(random_state=42))'

(11854, 2048)
OneVsRestClassifier(LogisticRegression(random_state=42))
0.5026655334930078
Sample-based Accuracy: 0.37796698662596895
Sample-based Precision: 0.5286336800674726
Sample-based Recall: 0.48488426576703214
Sample-based F1-score: 0.4677519400600649
Label-based Accuracy: 0.3028654724292546
Label-based Precision: 0.5109079695300565
Label-based Recall: 0.4062938233821338
Label-based F1-score: 0.4487280098957174


In [55]:
!python3.11 train.py --text_feature='features/text_XLM-RoBERTa_merge.npy' --image_feature='features/image_features_resnet50_remove.npy' --model='OneVsRestClassifier(LogisticRegression(random_state=42))'

(11854, 768)
(11854, 2048)
OneVsRestClassifier(LogisticRegression(random_state=42))
0.5433685176686553
Sample-based Accuracy: 0.4157161599528763
Sample-based Precision: 0.5632227532564895
Sample-based Recall: 0.52908349732921
Sample-based F1-score: 0.5072483946253977
Label-based Accuracy: 0.3473223671162996
Label-based Precision: 0.5798377517626594
Label-based Recall: 0.456364042268391
Label-based F1-score: 0.5055567773541776
