# Import libraries

In [None]:
import os
import torch
from pathlib import Path
import pandas
import pdb; 
%load_ext autoreload
%autoreload 2

# Specify an image folder

In [None]:
# The images are expected three directory levels above the current working
# directory, under Data/houses/resize256jpg.
IMAGE_FOLDER = Path(os.getcwd()).parent.parent.parent.joinpath('Data', 'houses/resize256jpg')

# Sanity check: fail right away when the image directory is missing.
assert os.path.isdir(IMAGE_FOLDER), f"{IMAGE_FOLDER} is not an existing directory."

# Specify columns

In [None]:
# Sentinel labels that are added to every attribute's label list below.
UNLABELED_TAG = "UNLABELED"
IGNORE_TAG = "IGNORE"

# Attribute (CSV column) name -> list of labels that column may take.
ATTRIBUTES = {
    "house_color": [
        "WhiteGreyHouse",
        "BrownHouse",
        "BrickHouse",
        "RedPinkOrangeHouse",
        "BlueHouse",
        "MixedHouse",
    ],
    "fence_type": [
        "WhitePicketFence",
        "BlackFence",
        "NoFence",
        "ChainFence",
        "BrownFence",
    ],
    "steps_up": [
        "NoSteps",
        "Steps",
        "LargePorch",
    ],
}
# Each label list ends with IGNORE followed by UNLABELED, in that order.
for key in ATTRIBUTES:
    ATTRIBUTES[key].extend((IGNORE_TAG, UNLABELED_TAG))

# Specify a CSV (existing or not) and Image Column

In [None]:
# Set OVERWRITE_CSV to True to regenerate the CSV even when it already exists.
OVERWRITE_CSV = False
CSV_PATH = Path('combined128.csv').absolute()
# Column holding each image's path relative to IMAGE_FOLDER, e.g. train/cat/10.jpg
CSV_IMAGE_COLUMN = 'image_path'

# Sanity check: the directory that holds the CSV must already exist.
assert CSV_PATH.parent.is_dir()
# Build the CSV when an overwrite was requested or no file is present yet.
if OVERWRITE_CSV or not CSV_PATH.is_file():
    print("OVERWRITING CSV ")
    from lib.prep import create_csv_with_image_paths
    output = create_csv_with_image_paths(
        CSV_PATH,
        CSV_IMAGE_COLUMN,
        IMAGE_FOLDER,
        list(ATTRIBUTES),
    )

# Load Widget
The widget loads a small subset of images (about 4) for the user to label.

In [None]:
from lib.widgets import MultilabelerWidget

# Labeling widget wired to the CSV, the image folder and the attribute labels.
mlw = MultilabelerWidget(
    csv_path=CSV_PATH,
    image_folder=IMAGE_FOLDER,
    image_column=CSV_IMAGE_COLUMN,
    attributes=ATTRIBUTES,
    width=300,
)

# Load labeled images into dataset

In [None]:
# Every attribute column is read as a pandas 'category' dtype.
dtype_dict = dict.fromkeys(ATTRIBUTES, 'category')

df = pandas.read_csv(CSV_PATH, dtype=dtype_dict)

In [None]:
# Show the table loaded from CSV_PATH as the cell output.
df

In [None]:
import fastai
from fastai import *
from fastai.vision import *
from fastai.vision.data import *

In [None]:
# Keep the rows whose 'steps_up' value is neither UNLABELED_TAG nor IGNORE_TAG.
porches_df = df[(df['steps_up'] != UNLABELED_TAG) & (df['steps_up'] != IGNORE_TAG)]

In [None]:
# Show the filtered 'steps_up' subset as the cell output.
porches_df

The fastai data-loading code below is adapted from:
https://gist.github.com/yang-zhang/ec071ae4775c2125595fd80f40efb0d6#file-multi-face-ipynb

In [None]:
# Build the fastai data pipeline (items -> split -> labels -> transforms ->
# databunch) from the rows of porches_df.
from lib.prep import LabelCls
# Image paths are read from the CSV_IMAGE_COLUMN column, under IMAGE_FOLDER.
il = ImageList.from_df(df=porches_df, path='/', folder=IMAGE_FOLDER, cols = CSV_IMAGE_COLUMN)
# NOTE(review): presumably a random 40% validation split with seed 2 — confirm against the fastai v1 API.
sil = il.split_by_rand_pct(0.4,2)
# Labels come from all attribute columns and are handled by the project's LabelCls.
lsil = sil.label_from_df(cols=list(ATTRIBUTES), label_cls=LabelCls, ATTRIBUTES=ATTRIBUTES)

# Augmentations: flip_vert disabled, max_rotate=10, no extra transforms.
tfms = get_transforms(flip_vert=False, max_rotate= 10,xtra_tfms=[])
# Clear use_on_y on every transform (presumably so none is applied to the labels — confirm).
for tfm in tfms:
    for subtfm in tfm:
        subtfm.use_on_y = False
        
# NOTE(review): the return value is discarded, so this relies on transform() updating lsil in place — confirm.
lsil.transform(tfms, tfm_y=False)
# Final databunch: bs=10, num_workers=4.
blsil = lsil.databunch(num_workers=4, bs=10)

In [None]:
# Show the databunch summary as the cell output.
blsil

In [None]:
from torchvision.utils import make_grid, save_image

# Take one batch from the databunch and render its inputs as a single grid
# image; the final expression is what the cell displays.
x, y = blsil.one_batch()
batch_grid = make_grid(x)
Image(batch_grid)

# Specify a save location for the classifier model

In [None]:
# The classifier is exported next to the notebook, in the current working directory.
CLASSIFIER_EXPORT = Path(os.getcwd()).joinpath('classifier_export.pt')

In [None]:
def myloss(input,target):
    """Return the sum of one loss term per attribute for a batch.

    ``input`` (model outputs) and ``target`` (labels) are sliced column-wise
    with the ``(start, end)`` pairs in ``blsil.y.attribute_label_endpoints``,
    one pair per attribute. Each target segment must contain exactly one 1 per
    row; this is enforced by an assertion.

    Rows whose segment has a 1 in either of its last two columns are masked:
    their logits and their target are zeroed before the loss is computed. The
    cell that builds ``ATTRIBUTES`` appends IGNORE and UNLABELED last, so those
    are presumably the two masked labels — confirm against ``LabelCls``.

    NOTE(review): masked rows are still part of the batch mean. With all-zero
    logits and target index 0 they add a constant to the loss but pass no
    gradient back to ``input``.

    Args:
        input: 2-D tensor of model outputs, indexed as ``[row, label_column]``.
        target: 2-D tensor of 0/1 labels with the same column layout.

    Returns:
        The sum of the per-attribute losses.

    Raises:
        AssertionError: if a target segment does not hold exactly one 1 per row.
    """
    target = target.long()
    losses = 0
    for segment_start, segment_end in blsil.y.attribute_label_endpoints:
        attribute_input = input[:, segment_start:segment_end]
        attribute_target = target[:, segment_start:segment_end]
        # Every row must carry exactly one active label inside this segment.
        assert(torch.sum(attribute_target)==attribute_target.shape[0]), attribute_target
        # mask is 1 for rows to keep and 0 for rows whose active label sits in
        # one of the segment's last two columns.
        mask = (1-attribute_target[:,-1:])*(1-attribute_target[:,-2:-1])
        masked_target = attribute_target * mask
        masked_input = attribute_input * mask.float()
        assert(torch.sum(masked_target)<=masked_target.shape[0]),masked_target
        if segment_end - segment_start > 1:
            # cross_entropy takes class indices, so collapse the one-hot rows.
            attribute_loss = F.cross_entropy(masked_input, masked_target.argmax(dim=1))
        else:
            # NOTE(review): single-column segments do not occur with the
            # ATTRIBUTES defined in this notebook. The slice above is already
            # 2-D, so unsqueeze(1) makes the target 3-D — verify shapes and
            # dtypes before relying on this branch.
            attribute_loss = F.l1_loss(attribute_input, attribute_target.unsqueeze(1))
        losses += attribute_loss
    return losses
        

In [None]:
class Attribute_accuracy_metric:
    """Accuracy over the label columns that belong to a single attribute.

    The instance is callable as ``metric(input, targs)``. Both tensors are
    sliced column-wise to ``[segment_start:segment_end]``, the same way
    ``myloss`` slices them, and the predicted class (argmax of the outputs) is
    compared with the labeled class (argmax of the one-hot target).

    NOTE(review): rows labeled IGNORE/UNLABELED are scored like any other
    class here; they are not masked out the way ``myloss`` masks them.
    """

    def __init__(self,attribute_idx):
        """Bind the metric to attribute number ``attribute_idx``.

        The column range is read from ``blsil.y.attribute_label_endpoints`` and
        the name from the key order of ``ATTRIBUTES``.
        """
        print(attribute_idx)
        self.segment_start, self.segment_end = blsil.y.attribute_label_endpoints[attribute_idx]
        self.func = self.__call__
        self.name = list(ATTRIBUTES)[attribute_idx]
        # Also expose the attribute name as __name__; fastai presumably uses
        # it to title the metric column instead of '__call__' — confirm.
        self.__name__ = self.name

    def __call__(self, input, targs):
        """Return the fraction of rows whose predicted class matches the target.

        Args:
            input: 2-D tensor of model outputs, indexed ``[row, label_column]``.
            targs: 2-D one-hot target tensor with the same column layout.

        Returns:
            A scalar float tensor holding the accuracy for this attribute.
        """
        # Slice columns, not rows: the segment is a range of label columns.
        input_segment = input[:, self.segment_start:self.segment_end]
        target_segment = targs[:, self.segment_start:self.segment_end]
        predicted = input_segment.argmax(dim=1)
        expected = target_segment.argmax(dim=1)
        return (predicted == expected).float().mean()
        

In [None]:
# Build one accuracy metric per attribute segment of the label vector.
metrics=[Attribute_accuracy_metric(idx) for idx in range(len(blsil.y.attribute_label_endpoints))]
# NOTE(review): the list built above is discarded right here, so the learner
# is created with no metrics.
metrics = []
# Learner over the databunch using models.resnet18, pretrained=True and the ShowGraph callback.
learn = cnn_learner(blsil, models.resnet18, metrics=metrics, pretrained=True, callback_fns=ShowGraph)
# Swap in the per-attribute loss defined by myloss.
learn.loss_func = myloss

In [None]:
# Train the learner; 4 is presumably the number of epochs — confirm against the fastai v1 API.
learn.fit(4)

In [None]:
def train_model(sz, bs, lr):
    """Run a frozen and then an unfrozen one-cycle training pass on ``learn``, then save it.

    NOTE(review): this function calls ``get_data`` and reads ``target``;
    neither is defined in the visible part of this notebook. Unless both exist
    elsewhere, calling it raises NameError — confirm before use.

    Args:
        sz: forwarded to ``get_data`` as its first argument.
        bs: forwarded to ``get_data`` as its second argument.
        lr: learning rate; used as ``slice(lr)`` for the frozen pass and as
            ``slice(lr/20, lr/2)`` for the unfrozen pass.
    """
    # Replace the learner's data with whatever get_data(sz, bs) returns.
    learn.data=get_data(sz, bs)
    # First pass: 5 one-cycle epochs after learn.freeze().
    learn.freeze()
    learn.fit_one_cycle(5, slice(lr))
    # Second pass: 5 one-cycle epochs after learn.unfreeze(), with pct_start=0.1.
    learn.unfreeze()
    learn.fit_one_cycle(5, slice(lr/20, lr/2), pct_start=0.1)
    # Save under the name given by the module-level `target` (see the note above).
    learn.save(f"{target}")



In [None]:
from lib.widgets import MultilabelerActiveLearningWidget

In [None]:
from lib.widgets import MultilabelerActiveLearningWidget

# Active-learning variant of the labeling widget: it also receives the learner,
# the classifier export path and the tag used for unlabeled rows.
mlw = MultilabelerActiveLearningWidget(
    learner=learn,
    classifier_export=CLASSIFIER_EXPORT,
    csv_path=CSV_PATH,
    image_folder=IMAGE_FOLDER,
    image_column=CSV_IMAGE_COLUMN,
    attributes=ATTRIBUTES,
    unlabeled_tag=UNLABELED_TAG,
    width=600,
)

# Train VAE

# Load classifier / train new classifier

# Load classifier + VAE + modification vector for each column