In [15]:
import torch
import numpy as np
import ot
from torch.autograd import Function

# Define the Kullback-Leibler divergence
def kl_divergence(p, q):
    return (p * (p / q).log()).sum()

# Define the entropy function
def entropy(p):
    return -(p * p.log()).sum()

class SinkhornDistance(Function):
    @staticmethod
    def forward(ctx, C, epsilon, max_iter):
        # Sinkhorn iterations
        K = torch.exp(-C.double() / epsilon)
        u = torch.ones(C.shape[0]).double()
        v = torch.ones(C.shape[1]).double()
        
        for _ in range(max_iter):
            u = 1.0 / (K @ v)
            v = 1.0 / (K.t() @ u)
        
        ctx.save_for_backward(K, u, v)
        return torch.sum(u.reshape((-1, 1)) * C.double() * v.reshape((1, -1)) * K)

    @staticmethod
    def backward(ctx, grad_output):
        K, u, v = ctx.saved_tensors
        grad_C = grad_output * u.reshape((-1, 1)) * v.reshape((1, -1)) * K
        return grad_C, None, None, None



def optimize_unbalanced_ot(C, P, epsilon, tau1, tau2, w_g, w_f, max_iter=1000, lr=1e-2):
    sinkhorn_distance = SinkhornDistance.apply

    # Define the optimizer
    optimizer = torch.optim.Adam([C], lr=lr)

    for i in range(max_iter):
        optimizer.zero_grad()

        # Compute the Sinkhorn distance
        dist = sinkhorn_distance(C,  epsilon, max_iter)

        # Compute the Kullback-Leibler divergences and the entropy
        kl1 = tau1 * kl_divergence(P.sum(1), w_g)
        kl2 = tau2 * kl_divergence(P.sum(0), w_f)
        ent = epsilon * entropy(P)

        # Total loss
        loss = dist - ent + kl1 + kl2

        # Backward pass and optimization step
        loss.backward()
        optimizer.step()

    return C


In [55]:
def select_top_k_classes(C_optimized, P, k):
    # Compute the measure of how much each class in the source domain contributes to the target domain
    class_contributions = P.sum(axis=1)

    # Select the indices of the top k values
    #top_k_classes = np.argpartition(class_contributions, -k)[-k:]
    top_k_classes = np.argsort(class_contributions)[:k]
    return top_k_classes


In [56]:
import torch
import numpy as np


# Step 1: Prepare your datasets

target_list = torch.load('/home/raphael/Documents/models/B/f_baselinemetadataset_cub_test_features.pt')
target_vectors = torch.stack([x['features'].mean(0) for x in target_list])
source_vectors = torch.load('mean_original_imagenet.pt')
print(source_vectors.shape, target_vectors.shape)
num_classes_source = source_vectors.shape[0]
num_classes_target = target_vectors.shape[0]

# Calculate the cost matrix C
C = np.zeros((num_classes_source, num_classes_target))
for i in range(num_classes_source):
    for j in range(num_classes_target):
        C[i, j] = np.linalg.norm(source_vectors[i] - target_vectors[j])  # replace with actual distance metric

# Convert to PyTorch tensor for optimization
C = torch.tensor(C, requires_grad=True,dtype=torch.double)

# Step 2: Normalize class distributions
# For this example, we'll assume uniform distributions
w_g = torch.ones(num_classes_source,dtype=torch.double) / num_classes_source  # replace with actual source class distribution
w_f = torch.ones(num_classes_target,dtype=torch.double) / num_classes_target  # replace with actual target class distribution

# Convert to PyTorch tensor
w_g = torch.tensor(w_g)
w_f = torch.tensor(w_f)

# Assume we have a transportation matrix P
# For this example, we'll just use a uniform matrix
P = np.ones((num_classes_source, num_classes_target)) / num_classes_source*num_classes_target  # replace with actual initialization
P = torch.tensor(P)

# Parameters for optimization
epsilon = 1.0  # entropy regularization parameter
tau1 = 1.0  # parameter for KL divergence of source domain
tau2 = 100.0  # parameter for KL divergence of target domain

# Step 3: Run the optimization
C_optimized = optimize_unbalanced_ot(C, P, epsilon, tau1, tau2, w_g, w_f)

# Step 4: Select top k classes
k = 5  # number of classes to select
top_k_classes = select_top_k_classes(C_optimized.detach().numpy(), P.detach().numpy(), k)

print(f'Top {k} classes: {top_k_classes}')


torch.Size([712, 640]) torch.Size([30, 640])


  w_g = torch.tensor(w_g)
  w_f = torch.tensor(w_f)


Top 5 classes: [  0 469 470 471 472]


In [53]:
name = torch.load('name_original.pt')

In [57]:
for k in   [0 ,469, 470, 471, 472]:
    print(name[k])

tench, Tinca tinca
parallel bars, bars
park bench
passenger car, coach, carriage
patio, terrace
