In [1]:
import pandas as pd
import numpy as np
import os 
from PIL import Image, ImageFile
from transformers import pipeline
import torch

from tqdm.notebook import tqdm
import torch.nn as nn
from torchvision import models

from torchvision.transforms import Resize, ToTensor, Normalize, Compose

In [2]:
class ImageEmbedder(nn.Module):
    """Frozen ResNet-50 feature extractor that returns convolutional feature maps.

    The last two child modules of torchvision's ResNet-50 are dropped, so
    ``forward`` returns the output of the remaining convolutional stack rather
    than class logits. All backbone parameters are frozen and the backbone is
    kept permanently in evaluation mode, so BatchNorm layers use their
    pretrained running statistics instead of per-batch statistics.
    """

    def __init__(self):
        super().__init__()
        # `pretrained=True` is deprecated in recent torchvision releases in
        # favour of `weights=`; fall back to the legacy flag on old versions.
        if hasattr(models, "ResNet50_Weights"):
            resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        else:
            resnet = models.resnet50(pretrained=True)

        # Keep every child except the final two.
        backbone_layers = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*backbone_layers)

        # The backbone is used purely as a fixed feature extractor.
        for p in self.resnet.parameters():
            p.requires_grad = False

        # nn.Module instances start in training mode; a frozen extractor must
        # run in eval mode or BatchNorm normalises with batch statistics and
        # silently updates its running averages.
        self.resnet.eval()

    def train(self, mode=True):
        """Set the training flag on this module but keep the backbone in eval mode.

        Args:
            mode: Same meaning as in ``nn.Module.train``.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Compute backbone feature maps for a batch of images.

        Args:
            images: Tensor passed directly to the backbone.

        Returns:
            The backbone output, with its shape unchanged, in contiguous memory.
        """
        # The original flattened and immediately un-flattened the spatial
        # dimensions, which is a no-op; only the contiguity guarantee is kept.
        return self.resnet(images).contiguous()

In [3]:
# Locations of the metadata CSV and of the raw description / image data.
# NOTE(review): these are absolute, machine-specific paths; adjust them when
# running elsewhere.
meta_path = '/home/smart01/SFLAB/su_GTM_t/GTM_T_sanguk/'
data_path = "/home/smart01/SFLAB/sanguk/mind_br_data/"

# Item metadata indexed by item number; the 'sales_std' column is not used below.
meta_df = pd.read_csv(os.path.join(meta_path, 'meta_data_image_text_nofilter.csv'), index_col='item_number')
meta_df = meta_df.drop(columns=['sales_std'])

# Item descriptions indexed by the '품번' column, with every cell cast to str.
# NOTE(review): .astype(str) turns missing cells into the literal string 'nan',
# which would then be embedded as text; confirm there are no missing descriptions.
text_df = pd.read_excel(os.path.join(data_path, "품번description(텍스트).xlsx"), index_col="품번").astype(str)
text_embedder = pipeline('feature-extraction', model='klue/bert-base')

# nn.Module instances are created in training mode; switch to eval so the
# BatchNorm layers use their pretrained running statistics during inference.
image_embedder = ImageEmbedder().eval()
img_transforms = Compose([
    Resize((256, 256)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])



In [18]:
# Assemble the per-item feature table: categorical labels recovered from the
# one-hot style columns of meta_df, the sales target, and empty placeholder
# columns that the embedding loop fills in later.
df = pd.DataFrame()

# For each column group, take the column with the largest value per row and
# keep the part of its name after the last underscore as the label.
df['fabric'] = meta_df.loc[:,meta_df.columns[meta_df.columns.str.startswith('fabric')]].idxmax(axis=1).apply(lambda x:x.split('_')[-1])
df['color'] = meta_df.loc[:,meta_df.columns[meta_df.columns.str.startswith(('main_color','color'))]].idxmax(axis=1).apply(lambda x:x.split('_')[-1])
df['category'] = meta_df.loc[:,meta_df.columns[meta_df.columns.str.startswith('category')]].idxmax(axis=1).apply(lambda x:x.split('_')[-1])
df['sales_mean'] = meta_df.loc[:,'sales_mean']

# Placeholder columns for the 2048 image features and the 768 text features.
img_df = pd.DataFrame(columns = [f'img_{i}' for i in range(2048)], index = df.index)

# Use a dedicated name for the text placeholder: rebinding `text_df` here would
# discard the description table loaded earlier, and the embedding loop's
# text_df['설명'] lookup would then fail on a fresh top-to-bottom run.
text_emb_df = pd.DataFrame(columns = [f'text_{i}' for i in range(768)], index = df.index)

df = pd.concat([df, img_df, text_emb_df], axis=1)

In [5]:
# Fill the placeholder columns of df with one image embedding and one text
# embedding per item.

# The column masks and the image directory do not change inside the loop.
img_cols = df.columns.str.startswith('img')
text_cols = df.columns.str.startswith('text')
image_dir = os.path.join(data_path, 'images')

# Iterate over the index directly; iterrows() would build a full row Series
# per item only to throw it away.
for item_id in tqdm(df.index, total=len(df), ascii=True):
    # Close the image file handle as soon as the pixels have been transformed.
    with Image.open(os.path.join(image_dir, item_id + '.png')) as raw_img:
        img = img_transforms(raw_img.convert('RGB'))

    # Inference only: no autograd graph is needed, and the result can be
    # converted to NumPy directly.
    with torch.no_grad():
        # Add a batch dimension, average over the last two axes of the feature
        # map, then drop the singleton batch dimension.
        img_embedding = image_embedder(img.unsqueeze(dim=0)).mean(axis=-1).mean(axis=-1).squeeze()
    df.loc[item_id, img_cols] = img_embedding.numpy()

    # NOTE(review): long descriptions may exceed the text model's maximum
    # input length; confirm whether truncation needs to be enabled.
    text = text_df['설명'].loc[item_id]
    word_embeddings = text_embedder(text)
    # Average over axis 1 of the pipeline output to get one vector per item.
    text_embedding = torch.FloatTensor(word_embeddings).mean(axis=1).squeeze()
    df.loc[item_id, text_cols] = text_embedding.numpy()

  0%|          | 0/1771 [00:00<?, ?it/s]

In [7]:
# Persist the assembled feature table (index included) as CSV.
all_meta_csv = '../../data/preprocess/all_meta_data.csv'
df.to_csv(path_or_buf=all_meta_csv)

In [15]:
# Reload the saved feature table from disk, restoring 'item_number' as the
# index, and display it as the cell output.
df = pd.read_csv(
    filepath_or_buffer='../../data/preprocess/all_meta_data.csv',
    index_col='item_number',
)
df