if __name__ == '__main__':
    # Smoke test of the classifier-free-guidance API: build a class-conditional
    # Unet, wrap it in GaussianDiffusion, take one training step, sample with
    # guidance, then interpolate.
    num_classes = 10
    batch_size = 8
    image_size = 128

    # Fall back to CPU so the demo also runs on machines without a GPU
    # (the original hard-coded .cuda() crashed there).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = Unet(
        dim = 64,
        dim_mults = (1, 2, 4, 8),
        num_classes = num_classes,
        cond_drop_prob = 0.5
    )

    diffusion = GaussianDiffusion(
        model,
        image_size = image_size,
        timesteps = 1000
    ).to(device)

    # Images are expected to be normalized to [0, 1]; torch.rand samples
    # uniformly from [0, 1), whereas torch.randn would give unbounded,
    # zero-centred values that contradict that contract.
    training_images = torch.rand(batch_size, 3, image_size, image_size, device = device)
    image_classes = torch.randint(0, num_classes, (batch_size,), device = device)    # say 10 classes

    loss = diffusion(training_images, classes = image_classes)
    loss.backward()

    # do above for many steps (with an optimizer step) to actually train

    sampled_images = diffusion.sample(
        classes = image_classes,
        cond_scale = 6.                # condition scaling, anything greater than 1 strengthens the classifier free guidance. reportedly 3-8 is good empirically
    )

    # One sample per requested class label, same shape as the training images.
    assert sampled_images.shape == (batch_size, 3, image_size, image_size)

    # interpolation

    interpolate_out = diffusion.interpolate(
        training_images[:1],
        training_images[:1],
        image_classes[:1]
    )

In [1]:
import torch
from classifier_free_guidance import Unet, GaussianDiffusion
import sys
sys.path.insert(1, '../data_utils/')
from DataLoader import DataLoader

In [2]:
# Number of beads per segment. The image side length is nbeads - 1 because
# self-interactions (the diagonal) are ignored.
nbeads = 65
two_channels = True
batch_size = 64

# nbeads choose 2, i.e. the count of unordered pairs of distinct beads.
# NOTE(review): confirm this pair count is the intended number of class labels.
num_classes = (nbeads * (nbeads - 1)) // 2

In [3]:
# HDF5 file with the processed data; the relative path is resolved against the
# current working directory (i.e. where the notebook kernel was started).
dataset_filepath = '../../data/processed_data.hdf5'

In [4]:
# Train on the first 16 cells (except 8, which doesn't have a structure).
# Test on cell 17, which range(1, 17) deliberately excludes.
train_cells = [k for k in range(1, 17) if k != 8]

In [5]:
# Pick the compute device: the current CUDA device when CUDA is usable,
# otherwise the CPU. An explicit availability check replaces the old bare
# `except:`, which also swallowed KeyboardInterrupt/SystemExit and hid
# unrelated errors.
if torch.cuda.is_available():
    # Same device that `torch.empty(1).cuda()` would land on.
    device = torch.device('cuda', torch.cuda.current_device())
else:
    device = torch.device('cpu')

In [6]:
# Training-data loader over the HDF5 dataset, restricted to train_cells.
# Filters passed as None presumably mean "no restriction" on that field —
# NOTE(review): confirm against the DataLoader signature in data_utils.
dl = DataLoader(
    dataset_filepath,
    # --- which records to draw from ---
    geos=None,
    organisms=None,
    cell_types=None,
    cell_numbers=train_cells,
    chroms=None,
    replicates=None,
    # --- how samples are cut and batched ---
    segment_length=nbeads,
    batch_size=batch_size,
    shuffle=True,
    allow_overlap=False,
    # --- preprocessing / placement ---
    normalize_distances=True,
    two_channels=two_channels,
    try_GPU=True,
    # --- precomputed statistics files ---
    mean_dist_fp='../../data/mean_dists.pt',
    mean_sq_dist_fp='../../data/squares.pt',
)

In [7]:
# Class-conditional Unet; the input channel count is 1, plus 1 more when
# two_channels is set. cond_drop_prob presumably controls how often the class
# condition is dropped during training (classifier-free guidance) — confirm
# in classifier_free_guidance.Unet.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=int(two_channels) + 1,
    num_classes=num_classes,
    cond_drop_prob=0.5,
)

In [8]:
# Wrap the Unet in the diffusion process. Images are (nbeads - 1) on a side
# because self-interactions are dropped; 1000 diffusion timesteps.
diffusion = GaussianDiffusion(model, image_size=nbeads - 1, timesteps=1000)
diffusion = diffusion.to(device)

In [9]:
# Notebook echo of the loader's coordinate metadata table: one row per
# accession/organism/cell type/cell/replicate/chromosome, with its
# idx_min..idx_max index range.
dl.coord_info

Unnamed: 0,Accession,Organism,Cell_Type,Cell,Replicate,Chromosome,idx_min,idx_max
0,GSE117876,Human,GM12878,1,0,1,0,10760
1,GSE117876,Human,GM12878,1,0,2,10761,22431
2,GSE117876,Human,GM12878,1,0,3,22432,32136
3,GSE117876,Human,GM12878,1,0,4,32137,41468
4,GSE117876,Human,GM12878,1,0,5,41469,50199
...,...,...,...,...,...,...,...,...
1009,GSE117876,Human,GM12878,16,2,19,5962405,5965142
1010,GSE117876,Human,GM12878,16,2,20,5965143,5968101
1011,GSE117876,Human,GM12878,16,2,21,5968102,5969679
1012,GSE117876,Human,GM12878,16,2,22,5969680,5971274


In [18]:
# Scratch tensor, shaped like a batch of two-channel 32x32 images, for
# checking flatten shapes.
a = torch.rand((10, 2, 32, 32))

In [19]:
# Collapse the last three dims (channels x height x width) into one:
# (10, 2, 32, 32) -> (10, 2 * 32 * 32) = (10, 2048).
a.flatten(start_dim=-3).shape

torch.Size([10, 2048])