In [1]:
%load_ext autoreload
%autoreload 2

from utils.dataloader import pdb_to_frames, SE3FrameDataset, squeeze_batch, to_device
from utils.noise_denoise import noise_translations, DSM_R3, score_R3, noise_rotations
from torch.utils.data import DataLoader
from utils.pdb_utils import write_ca_to_pdb, write_frames_to_pdb

from glob import glob
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from models.gnn import EGNNScoreModel, encode_node_features, encode_edge_features

In [2]:
pdb_folder = 'pdb_small/*.pdb'
pdb_frames = {Path(pdb_file).stem: pdb_to_frames(pdb_file)
              for pdb_file in glob(pdb_folder)}

In [3]:
dataset = SE3FrameDataset(pdb_frames)
loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn = squeeze_batch)

In [4]:
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
lr = 1e-3
epochs = 50
h_dim = 32
edge_dim = 32
hidden_dim = 64
n_layers = 4

# Initialize model
model = EGNNScoreModel(
    h_dim=2 * h_dim,
    edge_dim=2 * edge_dim,
    hidden_dim=hidden_dim,
    n_layers=n_layers
).to(device)

# optimizer = optim.Adam(model.parameters(), lr=lr)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,             
    betas=(0.9, 0.999),  
    eps=1e-8,            
    weight_decay=0      
)

for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    print(f"Epoch {epoch+1}")

    for batch in tqdm(loader, desc="Training"):
        batch = to_device(batch, device)
        x0 = batch["translations"]       # (N_copies, L, 3)
        t = batch["timesteps"]           # (N_copies, 1)
        x_t = noise_translations(x0, t)  # (N_copies, L, 3)

        L = x_t.shape[1]
        node_feats = encode_node_features(x_t, t, dim=h_dim)  # (N_copies, L, 2*h_dim)
        edge_feats = encode_edge_features(L, dim=edge_dim, device=x_t.device)  # (L, L, 2*edge_dim)
        edge_feats = edge_feats.unsqueeze(0).expand(x_t.shape[0], -1, -1, -1)  # (N_copies, L, L, 2*edge_dim)

        score_true = score_R3(x_t, x0, t)
        
        with torch.no_grad():
            pred = model(x_t, t, node_feats, edge_feats)
            print("Predicted score mean:", pred.mean().item(), "std:", pred.std().item())
            print("True score mean:", score_true.mean().item(), "std:", score_true.std().item())

        optimizer.zero_grad()
        
        score_pred = model(x_t, t, node_feats, edge_feats)
        loss = DSM_R3(score_pred, score_true, t)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"[Epoch {epoch+1}] Loss: {avg_loss:.6f}")

Epoch 1


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.2108448710119146 std: 0.6592480620443284
True score mean: 0.01277614274546102 std: 1.2964030164034546


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.43it/s]

Predicted score mean: -0.046275668292075495 std: 0.2705360506005359
True score mean: -0.0020094920925993543 std: 1.4504224585755265


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.43it/s]


[Epoch 1] Loss: 396.155486
Epoch 2


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.06830521225058303 std: 0.13708125403200902
True score mean: -0.0022953664208346593 std: 1.6453921559011437


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.43it/s]

Predicted score mean: 0.12760770037226546 std: 0.33188327610993523
True score mean: -0.0008940545806285584 std: 2.2395539032369816


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 2] Loss: 1676.028592
Epoch 3


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.11716544124737226 std: 0.32815231175227155
True score mean: -0.005780920100912969 std: 1.8256661577846458


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.41it/s]

Predicted score mean: 0.07931111956284426 std: 0.28511561556503223
True score mean: -0.009265682152358125 std: 1.2831466849209474


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.43it/s]


[Epoch 3] Loss: 590.159805
Epoch 4


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.03768500139082536 std: 0.20287200959108223
True score mean: 0.022708589552696484 std: 1.5647546645755328


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.51it/s]

Predicted score mean: -0.00476459245097523 std: 0.12173123634780603
True score mean: -0.011225063320352112 std: 1.4523050342689205


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.47it/s]


[Epoch 4] Loss: 643.201614
Epoch 5


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.034302529638423875 std: 0.04783745441513638
True score mean: 0.006065145309701501 std: 1.4390748800061477


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.45it/s]

Predicted score mean: -0.048568301053655706 std: 0.015285392403512324
True score mean: 0.020953097162313715 std: 1.4110838055653363


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.47it/s]


[Epoch 5] Loss: 457.589909
Epoch 6


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.02530040007066563 std: 0.06897611375015623
True score mean: 0.023625769137748023 std: 1.6555670181305566


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.43it/s]

Predicted score mean: 0.011542355827037006 std: 0.11752989206427834
True score mean: 0.01975166471949788 std: 1.4511007321503693


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s]


[Epoch 6] Loss: 630.923190
Epoch 7


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.04593888308070261 std: 0.1442023435496035
True score mean: -0.00434949958753362 std: 1.5952805017800862


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.43it/s]

Predicted score mean: 0.06728705748616164 std: 0.1546212272569815
True score mean: 0.02139059140490348 std: 1.5484042411136876


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.44it/s]


[Epoch 7] Loss: 710.852271
Epoch 8


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.08972804222799614 std: 0.15004986156759884
True score mean: 0.0013186659693588474 std: 1.8040405762012868


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.48it/s]

Predicted score mean: 0.08660483705583992 std: 0.11013939363263703
True score mean: -0.007380195751630573 std: 1.3357427368374986


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s]


[Epoch 8] Loss: 855.353512
Epoch 9


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.06994739457668427 std: 0.0748100258006372
True score mean: 0.0058250457507057745 std: 1.6407007062639765


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.45it/s]

Predicted score mean: 0.047155210259415936 std: 0.05590775694782812
True score mean: 0.004131330882291202 std: 1.6174683044375562


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.44it/s]


[Epoch 9] Loss: 819.704598
Epoch 10


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.026342190691339536 std: 0.0682833888292964
True score mean: -0.00287119176247589 std: 1.4547674767369743


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.41it/s]

Predicted score mean: 0.0020470284584670874 std: 0.0898060748762972
True score mean: -0.01033579005128049 std: 1.2636003318870648


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.44it/s]


[Epoch 10] Loss: 358.114104
Epoch 11


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.0242702326911974 std: 0.10088523949844505
True score mean: -0.01316627309288317 std: 1.395468233148887


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.42it/s]

Predicted score mean: -0.051985067935496065 std: 0.1111061503878579
True score mean: 0.010745556666344542 std: 1.1929770241430016


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.44it/s]


[Epoch 11] Loss: 261.116134
Epoch 12


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.06400453755900451 std: 0.11252327698341515
True score mean: 0.013393036550715281 std: 1.358450922642078


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.42it/s]

Predicted score mean: -0.06314670491253684 std: 0.10859926084149837
True score mean: 0.01176527288882701 std: 1.3431472086481198


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.44it/s]


[Epoch 12] Loss: 369.847032
Epoch 13


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.039005439968213707 std: 0.08803776228322367
True score mean: 0.013097578979864806 std: 1.2829170810573742


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.44it/s]

Predicted score mean: -0.0070300196244742095 std: 0.0646041314112025
True score mean: 0.001773937184372802 std: 2.575415426048133


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.42it/s]


[Epoch 13] Loss: 1176.019472
Epoch 14


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.027440001183067293 std: 0.044665192871549804
True score mean: 0.0056116239880256 std: 1.49105406164377


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.43it/s]

Predicted score mean: 0.056990994800168164 std: 0.03402919399253826
True score mean: 0.018451464217947983 std: 1.235560454023992


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 14] Loss: 338.993675
Epoch 15


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.07688335726046734 std: 0.03734665547157664
True score mean: -0.01228893283460383 std: 1.9567427469211534


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.42it/s]

Predicted score mean: 0.08277746519303933 std: 0.04727250290837862
True score mean: -0.014108702111262523 std: 1.9453731412381672


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 15] Loss: 1411.564545
Epoch 16


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.06324132271223198 std: 0.04569589968697804
True score mean: -0.004779779028774322 std: 1.5205150073560674


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.43it/s]

Predicted score mean: 0.03968739783632242 std: 0.04132269275398327
True score mean: 0.0062564432114692505 std: 1.404240232989029


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 16] Loss: 511.389176
Epoch 17


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.00970297716729779 std: 0.0400616293492476
True score mean: 0.011248724081917977 std: 1.4519595887070809


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.40it/s]

Predicted score mean: -0.016354810524703528 std: 0.032977127842814105
True score mean: -0.005610178217962537 std: 1.7447971185088273


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.35it/s]


[Epoch 17] Loss: 844.314626
Epoch 18


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.03205353059520662 std: 0.016033727799973325
True score mean: 0.020844708279545026 std: 1.6432072299607408


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.43it/s]

Predicted score mean: -0.025969390547808495 std: 0.04071751758918636
True score mean: 0.016385587818052154 std: 1.5502846163458122


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.41it/s]


[Epoch 18] Loss: 782.561515
Epoch 19


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.011874027002486355 std: 0.07184036128637
True score mean: -0.016651115090550866 std: 1.6122794104624516


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.43it/s]

Predicted score mean: -0.0031881669797764215 std: 0.10039324639095742
True score mean: -0.0073110376368051575 std: 1.1903323369030452


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 19] Loss: 372.368761
Epoch 20


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.0036771964801546974 std: 0.11895152039997331
True score mean: 0.006533627439368614 std: 1.588205658503679


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.42it/s]

Predicted score mean: 0.007921174001942111 std: 0.12511844154601942
True score mean: 0.014211503771793766 std: 1.5447898696301559


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.43it/s]


[Epoch 20] Loss: 684.678952
Epoch 21


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.016759329177503404 std: 0.11695533190921234
True score mean: 0.006025741133750509 std: 1.8177447067900987


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.41it/s]

Predicted score mean: 0.022204531726354636 std: 0.0998620305086969
True score mean: 0.005075928854775431 std: 1.6087232944056602


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.41it/s]


[Epoch 21] Loss: 914.404005
Epoch 22


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.014897268038463555 std: 0.060931761879829203
True score mean: 0.01174804204446258 std: 2.686200948973995


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.41it/s]

Predicted score mean: 0.01109221497212276 std: 0.045752362642732726
True score mean: -0.0038348562645395944 std: 1.3260905397210574


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.44it/s]


[Epoch 22] Loss: 1311.141118
Epoch 23


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.0006987628623269889 std: 0.06453326404209513
True score mean: -0.012987048567006887 std: 1.5945181296405746


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.42it/s]

Predicted score mean: -0.012206518496772638 std: 0.09193494723445138
True score mean: -0.0009299801304075097 std: 1.4705638123044644


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 23] Loss: 637.813577
Epoch 24


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.0171219254660523 std: 0.10335475466882117
True score mean: 0.0035534939103481662 std: 1.4820105492555669


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.42it/s]

Predicted score mean: -0.020040374182554405 std: 0.11090564044406756
True score mean: 0.006030769888726616 std: 1.546854896751142


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.42it/s]


[Epoch 24] Loss: 630.074596
Epoch 25


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.005815625761090925 std: 0.08778228648327906
True score mean: -0.019842156942765785 std: 1.2289645896827162


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.48it/s]

Predicted score mean: 0.009767457921577767 std: 0.055809312466867045
True score mean: 0.0038187989633550867 std: 1.9196722654810483


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 25] Loss: 601.744614
Epoch 26


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.021004057545950118 std: 0.03967932363036328
True score mean: 0.019728161052741357 std: 1.5236834818592506


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.48it/s]

Predicted score mean: 0.04657359828604341 std: 0.053482411725423014
True score mean: -0.0063930555415639845 std: 1.4413295305115663


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 26] Loss: 576.224680
Epoch 27


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.0616334715455281 std: 0.0833000363965369
True score mean: 1.7311639810907533e-05 std: 1.4366297954186875


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.38it/s]

Predicted score mean: 0.06856141950339725 std: 0.10274829805764871
True score mean: 0.014167355354062425 std: 1.326234304757523


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.39it/s]


[Epoch 27] Loss: 391.159553
Epoch 28


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.06320278837844023 std: 0.09350417949569383
True score mean: -0.013465160772373952 std: 1.5884296702964626


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.48it/s]

Predicted score mean: 0.032406396251627854 std: 0.05509434504131918
True score mean: -0.00230761847303091 std: 1.3839444812925024


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.35it/s]


[Epoch 28] Loss: 616.471758
Epoch 29


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.000224938504724235 std: 0.014877047324875017
True score mean: -2.939915440188331e-05 std: 1.2511003871277395


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.47it/s]

Predicted score mean: -0.0337345366230285 std: 0.030182741742205192
True score mean: 0.015209746695578695 std: 1.6912618634575236


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s]


[Epoch 29] Loss: 470.649033
Epoch 30


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.05894754951410818 std: 0.06948775306228107
True score mean: 0.006307291047274832 std: 1.3172999853865996


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.48it/s]

Predicted score mean: -0.06909162151230795 std: 0.08874843920887898
True score mean: -0.011849762851564694 std: 1.6638441050747375


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.47it/s]


[Epoch 30] Loss: 500.506546
Epoch 31


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.07193506656739361 std: 0.09430604550506237
True score mean: 0.02264825250962315 std: 1.560488673490871


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.11it/s]

Predicted score mean: -0.0611975430870189 std: 0.09702014541251511
True score mean: -0.016007741651061908 std: 1.4071759994315123


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.27it/s]


[Epoch 31] Loss: 521.788276
Epoch 32


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.049468135883145276 std: 0.08255857726171992
True score mean: -0.0030542564396914725 std: 1.8281561782678595


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.43it/s]

Predicted score mean: -0.027755189538158496 std: 0.05919222722510308
True score mean: 0.013484097186207656 std: 1.7600729712811696


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 32] Loss: 1097.021368
Epoch 33


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.003630088455519274 std: 0.036659339226349535
True score mean: 0.01195740554061323 std: 1.540361706511744


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.49it/s]

Predicted score mean: 0.01735876254565717 std: 0.026453125753674405
True score mean: 0.002281787158197567 std: 1.3760039765745946


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s]


[Epoch 33] Loss: 587.789681
Epoch 34


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.036174130554828265 std: 0.01467259285559839
True score mean: 0.011013524309869847 std: 1.4237169655770612


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.49it/s]

Predicted score mean: 0.04332336483494231 std: 0.013165453600565117
True score mean: 0.018546903308416008 std: 1.632064545562436


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s]


[Epoch 34] Loss: 619.301565
Epoch 35


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.05372204763776523 std: 0.020609059916703358
True score mean: -0.016952296081853262 std: 1.759441525691252


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.44it/s]

Predicted score mean: 0.05302937708988376 std: 0.02841334472760164
True score mean: -0.013420036428470207 std: 1.453546769757249


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s]


[Epoch 35] Loss: 708.641909
Epoch 36


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.0362669181539902 std: 0.03034160505348715
True score mean: 0.008887475102210727 std: 2.061666035669097


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.44it/s]

Predicted score mean: 0.016874534280230523 std: 0.03404987592888162
True score mean: 0.0012878938862749265 std: 1.4467479628427486


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s]


[Epoch 36] Loss: 883.726790
Epoch 37


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.004093485723037052 std: 0.04213192147728353
True score mean: 0.011142699691691446 std: 1.3497437725461705


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.50it/s]

Predicted score mean: -0.029795337786444113 std: 0.04228685989359737
True score mean: -0.007693204587395875 std: 1.43095796442236


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.47it/s]


[Epoch 37] Loss: 413.704999
Epoch 38


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.04990769585390133 std: 0.039944799768113375
True score mean: -0.013799323241021972 std: 1.2781371552320642


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.49it/s]

Predicted score mean: -0.06302424368271464 std: 0.029383038416395623
True score mean: 0.012629903254916218 std: 1.6191361595024638


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s]


[Epoch 38] Loss: 447.475453
Epoch 39


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.06785010401123033 std: 0.026821456578210556
True score mean: -0.0009365948069605412 std: 1.3581005656536482


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.49it/s]

Predicted score mean: -0.058043615011300856 std: 0.024036337171185106
True score mean: 1.4411084600173375e-05 std: 1.4943894644502427


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s]


[Epoch 39] Loss: 466.522318
Epoch 40


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.041728580877798335 std: 0.03317687360081242
True score mean: -0.008538297067799343 std: 1.2722518278840007


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.49it/s]

Predicted score mean: -0.0251690128201964 std: 0.04630011910831358
True score mean: 0.0016320793203494074 std: 1.488049459066953


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s]


[Epoch 40] Loss: 363.939880
Epoch 41


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.009472061558610662 std: 0.05566488085366331
True score mean: 0.0021010630231831024 std: 1.4720918489887473


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.45it/s]

Predicted score mean: 0.008216566230799008 std: 0.06109860041999703
True score mean: -0.0021401211111376555 std: 1.978457812122418


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.47it/s]


[Epoch 41] Loss: 1155.879626
Epoch 42


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.03352665705563294 std: 0.057292696469873595
True score mean: -0.017300924784677367 std: 2.1087549225977766


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.45it/s]

Predicted score mean: 0.03543110023103029 std: 0.031449123834830846
True score mean: 0.0029939369827897038 std: 2.0295658849127607


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.43it/s]


[Epoch 42] Loss: 1705.078444
Epoch 43


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.033451057656764836 std: 0.016735557935103108
True score mean: 0.006892176213353631 std: 1.3465696322052723


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.48it/s]

Predicted score mean: 0.0226343102423521 std: 0.042826658265222486
True score mean: 0.011259968439918413 std: 1.7670278257379726


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s]


[Epoch 43] Loss: 592.743479
Epoch 44


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.014361074918576315 std: 0.0726275340154542
True score mean: 0.005161251659632757 std: 1.3525747093120395


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.44it/s]

Predicted score mean: 0.008114226332154004 std: 0.09776254132119207
True score mean: -0.03592545198272362 std: 1.737830555229714


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 44] Loss: 795.460424
Epoch 45


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.01953338923396502 std: 0.11283589415898079
True score mean: 0.00014643764920656106 std: 1.953838062441812


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.41it/s]

Predicted score mean: -0.040781386572675896 std: 0.11281400672305204
True score mean: -0.0008058280345452842 std: 1.7560823253361073


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 45] Loss: 1168.766840
Epoch 46


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.04277117068971464 std: 0.10502700717765834
True score mean: -0.0003450737232775134 std: 1.5048489570410755


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.44it/s]

Predicted score mean: -0.04264383964504469 std: 0.09781381634680773
True score mean: -0.008593468436589754 std: 1.4599521742091852


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


[Epoch 46] Loss: 563.995534
Epoch 47


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.039035845147342575 std: 0.08957767301428607
True score mean: 0.009175761203411431 std: 1.5157276947269034


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.41it/s]

Predicted score mean: -0.025425528023942805 std: 0.07628951011734676
True score mean: -0.003510660256004077 std: 1.3844011670826317


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.44it/s]


[Epoch 47] Loss: 499.320394
Epoch 48


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.0016132796866178724 std: 0.05658576246765358
True score mean: 0.010845019158862056 std: 1.3536247901292993


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.43it/s]

Predicted score mean: 0.01899039815260447 std: 0.03904188341907566
True score mean: -0.03359189865883047 std: 1.7681894498157074


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.44it/s]


[Epoch 48] Loss: 806.798944
Epoch 49


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: 0.016975272131329003 std: 0.03719217704118154
True score mean: -0.005111475586391307 std: 1.4644756788592999


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.39it/s]

Predicted score mean: 0.010438426826470856 std: 0.05043453070691543
True score mean: -0.00031414874576610224 std: 1.6350003718692343


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.43it/s]


[Epoch 49] Loss: 725.096508
Epoch 50


Training:   0%|                                           | 0/2 [00:00<?, ?it/s]

Predicted score mean: -0.007409937940954259 std: 0.04786650437195301
True score mean: 0.0009148915343452325 std: 1.419640902845067


Training:  50%|█████████████████▌                 | 1/2 [00:00<00:00,  1.47it/s]

Predicted score mean: -0.016609120532428998 std: 0.045100824236459135
True score mean: 0.007652515216434724 std: 1.8899543161764356


Training: 100%|███████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]

[Epoch 50] Loss: 760.397158





In [6]:
@torch.no_grad()
def sample_reverse_diffusion(model, L, n_steps=1000, device='cuda'):
    """
    Sample protein coordinates using reverse diffusion (constant β(t) = 1).

    Args:
        model: trained EGNNScoreModel
        L: number of residues
        n_steps: number of reverse diffusion steps
        device: torch device

    Returns:
        x: (L, 3) denoised coordinates
    """
    model.eval()
    x = torch.randn(1, L, 3, device=device) # Gaussian noise
    ts = torch.linspace(1.0, 1e-3, n_steps, device=device)  # discretize time; avoid t=0

    dt = 1.0 / n_steps  # Positive forward step size

    for t in ts:
        t_batch = torch.full((1, 1), t, device=device)  # (1, 1)
        
        # Node & edge encoding 
        node_feats = encode_node_features(x, t_batch, dim=h_dim)      # (1, L, 2*h_dim)
        edge_feats = encode_edge_features(L, dim=edge_dim, device=device)  # (L, L, 2*edge_dim)
        edge_feats = edge_feats.unsqueeze(0)  # (1, L, L, 2*edge_dim)

        # Predict the score ∇ log p_t(x_t)
        score = model(x, t_batch, node_feats, edge_feats)  # (1, L, 3)
        # score = model(x, x0[0,:,:],t_batch)

        # Reverse diffusion update (Euler-Maruyama)
        z = torch.randn_like(x) if t > 1e-3 else 0.0  # No noise at final step
        x = x + (-0.5 * x - score)  * (-1*dt) + torch.sqrt(torch.tensor(-1*dt, device=device)) * z

    return x.squeeze(0)  # (L, 3)


In [7]:
sample = sample_reverse_diffusion(model, L=23, device='cpu')  # shape (32, 3)
write_ca_to_pdb(sample, 'sampled.pdb')

# sample = sample_reverse_diffusion(score_R3, L=23, device='cpu')  # shape (32, 3)
# write_ca_to_pdb(sample, 'compare.pdb')

Saved Cα PDB to: sampled.pdb


In [8]:
from torchinfo import summary

# Example batch dimensions
N, L = 8, 100
# Create dummy inputs
x_t = torch.randn(N, L, 3)
t = torch.randn(N, 1)
node_feats = encode_node_features(x_t, t, dim=h_dim)  # (1, L, 64)
# edge_feats = encode_edge_features(L, dim=edge_dim, device=device)  # (L, L, 64)
# edge_feats = edge_feats.unsqueeze(0)  # (1, L, L, 64)
edge_feats = encode_edge_features(L, dim=edge_dim, device=device)  # (L, L, edge_dim)
edge_feats = edge_feats.unsqueeze(0).expand(N, -1, -1, -1)          # (N, L, L, edge_dim)
# Now use input_data instead of input_size
summary(model, input_data=(x_t, t, node_feats, edge_feats))

Layer (type:depth-idx)                   Output Shape              Param #
EGNNScoreModel                           [8, 100, 3]               --
├─EGNN: 1-1                              [8, 100, 3]               --
│    └─ModuleList: 2-1                   --                        --
│    │    └─EGNNLayer: 3-1               [8, 100, 3]               29,185
│    │    └─EGNNLayer: 3-2               [8, 100, 3]               29,185
│    │    └─EGNNLayer: 3-3               [8, 100, 3]               29,185
│    │    └─EGNNLayer: 3-4               [8, 100, 3]               29,185
├─Linear: 1-2                            [8, 100, 3]               195
Total params: 116,935
Trainable params: 116,935
Non-trainable params: 0
Total mult-adds (M): 0.94
Input size (MB): 2.78
Forward/backward pass size (MB): 670.35
Params size (MB): 0.94
Estimated Total Size (MB): 674.07

In [4]:
# omega_grid = torch.linspace(1e-4,torch.pi,1000)

# for batch in loader:
#     x0 = batch['translations']
#     r0 = batch['rotations']
#     t = batch['timesteps']
#     print(batch['protein_id'])

#     print(x0.shape)
#     print(r0.shape)
#     write_frames_to_pdb(x0, r0, 'og.pdb')

#     omega_grid = torch.linspace(1e-4,torch.pi,1000)

#     for t in [0.3, 1.5, 5.0]:
#         t_tensor = torch.Tensor([t]).unsqueeze(0)
#         xt = noise_translations(x0.unsqueeze(0), t_tensor).squeeze(0)
#         rt = noise_rotations(r0.unsqueeze(0), t_tensor, omega_grid).squeeze(0)
        # write_frames_to_pdb(x0, r0, 'og.pdb')
        # write_frames_to_pdb(xt, rt, f'noised_{t}.pdb')
    
    # xt = noise_translations(x0.unsqueeze(0), t.unsqueeze(0))
    # rt = noise_rotations(r0.unsqueeze(0), t.unsqueeze(0), omega_grid)

    # write_frames_to_pdb(x0.squeeze(0), r0.squeeze(0), 'og.pdb')
    # write_frames_to_pdb(xt.squeeze(0), rt.squeeze(0), 'noised.pdb')

# xt = noise_translations(x0, t)
# rt = noise_rotations(r0, t, omega_grid)


# write_frames_to_pdb(x0[0,:,:].squeeze(0), r0[0,:,:].squeeze(0), 'og.pdb')
# write_frames_to_pdb(xt[0,:,:].squeeze(0), rt[0,:,:].squeeze(0), 'noised.pdb')