In [1]:
from latentis.measure.cka import CKA, CKAMode
import torch 

## CPU

In [2]:
cka = CKA(mode=CKAMode.LINEAR, device='cpu')

X = torch.randn(100, 10)
Y = torch.randn(100, 10)

cka(X, Y)

tensor(0.1093)

## With GPU

In [3]:
cka = CKA(mode=CKAMode.LINEAR, device='cuda')

X = torch.randn(100, 10)
Y = torch.randn(100, 10)

res = cka(X, Y)
print(res.device)

cuda:0


## No device

In [4]:
cka = CKA(mode=CKAMode.LINEAR, device='cuda')

X = torch.randn(100, 10)
Y = torch.randn(100, 10)

cka(X, Y)

tensor(0.1054, device='cuda:0')

## Random spaces experiment

In [9]:
num_samples = [100, 1000]

K = [1, 2, 5, 10, 100]
starting_dim = 10
num_iterations = 100

results = {n: {k*starting_dim: [] for k in K} for n in num_samples }

for n in num_samples:
    for k in K:
        ckas = []
        for i in range(num_iterations):
            
            X = torch.randn(n, starting_dim * k)
            Y = torch.randn(n, starting_dim * k)

            cka_score = cka(X, Y)
            ckas.append(cka_score)

        ckas = torch.stack(ckas)
        results[n][k*starting_dim] = ckas.mean()


In [10]:
results

{100: {10: tensor(0.0901, device='cuda:0'),
  20: tensor(0.1663, device='cuda:0'),
  50: tensor(0.3330, device='cuda:0'),
  100: tensor(0.5004, device='cuda:0'),
  1000: tensor(0.9089, device='cuda:0')},
 1000: {10: tensor(0.0098, device='cuda:0'),
  20: tensor(0.0198, device='cuda:0'),
  50: tensor(0.0475, device='cuda:0'),
  100: tensor(0.0910, device='cuda:0'),
  1000: tensor(0.4997, device='cuda:0')}}

In [13]:
import plotly.graph_objects as go

# Prepare the data for plotting
x_values = [k*starting_dim for k in K]
traces = []

for n in num_samples:
    y_values = [results[n][k*starting_dim].item() for k in K]
    traces.append(go.Scatter(x=x_values, y=y_values, mode='lines+markers', name=f'Samples: {n}'))

# Create the plot
fig = go.Figure(traces)

# Add layout details
fig.update_layout(
    title='Mean CKA Scores vs dimensionality for Different Sample Sizes',
    xaxis_title='K',
    yaxis_title='Mean CKA Score',
    legend_title='Sample Size'
)

# Show the plot
fig.show()

: 