In [2]:
import os
import sys
import copy, time
import progressbar
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, transforms

sys.path.append('../src/')
from multimodal_datasets import MultiModalDataset, MultiModalGridDataset

In [27]:
# Create multi-modal datasets

data_dir = '/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/'

train_tissues = ['151507', '151508', '151509', '151510', '151669', '151670', '151671', '151672', '151673', '151674']
val_tissues = ['151675', '151676']

countfiles_train = [os.path.join(data_dir, 'Countfiles_Visium_norm/%s_stdata_aligned_counts_IDs.txt.unified.tsv') % s for s in train_tissues]
annotfiles_train = [os.path.join(data_dir, 'Covariates_Visium/%s.tsv') % s for s in train_tissues]
imgfiles_train = [os.path.join(data_dir, 'maynard_patchdata_oddr/%s_full_image/') % s for s in train_tissues]

countfiles_val = [os.path.join(data_dir, 'Countfiles_Visium_norm/%s_stdata_aligned_counts_IDs.txt.unified.tsv') % s for s in val_tissues]
annotfiles_val = [os.path.join(data_dir, 'Covariates_Visium/%s.tsv') % s for s in val_tissues]
imgfiles_val = [os.path.join(data_dir, 'maynard_patchdata_oddr/%s_full_image/') % s for s in val_tissues]

# Joana's manually curated list of layer marker genes:
jp_markers = {
    'MBP': 'ENSG00000197971',    # WM
    'SNAP25': 'ENSG00000132639', # GM (Layers 1-6)
    'PCP4': 'ENSG00000183036',   # Layer 5
    'RORB': 'ENSG00000198963',   # Layer 4
    'SYNPR': 'ENSG00000163630',  # Layer 6
    'MFGE8': 'ENSG00000140545',
    'CBLN2': 'ENSG00000141668',
    'RPRM': 'ENSG00000177519',
    'NR4A2': 'ENSG00000153234',
    'CXCL14': 'ENSG00000145824',
    'C1QL2': 'ENSG00000144119',
    'CUX2': 'ENSG00000111249',
    'CARTPT': 'ENSG00000164326',
    'CCK': 'ENSG00000187094'
    }
# Note that CountDataset (extended by MultiModalDataset) orders selected genes based on row order in count file.
# This corresponds to alphanumberical ordering in our case, so we can sort by ENSEMBL ID to get same ordering.
select_genes = sorted([ensmbl for _, ensmbl in jp_markers.items()])

# Preprocessing transform for image data
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

mmp_train = MultiModalDataset(countfiles_train, imgfiles_train, annotfiles_train, 
                              select_genes=select_genes, img_transforms=preprocess)
mmp_val = MultiModalDataset(countfiles_val, imgfiles_val, annotfiles_val, 
                            select_genes=select_genes, img_transforms=preprocess)

/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 0_62 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 2_64 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 55_55 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 0_26 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 57_77 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 33_77 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 21_77 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 0_38 no image data
/Users/adaly/Documents/Splotch_proje

/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 51_77 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 0_16 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 73_77 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 0_48 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 15_77 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 119_77 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 101_77 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151510_full_image/ 77_77 no image data
/Users/adaly/Documents/Splotch_p

/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 126_14 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 124_46 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 116_22 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 123_25 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 118_0 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 124_50 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 125_13 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 30_4 no image data
/Users/adaly/Documents/Splo

/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 110_66 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 99_67 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 32_6 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 127_13 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 90_68 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 126_18 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 124_56 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151672_full_image/ 122_32 no image data
/Users/adaly/Documents/Splot

/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 40_70 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 110_24 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 119_51 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 117_49 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 111_61 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 119_49 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 112_20 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 117_37 no image data
/Users/adaly/Documents/Sp

/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 115_55 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 114_30 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 111_13 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 76_68 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 44_72 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 116_54 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 112_24 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151673_full_image/ 31_3 no image data
/Users/adaly/Documents/Splot

/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 117_55 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 100_6 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 13_11 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 110_24 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 66_72 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 117_49 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 111_61 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 27_7 no image data
/Users/adaly/Documents/Splotc

/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 67_71 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 94_0 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 99_5 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 116_54 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 68_72 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 112_24 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 112_62 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151674_full_image/ 109_21 no image data
/Users/adaly/Documents/Splotch

/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 47_69 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 124_50 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 49_1 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 85_65 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 49_69 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 52_68 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 121_45 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 104_6 no image data
/Users/adaly/Documents/Splotch_

/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 74_66 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 116_24 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 78_66 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 111_27 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 107_9 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 111_11 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 122_42 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 107_13 no image data
/Users/adaly/Documents/Splo

/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 118_40 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 53_1 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 117_25 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 122_54 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 124_48 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 113_61 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 24_68 no image data
/Users/adaly/Documents/Splotch_projects/Maynard_DLPFC/data/maynard_patchdata_oddr/151676_full_image/ 119_47 no image data
/Users/adaly/Documents/Splo

In [28]:
# Construct a network that takes an image patch and outputs an expression vector

x_count, x_img, y = mmp_train[0]
n_genes = x_count.shape[0]

# load a pre-trained DenseNet121
densenet = models.densenet121(True)

# Change final layer to predict gene dimension; Sigmoid activation ensures values fall in normalized 0,1 range.
densenet.classifier = nn.Sequential(
    nn.Linear(1024, n_genes),
    nn.Sigmoid()
)

pred = densenet(torch.unsqueeze(x_img,0))
print(pred)

tensor([[0.2759, 0.5301, 0.4844, 0.4730, 0.5630, 0.5016, 0.5452, 0.4289, 0.3434,
         0.5279, 0.6469, 0.4337, 0.6727, 0.5018]], grad_fn=<SigmoidBackward>)


In [42]:
def train_selfsupervised(model, dataloaders, criterion, optimizer, num_epochs, 
                         outfile=None, display=None):
    since = time.time()
    train_loss_hist, val_loss_hist = [],[]
    best_loss = np.inf
    
    # GPU support
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1), flush=True)
        print('-' * 10, flush=True)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            if display:
                iterator = progressbar.progressbar(dataloaders[phase])
            else:
                iterator = dataloaders[phase]
            
            running_loss = 0

            for x_count, x_img, y in iterator:
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase=='train'):
                    counts_pred = model(x_img)
                    loss = criterion(x_count, counts_pred)
                                        
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                running_loss += loss.item()
            epoch_loss = running_loss / len(dataloaders[phase])
            
            if phase == 'val':
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    
                    if outfile is not None:
                        torch.save(model.state_dict(), outfile)
                val_loss_hist.append(val_loss)
            else:
                train_loss_hist.append(train_loss)
    
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60), flush=True)
    print('Best val loss: {:4f}'.format(best_loss), flush=True)

mmp_dataloaders = {
    'train': DataLoader(mmp_train, batch_size=32, shuffle=True),
    'val': DataLoader(mmp_val, batch_size=32, shuffle=False)
}

criterion = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(densenet.parameters(), lr=1e-4)
    
train_selfsupervised(densenet, mmp_dataloaders, criterion, optimizer, 10)

Epoch 0/9
----------
tensor(0.1294, grad_fn=<MseLossBackward>)
tensor(0.1283, grad_fn=<MseLossBackward>)
tensor(0.1258, grad_fn=<MseLossBackward>)
tensor(0.1157, grad_fn=<MseLossBackward>)


KeyboardInterrupt: 