In [1]:
import os
import csv
from collections import defaultdict
import string
from sklearn.preprocessing import LabelEncoder
from gensim.utils import simple_preprocess

import pandas as pd
from sklearn.preprocessing import LabelEncoder
import numpy as np

from transformers import BertTokenizer, BertForSequenceClassification

import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler
import tensorflow as tf
import torchvision
import torchvision.transforms as transforms

from tabulate import tabulate
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, confusion_matrix
from tqdm import trange
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def clean(text):
    """Tokenise ``text`` and return the kept tokens joined by single spaces.

    The text is split with gensim's ``simple_preprocess``; any token that
    occurs as a substring of ``string.punctuation`` is discarded.

    Args:
        text: Raw document text.

    Returns:
        A single string with the remaining tokens separated by one space.
    """
    tokens = simple_preprocess(text)
    # Substring test against the punctuation string, exactly one token at a time.
    kept = filter(lambda token: token not in string.punctuation, tokens)
    return ' '.join(kept)

def create_file_data(category_path, label):
    """Read every ``.txt`` file of one category and pair its cleaned text with a label.

    Args:
        category_path: Directory that holds the documents of a single category.
        label: Label stored alongside every document read from that directory.

    Returns:
        A plain dict mapping each file name to
        ``{'text': <cleaned text>, 'label': label}``. Entries are inserted in
        sorted file-name order; files without a ``.txt`` suffix are ignored.
    """
    file_data = {}

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # resulting row order (and everything derived from it) reproducible.
    for filename in sorted(os.listdir(category_path)):
        if not filename.endswith('.txt'):
            continue
        # 'utf-8-sig' transparently drops a leading byte-order mark.
        with open(os.path.join(category_path, filename), 'r', encoding='utf-8-sig') as file:
            text = file.read()
        file_data[filename] = {'text': clean(text), 'label': label}

    return file_data

# Root directory of the corpus; each sub-directory holds the rulings of one
# category. Adjust this single constant when running on another machine.
DATA_DIR = '/Users/julia/PycharmProjects/judgeAI-main/orzeczenia'

file_data_k = create_file_data(os.path.join(DATA_DIR, 'kradziez'), 'kradziez')
file_data_o = create_file_data(os.path.join(DATA_DIR, 'oszustwo'), 'oszustwo')
file_data_z = create_file_data(os.path.join(DATA_DIR, 'zdrowie'), 'przestępstwo przeciwko zdrowiu')
file_data_ko = create_file_data(os.path.join(DATA_DIR, 'komunikacja'), 'przestępstwo w komunikacji')

# NOTE(review): entries are keyed by file name only, so a file name that occurs
# in more than one category directory is overwritten by the later category.
# Confirm that file names are unique across categories.
merged_dict = {**file_data_k, **file_data_o, **file_data_z, **file_data_ko}

csv_filename = 'data.csv'

# Fit the encoder ONCE on all labels. Fitting it on a single label per row (as
# a per-row fit_transform does) always yields class 0, which collapses every
# document into the same encoded class.
label_encoder = LabelEncoder()
all_labels = [values['label'] for values in merged_dict.values()]
encoded_labels = label_encoder.fit_transform(all_labels) if all_labels else []

with open(csv_filename, 'w', encoding='utf-8-sig', newline='') as csvfile:
    writer = csv.writer(csvfile)

    writer.writerow(['Key', 'Text', 'Label', 'Label_encoded'])

    for (key, values), label_encoded in zip(merged_dict.items(), encoded_labels):
        lem_text = values['text'].lower()
        writer.writerow([key, lem_text, values['label'], label_encoded])

In [None]:
# Reload the exported corpus from disk.
df = pd.read_csv('data.csv')

# Document texts and their integer-encoded labels as NumPy arrays.
X, y = df['Text'].values, df['Label_encoded'].values

In [None]:
# The checkpoint is a *cased* model. With do_lower_case=True BERT's basic
# tokenizer also strips accents, which removes Polish diacritics
# (e.g. 'ż' -> 'z') and no longer matches the cased vocabulary, so keep it off.
tokenizer = BertTokenizer.from_pretrained('dkleczek/bert-base-polish-cased-v1', do_lower_case=False)

In [13]:
def calculate_chunksize(text, tokenizer, target_length=512):
    """Return a chunk size that is a multiple of ``target_length``.

    The element at index 1 of ``text`` is tokenised, and its token count is
    rounded up to the next multiple of ``target_length`` that is strictly
    greater than the count (an exact multiple therefore yields one more block).

    Args:
        text: Indexable object; only ``text[1]`` is inspected.
        tokenizer: Object exposing ``tokenize(str)`` that returns a sized sequence.
        target_length: Block length the result is a multiple of.

    Returns:
        ``int`` chunk size, expressed as a number of tokens.
    """
    # NOTE(review): only the element at index 1 drives the result, and an input
    # with fewer than two elements raises IndexError — confirm this is intended.
    token_count = len(tokenizer.tokenize(text[1]))
    block_count = int((token_count + target_length) / target_length)
    return block_count * target_length

# One chunk size for the whole corpus, derived by calculate_chunksize.
chunksize = calculate_chunksize(X, tokenizer)

# Cut every document into consecutive slices of `chunksize` items; each slice
# inherits the label of the document it came from.
# NOTE(review): `chunksize` comes from a token count but is applied here as a
# slice length on `text` (presumably a str, i.e. characters) — confirm the
# unit mismatch is intended.
chunks = []
for text, label in zip(X, y):
    chunks.extend(
        {'Chunk': text[start:start + chunksize], 'Label': label}
        for start in range(0, len(text), chunksize)
    )

# Persist the chunk table, then continue from the copy read back from disk.
pd.DataFrame(chunks).to_csv('new_data.csv', index=False, encoding='utf-8-sig')
df = pd.read_csv('new_data.csv')



In [None]:
# Per-sample encodings are accumulated in these lists before being stacked.
token_id = []
attention_masks = []

def preprocessing(input_text, tokenizer):
    """Encode one text as fixed-length BERT inputs.

    Args:
        input_text: Text to encode.
        tokenizer: Tokenizer exposing ``encode_plus``.

    Returns:
        Whatever ``tokenizer.encode_plus`` returns for the text, requested as
        PyTorch tensors with an attention mask, padded and truncated to
        512 tokens (special tokens included).
    """
    return tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        max_length=512,
        # Explicit padding/truncation replaces the deprecated
        # `pad_to_max_length=True`, which relied on an implicit legacy
        # truncation fallback to keep sequences at max_length.
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )

# Encode every chunk, collecting input ids and attention masks separately.
encodings = [preprocessing(sample, tokenizer) for sample in df['Chunk']]
token_id.extend(encoding['input_ids'] for encoding in encodings)
attention_masks.extend(encoding['attention_mask'] for encoding in encodings)

# Stack the per-sample tensors along the first dimension.
token_id = torch.cat(token_id, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)

# Targets as a tensor, taken from the chunk table's 'Label' column.
y = torch.tensor(df['Label'].values)

In [15]:
# Stratified 70/30 split over sample indices. A fixed random_state makes the
# split (and therefore the reported validation metrics) reproducible.
# NOTE(review): the split is done per sample in `y`; if several samples stem
# from the same source document they can land in both sets — confirm whether a
# document-level (grouped) split is needed.
train_idx, val_idx = train_test_split(
    np.arange(len(y)),
    test_size=0.3,
    shuffle=True,
    stratify=y,
    random_state=42
)

train_set = TensorDataset(token_id[train_idx], attention_masks[train_idx], y[train_idx])
val_set = TensorDataset(token_id[val_idx], attention_masks[val_idx], y[val_idx])

# Training batches are drawn in random order ...
train_dataloader = DataLoader(
    train_set,
    sampler=RandomSampler(train_set),
    batch_size=16
)

# ... validation batches in a fixed, sequential order.
validation_dataloader = DataLoader(
    val_set,
    sampler=SequentialSampler(val_set),
    batch_size=16
)

In [None]:
# Pretrained Polish BERT with a sequence-classification head for 4 classes.
model = BertForSequenceClassification.from_pretrained('dkleczek/bert-base-polish-cased-v1', num_labels=4)

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-08)

In [None]:
# Select the device first and move the model onto it. An unconditional
# model.cuda() raises on machines without a CUDA device even though the
# device selection below is meant to fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

epochs = 10

# Class ids in the order used for the rows/columns of the confusion matrix.
class_ids = list(range(4))

# Confusion counts summed over all epochs (rows: true class, columns: predicted).
total_confusion_matrix = np.zeros((4, 4), dtype=float)

# One validation score per epoch.
total_val_accuracy = []
total_val_precision = []
total_val_recall = []
total_val_f1 = []

for epoch in trange(epochs, desc='Epoch'):

    # ---------------- training ----------------
    model.train()

    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0

    for step, batch in enumerate(train_dataloader):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        optimizer.zero_grad()

        # Passing labels makes the model return the loss alongside the logits.
        train_output = model(b_input_ids,
                             token_type_ids=None,
                             attention_mask=b_input_mask,
                             labels=b_labels)

        train_output.loss.backward()
        optimizer.step()

        tr_loss += train_output.loss.item()
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1

    # ---------------- validation ----------------
    model.eval()

    val_labels = []
    val_predictions = []

    for batch in validation_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():
            eval_output = model(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask)

        logits = eval_output.logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        val_labels.extend(label_ids.flatten())
        val_predictions.extend(np.argmax(logits, axis=1).flatten())

    # Metrics are computed once over the whole validation set. The previous
    # per-batch, per-class bookkeeping spread fp/fn/tn counts over every other
    # class (so the matrix was not a confusion matrix) and reported a
    # one-vs-rest accuracy that is inflated by true negatives.
    epoch_confusion_matrix = confusion_matrix(
        val_labels, val_predictions, labels=class_ids
    ).astype(float)

    epoch_val_accuracy = accuracy_score(val_labels, val_predictions)
    # Macro average; a class with an undefined score contributes 0 instead of
    # a non-numeric placeholder that would break arithmetic and formatting.
    epoch_val_precision = precision_score(
        val_labels, val_predictions, average='macro', zero_division=0
    )
    epoch_val_recall = recall_score(
        val_labels, val_predictions, average='macro', zero_division=0
    )
    # Harmonic mean of macro precision and macro recall, guarded against 0/0.
    pr_sum = epoch_val_precision + epoch_val_recall
    epoch_val_f1 = (2 * epoch_val_precision * epoch_val_recall) / pr_sum if pr_sum > 0 else 0.0

    print('\n\t - Train loss: {:.4f}'.format(tr_loss / max(nb_tr_steps, 1)))
    print('\t - Validation Accuracy: {:.4f}'.format(epoch_val_accuracy))
    print('\t - Validation Precision: {:.4f}'.format(epoch_val_precision))
    print('\t - Validation Recall: {:.4f}'.format(epoch_val_recall))
    print('\t - Validation F1-score: {:.4f}'.format(epoch_val_f1))

    total_confusion_matrix = total_confusion_matrix.astype(float) + epoch_confusion_matrix
    total_val_accuracy.append(epoch_val_accuracy)
    total_val_precision.append(epoch_val_precision)
    total_val_recall.append(epoch_val_recall)
    total_val_f1.append(epoch_val_f1)

def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or float NaN when it is empty."""
    # A float NaN (rather than the string 'nan') keeps the ':.4f' formatting
    # below working even when no epoch produced a score.
    return sum(values) / len(values) if len(values) > 0 else float('nan')

# Averages of the per-epoch validation scores.
total_avg_val_accuracy = _mean_or_nan(total_val_accuracy)
total_avg_val_precision = _mean_or_nan(total_val_precision)
total_avg_val_recall = _mean_or_nan(total_val_recall)
total_avg_val_f1 = _mean_or_nan(total_val_f1)

# Row-normalise the accumulated matrix so each row sums to 1; rows without any
# counts stay at 0 instead of producing a division by zero.
conf_counts = total_confusion_matrix.astype('float')
row_totals = conf_counts.sum(axis=1)[:, np.newaxis]
normalized_conf_matrix = np.divide(
    conf_counts,
    row_totals,
    out=np.zeros_like(conf_counts),
    where=row_totals > 0
)

fig, ax = plt.subplots(figsize=(8, 6))
labels = ["kradzież", "oszustwo", "przestępstwo przeciwko zdrowiu", "przestępstwo w komunikacji"]
# Values are shown as whole percentages of each row.
sns.heatmap(np.round(normalized_conf_matrix, 2) * 100, annot=True, fmt=".0f", cmap="Blues", cbar_kws={'label': 'Procenty'}, xticklabels=labels, yticklabels=labels, ax=ax)
ax.set_title('Confusion Matrix')
ax.set_xlabel('Przewidziana Etykieta')
ax.set_ylabel('Rzeczywista Etykieta')
plt.show()

print(f'Total Validation Accuracy: {total_avg_val_accuracy:.4f}')
print(f'Total Validation Precision: {total_avg_val_precision:.4f}')
print(f'Total Validation Recall: {total_avg_val_recall:.4f}')
print(f'Total Validation F1-score: {total_avg_val_f1:.4f}')