In [1]:
!export CUDA_VISIBLE_DEVICES=0


In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Now import your GPU-related libraries, such as PyTorch
import torch
torch.cuda.empty_cache()

torch.cuda.memory_summary(device=None, abbreviated=False)




In [3]:
# Import libraries and set up CUDA
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
torch.cuda.empty_cache()

import pandas as pd
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import random



random_seed = 777  # You can change this value to any seed you prefer
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)


# Load and preprocess your data for H1, H2, and H3
model1_data = pd.read_csv("model1_2636.csv")
texts_h1 = model1_data['text'].tolist()
labels_h1 = model1_data.iloc[:, 1].values.tolist()

model2_data = pd.read_csv("model2_2636.csv")
texts_h2 = model2_data['text'].tolist()
labels_h2 = model2_data.iloc[:, 1].values.tolist()

model3_data = pd.read_csv("model3_2636.csv")
texts_h3 = model3_data['text'].tolist()
labels_h3 = model3_data.iloc[:, 1:].values.tolist()

#print("Total Records in H1:", len(texts_h1),"texts_h1", "with", len(labels_h1), "labels")
#print("Total Records in H2:", len(texts_h2),"texts_h2","with", len(labels_h2), "labels")
#print("Total Records in H3:", len(texts_h3),"texts_h3", "with", len(labels_h3), "labels")

# Split the data for each hierarchy
def split_data(texts, labels, train_ratio, val_ratio, test_ratio):
    total_samples = len(texts)
    train_size = int(total_samples * train_ratio)
    val_size = int(total_samples * val_ratio)
    test_size = total_samples - train_size - val_size  # Remaining samples for test

    train_texts = texts[:train_size]
    val_texts = texts[train_size:train_size + val_size]
    test_texts = texts[train_size + val_size:]

    train_labels = labels[:train_size]
    val_labels = labels[train_size:train_size + val_size]
    test_labels = labels[train_size + val_size:]

    return train_texts, val_texts, test_texts, train_labels, val_labels, test_labels

# Split the data for each hierarchy
train_texts_h1, val_texts_h1, test_texts_h1, train_labels_h1, val_labels_h1, test_labels_h1 = split_data(texts_h1, labels_h1, 0.7, 0.1, 0.2)
train_texts_h2, val_texts_h2, test_texts_h2, train_labels_h2, val_labels_h2, test_labels_h2 = split_data(texts_h2, labels_h2, 0.7, 0.1, 0.2)
train_texts_h3, val_texts_h3, test_texts_h3, train_labels_h3, val_labels_h3, test_labels_h3 = split_data(texts_h3, labels_h3, 0.7, 0.1, 0.2)

#print("Train Val Test Splitting:")
#print("H1:", len(train_texts_h1), len(val_texts_h1), len(test_texts_h1))
#print("H2:", len(train_texts_h2), len(val_texts_h2), len(test_texts_h2))
#print("H3:", len(train_texts_h3), len(val_texts_h3), len(test_texts_h3))

# Tokenizer

# Import and initialize the ClinicalBERT tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
clinical_bert_model = AutoModel.from_pretrained("medicalai/ClinicalBERT")


# Define the TextDataset class
class TextDataset(Dataset):
    def __init__(self, texts, labels_h1, labels_h2, labels_h3, tokenizer, max_length=512):
        self.texts = texts
        self.labels_h1 = labels_h1  # Now expecting a list of binary values for H1
        self.labels_h2 = np.array(labels_h2).reshape(-1, 1)  # Reshape labels_h2 to [batch_size, 1]
        self.labels_h3 = labels_h3  # Multi-label but binary values for H3
        self.tokenizer = tokenizer
        self.max_length = max_length

    # Rest of the code remains unchanged

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        inputs = self.tokenizer.encode_plus(
            text, None, add_special_tokens=True, max_length=self.max_length,
            padding='max_length', return_token_type_ids=True, truncation=True
        )
        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'labels_h1': torch.tensor(self.labels_h1[idx], dtype=torch.float).unsqueeze(-1),
            'labels_h2': torch.tensor(self.labels_h2[idx], dtype=torch.float),
            'labels_h3': torch.tensor(self.labels_h3[idx], dtype=torch.float)
        }

# Create datasets and dataloaders for H1, H2, and H3
dataset_h1 = TextDataset(train_texts_h1 + val_texts_h1 + test_texts_h1, train_labels_h1 + val_labels_h1 + test_labels_h1, train_labels_h2 + val_labels_h2 + test_labels_h2, train_labels_h3 + val_labels_h3 + test_labels_h3, tokenizer)
dataset_h2 = TextDataset(train_texts_h2 + val_texts_h2 + test_texts_h2, train_labels_h1 + val_labels_h1 + test_labels_h1, train_labels_h2 + val_labels_h2 + test_labels_h2, train_labels_h3 + val_labels_h3 + test_labels_h3, tokenizer)
dataset_h3 = TextDataset(train_texts_h3 + val_texts_h3 + test_texts_h3, train_labels_h1 + val_labels_h1 + test_labels_h1, train_labels_h2 + val_labels_h2 + test_labels_h2, train_labels_h3 + val_labels_h3 + test_labels_h3, tokenizer)

# Split the datasets into train, val, and test for H1, H2, and H3
train_size_h1 = len(train_texts_h1)
val_size_h1 = len(val_texts_h1)
train_size_h2 = len(train_texts_h2)
val_size_h2 = len(val_texts_h2)
train_size_h3 = len(train_texts_h3)
val_size_h3 = len(val_texts_h3)

train_dataset_h1, val_dataset_h1, test_dataset_h1 = random_split(dataset_h1, [train_size_h1, val_size_h1, len(test_texts_h1)])
train_dataset_h2, val_dataset_h2, test_dataset_h2 = random_split(dataset_h2, [train_size_h2, val_size_h2, len(test_texts_h2)])
train_dataset_h3, val_dataset_h3, test_dataset_h3 = random_split(dataset_h3, [train_size_h3, val_size_h3, len(test_texts_h3)])

# Create dataloaders for H1, H2, and H3
train_dataloader_h1 = DataLoader(train_dataset_h1, batch_size=8, shuffle=True)
val_dataloader_h1 = DataLoader(val_dataset_h1,batch_size=8, shuffle=False)  # You can set shuffle to True if you want to shuffle the validation data.
train_dataloader_h2 = DataLoader(train_dataset_h2, batch_size=8, shuffle=True)
val_dataloader_h2 = DataLoader(val_dataset_h2, batch_size=8, shuffle=False)

train_dataloader_h3 = DataLoader(train_dataset_h3, batch_size=8, shuffle=True)
val_dataloader_h3 = DataLoader(val_dataset_h3, batch_size=8, shuffle=False)

# Define the model architecture for H1, H2, and H3
class HierarchicalClassifier(nn.Module):
    def __init__(self, num_labels_h3):
        super(HierarchicalClassifier, self).__init__()
        self.bert = clinical_bert_model  # Use the ClinicalBERT model
        self.dropout = nn.Dropout(0.1)
        self.fc_h1 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h2 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h3 = nn.Linear(self.bert.config.hidden_size, num_labels_h3)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits_h1 = self.fc_h1(self.dropout(pooled_output))
        logits_h2 = self.fc_h2(self.dropout(pooled_output))
        logits_h3 = self.fc_h3(self.dropout(pooled_output))
        return logits_h1, logits_h2, logits_h3

# Initialize and move the model to the appropriate device (CPU/GPU)
model = HierarchicalClassifier(num_labels_h3=len(train_labels_h3[0]))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=1e-5)  # You can adjust the learning rate as needed

# Training loop for H1, H2, and H3
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0

    for batch in tqdm(dataloader, total=len(dataloader), desc="Training"):
        input_ids = batch['ids'].to(device)
        attention_mask = batch['mask'].to(device)
        labels_h1 = batch['labels_h1'].to(device)
        labels_h2 = batch['labels_h2'].to(device)
        labels_h3 = batch['labels_h3'].to(device)

        optimizer.zero_grad()

        logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
        loss_h1 = criterion(logits_h1, labels_h1)
        loss_h2 = criterion(logits_h2, labels_h2)
        loss_h3 = criterion(logits_h3, labels_h3)

        loss = loss_h1 + loss_h2 + loss_h3
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    return total_loss / len(dataloader)

# Validation loop for H1, H2, and H3
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Validation"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            labels_h1 = batch['labels_h1'].to(device)
            labels_h2 = batch['labels_h2'].to(device)
            labels_h3 = batch['labels_h3'].to(device)

            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3

# Training and evaluation for H1, H2, and H3
# Custom accuracy function
def custom_accuracy(y_true, y_pred):
    correct_labels = np.sum(np.equal(y_true, y_pred), axis=1)
    total_labels = y_true.shape[1]
    sample_accuracy = np.mean(correct_labels / total_labels)
    return sample_accuracy

def custom_precision_recall_f1(y_true, y_pred):
    # Calculate precision and recall for each sample
    sample_precisions = []
    sample_recalls = []
    for true, pred in zip(y_true, y_pred):
        true_positive = np.sum((true == 1) & (pred == 1))
        false_positive = np.sum((true == 0) & (pred == 1))
        false_negative = np.sum((true == 1) & (pred == 0))
        
        sample_precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
        sample_recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
        
        sample_precisions.append(sample_precision)
        sample_recalls.append(sample_recall)
    
    # Calculate average precision and recall across all samples
    avg_precision = np.mean(sample_precisions)
    avg_recall = np.mean(sample_recalls)
    
    # Calculate F1 score
    if avg_precision + avg_recall > 0:
        avg_f1 = 2 * (avg_precision * avg_recall) / (avg_precision + avg_recall)
    else:
        avg_f1 = 0
    
    return avg_precision, avg_recall, avg_f1


# Training and evaluation for H1, H2, and H3
num_epochs = 33 # You can adjust the number of epochs as needed

for epoch in range(num_epochs):
    # Training for H1
    train_loss_h1 = train(model, train_dataloader_h1, optimizer, criterion, device)
    # Training for H2
    train_loss_h2 = train(model, train_dataloader_h2, optimizer, criterion, device)
    # Training for H3
    train_loss_h3 = train(model, train_dataloader_h3, optimizer, criterion, device)

    # Validation for H1
    val_loss_h1, val_preds_h1, _, _, val_labels_h1, _, _ = evaluate(model, val_dataloader_h1, criterion, device)
    # Validation for H2
    val_loss_h2, _, val_preds_h2, _, _, val_labels_h2, _ = evaluate(model, val_dataloader_h2, criterion, device)
    # Validation for H3
    val_loss_h3, _, _, val_preds_h3, _, _, val_labels_h3 = evaluate(model, val_dataloader_h3, criterion, device)

    # Metrics calculation for H1
    threshold_h1 = 0.5
    val_preds_h1_binary = (np.array(val_preds_h1) > threshold_h1).astype(int)

    acc_h1 = accuracy_score(val_labels_h1, val_preds_h1_binary)
    precision_h1 = precision_score(val_labels_h1, val_preds_h1_binary, average='micro')
    recall_h1 = recall_score(val_labels_h1, val_preds_h1_binary, average='micro')
    f1_h1 = f1_score(val_labels_h1, val_preds_h1_binary, average='micro')


    # Metrics calculation for H2
    threshold_h2 = 0.5
    val_preds_h2_binary = (np.array(val_preds_h2) > threshold_h2).astype(int)

    acc_h2 = accuracy_score(val_labels_h2, val_preds_h2_binary)
    precision_h2 = precision_score(val_labels_h2, val_preds_h2_binary, average='micro')
    recall_h2 = recall_score(val_labels_h2, val_preds_h2_binary, average='micro')
    f1_h2 = f1_score(val_labels_h2, val_preds_h2_binary, average='micro')

# Convert predictions and labels to binary using the threshold

    # Metrics calculation for H3
    threshold_h3 = 0.5
    val_preds_h3_binary = (np.array(val_preds_h3) > threshold_h3).astype(int)


    # Custom accuracy calculation for H3
    acc_h3_custom = custom_accuracy(np.array(val_labels_h3), val_preds_h3_binary)
    # Calculate custom precision, recall, and F1 score for H3
    precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(val_labels_h3), val_preds_h3_binary)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"H1: Train Loss: {train_loss_h1:.4f}, Val Loss: {val_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
    print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

    print(f"H2: Train Loss: {train_loss_h2:.4f}, Val Loss: {val_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
    print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

    print(f"H3: Train Loss: {train_loss_h3:.4f}, Val Loss: {val_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
    print(f"Precision H3: {precision_h3},Recall H3:{recall_h3},F1 Score H3: {f1_score_h3}")

# Saving the model
torch.save(model.state_dict(), 'model_clinicalBert.pth')

# Loading the model
#loaded_model = YourModelClass(*args, **kwargs)  # Instantiate your model
#loaded_model.load_state_dict(torch.load('model.pth'))
#loaded_model.eval()  # Set the model to evaluation mode



Training: 100%|██████████| 231/231 [01:32<00:00,  2.49it/s]
Training: 100%|██████████| 231/231 [01:35<00:00,  2.43it/s]
Training: 100%|██████████| 231/231 [01:34<00:00,  2.44it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.13it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.24it/s]


Epoch 1/33
H1: Train Loss: 1.7098, Val Loss: 1.4100, Accuracy: 0.9125
Precision H1: 0.9125, Recall H1: 0.9125, F1 H1: 0.9125
H2: Train Loss: 1.5122, Val Loss: 1.4480, Accuracy: 0.6578
Precision H2: 0.6578, Recall H2: 0.6578, F1 H2: 0.6578
H3: Train Loss: 1.4465, Val Loss: 1.3906, Custom Accuracy: 0.6739
Precision H3: 0.6899420604743798,Recall H3:0.48436510679856687,F1 Score H3: 0.5691591929606629


Training: 100%|██████████| 231/231 [01:34<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:36<00:00,  2.40it/s]
Training: 100%|██████████| 231/231 [01:37<00:00,  2.38it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.82it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.81it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.88it/s]


Epoch 2/33
H1: Train Loss: 1.3957, Val Loss: 1.2153, Accuracy: 0.9544
Precision H1: 0.9544, Recall H1: 0.9544, F1 H1: 0.9544
H2: Train Loss: 1.3274, Val Loss: 1.2787, Accuracy: 0.7490
Precision H2: 0.7490, Recall H2: 0.7490, F1 H2: 0.7490
H3: Train Loss: 1.2335, Val Loss: 1.2514, Custom Accuracy: 0.6844
Precision H3: 0.6844287524895889,Recall H3:0.5280422365973697,F1 Score H3: 0.5961499986540789


Training: 100%|██████████| 231/231 [01:36<00:00,  2.38it/s]
Training: 100%|██████████| 231/231 [01:37<00:00,  2.38it/s]
Training: 100%|██████████| 231/231 [01:36<00:00,  2.38it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.82it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.78it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.88it/s]


Epoch 3/33
H1: Train Loss: 1.1388, Val Loss: 0.9718, Accuracy: 0.9620
Precision H1: 0.9620, Recall H1: 0.9620, F1 H1: 0.9620
H2: Train Loss: 1.0500, Val Loss: 1.0602, Accuracy: 0.8441
Precision H2: 0.8441, Recall H2: 0.8441, F1 H2: 0.8441
H3: Train Loss: 0.9434, Val Loss: 1.0137, Custom Accuracy: 0.6976
Precision H3: 0.6745156617780192,Recall H3:0.5900221662579077,F1 Score H3: 0.6294460839582667


Training: 100%|██████████| 231/231 [01:36<00:00,  2.38it/s]
Training: 100%|██████████| 231/231 [01:36<00:00,  2.38it/s]
Training: 100%|██████████| 231/231 [01:36<00:00,  2.38it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.82it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.77it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.89it/s]


Epoch 4/33
H1: Train Loss: 0.8536, Val Loss: 0.9379, Accuracy: 0.9696
Precision H1: 0.9696, Recall H1: 0.9696, F1 H1: 0.9696
H2: Train Loss: 0.7989, Val Loss: 0.9991, Accuracy: 0.8669
Precision H2: 0.8669, Recall H2: 0.8669, F1 H2: 0.8669
H3: Train Loss: 0.7480, Val Loss: 0.9886, Custom Accuracy: 0.6996
Precision H3: 0.701548071700163,Recall H3:0.548928036475565,F1 Score H3: 0.6159244514529689


Training: 100%|██████████| 231/231 [01:37<00:00,  2.38it/s]
Training: 100%|██████████| 231/231 [01:36<00:00,  2.38it/s]
Training: 100%|██████████| 231/231 [01:36<00:00,  2.38it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.83it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.80it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.89it/s]


Epoch 5/33
H1: Train Loss: 0.7147, Val Loss: 0.8472, Accuracy: 0.9734
Precision H1: 0.9734, Recall H1: 0.9734, F1 H1: 0.9734
H2: Train Loss: 0.6790, Val Loss: 0.8545, Accuracy: 0.9240
Precision H2: 0.9240, Recall H2: 0.9240, F1 H2: 0.9240
H3: Train Loss: 0.6715, Val Loss: 0.7981, Custom Accuracy: 0.7055
Precision H3: 0.6748415716096325,Recall H3:0.5985596047383119,F1 Score H3: 0.6344157864249877


Training: 100%|██████████| 231/231 [01:36<00:00,  2.38it/s]
Training: 100%|██████████| 231/231 [01:37<00:00,  2.38it/s]
Training: 100%|██████████| 231/231 [01:34<00:00,  2.44it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.06it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.15it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.21it/s]


Epoch 6/33
H1: Train Loss: 0.6459, Val Loss: 0.8339, Accuracy: 0.9734
Precision H1: 0.9734, Recall H1: 0.9734, F1 H1: 0.9734
H2: Train Loss: 0.6221, Val Loss: 0.8297, Accuracy: 0.9430
Precision H2: 0.9430, Recall H2: 0.9430, F1 H2: 0.9430
H3: Train Loss: 0.6145, Val Loss: 0.8111, Custom Accuracy: 0.7025
Precision H3: 0.6815951475647292,Recall H3:0.5880096456142083,F1 Score H3: 0.6313531948290527


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.10it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.13it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.21it/s]


Epoch 7/33
H1: Train Loss: 0.5990, Val Loss: 0.8551, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.5980, Val Loss: 0.9273, Accuracy: 0.8973
Precision H2: 0.8973, Recall H2: 0.8973, F1 H2: 0.8973
H3: Train Loss: 0.5739, Val Loss: 0.8421, Custom Accuracy: 0.7101
Precision H3: 0.7033733587916097,Recall H3:0.567860434875644,F1 Score H3: 0.6283940898882266


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:34<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.16it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.21it/s]


Epoch 8/33
H1: Train Loss: 0.5837, Val Loss: 0.7239, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.5667, Val Loss: 0.7746, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.5602, Val Loss: 0.7176, Custom Accuracy: 0.7178
Precision H3: 0.6890684410646387,Recall H3:0.6086121947338676,F1 Score H3: 0.6463461728087492


Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.17it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.11it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.23it/s]


Epoch 9/33
H1: Train Loss: 0.5416, Val Loss: 0.7290, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.5450, Val Loss: 0.7429, Accuracy: 0.9658
Precision H2: 0.9658, Recall H2: 0.9658, F1 H2: 0.9658
H3: Train Loss: 0.5532, Val Loss: 0.7346, Custom Accuracy: 0.7224
Precision H3: 0.6814800640846268,Recall H3:0.6508419337316675,F1 Score H3: 0.6658087210680067


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.17it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.15it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.25it/s]


Epoch 10/33
H1: Train Loss: 0.5450, Val Loss: 0.7120, Accuracy: 0.9886
Precision H1: 0.9886, Recall H1: 0.9886, F1 H1: 0.9886
H2: Train Loss: 0.5343, Val Loss: 0.7428, Accuracy: 0.9620
Precision H2: 0.9620, Recall H2: 0.9620, F1 H2: 0.9620
H3: Train Loss: 0.5185, Val Loss: 0.6710, Custom Accuracy: 0.7330
Precision H3: 0.6950460608255286,Recall H3:0.6428088599191261,F1 Score H3: 0.6679076468199665


Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:35<00:00,  2.41it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.79it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.72it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.85it/s]


Epoch 11/33
H1: Train Loss: 0.5009, Val Loss: 0.6970, Accuracy: 0.9886
Precision H1: 0.9886, Recall H1: 0.9886, F1 H1: 0.9886
H2: Train Loss: 0.4975, Val Loss: 0.6735, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.4916, Val Loss: 0.6370, Custom Accuracy: 0.7450
Precision H3: 0.713765986864846,Recall H3:0.6493813748566599,F1 Score H3: 0.6800531635711881


Training: 100%|██████████| 231/231 [01:37<00:00,  2.37it/s]
Training: 100%|██████████| 231/231 [01:37<00:00,  2.38it/s]
Training: 100%|██████████| 231/231 [01:36<00:00,  2.38it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.81it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.76it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.85it/s]


Epoch 12/33
H1: Train Loss: 0.4866, Val Loss: 0.6644, Accuracy: 0.9886
Precision H1: 0.9886, Recall H1: 0.9886, F1 H1: 0.9886
H2: Train Loss: 0.4858, Val Loss: 0.6791, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.4688, Val Loss: 0.6382, Custom Accuracy: 0.7572
Precision H3: 0.7239733017299557,Recall H3:0.6774064655243361,F1 Score H3: 0.699916192481865


Training: 100%|██████████| 231/231 [01:36<00:00,  2.38it/s]
Training: 100%|██████████| 231/231 [01:37<00:00,  2.38it/s]
Training: 100%|██████████| 231/231 [01:37<00:00,  2.38it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.79it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.78it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  5.88it/s]


Epoch 13/33
H1: Train Loss: 0.4598, Val Loss: 0.6291, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.4471, Val Loss: 0.7097, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.4334, Val Loss: 0.6651, Custom Accuracy: 0.7754
Precision H3: 0.756277742114244,Recall H3:0.6752639101498417,F1 Score H3: 0.713478457985026


Training: 100%|██████████| 231/231 [01:36<00:00,  2.40it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.16it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.18it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.26it/s]


Epoch 14/33
H1: Train Loss: 0.4311, Val Loss: 0.6086, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.4230, Val Loss: 0.6621, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.4049, Val Loss: 0.6373, Custom Accuracy: 0.7994
Precision H3: 0.7882549009925436,Recall H3:0.715995780729621,F1 Score H3: 0.7503897988651187


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.17it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.23it/s]


Epoch 15/33
H1: Train Loss: 0.3945, Val Loss: 0.5826, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.3823, Val Loss: 0.5702, Accuracy: 0.9734
Precision H2: 0.9734, Recall H2: 0.9734, F1 H2: 0.9734
H3: Train Loss: 0.3731, Val Loss: 0.5486, Custom Accuracy: 0.8105
Precision H3: 0.7880011412330804,Recall H3:0.752443500732474,F1 Score H3: 0.769811937589721


Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.16it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.13it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.22it/s]


Epoch 16/33
H1: Train Loss: 0.3682, Val Loss: 0.6002, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.3488, Val Loss: 0.6548, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.3367, Val Loss: 0.5915, Custom Accuracy: 0.8274
Precision H3: 0.8499894381073088,Recall H3:0.722141019099194,F1 Score H3: 0.7808667992464668


Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.15it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.23it/s]


Epoch 17/33
H1: Train Loss: 0.3456, Val Loss: 0.5746, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.3575, Val Loss: 0.5814, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.3135, Val Loss: 0.5332, Custom Accuracy: 0.8435
Precision H3: 0.8358867051832831,Recall H3:0.7887500754420906,F1 Score H3: 0.811634587665272


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:34<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.15it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.15it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.23it/s]


Epoch 18/33
H1: Train Loss: 0.3048, Val Loss: 0.5470, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.2829, Val Loss: 0.5313, Accuracy: 0.9620
Precision H2: 0.9620, Recall H2: 0.9620, F1 H2: 0.9620
H3: Train Loss: 0.2712, Val Loss: 0.4644, Custom Accuracy: 0.8681
Precision H3: 0.8749740753543035,Recall H3:0.7975903796246001,F1 Score H3: 0.8344920912866403


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.15it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.13it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.22it/s]


Epoch 19/33
H1: Train Loss: 0.2639, Val Loss: 0.5041, Accuracy: 0.9886
Precision H1: 0.9886, Recall H1: 0.9886, F1 H1: 0.9886
H2: Train Loss: 0.2531, Val Loss: 0.5102, Accuracy: 0.9620
Precision H2: 0.9620, Recall H2: 0.9620, F1 H2: 0.9620
H3: Train Loss: 0.2401, Val Loss: 0.4605, Custom Accuracy: 0.8842
Precision H3: 0.8900020026445882,Recall H3:0.8229177982980265,F1 Score H3: 0.8551462690711873


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:34<00:00,  2.45it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.15it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.13it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.24it/s]


Epoch 20/33
H1: Train Loss: 0.2363, Val Loss: 0.4948, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.2244, Val Loss: 0.4442, Accuracy: 0.9658
Precision H2: 0.9658, Recall H2: 0.9658, F1 H2: 0.9658
H3: Train Loss: 0.2075, Val Loss: 0.4094, Custom Accuracy: 0.9047
Precision H3: 0.8989966201943388,Recall H3:0.8682587691142825,F1 Score H3: 0.8833603831228999


Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.17it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.16it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.23it/s]


Epoch 21/33
H1: Train Loss: 0.2023, Val Loss: 0.4849, Accuracy: 0.9886
Precision H1: 0.9886, Recall H1: 0.9886, F1 H1: 0.9886
H2: Train Loss: 0.2009, Val Loss: 0.4694, Accuracy: 0.9620
Precision H2: 0.9620, Recall H2: 0.9620, F1 H2: 0.9620
H3: Train Loss: 0.1955, Val Loss: 0.4055, Custom Accuracy: 0.9184
Precision H3: 0.9201324763111836,Recall H3:0.8816498773723108,F1 Score H3: 0.900480219764145


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.16it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.16it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.21it/s]


Epoch 22/33
H1: Train Loss: 0.1775, Val Loss: 0.4901, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.1682, Val Loss: 0.4681, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.1697, Val Loss: 0.3952, Custom Accuracy: 0.9254
Precision H3: 0.9211237853823405,Recall H3:0.9018746673689638,F1 Score H3: 0.9113976002487255


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.12it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.24it/s]


Epoch 23/33
H1: Train Loss: 0.1545, Val Loss: 0.4438, Accuracy: 0.9886
Precision H1: 0.9886, Recall H1: 0.9886, F1 H1: 0.9886
H2: Train Loss: 0.1456, Val Loss: 0.3811, Accuracy: 0.9658
Precision H2: 0.9658, Recall H2: 0.9658, F1 H2: 0.9658
H3: Train Loss: 0.1373, Val Loss: 0.3420, Custom Accuracy: 0.9427
Precision H3: 0.941969806703647,Recall H3:0.9209932019817951,F1 Score H3: 0.9313634080778984


Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.18it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.13it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.24it/s]


Epoch 24/33
H1: Train Loss: 0.1317, Val Loss: 0.4453, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.1249, Val Loss: 0.3933, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.1204, Val Loss: 0.3485, Custom Accuracy: 0.9468
Precision H3: 0.9455715492787736,Recall H3:0.9223346995210113,F1 Score H3: 0.9338085905971357


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.10it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.23it/s]


Epoch 25/33
H1: Train Loss: 0.1143, Val Loss: 0.4983, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.1127, Val Loss: 0.4123, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.1082, Val Loss: 0.3537, Custom Accuracy: 0.9476
Precision H3: 0.9437669744703966,Recall H3:0.9315323248783325,F1 Score H3: 0.9376097396048912


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.15it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.16it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.25it/s]


Epoch 26/33
H1: Train Loss: 0.1161, Val Loss: 0.4621, Accuracy: 0.9886
Precision H1: 0.9886, Recall H1: 0.9886, F1 H1: 0.9886
H2: Train Loss: 0.1148, Val Loss: 0.3948, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.0989, Val Loss: 0.3304, Custom Accuracy: 0.9526
Precision H3: 0.9488532802221015,Recall H3:0.937917880598489,F1 Score H3: 0.9433538905669768


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.13it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.23it/s]


Epoch 27/33
H1: Train Loss: 0.0951, Val Loss: 0.4282, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.0883, Val Loss: 0.3759, Accuracy: 0.9696
Precision H2: 0.9696, Recall H2: 0.9696, F1 H2: 0.9696
H3: Train Loss: 0.0862, Val Loss: 0.3224, Custom Accuracy: 0.9564
Precision H3: 0.9490901683867463,Recall H3:0.9438505094398632,F1 Score H3: 0.9464630872264979


Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.18it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.16it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.20it/s]


Epoch 28/33
H1: Train Loss: 0.0909, Val Loss: 0.4380, Accuracy: 0.9886
Precision H1: 0.9886, Recall H1: 0.9886, F1 H1: 0.9886
H2: Train Loss: 0.0820, Val Loss: 0.3499, Accuracy: 0.9696
Precision H2: 0.9696, Recall H2: 0.9696, F1 H2: 0.9696
H3: Train Loss: 0.0741, Val Loss: 0.2997, Custom Accuracy: 0.9585
Precision H3: 0.9558693946526646,Recall H3:0.9439245798561388,F1 Score H3: 0.9498594361824335


Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.47it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.17it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.22it/s]


Epoch 29/33
H1: Train Loss: 0.0707, Val Loss: 0.4640, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.0650, Val Loss: 0.4057, Accuracy: 0.9620
Precision H2: 0.9620, Recall H2: 0.9620, F1 H2: 0.9620
H3: Train Loss: 0.0636, Val Loss: 0.3675, Custom Accuracy: 0.9608
Precision H3: 0.9561470215462611,Recall H3:0.9476259059909249,F1 Score H3: 0.9518673938970613


Training: 100%|██████████| 231/231 [01:34<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:34<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.16it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.15it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.25it/s]


Epoch 30/33
H1: Train Loss: 0.0618, Val Loss: 0.4703, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.0592, Val Loss: 0.4050, Accuracy: 0.9658
Precision H2: 0.9658, Recall H2: 0.9658, F1 H2: 0.9658
H3: Train Loss: 0.0571, Val Loss: 0.3194, Custom Accuracy: 0.9629
Precision H3: 0.96103566902046,Recall H3:0.9473316818373853,F1 Score H3: 0.9541344712585019


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.15it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.12it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.22it/s]


Epoch 31/33
H1: Train Loss: 0.0525, Val Loss: 0.4832, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.0513, Val Loss: 0.3806, Accuracy: 0.9620
Precision H2: 0.9620, Recall H2: 0.9620, F1 H2: 0.9620
H3: Train Loss: 0.0499, Val Loss: 0.3131, Custom Accuracy: 0.9608
Precision H3: 0.9587180879956546,Recall H3:0.9450032645850135,F1 Score H3: 0.9518112740023023


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:34<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:34<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.15it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.23it/s]


Epoch 32/33
H1: Train Loss: 0.0480, Val Loss: 0.5237, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.0468, Val Loss: 0.4249, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.0478, Val Loss: 0.3338, Custom Accuracy: 0.9643
Precision H3: 0.9630635524171647,Recall H3:0.9500640571933348,F1 Score H3: 0.9565196397231022


Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Training: 100%|██████████| 231/231 [01:33<00:00,  2.46it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.14it/s]
Validation: 100%|██████████| 33/33 [00:05<00:00,  6.22it/s]


Epoch 33/33
H1: Train Loss: 0.0441, Val Loss: 0.5424, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.0926, Val Loss: 0.4741, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.0681, Val Loss: 0.3710, Custom Accuracy: 0.9585
Precision H3: 0.9559116422234293,Recall H3:0.9409640950515475,F1 Score H3: 0.9483789746461571


In [4]:
# Create test datasets for H1, H2, and H3
test_dataset_h1 = TextDataset(test_texts_h1, test_labels_h1, test_labels_h2, test_labels_h3, tokenizer)
test_dataset_h2 = TextDataset(test_texts_h2, test_labels_h1, test_labels_h2, test_labels_h3, tokenizer)
test_dataset_h3 = TextDataset(test_texts_h3, test_labels_h1, test_labels_h2, test_labels_h3, tokenizer)

# Create test dataloaders for H1, H2, and H3
test_dataloader_h1 = DataLoader(test_dataset_h1, batch_size=8, shuffle=False)
test_dataloader_h2 = DataLoader(test_dataset_h2, batch_size=8, shuffle=False)
test_dataloader_h3 = DataLoader(test_dataset_h3, batch_size=8, shuffle=False)
# Testing loop for H1, H2, and H3
def test(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Testing"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            labels_h1 = batch['labels_h1'].to(device)  # No need to reshape labels_h1
            labels_h2 = batch['labels_h2'].to(device)  # No need to reshape labels_h2

            #labels_h1 = batch['labels_h1'].unsqueeze(1).to(device)  # Reshape labels_h1
            #labels_h2 = batch['labels_h2'].unsqueeze(1).to(device)  # Reshape labels_h2
            #labels_h3 = batch['labels_h3'].unsqueeze(1).to(device)  # Reshape labels_h3
            labels_h3 = batch['labels_h3'].to(device)  # No need to reshape labels_h3
            print("Labels H1 shape:", labels_h1.shape)
            print("Labels H2 shape:", labels_h2.shape)
            print("Labels H3 shape:", labels_h3.shape)

            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            
            print("Logits H1 shape:", logits_h1.shape)
            print("Logits H2 shape:", logits_h2.shape)
            print("Logits H3 shape:", logits_h3.shape)
            
            
            
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3

# Testing for H1
test_loss_h1, test_preds_h1, _, _, test_labels_h1, _, _ = test(model, test_dataloader_h1, criterion, device)
# Testing for H2
test_loss_h2, _, test_preds_h2, _, _, test_labels_h2, _ = test(model, test_dataloader_h2, criterion, device)
# Testing for H3
test_loss_h3, _, _, test_preds_h3, _, _, test_labels_h3 = test(model, test_dataloader_h3, criterion, device)

# Metrics calculation for H1
threshold_h1 = 0.5
test_preds_h1_binary = (np.array(test_preds_h1) > threshold_h1).astype(int)

acc_h1 = accuracy_score(test_labels_h1, test_preds_h1_binary)
precision_h1 = precision_score(test_labels_h1, test_preds_h1_binary, average='micro')
recall_h1 = recall_score(test_labels_h1, test_preds_h1_binary, average='micro')
f1_h1 = f1_score(test_labels_h1, test_preds_h1_binary, average='micro')

# Metrics calculation for H2
threshold_h2 = 0.5
test_preds_h2_binary = (np.array(test_preds_h2) > threshold_h2).astype(int)

acc_h2 = accuracy_score(test_labels_h2, test_preds_h2_binary)
precision_h2 = precision_score(test_labels_h2, test_preds_h2_binary, average='micro')
recall_h2 = recall_score(test_labels_h2, test_preds_h2_binary, average='micro')
f1_h2 = f1_score(test_labels_h2, test_preds_h2_binary, average='micro')

# Metrics calculation for H3
threshold_h3 = 0.5
test_preds_h3_binary = (np.array(test_preds_h3) > threshold_h3).astype(int)

# Custom accuracy calculation for H3
acc_h3_custom = custom_accuracy(np.array(test_labels_h3), test_preds_h3_binary)
# Calculate custom precision, recall, and F1 score for H3
precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(test_labels_h3), test_preds_h3_binary)

print("Testing Results:")
print(f"H1: Test Loss: {test_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

print(f"H2: Test Loss: {test_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

print(f"H3: Test Loss: {test_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
print(f"Precision H3: {precision_h3}, Recall H3:{recall_h3}, F1 Score H3: {f1_score_h3}")




Testing:   2%|▏         | 1/66 [00:00<00:11,  5.88it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   5%|▍         | 3/66 [00:00<00:10,  6.25it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   8%|▊         | 5/66 [00:00<00:09,  6.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  11%|█         | 7/66 [00:01<00:09,  6.08it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  14%|█▎        | 9/66 [00:01<00:09,  6.31it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  17%|█▋        | 11/66 [00:01<00:08,  6.32it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  20%|█▉        | 13/66 [00:02<00:08,  6.29it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  23%|██▎       | 15/66 [00:02<00:08,  6.32it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  26%|██▌       | 17/66 [00:02<00:07,  6.47it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  29%|██▉       | 19/66 [00:03<00:07,  6.35it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  32%|███▏      | 21/66 [00:03<00:07,  6.06it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  35%|███▍      | 23/66 [00:03<00:06,  6.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  38%|███▊      | 25/66 [00:03<00:06,  6.29it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  41%|████      | 27/66 [00:04<00:06,  6.24it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  44%|████▍     | 29/66 [00:04<00:06,  6.02it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  47%|████▋     | 31/66 [00:04<00:05,  6.02it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  50%|█████     | 33/66 [00:05<00:05,  6.06it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  53%|█████▎    | 35/66 [00:05<00:05,  6.10it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  56%|█████▌    | 37/66 [00:05<00:04,  5.98it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  59%|█████▉    | 39/66 [00:06<00:04,  6.13it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  62%|██████▏   | 41/66 [00:06<00:04,  6.18it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  65%|██████▌   | 43/66 [00:06<00:03,  6.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  68%|██████▊   | 45/66 [00:07<00:03,  6.38it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  71%|███████   | 47/66 [00:07<00:03,  6.26it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  74%|███████▍  | 49/66 [00:07<00:02,  6.17it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  77%|███████▋  | 51/66 [00:08<00:02,  5.92it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  80%|████████  | 53/66 [00:08<00:02,  6.10it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  83%|████████▎ | 55/66 [00:08<00:01,  6.17it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  86%|████████▋ | 57/66 [00:09<00:01,  6.26it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  89%|████████▉ | 59/66 [00:09<00:01,  6.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  92%|█████████▏| 61/66 [00:09<00:00,  6.26it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  95%|█████████▌| 63/66 [00:10<00:00,  6.17it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  98%|█████████▊| 65/66 [00:10<00:00,  6.25it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing: 100%|██████████| 66/66 [00:10<00:00,  6.23it/s]
Testing:   2%|▏         | 1/66 [00:00<00:10,  6.15it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   5%|▍         | 3/66 [00:00<00:09,  6.36it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   8%|▊         | 5/66 [00:00<00:09,  6.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  11%|█         | 7/66 [00:01<00:09,  6.10it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  14%|█▎        | 9/66 [00:01<00:09,  6.30it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  17%|█▋        | 11/66 [00:01<00:08,  6.32it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  20%|█▉        | 13/66 [00:02<00:08,  6.28it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  23%|██▎       | 15/66 [00:02<00:08,  6.31it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  26%|██▌       | 17/66 [00:02<00:07,  6.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  29%|██▉       | 19/66 [00:03<00:07,  6.38it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  32%|███▏      | 21/66 [00:03<00:07,  6.00it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  35%|███▍      | 23/66 [00:03<00:06,  6.32it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  38%|███▊      | 25/66 [00:03<00:06,  6.25it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  41%|████      | 27/66 [00:04<00:06,  6.24it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  44%|████▍     | 29/66 [00:04<00:06,  6.02it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  47%|████▋     | 31/66 [00:04<00:05,  5.99it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  50%|█████     | 33/66 [00:05<00:05,  6.13it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  53%|█████▎    | 35/66 [00:05<00:05,  6.12it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  56%|█████▌    | 37/66 [00:05<00:04,  5.98it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  59%|█████▉    | 39/66 [00:06<00:04,  6.13it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  62%|██████▏   | 41/66 [00:06<00:04,  6.23it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  65%|██████▌   | 43/66 [00:06<00:03,  6.50it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  68%|██████▊   | 45/66 [00:07<00:03,  6.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  71%|███████   | 47/66 [00:07<00:03,  6.16it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  74%|███████▍  | 49/66 [00:07<00:02,  6.10it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  77%|███████▋  | 51/66 [00:08<00:02,  5.86it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  80%|████████  | 53/66 [00:08<00:02,  5.99it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  83%|████████▎ | 55/66 [00:08<00:01,  6.09it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  86%|████████▋ | 57/66 [00:09<00:01,  6.24it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  89%|████████▉ | 59/66 [00:09<00:01,  6.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  92%|█████████▏| 61/66 [00:09<00:00,  6.35it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  95%|█████████▌| 63/66 [00:10<00:00,  6.26it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  98%|█████████▊| 65/66 [00:10<00:00,  6.29it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing: 100%|██████████| 66/66 [00:10<00:00,  6.23it/s]
Testing:   2%|▏         | 1/66 [00:00<00:10,  6.25it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   5%|▍         | 3/66 [00:00<00:09,  6.34it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   8%|▊         | 5/66 [00:00<00:09,  6.53it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])


Testing:   9%|▉         | 6/66 [00:00<00:09,  6.02it/s]

Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  12%|█▏        | 8/66 [00:01<00:09,  6.20it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  15%|█▌        | 10/66 [00:01<00:08,  6.32it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  18%|█▊        | 12/66 [00:01<00:08,  6.30it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  21%|██        | 14/66 [00:02<00:08,  6.35it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  24%|██▍       | 16/66 [00:02<00:07,  6.35it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  27%|██▋       | 18/66 [00:02<00:07,  6.24it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  30%|███       | 20/66 [00:03<00:07,  6.23it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  33%|███▎      | 22/66 [00:03<00:07,  6.13it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  36%|███▋      | 24/66 [00:03<00:06,  6.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  39%|███▉      | 26/66 [00:04<00:06,  6.49it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  42%|████▏     | 28/66 [00:04<00:06,  6.23it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  45%|████▌     | 30/66 [00:04<00:05,  6.14it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  48%|████▊     | 32/66 [00:05<00:05,  6.04it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  52%|█████▏    | 34/66 [00:05<00:05,  6.13it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  55%|█████▍    | 36/66 [00:05<00:04,  6.26it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  58%|█████▊    | 38/66 [00:06<00:04,  6.07it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  61%|██████    | 40/66 [00:06<00:04,  6.20it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  64%|██████▎   | 42/66 [00:06<00:03,  6.27it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  67%|██████▋   | 44/66 [00:07<00:03,  6.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  70%|██████▉   | 46/66 [00:07<00:03,  6.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  73%|███████▎  | 48/66 [00:07<00:02,  6.22it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  76%|███████▌  | 50/66 [00:08<00:02,  5.99it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  79%|███████▉  | 52/66 [00:08<00:02,  6.03it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  82%|████████▏ | 54/66 [00:08<00:01,  6.13it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  85%|████████▍ | 56/66 [00:09<00:01,  6.22it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  88%|████████▊ | 58/66 [00:09<00:01,  6.20it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  91%|█████████ | 60/66 [00:09<00:00,  6.07it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  94%|█████████▍| 62/66 [00:09<00:00,  6.37it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  97%|█████████▋| 64/66 [00:10<00:00,  6.33it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing: 100%|██████████| 66/66 [00:10<00:00,  6.23it/s]


Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])
Testing Results:
H1: Test Loss: 0.0991, Accuracy: 0.9981
Precision H1: 0.9981, Recall H1: 0.9981, F1 H1: 0.9981
H2: Test Loss: 0.0991, Accuracy: 0.9905
Precision H2: 0.9905, Recall H2: 0.9905, F1 H2: 0.9905
H3: Test Loss: 0.0991, Custom Accuracy: 0.9904
Precision H3: 0.9919890873015872, Recall H3:0.9869055134680135, F1 Score H3: 0.9894407707993158


In [None]:
#unseen data

In [2]:
# Assuming the new unseen datasets are similar to 'model1_2636.csv', 'model2_2636.csv', 'model3_2636.csv'
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup

#from transformers import RobertaTokenizer, RobertaModel, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import numpy as np
import torch

import pandas as pd

# Loading and preprocessing new_data_h1
new_data_h1 = pd.read_csv("112_H_100_O_NEWMODEL1.CSV")
new_data_h1['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h1 = new_data_h1['text'].tolist()
new_labels_h1 = new_data_h1.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h2
new_data_h2 = pd.read_csv("112_H_100_O_NEWMODEL2.CSV")
new_data_h2['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h2 = new_data_h2['text'].tolist()
new_labels_h2 = new_data_h2.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h3 for multi-label classification
new_data_h3 = pd.read_csv("112_H_100_O_NEWMODEL3.CSV")
new_data_h3['text'].fillna("", inplace=True)  # Replace NaN values with empty strings

# Replace NaN values in label columns with 0 for new_data_h3
# Assuming all label columns are binary, fill NaN values with 0
for label_col in new_data_h3.columns[1:]:  # Skip the text column, only fill NaN in label columns
    new_data_h3[label_col].fillna(0, inplace=True)

new_texts_h3 = new_data_h3['text'].tolist()
new_labels_h3 = new_data_h3.iloc[:, 1:].values.tolist()  # Extract labels after filling NaN values




# Tokenizer

# Import and initialize the ClinicalBERT tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
clinical_bert_model = AutoModel.from_pretrained("medicalai/ClinicalBERT")


# Define the TextDataset class
class TextDataset(Dataset):
    def __init__(self, texts, labels_h1, labels_h2, labels_h3, tokenizer, max_length=512):
        self.texts = texts
        self.labels_h1 = labels_h1  # Now expecting a list of binary values for H1
        self.labels_h2 = np.array(labels_h2).reshape(-1, 1)  # Reshape labels_h2 to [batch_size, 1]
        self.labels_h3 = labels_h3  # Multi-label but binary values for H3
        self.tokenizer = tokenizer
        self.max_length = max_length

    # Rest of the code remains unchanged

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        inputs = self.tokenizer.encode_plus(
            text, None, add_special_tokens=True, max_length=self.max_length,
            padding='max_length', return_token_type_ids=True, truncation=True
        )
        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'labels_h1': torch.tensor(self.labels_h1[idx], dtype=torch.float).unsqueeze(-1),
            'labels_h2': torch.tensor(self.labels_h2[idx], dtype=torch.float),
            'labels_h3': torch.tensor(self.labels_h3[idx], dtype=torch.float)
        }
        
# Define the model architecture for H1, H2, and H3
class HierarchicalClassifier(nn.Module):
    def __init__(self, num_labels_h3):
        super(HierarchicalClassifier, self).__init__()
        self.bert = clinical_bert_model  # Use the ClinicalBERT model
        self.dropout = nn.Dropout(0.1)
        self.fc_h1 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h2 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h3 = nn.Linear(self.bert.config.hidden_size, num_labels_h3)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits_h1 = self.fc_h1(self.dropout(pooled_output))
        logits_h2 = self.fc_h2(self.dropout(pooled_output))
        logits_h3 = self.fc_h3(self.dropout(pooled_output))
        return logits_h1, logits_h2, logits_h3


# Initialize and move the model to the appropriate device (CPU/GPU)
#model = HierarchicalClassifier(num_labels_h3=len(train_labels_h3[0]))

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()

#from transformers import RobertaTokenizer

# Initialize the tokenizer
#tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Preprocess the texts and labels in the same way as the original datasets
# Assuming the preprocessing steps are included in the TextDataset class
# Create TextDataset instances for the new data
new_dataset_h1 = TextDataset(new_texts_h1, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h2 = TextDataset(new_texts_h2, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h3 = TextDataset(new_texts_h3, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)

# Create DataLoader instances for the new datasets
#new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False)
#new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False)
#new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False)


new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False, drop_last=True)

# Load the saved model
model_path = 'model_clinicalBert.pth'  # Path to your saved model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HierarchicalClassifier(num_labels_h3=len(new_labels_h3[0]))  # Make sure to provide the correct number of labels for H3
model.load_state_dict(torch.load(model_path))
model.eval()  # Set the model to evaluation mode
model.to(device)  # Move the model to the appropriate device
def test(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Testing"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            
            # Ensuring labels are correctly shaped and moved to the device
            labels_h1 = batch['labels_h1'].to(device)  # Shape: [batch_size, 1]
            labels_h2 = batch['labels_h2'].to(device)  # Shape: [batch_size, 1]
            labels_h3 = batch['labels_h3'].to(device)  # Shape: [batch_size, num_labels_h3]
            
            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            
            # Compute loss for each hierarchy
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            # Convert logits to probabilities for binary classification/multi-label classification
            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            # Store predictions and labels for metric calculation
            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3







# Testing the model on the new unseen data
test_loss_h1, test_preds_h1, _, _, test_labels_h1, _, _ = test(model, new_dataloader_h1, criterion, device)
test_loss_h2, _, test_preds_h2, _, _, test_labels_h2, _ = test(model, new_dataloader_h2, criterion, device)
test_loss_h3, _, _, test_preds_h3, _, _, test_labels_h3 = test(model, new_dataloader_h3, criterion, device)

# Calculate and print the metrics for H1, H2, and H3 using the same approach as before
# Use the appropriate threshold and custom functions for metrics calculation

# Metrics calculation for H1
threshold_h1 = 0.5
test_preds_h1_binary = (np.array(test_preds_h1) > threshold_h1).astype(int)

acc_h1 = accuracy_score(test_labels_h1, test_preds_h1_binary)
precision_h1 = precision_score(test_labels_h1, test_preds_h1_binary, average='micro')
recall_h1 = recall_score(test_labels_h1, test_preds_h1_binary, average='micro')
f1_h1 = f1_score(test_labels_h1, test_preds_h1_binary, average='micro')

# Metrics calculation for H2
threshold_h2 = 0.5
test_preds_h2_binary = (np.array(test_preds_h2) > threshold_h2).astype(int)

acc_h2 = accuracy_score(test_labels_h2, test_preds_h2_binary)
precision_h2 = precision_score(test_labels_h2, test_preds_h2_binary, average='micro')
recall_h2 = recall_score(test_labels_h2, test_preds_h2_binary, average='micro')
f1_h2 = f1_score(test_labels_h2, test_preds_h2_binary, average='micro')

# Metrics calculation for H3
threshold_h3 = 0.5
test_preds_h3_binary = (np.array(test_preds_h3) > threshold_h3).astype(int)

import numpy as np

def custom_accuracy(y_true, y_pred):
    correct_predictions = np.equal(y_true, y_pred)
    sample_accuracy = np.sum(correct_predictions, axis=1) / y_true.shape[1]
    return np.mean(sample_accuracy)

def custom_precision_recall_f1(y_true, y_pred):
    true_positives = np.sum((y_true == 1) & (y_pred == 1), axis=1)
    false_positives = np.sum((y_true == 0) & (y_pred == 1), axis=1)
    false_negatives = np.sum((y_true == 1) & (y_pred == 0), axis=1)

    precision = np.mean(true_positives / (true_positives + false_positives + 1e-8))
    recall = np.mean(true_positives / (true_positives + false_negatives + 1e-8))
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)

    return precision, recall, f1




# Custom accuracy calculation for H3
acc_h3_custom = custom_accuracy(np.array(test_labels_h3), test_preds_h3_binary)
# Calculate custom precision, recall, and F1 score for H3
precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(test_labels_h3), test_preds_h3_binary)

print("Testing Results:")
print(f"H1: Test Loss: {test_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

print(f"H2: Test Loss: {test_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

print(f"H3: Test Loss: {test_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
print(f"Precision H3: {precision_h3}, Recall H3:{recall_h3}, F1 Score H3: {f1_score_h3}")





Testing: 100%|██████████| 26/26 [00:04<00:00,  5.22it/s]
Testing: 100%|██████████| 26/26 [00:04<00:00,  5.96it/s]
Testing: 100%|██████████| 26/26 [00:04<00:00,  5.92it/s]

Testing Results:
H1: Test Loss: 0.5658, Accuracy: 0.9712
Precision H1: 0.9712, Recall H1: 0.9712, F1 H1: 0.9712
H2: Test Loss: 0.5658, Accuracy: 0.9856
Precision H2: 0.9856, Recall H2: 0.9856, F1 H2: 0.9856
H3: Test Loss: 0.5658, Custom Accuracy: 0.8961
Precision H3: 0.7919204039683364, Recall H3:0.9029647408127635, F1 Score H3: 0.8438049035906062





In [None]:
#TESTING_Chatgpt_Eq_text

In [2]:
# Assuming the new unseen datasets are similar to 'model1_2636.csv', 'model2_2636.csv', 'model3_2636.csv'
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup

#from transformers import RobertaTokenizer, RobertaModel, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import numpy as np
import torch

import pandas as pd

# Loading and preprocessing new_data_h1
new_data_h1 = pd.read_csv("P3_Chatgpt_Eq_Model1_508.csv")
new_data_h1['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h1 = new_data_h1['text'].tolist()
new_labels_h1 = new_data_h1.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h2
new_data_h2 = pd.read_csv("P3_Chatgpt_Eq_Model2_508.csv")
new_data_h2['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h2 = new_data_h2['text'].tolist()
new_labels_h2 = new_data_h2.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h3 for multi-label classification
new_data_h3 = pd.read_csv("P3_Chatgpt_Eq_Model3_508.csv")
new_data_h3['text'].fillna("", inplace=True)  # Replace NaN values with empty strings

# Replace NaN values in label columns with 0 for new_data_h3
# Assuming all label columns are binary, fill NaN values with 0
for label_col in new_data_h3.columns[1:]:  # Skip the text column, only fill NaN in label columns
    new_data_h3[label_col].fillna(0, inplace=True)

new_texts_h3 = new_data_h3['text'].tolist()
new_labels_h3 = new_data_h3.iloc[:, 1:].values.tolist()  # Extract labels after filling NaN values




# Tokenizer

# Import and initialize the ClinicalBERT tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
clinical_bert_model = AutoModel.from_pretrained("medicalai/ClinicalBERT")


# Define the TextDataset class
class TextDataset(Dataset):
    def __init__(self, texts, labels_h1, labels_h2, labels_h3, tokenizer, max_length=512):
        self.texts = texts
        self.labels_h1 = labels_h1  # Now expecting a list of binary values for H1
        self.labels_h2 = np.array(labels_h2).reshape(-1, 1)  # Reshape labels_h2 to [batch_size, 1]
        self.labels_h3 = labels_h3  # Multi-label but binary values for H3
        self.tokenizer = tokenizer
        self.max_length = max_length

    # Rest of the code remains unchanged

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        inputs = self.tokenizer.encode_plus(
            text, None, add_special_tokens=True, max_length=self.max_length,
            padding='max_length', return_token_type_ids=True, truncation=True
        )
        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'labels_h1': torch.tensor(self.labels_h1[idx], dtype=torch.float).unsqueeze(-1),
            'labels_h2': torch.tensor(self.labels_h2[idx], dtype=torch.float),
            'labels_h3': torch.tensor(self.labels_h3[idx], dtype=torch.float)
        }
        
# Define the model architecture for H1, H2, and H3
class HierarchicalClassifier(nn.Module):
    def __init__(self, num_labels_h3):
        super(HierarchicalClassifier, self).__init__()
        self.bert = clinical_bert_model  # Use the ClinicalBERT model
        self.dropout = nn.Dropout(0.1)
        self.fc_h1 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h2 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h3 = nn.Linear(self.bert.config.hidden_size, num_labels_h3)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits_h1 = self.fc_h1(self.dropout(pooled_output))
        logits_h2 = self.fc_h2(self.dropout(pooled_output))
        logits_h3 = self.fc_h3(self.dropout(pooled_output))
        return logits_h1, logits_h2, logits_h3


# Initialize and move the model to the appropriate device (CPU/GPU)
#model = HierarchicalClassifier(num_labels_h3=len(train_labels_h3[0]))

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()

#from transformers import RobertaTokenizer

# Initialize the tokenizer
#tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Preprocess the texts and labels in the same way as the original datasets
# Assuming the preprocessing steps are included in the TextDataset class
# Create TextDataset instances for the new data
new_dataset_h1 = TextDataset(new_texts_h1, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h2 = TextDataset(new_texts_h2, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h3 = TextDataset(new_texts_h3, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)

# Create DataLoader instances for the new datasets
#new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False)
#new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False)
#new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False)


new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False, drop_last=True)

# Load the saved model
model_path = 'model_clinicalBert.pth'  # Path to your saved model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HierarchicalClassifier(num_labels_h3=len(new_labels_h3[0]))  # Make sure to provide the correct number of labels for H3
model.load_state_dict(torch.load(model_path))
model.eval()  # Set the model to evaluation mode
model.to(device)  # Move the model to the appropriate device
def test(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Testing"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            
            # Ensuring labels are correctly shaped and moved to the device
            labels_h1 = batch['labels_h1'].to(device)  # Shape: [batch_size, 1]
            labels_h2 = batch['labels_h2'].to(device)  # Shape: [batch_size, 1]
            labels_h3 = batch['labels_h3'].to(device)  # Shape: [batch_size, num_labels_h3]
            
            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            
            # Compute loss for each hierarchy
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            # Convert logits to probabilities for binary classification/multi-label classification
            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            # Store predictions and labels for metric calculation
            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3







# Testing the model on the new unseen data
test_loss_h1, test_preds_h1, _, _, test_labels_h1, _, _ = test(model, new_dataloader_h1, criterion, device)
test_loss_h2, _, test_preds_h2, _, _, test_labels_h2, _ = test(model, new_dataloader_h2, criterion, device)
test_loss_h3, _, _, test_preds_h3, _, _, test_labels_h3 = test(model, new_dataloader_h3, criterion, device)

# Calculate and print the metrics for H1, H2, and H3 using the same approach as before
# Use the appropriate threshold and custom functions for metrics calculation

# Metrics calculation for H1
threshold_h1 = 0.5
test_preds_h1_binary = (np.array(test_preds_h1) > threshold_h1).astype(int)

acc_h1 = accuracy_score(test_labels_h1, test_preds_h1_binary)
precision_h1 = precision_score(test_labels_h1, test_preds_h1_binary, average='micro')
recall_h1 = recall_score(test_labels_h1, test_preds_h1_binary, average='micro')
f1_h1 = f1_score(test_labels_h1, test_preds_h1_binary, average='micro')

# Metrics calculation for H2
threshold_h2 = 0.5
test_preds_h2_binary = (np.array(test_preds_h2) > threshold_h2).astype(int)

acc_h2 = accuracy_score(test_labels_h2, test_preds_h2_binary)
precision_h2 = precision_score(test_labels_h2, test_preds_h2_binary, average='micro')
recall_h2 = recall_score(test_labels_h2, test_preds_h2_binary, average='micro')
f1_h2 = f1_score(test_labels_h2, test_preds_h2_binary, average='micro')

# Metrics calculation for H3
threshold_h3 = 0.5
test_preds_h3_binary = (np.array(test_preds_h3) > threshold_h3).astype(int)

import numpy as np

def custom_accuracy(y_true, y_pred):
    correct_predictions = np.equal(y_true, y_pred)
    sample_accuracy = np.sum(correct_predictions, axis=1) / y_true.shape[1]
    return np.mean(sample_accuracy)

def custom_precision_recall_f1(y_true, y_pred):
    true_positives = np.sum((y_true == 1) & (y_pred == 1), axis=1)
    false_positives = np.sum((y_true == 0) & (y_pred == 1), axis=1)
    false_negatives = np.sum((y_true == 1) & (y_pred == 0), axis=1)

    precision = np.mean(true_positives / (true_positives + false_positives + 1e-8))
    recall = np.mean(true_positives / (true_positives + false_negatives + 1e-8))
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)

    return precision, recall, f1




# Custom accuracy calculation for H3
acc_h3_custom = custom_accuracy(np.array(test_labels_h3), test_preds_h3_binary)
# Calculate custom precision, recall, and F1 score for H3
precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(test_labels_h3), test_preds_h3_binary)

print("Testing Results:")
print(f"H1: Test Loss: {test_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

print(f"H2: Test Loss: {test_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

print(f"H3: Test Loss: {test_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
print(f"Precision H3: {precision_h3}, Recall H3:{recall_h3}, F1 Score H3: {f1_score_h3}")





Testing: 100%|██████████| 63/63 [00:17<00:00,  3.65it/s]
Testing: 100%|██████████| 63/63 [00:16<00:00,  3.79it/s]
Testing: 100%|██████████| 63/63 [00:16<00:00,  3.74it/s]

Testing Results:
H1: Test Loss: 4.4273, Accuracy: 0.8472
Precision H1: 0.8472, Recall H1: 0.8472, F1 H1: 0.8472
H2: Test Loss: 4.3395, Accuracy: 0.5258
Precision H2: 0.5258, Recall H2: 0.5258, F1 H2: 0.5258
H3: Test Loss: 4.3395, Custom Accuracy: 0.7060
Precision H3: 0.6222686741955771, Recall H3:0.5491439322256687, F1 Score H3: 0.583423913757354



