In [3]:
import json
import pickle
from scipy.sparse import csr_matrix
import numpy as np
import pandas as pd

with open("buckets.json", "r") as f:
    buckets = json.load(f)

with open("data/mol_bits.pkl", "rb") as f:
    data = pickle.load(f)

with open("data/results/test_mols_neighbors.pkl", "rb") as f:
    test_mols_neighbors = pickle.load(f)

train = pd.read_csv("data/activity_train.csv", header=None)
train.columns = ["uniprot_id", "mol_id", "activity"]
train["mol_id"] = train["mol_id"].astype(str).str.strip()

train_pivot = train.pivot(
    index="uniprot_id", columns="mol_id", values="activity"
).fillna(0)

# Baseline Estimation for CF
- Define similarity $s_{ij}$ of molecules i and j.
- Select k nearest neighbors N(i;x).
    - Molecules most similar to i, that have activity on protein x.
- Estimate rating $r_{xi}$ as the weighted average:

$$ r_{xi} = b_{xi} + \frac{\sum_{j\in N(i;x)} S_{ij} \cdot (r_{xj} - b_{xj})}{\sum_{j\in N(i;x)} s_{ij}} $$
Where $b_{xi}$ is the baseline estimate for $r_{xi}$:
$$ b_{xi} = \mu + b_x + b_i $$

- $\mu$ - overall mean protein activity
- $b_x$ - activity deviation of molecule x (molecule bias)
- $b_i$ - activity deviation of protein i (protein bias)


## Obtaining $b_{xi}$

In [4]:
mu = train_pivot.mean(axis=1).mean()
bx = train_pivot.mean(axis=0) - mu
bi = train_pivot.mean(axis=1) - mu

bxi = pd.DataFrame(
    np.full(train_pivot.shape, mu), index=train_pivot.index, columns=train_pivot.columns
)
for i, row in train_pivot.iterrows():
    bxi.loc[i] = mu + bx + bi.loc[i]

In [5]:
# MSE
((bxi.round(0).values - train_pivot.values) ** 2).mean()

0.3908519813733769

In [8]:
# masked mse
mask = (bxi.round(0).values != 0) * 1

mse_masked = ((train_pivot.values - bxi.round(0).values * mask) ** 2).sum() / mask.sum()
mse_masked
# np.sqrt(mse_masked)

172.31175181251055

## Similarities and Nearest Neighbors

- Jaccard similarity - $J(A, B) = \frac{A \cap B}{ A \cup B}$
- Find in each bucket, the top k highest similarity neighbors.


In [12]:
train

Unnamed: 0,uniprot_id,mol_id,activity
0,O14842,CHEMBL2022243,4
1,O14842,CHEMBL2022244,6
2,O14842,CHEMBL2022245,2
3,O14842,CHEMBL2022246,1
4,O14842,CHEMBL2022247,4
...,...,...,...
135706,Q9Y5Y4,CHEMBL4214909,6
135707,Q9Y5Y4,CHEMBL4218012,2
135708,Q9Y5Y4,CHEMBL4217503,7
135709,Q9Y5Y4,CHEMBL4204359,8


In [4]:
for bucket, mols in buckets.items():
    if "CHEMBL10" in mols:
        print(bucket)

92997
15719
5294
96137
56944
85665
21457
7562
32023
496
65645
23764
43528
6992
25669
1620
82128
6512
95716
74642
6715
82068
92333
66436
92714
15917
49361
30162
66442
12774
51688
53651
1427
48951
20690
45343
37894
2143
17820
13765
89928
9377
36607
90024
77788
30423
62567
58279
5702
92511
