In [1]:
import numpy as np
import pandas as pd
import torch
import torch_geometric.transforms as T
import os

from collections import defaultdict
from torch_geometric.data import HeteroData
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import MessagePassing, GCNConv, HeteroConv, GraphConv, Linear
import torch.nn.functional as F
from torch import Tensor
import matplotlib.pyplot as plt

In [2]:
from sentence_transformers import SentenceTransformer

In [3]:
from neo4j import GraphDatabase

URI = "neo4j://localhost:7687"
AUTH = ("test", "666666")

with GraphDatabase.driver(URI, auth=AUTH) as driver: 

    driver.verify_connectivity() 


def fetch_data(query):
  with driver.session(database='mtg-wilcox') as session:
    result = session.run(query)
    return pd.DataFrame([r.values() for r in result], columns=result.keys())

  driver.verify_connectivity()


In [4]:

class SequenceEncoder(object):
    # The 'SequenceEncoder' encodes raw column strings into embeddings using a sentence transformer.
    def __init__(self, model_name='all-MiniLM-L6-v2', device=None):
        self.device = device
        self.model = SentenceTransformer(model_name, device=device)

    @torch.no_grad()
    def __call__(self, df):
        x = self.model.encode(df.values, show_progress_bar=True,
                              convert_to_tensor=True, device=self.device)
        return x.cpu()
    
class GenresEncoder(object):
    # The 'GenreEncoder' splits the raw column strings by 'sep' and converts
    # individual elements to categorical labels.
    def __init__(self, sep='|'):
        self.sep = sep

    def __call__(self, df):
        genres = set(g for col in df.values for g in col.split(self.sep))
        mapping = {genre: i for i, genre in enumerate(genres)}

        x = torch.zeros(len(df), len(mapping))
        for i, col in enumerate(df.values):
            for genre in col.split(self.sep):
                x[i, mapping[genre]] = 1
        return x
    
class IdentityEncoder(object):
    # The 'IdentityEncoder' takes the raw column values and converts them to
    # PyTorch tensors.
    def __init__(self, dtype=None, is_list=False):
        self.dtype = dtype
        self.is_list = is_list

    def __call__(self, df):
        if self.is_list:
            return torch.stack([torch.tensor(el) for el in df.values])
        return torch.from_numpy(df.values).to(self.dtype)


In [5]:
def load_node(cypher, index_col, encoders=None, category_col=None, **kwargs):
    # Execute the cypher query and retrieve data from Neo4j
    df = fetch_data(cypher)
    df.set_index(index_col, inplace=True)
    # Define node mapping
    mapping = {index: i for i, index in enumerate(df.index.unique())}
    # Define node features
    x = None
    if encoders is not None:
        xs = [encoder(df[col]) for col, encoder in encoders.items()]
        x = torch.cat(xs, dim=-1)

    y = None
    if category_col is not None:
        # Get unique categories and map to numerical labels
        categories = df[category_col].unique()
        category_to_idx = {cat: idx for idx, cat in enumerate(sorted(categories))}
        
        # Map category column to numerical labels
        y = df[category_col].map(category_to_idx).values
        y = torch.tensor(y, dtype=torch.long) # length: n_nodes
        return x, y, mapping, category_to_idx
    return x, mapping


In [6]:
def load_edge(cypher, src_index_col, src_mapping, dst_index_col, dst_mapping,
                  encoders=None, **kwargs):
    # Execute the cypher query and retrieve data from Neo4j
    df = fetch_data(cypher)
    # Define edge index
    src = [src_mapping[index] for index in df[src_index_col]]
    dst = [dst_mapping[index] for index in df[dst_index_col]]
    edge_index = torch.tensor([src, dst])
    
    # Define edge features
    edge_attr = None
    if encoders is not None:
        edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]
        edge_attr = torch.cat(edge_attrs, dim=-1)

    return edge_index, edge_attr

In [14]:
ct_query = """
MATCH (ct:CellType)
WHERE ct.species_of_origin = 'H.sapiens' OR ct.species_of_origin = 'P.troglodytes'
RETURN ct.id AS cell_type_name_species, ct.cell_type_name AS cell_type_name
"""

ct_x, ct_y, ct_mapping, y_mapping = load_node(ct_query, index_col='cell_type_name_species', category_col='cell_type_name')

# ct_x has no node features
# ct_mapping is just a dictionary mapping cell_type_name to index numbers

In [15]:
ct_x

In [16]:
ct_y

tensor([ 0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,
         1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23,
        19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9])

In [17]:
ct_mapping

{'Astro_H.sapiens': 0,
 'Oligo_H.sapiens': 1,
 'VLMC_H.sapiens': 2,
 'Micro-PVM_H.sapiens': 3,
 'OPC_H.sapiens': 4,
 'Endo_H.sapiens': 5,
 'L5-6 NP_H.sapiens': 6,
 'L6 CT_H.sapiens': 7,
 'L6b_H.sapiens': 8,
 'L5 ET_H.sapiens': 9,
 'Pax6_H.sapiens': 10,
 'Vip_H.sapiens': 11,
 'Sncg_H.sapiens': 12,
 'Lamp5_Lhx6_H.sapiens': 13,
 'Lamp5_H.sapiens': 14,
 'Sst Chodl_H.sapiens': 15,
 'Pvalb_H.sapiens': 16,
 'Sst_H.sapiens': 17,
 'Chandelier_H.sapiens': 18,
 'L5 IT_H.sapiens': 19,
 'L4 IT_H.sapiens': 20,
 'L2-3 IT_H.sapiens': 21,
 'L6 IT Car3_H.sapiens': 22,
 'L6 IT_H.sapiens': 23,
 'Astro_P.troglodytes': 24,
 'Oligo_P.troglodytes': 25,
 'VLMC_P.troglodytes': 26,
 'Micro-PVM_P.troglodytes': 27,
 'OPC_P.troglodytes': 28,
 'Endo_P.troglodytes': 29,
 'L5-6 NP_P.troglodytes': 30,
 'L6 CT_P.troglodytes': 31,
 'L6b_P.troglodytes': 32,
 'L5 ET_P.troglodytes': 33,
 'Pax6_P.troglodytes': 34,
 'Vip_P.troglodytes': 35,
 'Sncg_P.troglodytes': 36,
 'Lamp5_Lhx6_P.troglodytes': 37,
 'Lamp5_P.troglodytes': 38

In [18]:
y_mapping

{'Astro': 0,
 'Chandelier': 1,
 'Endo': 2,
 'L2-3 IT': 3,
 'L4 IT': 4,
 'L5 ET': 5,
 'L5 IT': 6,
 'L5-6 NP': 7,
 'L6 CT': 8,
 'L6 IT': 9,
 'L6 IT Car3': 10,
 'L6b': 11,
 'Lamp5': 12,
 'Lamp5_Lhx6': 13,
 'Micro-PVM': 14,
 'OPC': 15,
 'Oligo': 16,
 'Pax6': 17,
 'Pvalb': 18,
 'Sncg': 19,
 'Sst': 20,
 'Sst Chodl': 21,
 'VLMC': 22,
 'Vip': 23}

In [13]:
 # now we give gene nodes encoded features based on GOs, KEGGs, etc.

In [19]:
gene_query = """
MATCH (gene:Gene)
WHERE gene.species_of_origin = 'H.sapiens' OR gene.species_of_origin = 'P.troglodytes'
RETURN gene.id as gene_id, gene.gos as gos, gene.pfams as pfams, gene.description as description, gene.kegg_pathway as kegg_pathways
"""
gene_x,gene_mapping = load_node(
    gene_query, 
    index_col='gene_id', encoders={
        'description': SequenceEncoder(),
        'gos': GenresEncoder(sep=','),
        'pfams': GenresEncoder(sep=','),
        'kegg_pathways': GenresEncoder(sep=','),
    })

Batches:   0%|          | 0/1127 [00:00<?, ?it/s]

In [20]:
gene_x.shape

torch.Size([36039, 28743])

In [28]:
gene_mapping

{'ARRB2_P.troglodytes': 0,
 'CRY2_P.troglodytes': 1,
 'ARRB1_P.troglodytes': 2,
 'CRY1_P.troglodytes': 3,
 'IGF1R_P.troglodytes': 4,
 'CAMK2D_P.troglodytes': 5,
 'MAPK8_P.troglodytes': 6,
 'HSP90AA1_P.troglodytes': 7,
 'FYN_P.troglodytes': 8,
 'AGO2_P.troglodytes': 9,
 'PRKCH_P.troglodytes': 10,
 'AP2A1_P.troglodytes': 11,
 'AP2S1_P.troglodytes': 12,
 'MAPK9_P.troglodytes': 13,
 'TGFBR1_P.troglodytes': 14,
 'SPAST_P.troglodytes': 15,
 'RAB7A_P.troglodytes': 16,
 'PAFAH1B1_P.troglodytes': 17,
 'SIRT4_P.troglodytes': 18,
 'SNX1_P.troglodytes': 19,
 'TGFBR2_P.troglodytes': 20,
 'RPA2_P.troglodytes': 21,
 'TMEM30A_P.troglodytes': 22,
 'EIF3J_P.troglodytes': 23,
 'HGS_P.troglodytes': 24,
 'PSMB5_P.troglodytes': 25,
 'TRIM5_P.troglodytes': 26,
 'ATG5_P.troglodytes': 27,
 'EIF3H_P.troglodytes': 28,
 'EIF3M_P.troglodytes': 29,
 'EIF3K_P.troglodytes': 30,
 'CACNA1D_P.troglodytes': 31,
 'RAB11B_P.troglodytes': 32,
 'FEN1_P.troglodytes': 33,
 'SNX5_P.troglodytes': 34,
 'POFUT1_P.troglodytes': 35,

In [27]:
gene_mapping.keys()

dict_keys(['ARRB2_P.troglodytes', 'CRY2_P.troglodytes', 'ARRB1_P.troglodytes', 'CRY1_P.troglodytes', 'IGF1R_P.troglodytes', 'CAMK2D_P.troglodytes', 'MAPK8_P.troglodytes', 'HSP90AA1_P.troglodytes', 'FYN_P.troglodytes', 'AGO2_P.troglodytes', 'PRKCH_P.troglodytes', 'AP2A1_P.troglodytes', 'AP2S1_P.troglodytes', 'MAPK9_P.troglodytes', 'TGFBR1_P.troglodytes', 'SPAST_P.troglodytes', 'RAB7A_P.troglodytes', 'PAFAH1B1_P.troglodytes', 'SIRT4_P.troglodytes', 'SNX1_P.troglodytes', 'TGFBR2_P.troglodytes', 'RPA2_P.troglodytes', 'TMEM30A_P.troglodytes', 'EIF3J_P.troglodytes', 'HGS_P.troglodytes', 'PSMB5_P.troglodytes', 'TRIM5_P.troglodytes', 'ATG5_P.troglodytes', 'EIF3H_P.troglodytes', 'EIF3M_P.troglodytes', 'EIF3K_P.troglodytes', 'CACNA1D_P.troglodytes', 'RAB11B_P.troglodytes', 'FEN1_P.troglodytes', 'SNX5_P.troglodytes', 'POFUT1_P.troglodytes', 'PAN3_P.troglodytes', 'CACNA1E_P.troglodytes', 'HAS2_P.troglodytes', 'ASZ1_P.troglodytes', 'ZNF365_P.troglodytes', 'MIEF1_P.troglodytes', 'LONP1_P.troglodytes

In [33]:
marker_query = """
MATCH (g:Gene)-[r:GeneWilcoxMarkerInCellType]->(ct:CellType) 
WHERE g.species_of_origin = 'H.sapiens' OR g.species_of_origin = 'P.troglodytes' AND r.avg_log2fc >= 2
RETURN g.id as gene_id, ct.id as cell_type_name_species, r.avg_log2fc as avg_log2fc
"""

edge_index, edge_weights = load_edge(
    marker_query,
    src_index_col='gene_id',
    src_mapping=gene_mapping, # the two index mappings were used for this
    dst_index_col='cell_type_name_species',
    dst_mapping=ct_mapping,
    encoders={'avg_log2fc': IdentityEncoder(dtype=torch.float32)} , # remember to set the correct dtype for identity encoding the edge weight
)

In [34]:
edge_weights

tensor([2.0387, 3.3695, 3.2551,  ..., 1.7481, 1.3200, 1.0406])

In [35]:
edge_index

tensor([[    5,     5,     5,  ..., 36036, 36037, 36038],
        [   31,    30,    32,  ...,    11,    11,    11]])

In [36]:
marker_query_2 = """
MATCH (g:Gene)-[r:GeneWilcoxMarkerInCellType]->(ct:CellType) 
WHERE g.species_of_origin = 'H.sapiens' OR g.species_of_origin = 'P.troglodytes' AND r.avg_log2fc >= 2
RETURN g.id as gene_id, ct.id as cell_type_name_species, r.avg_log2fc as avg_log2fc
"""

edge_index_2, edge_weights_2 = load_edge(
    marker_query_2,
    src_index_col='cell_type_name_species',
    src_mapping=ct_mapping,# the two index mappings were used for this
    dst_index_col='gene_id',
    dst_mapping=gene_mapping,
    encoders={'avg_log2fc': IdentityEncoder(dtype=torch.float32)} , # remember to set the correct dtype for identity encoding the edge weight
)

In [None]:
edge_index_2
# should be two rows-switched edge_index

tensor([[   31,    30,    32,  ...,    11,    11,    11],
        [    5,     5,     5,  ..., 36036, 36037, 36038]])

In [None]:
edge_weights_2

# should be the same as edge_weights 1

tensor([2.0387, 3.3695, 3.2551,  ..., 1.7481, 1.3200, 1.0406])

In [39]:
og_query = """
MATCH (og:OrthologousGroup)
WHERE og.eggnog_dataset_name = 'Mammalia'
RETURN og.orthologous_group_id as og_id
"""

og_x, og_mapping = load_node(og_query, index_col='og_id')

In [40]:
og_mapping

{'3J40A': 0,
 '3J7G6': 1,
 '3J605': 2,
 '3J1XG': 3,
 '3JEDX': 4,
 '3J2JM': 5,
 '3J5E0': 6,
 '3JAF8': 7,
 '3J90S': 8,
 '3JDI6': 9,
 '3J6VY': 10,
 '3J2VQ': 11,
 '3JDG2': 12,
 '3J4QZ': 13,
 '3J9JY': 14,
 '3J88D': 15,
 '3JEHT': 16,
 '3J704': 17,
 '3J2G1': 18,
 '3JF4R': 19,
 '3JFKB': 20,
 '3JBWT': 21,
 '3JCP3': 22,
 '3JE5B': 23,
 '3J916': 24,
 '3J1Q8': 25,
 '3JBVQ': 26,
 '3J4RQ': 27,
 '3JANP': 28,
 '3JAIF': 29,
 '3J9W3': 30,
 '3JE2P': 31,
 '3JBA6': 32,
 '3J66F': 33,
 '3JB0Q': 34,
 '3J7NT': 35,
 '3J842': 36,
 '3JC84': 37,
 '3J7JC': 38,
 '3J33G': 39,
 '3JEAP': 40,
 '3J2AP': 41,
 '3J3YN': 42,
 '3JDRG': 43,
 '3J2QU': 44,
 '3JBPA': 45,
 '3J72P': 46,
 '3JF16': 47,
 '3J5EU': 48,
 '3J9YB': 49,
 '3J4J3': 50,
 '3J6GF': 51,
 '3JCIK': 52,
 '3JHDM': 53,
 '3JACD': 54,
 '3J5DX': 55,
 '3J2E7': 56,
 '3J6NR': 57,
 '3J5A2': 58,
 '3J5HX': 59,
 '3JAG2': 60,
 '3JFCU': 61,
 '3JEBB': 62,
 '3J3CJ': 63,
 '3J5PB': 64,
 '3J8S2': 65,
 '3JB9G': 66,
 '3J6V6': 67,
 '3J2MI': 68,
 '3J915': 69,
 '3J436': 70,
 '3J5TK': 71,
 '

In [42]:
gene_og_query = """
MATCH (g:Gene)-[r:GeneInOrthologousGroup]->(og:OrthologousGroup)
WHERE og.eggnog_dataset_name = 'Mammalia' AND g.species_of_origin = 'H.sapiens' OR g.species_of_origin = 'P.troglodytes'
RETURN g.id as gene_id, og.orthologous_group_id as og_id
"""

edge_index_og, edge_weights_og = load_edge(
    gene_og_query,
    src_index_col='gene_id',
    src_mapping=gene_mapping, # the two index mappings were used for this
    dst_index_col='og_id',
    dst_mapping=og_mapping, 
)

In [43]:
edge_index_og

tensor([[    0,     1,     2,  ..., 27578, 27579, 27580],
        [    0,     1,     2,  ...,  7751, 13212,  7615]])

In [44]:
edge_weights_og

In [45]:
gene_og_query = """
MATCH (g:Gene)-[r:GeneInOrthologousGroup]->(og:OrthologousGroup)
WHERE og.eggnog_dataset_name = 'Mammalia' AND g.species_of_origin = 'H.sapiens' OR g.species_of_origin = 'P.troglodytes'
RETURN g.id as gene_id, og.orthologous_group_id as og_id
"""

edge_index_og_2, edge_weights_og_2 = load_edge(
    gene_og_query,
    src_index_col='og_id',
    src_mapping=og_mapping, # the two index mappings were used for this
    dst_index_col='gene_id',
    dst_mapping=gene_mapping, 
)
# again add reverse edge

In [46]:
edge_index_og_2

tensor([[    0,     1,     2,  ...,  7751, 13212,  7615],
        [    0,     1,     2,  ..., 27578, 27579, 27580]])

In [47]:
data = HeteroData()


In [48]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [49]:
#  The edge_label tensor holds the ground truth labels that you want the model to predict for specific edges.
#  Used for edge prediction tasks

In [50]:
# Add user node features for message passing:
data['cell_type'].x = torch.eye(len(ct_mapping), device=device)
data['cell_type'].y = ct_y
# Add movie node features
data['gene'].x = gene_x

data['orthologous_group'].x = torch.eye(len(og_mapping), device=device)
data

HeteroData(
  cell_type={
    x=[48, 48],
    y=[48],
  },
  gene={ x=[36039, 28743] },
  orthologous_group={ x=[14948, 14948] }
)

In [51]:

# Add ratings between users and movies
data['gene', 'is_wilcox_marker_of', 'cell_type'].edge_index = edge_index
data['gene', 'is_wilcox_marker_of', 'cell_type'].edge_weights = edge_weights

data['cell_type', 'is_wilcox_marker_of', 'gene'].edge_index = edge_index_2
data['cell_type', 'is_wilcox_marker_of', 'gene'].edge_weights = edge_weights_2

data['gene', 'is_in', 'orthologous_group'].edge_index = edge_index_og
data['orthologous_group', 'is_in', 'gene'].edge_index = edge_index_og_2

data.to(device, non_blocking=True)

HeteroData(
  cell_type={
    x=[48, 48],
    y=[48],
  },
  gene={ x=[36039, 28743] },
  orthologous_group={ x=[14948, 14948] },
  (gene, is_wilcox_marker_of, cell_type)={
    edge_index=[2, 117640],
    edge_weights=[117640],
  },
  (cell_type, is_wilcox_marker_of, gene)={
    edge_index=[2, 117640],
    edge_weights=[117640],
  },
  (gene, is_in, orthologous_group)={ edge_index=[2, 20656] },
  (orthologous_group, is_in, gene)={ edge_index=[2, 20656] }
)

In [52]:
# I also need a reverse edge from cell type to genes
# just for the HGT model

In [53]:
data['cell_type'].x # is just a diagonal matrix - no features yet

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [54]:
data.node_types

['cell_type', 'gene', 'orthologous_group']

In [58]:
torch.save(data, 'mtg_hs_pt_wilcox_data_with_og.pt')

In [59]:
loaded_data = torch.load('mtg_hs_pt_wilcox_data_with_og.pt')

In [60]:
loaded_data

HeteroData(
  cell_type={
    x=[48, 48],
    y=[48],
  },
  gene={ x=[36039, 28743] },
  orthologous_group={ x=[14948, 14948] },
  (gene, is_wilcox_marker_of, cell_type)={
    edge_index=[2, 117640],
    edge_weights=[117640],
  },
  (cell_type, is_wilcox_marker_of, gene)={
    edge_index=[2, 117640],
    edge_weights=[117640],
  },
  (gene, is_in, orthologous_group)={ edge_index=[2, 20656] },
  (orthologous_group, is_in, gene)={ edge_index=[2, 20656] }
)

In [61]:
data.node_types

['cell_type', 'gene', 'orthologous_group']

In [66]:
data['cell_type'].y.shape

torch.Size([48])

In [74]:
ct_mapping

{'Astro_H.sapiens': 0,
 'Oligo_H.sapiens': 1,
 'VLMC_H.sapiens': 2,
 'Micro-PVM_H.sapiens': 3,
 'OPC_H.sapiens': 4,
 'Endo_H.sapiens': 5,
 'L5-6 NP_H.sapiens': 6,
 'L6 CT_H.sapiens': 7,
 'L6b_H.sapiens': 8,
 'L5 ET_H.sapiens': 9,
 'Pax6_H.sapiens': 10,
 'Vip_H.sapiens': 11,
 'Sncg_H.sapiens': 12,
 'Lamp5_Lhx6_H.sapiens': 13,
 'Lamp5_H.sapiens': 14,
 'Sst Chodl_H.sapiens': 15,
 'Pvalb_H.sapiens': 16,
 'Sst_H.sapiens': 17,
 'Chandelier_H.sapiens': 18,
 'L5 IT_H.sapiens': 19,
 'L4 IT_H.sapiens': 20,
 'L2-3 IT_H.sapiens': 21,
 'L6 IT Car3_H.sapiens': 22,
 'L6 IT_H.sapiens': 23,
 'Astro_P.troglodytes': 24,
 'Oligo_P.troglodytes': 25,
 'VLMC_P.troglodytes': 26,
 'Micro-PVM_P.troglodytes': 27,
 'OPC_P.troglodytes': 28,
 'Endo_P.troglodytes': 29,
 'L5-6 NP_P.troglodytes': 30,
 'L6 CT_P.troglodytes': 31,
 'L6b_P.troglodytes': 32,
 'L5 ET_P.troglodytes': 33,
 'Pax6_P.troglodytes': 34,
 'Vip_P.troglodytes': 35,
 'Sncg_P.troglodytes': 36,
 'Lamp5_Lhx6_P.troglodytes': 37,
 'Lamp5_P.troglodytes': 38

In [87]:
split={"train_idx": np.array(range(0, 24, 1)), "test_idx": np.array(range(24, 48, 1))}

In [88]:
# Now since i want to train on human and test on pt, i need to create the train val test split
for name in ['train', 'test']:
    idx = split[f'{name}_idx']
    idx = torch.from_numpy(idx).to(torch.long)
    mask = torch.zeros(data['cell_type'].num_nodes, dtype=torch.bool)
    mask[idx] = True
    data['cell_type'][f'{name}_mask'] = mask

In [89]:
data['cell_type']

{'x': tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]]), 'y': tensor([ 0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,
         1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23,
        19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9]), 'train_mask': tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False]), 'test_mask': tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, Fals

In [90]:
torch.save(data, 'mtg_hs_train_pt_test_wilcox_data_with_og.pt')