In [None]:
# for google colaboratory
from os import path
if 'google.colab' in str(get_ipython()):
    files = ['LLD-icon-sharp.hdf5', 'embed_train.zip']
    !pip install wget
    !pip install transformers
    import wget
    import shutil
    from google.colab import drive
    drive.mount('/content/drive')
    !mkdir Data
    for f in files:
        if not path.isfile('Data/' + f):
            if path.isfile('/content/drive/My Drive/Colab/AFRO/' + f):
                shutil.copy('/content/drive/My Drive/Colab/AFRO/' + f, 'Data')
            else:
                wget.download('https://data.vision.ee.ethz.ch/sagea/lld/data/' + f, 'Data')
    !unzip -q -n Data/embed_train.zip

In [None]:
# for debugging ------
from importlib import reload
import utils
utils = reload(utils)
from utils import lemmatize_and_clearing
# -------------------

import h5py
import numpy as np
import pandas as pd
from tqdm import notebook
import matplotlib.pyplot as plt
%matplotlib inline

import nltk

import torch
import torch.nn as nn
from IPython.display import clear_output
from torch.utils.data import DataLoader

# Run on the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Open data

## icon-sharp

In [None]:
# Open the icon dataset read-only ('r').
# NOTE: this rebinds `files`, which the Colab setup cell uses for its list of
# file names.
files = h5py.File('Data/LLD-icon-sharp.hdf5', 'r')

In [None]:
# Inspect the top-level entries of the opened HDF5 file
files.keys()

In [None]:
# `clusters` and `images` stay as the objects returned by indexing the open
# file (presumably h5py datasets that are read on demand -- confirm).
# `names` is read out in full with `[()]` and converted to str.
clusters = files['labels/resnet/rc_128']
names = files['meta_data/names'][()].astype(str)
images = files['data']

# Work with labels

In [None]:
# Fetch the NLTK 'stopwords' and 'wordnet' resources
# (presumably required by utils.lemmatize_and_clearing -- TODO confirm)
nltk.download('stopwords')
nltk.download('wordnet')

In [None]:
# `map` is lazy: no name is processed until `lem_names` is iterated, and the
# resulting iterator can be consumed only once.
lem_names =  map(lemmatize_and_clearing, names)

## Download BERT files and embed all the names with BERT

In [None]:
# Disabled cell: tokenizes the lemmatized names, runs them through
# bert-base-uncased in batches and hands the result to embed_and_write_file
# together with the file name 'embed_train.csv'.
# NOTE(review): BertTokenizer, BertModel and embed_and_write_file are not
# imported in the visible cells; re-enabling this code needs those imports.

# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# tok_names = tokenizer(list(lem_names), padding=True, return_tensors="pt")

# model = BertModel.from_pretrained('bert-base-uncased', return_dict=True)
# model = model.to(device)

# batch_size = 10
# dataloader = DataLoader(tok_names['input_ids'], batch_size=batch_size)

# embed_and_write_file(dataloader, model, device, 'embed_train.csv')

In [None]:
# Load the embedding table (presumably one embedding per row -- confirm).
# header=None: the first CSV row is treated as data, not as column names.
df = pd.read_csv('Data/embed_train.csv', header=None)

In [None]:
class Generator(nn.Module):
    """Small convolutional network mapping a 3-channel map to a 3-channel image.

    The single transposed convolution (kernel 4, stride 2, padding 1) doubles
    the spatial size; every other layer preserves it. The closing sigmoid
    keeps the output values in [0, 1].
    """

    def __init__(self):
        super().__init__()

        width = 3  # base channel count; the last convolution outputs `width` channels

        # 3 -> 4*width channels, spatial size unchanged
        stem = [
            nn.Conv2d(in_channels=3, out_channels=4 * width,
                      kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(4 * width),
        ]
        # 4*width -> 3*width channels, spatial size doubled
        upsample = [
            nn.ConvTranspose2d(4 * width, 3 * width, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(3 * width),
        ]
        # 3*width -> 2*width channels, spatial size unchanged
        refine = [
            nn.Conv2d(3 * width, 2 * width, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(2 * width),
        ]
        # 2*width -> width channels, squashed into [0, 1]
        head = [
            nn.Conv2d(2 * width, width, 3, 1, 1),
            nn.Sigmoid(),
        ]

        self.features_to_image = nn.Sequential(*stem, *upsample, *refine, *head)

    def forward(self, input_data):
        """Run `input_data` through the layer stack and return the result."""
        return self.features_to_image(input_data)

In [None]:
# Contiguous split of the row indices: first half for training, next quarter
# for validation, the remainder for testing (shuffling is disabled).
# NOTE(review): `tr`/`val` are later used to index `images`; if that is an h5py
# dataset, its index lists must be increasing, so shuffling needs care -- confirm.
n_rows = len(df)
ix = np.arange(n_rows)
# np.random.shuffle(ix)
train_end = n_rows // 2
val_end = train_end + n_rows // 4
tr, val, ts = ix[:train_end], ix[train_end:val_end], ix[val_end:]
print(len(tr), len(val), len(ts))

In [None]:
# Pair every embedding row with its icon and materialise the pairs in memory.
batch_size = 20
train_pairs = list(zip(df.iloc[tr].values, images[tr]))
val_pairs = list(zip(df.iloc[val].values, images[val]))
# Only the training loader is shuffled.
train_dataloader = DataLoader(train_pairs, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_pairs, batch_size=batch_size)

In [None]:
def train(model, opt, loss_fn, epochs, data_tr, data_val):
    """Train `model` and plot a real-vs-output preview after every epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise. It is called on inputs reshaped to (-1, 3, 16, 16).
    opt : torch.optim.Optimizer
        Optimizer already bound to the parameters of `model`.
    loss_fn : callable
        Called as ``loss_fn(prediction, target)``; must return a scalar tensor.
    epochs : int
        Number of passes over `data_tr`.
    data_tr : sized iterable of (X, Y) batches
        Training batches. X must hold 768 (= 3*16*16) values per sample; Y is
        divided by 255 before the loss (presumably 0-255 images -- confirm).
    data_val : iterable of (X, Y) batches
        Only its first batch is used, for the preview plots.

    Notes
    -----
    Relies on the notebook-level `device`. Returns None.
    """
    # The first validation batch is reused for the preview after every epoch
    X_val, Y_val = next(iter(data_val))
    X_val = torch.as_tensor(X_val, dtype=torch.float, device=device)
    # Never preview more samples than the validation batch contains
    n_preview = min(6, len(Y_val))
    for epoch in range(epochs):
        print('* Epoch %d/%d' % (epoch+1, epochs))
        avg_loss = 0
        model.train()  # train mode
        with notebook.tqdm(total=len(data_tr)) as progress_bar:
            for step, (X_batch, Y_batch) in enumerate(data_tr):
                # data to device
                X_batch = torch.as_tensor(X_batch, dtype=torch.float, device=device)
                Y_batch = Y_batch.to(device) / 255.
                # set parameter gradients to zero
                opt.zero_grad()
                # forward: each flat input row becomes a 3x16x16 map
                Y_pred = model(X_batch.view(-1, 3, 16, 16))
                # PyTorch losses take (input, target): prediction first
                loss = loss_fn(Y_pred, Y_batch)
                avg_loss += loss.item()
                loss.backward()  # backward-pass
                opt.step()  # update weights
                progress_bar.set_description('loss: %f' % (avg_loss/(step+1)))
                progress_bar.update()

        # mean loss over the epoch; the guard avoids dividing by zero on an empty loader
        avg_loss = avg_loss / max(len(data_tr), 1)

        # show intermediate results
        model.eval()  # testing mode
        with torch.no_grad():  # no autograd graph is needed for the preview
            Y_hat = model(X_val.view(-1, 3, 16, 16)).cpu().numpy()
        # Visualize tools
        clear_output(wait=True)
        for k in range(n_preview):
            # top row: real images, moved from channel-first to channel-last
            plt.subplot(2, 6, k+1)
            plt.imshow(np.moveaxis(Y_val[k].numpy(), 0, 2))
            plt.title('Real')
            plt.axis('off')

            # bottom row: model outputs for the same samples
            plt.subplot(2, 6, k+7)
            plt.imshow(np.moveaxis(Y_hat[k], 0, 2))
            plt.title('Output')
            plt.axis('off')
        plt.suptitle('%d / %d - loss: %f' % (epoch+1, epochs, avg_loss))
        plt.show()

In [None]:
# Build the generator and move it to the selected device in one step.
gen = Generator().to(device)

In [None]:
# Single-epoch run: Adam (lr=1e-3) minimising the MSE between model outputs
# and the scaled target images.
max_epochs = 1
loss_func = nn.MSELoss()
optim = torch.optim.Adam(gen.parameters(), lr=1e-3)
train(
    model=gen,
    opt=optim,
    loss_fn=loss_func,
    epochs=max_epochs,
    data_tr=train_dataloader,
    data_val=val_dataloader,
)

# Create text and add it to an icon

In [None]:
# custom module ---
import utils
utils = reload(utils)
from utils import add_text_to_img
# ------------------

In [None]:
from skimage.transform import resize
# Resize a single icon (index 2) to 128x128 after moving its first axis last.
size = (128, 128)
icon_last_axis = np.moveaxis(images[2], 0, -1)
img = resize(icon_last_axis, size, mode='constant', anti_aliasing=True)


In [None]:
# Display the resized icon
plt.imshow(img)

In [None]:
# Shape of the resized icon (the first two entries should match `size`)
img.shape

In [None]:
# Display the result of utils.add_text_to_img for the text 'Hello'
# (presumably the icon with the text drawn onto it -- confirm in utils)
plt.imshow(add_text_to_img('Hello', img))
