In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer, BertModel
from tqdm import tqdm
from data_utils import *
import fasttext, re, json, faiss
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import average_precision_score
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

In [4]:
# Central configuration for the whole notebook.
SEED = 42  # seeds torch/numpy RNGs: dataset split, DataLoader shuffling, numpy randomness

config = {
    # Image size passed to CustomTransform.
    'IMG_WIDTH': 224,
    'IMG_HEIGHT': 224,
    'TEST_DATASET_DIR': 'data/val2014',
    'batch_size': 64,
    # Training hyper-parameters (not used by the cells in this section).
    'epochs': 80,
    'learning_rate': 0.0001,
    # Retrieval: 'knn' -> scikit-learn NearestNeighbors, anything else -> FAISS IVF index.
    'n_neighbors': 5,
    'classifier': 'knn',
    # 'ft' -> fastText text encoder, anything else -> BERT.
    'text_encoder_type': 'ft',
    # Distance metric for NearestNeighbors only (the FAISS index is always L2).
    'metric': 'euclidean',
    # FAISS IVF: number of Voronoi cells, and how many cells are probed per query.
    'voronoi_cells': 64,
    'lookup_cells': 8,
    'seed': SEED,
    'device': torch.device("cuda" if torch.cuda.is_available() else "cpu")
}

# Seed global RNGs up front so Restart & Run All reproduces the same split and results.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [5]:
# COCO val2014 captions file (image info such as file_name, plus the captions).
with open("data/captions_val2014.json", "r") as f:
    captions_val = json.load(f)

# Full image-caption dataset used for retrieval; it is split below.
# NOTE(review): the training split below also uses the "val" transforms (no augmentation) — confirm intended.
retrieval_dataset = CocoMetricDataset(
    root=config["TEST_DATASET_DIR"],
    captions_file=captions_val,
    transforms=CustomTransform(config, mode="val"))

# 60% train / 20% validation / 20% test; the test split absorbs the rounding remainder.
total_length = len(retrieval_dataset)
train_size = int(0.6 * total_length)
valid_size = int(0.2 * total_length)
test_size = total_length - train_size - valid_size

# A seeded generator makes the split identical on every run, so a model evaluated later
# sees exactly the same test items it was not trained on.
# NOTE(review): the dataset appears to hold one item per image-caption pair, so different
# captions of the same image can land in different splits (leakage) — consider splitting by image id.
split_generator = torch.Generator().manual_seed(config.get('seed', 42))
train_dataset, validation_dataset, test_dataset = random_split(
    retrieval_dataset, [train_size, valid_size, test_size], generator=split_generator)

dataloader_train = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=coco_collator)
dataloader_validation = DataLoader(validation_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=coco_collator)
dataloader_test = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=coco_collator)

Creating image-answer pairs...: 100%|██████████| 202654/202654 [10:52<00:00, 310.80it/s]


In [5]:
def normalize_vector(vec):
    """Scale ``vec`` to unit L2 norm.

    A vector whose norm is exactly zero cannot be normalised; it is returned
    unchanged (the very same object).
    """
    magnitude = np.sqrt(np.sum(vec ** 2))
    if magnitude == 0:
        return vec
    return vec / magnitude

# Model definition: ResNet-50 image encoder plus a frozen text encoder (fastText or BERT) projected to the same 2048-d space
class Net(torch.nn.Module):
    """Dual encoder mapping images and captions into a shared 2048-d space.

    Images go through a pretrained ResNet-50 whose classification head is replaced
    by ``nn.Identity`` (pooled 2048-d features). Captions go through a frozen text
    encoder -- fastText (mean of L2-normalised word vectors, 300-d) or BERT
    ([CLS] token of the last layer, 768-d) -- followed by a trainable linear
    projection to 2048-d.

    Args:
        text_encoder_type: ``'ft'`` selects fastText; any other value selects BERT.
    """

    def __init__(self, text_encoder_type):
        super().__init__()

        # NOTE(review): `weights=` is the modern torchvision API; with the v0.10.0 hub tag it
        # only works if the installed (newer) torchvision is what gets imported — confirm.
        self.visual_encoder = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', weights='ResNet50_Weights.DEFAULT')
        self.visual_encoder.fc = nn.Identity()  # drop the classifier, keep the pooled features

        self.text_encoder_type = text_encoder_type
        if self.text_encoder_type == 'ft':
            # The fastText model is not an nn.Module: it has no trainable parameters and
            # is not moved by .to(device); its outputs are moved in sentence_to_vector.
            self.text_encoder = fasttext.load_model('fasttext_wiki.en.bin')
            text_dimension = 300
        else:
            self.tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
            self.text_encoder = BertModel.from_pretrained("google-bert/bert-base-uncased")
            self.text_encoder.eval()
            text_dimension = 768

            # BERT is used as a frozen feature extractor.
            for param in self.text_encoder.parameters():
                param.requires_grad = False

        self.proj = nn.Linear(text_dimension, 2048)

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen BERT text encoder in eval mode.

        ``nn.Module.train`` recurses into every child module, which would re-enable
        BERT's dropout even though its weights are frozen.
        """
        super().train(mode)
        text_encoder = getattr(self, 'text_encoder', None)
        if isinstance(text_encoder, nn.Module):
            text_encoder.eval()
        return self

    def sentence_to_vector(self, captions, text_type=None):
        """Encode a batch of captions into text embeddings (before projection).

        Args:
            captions: list of caption strings.
            text_type: ``'ft'`` for fastText, anything else for BERT. Defaults to the
                encoder this model was built with (``self.text_encoder_type``).

        Returns:
            Float tensor of shape ``(len(captions), text_dim)`` on the same device as
            ``self.proj``.
        """
        if text_type is None:
            text_type = self.text_encoder_type
        device = self.proj.weight.device

        if text_type == 'ft':
            batch_embedding = []
            for sentence in captions:
                words = re.findall(r'\b\w+\b', sentence.lower())
                word_vectors = [normalize_vector(self.text_encoder[word]) for word in words if word in self.text_encoder]
                if word_vectors:
                    # Stack into one ndarray first: building a tensor from a list of arrays is very slow.
                    embedding = torch.from_numpy(np.stack(word_vectors)).float().mean(dim=0)
                else:
                    # No in-vocabulary words: mean() of an empty tensor is a NaN scalar that
                    # would make torch.stack fail, so fall back to a zero vector.
                    embedding = torch.zeros(self.proj.in_features)
                batch_embedding.append(embedding)
            # fastText runs on CPU; move the result next to the projection layer.
            return torch.stack(batch_embedding).to(device)

        inputs = self.tokenizer(captions, return_tensors="pt", padding="longest", add_special_tokens=True, return_attention_mask=True)
        # The tokenizer always returns CPU tensors; BERT lives wherever the model was moved.
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        # Embedding of the first ([CLS]) token from the last hidden layer.
        return self.text_encoder(**inputs).last_hidden_state[:, 0, :]

    def forward(self, image, captions):
        """Return ``(image_embeddings, projected_text_embeddings)``, both 2048-d per sample."""
        return self.visual_encoder(image), self.proj(self.sentence_to_vector(captions, text_type=self.text_encoder_type))

# Build the dual encoder and move it (image encoder, projection, and BERT if selected)
# to the configured device; nn.Module.to returns the module itself, so it can be chained.
model = Net(config['text_encoder_type']).to(config['device'])
# Optional: restore trained weights instead of starting from the pretrained encoders
# (map_location lets a GPU checkpoint load on a CPU-only machine):
# model.load_state_dict(torch.load('pretrained/best_model_triplet.pth', map_location=config['device']))

Using cache found in C:\Users\Luis/.cache\torch\hub\pytorch_vision_v0.10.0


In [23]:
def extract_features(loader, model, config, query=False):
    """Run ``model`` over ``loader`` and collect one embedding and one label per sample.

    Args:
        loader: iterable of ``(images, captions, labels)`` batches, e.g. a DataLoader
            built with ``coco_collator``. Must support ``len()`` (used for the progress bar).
        model: module whose forward takes ``(images, captions)`` and returns
            ``(visual_features, text_features)``.
        config: config dict holding a ``'device'`` entry, or a device (``torch.device``
            or string) directly.
        query: if True, return the text embeddings (queries); otherwise the visual
            embeddings (database).

    Returns:
        ``(features, labels)``: a float32 array with one row per sample and a 1-D-per-batch
        concatenation of the batch labels.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    device = config['device'] if isinstance(config, dict) else config
    model.eval()
    features, all_labels = [], []
    with torch.no_grad():
        for images, captions, batch_labels in tqdm(loader, total=len(loader), desc='Extracting features'):
            # Captions are a list of strings (no .to()), so they are passed through unchanged.
            visual_features, text_features = model(images.to(device), captions)
            batch_features = text_features if query else visual_features
            features.append(batch_features.cpu().numpy())
            if torch.is_tensor(batch_labels):
                all_labels.append(batch_labels.cpu().numpy())
            else:
                all_labels.append(np.asarray(batch_labels))

    if not features:
        raise ValueError('extract_features: loader yielded no batches')
    return np.concatenate(features).astype('float32'), np.concatenate(all_labels)

# Image retrieval: nearest-neighbour search (scikit-learn k-NN or a FAISS IVF index) over image embeddings
class ImageRetrievalSystem:
    """Nearest-neighbour retrieval of database images for a set of caption queries.

    ``fit`` embeds the database loader with the image branch of ``model`` and
    indexes the features; ``retrieve`` embeds the query loader with the text branch
    (``query=True``) and returns the indices of the nearest database items.

    The index is scikit-learn ``NearestNeighbors`` when ``config['classifier'] == 'knn'``
    (distance ``config['metric']``), otherwise a FAISS ``IndexIVFFlat`` with L2 distance
    split into ``config['voronoi_cells']`` cells, of which ``config['lookup_cells']``
    are probed per query.
    """

    def __init__(self, model, database_loader, query_loader, config):
        self.model = model
        self.database_loader = database_loader
        self.query_loader = query_loader
        # extract_features expects the config dict (it reads config['device']).
        self.config = config
        self.device = config['device']

        self.dim = 2048  # must match the embedding size produced by the model
        self.classifier_type = config['classifier']
        self.n_neighbors = config['n_neighbors']

        if self.classifier_type == 'knn':
            self.classifier = NearestNeighbors(n_neighbors=config['n_neighbors'], metric=config['metric'])
        else:
            # config['metric'] is not used here: this FAISS index always uses L2 distance.
            self.classifier = faiss.IndexIVFFlat(faiss.IndexFlatL2(self.dim), self.dim, config['voronoi_cells'])
            self.classifier.nprobe = config['lookup_cells']

    def fit(self):
        """Embed the database images and build the nearest-neighbour index.

        Stores the database labels in ``self.train_labels`` (row-aligned with the index).
        """
        features, self.train_labels = extract_features(self.database_loader, self.model, self.config)

        print('Fitting the classifier...')
        if self.classifier_type == 'knn':
            self.classifier.fit(features, self.train_labels)
        else:
            # An IVF index must learn its cell centroids before vectors can be added.
            self.classifier.train(features)
            self.classifier.add(features)

    def retrieve(self):
        """Embed the query captions and look up their nearest database images.

        Returns:
            ``(predictions, labels)``: ``predictions`` holds, per query, the row indices of
            the nearest database items (map through ``self.train_labels`` to get their
            labels); ``labels`` are the query labels.
        """
        features, labels = extract_features(self.query_loader, self.model, self.config, query=True)

        print('Retrieving images...')
        if self.classifier_type == 'knn':
            _, predictions = self.classifier.kneighbors(features, return_distance=True)
        else:
            _, predictions = self.classifier.search(features, self.n_neighbors)

        return predictions, labels

In [18]:
# Peek at a single training batch to confirm what the collator yields:
# (images, captions, labels) with one caption and one label per image.
images, captions, labels = next(iter(dataloader_train))
assert images.shape[0] == len(captions) == labels.shape[0], 'batch components disagree in size'
print(images.shape, len(captions), labels.shape)

  0%|          | 0/1900 [00:01<?, ?it/s]

0 torch.Size([64, 3, 224, 224]) 64 torch.Size([64])





In [22]:
# Sanity check for extract_features: concatenating per-batch 1-D label arrays
# must give one flat array whose length is the total number of samples.
label_batches = [labels.numpy(), labels.numpy()]
stacked_labels = np.concatenate(label_batches)
assert stacked_labels.shape == (2 * len(labels),)
stacked_labels.shape

(128,)