Restarted aspire (Python 3.8.18)

In [27]:
import aspire
import numpy as np
import pandas as pd
from aspire.operators import RadialCTFFilter
from aspire.source.simulation import Simulation
from aspire.volume import LegacyVolume, Volume
from utils import volsCovarEigenvec
import time
from covar_estimation import im_stack_backward
import torch
# Specify parameters
img_size = 28  # image size in square
num_imgs = 2000  # number of images
dtype = np.float32

rank = 4
c = rank + 1
vols = LegacyVolume(
    L=img_size,
    C=c,
    dtype=dtype,
).generate()
vols -= np.mean(vols,axis=0)
sim = Simulation(
    unique_filters=[RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7)],
    n=num_imgs,
    vols=vols,
    dtype=dtype,
    amplitudes=1,
    offsets = 0
)

vectorsGD = torch.tensor(volsCovarEigenvec(vols),requires_grad = False)


2024-04-13 17:59:16,984 INFO [aspire.source.image] Creating Simulation with 2000 images.


In [28]:
%load_ext autoreload
%autoreload 2

from covar_sgd import CovarDataset,Covar,CovarTrainer
cds = CovarDataset(sim,vectorsGD = vectorsGD)
covar = Covar(resolution=img_size,rank=rank)#,vectors=vectorsGD.reshape((rank,img_size,img_size,img_size))/cds.im_norm_factor)
batch_size = 1
learning_rate = 1e-4 
momentum = 0.9
reg = 1e-5
gamma_lr = 0.8
gamma_reg = 1
kwargs_dict = {'max_epochs' : 10, 'lr' : learning_rate,'momentum' : momentum,'optim_type' : 'SGD','reg' : reg,'gamma_lr': gamma_lr,'gamma_reg' : gamma_reg}
#kwargs_dict = {'max_epochs' : 10, 'lr' : 1e-10,'momentum' : momentum,'optim_type' : 'Adam','reg' : reg,'gamma_lr': gamma_lr,'gamma_reg' : gamma_reg}
device = torch.device('cuda:0')
dataloader = torch.utils.data.DataLoader(cds,batch_size = batch_size,shuffle = False)#,collate_fn=dataset_collate)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
torch.manual_seed(0)
from covar_distributed import trainParallel

trainParallel(covar,cds,num_gpus = 8,batch_size = batch_size,**kwargs_dict)



In [30]:
torch.manual_seed(0)
trainer = CovarTrainer(covar,dataloader,device)
trainer.train(**kwargs_dict)

Epoch 2 , cost value : 8.47e-01,  cosine sim : 0.10, frobenium norm error : 1.00e+00:  96%|█████████▌| 1924/2000 [03:11<00:07, 10.04it/s] 
Epoch 8 , cost value : 5.00e-01,  cosine sim : 0.52, frobenium norm error : 9.62e-01:  44%|████▍     | 879/2000 [01:25<01:48, 10.34it/s] 
Epoch 0 , cost value : 6.52e-01,  cosine sim : 0.82, frobenium norm error : 5.82e-01: 100%|██████████| 2000/2000 [00:12<00:00, 165.47it/s]
Epoch 1 , cost value : 6.46e-01,  cosine sim : 0.86, frobenium norm error : 5.70e-01: 100%|██████████| 2000/2000 [00:10<00:00, 188.48it/s]
Epoch 2 , cost value : 3.08e-01,  cosine sim : 0.86, frobenium norm error : 5.96e-01:   3%|▎         | 67/2000 [00:00<00:10, 190.79it/s]

KeyboardInterrupt: 

In [None]:
%load_ext autoreload
%autoreload 2
from covar_analyzer import CovarAnalyzer
import torch
c = CovarAnalyzer.load('data/tmp2/results.csv')

#c.plotCosineSim()
#c.plotWeightedCosineSim()
c.plotFroErr()




In [None]:
import pickle
from projection_funcs import centered_fft2,vol_forward
x = pickle.load(open('data/pts.bin','rb'))
projs = x['projs'][0].reshape((rank,img_size,img_size))
fft_projs = aspire.image.Image(np.abs(centered_fft2(torch.tensor(projs)).numpy()))
projs = aspire.image.Image(projs)


projs.show()
fft_projs.show()

In [None]:
from nufft_plan import NufftPlan
d = torch.device('cuda:1')
plan = NufftPlan((img_size,)*3,batch_size = rank,gpu_device_id = d.index,gpu_method=1,gpu_sort = 0)
plan.setpts(torch.tensor(x['pts']).to(d))
vols = torch.tensor(x['vols']).to(d)
proj_vols = vol_forward(vols,plan)
p = aspire.image.Image(proj_vols.cpu().numpy())
p.show()

In [None]:

p = (torch.remainder(cds.pts_rot + torch.pi , 2 * torch.pi) - torch.pi) - torch.tensor(x['pts'])
#torch.tensor(x['pts']).shape
ind = torch.argmin(torch.norm(p.reshape((2000,-1)),dim=1))
aspire.utils.Rotation.from_matrix((sim.rotations[ind])).as_rotvec()



In [None]:
b = fft_projs.asnumpy()[1].reshape(1,-1)
a = np.argmax(fft_projs.asnumpy()[1])
b[0,a]
x['pts'][:,a]

In [None]:
from utils import meanCTFPSD
aspire.image.Image(cds.unique_filters.numpy()).show()