In [2]:
import pandas as pd
from typing import *

In [3]:
import numpy as np

In [4]:
df = pd.read_csv("rotten_tomatoes_reviews.csv")

In [26]:
from pathlib import Path

In [5]:
def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]

In [6]:
def df_names_to_idx(names, df):
    "Return the column indexes of `names` in `df`."
    if not listify(names): names = [names]
    if isinstance(names[0], int): return names
    return [df.columns.get_loc(c) for c in names]

In [7]:
df.columns

Index(['Freshness', 'Review'], dtype='object')

In [8]:
df_names_to_idx(["Freshness","Review"],df)

[0, 1]

In [9]:
def is1d(a:Collection)->bool:
    "Return `True` if `a` is one-dimensional"
    return len(a.shape) == 1 if hasattr(a, 'shape') else True

In [10]:
def _maybe_squeeze(arr): return (arr if is1d(arr) else np.squeeze(arr))

In [11]:
def from_df(df, path ='.', cols=0, processor=None, **kwargs):
        "Create an `ItemList` in `path` from the inputs in the `cols` of `df`."
        inputs = df.iloc[:,df_names_to_idx(cols, df)]
        assert inputs.isna().sum().sum() == 0, f"You have NaN values in column(s) {cols} of your dataframe, please fix it."
        res = _maybe_squeeze(inputs.values)
        return res

In [None]:
np.squeeze(df.iloc[:,[0,1]].values)

In [12]:
inputs = df.iloc[:,[0,1]]

In [13]:
inputs.isna().sum().sum()

0

In [None]:
inputs.values

In [14]:
res = _maybe_squeeze(inputs.values)

In [15]:
res[0]

array([1,
       " Manakamana doesn't answer any questions, yet makes its point: Nepal, like the rest of our planet, is a picturesque but far from peaceable kingdom."],
      dtype=object)

In [19]:
df.loc[0].tolist()

[1,
 " Manakamana doesn't answer any questions, yet makes its point: Nepal, like the rest of our planet, is a picturesque but far from peaceable kingdom."]

In [20]:
class ListContainer():
    def __init__(self, items): self.items = listify(items)
    def __getitem__(self, idx):
        if isinstance(idx, (int,slice)): return self.items[idx]
        if isinstance(idx[0],bool):
            assert len(idx)==len(self) # bool mask
            return [o for m,o in zip(idx,self.items) if m]
        return [self.items[i] for i in idx]
    def __len__(self): return len(self.items)
    def __iter__(self): return iter(self.items)
    def __setitem__(self, i, o): self.items[i] = o
    def __delitem__(self, i): del(self.items[i])
    def __repr__(self):
        res = f'{self.__class__.__name__} ({len(self)} items)\n{self.items[:10]}'
        if len(self)>10: res = res[:-1]+ '...]'
        return res

In [44]:
def compose(x, funcs, *args, order_key='_order', **kwargs):
    key = lambda o: getattr(o, order_key, 0)
    for f in sorted(listify(funcs), key=key): x = f(x, **kwargs)
    return x

class ItemList(ListContainer):
    def __init__(self, items, path='.', tfms=None):
        super().__init__(items)
        self.path,self.tfms = Path(path),tfms

    def __repr__(self): return f'{super().__repr__()}\nPath: {self.path}'
    def new(self, items): return self.__class__(items, self.path, tfms=self.tfms)
    
    def  get(self, i): return i
    def _get(self, i): return compose(self.get(i), self.tfms)
    
    def __getitem__(self, idx):
        ("Getitem called")
        res = super().__getitem__(idx)
        if isinstance(res,list): return [self._get(o) for o in res]
        return self._get(res)

In [45]:
class TextList(ItemList):
    @classmethod
    def from_files(cls, path, extensions='.txt', recurse=True, include=None, **kwargs):
        return cls(get_files(path, extensions, recurse=recurse, include=include), path, **kwargs)
    
    @classmethod
    def from_df(cls,df, path ='.', cols=0, processor=None, **kwargs):
        "Create an `ItemList` in `path` from the inputs in the `cols` of `df`."
        inputs = df.iloc[:,df_names_to_idx(cols, df)]
        assert inputs.isna().sum().sum() == 0, f"You have NaN values in column(s) {cols} of your dataframe, please fix it."
        res = cls(_maybe_squeeze(inputs.values),path=path,**kwargs)
        return res
    
    
        
    
    def get(self, i):
        if isinstance(i, Path): return read_file(i)
        if isinstance(i,pd.DataFrame): return df.loc(i).tolist()
        return i

In [46]:
il = TextList.from_df(df,cols=["Freshness","Review"])

In [None]:
il.

In [58]:
len(il.items) - 96000

384000

In [64]:
il.new

<bound method ItemList.new of TextList (480000 items)
[array([1,
       " Manakamana doesn't answer any questions, yet makes its point: Nepal, like the rest of our planet, is a picturesque but far from peaceable kingdom."],
      dtype=object), array([1,
       " Wilfully offensive and powered by a chest-thumping machismo, but it's good clean fun."],
      dtype=object), array([0,
       ' It would be difficult to imagine material more wrong for Spade than Lost & Found.'],
      dtype=object), array([0,
       " Despite the gusto its star brings to the role, it's hard to ride shotgun on Hector's voyage of discovery."],
      dtype=object), array([0,
       " If there was a good idea at the core of this film, it's been buried in an unsightly pile of flatulence jokes, dog-related bad puns and a ridiculous serial arson plot."],
      dtype=object), array([0,
       ' Gleeson goes the Hallmark Channel route, damaging an intermittently curious entry in the time travel subgenre.'],
      dty

In [85]:
def split_df(items,valid_pct = 0.2):
    cut = int(valid_pct * len(items))
    train_idxs = len(items) - cut
    print(train_idxs)
    train = items[:train_idxs]
    valid = items[train_idxs:]
    return train,valid
    
    

In [86]:
#export
class SplitData():
    def __init__(self, train, valid): self.train,self.valid = train,valid
        
    def __getattr__(self,k): return getattr(self.train,k)
    #This is needed if we want to pickle SplitData and be able to load it back without recursion errors
    def __setstate__(self,data:Any): self.__dict__.update(data) 
    
    @classmethod
    def split_by_func(cls, il, f):
        lists = map(il.new,f)
        return cls(*lists)

    def __repr__(self): return f'{self.__class__.__name__}\nTrain: {self.train}\nValid: {self.valid}\n'

In [66]:
from functools import partial

In [87]:
splitter = partial(split_df,il.items)

In [88]:
sd = SplitData.split_by_func(il,splitter(0.2))

384000


array([1,
       " Manakamana doesn't answer any questions, yet makes its point: Nepal, like the rest of our planet, is a picturesque but far from peaceable kingdom."],
      dtype=object)

In [None]:
def _label_from_list(self, labels:Iterator, label_cls:Callable=None, from_item_lists:bool=False, **kwargs)->'LabelList':
        "Label `self.items` with `labels`."
        if not from_item_lists:
            raise Exception("Your data isn't split, if you don't want a validation set, please use `split_none`.")
        labels = array(labels, dtype=object)
        label_cls = self.get_label_cls(labels, label_cls=label_cls, **kwargs)
        y = label_cls(labels, path=self.path, **kwargs)
        res = self._label_list(x=self, y=y)
        return res

    def label_from_df(self, cols:IntsOrStrs=1, label_cls:Callable=None, **kwargs):
        "Label `self.items` from the values in `cols` in `self.inner_df`."
        labels = self.inner_df.iloc[:,df_names_to_idx(cols, self.inner_df)]
        assert labels.isna().sum().sum() == 0, f"You have NaN values in column(s) {cols} of your dataframe, please fix it."
        if is_listy(cols) and len(cols) > 1 and (label_cls is None or label_cls == MultiCategoryList):
            new_kwargs,label_cls = dict(one_hot=True, classes= cols),MultiCategoryList
            kwargs = {**new_kwargs, **kwargs}
        return self._label_from_list(_maybe_squeeze(labels), label_cls=label_cls, **kwargs)

In [93]:
type(sd.train.items)

list

In [94]:
import numba