In [None]:
# Plot precision-recall (PR) curves and generate the LaTeX for displaying the results

In [None]:
import torch
from torch import nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset

from PIL import Image
from pathlib import Path
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve, precision_recall_curve


import json

# Directory of image crops; AridCropDataset expects one sub-directory per object instance.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
root = Path('/home/justin/Desktop/arid-crops')
# Directory holding the per-object stats-{obj}.json, {obj}.pth and results-{obj}.json files.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
stats_root = Path('/home/justin/Desktop/thesis-results')

In [None]:
class AridCropDataset(Dataset):
    """One-vs-rest image dataset built from a directory of per-object crops.

    ``crop_root`` is expected to contain one sub-directory per object instance.
    Every entry inside a sub-directory whose name starts with ``object_name``
    is labelled 1; entries of all other sub-directories are labelled 0.
    Positives are stored before negatives in ``self.items``.

    Note that the match is a plain prefix test, so ``'apple_1'`` would also
    match a directory called ``'apple_10'``.

    Args:
        crop_root: Path (str or ``Path``) of the directory holding the
            per-object sub-directories.
        object_name: Name prefix identifying the positive object instance.
        transform: Callable applied to the RGB ``PIL.Image`` of each sample.
    """

    def __init__(self, crop_root, object_name, transform):
        self.crop_root = Path(crop_root)
        self.transform = transform

        pos = []
        neg = []
        # iterdir() yields entries in a filesystem-dependent order; sorting
        # makes the sample order reproducible across machines and runs.
        for obj in sorted(self.crop_root.iterdir()):
            if not obj.is_dir():
                continue
            label = 1 if obj.stem.startswith(object_name) else 0
            bucket = pos if label == 1 else neg
            bucket.extend((img_path, label) for img_path in sorted(obj.iterdir()))

        # Positives first, then negatives (same layout as before).
        self.items = pos + neg

    def __len__(self):
        """Return the total number of (path, label) samples."""
        return len(self.items)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(transform(rgb_image), label)``."""
        path, lbl = self.items[idx]
        # Image.open is lazy and keeps the file open; the context manager
        # closes the handle once convert() has produced an in-memory copy.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        sample = self.transform(rgb)
        return sample, lbl

In [None]:
# Maps each evaluated object instance id (the key used in the stats-{obj}.json,
# {obj}.pth and results-{obj}.json file names) to a short text description.
# NOTE(review): the descriptions are not read anywhere in the visible cells;
# presumably they are the search terms used to collect training images — confirm.
target_objects = {
    'apple_0': 'green apple fruit',
    'apple_1': 'yellow apple fruit',
    'apple_2': 'red apple fruit',
    'banana_0': 'green banana', 
    'banana_1': 'yellow banana',
    'banana_2': 'yellow banana',
    'bell_pepper_0': 'green bell pepper', 
    'bell_pepper_1': 'red bell pepper',
    'bell_pepper_2': 'yellow bell pepper',
    'bowl_0': 'white bowl',
    'bowl_1': 'blue bowl',
    'bowl_2': 'white bowl',
    'cell_phone_0': 'black smartphone',
    'cell_phone_1': 'white smartphone',
    'cell_phone_2': 'black white smartphone',
    'coffee_mug_0': 'blue coffee mug',
    'coffee_mug_1': 'black white coffee mug',
    'coffee_mug_2': 'white coffee mug',
    'keyboard_0': 'black computer keyboard',
    'keyboard_1': 'white computer keyboard',
    'keyboard_2': 'black computer keyboard',
    'orange_0': 'orange fruit',
    'orange_1': 'orange fruit',
    'orange_2': 'orange fruit',
    'scissors_0': 'green handle scissors',
    'scissors_1': 'purple handle scissors',
    'scissors_2': 'black handle scissors',
    'water_bottle_0': 'empty water bottle',
    'water_bottle_1': 'empty water bottle',
    'water_bottle_2': 'empty water bottle',
    'toothbrush_0': 'toothbrush',
    'toothbrush_2': 'green white toothbrush',
    'ball_0': 'blue ball',
    'binder_0': 'green three ring binder',
    'calculator_0': 'ti calculator',
    'camera_0': 'silver digital camera',
    'cap_2': 'black baseball hat',
    'cereal_box_0': 'corn flakes cereal box',
    'comb_0': 'brown hair comb',
    'dry_battery_0': 'double a duracell battery',
    'flashlight_2': 'black flashlight',
    'greens_0': 'leafy green vegetable',
    'lemon_0': 'yellow lemon',
    'lightbulb_0': 'white lightbulb',
    'mushroom_0': 'white mushroom',
    'pliers_1': 'needlenose pliers',
    'shampoo_0': 'grey shampoo bottle',
    'toothpaste_0': 'toothpaste tube',
    'stapler_0': 'grey stapler',
    'potato_0': 'potato',
    'sponge_0': 'yellow sponge',
    'tomato_0': 'red tomato',
}
# Sanity check: show how many target objects are configured.
print(len(target_objects))

In [None]:
def get_tranform(mean, std, size=(128, 128)):
    """Build the preprocessing pipeline applied to every image crop.

    The (misspelled) function name is kept so existing callers keep working.

    Args:
        mean: Per-channel means forwarded to ``transforms.Normalize``.
        std: Per-channel standard deviations forwarded to
            ``transforms.Normalize``.
        size: Target size forwarded to ``transforms.Resize``. Defaults to
            ``(128, 128)``, the value that was previously hard-coded.

    Returns:
        A ``transforms.Compose`` that resizes, converts to a tensor and
        normalizes, in that order.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

In [None]:
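# This evaluation loop was run on FloydHub because of its long compute time. It is kept here,
# commented out, for reference: it writes the results-{obj}.json files that the cells below read.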


# import random
# for obj, search_term in target_objects.items():
#     print(obj)
#     with open(Path(stats_root / f'stats-{obj}.json')) as outfile:
#         stats = json.load(outfile)

#     model = models.resnet152(pretrained=True)
#     model.fc = nn.Linear(in_features=2048, out_features=1, bias=True)
#     device = torch.device('cpu')
#     model.load_state_dict(torch.load(Path(stats_root / f'{obj}.pth'), map_location=device))
#     dset = AridCropDataset(root, obj, get_tranform(stats['mean'], stats['std']))
#     loader = DataLoader(dset, batch_size=600, shuffle=False, num_workers=8)
    
#     with torch.no_grad():
#         model.eval()
#         labels = []
#         preds = []
        
#         for inputs, lbls in tqdm(loader):
#             outputs = model(inputs)
#             labels.extend([l.item() for l in lbls])
#             preds.extend([o.item() for o in outputs.sigmoid()])

#         arid_f1 = f1_score(labels, np.array([1 if p > 0.5 else 0 for p in preds]))
#         roc_auc = roc_auc_score(labels, np.array(preds))
#         precision, recall, thresholds = precision_recall_curve(labels, np.array(preds))
#         acc = accuracy_score(labels, np.array([1 if p > 0.5 else 0 for p in preds]))
#         results = {
#             'f1': arid_f1,
#             'roc_auc': roc_auc,
#             'precision': precision.tolist(),
#             'recall': recall.tolist(),
#             'accuracy': acc,
#         }
#         with open(Path(stats_root / f'results-{obj}.json'), 'w') as outfile:
#             json.dump(results, outfile)

In [None]:
# Load the saved per-object evaluation results and split the objects into
# fixed-size groups (one group per PR plot / LaTeX table).
GROUP_SIZE = 5  # number of objects per group

# d maps object id -> {'p': precision list, 'r': recall list, 'f1': F1 score}
d = {}
for obj in target_objects:
    with open(stats_root / f'results-{obj}.json') as infile:
        stats = json.load(infile)
    d[obj] = {
        'p': stats['precision'],
        'r': stats['recall'],
        'f1': stats['f1'],
    }

# Alphabetically sorted object ids.
o = sorted(target_objects)

# Chunk into groups of GROUP_SIZE; the last group holds the remainder. This
# works for any number of objects rather than assuming exactly 52.
groups = [o[start:start + GROUP_SIZE] for start in range(0, len(o), GROUP_SIZE)]

# plot() is defined in a later cell; run that cell before enabling this loop.
# for idx, g in enumerate(groups):
#     plot(d, g, f'pr-curve-{idx}.png')






In [None]:
def plot(data, groups, f_name, w=10, t=90):
    """Draw the precision-recall curves of several objects and save the figure.

    Args:
        data: Mapping from object id to a dict with keys ``'r'`` (recall
            values) and ``'p'`` (precision values).
        groups: Iterable of object ids to draw; also used as legend labels.
        f_name: Path of the image file the figure is saved to.
        w: Unused; kept for backward compatibility.
        t: Unused; kept for backward compatibility.
    """
    # Local import: matplotlib is only needed by this function.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(25, 15))
    font = {'family': 'DejaVu Sans',
            'weight': 'bold',
            'size': 30}

    plt.rc('font', **font)

    for g in groups:
        # Read from the `data` argument (previously the global `d` was used,
        # silently ignoring whatever the caller passed in).
        plt.plot(data[g]['r'], data[g]['p'], linewidth=4)

    plt.ylabel('Precision')
    plt.xlabel('Recall')

    plt.legend(groups)
    plt.tight_layout()
    plt.savefig(f_name)

In [None]:
def max_f1(ps, rs):
    """Return the largest F1 score over paired precision/recall values.

    ``ps`` and ``rs`` are walked in lockstep; pairs whose sum is zero are
    skipped (their F1 is undefined). Returns 0 when no pair yields a
    positive score.
    """
    scores = [
        (2 * precision * recall) / (precision + recall)
        for precision, recall in zip(ps, rs)
        if precision + recall != 0
    ]
    # Seeding with 0 gives the floor value and the empty-input result.
    return max([0] + scores)

In [None]:
# Spot-check: best F1 along the saved PR curve of a single object ('ball_0').
max_f1(d['ball_0']['p'], d['ball_0']['r'])

In [None]:
# Print one LaTeX tabular per group: a row of bold object names followed by a
# row with each object's best F1 score (rounded to 3 decimals).
rs = {}  # NOTE(review): never read in the visible cells; kept in case a later cell uses it
for g in groups:
    # One ruled, centred column per object, so a shorter final group gets a
    # column spec that matches its number of cells.
    col_spec = '| ' + 'c | ' * len(g)
    names = []
    scores = []
    # `obj_name` (not `o`) so the sorted id list `o` built earlier is not overwritten.
    for obj_name in g:
        curve = d[obj_name]
        best_f1 = round(max_f1(curve['p'], curve['r']), 3)
        # Raw string for the escaped underscore: '\_' in a normal literal is
        # an invalid escape sequence and triggers a warning.
        names.append(r'\textbf{' + obj_name.replace('_', r'\_') + r'}')
        scores.append(f'{best_f1}')
    print(r'\begin{tabular}{' + col_spec + r'}\hline')
    print(f'{"&".join(names)} \\\\')
    print(r'\hline')
    print(f'{"&".join(scores)} \\\\')
    print(r'\hline \end{tabular}')
    print(r'\\[1.0ex]')

    