In [1]:
import torch
import sys
import pandas as pd
import numpy as np
#%matplotlib widget
import matplotlib.pyplot as plt

from matplotlib.colors import LogNorm
import cooltools.lib.plotting # provides 'fall' colormap

sys.path.insert(1,'../data_utils_v2/')
from Sample import Sample

In [2]:
# Load the pickled sample object from disk.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
sample = pd.read_pickle('../diffusion/sampling_small/sample_0_1_1_38_X.pkl')
#sample = pd.read_pickle('../diffusion/results_large/generated_hic_map.pkl')['sample'];

In [3]:
# Ask the sample for its contact-probability map.
# NOTE(review): the recorded output of this cell is
# "AttributeError: 'Sample' object has no attribute 'coords'".
hic_map = sample.contact_probabilities()

AttributeError: 'Sample' object has no attribute 'coords'

In [None]:
class DistLoss(torch.nn.Module):
    '''
    L1 mismatch between a reference distance map and the pairwise
    distances of a set of coordinates.

    The reference map is reduced to its upper triangle once, at
    construction time. The main diagonal is skipped when it holds
    self-interactions and kept otherwise.
    '''

    def __init__(self,dists,self_interaction_included=True):
        super().__init__()

        # Offset 1 drops the main diagonal (self-interaction),
        # offset 0 keeps it.
        diag_offset = int(self_interaction_included)
        n_rows,n_cols = dists.shape[-2],dists.shape[-1]
        rows,cols = torch.triu_indices(n_rows,n_cols,diag_offset)
        self.dists = dists[rows,cols]

    def get_dists(self,coords):
        '''Condensed pairwise distances of `coords` (see torch pdist).'''
        return torch.nn.functional.pdist(coords)

    def bond_strength(self):
        '''
        could add molecular interactions as well
        '''

    def forward(self,coords):
        '''Sum of absolute differences between reference and actual distances.'''
        residual = self.dists - torch.nn.functional.pdist(coords)
        return residual.abs().sum()
    


def _adjust_coords_(
    coords,
    dists,
    n_it=1000,
    lr=0.01
):
    '''
    Given an initial guess of coordinate locations and the
    distance map predicted by the diffusion model, adjust
    the coordinates to best match the predicted map.

    Parameters
    ----------
    coords : torch.Tensor
        Initial coordinates. Optimised IN PLACE by Adam.
    dists : torch.Tensor
        Reference distance map, passed to `DistLoss(dists, False)`,
        i.e. treated as a map without self-interaction entries.
    n_it : int
        Number of Adam steps. 0 is allowed and leaves `coords` untouched.
    lr : float
        Adam learning rate.

    Returns
    -------
    (coords, loss) : the same tensor object that was passed in, and the
        loss (Python float) evaluated at the returned coordinates.
    '''

    dist_loss = DistLoss(dists,False)

    coords.requires_grad_(True)

    optimizer = torch.optim.Adam([coords],lr=lr)

    for _ in range(n_it):
        optimizer.zero_grad()
        loss = dist_loss(coords)
        loss.backward()
        optimizer.step()

    coords.requires_grad_(False)

    # Re-evaluate after the last step so the reported loss belongs to the
    # coordinates actually returned (the in-loop value predates the final
    # update and does not exist at all when n_it == 0).
    final_loss = dist_loss(coords).item()

    return coords,final_loss
    

    




In [None]:
def _dist_to_coord_neighbors_known_(sample,s):

    # Ensure we have distances in their unnormalized form
    sample.unnormalize_()

    # For covenience...
    dists = sample.batch[s,0,...]
    
    # Initialize the coordinates object 
    b,dt,dev = sample.batch_seg_len, sample.dtype, sample.device
    coords = torch.empty(b,3,dtype=dt,device=dev)
    coords[:] = torch.nan 
    
    # Place bead 0 at the origin 
    coords[0,:] = 0 
    
    # Place bead 1 on the x axis 
    coords[1,0] = dists[0,1]
    coords[1,1:] = 0 
    
    # Get other x coordinates
    #coords[2:,0] = (dists[0,1]**2 + dists[0,2:]**2 - dists[1,2:]**2) / (2*dists[0,1]**2)
    coords[2:,0] = ( 1 + (dists[0,2:]**2 - dists[1,2:]**2)/dists[0,1]**2 ) / 2
    
    # Place bead 2 in the xy plane with positive y 
    coords[2,1] = ( dists[0,2]**2 - coords[2,0]**2 ).sqrt()
    coords[2,2] = 0 
    
    # Get other y coordinates
    coords[3:,1] = 1
    coords[3:,1]+= ( dists[0,3:]**2 - coords[3:,0]**2 - dists[2,3:]**2 + (coords[3:,0] - coords[2,0])**2 ) / (dists[0,2]**2 - coords[2,0]**2)
    coords[3:,1]/= 2
    
    # Give bead 3 a positive z value 
    coords[3,2] = ( dists[0,3]**2 - coords[3,:2].square().sum() ).sqrt() 
    
    # Get other z coordinates
    coords[4:,2] = dists[0,4:]**2 - dists[3,4:]**2
    coords[4:,2]+= (coords[4:,:2] - coords[3,:2]).square().sum(1) - coords[4:,:2].square().sum(1)
    coords[4:,2]/= coords[3,2]**2
    coords[4:,2]+= 1
    coords[4:,2]/= 2
    
    return coords
    

In [None]:
def _dist_to_coord_(sample,s):
    '''
    Assumes self interactions are excluded from the map 
    '''
    # Ensure we have distances in their unnormalized form
    sample.unnormalize_()

    # For covenience...
    dists = sample.batch[s,0,...]
    
    # Initialize the coordinates object 
    b,dt,dev = dists.shape[-1]+1, sample.dtype, sample.device
    coords = torch.empty(b,3,dtype=dt,device=dev)
    coords[:] = torch.nan 
    
    # Place bead 0 at the origin 
    coords[0,:] = 0 
    
    # Place bead 1 on the x axis 
    coords[1,0] = dists[0,0]
    coords[1,1:] = 0 
    
    # Get other x coordinates
    #coords[2:,0] = (dists[0,1]**2 + dists[0,2:]**2 - dists[1,2:]**2) / (2*dists[0,1]**2)
    coords[2:,0] = ( 1 + (dists[0,1:]**2 - dists[1,1:]**2)/dists[0,0]**2 ) / 2
    
    # Place bead 2 in the xy plane with positive y 
    coords[2,1] = ( dists[0,1]**2 - coords[2,0]**2 ).sqrt()
    coords[2,2] = 0 
    
    # Get other y coordinates
    coords[3:,1] = 1
    coords[3:,1]+= ( dists[0,2:]**2 - coords[3:,0]**2 - dists[2,2:]**2 + (coords[3:,0] - coords[2,0])**2 ) / (dists[0,1]**2 - coords[2,0]**2)
    coords[3:,1]/= 2
    
    # Give bead 3 a positive z value 
    coords[3,2] = ( dists[0,2]**2 - coords[3,:2].square().sum() ).sqrt() 
    
    # Get other z coordinates
    coords[4:,2] = dists[0,3:]**2 - dists[3,3:]**2
    coords[4:,2]+= (coords[4:,:2] - coords[3,:2]).square().sum(1) - coords[4:,:2].square().sum(1)
    coords[4:,2]/= coords[3,2]**2
    coords[4:,2]+= 1
    coords[4:,2]/= 2
    
    return coords
    

In [None]:
class DistLoss(torch.nn.Module):
    '''
    L1 mismatch between reference distance maps and the pairwise
    distances of (optionally batched) coordinates.

    The upper triangle of the reference map is extracted once at
    construction. The bead-pair indices used on the coordinate side are
    stored in `self.triu_indices`.
    '''

    def __init__(self,dists,self_interaction_included=True):
        super().__init__()

        # Offset 1 drops the main diagonal (self-interaction),
        # offset 0 keeps it.
        diag_offset = int(self_interaction_included)
        n_rows,n_cols = dists.shape[-2],dists.shape[-1]
        rows,cols = torch.triu_indices(n_rows,n_cols,diag_offset)
        self.dists = dists[...,rows,cols].squeeze()

        # A map without self-interaction has one row fewer than there
        # are beads, hence the size correction.
        n_beads = n_rows + 1 - diag_offset
        self.triu_indices = torch.triu_indices(n_beads,n_beads,1)

    def get_dists(self,coords):
        '''Condensed pairwise distances of `coords` (see torch pdist).'''
        return torch.nn.functional.pdist(coords)

    def bond_strength(self):
        '''
        could add molecular interactions as well
        '''

    def forward(self,coords):
        '''Sum of absolute differences between reference and actual distances.'''
        rows,cols = self.triu_indices
        pairwise = torch.cdist(coords,coords)
        residual = self.dists - pairwise[...,rows,cols]
        return residual.abs().sum()
    


def _adjust_coords_(
    coords,
    dists,
    n_it=1000,
    lr=0.01
):
    '''
    Refine an initial coordinate guess so that its pairwise distances
    match the distance map predicted by the diffusion model.

    `coords` is optimised in place with Adam for `n_it` steps at
    learning rate `lr`, and the same tensor is returned with gradient
    tracking switched off again.
    '''

    criterion = DistLoss(dists,False)

    coords.requires_grad_(True)
    adam = torch.optim.Adam([coords],lr=lr)

    for _step in range(n_it):
        adam.zero_grad()
        mismatch = criterion(coords)
        mismatch.backward()
        adam.step()

    coords.requires_grad_(False)

    return coords

def get_coords(sample):
    '''
    Build 3D coordinates for every distance map in `sample`.

    Each map is first converted analytically with `_dist_to_coord_`,
    then the whole stack is refined against `sample.batch` with
    `_adjust_coords_`. The result has shape
    (len(sample), sample.seg_len+1, 3) and is cast to
    `sample.batch_dtype`.
    '''

    # Storage for the per-map initial guesses
    shape = (len(sample),sample.seg_len+1,3)
    coords = torch.empty(*shape,dtype=sample.dtype,device=sample.device)

    for idx in range(len(sample)):
        coords[idx,...] = _dist_to_coord_(sample,idx)

    # Refine the coordinates
    refined = _adjust_coords_(coords,sample.batch)

    return refined.to(sample.batch_dtype)


In [None]:
class DistLoss(torch.nn.Module):
    '''
    L1 loss between reference distance maps and the pairwise distances
    of (optionally batched) coordinates, with optional trimming of the
    largest errors.

    Parameters
    ----------
    dists : torch.Tensor
        Reference distances; the last two dimensions form the map.
    self_interaction_included : bool
        True  -> the main diagonal holds self-distances and is skipped.
        False -> the map has no self-interaction entries; the diagonal
                 is kept and the coordinates are expected to hold one
                 more bead than the map has rows.
    '''

    def __init__(self,dists,self_interaction_included=True):
        super().__init__()

        # ignore diagonal if it's self-interaction
        i,j = torch.triu_indices(dists.shape[-2],dists.shape[-1],
                                               int(self_interaction_included))
        self.dists = dists[...,i,j].squeeze()

        # Bead-pair indices (i < j) used to read the matching entries
        # out of the full pairwise-distance matrix of the coordinates.
        self.triu_indices = torch.triu_indices(
            dists.shape[-2]+1-int(self_interaction_included),
            dists.shape[-2]+1-int(self_interaction_included),
            1
        )

    def get_dists(self,coords):
        '''Condensed pairwise distances of `coords` (see torch pdist).'''
        return torch.nn.functional.pdist(coords)

    def bond_strength(self):
        '''
        could add molecular interactions as well
        '''

    def forward(self,coords,drop_percentile=None):
        '''
        Sum of absolute distance errors.

        Parameters
        ----------
        coords : torch.Tensor
            Coordinates whose pairwise distances are compared with the
            reference map.
        drop_percentile : float or None
            Percentage, in [0, 100], of the largest errors (pooled over
            everything passed in) to discard before summing. For
            example, 1 drops the worst 1 % of the terms. None keeps
            every term.

        Raises
        ------
        ValueError
            If `drop_percentile` lies outside [0, 100].
        '''
        i,j = self.triu_indices

        errs = (self.dists - torch.cdist(coords,coords)[...,i,j] ).abs()
        if drop_percentile is not None:
            if not 0 <= drop_percentile <= 100:
                raise ValueError(
                    f'drop_percentile must be within [0, 100], got {drop_percentile}'
                )
            # Remove outliers: sort ascending and keep the smallest
            # (100 - drop_percentile) percent of the errors.
            errs,_ = errs.flatten().sort()
            n_keep = int((100 - drop_percentile)*len(errs)/100)
            errs = errs[:n_keep]

        return errs.sum()
    


def _adjust_coords_(
    coords,
    dists,
    n_it=1000,
    lr=0.01,
    drop_percentile=None
):
    '''
    Refine an initial coordinate guess so that its pairwise distances
    match the distance map predicted by the diffusion model.

    `coords` is optimised in place with Adam for `n_it` steps at
    learning rate `lr`. `drop_percentile` is forwarded unchanged to the
    loss on every step. The same tensor is returned with gradient
    tracking switched off again.
    '''

    criterion = DistLoss(dists,False)

    coords.requires_grad_(True)
    adam = torch.optim.Adam([coords],lr=lr)

    for _step in range(n_it):
        adam.zero_grad()
        mismatch = criterion(coords,drop_percentile)
        mismatch.backward()
        adam.step()

    coords.requires_grad_(False)

    return coords

def get_coords(sample,drop_percentile=None):
    '''
    Build 3D coordinates for every distance map in `sample`.

    Each map is first converted analytically with `_dist_to_coord_`,
    then the whole stack is refined against `sample.batch` with
    `_adjust_coords_` (which receives `drop_percentile`). The result has
    shape (len(sample), sample.seg_len+1, 3) and is cast to
    `sample.batch_dtype`.
    '''

    # Storage for the per-map initial guesses
    shape = (len(sample),sample.seg_len+1,3)
    coords = torch.empty(*shape,dtype=sample.dtype,device=sample.device)

    for idx in range(len(sample)):
        coords[idx,...] = _dist_to_coord_(sample,idx)

    # Refine the coordinates
    refined = _adjust_coords_(coords,sample.batch,drop_percentile=drop_percentile)

    return refined.to(sample.batch_dtype)


In [None]:
# Disabled scratch code, kept as a bare string literal so it never runs.
# NOTE(review): it calls `_dist_to_coord_neighbors_unknown_`, which is not
# defined in the visible part of this notebook, and unpacks two return
# values from `_adjust_coords_`.
'''
s = 0 
coords = _dist_to_coord_neighbors_unknown_(sample,s) 
dists = sample.batch[s,0,...]
coords2, loss_final = _adjust_coords_(coords,dists)#,lr=.01)
''';

In [None]:
# Index of the structure that the plotting cells below inspect
s = 0
# Analytical (unrefined) coordinates for structure `s` only
coords = _dist_to_coord_(sample,s) 
# Move the sample to the GPU (presumably in place: the return value is unused)
sample.cuda()
# Analytical guess + refinement for every structure in the sample
coords2 = get_coords(sample,drop_percentile=1)

In [None]:
# Number of structures (first axis) that contain at least one NaN coordinate
coords2.isnan().any(1).any(1).sum()

In [None]:
# 3D view of `coords`, the analytical (unrefined) coordinates of one structure.
from mpl_toolkits.mplot3d import Axes3D
#fig = plt.figure()
#ax = Axes3D(fig)
ax = plt.axes(projection='3d')

# The string literal below is an earlier copy of the plotting code; it is
# never executed.
'''
# Data for a three-dimensional line
zline = coords[:,2].cpu().numpy()#np.linspace(0, 15, 1000)
xline = coords[:,0].cpu().numpy()#np.sin(zline)
yline = coords[:,1].cpu().numpy()#np.cos(zline)
ax.plot3D(xline, yline, zline, 'gray')

# Data for three-dimensional scattered points
#zdata = 15 * np.random.random(100)
#xdata = np.sin(zdata) + 0.1 * np.random.randn(100)
#ydata = np.cos(zdata) + 0.1 * np.random.randn(100)
ax.scatter3D(xline, yline, zline, c=zline, cmap='Greens');
''';

# Data for a three-dimensional line: one array per axis, beads joined in index order
zline = coords[:,2].cpu().numpy()#np.linspace(0, 15, 1000)
xline = coords[:,0].cpu().numpy()#np.sin(zline)
yline = coords[:,1].cpu().numpy()#np.cos(zline)
ax.plot3D(xline, yline, zline, 'gray')

# Data for three-dimensional scattered points, coloured by z value
#ax.scatter(xline, yline, zline, c=zline, cmap='Greens');
ax.scatter3D(xline, yline, zline, c=zline, cmap='Greens');

#plt.show()

In [None]:
# 3D view of structure `s` taken from `coords2`, the output of get_coords.
from mpl_toolkits.mplot3d import Axes3D
#fig = plt.figure()
#ax = Axes3D(fig)
ax = plt.axes(projection='3d')

# The string literal below is an earlier copy of the plotting code; it is
# never executed.
'''
# Data for a three-dimensional line
zline = coords[:,2].cpu().numpy()#np.linspace(0, 15, 1000)
xline = coords[:,0].cpu().numpy()#np.sin(zline)
yline = coords[:,1].cpu().numpy()#np.cos(zline)
ax.plot3D(xline, yline, zline, 'gray')

# Data for three-dimensional scattered points
#zdata = 15 * np.random.random(100)
#xdata = np.sin(zdata) + 0.1 * np.random.randn(100)
#ydata = np.cos(zdata) + 0.1 * np.random.randn(100)
ax.scatter3D(xline, yline, zline, c=zline, cmap='Greens');
''';

# Data for a three-dimensional line: one array per axis, beads joined in index order
zline = coords2[s,:,2].cpu().numpy()#np.linspace(0, 15, 1000)
xline = coords2[s,:,0].cpu().numpy()#np.sin(zline)
yline = coords2[s,:,1].cpu().numpy()#np.cos(zline)
ax.plot3D(xline, yline, zline, 'gray')

# Data for three-dimensional scattered points, coloured by z value
#ax.scatter(xline, yline, zline, c=zline, cmap='Greens');
ax.scatter3D(xline, yline, zline, c=zline, cmap='Greens');

#plt.show()

In [None]:
# Keep only the structures without any NaN coordinate
coords2 = coords2[~coords2.isnan().any(1).any(1),...]
# Bead pairs (i < j) of the full pairwise-distance matrix
i,j = torch.triu_indices(coords2.shape[-2],coords2.shape[-2],1)

b = coords2.shape[0]   # number of structures kept
c = coords2.shape[-2]  # number of beads per structure
# Distance maps without the self-interaction diagonal: entry [i, j-1]
# stores the distance between bead i and bead j, so each map is (c-1, c-1).
d = torch.empty(b,c-1,c-1,dtype=coords2.dtype,device=coords2.device)
d[...,i,j-1] = torch.cdist(coords2,coords2)[...,i,j]
# Mirror the filled upper triangle into the lower triangle
d[...,j-1,i] = d[...,i,j-1]

# Wrap the maps in a new Sample, inserting a singleton axis at position 1
sample1 = Sample(data = d.unsqueeze(1))

In [None]:
# Contact-probability map of the sample rebuilt from the refined coordinates
hic_map = sample1.contact_probabilities()

In [None]:
# Heat map of `hic_map` on a logarithmic colour scale capped at 1.
norm = LogNorm(vmax=1)

fig, ax = plt.subplots()

im = ax.matshow(hic_map.cpu().numpy(), cmap='fall', norm=norm)
# (an `extent=(region[1], region[2], region[2], region[1])` argument was left disabled)
ax.xaxis.set_visible(False)

cbar = fig.colorbar(im, location='right', label='Interaction Frequencies');

In [None]:
def get_hic2(sample,r_c=2,sigma=3):
    '''
    Turn the distance maps in `sample.batch` into contact probabilities
    and average them over the first axis.

    Distances below the cutoff `r_c` map to 0.5*(1 + tanh(sigma*(r_c - r))).
    All other distances map to 0.5*(r_c/r)**4. `sample.batch` itself is
    not modified (a clone is used), but `sample.unnormalize_()` is called.
    '''

    sample.unnormalize_()

    probs = sample.batch.clone() # starts out holding the distances

    # Decide both regimes up front, before any value is overwritten
    near = probs < r_c
    far = ~near

    probs[near] = ( torch.tanh( ( r_c - probs[near] )*sigma ) + 1 ) / 2
    probs[far] = ( r_c / probs[far] )**4 / 2

    averaged = probs.mean(0)
    return averaged.squeeze()

    
    

In [None]:
# Heat map of get_hic2(sample) on a logarithmic colour scale capped at 1.
norm = LogNorm(vmax=1)

fig, ax = plt.subplots()

im = ax.matshow(get_hic2(sample).cpu().numpy(), cmap='fall', norm=norm)
# (an `extent=(region[1], region[2], region[2], region[1])` argument was left disabled)
ax.xaxis.set_visible(False)

cbar = fig.colorbar(im, location='right', label='Interaction Frequencies');

#### Validate the implementation of the above inside the class

In [None]:
import torch
import sys
import pandas as pd
import numpy as np
#%matplotlib widget
import matplotlib.pyplot as plt

from matplotlib.colors import LogNorm
import cooltools.lib.plotting # provides 'fall' colormap

sys.path.insert(1,'../data_utils_v2/')
from Sample import Sample

In [None]:
# Rebuild a Sample from the `batch` attribute of the pickled object.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
sample = Sample(data=pd.read_pickle('../diffusion/sampling_small/sample_0_1_1_38_X.pkl').batch)
#sample = pd.read_pickle('../diffusion/results_large/generated_hic_map.pkl')['sample'];

In [None]:
# Move the sample to the GPU (presumably in place: the return value is unused)
sample.cuda()
# Contact probabilities from the class method, with explicit r_c and sigma
hic_map = sample.contact_probabilities(r_c=1.76,sigma=3.72)

In [None]:
# Heat map of `hic_map` on a logarithmic colour scale capped at 1.
norm = LogNorm(vmax=1)

fig, ax = plt.subplots()

im = ax.matshow(hic_map.cpu().numpy(), cmap='fall', norm=norm)
# (an `extent=(region[1], region[2], region[2], region[1])` argument was left disabled)
ax.xaxis.set_visible(False)

cbar = fig.colorbar(im, location='right', label='Interaction Frequencies');