# Tahoe Inference - Chemical Perturbation Prediction
This notebook runs State model inference to predict chemical perturbation effects on single cells. We'll test the reproducibility of Tahoe results using our prepared dataset.
## Input h5ad

In [31]:
import h5py
import os
os.environ['MPLBACKEND'] = 'Agg' # solve the docker jupyter backend issue

In [32]:
# Check the structure of our gene perturbation sample H5 files
with h5py.File('virtual_cell/predicted_only_10.h5ad', 'r') as f:
    print(type(f))
    print("Keys in root:", list(f.keys()))
    if 'obsm' in f:
        print("Keys in obsm:", list(f['obsm'].keys()))
        print("Keys in X:", list(f['X'].keys()))
    else:
        print("No 'obsm' group found")

<class 'h5py._hl.files.File'>
Keys in root: ['X', 'layers', 'obs', 'obsm', 'obsp', 'uns', 'var', 'varm', 'varp']
Keys in obsm: []


AttributeError: 'Dataset' object has no attribute 'keys'
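
The `AttributeError` above arises because `f['X']` is an `h5py.Dataset` in this file (a dense matrix). `.keys()` exists only on groups, which is how a sparse `X` is stored (`data`, `indices`, `indptr`). Below is a minimal sketch of an inspection helper that handles both layouts. `describe_node` and `inspect_h5ad` are hypothetical names, and the stand-in objects in the demo only mimic the two h5py node types so the snippet runs without a file; the shape and dtype shown are illustrative, not read from our data.

```python
from types import SimpleNamespace


def describe_node(node):
    """Summarise an HDF5 node without assuming it is a group."""
    if hasattr(node, "keys"):
        # Group-like: a sparse X is stored this way (data / indices / indptr).
        return {"kind": "group", "keys": sorted(node.keys())}
    # Dataset-like: a dense X has a shape and dtype but no keys().
    return {"kind": "dataset", "shape": tuple(node.shape), "dtype": str(node.dtype)}


def inspect_h5ad(path, names=("X", "obsm")):
    """Describe selected top-level entries of an .h5ad file."""
    import h5py  # imported lazily so the demo below runs without h5py

    with h5py.File(path, "r") as f:
        return {name: describe_node(f[name]) for name in names if name in f}


# Offline demo: a dict stands in for a group, a namespace for a dataset.
sparse_like = {"data": None, "indices": None, "indptr": None}
dense_like = SimpleNamespace(shape=(10, 2000), dtype="float32")  # illustrative values

sparse_summary = describe_node(sparse_like)
dense_summary = describe_node(dense_like)
print(sparse_summary)
print(dense_summary)
```

On a real file, `inspect_h5ad('virtual_cell/predicted_only_10.h5ad')` would report `X` as a dataset instead of raising.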

In [8]:

# Check the structure of one of the Colab H5 files
with h5py.File('/workspace/training_dataset/competition_support_set/competition_train.h5', 'r') as f:
    print(type(f))
    print("Keys in root:", list(f.keys()))
    if 'obsm' in f:
        print("Keys in obsm:", list(f['obsm'].keys()))
        print("Keys in X:", list(f['X'].keys()))
    else:
        print("No 'obsm' group found")

<class 'h5py._hl.files.File'>
Keys in root: ['X', 'layers', 'obs', 'obsm', 'obsp', 'uns', 'var', 'varm', 'varp']
Keys in obsm: []
Keys in X: ['data', 'indices', 'indptr']


In [5]:
# what parameter to use for pert_col in ST-Tahoe
! cat virtual_cell/ST-Tahoe/config.yaml | grep -i pert

  name: PerturbationDataModule
    pert_rep: onehot
    pert_col: drugname_drugconc
    control_pert: DMSO_TF
    perturbation_features_file: null
  name: PertSets
    freeze_pert: false


In [7]:
# Our downsampled input is a gene perturbation dataset
import anndata as ad
adata = ad.read_h5ad("virtual_cell/predicted_only_10.h5ad")
print("Available columns in adata.obs:")
print(adata.obs.columns.tolist())

Available columns in adata.obs:
['target_gene']
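
The config above expects `pert_col: drugname_drugconc` with `control_pert: DMSO_TF`, while this file only has `target_gene` in `adata.obs`. A small pre-flight check along these lines can surface the mismatch before launching inference. This is a sketch: `check_pert_inputs` is a hypothetical helper, it takes plain Python containers instead of an AnnData object, and the example values for `target_gene` are placeholders. Whether the control label must be present at inference time is not confirmed here, so treat the second check as a heuristic.

```python
def check_pert_inputs(obs, pert_col, control_pert):
    """Return a list of problems found; an empty list means the check passed.

    obs maps column name -> list of per-cell values, e.g. built with
    {c: adata.obs[c].tolist() for c in adata.obs.columns}.
    """
    problems = []
    if pert_col not in obs:
        problems.append(
            f"column '{pert_col}' not found; available: {sorted(obs)}"
        )
        return problems
    if control_pert not in set(obs[pert_col]):
        problems.append(
            f"control label '{control_pert}' never appears in '{pert_col}'"
        )
    return problems


# Values from ST-Tahoe/config.yaml as printed above.
expected_col, expected_control = "drugname_drugconc", "DMSO_TF"

# Placeholder obs table mirroring the single column seen in our file.
obs = {"target_gene": ["GENE_A", "GENE_B", "non-targeting"]}

issues = check_pert_inputs(obs, expected_col, expected_control)
for issue in issues:
    print("WARNING:", issue)
```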


## Hugging Face Tahoe-100M Dataset
This dataset contains over 100 million transcriptomic profiles from 50 cancer cell lines exposed to 1,100 small-molecule perturbations (see also arcinstitute/ST-Parse, 429G).
Because of its size, we test the model on a downsampled subset.

In [9]:
! pip install datasets

Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-21.0.0-cp311-cp311-manylinux_2_28_aarch64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting tqdm>=4.66.3 (from datasets)
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting huggingface-hub>=0.24.0 (from datasets)
  Downloading huggingface_hub-0.34.4-py3-none-any.whl.metadata (14 kB)
Collecting hf-xet<2.0.0,>=1.1.3 (from huggingface-hub>=0.24.0->datasets)
  Downloading hf_xet-1.1.9-cp37-abi3-manylinux_2_28_aarch64.whl.metadata (4.7 kB)
Downloading datasets-4.0.0-py3-none-any.whl (49

### Slow download: Hugging Face fetches the underlying Parquet files (3,388 in total) even when only 10 samples are requested
The Tahoe-100M dataset is split across 3,388 Parquet files. With `streaming=False`, `load_dataset()` first resolves all 3,388 files and then starts downloading shards one after another, so the `train[:10]` slice does not limit how much data is fetched.

In [11]:
import time
from datasets import load_dataset

print("Starting download test...")
start_time = time.time()

try:
    # Test with even smaller sample first
    dataset = load_dataset("tahoebio/Tahoe-100M", split="train[:10]", streaming=False)
    
    elapsed = time.time() - start_time
    print(f"✅ Downloaded 10 samples in {elapsed:.1f} seconds")
    print(f"Dataset shape: {len(dataset)}")
    
    # If that works, try 100
    if elapsed < 30:  # If 10 samples took less than 30 seconds
        dataset = load_dataset("tahoebio/Tahoe-100M", split="train[:100]", streaming=False)
        total_elapsed = time.time() - start_time
        print(f"✅ Downloaded 100 samples in {total_elapsed:.1f} seconds total")
        
except Exception as e:
    print(f"❌ Error: {e}")
    elapsed = time.time() - start_time
    print(f"Failed after {elapsed:.1f} seconds")

Starting download test...


Resolving data files:   0%|          | 0/3388 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/3388 [00:00<?, ?files/s]

train-00321-of-03388.parquet:   0%|          | 0.00/100M [00:00<?, ?B/s]

[... progress lines for train-00322 through train-00403 omitted ...]

train-00404-of-03388.parquet:   0%|          | 0.00/77.1M [00:00<?, ?B/s]

KeyboardInterrupt: 
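
One way to avoid materialising every shard is streaming mode, where `load_dataset(..., streaming=True)` returns an iterable that fetches records lazily. The sketch below assumes network access is available when `stream_tahoe_head` is called and that taking the first records of the `train` split is an acceptable sample. The helper names are hypothetical, and the generator stands in for the remote stream so the snippet runs offline.

```python
from itertools import islice


def first_n_records(records, n):
    """Return the first n items of any iterable without consuming the rest."""
    return list(islice(records, n))


def stream_tahoe_head(n=10, repo_id="tahoebio/Tahoe-100M"):
    """Fetch the first n records lazily instead of downloading every shard."""
    from datasets import load_dataset  # lazy import; needs network access

    stream = load_dataset(repo_id, split="train", streaming=True)
    return first_n_records(stream, n)


def fake_stream():
    """Endless generator standing in for the remote stream (offline demo)."""
    i = 0
    while True:
        yield {"row": i}
        i += 1


head = first_n_records(fake_stream(), 10)
print(len(head), head[0], head[-1])
```

Because `islice` stops after `n` items, the endless generator is never exhausted; the same property is what keeps a streamed dataset from being read in full.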

## Model download

In [5]:
! pip install huggingface_hub # install huggingface_hub

Collecting huggingface_hub
  Downloading huggingface_hub-0.34.4-py3-none-any.whl.metadata (14 kB)
Collecting hf-xet<2.0.0,>=1.1.3 (from huggingface_hub)
  Downloading hf_xet-1.1.9-cp37-abi3-manylinux_2_28_aarch64.whl.metadata (4.7 kB)
Downloading huggingface_hub-0.34.4-py3-none-any.whl (561 kB)
   561.5/561.5 kB 14.3 MB/s eta 0:00:00
Downloading hf_xet-1.1.9-cp37-abi3-manylinux_2_28_aarch64.whl (3.1 MB)
   3.1/3.1 MB 24.8 MB/s eta 0:00:00
Installing collected packages: hf-xet, huggingface_hub
Successfully installed hf-xet-1.1.9 huggingface_hub-0.34.4

In [6]:
import os
from huggingface_hub import snapshot_download

# Define the repository and local directory
repo_id = "arcinstitute/ST-Tahoe"
local_dir = "virtual_cell/ST-Tahoe"

# Download all files from the repository
print(f"Downloading all files from {repo_id}...")
local_path = snapshot_download(
    repo_id=repo_id,
    local_dir=local_dir,
    local_dir_use_symlinks=False  # Deprecated and ignored in recent huggingface_hub releases; files are written directly into local_dir
)

Downloading all files from arcinstitute/ST-Tahoe...


For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.


Fetching 14 files:   0%|          | 0/14 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

cell_type_onehot_map.pkl:   0%|          | 0.00/518k [00:00<?, ?B/s]

MODEL_ACCEPTABLE_USE_POLICY.md: 0.00B [00:00, ?B/s]

batch_onehot_map.pkl:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

MODEL_LICENSE.md: 0.00B [00:00, ?B/s]

config.yaml: 0.00B [00:00, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

LICENSE.md: 0.00B [00:00, ?B/s]

data_module.torch:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

pert_onehot_map.pt:   0%|          | 0.00/5.50M [00:00<?, ?B/s]

final.ckpt:   0%|          | 0.00/3.01G [00:00<?, ?B/s]

final_from_preprint.ckpt:   0%|          | 0.00/3.07G [00:00<?, ?B/s]

var_dims.pkl:   0%|          | 0.00/206k [00:00<?, ?B/s]

wandb_path.txt:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

In [7]:
! ls virtual_cell/ST-Tahoe/

LICENSE.md			cell_type_onehot_map.pkl  pert_onehot_map.pt
MODEL_ACCEPTABLE_USE_POLICY.md	config.yaml		  var_dims.pkl
MODEL_LICENSE.md		data_module.torch	  wandb_path.txt
README.md			final.ckpt
batch_onehot_map.pkl		final_from_preprint.ckpt


## Create a Test Dataset

In [18]:
import anndata as ad
import pandas as pd
import numpy as np

# Create a small synthetic dataset matching Tahoe format
n_cells = 1000
n_genes = 2000

# Generate expression data (seeded so the synthetic counts are reproducible)
rng = np.random.default_rng(0)
X = rng.negative_binomial(5, 0.3, size=(n_cells, n_genes))

# Create metadata with drug perturbations
drugs = ["8-Hydroxyquinoline", "Doxorubicin", "Paclitaxel", "DMSO"] * 250
obs = pd.DataFrame({
    'drug': drugs,
    'cell_line_id': np.random.choice(['A549', 'HeLa', 'MCF7'], n_cells),
    'sample': [f'smp_{i}' for i in range(n_cells)]
})

# Create gene names
var = pd.DataFrame({
    'gene_name': [f'gene_{i}' for i in range(n_genes)]
})

# Create AnnData object
adata = ad.AnnData(X=X, obs=obs, var=var)
adata.write_h5ad('test_tahoe_data.h5ad')



In [20]:
! ls -lh | grep test

-rw-r--r--  1 root root  16M Sep  9 20:31 test_tahoe_data.h5ad


## Inference on the synthetic Tahoe sample

In [29]:
! state tx infer \
  --output "virtual_cell/prediction_test_tahoe_data_250909.h5ad" \
  --model_dir virtual_cell/ST-Tahoe \
  --checkpoint virtual_cell/ST-Tahoe/final_from_preprint.ckpt \
  --adata "test_tahoe_data.h5ad" \
  --pert_col "drug"

INFO:state._cli._tx._infer:Loaded config from virtual_cell/ST-Tahoe/config.yaml
INFO:state._cli._tx._infer:Loading model from checkpoint: virtual_cell/ST-Tahoe/final_from_preprint.ckpt
PertSetsPerturbationModel(
  (loss_fn): SamplesLoss()
  (pert_encoder): Sequential(
    (0): Linear(in_features=1138, out_features=1488, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=1488, out_features=1488, bias=True)
    (4): GELU(approximate='none')
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=1488, out_features=1488, bias=True)
    (7): GELU(approximate='none')
    (8): Dropout(p=0.1, inplace=False)
    (9): Linear(in_features=1488, out_features=1488, bias=True)
  )
  (basal_encoder): Sequential(
    (0): Linear(in_features=2000, out_features=1488, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=1488, out_features=1488, bias=True)
    (4): GELU(approx

In [30]:
! ls -lh virtual_cell

total 93M
-rw-r--r--  1 root root 2.2K Jul 17 16:01 -r
drwxr-xr-x 15 root root  480 Jul 18 16:48 SE-600M
drwxr-xr-x 17 root root  544 Sep  2 18:18 ST-Tahoe
drwxr-xr-x 16 root root  512 Jul 18 16:42 ST_Parse
-rw-r--r--  1 root root  70M Jul 17 14:50 predicted_only_10.h5ad
-rw-r--r--  1 root root  24M Sep  9 20:37 prediction_test_tahoe_data_250909.h5ad


## Lessons learned when running inference on our sample
### 1. Unexpected key(s) in state_dict: "basal_encoder.weight", "basal_encoder.bias"

This error indicates a mismatch between the model architecture saved in the checkpoint and the model architecture being loaded.
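
To see where the two architectures diverge, it can help to compare parameter names directly. The sketch below is an assumption-laden illustration: `diff_state_dict_keys` and `checkpoint_keys_from_file` are hypothetical helpers, the checkpoint is assumed to follow the PyTorch Lightning layout with a top-level `state_dict` entry, and the model-side key names in the demo are guesses at what a `Sequential` encoder would expose, not values read from the State code.

```python
def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Split key differences the way load_state_dict reports them."""
    checkpoint_keys, model_keys = set(checkpoint_keys), set(model_keys)
    return {
        "unexpected": sorted(checkpoint_keys - model_keys),  # in checkpoint only
        "missing": sorted(model_keys - checkpoint_keys),     # in model only
    }


def checkpoint_keys_from_file(path):
    """Read parameter names from a Lightning-style checkpoint (assumed layout)."""
    import torch  # imported lazily so the demo below runs without torch

    # Newer torch versions may need weights_only=False for checkpoints that
    # pickle extra objects; only do that for files you trust.
    ckpt = torch.load(path, map_location="cpu")
    return list(ckpt.get("state_dict", ckpt).keys())


# Illustrative key lists: the checkpoint side uses the names from the error
# message; the model side is a guess at what a Sequential encoder exposes.
ckpt_keys = ["basal_encoder.weight", "basal_encoder.bias", "pert_encoder.0.weight"]
model_keys = ["basal_encoder.0.weight", "basal_encoder.0.bias", "pert_encoder.0.weight"]

report = diff_state_dict_keys(ckpt_keys, model_keys)
print(report)
```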

In [8]:
! state tx infer \
  --output "virtual_cell/prediction_ST-Tahoe_250902/prediction_only_10.h5ad" \
  --model_dir virtual_cell/ST-Tahoe \
  --checkpoint virtual_cell/ST-Tahoe/final.ckpt \
  --adata "virtual_cell/predicted_only_10.h5ad" \
  --pert_col "chemical"

INFO:state._cli._tx._infer:Loaded config from virtual_cell/ST-Tahoe/config.yaml
INFO:state._cli._tx._infer:Loading model from checkpoint: virtual_cell/ST-Tahoe/final.ckpt
PertSetsPerturbationModel(
  (loss_fn): SamplesLoss()
  (gene_decoder): LatentToGeneDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=2000, out_features=1024, bias=True)
      (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (2): GELU(approximate='none')
      (3): Dropout(p=0.1, inplace=False)
      (4): Linear(in_features=1024, out_features=1024, bias=True)
      (5): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (6): GELU(approximate='none')
      (7): Dropout(p=0.1, inplace=False)
      (8): Linear(in_features=1024, out_features=512, bias=True)
      (9): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (10): GELU(approximate='none')
      (11): Dropout(p=0.1, inplace=False)
      (12): Linear(in_features=512, out_features=2000, bias=True)
      (13): ReLU

In [33]:
# Embed the synthetic Tahoe data first. The process was killed when memory use reached 14G+ of the 16G total.
! state emb transform \
  --model-folder virtual_cell/SE-600M \
  --input virtual_cell/test_tahoe_data.h5ad \
  --output virtual_cell/test_tahoe_data_emb.h5ad

INFO:state._cli._emb._transform:Using model checkpoint: virtual_cell/SE-600M/se600m_epoch4.ckpt
INFO:state._cli._emb._transform:Creating inference object
INFO:state._cli._emb._transform:Loading model from checkpoint: virtual_cell/SE-600M/se600m_epoch4.ckpt
Killed
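
`Killed` with no Python traceback usually means the operating system (or the container's memory limit) terminated the process, which fits the 14G+ of 16G noted above. A pre-flight check of available memory can make that failure mode visible earlier. The sketch below parses `/proc/meminfo`-style text, so it is Linux-specific. `parse_meminfo` and `enough_memory` are hypothetical helpers, the sample text is made up, and the 14 GiB threshold comes only from the observation above, not from a documented requirement of SE-600M. Inside a container, `/proc/meminfo` may report host memory instead of the container limit.

```python
def parse_meminfo(text):
    """Parse /proc/meminfo-style text into {field: kibibytes}."""
    info = {}
    for line in text.splitlines():
        name, _, rest = line.partition(":")
        parts = rest.split()
        if parts and parts[0].isdigit():
            info[name.strip()] = int(parts[0])
    return info


def enough_memory(meminfo_text, required_gib):
    """True if MemAvailable is at least required_gib."""
    available_kib = parse_meminfo(meminfo_text).get("MemAvailable", 0)
    return available_kib >= required_gib * 1024 * 1024


# Made-up sample; on Linux use open("/proc/meminfo").read() instead.
sample = """MemTotal:       16384000 kB
MemFree:         1200000 kB
MemAvailable:    2097152 kB
"""

print(parse_meminfo(sample)["MemAvailable"])
print(enough_memory(sample, required_gib=14))
```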


### 2. shape '[1, -1, 2000]' is invalid for input of size 9256960
ST-Parse logs for the genetic and chemical perturbation dataset.
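
The reshape fails because 9256960 is not a multiple of 2000 (9256960 = 4628 × 2000 + 960), so the input's feature dimension cannot be 2000. The sketch below checks this arithmetic and tests one hypothesis: that the input carries 18,080 features per cell, in which case the tensor would hold 512 rows (512 × 18080 = 9256960). The value 18080 is an assumption that happens to divide the total evenly; the log does not confirm it, so verify it against `adata.n_vars`. The helper names are hypothetical.

```python
def reshape_check(total_elements, last_dim):
    """Report whether total_elements can be viewed as [1, -1, last_dim]."""
    rows, remainder = divmod(total_elements, last_dim)
    return {"fits": remainder == 0, "rows": rows, "remainder": remainder}


def candidate_last_dims(total_elements, candidates):
    """Return the candidate feature sizes that divide total_elements evenly."""
    return [d for d in candidates if total_elements % d == 0]


total = 9256960  # from the error message
expected = 2000  # last dimension the model tried to use

result = reshape_check(total, expected)
print(result)

# 18080 is an assumed gene count, included only to test the hypothesis.
plausible = candidate_last_dims(total, [2000, 18080])
print(plausible, total // 18080)
```

If the hypothesis holds, the input would need to be restricted to the 2,000 features the model was trained on before inference.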

In [39]:
! state tx infer \
  --output "virtual_cell/prediction_250617/prediction_only_10.h5ad" \
  --model_dir virtual_cell/ST_Parse \
  --checkpoint virtual_cell/ST_Parse/final.ckpt \
  --adata "virtual_cell/predicted_only_10.h5ad" \
  --pert_col "target_gene"

INFO:state._cli._tx._infer:Loaded config from virtual_cell/ST_Parse/config.yaml
INFO:state._cli._tx._infer:Loading model from checkpoint: virtual_cell/ST_Parse/final.ckpt
PertSetsPerturbationModel(
  (loss_fn): SamplesLoss()
  (gene_decoder): LatentToGeneDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=2000, out_features=1024, bias=True)
      (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (2): GELU(approximate='none')
      (3): Dropout(p=0.1, inplace=False)
      (4): Linear(in_features=1024, out_features=1024, bias=True)
      (5): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (6): GELU(approximate='none')
      (7): Dropout(p=0.1, inplace=False)
      (8): Linear(in_features=1024, out_features=512, bias=True)
      (9): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (10): GELU(approximate='none')
      (11): Dropout(p=0.1, inplace=False)
      (12): Linear(in_features=512, out_features=2000, bias=True)
      (13): ReLU