In [None]:
import os

import numpy as np
from evodiff.utils import Tokenizer

# Hyperparameters of the BLOSUM-based forward process.
beta = 1        # exponent applied to the BLOSUM conditional likelihoods (cond_liks ** beta)
alpha = 0.0001  # fractional-power step used to turn cond_liks into a rate matrix
gamma = 0       # in the uniformization cell, the smallest diagonal entry of K equals gamma

# BLOSUM62 matrix extended with MSA/special symbols. Set the BLOSUM_MATRIX_PATH environment
# variable to point elsewhere instead of editing the hard-coded cluster path.
BLOSUM_MATRIX_PATH = os.environ.get(
    'BLOSUM_MATRIX_PATH', '/scratch/aa11803/d3pm/data/blosum62-special-MSA.mat')

tokenizer = Tokenizer()
tok_alphabet = np.array(tokenizer.alphabet)
# Row/column order of the symbols in the BLOSUM file.
blosum_alphabet = np.array(list('ARNDCQEGHILKMFPSTWYVBZXJOU-'))

# map_[i, k] is True when BLOSUM symbol i is tokenizer symbol k. argmax silently returns 0 for a
# symbol missing from the tokenizer (which would overwrite token 0's scores), so fail loudly.
map_ = blosum_alphabet[:, None] == tok_alphabet[None, :]
assert map_.any(axis=1).all(), 'every BLOSUM symbol must be in the tokenizer alphabet'
blosum_to_tok = np.argmax(map_, axis=1)

# Background amino-acid frequencies (%) from
# https://web.expasy.org/protscale/pscale/A.A.Swiss-Prot.html
# That table is ordered Ala, Arg, Asn, ..., Val, i.e. 'ARNDCQEGHILKMFPSTWYV' (the first 20
# BLOSUM symbols), NOT tokenizer order. Scatter the values into tokenizer order so aa_freq[k]
# lines up with column k of blosum_matrix; every other token gets frequency 0.
swissprot_freq = np.array([8.25, 5.53, 4.06, 5.45, 1.37, 3.93, 6.75,
                           7.07, 2.27, 5.96, 9.66, 5.84, 2.42, 3.86,
                           4.70, 6.56, 5.34, 1.08, 2.92, 6.87]) / 100
aa_freq = np.zeros(len(tok_alphabet))
aa_freq[blosum_to_tok[:len(swissprot_freq)]] = swissprot_freq

# Keep only data rows (they start with an alphabet symbol) and drop their leading row label.
with open(BLOSUM_MATRIX_PATH) as f:
    load_matrix = np.array([line.split()[1:] for line in f if line[0] in blosum_alphabet], dtype=int)

# Re-index the scores from file order into tokenizer order in one vectorized assignment;
# tokenizer symbols that are not in blosum_alphabet keep a score of 0.
n_blosum = len(blosum_alphabet)
blosum_matrix = np.zeros((len(tok_alphabet), len(tok_alphabet)))
blosum_matrix[np.ix_(blosum_to_tok, blosum_to_tok)] = load_matrix[:n_blosum, :n_blosum]

# Conditional likelihoods p(aa_j | aa_i) proportional to 2^(BLOSUM_ij / 2) * p(aa_j), raised to
# the power beta and row-normalized so row i is a distribution over tokens j.
cond_liks = (2. ** (blosum_matrix / 2)) * aa_freq[None, :]
cond_liks = cond_liks ** beta
cond_liks = cond_liks / cond_liks.sum(-1)[:, None]

# The first N_CANONICAL tokens are treated as the canonical amino acids; the rows of the
# remaining tokens come straight from cond_liks - I and are rescaled at the end.
N_CANONICAL = 20
L = cond_liks - np.eye(len(cond_liks))

# K = cond_liks^alpha via an eigendecomposition, so (K - I) / alpha is a first-order
# approximation of the matrix logarithm of cond_liks for small alpha.
l, V = np.linalg.eig(cond_liks[:N_CANONICAL, :N_CANONICAL])
# eig of a non-symmetric matrix may come back with a complex dtype. Round-off imaginary parts
# are harmless, but a genuinely complex or non-positive eigenvalue would make l**alpha
# complex/NaN and silently corrupt L, so check explicitly.
assert np.allclose(np.imag(l), 0), 'expected real eigenvalues'
assert (np.real(l) > 0).all(), 'fractional matrix power needs positive eigenvalues'
V_inv = np.linalg.inv(V)
K = np.real((V * (l ** alpha)[None, :]) @ V_inv)
# Zero out negative entries and renormalize K back to a row-stochastic matrix.
K[K < 0] = 0
K = K / K.sum(-1)[:, None]
L[:N_CANONICAL, :N_CANONICAL] = (K - np.eye(len(K))) / alpha
# Scale the special-token rows by the largest exit rate (-min diagonal, taken before scaling).
L[N_CANONICAL:] *= -np.diagonal(L).min()
assert np.allclose(L.sum(-1), 0), 'rows of a rate matrix must sum to zero'

In [None]:
from matplotlib import pyplot as plt


def _label_token_axes(n_ticks, labels):
    """Put `n_ticks` integer ticks on both axes of the current plot, labelled with `labels`."""
    ticks = np.arange(n_ticks)
    plt.xticks(ticks, labels)
    plt.yticks(ticks, labels)


# Rate matrix L from the first cell (off-diagonal entries >= 0, diagonal entries <= 0).
# NOTE: L is rebound further down in this cell, so re-running this cell plots the library
# generator here instead.
plt.figure()
plt.imshow(L, cmap='bwr', vmin=-1, vmax=1)
plt.colorbar()
plt.title('Rate matrix L')
_label_token_axes(len(L), tok_alphabet)

# The same matrix shifted by the identity.
plt.figure()
plt.imshow(L + np.eye(len(cond_liks)), cmap='Blues', vmin=0, vmax=1)
plt.colorbar()
plt.title('L + I')
_label_token_axes(len(L), tok_alphabet)

from d3pm_sc.utils import get_inf_gen

# NOTE: this rebinds L. Later cells (e.g. K = L / ...) use this library generator, not the one
# built in the first cell. The 31 is presumably the number of token classes -- TODO confirm.
L = get_inf_gen({'type': 'blosum', 'beta':beta, 'normalize': False, 'alpha':alpha}, 31)
plt.figure()
plt.imshow(L, cmap='bwr', vmin=-1, vmax=1)
plt.colorbar()
plt.title('Rate matrix from d3pm_sc.utils.get_inf_gen')

# Power iteration y <- cond_liks.T @ y, 100000 times, starting from the background frequencies.
# matrix_power computes the same product by repeated squaring (~17 matmuls, not 100000 matvecs).
y = np.linalg.matrix_power(cond_liks.T, 100000) @ aa_freq
plt.figure(figsize=[5, 3])
plt.plot(y, label='stationary', color='blue')
plt.plot(aa_freq, label='background freq', color='black')
plt.title('Stationary distribution of cond_liks vs. background frequencies')
plt.xlabel("AA")
plt.xticks(np.arange(len(y)), tok_alphabet)
plt.ylabel("freq")
plt.ylim(0, 0.11)
plt.legend();

In [None]:
# Uniformize the generator L into a one-step transition matrix: divide by the largest exit rate
# (softened by 1 - gamma) and add the identity, so the smallest diagonal entry of K is gamma.
max_exit_rate = -np.diagonal(L).min()
K = L / (max_exit_rate / (1 - gamma)) + np.eye(len(L))

fig_k, ax_k = plt.subplots()
im_k = ax_k.imshow(K, cmap='Blues', vmin=0, vmax=1)
fig_k.colorbar(im_k, ax=ax_k)
token_ticks = np.arange(len(L))
token_labels = tok_alphabet[:len(L)]
ax_k.set_xticks(token_ticks)
ax_k.set_xticklabels(token_labels)
ax_k.set_yticks(token_ticks)
ax_k.set_yticklabels(token_labels)

# Diagonal of K (the self-transition probabilities when K is row-stochastic).
fig_diag, ax_diag = plt.subplots()
ax_diag.plot(np.diagonal(K))
ax_diag.set_ylim(0, 1)

# Look at the UniRef50 sequence data

In [None]:
from sequence_models.datasets import UniRefDataset

# UniRef50 data root; set the UNIREF50_DIR environment variable to point elsewhere instead of
# editing the hard-coded cluster path.
UNIREF50_DIR = os.environ.get('UNIREF50_DIR', '/vast/aa11803/uniref50_data/')
# Forwarded as max_len to UniRefDataset for both splits.
MAX_LEN = 1024

train_dataset = UniRefDataset(UNIREF50_DIR, 'train', structure=False, max_len=MAX_LEN)
test_dataset = UniRefDataset(UNIREF50_DIR, 'test', structure=False, max_len=MAX_LEN)

# Look at the train/test dataloaders

In [None]:
%load_ext autoreload
%autoreload 2
    
from omegaconf import OmegaConf
import data
from evodiff.utils import Tokenizer

cfg = OmegaConf.load('configs/basic_protein.yaml')
cfg.train.batch_size = 128 // 4

num_classes = cfg.data.N
cfg.train.pack = False

##### Load data
train_dataloader, test_dataloader = data.get_dataloaders(cfg)
tokenizer = Tokenizer()

In [None]:
# Pull a single batch from the training loader for inspection (raises StopIteration if empty).
t = next(iter(train_dataloader))

In [None]:
# Decode the first entry of the batch's first element back into a string.
# Presumably t[0] holds token ids shaped (batch, seq) -- TODO confirm against data.get_dataloaders.
tokenizer.untokenize(t[0][0])

In [None]:
# Reciprocal of the mean of the batch's second element.
# NOTE(review): what t[1] holds is not visible here (perhaps a mask or per-sequence value) --
# confirm before interpreting this number.
1 / t[1].mean()