In [3]:
# Importing stock ml libraries
import os
import time
import numpy as np
import pandas as pd
from sklearn import metrics
import transformers
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertModel, BertConfig, BertForSequenceClassification
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

In [4]:
# Setting up the device for GPU usage
from torch import cuda
# Use the GPU when CUDA reports one as available; otherwise fall back to the CPU.
device = 'cuda' if cuda.is_available() else 'cpu'
# Bare expression so the notebook displays which device was selected.
device

'cuda'

In [5]:
# Root-level anomaly labels (source: ASRS coding forms), listed by descending frequency.
# check_prefixes matches these as string prefixes of each ';'-separated anomaly entry.
anomaly_labels = [
    'Deviation / Discrepancy - Procedural',
    'Aircraft Equipment',
    'Conflict',
    'Inflight Event / Encounter',
    'ATC Issue',
    'Deviation - Altitude',
    'Deviation - Track / Heading',
    'Ground Event / Encounter',
    'Flight Deck / Cabin / Aircraft Event',
    'Ground Incursion',
    'Airspace Violation',
    'Deviation - Speed',
    'Ground Excursion',
    'No Specific Anomaly Occurred',
]

In [6]:
# Function to check prefixes and include 'Other' category
def check_prefixes(anomaly, prefixes):
    """Multi-hot encode one ``Anomaly`` cell against a collection of label prefixes.

    Args:
    anomaly (str | NaN): ';'-separated anomaly entries for one report, or a
        missing value.
    prefixes (iterable of str): Label prefixes; an entry matches a label when it
        starts with that prefix (after stripping surrounding whitespace).

    Returns:
    pd.Series: Integer 0/1 flags indexed by each prefix followed by 'Other'.
        'Other' is 1 only when the cell is present and no prefix matched.
        A missing cell yields all zeros (including 'Other').
    """
    # Materialise once so tuples/generators work in both branches.
    labels = list(prefixes)

    if pd.isna(anomaly):
        # Return a series of 0s if the anomaly is NaN
        return pd.Series({label: 0 for label in labels + ['Other']})

    split_anomalies = [item.strip() for item in anomaly.split(';')]
    # Cast to int so both branches produce the same 0/1 integer encoding
    # (the NaN branch above already returns ints).
    prefix_matches = {
        label: int(any(item.startswith(label) for item in split_anomalies))
        for label in labels
    }
    # If no prefix matches, this is 'Other'
    prefix_matches['Other'] = int(not any(prefix_matches.values()))
    return pd.Series(prefix_matches)

In [7]:
# from google.colab import drive
# drive.mount('/content/drive')

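**TODO:** Decide whether rows with a NaN `Anomaly` value should be dropped. At the moment `check_prefixes` keeps them and encodes them as all zeros.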

In [8]:
# Round-trip the training pickle: read it, then write the same object back to the same path.
# NOTE(review): presumably a one-off re-serialisation (e.g. to refresh the pickle format) —
# confirm whether this cell is still needed, since it rewrites the input file on every run.
# NOTE(review): unpickling can execute arbitrary code; only load files from a trusted source.
loaded_data = pd.read_pickle("./data/train_data_final.pkl")
pd.to_pickle(loaded_data, "./data/train_data_final.pkl")

In [9]:
# Load the pickled training split; element 0 of the loaded object is used as the DataFrame.
loaded_data = pd.read_pickle("./data/train_data_final.pkl")
train_df = loaded_data[0]
print("\nA Dataframe with", len(train_df), "entries has been loaded")

# Encode every 'Anomaly' cell into one 0/1 flag per label (plus 'Other');
# apply() forwards `prefixes` to check_prefixes for each element.
train_anomaly_encoding = train_df['Anomaly'].apply(check_prefixes, prefixes=anomaly_labels)
# Keep each row's flag vector as a plain Python list in a single column.
train_df['anomaly_encoding'] = train_anomaly_encoding.to_numpy().tolist()
train_df.head()



A Dataframe with 97417 entries has been loaded


Unnamed: 0_level_0,Date,Local Time Of Day,Locale Reference,State Reference,Relative Position.Angle.Radial,Relative Position.Distance.Nautical Miles,Altitude.AGL.Single Value,Altitude.MSL.Single Value,Flight Conditions,Weather Elements / Visibility,...,Result,Contributing Factors / Situations,Primary Problem,Narrative,Callback,Narrative.1,Callback.1,Synopsis,Year,anomaly_encoding
ACN,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1163382,201404,1201-1800,ZZZ.Airport,US,,,0.0,,VMC,,...,Flight Crew Returned To Gate; Flight Crew Reje...,Aircraft; Environment - Non Weather Related; P...,Environment - Non Weather Related,I was the pilot flying performing the takeoff....,,At approximately 75 KTS I glanced at my airspe...,,A B767 Captain; the pilot not flying; rejected...,2014,"[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
893734,201006,1801-2400,SFO.Airport,CA,,,0.0,,,,...,General None Reported / Taken,Human Factors,Human Factors,We had 6 shipments of dry ice for the flight; ...,,,,A B767-300 Pilot reported his Dangerous Goods ...,2010,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
991883,201201,0601-1200,EGLL.Airport,FO,,,,4000.0,,,...,General None Reported / Taken,Procedure; Company Policy; Human Factors; Manuals,Company Policy,I have seen a lot of mistakes on every flight ...,,,,The Captain of an international flight crew re...,2012,"[1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]"
1590076,201810,0001-0600,EUG.Airport,OR,,,,3900.0,VMC,,...,Flight Crew Took Evasive Action,Human Factors; Procedure,Ambiguous,It was my first time flying into KEUG and I wa...,,Night VMC visual approach left base leg into E...,,Air carrier flight crew reported receiving a T...,2018,"[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1715282,202001,1801-2400,MDW.Airport,IL,,5.0,,2000.0,,,...,General None Reported / Taken,Procedure,Procedure,I am writing this report to bring attention to...,,,,Air Carrier First Officer reported that the us...,2020,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


In [10]:
# Load the pickled test split; element 0 of the loaded object is used as the DataFrame.
loaded_data = pd.read_pickle("./data/test_data_final.pkl")
test_df = loaded_data[0]
print("\nA Dataframe with", len(test_df), "entries has been loaded")

# Encode every 'Anomaly' cell into one 0/1 flag per label (plus 'Other');
# apply() forwards `prefixes` to check_prefixes for each element.
test_anomaly_encoding = test_df['Anomaly'].apply(check_prefixes, prefixes=anomaly_labels)
# Keep each row's flag vector as a plain Python list in a single column.
test_df['anomaly_encoding'] = test_anomaly_encoding.to_numpy().tolist()
test_df.head()



A Dataframe with 10824 entries has been loaded


Unnamed: 0_level_0,Date,Local Time Of Day,Locale Reference,State Reference,Relative Position.Angle.Radial,Relative Position.Distance.Nautical Miles,Altitude.AGL.Single Value,Altitude.MSL.Single Value,Flight Conditions,Weather Elements / Visibility,...,Result,Contributing Factors / Situations,Primary Problem,Narrative,Callback,Narrative.1,Callback.1,Synopsis,Year,anomaly_encoding
ACN,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1014798,201206,0601-1200,SLC.Airport,UT,,,,11300.0,VMC,,...,General None Reported / Taken,Aircraft; Human Factors,Aircraft,Flying into SLC on the DELTA THREE RNAV arriva...,The Reporter stated that his aircraft is equip...,,,A CE750 Captain noted that his aircraft's FMS ...,2012,"[1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1806744,202105,1201-1800,ORD.Airport,IL,,,,3900.0,,,...,Flight Crew FLC complied w / Automation / Advi...,Human Factors; Procedure; Airspace Structure; ...,Airspace Structure,ORD was on a very busy east flow arrival push....,,,,C90TRACON Controller reported they did not not...,2021,"[1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1044902,201210,0001-0600,S46.TRACON,WA,,,,,,,...,General None Reported / Taken,ATC Equipment / Nav Facility / Buildings,ATC Equipment / Nav Facility / Buildings,B737-800 was vectored to an ILS Runway 16L app...,,,,S46 Controller expressed concern regarding the...,2012,"[1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1764093,202009,0601-1200,ZZZ.Tower,US,,,400.0,,VMC,,...,Flight Crew Executed Go Around / Missed Approach,Human Factors; Procedure,Human Factors,We were on a 6 mile final when tower cleared a...,,While on about a six mile final tower cleared ...,,CRJ-200 flight crew reported failing to retrac...,2020,"[1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]"
1786435,202102,1201-1800,ZZZ.ARTCC,US,,,,17000.0,,,...,Air Traffic Control Issued New Clearance; Flig...,Environment - Non Weather Related; Human Facto...,Environment - Non Weather Related,During Climb we Leveled at 17;000 departure sw...,,after copilot (pf) leved at 17000'; dfw depart...,,Air carrier First Officer reported an altitude...,2021,"[1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


In [17]:
# Checkpoint naming. With MODEL_NAME = None the save/load helpers fall back to
# the model's class name as the checkpoint file prefix.
# MODEL_NAME = "model"
MODEL_NAME = None
MODEL_DIRECTORY = "model_save"


# Training configuration and tokenizer
# Defining some key variables that will be used later on in the training
MAX_LEN = 512  # passed to the tokenizer as max_length (pad/truncate target)
TRAIN_BATCH_SIZE = 32 # 32 Size for NASA
VALID_BATCH_SIZE = 32
EPOCHS = 5 # 5 Epochs for NASA
LEARNING_RATE = 1e-05 * 2 # 0.00002 Rate for NASA
# Alternative tokenizer for the plain BERT baseline:
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained("NASA-AIML/MIKA_SafeAeroBERT")

In [12]:
class CustomDataset(Dataset):
    """Dataset pairing each report narrative with its multi-hot anomaly target.

    Reads the ``Narrative`` and ``anomaly_encoding`` columns of ``dataframe``.
    Each item is tokenized on access, padded/truncated to ``max_len``.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.narrative = dataframe.Narrative
        self.targets = self.data.anomaly_encoding

    def __len__(self):
        """Number of narratives in the dataset."""
        return len(self.narrative)

    def __getitem__(self, index):
        """Return a dict of tensors: 'ids', 'mask', 'token_type_ids' (long) and 'targets' (float)."""
        # str() guards against non-string cells; split/join collapses whitespace runs.
        text = " ".join(str(self.narrative.iloc[index]).split())

        encoded = self.tokenizer(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=True
        )

        # Map tokenizer output keys onto the names the training loop expects.
        key_pairs = (
            ('ids', 'input_ids'),
            ('mask', 'attention_mask'),
            ('token_type_ids', 'token_type_ids'),
        )
        item = {
            out_key: torch.tensor(encoded[enc_key], dtype=torch.long)
            for out_key, enc_key in key_pairs
        }
        item['targets'] = torch.tensor(self.targets.iloc[index], dtype=torch.float)
        return item


In [13]:
# Creating the dataset and dataloader for the neural network
print("TRAIN Dataset: {}".format(train_df.shape))
print("TEST Dataset: {}".format(test_df.shape))

training_set = CustomDataset(train_df, tokenizer, MAX_LEN)
testing_set = CustomDataset(test_df, tokenizer, MAX_LEN)

TRAIN Dataset: (97417, 97)
TEST Dataset: (10824, 97)


In [14]:
# Training batches are reshuffled every epoch.
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 2
                }

# Evaluation gains nothing from shuffling (the metrics are order-independent),
# and a fixed order keeps evaluation runs reproducible and comparable.
test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': False,
                'num_workers': 2
                }

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

In [15]:
# Release unused cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [18]:
class BERTClass(torch.nn.Module):
    """Thin wrapper around ``bert-base-uncased`` with a sequence-classification head.

    ``forward`` returns the raw logits (one per label), not probabilities.
    """

    def __init__(self, num_labels=15):
        super().__init__()
        # Attribute name `l1` is part of the saved state_dict keys; keep it stable.
        self.l1 = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased",
            num_labels=num_labels,
        )

    def forward(self, ids, mask, token_type_ids):
        """Run the wrapped model and return its ``logits`` tensor."""
        return self.l1(ids, attention_mask=mask, token_type_ids=token_type_ids).logits

# NOTE(review): a later cell rebinds `model` to SafeAeroBERTClass, so on a
# top-to-bottom run this instance is discarded.
model = BERTClass()

# Freeze all layers in the model
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the classifier and pooler layers (in that order)
for trainable_module in (model.l1.classifier, model.l1.bert.pooler):
    for param in trainable_module.parameters():
        param.requires_grad = True

# Last expression: moves the model to `device` and displays its structure.
model.to(device)

Downloading config.json: 100%|██████████| 570/570 [00:00<00:00, 33.1kB/s]
Downloading model.safetensors: 100%|██████████| 440M/440M [00:28<00:00, 15.3MB/s] 
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BERTClass(
  (l1): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=Tr

In [19]:
class SafeAeroBERTClass(torch.nn.Module):
    """Thin wrapper around ``NASA-AIML/MIKA_SafeAeroBERT`` with a sequence-classification head.

    ``forward`` returns the raw logits (one per label), not probabilities.
    """

    def __init__(self, num_labels=15):
        super().__init__()
        # Attribute name `l1` is part of the saved state_dict keys; keep it stable.
        self.l1 = AutoModelForSequenceClassification.from_pretrained(
            "NASA-AIML/MIKA_SafeAeroBERT",
            num_labels=num_labels,
        )

    def forward(self, ids, mask, token_type_ids):
        """Run the wrapped model and return its ``logits`` tensor."""
        return self.l1(ids, attention_mask=mask, token_type_ids=token_type_ids).logits

model = SafeAeroBERTClass()

# Freeze all layers in the model
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the classifier and pooler layers (in that order)
for trainable_module in (model.l1.classifier, model.l1.bert.pooler):
    for param in trainable_module.parameters():
        param.requires_grad = True

# Last expression: moves the model to `device` and displays its structure.
model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at NASA-AIML/MIKA_SafeAeroBERT and are newly initialized: ['classifier.bias', 'bert.pooler.dense.weight', 'classifier.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SafeAeroBERTClass(
  (l1): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768,

In [20]:
# Multi-label objective: binary cross-entropy applied directly to the raw logits.
loss_fn = torch.nn.BCEWithLogitsLoss()

# All parameters are handed to Adam, frozen ones included.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Each metric is called as fn(y_true, y_pred) on binarised predictions.
metrics_dict = {
    "Accuracy": metrics.accuracy_score,
    "F1 Micro Score": lambda y_true, y_pred: metrics.f1_score(
        y_true, y_pred, average='micro', zero_division=0
    ),
    "F1 Macro Score": lambda y_true, y_pred: metrics.f1_score(
        y_true, y_pred, average='macro', zero_division=0
    ),
}

In [21]:
def save_model(model, epoch, directory='model_save', model_name=None):
    """
    Saves the model's state_dict to ``<directory>/<model_name>_epoch_<epoch>.pth``.

    Args:
    model (torch.nn.Module): The model to save.
    epoch (int): The current epoch number; embedded in the file name.
    directory (str): Directory to save into; created (with parents) if missing.
    model_name (str | None): File-name prefix. Defaults to the model's class name.
    """
    if model_name is None:
        model_name = model.__class__.__name__

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(directory, exist_ok=True)

    file_path = os.path.join(directory, f"{model_name}_epoch_{epoch}.pth")

    torch.save(model.state_dict(), file_path)
    print(f'Model saved at {file_path}')


In [22]:
def load_model(model, directory='model_save', model_name=None, epoch=None):
    """
    Loads a saved state_dict into ``model`` and moves the model to ``device``.

    The checkpoint is read from ``<directory>/<model_name>_epoch_<epoch>.pth``.
    If nothing can be loaded, a message is printed and the model is left untouched.

    Args:
    model (torch.nn.Module): The model to load state into.
    directory (str): Directory containing the saved checkpoints.
    model_name (str | None): File-name prefix. Defaults to the model's class name.
    epoch (int | None): Epoch to load. Defaults to the latest epoch found by
        ``find_last_saved_epoch``.
    """
    if model_name is None:
        model_name = model.__class__.__name__

    if epoch is None:
        epoch = find_last_saved_epoch(directory, model_name)
        if epoch == -1:
            print("No saved model found.")
            return

    file_path = os.path.join(directory, f"{model_name}_epoch_{epoch}.pth")
    if not os.path.exists(file_path):
        print(f"No model file found at {file_path}")
        return

    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU can still be loaded on a CPU-only machine.
    state_dict = torch.load(file_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    print(f'Model loaded from {file_path}')

In [23]:
def find_last_saved_epoch(directory='model_save', model_name=None):
    """
    Finds the last saved epoch number in the specified directory.

    Only files named exactly ``<model_name>_epoch_<N>.pth`` are considered, which
    is the naming scheme ``save_model`` writes and ``load_model`` reads. Files
    from other models that merely share a name prefix are ignored.

    Args:
    directory (str): The directory where models are saved.
    model_name (str | None): File-name prefix. Defaults to the class name of the
        notebook-level ``model`` global.

    Returns:
    int: The last saved epoch number. Returns -1 if no saved model is found
        (including when ``directory`` does not exist).
    """
    if model_name is None:
        # Relies on the notebook-level `model` variable being defined.
        model_name = model.__class__.__name__

    # A missing directory simply means nothing has been saved yet.
    if not os.path.isdir(directory):
        return -1

    prefix = f"{model_name}_epoch_"
    suffix = ".pth"

    saved_epochs = []
    for filename in os.listdir(directory):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue
        epoch_text = filename[len(prefix):-len(suffix)]
        # isdecimal() rejects empty or non-numeric remainders without raising.
        if epoch_text.isdecimal():
            saved_epochs.append(int(epoch_text))

    return max(saved_epochs, default=-1)

In [24]:
def process_batch(model, batch_data, device, loss_fn, mode, optimizer=None):
    """Run one batch through the model and return ``(outputs, targets, loss)``.

    In ``'train'`` mode the optimizer is zeroed, the loss is back-propagated and
    one optimizer step is taken. Any other mode runs the forward pass and loss
    under ``torch.no_grad()`` without updating weights.

    ``batch_data`` must provide the keys 'ids', 'mask', 'token_type_ids' and 'targets'.
    """
    # Move the three integer inputs to the device in forward()-argument order.
    model_inputs = [
        batch_data[key].to(device, dtype=torch.long)
        for key in ('ids', 'mask', 'token_type_ids')
    ]
    targets = batch_data['targets'].to(device, dtype=torch.float)

    if mode != 'train':
        with torch.no_grad():
            outputs = model(*model_inputs)
            loss = loss_fn(outputs, targets)
        return outputs, targets, loss

    optimizer.zero_grad()
    outputs = model(*model_inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    return outputs, targets, loss


In [25]:
def calculate_metrics(metrics_dict, targets, outputs):
    """Evaluate every metric as ``fn(targets, outputs)`` and return ``{name: value}``."""
    results = {}
    for metric_name, metric_fn in metrics_dict.items():
        results[metric_name] = metric_fn(targets, outputs)
    return results


In [26]:
def print_batch_results(mode, epoch, batch, dataset_size, loss, batch_metrics_results, start_time, batch_start_time, batch_size):
    """Print a single-line, carriage-return-overwritten progress report for one batch.

    Args:
    mode (str): Phase label; printed capitalised (e.g. 'train' -> 'Train').
    epoch (int | None): Zero-based epoch; the epoch field is omitted when None.
    batch (int): Zero-based batch index.
    dataset_size (int): Total number of samples in the dataset.
    loss (float): Loss value for this batch.
    batch_metrics_results (dict): ``{metric name: value}`` shown with 4 decimals.
    start_time (float): ``time.time()`` at the start of the pass.
    batch_start_time (float): ``time.time()`` at the start of this batch.
    batch_size (int): Samples per batch, used to estimate the samples processed.
    """
    current_time = time.time()
    elapsed_time = current_time - start_time
    batch_time_ms = (current_time - batch_start_time) * 1000

    # (batch + 1) * batch_size overshoots on a short final batch when the
    # nominal batch size is supplied; never report more than the dataset holds.
    current = min((batch + 1) * batch_size, dataset_size)
    metric_str = ", ".join([f"{metric}: {value:.4f}" for metric, value in batch_metrics_results.items()])
    epoch_str = f"Epoch: {epoch+1}, " if epoch is not None else ""

    # Leading '\r' and end="" keep the report on one console line.
    print(f"\r{mode.capitalize()} - {epoch_str}Batch: {batch+1} [{current:>5d}/{dataset_size:>5d}], "
          f"Time: {elapsed_time:.0f}s {batch_time_ms:.0f}ms/step, Loss: {loss:>7f}, {metric_str}", end="")


In [27]:
def process_batches(mode, model, loader, device, loss_fn, metrics_dict, optimizer=None, epoch=None):
    """Run one full pass over ``loader`` in the given mode.

    Puts the model in train mode when ``mode == 'train'`` and in eval mode
    otherwise, processes every batch, prints per-batch progress and metrics,
    and collects thresholded predictions.

    Returns:
    tuple: ``(avg_loss, all_outputs, all_targets)``. ``avg_loss`` is the mean
        per-batch loss (NaN if the loader yielded no batches). ``all_outputs``
        is a list of per-sample boolean arrays (sigmoid(logit) >= 0.5).
        ``all_targets`` is a list of per-sample target arrays.
    """
    if mode == 'train':
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    num_batches = 0
    all_targets = []
    all_outputs = []
    start_time = time.time()
    # May be None when the loader was built with a custom batch_sampler.
    nominal_batch_size = getattr(loader, 'batch_size', None)

    for batch, data in enumerate(loader, 0):
        batch_start_time = time.time()

        outputs, targets, loss = process_batch(model, data, device, loss_fn, mode, optimizer)
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        # Binarise predictions: probability >= 0.5 counts as a positive label.
        outputs_binary = torch.sigmoid(outputs).cpu().detach().numpy() >= 0.5
        targets = targets.cpu().detach().numpy()
        all_outputs.extend(outputs_binary)
        all_targets.extend(targets)

        batch_metrics_results = calculate_metrics(metrics_dict, targets, outputs_binary)
        # Report progress with the loader's nominal batch size: using a short
        # final batch's own size would make the sample counter jump backwards.
        batch_size = nominal_batch_size or targets.shape[0]
        print_batch_results(mode, epoch, batch, len(loader.dataset), batch_loss, batch_metrics_results, start_time, batch_start_time, batch_size)

    # Terminate the carriage-return progress line.
    print()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    return avg_loss, all_outputs, all_targets

In [28]:
def evaluate(model, validation_loader, loss_fn, metrics_dict, device):
    """Run an evaluation pass, print loss and metrics, and return them.

    Returns:
    tuple: ``(avg_val_loss, metrics_results)`` where ``metrics_results`` maps
        each metric name to its value over the whole loader.
    """
    avg_val_loss, val_outputs, val_targets = process_batches(
        'evaluate', model, validation_loader, device, loss_fn, metrics_dict
    )
    # Stack the per-sample predictions into one array before scoring.
    val_outputs = np.array(val_outputs) >= 0.5

    metrics_results = {}
    for metric_name, metric_fn in metrics_dict.items():
        metrics_results[metric_name] = metric_fn(val_targets, val_outputs)

    print(f"Average Validation Loss: {avg_val_loss:.4f}")
    for metric, value in metrics_results.items():
        print(f"Validation {metric}: {value:.4f}")

    return avg_val_loss, metrics_results


In [29]:
def train(model, epoch, training_loader, validation_loader, optimizer, loss_fn, metrics_dict, device):
    """Train for one epoch, then optionally evaluate.

    Args:
    epoch (int): Zero-based epoch index (printed one-based).
    validation_loader: Loader to evaluate on after training, or None to skip.

    Returns:
    tuple: ``(avg_train_loss, avg_val_loss, val_metrics_results)``. Without a
        validation loader the last two are ``None`` and ``{}``.
    """
    print(f"Training Epoch {epoch + 1}")

    # Training phase
    avg_train_loss, _, _ = process_batches(
        'train', model, training_loader, device, loss_fn, metrics_dict, optimizer, epoch=epoch
    )
    print(f"Average Training Loss for Epoch {epoch + 1}: {avg_train_loss:.4f}")

    # Validation phase (skipped when no loader is supplied)
    if validation_loader is None:
        return avg_train_loss, None, {}

    avg_val_loss, val_metrics_results = evaluate(model, validation_loader, loss_fn, metrics_dict, device)
    return avg_train_loss, avg_val_loss, val_metrics_results


In [30]:
# Resume from the newest checkpoint if one exists.
last_saved_epoch = find_last_saved_epoch(directory=MODEL_DIRECTORY, model_name=MODEL_NAME)

# -1 means "nothing saved", so +1 gives 0 in that case and the next epoch otherwise.
start_epoch = last_saved_epoch + 1
if last_saved_epoch == -1:
    print("No saved model found. Starting training from the beginning.")
else:
    # NOTE(review): only model weights are restored; the optimizer state is not checkpointed.
    load_model(model, directory=MODEL_DIRECTORY, model_name=MODEL_NAME, epoch=last_saved_epoch)
    print(f"Resuming training from epoch {start_epoch + 1}")

# NOTE(review): the test loader doubles as the per-epoch validation loader here —
# confirm whether a separate validation split is intended.
for epoch in range(start_epoch, EPOCHS):
    train_loss, val_loss, val_metrics = train(
        model, epoch, training_loader, testing_loader, optimizer, loss_fn, metrics_dict, device
    )
    # Checkpoint after every epoch so an interrupted run can resume.
    save_model(model, epoch, directory=MODEL_DIRECTORY, model_name=MODEL_NAME)
    # Additional epoch-level processing if needed

print("Finished Training")

# Testing phase
avg_test_loss, test_metrics_results = evaluate(model, testing_loader, loss_fn, metrics_dict, device)
print(f"Testing Loss: {avg_test_loss}")
for metric, value in test_metrics_results.items():
    print(f"Testing {metric}: {value}")


No saved model found. Starting training from the beginning.
Training Epoch 1
Train - Epoch: 1, Batch: 20 [  640/97417], Time: 112s 5555ms/step, Loss: 0.560862, Accuracy: 0.0625, F1 Micro Score: 0.4000, F1 Macro Score: 0.0490

KeyboardInterrupt: 

In [None]:
# from torch.nn.utils.rnn import pad_sequence
# def collate_fn(batch):
#     ids = [item['ids'].clone().detach() for item in batch]
#     masks = [item['mask'].clone().detach() for item in batch]
#     token_type_ids = [item['token_type_ids'].clone().detach() for item in batch]
#     targets = [item['targets'].clone().detach() for item in batch]

#     # Padding the sequences to the maximum length in this batch
#     ids = pad_sequence(ids, batch_first=True, padding_value=tokenizer.pad_token_id)
#     masks = pad_sequence(masks, batch_first=True, padding_value=0)
#     token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=0)

#     targets = torch.stack(targets)

#     return {
#         'ids': ids,
#         'mask': masks,
#         'token_type_ids': token_type_ids,
#         'targets': targets
#     }