In [210]:
import numpy as np
import torch
from sklearn.metrics import roc_auc_score
from torch.optim import AdamW

from client.dataset import get_train_test_datasets
from client.net import Net

# Dataset

In [310]:
dataset_path = "client/anti_fraud_dataset/client_2/client_anti_fraud_dataset.csv"  # client 2's CSV, relative to the working directory

In [311]:
(train_set, test_set) = get_train_test_datasets(dataset_path)  # train/test split is done inside client.dataset

In [312]:
batch_size = 16  # samples per mini-batch, shared by the train and test loaders

In [313]:
DataLoader = torch.utils.data.DataLoader
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)   # reshuffled every epoch
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False)    # fixed order for evaluation

# Model Training

In [315]:
# Seed torch's global RNG so weight initialisation is reproducible. DataLoader
# shuffling without an explicit generator also draws from this global RNG when
# the loader is iterated, so the training batch order becomes reproducible too.
SEED = 42
torch.manual_seed(SEED)

learning_rate = 1e-4  # AdamW step size

# The input width is read from the first training sample's 'transaction' tensor.
model = Net(n_features=train_set[0]['transaction'].shape[0])
optimizer = AdamW(params=model.parameters(), lr=learning_rate)
# BCELoss expects probabilities in [0, 1] (not raw logits) from the model.
loss_fn = torch.nn.BCELoss()

In [316]:
epochs = 15  # full passes over train_dataloader in the training loop

In [318]:
# Train for `epochs` epochs. After each epoch, report the mean train/test BCE
# loss and the test ROC-AUC. BCELoss requires the model to emit probabilities
# in [0, 1]. NOTE(review): presumably Net ends with a sigmoid; TODO confirm.
for epoch in range(epochs):
    # --- training pass ---
    model.train()
    train_epoch_loss = 0.0

    for batch in train_dataloader:
        # (batch, n_features) -> (batch, 1, n_features). The model is fed a
        # singleton middle axis; presumably a channel/sequence dim (TODO confirm).
        transactions = batch['transaction'].unsqueeze(1)
        batch_labels = batch['label']

        optimizer.zero_grad()
        loss = loss_fn(model(transactions), batch_labels)
        loss.backward()
        optimizer.step()

        train_epoch_loss += loss.item()

    # Mean of the per-batch mean losses.
    train_epoch_loss /= len(train_dataloader)

    # --- evaluation pass ---
    model.eval()
    test_epoch_loss = 0.0
    output_chunks = []
    label_chunks = []

    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for batch in test_dataloader:
            transactions = batch['transaction'].unsqueeze(1)
            batch_labels = batch['label']
            output = model(transactions)

            test_epoch_loss += loss_fn(output, batch_labels).item()
            output_chunks.append(output.cpu().numpy().reshape(-1))
            label_chunks.append(batch_labels.cpu().numpy().reshape(-1))

    test_epoch_loss /= len(test_dataloader)

    # Concatenate once per epoch instead of re-stacking arrays inside the batch
    # loop. The flattened predictions and targets stay available as globals.
    outputs = np.concatenate(output_chunks)
    labels = np.concatenate(label_chunks)
    test_roc_auc_score = roc_auc_score(labels, outputs)

    print(
        f"{epoch + 1}. Train loss - {train_epoch_loss}, "
        f"test loss - {test_epoch_loss}, test ROC-AUC - {test_roc_auc_score}"
    )
        

0.9550151975683892
1. Train loss - 0.5729241016365233, test loss - 0.5583413441975912
0.9562310030395137
2. Train loss - 0.44225669900576275, test loss - 0.39403529465198517
0.9592705167173252
3. Train loss - 0.3522117797817503, test loss - 0.3245007743438085
0.9610942249240121
4. Train loss - 0.3366995134523937, test loss - 0.2868928238749504
0.9617021276595744
5. Train loss - 0.3156351276806423, test loss - 0.267974612613519
0.9617021276595744
6. Train loss - 0.3025435053166889, test loss - 0.2595323200027148
0.9623100303951369
7. Train loss - 0.27698197393190294, test loss - 0.24961544076601663
0.9610942249240121
8. Train loss - 0.2579490230196998, test loss - 0.23085092504819235
0.9617021276595744
9. Train loss - 0.2414939829281398, test loss - 0.22434659178058305
0.9617021276595744
10. Train loss - 0.2410434546569983, test loss - 0.21819024781386057
0.9629179331306992
11. Train loss - 0.23216892033815384, test loss - 0.2070761633416017
0.9641337386018236
12. Train loss - 0.2320052