In [None]:
import pandas as pd
from pathlib import Path

In [None]:
# Dataset root. pathlib never expands '~' on its own, so do it eagerly:
# otherwise root.exists() is False and any consumer that doesn't expand '~'
# itself (e.g. cv2.imread) would look for a literal '~' directory.
root = Path('~/datasets_row/d1').expanduser()

In [None]:
# Show the dataset directory layout with sizes (shell `tree --du -h`)
!tree ~/datasets_row/d1 --du -h

In [None]:
# Raw test split, kept under a name so later cells can inspect it
test_raw = pd.read_parquet(root / 'test.parquet')
test_raw

In [None]:
# Raw train split, kept under a name so later cells can inspect it
train_raw = pd.read_parquet(root / 'train.parquet')
train_raw

In [None]:
import torch

# Record which PyTorch version this notebook ran against
torch_version = torch.__version__
print(torch_version)

In [None]:
from torch.utils.data import Dataset
from typing import List, Any, Dict
import os
import json
from easydict import EasyDict
from collections import defaultdict
import numpy as np
import cv2
from torchvision import transforms
from pprint import pprint
class d1_text(Dataset):
    """Map-style dataset over the ``d1`` catalogue, yielding text, images or both.

    ``__init__`` reads ``<root_>/train.parquet`` and builds ``self.all``, a flat
    table with:

    * one column per key of the JSON objects stored in ``text_fields``
      (column names listed in ``self.features``);
    * one ``cat_<j>`` column per level of the ``category_name`` path after its
      first element is dropped (names in ``self.cat_levels``; shallower paths
      are padded with ``None``);
    * ``product_id`` as a string (also kept in ``self.imgs_train_ids``);
    * ``category_name`` (``self.target``): the path re-joined with ``'->'``
      without its first element.

    The raw, unmodified category paths are kept in ``self.train_targets``.

    What ``dataset[idx]`` returns depends on ``self.mode`` (set via ``text()``,
    ``img()`` or ``multi()``; it is ``'text'`` after construction):

    * ``'text'``  -> ``(feature_values, target)``
    * ``'img'``   -> ``(image, target)``
    * ``'multi'`` -> ``(feature_values, image, target)``

    Images are read from ``<root_>/images/<product_id>.jpg`` and returned in
    RGB channel order.
    """

    features: List[str]
    target: str
    mode: str
    imgs_train_ids: List[str]

    def __init__(self, root_: str):
        self.root_ = root_
        df_train = pd.read_parquet(os.path.join(root_, 'train.parquet'))

        # Product ids double as image file names, so keep them as strings.
        self.imgs_train_ids = df_train['product_id'].map(str).tolist()

        # Each text_fields cell is a JSON object. Building the frame from the
        # list of dicts aligns values by key (NaN where a row lacks a key)
        # instead of mis-aligning/crashing when rows have different keys.
        text_table = pd.DataFrame([json.loads(raw) for raw in df_train['text_fields']])
        self.features = list(text_table.columns)

        self.target = 'category_name'
        self.train_targets = df_train[self.target].to_numpy()
        # Split each path into levels and drop its first element
        # (presumably the shared root category -- verify against the data).
        levels_per_row = [path_.split('->')[1:] for path_ in self.train_targets]
        depth = max((len(levels) for levels in levels_per_row), default=0)
        cat_table = pd.DataFrame({
            f'cat_{j}': [levels[j] if j < len(levels) else None for levels in levels_per_row]
            for j in range(depth)
        })
        self.cat_levels = list(cat_table.columns)

        table = pd.concat([text_table, cat_table], axis=1)
        table['product_id'] = self.imgs_train_ids
        table[self.target] = ['->'.join(levels) for levels in levels_per_row]
        self.all = table

        self.text()

    def _load_image(self, product_id: str):
        """Read ``<root_>/images/<product_id>.jpg`` and return it in RGB order.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        img_path = os.path.join(self.root_, 'images', product_id + '.jpg')
        img_bgr = cv2.imread(img_path)
        # cv2.imread returns None (it does not raise) for missing/unreadable files
        if img_bgr is None:
            raise FileNotFoundError(f'could not read image: {img_path}')
        # OpenCV decodes to BGR; matplotlib's imshow expects RGB
        return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx: int):
        """Return item ``idx`` in the form selected by ``self.mode``.

        Raises:
            IndexError: if ``idx`` is out of range (negative indices count from the end).
            ValueError: if ``self.mode`` is not 'text', 'img' or 'multi'.
        """
        # List indexer -> one-row DataFrame; supports negative idx and raises
        # IndexError when out of bounds (which also ends plain iteration).
        row = self.all.iloc[[idx]]
        target_ = row[self.target].values[0]
        if self.mode == 'text':
            return row[self.features].values.flatten(), target_
        if self.mode not in ('img', 'multi'):
            raise ValueError(f'unknown mode {self.mode!r}')
        img_ = self._load_image(row['product_id'].values[0])
        if self.mode == 'img':
            return img_, target_
        return row[self.features].values.flatten(), img_, target_

    def __len__(self):
        """Number of products (rows of ``self.all``)."""
        return len(self.all)

    def text(self):
        """Switch items to ``(feature_values, target)``."""
        self.mode = 'text'

    def img(self):
        """Switch items to ``(image, target)``."""
        self.mode = 'img'

    def multi(self):
        """Switch items to ``(feature_values, image, target)``."""
        self.mode = 'multi'

    @staticmethod
    def _wrap_text(text_, max_line_length):
        """Hard-wrap ``str(text_)`` into lines of at most ``max_line_length`` characters."""
        text_ = str(text_)
        return '\n'.join(
            text_[start:start + max_line_length]
            for start in range(0, len(text_), max_line_length)
        )

    def get_view(self, idx, max_line_length=50):
        """Plot the image of item ``idx`` captioned with its first two text fields and target.

        Assumes the first two entries of ``self.features`` are a title and a
        description (TODO confirm against the parquet schema).

        Args:
            idx: item position.
            max_line_length: wrap width (characters) for the caption and title.

        Returns:
            ``(fig, ax)`` of the created matplotlib figure.

        Raises:
            ValueError: if the dataset is not in 'multi' mode.
        """
        if self.mode != 'multi':
            raise ValueError(f'mode {self.mode} not supported yet')
        import matplotlib.pyplot as plt  # local: only needed for previews

        text, img, target = self[idx]
        fig, ax = plt.subplots()
        ax.imshow(img)
        title_ = text[0] if len(text) > 0 else ''
        desc_ = text[1] if len(text) > 1 else ''
        caption = ('title: ' + self._wrap_text(title_, max_line_length)
                   + '\n\ndesc: ' + self._wrap_text(desc_, max_line_length))
        ax.set_xlabel(caption, fontsize=10)
        ax.set_title('target:' + self._wrap_text(target, max_line_length), fontsize=10)
        return fig, ax
        


In [None]:
# Reuse the notebook-level `root` rather than a second, hard-coded absolute path.
# expanduser() matters: cv2.imread inside d1_text does not expand '~'.
dataset = d1_text(root_=str(root.expanduser()))

In [None]:
# Text mode: each item is (feature values, target category path)
dataset.text()
print(dataset.features)
print(dataset.target)
first_text, first_target = dataset[0]
first_text, first_target

In [None]:
# Switch to 'multi' mode so items are (text, image, target), as required by the
# get_view previews later in the notebook.
dataset.multi()

In [None]:
import matplotlib.pyplot as plt

# Number of products per first category level (cat_0), largest bar on top
cat0_counts = dataset.all.groupby('cat_0').size().sort_values()
fig, ax = plt.subplots()
ax.barh(cat0_counts.index.astype(str), cat0_counts.values)
# Label each bar with its count (using counts as x-ticks overlapped/duplicated)
for y_pos, count in enumerate(cat0_counts.values):
    ax.text(count, y_pos, f' {count}', va='center', fontsize=8)
ax.set(title='Products per top-level category', xlabel='number of products', ylabel='cat_0')
plt.show()

In [None]:
# Samples per full target path, with summary statistics of the class sizes
target_counts = dataset.all.groupby(dataset.target).size().to_frame('cnt')
cnt = target_counts['cnt']
print('total:', cnt.sum(), 'min:', cnt.min(), 'max:', cnt.max(), 'median:', cnt.median())
target_counts

In [None]:
import matplotlib.pyplot as plt

N_PREVIEW = 10  # number of samples to render
# get_view needs (text, image, target) items; set the mode here instead of
# relying on an earlier cell having done it.
dataset.multi()
for i in range(min(N_PREVIEW, len(dataset.all))):
    fig, ax = dataset.get_view(i)
    plt.show()
    # Release the figure so many previews don't accumulate open figures
    plt.close(fig)

In [None]:
import copy
def parse_category_tree(paths, sep='->'):
    """Build a nested category tree from ``sep``-delimited category paths.

    Every path is expected to start with the same root category
    (e.g. ``'Root->Clothes->Shoes'``). The result describes that root's children:

    * a category without sub-categories is represented by its name (``str``);
    * a category with sub-categories is a one-item dict ``{name: [children...]}``.

    Categories are identified by their full path, so two categories sharing a
    name under different parents are kept apart. Children are sorted by name so
    the output is deterministic.

    Args:
        paths: iterable of category path strings (duplicates are harmless).
        sep: separator between path levels.

    Returns:
        List describing the root's children. If several roots exist, only the
        alphabetically first one is returned; empty input gives ``[]``.
    """
    # Prefix tree: {name: {child_name: {...}}}; a leaf has an empty dict.
    trie = {}
    for path_ in paths:
        node = trie
        for name in path_.split(sep):
            node = node.setdefault(name, {})

    def _to_nested_list(children):
        # Leaves become plain names; inner nodes become {name: [...]}.
        return [
            {name: _to_nested_list(children[name])} if children[name] else name
            for name in sorted(children)
        ]

    if not trie:
        return []
    return _to_nested_list(trie[min(trie)])

In [None]:
# parse_category_tree needs complete paths (each starting with the shared root
# category), but `dataset.all` stores them with the first level stripped, so
# read the original column straight from the parquet file.
train_category_paths = pd.read_parquet(
    os.path.join(dataset.root_, 'train.parquet'), columns=['category_name']
)['category_name']
tree = parse_category_tree(np.unique(train_category_paths))

In [None]:
# Pretty-print the nested category tree returned by parse_category_tree
pprint(tree)