In [195]:
%load_ext autoreload
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

import sys
sys.path.append("../src")

import numpy as np
from PIL import Image
from skimage import io
import torch
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader

from dataset.birds_dataset import BirdsDataset
from models.discriminator import Discriminator
from models.text_encoder import TextEncoder

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")

In [4]:
device

device(type='cuda')

### Fine Tuning

### Implementing the tips from the How To Train a GAN video:
https://www.youtube.com/watch?v=myGAju4L7O8&t=482s

In [7]:
dataset_path = "../test/Example_Dataset/"
base_image_size = 64

In [8]:
def compose_image_transforms(base_image_size):
    """Returns composed image transforms for using on PIL images"""
    resize_factor_for_cropping = 76 / 64 # TODO Understand why they hardcodes this value
    new_size = tuple(2*[int(base_image_size * resize_factor_for_cropping)])
    image_transforms = transforms.Compose([transforms.Resize(new_size),
                                           transforms.RandomCrop(base_image_size),
                                           transforms.RandomHorizontalFlip()
                                           ])
    return image_transforms

#### 1  - Normalized Inputs

TODO: Add to unittests

In [9]:
train_dataset = BirdsDataset(dataset_path, split='train', image_transform=compose_image_transforms(base_image_size),
                                   base_image_size=base_image_size, number_images=3,
                                   text_transform=None, max_caption_length=18)

In [98]:
train_dataset.draw_random_sample().images[0].min()

tensor(-0.8980)

In [99]:
train_dataset.draw_random_sample().images[0].max()

tensor(0.9137)

#### 5 - Avoiding spart gradients by switching frm upsample to ConvTranspose

In [88]:
ct = torch.nn.ConvTranspose2d(3, 3, 3, stride=2, padding=1, output_padding=1)

In [89]:
us = torch.nn.Upsample(scale_factor=2)

In [90]:
img = train_dataset[0].images[0].unsqueeze(0)

In [101]:
ct(img).shape

torch.Size([1, 3, 128, 128])

In [102]:
us(img).shape

torch.Size([1, 3, 128, 128])

#### 6 - Label Smoothing

In [105]:
batch_size = 4

In [107]:
real_labels = (torch.ones(batch_size) * 1)
fake_labels = (torch.ones(batch_size) * 0)

In [111]:
def prepare_label_smoothing(real_labels, fake_labels, smooth=0.1):
    batch_size = real_labels.shape[0]
    smoothed_real = real_labels - smooth * torch.rand(batch_size)
    smoothed_fake = fake_labels + smooth * torch.rand(batch_size)
    return smoothed_real, smoothed_fake

In [113]:
prepare_label_smoothing(real_labels, fake_labels)

(tensor([0.9355, 0.9524, 0.9334, 0.9181]),
 tensor([0.0839, 0.0383, 0.0930, 0.0913]))

#### TODO Wrong labels

In [131]:
torch.any(torch.isnan(torch.tensor(np.inf)))

tensor(0, dtype=torch.uint8)

In [134]:
smooth = 0.1

In [152]:
real_labels = torch.ones(100)
fake_labels = torch.zeros(100)

In [150]:
(real_labels - smooth * torch.rand_like(real_labels)).min()

tensor(0.9003)

In [179]:
(fake_labels + smooth * torch.rand_like(real_labels)).max()

tensor(0.0988)

In [232]:
(torch.randn(10) * 0.02).max()
(torch.randn(10) * 0.02).min()

tensor(0.0314)

tensor(-0.0350)

In [2]:
m1 = torch.randn(2, 3, 64, 64)
m2 = torch.randn(2, 3, 64, 64)

In [3]:
mean1 = torch.mean(m1, dim=0)
mean2 = torch.mean(m2, dim=0)
torch.norm(mean1 - mean2, p=2)

tensor(110.5243)

### Do Not Delete: Example of masking difference:
I think they implemented it wrongly, since the data from a each sample in the batch will contaminate the info in all other samples in the batch. 
i.e. the seqlen of a single sample affects all other samples.
This shouldn't happen.
Thus changed.

In [85]:
B = 2
N = 16
SEQ_LEN = 4

In [86]:
S = torch.randn(B, N, SEQ_LEN)

In [87]:
S.shape

torch.Size([2, 16, 4])

In [88]:
S

tensor([[[ 0.8598, -0.5874, -0.6670,  2.3519],
         [-0.7280, -0.1804,  0.7576, -0.5519],
         [-0.0641, -2.0211, -0.5546, -0.0865],
         [-0.0550,  0.2636,  1.4460, -0.6592],
         [-0.3674,  0.2565, -0.3410, -0.6178],
         [ 1.1867,  0.7430,  0.0890,  1.7202],
         [-0.9460,  0.0776, -1.4905, -0.4916],
         [-0.6252,  1.4715,  1.3044,  1.7291],
         [ 0.5857, -0.1119, -0.8255,  1.3710],
         [-0.6686, -0.2314,  0.2326, -1.1778],
         [ 0.9265,  1.7572,  0.1455,  0.3765],
         [-0.6366,  1.2138,  0.4201,  1.3872],
         [ 0.0313,  1.0288, -0.5221, -1.0860],
         [-0.4163,  1.3163, -0.4635, -1.2261],
         [ 0.2558,  0.7133, -0.5025,  0.7142],
         [-0.1134, -0.7476, -0.1653, -0.7838]],

        [[ 0.1064,  0.0189, -0.4552, -1.4402],
         [-1.4196, -0.3377, -1.8147, -0.5775],
         [ 1.1651, -0.8108,  0.9696, -0.9492],
         [ 0.5759, -1.0418, -0.5664, -0.1953],
         [ 1.3160,  0.0754, -1.3854,  0.2612],
         [-

In [89]:
mask = torch.Tensor([[0, 0, 0, 1],
                     [0, 0, 1, 1]])
mask = mask.type(torch.uint8)

In [90]:
assert mask.shape == torch.Size([B, SEQ_LEN])

In [91]:
mask

tensor([[0, 0, 0, 1],
        [0, 0, 1, 1]], dtype=torch.uint8)

***The important part!!!!***

In [92]:
new_mask = mask.unsqueeze(1)
new_S = S.data.masked_fill_(new_mask.data, -float('inf'))

In [93]:
S

tensor([[[ 0.8598, -0.5874, -0.6670,    -inf],
         [-0.7280, -0.1804,  0.7576,    -inf],
         [-0.0641, -2.0211, -0.5546,    -inf],
         [-0.0550,  0.2636,  1.4460,    -inf],
         [-0.3674,  0.2565, -0.3410,    -inf],
         [ 1.1867,  0.7430,  0.0890,    -inf],
         [-0.9460,  0.0776, -1.4905,    -inf],
         [-0.6252,  1.4715,  1.3044,    -inf],
         [ 0.5857, -0.1119, -0.8255,    -inf],
         [-0.6686, -0.2314,  0.2326,    -inf],
         [ 0.9265,  1.7572,  0.1455,    -inf],
         [-0.6366,  1.2138,  0.4201,    -inf],
         [ 0.0313,  1.0288, -0.5221,    -inf],
         [-0.4163,  1.3163, -0.4635,    -inf],
         [ 0.2558,  0.7133, -0.5025,    -inf],
         [-0.1134, -0.7476, -0.1653,    -inf]],

        [[ 0.1064,  0.0189,    -inf,    -inf],
         [-1.4196, -0.3377,    -inf,    -inf],
         [ 1.1651, -0.8108,    -inf,    -inf],
         [ 0.5759, -1.0418,    -inf,    -inf],
         [ 1.3160,  0.0754,    -inf,    -inf],
         [-

In [94]:
new_S

tensor([[[ 0.8598, -0.5874, -0.6670,    -inf],
         [-0.7280, -0.1804,  0.7576,    -inf],
         [-0.0641, -2.0211, -0.5546,    -inf],
         [-0.0550,  0.2636,  1.4460,    -inf],
         [-0.3674,  0.2565, -0.3410,    -inf],
         [ 1.1867,  0.7430,  0.0890,    -inf],
         [-0.9460,  0.0776, -1.4905,    -inf],
         [-0.6252,  1.4715,  1.3044,    -inf],
         [ 0.5857, -0.1119, -0.8255,    -inf],
         [-0.6686, -0.2314,  0.2326,    -inf],
         [ 0.9265,  1.7572,  0.1455,    -inf],
         [-0.6366,  1.2138,  0.4201,    -inf],
         [ 0.0313,  1.0288, -0.5221,    -inf],
         [-0.4163,  1.3163, -0.4635,    -inf],
         [ 0.2558,  0.7133, -0.5025,    -inf],
         [-0.1134, -0.7476, -0.1653,    -inf]],

        [[ 0.1064,  0.0189,    -inf,    -inf],
         [-1.4196, -0.3377,    -inf,    -inf],
         [ 1.1651, -0.8108,    -inf,    -inf],
         [ 0.5759, -1.0418,    -inf,    -inf],
         [ 1.3160,  0.0754,    -inf,    -inf],
         [-

Copy the above code

In [95]:
new_S.shape

torch.Size([2, 16, 4])

In [96]:
new_S

tensor([[[ 0.8598, -0.5874, -0.6670,    -inf],
         [-0.7280, -0.1804,  0.7576,    -inf],
         [-0.0641, -2.0211, -0.5546,    -inf],
         [-0.0550,  0.2636,  1.4460,    -inf],
         [-0.3674,  0.2565, -0.3410,    -inf],
         [ 1.1867,  0.7430,  0.0890,    -inf],
         [-0.9460,  0.0776, -1.4905,    -inf],
         [-0.6252,  1.4715,  1.3044,    -inf],
         [ 0.5857, -0.1119, -0.8255,    -inf],
         [-0.6686, -0.2314,  0.2326,    -inf],
         [ 0.9265,  1.7572,  0.1455,    -inf],
         [-0.6366,  1.2138,  0.4201,    -inf],
         [ 0.0313,  1.0288, -0.5221,    -inf],
         [-0.4163,  1.3163, -0.4635,    -inf],
         [ 0.2558,  0.7133, -0.5025,    -inf],
         [-0.1134, -0.7476, -0.1653,    -inf]],

        [[ 0.1064,  0.0189,    -inf,    -inf],
         [-1.4196, -0.3377,    -inf,    -inf],
         [ 1.1651, -0.8108,    -inf,    -inf],
         [ 0.5759, -1.0418,    -inf,    -inf],
         [ 1.3160,  0.0754,    -inf,    -inf],
         [-

In [97]:
final_S = torch.exp(new_S) / torch.sum(torch.exp(new_S), dim=2, keepdim=True)

In [98]:
final_S

tensor([[[0.6885, 0.1619, 0.1496, 0.0000],
         [0.1399, 0.2419, 0.6181, 0.0000],
         [0.5703, 0.0806, 0.3492, 0.0000],
         [0.1457, 0.2004, 0.6538, 0.0000],
         [0.2569, 0.4794, 0.2637, 0.0000],
         [0.5063, 0.3248, 0.1689, 0.0000],
         [0.2292, 0.6379, 0.1330, 0.0000],
         [0.0624, 0.5079, 0.4297, 0.0000],
         [0.5742, 0.2858, 0.1400, 0.0000],
         [0.1996, 0.3090, 0.4914, 0.0000],
         [0.2665, 0.6115, 0.1220, 0.0000],
         [0.0977, 0.6214, 0.2810, 0.0000],
         [0.2333, 0.6326, 0.1341, 0.0000],
         [0.1314, 0.7432, 0.1254, 0.0000],
         [0.3280, 0.5183, 0.1537, 0.0000],
         [0.4033, 0.2139, 0.3829, 0.0000]],

        [[0.5219, 0.4781, 0.0000, 0.0000],
         [0.2531, 0.7469, 0.0000, 0.0000],
         [0.8782, 0.1218, 0.0000, 0.0000],
         [0.8345, 0.1655, 0.0000, 0.0000],
         [0.7757, 0.2243, 0.0000, 0.0000],
         [0.0197, 0.9803, 0.0000, 0.0000],
         [0.4089, 0.5911, 0.0000, 0.0000],
         

In [99]:
values = torch.randn(B, 128, SEQ_LEN)

In [127]:
final_S.transpose(1,2)

tensor([[[0.6885, 0.1399, 0.5703, 0.1457, 0.2569, 0.5063, 0.2292, 0.0624,
          0.5742, 0.1996, 0.2665, 0.0977, 0.2333, 0.1314, 0.3280, 0.4033],
         [0.1619, 0.2419, 0.0806, 0.2004, 0.4794, 0.3248, 0.6379, 0.5079,
          0.2858, 0.3090, 0.6115, 0.6214, 0.6326, 0.7432, 0.5183, 0.2139],
         [0.1496, 0.6181, 0.3492, 0.6538, 0.2637, 0.1689, 0.1330, 0.4297,
          0.1400, 0.4914, 0.1220, 0.2810, 0.1341, 0.1254, 0.1537, 0.3829],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.5219, 0.2531, 0.8782, 0.8345, 0.7757, 0.0197, 0.4089, 0.8038,
          0.6607, 0.6790, 0.2401, 0.1374, 0.5467, 0.8496, 0.8581, 0.5801],
         [0.4781, 0.7469, 0.1218, 0.1655, 0.2243, 0.9803, 0.5911, 0.1962,
          0.3393, 0.3210, 0.7599, 0.8626, 0.4533, 0.1504, 0.1419, 0.4199],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.00

In [101]:
res = torch.bmm(values, final_S.transpose(1,2))

In [131]:
res.shape

torch.Size([2, 128, 16])

##### What in their code was supposed to happen?

In [119]:
S = torch.randn(B, N, SEQ_LEN)

In [120]:
s_target = S.view(B * N, SEQ_LEN)
mask_target = mask.repeat(N, 1)

In [121]:
mask_target.shape

torch.Size([32, 4])

In [122]:
s_target.shape

torch.Size([32, 4])

In [123]:
s_target = s_target.data.masked_fill_(mask_target.data, -float('inf'))

In [124]:
s_target

tensor([[ 0.0721,  0.3769,  0.8276,    -inf],
        [ 1.1037,  0.2412,    -inf,    -inf],
        [ 0.2906,  0.5414,  0.1116,    -inf],
        [ 0.5539,  0.5204,    -inf,    -inf],
        [ 0.9170, -0.0896, -2.3579,    -inf],
        [-0.2896,  0.1527,    -inf,    -inf],
        [ 1.1611,  1.0140, -1.1554,    -inf],
        [ 0.1728, -0.4008,    -inf,    -inf],
        [ 0.3642, -0.2595,  2.7537,    -inf],
        [ 1.1848, -0.3510,    -inf,    -inf],
        [-0.2630,  0.3359, -1.1433,    -inf],
        [-0.1059,  0.7815,    -inf,    -inf],
        [-0.5517, -1.1202, -0.5394,    -inf],
        [-1.3095, -0.1720,    -inf,    -inf],
        [-0.7143, -0.3530, -2.3358,    -inf],
        [-1.3454, -0.0322,    -inf,    -inf],
        [ 0.8054,  0.6466, -1.2950,    -inf],
        [-0.6130,  0.0538,    -inf,    -inf],
        [-1.4917, -0.0735,  0.0297,    -inf],
        [-0.9626,  0.8355,    -inf,    -inf],
        [-0.5957, -1.2979,  3.6472,    -inf],
        [-0.7035, -2.0725,    -inf

In [125]:
final = s_target.view(B, N, SEQ_LEN)
final

tensor([[[ 0.0721,  0.3769,  0.8276,    -inf],
         [ 1.1037,  0.2412,    -inf,    -inf],
         [ 0.2906,  0.5414,  0.1116,    -inf],
         [ 0.5539,  0.5204,    -inf,    -inf],
         [ 0.9170, -0.0896, -2.3579,    -inf],
         [-0.2896,  0.1527,    -inf,    -inf],
         [ 1.1611,  1.0140, -1.1554,    -inf],
         [ 0.1728, -0.4008,    -inf,    -inf],
         [ 0.3642, -0.2595,  2.7537,    -inf],
         [ 1.1848, -0.3510,    -inf,    -inf],
         [-0.2630,  0.3359, -1.1433,    -inf],
         [-0.1059,  0.7815,    -inf,    -inf],
         [-0.5517, -1.1202, -0.5394,    -inf],
         [-1.3095, -0.1720,    -inf,    -inf],
         [-0.7143, -0.3530, -2.3358,    -inf],
         [-1.3454, -0.0322,    -inf,    -inf]],

        [[ 0.8054,  0.6466, -1.2950,    -inf],
         [-0.6130,  0.0538,    -inf,    -inf],
         [-1.4917, -0.0735,  0.0297,    -inf],
         [-0.9626,  0.8355,    -inf,    -inf],
         [-0.5957, -1.2979,  3.6472,    -inf],
         [-