Restarted aspire (Python 3.8.18)

In [53]:
import aspire
import numpy as np
import pandas as pd
from aspire.operators import RadialCTFFilter
from aspire.source.simulation import Simulation
from aspire.volume import LegacyVolume, Volume
from utils import volsCovarEigenvec
import time
from covar_estimation import im_stack_backward
import torch
# Specify parameters
img_size = 15  # image size in square
num_imgs = 2000  # number of images
dtype = np.float32


rank = 4
c = rank + 1
vols = LegacyVolume(
    L=img_size,
    C=c,
    dtype=dtype,
).generate()
vols -= np.mean(vols,axis=0)
sim = Simulation(
    #unique_filters=[RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7)],
    n=num_imgs,
    vols=vols,
    dtype=dtype,
    amplitudes=1,
    offsets = 0
)

vectorsGD = torch.tensor(volsCovarEigenvec(vols).asnumpy(),requires_grad = False)

2024-03-23 14:25:02,978 INFO [aspire.source.image] Creating Simulation with 2000 images.


In [56]:
%load_ext autoreload
%autoreload 2

from covar_sgd import CovarDataset,Covar,CovarTrainer,dataset_collate
cds = CovarDataset(sim,rank = rank)
covar = Covar(resolution=img_size,rank=rank,norm_factor=cds.im_norm_factor)#,vectors=vectorsGD)
batch_size = 10
learning_rate = 1e-4
momentum = 0.9
device = torch.device('cuda:0')
dataloader = torch.utils.data.DataLoader(cds,batch_size = batch_size,shuffle = False,collate_fn=dataset_collate)
optimizer = torch.optim.SGD(covar.parameters(),lr = learning_rate,momentum = momentum)
trainer = CovarTrainer(covar,dataloader,device,vectorsGD = vectorsGD/cds.im_norm_factor)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [57]:
trainer.train(10)

Epoch 0 , :   0%|          | 0/200 [00:00<?, ?it/s]

cost value : 9.82e+04,  cosine sim : 0.59: 100%|██████████| 200/200 [00:11<00:00, 17.56it/s]
cost value : 3.04e+05,  cosine sim : 0.66: 100%|██████████| 200/200 [00:11<00:00, 17.81it/s]
cost value : 8.61e+04,  cosine sim : 0.70: 100%|██████████| 200/200 [00:11<00:00, 17.59it/s]
cost value : 4.59e+04,  cosine sim : 0.71: 100%|██████████| 200/200 [00:11<00:00, 17.08it/s]
cost value : 2.33e+05,  cosine sim : 0.75: 100%|██████████| 200/200 [00:11<00:00, 17.84it/s]
cost value : 3.05e+05,  cosine sim : 0.76: 100%|██████████| 200/200 [00:11<00:00, 17.45it/s]
cost value : 7.78e+04,  cosine sim : 0.77: 100%|██████████| 200/200 [00:11<00:00, 17.10it/s]
cost value : 3.91e+04,  cosine sim : 0.76: 100%|██████████| 200/200 [00:10<00:00, 18.60it/s]
cost value : 6.53e+04,  cosine sim : 0.77: 100%|██████████| 200/200 [00:11<00:00, 17.47it/s]
cost value : 3.47e+05,  cosine sim : 0.77: 100%|██████████| 200/200 [00:11<00:00, 18.13it/s]


In [52]:
images,plans = cds[:]
images = images.to(device)
#covar = Covar(resolution=img_size,rank=rank,norm_factor=cds.im_norm_factor)
covar = covar.to(device)
a = covar.cost(images,plans)

covarGD = Covar(resolution=img_size,rank=rank,norm_factor=cds.im_norm_factor,vectors=vectorsGD)
covarGD = covarGD.to(device)
b = covarGD.cost(images,plans)

print(a,b)


cost value : 6.02e+02,  cosine sim : 1.00:  40%|████      | 80/200 [03:05<04:38,  2.32s/it]


tensor(3086.4863, device='cuda:0', grad_fn=<MeanBackward1>) tensor(-0.0007, device='cuda:0', grad_fn=<MeanBackward1>)


In [48]:
trainer.log_cosine_sim

[array([[-0.02840121]], dtype=float32),
 array([[-0.03713643]], dtype=float32),
 array([[-0.07990742]], dtype=float32),
 array([[-0.19121245]], dtype=float32),
 array([[-0.42498663]], dtype=float32),
 array([[-0.6909434]], dtype=float32),
 array([[-0.7764266]], dtype=float32),
 array([[-0.7366661]], dtype=float32),
 array([[-0.67463344]], dtype=float32),
 array([[-0.6347184]], dtype=float32),
 array([[-0.6316152]], dtype=float32),
 array([[-0.68954724]], dtype=float32),
 array([[-0.7649933]], dtype=float32),
 array([[-0.80808586]], dtype=float32),
 array([[-0.8180496]], dtype=float32),
 array([[-0.8170496]], dtype=float32),
 array([[-0.8171366]], dtype=float32),
 array([[-0.82627994]], dtype=float32),
 array([[-0.85104334]], dtype=float32),
 array([[-0.8747826]], dtype=float32),
 array([[-0.8915245]], dtype=float32),
 array([[-0.90418965]], dtype=float32),
 array([[-0.9114733]], dtype=float32),
 array([[-0.92441475]], dtype=float32),
 array([[-0.9385556]], dtype=float32),
 array([[-0.9