In [4]:
import time
import pickle
import torch
import copy
import torch.nn          as nn
import numpy             as np
import pandas            as pd
import matplotlib.pyplot as plt 
import seaborn           as sns

from typing           import List
from torch.utils.data import Dataset, DataLoader
from torch_geometric.loader import DataLoader as PyG_Dataloader

from config import (
    PATH_TO_FEATURES,
    PATH_TO_SAVED_DRUG_FEATURES,
    PATH_SUMMARY_DATASETS
)

torch.manual_seed(42)
sns.set_theme(style="white")

---

# Experiments on the `GraphTab` approach

In this notebook we are going to expirment the approach of 
- replacing the cell-line branch by a GNN and 
- having the drug branch using tabular input.

In [5]:
WITHOUT_MISSING_FOLDER = '/without_missing/'

## Pre-processing

In [19]:
# Reading the cell-line gene graphs.
with open(f'{PATH_SUMMARY_DATASETS}{WITHOUT_MISSING_FOLDER}cell_line_graphs_dict.pkl', 'rb') as f:
    cl_graphs = pickle.load(f)

# Reading the drug response matrix.
with open(f'{PATH_SUMMARY_DATASETS}{WITHOUT_MISSING_FOLDER}drug_response_matrix__gdsc2.pkl', 'rb') as f: 
    drug_cl = pickle.load(f)  

In [20]:
print(f"Number of cell-lines/graphs: {len(list(cl_graphs.keys()))}")
print(cl_graphs['22RV1'])

Number of cell-lines/graphs: 732
Data(x=[858, 4], edge_index=[2, 83126])


In [11]:
print(f"Shape: {drug_cl.shape}")
print(f"Number of different drug id's   : {len(np.unique(drug_cl.DRUG_ID.values))}")
print(f"Number of different drug name's : {len(np.unique(drug_cl.DRUG_NAME.values))}")
drug_cl.head(3)

Shape: (91991, 5)
Number of different drug id's   : 152
Number of different drug name's : 152


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935


In [13]:
with open(f'{PATH_SUMMARY_DATASETS}drug_smiles_fingerprints_matrix.pkl', 'rb') as f:
    drug_name_smiles = pickle.load(f)
# drug_name_smiles.set_index(['drug_name'], inplace=True)
print(drug_name_smiles.shape)
# TODO: Note that yet there are some DRUG_NAME's which have >1 DRUG_ID
drug_name_smiles.head(5)

(152, 257)


Unnamed: 0,DRUG_ID,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
1,1073,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
810,1910,1,1,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
1562,1913,0,1,1,0,1,0,0,0,0,...,0,0,0,0,1,1,0,0,0,0
2314,1634,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
3044,2045,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


---

## Dataset summary

We got three datasets now: 
- Drug response matrix
- fingerprint matrix
- cell-line graph dictionary

In [15]:
# Drug response matrix holding the ln(IC50) values for each cell-line drug tuple.
drug_response_matrix = copy.deepcopy(drug_cl)
print(drug_response_matrix.shape)
drug_response_matrix.sort_values(['DRUG_ID']).head(5)

(91991, 5)


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3445814,QGP-1,1003,Camptothecin,GDSC2,-0.534857
3446843,RC-K8,1003,Camptothecin,GDSC2,-3.324229
3447648,RCC-JW,1003,Camptothecin,GDSC2,-2.573593
3440830,CHP-134,1003,Camptothecin,GDSC2,-3.817771


In [16]:
# The drug matrix holding the fingerprints for the drugs in the drug response matrix.
fingerprints = copy.deepcopy(drug_name_smiles)
print(fingerprints.shape)
fingerprints.sort_values(['DRUG_ID']).head(5)

(152, 257)


Unnamed: 0,DRUG_ID,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
21273,1003,0,0,0,0,0,0,0,0,0,...,0,0,1,0,1,0,0,0,0,0
94088,1004,1,0,0,0,0,0,0,0,1,...,0,1,0,1,0,1,0,0,0,0
25052,1006,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,0
39836,1010,1,0,0,0,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,1
61065,1011,0,1,0,0,0,0,0,0,0,...,1,0,0,0,1,1,0,0,0,1


In [17]:
# Drugs as dictionary.
fingerprints_dict = fingerprints.set_index('DRUG_ID').T.to_dict('list')
print(len(fingerprints_dict.keys()))
print(len(fingerprints_dict[1003]))

152
256


In [21]:
# The cell-line graphs holding for each cell-line in the drug-response-matrix the corresponding 
# graph with the cell-line level gene features.
cell_line_graphs = copy.deepcopy(cl_graphs)
print(f"Number of cell-lines/graphs: {len(list(cell_line_graphs.keys()))}")
print(cell_line_graphs['22RV1'])

Number of cell-lines/graphs: 732
Data(x=[858, 4], edge_index=[2, 83126])


In [22]:
c = 0
sum_per = []
for k, v in cell_line_graphs.items():
    if v.x.isnan().any(): 
        sum_per.append(v.x.isnan().sum())
        c += 1
print(c)

0


---

## Create PyTorch dataset

dataset should be 

| Tuple Datapoint | Cell Branch Input | Drug Branch Input | Target | Example Dataset |
| --------------- | ----------------- | ----------------- | ------ | --------------- |
| `(cl_1, drug_1)`| $\text{graph}_{\text{cl}_1}$ | $\text{smiles}_{\text{drug}_1}$ | $ln(IC50)_{\text{cl}_1\text{drug}_1}$ | Train |
| `(cl_1, ...   )`| $\text{graph}_{\text{cl}_1}$ | $\text{smiles}_{\text{drug}_{...}}$ | $ln(IC50)_{\text{cl}_1\text{drug}_{...}}$ | Train | 
| `(cl_1, drug_m)`| $\text{graph}_{\text{cl}_1}$ | $\text{smiles}_{\text{drug}_m}$ | $ln(IC50)_{\text{cl}_1\text{drug}_m}$ | Test | 
| `(cl_i, drug_1)`| $\text{graph}_{\text{cl}_i}$ | $\text{smiles}_{\text{drug}_1}$ | $ln(IC50)_{\text{cl}_i\text{drug}_1}$ | Test |
| `(cl_i, ...   )`| $\text{graph}_{\text{cl}_i}$ | $\text{smiles}_{\text{drug}_{...}}$ | $ln(IC50)_{\text{cl}_i\text{drug}_{...}}$ | Train | 
| `(cl_i, drug_m)`| $\text{graph}_{\text{cl}_i}$ | $\text{smiles}_{\text{drug}_m}$ | $ln(IC50)_{\text{cl}_i\text{drug}_m}$ | Train |
| `(cl_n, drug_1)`| $\text{graph}_{\text{cl}_n}$ | $\text{smiles}_{\text{drug}_1}$ | $ln(IC50)_{\text{cl}_n\text{drug}_1}$ | Test | 
| `(cl_n, ...   )`| $\text{graph}_{\text{cl}_n}$ | $\text{smiles}_{\text{drug}_{...}}$ | $ln(IC50)_{\text{cl}_n\text{drug}_{...}}$ | Test |
| `(cl_n, drug_m)`| $\text{graph}_{\text{cl}_n}$ | $\text{smiles}_{\text{drug}_m}$ | $ln(IC50)_{\text{cl}_n\text{drug}_m}$ | Train | 

In [26]:
drug_response_matrix.head(5)

Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935
3551362,22RV1,1010,Gefitinib,GDSC2,4.032555
3571270,22RV1,1011,Navitoclax,GDSC2,3.963435


In [23]:
from typing import Tuple
from torch_geometric.data import Data, Dataset

class GraphTabDataset(Dataset): 
    def __init__(self, cl_graphs, drugs, drug_response_matrix):
        super().__init__()

        # SMILES fingerprints of the drugs and cell-line graphs.
        self.drugs = drugs
        self.cell_line_graphs = cl_graphs

        # Lookup datasets for the response values.
        drug_response_matrix.reset_index(drop=True, inplace=True)
        self.cell_lines = drug_response_matrix['CELL_LINE_NAME']
        self.drug_ids = drug_response_matrix['DRUG_ID']
        self.drug_names = drug_response_matrix['DRUG_NAME']
        self.ic50s = drug_response_matrix['LN_IC50']

    def __len__(self):
        return len(self.ic50s)

    def __getitem__(self, idx: int):
        """
        Returns a tuple of cell-line, drug and the corresponding ln(IC50)
        value for a given index.

        Args:
            idx (`int`): Index to specify the row in the drug response matrix.  
        Returns:
            `Tuple[torch_geometric.data.data.Data, np.ndarray, np.float64]`:
            Tuple of a cell-line graph, drug SMILES fingerprint and the 
            corresponding ln(IC50) value.
        """
        return (self.cell_line_graphs[self.cell_lines.iloc[idx]], 
                self.drugs[self.drug_ids.iloc[idx]],
                self.ic50s.iloc[idx])

    def print_dataset_summary(self):
        print(f"GraphTabDataset Summary")
        print(f"{23*'='}")
        print(f"# observations : {len(self.ic50s)}")
        print(f"# cell-lines   : {len(np.unique(self.cell_lines))}")
        print(f"# drugs        : {len(np.unique(self.drug_names))}")
        print(f"# genes        : {self.cell_line_graphs[next(iter(self.cell_line_graphs))].x.shape[0]}")

In [24]:
graph_tab_dataset = GraphTabDataset(cl_graphs=cell_line_graphs,
                                    drugs=fingerprints_dict,
                                    drug_response_matrix=drug_response_matrix)

graph_tab_dataset.print_dataset_summary()                                    

GraphTabDataset Summary
# observations : 91991
# cell-lines   : 732
# drugs        : 152
# genes        : 858


In [25]:
random_seed = 12345

BATCH_SIZE = 10 # TODO: tune batch_size
LR = 0.001 # TODO: tune this or implement lr decay
TRAIN_RATIO = 0.8 
TEST_VAL_RATIO = 1-TRAIN_RATIO # How much of all data is for the test and validation set.
VAL_RATIO = 0.5 # How much of the of the test and validation set is only for validation.
NUM_EPOCHS = 100

In [28]:
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch
from torch.utils.data import DataLoader
# from torch_geometric.loader import DataLoader

def _collate_graph_tab(samples):
    """
    Collates a list of cell-line graphs, SMILES drug fingerprints and 
    ln(IC50) values to pytorch handable formats.

    Args: 
        samples (`List[Tuple[torch.Data, List[int], float]]`): List of tuples of 
            cell-line graphs, SMILES drug fingerprints and ln(IC50) values.
            Example:
            >>> [(Data(x=[858, 4], edge_index=[2, 83126]), [0, 0, 1, ..., 0], 4.792643),
                 (Data(x=[858, 4], edge_index=[2, 83126]), [1, 0, 1, ..., 1], 5.857639),
                 ...]
            
    Returns:
        `DataBatch`: All graphs in the batch as one big (disconnected) graph.
            Example: 
            >>> DataBatch(x=[858000, 4], edge_index=[2, 83126000], batch=[858000], ptr=[1001])
                # x=[858*BATCH_SIZE, 4], edge_index=[2, 83126*BATCH_SIZE], batch=[858*BATCH_SIZE]
        `torch.tensor(List[List[int]])`: tensor of list of drug fingerprints.
            Example:
            >>> torch.tensor([[0,0,1,...,0], 
                              [1,0,1,...,1],
                              ...])
        `torch.tensor(float)`: tensor of ln(IC50)'s.
            Example: 
            >>> torch.tensor(4.792643, 5.857639, ...)
    """
    cells, drugs, targets = map(list, zip(*samples))
    drugs = [torch.tensor(drug_fp, dtype=torch.float64) for drug_fp in drugs] # list of fingerprint tensors

    return Batch.from_data_list(cells), torch.stack(drugs, 0), torch.tensor(targets)
    # return "hallo", torch.stack(drugs, 0), torch.tensor(targets)

def create_datasets(drm, cl_dict, drug_dict):
    print(f"Full     shape: {drm.shape}")
    train_set, test_val_set = train_test_split(drm, 
                                               test_size=TEST_VAL_RATIO, 
                                               random_state=random_seed,
                                               stratify=drm['CELL_LINE_NAME'])
    test_set, val_set = train_test_split(test_val_set,
                                         test_size=VAL_RATIO,
                                         random_state=random_seed,
                                         stratify=test_val_set['CELL_LINE_NAME'])
    print(f"train    shape: {train_set.shape}")
    print(f"test_val shape: {test_val_set.shape}")
    print(f"test     shape: {test_set.shape}")
    print(f"val      shape: {val_set.shape}")

    train_dataset = GraphTabDataset(cl_graphs=cl_dict,
                                    drugs=drug_dict,
                                    drug_response_matrix=train_set)
    test_dataset = GraphTabDataset(cl_graphs=cl_dict,
                                   drugs=drug_dict,
                                   drug_response_matrix=test_set)
    val_dataset = GraphTabDataset(cl_graphs=cl_dict,
                                  drugs=drug_dict,
                                  drug_response_matrix=val_set)

    print("\ntrain_dataset:")
    train_dataset.print_dataset_summary()
    print("\n\ntest_dataset:")
    test_dataset.print_dataset_summary()
    print("\n\nval_dataset:")
    val_dataset.print_dataset_summary()

    # TODO: try out different `num_workers`.
    train_loader = DataLoader(dataset=train_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              collate_fn=_collate_graph_tab)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=True,
                             collate_fn=_collate_graph_tab)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            collate_fn=_collate_graph_tab)

    return train_loader, test_loader, val_loader

train_loader, test_loader, val_loader = create_datasets(drug_response_matrix, 
                                                        cl_graphs,
                                                        fingerprints_dict)

Full     shape: (91991, 5)
train    shape: (73592, 5)
test_val shape: (18399, 5)
test     shape: (9199, 5)
val      shape: (9200, 5)

train_dataset:
GraphTabDataset Summary
# observations : 73592
# cell-lines   : 732
# drugs        : 152
# genes        : 858


test_dataset:
GraphTabDataset Summary
# observations : 9199
# cell-lines   : 732
# drugs        : 152
# genes        : 858


val_dataset:
GraphTabDataset Summary
# observations : 9200
# cell-lines   : 732
# drugs        : 152
# genes        : 858


In [29]:
for step, data in enumerate(train_loader):
    cell, drugs, targets = data
    print(f'Step {step + 1}:')
    print(f'=======')
    print(f'Number of graphs in the current batch: {cell.num_graphs}')
    print(cell)
    # print(drugs)
    # print(targets)
    print()

    if step == 3:
        break

Step 1:
Number of graphs in the current batch: 10
DataBatch(x=[8580, 4], edge_index=[2, 831260], batch=[8580], ptr=[11])

Step 2:
Number of graphs in the current batch: 10
DataBatch(x=[8580, 4], edge_index=[2, 831260], batch=[8580], ptr=[11])

Step 3:
Number of graphs in the current batch: 10
DataBatch(x=[8580, 4], edge_index=[2, 831260], batch=[8580], ptr=[11])

Step 4:
Number of graphs in the current batch: 10
DataBatch(x=[8580, 4], edge_index=[2, 831260], batch=[8580], ptr=[11])



In [30]:
# for batch_cell_graph, batch_drugs, batch_ic50s in train_loader
print("Number of batches per dataset:")
print(f"  train : {len(train_loader)}")
print(f"  test  : {len(test_loader)}")
print(f"  val   : {len(val_loader)}")

Number of batches per dataset:
  train : 7360
  test  : 920
  val   : 920


By using `Batch.from_data_list(cells)` we batched all graphs in a batch into a single giant graph.

__TODO__: 
- [ ] We can also use the class `torch_geometric.data.DataLoader` directly to do this.
- example notebook: https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=G-DwBYkquRUN

- PyTorch.data doc: https://pytorch.org/docs/stable/data.html

In [74]:
for data in train_loader: 
    cell, drug, ic = data
    print(drug)
    print(drug[0][0])
    break

tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [1., 1., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 1., 0., 1.]], dtype=torch.float64)
tensor(0., dtype=torch.float64)


In [78]:
Batch.to_data_list(cell)[:10]

[Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126])]

## Model development

In [32]:
from tqdm import tqdm
import torch.nn.functional as F

class BuildModel():
    def __init__(self, model, criterion, optimizer, num_epochs, train_loader, test_loader, val_loader):
        self.train_losses = []
        self.test_losses = []
        self.val_losses = []

        self.train_loader = train_loader
        self.test_loader = test_loader
        self.val_loader = val_loader

        self.num_epochs = num_epochs

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

    def train(self): 
        train_losses, test_losses = [], []

        for epoch in range(self.num_epochs):
            model.train()
            print("=====Epoch {}".format(epoch))
            print("Training...")
            for i, data in enumerate(tqdm(self.train_loader, desc='Iteration')):
                cell, drug, targets = data
                cell, drug, targets = cell.to(device), drug.to(device), targets.to(device)

                print(f"targets.size : {targets.size}")
                print(f"targets      : {targets[:10]}")

                # Models predictions of the ic50s for a batch of cell-lines and drugs
                preds = self.model(cell, drug.float())
                print(100*"=")
                print(targets)
                print(targets.view(-1, 1))
                losses = self.criterion(preds, targets.view(-1, 1).float()) # =train_loss

                # TODO: maybe put this on top of the loop.
                self.optimizer.zero_grad()
                losses.backward()
                self.optimizer.step()

            train_losses.add(losses)

            # TODO: Testing
            print("Testing...")


    def validate(self, model, loader, device): 
        model.eval()

        y_true, y_pred = [], []
        total_loss = 0
        with torch.no_grad():
            for data in tqdm(loader, desc='Iteration'):
                cell, drug, targets = data
                cell, drug, targets = cell.to(device), drug.to(device), targets.to(device)

                preds = model(cell, drug)
                total_loss += F.mse_loss(preds, targets.view(-1, 1).float(), reduction='sum')
                y_true.append(targets.view(-1, 1))
                y_pred.append(preds)

In [33]:
%load_ext autoreload
%autoreload
from v3_GCN import GraphTab_v1
from my_utils.model_helpers import train_and_test_model

torch.manual_seed(random_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

model = GraphTab_v1().to(device)
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR) # TODO: include weight_decay of lr

build_model = BuildModel(model=model,
                         criterion=loss_func,
                         optimizer=optimizer,
                         num_epochs=10,
                         train_loader=train_loader,
                         test_loader=test_loader,
                         val_loader=val_loader)

build_model.train()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
device: cpu
self.drug_nn: Sequential(
  (0): Linear(in_features=256, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=128, bias=True)
  (3): ReLU()
)
self.cell_emb: Sequential(
  (0): GCNConv(4, 256)
  (1): ReLU(inplace=True)
  (2): GCNConv(256, 256)
  (3): ReLU(inplace=True)
  (4): <function global_mean_pool at 0x12e8a1480>
  (5): Linear(in_features=256, out_features=128, bias=True)
  (6): ReLU()
  (7): Linear(in_features=128, out_features=128, bias=True)
  (8): ReLU()
)
=====Epoch 0
Training...


Iteration:   0%|          | 0/7360 [00:00<?, ?it/s]


targets.size : <built-in method size of Tensor object at 0x13de702c0>
targets      : tensor([ 4.1808,  1.0550,  6.2007,  5.4340,  4.9870,  2.0163,  0.2725,  3.8410,
        -0.7047,  3.2447], dtype=torch.float64)
tensor([[0., 1., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 1.],
        [0., 1., 0.,  ..., 0., 0., 1.],
        [0., 0., 1.,  ..., 0., 0., 1.]])
drug_emb.shape: torch.Size([10, 128])
drug_emb: tensor([[0.0149, 0.0569, 0.0000,  ..., 0.1290, 0.1048, 0.0743],
        [0.0020, 0.0000, 0.0000,  ..., 0.0303, 0.0006, 0.0318],
        [0.0488, 0.0000, 0.0000,  ..., 0.0342, 0.0000, 0.1340],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0132],
        [0.0000, 0.1214, 0.0000,  ..., 0.1818, 0.0483, 0.0621],
        [0.0580, 0.0000, 0.0000,  ..., 0.1392, 0.0000, 0.0631]],
       grad_fn=<ReluBackward0>)


RuntimeError: index 8580 is out of bounds for dimension 0 with size 8580

In [34]:
for step, data in enumerate(train_loader):
    cell, drug, ic50 = data
    print(f'Step {step + 1}:')
    print(f'=======')
    print(f'Number of graphs in the current batch: {cell.num_graphs}')
    print(cell)
    print(ic50)
    print()

    if step == 1:
        break

Step 1:
Number of graphs in the current batch: 10
DataBatch(x=[8580, 4], edge_index=[2, 831260], batch=[8580], ptr=[11])
tensor([ 4.5438,  4.4568, -0.9637, -2.6823, -3.9976,  3.0180, -2.8586,  4.7727,
        -2.9860, -1.4190], dtype=torch.float64)

Step 2:
Number of graphs in the current batch: 10
DataBatch(x=[8580, 4], edge_index=[2, 831260], batch=[8580], ptr=[11])
tensor([ 3.1815,  2.9466,  0.6865,  4.0391,  0.9185,  4.0514,  2.4978,  6.8512,
        -3.0862,  4.6164], dtype=torch.float64)



In [304]:
# 858000
cell.x[858000]

IndexError: index 858000 is out of bounds for dimension 0 with size 858000

In [305]:
cell.x[858000-1]

tensor([ 4.2288, -1.0000,  2.0000,  0.0000], dtype=torch.float64)