In [1]:
%load_ext autoreload
%autoreload 2

In [None]:

import numpy as np
import torch
from tqdm import tqdm

import ft_utils.AverageMethod
import ft_utils.GaussianMethod
import ft_utils.InterfaceGANMethod
from ft_utils.BatchImageClassifier import BatchImageClassifier
from ft_utils.BatchImageGenerator import BatchImageGenerator
from ft_utils.utils import BATCH_SIZE

In [None]:
# Score the first NUM_TRAIN_IMAGES generated images against the text attributes and
# collect their W latents; (latent_vectors, scores) is the training set for every method.
attributes = ["blonde", "not blonde"]
NUM_TRAIN_IMAGES = 200_000
num_train_batches = round(NUM_TRAIN_IMAGES / BATCH_SIZE)

batch_classifier = BatchImageClassifier("out_batch_transfer")
batch_generator = BatchImageGenerator("out_batch_transfer", True)

text_features = batch_classifier.tokenize_attributes(attributes)

scores = []
latent_vectors_list = []

print("Scoring images")
for train_batch_idx in tqdm(range(num_train_batches), desc="Scoring"):
    start = train_batch_idx * BATCH_SIZE
    probs = batch_classifier.classify_from_batch(start, BATCH_SIZE, text_features)
    # t[0, 0]: presumably the probability of attributes[0] ("blonde") for one image — confirm
    # against BatchImageClassifier.classify_from_batch.
    scores.extend(t[0, 0].item() for t in probs)
    latent_vectors_list.append(batch_generator.load_w_batch(start, BATCH_SIZE))

latent_vectors = np.concatenate(latent_vectors_list, axis=0)
scores = np.array(scores).reshape(-1, 1)

# Every latent row must pair with exactly one score, otherwise training silently misaligns.
assert latent_vectors.shape[0] == scores.shape[0], (
    f"{latent_vectors.shape[0]} latent vectors vs {scores.shape[0]} scores"
)

In [None]:
# The ft_utils method modules are imported in the import cell; `%autoreload 2`
# (first cell) reloads modified modules before each cell runs, so no explicit
# importlib.reload is needed here.
methods = {
    "InterfaceGAN": ft_utils.InterfaceGANMethod.InterfaceGANMethod(),
    "Average": ft_utils.AverageMethod.AverageMethod(),
    "Gaussian": ft_utils.GaussianMethod.GaussianMethod(),
}

print("Starting training...")
for method_name, method in methods.items():
    print(f"{method_name}: Training...")
    # All methods are fit on the same (latent_vectors, scores) pairs from the scoring cell.
    method.train(latent_vectors, scores)

print("Finished training")

In [None]:
print("Measuring...")

# One list of classifier outputs per trained method, plus the unmodified "Original"
# baseline. Keys are derived from `methods` (same order as before, "Original" last)
# so adding a method there cannot cause a KeyError in measure_seed.
method_scores = {
    name: {"openclip_scores": []}
    for name in [*methods, "Original"]
}

def _w_to_ws(w_vector, num_ws=14, device="mps"):
    """Broadcast one W latent to a (num_ws, D) float32 tensor on `device`.

    The input is flattened first, so a (D,) vector and a (1, D) row both yield
    num_ws identical copies of the same latent.
    """
    w_flat = np.asarray(w_vector, dtype=np.float32).reshape(-1)
    ws = np.tile(w_flat, (num_ws, 1))  # same latent fed to every generator layer
    return torch.from_numpy(ws).to(device)


def _score_w_vector(w_vector, num_ws=14, device="mps"):
    """Generate the image for `w_vector` and return the classifier output as a NumPy array."""
    ws = _w_to_ws(w_vector, num_ws, device)
    # Evaluation only: skip autograd bookkeeping to save device memory.
    with torch.no_grad():
        image = batch_generator.generate_from_w_vec(ws, filename=None)
        scores = batch_classifier.classify_image_vec(image, text_features)
    # .cpu() is required before .numpy() when the tensor lives on an MPS/CUDA device
    # (and is a no-op for CPU tensors).
    return scores.detach().cpu().numpy()


def measure_seed(initial_w_vector, num_ws=14, device="mps"):
    """Score one starting latent and each method's latent walk from it.

    Appends one classifier-output array to method_scores["Original"]["openclip_scores"]
    and one to method_scores[name]["openclip_scores"] for every entry in `methods`.

    Args:
        initial_w_vector: a single W latent; the evaluation loop passes a one-row
            slice of a latent batch.
        num_ws: how many copies of the latent are stacked for the generator
            (presumably the generator's number of style inputs — 14 here).
        device: torch device the generator/classifier run on.
    """
    method_scores["Original"]["openclip_scores"].append(
        _score_w_vector(initial_w_vector, num_ws, device)
    )

    for method_name, method in methods.items():
        resulting_w_vector = method.latent_walk(initial_w_vector)
        method_scores[method_name]["openclip_scores"].append(
            _score_w_vector(resulting_w_vector, num_ws, device)
        )

# Evaluate on latents the methods were not trained on. The scoring cell read batches
# [0, round(200_000 / BATCH_SIZE)); keep this bound in sync with it.
# NOTE(review): the `1 +` leaves one batch unused — presumably a deliberate gap; confirm.
batch_idx_start = 1 + round(200_000 / BATCH_SIZE)
num_samples = 100

current_sample_group = None
for sample_idx in tqdm(range(num_samples), desc="Measuring"):
    # Walk evaluation batches in order. Using BATCH_SIZE (not a hard-coded 64)
    # guarantees seed_idx always indexes inside the loaded batch.
    batch_offset, seed_idx = divmod(sample_idx, BATCH_SIZE)
    if seed_idx == 0:
        batch_idx = batch_idx_start + batch_offset
        current_sample_group = batch_generator.load_w_batch(batch_idx * BATCH_SIZE, BATCH_SIZE)
    # Slice rather than index so the latent keeps its leading batch axis.
    measure_seed(current_sample_group[seed_idx:seed_idx + 1])

# score[0, 0] is presumably the probability of attributes[0] for one image — confirm.
og_scores = method_scores["Original"]["openclip_scores"]
og_avg = np.average([score[0, 0] for score in og_scores])

for method_name in methods:
    # Distinct name so this cell does not overwrite the global `scores` training targets.
    method_openclip_scores = method_scores[method_name]["openclip_scores"]
    avg = np.average([score[0, 0] for score in method_openclip_scores])
    # Absolute difference of mean probabilities, shown in percentage points.
    print(f"{method_name}: score improvement {100*(avg - og_avg):+.1f}%")
