In [157]:
import torch
import time
import os
import sys
# Make the parent of the notebook's working directory importable so that the
# project packages imported below (data.*, networks.*) can be resolved.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from data.swarmset import ContinuingDataset, SwarmDataset
from networks.embedding import NoveltyEmbedding
from networks.archive import DataAggregationArchive
from networks.ensemble import Ensemble
import numpy as np
from scipy import ndimage
import random
from sklearn.manifold import TSNE
import matplotlib
import matplotlib.pyplot as plt

# Embed fonts as TrueType (Type 42) in PDF/PS figure exports.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

def CSVLineToVec(line):
    """Parse one comma-separated line of numbers into a float numpy array.

    Leading/trailing whitespace (including the line terminator) and any
    embedded newlines are removed before splitting on ",". Every field must be
    accepted by ``float`` (surrounding spaces are fine), otherwise ValueError
    is raised -- including for an empty line.

    Returns:
        numpy.ndarray of float64, one element per field.
    """
    fields = line.strip().replace("\n", "").split(",")
    return np.array([float(field) for field in fields])

In [293]:
# For single Sensor Baseline Model
# TRUTH_FILE = "validation-data-two-sensor.txt"
# TRUTH_FILE = "validation-data-baseline.txt"

# Evaluation configurations. Pick one via EXPERIMENT below instead of
# commenting/uncommenting assignments, so that label files, datasets and the
# ensemble checkpoint always come from the same experiment.
#   *_file : oracle label file name (read from ../data/oracle by the next cell)
#   *_dir  : SwarmDataset directory those labels index into
#   ensemble_path : checkpoint to load; None skips loading, so the freshly
#                   constructed ensemble (init="Random") is evaluated as-is.
EXPERIMENTS = {
    # RANDOM - Only Random Weight Initialization
    "RANDOM": {
        "validation_file": "validation-data-baseline.txt",
        "validation_dir": "../data/validation-easy-model",
        "testing_file": "original-hand-labeled-classes.txt",
        "testing_dir": "../data/full-mini",
        "ensemble_path": None,
    },
    # BASELINE - Pretraining only. No HIL.
    "BASELINE": {
        "validation_file": "validation-data-baseline.txt",
        "validation_dir": "../data/validation-easy-model",
        "testing_file": "original-hand-labeled-classes.txt",
        "testing_dir": "../data/full-mini",
        "ensemble_path": "../checkpoints/ensembles/01-20-23-baseline",
    },
    # BASELINE - Pretraining + HIL.
    # Alternative testing split: "heuristic-simple-model-classes.txt"
    # with "../data/filtered-full".
    "BASELINE_HIL": {
        "validation_file": "validation-data-baseline.txt",
        "validation_dir": "../data/validation-easy-model",
        "testing_file": "original-hand-labeled-classes.txt",
        "testing_dir": "../data/full-mini",
        "ensemble_path": "../checkpoints/ensembles/01-27-23-BLH-HIL-B",
    },
}
EXPERIMENT = "BASELINE_HIL"

_config = EXPERIMENTS[EXPERIMENT]
VALIDATION_FILE = _config["validation_file"]
VALIDATION_DATA = SwarmDataset(_config["validation_dir"], rank=0)
TESTING_FILE = _config["testing_file"]
TESTING_DATA = SwarmDataset(_config["testing_dir"], rank=0)
ENSEMBLE_PATH = _config["ensemble_path"]


OUT = "../data/oracle"  # directory holding the oracle label files


def load_class_labels(path):
    """Read an oracle label file into a dense list of integer classes.

    Each non-blank line is parsed with ``CSVLineToVec``. Field 0 is the sample
    index and field 1 is its class; any further fields are ignored, but they
    must still parse as numbers. Blank or whitespace-only lines (e.g. a
    trailing newline at EOF) are skipped.

    Returns:
        list of length equal to the number of non-blank lines, where position
        i holds the class of sample i. Positions not named by any line stay -1.

    Raises:
        IndexError: if a sample index is >= the number of labelled lines.
        ValueError: if a line contains a non-numeric field.
    """
    with open(path, "r") as f:
        rows = [CSVLineToVec(line) for line in f if line.strip()]
    labels = [-1] * len(rows)
    for row in rows:
        labels[int(row[0])] = int(row[1])
    return labels


validation_classes = load_class_labels(os.path.join(OUT, VALIDATION_FILE))
testing_classes = load_class_labels(os.path.join(OUT, TESTING_FILE))

# Pair every sample index with its class label: [(index, class), ...].
testing_set = list(enumerate(testing_classes))
validation_set = list(enumerate(validation_classes))

print(len(validation_set), len(testing_set))

200 1000


In [294]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ensemble = Ensemble(size=3, output_size=5, lr=15e-4, learning_decay=0.7, decay_step=1, threshold=9.0, weight_decay=1e-4, new_model=True, init="Random")
if ENSEMBLE_PATH:
    ensemble.load_ensemble(ENSEMBLE_PATH, full=True)
ensemble.eval_mode()


def embed_dataset(model, dataset, count, device):
    """Embed samples ``0 .. count-1`` of ``dataset`` with one ensemble member.

    Args:
        model: ensemble member, invoked as ``model.forward`` on a batch of one.
        dataset: indexable; ``dataset[j][0]`` is the image fed to the model.
        count: number of leading samples to embed.
        device: torch device the input tensor is created on.

    Returns:
        list where ``embeddings[j]`` is the numpy embedding of sample ``j``
        (batch axis squeezed away).
    """
    embeddings = []
    # Inference only: no_grad avoids building an autograd graph per sample.
    with torch.no_grad():
        for j in range(count):
            # Index the dataset once per sample and add a leading batch axis.
            image = np.expand_dims(dataset[j][0], axis=0)
            out = model.forward(torch.tensor(image, device=device, dtype=torch.float))
            embeddings.append(out.detach().cpu().squeeze(dim=0).numpy())
    return embeddings


def triplet_accuracy(embeddings, labelled, num_classes):
    """Count correctly ordered (anchor, positive, negative) triplets.

    A triplet is an anchor and a positive of the same class (distinct sample
    indices; each unordered pair is taken once, with the positive at the same
    or a later position in ``labelled``) plus any negative of a different
    class. It is "correct" when the anchor-positive Euclidean distance is
    strictly smaller than the anchor-negative distance. Distances equal
    ``np.linalg.norm(e[x] - e[z])`` up to floating-point summation order.

    Args:
        embeddings: sequence of equal-length vectors indexed by sample index.
        labelled: sequence of ``(sample_index, class)`` pairs.
        num_classes: per-class dicts get keys ``0 .. num_classes-1``.

    Returns:
        (correct, total, class_correct, class_total), where the dicts map an
        anchor class to its correct / total triplet counts.
    """
    emb = np.asarray(embeddings)
    ids = np.array([idx for idx, _ in labelled], dtype=int)
    labels = np.array([cls for _, cls in labelled], dtype=int)
    positions = np.arange(len(labelled))
    class_correct = {c: 0 for c in range(num_classes)}
    class_total = {c: 0 for c in range(num_classes)}
    for a, (anchor, anchor_class) in enumerate(labelled):
        # Class 0 is never used as an anchor (it may still act as a negative);
        # -1 marks an index missing from the label file and has no bucket.
        if anchor_class <= 0:
            continue
        # Distance from the anchor to every labelled sample, in `labelled` order.
        dists = np.linalg.norm(emb[anchor] - emb[ids], axis=1)
        negatives = (labels != anchor_class) & (ids != anchor)
        positives = np.flatnonzero((positions >= a) & (labels == anchor_class) & (ids != anchor))
        for p in positives:
            neg_dists = dists[negatives & (ids != ids[p])]
            class_correct[anchor_class] += int(np.count_nonzero(dists[p] < neg_dists))
            class_total[anchor_class] += int(neg_dists.size)
    return sum(class_correct.values()), sum(class_total.values()), class_correct, class_total


METRICS = ["VAL"]  # add "TRAIN" to also score the testing split

for metric in METRICS:
    a = []  # per-member accuracies (%) for this metric
    sampled_dataset = TESTING_DATA if metric == "TRAIN" else VALIDATION_DATA
    c_set = testing_set if metric == "TRAIN" else validation_set
    classes = testing_classes if metric == "TRAIN" else validation_classes

    print("=" * 20)
    print(f"{metric} results")
    print(f"Class Set of Size: {len(c_set)}")
    for i, member in enumerate(ensemble.ensemble):
        embedded_positions = embed_dataset(member, sampled_dataset, len(classes), device)
        correct, total, class_accuracy, class_counts = triplet_accuracy(
            embedded_positions, c_set, max(classes) + 1
        )

        print(class_counts)
        print(f"CLASS ACCURACY (Out of {total} triplets):")
        for class_value, n_triplets in class_counts.items():
            if class_value == 0:
                continue
            if n_triplets == 0:
                # e.g. a class with a single sample has no positive pair.
                print(f"{class_value}: n/a (no triplets)")
            else:
                print(f"{class_value}: {class_accuracy[class_value] * 100 / n_triplets}")
        acc = correct * 100 / total if total else float("nan")
        a.append(acc)
        print(f"Ensemble {i} ~ Accuracy: {acc}")

    # Average over however many members the ensemble actually has.
    print(f"Average: {sum(a) / len(a) if a else float('nan')}")
    print("=" * 20)

Adjusting learning rate of group 0 to 1.5000e-03.
Adjusting learning rate of group 0 to 1.5000e-03.
Adjusting learning rate of group 0 to 1.5000e-03.
VAL results
Class Set of Size: 200
{0: 0, 1: 60723, 2: 27846, 3: 6876, 4: 108558, 5: 12408}
CLASS ACCURACY (Out of 216411 triplets):
1: 95.60463086474647
2: 73.90648567119155
3: 95.08435136707388
4: 90.5396193739752
5: 86.78272082527401
Ensemble 0 ~ Accuracy: 89.74959683195401
{0: 0, 1: 60723, 2: 27846, 3: 6876, 4: 108558, 5: 12408}
CLASS ACCURACY (Out of 216411 triplets):
1: 80.94791100571447
2: 67.3202614379085
3: 80.36649214659685
4: 67.37964958823855
5: 75.99129593810444
Ensemble 1 ~ Accuracy: 72.0855224549584
{0: 0, 1: 60723, 2: 27846, 3: 6876, 4: 108558, 5: 12408}
CLASS ACCURACY (Out of 216411 triplets):
1: 94.00885990481366
2: 75.88522588522588
3: 88.91797556719023
4: 89.39092466699829
5: 92.69825918762089
Ensemble 2 ~ Accuracy: 89.12347339090897
Average: 83.65286422594046
