In [1]:
import numpy as np
import pandas as pd
import os
import torch

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import (
    WidePreprocessor,
    TabPreprocessor,
    TextPreprocessor,
    ImagePreprocessor,
)
from pytorch_widedeep.models import (
    Wide,
    TabMlp,
    Vision,
    BasicRNN,
    WideDeep,
)
from pytorch_widedeep.losses import RMSELoss
from pytorch_widedeep.initializers import *
from pytorch_widedeep.callbacks import *

In [2]:
df = pd.read_csv('D:\\RMBI_MMKG_data\\entity_name_des_trunc.csv')
df

Unnamed: 0,LabelName,ImageID,name,description
0,/m/09j2d,5975e5e6973ca2ad.jpg,Clothing,Clothing are items worn on the body. Typically...
1,/m/01yrx,bfe93a5657824d86.jpg,Cat,The cat is a domestic species of a small carni...
2,/m/09j2d,44784a16118b2bc0.jpg,Clothing,Clothing are items worn on the body. Typically...
3,/m/0k4j,673aef54336cf496.jpg,Car,A car is a wheeled motor vehicle used for tran...
4,/m/09j2d,64fdc83210da5cee.jpg,Clothing,Clothing are items worn on the body. Typically...
...,...,...,...,...
995,/m/06msq,09fc7ac0ae728478.jpg,Sculpture,Sculpture is the branch of the visual arts tha...
996,/m/01yrx,4a0ebed614054324.jpg,Cat,The cat is a domestic species of a small carni...
997,/m/09j2d,6dc01a5376d7f23e.jpg,Clothing,Clothing are items worn on the body. Typically...
998,/m/0k4j,493ee5a73b5863de.jpg,Car,A car is a wheeled motor vehicle used for tran...


In [3]:
text_col = "description"
img_col = "ImageID"

word_vectors_path = "D:\\RMBI_MMKG_data\\glove.6B.100d.txt"
img_path = "D:\\RMBI_MMKG_data\\img"

target_col = "name"

In [4]:
target = df[target_col].values

In [5]:
crossed_cols = []
wide_cols = []

In [6]:
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols)
X_wide = wide_preprocessor.fit_transform(df)

In [7]:
X_wide

array([], shape=(1000, 0), dtype=int64)

In [8]:
cat_embed_cols = ['name']

In [9]:
tab_preprocessor = TabPreprocessor(cat_embed_cols)
X_tab = tab_preprocessor.fit_transform(df)

In [10]:
X_tab

array([[ 1],
       [ 2],
       [ 1],
       [ 3],
       [ 1],
       [ 1],
       [ 2],
       [ 1],
       [ 4],
       [ 3],
       [ 1],
       [ 1],
       [ 5],
       [ 6],
       [ 7],
       [ 1],
       [ 1],
       [ 3],
       [ 5],
       [ 3],
       [ 3],
       [ 3],
       [ 8],
       [ 9],
       [10],
       [ 1],
       [11],
       [ 8],
       [ 1],
       [ 3],
       [ 1],
       [ 3],
       [ 1],
       [ 3],
       [ 1],
       [ 1],
       [ 1],
       [ 1],
       [ 7],
       [ 7],
       [ 1],
       [ 5],
       [12],
       [ 1],
       [ 1],
       [13],
       [ 1],
       [ 1],
       [ 3],
       [ 7],
       [ 3],
       [ 1],
       [14],
       [ 1],
       [ 3],
       [ 8],
       [ 1],
       [ 1],
       [ 7],
       [ 7],
       [ 1],
       [15],
       [ 1],
       [ 1],
       [ 1],
       [16],
       [ 1],
       [ 3],
       [ 1],
       [ 7],
       [ 1],
       [ 5],
       [ 1],
       [ 1],
       [17],
       [ 1],
       [17],

In [11]:
text_preprocessor = TextPreprocessor(
    word_vectors_path=word_vectors_path, text_col=text_col
)
X_text = text_preprocessor.fit_transform(df)

The vocabulary contains 357 tokens
Indexing word vectors...
Loaded 400001 word vectors
Preparing embeddings matrix...
343 words in the vocabulary had D:\RMBI_MMKG_data\glove.6B.100d.txt vectors and appear more than 5 times


In [12]:
df[img_col]

0      5975e5e6973ca2ad.jpg
1      bfe93a5657824d86.jpg
2      44784a16118b2bc0.jpg
3      673aef54336cf496.jpg
4      64fdc83210da5cee.jpg
               ...         
995    09fc7ac0ae728478.jpg
996    4a0ebed614054324.jpg
997    6dc01a5376d7f23e.jpg
998    493ee5a73b5863de.jpg
999    983897722524a451.jpg
Name: ImageID, Length: 1000, dtype: object

In [14]:
print(df[img_col])

0      5975e5e6973ca2ad.jpg
1      bfe93a5657824d86.jpg
2      44784a16118b2bc0.jpg
3      673aef54336cf496.jpg
4      64fdc83210da5cee.jpg
               ...         
995    09fc7ac0ae728478.jpg
996    4a0ebed614054324.jpg
997    6dc01a5376d7f23e.jpg
998    493ee5a73b5863de.jpg
999    983897722524a451.jpg
Name: ImageID, Length: 1000, dtype: object


In [13]:
image_processor = ImagePreprocessor(img_col=img_col, img_path=img_path)
X_images = image_processor.fit_transform(df)

Reading Images from D:\RMBI_MMKG_data\img
Resizing


100%|█████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 419.71it/s]


Computing normalisation metrics


In [54]:
# Linear model
wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)

# DeepDense: 2 Dense layers
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    cat_embed_dropout=0.1,
    continuous_cols=continuous_cols,
    mlp_hidden_dims=[128, 64],
    mlp_dropout=0.1,
)

# DeepText: a stack of 2 LSTMs
basic_rnn = BasicRNN(
    vocab_size=len(text_preprocessor.vocab.itos),
    embed_matrix=text_preprocessor.embedding_matrix,
    n_layers=2,
    hidden_dim=64,
    rnn_dropout=0.5,
)

# Pretrained Resnet 18
resnet = Vision(pretrained_model_name="resnet18", n_trainable=4)

In [55]:
model = WideDeep(
    wide=wide,
    deeptabular=tab_mlp,
    deeptext=basic_rnn,
    deepimage=resnet,
    head_hidden_dims=[256, 128],
)

In [56]:
model

WideDeep(
  (wide): Wide(
    (wide_linear): Embedding(1, 1, padding_idx=0)
  )
  (deeptabular): TabMlp(
    (cat_and_cont_embed): DiffSizeCatAndContEmbeddings(
      (cat_embed): DiffSizeCatEmbeddings(
        (embed_layers): ModuleDict(
          (emb_layer_name): Embedding(35, 12, padding_idx=0)
        )
        (embedding_dropout): Dropout(p=0.1, inplace=False)
      )
      (cont_norm): BatchNorm1d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (tab_mlp): MLP(
      (mlp): Sequential(
        (dense_layer_0): Sequential(
          (0): Dropout(p=0.1, inplace=False)
          (1): Linear(in_features=12, out_features=128, bias=True)
          (2): ReLU(inplace=True)
        )
        (dense_layer_1): Sequential(
          (0): Dropout(p=0.1, inplace=False)
          (1): Linear(in_features=128, out_features=64, bias=True)
          (2): ReLU(inplace=True)
        )
      )
    )
  )
  (deeptext): BasicRNN(
    (word_embed): Embedding(360, 100, padding_

In [57]:
trainer = Trainer(model, objective="rmse")

In [59]:
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    X_text=X_text,
    X_img=X_images,
    target=target,
    n_epochs=1,
    batch_size=32,
    val_split=0.2,
)

epoch 1:   0%|                                                                                  | 0/25 [02:39<?, ?it/s]


AttributeError: 'tuple' object has no attribute 'view'

In [15]:
import torch
print(torch.__version__)

1.11.0+cpu
