In [1]:
%load_ext autoreload
%autoreload 2

In [12]:
import torch
from plaid.datasets import CATHShardedDataModule

shard_dir = "/homefs/home/lux70/storage/data/rocklin/shards/"
dm = CATHShardedDataModule(storage_type="hdf5", shard_dir=shard_dir, seq_len=256, batch_size=)
dm.setup()
train_dataloader = dm.train_dataloader()
val_dataloader = dm.val_dataloader()

In [13]:
print(len(train_dataloader.dataset))
print(len(val_dataloader.dataset))

train_batch = next(iter(train_dataloader))
val_batch = next(iter(val_dataloader))

print(train_batch[0].shape)
print(val_batch[0].shape)


7904
1976
torch.Size([1976, 43, 1024])
torch.Size([1976, 43, 1024])


In [None]:
from plaid.evaluation import calc_fid_fn, parmar_fid

# Unnormalized raw latent

In [19]:
print(train_batch[0].mean(), train_batch[0].std())
print(val_batch[0].mean(), val_batch[0].std())

tensor(1.1961) tensor(71.3851)
tensor(1.1991) tensor(71.4518)


In [16]:
calc_fid_fn(train_batch[0].mean(dim=1), val_batch[0].mean(dim=1))

tensor(-200.1610)

In [21]:
# try a different implementation
parmar_fid(train_batch[0].mean(dim=1).numpy(), val_batch[0].mean(dim=1).numpy())

ValueError: Imaginary component 0.0029816698803220455

# Squish to [-1, 1] range

In [23]:
from plaid.utils import LatentScaler

scaler = LatentScaler()
scaled_train = scaler.scale(train_batch[0])
scaled_val = scaler.scale(val_batch[0])

In [24]:
print(scaled_train.max())
print(scaled_train.min())
print(scaled_train.mean())
print(scaled_train.std())

tensor(1.0689)
tensor(-1.0252)
tensor(-0.0052)
tensor(0.1865)


In [25]:
print(scaled_val.max())
print(scaled_val.min())
print(scaled_val.mean())
print(scaled_val.std())

tensor(1.1002)
tensor(-1.1092)
tensor(-0.0051)
tensor(0.1867)


In [26]:
calc_fid_fn(scaled_train.mean(dim=1), scaled_val.mean(dim=1))

tensor(-0.0066)

In [27]:
# try dif implementation
parmar_fid(scaled_train.mean(dim=1).numpy(), scaled_val.mean(dim=1).numpy())

0.04311427881934016

In [28]:
from plaid.evaluation import parmar_kid

parmar_kid(scaled_train.mean(dim=1).numpy(), scaled_val.mean(dim=1).numpy())

1.65781656656236e-06

# Compare to UniRef

In [30]:
from plaid.constants import CACHED_TENSORS_DIR
from pathlib import Path
from safetensors.torch import load_file

In [37]:
uniref_holdout = load_file(Path(CACHED_TENSORS_DIR) / "holdout_esmfold_feats.st")['features']
cath_holdout = load_file(Path(CACHED_TENSORS_DIR) / "cath_esmfold_feats.st")['embeddings']

scaled_uniref = scaler.scale(uniref_holdout)
scaled_cath = scaler.scale(cath_holdout)

In [38]:
parmar_fid(
    scaled_train.mean(dim=1).numpy(),
    scaled_uniref.numpy(),
)

8.382239767818914

In [39]:
parmar_kid(
    scaled_train.mean(dim=1).numpy(),
    scaled_uniref.numpy(),
)

0.01907546199324326

# Compare to CATH

In [42]:
parmar_fid(
    scaled_train.mean(dim=1).numpy(),
    scaled_cath.mean(dim=1).numpy(),
)

6.866759709475925

In [43]:
parmar_kid(
    scaled_train.mean(dim=1).numpy(),
    scaled_cath.mean(dim=1).numpy(),
)

0.011983237093343315