In [None]:
"""
Initial loading: 100 gifs
    3. The general query pipeline for a gif query is as follows:
    1. Input: a query gif (may contain multiple frames)
    2. Extract the gif to multiple images based on frame number (currently using PIL IMAGE package)
    3. For each image frame in gif, query the top similar image from training gif-image collections using CAPSULE.
    4. Rank potential gifs based on all similar images’ corresponding gifID frequencies.
    5. Return the top k similar gif id.
4. Training pipeline:
    1. For each gif in the collection, split it into multiple frames.
    2. Apply feature extraction to each image frame (same as the CAPSULE process).
    3. Insert the gif id into our hash table.

"""

import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import io
import requests
import random
import cv2
from collections import defaultdict
import pandas as pd
import time
from ast import literal_eval

In [None]:
!pip uninstall opencv-python



In [None]:
# Pin opencv-python to 3.4.2.17.
# NOTE(review): opencv-python and opencv-contrib-python (next cell) both provide the
# `cv2` package; installing only the contrib build is presumably sufficient — confirm.
!pip install opencv-python==3.4.2.17

Collecting opencv-python==3.4.2.17
  Using cached https://files.pythonhosted.org/packages/8f/8f/a5d2fa3a3309c4e4aa28eb989d81a95b57c63406b4d439758a1a0a810c77/opencv_python-3.4.2.17-cp37-cp37m-manylinux1_x86_64.whl
[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.[0m
Installing collected packages: opencv-python
Successfully installed opencv-python-3.4.2.17


In [None]:
# The contrib build provides cv2.xfeatures2d, used for SIFT_create later in the notebook.
# cv2 was already imported in the first cell, so the runtime must be restarted for this
# version to be loaded (the version check in the next cell still reported 4.1.2).
!pip install opencv-contrib-python==3.4.2.17

Collecting opencv-contrib-python==3.4.2.17
[?25l  Downloading https://files.pythonhosted.org/packages/12/32/8d32d40cd35e61c80cb112ef5e8dbdcfbb06124f36a765df98517a12e753/opencv_contrib_python-3.4.2.17-cp37-cp37m-manylinux1_x86_64.whl (30.6MB)
[K     |████████████████████████████████| 30.6MB 143kB/s 
Installing collected packages: opencv-contrib-python
  Found existing installation: opencv-contrib-python 4.1.2.30
    Uninstalling opencv-contrib-python-4.1.2.30:
      Successfully uninstalled opencv-contrib-python-4.1.2.30
Successfully installed opencv-contrib-python-3.4.2.17


In [None]:
# Report which OpenCV build is actually loaded. cv2 was imported in the first cell,
# so the pip installs above only take effect after the runtime is restarted.
print(cv2.__version__)
if not cv2.__version__.startswith("3.4.2"):
    print("WARNING: loaded OpenCV", cv2.__version__, "but 3.4.2.17 was pinned above;",
          "restart the runtime so cv2.xfeatures2d.SIFT_create comes from that build.")

4.1.2


In [None]:
import csv
import pandas as pd
from google.colab import drive
# Mount Google Drive; later cells read and write files under gdrive/MyDrive/.
GDRIVE_MOUNT_POINT = '/content/gdrive/'
drive.mount(GDRIVE_MOUNT_POINT, force_remount=True)


Mounted at /content/gdrive/


In [None]:
# Feature extractor used by DataLoader.readImage. Despite the variable name this is
# SIFT from cv2.xfeatures2d (opencv-contrib), retaining the 50 best keypoints per frame.
# Alternatives that were tried: cv2.AKAZE_create(),
# cv2.xfeatures2d.SURF_create(hessianThreshold=900).
surf = cv2.xfeatures2d.SIFT_create(nfeatures=50)



class DataLoader():
    """Downloads gifs and extracts SIFT descriptors from a few evenly spaced frames."""

    def __init__(self):
        # Kept for compatibility; the methods below do not use these attributes.
        self.numImages = 0
        self.imageUrls = []

    @staticmethod
    def _frame_interval(numframes, num_samples=5):
        """Stride between sampled frames of a gif whose last frame index is ``numframes``.

        With stride ``numframes // num_samples + 1`` over frames ``0..numframes``,
        at most ``num_samples`` frames are kept.
        """
        return numframes // num_samples + 1

    def readImage(self, link, timeout=30):
        """Download the image at ``link`` and compute SIFT descriptors for sampled frames.

        :param link: URL of a gif (or any image PIL can open).
        :param timeout: seconds to wait for the HTTP response.
        :return: list with one descriptor matrix per sampled frame; frames for which
            no descriptors were produced are skipped. Returns None when the download
            fails, the payload is not a readable image, or frame sampling is rejected.
        """
        try:
            response = requests.get(link, timeout=timeout)
            response.raise_for_status()
            im = Image.open(io.BytesIO(response.content))
        except (requests.RequestException, OSError) as exc:
            # PIL.UnidentifiedImageError subclasses OSError. A single bad link
            # should not abort a long training run; the training loop skips None.
            print("Skipping", link, "-", exc)
            return None

        img_data_lst = self.removeDuplicates(im)
        if img_data_lst is None:
            return None

        features_lst = []
        for img_data in img_data_lst:
            key_points = surf.detect(img_data, None)
            _, features = surf.compute(img_data, key_points)
            # Descriptors can come back as None (e.g. no keypoints in the frame);
            # downstream code iterates over every returned matrix, so drop those.
            if features is not None:
                features_lst.append(features)
        return features_lst

    def removeDuplicates(self, im, num_samples=5):
        """Sample up to ``num_samples`` evenly spaced frames from a (possibly animated) image.

        Despite the name, no duplicate detection is performed: every
        ``_frame_interval(...)``-th frame, starting at frame 0, is pasted onto an
        RGBA canvas and returned as a numpy array.

        :param im: PIL image positioned at frame 0; it is left at its last frame.
        :param num_samples: maximum number of frames to keep.
        :return: list of RGBA frame arrays, or None if more than 10 frames were
            sampled (a guard that cannot trigger when ``num_samples <= 10``).
        """
        # Count the frames after frame 0 by seeking until PIL signals the end.
        numframes = 0
        try:
            while True:
                im.seek(im.tell() + 1)
                numframes += 1
        except EOFError:
            im.seek(0)
        interval = self._frame_interval(numframes, num_samples)

        image_lst = []
        try:
            while True:
                if im.tell() % interval == 0:
                    canvas = Image.new("RGBA", im.size)
                    canvas.paste(im)
                    image_lst.append(np.array(canvas))
                im.seek(im.tell() + 1)
        except EOFError:
            pass
        if len(image_lst) > 10:
            return None
        print("numframes", numframes, "interval", interval, "image num", len(image_lst))
        return image_lst



# {hashCode: image id}
# L k*r matrix

# [
# hash table 1: [hashcode 1: [image 1, image 2, ..], hash code 2: []]   ]

# def hashfuncGenerate(seed):
#     return hashfunc(sedd)
class SRP:
    """Sign random projection hash producing a k-bit integer code.

    Bit i is 1 when the dot product of the input with a fixed random +/-1 vector
    is positive; bit 0 is the most significant bit of the result. The +/-1
    vectors are drawn from ``random.Random(seed)``, so SRP objects with the same
    seed and k hash identically.
    """

    def __init__(self, k, L, d, seed):
        self.k = k        # number of bits in the hash code
        self.L = L        # stored for reference; hash() does not use it
        self.d = d        # nominal input length; hash() uses len(input)
        self.seed = seed
        # +/-1 projection matrices cached per input length (assumes k and seed
        # are not changed after construction).
        self._projections = {}

    def _projection(self, length):
        """Return the (k, length) +/-1 matrix used for inputs with ``length`` elements.

        Signs are drawn row by row from a private ``random.Random(self.seed)``,
        which yields the same sequence as reseeding the global generator with
        ``self.seed`` before every hash, without clobbering the global ``random``
        state that reservoir sampling draws from.
        """
        proj = self._projections.get(length)
        if proj is None:
            rng = random.Random(self.seed)
            signs = [[rng.choice([-1, 1]) for _ in range(length)] for _ in range(self.k)]
            # reshape keeps the (k, length) shape even when k or length is 0.
            proj = np.array(signs, dtype=np.float64).reshape(self.k, length)
            self._projections[length] = proj
        return proj

    def hash(self, input):
        """Hash a 1-D vector to an int in ``[0, 2**k)``.

        For whole-number inputs (such as the float32 SIFT descriptors printed in
        this notebook) the projected sums are exact, so the bits match a
        per-element accumulation; for arbitrary floats a sum extremely close to 0
        could round differently.
        """
        vector = np.asarray(input, dtype=np.float64)
        sums = self._projection(len(vector)) @ vector
        res = 0
        for positive in sums > 0:
            res = (res << 1) | int(positive)
        return res

# srp = SRP(10, 1, 10, 1)
# inputArray = [random.randint(0, 100) for i in range(128)]
# print(inputArray)
# print(srp.hash(inputArray))

class Resovoir:
    """Reservoir sample holding at most ``r`` of the ids inserted so far."""

    def __init__(self, r):
        self.r = r            # capacity
        self.count = 0        # number of ids offered so far
        self.resArray = []    # ids currently kept

    def insert(self, id):
        """Append ``id`` while there is room; otherwise keep it with probability r/(count+1)."""
        if len(self.resArray) >= self.r:
            slot = random.randint(0, self.count)
            if slot < self.r:
                self.resArray[slot] = id
        else:
            self.resArray.append(id)
        self.count += 1

    def get(self):
        """Return the kept ids (the live list, not a copy)."""
        return self.resArray

    def printRes(self):
        """Same as get(); despite the name nothing is printed."""
        return self.get()

class hashTable:
    """L LSH tables whose buckets are reservoirs of gif ids.

    Table ``l`` is addressed by the k-bit code of ``hashfunc_lst[l]`` (an SRP
    seeded with ``l``); each of its ``2 ** k`` buckets keeps at most ``r`` ids.
    """

    def __init__(self, k, L, r, d=128):
        self.k = k  # hash dimension (bits per code)
        self.L = L  # number of hash tables / hash functions
        self.r = r  # reservoir size per bucket
        self.hashtables = [[Resovoir(self.r) for j in range(2 ** self.k)] for l in range(self.L)]
        self.d = d  # input vector length handed to SRP (128 for SIFT descriptors)
        self.hashfunc_lst = [SRP(k, L, self.d, i) for i in range(self.L)]

    def insert(self, input, id):
        """Offer ``id`` to the bucket that ``input`` hashes to in every table."""
        for l in range(self.L):
            hashed_index = self.hashfunc_lst[l].hash(input)
            self.hashtables[l][hashed_index].insert(id)

    def printHashTable(self):
        """Print how many ids each table currently stores across all of its buckets."""
        for i in range(self.L):
            tableLen = sum(len(self.hashtables[i][j].get()) for j in range(2 ** self.k))
            print("Table: ", i, " Length: ", tableLen)

    def tocsv(self, path="gdrive/MyDrive/HashTable_Result_sift2.csv"):
        """Write one ``table,index,ids`` row per bucket to ``path`` (no index column).

        :param path: output CSV path; defaults to the Drive file used so far.
        """
        rows = [
            [i, j, self.hashtables[i][j].get()]
            for i in range(self.L)
            for j in range(2 ** self.k)
        ]
        pd.DataFrame(rows, columns=["table", "index", "ids"]).to_csv(path, index=False)

    def query(self, features_lst):
        """Score ids by how many (feature, table) lookups hit a bucket containing them.

        :param features_lst: iterable of per-frame descriptor matrices (each row a
            feature vector). A None list, or None frames within it, are skipped.
        :return: defaultdict(int) mapping id -> hit count.
        """
        scores = defaultdict(int)
        if features_lst is None:
            return scores
        qstart = time.time()
        for features in features_lst:  # features: one frame's feature matrix
            if features is None:
                continue
            for feature in features:  # feature: one feature vector
                for l in range(self.L):
                    hashed_index = self.hashfunc_lst[l].hash(feature)
                    for id in self.hashtables[l][hashed_index].get():
                        scores[id] += 1
        print("Query takes", time.time() - qstart, "s")
        return scores

In [None]:
def read_links(path):
    """Return the first whitespace-separated token of every non-blank line in ``path``."""
    with open(path) as link_file:
        return [line.split()[0] for line in link_file if line.strip()]


# Training / testing gif URLs (the raw_gifs.tsv list is no longer used).
lines = read_links('gdrive/MyDrive/training.txt')
test_lines = read_links('gdrive/MyDrive/testing.txt')
numGifs = len(lines)

# initialize dataloader
dataloader = DataLoader()

# initialize hash table: k-bit codes, L tables, reservoir size per bucket
HASH_BITS = 15
NUM_TABLES = 10
RESERVOIR_SIZE = 2000
starttime = time.time()
auc = []  # per-evaluation retrieval scores appended by the training loop
lshHashTable = hashTable(HASH_BITS, NUM_TABLES, RESERVOIR_SIZE)

In [None]:
# Insert up to MAX_TRAIN_GIFS gifs into the LSH table. Every 30 gifs the table is
# checkpointed to CSV; every 5 gifs (from gif 15 on) each test gif is queried and a
# hit is counted when its index is among the TOP_K returned ids.
MAX_TRAIN_GIFS = 1000  # use numGifs for a full run
TOP_K = 3
test_features_cache = {}  # test index -> descriptors, so test gifs are fetched once

for gif_id in range(min(MAX_TRAIN_GIFS, numGifs)):
    link = lines[gif_id]
    print("LINK:", link, "ID", gif_id)
    features_lst = dataloader.readImage(link)  # one descriptor matrix per sampled frame
    if features_lst is None:
        continue

    # Insertion
    framecount = 0
    start = time.time()
    for features in features_lst:  # features: one frame's feature matrix
        if features is None:
            # No descriptors for this frame: skip it but keep the remaining frames.
            print("No descriptors for frame", framecount, "- skipping")
            continue
        print("FRAME COUNT: ", framecount, "feature length: ", len(features))
        framecount += 1
        for feature in features:  # feature: one feature vector
            lshHashTable.insert(feature, gif_id)
    print("Insertion takes", time.time() - start, "s")

    if gif_id % 30 == 1:
        lshHashTable.tocsv()

    if gif_id >= 15 and gif_id % 5 == 0 and test_lines:
        correct = 0
        for testId, test_link in enumerate(test_lines):
            if testId not in test_features_cache:
                test_features = dataloader.readImage(test_link)
                test_features_cache[testId] = [f for f in (test_features or []) if f is not None]
            rslt_dict = lshHashTable.query(test_features_cache[testId])
            sort_rslt = sorted(rslt_dict.items(), key=lambda kv: kv[1], reverse=True)[:TOP_K]
            print(testId, sort_rslt)
            print("Test Link", test_link)
            for rank, (match_id, _score) in enumerate(sort_rslt, start=1):
                print("Top", rank, "Link", lines[match_id])
            # NOTE(review): a hit means test gif i retrieved training gif i, i.e. this
            # assumes testing.txt is aligned line-by-line with training.txt — confirm.
            if any(match_id == testId for match_id, _score in sort_rslt):
                correct += 1
        accuracy = correct / len(test_lines)
        auc.append(accuracy)
        print(gif_id, "correctness:", accuracy, "auc", auc)
print(auc)


LINK: https://media.giphy.com/media/2jCl9ZlZikbOo/giphy.gif ID 0
numframes 22 interval 5 image num 5
FRAME COUNT:  0 feature length:  51
FRAME COUNT:  1 feature length:  50
FRAME COUNT:  2 feature length:  51
FRAME COUNT:  3 feature length:  50
FRAME COUNT:  4 feature length:  50
Insertion takes 19.075226068496704 s
LINK: https://media.giphy.com/media/jp3CUK9cKzImnzCfxR/giphy.gif ID 1
numframes 81 interval 17 image num 5
FRAME COUNT:  0 feature length:  51
FRAME COUNT:  1 feature length:  50
FRAME COUNT:  2 feature length:  50
FRAME COUNT:  3 feature length:  50
FRAME COUNT:  4 feature length:  50
Insertion takes 19.21231698989868 s
LINK: https://media.giphy.com/media/3oEjI0xCVDBd0SxEOI/giphy.gif ID 2
numframes 46 interval 10 image num 5
FRAME COUNT:  0 feature length:  50
FRAME COUNT:  1 feature length:  50
FRAME COUNT:  2 feature length:  50
FRAME COUNT:  3 feature length:  50
FRAME COUNT:  4 feature length:  50
Insertion takes 19.05814552307129 s
LINK: https://media.giphy.com/media/

UnidentifiedImageError: ignored

In [None]:
# Query with a gif that is not in the training set and show its top-3 matches.
notInTrainLink = "https://i.pinimg.com/originals/ac/e4/96/ace4966bdadf25bda7acbb128a3393d8.gif"
test_features_lst = dataloader.readImage(notInTrainLink)
if test_features_lst is None:
    print("Could not load", notInTrainLink)
else:
    frames = [f for f in test_features_lst if f is not None]
    # Summarise the descriptors instead of printing every matrix.
    print("frames:", len(frames), "descriptor shapes:", [f.shape for f in frames])
    rslt_dict = lshHashTable.query(frames)
    sort_rslt = sorted(rslt_dict.items(), key=lambda kv: kv[1], reverse=True)[:3]
    for match_id, _score in sort_rslt:
        print(match_id, lines[match_id])
    print("Not In Train", sort_rslt)

numframes 13 interval 5 image num 3
[array([[ 26.,   3.,   0., ...,   0.,   0.,   1.],
       [ 34.,   0.,   0., ...,   0.,   0.,   1.],
       [ 32.,  54.,   8., ...,   1.,   6., 139.],
       ...,
       [  0., 108.,  70., ...,   0.,   0.,   0.],
       [  4.,   0.,   0., ...,   3., 145.,  45.],
       [ 25.,  11.,   1., ...,   6.,   0.,   4.]], dtype=float32), array([[ 20.,   0.,   0., ...,   0.,   1.,   2.],
       [  3.,  45., 129., ..., 106.,   2.,   0.],
       [ 23.,   0.,   0., ...,   1.,   0.,   0.],
       ...,
       [ 63.,  27.,  17., ...,   0.,   0.,   3.],
       [ 57.,  76.,   0., ...,   0.,   3.,  64.],
       [ 68.,  45.,   0., ...,   0.,   0.,   0.]], dtype=float32), array([[ 20.,   1.,   0., ...,   0.,   0.,   0.],
       [  1.,  57.,  12., ...,   0.,   0.,   3.],
       [ 58.,  86.,   2., ...,   0.,   3.,  19.],
       ...,
       [ 97., 132.,   1., ...,   3.,   0.,   0.],
       [ 33.,   0.,   0., ...,   0.,   2.,   1.],
       [ 52.,   2.,   0., ...,   3.,   1., 

In [None]:
# Checkpoint the current LSH buckets to CSV on Google Drive.
table_to_checkpoint = lshHashTable
table_to_checkpoint.tocsv()

In [None]:
# Self-retrieval check: query with training gifs themselves and count how often a
# gif's own id is the top-scoring result.
NUM_SELF_QUERIES = 250
testdataloader = DataLoader()
correct = 0
num_queries = min(NUM_SELF_QUERIES, len(lines))
for i in range(num_queries):
    print("ID", i)
    test_features_lst = testdataloader.readImage(lines[i])
    frames = [f for f in (test_features_lst or []) if f is not None]
    query_result = sorted(lshHashTable.query(frames).items(), key=lambda kv: kv[1], reverse=True)
    if query_result and query_result[0][0] == i:
        correct += 1
        print("CORRECT", query_result[:2])
    else:
        print("NOOOOOOO", query_result[:2])
print("Top-1 self-retrieval accuracy:", correct / num_queries if num_queries else 0.0)

In [None]:
def loadResovoir(r, inputArray):
    """Rebuild a Resovoir of capacity ``r`` holding the ids in ``inputArray``.

    The CSV stores only the kept ids, not how many ids were offered to the
    bucket, so ``count`` is set to the number kept (a lower bound). Leaving it
    at 0 would make every later insert into a full bucket draw
    ``randint(0, 0)`` and overwrite slot 0.
    """
    res = Resovoir(r)
    res.resArray = list(inputArray)
    res.count = len(res.resArray)
    return res


# NOTE(review): K/L/R must match the hashTable that wrote this CSV (the training
# cell builds hashTable(15, 10, 2000) and saves HashTable_Result_sift2.csv) — confirm.
K = 16
L = 25
R = 2000

test = pd.read_csv("gdrive/MyDrive/HashTable_Result.csv")
test.ids = test.ids.apply(literal_eval)
if not test.empty and (test["table"].max() >= L or test["index"].max() >= 2 ** K):
    raise ValueError("HashTable CSV has table/index values outside L=%d, K=%d" % (L, K))

# Start from the table's own empty buckets so any bucket missing from the CSV still
# supports get()/insert(); per-row printing is omitted (2**K * L rows).
lshHashTable = hashTable(K, L, R)
hashtables = lshHashTable.hashtables
for table, index, ids in zip(test["table"], test["index"], test["ids"]):
    hashtables[int(table)][int(index)] = loadResovoir(R, ids)
lshHashTable.printHashTable()