In [8]:
import mdtraj as md
import torch
import pandas as pd
import copy
import os
import pickle
from tqdm.auto import tqdm
from NewSample import coords_to_dists, loss_fcn, coord_to_dcd, smooth_transition_loss_by_sample

In [2]:
def to_list(object):
    return object if type(object) == list else [object]

def parse_filename(filename):
    filename = filename.split('/')[-1]  # remove directory path
    components = filename.split('_')
    region_idx = int(components[1])
    cond_scale = float(components[2]) if '.' in components[2] else int(components[2])
    rescaled_phi = float(components[3]) if '.' in components[3] else int(components[3])
    milestone = int(components[4])
    chrom = components[5].split('.')[0]

    return region_idx, cond_scale, rescaled_phi, milestone, chrom

def is_valid(
    filename,
    region_idx,
    cond_scale,
    rescaled_phi,
    milestone,
    chrom
):

    try:
        ri, cs, rp, ms, ch = parse_filename(filename) 
    except:
        return False, (None, None, None, None, None)
    
    for desired,actual in [(region_idx,ri),(cond_scale,cs),
                         (rescaled_phi,rp),(milestone,ms),(chrom,ch)]:
        if desired is not None and actual not in to_list(desired):
            return False, (ri, cs, rp, ms, ch)
    
    return True, (ri, cs, rp, ms, ch)
    
def get_all_sample_types(
    sample_directory,
    *,
    region_idx=None,
    cond_scale=None,
    rescaled_phi=None,
    milestone=None,
    chrom = None
):
    samples = os.listdir(sample_directory)

    to_process = pd.DataFrame({
        'region_idx':[],
        'cond_scale':[],
        'rescaled_phi':[],
        'milestone':[],
        'chrom':[],
    })
    
    for f in samples:
        valid, properties = is_valid(f,region_idx,cond_scale,rescaled_phi,milestone,chrom) 
        if valid: 
            to_process.loc[len(to_process)] = properties

    to_process.sort_values(
        ['milestone','chrom','region_idx','cond_scale','rescaled_phi'], # sorts in this order
        inplace=True,
        ignore_index=True,
    )

    return [*to_process.itertuples(index=False,name=None)]

In [9]:
'''
def get_best_alignments(t_ref,t_sample):
    '#''
    Efficiency could be improved by simply finding the best distance alignments, which
    are independent of coordinate superimposition, then only aligning the relevant structures,
    rather than aligning all before comparing distance maps
    '#''
    ref_dists = coords_to_dists(torch.from_numpy(t_ref.xyz))
    while ref_dists.ndim < 3: # Shouldn't happen unless only one structure was available
        ref_dists = ref_dists.unsqueeze(0)
    #ref_dists = ref_dists.unsqueeze(1).expand(-1,len(t_sample),-1,-1)
        
    best_alignments = []
    for frame in range(len(t_ref)):

        # This is an in-place operation
        t_sample.superpose(t_ref,frame=frame)

        # Measure similarities
        ref_dist = ref_dists[frame,...]
        sample_coords = torch.from_numpy(t_sample.xyz)
        losses = torch.tensor([
            loss_fcn(sample_coords[i,:,:],ref_dist) for i in range(len(t_sample))
        ]) # Should already be flat
        losses/= ref_dist.numel()
        min_loss,min_loss_idx = losses.max(0)

        best_alignments.append(
            (min_loss,int(min_loss_idx),sample_coords[min_loss_idx,...])
        )

    best_info = [{'Loss':loss,'Index':idx} for loss,idx,_ in best_alignments]
    best_alignments = torch.stack([coords for _,_,coords in best_alignments],dim=0)

    return best_info,best_alignments
'''
def get_best_alignments(t_ref,t_sample,r_c=1.,long_scale=1/8,use_gpu=True,high_precision=False):

    ref_dists = coords_to_dists(torch.from_numpy(t_ref.xyz))
    ref_dists = ref_dists.unsqueeze(1).expand(-1,len(t_sample),-1,-1)
    sample_dists = coords_to_dists(torch.from_numpy(t_sample.xyz))
    sample_dists = sample_dists.unsqueeze(0).expand(len(t_ref),-1,-1,-1)

    losses = smooth_transition_loss_by_sample(sample_dists,ref_dists,r_c,long_scale,use_gpu,high_precision)

    best_loss,best_loss_idx = losses.min(1)
    
    coords = torch.empty(*t_ref.xyz.shape)
    for frame,idx in enumerate(best_loss_idx):
        ts1 = t_sample[idx]
        ts2 = copy.deepcopy(ts1)
        ts2.xyz[...,-1]*= -1 # Reflect, in case we computed the isomer with backward chirality
        ts1.superpose(t_ref,frame=frame)
        ts2.superpose(t_ref,frame=frame)

        diff1 = ((t_ref[frame].xyz - ts1.xyz)**2).sum()
        diff2 = ((t_ref[frame].xyz - ts2.xyz)**2).sum()
        if diff1 < diff2: 
            coords[frame,...] = torch.from_numpy(ts1.xyz).squeeze(0)
        else:
            coords[frame,...] = torch.from_numpy(ts2.xyz).squeeze(0)

    best_info = [
        {
            'Loss':best_loss[i],
            'Index':best_loss_idx[i]
        } for i in range(len(best_loss))
    ]
    
    return best_info,coords
    

def get_gen_filepaths(
    sample_dir,
    region_idx,
    cond_scale,
    rescaled_phi,
    milestone,
    chrom
):
    
    filepath = sample_dir
    if filepath != '' and filepath[-1] != '/':
        filepath+= '/'
    filepath+= f'sample_{region_idx}_{float(cond_scale)}_{float(rescaled_phi)}'
    filepath+= f'_{milestone}_{chrom}'
    return filepath+'.dcd', filepath+'.psf'

def get_tan_filepaths(
    sample_dir,
    chrom,
    region_idx,
    nbins
):
    filepath = sample_dir
    if filepath != '' and filepath[-1] != '/':
        filepath+= '/'
    filepath+= f'chrom_{chrom}_region_{region_idx}_nbins_{nbins}'
    return filepath+'.dcd',filepath+'.psf'

def align_many_samples(
    sample_dir,
    reference_dir,
    *,
    chroms=None,
    milestones=None,
    region_idxs=None,
    cond_scales=None,
    rescaled_phis=None,
    r_c=1.,
    long_scale=1/8,
    use_gpu=True,
    high_precision=False,
):
    sample_types = get_all_sample_types(
        sample_directory=sample_dir,
        region_idx=region_idxs,
        cond_scale=cond_scales,
        rescaled_phi=rescaled_phis,
        milestone=milestones,
        chrom = chroms
    )

    if sample_dir != '' and sample_dir[-1] != '/':
        sample_dir = sample_dir + '/'
    dest_dir = sample_dir + 'aligned/'

    #with tqdm(initial = 0, total = len(sample_types), disable = False) as pbar:
    for region_idx, cond_scale, rescaled_phi, milestone, chrom in tqdm(sample_types,desc='Sample Alignment Progress'):
        # Load generated samples
        gen_dcd,gen_psf = get_gen_filepaths(sample_dir,region_idx,cond_scale,rescaled_phi,milestone,chrom)
        t_gen = md.load(gen_dcd,top=gen_psf)

        # Load reference samples
        nbins = t_gen.xyz.shape[-2]
        ref_dcd,ref_psf = get_tan_filepaths(reference_dir,chrom,region_idx,nbins)
        t_ref = md.load(ref_dcd,top=ref_psf)

        # Align the data, place configurations into a 
        info,coords = get_best_alignments(t_ref,t_gen,r_c,long_scale,use_gpu,high_precision)

        # Save the information
        sample_name = '.'.join(gen_dcd.split('/')[-1].split('.')[:-1])
        coord_to_dcd(coords,dest_dir,sample_name)
        pickle.dump(info,open(dest_dir+sample_name+'_loss_info.pkl','wb'))
    
        
    

In [10]:
align_many_samples(
    '../../data/samples/origami_64_no_embed_reduction/dcd_files/',
    '../../data/samples/Tan/',
    chroms=None,#'1',
    milestones=120,
    region_idxs=None,#330,
    cond_scales=None,#3.0,
    rescaled_phis=None#.5
)

Sample Alignment Progress:   0%|          | 0/2670 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [1]:
from pairwise_alignment import align_many_samples
align_many_samples(
    '../../data/samples/origami_64_no_embed_reduction/dcd_files/',
    '../../data/samples/Tan/',
    chroms=None,#'1',
    milestones=120,
    region_idxs=None,#330,
    cond_scales=None,#3.0,
    rescaled_phis=None#.5
)

Sample Alignment Progress:   0%|          | 0/2670 [00:00<?, ?it/s]

In [6]:
import time
tt = -time.time()
align_many_samples(
    '../../data/samples/origami_64_no_embed_reduction/dcd_files/',
    '../../data/samples/Tan/',
    chroms='1',
    milestones=120,
    region_idxs=330,
    cond_scales=3.0,
    rescaled_phis=.5
)
tt+= time.time()
tt

10.19516110420227

In [None]:
asdf

In [40]:
align_many_samples(
    '../../data/samples/origami_64_no_embed_reduction/dcd_files/',
    '../../data/samples/Tan/',
    chroms=None,#'1',
    milestones=120,
    region_idxs=None,#330,
    cond_scales=None,#3.0,
    rescaled_phis=None#.5
)

KeyboardInterrupt: 

In [6]:
t = md.load(
    '../../data/samples/origami_64_no_embed_reduction/dcd_files/sample_200_0.5_0.0_36_1.dcd',
    top = '../../data/samples/origami_64_no_embed_reduction/dcd_files/sample_200_0.5_0.0_36_1.psf'
)

In [7]:
t1 = md.load(
    '../../data/samples/Tan/chrom_1_region_200_nbins_64.dcd',
    top = '../../data/samples/Tan/chrom_1_region_200_nbins_64.psf'
)

In [8]:
print(len(t))
print(len(t1))

1919
90


In [10]:
len(t1)

90

In [11]:
len(t)

1919

In [None]:
1919 

In [60]:
t2 = copy.deepcopy(t)
t2.xyz[...,2]*= -1

In [69]:
torch.rand(5,5,5).sum((-1,-2)).shape

torch.Size([5])

In [8]:
t1

<mdtraj.Trajectory with 90 frames, 64 atoms, 1 residues, without unitcells at 0x7f8f6e658310>

In [19]:
t1[0].xyz.shape

(1, 64, 3)

In [20]:
losses.shape

torch.Size([1919, 90])

In [21]:
a,b = losses.min(0)
a.shape

torch.Size([90])

In [27]:
import numpy as np
(np.random.rand(5)**2).sum()

3.114803162908151

In [15]:
2000 * 90 * 8 * 64**2 / 1024**3

5.4931640625

In [9]:
a=torch.rand(5,5,5)
mask=torch.where(a<.5)#,dtype=torch.short)

In [54]:
def smooth_transition_one_element(output,target,r_c,long_scale,m,b):
    diff = max(output-target,target-output) / r_c
    if diff < 1:
        return diff**2
    else:
        return m*diff**long_scale + b

def smooth_transition_loss_by_sample(
    output,
    target,
    r_c=1.0, # Transition distance from x**2 -> x**(long_scale)
    long_scale=1/8,
    use_gpu=True,
    high_precision=True
):
    '''
    Reduces to smooth L1 loss if  long_scale == 1.
    
    Rather than summing over ALL data, sum over the final two 
    dimensions (corresponding to individual distance maps). 
    '''
    # Scale to ensure the two functions have the same slope at r_c
    m = 2 / long_scale
    # Shift to ensure the two functions have the same value at r_c
    b = 1 - m
    
    return_device = output.device
    return_dtype = output.dtype
    if use_gpu and torch.cuda.is_available():
        output = output.cuda()
        target = target.cuda()
    if high_precision:
        output = output.double()
        target = target.double()

    losses = (output - target).abs_()
    del output, target
    losses/= r_c
    
    if losses.is_cuda:
        '''
        This is slower than using masking, but it avoid memory issues associated
        with mask indexing (torch turns bool masks into int64 indexing arrays) while
        remaining faster than some alternative low-memory options I tried. 
        '''
    
        #losses = torch.where(
        #    losses < 1,
        #    losses.square(),
        #    m*losses.pow(long_scale)+b,
        #    out=losses
        #)
        torch.where(
            losses < 1,
            losses.square(),
            m*losses.pow(long_scale)+b,
            out=losses
        )
        
    else:
        '''
        Assume that these memory issues don't arise on the CPU
        '''
        mask = losses < 1
        if mask.any():
            #losses[mask].square_()
            losses[mask] = losses[mask]**2
        mask^= True
        if mask.any():
            losses[mask] = m*losses[mask]**long_scale + b
        del mask
    

    return losses.sum((-1,-2)).to(dtype=return_dtype,device=return_device)

In [56]:
dist1 = coords_to_dists(torch.from_numpy(t1.xyz))
dist2 = coords_to_dists(torch.from_numpy(t.xyz))
dist1 = dist1.unsqueeze(1).expand(-1,dist2.shape[0],-1,-1)
dist2 = dist2.unsqueeze(0).expand(dist1.shape[0],-1,-1,-1)

In [57]:
tt = -time.time()
losses = smooth_transition_loss_by_sample(dist1,dist2,use_gpu=True)
tt+= time.time()
tt

1.5471458435058594

In [58]:
tt = -time.time()
losses1 = smooth_transition_loss_by_sample(dist1,dist2,use_gpu=False)
tt+= time.time()
tt

3.2595009803771973

In [1]:
import torch
a = torch.rand(5,10,3)
b = torch.linalg.norm(a,dim=-1)

In [2]:
b.shape

torch.Size([5, 10])

In [5]:
torch.linalg.vecdot(a[...,:1],a,dim=-1).shape

torch.Size([5, 10])

In [6]:
b[1,:] = torch.nan
b.nan_to_num_(torch.inf)
b

tensor([[0.4869, 1.0372, 1.2529, 0.8914, 0.3977, 1.1891, 1.3347, 0.4880, 0.6940,
         1.0607],
        [   inf,    inf,    inf,    inf,    inf,    inf,    inf,    inf,    inf,
            inf],
        [1.3215, 0.9294, 0.9438, 0.8269, 0.9954, 0.9360, 0.6900, 1.2909, 1.2902,
         1.1081],
        [1.1064, 0.8872, 1.1795, 0.8327, 0.7717, 1.0095, 1.0139, 1.1269, 0.4704,
         1.1709],
        [0.8528, 0.4068, 0.9754, 0.7854, 0.8601, 1.2958, 1.3150, 0.7143, 1.0410,
         0.8614]])

In [12]:
a = torch.zeros(5)
a.min(0)

torch.return_types.min(
values=tensor(0.),
indices=tensor(0))