# Electrocardiogram Analysis using ECG-FM

The electrocardiogram (ECG) is a low-cost, non-invasive diagnostic test that has been ubiquitous in the assessment and management of cardiovascular disease for decades. ECG-FM is a pretrained, open foundation model for ECG analysis.

In this tutorial, we show how to perform inference for multi-label classification using a finetuned ECG-FM model. Specifically, we take a model finetuned on the [PhysioNet 2021 v1.0.3 dataset](https://physionet.org/content/challenge-2021/1.0.3/) and run inference on a sample of the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) to show how its predictions can be adapted to a new set of labels.

## Overview
0. Installation
1. Prepare checkpoints
2. Prepare data
3. Run inference
4. Interpret results
5. Extra: Load models

## 0. Installation

ECG-FM was developed in collaboration with the [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) framework, which implements a collection of deep learning methods for ECG analysis.

Clone [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) and follow the requirements and installation section in its top-level README. Then install `pandas` (plus `huggingface_hub` and `scikit-learn`, which this notebook also imports) and make the environment available as a notebook kernel by running:
```
python3 -m pip install --user pandas huggingface_hub scikit-learn
python3 -m pip install --user --upgrade jupyterlab ipywidgets ipykernel
python3 -m ipykernel install --user --name ecg_fm
```

```python
import os
import pandas as pd
import torch

from fairseq_signals.utils.store import MemmapReader
```
```python
# Path to your local clone of the ECG-FM repository (adjust to your setup)
root = '/home/aa2650/playground/ECG-FM'
root
```
```
'/home/aa2650/playground/ECG-FM'
```

```python
# Path to your local clone of fairseq-signals (adjust to your setup)
fairseq_signals_root = '/home/aa2650/playground/fairseq-signals'
fairseq_signals_root = fairseq_signals_root.rstrip('/')
fairseq_signals_root
```
```
'/home/aa2650/playground/fairseq-signals'
```

## 1. Prepare checkpoints

### Download checkpoints

The checkpoints are available on [HuggingFace](https://huggingface.co/wanglab/ecg-fm-preprint). Alternatively, they can be downloaded using the commands below.

**Disclaimer: These models are different from those reported in our arXiv paper.** These BERT-Base sized models were trained purely on public data sources due to privacy concerns surrounding UHN-ECG data and patient identification. Validation for the final models will be available upon full publication.

```python
from huggingface_hub import hf_hub_download

# Download the finetuned checkpoint and its config into notebooks/ckpts
_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='physionet_finetuned.pt',
    local_dir=os.path.join(root, 'notebooks/ckpts'),
)
_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='physionet_finetuned.yaml',
    local_dir=os.path.join(root, 'notebooks/ckpts'),
)
```


```python
assert os.path.isfile(os.path.join(root, 'notebooks/ckpts/physionet_finetuned.pt'))
assert os.path.isfile(os.path.join(root, 'notebooks/ckpts/physionet_finetuned.yaml'))
```

## 2. Prepare data

The model used here was finetuned on the [PhysioNet 2021 v1.0.3 dataset](https://physionet.org/content/challenge-2021/1.0.3/). To keep this tutorial simple, we have processed a sample of 10 ECGs (14 five-second segments) from the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) so that we can demonstrate how to adapt the predictions to a new set of labels.
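Each segment has a `sample_size` of 2500 samples, which for a 5-second segment implies a 500 Hz sampling rate. The sketch below shows one way a recording could be cut into such segments. It assumes non-overlapping windows and that leftover samples shorter than a full segment are discarded. This is an illustration only and is not taken from the fairseq-signals preprocessing code; `segment_bounds` is a hypothetical helper. Under these assumptions a 10-second recording yields two segments and a shorter one yields a single segment, which is consistent with some of the 10 ECGs contributing two segments each.

```python
SAMPLE_RATE_HZ = 500  # assumption: implied by 2500 samples per 5-second segment
SEGMENT_SECONDS = 5
SEGMENT_LEN = SAMPLE_RATE_HZ * SEGMENT_SECONDS  # 2500 samples


def segment_bounds(n_samples: int, segment_len: int = SEGMENT_LEN) -> list:
    """Return (start, end) sample indices of non-overlapping, full-length segments.

    Samples left over after the last full segment are dropped (an assumed policy,
    not necessarily what fairseq-signals does).
    """
    n_segments = n_samples // segment_len
    return [(i * segment_len, (i + 1) * segment_len) for i in range(n_segments)]


# Hypothetical recording lengths, in samples at 500 Hz
print(segment_bounds(5000))  # 10 s -> [(0, 2500), (2500, 5000)]
print(segment_bounds(3650))  # 7.3 s -> [(0, 2500)]
```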

To perform inference on a full dataset (or on your own data), refer to the flexible, end-to-end, multi-source data preprocessing pipeline described [here](https://github.com/Jwoo5/fairseq-signals/tree/master/scripts/preprocess/ecg). Its README explains how the data is organized, and preprocessing scripts are implemented for several datasets.

### Update manifest

The segmented split must be saved with absolute file paths, so we will update the current relative file paths accordingly.

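```python
# Load the split; its 'path' column holds paths relative to the segment directory
segmented_split = pd.read_csv(
    os.path.join(root, 'data/code_15/segmented_split_incomplete.csv'),
    index_col='idx',
)

# Save a copy with absolute paths (`segmented_split` keeps the relative paths for display later)
segmented_split_abs = segmented_split.copy()
segmented_split_abs['path'] = (root + '/data/code_15/segmented/') + segmented_split_abs['path']
segmented_split_abs.to_csv(os.path.join(root, 'data/code_15/segmented_split.csv'))
```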

```python
assert os.path.isfile(os.path.join(root, 'data/code_15/segmented_split.csv'))
```

Run the following command to generate the `test.tsv` manifest used for inference.

```python
cmd = f"""
cd {fairseq_signals_root}/scripts/preprocess && \
python manifests.py \
    --split_file_paths "{root}/data/code_15/segmented_split.csv" \
    --save_dir "{root}/data/manifests/code_15_subset10/"
"""
# os.system returns the command's exit status; 0 means success
os.system(cmd)
```
```
0
```

```python
assert os.path.isfile(os.path.join(root, 'data/manifests/code_15_subset10/test.tsv'))
```

## 3. Run inference

Inside our environment, we can run the following command through Hydra's command-line interface to extract the logits for each segment. This step requires an available GPU.
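As an alternative to `os.system`, the same call can be handed to Python's `subprocess.run` as an argument list. Each Hydra override is then a single `key=value` token, so no shell quoting or line continuations are needed, and `check=True` raises an error on a non-zero exit status instead of silently returning it. This is a sketch: `build_inference_args` and `run_inference` are hypothetical helpers, and the overrides simply mirror the command in the next cell.

```python
import subprocess


def build_inference_args(root: str, num_labels: int = 26, batch_size: int = 10) -> list:
    """Assemble the fairseq-hydra-inference call used in this tutorial as an argv list."""
    return [
        "fairseq-hydra-inference",
        f"task.data={root}/data/manifests/code_15_subset10/",
        f"common_eval.path={root}/notebooks/ckpts/physionet_finetuned.pt",
        f"common_eval.results_path={root}/outputs",
        f"model.num_labels={num_labels}",
        "dataset.valid_subset=test",
        f"dataset.batch_size={batch_size}",
        "dataset.num_workers=3",
        "dataset.disable_validation=false",
        "distributed_training.distributed_world_size=1",
        "distributed_training.find_unused_parameters=True",
        "--config-dir", f"{root}/notebooks/ckpts",
        "--config-name", "physionet_finetuned",
    ]


def run_inference(root: str) -> None:
    # check=True raises subprocess.CalledProcessError if the command fails
    subprocess.run(build_inference_args(root), check=True)
```

Calling `run_inference(root)` requires the `ecg_fm` environment and a GPU, just like the cell below.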

```python
inference_cmd = f"""fairseq-hydra-inference \\
    task.data="{root}/data/manifests/code_15_subset10/" \\
    common_eval.path="{root}/notebooks/ckpts/physionet_finetuned.pt" \\
    common_eval.results_path="{root}/outputs" \\
    model.num_labels=26 \\
    dataset.valid_subset="test" \\
    dataset.batch_size=10 \\
    dataset.num_workers=3 \\
    dataset.disable_validation=false \\
    distributed_training.distributed_world_size=1 \\
    distributed_training.find_unused_parameters=True \\
    --config-dir "{root}/notebooks/ckpts" \\
    --config-name physionet_finetuned
"""

os.system(inference_cmd)
```

```python
assert os.path.isfile(os.path.join(root, 'outputs/outputs_test.npy'))
assert os.path.isfile(os.path.join(root, 'outputs/outputs_test_header.pkl'))
```

## 4. Interpret results

The rows of the logits array follow the order of the samples in the manifest, and its columns follow the order of the labels in the label definition file.
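To make that ordering concrete, here is a pure-Python sketch of what the pandas/`torch.sigmoid` code in this section does: row `i` is paired with the `i`-th manifest sample, column `j` with the `j`-th label, and each logit is passed through its own sigmoid. Independent sigmoids (rather than a softmax) suit a multi-label task, so the probabilities in a row need not sum to 1. The helpers `sigmoid` and `logits_to_probs` and the toy logits are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logits_to_probs(logits, sample_paths, label_names):
    """Map a (num_samples x num_labels) logit matrix to {sample: {label: probability}}."""
    if len(logits) != len(sample_paths):
        raise ValueError("expected one row of logits per manifest sample")
    probs = {}
    for path, row in zip(sample_paths, logits):
        if len(row) != len(label_names):
            raise ValueError("expected one logit per label")
        # One independent sigmoid per label (multi-label, not softmax)
        probs[path] = {label: sigmoid(v) for label, v in zip(label_names, row)}
    return probs


# Toy logits for two segments and two labels (values are made up)
probs = logits_to_probs(
    [[0.0, 2.0], [-2.0, 0.0]],
    ["code_15_438277_0.mat", "code_15_358121_0.mat"],
    ["AF", "SB"],
)
print(probs["code_15_438277_0.mat"]["AF"])  # 0.5
```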

### Get predictions on PhysioNet 2021 labels

```python
physionet2021_label_def = pd.read_csv(
    os.path.join(root, 'data/physionet2021/labels/label_def.csv'),
    index_col='name',
)
physionet2021_label_names = physionet2021_label_def.index
physionet2021_label_def
```

| name | pos_count_all | pos_percent_all |
|---|---|---|
| AF | 5230 | 0.060793 |
| AFL | 8271 | 0.096142 |
| BBB | 490 | 0.005696 |
| Brady | 283 | 0.00329 |
| CLBBB\|LBBB | 1487 | 0.017285 |
| CRBBB\|RBBB | 4794 | 0.055725 |
| IAVB | 3516 | 0.04087 |
| IRBBB | 1854 | 0.021551 |
| LAD | 7614 | 0.088505 |
| LAnFB | 2179 | 0.025329 |
| ... | ... | ... |


```python
# Load the array of computed logits
logits = MemmapReader.from_header(
    os.path.join(root, 'outputs/outputs_test.npy')
)[:]
logits.shape
```
```
(14, 26)
```

```python
# Construct predictions from logits
pred = pd.DataFrame(
    torch.sigmoid(torch.tensor(logits)).numpy(),
    columns=physionet2021_label_names,
)

# Join in sample information
pred = segmented_split.reset_index().join(pred, how='left').set_index('idx')
pred
```

| idx | save_file | split | path | sample_size | AF | AFL | BBB | Brady | CLBBB\|LBBB | CRBBB\|RBBB | ... | PR | PRWP | PVC\|VPB | QAb | RAD | SA | SB | STach | TAb | TInv |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | code_15_438277.mat | test | code_15_438277_0.mat | 2500 | 0.00115 | 0.001436 | 1.465139e-15 | 0.3206733 | 1.055977e-06 | 0.004426 | ... | 1.05666e-07 | 1.591192e-11 | 0.08102255 | 8.600521e-08 | 4.068469e-15 | 0.002309897 | 0.0004729232 | 4.5e-05 | 0.006356 | 1.8e-05 |
| 1 | code_15_358121.mat | test | code_15_358121_0.mat | 2500 | 0.002214 | 3.9e-05 | 8.9982e-09 | 1.075806e-09 | 4.526197e-06 | 0.999969 | ... | 1.393881e-08 | 9.013309e-14 | 6.347319e-07 | 1.916207e-06 | 1.114556e-07 | 0.004842394 | 1.612027e-07 | 0.000795 | 0.066468 | 0.006737 |
| 2 | code_15_1594286.mat | test | code_15_1594286_0.mat | 2500 | 0.000137 | 8e-06 | 1.446803e-11 | 3.488379e-07 | 0.08700616 | 0.997258 | ... | 1.374848e-07 | 2.920221e-11 | 3.470497e-10 | 0.004947504 | 3.023014e-06 | 0.0005003812 | 0.01905788 | 0.000454 | 0.000527 | 0.000309 |
| 2 | code_15_1594286.mat | test | code_15_1594286_1.mat | 2500 | 7.6e-05 | 0.01058 | 3.114224e-09 | 6.418872e-05 | 0.003409652 | 0.999826 | ... | 7.799306e-05 | 6.360084e-09 | 7.854737e-08 | 1.693255e-05 | 6.996347e-09 | 9.390274e-06 | 0.0001750806 | 0.000296 | 0.000292 | 0.000458 |
| 3 | code_15_975093.mat | test | code_15_975093_0.mat | 2500 | 8.1e-05 | 0.005579 | 7.320367e-10 | 7.329348e-05 | 1.005381e-08 | 0.000283 | ... | 3.378041e-06 | 5.209452e-11 | 0.0009393553 | 4.880567e-05 | 1.006852e-06 | 0.01523182 | 0.9244674 | 0.000849 | 0.000944 | 0.001305 |
| 3 | code_15_975093.mat | test | code_15_975093_1.mat | 2500 | 2.5e-05 | 0.000403 | 5.982543e-10 | 0.002227188 | 1.790568e-08 | 0.001132 | ... | 3.719527e-08 | 8.07316e-11 | 0.0001271174 | 1.160941e-05 | 9.034796e-09 | 6.188131e-05 | 0.0670777 | 0.00118 | 0.000137 | 0.000324 |
| 4 | code_15_795649.mat | test | code_15_795649_0.mat | 2500 | 0.00298 | 5.6e-05 | 1.536473e-11 | 2.363584e-06 | 2.680685e-08 | 0.2555 | ... | 6.478713e-09 | 2.62673e-12 | 1.008909e-05 | 4.761355e-07 | 1.761901e-11 | 1.501665e-06 | 8.729744e-07 | 0.480006 | 0.00746 | 0.005422 |
| 5 | code_15_1238362.mat | test | code_15_1238362_0.mat | 2500 | 0.00559 | 2.5e-05 | 2.469575e-09 | 0.0644764 | 0.3009445 | 0.136349 | ... | 6.150279e-07 | 2.775399e-10 | 0.4705916 | 1.189895e-07 | 1.026724e-11 | 0.0005486983 | 0.8215199 | 8e-06 | 0.087187 | 0.004861 |
| 6 | code_15_2969044.mat | test | code_15_2969044_0.mat | 2500 | 0.005372 | 0.000162 | 1.247481e-08 | 1.449286e-05 | 1.508343e-07 | 7.9e-05 | ... | 5.95428e-09 | 5.4438340000000006e-17 | 0.003062968 | 1.733397e-06 | 6.188186e-10 | 1.011417e-08 | 9.30058e-06 | 0.998813 | 0.038875 | 0.000885 |
| 6 | code_15_2969044.mat | test | code_15_2969044_1.mat | 2500 | 0.224791 | 0.000241 | 2.770369e-10 | 1.00198e-06 | 2.593149e-08 | 0.000822 | ... | 9.283297e-10 | 9.714495e-18 | 0.0002026472 | 1.320838e-07 | 8.966632e-11 | 1.494184e-07 | 2.114854e-06 | 0.996955 | 0.071183 | 0.000123 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |

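Predictions are made per 5-second segment, so an ECG with two segments appears twice in the table above (for example, idx 3 and idx 6). If one prediction per ECG is needed, the segment probabilities can be aggregated. The sketch below offers mean or max aggregation; these are two simple, common choices, not a rule taken from ECG-FM, and `aggregate_by_recording` is a hypothetical helper. The input values are copied from the table above.

```python
from collections import defaultdict
from statistics import mean


def aggregate_by_recording(segment_probs, how="mean"):
    """Collapse per-segment probabilities into one probability per recording and label.

    segment_probs: iterable of (recording_idx, {label: probability}) pairs.
    how: "mean" or "max".
    """
    reduce = {"mean": mean, "max": max}[how]
    grouped = defaultdict(lambda: defaultdict(list))
    for idx, probs in segment_probs:
        for label, p in probs.items():
            grouped[idx][label].append(p)
    return {
        idx: {label: reduce(values) for label, values in labels.items()}
        for idx, labels in grouped.items()
    }


# SB and STach probabilities for the two segments of ECGs 3 and 6
segments = [
    (3, {"SB": 0.9244674, "STach": 0.000849}),
    (3, {"SB": 0.0670777, "STach": 0.00118}),
    (6, {"SB": 9.30058e-06, "STach": 0.998813}),
    (6, {"SB": 2.114854e-06, "STach": 0.996955}),
]
per_ecg = aggregate_by_recording(segments)
print(round(per_ecg[3]["SB"], 4))  # 0.4958
```

With mean aggregation, ECG 3's SB score falls just below 0.5 even though its first segment alone exceeds that threshold. Which behaviour is preferable depends on the application.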

```python
# Perform a (crude) thresholding of 0.5 for all labels
pred_thresh = pred.copy()
pred_thresh[physionet2021_label_names] = pred_thresh[physionet2021_label_names] > 0.5

# Construct a readable column of predicted labels for each sample
pred_thresh['labels'] = pred_thresh[physionet2021_label_names].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
pred_thresh['labels']
```
```
idx
0                    IAVB
1              CRBBB|RBBB
2    CRBBB|RBBB, LAD, NSR
2              CRBBB|RBBB
3                IAVB, SB
3                    IAVB
4                PAC|SVPB
5                      SB
6                   STach
6                   STach
7                    IAVB
8                   Brady
9    CRBBB|RBBB, PAC|SVPB
9    CRBBB|RBBB, PAC|SVPB
Name: labels, dtype: object
```
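A common refinement of the single 0.5 cut-off is a per-label threshold. The sketch below keeps 0.5 as the default and lets individual labels override it. The 0.4 threshold for STach is a placeholder chosen only for illustration; in practice, thresholds would be selected on held-out validation data, not on the test set. `threshold_labels` is a hypothetical helper, and the probabilities are copied from the prediction table above for ECG 4.

```python
def threshold_labels(probs, thresholds=None, default=0.5):
    """Return the labels whose probability exceeds its threshold, in input order.

    thresholds: optional {label: threshold} overrides; labels not listed use
    `default`, which reproduces the 0.5 cut-off used above.
    """
    thresholds = thresholds or {}
    return [label for label, p in probs.items() if p > thresholds.get(label, default)]


probs_idx4 = {"CRBBB|RBBB": 0.2555, "STach": 0.480006}
print(threshold_labels(probs_idx4))                  # []
print(threshold_labels(probs_idx4, {"STach": 0.4}))  # ['STach']
```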

### Map predictions to CODE-15 labels

```python
code_15_label_def = pd.read_csv(
    os.path.join(root, 'data/code_15/labels/label_def.csv'),
    index_col='name',
)
code_15_label_names = code_15_label_def.index
code_15_label_def
```

| name | pos_count_all | pos_percent_all |
|---|---|---|
| is_male | 138528 | 0.402691 |
| 1dAVb | 5699 | 0.016567 |
| RBBB | 9652 | 0.028058 |
| LBBB | 6011 | 0.017474 |
| SB | 5588 | 0.016244 |
| ST | 7571 | 0.022008 |
| AF | 7008 | 0.020372 |
| normal_ecg | 134497 | 0.390973 |


```python
# Map PhysioNet 2021 label names to their CODE-15 counterparts
label_mapping = {
    'CRBBB|RBBB': 'RBBB',
    'CLBBB|LBBB': 'LBBB',
    'SB': 'SB',
    'STach': 'ST',
    'AF': 'AF',
}

physionet2021_label_def['name_mapped'] = physionet2021_label_def.index.map(label_mapping)
physionet2021_label_def
```

| name | pos_count_all | pos_percent_all | name_mapped |
|---|---|---|---|
| AF | 5230 | 0.060793 | AF |
| AFL | 8271 | 0.096142 | |
| BBB | 490 | 0.005696 | |
| Brady | 283 | 0.00329 | |
| CLBBB\|LBBB | 1487 | 0.017285 | LBBB |
| CRBBB\|RBBB | 4794 | 0.055725 | RBBB |
| IAVB | 3516 | 0.04087 | |
| IRBBB | 1854 | 0.021551 | |
| LAD | 7614 | 0.088505 | |
| LAnFB | 2179 | 0.025329 | |
| ... | ... | ... | ... |

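The mapping step amounts to renaming the mapped labels and dropping the rest. The sketch below does this for a single sample's probabilities in pure Python. It also shows one way to handle several source labels mapping to the same target by keeping the maximum probability. That rule is an assumption: in the tutorial's mapping every CODE-15 label has exactly one PhysioNet source, and the `Brady` to `SB` mapping used in the second call is purely hypothetical, not a validated clinical correspondence. `map_probs` is a hypothetical helper; the probabilities are copied from the prediction table above.

```python
def map_probs(probs, label_mapping, combine=max):
    """Rename source-label probabilities to target labels and drop unmapped labels.

    If several source labels map to the same target, their probabilities are
    merged with `combine` (max by default, an assumption).
    """
    mapped = {}
    for source, p in probs.items():
        target = label_mapping.get(source)
        if target is None:
            continue  # no CODE-15 counterpart
        mapped[target] = combine(mapped[target], p) if target in mapped else p
    return mapped


label_mapping = {
    'CRBBB|RBBB': 'RBBB',
    'CLBBB|LBBB': 'LBBB',
    'SB': 'SB',
    'STach': 'ST',
    'AF': 'AF',
}
# Probabilities for ECG 1
probs_idx1 = {"AF": 0.002214, "AFL": 3.9e-05, "CRBBB|RBBB": 0.999969, "SB": 1.612027e-07}
print(map_probs(probs_idx1, label_mapping))
# {'AF': 0.002214, 'RBBB': 0.999969, 'SB': 1.612027e-07}

# Hypothetical many-to-one mapping, using ECG 0's SB and Brady probabilities
print(map_probs({"SB": 0.0004729232, "Brady": 0.3206733}, {"SB": "SB", "Brady": "SB"}))
# {'SB': 0.3206733}
```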

```python
pred_mapped = pred.copy()
# Drop PhysioNet 2021 labels without a CODE-15 counterpart, then rename the rest
unmapped_labels = [name for name in physionet2021_label_names if name not in label_mapping]
pred_mapped = pred_mapped.drop(columns=unmapped_labels).rename(columns=label_mapping)
pred_mapped
```

| idx | save_file | split | path | sample_size | AF | LBBB | RBBB | SB | ST |
|---|---|---|---|---|---|---|---|---|---|
| 0 | code_15_438277.mat | test | code_15_438277_0.mat | 2500 | 0.00115 | 1.055977e-06 | 0.004426 | 0.0004729232 | 4.5e-05 |
| 1 | code_15_358121.mat | test | code_15_358121_0.mat | 2500 | 0.002214 | 4.526197e-06 | 0.999969 | 1.612027e-07 | 0.000795 |
| 2 | code_15_1594286.mat | test | code_15_1594286_0.mat | 2500 | 0.000137 | 0.08700616 | 0.997258 | 0.01905788 | 0.000454 |
| 2 | code_15_1594286.mat | test | code_15_1594286_1.mat | 2500 | 7.6e-05 | 0.003409652 | 0.999826 | 0.0001750806 | 0.000296 |
| 3 | code_15_975093.mat | test | code_15_975093_0.mat | 2500 | 8.1e-05 | 1.005381e-08 | 0.000283 | 0.9244674 | 0.000849 |
| 3 | code_15_975093.mat | test | code_15_975093_1.mat | 2500 | 2.5e-05 | 1.790568e-08 | 0.001132 | 0.0670777 | 0.00118 |
| 4 | code_15_795649.mat | test | code_15_795649_0.mat | 2500 | 0.00298 | 2.680685e-08 | 0.2555 | 8.729744e-07 | 0.480006 |
| 5 | code_15_1238362.mat | test | code_15_1238362_0.mat | 2500 | 0.00559 | 0.3009445 | 0.136349 | 0.8215199 | 8e-06 |
| 6 | code_15_2969044.mat | test | code_15_2969044_0.mat | 2500 | 0.005372 | 1.508343e-07 | 7.9e-05 | 9.30058e-06 | 0.998813 |
| 6 | code_15_2969044.mat | test | code_15_2969044_1.mat | 2500 | 0.224791 | 2.593149e-08 | 0.000822 | 2.114854e-06 | 0.996955 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |


```python
pred_thresh_mapped = pred_thresh.copy()
# Keep only the mapped labels, renamed to their CODE-15 names
pred_thresh_mapped = pred_thresh_mapped.drop(
    columns=[name for name in physionet2021_label_names if name not in label_mapping]
).rename(columns=label_mapping)

# Use a list (not dict_values) to select the mapped CODE-15 columns
pred_thresh_mapped['predicted'] = pred_thresh_mapped[list(label_mapping.values())].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
pred_thresh_mapped
```

| idx | save_file | split | path | sample_size | AF | LBBB | RBBB | SB | ST | labels | predicted |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | code_15_438277.mat | test | code_15_438277_0.mat | 2500 | False | False | False | False | False | IAVB | |
| 1 | code_15_358121.mat | test | code_15_358121_0.mat | 2500 | False | False | True | False | False | CRBBB\|RBBB | RBBB |
| 2 | code_15_1594286.mat | test | code_15_1594286_0.mat | 2500 | False | False | True | False | False | CRBBB\|RBBB, LAD, NSR | RBBB |
| 2 | code_15_1594286.mat | test | code_15_1594286_1.mat | 2500 | False | False | True | False | False | CRBBB\|RBBB | RBBB |
| 3 | code_15_975093.mat | test | code_15_975093_0.mat | 2500 | False | False | False | True | False | IAVB, SB | SB |
| 3 | code_15_975093.mat | test | code_15_975093_1.mat | 2500 | False | False | False | False | False | IAVB | |
| 4 | code_15_795649.mat | test | code_15_795649_0.mat | 2500 | False | False | False | False | False | PAC\|SVPB | |
| 5 | code_15_1238362.mat | test | code_15_1238362_0.mat | 2500 | False | False | False | True | False | SB | SB |
| 6 | code_15_2969044.mat | test | code_15_2969044_0.mat | 2500 | False | False | False | False | True | STach | ST |
| 6 | code_15_2969044.mat | test | code_15_2969044_1.mat | 2500 | False | False | False | False | True | STach | ST |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |


### Compare predicted CODE-15 to actual

```python
code_15_labels = pd.read_csv(os.path.join(root, 'data/code_15/labels/labels.csv'), index_col='idx')
# Build the actual label string from the same CODE-15 columns, in the same order as 'predicted'
code_15_labels['actual'] = code_15_labels[list(label_mapping.values())].apply(
    lambda row: ', '.join(row.index[row]),
    axis=1,
)
code_15_labels
```

| idx | is_male | 1dAVb | RBBB | LBBB | SB | ST | AF | normal_ecg | actual |
|---|---|---|---|---|---|---|---|---|---|
| 0 | True | True | False | False | False | False | False | False | |
| 1 | True | False | True | False | False | False | False | False | RBBB |
| 2 | False | False | True | False | False | False | False | False | RBBB |
| 3 | True | True | False | False | False | False | False | False | |
| 4 | True | False | False | False | False | True | False | False | ST |
| 5 | True | False | False | False | True | False | False | False | SB |
| 6 | False | False | False | False | False | True | False | False | ST |
| 7 | False | True | False | False | False | False | False | False | |
| 8 | True | False | False | False | True | False | False | False | SB |
| 9 | True | False | True | False | False | False | False | False | RBBB |


```python
# Put predicted and actual labels side by side (one row per segment)
comparison = pred_thresh_mapped[['predicted']].join(code_15_labels[['actual']], how='left')

# Overall accuracy: fraction of segments whose predicted label string exactly matches the actual one
accuracy = (comparison['predicted'] == comparison['actual']).mean()
print(f"Overall accuracy: {accuracy:.2%}")
```
```
Overall accuracy: 78.57%
```
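The exact-match check above compares joined strings, which works here because both columns are built from the same label order. A slightly more robust sketch compares label sets, so the result does not depend on label order. `exact_match_accuracy` is a hypothetical helper. The 14 (predicted, actual) pairs are transcribed from the thresholded PhysioNet labels after mapping and the CODE-15 labels shown above, where an empty string means no CODE-15 label.

```python
def exact_match_accuracy(predicted, actual):
    """Fraction of samples whose predicted label set equals the actual label set."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must be the same length")

    def to_set(labels: str) -> frozenset:
        # "" -> empty set; "RBBB, SB" -> {"RBBB", "SB"}
        return frozenset(x for x in labels.split(", ") if x)

    hits = sum(to_set(p) == to_set(a) for p, a in zip(predicted, actual))
    return hits / len(actual)


predicted = ["", "RBBB", "RBBB", "RBBB", "SB", "", "", "SB", "ST", "ST", "", "", "RBBB", "RBBB"]
actual = ["", "RBBB", "RBBB", "RBBB", "", "", "ST", "SB", "ST", "ST", "", "SB", "RBBB", "RBBB"]
print(f"{exact_match_accuracy(predicted, actual):.2%}")  # 78.57%
```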


```python
from sklearn.metrics import classification_report
from sklearn.preprocessing import MultiLabelBinarizer

# Extract the actual and predicted label strings (one per segment)
y_true_str = comparison['actual']
y_pred_str = comparison['predicted']

# Split into lists of labels; an empty string (no CODE-15 label) becomes an explicit "(none)" class
y_true_list = [labels.split(", ") if labels else ["(none)"] for labels in y_true_str]
y_pred_list = [labels.split(", ") if labels else ["(none)"] for labels in y_pred_str]

# Binarize; fit on both lists so labels that only appear in predictions are not ignored
mlb = MultiLabelBinarizer()
mlb.fit(y_true_list + y_pred_list)
y_true_bin = mlb.transform(y_true_list)
y_pred_bin = mlb.transform(y_pred_list)

# Multi-label classification report
report = classification_report(
    y_true_bin,
    y_pred_bin,
    target_names=mlb.classes_,
    zero_division=0,
)
print(report)
```
```
              precision    recall  f1-score   support

      (none)       0.60      0.75      0.67         4
        RBBB       1.00      1.00      1.00         5
          SB       0.50      0.50      0.50         2
          ST       1.00      0.67      0.80         3

   micro avg       0.79      0.79      0.79        14
   macro avg       0.78      0.73      0.74        14
weighted avg       0.81      0.79      0.79        14
 samples avg       0.79      0.79      0.79        14
```

Since every segment here has at most one mapped label, the same strings can also be scored as a single-label (multiclass) problem, which adds an accuracy row:

```python
# Pass the labels explicitly so the display names line up with them
labels = sorted(set(y_true_str) | set(y_pred_str))
report = classification_report(
    y_true_str,
    y_pred_str,
    labels=labels,
    target_names=[label or '(none)' for label in labels],
    zero_division=0,
)
print(report)
```
```
              precision    recall  f1-score   support

      (none)       0.60      0.75      0.67         4
        RBBB       1.00      1.00      1.00         5
          SB       0.50      0.50      0.50         2
          ST       1.00      0.67      0.80         3

    accuracy                           0.79        14
   macro avg       0.78      0.73      0.74        14
weighted avg       0.81      0.79      0.79        14
```
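To see where the per-class rows come from, the sketch below recomputes one-vs-rest precision, recall, and F1 by hand, including the row for segments with no mapped CODE-15 label. It uses the same 14 (predicted, actual) pairs transcribed from the tables above, and its results agree with the per-class rows of the reports. `per_class_prf` is a hypothetical helper, and treating an empty string as its own class is a choice made for this illustration.

```python
def per_class_prf(predicted, actual, none_label="(none)"):
    """One-vs-rest precision, recall, F1 and support for single-label string predictions.

    Empty strings (no mapped label) are treated as their own class, `none_label`.
    """
    pred = [p or none_label for p in predicted]
    true = [a or none_label for a in actual]
    scores = {}
    for cls in sorted(set(true) | set(pred)):
        tp = sum(p == cls and t == cls for p, t in zip(pred, true))
        fp = sum(p == cls and t != cls for p, t in zip(pred, true))
        fn = sum(p != cls and t == cls for p, t in zip(pred, true))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[cls] = (precision, recall, f1, tp + fn)  # support = number of true instances
    return scores


predicted = ["", "RBBB", "RBBB", "RBBB", "SB", "", "", "SB", "ST", "ST", "", "", "RBBB", "RBBB"]
actual = ["", "RBBB", "RBBB", "RBBB", "", "", "ST", "SB", "ST", "ST", "", "SB", "RBBB", "RBBB"]
for cls, (p, r, f1, support) in per_class_prf(predicted, actual).items():
    print(f"{cls:>8}  {p:.2f}  {r:.2f}  {f1:.2f}  {support}")
```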



## 5. Extra: Load models

Outside of the scripts and the Hydra command-line interface, models can also be loaded directly in Python:

```python
from fairseq_signals.models import build_model_from_checkpoint
```


```python
# The checkpoint was downloaded to notebooks/ckpts in step 1
model_finetuned = build_model_from_checkpoint(
    checkpoint_path=os.path.join(root, 'notebooks/ckpts/physionet_finetuned.pt')
)
model_finetuned
```
```
ECGTransformerClassificationModel(
  (encoder): ECGTransformerModel(
    (dropout_input): Dropout(p=0.0, inplace=False)
    (dropout_features): Dropout(p=0.0, inplace=False)
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-11): 12 x TransformerEncoderLayer(
          (self_attn): MultiHeadAttention(
            (dropout): Dropout()
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (dropout1): Dropout(p=0.0, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
          (dropout3): Dropout(p=0.0, inplace=False)
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          ...
```

```python
# Run if the pretrained model hasn't already been downloaded
from huggingface_hub import hf_hub_download

_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='mimic_iv_ecg_physionet_pretrained.pt',
    local_dir=os.path.join(root, 'notebooks/ckpts'),
)
_ = hf_hub_download(
    repo_id='wanglab/ecg-fm-preprint',
    filename='mimic_iv_ecg_physionet_pretrained.yaml',
    local_dir=os.path.join(root, 'notebooks/ckpts'),
)
```


```python
# The checkpoint was downloaded to notebooks/ckpts in the previous cell
model_pretrained = build_model_from_checkpoint(
    checkpoint_path=os.path.join(root, 'notebooks/ckpts/mimic_iv_ecg_physionet_pretrained.pt')
)
model_pretrained
```
```
Wav2Vec2CMSCModel(
  (dropout_input): Dropout(p=0.1, inplace=False)
  (dropout_features): Dropout(p=0.1, inplace=False)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiHeadAttention(
          (dropout): Dropout()
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
        (dropout3): Dropout(p=0.1, inplace=False)
        (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        ...
```