# Encoding Item Textual and Visual Modality Features for Amazon2023

In [24]:
from typing import Callable, Iterable

import torch, tqdm, os
import pandas as pd
from freerec.data.tags import USER, ITEM
from freerec.data.utils import download_from_url
from freerec.utils import export_pickle

from concurrent.futures import ThreadPoolExecutor

import torchdata.datapipes as dp
from transformers import AutoImageProcessor, AutoModel
from sentence_transformers import SentenceTransformer
from PIL import Image

You can download the models according to [[hf-mirror](https://hf-mirror.com/)].

In [25]:
dataset: str = "Amazon2023Beauty_10104811_ROU"
datadir: str = f"../data/Processed/{dataset}"
image_folder: str = os.path.join(datadir, "item_images", "large")
model_cache_dir: str = "../models"

In [26]:
item_df = pd.read_csv(
    os.path.join(datadir, "item.txt"), sep='\t'
)
item_df.head(5)

Unnamed: 0,ITEM,TEXT,IMAGE_URL
0,0,Title: Klutz Metallic Glam Nail Studio Activit...,https://m.media-amazon.com/images/I/51dyKdZMlC...
1,1,Title: Versace Bright Crystal Eau de Toilette ...,https://m.media-amazon.com/images/I/41lnN8CpvE...
2,2,Title: Conair CD82ZCS Instant Heat Curling Iro...,https://m.media-amazon.com/images/I/31N529CJ78...
3,3,Title: Conair CD82ZCS Instant Heat Curling Iro...,https://m.media-amazon.com/images/I/31N529CJ78...
4,4,Title: Refill Cartridges CCR\nFeatures: 1. Cle...,https://m.media-amazon.com/images/I/41KIM5M9xi...


We first download images from given urls. Please check `image_size` before going on.

In [32]:
def download_images(item_df: pd.DataFrame):
    ids = item_df[ITEM.name]
    urls = item_df['IMAGE_URL']
    with ThreadPoolExecutor() as executor:
        for id_, url in tqdm.tqdm(zip(ids, urls), desc="Download images: "):
            if url:
                executor.submit(
                    download_from_url,
                    url=url,
                    root=image_folder,
                    filename=f"{id_}.jpg",
                    log=False
                )

In [33]:
download_images(item_df)

Download images: : 30380it [00:32, 937.29it/s] 


Then, we will encode visual modality first.

In [27]:
def encode_visual_modality(
    item_df: pd.DataFrame,
    model: str, model_dir: str,
    num_workers: int = 4, batch_size: int = 128,
):
    images = []
    try:
        processor = AutoImageProcessor.from_pretrained(
            os.path.join(model_dir, model), local_files_only=True
        )
    except OSError:
        print("No processor file")
        processor = lambda x: x

    is_missed_urls = torch.ones((len(item_df,)))
    def _process(idx: int):
        try:
            image = Image.open(
                os.path.join(
                    image_folder, f"{idx}.jpg"
                )
            ).convert('RGB')
        except FileNotFoundError:
            image = Image.new('RGB', (224, 224))
            is_missed_urls[idx] = 0
        return idx, processor(images=image, return_tensors='pt')['pixel_values'][0]

    datapipe = dp.iter.IterableWrapper(
        range(len(item_df))
    ).sharding_filter().map(
        _process
    )
    dataloader = torch.utils.data.DataLoader(
        datapipe, 
        num_workers=num_workers, batch_size=batch_size,
        shuffle=False
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = AutoModel.from_pretrained(
        os.path.join(model_dir, model), local_files_only=True
    ).to(device).eval()

    vIndices = []
    vFeats = []
    with torch.no_grad():
        encoder.eval()
        for (indices, images) in tqdm.tqdm(dataloader, desc="Visual batches: "):
            vIndices.append(indices)
            outputs = encoder(pixel_values=images.to(device)).last_hidden_state
            if outputs.ndim == 3:
                # vit (Batch, Sequence, D) -> (Batch, D, Sequence)
                outputs = outputs.transpose(-1, -2)
            else:
                # resnet (Batch, D, K, K) -> (Batch, D, K x K)
                outputs = outputs.flatten(2)
            vFeats.append(
                outputs.mean(-1)
            )
    vIndices = torch.cat(vIndices, dim=0)
    vFeats = torch.cat(vFeats, dim=0).flatten(1) # (N, D)
    vFeats = vFeats[vIndices.argsort()] # reindex
    assert vFeats.size(0) == len(item_df), f"Unknown errors happen ..."

    is_missed_urls = is_missed_urls.bool().to(vFeats.device)
    mean = vFeats[is_missed_urls].mean(dim=0, keepdim=True).repeat((vFeats.size(0), 1))
    is_missed_urls = is_missed_urls.unsqueeze(-1).expand_as(vFeats)
    vFeats = torch.where(is_missed_urls, vFeats, mean)

    export_pickle(
        vFeats, os.path.join(
            datadir, f"visual_{model}.pkl"
        )
    )
    return vFeats

In [33]:
encode_visual_modality(
    item_df,
    model="vit-base-16-224",
    model_dir=model_cache_dir
)

Visual batches:   0%|          | 0/238 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been 

tensor([[-0.0488, -0.0903, -0.2875,  ..., -0.0211, -0.2395,  0.0547],
        [ 0.2017,  0.0873, -0.1558,  ..., -0.1068,  0.2998,  0.1285],
        [-0.0352,  0.2041,  0.0408,  ..., -0.0813,  0.0746,  0.0412],
        ...,
        [-0.1274,  0.1027, -0.1186,  ..., -0.0758, -0.0648, -0.0265],
        [ 0.1171, -0.1759,  0.0610,  ...,  0.0516, -0.0141, -0.1135],
        [-0.2421,  0.3244, -0.1851,  ..., -0.2860,  0.1293,  0.0326]],
       device='cuda:0')

In [29]:
def encode_textual_modality(
    item_df: pd.DataFrame,
    model: str, model_dir: str,
    batch_size: int = 128
):
    sentences = item_df['TEXT']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = SentenceTransformer(
        os.path.join(model_dir, model),
        device=device
    ).eval()

    with torch.no_grad():
        tFeats = encoder.encode(
            sentences, 
            convert_to_tensor=True,
            batch_size=batch_size, show_progress_bar=True
        ).cpu()
    assert tFeats.size(0) == len(item_df), f"Unknown errors happen ..."

    export_pickle(
        tFeats, os.path.join(
            datadir, f"textual_{model}.pkl"
        )
    )
    return tFeats

In [36]:
encode_textual_modality(
    item_df,
    model="roberta-base",
    model_dir=model_cache_dir
)

Batches: 100%|██████████| 238/238 [03:22<00:00,  1.17it/s]


tensor([[ 1.2388, -0.1938, -0.3167,  ..., -0.0088, -0.0606, -0.3454],
        [ 0.5566, -0.7622,  0.1914,  ...,  0.5741,  0.0815,  0.2836],
        [ 0.7353,  0.8229, -0.2983,  ...,  0.1093, -0.0770, -0.1295],
        ...,
        [ 1.1915, -0.4700,  0.2279,  ..., -1.0194,  0.5871, -0.0360],
        [ 0.4231, -0.4110, -0.7817,  ..., -0.7333,  0.1331, -0.4086],
        [ 1.5562,  0.1625,  0.5239,  ..., -1.0524,  0.3007,  0.2076]])

The following code is used to extract features by CLIP.

In [31]:

def encode_clip_textual_visual_modality(
    item_df: pd.DataFrame,
    img_clip_model: str,
    text_clip_model: str,
    model_dir: str,
    batch_size: int = 128
):
    images = []
    is_missed_urls = torch.ones((len(item_df,)))
    for idx in range(len(item_df)):
        try:
            image = Image.open(
                os.path.join(
                    image_folder, f"{idx}.jpg"
                )
            ).convert('RGB')
        except FileNotFoundError:
            image = Image.new('RGB', (224, 224))
            is_missed_urls[idx] = 0
        images.append(image)

    sentences = item_df['TEXT']

    img_encoder = SentenceTransformer(
        os.path.join(model_dir, img_clip_model),
        device=torch.device('cuda:0')
    ).eval()
    text_encoder = SentenceTransformer(
        os.path.join(model_dir, text_clip_model),
        device=torch.device('cuda:1')
    ).eval()

    with torch.no_grad():
        vFeats = img_encoder.encode(
            images,
            convert_to_tensor=True,
            batch_size=batch_size, show_progress_bar=True
        ).cpu()
        tFeats = text_encoder.encode(
            sentences, 
            convert_to_tensor=True,
            batch_size=batch_size, show_progress_bar=True
        ).cpu()
    assert vFeats.size(0) == len(item_df), f"Unknown errors happen ..."
    assert tFeats.size(0) == len(item_df), f"Unknown errors happen ..."

    is_missed_urls = is_missed_urls.bool().to(vFeats.device)
    mean = vFeats[is_missed_urls].mean(dim=0, keepdim=True).repeat((vFeats.size(0), 1))
    is_missed_urls = is_missed_urls.unsqueeze(-1).expand_as(vFeats)
    vFeats = torch.where(is_missed_urls, vFeats, mean)

    export_pickle(
        vFeats, os.path.join(
            datadir, f"visual_{img_clip_model}.pkl"
        )
    )

    export_pickle(
        tFeats, os.path.join(
            datadir, f"textual_{text_clip_model}.pkl"
        )
    )
    return vFeats, tFeats

In [32]:
encode_clip_textual_visual_modality(
    item_df,
    img_clip_model="clip-vit-b-32",
    text_clip_model="clip-vit-b-32-multilingual-v1",
    model_dir=model_cache_dir
)

Batches: 100%|██████████| 238/238 [04:12<00:00,  1.06s/it]
Batches: 100%|██████████| 238/238 [01:41<00:00,  2.35it/s]


(tensor([[-0.0367, -0.4885,  0.3980,  ..., -0.1649,  0.4623,  0.0621],
         [-0.2082, -0.0440,  0.0904,  ...,  0.3867, -0.0183, -0.0969],
         [-0.0832,  0.3915, -0.0901,  ...,  0.9376,  0.2346,  0.1945],
         ...,
         [-0.1495,  0.3978,  0.2217,  ...,  0.1837, -0.4292,  0.1551],
         [-0.2783,  0.1620,  0.1765,  ...,  0.4337, -0.3054, -0.1246],
         [-0.2874, -0.1172,  0.1805,  ...,  0.1124,  0.0043,  0.2117]]),
 tensor([[-0.0480,  0.1438, -0.0667,  ...,  0.0107,  0.0527, -0.1216],
         [ 0.0211,  0.1272, -0.0593,  ...,  0.0686,  0.0270, -0.0941],
         [-0.0333,  0.1581, -0.0986,  ...,  0.0823,  0.0271,  0.0128],
         ...,
         [ 0.0293,  0.1610, -0.0843,  ..., -0.0547, -0.0385, -0.0440],
         [ 0.1392,  0.0145, -0.0654,  ..., -0.0097,  0.0556, -0.1122],
         [ 0.0427,  0.1206, -0.1094,  ...,  0.0034, -0.0012, -0.0377]]))