In [1]:
!export CUDA_VISIBLE_DEVICES=0


In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Now import your GPU-related libraries, such as PyTorch
import torch
torch.cuda.empty_cache()

torch.cuda.memory_summary(device=None, abbreviated=False)




In [3]:
# Import libraries and set up CUDA
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
torch.cuda.empty_cache()

import pandas as pd
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import random


random_seed = 777  # You can change this value to any seed you prefer
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)


# Load and preprocess your data for H1, H2, and H3
model1_data = pd.read_csv("model1_2636.csv")
texts_h1 = model1_data['text'].tolist()
labels_h1 = model1_data.iloc[:, 1].values.tolist()

model2_data = pd.read_csv("model2_2636.csv")
texts_h2 = model2_data['text'].tolist()
labels_h2 = model2_data.iloc[:, 1].values.tolist()

model3_data = pd.read_csv("model3_2636.csv")
texts_h3 = model3_data['text'].tolist()
labels_h3 = model3_data.iloc[:, 1:].values.tolist()

#print("Total Records in H1:", len(texts_h1),"texts_h1", "with", len(labels_h1), "labels")
#print("Total Records in H2:", len(texts_h2),"texts_h2","with", len(labels_h2), "labels")
#print("Total Records in H3:", len(texts_h3),"texts_h3", "with", len(labels_h3), "labels")

# Split the data for each hierarchy
def split_data(texts, labels, train_ratio, val_ratio, test_ratio):
    total_samples = len(texts)
    train_size = int(total_samples * train_ratio)
    val_size = int(total_samples * val_ratio)
    test_size = total_samples - train_size - val_size  # Remaining samples for test

    train_texts = texts[:train_size]
    val_texts = texts[train_size:train_size + val_size]
    test_texts = texts[train_size + val_size:]

    train_labels = labels[:train_size]
    val_labels = labels[train_size:train_size + val_size]
    test_labels = labels[train_size + val_size:]

    return train_texts, val_texts, test_texts, train_labels, val_labels, test_labels

# Split the data for each hierarchy
train_texts_h1, val_texts_h1, test_texts_h1, train_labels_h1, val_labels_h1, test_labels_h1 = split_data(texts_h1, labels_h1, 0.7, 0.1, 0.2)
train_texts_h2, val_texts_h2, test_texts_h2, train_labels_h2, val_labels_h2, test_labels_h2 = split_data(texts_h2, labels_h2, 0.7, 0.1, 0.2)
train_texts_h3, val_texts_h3, test_texts_h3, train_labels_h3, val_labels_h3, test_labels_h3 = split_data(texts_h3, labels_h3, 0.7, 0.1, 0.2)

#print("Train Val Test Splitting:")
#print("H1:", len(train_texts_h1), len(val_texts_h1), len(test_texts_h1))
#print("H2:", len(train_texts_h2), len(val_texts_h2), len(test_texts_h2))
#print("H3:", len(train_texts_h3), len(val_texts_h3), len(test_texts_h3))

# Tokenizer
# Tokenizer and Model setup for Clinical-BigBird
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")
bigbird_model = AutoModel.from_pretrained("yikuan8/Clinical-BigBird")

# Define the TextDataset class
class TextDataset(Dataset):
    def __init__(self, texts, labels_h1, labels_h2, labels_h3, tokenizer, max_length=512):
        self.texts = texts
        self.labels_h1 = labels_h1  # Now expecting a list of binary values for H1
        self.labels_h2 = np.array(labels_h2).reshape(-1, 1)  # Reshape labels_h2 to [batch_size, 1]
        self.labels_h3 = labels_h3  # Multi-label but binary values for H3
        self.tokenizer = tokenizer
        self.max_length = max_length

    # Rest of the code remains unchanged

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        inputs = self.tokenizer.encode_plus(
            text, None, add_special_tokens=True, max_length=self.max_length,
            padding='max_length', return_token_type_ids=True, truncation=True
        )
        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'labels_h1': torch.tensor(self.labels_h1[idx], dtype=torch.float).unsqueeze(-1),
            'labels_h2': torch.tensor(self.labels_h2[idx], dtype=torch.float),
            'labels_h3': torch.tensor(self.labels_h3[idx], dtype=torch.float)
        }

# Create datasets and dataloaders for H1, H2, and H3
dataset_h1 = TextDataset(train_texts_h1 + val_texts_h1 + test_texts_h1, train_labels_h1 + val_labels_h1 + test_labels_h1, train_labels_h2 + val_labels_h2 + test_labels_h2, train_labels_h3 + val_labels_h3 + test_labels_h3, tokenizer)
dataset_h2 = TextDataset(train_texts_h2 + val_texts_h2 + test_texts_h2, train_labels_h1 + val_labels_h1 + test_labels_h1, train_labels_h2 + val_labels_h2 + test_labels_h2, train_labels_h3 + val_labels_h3 + test_labels_h3, tokenizer)
dataset_h3 = TextDataset(train_texts_h3 + val_texts_h3 + test_texts_h3, train_labels_h1 + val_labels_h1 + test_labels_h1, train_labels_h2 + val_labels_h2 + test_labels_h2, train_labels_h3 + val_labels_h3 + test_labels_h3, tokenizer)

# Split the datasets into train, val, and test for H1, H2, and H3
train_size_h1 = len(train_texts_h1)
val_size_h1 = len(val_texts_h1)
train_size_h2 = len(train_texts_h2)
val_size_h2 = len(val_texts_h2)
train_size_h3 = len(train_texts_h3)
val_size_h3 = len(val_texts_h3)

train_dataset_h1, val_dataset_h1, test_dataset_h1 = random_split(dataset_h1, [train_size_h1, val_size_h1, len(test_texts_h1)])
train_dataset_h2, val_dataset_h2, test_dataset_h2 = random_split(dataset_h2, [train_size_h2, val_size_h2, len(test_texts_h2)])
train_dataset_h3, val_dataset_h3, test_dataset_h3 = random_split(dataset_h3, [train_size_h3, val_size_h3, len(test_texts_h3)])

# Create dataloaders for H1, H2, and H3
train_dataloader_h1 = DataLoader(train_dataset_h1, batch_size=8, shuffle=True)
val_dataloader_h1 = DataLoader(val_dataset_h1,batch_size=8, shuffle=False)  # You can set shuffle to True if you want to shuffle the validation data.
train_dataloader_h2 = DataLoader(train_dataset_h2, batch_size=8, shuffle=True)
val_dataloader_h2 = DataLoader(val_dataset_h2, batch_size=8, shuffle=False)

train_dataloader_h3 = DataLoader(train_dataset_h3, batch_size=8, shuffle=True)
val_dataloader_h3 = DataLoader(val_dataset_h3, batch_size=8, shuffle=False)

# Define the model architecture for H1, H2, and H3 using Clinical-BigBird
class HierarchicalClassifier(nn.Module):
    def __init__(self, num_labels_h3):
        super(HierarchicalClassifier, self).__init__()
        self.bigbird = bigbird_model
        self.dropout = nn.Dropout(0.1)
        self.fc_h1 = nn.Linear(self.bigbird.config.hidden_size, 1)
        self.fc_h2 = nn.Linear(self.bigbird.config.hidden_size, 1)
        self.fc_h3 = nn.Linear(self.bigbird.config.hidden_size, num_labels_h3)

    def forward(self, input_ids, attention_mask):
        outputs = self.bigbird(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits_h1 = self.fc_h1(self.dropout(pooled_output))
        logits_h2 = self.fc_h2(self.dropout(pooled_output))
        logits_h3 = self.fc_h3(self.dropout(pooled_output))
        return logits_h1, logits_h2, logits_h3
        
            
# Initialize and move the model to the appropriate device (CPU/GPU)
model = HierarchicalClassifier(num_labels_h3=len(train_labels_h3[0]))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=1e-5)  # You can adjust the learning rate as needed

# Training loop for H1, H2, and H3
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0

    for batch in tqdm(dataloader, total=len(dataloader), desc="Training"):
        input_ids = batch['ids'].to(device)
        attention_mask = batch['mask'].to(device)
        labels_h1 = batch['labels_h1'].to(device)
        labels_h2 = batch['labels_h2'].to(device)
        labels_h3 = batch['labels_h3'].to(device)

        optimizer.zero_grad()

        logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
        loss_h1 = criterion(logits_h1, labels_h1)
        loss_h2 = criterion(logits_h2, labels_h2)
        loss_h3 = criterion(logits_h3, labels_h3)

        loss = loss_h1 + loss_h2 + loss_h3
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    return total_loss / len(dataloader)

# Validation loop for H1, H2, and H3
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Validation"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            labels_h1 = batch['labels_h1'].to(device)
            labels_h2 = batch['labels_h2'].to(device)
            labels_h3 = batch['labels_h3'].to(device)

            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3

# Training and evaluation for H1, H2, and H3
# Custom accuracy function
def custom_accuracy(y_true, y_pred):
    correct_labels = np.sum(np.equal(y_true, y_pred), axis=1)
    total_labels = y_true.shape[1]
    sample_accuracy = np.mean(correct_labels / total_labels)
    return sample_accuracy

def custom_precision_recall_f1(y_true, y_pred):
    # Calculate precision and recall for each sample
    sample_precisions = []
    sample_recalls = []
    for true, pred in zip(y_true, y_pred):
        true_positive = np.sum((true == 1) & (pred == 1))
        false_positive = np.sum((true == 0) & (pred == 1))
        false_negative = np.sum((true == 1) & (pred == 0))
        
        sample_precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
        sample_recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
        
        sample_precisions.append(sample_precision)
        sample_recalls.append(sample_recall)
    
    # Calculate average precision and recall across all samples
    avg_precision = np.mean(sample_precisions)
    avg_recall = np.mean(sample_recalls)
    
    # Calculate F1 score
    if avg_precision + avg_recall > 0:
        avg_f1 = 2 * (avg_precision * avg_recall) / (avg_precision + avg_recall)
    else:
        avg_f1 = 0
    
    return avg_precision, avg_recall, avg_f1


# Training and evaluation for H1, H2, and H3
num_epochs = 33 # You can adjust the number of epochs as needed

for epoch in range(num_epochs):
    # Training for H1
    train_loss_h1 = train(model, train_dataloader_h1, optimizer, criterion, device)
    # Training for H2
    train_loss_h2 = train(model, train_dataloader_h2, optimizer, criterion, device)
    # Training for H3
    train_loss_h3 = train(model, train_dataloader_h3, optimizer, criterion, device)

    # Validation for H1
    val_loss_h1, val_preds_h1, _, _, val_labels_h1, _, _ = evaluate(model, val_dataloader_h1, criterion, device)
    # Validation for H2
    val_loss_h2, _, val_preds_h2, _, _, val_labels_h2, _ = evaluate(model, val_dataloader_h2, criterion, device)
    # Validation for H3
    val_loss_h3, _, _, val_preds_h3, _, _, val_labels_h3 = evaluate(model, val_dataloader_h3, criterion, device)

    # Metrics calculation for H1
    threshold_h1 = 0.5
    val_preds_h1_binary = (np.array(val_preds_h1) > threshold_h1).astype(int)

    acc_h1 = accuracy_score(val_labels_h1, val_preds_h1_binary)
    precision_h1 = precision_score(val_labels_h1, val_preds_h1_binary, average='micro')
    recall_h1 = recall_score(val_labels_h1, val_preds_h1_binary, average='micro')
    f1_h1 = f1_score(val_labels_h1, val_preds_h1_binary, average='micro')


    # Metrics calculation for H2
    threshold_h2 = 0.5
    val_preds_h2_binary = (np.array(val_preds_h2) > threshold_h2).astype(int)

    acc_h2 = accuracy_score(val_labels_h2, val_preds_h2_binary)
    precision_h2 = precision_score(val_labels_h2, val_preds_h2_binary, average='micro')
    recall_h2 = recall_score(val_labels_h2, val_preds_h2_binary, average='micro')
    f1_h2 = f1_score(val_labels_h2, val_preds_h2_binary, average='micro')

# Convert predictions and labels to binary using the threshold

    # Metrics calculation for H3
    threshold_h3 = 0.5
    val_preds_h3_binary = (np.array(val_preds_h3) > threshold_h3).astype(int)


    # Custom accuracy calculation for H3
    acc_h3_custom = custom_accuracy(np.array(val_labels_h3), val_preds_h3_binary)
    # Calculate custom precision, recall, and F1 score for H3
    precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(val_labels_h3), val_preds_h3_binary)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"H1: Train Loss: {train_loss_h1:.4f}, Val Loss: {val_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
    print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

    print(f"H2: Train Loss: {train_loss_h2:.4f}, Val Loss: {val_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
    print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

    print(f"H3: Train Loss: {train_loss_h3:.4f}, Val Loss: {val_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
    print(f"Precision H3: {precision_h3},Recall H3:{recall_h3},F1 Score H3: {f1_score_h3}")

# Saving the model
torch.save(model.state_dict(), 'model_Clinical_Bird.pth')

# Loading the model
#loaded_model = YourModelClass(*args, **kwargs)  # Instantiate your model
#loaded_model.load_state_dict(torch.load('model.pth'))
#loaded_model.eval()  # Set the model to evaluation mode



Training:   0%|          | 0/231 [00:00<?, ?it/s]Attention type 'block_sparse' is not possible if sequence_length: 512 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...
Training: 100%|██████████| 231/231 [03:13<00:00,  1.19it/s]
Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:11<00:00,  1.21it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 1/33
H1: Train Loss: 1.7047, Val Loss: 1.3918, Accuracy: 0.9316
Precision H1: 0.9316, Recall H1: 0.9316, F1 H1: 0.9316
H2: Train Loss: 1.5215, Val Loss: 1.4247, Accuracy: 0.6578
Precision H2: 0.6578, Recall H2: 0.6578, F1 H2: 0.6578
H3: Train Loss: 1.4641, Val Loss: 1.4124, Custom Accuracy: 0.6727
Precision H3: 0.6709487597320297,Recall H3:0.5243621714154034,F1 Score H3: 0.5886671649925109


Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 2/33
H1: Train Loss: 1.4280, Val Loss: 1.3310, Accuracy: 0.9354
Precision H1: 0.9354, Recall H1: 0.9354, F1 H1: 0.9354
H2: Train Loss: 1.4019, Val Loss: 1.3621, Accuracy: 0.6768
Precision H2: 0.6768, Recall H2: 0.6768, F1 H2: 0.6768
H3: Train Loss: 1.3332, Val Loss: 1.3717, Custom Accuracy: 0.6748
Precision H3: 0.6976552598225602,Recall H3:0.47852149962416124,F1 Score H3: 0.5676749493129188


Training: 100%|██████████| 231/231 [03:09<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.44it/s]


Epoch 3/33
H1: Train Loss: 1.2755, Val Loss: 1.1246, Accuracy: 0.9506
Precision H1: 0.9506, Recall H1: 0.9506, F1 H1: 0.9506
H2: Train Loss: 1.2329, Val Loss: 1.1284, Accuracy: 0.7795
Precision H2: 0.7795, Recall H2: 0.7795, F1 H2: 0.7795
H3: Train Loss: 1.1327, Val Loss: 1.1770, Custom Accuracy: 0.6868
Precision H3: 0.6920876335325004,Recall H3:0.5183058175453613,F1 Score H3: 0.592721393843731


Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 4/33
H1: Train Loss: 1.0595, Val Loss: 0.9583, Accuracy: 0.9658
Precision H1: 0.9658, Recall H1: 0.9658, F1 H1: 0.9658
H2: Train Loss: 1.0220, Val Loss: 1.0047, Accuracy: 0.8441
Precision H2: 0.8441, Recall H2: 0.8441, F1 H2: 0.8441
H3: Train Loss: 0.9619, Val Loss: 1.0143, Custom Accuracy: 0.6961
Precision H3: 0.6760275212746697,Recall H3:0.560620188852128,F1 Score H3: 0.6129387917718001


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]


Epoch 5/33
H1: Train Loss: 0.8618, Val Loss: 0.8592, Accuracy: 0.9658
Precision H1: 0.9658, Recall H1: 0.9658, F1 H1: 0.9658
H2: Train Loss: 0.8370, Val Loss: 0.8992, Accuracy: 0.8935
Precision H2: 0.8935, Recall H2: 0.8935, F1 H2: 0.8935
H3: Train Loss: 0.7818, Val Loss: 0.9633, Custom Accuracy: 0.7031
Precision H3: 0.6875475285171102,Recall H3:0.5650682545169237,F1 Score H3: 0.6203199530115331


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]


Epoch 6/33
H1: Train Loss: 0.7322, Val Loss: 0.8714, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.7300, Val Loss: 0.7878, Accuracy: 0.9354
Precision H2: 0.9354, Recall H2: 0.9354, F1 H2: 0.9354
H3: Train Loss: 0.7084, Val Loss: 0.8421, Custom Accuracy: 0.7008
Precision H3: 0.6820115879051242,Recall H3:0.5752344191507689,F1 Score H3: 0.6240887422524247


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 7/33
H1: Train Loss: 0.6809, Val Loss: 0.8597, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.6791, Val Loss: 0.8372, Accuracy: 0.9316
Precision H2: 0.9316, Recall H2: 0.9316, F1 H2: 0.9316
H3: Train Loss: 0.6384, Val Loss: 0.8562, Custom Accuracy: 0.7014
Precision H3: 0.7040467137425312,Recall H3:0.5394873504189094,F1 Score H3: 0.6108787963508304


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 8/33
H1: Train Loss: 0.6361, Val Loss: 0.8240, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.6231, Val Loss: 0.8161, Accuracy: 0.9354
Precision H2: 0.9354, Recall H2: 0.9354, F1 H2: 0.9354
H3: Train Loss: 0.5990, Val Loss: 0.7477, Custom Accuracy: 0.7043
Precision H3: 0.6807004043696059,Recall H3:0.5788497413022128,F1 Score H3: 0.625657111513403


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 9/33
H1: Train Loss: 0.6012, Val Loss: 0.8417, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.6115, Val Loss: 0.7937, Accuracy: 0.9430
Precision H2: 0.9430, Recall H2: 0.9430, F1 H2: 0.9430
H3: Train Loss: 0.5804, Val Loss: 0.7931, Custom Accuracy: 0.7043
Precision H3: 0.6811470215462611,Recall H3:0.5756692399277951,F1 Score H3: 0.6239820412774183


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:09<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:09<00:00,  1.22it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.41it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 10/33
H1: Train Loss: 0.5821, Val Loss: 0.7715, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.6224, Val Loss: 0.7219, Accuracy: 0.9620
Precision H2: 0.9620, Recall H2: 0.9620, F1 H2: 0.9620
H3: Train Loss: 0.5686, Val Loss: 0.7827, Custom Accuracy: 0.7075
Precision H3: 0.6801270444806568,Recall H3:0.5975691186717803,F1 Score H3: 0.6361808546914451


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:09<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.35it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.36it/s]


Epoch 11/33
H1: Train Loss: 0.5632, Val Loss: 0.7944, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.5716, Val Loss: 0.7718, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.5989, Val Loss: 0.8181, Custom Accuracy: 0.7093
Precision H3: 0.6806231791022666,Recall H3:0.6055765421735003,F1 Score H3: 0.6409104659345741


Training: 100%|██████████| 231/231 [03:11<00:00,  1.21it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 12/33
H1: Train Loss: 0.5770, Val Loss: 0.7487, Accuracy: 0.9734
Precision H1: 0.9734, Recall H1: 0.9734, F1 H1: 0.9734
H2: Train Loss: 0.5525, Val Loss: 0.7695, Accuracy: 0.9468
Precision H2: 0.9468, Recall H2: 0.9468, F1 H2: 0.9468
H3: Train Loss: 0.5352, Val Loss: 0.7995, Custom Accuracy: 0.7204
Precision H3: 0.7085732391816042,Recall H3:0.587348087062916,F1 Score H3: 0.6422907442743367


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.44it/s]


Epoch 13/33
H1: Train Loss: 0.5284, Val Loss: 0.7383, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.5319, Val Loss: 0.8173, Accuracy: 0.9430
Precision H2: 0.9430, Recall H2: 0.9430, F1 H2: 0.9430
H3: Train Loss: 0.5182, Val Loss: 0.7880, Custom Accuracy: 0.7286
Precision H3: 0.7146608672274071,Recall H3:0.6054285385083864,F1 Score H3: 0.6555254250123451


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 14/33
H1: Train Loss: 0.5093, Val Loss: 0.7396, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.5154, Val Loss: 0.8589, Accuracy: 0.9278
Precision H2: 0.9278, Recall H2: 0.9278, F1 H2: 0.9278
H3: Train Loss: 0.5234, Val Loss: 0.7787, Custom Accuracy: 0.7295
Precision H3: 0.7225371586588317,Recall H3:0.5896339824096478,F1 Score H3: 0.6493550253697048


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.41it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 15/33
H1: Train Loss: 0.5074, Val Loss: 0.7434, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.5070, Val Loss: 0.8332, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.4815, Val Loss: 0.8057, Custom Accuracy: 0.7438
Precision H3: 0.7179062487997848,Recall H3:0.6252914808047888,F1 Score H3: 0.6684059263905713


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.44it/s]


Epoch 16/33
H1: Train Loss: 0.4869, Val Loss: 0.7068, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.4840, Val Loss: 0.7674, Accuracy: 0.9354
Precision H2: 0.9354, Recall H2: 0.9354, F1 H2: 0.9354
H3: Train Loss: 0.4733, Val Loss: 0.7715, Custom Accuracy: 0.7444
Precision H3: 0.7335445711871568,Recall H3:0.6078262253167195,F1 Score H3: 0.6647939987486863


Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 17/33
H1: Train Loss: 0.4749, Val Loss: 0.6942, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.4655, Val Loss: 0.8172, Accuracy: 0.9468
Precision H2: 0.9468, Recall H2: 0.9468, F1 H2: 0.9468
H3: Train Loss: 0.4571, Val Loss: 0.7698, Custom Accuracy: 0.7567
Precision H3: 0.7367411211517676,Recall H3:0.6462406520391312,F1 Score H3: 0.6885297720426848


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 18/33
H1: Train Loss: 0.4643, Val Loss: 0.6853, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.4586, Val Loss: 0.6901, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.4359, Val Loss: 0.6892, Custom Accuracy: 0.7625
Precision H3: 0.7329786183398351,Recall H3:0.6806226304325163,F1 Score H3: 0.7058310619044885


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 19/33
H1: Train Loss: 0.4260, Val Loss: 0.6823, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.4128, Val Loss: 0.7125, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.4118, Val Loss: 0.6904, Custom Accuracy: 0.7780
Precision H3: 0.758185741170532,Recall H3:0.6829293752297554,F1 Score H3: 0.7185925796393889


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 20/33
H1: Train Loss: 0.4074, Val Loss: 0.6610, Accuracy: 0.9886
Precision H1: 0.9886, Recall H1: 0.9886, F1 H1: 0.9886
H2: Train Loss: 0.3978, Val Loss: 0.6942, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.4039, Val Loss: 0.6716, Custom Accuracy: 0.7879
Precision H3: 0.7821350385989169,Recall H3:0.6611127571203617,F1 Score H3: 0.7165497890829912


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:10<00:00,  1.21it/s]
Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.35it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]


Epoch 21/33
H1: Train Loss: 0.3847, Val Loss: 0.5822, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.4142, Val Loss: 0.6501, Accuracy: 0.9468
Precision H2: 0.9468, Recall H2: 0.9468, F1 H2: 0.9468
H3: Train Loss: 0.3699, Val Loss: 0.6699, Custom Accuracy: 0.8122
Precision H3: 0.7953556477320736,Recall H3:0.7271979710192639,F1 Score H3: 0.7597512575534946


Training: 100%|██████████| 231/231 [03:11<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.44it/s]


Epoch 22/33
H1: Train Loss: 0.3539, Val Loss: 0.5819, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.3461, Val Loss: 0.6653, Accuracy: 0.9430
Precision H2: 0.9430, Recall H2: 0.9430, F1 H2: 0.9430
H3: Train Loss: 0.3339, Val Loss: 0.6433, Custom Accuracy: 0.8228
Precision H3: 0.8161852089608743,Recall H3:0.7343120230002359,F1 Score H3: 0.7730869808479122


Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.36it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.35it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]


Epoch 23/33
H1: Train Loss: 0.3277, Val Loss: 0.5842, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.3180, Val Loss: 0.6314, Accuracy: 0.9468
Precision H2: 0.9468, Recall H2: 0.9468, F1 H2: 0.9468
H3: Train Loss: 0.3077, Val Loss: 0.6002, Custom Accuracy: 0.8467
Precision H3: 0.8490100203978532,Recall H3:0.7636530980637445,F1 Score H3: 0.8040726236518351


Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:11<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:11<00:00,  1.20it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.36it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.35it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]


Epoch 24/33
H1: Train Loss: 0.3002, Val Loss: 0.6291, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.2911, Val Loss: 0.7100, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.2837, Val Loss: 0.6447, Custom Accuracy: 0.8502
Precision H3: 0.857752429235319,Recall H3:0.7750127565716919,F1 Score H3: 0.8142861942627234


Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.36it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.35it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]


Epoch 25/33
H1: Train Loss: 0.2938, Val Loss: 0.5511, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.2781, Val Loss: 0.5936, Accuracy: 0.9468
Precision H2: 0.9468, Recall H2: 0.9468, F1 H2: 0.9468
H3: Train Loss: 0.2741, Val Loss: 0.5592, Custom Accuracy: 0.8701
Precision H3: 0.8675735755583664,Recall H3:0.8040826516111688,F1 Score H3: 0.8346223939642645


Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:11<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.36it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.35it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]


Epoch 26/33
H1: Train Loss: 0.2529, Val Loss: 0.5272, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.2423, Val Loss: 0.5791, Accuracy: 0.9430
Precision H2: 0.9430, Recall H2: 0.9430, F1 H2: 0.9430
H3: Train Loss: 0.2295, Val Loss: 0.5456, Custom Accuracy: 0.8894
Precision H3: 0.8810892740550536,Recall H3:0.8514124131044283,F1 Score H3: 0.8659966689135891


Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:11<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.35it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]


Epoch 27/33
H1: Train Loss: 0.2276, Val Loss: 0.4883, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.2175, Val Loss: 0.5134, Accuracy: 0.9430
Precision H2: 0.9430, Recall H2: 0.9430, F1 H2: 0.9430
H3: Train Loss: 0.2079, Val Loss: 0.5144, Custom Accuracy: 0.9008
Precision H3: 0.8963982574248734,Recall H3:0.8519527156409286,F1 Score H3: 0.8736105524278786


Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:11<00:00,  1.20it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.36it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]


Epoch 28/33
H1: Train Loss: 0.2076, Val Loss: 0.4694, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.2145, Val Loss: 0.5394, Accuracy: 0.9354
Precision H2: 0.9354, Recall H2: 0.9354, F1 H2: 0.9354
H3: Train Loss: 0.1936, Val Loss: 0.5003, Custom Accuracy: 0.9131
Precision H3: 0.9092640143970943,Recall H3:0.8683809853011375,F1 Score H3: 0.8883523772801146


Training: 100%|██████████| 231/231 [03:10<00:00,  1.21it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:09<00:00,  1.22it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.36it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.35it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]


Epoch 29/33
H1: Train Loss: 0.1818, Val Loss: 0.4542, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.1744, Val Loss: 0.4950, Accuracy: 0.9468
Precision H2: 0.9468, Recall H2: 0.9468, F1 H2: 0.9468
H3: Train Loss: 0.1744, Val Loss: 0.5838, Custom Accuracy: 0.9175
Precision H3: 0.9183909985240784,Recall H3:0.8762586484069373,F1 Score H3: 0.8968302603820033


Training: 100%|██████████| 231/231 [03:11<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:12<00:00,  1.20it/s]
Training: 100%|██████████| 231/231 [03:11<00:00,  1.21it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.44it/s]


Epoch 30/33
H1: Train Loss: 0.1683, Val Loss: 0.4371, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.1510, Val Loss: 0.4782, Accuracy: 0.9392
Precision H2: 0.9392, Recall H2: 0.9392, F1 H2: 0.9392
H3: Train Loss: 0.1470, Val Loss: 0.4821, Custom Accuracy: 0.9339
Precision H3: 0.9306256481161423,Recall H3:0.9090338474368893,F1 Score H3: 0.9197030379539373


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]


Epoch 31/33
H1: Train Loss: 0.1630, Val Loss: 0.4214, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.1339, Val Loss: 0.4496, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.1284, Val Loss: 0.4507, Custom Accuracy: 0.9409
Precision H3: 0.9349930593276601,Recall H3:0.9187854920744655,F1 Score H3: 0.9268184244022202


Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.22it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.35it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.35it/s]


Epoch 32/33
H1: Train Loss: 0.1264, Val Loss: 0.4325, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.1146, Val Loss: 0.4492, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.1142, Val Loss: 0.4266, Custom Accuracy: 0.9476
Precision H3: 0.9373181845615306,Recall H3:0.9348185823470995,F1 Score H3: 0.9360667147698095


Training: 100%|██████████| 231/231 [03:09<00:00,  1.22it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Training: 100%|██████████| 231/231 [03:08<00:00,  1.23it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.43it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.41it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.42it/s]


Epoch 33/33
H1: Train Loss: 0.1122, Val Loss: 0.3879, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.1023, Val Loss: 0.4647, Accuracy: 0.9468
Precision H2: 0.9468, Recall H2: 0.9468, F1 H2: 0.9468
H3: Train Loss: 0.1012, Val Loss: 0.4288, Custom Accuracy: 0.9468
Precision H3: 0.9368437223950532,Recall H3:0.9341352141732369,F1 Score H3: 0.9354875078086302


In [4]:
# Create test datasets for H1, H2, and H3
test_dataset_h1 = TextDataset(test_texts_h1, test_labels_h1, test_labels_h2, test_labels_h3, tokenizer)
test_dataset_h2 = TextDataset(test_texts_h2, test_labels_h1, test_labels_h2, test_labels_h3, tokenizer)
test_dataset_h3 = TextDataset(test_texts_h3, test_labels_h1, test_labels_h2, test_labels_h3, tokenizer)

# Create test dataloaders for H1, H2, and H3
test_dataloader_h1 = DataLoader(test_dataset_h1, batch_size=8, shuffle=False)
test_dataloader_h2 = DataLoader(test_dataset_h2, batch_size=8, shuffle=False)
test_dataloader_h3 = DataLoader(test_dataset_h3, batch_size=8, shuffle=False)
# Testing loop for H1, H2, and H3
def test(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Testing"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            labels_h1 = batch['labels_h1'].to(device)  # No need to reshape labels_h1
            labels_h2 = batch['labels_h2'].to(device)  # No need to reshape labels_h2

            #labels_h1 = batch['labels_h1'].unsqueeze(1).to(device)  # Reshape labels_h1
            #labels_h2 = batch['labels_h2'].unsqueeze(1).to(device)  # Reshape labels_h2
            #labels_h3 = batch['labels_h3'].unsqueeze(1).to(device)  # Reshape labels_h3
            labels_h3 = batch['labels_h3'].to(device)  # No need to reshape labels_h3
            print("Labels H1 shape:", labels_h1.shape)
            print("Labels H2 shape:", labels_h2.shape)
            print("Labels H3 shape:", labels_h3.shape)

            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            
            print("Logits H1 shape:", logits_h1.shape)
            print("Logits H2 shape:", logits_h2.shape)
            print("Logits H3 shape:", logits_h3.shape)
            
            
            
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3

# Testing for H1
test_loss_h1, test_preds_h1, _, _, test_labels_h1, _, _ = test(model, test_dataloader_h1, criterion, device)
# Testing for H2
test_loss_h2, _, test_preds_h2, _, _, test_labels_h2, _ = test(model, test_dataloader_h2, criterion, device)
# Testing for H3
test_loss_h3, _, _, test_preds_h3, _, _, test_labels_h3 = test(model, test_dataloader_h3, criterion, device)

# Metrics calculation for H1
threshold_h1 = 0.5
test_preds_h1_binary = (np.array(test_preds_h1) > threshold_h1).astype(int)

acc_h1 = accuracy_score(test_labels_h1, test_preds_h1_binary)
precision_h1 = precision_score(test_labels_h1, test_preds_h1_binary, average='micro')
recall_h1 = recall_score(test_labels_h1, test_preds_h1_binary, average='micro')
f1_h1 = f1_score(test_labels_h1, test_preds_h1_binary, average='micro')

# Metrics calculation for H2
threshold_h2 = 0.5
test_preds_h2_binary = (np.array(test_preds_h2) > threshold_h2).astype(int)

acc_h2 = accuracy_score(test_labels_h2, test_preds_h2_binary)
precision_h2 = precision_score(test_labels_h2, test_preds_h2_binary, average='micro')
recall_h2 = recall_score(test_labels_h2, test_preds_h2_binary, average='micro')
f1_h2 = f1_score(test_labels_h2, test_preds_h2_binary, average='micro')

# Metrics calculation for H3
threshold_h3 = 0.5
test_preds_h3_binary = (np.array(test_preds_h3) > threshold_h3).astype(int)

# Custom accuracy calculation for H3
acc_h3_custom = custom_accuracy(np.array(test_labels_h3), test_preds_h3_binary)
# Calculate custom precision, recall, and F1 score for H3
precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(test_labels_h3), test_preds_h3_binary)

print("Testing Results:")
print(f"H1: Test Loss: {test_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

print(f"H2: Test Loss: {test_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

print(f"H3: Test Loss: {test_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
print(f"Precision H3: {precision_h3}, Recall H3:{recall_h3}, F1 Score H3: {f1_score_h3}")




Testing:   0%|          | 0/66 [00:00<?, ?it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   2%|▏         | 1/66 [00:00<00:19,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   3%|▎         | 2/66 [00:00<00:18,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   5%|▍         | 3/66 [00:00<00:18,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   6%|▌         | 4/66 [00:01<00:17,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   8%|▊         | 5/66 [00:01<00:17,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   9%|▉         | 6/66 [00:01<00:17,  3.37it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  11%|█         | 7/66 [00:02<00:17,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  12%|█▏        | 8/66 [00:02<00:17,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  14%|█▎        | 9/66 [00:02<00:16,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  15%|█▌        | 10/66 [00:02<00:16,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  17%|█▋        | 11/66 [00:03<00:16,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  18%|█▊        | 12/66 [00:03<00:15,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  20%|█▉        | 13/66 [00:03<00:15,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  21%|██        | 14/66 [00:04<00:15,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  23%|██▎       | 15/66 [00:04<00:14,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  24%|██▍       | 16/66 [00:04<00:14,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  26%|██▌       | 17/66 [00:04<00:14,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  27%|██▋       | 18/66 [00:05<00:14,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  29%|██▉       | 19/66 [00:05<00:13,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  30%|███       | 20/66 [00:05<00:13,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  32%|███▏      | 21/66 [00:06<00:13,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  33%|███▎      | 22/66 [00:06<00:12,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  35%|███▍      | 23/66 [00:06<00:12,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  36%|███▋      | 24/66 [00:07<00:12,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  38%|███▊      | 25/66 [00:07<00:11,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  39%|███▉      | 26/66 [00:07<00:11,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  41%|████      | 27/66 [00:07<00:11,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  42%|████▏     | 28/66 [00:08<00:11,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  44%|████▍     | 29/66 [00:08<00:10,  3.38it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  45%|████▌     | 30/66 [00:08<00:10,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  47%|████▋     | 31/66 [00:09<00:10,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  48%|████▊     | 32/66 [00:09<00:10,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  50%|█████     | 33/66 [00:09<00:09,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  52%|█████▏    | 34/66 [00:09<00:09,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  53%|█████▎    | 35/66 [00:10<00:09,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  55%|█████▍    | 36/66 [00:10<00:08,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  56%|█████▌    | 37/66 [00:10<00:08,  3.37it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  58%|█████▊    | 38/66 [00:11<00:08,  3.38it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  59%|█████▉    | 39/66 [00:11<00:07,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  61%|██████    | 40/66 [00:11<00:07,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  62%|██████▏   | 41/66 [00:12<00:07,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  64%|██████▎   | 42/66 [00:12<00:07,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  65%|██████▌   | 43/66 [00:12<00:06,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  67%|██████▋   | 44/66 [00:12<00:06,  3.48it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  68%|██████▊   | 45/66 [00:13<00:06,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  70%|██████▉   | 46/66 [00:13<00:05,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  71%|███████   | 47/66 [00:13<00:05,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  73%|███████▎  | 48/66 [00:14<00:05,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  74%|███████▍  | 49/66 [00:14<00:05,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  76%|███████▌  | 50/66 [00:14<00:04,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  77%|███████▋  | 51/66 [00:14<00:04,  3.36it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  79%|███████▉  | 52/66 [00:15<00:04,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  80%|████████  | 53/66 [00:15<00:03,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  82%|████████▏ | 54/66 [00:15<00:03,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  83%|████████▎ | 55/66 [00:16<00:03,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  85%|████████▍ | 56/66 [00:16<00:02,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  86%|████████▋ | 57/66 [00:16<00:02,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  88%|████████▊ | 58/66 [00:16<00:02,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  89%|████████▉ | 59/66 [00:17<00:02,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  91%|█████████ | 60/66 [00:17<00:01,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  92%|█████████▏| 61/66 [00:17<00:01,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  94%|█████████▍| 62/66 [00:18<00:01,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  95%|█████████▌| 63/66 [00:18<00:00,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  97%|█████████▋| 64/66 [00:18<00:00,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  98%|█████████▊| 65/66 [00:19<00:00,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing: 100%|██████████| 66/66 [00:19<00:00,  3.42it/s]
Testing:   0%|          | 0/66 [00:00<?, ?it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   2%|▏         | 1/66 [00:00<00:19,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   3%|▎         | 2/66 [00:00<00:18,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   5%|▍         | 3/66 [00:00<00:18,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   6%|▌         | 4/66 [00:01<00:17,  3.47it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   8%|▊         | 5/66 [00:01<00:17,  3.47it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   9%|▉         | 6/66 [00:01<00:17,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  11%|█         | 7/66 [00:02<00:17,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  12%|█▏        | 8/66 [00:02<00:16,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  14%|█▎        | 9/66 [00:02<00:16,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  15%|█▌        | 10/66 [00:02<00:16,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  17%|█▋        | 11/66 [00:03<00:16,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  18%|█▊        | 12/66 [00:03<00:15,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  20%|█▉        | 13/66 [00:03<00:15,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  21%|██        | 14/66 [00:04<00:15,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  23%|██▎       | 15/66 [00:04<00:14,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  24%|██▍       | 16/66 [00:04<00:14,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  26%|██▌       | 17/66 [00:04<00:14,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  27%|██▋       | 18/66 [00:05<00:13,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  29%|██▉       | 19/66 [00:05<00:13,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  30%|███       | 20/66 [00:05<00:13,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  32%|███▏      | 21/66 [00:06<00:13,  3.38it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  33%|███▎      | 22/66 [00:06<00:12,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  35%|███▍      | 23/66 [00:06<00:12,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  36%|███▋      | 24/66 [00:07<00:12,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  38%|███▊      | 25/66 [00:07<00:11,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  39%|███▉      | 26/66 [00:07<00:11,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  41%|████      | 27/66 [00:07<00:11,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  42%|████▏     | 28/66 [00:08<00:11,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  44%|████▍     | 29/66 [00:08<00:10,  3.38it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  45%|████▌     | 30/66 [00:08<00:10,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  47%|████▋     | 31/66 [00:09<00:10,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  48%|████▊     | 32/66 [00:09<00:09,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  50%|█████     | 33/66 [00:09<00:09,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  52%|█████▏    | 34/66 [00:09<00:09,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  53%|█████▎    | 35/66 [00:10<00:09,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  55%|█████▍    | 36/66 [00:10<00:08,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  56%|█████▌    | 37/66 [00:10<00:08,  3.36it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  58%|█████▊    | 38/66 [00:11<00:08,  3.37it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  59%|█████▉    | 39/66 [00:11<00:07,  3.38it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  61%|██████    | 40/66 [00:11<00:07,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  62%|██████▏   | 41/66 [00:11<00:07,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  64%|██████▎   | 42/66 [00:12<00:07,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  65%|██████▌   | 43/66 [00:12<00:06,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  67%|██████▋   | 44/66 [00:12<00:06,  3.47it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  68%|██████▊   | 45/66 [00:13<00:06,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  70%|██████▉   | 46/66 [00:13<00:05,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  71%|███████   | 47/66 [00:13<00:05,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  73%|███████▎  | 48/66 [00:14<00:05,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  74%|███████▍  | 49/66 [00:14<00:04,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  76%|███████▌  | 50/66 [00:14<00:04,  3.38it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  77%|███████▋  | 51/66 [00:14<00:04,  3.36it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  79%|███████▉  | 52/66 [00:15<00:04,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  80%|████████  | 53/66 [00:15<00:03,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  82%|████████▏ | 54/66 [00:15<00:03,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  83%|████████▎ | 55/66 [00:16<00:03,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  85%|████████▍ | 56/66 [00:16<00:02,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  86%|████████▋ | 57/66 [00:16<00:02,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  88%|████████▊ | 58/66 [00:16<00:02,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  89%|████████▉ | 59/66 [00:17<00:02,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  91%|█████████ | 60/66 [00:17<00:01,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  92%|█████████▏| 61/66 [00:17<00:01,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  94%|█████████▍| 62/66 [00:18<00:01,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  95%|█████████▌| 63/66 [00:18<00:00,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  97%|█████████▋| 64/66 [00:18<00:00,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  98%|█████████▊| 65/66 [00:19<00:00,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing: 100%|██████████| 66/66 [00:19<00:00,  3.42it/s]
Testing:   0%|          | 0/66 [00:00<?, ?it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   2%|▏         | 1/66 [00:00<00:19,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   3%|▎         | 2/66 [00:00<00:18,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   5%|▍         | 3/66 [00:00<00:18,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   6%|▌         | 4/66 [00:01<00:17,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   8%|▊         | 5/66 [00:01<00:17,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   9%|▉         | 6/66 [00:01<00:17,  3.37it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  11%|█         | 7/66 [00:02<00:17,  3.37it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  12%|█▏        | 8/66 [00:02<00:17,  3.38it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  14%|█▎        | 9/66 [00:02<00:16,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  15%|█▌        | 10/66 [00:02<00:16,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  17%|█▋        | 11/66 [00:03<00:16,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  18%|█▊        | 12/66 [00:03<00:15,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  20%|█▉        | 13/66 [00:03<00:15,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  21%|██        | 14/66 [00:04<00:15,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  23%|██▎       | 15/66 [00:04<00:14,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  24%|██▍       | 16/66 [00:04<00:14,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  26%|██▌       | 17/66 [00:04<00:14,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  27%|██▋       | 18/66 [00:05<00:14,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  29%|██▉       | 19/66 [00:05<00:13,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  30%|███       | 20/66 [00:05<00:13,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  32%|███▏      | 21/66 [00:06<00:13,  3.37it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  33%|███▎      | 22/66 [00:06<00:12,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  35%|███▍      | 23/66 [00:06<00:12,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  36%|███▋      | 24/66 [00:07<00:12,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  38%|███▊      | 25/66 [00:07<00:11,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  39%|███▉      | 26/66 [00:07<00:11,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  41%|████      | 27/66 [00:07<00:11,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  42%|████▏     | 28/66 [00:08<00:11,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  44%|████▍     | 29/66 [00:08<00:10,  3.37it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  45%|████▌     | 30/66 [00:08<00:10,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  47%|████▋     | 31/66 [00:09<00:10,  3.37it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  48%|████▊     | 32/66 [00:09<00:10,  3.38it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  50%|█████     | 33/66 [00:09<00:09,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  52%|█████▏    | 34/66 [00:09<00:09,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  53%|█████▎    | 35/66 [00:10<00:09,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  55%|█████▍    | 36/66 [00:10<00:08,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  56%|█████▌    | 37/66 [00:10<00:08,  3.38it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  58%|█████▊    | 38/66 [00:11<00:08,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  59%|█████▉    | 39/66 [00:11<00:07,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  61%|██████    | 40/66 [00:11<00:07,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  62%|██████▏   | 41/66 [00:12<00:07,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  64%|██████▎   | 42/66 [00:12<00:07,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  65%|██████▌   | 43/66 [00:12<00:06,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  67%|██████▋   | 44/66 [00:12<00:06,  3.47it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  68%|██████▊   | 45/66 [00:13<00:06,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  70%|██████▉   | 46/66 [00:13<00:05,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  71%|███████   | 47/66 [00:13<00:05,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  73%|███████▎  | 48/66 [00:14<00:05,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  74%|███████▍  | 49/66 [00:14<00:04,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  76%|███████▌  | 50/66 [00:14<00:04,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  77%|███████▋  | 51/66 [00:14<00:04,  3.36it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  79%|███████▉  | 52/66 [00:15<00:04,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  80%|████████  | 53/66 [00:15<00:03,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  82%|████████▏ | 54/66 [00:15<00:03,  3.41it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  83%|████████▎ | 55/66 [00:16<00:03,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  85%|████████▍ | 56/66 [00:16<00:02,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  86%|████████▋ | 57/66 [00:16<00:02,  3.42it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  88%|████████▊ | 58/66 [00:16<00:02,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  89%|████████▉ | 59/66 [00:17<00:02,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  91%|█████████ | 60/66 [00:17<00:01,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  92%|█████████▏| 61/66 [00:17<00:01,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  94%|█████████▍| 62/66 [00:18<00:01,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  95%|█████████▌| 63/66 [00:18<00:00,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  97%|█████████▋| 64/66 [00:18<00:00,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  98%|█████████▊| 65/66 [00:19<00:00,  3.43it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing: 100%|██████████| 66/66 [00:19<00:00,  3.41it/s]

Testing Results:
H1: Test Loss: 0.0985, Accuracy: 0.9981
Precision H1: 0.9981, Recall H1: 0.9981, F1 H1: 0.9981
H2: Test Loss: 0.0985, Accuracy: 0.9962
Precision H2: 0.9962, Recall H2: 0.9962, F1 H2: 0.9962
H3: Test Loss: 0.0985, Custom Accuracy: 0.9869
Precision H3: 0.9875943553500371, Recall H3:0.9856872294372295, F1 Score H3: 0.9866398707995548





In [None]:
#UNSEEN DATASETS

In [3]:
# Assuming the new unseen datasets are similar to 'model1_2636.csv', 'model2_2636.csv', 'model3_2636.csv'
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
#from transformers import RobertaTokenizer, RobertaModel, AdamW, get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import numpy as np
import torch

import pandas as pd

# Loading and preprocessing new_data_h1
new_data_h1 = pd.read_csv("112_H_100_O_NEWMODEL1.CSV")
new_data_h1['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h1 = new_data_h1['text'].tolist()
new_labels_h1 = new_data_h1.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h2
new_data_h2 = pd.read_csv("112_H_100_O_NEWMODEL2.CSV")
new_data_h2['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h2 = new_data_h2['text'].tolist()
new_labels_h2 = new_data_h2.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h3 for multi-label classification
new_data_h3 = pd.read_csv("112_H_100_O_NEWMODEL3.CSV")
new_data_h3['text'].fillna("", inplace=True)  # Replace NaN values with empty strings

# Replace NaN values in label columns with 0 for new_data_h3
# Assuming all label columns are binary, fill NaN values with 0
for label_col in new_data_h3.columns[1:]:  # Skip the text column, only fill NaN in label columns
    new_data_h3[label_col].fillna(0, inplace=True)

new_texts_h3 = new_data_h3['text'].tolist()
new_labels_h3 = new_data_h3.iloc[:, 1:].values.tolist()  # Extract labels after filling NaN values




class TextDataset(Dataset):
    def __init__(self, texts, labels_h1, labels_h2, labels_h3, tokenizer, max_length=512):
        self.texts = texts
        self.labels_h1 = labels_h1  # Now expecting a list of binary values for H1
        self.labels_h2 = np.array(labels_h2).reshape(-1, 1)  # Reshape labels_h2 to [batch_size, 1]
        self.labels_h3 = labels_h3  # Multi-label but binary values for H3
        self.tokenizer = tokenizer
        self.max_length = max_length

    # Rest of the code remains unchanged

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        inputs = self.tokenizer.encode_plus(
            text, None, add_special_tokens=True, max_length=self.max_length,
            padding='max_length', return_token_type_ids=True, truncation=True
        )
        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'labels_h1': torch.tensor(self.labels_h1[idx], dtype=torch.float).unsqueeze(-1),
            'labels_h2': torch.tensor(self.labels_h2[idx], dtype=torch.float),
            'labels_h3': torch.tensor(self.labels_h3[idx], dtype=torch.float)
        }
        
class HierarchicalClassifier(nn.Module):
    def __init__(self, num_labels_h3):
        super(HierarchicalClassifier, self).__init__()
        self.bigbird = bigbird_model
        self.dropout = nn.Dropout(0.1)
        self.fc_h1 = nn.Linear(self.bigbird.config.hidden_size, 1)
        self.fc_h2 = nn.Linear(self.bigbird.config.hidden_size, 1)
        self.fc_h3 = nn.Linear(self.bigbird.config.hidden_size, num_labels_h3)

    def forward(self, input_ids, attention_mask):
        outputs = self.bigbird(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits_h1 = self.fc_h1(self.dropout(pooled_output))
        logits_h2 = self.fc_h2(self.dropout(pooled_output))
        logits_h3 = self.fc_h3(self.dropout(pooled_output))
        return logits_h1, logits_h2, logits_h3

# Initialize and move the model to the appropriate device (CPU/GPU)
#model = HierarchicalClassifier(num_labels_h3=len(train_labels_h3[0]))

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()

# Tokenizer and Model setup for Clinical-BigBird
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")
bigbird_model = AutoModel.from_pretrained("yikuan8/Clinical-BigBird")


#from transformers import RobertaTokenizer

# Initialize the tokenizer
#tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Preprocess the texts and labels in the same way as the original datasets
# Assuming the preprocessing steps are included in the TextDataset class
# Create TextDataset instances for the new data
new_dataset_h1 = TextDataset(new_texts_h1, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h2 = TextDataset(new_texts_h2, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h3 = TextDataset(new_texts_h3, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)

# Create DataLoader instances for the new datasets
#new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False)
#new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False)
#new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False)


new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False, drop_last=True)

# Load the saved model
model_path = 'model_Clinical_Bird.pth'  # Path to your saved model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HierarchicalClassifier(num_labels_h3=len(new_labels_h3[0]))  # Make sure to provide the correct number of labels for H3
model.load_state_dict(torch.load(model_path))
model.eval()  # Set the model to evaluation mode
model.to(device)  # Move the model to the appropriate device
def test(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Testing"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            
            # Ensuring labels are correctly shaped and moved to the device
            labels_h1 = batch['labels_h1'].to(device)  # Shape: [batch_size, 1]
            labels_h2 = batch['labels_h2'].to(device)  # Shape: [batch_size, 1]
            labels_h3 = batch['labels_h3'].to(device)  # Shape: [batch_size, num_labels_h3]
            
            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            
            # Compute loss for each hierarchy
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            # Convert logits to probabilities for binary classification/multi-label classification
            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            # Store predictions and labels for metric calculation
            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3







# Testing the model on the new unseen data
test_loss_h1, test_preds_h1, _, _, test_labels_h1, _, _ = test(model, new_dataloader_h1, criterion, device)
test_loss_h2, _, test_preds_h2, _, _, test_labels_h2, _ = test(model, new_dataloader_h2, criterion, device)
test_loss_h3, _, _, test_preds_h3, _, _, test_labels_h3 = test(model, new_dataloader_h3, criterion, device)

# Calculate and print the metrics for H1, H2, and H3 using the same approach as before
# Use the appropriate threshold and custom functions for metrics calculation

# Metrics calculation for H1
threshold_h1 = 0.5
test_preds_h1_binary = (np.array(test_preds_h1) > threshold_h1).astype(int)

acc_h1 = accuracy_score(test_labels_h1, test_preds_h1_binary)
precision_h1 = precision_score(test_labels_h1, test_preds_h1_binary, average='micro')
recall_h1 = recall_score(test_labels_h1, test_preds_h1_binary, average='micro')
f1_h1 = f1_score(test_labels_h1, test_preds_h1_binary, average='micro')

# Metrics calculation for H2
threshold_h2 = 0.5
test_preds_h2_binary = (np.array(test_preds_h2) > threshold_h2).astype(int)

acc_h2 = accuracy_score(test_labels_h2, test_preds_h2_binary)
precision_h2 = precision_score(test_labels_h2, test_preds_h2_binary, average='micro')
recall_h2 = recall_score(test_labels_h2, test_preds_h2_binary, average='micro')
f1_h2 = f1_score(test_labels_h2, test_preds_h2_binary, average='micro')

# Metrics calculation for H3
threshold_h3 = 0.5
test_preds_h3_binary = (np.array(test_preds_h3) > threshold_h3).astype(int)

import numpy as np

def custom_accuracy(y_true, y_pred):
    correct_predictions = np.equal(y_true, y_pred)
    sample_accuracy = np.sum(correct_predictions, axis=1) / y_true.shape[1]
    return np.mean(sample_accuracy)

def custom_precision_recall_f1(y_true, y_pred):
    true_positives = np.sum((y_true == 1) & (y_pred == 1), axis=1)
    false_positives = np.sum((y_true == 0) & (y_pred == 1), axis=1)
    false_negatives = np.sum((y_true == 1) & (y_pred == 0), axis=1)

    precision = np.mean(true_positives / (true_positives + false_positives + 1e-8))
    recall = np.mean(true_positives / (true_positives + false_negatives + 1e-8))
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)

    return precision, recall, f1




# Custom accuracy calculation for H3
acc_h3_custom = custom_accuracy(np.array(test_labels_h3), test_preds_h3_binary)
# Calculate custom precision, recall, and F1 score for H3
precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(test_labels_h3), test_preds_h3_binary)

print("Testing Results:")
print(f"H1: Test Loss: {test_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

print(f"H2: Test Loss: {test_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

print(f"H3: Test Loss: {test_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
print(f"Precision H3: {precision_h3}, Recall H3:{recall_h3}, F1 Score H3: {f1_score_h3}")





    Found GPU%d %s which is of cuda capability %d.%d.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability supported by this library is %d.%d.
    
Testing:   0%|          | 0/26 [00:00<?, ?it/s]Attention type 'block_sparse' is not possible if sequence_length: 512 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...
Testing: 100%|██████████| 26/26 [00:07<00:00,  3.31it/s]
Testing: 100%|██████████| 26/26 [00:07<00:00,  3.49it/s]
Testing: 100%|██████████| 26/26 [00:07<00:00,  3.48it/s]

Testing Results:
H1: Test Loss: 0.5334, Accuracy: 0.9760
Precision H1: 0.9760, Recall H1: 0.9760, F1 H1: 0.9760
H2: Test Loss: 0.5334, Accuracy: 0.9808
Precision H2: 0.9808, Recall H2: 0.9808, F1 H2: 0.9808
H3: Test Loss: 0.5334, Custom Accuracy: 0.8920
Precision H3: 0.7907528215355902, Recall H3:0.8971799006601666, F1 Score H3: 0.8406111503616124





In [None]:
#Testing_ChatgptEq_Text

In [1]:
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
#from transformers import RobertaTokenizer, RobertaModel, AdamW, get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import numpy as np
import torch

import pandas as pd

# Loading and preprocessing new_data_h1
new_data_h1 = pd.read_csv("P3_Chatgpt_Eq_Model1_508.csv")
new_data_h1['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h1 = new_data_h1['text'].tolist()
new_labels_h1 = new_data_h1.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h2
new_data_h2 = pd.read_csv("P3_Chatgpt_Eq_Model2_508.csv")
new_data_h2['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h2 = new_data_h2['text'].tolist()
new_labels_h2 = new_data_h2.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h3 for multi-label classification
new_data_h3 = pd.read_csv("P3_Chatgpt_Eq_Model3_508.csv")
new_data_h3['text'].fillna("", inplace=True)  # Replace NaN values with empty strings

# Replace NaN values in label columns with 0 for new_data_h3
# Assuming all label columns are binary, fill NaN values with 0
for label_col in new_data_h3.columns[1:]:  # Skip the text column, only fill NaN in label columns
    new_data_h3[label_col].fillna(0, inplace=True)

new_texts_h3 = new_data_h3['text'].tolist()
new_labels_h3 = new_data_h3.iloc[:, 1:].values.tolist()  # Extract labels after filling NaN values




class TextDataset(Dataset):
    def __init__(self, texts, labels_h1, labels_h2, labels_h3, tokenizer, max_length=512):
        self.texts = texts
        self.labels_h1 = labels_h1  # Now expecting a list of binary values for H1
        self.labels_h2 = np.array(labels_h2).reshape(-1, 1)  # Reshape labels_h2 to [batch_size, 1]
        self.labels_h3 = labels_h3  # Multi-label but binary values for H3
        self.tokenizer = tokenizer
        self.max_length = max_length

    # Rest of the code remains unchanged

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        inputs = self.tokenizer.encode_plus(
            text, None, add_special_tokens=True, max_length=self.max_length,
            padding='max_length', return_token_type_ids=True, truncation=True
        )
        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'labels_h1': torch.tensor(self.labels_h1[idx], dtype=torch.float).unsqueeze(-1),
            'labels_h2': torch.tensor(self.labels_h2[idx], dtype=torch.float),
            'labels_h3': torch.tensor(self.labels_h3[idx], dtype=torch.float)
        }
        
class HierarchicalClassifier(nn.Module):
    def __init__(self, num_labels_h3):
        super(HierarchicalClassifier, self).__init__()
        self.bigbird = bigbird_model
        self.dropout = nn.Dropout(0.1)
        self.fc_h1 = nn.Linear(self.bigbird.config.hidden_size, 1)
        self.fc_h2 = nn.Linear(self.bigbird.config.hidden_size, 1)
        self.fc_h3 = nn.Linear(self.bigbird.config.hidden_size, num_labels_h3)

    def forward(self, input_ids, attention_mask):
        outputs = self.bigbird(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits_h1 = self.fc_h1(self.dropout(pooled_output))
        logits_h2 = self.fc_h2(self.dropout(pooled_output))
        logits_h3 = self.fc_h3(self.dropout(pooled_output))
        return logits_h1, logits_h2, logits_h3

# Initialize and move the model to the appropriate device (CPU/GPU)
#model = HierarchicalClassifier(num_labels_h3=len(train_labels_h3[0]))

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()

# Tokenizer and Model setup for Clinical-BigBird
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")
bigbird_model = AutoModel.from_pretrained("yikuan8/Clinical-BigBird")


#from transformers import RobertaTokenizer

# Initialize the tokenizer
#tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Preprocess the texts and labels in the same way as the original datasets
# Assuming the preprocessing steps are included in the TextDataset class
# Create TextDataset instances for the new data
new_dataset_h1 = TextDataset(new_texts_h1, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h2 = TextDataset(new_texts_h2, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h3 = TextDataset(new_texts_h3, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)

# Create DataLoader instances for the new datasets
#new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False)
#new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False)
#new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False)


new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False, drop_last=True)

# Load the saved model
model_path = 'model_Clinical_Bird.pth'  # Path to your saved model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HierarchicalClassifier(num_labels_h3=len(new_labels_h3[0]))  # Make sure to provide the correct number of labels for H3
model.load_state_dict(torch.load(model_path))
model.eval()  # Set the model to evaluation mode
model.to(device)  # Move the model to the appropriate device
def test(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Testing"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            
            # Ensuring labels are correctly shaped and moved to the device
            labels_h1 = batch['labels_h1'].to(device)  # Shape: [batch_size, 1]
            labels_h2 = batch['labels_h2'].to(device)  # Shape: [batch_size, 1]
            labels_h3 = batch['labels_h3'].to(device)  # Shape: [batch_size, num_labels_h3]
            
            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            
            # Compute loss for each hierarchy
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            # Convert logits to probabilities for binary classification/multi-label classification
            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            # Store predictions and labels for metric calculation
            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3







# Testing the model on the new unseen data
test_loss_h1, test_preds_h1, _, _, test_labels_h1, _, _ = test(model, new_dataloader_h1, criterion, device)
test_loss_h2, _, test_preds_h2, _, _, test_labels_h2, _ = test(model, new_dataloader_h2, criterion, device)
test_loss_h3, _, _, test_preds_h3, _, _, test_labels_h3 = test(model, new_dataloader_h3, criterion, device)

# Calculate and print the metrics for H1, H2, and H3 using the same approach as before
# Use the appropriate threshold and custom functions for metrics calculation

# Metrics calculation for H1
threshold_h1 = 0.5
test_preds_h1_binary = (np.array(test_preds_h1) > threshold_h1).astype(int)

acc_h1 = accuracy_score(test_labels_h1, test_preds_h1_binary)
precision_h1 = precision_score(test_labels_h1, test_preds_h1_binary, average='micro')
recall_h1 = recall_score(test_labels_h1, test_preds_h1_binary, average='micro')
f1_h1 = f1_score(test_labels_h1, test_preds_h1_binary, average='micro')

# Metrics calculation for H2
threshold_h2 = 0.5
test_preds_h2_binary = (np.array(test_preds_h2) > threshold_h2).astype(int)

acc_h2 = accuracy_score(test_labels_h2, test_preds_h2_binary)
precision_h2 = precision_score(test_labels_h2, test_preds_h2_binary, average='micro')
recall_h2 = recall_score(test_labels_h2, test_preds_h2_binary, average='micro')
f1_h2 = f1_score(test_labels_h2, test_preds_h2_binary, average='micro')

# Metrics calculation for H3
threshold_h3 = 0.5
test_preds_h3_binary = (np.array(test_preds_h3) > threshold_h3).astype(int)

import numpy as np

def custom_accuracy(y_true, y_pred):
    correct_predictions = np.equal(y_true, y_pred)
    sample_accuracy = np.sum(correct_predictions, axis=1) / y_true.shape[1]
    return np.mean(sample_accuracy)

def custom_precision_recall_f1(y_true, y_pred):
    true_positives = np.sum((y_true == 1) & (y_pred == 1), axis=1)
    false_positives = np.sum((y_true == 0) & (y_pred == 1), axis=1)
    false_negatives = np.sum((y_true == 1) & (y_pred == 0), axis=1)

    precision = np.mean(true_positives / (true_positives + false_positives + 1e-8))
    recall = np.mean(true_positives / (true_positives + false_negatives + 1e-8))
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)

    return precision, recall, f1




# Custom accuracy calculation for H3
acc_h3_custom = custom_accuracy(np.array(test_labels_h3), test_preds_h3_binary)
# Calculate custom precision, recall, and F1 score for H3
precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(test_labels_h3), test_preds_h3_binary)

print("Testing Results:")
print(f"H1: Test Loss: {test_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

print(f"H2: Test Loss: {test_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

print(f"H3: Test Loss: {test_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
print(f"Precision H3: {precision_h3}, Recall H3:{recall_h3}, F1 Score H3: {f1_score_h3}")





    Found GPU%d %s which is of cuda capability %d.%d.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability supported by this library is %d.%d.
    
Testing:   0%|          | 0/63 [00:00<?, ?it/s]Attention type 'block_sparse' is not possible if sequence_length: 512 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...
Testing: 100%|██████████| 63/63 [00:32<00:00,  1.91it/s]
Testing: 100%|██████████| 63/63 [00:32<00:00,  1.92it/s]
Testing: 100%|██████████| 63/63 [00:32<00:00,  1.94it/s]

Testing Results:
H1: Test Loss: 3.7800, Accuracy: 0.8472
Precision H1: 0.8472, Recall H1: 0.8472, F1 H1: 0.8472
H2: Test Loss: 3.7020, Accuracy: 0.5159
Precision H2: 0.5159, Recall H2: 0.5159, F1 H2: 0.5159
H3: Test Loss: 3.7020, Custom Accuracy: 0.6810
Precision H3: 0.5680531923039022, Recall H3:0.6326776822343488, F1 Score H3: 0.598626356197814



