# In [None]:
import thulac
import pickle as pk
from tqdm import tqdm
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import multiprocessing
import random
import torch.optim as optim
import math
import re


# ---- Reproducibility ----
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# NOTE(review): anomaly detection slows down every backward pass; it is a
# debugging aid and is presumably meant to be disabled for full training runs.
torch.autograd.set_detect_anomaly(True)


# ---- Device selection ----
print("CUDA available:", torch.cuda.is_available())
print("Number of GPUs:", torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
# `device` falls back to the CPU when CUDA is missing; the rest of the script must
# use this value rather than a hard-coded "cuda" so it still runs on CPU-only hosts.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    # NOTE(review): this re-seeds CUDA with 0, overriding the 42 set above; kept
    # so that GPU runs stay reproducible with earlier results.
    torch.cuda.manual_seed_all(0)

# THULAC word segmenter, used in segmentation-only mode during preprocessing.
Cutter = thulac.thulac(seg_only=True)


# ---- Model / training hyperparameters ----
hidden_size = 256                   # decoder LSTM size (also the encoder output width)
max_one_sentence_word_num = 128     # words per pseudo-sentence built in collate_fn
max_one_document_sentence_num = 32  # pseudo-sentences per document built in collate_fn
learning_rate = 1e-3
weight_decay = 1e-3                 # NOTE(review): not passed to the Adam optimizer
dropout_rate = 0.5
batch_size = 512
epoch_num = 16
filters = 64                        # CNN output channels per kernel height
min_gram = 2                        # smallest CNN kernel height
max_gram = 5                        # largest CNN kernel height (inclusive)
window_sizes = [2, 3, 4, 5]         # only used to pad very short facts
embedding_dim = 200


# ---- Label-space sizes (one label per line) ----
with open("law.txt", 'r', encoding='utf-8') as f:
    article_num = sum(1 for _ in f)
with open("accu.txt", 'r', encoding='utf-8') as f:
    accusation_num = sum(1 for _ in f)
penalty_num = 11


def load_mapping(file_path):
    """Read one label per line and build lookups in both directions.

    Args:
        file_path: UTF-8 text file with one label name per line.

    Returns:
        ``(id_to_name, name_to_id)``. ``id_to_name`` maps the 0-based line index
        to a one-element list ``[name]`` (callers take ``[0]``). ``name_to_id``
        maps the stripped name back to its index; a repeated name keeps the index
        of its last occurrence.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        names = [line.strip() for line in f]
    id_to_name = {idx: [name] for idx, name in enumerate(names)}
    name_to_id = {name: idx for idx, name in enumerate(names)}
    return id_to_name, name_to_id

# Label vocabularies: a label's line index in its file is the class id used everywhere.
id_to_article, article_to_id = load_mapping("law.txt")
article_num = len(id_to_article)
print(f"article_num: {article_num}")
_article_preview = list(id_to_article.items())[:5]
print(f"Sample id_to_article: {dict(_article_preview)}")
_article_reverse_preview = list(article_to_id.items())[:5]
print(f"Sample article_to_id: {dict(_article_reverse_preview)}")

id_to_crime, crime_to_id = load_mapping("accu.txt")
accusation_num = len(id_to_crime)
print(f"accusation_num: {accusation_num}")


def load_w2v_matrix(numpy_path: str, w2id_path: str):
    """Load a pretrained embedding matrix plus its word -> row-id dictionary.

    Args:
        numpy_path: file readable by ``np.load`` that must hold a single ndarray.
        w2id_path: pickle file holding the word -> id mapping.

    Returns:
        ``(embedding, word2id)`` where ``embedding`` is a float32 torch tensor.

    Raises:
        TypeError: if ``np.load`` yields something other than an ndarray.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(w2id_path, 'rb') as handle:
        vocab = pk.load(handle)
    loaded = np.load(numpy_path)
    if isinstance(loaded, np.ndarray):
        return torch.from_numpy(loaded.astype(np.float32)), vocab
    raise TypeError(f"Expected np.ndarray, got {type(loaded)}")

# `w`: pretrained embedding matrix (float32 tensor); `d`: word -> row-id dictionary.
w, d = load_w2v_matrix(numpy_path="cail_thulac.npy", w2id_path="w2id_thulac.pkl")
print("Model loaded succeed")

def transform_word2id(word):
    """Return the embedding row id of ``word``, or the ``"BLANK"`` id if it is out of vocabulary."""
    blank_id = d["BLANK"]
    return d[word] if word in d else blank_id


def convert_imprisonment_to_term(tempterm):
    """Bucket a ``term_of_imprisonment`` record into one of 11 penalty classes.

    Class 0: death penalty, life imprisonment, or imprisonment coded -1 / -2.
    Classes 1-9: fixed terms from longest to shortest (bounds in months below).
    Class 10: 0 months or less, including any other negative code.
    """
    # Compared with `== True` (not truthiness), matching the original checks.
    if tempterm["death_penalty"] == True:  # noqa: E712
        return 0
    if tempterm["life_imprisonment"] == True:  # noqa: E712
        return 0
    months = tempterm["imprisonment"]
    if months == -1 or months == -2:
        return 0
    # (exclusive lower bound in months, class), longest sentences first.
    for lower_bound, term_class in ((120, 1), (84, 2), (60, 3), (36, 4), (24, 5),
                                    (12, 6), (9, 7), (6, 8), (0, 9)):
        if months > lower_bound:
            return term_class
    return 10


# Representative length (months) for each penalty class of
# convert_imprisonment_to_term; turns a predicted class back into months.
_TERM_TO_MONTHS = {
    0: 216,    # death penalty / life imprisonment -> treated as 18 years
    1: 168,    # > 120 months
    2: 102,    # (84, 120]
    3: 72,     # (60, 84]
    4: 48,     # (36, 60]
    5: 30,     # (24, 36]
    6: 18,     # (12, 24]
    7: 10.5,   # (9, 12]
    8: 7.5,    # (6, 9]
    9: 3,      # (0, 6]
    10: 0,     # no imprisonment
}


def map_term_to_months(term_idx):
    """Return the representative months for class ``term_idx``.

    Unknown classes yield -3, which calculate_imprisonment_score counts as an abstention.
    """
    return _TERM_TO_MONTHS.get(term_idx, -3)



def load_jsonlines(file_path):
    """Parse a JSON-lines file, tolerating blank lines and bad rows.

    Each line is stripped of whitespace and surrounding double quotes. Blank
    lines are skipped. A line that fails to parse is reported and replaced
    by a placeholder record with empty labels and ``imprisonment`` -3, so one
    record is still produced per non-blank line.

    Returns:
        List of parsed (or placeholder) dicts in file order.
    """
    records = []
    bad_count = 0
    with open(file_path, 'r', encoding='utf-8') as handle:
        for line_num, raw in enumerate(handle, 1):
            text = raw.strip().strip('"')
            if not text:
                continue
            try:
                records.append(json.loads(text))
            except json.JSONDecodeError as e:
                print(f"Warning - Parsing error on line {line_num}: {e}, Line content: {text}")
                bad_count += 1
                records.append({
                    "meta": {
                        "relevant_articles": [],
                        "accusation": [],
                        "term_of_imprisonment": {"imprisonment": -3, "death_penalty": False, "life_imprisonment": False},
                    }
                })
    if bad_count:
        print(f"Skipped {bad_count} lines of invalid data, but default values were filled to maintain the same number of lines")
    return records


def calculate_imprisonment_score(true_imprisonments, pred_imprisonments):
    """Log-distance score between true and predicted imprisonment lengths (months).

    Each element is a one-item list ``[months]``. Codes -1 / -2 become the cap
    of 216 months. A prediction of -3 is an abstention and costs the maximum
    distance ``log(216)``. Any pair with another negative value is ignored.
    Each remaining pair contributes ``|log(t + 1) - log(p + 1)|``.

    Returns:
        ``{"score": (log 216 - mean distance) / log 216,
           "abstention_rate": abstentions / len(true_imprisonments)}``,
        or ``{"score": 0, "abstention_rate": 1.0}`` when no pair is scored.
    """
    cap = 216
    distances = []
    abstained = 0
    for true_row, pred_row in zip(true_imprisonments, pred_imprisonments):
        t = true_row[0]
        p = pred_row[0]
        if t in (-2, -1):
            t = cap
        if p in (-2, -1):
            p = cap
        if p == -3:
            abstained += 1
            distances.append(math.log(cap))
        elif not (t < 0 or p < 0):
            distances.append(abs(math.log(t + 1) - math.log(p + 1)))

    if not distances:
        return {"score": 0, "abstention_rate": 1.0}

    mean_distance = sum(distances) / len(distances)
    log_cap = math.log(cap)
    return {
        "score": (log_cap - mean_distance) / log_cap,
        "abstention_rate": abstained / len(true_imprisonments),
    }


# Segmented tokens dropped from the fact text (punctuation and empty strings).
_PUNCTUATION_TOKENS = frozenset([",", ".", "?", "\"", "”", "。", "？", "", "，", "、"])


def preprocess_and_cache(file_path, cache_path):
    """Segment, index and label a JSON-lines case file, caching the result.

    If ``cache_path`` exists, its pickled list is loaded and returned (the
    first five samples' labels are printed). Otherwise, for each non-blank line
    of ``file_path``:

    * the ``fact`` text is segmented with ``Cutter``, punctuation is dropped,
      and words are mapped to embedding ids via ``transform_word2id``;
    * facts shorter than ``min(window_sizes)`` are padded with the BLANK id;
    * article and charge names are mapped to ids;
    * the penalty is bucketed with ``convert_imprisonment_to_term``.

    Cases with no known article, or with no known charge, are dropped. The
    result is then pickled to ``cache_path``.

    Args:
        file_path: JSON-lines input, one case per line.
        cache_path: pickle file used as the cache.

    Returns:
        List of ``(fact_ids, article_ids, accusation_ids, penalty_class)``.
    """
    if os.path.exists(cache_path):
        print(f"Loading cached data from {cache_path}")
        # NOTE: unpickling can execute arbitrary code; only load caches this script wrote.
        with open(cache_path, 'rb') as f:
            data = pk.load(f)
        for i, (fact, article, accusation, penalty) in enumerate(data[:5]):
            print(f"Cached sample {i}: article={article}, accusation={accusation}")
        return data

    print(f"Generating cache for {file_path}")
    min_length = min(window_sizes)
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in tqdm(file, desc=f"Preprocessing {file_path}"):
            line = line.strip()
            if not line:
                continue  # blank lines (e.g. a trailing newline) would make json.loads raise
            d_local = json.loads(line)

            # Items of Cutter.cut are expected to be lists whose first element is the word.
            fact = [
                transform_word2id(w[0])
                for w in Cutter.cut(d_local['fact'])
                if isinstance(w, list) and w and w[0] not in _PUNCTUATION_TOKENS
            ]
            # Guarantee at least min(window_sizes) tokens.
            if len(fact) < min_length:
                fact.extend([d["BLANK"]] * (min_length - len(fact)))

            article_labels = []
            for art in d_local['meta']['relevant_articles']:
                art_str = str(art)
                if art_str in article_to_id:
                    article_labels.append(article_to_id[art_str])
                else:
                    print(f"Warning: Article {art_str} not in article_to_id")
            if not article_labels:
                continue

            accusation_labels = []
            for acc in d_local['meta']['accusation']:
                acc_clean = re.sub(r"[\[\]]", "", acc).strip()  # drop stray square brackets
                if acc_clean in crime_to_id:
                    accusation_labels.append(crime_to_id[acc_clean])
                else:
                    print(f"Warning: Accusation {acc_clean} not in crime_to_id")
            if not accusation_labels:
                continue

            penalty = convert_imprisonment_to_term(d_local['meta']['term_of_imprisonment'])
            data.append((fact, article_labels, accusation_labels, penalty))

    # Write to a temporary file and rename it into place. An interrupted run then
    # cannot leave a truncated cache that would be loaded next time.
    tmp_path = cache_path + ".tmp"
    with open(tmp_path, 'wb') as f:
        pk.dump(data, f)
    os.replace(tmp_path, cache_path)
    print(f"Data cached to {cache_path}")
    return data


class DataAdapterDataset(Dataset):
    """Map-style dataset over the preprocessed train or test split.

    Items are ``(fact_ids, article_ids, accusation_ids, penalty_class)`` tuples
    as returned by ``preprocess_and_cache``.
    """

    # mode -> (raw JSON-lines file, pickle cache path)
    _SPLITS = {
        'train': ('train_data94835.json', 'train94835_cache.pkl'),
        'test': ('test_data7050.json', 'test7050_cache.pkl'),
    }

    def __init__(self, mode='train'):
        """Load (or build and cache) the split named by ``mode``.

        Raises:
            ValueError: if ``mode`` is neither 'train' nor 'test'. Previously this
                surfaced as an UnboundLocalError on ``cache_path``.
        """
        if mode not in self._SPLITS:
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        file_path, cache_path = self._SPLITS[mode]
        self.data = preprocess_and_cache(file_path, cache_path)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


def collate_fn(batch):
    """Turn a list of preprocessed samples into model-ready tensors.

    Each sample is ``(fact_ids, article_ids, accusation_ids, penalty_class)``.
    Fact ids are laid out row-major into a grid: word ``k`` goes to
    pseudo-sentence ``k // max_one_sentence_word_num``, position
    ``k % max_one_sentence_word_num``. Words beyond the grid's capacity are
    dropped and unused cells stay 0.

    Returns:
        fact_tensor: LongTensor (batch, max_one_document_sentence_num, max_one_sentence_word_num).
        article_labels: FloatTensor (batch, article_num), multi-hot.
        accusation_labels: FloatTensor (batch, accusation_num), multi-hot.
        penalty_labels: LongTensor (batch,).
    Label ids >= the corresponding label count are ignored.
    """
    batch_len = len(batch)
    capacity = max_one_sentence_word_num * max_one_document_sentence_num
    fact_tensor = torch.zeros(batch_len, max_one_document_sentence_num, max_one_sentence_word_num, dtype=torch.long)
    # Flat view of the contiguous grid: one slice copy per document reproduces the
    # row-major layout without a Python-level element write for every word.
    flat_facts = fact_tensor.view(batch_len, capacity)
    for i, item in enumerate(batch):
        words = item[0][:capacity]
        if len(words):
            flat_facts[i, :len(words)] = torch.as_tensor(words, dtype=torch.long)

    article_labels = torch.zeros(batch_len, article_num, dtype=torch.float)
    accusation_labels = torch.zeros(batch_len, accusation_num, dtype=torch.float)
    for i, item in enumerate(batch):
        for lbl in item[1]:
            if lbl < article_num:
                article_labels[i, lbl] = 1.0
        for lbl in item[2]:
            if lbl < accusation_num:
                accusation_labels[i, lbl] = 1.0

    penalty_labels = torch.tensor([item[3] for item in batch], dtype=torch.long)
    return fact_tensor, article_labels, accusation_labels, penalty_labels

# Loader settings shared by both splits; only the dataset and shuffling differ.
_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=min(multiprocessing.cpu_count(), 4),
    pin_memory=True,
    drop_last=False,
    collate_fn=collate_fn,
)
dataloader = DataLoader(DataAdapterDataset('train'), shuffle=True, **_loader_kwargs)
test_dataloader = DataLoader(DataAdapterDataset('test'), shuffle=False, **_loader_kwargs)

for _caption, _count in (
    ("Train samples:", len(dataloader.dataset)),
    ("Test samples:", len(test_dataloader.dataset)),
    ("Train batches:", len(dataloader)),
    ("Test batches:", len(test_dataloader)),
):
    print(_caption, _count)


class CNNEncoder(nn.Module):
    """Multi-width text CNN with one global max-pool per kernel height.

    The input ``x`` is (batch, sentences, words, embedding_dim). Sentences and
    words are flattened into a single sequence. Each ``Conv2d`` of height
    ``min_gram..max_gram`` is applied, followed by ReLU and a max over all
    positions.
    Output: (batch, (max_gram - min_gram + 1) * filters).
    """

    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, filters, (height, embedding_dim)) for height in range(min_gram, max_gram + 1)]
        )

    def forward(self, x):
        n_docs, n_sent, n_word = x.size(0), x.size(1), x.size(2)
        seq_len = n_sent * n_word
        tokens = x.view(n_docs, 1, -1, embedding_dim)
        pooled = []
        for height, conv in enumerate(self.convs, start=min_gram):
            feature_map = F.relu(conv(tokens)).view(n_docs, filters, -1)
            # Kernel spans every valid conv position -> one value per filter.
            pooled.append(F.max_pool1d(feature_map, kernel_size=seq_len - height + 1).view(n_docs, -1))
        return torch.cat(pooled, dim=1).view(-1, (max_gram - min_gram + 1) * filters)

class LSTMDecoder(nn.Module):
    """Topological multi-task decoder (articles -> charges -> penalty).

    Task ids are 1 = law article, 2 = charge, 3 = term of penalty; index 0 is
    unused so ids match the node ids from ``generate_graph``. Every task has
    its own ``LSTMCell`` fed the same document vector. After task ``a`` runs,
    its ``(h, c)`` is passed to each successor ``b`` with ``graph[a][b]``: the
    first predecessor's state is copied, and later ones are added through a
    per-edge linear layer.
    """

    def __init__(self):
        super(LSTMDecoder, self).__init__()
        self.feature_len = hidden_size
        self.hidden_dim = hidden_size
        # One output head per task: articles, charges, penalty classes.
        self.outfc = nn.ModuleList([nn.Linear(hidden_size, article_num), nn.Linear(hidden_size, accusation_num), nn.Linear(hidden_size, penalty_num)])
        # NOTE(review): midfc and sigmoid are never used in forward(); kept so the
        # parameter set / state_dict keys stay unchanged.
        self.midfc = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(3)])
        # Slot 0 is a placeholder so cell_list[a] lines up with task id a (1..3).
        self.cell_list = nn.ModuleList([None] + [nn.LSTMCell(hidden_size, hidden_size) for _ in range(3)])
        # [a][b]: projection applied to task a's state before adding it into task b's.
        self.hidden_state_fc_list = nn.ModuleList([nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(4)]) for _ in range(4)])
        self.cell_state_fc_list = nn.ModuleList([nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(4)]) for _ in range(4)])
        self.sigmoid = nn.Sigmoid()

    def init_hidden(self, batch_size, reference=None):
        """Reset the per-task ``(h, c)`` states to zeros.

        Args:
            batch_size: rows in each state tensor.
            reference: optional tensor whose device and dtype the states should
                match. When omitted, the module-level ``device`` and the default
                dtype are used.
        """
        if reference is None:
            target_device, target_dtype = device, None
        else:
            target_device, target_dtype = reference.device, reference.dtype
        self.hidden_list = [
            (torch.zeros(batch_size, self.hidden_dim, device=target_device, dtype=target_dtype),
             torch.zeros(batch_size, self.hidden_dim, device=target_device, dtype=target_dtype))
            for _ in range(4)
        ]

    def forward(self, x):
        """Run the task cells in order and return ``[article, charge, penalty]`` logits.

        Args:
            x: (batch, hidden_size) document features, fed directly to each
               ``LSTMCell(hidden_size, hidden_size)``.
        """
        sample_num = x.size(0)
        # Allocate states where the input lives instead of on a global device.
        self.init_hidden(sample_num, reference=x)
        outputs = []
        graph = generate_graph()
        first = [True] * 4
        for a in range(1, 4):
            hx, cx = self.hidden_list[a]
            h, c = self.cell_list[a](x, (hx, cx))
            for b in range(1, 4):
                if graph[a][b]:
                    hp, cp = self.hidden_list[b]
                    if first[b]:
                        # First predecessor: hand its state over unchanged.
                        first[b] = False
                        hp, cp = h, c
                    else:
                        # Later predecessors: add a projected copy of their state.
                        hp = hp + self.hidden_state_fc_list[a][b](h)
                        cp = cp + self.cell_state_fc_list[a][b](c)
                    self.hidden_list[b] = (hp, cp)

            output = self.outfc[a - 1](h).view(sample_num, -1)
            outputs.append(output)
        return outputs

def generate_graph(s="[(1 2),(2 3),(1 3)]"):
    """Build the task-dependency adjacency matrix used by ``LSTMDecoder``.

    Args:
        s: edge list ``"[(a b),(c d),...]"``. Each ``(a b)`` adds a directed
           edge a -> b between task ids. The default wires article (1) ->
           charge (2) -> penalty (3) plus article -> penalty. ``"[]"`` (or any
           list with no pairs) means no edges.

    Returns:
        Square list of bool lists with ``graph[a][b]`` True iff edge a -> b.
        Its size is ``max node id + 1``, or 4 when there are no edges.
    """
    pairs = []
    for chunk in s.replace("[", "").replace("]", "").split(","):
        tokens = chunk.replace("(", "").replace(")", "").split()
        if not tokens:
            continue  # tolerates "[]", "[ ]" and trailing commas
        pairs.append((int(tokens[0]), int(tokens[1])))
    n = max((max(a, b) for a, b in pairs), default=3) + 1
    graph = [[False] * n for _ in range(n)]
    for a, b in pairs:
        graph[a][b] = True
    return graph

class TopJudge(nn.Module):
    """Frozen word embeddings -> CNN encoder -> dropout -> topological LSTM decoder.

    ``forward`` maps word ids (batch, sentences, words) to the decoder's list of
    ``[article, charge, penalty]`` logit tensors.
    """

    def __init__(self, embedding: "torch.Tensor | np.ndarray", dropout_rate: float = dropout_rate):
        super(TopJudge, self).__init__()
        # Size the lookup table from the pretrained matrix itself (tensor or ndarray).
        weights = torch.as_tensor(embedding, dtype=torch.float32)
        vocab_size, vector_dim = weights.shape
        self.embs = nn.Embedding(vocab_size, vector_dim)
        self.embs.weight.data.copy_(weights)
        self.embs.weight.requires_grad = False  # pretrained vectors stay frozen
        # NOTE(review): CNNEncoder's kernels are sized by the module-level
        # embedding_dim, so vector_dim must equal it.
        self.encoder = CNNEncoder()
        self.decoder = LSTMDecoder()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = self.embs(x)
        x = self.encoder(x)
        x = self.dropout(x)
        x = self.decoder(x)
        return x

model = TopJudge(w).to(device)

# Loss Function and Optimizer
# Article / charge heads are multi-label: BCE on logits, positives weighted 50x.
criterion_multi = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([50.0], device=device))
# Penalty head is single-label over penalty_num classes.
criterion_single = nn.CrossEntropyLoss()
# NOTE(review): the module-level weight_decay hyperparameter is not passed here.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)


for epoch in range(epoch_num):
    # Nothing in this loop switches to eval mode, so setting train mode once per epoch suffices.
    model.train()
    for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{epoch_num}")):
        optimizer.zero_grad()
        # NOTE: despite the name, facts_embedding holds word ids; the model embeds them.
        facts_embedding, article_labels, accusation_labels, penalty_labels = (t.to(device) for t in batch)

        o = model(facts_embedding)

        penalty_logits = o[2]
        if torch.isnan(penalty_logits).any() or torch.isinf(penalty_logits).any():
            print(f"Step {step+1}: o[2] contains nan or inf: {o[2]}")
            break  # abandons the rest of this epoch only

        loss1 = criterion_multi(o[0], article_labels)
        loss2 = criterion_multi(o[1], accusation_labels)
        loss3 = criterion_single(o[2], penalty_labels)
        loss = loss1 + loss2 + loss3
        loss.backward()
        print(f"Epoch {epoch+1}, Step {step+1}, Loss: {loss.item()}, article Loss: {loss1.item()}, accusation Loss: {loss2.item()}, penalty Loss: {loss3.item()}")
        # Log the gradient norm of the first parameter that received a gradient.
        first_graded = next(((pname, p) for pname, p in model.named_parameters() if p.grad is not None), None)
        if first_graded is not None:
            print(f"{first_graded[0]} grad norm: {first_graded[1].grad.norm()}")
        optimizer.step()


# Human-readable description of each penalty class, indexed by class id 0..10.
penalty_strings = dict(enumerate([
    "death_penalty or life_imprisonment",
    "(10 years - ∞)",
    "(7 years - 10 years]",
    "(5 years - 7 years]",
    "(3 years - 5 years]",
    "(2 years - 3 years]",
    "(1 year - 2 years]",
    "(9 months - 12 months]",
    "(6 months - 9 months]",
    "(0 - 6 months]",
    "0",
]))


# Per-sample labels / predictions accumulated during the test loop.
article_labels_all = []
article_preds_all = []
accusation_labels_all = []
accusation_preds_all = []
penalty_labels_all = []
penalty_preds_all = []


def _kept_by_preprocessing(item):
    """Return True if preprocess_and_cache keeps this case.

    A case is kept when it has at least one known article and at least one
    known charge.
    """
    meta = item['meta']
    has_article = any(str(art) in article_to_id for art in meta['relevant_articles'])
    has_accusation = any(re.sub(r"[\[\]]", "", acc).strip() in crime_to_id for acc in meta['accusation'])
    return has_article and has_accusation


# true_imprisonments is paired *by position* with the test dataloader's samples.
# It must therefore skip exactly the cases that preprocess_and_cache drops;
# otherwise every case after the first dropped one is compared with the wrong
# prediction.
test_data = [item for item in load_jsonlines('test_data7050.json') if _kept_by_preprocessing(item)]
true_imprisonments = []
for item in test_data:
    term = item['meta']['term_of_imprisonment']
    if term['death_penalty'] or term['life_imprisonment']:
        imprisonment = 216  # death penalty / life imprisonment are capped at 216 months
    else:
        imprisonment = term['imprisonment']
    true_imprisonments.append([imprisonment])


output_file = "prediction_results.txt"
comparison_file = "imprisonment_comparison.txt"

# Run the model over the test split and write two reports:
# - comparison_file: one TSV row per sample (penalty only);
# - output_file: three rows per sample (article, charge, penalty).
# Labels and predictions are also collected for the metrics below.
with open(output_file, 'w', encoding='utf-8') as f, open(comparison_file, 'w', encoding='utf-8') as cf:
    line_num = 0
    pred_imprisonments = []
    cf.write("Line\tPred_Term_Index\tPred_Term_Str\tTrue_Term_Index\tTrue_Term_Str\tPred_Months\tTrue_Months\tMatch\n")

    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(tqdm(test_dataloader, desc="Testing")):
            facts_embedding = batch[0].to(device)
            article_labels = batch[1].to(device)
            accusation_labels = batch[2].to(device)
            penalty_labels = batch[3].to(device)
            o = model(facts_embedding)

            # Multi-label heads are decoded as top-1: a one-hot row at the highest score.
            article_preds_raw = torch.sigmoid(o[0])
            accusation_preds_raw = torch.sigmoid(o[1])
            k = 1
            _, article_topk_indices = torch.topk(article_preds_raw, k, dim=1)
            _, accusation_topk_indices = torch.topk(accusation_preds_raw, k, dim=1)
            article_preds = torch.zeros_like(article_preds_raw).scatter_(1, article_topk_indices, 1.0)
            accusation_preds = torch.zeros_like(accusation_preds_raw).scatter_(1, accusation_topk_indices, 1.0)

            penalty_preds = torch.argmax(o[2], dim=1)

            # Convert what the per-sample report needs to Python lists once per batch.
            # Iterating tensors element by element creates a 0-d tensor per comparison
            # (and, on GPU, a device sync each time).
            true_penalty_list = penalty_labels.tolist()
            pred_penalty_list = penalty_preds.tolist()
            article_true_rows = article_labels.tolist()
            article_pred_rows = article_preds.tolist()
            accusation_true_rows = accusation_labels.tolist()
            accusation_pred_rows = accusation_preds.tolist()

            batch_pred_imprisonments = [[map_term_to_months(pred)] for pred in pred_penalty_list]
            pred_imprisonments.extend(batch_pred_imprisonments)

            batch_true_imprisonments = true_imprisonments[line_num:line_num + len(true_penalty_list)]
            for i, (true_penalty, pred_penalty) in enumerate(zip(true_penalty_list, pred_penalty_list)):
                true_penalty_str = penalty_strings.get(true_penalty, "Unknown")
                pred_penalty_str = penalty_strings.get(pred_penalty, "Unknown")
                pred_months = map_term_to_months(pred_penalty)
                true_months = batch_true_imprisonments[i][0]
                penalty_match = 1 if true_penalty == pred_penalty else 0
                cf.write(f"{line_num + 1}\t{pred_penalty}\t{pred_penalty_str}\t{true_penalty}\t{true_penalty_str}\t{pred_months}\t{true_months}\t{penalty_match}\n")
                line_num += 1

                true_article = [idx for idx, val in enumerate(article_true_rows[i]) if val == 1]
                pred_article = [idx for idx, val in enumerate(article_pred_rows[i]) if val == 1]
                true_article_str = [id_to_article.get(idx, ["Unknown article"])[0] for idx in true_article]
                pred_article_str = [id_to_article.get(idx, ["Unknown article"])[0] for idx in pred_article]
                article_match = 1 if set(true_article) == set(pred_article) else 0
                f.write(f"{line_num} (predict: article, ans: article)\t{pred_article_str}\t{true_article_str}\t{article_match}\n")

                true_accusation = [idx for idx, val in enumerate(accusation_true_rows[i]) if val == 1]
                pred_accusation = [idx for idx, val in enumerate(accusation_pred_rows[i]) if val == 1]
                true_accusation_str = [id_to_crime.get(idx, ["Unknown accusation"])[0] for idx in true_accusation]
                pred_accusation_str = [id_to_crime.get(idx, ["Unknown accusation"])[0] for idx in pred_accusation]
                accusation_match = 1 if set(true_accusation) == set(pred_accusation) else 0
                f.write(f"{line_num} (predict: accusation, ans: accusation)\t{pred_accusation_str}\t{true_accusation_str}\t{accusation_match}\n")

                f.write(f"{line_num} (predict: penalty, ans: penalty)\t[{pred_penalty_str}]\t[{true_penalty_str}]\t{penalty_match}\n")

            article_labels_all.extend(article_labels.cpu().numpy())
            article_preds_all.extend(article_preds.cpu().numpy())
            accusation_labels_all.extend(accusation_labels.cpu().numpy())
            accusation_preds_all.extend(accusation_preds.cpu().numpy())
            penalty_labels_all.extend(penalty_labels.cpu().numpy())
            penalty_preds_all.extend(penalty_preds.cpu().numpy())


# Imprisonment score: true months vs. the months implied by each predicted class.
imprisonment_metrics = calculate_imprisonment_score(true_imprisonments, pred_imprisonments)
with open(output_file, 'a', encoding='utf-8') as f:
    f.writelines([
        f"\nImprisonment Normalized Score: {imprisonment_metrics['score']:.3f}\n",
        f"Imprisonment Abstention Rate: {imprisonment_metrics['abstention_rate']:.3f}\n",
    ])


# Stack the per-sample rows collected during testing into NumPy arrays.
(
    true_articles_bin,
    pred_articles_bin,
    true_accusations_bin,
    pred_accusations_bin,
    true_imprisonments_bin,
    pred_imprisonments_bin,
) = map(np.array, (
    article_labels_all,
    article_preds_all,
    accusation_labels_all,
    accusation_preds_all,
    penalty_labels_all,
    penalty_preds_all,
))


def compute_confusion_matrix(true_bin, pred_bin, num_labels):
    """Per-label TP / FP / FN / TN counts for 0/1 indicator matrices.

    Args:
        true_bin, pred_bin: 2-D arrays (samples, labels) of 0/1 values.
        num_labels: number of leading label columns to evaluate.

    Returns:
        List of ``{"TP", "FP", "FN", "TN"}`` int dicts, one per label column.
    """
    def _counts(col):
        gold = true_bin[:, col]
        guess = pred_bin[:, col]
        return {
            "TP": int((gold * guess).sum()),
            "FP": int(((1 - gold) * guess).sum()),
            "FN": int((gold * (1 - guess)).sum()),
            "TN": int(((1 - gold) * (1 - guess)).sum()),
        }

    return [_counts(col) for col in range(num_labels)]


def get_value(res):
    """Precision, recall and F1 for one label's confusion counts.

    With zero true positives, all three are 1.0 if there are also no false
    positives and no false negatives (label never present, never predicted);
    otherwise all three are 0.0.

    Returns:
        ``(precision, recall, f1)`` as floats.
    """
    tp = res["TP"]
    if tp == 0:
        score = 1.0 if (res["FP"] == 0 and res["FN"] == 0) else 0.0
        return score, score, score
    fp = res["FP"]
    fn = res["FN"]
    predicted = tp + fp
    actual = tp + fn
    precision = 1.0 * tp / predicted if predicted > 0 else 0.0
    recall = 1.0 * tp / actual if actual > 0 else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1

def gen_result(res):
    """Macro-average precision, recall and F1 over per-label confusion counts.

    Args:
        res: list of per-label count dicts as built by compute_confusion_matrix;
            each is scored with get_value.

    Returns:
        ``(macro_precision, macro_recall, macro_f1)``; all 0.0 when ``res`` is empty.
    """
    # The unused micro-average totals that used to be accumulated here were dead code.
    if not res:
        return 0.0, 0.0, 0.0
    scores = [get_value(label_counts) for label_counts in res]
    count = len(scores)
    macro_precision = sum(s[0] for s in scores) / count
    macro_recall = sum(s[1] for s in scores) / count
    macro_f1 = sum(s[2] for s in scores) / count
    return macro_precision, macro_recall, macro_f1


def imprisonment_multi_label_accuracy(true_labels, pred_labels):
    """Fraction of rows whose predicted indicator vector equals the true one exactly (0 if empty)."""
    exact_rows = np.all(true_labels == pred_labels, axis=1)
    hits = np.sum(exact_rows)
    n_rows = len(true_labels)
    if n_rows <= 0:
        return 0
    return hits / n_rows

def multi_label_accuracy(true_labels, pred_labels):
    """Lenient multi-label accuracy.

    A row counts as correct when either set of active labels (value == 1)
    contains the other. With top-1 predictions, that means the predicted label
    is one of the true labels, or the true set is that single label (or
    empty). Returns 0 for empty input.
    """
    def _is_hit(true_row, pred_row):
        gold = set(np.where(true_row == 1)[0])
        guess = set(np.where(pred_row == 1)[0])
        return gold <= guess or guess <= gold

    hits = sum(1 for t_row, p_row in zip(true_labels, pred_labels) if _is_hit(t_row, p_row))
    n_rows = len(true_labels)
    return hits / n_rows if n_rows > 0 else 0


# Accuracy: lenient subset match for the multi-label tasks, exact match for penalty.
articles_acc = multi_label_accuracy(true_articles_bin, pred_articles_bin)
accusations_acc = multi_label_accuracy(true_accusations_bin, pred_accusations_bin)
# Penalty is single-label; one-hot encode class ids so the indicator-matrix helpers apply.
_penalty_eye = np.eye(penalty_num)
true_imprisonments_onehot = _penalty_eye[true_imprisonments_bin]
pred_imprisonments_onehot = _penalty_eye[pred_imprisonments_bin]
imprisonments_acc = imprisonment_multi_label_accuracy(true_imprisonments_onehot, pred_imprisonments_onehot)

# Macro-averaged precision / recall / F1 per task.
res_articles = compute_confusion_matrix(true_articles_bin, pred_articles_bin, article_num)
macro_p_a, macro_r_a, macro_f_a = gen_result(res_articles)
res_accusations = compute_confusion_matrix(true_accusations_bin, pred_accusations_bin, accusation_num)
macro_p_c, macro_r_c, macro_f_c = gen_result(res_accusations)
res_imprisonments = compute_confusion_matrix(true_imprisonments_onehot, pred_imprisonments_onehot, penalty_num)
macro_p_i, macro_r_i, macro_f_i = gen_result(res_imprisonments)


# Console report: accuracy and macro P/R/F1 for each task, then the imprisonment score.
for _title, _acc, _p, _r, _f in (
    ("Law_Articles:", articles_acc, macro_p_a, macro_r_a, macro_f_a),
    ("Charges:", accusations_acc, macro_p_c, macro_r_c, macro_f_c),
    (" Terms of Penalty:", imprisonments_acc, macro_p_i, macro_r_i, macro_f_i),
):
    print(_title)
    print(f"  Accuracy: {_acc:.3f}")
    print(f"  Macro Precision: {_p:.3f}")
    print(f"  Macro Recall: {_r:.3f}")
    print(f"  Macro F1: {_f:.3f}")
print(f"  Normalized Score: {imprisonment_metrics['score']:.3f}")
print(f"  Abstention Rate: {imprisonment_metrics['abstention_rate']:.3f}")

