# Taller _Representation Learning_

## Entrenando la red

### NOTA: Para que funcione el código hay que descargar el dataset.

Para descargar el Flickr8K dataset:
[https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip](https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip).
Si ese link ya no funciona, hay que seguir el proceso y llenar el formulario [aquí](https://forms.illinois.edu/sec/1713398).

- Extraer el ZIP en el directorio `data`
- Además, hay que descargar los _captions_ del dataset [aquí](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip). Extraer en `caption_datasets`.

### Imports

In [1]:
import matplotlib.pyplot as plt
import random
import json

In [2]:
%matplotlib widget

In [3]:
from scipy import ndimage
import numpy as np
from copy import deepcopy
from PIL import Image
import IPython.display
from math import floor
import string

In [4]:
import torch
import torch.nn as nn                     # neural networks
import torch.nn.functional as F           # layers, activations and more
import torch.optim as optim  
import torchvision.transforms.functional as TF

In [5]:
import torchvision
from torchvision import datasets, models, transforms

In [6]:
# Check whether a CUDA GPU is available; the bare name below displays the
# result as the cell output.
is_cuda = torch.cuda.is_available()
is_cuda

False

In [7]:
# Train on the GPU exactly when CUDA was detected in the previous cell.
USE_GPU = bool(is_cuda)

### Parámetros

In [8]:
from classes import INCEPTION as inception
from classes import \
    ENDWORD, STARTWORD, PADWORD, HEIGHT, WIDTH, \
    INPUT_EMBEDDING, HIDDEN_SIZE, OUTPUT_EMBEDDING, \
    CAPTION_FILE, IMAGE_DIR

> Original arch:
Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.

## Clase para iterar en los datos

In [9]:
from classes import Flickr8KImageCaptionDataset

## Clase de la red

In [10]:
from classes import IC_V6

In [11]:
# Flickr8K image/caption dataset wrapper. The cells below use its vocabulary
# (tokens, w2i), its training_data, return_train_batch() and caption_image_greedy().
f = Flickr8KImageCaptionDataset()

In [12]:
# Image-captioning network IC_V6, built from the dataset's token list
# (presumably to size its vocabulary — TODO confirm in classes.IC_V6).
net = IC_V6(f.tokens)

In [13]:
# Load the pretrained captioning weights. map_location='cpu' lets a checkpoint
# that was saved from a GPU load on a CPU-only machine; the net is moved to the
# GPU afterwards (the `if USE_GPU:` cell) when a GPU is available.
net.load_state_dict(torch.load('models/epochs_40_loss_2_841_v6.pth', map_location='cpu'))

<All keys matched successfully>

In [14]:
# Move both the captioning net and the Inception model to the GPU when one is used.
if USE_GPU:
    for module in (net, inception):
        module.cuda()

## Entrenar la red

In [15]:
# Cross-entropy with reduction='none': each call returns one loss value per
# target; the training loop sums these over the caption and averages itself.
l = nn.CrossEntropyLoss(reduction='none')

In [16]:
# Adam optimizer over net.parameters() with a learning rate of 1e-4.
o = optim.Adam(params=net.parameters(), lr=1e-4)

In [17]:
# Fine-tune the captioning net for `epochs` more epochs, one caption at a time,
# using teacher forcing (the ground-truth token is fed back as the next input).
epochs = 20
# Inception is kept in eval mode while the captioning net is trained.
inception.eval()
net.train()
# Tensors created in this loop must live on the same device as the net's
# weights (the net is moved to the GPU in an earlier cell when USE_GPU is True).
device = next(net.parameters()).device
start_token = f.w2i[STARTWORD]
loss_so_far = 0.0
samples_seen = 0  # samples trained on across all epochs (for the running average)
total_samples = len(f.training_data)
for epoch in range(epochs):
    for (image_tensor, tokens, _, index) in f.return_train_batch():
        # An empty caption would give 0/0 and a plain float with no .backward().
        if len(tokens) == 0:
            continue
        # o was built from net.parameters(), so this also clears net's gradients.
        o.zero_grad()
        loss = 0.
        input_tensor = torch.tensor(start_token, device=device)
        for step, token in enumerate(tokens):
            if step == 0:
                # First step: STARTWORD together with the image initialises the decoder.
                out, hidden = net(input_tensor, image_tensor, process_image=True)
            else:
                out, hidden = net(input_tensor, hidden)
            class_label = torch.tensor(token, device=device).view(1)
            # Teacher forcing: the current ground-truth token is the next input.
            input_tensor = torch.tensor(token, device=device)

            out = out.squeeze().view(1, -1)
            loss += l(out, class_label)

        # Average the per-token losses over the caption length.
        loss = loss / len(tokens)
        loss.backward()
        o.step()
        loss_value = loss.detach().item()
        loss_so_far += loss_value
        samples_seen += 1

        if np.random.rand() < 0.002:  # report on ~0.2% of samples
            print("Epoch: ", epoch, ", index: ", index,
                  " loss: ", round(loss_value, 3),
                  " | running avg loss: ", round(loss_so_far / samples_seen, 3))

            torch.save(net.state_dict(), 'models/running_save.pth')
            torch.save(inception.state_dict(), 'models/running_inception_save.pth')

            # Show the ground-truth captions and the current greedy caption for
            # the sample just trained on.
            net.eval()
            train_filename = IMAGE_DIR + f.training_data[index]['filename']
            print("> Original caption: ")
            for sentence in f.training_data[index]['sentences']:
                print(sentence['raw'].lower())
            print("")
            with torch.no_grad():  # decoding only; no gradients needed
                print("> Current caption:", f.caption_image_greedy(net, train_filename))
            print("---")
            # To also display the image:
            # pil_im = Image.open(train_filename, 'r')
            # plt.figure()
            # plt.imshow(np.asarray(pil_im))
            # plt.show()
            net.train()

    print("\n\n")
    print("==== EPOCH DONE. === ")
    print("\n\n")

Epoch:  0 , index:  924  loss:  3.379  | running avg loss:  3.636
> Original caption: 
a baby is laughing .
a baby laughs at his reflection in a mirror .
a blonde haired toddler in a burgundy hoodie smiling .
a boy in a red hooded top is smiling whilst looking away from his reflection .
a child in front of his own reflection turning towards the camera and smiling .

> Current caption: a man and a woman are sitting on a bench
---
Epoch:  0 , index:  1029  loss:  3.923  | running avg loss:  3.636
> Original caption: 
a man leads two cows down the dirt shoulder of a paved road .
a man walking his cow down the side of the road .
a man walks down the road leading a cow with no rider and another cow with a rider .
a man walks with a cow down a dirt sidewalk .
two men pull two cows down the road .

> Current caption: a man in a black shirt and jeans is sitting on a bench
---
Epoch:  0 , index:  1170  loss:  2.19  | running avg loss:  3.628
> Original caption: 
two brown dog grapple each other

KeyboardInterrupt: 

## Guardar la red

In [19]:
import os

# Save the fine-tuned captioning net and the Inception model weights.
# NOTE(review): these file names are hard-coded and do not necessarily reflect
# the epochs/loss of the current run — consider renaming.
os.makedirs('models_new', exist_ok=True)  # torch.save does not create directories
torch.save(net.state_dict(), 'models_new/epochs_40_loss_2_841.pth')
torch.save(inception.state_dict(), 'models_new/inception_epochs_40.pth')