In [None]:
import os

import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from tqdm.auto import tqdm
from transformers import AdamW
from transformers import BertTokenizer, BertForSequenceClassification

from CAP.utils import *

# Fixed seed handed to seed_everything (provided by CAP.utils) so that
# repeated runs of the notebook start from the same random state.
SEED = 14759
seed_everything(SEED)

## Preprocessing function

In [None]:
def preprocess_data(dataframe, tokenizer, max_len=64):
    """Tokenize a labelled text frame into a ``TensorDataset``.

    Args:
        dataframe: pandas DataFrame with a ``'label'`` column containing the
            raw text and a ``'class'`` column containing the class id
            (anything accepted by ``int()``).
        tokenizer: Callable Hugging Face style tokenizer. It is invoked once
            on the whole list of texts and must return a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_len: Every text is truncated / padded to exactly this many tokens.

    Returns:
        ``TensorDataset(input_ids, attention_masks, labels)`` whose first
        dimension equals ``len(dataframe)``. An empty frame yields an empty
        dataset instead of raising.
    """
    # Convert to str *before* upper-casing: the previous
    # ``str(value.upper())`` raised AttributeError for any non-string cell
    # (numbers, NaN, ...).
    texts = [str(text).upper() for text in dataframe['label']]
    labels = torch.tensor([int(cls) for cls in dataframe['class']], dtype=torch.long)

    if not texts:
        # Nothing to tokenize; return correctly shaped empty tensors rather
        # than asking the tokenizer to build tensors from an empty batch.
        empty = torch.empty((0, max_len), dtype=torch.long)
        return TensorDataset(empty, empty.clone(), labels)

    # One batched call replaces the per-row ``encode_plus`` loop.
    # ``padding='max_length'`` is the supported spelling of the deprecated
    # ``pad_to_max_length=True`` flag.
    encoded = tokenizer(
        texts,
        add_special_tokens=True,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    return TensorDataset(encoded['input_ids'], encoded['attention_mask'], labels)

## Prepare data

In [None]:
# Number of samples requested from each text generator; also used as the
# size of each per-class block when the class ids are built.
DATA_COUNT = 30_000

In [None]:
# Build the raw text corpus by appending the output of five generators in a
# fixed order. The helpers come from the CAP.utils star import and are not
# visible in this notebook.
# NOTE(review): later cells assign class ids by position in blocks of
# DATA_COUNT, so each call presumably contributes exactly DATA_COUNT items -
# confirm against CAP.utils.
# NOTE(review): `+=` is assumed to extend a Python list in place; verify the
# return type of load_normal_word.
data = load_normal_word(DATA_COUNT, 3, 10, './data/normal/')
# create_hash_text receives only the entries loaded so far, so this call has
# to stay directly after load_normal_word.
data += create_hash_text(data, max_length=10)
data += create_ip_text(DATA_COUNT, 'ipv4', verbose=False)
data += create_ip_text(DATA_COUNT, 'ipv6', text_length_range=[3,4,5,6,7,8,9,10], verbose=False)
data += create_mac_text(DATA_COUNT, verbose=False)

In [None]:
# Assemble the labelled frame: 'label' holds the raw text and 'class' an
# integer id. Ids 0..4 are laid out in consecutive blocks of DATA_COUNT rows,
# matching the order in which `data` was concatenated.
data_dict = {
    'label': data,
    'class': [class_id for class_id in range(5) for _ in range(DATA_COUNT)],
}

df = pd.DataFrame(data_dict)

In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 128

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

# Hold out 10% of the rows for validation, then tokenize each split
# (training split first, validation split second).
train_df, val_df = train_test_split(df, test_size=0.1, shuffle=True)
train_dataset, val_dataset = (
    preprocess_data(split_df, tokenizer) for split_df in (train_df, val_df)
)

# Training batches are drawn in random order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    sampler=RandomSampler(train_dataset),
)

# Validation batches keep the dataset order.
validation_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    sampler=SequentialSampler(val_dataset),
)

## Define model

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Pretrained BERT encoder with a 5-way classification head, moved to `device`.
model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=5)
model = model.to(device)

## Train model

In [None]:
# Training hyper-parameters and early-stopping state read by the training loop.
num_epochs = 30  # upper bound on the number of passes over the training data
stop = 3         # early-stopping threshold that `early` is compared against
early = 0        # counter of epochs whose validation loss did not improve

optimizer = AdamW(params=model.parameters(), lr=1e-4)

In [None]:
# Fine-tune the model for up to `num_epochs` epochs. After every epoch the
# validation split is evaluated; the weights are checkpointed whenever the
# validation loss reaches a new minimum, and training stops early once the
# loss has failed to improve for more than `stop` consecutive epochs.

# Make sure the checkpoint directory exists before the first torch.save call.
os.makedirs('./CAP_result', exist_ok=True)

best_val_loss = float('inf')
# Reset the early-stopping counter here so that re-running this cell does not
# inherit a stale count from a previous run.
early = 0

for epoch in range(num_epochs):
    # ---- Training phase ----
    model.train()
    total_train_loss = 0
    train_progress_bar = tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}")
    for step, batch in enumerate(train_progress_bar, start=1):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        optimizer.zero_grad()

        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels)
        loss = outputs.loss
        total_train_loss += loss.item()

        loss.backward()
        optimizer.step()

        # Divide by the true number of processed batches. tqdm's
        # `last_print_n` only advances when the bar is redrawn, so using it
        # as the divisor inflates the displayed running average.
        train_progress_bar.set_postfix({'avg_train_loss': total_train_loss / step})

    avg_train_loss = total_train_loss / len(train_dataloader)
    print(f"\nAverage training loss for epoch {epoch + 1}: {avg_train_loss}")

    # ---- Validation phase ----
    model.eval()
    total_eval_loss = 0
    total_eval_accuracy = 0  # running count of correct predictions
    seen_samples = 0         # running count of evaluated samples
    eval_progress_bar = tqdm(validation_dataloader, desc=f"Validation Epoch {epoch + 1}")
    for step, batch in enumerate(eval_progress_bar, start=1):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        with torch.no_grad():
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels)
            loss = outputs.loss
            logits = outputs.logits
            total_eval_loss += loss.item()

        preds = torch.argmax(logits, dim=1)
        correct_predictions = torch.sum(preds == b_labels)
        total_eval_accuracy += correct_predictions.item()
        seen_samples += b_labels.size(0)

        # Use the exact number of samples seen so far; multiplying a batch
        # count by the current batch size is wrong whenever the last batch is
        # smaller than the others.
        eval_progress_bar.set_postfix({'avg_val_loss': total_eval_loss / step,
                                       'val_accuracy': total_eval_accuracy / seen_samples})

    avg_eval_loss = total_eval_loss / len(validation_dataloader)
    avg_eval_accuracy = total_eval_accuracy / len(validation_dataloader.dataset)
    print(f"\nAverage validation loss for epoch {epoch + 1}: {avg_eval_loss}")
    print(f"Validation accuracy for epoch {epoch + 1}: {avg_eval_accuracy}")

    # Save model at best validation loss
    if avg_eval_loss < best_val_loss:
        best_val_loss = avg_eval_loss
        # An improvement restarts the patience window, so `early` counts
        # *consecutive* non-improving epochs rather than a lifetime total.
        early = 0
        torch.save(model.state_dict(), f'./CAP_result/MAERec-S_{avg_eval_loss}_{avg_eval_accuracy}.pth')
        print(f"Model saved: epoch {epoch+1}, val_loss {avg_eval_loss:.4f}")

    else:
        # Same threshold as before: `stop` non-improving epochs are
        # tolerated and training ends on the next one.
        if early >= stop:
            print(f"Early stopping at epoch {epoch + 1}")
            break
        early += 1

print("Training and validation complete!")