In [1]:
from transformers import ElectraTokenizer, ElectraForTokenClassification
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence

# Define entity and relation types
entity_types = ["Corporation", "Locale", "Application", "Malware", "Vulnerability", "Assault_technique",
                "Infected_file", "APT_group", "Ransomware", "Campaign", "Algorithm", "System", "Organization", "Indicator"]

relation_types = ["originate_from", "utilizes", "is_susceptible", "contains_product", "contains_file",
                  "belongs_to", "operates_in", "launches", "is_a_type_of", "indicates", "implements"]

# Define label mappings
entity_tags = ["O"] + [f"B-{entity}" for entity in entity_types] + [f"I-{entity}" for entity in entity_types] + [f"E-{entity}" for entity in entity_types] + [f"S-{entity}" for entity in entity_types]
relation_tags = ["O"] + [f"B-{relation}" for relation in relation_types] + [f"I-{relation}" for relation in relation_types] + [f"E-{relation}" for relation in relation_types] + [f"S-{relation}" for relation in relation_types]

entity_label_map = {tag: idx for idx, tag in enumerate(entity_tags)}
relation_label_map = {tag: idx for idx, tag in enumerate(relation_tags)}

# Add start and end tokens to entity_label_map
entity_label_map['<start>'] = len(entity_label_map)
entity_label_map['<end>'] = len(entity_label_map)

# Define the preprocessing function
def preprocess_ner_data(file_path, encoding='utf-8'):
    with open(file_path, 'r', encoding=encoding) as file:
        lines = file.readlines()

    sentences = []
    labels = []
    sentence = []
    label = []

    for line in lines:
        if line.strip() == "":
            if sentence:
                sentences.append(['<start>'] + sentence + ['<end>'])
                labels.append(['<start>'] + label + ['<end>'])
                sentence = []
                label = []
        else:
            word, tag = line.strip().split('\t')
            sentence.append(word)
            label.append(tag)

    if sentence:
        sentences.append(['<start>'] + sentence + ['<end>'])
        labels.append(['<start>'] + label + ['<end>'])

    return sentences, labels

# Preprocess data
train_sentences, train_labels = preprocess_ner_data("C:\\Users\\Admin\\Desktop\\nitt\\data\\MITREtrain.txt")
valid_sentences, valid_labels = preprocess_ner_data("C:\\Users\\Admin\\Desktop\\nitt\\data\\MITREvalid.txt")
test_sentences, test_labels = preprocess_ner_data("C:\\Users\\Admin\\Desktop\\nitt\\data\\MITREtest.txt")

print("Entity Label Map:")
print(entity_label_map)
print("\nRelation Label Map:")
print(relation_label_map)

# Initialize the ELECTRA tokenizer
tokenizer = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')

# Function to tokenize and encode sentences
def tokenize_and_encode(sentences, labels, tokenizer, entity_label_map, max_length=512):
    all_input_ids = []
    all_attention_masks = []
    all_label_ids = []

    for sentence, label in zip(sentences, labels):
        # Tokenize each word separately
        tokenized_words = [tokenizer.tokenize(word) for word in sentence]
        
        # Flatten the list of tokenized words
        tokens = [token for word in tokenized_words for token in word]
        
        # Convert tokens to ids
        input_id = tokenizer.convert_tokens_to_ids(tokens)
        
        # Create attention mask
        attention_mask = [1] * len(input_id)
        
        # Align labels
        label_id = []
        for word, word_tokens in zip(label, tokenized_words):
            label_id.extend([entity_label_map.get(word, entity_label_map["O"])] * len(word_tokens))
        
        # Pad or truncate
        if len(input_id) < max_length:
            padding_length = max_length - len(input_id)
            input_id = input_id + [tokenizer.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length
            label_id = label_id + [entity_label_map['<pad>']] * padding_length
        else:
            input_id = input_id[:max_length]
            attention_mask = attention_mask[:max_length]
            label_id = label_id[:max_length]
        
        all_input_ids.append(input_id)
        all_attention_masks.append(attention_mask)
        all_label_ids.append(label_id)

    # Convert to tensors
    all_input_ids = torch.tensor(all_input_ids)
    all_attention_masks = torch.tensor(all_attention_masks)
    all_label_ids = torch.tensor(all_label_ids)
    
    return all_input_ids, all_attention_masks, all_label_ids

# Add padding token to entity_label_map
entity_label_map['<pad>'] = len(entity_label_map)

# Tokenize with adjusted max_length
train_inputs, train_masks, train_labels = tokenize_and_encode(train_sentences, train_labels, tokenizer, entity_label_map, max_length=512)
valid_inputs, valid_masks, valid_labels = tokenize_and_encode(valid_sentences, valid_labels, tokenizer, entity_label_map, max_length=512)
test_inputs, test_masks, test_labels = tokenize_and_encode(test_sentences, test_labels, tokenizer, entity_label_map, max_length=512)

# Create TensorDataset instances
train_data = TensorDataset(train_inputs, train_masks, train_labels)
valid_data = TensorDataset(valid_inputs, valid_masks, valid_labels)
test_data = TensorDataset(test_inputs, test_masks, test_labels)

train_batch_size = 16
train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=train_batch_size)
test_loader = DataLoader(test_data, batch_size=train_batch_size)

# Print shapes to verify
print("\nTrain Data Shapes:")
print("Input IDs:", train_inputs.shape)
print("Attention Masks:", train_masks.shape)
print("Label IDs:", train_labels.shape)

print(f"\nTotal number of labels: {len(entity_label_map)}")

# Initialize ELECTRA model for token classification
model = ElectraForTokenClassification.from_pretrained('google/electra-base-discriminator', num_labels=len(entity_label_map))

  from .autonotebook import tqdm as notebook_tqdm


Entity Label Map:
{'O': 0, 'B-Corporation': 1, 'B-Locale': 2, 'B-Application': 3, 'B-Malware': 4, 'B-Vulnerability': 5, 'B-Assault_technique': 6, 'B-Infected_file': 7, 'B-APT_group': 8, 'B-Ransomware': 9, 'B-Campaign': 10, 'B-Algorithm': 11, 'B-System': 12, 'B-Organization': 13, 'B-Indicator': 14, 'I-Corporation': 15, 'I-Locale': 16, 'I-Application': 17, 'I-Malware': 18, 'I-Vulnerability': 19, 'I-Assault_technique': 20, 'I-Infected_file': 21, 'I-APT_group': 22, 'I-Ransomware': 23, 'I-Campaign': 24, 'I-Algorithm': 25, 'I-System': 26, 'I-Organization': 27, 'I-Indicator': 28, 'E-Corporation': 29, 'E-Locale': 30, 'E-Application': 31, 'E-Malware': 32, 'E-Vulnerability': 33, 'E-Assault_technique': 34, 'E-Infected_file': 35, 'E-APT_group': 36, 'E-Ransomware': 37, 'E-Campaign': 38, 'E-Algorithm': 39, 'E-System': 40, 'E-Organization': 41, 'E-Indicator': 42, 'S-Corporation': 43, 'S-Locale': 44, 'S-Application': 45, 'S-Malware': 46, 'S-Vulnerability': 47, 'S-Assault_technique': 48, 'S-Infected_fi

Some weights of ElectraForTokenClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
from TorchCRF import CRF
import torch
import torch.nn as nn
from transformers import ElectraModel, ElectraTokenizer, AdamW
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
import time
from seqeval.metrics import f1_score as seqeval_f1_score
from seqeval.scheme import IOB2

# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class JointExtractionModel(nn.Module):
    def __init__(self, num_labels, hidden_size, num_layers, device):
        super(JointExtractionModel, self).__init__()
        self.device = device
        self.electra = ElectraModel.from_pretrained('google/electra-base-discriminator').to(device)
        self.gru = nn.GRU(768, hidden_size, num_layers=num_layers, bidirectional=True, batch_first=True).to(device)
        self.linear = nn.Linear(hidden_size * 2, num_labels).to(device)
        self.crf = CRF(num_labels).to(device)
        print(f"Model initialized with num_labels: {num_labels}")

    def forward(self, input_ids, attention_mask):
        electra_output = self.electra(input_ids=input_ids, attention_mask=attention_mask)
        hidden = electra_output.last_hidden_state
        gru_output, _ = self.gru(hidden)
        logits = self.linear(gru_output)
        return logits

    def loss(self, input_ids, attention_mask, tags):
        logits = self.forward(input_ids, attention_mask)
        log_likelihood = self.crf.forward(logits, tags, mask=attention_mask.bool())
        return -log_likelihood.mean()

    def decode(self, input_ids, attention_mask):
        logits = self.forward(input_ids, attention_mask)
        return self.crf.viterbi_decode(logits, mask=attention_mask.bool())

# Initialize the model
hidden_size = 256  # or any other value you prefer
num_layers = 2  # or any other value you prefer
num_labels = len(entity_label_map)
model = JointExtractionModel(num_labels, hidden_size, num_layers, device)

# Initialize the ELECTRA tokenizer
tokenizer = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')

# The rest of your code (data loading, training loop, etc.) remains the same

Model initialized with num_labels: 60


In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
from seqeval.metrics import f1_score, precision_score, recall_score
from sklearn.metrics import f1_score as sklearn_f1_score
import numpy as np
from seqeval.metrics import f1_score as seqeval_f1_score
from seqeval.metrics import precision_score as seqeval_precision_score
from seqeval.metrics import recall_score as seqeval_recall_score
from collections import Counter

def train(model, train_loader, val_loader, test_loader, epochs, lr, entity_label_map, device, accumulation_steps=1):
    model = model.to(device)
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(train_loader) * epochs // accumulation_steps
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
    best_val_f1 = 0.0
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        optimizer.zero_grad()  # Zero gradients at the start of each epoch
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for step, batch in enumerate(progress_bar):
            input_ids, attention_mask, tags = [b.to(device) for b in batch]
            loss = model.loss(input_ids, attention_mask, tags)
            loss = loss / accumulation_steps  # Normalize loss for gradient accumulation
            loss.backward()
            
            if (step + 1) % accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            
            total_loss += loss.item() * accumulation_steps
            
            # Update progress bar
            progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})
        
        avg_loss = total_loss / len(train_loader)
        val_f1, val_precision, val_recall = evaluate(model, val_loader, entity_label_map, device)
        
        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'  Train Loss: {avg_loss:.4f}')
        print(f'  Val F1: {val_f1:.4f}')
        print(f'  Val Precision: {val_precision:.4f}')
        print(f'  Val Recall: {val_recall:.4f}')
        
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), 'best_electra_ner_model_16.pt')
    
    model.load_state_dict(torch.load('best_electra_ner_model_16.pt', map_location=device))
    test_f1, test_precision, test_recall = evaluate(model, test_loader, entity_label_map, device)
    print(f'Test F1-score: {test_f1:.4f}')
    print(f'Test Precision: {test_precision:.4f}')
    print(f'Test Recall: {test_recall:.4f}')
    
    return best_val_f1, test_f1, test_precision, test_recall

In [4]:
def evaluate(model, data_loader, entity_label_map, device):
    model.eval()
    all_preds = []
    all_labels = []
    inv_label_map = {v: k for k, v in entity_label_map.items()}

    with torch.no_grad():
        for batch in data_loader:
            input_ids, attention_mask, tags = [b.to(device) for b in batch]
            predictions = model.decode(input_ids, attention_mask)
            
            for pred, true, mask in zip(predictions, tags, attention_mask):
                pred = [inv_label_map[p] for p, m in zip(pred, mask) if m.item() == 1]
                true = [inv_label_map[t.item()] for t, m in zip(true, mask) if m.item() == 1 and t.item() != -100]
                
                # Ensure pred and true have the same length
                min_len = min(len(pred), len(true))
                pred = pred[:min_len]
                true = true[:min_len]
                
                all_preds.append(pred)
                all_labels.append(true)

    # Print some debug information
    print(f"Number of predicted sequences: {len(all_preds)}")
    print(f"Number of true label sequences: {len(all_labels)}")
    print(f"Lengths of first few predicted sequences: {[len(p) for p in all_preds[:5]]}")
    print(f"Lengths of first few true label sequences: {[len(t) for t in all_labels[:5]]}")

    f1 = f1_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds)

    return f1, precision, recall

In [5]:
epochs = 100
learning_rate = 2e-5
accumulation_steps = 4  # Adjust as needed

best_val_f1, test_f1, test_precision, test_recall = train(
    model, 
    train_loader, 
    valid_loader, 
    test_loader, 
    epochs, 
    learning_rate, 
    entity_label_map,
    device,
    accumulation_steps
)

Epoch 1/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:34<00:00,  1.22s/it, loss=10.3091]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]




Epoch 1/100:
  Train Loss: 91.6144
  Val F1: 0.5097
  Val Precision: 0.8399
  Val Recall: 0.3659


Epoch 2/100: 100%|██████████████████████████████████████████████████████| 176/176 [03:33<00:00,  1.21s/it, loss=4.2172]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 2/100:
  Train Loss: 32.3981
  Val F1: 0.7673
  Val Precision: 0.8301
  Val Recall: 0.7133


Epoch 3/100: 100%|██████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=3.7686]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 3/100:
  Train Loss: 17.3750
  Val F1: 0.7839
  Val Precision: 0.8575
  Val Recall: 0.7219


Epoch 4/100: 100%|██████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=5.2968]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 4/100:
  Train Loss: 11.1603
  Val F1: 0.7854
  Val Precision: 0.8178
  Val Recall: 0.7554


Epoch 5/100: 100%|██████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=1.9497]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 5/100:
  Train Loss: 8.1052
  Val F1: 0.8137
  Val Precision: 0.8107
  Val Recall: 0.8168


Epoch 6/100: 100%|██████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.7818]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 6/100:
  Train Loss: 6.5135
  Val F1: 0.8273
  Val Precision: 0.8380
  Val Recall: 0.8170


Epoch 7/100: 100%|██████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=1.9797]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 7/100:
  Train Loss: 5.5602
  Val F1: 0.8395
  Val Precision: 0.8263
  Val Recall: 0.8530


Epoch 8/100: 100%|██████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.6014]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 8/100:
  Train Loss: 4.7416
  Val F1: 0.8505
  Val Precision: 0.8344
  Val Recall: 0.8671


Epoch 9/100: 100%|██████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.6839]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 9/100:
  Train Loss: 4.0533
  Val F1: 0.8480
  Val Precision: 0.8337
  Val Recall: 0.8628


Epoch 10/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=1.3772]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 10/100:
  Train Loss: 3.4717
  Val F1: 0.8504
  Val Precision: 0.8442
  Val Recall: 0.8568


Epoch 11/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.5569]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 11/100:
  Train Loss: 3.1219
  Val F1: 0.8585
  Val Precision: 0.8579
  Val Recall: 0.8591


Epoch 12/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.3974]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 12/100:
  Train Loss: 2.6871
  Val F1: 0.8688
  Val Precision: 0.8722
  Val Recall: 0.8654


Epoch 13/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.7661]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 13/100:
  Train Loss: 2.2488
  Val F1: 0.8730
  Val Precision: 0.8874
  Val Recall: 0.8591


Epoch 14/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.2991]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 14/100:
  Train Loss: 2.0206
  Val F1: 0.8779
  Val Precision: 0.8756
  Val Recall: 0.8803


Epoch 15/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0718]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 15/100:
  Train Loss: 1.7367
  Val F1: 0.8769
  Val Precision: 0.8663
  Val Recall: 0.8877


Epoch 16/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.2243]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 16/100:
  Train Loss: 1.4289
  Val F1: 0.8848
  Val Precision: 0.8702
  Val Recall: 0.8999


Epoch 17/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.8221]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 17/100:
  Train Loss: 1.3129
  Val F1: 0.8736
  Val Precision: 0.8874
  Val Recall: 0.8603


Epoch 18/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0262]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 18/100:
  Train Loss: 1.1785
  Val F1: 0.8864
  Val Precision: 0.8760
  Val Recall: 0.8971


Epoch 19/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.1222]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 19/100:
  Train Loss: 1.0021
  Val F1: 0.8934
  Val Precision: 0.8861
  Val Recall: 0.9008


Epoch 20/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.1749]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 20/100:
  Train Loss: 0.8661
  Val F1: 0.8854
  Val Precision: 0.8890
  Val Recall: 0.8818


Epoch 21/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.1563]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 21/100:
  Train Loss: 0.7833
  Val F1: 0.8872
  Val Precision: 0.8742
  Val Recall: 0.9005


Epoch 22/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.2329]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 22/100:
  Train Loss: 0.6967
  Val F1: 0.8735
  Val Precision: 0.8671
  Val Recall: 0.8799


Epoch 23/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0662]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 23/100:
  Train Loss: 0.6874
  Val F1: 0.8857
  Val Precision: 0.8741
  Val Recall: 0.8977


Epoch 24/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.0288]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 24/100:
  Train Loss: 0.5442
  Val F1: 0.8848
  Val Precision: 0.8871
  Val Recall: 0.8824


Epoch 25/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.1173]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 25/100:
  Train Loss: 0.5340
  Val F1: 0.8855
  Val Precision: 0.8815
  Val Recall: 0.8895


Epoch 26/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.1898]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 26/100:
  Train Loss: 0.4493
  Val F1: 0.8892
  Val Precision: 0.8819
  Val Recall: 0.8967


Epoch 27/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.1163]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 27/100:
  Train Loss: 0.4056
  Val F1: 0.8868
  Val Precision: 0.8782
  Val Recall: 0.8956


Epoch 28/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:31<00:00,  1.20s/it, loss=0.2056]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 28/100:
  Train Loss: 0.3294
  Val F1: 0.8899
  Val Precision: 0.8808
  Val Recall: 0.8993


Epoch 29/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0334]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 29/100:
  Train Loss: 0.3097
  Val F1: 0.8822
  Val Precision: 0.8630
  Val Recall: 0.9022


Epoch 30/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0152]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 30/100:
  Train Loss: 0.2881
  Val F1: 0.8822
  Val Precision: 0.8785
  Val Recall: 0.8859


Epoch 31/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0158]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 31/100:
  Train Loss: 0.2484
  Val F1: 0.8783
  Val Precision: 0.8863
  Val Recall: 0.8705


Epoch 32/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0131]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 32/100:
  Train Loss: 0.2253
  Val F1: 0.8891
  Val Precision: 0.8772
  Val Recall: 0.9014


Epoch 33/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0495]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 33/100:
  Train Loss: 0.2326
  Val F1: 0.8858
  Val Precision: 0.8781
  Val Recall: 0.8936


Epoch 34/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:30<00:00,  1.19s/it, loss=0.0177]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 34/100:
  Train Loss: 0.1874
  Val F1: 0.8952
  Val Precision: 0.8840
  Val Recall: 0.9067


Epoch 35/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0254]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 35/100:
  Train Loss: 0.1947
  Val F1: 0.8857
  Val Precision: 0.8655
  Val Recall: 0.9069


Epoch 36/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.0153]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 36/100:
  Train Loss: 0.1788
  Val F1: 0.8884
  Val Precision: 0.8781
  Val Recall: 0.8989


Epoch 37/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.0082]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 37/100:
  Train Loss: 0.1382
  Val F1: 0.8893
  Val Precision: 0.8789
  Val Recall: 0.9001


Epoch 38/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0062]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 38/100:
  Train Loss: 0.1603
  Val F1: 0.8923
  Val Precision: 0.8903
  Val Recall: 0.8942


Epoch 39/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0307]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 39/100:
  Train Loss: 0.1597
  Val F1: 0.8865
  Val Precision: 0.8783
  Val Recall: 0.8950


Epoch 40/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.0098]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 40/100:
  Train Loss: 0.1236
  Val F1: 0.8873
  Val Precision: 0.8763
  Val Recall: 0.8985


Epoch 41/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.0066]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 41/100:
  Train Loss: 0.1230
  Val F1: 0.8887
  Val Precision: 0.8756
  Val Recall: 0.9022


Epoch 42/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0033]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 42/100:
  Train Loss: 0.1192
  Val F1: 0.8880
  Val Precision: 0.8834
  Val Recall: 0.8926


Epoch 43/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0037]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 43/100:
  Train Loss: 0.1172
  Val F1: 0.8895
  Val Precision: 0.8816
  Val Recall: 0.8975


Epoch 44/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.0042]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 44/100:
  Train Loss: 0.1086
  Val F1: 0.8862
  Val Precision: 0.8766
  Val Recall: 0.8961


Epoch 45/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.0108]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 45/100:
  Train Loss: 0.1112
  Val F1: 0.8874
  Val Precision: 0.8779
  Val Recall: 0.8971


Epoch 46/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0054]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 46/100:
  Train Loss: 0.1122
  Val F1: 0.8783
  Val Precision: 0.8756
  Val Recall: 0.8811


Epoch 47/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.0070]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 47/100:
  Train Loss: 0.0861
  Val F1: 0.8827
  Val Precision: 0.8794
  Val Recall: 0.8859


Epoch 48/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0153]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 48/100:
  Train Loss: 0.1007
  Val F1: 0.8776
  Val Precision: 0.8680
  Val Recall: 0.8875


Epoch 49/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0031]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 49/100:
  Train Loss: 0.0862
  Val F1: 0.8894
  Val Precision: 0.8802
  Val Recall: 0.8987


Epoch 50/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.0049]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 50/100:
  Train Loss: 0.0607
  Val F1: 0.8824
  Val Precision: 0.8838
  Val Recall: 0.8811


Epoch 51/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0043]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 51/100:
  Train Loss: 0.0757
  Val F1: 0.8867
  Val Precision: 0.8833
  Val Recall: 0.8901


Epoch 52/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.0030]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 52/100:
  Train Loss: 0.0699
  Val F1: 0.8809
  Val Precision: 0.8862
  Val Recall: 0.8758


Epoch 53/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0027]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 53/100:
  Train Loss: 0.0599
  Val F1: 0.8865
  Val Precision: 0.8827
  Val Recall: 0.8905


Epoch 54/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0034]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 54/100:
  Train Loss: 0.0747
  Val F1: 0.8864
  Val Precision: 0.8790
  Val Recall: 0.8940


Epoch 55/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.2463]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 55/100:
  Train Loss: 0.0593
  Val F1: 0.8855
  Val Precision: 0.8797
  Val Recall: 0.8914


Epoch 56/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.1458]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 56/100:
  Train Loss: 0.0646
  Val F1: 0.8866
  Val Precision: 0.8876
  Val Recall: 0.8856


Epoch 57/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.19s/it, loss=0.0031]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 57/100:
  Train Loss: 0.0471
  Val F1: 0.8805
  Val Precision: 0.8779
  Val Recall: 0.8832


Epoch 58/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0022]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 58/100:
  Train Loss: 0.0473
  Val F1: 0.8824
  Val Precision: 0.8796
  Val Recall: 0.8852


Epoch 59/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0011]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 59/100:
  Train Loss: 0.0367
  Val F1: 0.8817
  Val Precision: 0.8745
  Val Recall: 0.8891


Epoch 60/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0024]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 60/100:
  Train Loss: 0.0365
  Val F1: 0.8833
  Val Precision: 0.8789
  Val Recall: 0.8877


Epoch 61/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:30<00:00,  1.19s/it, loss=0.0013]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 61/100:
  Train Loss: 0.0334
  Val F1: 0.8842
  Val Precision: 0.8854
  Val Recall: 0.8830


Epoch 62/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0029]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 62/100:
  Train Loss: 0.0363
  Val F1: 0.8868
  Val Precision: 0.8810
  Val Recall: 0.8926


Epoch 63/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.0036]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 63/100:
  Train Loss: 0.0442
  Val F1: 0.8825
  Val Precision: 0.8757
  Val Recall: 0.8893


Epoch 64/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.2023]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 64/100:
  Train Loss: 0.0341
  Val F1: 0.8836
  Val Precision: 0.8732
  Val Recall: 0.8944


Epoch 65/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0009]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 65/100:
  Train Loss: 0.0304
  Val F1: 0.8854
  Val Precision: 0.8820
  Val Recall: 0.8889


Epoch 66/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.0035]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 66/100:
  Train Loss: 0.0325
  Val F1: 0.8849
  Val Precision: 0.8778
  Val Recall: 0.8922


Epoch 67/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0008]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 67/100:
  Train Loss: 0.0219
  Val F1: 0.8810
  Val Precision: 0.8902
  Val Recall: 0.8720


Epoch 68/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0022]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 68/100:
  Train Loss: 0.0241
  Val F1: 0.8874
  Val Precision: 0.8836
  Val Recall: 0.8912


Epoch 69/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:29<00:00,  1.19s/it, loss=0.0007]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 69/100:
  Train Loss: 0.0225
  Val F1: 0.8857
  Val Precision: 0.8842
  Val Recall: 0.8871


Epoch 70/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.0009]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 70/100:
  Train Loss: 0.0218
  Val F1: 0.8870
  Val Precision: 0.8842
  Val Recall: 0.8899


Epoch 71/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:28<00:00,  1.18s/it, loss=0.0041]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 71/100:
  Train Loss: 0.0197
  Val F1: 0.8860
  Val Precision: 0.8791
  Val Recall: 0.8930


Epoch 72/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:26<00:00,  1.17s/it, loss=0.0026]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 72/100:
  Train Loss: 0.0345
  Val F1: 0.8877
  Val Precision: 0.8815
  Val Recall: 0.8940


Epoch 73/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:26<00:00,  1.18s/it, loss=0.0574]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 73/100:
  Train Loss: 0.0306
  Val F1: 0.8883
  Val Precision: 0.8797
  Val Recall: 0.8971


Epoch 74/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:26<00:00,  1.17s/it, loss=0.0010]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 74/100:
  Train Loss: 0.0222
  Val F1: 0.8865
  Val Precision: 0.8842
  Val Recall: 0.8889


Epoch 75/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0031]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 75/100:
  Train Loss: 0.0182
  Val F1: 0.8839
  Val Precision: 0.8795
  Val Recall: 0.8883


Epoch 76/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0011]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 76/100:
  Train Loss: 0.0301
  Val F1: 0.8844
  Val Precision: 0.8835
  Val Recall: 0.8854


Epoch 77/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0023]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 77/100:
  Train Loss: 0.0305
  Val F1: 0.8845
  Val Precision: 0.8748
  Val Recall: 0.8944


Epoch 78/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0017]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 78/100:
  Train Loss: 0.0223
  Val F1: 0.8821
  Val Precision: 0.8782
  Val Recall: 0.8859


Epoch 79/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:39<00:00,  1.25s/it, loss=0.0016]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 79/100:
  Train Loss: 0.0212
  Val F1: 0.8836
  Val Precision: 0.8804
  Val Recall: 0.8869


Epoch 80/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:45<00:00,  1.28s/it, loss=0.0018]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 80/100:
  Train Loss: 0.0275
  Val F1: 0.8854
  Val Precision: 0.8761
  Val Recall: 0.8950


Epoch 81/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:50<00:00,  1.31s/it, loss=0.0008]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 81/100:
  Train Loss: 0.0164
  Val F1: 0.8828
  Val Precision: 0.8842
  Val Recall: 0.8814


Epoch 82/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:40<00:00,  1.25s/it, loss=0.0018]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 82/100:
  Train Loss: 0.0107
  Val F1: 0.8827
  Val Precision: 0.8811
  Val Recall: 0.8844


Epoch 83/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:27<00:00,  1.18s/it, loss=0.0009]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 83/100:
  Train Loss: 0.0259
  Val F1: 0.8836
  Val Precision: 0.8776
  Val Recall: 0.8897


Epoch 84/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:39<00:00,  1.25s/it, loss=0.0011]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 84/100:
  Train Loss: 0.0209
  Val F1: 0.8872
  Val Precision: 0.8784
  Val Recall: 0.8961


Epoch 85/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:42<00:00,  1.27s/it, loss=0.0008]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 85/100:
  Train Loss: 0.0257
  Val F1: 0.8867
  Val Precision: 0.8853
  Val Recall: 0.8881


Epoch 86/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:59<00:00,  1.36s/it, loss=0.0017]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 86/100:
  Train Loss: 0.0173
  Val F1: 0.8853
  Val Precision: 0.8810
  Val Recall: 0.8897


Epoch 87/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:43<00:00,  1.27s/it, loss=0.0022]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 87/100:
  Train Loss: 0.0171
  Val F1: 0.8866
  Val Precision: 0.8809
  Val Recall: 0.8924


Epoch 88/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:45<00:00,  1.28s/it, loss=0.0024]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 88/100:
  Train Loss: 0.0140
  Val F1: 0.8860
  Val Precision: 0.8833
  Val Recall: 0.8887


Epoch 89/100: 100%|█████████████████████████████████████████████████████| 176/176 [04:05<00:00,  1.39s/it, loss=0.0010]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 89/100:
  Train Loss: 0.0218
  Val F1: 0.8853
  Val Precision: 0.8806
  Val Recall: 0.8901


Epoch 90/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:51<00:00,  1.32s/it, loss=0.0007]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 90/100:
  Train Loss: 0.0176
  Val F1: 0.8844
  Val Precision: 0.8836
  Val Recall: 0.8852


Epoch 91/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:45<00:00,  1.28s/it, loss=0.0010]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 91/100:
  Train Loss: 0.0160
  Val F1: 0.8869
  Val Precision: 0.8791
  Val Recall: 0.8948


Epoch 92/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:34<00:00,  1.22s/it, loss=0.0017]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 92/100:
  Train Loss: 0.0180
  Val F1: 0.8872
  Val Precision: 0.8789
  Val Recall: 0.8957


Epoch 93/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:26<00:00,  1.17s/it, loss=0.0008]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 93/100:
  Train Loss: 0.0153
  Val F1: 0.8871
  Val Precision: 0.8799
  Val Recall: 0.8944


Epoch 94/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:25<00:00,  1.17s/it, loss=0.0019]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 94/100:
  Train Loss: 0.0148
  Val F1: 0.8865
  Val Precision: 0.8787
  Val Recall: 0.8944


Epoch 95/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:25<00:00,  1.17s/it, loss=0.0012]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 95/100:
  Train Loss: 0.0125
  Val F1: 0.8862
  Val Precision: 0.8791
  Val Recall: 0.8934


Epoch 96/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:23<00:00,  1.16s/it, loss=0.0014]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 96/100:
  Train Loss: 0.0153
  Val F1: 0.8867
  Val Precision: 0.8827
  Val Recall: 0.8907


Epoch 97/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:24<00:00,  1.16s/it, loss=0.0005]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 97/100:
  Train Loss: 0.0158
  Val F1: 0.8863
  Val Precision: 0.8814
  Val Recall: 0.8912


Epoch 98/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:25<00:00,  1.17s/it, loss=0.0008]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 98/100:
  Train Loss: 0.0145
  Val F1: 0.8865
  Val Precision: 0.8825
  Val Recall: 0.8905


Epoch 99/100: 100%|█████████████████████████████████████████████████████| 176/176 [03:25<00:00,  1.17s/it, loss=0.0011]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 99/100:
  Train Loss: 0.0130
  Val F1: 0.8865
  Val Precision: 0.8819
  Val Recall: 0.8912


Epoch 100/100: 100%|████████████████████████████████████████████████████| 176/176 [03:24<00:00,  1.16s/it, loss=0.0010]


Number of predicted sequences: 813
Number of true label sequences: 813
Lengths of first few predicted sequences: [46, 56, 20, 18, 24]
Lengths of first few true label sequences: [46, 56, 20, 18, 24]
Epoch 100/100:
  Train Loss: 0.0117
  Val F1: 0.8865
  Val Precision: 0.8819
  Val Recall: 0.8912
Number of predicted sequences: 748
Number of true label sequences: 748
Lengths of first few predicted sequences: [17, 8, 59, 48, 23]
Lengths of first few true label sequences: [17, 8, 59, 48, 23]
Test F1-score: 0.8684
Test Precision: 0.8604
Test Recall: 0.8765


In [38]:
import torch
import torch.nn as nn
from transformers import ElectraModel, ElectraTokenizer
from TorchCRF import CRF

# Define entity types and label mappings (as before)
entity_types = ["Corporation", "Locale", "Application", "Malware", "Vulnerability", "Assault_technique",
                "Infected_file", "APT_group", "Ransomware", "Campaign", "Algorithm", "System", "Organization", "Indicator"]

entity_tags = (["O"] + 
               [f"B-{entity}" for entity in entity_types] + 
               [f"I-{entity}" for entity in entity_types] + 
               [f"E-{entity}" for entity in entity_types] + 
               [f"S-{entity}" for entity in entity_types])

entity_label_map = {tag: idx for idx, tag in enumerate(entity_tags)}
entity_label_map['<start>'] = len(entity_label_map)
entity_label_map['<end>'] = len(entity_label_map)
entity_label_map['<pad>'] = len(entity_label_map)

# Reverse mapping for decoding
id2label = {v: k for k, v in entity_label_map.items()}

# Define the JointExtractionModel
class JointExtractionModel(nn.Module):
    def __init__(self, num_labels, hidden_size, num_layers, device):
        super(JointExtractionModel, self).__init__()
        self.device = device
        self.electra = ElectraModel.from_pretrained('google/electra-base-discriminator')
        self.gru = nn.GRU(768, hidden_size, num_layers=num_layers, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, num_labels)
        self.crf = CRF(num_labels)

    def forward(self, input_ids, attention_mask):
        electra_output = self.electra(input_ids=input_ids, attention_mask=attention_mask)
        hidden = electra_output.last_hidden_state
        gru_output, _ = self.gru(hidden)
        logits = self.linear(gru_output)
        return logits

    def decode(self, input_ids, attention_mask):
        logits = self.forward(input_ids, attention_mask)
        return self.crf.viterbi_decode(logits, mask=attention_mask.bool())

# Initialize model and load weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hidden_size = 256
num_layers = 2
num_labels = len(entity_label_map)

model = JointExtractionModel(num_labels, hidden_size, num_layers, device)
model.load_state_dict(torch.load('best_electra_ner_model_16.pt', map_location=device))
model.to(device)
model.eval()

# Initialize tokenizer
tokenizer = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')

def preprocess_text(text, max_length=512):
    words = ['<start>'] + text.split() + ['<end>']
    tokenized_words = [tokenizer.tokenize(word) for word in words]
    tokens = [token for word in tokenized_words for token in word]
    
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    attention_mask = [1] * len(input_ids)
    
    if len(input_ids) < max_length:
        padding_length = max_length - len(input_ids)
        input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
        attention_mask = attention_mask + [0] * padding_length
    else:
        input_ids = input_ids[:max_length]
        attention_mask = attention_mask[:max_length]
    
    return tokens, torch.tensor([input_ids]).to(device), torch.tensor([attention_mask]).to(device)

In [42]:
def predict(text):
    tokens, input_ids, attention_mask = preprocess_text(text)
    
    with torch.no_grad():
        predicted_labels = model.decode(input_ids, attention_mask)[0]
    
    predicted_tags = [id2label[label] for label in predicted_labels]
    
    # Remove special tokens and their predictions
    words = text.split()
    predicted_tags = predicted_tags[1:-1]  # Remove <start> and <end> predictions
    
    # Align predictions with original words
    aligned_labels = []
    token_index = 0
    for word in words:
        word_tokens = tokenizer.tokenize(word)
        if token_index < len(predicted_tags):
            word_label = predicted_tags[token_index]
            aligned_labels.append(word_label)
            token_index += len(word_tokens)
        else:
            aligned_labels.append('O')  # Default to 'O' if we run out of predictions
    
    return list(zip(words, aligned_labels))

def process_file(input_file, output_file):
    with open(input_file, 'r', encoding='utf-8') as f_in, open(output_file, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            line = line.strip()
            if line:
                results = predict(line)
                for word, label in results:
                    f_out.write(f"{word}\t{label}\n")
                f_out.write("\n")  # Empty line between sentences
            else:
                f_out.write("\n")  # Preserve empty lines from input

In [43]:
if __name__ == "__main__":
    input_file = r"C:\Users\Admin\Desktop\nitt\sample_test_dataset_NITT_project.txt" 
    output_file = "predictions_sample.txt"
    process_file(input_file, output_file)
    print(f"Predictions saved to {output_file}")

Predictions saved to predictions_sample.txt


In [33]:
# def save_results(results, output_file):
#     with open(output_file, 'w', encoding='utf-8') as f:
#         for sentence_num, sentence_result in enumerate(results, 1):
#             f.write(f"# Sentence {sentence_num}\n")
#             for word, label in sentence_result:
#                 f.write(f"{word}\t{label}\n")
#             f.write("\n")  # Empty line between sentences

In [34]:
def preprocess_text(text, max_length=512):
    words = ['<start>'] + text.split() + ['<end>']
    tokenized_words = [tokenizer.tokenize(word) for word in words]
    tokens = [token for word in tokenized_words for token in word]
    
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    attention_mask = [1] * len(input_ids)
    
    if len(input_ids) < max_length:
        padding_length = max_length - len(input_ids)
        input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
        attention_mask = attention_mask + [0] * padding_length
    else:
        input_ids = input_ids[:max_length]
        attention_mask = attention_mask[:max_length]
    
    return words, tokens, torch.tensor([input_ids]).to(device), torch.tensor([attention_mask]).to(device)

In [36]:
def process_file(input_file, output_file):
    with open(input_file, 'r', encoding='utf-8') as f:
        sentences = f.readlines()
    
    all_predictions = []
    for sentence in sentences:
        sentence = sentence.strip()
        if sentence:
            result = predict(sentence)
            all_predictions.extend([label for _, label in result])
            print(f"Processed: {sentence}")
            print(f"Predictions: {[label for _, label in result]}")
            print()
    
    # Analyze predictions
    label_counts = Counter(all_predictions)
    total_predictions = len(all_predictions)
    
    print("\nPrediction Analysis:")
    print(f"Total predictions: {total_predictions}")
    for label, count in label_counts.items():
        percentage = (count / total_predictions) * 100
        print(f"{label}: {count} ({percentage:.2f}%)")
    
    # Save detailed results to file
    with open(output_file, 'w', encoding='utf-8') as f:
        for sentence, result in zip(sentences, all_predictions):
            f.write(f"Sentence: {sentence.strip()}\n")
            f.write(f"Predictions: {' '.join(result)}\n\n")
    
    print(f"\nDetailed results saved to {output_file}")

def predict(text):
    words, tokens, input_ids, attention_mask = preprocess_text(text)
    
    print("Debug - Tokens:", tokens)
    
    with torch.no_grad():
        predicted_labels = model.decode(input_ids, attention_mask)[0]
    
    predicted_tags = [id2label[label] for label in predicted_labels]
    
    print("Debug - Raw predicted tags:", predicted_tags)
    
    # Remove special tokens and their predictions
    tokens = tokens[1:-1]  # Remove <start> and <end>
    predicted_tags = predicted_tags[1:-1]  # Remove corresponding predictions
    
    aligned_labels = []
    word_index = 0
    token_index = 0
    
    while word_index < len(words) and token_index < len(tokens):
        word = words[word_index]
        word_tokens = tokenizer.tokenize(word)
        word_label = predicted_tags[token_index]
        
        # For words split into subwords, use the label of the first subword
        aligned_labels.append(word_label)
        
        word_index += 1
        token_index += len(word_tokens)
    
    # Ensure we have a label for each word
    while len(aligned_labels) < len(words):
        aligned_labels.append('O')
    
    result = list(zip(words, aligned_labels))
    print("Debug - Aligned result:", result)
    
    return result

In [37]:
if __name__ == "__main__":
    input_file = r"C:\Users\Admin\Desktop\nitt\sample_test_dataset_NITT_project.txt"  # Replace with your input file path
    output_file = "prediction_results.txt"
    process_file(input_file, output_file)

Debug - Tokens: ['<', 'start', '>', 'open', '##pg', '##p', 'is', 'the', 'protocol', 'in', 'pg', '##p', '(', 'pretty', 'good', 'privacy', ')', '.', '<', 'end', '>']
Debug - Raw predicted tags: ['<start>', '<start>', '<start>', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '<end>', '<end>', '<end>']
Debug - Aligned result: [('<start>', '<start>'), ('OpenPGP', 'O'), ('is', 'O'), ('the', 'O'), ('protocol', 'O'), ('in', 'O'), ('PGP', 'O'), ('(', 'O'), ('Pretty', 'O'), ('Good', 'O'), ('Privacy', 'O'), (')', 'O'), ('.', '<end>'), ('<end>', '<end>')]
Processed: OpenPGP is the protocol in PGP ( Pretty Good Privacy ) .
Predictions: ['<start>', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '<end>', '<end>']

Debug - Tokens: ['<', 'start', '>', 'they', 'use', 'implant', '##s', 'and', 'c2', '(', 'command', 'and', 'control', ')', 'code', 'that', 'shared', 'across', 'multiple', 'chinese', '-', 'speaking', 'apt', '##s', '.', '<', 'end', '>']
Debug - Raw predicted 

### Checking if the model weights have been loaded properly

In [35]:
# def test_single_sentence(sentence):
#     print(f"Testing sentence: {sentence}")
#     result = predict(sentence)
#     print("Prediction result:")
#     for word, label in result:
#         print(f"{word}\t{label}")
#     print("\n")
#     return result

In [17]:
# Initialize model and load weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hidden_size = 256
num_layers = 2
num_labels = len(entity_label_map)

model = JointExtractionModel(num_labels, hidden_size, num_layers, device)
state_dict = torch.load('best_electra_ner_model_16.pt', map_location=device)

# Print the keys in the state dict
print("Keys in state dict:", state_dict.keys())

# Try to load state dict and catch any errors
try:
    model.load_state_dict(state_dict)
    print("Model loaded successfully")
except RuntimeError as e:
    print("Error loading model:", str(e))
    # If there's an error, try to load with strict=False
    model.load_state_dict(state_dict, strict=False)
    print("Model loaded with strict=False")

model.to(device)
model.eval()

Keys in state dict: odict_keys(['electra.embeddings.word_embeddings.weight', 'electra.embeddings.position_embeddings.weight', 'electra.embeddings.token_type_embeddings.weight', 'electra.embeddings.LayerNorm.weight', 'electra.embeddings.LayerNorm.bias', 'electra.encoder.layer.0.attention.self.query.weight', 'electra.encoder.layer.0.attention.self.query.bias', 'electra.encoder.layer.0.attention.self.key.weight', 'electra.encoder.layer.0.attention.self.key.bias', 'electra.encoder.layer.0.attention.self.value.weight', 'electra.encoder.layer.0.attention.self.value.bias', 'electra.encoder.layer.0.attention.output.dense.weight', 'electra.encoder.layer.0.attention.output.dense.bias', 'electra.encoder.layer.0.attention.output.LayerNorm.weight', 'electra.encoder.layer.0.attention.output.LayerNorm.bias', 'electra.encoder.layer.0.intermediate.dense.weight', 'electra.encoder.layer.0.intermediate.dense.bias', 'electra.encoder.layer.0.output.dense.weight', 'electra.encoder.layer.0.output.dense.bias',

JointExtractionModel(
  (electra): ElectraModel(
    (embeddings): ElectraEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ElectraEncoder(
      (layer): ModuleList(
        (0-11): 12 x ElectraLayer(
          (attention): ElectraAttention(
            (self): ElectraSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ElectraSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((76