In [1]:
from transformers import ElectraTokenizer, ElectraForTokenClassification
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence

# Define entity and relation types
entity_types = ["Corporation", "Locale", "Application", "Malware", "Vulnerability", "Assault_technique",
                "Infected_file", "APT_group", "Ransomware", "Campaign", "Algorithm", "System", "Organization", "Indicator"]

relation_types = ["originate_from", "utilizes", "is_susceptible", "contains_product", "contains_file",
                  "belongs_to", "operates_in", "launches", "is_a_type_of", "indicates", "implements"]

# Define label mappings
entity_tags = ["O"] + [f"B-{entity}" for entity in entity_types] + [f"I-{entity}" for entity in entity_types] + [f"E-{entity}" for entity in entity_types] + [f"S-{entity}" for entity in entity_types]
relation_tags = ["O"] + [f"B-{relation}" for relation in relation_types] + [f"I-{relation}" for relation in relation_types] + [f"E-{relation}" for relation in relation_types] + [f"S-{relation}" for relation in relation_types]

entity_label_map = {tag: idx for idx, tag in enumerate(entity_tags)}
relation_label_map = {tag: idx for idx, tag in enumerate(relation_tags)}

# Add start and end tokens to entity_label_map
entity_label_map['<start>'] = len(entity_label_map)
entity_label_map['<end>'] = len(entity_label_map)

# Define the preprocessing function
def preprocess_ner_data(file_path, encoding='utf-8'):
    with open(file_path, 'r', encoding=encoding) as file:
        lines = file.readlines()

    sentences = []
    labels = []
    sentence = []
    label = []

    for line in lines:
        if line.strip() == "":
            if sentence:
                sentences.append(['<start>'] + sentence + ['<end>'])
                labels.append(['<start>'] + label + ['<end>'])
                sentence = []
                label = []
        else:
            word, tag = line.strip().split('\t')
            sentence.append(word)
            label.append(tag)

    if sentence:
        sentences.append(['<start>'] + sentence + ['<end>'])
        labels.append(['<start>'] + label + ['<end>'])

    return sentences, labels

# Preprocess data
train_sentences, train_labels = preprocess_ner_data("C:\\Users\\Admin\\Desktop\\nitt\\data\\MITREtrain.txt")
valid_sentences, valid_labels = preprocess_ner_data("C:\\Users\\Admin\\Desktop\\nitt\\data\\MITREvalid.txt")
test_sentences, test_labels = preprocess_ner_data("C:\\Users\\Admin\\Desktop\\nitt\\data\\MITREtest.txt")

print("Entity Label Map:")
print(entity_label_map)
print("\nRelation Label Map:")
print(relation_label_map)

# Initialize the ELECTRA tokenizer
tokenizer = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')

# Function to tokenize and encode sentences
def tokenize_and_encode(sentences, labels, tokenizer, entity_label_map, max_length=512):
    all_input_ids = []
    all_attention_masks = []
    all_label_ids = []

    for sentence, label in zip(sentences, labels):
        # Tokenize each word separately
        tokenized_words = [tokenizer.tokenize(word) for word in sentence]
        
        # Flatten the list of tokenized words
        tokens = [token for word in tokenized_words for token in word]
        
        # Convert tokens to ids
        input_id = tokenizer.convert_tokens_to_ids(tokens)
        
        # Create attention mask
        attention_mask = [1] * len(input_id)
        
        # Align labels
        label_id = []
        for word, word_tokens in zip(label, tokenized_words):
            label_id.extend([entity_label_map.get(word, entity_label_map["O"])] * len(word_tokens))
        
        # Pad or truncate
        if len(input_id) < max_length:
            padding_length = max_length - len(input_id)
            input_id = input_id + [tokenizer.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length
            label_id = label_id + [entity_label_map['<pad>']] * padding_length
        else:
            input_id = input_id[:max_length]
            attention_mask = attention_mask[:max_length]
            label_id = label_id[:max_length]
        
        all_input_ids.append(input_id)
        all_attention_masks.append(attention_mask)
        all_label_ids.append(label_id)

    # Convert to tensors
    all_input_ids = torch.tensor(all_input_ids)
    all_attention_masks = torch.tensor(all_attention_masks)
    all_label_ids = torch.tensor(all_label_ids)
    
    return all_input_ids, all_attention_masks, all_label_ids

# Add padding token to entity_label_map
entity_label_map['<pad>'] = len(entity_label_map)

# Tokenize with adjusted max_length
train_inputs, train_masks, train_labels = tokenize_and_encode(train_sentences, train_labels, tokenizer, entity_label_map, max_length=512)
valid_inputs, valid_masks, valid_labels = tokenize_and_encode(valid_sentences, valid_labels, tokenizer, entity_label_map, max_length=512)
test_inputs, test_masks, test_labels = tokenize_and_encode(test_sentences, test_labels, tokenizer, entity_label_map, max_length=512)

# Create TensorDataset instances
train_data = TensorDataset(train_inputs, train_masks, train_labels)
valid_data = TensorDataset(valid_inputs, valid_masks, valid_labels)
test_data = TensorDataset(test_inputs, test_masks, test_labels)

train_batch_size = 8
train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=train_batch_size)
test_loader = DataLoader(test_data, batch_size=train_batch_size)

# Print shapes to verify
print("\nTrain Data Shapes:")
print("Input IDs:", train_inputs.shape)
print("Attention Masks:", train_masks.shape)
print("Label IDs:", train_labels.shape)

print(f"\nTotal number of labels: {len(entity_label_map)}")

# Initialize ELECTRA model for token classification
model = ElectraForTokenClassification.from_pretrained('google/electra-base-discriminator', num_labels=len(entity_label_map))

  from .autonotebook import tqdm as notebook_tqdm


Entity Label Map:
{'O': 0, 'B-Corporation': 1, 'B-Locale': 2, 'B-Application': 3, 'B-Malware': 4, 'B-Vulnerability': 5, 'B-Assault_technique': 6, 'B-Infected_file': 7, 'B-APT_group': 8, 'B-Ransomware': 9, 'B-Campaign': 10, 'B-Algorithm': 11, 'B-System': 12, 'B-Organization': 13, 'B-Indicator': 14, 'I-Corporation': 15, 'I-Locale': 16, 'I-Application': 17, 'I-Malware': 18, 'I-Vulnerability': 19, 'I-Assault_technique': 20, 'I-Infected_file': 21, 'I-APT_group': 22, 'I-Ransomware': 23, 'I-Campaign': 24, 'I-Algorithm': 25, 'I-System': 26, 'I-Organization': 27, 'I-Indicator': 28, 'E-Corporation': 29, 'E-Locale': 30, 'E-Application': 31, 'E-Malware': 32, 'E-Vulnerability': 33, 'E-Assault_technique': 34, 'E-Infected_file': 35, 'E-APT_group': 36, 'E-Ransomware': 37, 'E-Campaign': 38, 'E-Algorithm': 39, 'E-System': 40, 'E-Organization': 41, 'E-Indicator': 42, 'S-Corporation': 43, 'S-Locale': 44, 'S-Application': 45, 'S-Malware': 46, 'S-Vulnerability': 47, 'S-Assault_technique': 48, 'S-Infected_fi

Some weights of ElectraForTokenClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
def calculate_class_weights(train_loader, entity_label_map):
    label_counts = {label: 0 for label in entity_label_map.keys()}
    total_labels = 0
    
    for batch in train_loader:
        labels = batch[2]  # Assuming labels are the third item in the batch
        for label_seq in labels:
            for label in label_seq:
                label_name = list(entity_label_map.keys())[list(entity_label_map.values()).index(label.item())]
                label_counts[label_name] += 1
                total_labels += 1
    
    # Handle labels not present in the training data
    min_count = 1  # To avoid division by zero
    for label in label_counts:
        if label_counts[label] == 0:
            label_counts[label] = min_count
            total_labels += min_count
    
    class_weights = {label: total_labels / (len(entity_label_map) * count) for label, count in label_counts.items()}
    
    # Normalize weights
    max_weight = max(class_weights.values())
    class_weights = {label: min(weight / max_weight, 10.0) for label, weight in class_weights.items()}
    
    print("Class weights:")
    for label, weight in class_weights.items():
        print(f"{label}: {weight:.4f}")
    
    return class_weights

In [6]:
# Before your training loop
class_weights = calculate_class_weights(train_loader, entity_label_map)



Class weights:
B-Indicator: 0.6497
B-Malware: 4.8730
B-Organization: 19.5034
B-System: 7.2310
B-Vulnerability: 50.6613
I-Indicator: 3.1188
I-Malware: 35.2053
I-Organization: 67.5484
I-System: 14.9433
I-Vulnerability: 95.4995
O: 0.1162
<start>: 0.0000
<end>: 0.0000
<pad>: 0.0000

Class distribution:
<end>: 8433 (0.59%)
<pad>: 1330973 (92.48%)
<start>: 8433 (0.59%)
B-Indicator: 12788 (0.89%)
B-Malware: 1705 (0.12%)
B-Organization: 426 (0.03%)
B-System: 1149 (0.08%)
B-Vulnerability: 164 (0.01%)
I-Indicator: 2664 (0.19%)
I-Malware: 236 (0.02%)
I-Organization: 123 (0.01%)
I-System: 556 (0.04%)
I-Vulnerability: 87 (0.01%)
O: 71495 (4.97%)


In [7]:
from TorchCRF import CRF
import torch
import torch.nn as nn
from transformers import ElectraModel, ElectraTokenizer, AdamW
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
import time
from seqeval.metrics import f1_score as seqeval_f1_score
from seqeval.scheme import IOB2

# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class JointExtractionModel(nn.Module):
    def __init__(self, num_labels, hidden_size, num_layers, device):
        super(JointExtractionModel, self).__init__()
        self.device = device
        self.electra = ElectraModel.from_pretrained('google/electra-base-discriminator').to(device)
        self.gru = nn.GRU(768, hidden_size, num_layers=num_layers, bidirectional=True, batch_first=True).to(device)
        self.linear = nn.Linear(hidden_size * 2, num_labels).to(device)
        self.crf = CRF(num_labels).to(device)
        print(f"Model initialized with num_labels: {num_labels}")

    def forward(self, input_ids, attention_mask):
        electra_output = self.electra(input_ids=input_ids, attention_mask=attention_mask)
        hidden = electra_output.last_hidden_state
        gru_output, _ = self.gru(hidden)
        logits = self.linear(gru_output)
        return logits

    def loss(self, input_ids, attention_mask, tags):
        logits = self.forward(input_ids, attention_mask)
        log_likelihood = self.crf.forward(logits, tags, mask=attention_mask.bool())
        return -log_likelihood.mean()

    def decode(self, input_ids, attention_mask):
        logits = self.forward(input_ids, attention_mask)
        return self.crf.viterbi_decode(logits, mask=attention_mask.bool())

# Initialize the model
hidden_size = 256  # or any other value you prefer
num_layers = 2  # or any other value you prefer
num_labels = len(entity_label_map)
model = JointExtractionModel(num_labels, hidden_size, num_layers, device)

# Initialize the ELECTRA tokenizer
tokenizer = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')

# The rest of your code (data loading, training loop, etc.) remains the same

Model initialized with num_labels: 60


In [13]:
import torch
import torch.nn as nn
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
from seqeval.metrics import f1_score, precision_score, recall_score
import numpy as np

def train(model, train_loader, val_loader, test_loader, entity_label_map, device, epochs=10, lr=2e-5, max_grad_norm=1.0, warmup_steps=0, patience=3):
    # Calculate class weights
    class_weights = calculate_class_weights(train_loader, entity_label_map)
    class_weight_tensor = torch.tensor([class_weights[label] for label in entity_label_map.keys()]).to(device)
    
    # Initialize optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    
    # Initialize best model tracking
    best_val_f1 = 0
    best_model = None
    
    # Early stopping variables
    no_improve_epochs = 0
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in progress_bar:
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            
            # Forward pass
            logits = model(input_ids, attention_mask)
            
            # Calculate weighted loss
            loss = model.crf(logits, labels, mask=attention_mask.bool(), reduction='none')
            weighted_loss = (loss * class_weight_tensor[labels]).mean()
            
            # Backward pass
            optimizer.zero_grad()
            weighted_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            
            total_loss += weighted_loss.item()
            progress_bar.set_postfix({'loss': f"{weighted_loss.item():.4f}"})
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")
        
        # Validation
        val_f1, val_precision, val_recall = evaluate(model, val_loader, entity_label_map, device)
        print(f"Validation - F1: {val_f1:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}")
        
        # Check if this is the best model
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_model = model.state_dict()
            torch.save(best_model, 'electra_new_cw.pt')
            print("New best model saved!")
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
        
        # Early stopping
        if no_improve_epochs >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break
    
    # Load the best model and evaluate on test set
    model.load_state_dict(torch.load('electra_new_cw.pt'))
    test_f1, test_precision, test_recall = evaluate(model, test_loader, entity_label_map, device)
    print(f"Test - F1: {test_f1:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")
    
    return model

In [14]:
def evaluate(model, data_loader, entity_label_map, device):
    model.eval()
    all_preds = []
    all_labels = []
    inv_label_map = {v: k for k, v in entity_label_map.items()}

    with torch.no_grad():
        for batch in data_loader:
            input_ids, attention_mask, tags = [b.to(device) for b in batch]
            predictions = model.decode(input_ids, attention_mask)
            
            for pred, true, mask in zip(predictions, tags, attention_mask):
                pred = [inv_label_map[p] for p, m in zip(pred, mask) if m.item() == 1]
                true = [inv_label_map[t.item()] for t, m in zip(true, mask) if m.item() == 1 and t.item() != -100]
                
                # Ensure pred and true have the same length
                min_len = min(len(pred), len(true))
                pred = pred[:min_len]
                true = true[:min_len]
                
                all_preds.append(pred)
                all_labels.append(true)

    # Print some debug information
    print(f"Number of predicted sequences: {len(all_preds)}")
    print(f"Number of true label sequences: {len(all_labels)}")
    print(f"Lengths of first few predicted sequences: {[len(p) for p in all_preds[:5]]}")
    print(f"Lengths of first few true label sequences: {[len(t) for t in all_labels[:5]]}")

    f1 = f1_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds)

    return f1, precision, recall

In [15]:
#model = JointExtractionModel(num_labels, hidden_size, num_layers, device)
model = train(model, train_loader, valid_loader, test_loader, entity_label_map, device, epochs=10, lr=2e-5, patience=3)

Class weights:
O: 0.0000
B-Corporation: 1.0000
B-Locale: 1.0000
B-Application: 1.0000
B-Malware: 0.0006
B-Vulnerability: 0.0061
B-Assault_technique: 1.0000
B-Infected_file: 1.0000
B-APT_group: 1.0000
B-Ransomware: 1.0000
B-Campaign: 1.0000
B-Algorithm: 1.0000
B-System: 0.0009
B-Organization: 0.0023
B-Indicator: 0.0001
I-Corporation: 1.0000
I-Locale: 1.0000
I-Application: 1.0000
I-Malware: 0.0042
I-Vulnerability: 0.0115
I-Assault_technique: 1.0000
I-Infected_file: 1.0000
I-APT_group: 1.0000
I-Ransomware: 1.0000
I-Campaign: 1.0000
I-Algorithm: 1.0000
I-System: 0.0018
I-Organization: 0.0081
I-Indicator: 0.0004
E-Corporation: 1.0000
E-Locale: 1.0000
E-Application: 1.0000
E-Malware: 1.0000
E-Vulnerability: 1.0000
E-Assault_technique: 1.0000
E-Infected_file: 1.0000
E-APT_group: 1.0000
E-Ransomware: 1.0000
E-Campaign: 1.0000
E-Algorithm: 1.0000
E-System: 1.0000
E-Organization: 1.0000
E-Indicator: 1.0000
S-Corporation: 1.0000
S-Locale: 1.0000
S-Application: 1.0000
S-Malware: 1.0000
S-Vulnerabi

Epoch 1/10:   0%|                                                                              | 0/352 [00:00<?, ?it/s]


TypeError: forward() got an unexpected keyword argument 'reduction'