In [1]:
import os
import time
from itertools import chain

import pandas as pd
import numpy as np

from nltk.tokenize import word_tokenize
from sklearn.model_selection import train_test_split
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as nnF
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

import torchvision.models as models
from torchvision.transforms import functional as F

In [2]:
def data_text_prep(data_dir="../data/text_clean"):
    """Load the cleaned train/test titles, tokenize them and build the vocabulary.

    Parameters
    ----------
    data_dir : str
        Directory holding ``train.csv`` and ``test.csv``; both need
        ``title_1_pre`` and ``title_2_pre`` text columns.

    Returns
    -------
    d_train, d_test : pd.DataFrame
        The loaded frames with list-of-token columns ``title_1_token`` and
        ``title_2_token`` added.
    word2idx : dict[str, int]
    idx2word : dict[int, str]
        Vocabulary built from the training titles only. ``'<PAD>'`` is 0,
        ``'<UNK>'`` is 1 and real tokens start at 2, in sorted order.
    """
    d_train = pd.read_csv(os.path.join(data_dir, "train.csv"))
    d_test = pd.read_csv(os.path.join(data_dir, "test.csv"))

    # tokenize
    for frame in (d_train, d_test):
        for col in ("title_1", "title_2"):
            # read_csv turns empty titles into NaN, which word_tokenize cannot handle.
            frame[f"{col}_token"] = frame[f"{col}_pre"].fillna("").astype(str).apply(word_tokenize)

    train_tokens = chain.from_iterable(chain(d_train.title_1_token, d_train.title_2_token))
    # sorted() keeps token ids stable across kernel restarts; iteration order of a
    # set of str depends on PYTHONHASHSEED, which would silently remap embeddings.
    vocab_token = sorted(set(train_tokens))

    word2idx = {w: k for k, w in enumerate(vocab_token, 2)}
    idx2word = {k: w for k, w in enumerate(vocab_token, 2)}

    # Reserved ids: 0 for padding (see ShopeeDataset.set_fix_length), 1 for unknown words.
    word2idx['<UNK>'] = 1
    idx2word[1] = '<UNK>'
    word2idx['<PAD>'] = 0
    idx2word[0] = '<PAD>'

    return d_train, d_test, word2idx, idx2word

In [3]:
class ShopeeDataset(torch.utils.data.Dataset):
    """Map-style dataset of (title, image) pairs for the Shopee matching task.

    ``data`` is split into ``'train'`` / ``'val'`` with ``train_test_split``;
    ``test`` is kept as its own split. Call :meth:`set_split` to choose which
    split ``__getitem__`` / ``__len__`` read from.

    Each item is ``(t1_ids, t2_ids, img1, img2, label)``. The id tensors are
    LongTensors of length ``max_len`` (0-padded or truncated). The images are
    ``to_tensor`` outputs of 224x224 RGB images.
    """

    # Class-level defaults; __init__ overrides them per instance.
    max_len = 25
    image_dir = "../data/raw/training_img/training_img"

    def __init__(self, data, test, word2idx, idx2word,
                 image_dir="../data/raw/training_img/training_img",
                 max_len=25, random_state=127):
        """
        Parameters
        ----------
        data : pd.DataFrame
            Labelled frame split into train/val. Read columns: ``title_1_token``,
            ``title_2_token``, ``image_1``, ``image_2``, ``Label``.
        test : pd.DataFrame
            Frame served by the ``'test'`` split (its index is used as-is).
        word2idx, idx2word : dict
            Vocabulary maps. Unknown words encode to 1; unknown ids decode to ``'<UNK>'``.
        image_dir : str
            Directory that ``image_1`` / ``image_2`` file names are joined to.
        max_len : int
            Fixed token length produced by :meth:`set_fix_length`.
        random_state : int
            Seed for the train/val split.
        """
        train, val = train_test_split(data, random_state=random_state)
        self.word2idx = word2idx
        self.idx2word = idx2word
        self.image_dir = image_dir
        self.max_len = max_len
        # Reset to a 0..n-1 index so __getitem__'s .loc[idx] matches DataLoader indices.
        frames = {
            'train': train.reset_index(drop=True),
            'val': val.reset_index(drop=True),
            'test': test,
        }
        self.dataset = {split: (frame, frame.shape[0]) for split, frame in frames.items()}
        self.set_split('train')

    def set_split(self, split='train'):
        """Select which split ('train', 'val' or 'test') subsequent indexing reads from."""
        self.data, self.length = self.dataset[split]

    def encode(self, text):
        """Map a list of tokens to a LongTensor of ids; unknown tokens become 1 (<UNK>)."""
        return torch.tensor([self.word2idx.get(word, 1) for word in text], dtype=torch.long)

    def decode(self, ids):
        """Map ids (ints or the tensor returned by :meth:`encode`) back to tokens.

        Ids missing from the vocabulary decode to ``'<UNK>'``.
        """
        # int() is required: elements of a tensor are 0-d tensors, which hash by
        # identity and would therefore never match a dict key.
        return [self.idx2word.get(int(id_), '<UNK>') for id_ in ids]

    def set_fix_length(self, ids, max_len=None):
        """Right-pad with 0 (<PAD>) or truncate ``ids`` to exactly ``max_len`` entries.

        ``max_len`` defaults to the instance's ``max_len`` (25 unless overridden).
        """
        if max_len is None:
            max_len = self.max_len
        fixed = torch.zeros(max_len, dtype=torch.long)
        keep = min(ids.shape[0], max_len)
        fixed[:keep] = ids[:keep]
        return fixed

    def read_image(self, path):
        """Load an image file as a 3x224x224 float tensor (RGB)."""
        # `with` closes the underlying file. Converting to RGB before resizing matters
        # because Pillow always uses NEAREST resampling for mode '1'/'P' images.
        with Image.open(path) as raw:
            rgb = raw.convert('RGB').resize((224, 224))
        return F.to_tensor(rgb)

    def __getitem__(self, idx):
        """Return ``(t1_ids, t2_ids, img1, img2, label)`` for row ``idx`` of the active split."""
        t1_encode = self.set_fix_length(self.encode(self.data.loc[idx, 'title_1_token']))
        t2_encode = self.set_fix_length(self.encode(self.data.loc[idx, 'title_2_token']))

        i1_scaled = self.read_image(os.path.join(self.image_dir, self.data.loc[idx, 'image_1']))
        i2_scaled = self.read_image(os.path.join(self.image_dir, self.data.loc[idx, 'image_2']))

        # NOTE(review): the 'test' split uses the same image_dir and also expects a
        # 'Label' column — confirm both hold for the test CSV.
        label = self.data.loc[idx, 'Label']

        return t1_encode, t2_encode, i1_scaled, i2_scaled, label

    def __len__(self):
        """Number of rows in the active split."""
        return self.length

In [4]:
class TextEncoder(nn.Module):
    """Embed token ids and run them through an LSTM.

    Input: (batch, seq_len) LongTensor of token ids (the LSTM is batch_first).
    Output: (batch, 1, seq_len, hid_size) — the per-step LSTM outputs with a
    singleton channel axis inserted at dim 1.
    """

    def __init__(self, num_vocab, emb_size=512, hid_size=256, num_layers=1):
        super().__init__()
        embedding = nn.Embedding(num_vocab, emb_size)
        recurrent = nn.LSTM(emb_size, hid_size, num_layers=num_layers, batch_first=True)
        # Kept as a Sequential called `network` so state_dict keys stay network.0.* / network.1.*
        self.network = nn.Sequential(embedding, recurrent)

    def forward(self, input_):
        # nn.LSTM yields (output, (h_n, c_n)); only the per-step output is used.
        step_outputs, _final_state = self.network(input_)
        return step_outputs.unsqueeze(1)

In [5]:
class ImageEncoder(nn.Module):
    """MobileNetV2 feature extractor followed by a 1280 -> out_channels convolution.

    forward returns (batch, 1, H'*W', out_channels). Each spatial position of the
    final feature map becomes one "row", laid out like the per-token rows of
    TextEncoder's output.
    """

    def __init__(self, out_channels=256, kernel_size=(3,3)):
        super().__init__()
        # NOTE(review): the whole mobilenet_v2() module is registered as `self.mobilenet`,
        # but only `.features` is used in forward. Any other submodules (presumably the
        # classifier head) are counted in parameters() without being trained. It is also
        # built with no arguments — confirm whether pretrained weights were intended.
        self.mobilenet = models.mobilenet_v2()
        self.backbone = self.mobilenet.features
        projection = nn.Conv2d(in_channels=1280, out_channels=out_channels, kernel_size=kernel_size)
        self.model = nn.Sequential(self.backbone, projection)

    def forward(self, input_):
        feature_map = self.model(input_)  # (B, C, H', W')
        # (B, C, H', W') -> (B, C, H'*W') -> (B, H'*W', C) -> (B, 1, H'*W', C)
        return feature_map.flatten(start_dim=2).transpose(1, 2).unsqueeze(1)

In [6]:
class BaseNetwork(nn.Module):
    """Three (Conv2d -> MaxPool2d) stages, each with a single output channel.

    With the default kernels, a (B, 1, 25, 512) input is reduced to (B, 1, 1, 55).
    forward then drops the two singleton axes and returns (B, 55).
    """

    def __init__(self, in_channel, kernel_size_cnn=(3,11), kernel_size_max_pool=2):
        super().__init__()
        stages = []
        # Only the first conv sees `in_channel` inputs; later ones see the 1-channel output.
        for stage_in_channels in (in_channel, 1, 1):
            stages.append(nn.Conv2d(in_channels=stage_in_channels, out_channels=1,
                                    kernel_size=kernel_size_cnn))
            stages.append(nn.MaxPool2d(kernel_size=kernel_size_max_pool))
        self.base_network = nn.Sequential(*stages)

    def forward(self, input_):
        reduced = self.base_network(input_)
        # Drop the channel axis, then the (now leading) height axis.
        return reduced.squeeze(1).squeeze(1)

In [7]:
class WrapperModel(nn.Module):
    """Siamese text+image matcher.

    Both items of a pair go through the same TextEncoder, ImageEncoder and
    BaseNetwork. Their vectors are compared with a Euclidean distance, which is
    fed to a 1 -> 1 linear layer.
    """

    def __init__(self, num_vocab=None):
        """
        Parameters
        ----------
        num_vocab : int, optional
            Embedding table size. Defaults to ``len(word2idx)`` from the notebook
            global vocabulary, which is what this class always used before.
        """
        super(WrapperModel, self).__init__()
        if num_vocab is None:
            # Backward-compatible fallback; pass num_vocab explicitly to avoid depending
            # on a global defined in another cell.
            num_vocab = len(word2idx)
        self.model_text = TextEncoder(num_vocab=num_vocab)
        self.model_image = ImageEncoder()
        self.model_base = BaseNetwork(in_channel=1)
        self.fc = nn.Linear(1, 1)

    def forward(self, t1_encode, t2_encode, i1_scaled, i2_scaled):
        """Return ``(logit, distance)``, both of shape (batch, 1).

        ``logit`` is ``fc(distance)``; apply a sigmoid to get a probability.
        """
        # The same submodules encode both sides, so the two branches share weights.
        feat_t1 = self.model_text(t1_encode)
        feat_t2 = self.model_text(t2_encode)

        feat_i1 = self.model_image(i1_scaled)
        feat_i2 = self.model_image(i2_scaled)

        # Join text and image features along the last axis. torch.cat requires dims 0-2
        # to agree, i.e. the text length must equal the number of image positions
        # (presumably 25 each with 224x224 inputs — TODO confirm).
        concat_1 = torch.cat((feat_t1, feat_i1), dim=3)
        concat_2 = torch.cat((feat_t2, feat_i2), dim=3)

        vec_1 = self.model_base(concat_1)
        vec_2 = self.model_base(concat_2)

        ed = euclidean_distance(vec_1, vec_2)

        out = self.fc(ed)

        return out, ed

In [8]:
def euclidean_distance(vec_1, vec_2, eps=1e-12):
    """Row-wise Euclidean distance between two (N, D) batches.

    Parameters
    ----------
    vec_1, vec_2 : torch.Tensor
        Tensors of the same shape; the distance is taken over dim 1.
    eps : float
        Lower bound applied to the squared distance before the square root.

    Returns
    -------
    torch.Tensor
        Distances of shape (N, 1).
    """
    sq_sum = torch.sum(torch.pow(vec_1 - vec_2, 2), dim=1)
    # d/dx sqrt(x) is infinite at x = 0, so identical pairs would otherwise produce NaN
    # gradients. Clamping leaves any squared distance >= eps unchanged.
    ed = torch.sqrt(torch.clamp(sq_sum, min=eps))
    return ed.reshape(-1, 1)

In [9]:
def cont_loss(label, distance, margin=0.5):
    """Contrastive loss averaged over the batch.

    Parameters
    ----------
    label : torch.Tensor or array-like
        1 for pairs that should be close (penalised by distance**2), 0 for pairs
        that should be at least ``margin`` apart (penalised by (margin - distance)**2
        while closer than the margin). May be (B,) or already shaped like ``distance``.
    distance : torch.Tensor
        Pair distances, e.g. the (B, 1) output of ``euclidean_distance``.
    margin : float
        Separation required between dissimilar pairs.

    Returns
    -------
    torch.Tensor
        Scalar mean loss.
    """
    label = torch.as_tensor(label, dtype=distance.dtype, device=distance.device)
    if label.numel() == distance.numel():
        # Align (B,) labels with (B, 1) distances. Otherwise broadcasting silently builds
        # a (B, B) matrix that pairs every label with every distance.
        label = label.reshape(distance.shape)
    similar_term = label * torch.pow(distance, 2)
    dissimilar_term = (1 - label) * torch.pow(torch.clamp(margin - distance, min=0), 2)
    return torch.mean(similar_term + dissimilar_term)

In [10]:
def compute_accuracy(y_true, y_pred, threshold=0.5):
    """Percentage of samples whose thresholded prediction equals the label.

    Parameters
    ----------
    y_true : torch.Tensor
        Integer labels, e.g. shape (B,).
    y_pred : torch.Tensor
        Probabilities with the same number of elements as ``y_true`` (e.g. (B, 1)).
    threshold : float
        Predictions strictly above this count as class 1.

    Returns
    -------
    float
        Accuracy in percent; 0.0 for an empty batch.
    """
    n_total = y_true.numel()
    if n_total == 0:
        return 0.0
    # reshape (not squeeze) so a batch of 1 keeps its batch dimension.
    pred_labels = (y_pred > threshold).long().reshape(y_true.shape)
    n_correct = torch.eq(y_true, pred_labels).sum().item()
    return (n_correct / n_total) * 100

In [11]:
train, test, word2idx, idx2word = data_text_prep()

In [12]:
dataset = ShopeeDataset(train, test, word2idx, idx2word)
model = WrapperModel()

In [13]:
num_params = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {num_params:,}")

Trainable params: 11,582,544


In [14]:
# AdamW over every model parameter.
LEARNING_RATE = 1e-3
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# NOTE(review): cont_loss depends only on `distance`, so `model.fc` never receives a
# gradient and the accuracy below (sigmoid(fc(distance)) > 0.5) is computed with an
# untrained layer. Consider thresholding `distance` directly instead.
NUM_EPOCHS = 100
MARGIN = 0.5
# Set to a small int (e.g. 1) to smoke-test only a few batches per epoch.
DEBUG_MAX_BATCHES = None

for epoch in range(1, NUM_EPOCHS + 1):

    running_loss = 0
    running_loss_v = 0
    running_acc = 0
    running_acc_v = 0

    start = time.time()

    # ---- train ----
    model.train()
    dataset.set_split('train')
    # Reshuffle each epoch; without shuffle=True every epoch sees the same batch order.
    data_gen = DataLoader(dataset, batch_size=32, shuffle=True)
    for batch_index, (t1_encode, t2_encode, i1_scaled, i2_scaled, label) in enumerate(data_gen, 1):
        optimizer.zero_grad()

        y_pred, distance = model(t1_encode, t2_encode, i1_scaled, i2_scaled)
        y_pred = torch.sigmoid(y_pred)
        # `label` is (B,) while `distance` is (B, 1); align them so the loss stays
        # element-wise instead of broadcasting to a (B, B) matrix.
        target = label.to(distance.dtype).view_as(distance)
        loss = cont_loss(target, distance, margin=MARGIN)
        running_loss += (loss.item() - running_loss) / batch_index

        accuracy = compute_accuracy(label, y_pred)
        running_acc += (accuracy - running_acc) / batch_index

        loss.backward()

        optimizer.step()

        if DEBUG_MAX_BATCHES is not None and batch_index >= DEBUG_MAX_BATCHES:
            break

    # ---- validate ----
    model.eval()
    dataset.set_split('val')
    data_gen = DataLoader(dataset, batch_size=32)
    # Evaluation needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        for batch_index, (t1_encode, t2_encode, i1_scaled, i2_scaled, label) in enumerate(data_gen, 1):

            y_pred, distance = model(t1_encode, t2_encode, i1_scaled, i2_scaled)
            y_pred = torch.sigmoid(y_pred)

            target = label.to(distance.dtype).view_as(distance)
            loss = cont_loss(target, distance, margin=MARGIN)
            running_loss_v += (loss.item() - running_loss_v) / batch_index

            accuracy = compute_accuracy(label, y_pred)
            running_acc_v += (accuracy - running_acc_v) / batch_index

            if DEBUG_MAX_BATCHES is not None and batch_index >= DEBUG_MAX_BATCHES:
                break

    duration = time.time() - start
    print(f"epoch: {epoch} | time: {duration:.1f}s")
    print(f"\ttrain loss: {running_loss:.2f} | train accuracy: {running_acc:.2f}")
    print(f"\tval loss: {running_loss_v:.2f} | val accuracy: {running_acc_v:.2f}")

In [None]:
dataset.data.loc[12, :]

In [None]:
img = dataset.read_image("../data/raw/training_img/training_img/b5fccecda25cde1a5e24e8d509e342a7.jpg")

In [None]:
img.shape

In [None]:
img = Image.open("../data/raw/training_img/training_img/b5fccecda25cde1a5e24e8d509e342a7.jpg")

In [None]:
np.array(img).shape

In [None]:
# NOTE(review): `label` is a loop variable leaked from the training-loop cell (its last
# batch); it is undefined on a fresh kernel until that cell has run.
label

In [None]:
# NOTE(review): same as above — leftover from the training-loop cell.
y_pred

In [None]:
# NOTE(review): `embedding_weight` is not defined anywhere in this notebook, so this cell
# and the ones below raise NameError on Restart & Run All. Presumably a list of
# embedding-weight snapshots collected during training — confirm and add the code that
# builds it.
len(embedding_weight)

In [None]:
embedding_weight[1].data

In [None]:
embedding_weight[2].data

In [None]:
for i in range(len(embedding_weight) - 1):
    # Despite its name, `change` counts elements that are EQUAL between consecutive
    # entries, i.e. how many values did not change.
    change = torch.eq(embedding_weight[i].data, embedding_weight[i+1].data).sum().item()
    print(change)