In [None]:
# Most commonly used
import sys
import os
import json
import pickle
import math
from collections import Counter, defaultdict
from functools import partial
from tqdm import tqdm, trange
from colors import blue, red, green, cyan

# Numerical computation
import numpy as np
import torch
import torch.nn.functional as F

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
sns.set_context("talk")

sys.path.append('ANONYMOUS_ROOTDIR/develop/open-world/')
from utils import svd, reduce_and_visualize, load_clip, encode_clip, encode_clip_classification, train_clip_toy, ce_loss, uniform_loss, dual_ce_loss, simple_ce_loss
from datasets import ImageCaptionDataset, ClassificationDataset


def evaluate_retrieval(image_features, text_features):
    """Compute paired image<->text retrieval recall at K = 1, 5 and 10.

    Row ``i`` of ``image_features`` and row ``i`` of ``text_features`` are
    treated as a matching pair. Similarity is the plain dot product, so any
    normalisation has to be applied by the caller.

    Args:
        image_features: 2-D tensor with one row per image.
        text_features: 2-D tensor with one row per text and the same feature
            width as ``image_features``.

    Returns:
        dict with, for each K, ``'Text R@K'`` (fraction of images whose own
        text index is among the K most similar texts) and ``'Image R@K'``
        (fraction of texts whose own image index is among the K most similar
        images).
    """
    sim = image_features @ text_features.T
    n_images, n_texts = sim.shape

    # The rankings do not depend on K, so sort once instead of once per K.
    text_ranking = sim.argsort(dim=-1)   # per image: text indices, ascending similarity
    image_ranking = sim.argsort(dim=0)   # per text: image indices, ascending similarity

    metrics = {}
    for K in [1, 5, 10]:
        # The last K entries of an ascending argsort are the K best matches.
        text_r = np.mean([i in text_ranking[i, -K:] for i in range(n_images)])
        # Walk over texts (columns) here; using the image count would index
        # out of range whenever there are more images than texts.
        image_r = np.mean([j in image_ranking[-K:, j] for j in range(n_texts)])

        metrics[f'Text R@{K}'] = text_r
        metrics[f'Image R@{K}'] = image_r
    return metrics


def evaluate_classification(image_features, text_features, labels):
    """Compute top-K hit rates (K = 1, 5, 10) for similarity-based classification.

    Each image is scored against every row of ``text_features`` with a dot
    product; an image counts as a hit at K when ``labels[i]`` is one of the K
    highest-scoring text-row indices.

    Args:
        image_features: 2-D tensor with one row per image.
        text_features: 2-D tensor with one row per candidate class.
        labels: sequence indexable by image position; ``labels[i]`` is the
            index of the correct row of ``text_features`` for image ``i``.

    Returns:
        dict mapping ``'Hit@K'`` to the fraction of images that are hits.
    """
    sim = image_features @ text_features.T
    # The ranking is independent of K; sort once rather than once per K.
    class_ranking = sim.argsort(dim=-1)
    n_images = len(class_ranking)

    metrics = {}
    for K in [1, 5, 10]:
        # The last K entries of an ascending argsort are the top-K classes.
        hits = [labels[i] in class_ranking[i, -K:] for i in range(n_images)]
        metrics[f'Hit@{K}'] = np.mean(hits)
    return metrics


def evaluate_binary_classification(image_features, text_features, labels):
    """Score a two-way similarity classifier with ROC-AUC.

    Dot-product similarities are multiplied by a fixed factor of 100,
    soft-maxed across the text axis, and the probability in column 1 is used
    as the score passed to ``roc_auc_score`` together with ``labels``.

    Args:
        image_features: 2-D tensor with one row per image.
        text_features: 2-D tensor with at least two rows; column 1 of the
            softmax is the score that gets evaluated.
        labels: ground-truth targets, forwarded unchanged to ``roc_auc_score``.

    Returns:
        dict with the single key ``'ROC-AUC'``.
    """
    from sklearn.metrics import roc_auc_score

    scaled_logits = (image_features @ text_features.T) * 100
    class_probs = F.softmax(scaled_logits, dim=-1)
    positive_scores = class_probs[:, 1]
    return {'ROC-AUC': roc_auc_score(labels, positive_scores)}


def move_features(image_features, text_features, evaluate_func, direction_vec=None):
    """Sweep both feature sets along a direction and evaluate at every step.

    For each ``delta`` in ``np.arange(-5, 5, 0.25)`` the text features are
    moved by ``+0.5 * delta * direction_vec`` and the image features by
    ``-0.5 * delta * direction_vec``; every row is then rescaled to unit L2
    norm before evaluation. The inputs themselves are not modified.

    Args:
        image_features: 2-D tensor with one row per image.
        text_features: 2-D tensor with one row per text.
        evaluate_func: callable ``(image_features, text_features) -> metrics``.
        direction_vec: direction to move along. When ``None``, the unit vector
            pointing from the text-feature mean to the image-feature mean.

    Returns:
        dict mapping each ``delta`` to ``(metrics, gap_distance, preds)``:
        the value returned by ``evaluate_func``, the norm of the difference
        between the two shifted feature means, and the arg-max text index for
        every image.
    """
    if direction_vec is None:
        # Default direction: normalised difference of the two feature means.
        centroid_diff = image_features.mean(axis=0) - text_features.mean(axis=0)
        direction_vec = centroid_diff / centroid_diff.norm()

    results = {}
    for delta in np.arange(-5, 5, 0.25):
        # Each modality covers half of the total displacement, in opposite directions.
        offset = 0.5 * delta * direction_vec

        shifted_text = text_features + offset
        shifted_text /= shifted_text.norm(dim=-1, keepdim=True)

        shifted_image = image_features - offset
        shifted_image /= shifted_image.norm(dim=-1, keepdim=True)

        similarity = shifted_image @ shifted_text.T
        preds = similarity.argmax(dim=-1)

        centroid_gap = shifted_text.mean(axis=0) - shifted_image.mean(axis=0)
        gap_distance = centroid_gap.norm().item()

        metrics = evaluate_func(shifted_image, shifted_text)
        results[delta] = (metrics, gap_distance, preds)

        print(delta, metrics, gap_distance)
    return results


def move_features_along_hypersphere(image_features, text_features, evaluate_func):
    """Placeholder: performs no computation and ignores all arguments.

    Returns:
        The literal string ``"Impossible"``.
    """
    return "Impossible"


def plot_metrics(all_metrics, metric_name='Hit@1'):
    """Plot one metric against the signed gap distance of a sweep.

    Entries are visited in increasing key order. Gap distances after the
    position of the smallest gap are negated, so the two sides of the minimum
    land on opposite sides of zero on the x-axis. The best metric value is
    printed, and a dashed vertical line marks the gap distance stored under
    key ``0``.

    Args:
        all_metrics: dict mapping a sweep value to
            ``(metrics, gap_distance, preds)``; it must contain the key ``0``.
        metric_name: key to read from each ``metrics`` dict.
    """
    gaps, scores = [], []
    for delta in sorted(all_metrics):
        metrics, gap_distance, _ = all_metrics[delta]
        gaps.append(gap_distance)
        scores.append(metrics[metric_name])
    print(f'Optimal {metric_name}: {max(scores)}')

    # Flip the sign of every gap that comes after the first minimum.
    turning_point = gaps.index(min(gaps))
    signed_gaps = [
        gap if position <= turning_point else -gap
        for position, gap in enumerate(gaps)
    ]
    plt.plot(signed_gaps, scores, 'o-')
    plt.xlabel('Gap Distance')
    plt.ylabel(metric_name)

    # Reference line at the gap of the unshifted features (sweep value 0).
    _, initial_gap, _ = all_metrics[0]
    plt.axvline(initial_gap, color='k', linestyle='--')

    plt.show()

In [None]:
# Move features along direction computed on downstream tasks
# (EuroSAT; the shift direction is derived from these same features).

model = load_clip()
dataset = ClassificationDataset(name='EuroSAT')
image_features, text_features = encode_clip_classification(model, dataset, prompt='a centered satellite photo of {}.')
# The label is read from position 1 of every dataset item.
labels = [item[1] for item in dataset]
# Baseline Hit@K before any shift is applied.
metrics = evaluate_classification(image_features, text_features, labels)
print(metrics)

# Presumably draws 2-D projections of both feature sets, one per listed method -- TODO confirm in utils.
reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca', 'tsne', 'umap'], n_dim=2)

# No direction_vec is passed, so move_features falls back to the normalised
# difference between the image-feature mean and the text-feature mean.
all_metrics = move_features(image_features, text_features, partial(evaluate_classification, labels=labels))
plot_metrics(all_metrics, metric_name='Hit@1')

In [None]:
# Move features along direction computed on MSCOCO
# (direction estimated on the image-caption data, then applied to SVHN).

model = load_clip()

# Step 1: unit vector from the text-feature mean to the image-feature mean of the caption dataset.
dataset = ImageCaptionDataset()
image_features, text_features = encode_clip(model, dataset)
direction_vec = image_features.mean(axis=0) - text_features.mean(axis=0)
direction_vec = direction_vec / direction_vec.norm()

# Step 2: `dataset`, `image_features` and `text_features` are rebound to the SVHN task from here on.
dataset = ClassificationDataset(name='SVHN')
image_features, text_features = encode_clip_classification(model, dataset, prompt='a street sign of the number: "{}".')
# The label is read from position 1 of every dataset item.
labels = [item[1] for item in dataset]
# Baseline Hit@K before any shift is applied.
metrics = evaluate_classification(image_features, text_features, labels)
print(metrics)

# Presumably draws 2-D projections of both feature sets, one per listed method -- TODO confirm in utils.
reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca', 'tsne', 'umap'], n_dim=2)

# Sweep SVHN features along the direction computed in step 1 (passed positionally as direction_vec).
all_metrics = move_features(image_features, text_features, partial(evaluate_classification, labels=labels), direction_vec)
plot_metrics(all_metrics, metric_name='Hit@1')

In [None]:
# Retrieval: baseline recall on the image-caption dataset, then the same sweep as above.

model = load_clip()
dataset = ImageCaptionDataset()
image_features, text_features = encode_clip(model, dataset)
# Baseline Text/Image R@K before any shift is applied.
metrics = evaluate_retrieval(image_features, text_features)
print(metrics)
# Presumably draws 2-D projections of both feature sets, one per listed method -- TODO confirm in utils.
reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca', 'tsne', 'umap'], n_dim=2)

# No direction_vec is passed, so the direction is derived from these features inside move_features.
all_metrics = move_features(image_features, text_features, evaluate_retrieval)
# One plot per retrieval direction, both taken from the same sweep.
plot_metrics(all_metrics, metric_name='Image R@1')
plot_metrics(all_metrics, metric_name='Text R@1')

# Fine-tuning CLIP

In [None]:
# Fine-tune on the train split (max_data_size=50000) with the logit scale overwritten to log(100).
dataset = ImageCaptionDataset(split='train', max_data_size=50000)
model = load_clip()
# NOTE(review): torch.tensor(100) is an integer tensor created without an explicit
# device or dtype -- confirm the result matches what model.logit_scale expects.
model.logit_scale.data = torch.log(torch.tensor(100))
# `model` is rebound to whatever train_clip_toy returns; the first positional path
# argument is presumably the output directory -- TODO confirm in utils.
logs, model = train_clip_toy(model, dataset, f'ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_refactor_t100/', batch_size=64, end_epoch=5)

In [None]:
# Fine-tune with ce_loss and uniform_loss passed together as loss_funcs.
# NOTE(review): no split is given here, unlike the split='train' run -- confirm the default split is the intended one.
dataset = ImageCaptionDataset()
model = load_clip()
logs, model = train_clip_toy(model, dataset, f'ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_uniform_refactor/', loss_funcs=[ce_loss, uniform_loss])

In [None]:
# Fine-tune with dual_ce_loss as the only loss function.
# NOTE(review): no split is given here, unlike the split='train' run -- confirm the default split is the intended one.
dataset = ImageCaptionDataset()
model = load_clip()
logs, model = train_clip_toy(model, dataset, f'ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_dualloss_refactor/', loss_funcs=[dual_ce_loss])

In [None]:
# Fine-tune with simple_ce_loss as the only loss function.
# NOTE(review): no split is given here, unlike the split='train' run -- confirm the default split is the intended one.
dataset = ImageCaptionDataset()
model = load_clip()
logs, model = train_clip_toy(model, dataset, f'ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_removehard_refactor/', loss_funcs=[simple_ce_loss])

# Downstream Task using Fine-tuned Models

In [None]:
# Evaluate a fine-tuned checkpoint on val-split retrieval.
# NOTE(review): this loads from ".../pretrained_512d_refactor_t30/", but no cell in this
# notebook trains into that directory (the fine-tuning cell above uses "..._t100/") --
# confirm the checkpoint exists and is the intended one.
model = load_clip('ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_refactor_t30/model_epoch_1.pt')
# Alternative input kept for reference: the first 500 items of the train split.
# dataset = ImageCaptionDataset(split='train', max_data_size=50000)
# dataset.data = dataset.data[:500]
dataset = ImageCaptionDataset(split='val')
image_features, text_features = encode_clip(model, dataset)
# Gap: norm of the difference between the image-feature mean and the text-feature mean.
feature_dist = (image_features.mean(axis=0) - text_features.mean(axis=0)).norm().item()
print(feature_dist)
metrics = evaluate_retrieval(image_features, text_features)
print(metrics)
reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca'], n_dim=2)


In [None]:
# Evaluate the same fine-tuned checkpoint on CIFAR10 classification.
# NOTE(review): same "..._t30/" checkpoint path as the retrieval cell -- confirm it is the intended one.
model = load_clip('ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_refactor_t30/model_epoch_1.pt')
dataset = ClassificationDataset(name='CIFAR10')
# No prompt is passed, so encode_clip_classification uses its own default.
image_features, text_features = encode_clip_classification(model, dataset)
# Gap: norm of the difference between the image-feature mean and the text-feature mean.
feature_dist = (image_features.mean(axis=0) - text_features.mean(axis=0)).norm().item()
print(feature_dist)
# The label is read from position 1 of every dataset item.
labels = [item[1] for item in dataset]
metrics = evaluate_classification(image_features, text_features, labels)
print(metrics)
reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca'], n_dim=2)

In [None]:
# Figure 1: gap and retrieval R@1 plotted against the values in `xs`.
# NOTE(review): every number below is hard-coded; presumably copied by hand from
# earlier runs -- confirm the source and keep the lists aligned position by position.
plt.figure()
gaps = [0.2384, 0.3028, 0.5524, 0.6352, 0.7961, 1.0006]
acc_i = [0.0214, 0.1896, 0.1772, 0.2048, 0.2090, 0.1836]
acc_t = [0.0170, 0.1660, 0.1740, 0.2098, 0.2036, 0.1894]
# x positions shared by both figures; the x-axis is labelled 'Temperature'.
xs = [1, 1/10, 1/20, 1/30, 1/50, 1/100]
plt.plot(xs, gaps, 'o-', label='Gap')
plt.plot(xs, acc_i, 'o-', label='Image R@1')
plt.plot(xs, acc_t, 'o-', label='Text R@1')
# Dashed horizontal reference line at y=0.8262.
# NOTE(review): the meaning of 0.8262 is not stated anywhere in this notebook -- document it.
plt.axhline(y=0.8262, color='k', linestyle='--')
plt.legend()
plt.xlabel('Temperature')

# Figure 2: gap and a single accuracy series against the same `xs`.
plt.figure()
gaps = [0.9407, 0.6450, 0.8455, 0.9346, 1.0092, 1.1241]
acc = [0.1918, 0.5036, 0.4525, 0.4544, 0.5065, 0.3348]
plt.plot(xs, gaps, 'o-', label='Gap')
plt.plot(xs, acc, 'o-', label='Acc')
# Dashed horizontal reference line at y=1.1136.
# NOTE(review): the meaning of 1.1136 is not stated anywhere in this notebook -- document it.
plt.axhline(y=1.1136, color='k', linestyle='--')
plt.legend()
plt.xlabel('Temperature')

# Gap vs Prediction Overlap

In [None]:
# Compare the predictions of two feature sets.
# NOTE(review): preds1, preds2, image_features1/2 and text_features1/2 are never assigned
# in the visible notebook, so this cell fails on a fresh kernel (hidden state). `dataset`
# is whichever dataset an earlier cell left bound.
# Fraction of positions where both prediction tensors agree.
print((preds1 == preds2).float().mean())
sim1 = image_features1 @ text_features1.t()
sim2 = image_features2 @ text_features2.t()

overlaps = []
for idx in range(len(sim1)):
    # Top-5 column indices by similarity, best first, for each model.
    top_preds1 = sim1[idx].argsort().tolist()[::-1][:5]
    for pred in top_preds1: print(dataset.data[pred])
    print()
    top_preds2 = sim2[idx].argsort().tolist()[::-1][:5]
    for pred in top_preds2: print(dataset.data[pred])
    # Intersection-over-union of the two top-5 index sets.
    overlap = len(set(top_preds1) & set(top_preds2)) / len(set(top_preds1) | set(top_preds2))
    overlaps.append(overlap)
    # Only the first row is processed; `overlaps` therefore holds a single value.
    break

# print(np.mean(overlaps))

# Fix initialization

In [None]:
# Evaluate the "random_t100" checkpoint on the first 500 items of the train split.
model = load_clip('ANONYMOUS_ROOTDIR/develop/open-world/exps/random_t100/model_epoch_1.pt')
dataset = ImageCaptionDataset(split='train', max_data_size=50000)
# Truncates the dataset object in place to its first 500 entries.
dataset.data = dataset.data[:500]
# dataset = ImageCaptionDataset(split='val')
image_features, text_features = encode_clip(model, dataset)
# Gap: norm of the difference between the image-feature mean and the text-feature mean.
feature_dist = (image_features.mean(axis=0) - text_features.mean(axis=0)).norm().item()
print(feature_dist)
metrics = evaluate_retrieval(image_features, text_features)
print(metrics)
reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca'], n_dim=2)


In [None]:
# Evaluate the "random_t100_fix_init" checkpoint on the first 500 items of the train split,
# multiplying the text features by the transpose of a saved matrix before evaluation.
model = load_clip('ANONYMOUS_ROOTDIR/develop/open-world/exps/random_t100_fix_init/model_epoch_1.pt')
dataset = ImageCaptionDataset(split='train', max_data_size=50000)
# Truncates the dataset object in place to its first 500 entries.
dataset.data = dataset.data[:500]
# w.pt unpacks into three values; only the first is used.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
w, _, _ = torch.load('ANONYMOUS_ROOTDIR/develop/open-world/exps/random_t100_fix_init/w.pt')
# dataset = ImageCaptionDataset(split='val')
image_features, text_features = encode_clip(model, dataset)
# NOTE(review): the product is not re-normalised afterwards -- confirm whether the
# metrics below are meant to be computed on un-normalised text features.
text_features = text_features @ w.T
# Gap: norm of the difference between the image-feature mean and the text-feature mean.
feature_dist = (image_features.mean(axis=0) - text_features.mean(axis=0)).norm().item()
print(feature_dist)
metrics = evaluate_retrieval(image_features, text_features)
print(metrics)
reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca'], n_dim=2)
