In [1]:
from pathlib import Path
from anndata import read_h5ad
import sys
import scanpy as sc
import os
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

In [2]:
DATA_PATH = Path("../data/dataset_breast_cancer_9visium")

In [3]:
(DATA_PATH / "all_adata.h5ad").exists()

True

In [4]:
adata_all = read_h5ad(DATA_PATH / "all_adata.h5ad")

In [5]:
adata_all

AnnData object with n_obs × n_vars = 24578 × 14664
    obs: 'in_tissue', 'array_row', 'array_col', 'imagecol', 'imagerow', 'tile_tissue_mask_path', 'tissue_area', 'tile_path', 'library_id'
    var: 'gene_ids-FFPE', 'feature_types-FFPE', 'genome-FFPE', 'gene_ids-block1', 'feature_types-block1', 'genome-block1', 'gene_ids-block2', 'feature_types-block2', 'genome-block2'
    uns: 'spatial'
    obsm: 'spatial'

In [6]:
# adata_all.uns['spatial']['1160920F']

In [23]:
adata_all_train_valid = adata_all[adata_all.obs["library_id"].isin(
    adata_all.obs.library_id.cat.remove_categories(["FFPE", "1160920F"]).unique())]

In [8]:
adata_all_train_valid

View of AnnData object with n_obs × n_vars = 17457 × 14664
    obs: 'in_tissue', 'array_row', 'array_col', 'imagecol', 'imagerow', 'tile_tissue_mask_path', 'tissue_area', 'tile_path', 'library_id'
    var: 'gene_ids-FFPE', 'feature_types-FFPE', 'genome-FFPE', 'gene_ids-block1', 'feature_types-block1', 'genome-block1', 'gene_ids-block2', 'feature_types-block2', 'genome-block2'
    uns: 'spatial'
    obsm: 'spatial'

In [26]:
training_index = adata_all_train_valid.obs.sample(frac=0.7, random_state=1).index
training_dataset = adata_all_train_valid[training_index,].copy()


In [27]:
training_dataset

AnnData object with n_obs × n_vars = 12220 × 14664
    obs: 'in_tissue', 'array_row', 'array_col', 'imagecol', 'imagerow', 'tile_tissue_mask_path', 'tissue_area', 'tile_path', 'library_id'
    var: 'gene_ids-FFPE', 'feature_types-FFPE', 'genome-FFPE', 'gene_ids-block1', 'feature_types-block1', 'genome-block1', 'gene_ids-block2', 'feature_types-block2', 'genome-block2'
    uns: 'spatial'
    obsm: 'spatial'

In [7]:
gene_list=["COX6C","TTLL12", "PABPC1", "GNAS", "HSP90AB1", "TFF3", "ATP1A1", "B2M", "FASN", "SPARC", "CD74", "CD63", "CD24", "CD81"]

In [46]:
test = DataGenerator(adata=training_dataset, genes=gene_list, aug=False)

In [53]:
test._load_label(test.adata.obs_names[0])

(array([6.0063534], dtype=float32),
 array([1.3862944], dtype=float32),
 array([3.218876], dtype=float32),
 array([2.6390574], dtype=float32),
 array([3.6888795], dtype=float32),
 array([3.0910425], dtype=float32),
 array([1.609438], dtype=float32),
 array([4.8598123], dtype=float32),
 array([3.3322046], dtype=float32),
 array([2.3025851], dtype=float32),
 array([4.1108737], dtype=float32),
 array([2.8332133], dtype=float32),
 array([2.5649493], dtype=float32),
 array([2.4849067], dtype=float32))

In [54]:
test.adata.obs_names[0]

'GTATTCTTACCGTGCT-1-block2'

In [60]:
adata_all.obs

Unnamed: 0,in_tissue,array_row,array_col,imagecol,imagerow,tile_tissue_mask_path,tissue_area,tile_path,library_id
GATAAGGGACGATTAG-1-1142243F,1,1,3,12601,4511,/tmp/1142243F_tissue_mask/1142243F-12601-4511-...,0.733437,/clusterdata/uqxtan9/Xiao/breast_cancer_9visiu...,1142243F
TGTTGGCTGGCGGAAG-1-1142243F,1,1,5,12872,4512,/tmp/1142243F_tissue_mask/1142243F-12872-4512-...,0.878391,/clusterdata/uqxtan9/Xiao/breast_cancer_9visiu...,1142243F
GCGAGGGACTGCTAGA-1-1142243F,1,1,7,13144,4513,/tmp/1142243F_tissue_mask/1142243F-13144-4513-...,0.884632,/clusterdata/uqxtan9/Xiao/breast_cancer_9visiu...,1142243F
GCGCGTTTAAATCGTA-1-1142243F,1,1,9,13416,4514,/tmp/1142243F_tissue_mask/1142243F-13416-4514-...,0.813425,/clusterdata/uqxtan9/Xiao/breast_cancer_9visiu...,1142243F
ATCTATCGATGATCAA-1-1142243F,1,3,3,12599,4984,/tmp/1142243F_tissue_mask/1142243F-12599-4984-...,0.879218,/clusterdata/uqxtan9/Xiao/breast_cancer_9visiu...,1142243F
...,...,...,...,...,...,...,...,...,...
TTGTTCAGTGTGCTAC-1-FFPE,1,24,64,14428,9698,/tmp/FFPE_tissue_mask/FFPE-14428-9698-299.jpeg,0.967338,/clusterdata/uqxtan9/Xiao/breast_cancer_9visiu...,FFPE
TTGTTGTGTGTCAAGA-1-FFPE,1,31,77,16312,11467,/tmp/FFPE_tissue_mask/FFPE-16312-11467-299.jpeg,0.975951,/clusterdata/uqxtan9/Xiao/breast_cancer_9visiu...,FFPE
TTGTTTCACATCCAGG-1-FFPE,1,58,42,11229,18277,/tmp/FFPE_tissue_mask/FFPE-11229-18277-299.jpeg,0.766479,/clusterdata/uqxtan9/Xiao/breast_cancer_9visiu...,FFPE
TTGTTTCATTAGTCTA-1-FFPE,1,60,30,9488,18780,/tmp/FFPE_tissue_mask/FFPE-9488-18780-299.jpeg,0.973043,/clusterdata/uqxtan9/Xiao/breast_cancer_9visiu...,FFPE


In [73]:
idxs = adata_all.obs.loc[adata_all.obs['library_id'] == '1142243F'].index
print(len(idxs))
adata_all[idxs, gene_list].X.todense()

4704


matrix([[2.0794415, 0.       , 2.5649493, ..., 1.3862944, 2.4849067,
         0.6931472],
        [2.4849067, 0.       , 2.3025851, ..., 1.7917595, 1.0986123,
         1.609438 ],
        [2.7080503, 1.0986123, 2.944439 , ..., 1.7917595, 2.1972246,
         1.3862944],
        ...,
        [2.3025851, 0.       , 2.944439 , ..., 1.609438 , 1.609438 ,
         2.3978953],
        [2.6390574, 0.6931472, 2.5649493, ..., 1.7917595, 2.3025851,
         2.3978953],
        [2.3025851, 0.       , 3.218876 , ..., 2.7080503, 1.9459102,
         2.0794415]], dtype=float32)

In [83]:
adata_all.raw

In [84]:
type(adata_all.raw)

NoneType

In [3]:
import torch

class DataGenerator(torch.utils.data.Dataset):
    """
    data generator for multiple branches gene prediction model
    """

    def __init__(self, adata, dim=(299, 299), n_channels=3, genes=None, aug=False, tile_path="tile_path"):
        'Initialization'
        self.dim = dim
        self.adata = adata
        self.n_channels = n_channels
        self.genes = genes
        self.num_genes = len(genes)
        self.aug = aug
        self.tile_path = tile_path
        self.indexes = np.arange(self.adata.n_obs)
        
    def __len__(self):
        'Denotes the number of samples'
        return int(self.adata.n_obs)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Find obs name
        obs_temp = self.adata.obs_names[index]

        # Generate data
        X_img, y = self._load_data(obs_temp)

        return torch.Tensor(X_img), torch.Tensor(y)

    def _load_data(self, obs):
        img_path = self.adata.obs.loc[obs, 'tile_path']
        X_img = Image.open(img_path).convert('RGB')
        X_img = transforms.Resize(self.dim)(X_img)
        X_img = np.array(X_img).astype('uint8')
        #         X_img = np.expand_dims(X_img, axis=0)
        #         n_rotate = np.random.randint(0, 4)
        #         X_img = np.rot90(X_img, k=n_rotate, axes=(1, 2))
        if self.aug:
            X_img = seq_aug(image=X_img)
        y = self._load_label(obs)
        return X_img, y

    def _load_label(self, obs):
        batch_adata = self.adata[obs, self.genes].copy()

        return tuple([batch_adata.to_df()[i].values for i in self.genes])

    def get_classes(self):
        return self.adata.to_df().loc[:, self.genes]


In [2]:
import os
import glob
import torch
import torchvision
import numpy as np
import scanpy as sc
import pandas as pd 
import scprep as scp
import anndata as ad
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import ImageFile, Image
sys.path.append("../models/Hist2ST/")
from utils import read_tiff, get_data
from graph_construction import calcADJ
from collections import defaultdict as dfd
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None


In [3]:
class ViT_SKIN(torch.utils.data.Dataset):
    """Some Information about ViT_SKIN"""
    def __init__(self,train=True,r=4,norm=False,fold=0,flatten=True,ori=False,adj=False,prune='NA',neighs=4):
        super(ViT_SKIN, self).__init__()

        self.dir = '../data/GSE144240_RAW/'
        self.r = 224//r

        patients = ['P2', 'P5', 'P9', 'P10']
        reps = ['rep1', 'rep2', 'rep3']
        names = []
        for i in patients:
            for j in reps:
                names.append(i+'_ST_'+j)
        gene_list = list(np.load('../models/Hist2ST/data/skin_hvg_cut_1000.npy',allow_pickle=True))

        self.ori = ori
        self.adj = adj
        self.norm = norm
        self.train = train
        self.flatten = flatten
        self.gene_list = gene_list
        samples = names
        te_names = [samples[fold]]
        tr_names = list(set(samples)-set(te_names))

        if train:
            self.names = tr_names
        else:
            self.names = te_names

        print(te_names)
        print('Loading imgs...')
        self.img_dict = {i:torch.Tensor(np.array(self.get_img(i))) for i in self.names}
        print('Loading metadata...')
        self.meta_dict = {i:self.get_meta(i) for i in self.names}

        self.gene_set = list(gene_list)
        if self.norm:
            self.exp_dict = {
                i:sc.pp.scale(scp.transform.log(scp.normalize.library_size_normalize(m[self.gene_set].values)))
                for i,m in self.meta_dict.items()
            }
        else:
            self.exp_dict = {
                i:scp.transform.log(scp.normalize.library_size_normalize(m[self.gene_set].values)) 
                for i,m in self.meta_dict.items()
            }
        if self.ori:
            self.ori_dict = {i:m[self.gene_set].values for i,m in self.meta_dict.items()}
            self.counts_dict={}
            for i,m in self.ori_dict.items():
                n_counts=m.sum(1)
                sf = n_counts / np.median(n_counts)
                self.counts_dict[i]=sf
        self.center_dict = {
            i:np.floor(m[['pixel_x','pixel_y']].values).astype(int)
            for i,m in self.meta_dict.items()
        }
        self.loc_dict = {i:m[['x','y']].values for i,m in self.meta_dict.items()}
        self.adj_dict = {
            i:calcADJ(m,neighs,pruneTag=prune)
            for i,m in self.loc_dict.items()
        }
        self.patch_dict=dfd(lambda :None)
        self.lengths = [len(i) for i in self.meta_dict.values()]
        self.cumlen = np.cumsum(self.lengths)
        self.id2name = dict(enumerate(self.names))


    def filter_helper(self):
        a = np.zeros(len(self.gene_list))
        n = 0
        for i,exp in self.exp_dict.items():
            n += exp.shape[0]
            exp[exp>0] = 1
            for j in range((len(self.gene_list))):
                a[j] += np.sum(exp[:,j])


    def __getitem__(self, index):
        print("Index: ", index)
        ID=self.id2name[index]
        im = self.img_dict[ID].permute(1,0,2)

        exps = self.exp_dict[ID]
        if self.ori:
            oris = self.ori_dict[ID]
            sfs = self.counts_dict[ID]
        adj=self.adj_dict[ID]
        centers = self.center_dict[ID]
        loc = self.loc_dict[ID]
        patches = self.patch_dict[ID]
        positions = torch.LongTensor(loc)
        patch_dim = 3 * self.r * self.r * 4
        exps = torch.Tensor(exps)
        if patches is None:
            n_patches = len(centers)
            if self.flatten:
                patches = torch.zeros((n_patches,patch_dim))
            else:
                patches = torch.zeros((n_patches,3,2*self.r,2*self.r))

            for i in range(n_patches):
                center = centers[i]
                x, y = center
                patch = im[(x-self.r):(x+self.r),(y-self.r):(y+self.r),:]
                if self.flatten:
                    patches[i] = patch.flatten()
                else:
                    patches[i]=patch.permute(2,0,1)
            self.patch_dict[ID]=patches
        data=[patches, positions, exps]
        if self.adj:
            data.append(adj)
        if self.ori:
            data+=[torch.Tensor(oris),torch.Tensor(sfs)]
        data.append(torch.Tensor(centers))
        return data
        
    def __len__(self):
        return len(self.exp_dict)

    def get_img(self,name):
        path = glob.glob(self.dir+'*'+name+'.jpg')[0]
        im = Image.open(path)
        return im

    def get_cnt(self,name):
        path = glob.glob(self.dir+'*'+name+'_stdata.tsv')[0]
        df = pd.read_csv(path,sep='\t',index_col=0)
        return df

    def get_pos(self,name):
        path = glob.glob(self.dir+'*spot*'+name+'.tsv')[0]
        df = pd.read_csv(path,sep='\t')

        x = df['x'].values
        y = df['y'].values
        x = np.around(x).astype(int)
        y = np.around(y).astype(int)
        id = []
        for i in range(len(x)):
            id.append(str(x[i])+'x'+str(y[i])) 
        df['id'] = id

        return df

    def get_meta(self,name,gene_list=None):
        cnt = self.get_cnt(name)
        pos = self.get_pos(name)
        meta = cnt.join(pos.set_index('id'),how='inner')

        return meta

    def get_overlap(self,meta_dict,gene_list):
        gene_set = set(gene_list)
        for i in meta_dict.values():
            gene_set = gene_set&set(i.columns)
        return list(gene_set)

In [4]:
dataset = ViT_SKIN(
            train='train',fold=0,flatten=False,adj=True,ori=True,prune='NA'
        )

['P2_ST_rep1']
Loading imgs...
Loading metadata...


In [7]:
dataset.id2name

{0: 'P10_ST_rep1',
 1: 'P5_ST_rep2',
 2: 'P10_ST_rep2',
 3: 'P9_ST_rep2',
 4: 'P10_ST_rep3',
 5: 'P2_ST_rep2',
 6: 'P9_ST_rep3',
 7: 'P9_ST_rep1',
 8: 'P5_ST_rep3',
 9: 'P5_ST_rep1',
 10: 'P2_ST_rep3'}

In [5]:
train_loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True)

iterator=iter(train_loader)

# [print(iterator._next_index()) for i in range(12)]

NameError: name 'DataLoader' is not defined

In [6]:
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True)
iterator=iter(train_loader)

# for i in range(20):
#     next(iterator)
#     print(dataset.patch_dict.keys())

In [8]:
test = next(iterator)

Index:  7


In [9]:
[i.shape for i in test]

[torch.Size([1, 521, 3, 112, 112]),
 torch.Size([1, 521, 2]),
 torch.Size([1, 521, 171]),
 torch.Size([1, 521, 521]),
 torch.Size([1, 521, 171]),
 torch.Size([1, 521]),
 torch.Size([1, 521, 2])]

In [15]:
dataset.counts_dict['P5_ST_rep3'].dtype

dtype('float64')

In [18]:
dataset.patch_dict['P10_ST_rep3'].dtype

torch.float32

In [19]:
dataset.img_dict['P10_ST_rep3'].dtype

torch.float32

In [24]:
dataset.center_dict['P10_ST_rep3'].shape

(462, 2)

In [18]:
# dataset.patch_dict

In [13]:
len(dataset.names)

11

In [19]:
iterator._next_index()

StopIteration: 

In [12]:
dataset.names

['P9_ST_rep3',
 'P5_ST_rep1',
 'P9_ST_rep1',
 'P9_ST_rep2',
 'P10_ST_rep1',
 'P2_ST_rep2',
 'P5_ST_rep2',
 'P2_ST_rep3',
 'P10_ST_rep2',
 'P5_ST_rep3',
 'P10_ST_rep3']

In [None]:
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True)

In [7]:
iterator=iter(train_loader)
test = next(iterator)

Index:  4


In [21]:
dataset.img_dict.keys()

dict_keys(['P10_ST_rep1', 'P9_ST_rep2', 'P9_ST_rep3', 'P5_ST_rep3', 'P2_ST_rep1', 'P9_ST_rep1', 'P10_ST_rep3', 'P5_ST_rep1', 'P5_ST_rep2', 'P10_ST_rep2', 'P2_ST_rep3'])

In [19]:
dataset.img_dict['P10_ST_rep1'].shape

torch.Size([15872, 15872, 3])

In [27]:
dataset.exp_dict['P10_ST_rep3'].shape


(462, 171)

In [49]:
dataset.meta_dict['P10_ST_rep3']
# dataset.meta_dict['P10_ST_rep3'].loc['15x19']


Unnamed: 0,FO538757.1,RP4-669L17.10,RP11-206L10.9,SAMD11,NOC2L,KLHL17,PLEKHN1,PERM1,HES4,ISG15,...,BPY2C,AC006386.1,AC006328.1,TTTY3,x,y,new_x,new_y,pixel_x,pixel_y
15x19,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,15,19,15.07,18.93,4885.3,5725.4
16x18,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,4.0,0.0,0.0,16,18,16.04,17.97,5060.4,5534.0
16x20,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,...,0.0,2.0,0.0,0.0,16,20,16.03,19.95,5057.6,5931.4
16x22,1.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,16,22,16.00,21.93,5052.5,6328.3
17x17,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,17,17,17.05,16.94,5241.8,5326.6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
54x32,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,54,32,53.97,32.09,11904.5,8363.0
54x34,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,54,34,53.99,34.07,11907.9,8761.1
55x27,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,55,27,55.00,27.04,12090.5,7350.8
55x29,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,55,29,54.93,29.09,12078.8,7763.0


In [54]:
# dataset.loc_dict['P10_ST_rep3']
3 * dataset.r * dataset.r * 4

37632

In [58]:
dfd

collections.defaultdict

In [67]:
dataset.patch_dict['P2_ST_rep3'].shape

torch.Size([638, 3, 112, 112])

In [81]:
dataset.loc_dict['P2_ST_rep3']

array([[10, 24],
       [10, 26],
       [10, 28],
       ...,
       [ 9, 39],
       [ 9, 41],
       [ 9, 43]])

In [34]:
dataset.adj_dict['P2_ST_rep3'].count_nonzero()

tensor(2552)

In [39]:
dataset.adj_dict['P2_ST_rep3'].count_nonzero() / (dataset.adj_dict['P2_ST_rep3'].shape[0]) 

tensor(4.)

In [26]:
dataset.exp_dict['P2_ST_rep3'].shape

(638, 171)

In [73]:
len(test)

7

In [46]:
test[1].shape

torch.Size([1, 521, 2])

In [43]:
torch.LongTensor(dataset.loc_dict['P2_ST_rep3'])

tensor([[10, 24],
        [10, 26],
        [10, 28],
        ...,
        [ 9, 39],
        [ 9, 41],
        [ 9, 43]])

In [48]:
for i,m in dataset.exp_dict.items():
    print(m.shape)

(521, 171)
(1145, 171)
(638, 171)
(1071, 171)
(462, 171)
(590, 171)
(1182, 171)
(521, 171)
(621, 171)
(608, 171)
(646, 171)


In [47]:
for i in range(10):
    print(i)
    next(iterator)
    print([i.shape for i in test])

0
Index:  1
[torch.Size([1, 521, 3, 112, 112]), torch.Size([1, 521, 2]), torch.Size([1, 521, 171]), torch.Size([1, 521, 521]), torch.Size([1, 521, 171]), torch.Size([1, 521]), torch.Size([1, 521, 2])]
1
Index:  2
[torch.Size([1, 521, 3, 112, 112]), torch.Size([1, 521, 2]), torch.Size([1, 521, 171]), torch.Size([1, 521, 521]), torch.Size([1, 521, 171]), torch.Size([1, 521]), torch.Size([1, 521, 2])]
2
Index:  10
[torch.Size([1, 521, 3, 112, 112]), torch.Size([1, 521, 2]), torch.Size([1, 521, 171]), torch.Size([1, 521, 521]), torch.Size([1, 521, 171]), torch.Size([1, 521]), torch.Size([1, 521, 2])]
3
Index:  9
[torch.Size([1, 521, 3, 112, 112]), torch.Size([1, 521, 2]), torch.Size([1, 521, 171]), torch.Size([1, 521, 521]), torch.Size([1, 521, 171]), torch.Size([1, 521]), torch.Size([1, 521, 2])]
4
Index:  6
[torch.Size([1, 521, 3, 112, 112]), torch.Size([1, 521, 2]), torch.Size([1, 521, 171]), torch.Size([1, 521, 521]), torch.Size([1, 521, 171]), torch.Size([1, 521]), torch.Size([1, 521,

StopIteration: 

In [1]:
dataset.counts_dict['P2_ST_rep3'].shape

NameError: name 'dataset' is not defined

In [3]:
class ViT_HER2ST(torch.utils.data.Dataset):
    """Some Information about HER2ST"""
    def __init__(self,train=True,fold=0,r=4,flatten=True,ori=False,adj=False,prune='Grid',neighs=4):
        super(ViT_HER2ST, self).__init__()
        
        self.cnt_dir = '../models/Hist2ST/data/her2st/data/ST-cnts'
        self.img_dir = '../models/Hist2ST/data/her2st/data/ST-imgs'
        self.pos_dir = '../models/Hist2ST/data/her2st/data/ST-spotfiles'
        self.lbl_dir = '../models/Hist2ST/data/her2st/data/ST-pat/lbl'
        self.r = 224//r

        # gene_list = list(np.load('data/her_hvg.npy',allow_pickle=True))
        gene_list = list(np.load('../models/Hist2ST/data/her_hvg_cut_1000.npy',allow_pickle=True))
        self.gene_list = gene_list
        names = os.listdir(self.cnt_dir)
        names.sort()
        names = [i[:2] for i in names]
        self.train = train
        self.ori = ori
        self.adj = adj
        # samples = ['A1','B1','C1','D1','E1','F1','G2','H1']
        samples = names[1:33]

        te_names = [samples[fold]]
        print(te_names)
        tr_names = list(set(samples)-set(te_names))

        if train:
            self.names = tr_names
        else:
            self.names = te_names

        print('Loading imgs...')
        self.img_dict = {i:torch.Tensor(np.array(self.get_img(i))) for i in self.names}
        print('Loading metadata...')
        self.meta_dict = {i:self.get_meta(i) for i in self.names}
        self.label={i:None for i in self.names}
        self.lbl2id={
            'invasive cancer':0, 'breast glands':1, 'immune infiltrate':2, 
            'cancer in situ':3, 'connective tissue':4, 'adipose tissue':5, 'undetermined':-1
        }
        if not train and self.names[0] in ['A1','B1','C1','D1','E1','F1','G2','H1','J1']:
            self.lbl_dict={i:self.get_lbl(i) for i in self.names}
            # self.label={i:m['label'].values for i,m in self.lbl_dict.items()}
            idx=self.meta_dict[self.names[0]].index
            lbl=self.lbl_dict[self.names[0]]
            lbl=lbl.loc[idx,:]['label'].values
            # lbl=torch.Tensor(list(map(lambda i:self.lbl2id[i],lbl)))
            self.label[self.names[0]]=lbl
        elif train:
            for i in self.names:
                idx=self.meta_dict[i].index
                if i in ['A1','B1','C1','D1','E1','F1','G2','H1','J1']:
                    lbl=self.get_lbl(i)
                    lbl=lbl.loc[idx,:]['label'].values
                    lbl=torch.Tensor(list(map(lambda i:self.lbl2id[i],lbl)))
                    self.label[i]=lbl
                else:
                    self.label[i]=torch.full((len(idx),),-1)
        self.gene_set = list(gene_list)
        self.exp_dict = {
            i:scp.transform.log(scp.normalize.library_size_normalize(m[self.gene_set].values)) 
            for i,m in self.meta_dict.items()
        }
        if self.ori:
            self.ori_dict = {i:m[self.gene_set].values for i,m in self.meta_dict.items()}
            self.counts_dict={}
            for i,m in self.ori_dict.items():
                n_counts=m.sum(1)
                sf = n_counts / np.median(n_counts)
                self.counts_dict[i]=sf
        self.center_dict = {
            i:np.floor(m[['pixel_x','pixel_y']].values).astype(int) 
            for i,m in self.meta_dict.items()
        }
        self.loc_dict = {i:m[['x','y']].values for i,m in self.meta_dict.items()}
        self.adj_dict = {
            i:calcADJ(m,neighs,pruneTag=prune)
            for i,m in self.loc_dict.items()
        }
        self.patch_dict=dfd(lambda :None)
        self.lengths = [len(i) for i in self.meta_dict.values()]
        self.cumlen = np.cumsum(self.lengths)
        self.id2name = dict(enumerate(self.names))
        self.flatten=flatten
    def __getitem__(self, index):
        ID=self.id2name[index]
        im = self.img_dict[ID]
        im = im.permute(1,0,2)
        # im = torch.Tensor(np.array(self.im))
        exps = self.exp_dict[ID]
        if self.ori:
            oris = self.ori_dict[ID]
            sfs = self.counts_dict[ID]
        centers = self.center_dict[ID]
        loc = self.loc_dict[ID]
        adj = self.adj_dict[ID]
        patches = self.patch_dict[ID]
        positions = torch.LongTensor(loc)
        patch_dim = 3 * self.r * self.r * 4
        label=self.label[ID]
        exps = torch.Tensor(exps)
        if patches is None:
            n_patches = len(centers)
            if self.flatten:
                patches = torch.zeros((n_patches,patch_dim))
            else:
                patches = torch.zeros((n_patches,3,2*self.r,2*self.r))
            for i in range(n_patches):
                center = centers[i]
                x, y = center
                patch = im[(x-self.r):(x+self.r),(y-self.r):(y+self.r),:]
                if self.flatten:
                    patches[i] = patch.flatten()
                else:
                    patches[i]=patch.permute(2,0,1)
            self.patch_dict[ID]=patches
        data=[patches, positions, exps]
        if self.adj:
            data.append(adj)
        if self.ori:
            data+=[torch.Tensor(oris),torch.Tensor(sfs)]
        data.append(torch.Tensor(centers))
        return data
        
    def __len__(self):
        return len(self.exp_dict)

    def get_img(self,name):
        pre = self.img_dir+'/'+name[0]+'/'+name
        fig_name = os.listdir(pre)[0]
        path = pre+'/'+fig_name
        im = Image.open(path)
        return im

    def get_cnt(self,name):
        path = self.cnt_dir+'/'+name+'.tsv'
        df = pd.read_csv(path,sep='\t',index_col=0)

        return df

    def get_pos(self,name):
        path = self.pos_dir+'/'+name+'_selection.tsv'
        # path = self.pos_dir+'/'+name+'_labeled_coordinates.tsv'
        df = pd.read_csv(path,sep='\t')

        x = df['x'].values
        y = df['y'].values
        x = np.around(x).astype(int)
        y = np.around(y).astype(int)
        id = []
        for i in range(len(x)):
            id.append(str(x[i])+'x'+str(y[i])) 
        df['id'] = id

        return df

    def get_meta(self,name,gene_list=None):
        cnt = self.get_cnt(name)
        pos = self.get_pos(name)
        meta = cnt.join((pos.set_index('id')))

        return meta

    def get_lbl(self,name):
        # path = self.pos_dir+'/'+name+'_selection.tsv'
        path = self.lbl_dir+'/'+name+'_labeled_coordinates.tsv'
        df = pd.read_csv(path,sep='\t')

        x = df['x'].values
        y = df['y'].values
        x = np.around(x).astype(int)
        y = np.around(y).astype(int)
        id = []
        for i in range(len(x)):
            id.append(str(x[i])+'x'+str(y[i])) 
        df['id'] = id
        df.drop('pixel_x', inplace=True, axis=1)
        df.drop('pixel_y', inplace=True, axis=1)
        df.drop('x', inplace=True, axis=1)
        df.drop('y', inplace=True, axis=1)
        df.set_index('id',inplace=True)
        return df


In [8]:
ds = ViT_HER2ST(
    train='train',fold=5,flatten=False,
    ori=True,neighs=4,adj=True,prune='Grid'
)


['B1']
Loading imgs...
Loading metadata...


In [12]:
ds.meta_dict['G1']

Unnamed: 0,FO538757.1,SAMD11,NOC2L,KLHL17,PLEKHN1,PERM1,HES4,ISG15,AGRN,RNF223,...,USP9Y,TMSB4Y,NLGN4Y,x,y,new_x,new_y,pixel_x,pixel_y,selected
10x10,0,1,0,0,0,0,0,3,0,0,...,0,0,0,10,10,9.803,9.952,2561.12,2609.24,1
10x11,0,0,0,0,0,0,0,2,0,0,...,0,0,0,10,11,9.815,10.943,2564.61,2898.09,1
10x12,0,0,0,0,0,0,1,1,0,0,...,0,0,0,10,12,9.816,11.946,2564.91,3190.44,1
10x13,0,0,0,0,0,0,1,4,0,0,...,0,0,0,10,13,9.850,12.948,2574.80,3482.49,1
10x14,0,0,0,0,0,0,1,1,0,0,...,0,0,0,10,14,9.833,13.934,2569.85,3769.88,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9x23,0,0,0,0,0,0,2,1,1,0,...,0,0,0,9,23,8.907,22.958,2300.44,6400.11,1
9x24,0,0,0,0,0,0,1,0,0,0,...,0,0,0,9,24,8.922,23.958,2304.81,6691.58,1
9x25,0,0,0,0,0,0,1,3,1,0,...,0,0,0,9,25,8.836,24.973,2279.79,6987.42,1
9x26,0,0,0,0,0,0,0,0,0,0,...,0,0,0,9,26,8.919,25.957,2303.93,7274.23,1


In [13]:
# ds.center_dict['G1']

In [14]:
# ds.loc_dict['G1']

In [26]:
from torch.utils.data import DataLoader
train_loader = DataLoader(ds, batch_size=1, num_workers=0, shuffle=True)


In [27]:
iterator=iter(train_loader)
test = next(iterator)

In [28]:
for i in range(40):
    next(iterator)
    print(i)


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


StopIteration: 

In [32]:
ds.id2name

{0: 'B2',
 1: 'D1',
 2: 'C3',
 3: 'D6',
 4: 'F3',
 5: 'B3',
 6: 'D5',
 7: 'F2',
 8: 'E3',
 9: 'C1',
 10: 'F1',
 11: 'A6',
 12: 'D3',
 13: 'G1',
 14: 'C5',
 15: 'B5',
 16: 'A2',
 17: 'A5',
 18: 'E2',
 19: 'G3',
 20: 'A3',
 21: 'G2',
 22: 'A4',
 23: 'B4',
 24: 'B6',
 25: 'C6',
 26: 'D2',
 27: 'C4',
 28: 'E1',
 29: 'C2',
 30: 'D4'}

In [28]:
from PIL import Image

def read_tiff(path):
    Image.MAX_IMAGE_PIXELS = 933120000
    im = Image.open(path)
    imarray = np.array(im)
    # I = plt.imread(path)
    return im


class STDataset(torch.utils.data.Dataset):
    """Some Information about STDataset"""
    def __init__(self, adata, img_path, diameter=177.5, train=True):
        super(STDataset, self).__init__()

        self.exp = adata.X.toarray()
        self.im = read_tiff(img_path)
        self.r = np.ceil(diameter/2).astype(int)
        self.train = train
        # self.d_spot = self.d_spot if self.d_spot%2==0 else self.d_spot+1
        self.transforms = transforms.Compose([
            transforms.ColorJitter(0.5,0.5,0.5),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=180),
            transforms.ToTensor()
        ])
        self.centers = adata.obsm['spatial']
        self.pos = adata.obsm['position_norm']
    def __getitem__(self, index):
        exp = self.exp[index]
        center = self.centers[index]
        x, y = center
        patch = self.im.crop((x-self.r, y-self.r, x+self.r, y+self.r))
        exp = torch.Tensor(exp)
        mask = exp!=0
        mask = mask.float()
        if self.train:
            patch = self.transforms(patch)
        pos = torch.Tensor(self.pos[index])
        return patch, pos, exp, mask

    def __len__(self):
        return len(self.centers)


In [58]:
list(ds.meta_dict.values())[0]

Unnamed: 0,FO538757.1,SAMD11,NOC2L,KLHL17,PLEKHN1,PERM1,HES4,ISG15,AGRN,RNF223,...,USP9Y,DDX3Y,UTY,x,y,new_x,new_y,pixel_x,pixel_y,selected
10x10,0,0,0,0,1,0,0,2,0,0,...,0,0,0,10,10,9.937,10.017,2599.27,2620.50,1
10x11,0,0,0,0,0,0,0,2,0,0,...,0,0,0,10,11,9.945,11.103,2601.60,2936.11,1
10x12,0,0,0,0,0,0,0,6,0,0,...,0,0,0,10,12,9.941,12.028,2600.43,3204.93,1
10x13,1,0,2,0,0,0,9,14,2,0,...,0,0,0,10,13,9.939,13.029,2599.85,3495.84,1
10x14,0,0,1,0,0,0,3,4,2,0,...,0,0,0,10,14,9.987,14.089,2613.81,3803.89,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9x23,2,1,4,1,0,0,12,9,0,0,...,0,0,0,9,23,9.064,22.968,2345.36,6384.29,1
9x24,2,1,1,0,0,0,9,16,2,0,...,0,0,0,9,24,9.124,23.955,2362.81,6671.13,1
9x25,0,0,1,1,0,0,5,18,2,0,...,0,0,0,9,25,9.114,24.978,2359.91,6968.43,1
9x26,1,1,0,0,0,0,1,9,1,0,...,0,0,0,9,26,9.104,25.987,2357.00,7261.66,1


In [45]:
# ds.center_dict['D4']

In [46]:
# ds.img_dict['D4']

In [44]:
ds.get_img('D4').size

(9307, 9881)