In [2]:
import os
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from CAP import CAPModel
from CAP.dataset import CAPDataSet
from CAP.utils import *

seed_everything(1813) # Fix the random seed

## Functions

In [3]:
def switch2dataset(text_list: List[str], class_numer: int) -> List[Tuple[str, int]]:
    """
    Attach one class label to every text, producing (text, label) samples.
    :param text_list: Texts that all belong to the same class
    :param class_numer: Integer class label paired with each text
    :return: List of (text, class_numer) tuples, in the order of text_list
    """
    return [(sample, class_numer) for sample in text_list]

In [4]:
def split_list(l: list, test_size: float = 0.2) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """
    Split a list into a leading train part and a trailing test part.
    The list is cut at a single index; elements are not shuffled.
    :param l: Dataset tuple list
    :param test_size: Fraction of the list assigned to the test part, in [0, 1]. Default = 0.2
    :return: (train_dataset, test_dataset)
    :raises ValueError: if test_size is outside [0, 1]
    """
    if not 0 <= test_size <= 1:
        raise ValueError(f'test_size must be between 0 and 1, got {test_size}')
    # Round before truncating to absorb float error: 10 * (1 - 0.8) evaluates to
    # 1.9999999999999996, which a bare int() would cut down to 1 instead of 2.
    split_at = int(round(len(l) * (1 - test_size), 9))
    return l[:split_at], l[split_at:]

## Prepare data

In [5]:
# Sample count requested for each text class.
DATA_COUNT = 30_000

In [6]:
# Build the raw text samples for each of the five classes.
# NOTE(review): the positional 3 and 10 look like a min/max word length -
# confirm against load_normal_word.
normal_data = load_normal_word(
    DATA_COUNT,
    3,
    10,
    './data/normal/',
)
# Hash texts are produced from the normal words loaded above.
hash_data = create_hash_text(normal_data, max_length=10)
ipv4_data = create_ip_text(DATA_COUNT, 'ipv4', verbose=False)
ipv6_data = create_ip_text(
    DATA_COUNT,
    'ipv6',
    text_length_range=list(range(3, 11)),  # 3, 4, ..., 10
    verbose=False,
)
mac_data = create_mac_text(DATA_COUNT, verbose=False)

In [7]:
# Label every sample with its class index:
# 0 = normal, 1 = hash, 2 = ipv4, 3 = ipv6, 4 = mac.
# NOTE: the names are rebound in place, so running this cell a second time
# would wrap the already-labelled tuples again.
normal_data, hash_data, ipv4_data, ipv6_data, mac_data = (
    switch2dataset(class_texts, class_index)
    for class_index, class_texts in enumerate(
        (normal_data, hash_data, ipv4_data, ipv6_data, mac_data)
    )
)

In [8]:
# Split each class separately with split_list's default test_size, so every
# class contributes the same train/validation proportion.
(
    (normal_train, normal_val),
    (hash_train, hash_val),
    (ipv4_train, ipv4_val),
    (ipv6_train, ipv6_val),
    (mac_train, mac_val),
) = (
    split_list(class_samples)
    for class_samples in (normal_data, hash_data, ipv4_data, ipv6_data, mac_data)
)

In [9]:
# Concatenate the per-class splits into one training list and one validation list.
train_data = [*normal_train, *hash_train, *ipv4_train, *ipv6_train, *mac_train]
val_data = [*normal_val, *hash_val, *ipv4_val, *ipv6_val, *mac_val]

## Prepare Dataset

In [10]:
# Wrap both splits in CAPDataSet using the same pad_size.
train_dataset, val_dataset = (
    CAPDataSet(samples, pad_size=30) for samples in (train_data, val_data)
)

In [11]:
# Shuffle only the training batches. Validation is iterated in a fixed order so
# that the per-batch-averaged validation loss is comparable from epoch to epoch
# (a reshuffled partial last batch would otherwise make it jitter).
train_loader = DataLoader(train_dataset, batch_size=512, num_workers=0, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, num_workers=0, shuffle=False)

## Define model

In [12]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

n_epochs = 100
SAVE_PATH = './weight/'

# NOTE(review): the meaning of the positional arguments is not visible here;
# the 5 presumably matches the five class labels - confirm against CAPModel.
model = CAPModel(71, 300, 500, 5, num_layers=6, bidirectional=True)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

## Train model

In [None]:
# Train for n_epochs, validate after every epoch, and checkpoint the weights
# whenever the validation loss reaches a new minimum.
model.to(device)

# Per-epoch history of the averaged loss and accuracy.
train_loss = torch.zeros(n_epochs)
valid_loss = torch.zeros(n_epochs)

train_acc = torch.zeros(n_epochs)
valid_acc = torch.zeros(n_epochs)

valid_loss_min = float('inf')  # np.Inf no longer exists in NumPy 2.0
past_lr = 0.0001  # not read anywhere in this cell
low_epoch = 0

# Create the checkpoint directory up front so the first save cannot fail
# after a full epoch of training.
os.makedirs(SAVE_PATH, exist_ok=True)

for e in range(n_epochs):
    print(f'\n====================== [Epoch {e+1}] ======================')

    # ---- training pass ----
    model.train()
    train_tq = tqdm(train_loader)

    count = 0  # samples seen this epoch
    cnt = 0    # batches seen this epoch
    for data, label in train_tq:
        cnt += 1
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, label)

        # argmax over the raw outputs picks the same class as argmax over
        # their softmax, so the softmax is skipped; the comparison runs on
        # `device` and only the resulting count is moved to the host.
        train_acc[e] += (output.argmax(dim=1) == label).sum().item()
        count += len(label)

        loss.backward()
        optimizer.step()
        train_loss[e] += loss.item()
        train_tq.set_description(f'Train_loss : {train_loss[e] / cnt}')

    # Loss is the mean of the batch losses; accuracy is per sample.
    train_loss[e] /= len(train_loader)
    train_acc[e] /= count

    # ---- validation pass (no gradients needed) ----
    model.eval()
    count = 0
    with torch.no_grad():
        for data, label in tqdm(val_loader):
            data, label = data.to(device), label.to(device)

            output = model(data)
            loss = criterion(output, label)

            valid_acc[e] += (output.argmax(dim=1) == label).sum().item()
            count += len(label)

            valid_loss[e] += loss.item()

    valid_loss[e] /= len(val_loader)
    valid_acc[e] /= count

    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(e+1, train_loss[e], valid_loss[e]))
    print('Epoch: {} \tTraining Acc: {:.6f} \tValidation Accuracy: {:.6f}'.format(e+1, train_acc[e], valid_acc[e]))

    # Keep only the weights of the epoch with the lowest validation loss.
    if valid_loss[e] < valid_loss_min:
        valid_loss_min = valid_loss[e].item()
        print(f'Validation loss decreased at Epoch {e+1} - model saved')
        low_epoch = e
        torch.save(model.state_dict(), os.path.join(SAVE_PATH, 'cap-lstm.pth'))

print('\n============== Training Result ==============')
# The tracked value is the minimum validation loss, so it is labelled as such.
print(f'Lowest validation loss : {valid_loss_min}')
print(f'Best epoch : Epoch {low_epoch+1}')