<a href="https://colab.research.google.com/github/EkaterinaVoloshina/ASR_probing/blob/main/image_similarity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
%cd drive/MyDrive/Thesis

/content/drive/MyDrive/Thesis


In [5]:
from PIL import Image
import requests
from io import BytesIO
from tqdm.notebook import tqdm
import csv
import os
import json

import pandas as pd
import numpy as np
import torchvision.transforms as transforms

from torchvision.datasets import ImageFolder
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors

from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet50
from itertools import combinations

In [19]:
class DatasetLoader(Dataset):
    def __init__(self, filename, dir_name="images",
                 image_size=256, device="cpu", 
                 download=False, from_cache=False):
        self.device = device
        self.image_size = image_size
        self.dir_name = dir_name
        self.loader = transforms.Compose([
            transforms.Resize((image_size, image_size)),  # scale imported image
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),])

        if from_cache:
            self.image_tensor, self.labels = self.load_from_cache(
                dir_name=dir_name,
                filename=filename
            )
        else:
            annotations, images = self.open_file(filename)
            self.image_tensor, self.labels = self.load_dataset(
                images=images,
                annotations=annotations,
                download=download)
            
    
    def __len__(self):
        return len(self.image_tensor)

    def __getitem__(self, idx):
        return self.image_tensor[idx], self.labels[idx]
    
    def load_from_cache(self, dir_name, filename):
        annotations = pd.read_csv(filename)
        labels = []
        image_tensor = torch.Tensor(len(annotations), 3, self.image_size, self.image_size)
        for num, image_name in enumerate(os.listdir(dir_name)):
            img = Image.open(os.path.join(dir_name, image_name)).convert('RGB')
            image = self.loader(img)
            idx = image_name.split(".")[0]
            image_tensor[num, :, :, :] = image
            label = annotations[annotations["idx"] == int(idx)].values.tolist()[0]
            print(label)
            labels.append(label)
        return image_tensor, labels
    
    def open_file(self, filename):
        with open(filename) as f:
            content = json.load(f)
        return (content["annotations"],
        content["images"])

    def open_and_save(self, url, filename=None, download=False):
        response = requests.get(url)
        try:
            img = Image.open(BytesIO(response.content)).convert('RGB')
            image = self.loader(img)
            if download:
                img.save(os.path.join(self.dir_name, f"{filename}.jpg"))
            return image
        except:
            print(url)

    def find_annotation(self, idx, annotations):
        labels = []
        for annotation in annotations:
            if annotation[0]["photo_flickr_id"] == idx:
                labels.append(annotation[0]["original_text"])

        return labels

    def load_dataset(self, images, annotations, download=False):
        labels = []
        images = images[:50]
        image_tensor = torch.Tensor(len(images), 3, self.image_size, self.image_size)
        for num, image in tqdm(enumerate(images)):
            idx = image["id"]
            label = self.find_annotation(idx, annotations)
            if label != []:
                img = self.open_and_save(image["url_o"], 
                                         filename=image["id"],
                                         download=download)
                if img != None: 
                    image_tensor[num, :, :, :] = img
                    labels.append([str(idx), label[0], image["album_id"]])
        
        image_tensor = image_tensor[:len(labels), :, :, :]

        if download:
            with open('labels.csv','w') as f:
                w = csv.writer(f)
                w.writerow(("idx", "annotation", "album_id"))
                for label in labels:
                    w.writerow(label)
                
        return image_tensor, labels    

In [7]:
class EncoderModel(torch.nn.Module):
    def __init__(self, img_size):
        super(EncoderModel, self).__init__()
        pretrained_model = resnet50(pretrained=True)
        self.model = torch.nn.Sequential(*list(pretrained_model.children())[:-1])
        self.adaptive_pool = torch.nn.AdaptiveAvgPool2d((img_size, img_size))
    
    def forward(self, x):
        x = self.model(x)
        x = self.adaptive_pool(x)
        x = x.max(-1).values.max(-1).values
        return x

In [8]:
def get_vectors(data, model, device="cpu"):
    images = [] # change to tensor
    annotations = [] 
    albums = []
    ids = []
    for image, (idx, annotation, album_id) in data:
        image = image.to(device)
        vector = model(image)
        images.append(vector)
        ids.append(idx)
        annotations.append(annotation)
        albums.append(album_id)
    return images, ids, annotations, albums

In [9]:
def confuse_captions(images, annotations, ids, album):

    new_annotations = {}

    cos_sim = cosine_similarity(images[0].detach().cpu().numpy())

    сheck_album_id = lambda img1, img2: True if album[img1] == album[img2] else False

    
    for i, vector in enumerate(cos_sim):
        closest = np.argsort(vector)
        if closest[-1] == i:
            closest_value_idx = -2
        else:
            closest_value_idx = -1

        closest_value = closest[closest_value_idx]
        while сheck_album_id(i, closest_value) and closest_value != -1:
            closest_value_idx -= 1
            closest_value = closest[closest_value_idx]
        
        new_annotations[ids[0][i]] = (annotations[0][i], annotations[0][closest_value])
    
    return new_annotations

In [10]:
%mkdir images

mkdir: cannot create directory ‘images’: File exists


In [20]:
#loader = DatasetLoader(filename="dii/test.description-in-isolation.json")
#                       download=True)
loader = DatasetLoader(filename="labels.csv", dir_name="images", from_cache=True)
data = torch.utils.data.DataLoader(loader, batch_size=32)
model = EncoderModel(256).to("cpu")

[1741642, 'The sign is describing when the services will begin.', 44277]
[1741587, 'A man in a top hat has a magic trick on the floor.', 44277]
[1741622, 'A older man with a black hat, mustache and glasses.', 44277]
[1741640, 'Sitting there waiting on someone to come over and buy something.', 44277]
[1741632, 'a case full of books in a house, books appear to be old', 44277]
[355205, 'Taken at some sort of carnival, the camera captured the movement and lights of the amusement ride.', 8139]
[355331, 'Large stuffed neon ape toys hang from the ceiling of a carnival game.', 8139]
[355208, 'two children riding on a dragon roller coaster', 8139]
[355204, 'Two girls smiling while sitting in a cart for a carnival ride.', 8139]
[355332, 'A mother and her daughters look at a carnival game.', 8139]
[21728852, 'Furry animals being pet by some people inside a building', 504823]
[21725505, 'A small car is demoed at a show under blue lighting.', 504823]
[21731442, 'the guy in pink shirt is riding a mo



In [22]:
images, ids, annotations, albums = get_vectors(data, model)

In [None]:
images[1]

In [None]:
table_ann = confuse_captions(images, annotations, ids, albums[0])

with open('annotations.csv','w') as f:
    w = csv.writer(f)
    w.writerows(table_ann.items())