# Taller _Representation Learning_: Transfer Learning

### NOTA: Para que funcione el codigo hay que descargar el dataset.

Para descargar el Flickr8K dataset:
[https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip](https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip).
Si ese link ya no funciona hay que seguir el proceso y llenar el formulario [aqui](https://forms.illinois.edu/sec/1713398).

- Extraer el ZIP en el directorio `data`
- Ademas hay que descargar los _captions_ del dataset [aqui](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip). Extraer en `caption_datasets`.

Despues de aprender la tarea de "generar _captions_", ahora utilizaremos ese "cortex" para resolver tareas relacionadas para las cuales el modelo no fue entrenado pero que tambien "aprendio". Algunas de estas son:

- Hacer algebra en la semantica conceptual de las palabras

- Encontrar imagenes semanticamente similares

- Encontrar imagenes a partir de una descripcion

### Imports

In [7]:
import matplotlib.pyplot as plt
import random
import json

In [8]:
%matplotlib widget

In [9]:
from scipy import ndimage
import numpy as np
from copy import deepcopy
from PIL import Image
import IPython.display
from math import floor
import string
import torch
import torch.nn as nn                     # neural networks
import torch.nn.functional as F           # layers, activations and more
import torch.optim as optim  
import torchvision.transforms.functional as TF
import torchvision
from torchvision import datasets, models, transforms

In [10]:
is_cuda = torch.cuda.is_available()
is_cuda

False

In [11]:
# Mirror CUDA availability into the flag used by the rest of the notebook.
USE_GPU = True if is_cuda else False

### Parametros

In [12]:
from classes import INCEPTION as inception
from classes import \
    ENDWORD, STARTWORD, PADWORD, HEIGHT, WIDTH, \
    INPUT_EMBEDDING, HIDDEN_SIZE, OUTPUT_EMBEDDING, \
    CAPTION_FILE, IMAGE_DIR

### Cargando InceptionV3 pre-entrenada

In [13]:
# map_location='cpu' lets a checkpoint that was saved from a GPU session be
# read on a CPU-only machine; load_state_dict then copies the values into the
# model's existing parameters, wherever those live.
inception.load_state_dict(
    torch.load('models/inception_epochs_40.pth', map_location='cpu'))

<All keys matched successfully>

In [14]:
# Run the image encoder on the GPU when one is available.
if USE_GPU:
    inception.cuda()

## Clase para iterar en los datos

In [15]:
import pickle

# NOTE(review): pickle can run arbitrary code while loading; only use a
# data-loader pickle that comes from a trusted source.
# The context manager closes the file handle once the object is loaded.
with open("pickles/flickr_data_loader.pkl", "rb") as loader_file:
    f = pickle.load(loader_file)

## Clase de la red

In [16]:
from classes import IC_V6

net = IC_V6(f.tokens)

In [17]:
# map_location='cpu' lets a checkpoint that was saved from a GPU session be
# read on a CPU-only machine; load_state_dict then copies the values into the
# model's existing parameters, wherever those live.
net.load_state_dict(
    torch.load('models/epochs_40_loss_2_841_v6.pth', map_location='cpu'))

<All keys matched successfully>

In [18]:
# Run both the captioning network and the image encoder on the GPU when
# one is available.
if USE_GPU:
    net.cuda()
    inception.cuda()

In [19]:
# Put the captioning network in evaluation mode before using it for inference.
net.eval()

IC_V6(
  (batchnorm): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (input_embedding): Embedding(8385, 300)
  (embedding_dropout): Dropout(p=0.22, inplace=False)
  (gru): GRU(300, 300, num_layers=3, dropout=0.22)
  (linear): Linear(in_features=300, out_features=300, bias=True)
  (out): Linear(in_features=300, out_features=8385, bias=True)
)

## Visualizando los embeddings

In [20]:
# A word must have appeared at least this many times to be visualized.
frequency_threshold = 50

all_word_embeddings = []
all_words = []

# Build the index tensor on the same device as the embedding table so the
# lookup also works after net.cuda().
embedding_device = net.input_embedding.weight.device

# No gradients are needed for a plain lookup.
with torch.no_grad():
    for word, frequency in f.word_frequency.items():
        if frequency < frequency_threshold:
            continue
        word_index = torch.tensor(f.w2i[word], device=embedding_device)
        # .cpu() is a no-op for CPU tensors and is required before .numpy()
        # when the embedding lives on the GPU.
        all_word_embeddings.append(net.input_embedding(word_index).cpu().numpy())
        all_words.append(word)

In [21]:
len(all_words)

701

Usando T-SNE (http://www.jmlr.org/papers/volume9/vandermaaten08a/vandermaaten08a.pdf) para visualizar el embedding.

In [22]:
from sklearn.manifold import TSNE
tsne = TSNE(n_components=2, random_state=0)

In [23]:
X_2d = tsne.fit_transform(all_word_embeddings)

In [24]:
#new_cmap = rand_cmap(10, type='bright', first_color_black=True, last_color_black=False, verbose=True)

In [25]:
def update_annot(ind):
    """Move the tooltip to the hovered point and show the hovered words.

    ``ind`` is the details dict returned by ``sc.contains(event)``;
    ``ind["ind"]`` holds the indices of the scatter points under the cursor.
    """
    hit_indices = ind["ind"]

    # Anchor the tooltip on the first point under the cursor.
    annot.xy = sc.get_offsets()[hit_indices[0]]

    # List every word whose point is under the cursor, separated by spaces.
    annot.set_text(" ".join(all_words[n] for n in hit_indices))

    tooltip_patch = annot.get_bbox_patch()
    tooltip_patch.set_facecolor('white')
    tooltip_patch.set_alpha(0.9)

def hover(event):
    """Mouse-move callback: show the tooltip over a point, hide it elsewhere."""
    was_visible = annot.get_visible()

    # Ignore mouse movement outside the plot axes.
    if event.inaxes != ax:
        return

    hit, details = sc.contains(event)
    if hit:
        update_annot(details)
        annot.set_visible(True)
    elif was_visible:
        annot.set_visible(False)
    else:
        # Nothing under the cursor and nothing to hide: no redraw needed.
        return

    fig.canvas.draw_idle()
                
def onpick(event):
    """Pick-event callback: label the picked point(s) at the mouse position.

    ``event.ind`` holds the indices of the picked scatter points and
    ``event.mouseevent`` carries the mouse position in data coordinates.
    """
    ind = event.ind
    print(ind)
    annot.xy = (event.mouseevent.xdata, event.mouseevent.ydata)
    # The labels used to be read from a global ``y`` that is never defined,
    # which raised NameError. The scatter points are parallel to
    # ``all_words``, so the labels are taken from there.
    annot.set_text(" ".join(all_words[n] for n in ind))
    ax.figure.canvas.draw_idle()

In [26]:
# Scatter plot of the 2-D t-SNE projection of the word embeddings.
fig, ax = plt.subplots(figsize=(12, 12))

sc = plt.scatter(X_2d[:, 0], X_2d[:, 1])

# Tooltip annotation: created hidden, filled in and shown by `hover`.
annot = ax.annotate(
    "",
    xy=(0, 0),
    xytext=(20, 20),
    textcoords="offset points",
    bbox=dict(boxstyle="round", fc="w"),
    arrowprops=dict(arrowstyle="->", color='red'),
)
annot.set_visible(False)

# Only the hover callback is wired up; `onpick` is left disconnected.
fig.canvas.mpl_connect("motion_notify_event", hover)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Algebra en los embeddings

### Encontrar las palabras mas similares

In [27]:
from nlp_utils import return_cosine_sorted, return_similar_words, \
    return_embedding, return_analogy


In [28]:
# Draw random words from the visualized vocabulary until one with at least
# 5 characters comes up; the 1-character placeholder forces the first draw.
query = '_'
while len(query) < 5:
    query = np.random.choice(all_words)
query

'person'

In [33]:
return_similar_words('person', all_words, all_word_embeddings)

array([['girl', '0.18318498134613037'],
       ['cat', '0.17168444395065308'],
       ['shoes', '0.17018908262252808'],
       ['leaping', '0.16315896809101105'],
       ['cowboy', '0.16065259277820587']], dtype='<U32')

In [36]:
return_analogy('earth', 'brown', 'sky', all_words, all_word_embeddings)

0

### Visualizar embeddings de imagenes

In [37]:
from utils import cart2pol, pol2cart

In [38]:
import itertools

inception.eval()

# Load the cached image embeddings if present; otherwise compute them for
# every training image and write the cache.
# NOTE(review): pickle can run arbitrary code while loading; only keep this
# cache if the files under pickles/ are trusted.
try:
    with open('pickles/all_image_embeddings.pkl', 'rb') as embeddings_file:
        all_image_embeddings = pickle.load(embeddings_file)
    with open('pickles/all_image_filenames.pkl', 'rb') as filenames_file:
        all_image_filenames = pickle.load(filenames_file)
except Exception as e:
    # Best-effort cache: any failure to read it triggers a full recompute.
    print("> error loading data:", e)
    all_image_embeddings = []
    all_image_filenames = []
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        for sample in f.training_data:
            filename = sample['filename']
            image_tensor = f.image_to_tensor('data/' + filename)
            # .cpu() is a no-op for CPU tensors and is required before
            # .numpy() when the encoder runs on the GPU.
            all_image_embeddings.append(inception(image_tensor).cpu().numpy())
            all_image_filenames.append(filename)
    with open('pickles/all_image_embeddings.pkl', 'wb') as embeddings_file:
        pickle.dump(all_image_embeddings, embeddings_file)
    with open('pickles/all_image_filenames.pkl', 'wb') as filenames_file:
        pickle.dump(all_image_filenames, filenames_file)

In [39]:
# Working copies (full slices) used by the t-SNE plot below, so the plot
# does not share the list objects used for similarity search.
all_image_embeddings_temp = all_image_embeddings[:]
all_image_filenames_temp = all_image_filenames[:]

In [40]:
from matplotlib.offsetbox import (TextArea, DrawingArea, OffsetImage,
                                  AnnotationBbox)

In [41]:
from sklearn.manifold import TSNE
tsne_images = TSNE(n_components=2, random_state=0)

In [42]:
# Fit the t-SNE instance created for the images (`tsne_images`) instead of
# re-fitting the one that was built for the word embeddings.
# NOTE: this rebinds X_2d, replacing the word-embedding projection.
X_2d = tsne_images.fit_transform(np.squeeze(all_image_embeddings_temp))

In [44]:
# Scatter plot of the 2-D t-SNE projection of the image embeddings.
fig, ax = plt.subplots(figsize=(10, 10))
sc = plt.scatter(X_2d[:, 0], X_2d[:, 1])

# Tooltip annotation: created hidden, positioned and shown by `hover`.
annot = ax.annotate(
    "",
    xy=(0, 0),
    xytext=(20, 20),
    textcoords="offset points",
    bbox=dict(boxstyle="round", fc="w"),
    arrowprops=dict(arrowstyle="->", color='red'),
)
annot.set_visible(False)

def update_annot(ind):
    """Show thumbnails of the hovered images around the hovered point.

    ``ind`` is the details dict returned by ``sc.contains(event)``;
    ``ind["ind"]`` holds the indices of the scatter points under the cursor.
    At most 4 thumbnails are drawn, spread evenly on a circle around the
    first hovered point.
    """
    hit_indices = ind["ind"]
    pos = sc.get_offsets()[hit_indices[0]]
    annot.xy = pos

    # This runs on every mouse-move over a point, so drop the thumbnails
    # drawn by the previous call first; otherwise AnnotationBbox artists
    # pile up on the axes for as long as the cursor stays over points.
    remove_all_images()

    rho = 10  # how far from the point to draw the centers of the thumbnails
    num_images = min(len(hit_indices), 4)  # at most 4
    radians_offset = 2 * np.pi / num_images
    for i in range(num_images):
        hovered_filename = 'data/' + all_image_filenames_temp[hit_indices[i]]
        arr_img = Image.open(hovered_filename, 'r')
        imagebox = OffsetImage(arr_img, zoom=0.3)
        # presumably pol2cart maps (radius, angle) to an (x, y) offset;
        # verify against utils.
        offset = pol2cart(rho, i * radians_offset)
        new_xy = (pos[0] + offset[0], pos[1] + offset[1])
        ax.add_artist(AnnotationBbox(imagebox, new_xy))

    # Tooltip styling does not depend on the thumbnails; apply it once.
    tooltip_patch = annot.get_bbox_patch()
    tooltip_patch.set_facecolor('white')
    tooltip_patch.set_alpha(0.9)


def hover(event):
    """Mouse-move callback: show thumbnails over a point, clear them elsewhere."""
    was_visible = annot.get_visible()

    # Ignore mouse movement outside the plot axes.
    if event.inaxes != ax:
        return

    hit, details = sc.contains(event)
    if hit:
        update_annot(details)
        annot.set_visible(True)
    elif was_visible:
        # Cursor left the point: hide the tooltip and drop the thumbnails.
        annot.set_visible(False)
        remove_all_images()
    else:
        # Nothing under the cursor and nothing to hide: no redraw needed.
        return

    fig.canvas.draw_idle()

def remove_all_images():
    """Remove every AnnotationBbox (image thumbnail) attached to ``ax``."""
    # findobj accepts a class directly, so there is no need to build a
    # throw-away AnnotationBbox(1, 1) with dummy arguments just to read
    # its type.
    for obj in ax.findobj(match=AnnotationBbox):
        obj.remove()

# Wire the hover callback to mouse movement; the pick callback stays disabled.
fig.canvas.mpl_connect("motion_notify_event", hover)
#fig.canvas.mpl_connect('pick_event', onpick)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Encontrar imagenes similares

In [45]:
def plot_image(filename):
    """Open the image file at ``filename`` and show it in a new figure."""
    image = Image.open(filename, 'r')
    plt.figure()
    pixels = np.asarray(image)
    plt.imshow(pixels)
    plt.show()

In [46]:
from scipy import spatial


def return_embedding_image(image_filename):
    """Return the inception embedding of an image file as a squeezed NumPy array.

    The file is converted with ``f.image_to_tensor`` and passed through
    ``inception``.
    """
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        embedding = inception(f.image_to_tensor(image_filename))
    # .cpu() is a no-op for CPU tensors and is required before .numpy()
    # when the encoder runs on the GPU.
    return embedding.cpu().numpy().squeeze()

def return_similar_images(image_filename, top_n=5):
    """Return ``top_n`` rows of (filename, cosine similarity) for an image file.

    Rows come from ``return_cosine_sorted_image`` (most similar first),
    starting at the second-ranked row.
    """
    query_embedding = return_embedding_image(image_filename)
    ranked = return_cosine_sorted_image(query_embedding)
    # NOTE(review): the best-ranked row is always skipped, presumably
    # because it is the query itself when the query belongs to the indexed
    # images. Confirm this is intended for images outside the index.
    return ranked[1:top_n + 1]
    
def return_cosine_sorted_image(target_image_embedding):
    """Rank every indexed image by cosine similarity to the given embedding.

    Returns a 2-column array of (filename, similarity) rows, most similar
    first. Stacking filenames with numbers makes both columns strings.
    """
    similarities = np.array([
        1 - spatial.distance.cosine(target_image_embedding, candidate)
        for candidate in all_image_embeddings
    ])
    # argsort is ascending; reverse it to get the most similar images first.
    order = np.argsort(similarities)[::-1]
    ranked_filenames = np.array(all_image_filenames)[order]
    return np.vstack((ranked_filenames, similarities[order])).T

In [47]:
search_filename = 'custom_images/kite.jpg'
plot_image(search_filename)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [48]:
similar_images = return_similar_images(search_filename)

In [52]:
plot_image('data/'+similar_images[2][0])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Buscar imagenes con una frase

In [85]:
target_sentence = 'a kid playing'
tokens= f.convert_sentence_to_tokens(target_sentence)

In [86]:
from classes import set_parameter_requires_grad, INPUT_EMBEDDING

set_parameter_requires_grad(net, True)

In [87]:
# Small random starting point for the embedding that is optimised directly
# by gradient descent. requires_grad_ replaces the deprecated
# torch.autograd.Variable wrapper and yields the same trainable leaf tensor.
embedding_tensor = (torch.randn(1, INPUT_EMBEDDING) * 0.01).requires_grad_(True)

In [88]:
# Cross-entropy with no reduction: returns one loss value per sample
# instead of averaging over the batch.
l = torch.nn.CrossEntropyLoss(reduction='none')

In [89]:
print(embedding_tensor.shape)

torch.Size([1, 300])


In [90]:
# Gradient descent on `embedding_tensor`: find the embedding for which the
# captioning network assigns the lowest cross-entropy to the target tokens,
# then look up the images whose embeddings are closest to it.
epochs = 100 # best at > 10**5
loss_so_far = 0.0
lr = 0.001
start_token = f.w2i[STARTWORD]  # loop-invariant lookup
with torch.autograd.set_detect_anomaly(True):
    for epoch in range(epochs):
        input_token = start_token
        input_tensor = torch.tensor(input_token)
        loss = 0.

        # forward
        for token in tokens:
            if input_token == start_token:
                # First step: feed the candidate embedding in place of an
                # image (use_inception=False).
                out, hidden = net(input_tensor, embedding_tensor, process_image=True, use_inception=False)
            else:
                out, hidden = net(input_tensor, hidden)
            # current label
            class_label = torch.tensor(token).view(1)
            # The target token becomes the next input.
            input_token = token
            input_tensor = torch.tensor(input_token)
            # predicted label
            out = out.squeeze().view(1, -1)
            loss += l(out, class_label)

        # backward
        loss.backward()
        # Plain SGD step applied in place under no_grad, instead of cloning
        # the tensor and re-wrapping it in the deprecated Variable each
        # epoch. The values are the same; the gradient is cleared so the
        # next backward starts from scratch.
        with torch.no_grad():
            embedding_tensor -= lr * embedding_tensor.grad
        embedding_tensor.grad = None
        loss_so_far += loss.detach().item()

        if epoch % 10 == 0:
            print("==== Epoch: ",epoch, " loss: ",round(loss.detach().item(),3)," | running avg loss: ", round(loss_so_far/(epoch+1),3))
            if epoch % 90 == 0:
                similar_images = return_cosine_sorted_image(embedding_tensor.detach().numpy().squeeze())
                print(similar_images[:2])
        

==== Epoch:  0  loss:  17.016  | running avg loss:  17.016
[['3108544687_c7115823f5.jpg' '0.14893005788326263']
 ['1806580620_a8fe0fb9f8.jpg' '0.1461697816848755']]
==== Epoch:  10  loss:  14.165  | running avg loss:  15.294
==== Epoch:  20  loss:  12.042  | running avg loss:  14.169
==== Epoch:  30  loss:  10.97  | running avg loss:  13.261
==== Epoch:  40  loss:  10.513  | running avg loss:  12.638
==== Epoch:  50  loss:  10.026  | running avg loss:  12.171
==== Epoch:  60  loss:  9.548  | running avg loss:  11.775
==== Epoch:  70  loss:  9.169  | running avg loss:  11.431
==== Epoch:  80  loss:  8.814  | running avg loss:  11.128
==== Epoch:  90  loss:  8.497  | running avg loss:  10.853
[['3527715826_ea5b4e8de4.jpg' '0.16713720560073853']
 ['3208032657_27b9d6c4f3.jpg' '0.16551019251346588']]


In [91]:
plot_image('data/3527715826_ea5b4e8de4.jpg')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …