In [1]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn
import util.util_validation as ut_val

from torchvision import transforms, datasets
from tqdm import tqdm

from networks.resnet_big import model_dict

# Use seaborn's dark-grid look for every matplotlib figure in this notebook.
seaborn.set_style(style="darkgrid")

In [2]:
# Trained checkpoints available for analysis, keyed by a short run label.
# Each value is a two-element list [checkpoint path, extra entry]. Only element [0] is used below.
# The extra entry is None for the SupCE runs and "" for the SupCon runs.
# NOTE(review): purpose of the extra entry is unclear from this notebook.
models_dict = dict(
    CE_baseline=[
        "./save/SupCE/animals10/SupCE_animals10_resnet18_lr_0.125_decay_0.0001_bsz_26_trial_0_baseline_cosine/models/ckpt_epoch_500.pth",
        None,
    ],
    CE_diffAug=[
        "./save/SupCE/animals10_diff_-1+4000/SupCE_animals10_diff_-1+4000_resnet18_lr_0.125_decay_0.0001_bsz_26_trial_0_diffAug_cosine/models/last.pth",
        None,
    ],
    CE_diffAugAllAug=[
        "./save/SupCE/animals10_diff_-1+4000/SupCE_animals10_diff_-1+4000_resnet18_lr_0.125_decay_0.0001_bsz_26_trial_0_diffAugAllAug_cosine/models/last.pth",
        None,
    ],
    SupCon_baseline=[
        "./save/SupCon/animals10_diff_-1/SupCon_animals10_diff_-1_resnet18_lr_0.125_decay_0.0001_bsz_26_temp_0.1_trial_0_try3_cosine/models/last.pth",
        "",
    ],
    SupCon_diffCSameSAug=[
        "./save/SupCon/animals10_diff_-1+4000/SupCon_animals10_diff_-1+4000_resnet18_lr_0.125_decay_0.0001_bsz_26_temp_0.1_trial_0_colorAugSameShapeAug_cosine/models/last.pth",
        "",
    ],
    SupConHybrid_diffColorAug=[
        "./save/SupCon/animals10_diff_-1+4000/SupConHybrid_animals10_diff_-1+4000_resnet18_lr_0.125_decay_0.0001_bsz_26_temp_0.1_trial_0_colorAug_cosine/models/last.pth",
        "",
    ],
)

In [3]:
# Which trained run to analyse (a key of models_dict); change this to evaluate another model.
model_key = "SupConHybrid_diffColorAug"
root_model = models_dict[model_key][0]

# GPU index used for inference.
cuda_device = 0

# Shape/texture conflict image set the embeddings are computed on.
root_dataset = "./datasets/adaIN/shape_texture_conflict_animals10_many/"

# Paths derived from the checkpoint location, then the run's training parameters.
# NOTE(review): meaning of the three returned values is inferred from the helper's name; confirm in util_validation.
path_save, path_run_md, epoch = ut_val.get_paths_to_embeddings_and_run_md(root_model)
params = ut_val.read_parameters_from_run_md(path_run_md)

## Set up the dataloader and model

In [4]:
# Deterministic evaluation preprocessing, parameterised by the run's own training settings.
normalize = transforms.Normalize(mean=params['mean'], std=params['std'])
val_transform = transforms.Compose([
    transforms.Resize(params['size']),
    transforms.CenterCrop(params['size']),
    transforms.ToTensor(),
    normalize,
])

conflict_dataset = ut_val.shapeTextureConflictDataset(root_dataset, val_transform)
classes = conflict_dataset.classes

# shuffle=False keeps batches in dataset order.
# Later cells index embedding rows by position in conflict_dataset.paths.
conflict_dataloader = torch.utils.data.DataLoader(
    conflict_dataset,
    batch_size=params['batch_size'],
    shuffle=False,
    num_workers=16,
    pin_memory=True,
)

model = ut_val.set_model(root_model, params, len(classes), cuda_device)

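## Compute embeddings

Encode every shape/texture conflict image with the model's encoder and collect its shape and texture labels.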

In [5]:
# Encoder output width for this architecture; used below to restore the (N, embedding_size) layout.
# Indexing instead of `_, embedding_size = ...` avoids rebinding IPython's `_` output variable.
embedding_size = model_dict[params['model']][1]

model.eval()

# Collect per-batch arrays in lists and concatenate once after the loop.
# np.append allocates and copies the whole accumulated array on every call,
# which made the original loop quadratic in the dataset size.
embedding_chunks = []
shape_label_chunks = []
texture_label_chunks = []
with torch.no_grad():
    for images, labels in tqdm(conflict_dataloader):
        images = images.cuda(device=cuda_device, non_blocking=True)

        features = model.encoder(images)

        embedding_chunks.append(features.cpu().numpy().ravel())
        # labels[0] feeds shape_labels, labels[1] feeds texture_labels.
        shape_label_chunks.append(labels[0].numpy().ravel())
        texture_label_chunks.append(labels[1].numpy().ravel())

# Start each concatenation from the same empty array the np.append version started from.
# This keeps the resulting dtypes (float64 embedding, int labels) identical,
# and an empty dataloader still yields empty arrays instead of raising.
embedding = np.concatenate([np.array([])] + embedding_chunks).reshape(-1, embedding_size)
shape_labels = np.concatenate([np.array([], dtype=int)] + shape_label_chunks)
texture_labels = np.concatenate([np.array([], dtype=int)] + texture_label_chunks)

100%|██████████| 897/897 [04:42<00:00,  3.18it/s]


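## Estimate dimensions

Pair up images that share the same source shape image, or the same source texture image. For each set of pairs, compute the average per-dimension correlation between the paired embeddings. A softmax over the two scores, plus a fixed residual score of 1.0, then estimates how many embedding dimensions encode shape, how many encode texture, and how many are left over.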

In [6]:
def _group_indices(keys):
    """Group sample positions by key.

    Returns one ascending index array per distinct key. Keys are visited in sorted order,
    so the result does not depend on string-hash randomization the way iterating a ``set`` does.
    """
    positions = {}
    for index, key in enumerate(keys):
        positions.setdefault(key, []).append(index)
    return [np.array(positions[key], dtype=np.intp) for key in sorted(positions)]


def _ordered_pairs(groups):
    """Enumerate every ordered pair (a, b), a != b, inside each group of indices.

    Returns two aligned index arrays ``(pair_A, pair_B)``. For a group of size n:
    - its members appear n-1 times in ``pair_A``;
    - ``pair_B`` holds the members cyclically shifted by 1..n-1.
    Single-member groups contribute no pairs; ``np.concatenate`` on an empty list would raise ValueError.
    """
    firsts, seconds = [], []
    for members in groups:
        for shift in range(1, len(members)):
            firsts.append(members)
            seconds.append(np.roll(members, shift=shift))
    if not firsts:
        return np.array([], dtype=np.intp), np.array([], dtype=np.intp)
    return np.concatenate(firsts), np.concatenate(seconds)


# Each path ends in .../<s>/<t>/<shape_image>_stylized_<texture_image>.jpg.
# NOTE(review): <s>/<t> look like the shape/texture class folders; confirm against the dataset layout.
shape_texture_name_list = [path.replace(".jpg", '').split('/')[-3:] for path in conflict_dataset.paths]
shapeName_textureName_list = [(s + '/' + n.split('_stylized_')[0], t + '/' + n.split('_stylized_')[1])
                              for s, t, n in shape_texture_name_list]

# Identity of the source shape image and source texture image of every sample.
shape_keys = [shape_name for shape_name, _ in shapeName_textureName_list]
texture_keys = [texture_name for _, texture_name in shapeName_textureName_list]
shape_array = np.array(shape_keys)
texture_array = np.array(texture_keys)

# Samples sharing a shape (resp. texture) source image form one group.
# Pair every two distinct members of a group.
shape_pairs = _group_indices(shape_keys)
texture_pairs = _group_indices(texture_keys)

shape_pair_A, shape_pair_B = _ordered_pairs(shape_pairs)
texture_pair_A, texture_pair_B = _ordered_pairs(texture_pairs)

In [7]:
def compute_correlation_score(embedding_A, embedding_B):
    """Mean per-dimension Pearson correlation between two row-aligned embedding matrices.

    Row i of ``embedding_A`` is paired with row i of ``embedding_B``. For each column, the
    correlation across rows is computed. Columns where either side is constant (0/0 -> NaN)
    count as 0. Returns the average over columns as a Python float.
    """
    a = torch.tensor(embedding_A)
    b = torch.tensor(embedding_B)

    # Center every column on its mean over the paired samples.
    a_centered = a - a.mean(dim=0)
    b_centered = b - b.mean(dim=0)

    covariance = (a_centered * b_centered).sum(dim=0)
    energy_a = (a_centered * a_centered).sum(dim=0)
    energy_b = (b_centered * b_centered).sum(dim=0)
    per_dim = covariance / torch.sqrt(energy_a * energy_b)

    # Constant columns give NaN; treat them as uncorrelated.
    return torch.nan_to_num(per_dim, nan=0.0).mean().item()

def estimate_dims(correlation_scores, embedding_size):
    """Split ``embedding_size`` dimensions among factors via a softmax over their scores.

    A residual factor with fixed score 1.0 is appended after the given scores.
    Every factor except the residual receives ``int(weight * embedding_size)`` dimensions.
    The residual takes the remainder, so the returned list always sums to ``embedding_size``.
    """
    logits = np.concatenate((correlation_scores, [1.0]))

    # Subtract the maximum before exponentiating for numerical stability.
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()

    allocation = [int(weight * embedding_size) for weight in weights]
    allocation[-1] = embedding_size - sum(allocation[:-1])
    return allocation

In [8]:
# Average correlation between embeddings of pairs sharing a shape source image, then a texture source image.
corr_score_shape = compute_correlation_score(embedding[shape_pair_A],
                                             embedding[shape_pair_B])
corr_score_texture = compute_correlation_score(embedding[texture_pair_A],
                                               embedding[texture_pair_B])

# Turn the two scores into a [shape, texture, residual] split of the embedding dimensions.
dims = estimate_dims(correlation_scores=[corr_score_shape, corr_score_texture],
                     embedding_size=embedding_size)
dims

[105, 161, 246]

Estimated `[shape, texture, residual]` dimensions (out of 512) for each model:

- CE_baseline: [93, 179, 240]
- CE_diffAug: [90, 184, 238]
- CE_diffAugAllAug: [96, 173, 243]
- SupCon_baseline: [102, 159, 251]
- SupCon_diffCSameSAug: [104, 158, 250]
- SupConHybrid_diffColorAug: [105, 161, 246]