In [1]:
!export CUDA_VISIBLE_DEVICES=0


In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Now import your GPU-related libraries, such as PyTorch
import torch
torch.cuda.empty_cache()

torch.cuda.memory_summary(device=None, abbreviated=False)




In [3]:
# Import libraries and set up CUDA
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
torch.cuda.empty_cache()

import pandas as pd
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import random



random_seed = 777  # You can change this value to any seed you prefer
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)


# Load and preprocess your data for H1, H2, and H3
model1_data = pd.read_csv("model1_2636.csv")
texts_h1 = model1_data['text'].tolist()
labels_h1 = model1_data.iloc[:, 1].values.tolist()

model2_data = pd.read_csv("model2_2636.csv")
texts_h2 = model2_data['text'].tolist()
labels_h2 = model2_data.iloc[:, 1].values.tolist()

model3_data = pd.read_csv("model3_2636.csv")
texts_h3 = model3_data['text'].tolist()
labels_h3 = model3_data.iloc[:, 1:].values.tolist()

#print("Total Records in H1:", len(texts_h1),"texts_h1", "with", len(labels_h1), "labels")
#print("Total Records in H2:", len(texts_h2),"texts_h2","with", len(labels_h2), "labels")
#print("Total Records in H3:", len(texts_h3),"texts_h3", "with", len(labels_h3), "labels")

# Split the data for each hierarchy
def split_data(texts, labels, train_ratio, val_ratio, test_ratio):
    total_samples = len(texts)
    train_size = int(total_samples * train_ratio)
    val_size = int(total_samples * val_ratio)
    test_size = total_samples - train_size - val_size  # Remaining samples for test

    train_texts = texts[:train_size]
    val_texts = texts[train_size:train_size + val_size]
    test_texts = texts[train_size + val_size:]

    train_labels = labels[:train_size]
    val_labels = labels[train_size:train_size + val_size]
    test_labels = labels[train_size + val_size:]

    return train_texts, val_texts, test_texts, train_labels, val_labels, test_labels

# Split the data for each hierarchy
train_texts_h1, val_texts_h1, test_texts_h1, train_labels_h1, val_labels_h1, test_labels_h1 = split_data(texts_h1, labels_h1, 0.7, 0.1, 0.2)
train_texts_h2, val_texts_h2, test_texts_h2, train_labels_h2, val_labels_h2, test_labels_h2 = split_data(texts_h2, labels_h2, 0.7, 0.1, 0.2)
train_texts_h3, val_texts_h3, test_texts_h3, train_labels_h3, val_labels_h3, test_labels_h3 = split_data(texts_h3, labels_h3, 0.7, 0.1, 0.2)

#print("Train Val Test Splitting:")
#print("H1:", len(train_texts_h1), len(val_texts_h1), len(test_texts_h1))
#print("H2:", len(train_texts_h2), len(val_texts_h2), len(test_texts_h2))
#print("H3:", len(train_texts_h3), len(val_texts_h3), len(test_texts_h3))

# Tokenizer
#tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Import and initialize the Bio_ClinicalBERT tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
bio_clinical_bert_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")


# Define the TextDataset class
class TextDataset(Dataset):
    def __init__(self, texts, labels_h1, labels_h2, labels_h3, tokenizer, max_length=512):
        self.texts = texts
        self.labels_h1 = labels_h1  # Now expecting a list of binary values for H1
        self.labels_h2 = np.array(labels_h2).reshape(-1, 1)  # Reshape labels_h2 to [batch_size, 1]
        self.labels_h3 = labels_h3  # Multi-label but binary values for H3
        self.tokenizer = tokenizer
        self.max_length = max_length

    # Rest of the code remains unchanged

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        inputs = self.tokenizer.encode_plus(
            text, None, add_special_tokens=True, max_length=self.max_length,
            padding='max_length', return_token_type_ids=True, truncation=True
        )
        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'labels_h1': torch.tensor(self.labels_h1[idx], dtype=torch.float).unsqueeze(-1),
            'labels_h2': torch.tensor(self.labels_h2[idx], dtype=torch.float),
            'labels_h3': torch.tensor(self.labels_h3[idx], dtype=torch.float)
        }

# Create datasets and dataloaders for H1, H2, and H3
dataset_h1 = TextDataset(train_texts_h1 + val_texts_h1 + test_texts_h1, train_labels_h1 + val_labels_h1 + test_labels_h1, train_labels_h2 + val_labels_h2 + test_labels_h2, train_labels_h3 + val_labels_h3 + test_labels_h3, tokenizer)
dataset_h2 = TextDataset(train_texts_h2 + val_texts_h2 + test_texts_h2, train_labels_h1 + val_labels_h1 + test_labels_h1, train_labels_h2 + val_labels_h2 + test_labels_h2, train_labels_h3 + val_labels_h3 + test_labels_h3, tokenizer)
dataset_h3 = TextDataset(train_texts_h3 + val_texts_h3 + test_texts_h3, train_labels_h1 + val_labels_h1 + test_labels_h1, train_labels_h2 + val_labels_h2 + test_labels_h2, train_labels_h3 + val_labels_h3 + test_labels_h3, tokenizer)

# Split the datasets into train, val, and test for H1, H2, and H3
train_size_h1 = len(train_texts_h1)
val_size_h1 = len(val_texts_h1)
train_size_h2 = len(train_texts_h2)
val_size_h2 = len(val_texts_h2)
train_size_h3 = len(train_texts_h3)
val_size_h3 = len(val_texts_h3)

train_dataset_h1, val_dataset_h1, test_dataset_h1 = random_split(dataset_h1, [train_size_h1, val_size_h1, len(test_texts_h1)])
train_dataset_h2, val_dataset_h2, test_dataset_h2 = random_split(dataset_h2, [train_size_h2, val_size_h2, len(test_texts_h2)])
train_dataset_h3, val_dataset_h3, test_dataset_h3 = random_split(dataset_h3, [train_size_h3, val_size_h3, len(test_texts_h3)])

# Create dataloaders for H1, H2, and H3
train_dataloader_h1 = DataLoader(train_dataset_h1, batch_size=8, shuffle=True)
val_dataloader_h1 = DataLoader(val_dataset_h1,batch_size=8, shuffle=False)  # You can set shuffle to True if you want to shuffle the validation data.
train_dataloader_h2 = DataLoader(train_dataset_h2, batch_size=8, shuffle=True)
val_dataloader_h2 = DataLoader(val_dataset_h2, batch_size=8, shuffle=False)

train_dataloader_h3 = DataLoader(train_dataset_h3, batch_size=8, shuffle=True)
val_dataloader_h3 = DataLoader(val_dataset_h3, batch_size=8, shuffle=False)

# Define the model architecture for H1, H2, and H3
class HierarchicalClassifier(nn.Module):
    def __init__(self, num_labels_h3):
        super(HierarchicalClassifier, self).__init__()
        self.bert = bio_clinical_bert_model  # Correctly assigned Bio_ClinicalBERT model
        self.dropout = nn.Dropout(0.1)
        self.fc_h1 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h2 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h3 = nn.Linear(self.bert.config.hidden_size, num_labels_h3)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)  # Correctly refer to self.bert here
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits_h1 = self.fc_h1(self.dropout(pooled_output))
        logits_h2 = self.fc_h2(self.dropout(pooled_output))
        logits_h3 = self.fc_h3(self.dropout(pooled_output))
        return logits_h1, logits_h2, logits_h3

# Initialize and move the model to the appropriate device (CPU/GPU)
model = HierarchicalClassifier(num_labels_h3=len(train_labels_h3[0]))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=1e-5)  # You can adjust the learning rate as needed

# Training loop for H1, H2, and H3
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0

    for batch in tqdm(dataloader, total=len(dataloader), desc="Training"):
        input_ids = batch['ids'].to(device)
        attention_mask = batch['mask'].to(device)
        labels_h1 = batch['labels_h1'].to(device)
        labels_h2 = batch['labels_h2'].to(device)
        labels_h3 = batch['labels_h3'].to(device)

        optimizer.zero_grad()

        logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
        loss_h1 = criterion(logits_h1, labels_h1)
        loss_h2 = criterion(logits_h2, labels_h2)
        loss_h3 = criterion(logits_h3, labels_h3)

        loss = loss_h1 + loss_h2 + loss_h3
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    return total_loss / len(dataloader)

# Validation loop for H1, H2, and H3
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Validation"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            labels_h1 = batch['labels_h1'].to(device)
            labels_h2 = batch['labels_h2'].to(device)
            labels_h3 = batch['labels_h3'].to(device)

            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3

# Training and evaluation for H1, H2, and H3
# Custom accuracy function
def custom_accuracy(y_true, y_pred):
    correct_labels = np.sum(np.equal(y_true, y_pred), axis=1)
    total_labels = y_true.shape[1]
    sample_accuracy = np.mean(correct_labels / total_labels)
    return sample_accuracy

def custom_precision_recall_f1(y_true, y_pred):
    # Calculate precision and recall for each sample
    sample_precisions = []
    sample_recalls = []
    for true, pred in zip(y_true, y_pred):
        true_positive = np.sum((true == 1) & (pred == 1))
        false_positive = np.sum((true == 0) & (pred == 1))
        false_negative = np.sum((true == 1) & (pred == 0))
        
        sample_precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
        sample_recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
        
        sample_precisions.append(sample_precision)
        sample_recalls.append(sample_recall)
    
    # Calculate average precision and recall across all samples
    avg_precision = np.mean(sample_precisions)
    avg_recall = np.mean(sample_recalls)
    
    # Calculate F1 score
    if avg_precision + avg_recall > 0:
        avg_f1 = 2 * (avg_precision * avg_recall) / (avg_precision + avg_recall)
    else:
        avg_f1 = 0
    
    return avg_precision, avg_recall, avg_f1


# Training and evaluation for H1, H2, and H3
num_epochs = 33 # You can adjust the number of epochs as needed

for epoch in range(num_epochs):
    # Training for H1
    train_loss_h1 = train(model, train_dataloader_h1, optimizer, criterion, device)
    # Training for H2
    train_loss_h2 = train(model, train_dataloader_h2, optimizer, criterion, device)
    # Training for H3
    train_loss_h3 = train(model, train_dataloader_h3, optimizer, criterion, device)

    # Validation for H1
    val_loss_h1, val_preds_h1, _, _, val_labels_h1, _, _ = evaluate(model, val_dataloader_h1, criterion, device)
    # Validation for H2
    val_loss_h2, _, val_preds_h2, _, _, val_labels_h2, _ = evaluate(model, val_dataloader_h2, criterion, device)
    # Validation for H3
    val_loss_h3, _, _, val_preds_h3, _, _, val_labels_h3 = evaluate(model, val_dataloader_h3, criterion, device)

    # Metrics calculation for H1
    threshold_h1 = 0.5
    val_preds_h1_binary = (np.array(val_preds_h1) > threshold_h1).astype(int)

    acc_h1 = accuracy_score(val_labels_h1, val_preds_h1_binary)
    precision_h1 = precision_score(val_labels_h1, val_preds_h1_binary, average='micro')
    recall_h1 = recall_score(val_labels_h1, val_preds_h1_binary, average='micro')
    f1_h1 = f1_score(val_labels_h1, val_preds_h1_binary, average='micro')


    # Metrics calculation for H2
    threshold_h2 = 0.5
    val_preds_h2_binary = (np.array(val_preds_h2) > threshold_h2).astype(int)

    acc_h2 = accuracy_score(val_labels_h2, val_preds_h2_binary)
    precision_h2 = precision_score(val_labels_h2, val_preds_h2_binary, average='micro')
    recall_h2 = recall_score(val_labels_h2, val_preds_h2_binary, average='micro')
    f1_h2 = f1_score(val_labels_h2, val_preds_h2_binary, average='micro')

# Convert predictions and labels to binary using the threshold

    # Metrics calculation for H3
    threshold_h3 = 0.5
    val_preds_h3_binary = (np.array(val_preds_h3) > threshold_h3).astype(int)


    # Custom accuracy calculation for H3
    acc_h3_custom = custom_accuracy(np.array(val_labels_h3), val_preds_h3_binary)
    # Calculate custom precision, recall, and F1 score for H3
    precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(val_labels_h3), val_preds_h3_binary)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"H1: Train Loss: {train_loss_h1:.4f}, Val Loss: {val_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
    print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

    print(f"H2: Train Loss: {train_loss_h2:.4f}, Val Loss: {val_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
    print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

    print(f"H3: Train Loss: {train_loss_h3:.4f}, Val Loss: {val_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
    print(f"Precision H3: {precision_h3},Recall H3:{recall_h3},F1 Score H3: {f1_score_h3}")

# Saving the model
torch.save(model.state_dict(), 'model.pth')

# Loading the model
#loaded_model = YourModelClass(*args, **kwargs)  # Instantiate your model
#loaded_model.load_state_dict(torch.load('model.pth'))
#loaded_model.eval()  # Set the model to evaluation mode



Training: 100%|██████████| 231/231 [02:55<00:00,  1.32it/s]
Training: 100%|██████████| 231/231 [02:55<00:00,  1.32it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 1/33
H1: Train Loss: 1.7278, Val Loss: 1.4918, Accuracy: 0.9049
Precision H1: 0.9049, Recall H1: 0.9049, F1 H1: 0.9049
H2: Train Loss: 1.5552, Val Loss: 1.4855, Accuracy: 0.6502
Precision H2: 0.6502, Recall H2: 0.6502, F1 H2: 0.6502
H3: Train Loss: 1.4929, Val Loss: 1.4087, Custom Accuracy: 0.6710
Precision H3: 0.6918975194640594,Recall H3:0.47677563247905447,F1 Score H3: 0.5645374447161476


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 2/33
H1: Train Loss: 1.4458, Val Loss: 1.3054, Accuracy: 0.9544
Precision H1: 0.9544, Recall H1: 0.9544, F1 H1: 0.9544
H2: Train Loss: 1.4025, Val Loss: 1.2854, Accuracy: 0.7110
Precision H2: 0.7110, Recall H2: 0.7110, F1 H2: 0.7110
H3: Train Loss: 1.3446, Val Loss: 1.3020, Custom Accuracy: 0.6783
Precision H3: 0.670903494477639,Recall H3:0.5278809276908136,F1 Score H3: 0.5908604624928947


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 3/33
H1: Train Loss: 1.2689, Val Loss: 1.0838, Accuracy: 0.9544
Precision H1: 0.9544, Recall H1: 0.9544, F1 H1: 0.9544
H2: Train Loss: 1.2064, Val Loss: 1.0860, Accuracy: 0.8023
Precision H2: 0.8023, Recall H2: 0.8023, F1 H2: 0.8023
H3: Train Loss: 1.1050, Val Loss: 1.0887, Custom Accuracy: 0.6841
Precision H3: 0.6736013036393264,Recall H3:0.5459284589512725,F1 Score H3: 0.6030818319057899


Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:51<00:00,  1.35it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]


Epoch 4/33
H1: Train Loss: 0.9985, Val Loss: 0.9008, Accuracy: 0.9658
Precision H1: 0.9658, Recall H1: 0.9658, F1 H1: 0.9658
H2: Train Loss: 0.9169, Val Loss: 0.8951, Accuracy: 0.9163
Precision H2: 0.9163, Recall H2: 0.9163, F1 H2: 0.9163
H3: Train Loss: 0.8380, Val Loss: 0.9310, Custom Accuracy: 0.6964
Precision H3: 0.6688786287645603,Recall H3:0.5814686791873103,F1 Score H3: 0.6221183031799981


Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.61it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.66it/s]


Epoch 5/33
H1: Train Loss: 0.7772, Val Loss: 0.8705, Accuracy: 0.9734
Precision H1: 0.9734, Recall H1: 0.9734, F1 H1: 0.9734
H2: Train Loss: 0.7267, Val Loss: 0.8347, Accuracy: 0.9202
Precision H2: 0.9202, Recall H2: 0.9202, F1 H2: 0.9202
H3: Train Loss: 0.6948, Val Loss: 0.8305, Custom Accuracy: 0.6973
Precision H3: 0.6829757378236465,Recall H3:0.5723776329289637,F1 Score H3: 0.622804773972203


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 6/33
H1: Train Loss: 0.6759, Val Loss: 0.8431, Accuracy: 0.9696
Precision H1: 0.9696, Recall H1: 0.9696, F1 H1: 0.9696
H2: Train Loss: 0.6569, Val Loss: 0.8306, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.6271, Val Loss: 0.8038, Custom Accuracy: 0.7023
Precision H3: 0.6785940611986239,Recall H3:0.5951582637894425,F1 Score H3: 0.6341434757099792


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:53<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.50it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.49it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.60it/s]


Epoch 7/33
H1: Train Loss: 0.6085, Val Loss: 0.8375, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.6131, Val Loss: 0.8272, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.5802, Val Loss: 0.7626, Custom Accuracy: 0.7081
Precision H3: 0.6858832759973444,Recall H3:0.5858946608946608,F1 Score H3: 0.6319583596265945


Training: 100%|██████████| 231/231 [02:55<00:00,  1.32it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:53<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.50it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.46it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.51it/s]


Epoch 8/33
H1: Train Loss: 0.5895, Val Loss: 0.8473, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.5843, Val Loss: 0.7742, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.5665, Val Loss: 0.7439, Custom Accuracy: 0.7093
Precision H3: 0.6805329229283602,Recall H3:0.5976888658447594,F1 Score H3: 0.6364262516061162


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 9/33
H1: Train Loss: 0.5650, Val Loss: 0.8449, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.5955, Val Loss: 0.7852, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.5606, Val Loss: 0.7192, Custom Accuracy: 0.7119
Precision H3: 0.6790678375279137,Recall H3:0.6127704256031252,F1 Score H3: 0.6442179329892858


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 10/33
H1: Train Loss: 0.5547, Val Loss: 0.8737, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.5452, Val Loss: 0.8276, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.5347, Val Loss: 0.7304, Custom Accuracy: 0.7262
Precision H3: 0.6959940249864205,Recall H3:0.6238691916448571,F1 Score H3: 0.6579609527511079


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 11/33
H1: Train Loss: 0.5264, Val Loss: 0.8259, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.5277, Val Loss: 0.8089, Accuracy: 0.9392
Precision H2: 0.9392, Recall H2: 0.9392, F1 H2: 0.9392
H3: Train Loss: 0.5119, Val Loss: 0.7131, Custom Accuracy: 0.7330
Precision H3: 0.7139429328592827,Recall H3:0.6143729527759946,F1 Score H3: 0.6604260816538143


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 12/33
H1: Train Loss: 0.5445, Val Loss: 0.7648, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.5179, Val Loss: 0.7763, Accuracy: 0.9620
Precision H2: 0.9620, Recall H2: 0.9620, F1 H2: 0.9620
H3: Train Loss: 0.5062, Val Loss: 0.7132, Custom Accuracy: 0.7420
Precision H3: 0.6995788959667286,Recall H3:0.6692912284167037,F1 Score H3: 0.6840999892036802


Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 13/33
H1: Train Loss: 0.5001, Val Loss: 0.7557, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.4926, Val Loss: 0.8086, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.4730, Val Loss: 0.7116, Custom Accuracy: 0.7517
Precision H3: 0.7231046203479664,Recall H3:0.653031811872116,F1 Score H3: 0.686284163899545


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 14/33
H1: Train Loss: 0.4631, Val Loss: 0.7064, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.4680, Val Loss: 0.7360, Accuracy: 0.9582
Precision H2: 0.9582, Recall H2: 0.9582, F1 H2: 0.9582
H3: Train Loss: 0.4641, Val Loss: 0.6832, Custom Accuracy: 0.7666
Precision H3: 0.7321469721659835,Recall H3:0.6993432423090217,F1 Score H3: 0.7153692455370009


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.52it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.49it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.51it/s]


Epoch 15/33
H1: Train Loss: 0.4630, Val Loss: 0.6784, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.4456, Val Loss: 0.7307, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.4215, Val Loss: 0.6487, Custom Accuracy: 0.7809
Precision H3: 0.7532799477666398,Recall H3:0.7077111418366171,F1 Score H3: 0.7297848915715489


Training: 100%|██████████| 231/231 [02:53<00:00,  1.33it/s]
Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 16/33
H1: Train Loss: 0.4137, Val Loss: 0.6584, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.4083, Val Loss: 0.6862, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.3924, Val Loss: 0.6396, Custom Accuracy: 0.7970
Precision H3: 0.7897817117398866,Recall H3:0.7024648988527316,F1 Score H3: 0.7435686920847091


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.61it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 17/33
H1: Train Loss: 0.3870, Val Loss: 0.6216, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.3767, Val Loss: 0.6666, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.3593, Val Loss: 0.5928, Custom Accuracy: 0.8116
Precision H3: 0.780191513176304,Recall H3:0.746378368146429,F1 Score H3: 0.7629104642647077


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:53<00:00,  1.33it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 18/33
H1: Train Loss: 0.3561, Val Loss: 0.6134, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.3550, Val Loss: 0.6512, Accuracy: 0.9392
Precision H2: 0.9392, Recall H2: 0.9392, F1 H2: 0.9392
H3: Train Loss: 0.3330, Val Loss: 0.5705, Custom Accuracy: 0.8371
Precision H3: 0.8080114013574089,Recall H3:0.7932325701337107,F1 Score H3: 0.8005537844420023


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.61it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 19/33
H1: Train Loss: 0.3347, Val Loss: 0.5878, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.3111, Val Loss: 0.6359, Accuracy: 0.9430
Precision H2: 0.9430, Recall H2: 0.9430, F1 H2: 0.9430
H3: Train Loss: 0.2981, Val Loss: 0.5492, Custom Accuracy: 0.8602
Precision H3: 0.8557256431781146,Recall H3:0.7969536483795038,F1 Score H3: 0.82529463136184


Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:54<00:00,  1.32it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.61it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 20/33
H1: Train Loss: 0.2978, Val Loss: 0.5630, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.2803, Val Loss: 0.6125, Accuracy: 0.9430
Precision H2: 0.9430, Recall H2: 0.9430, F1 H2: 0.9430
H3: Train Loss: 0.2626, Val Loss: 0.5096, Custom Accuracy: 0.8736
Precision H3: 0.8551720354001724,Recall H3:0.8440374686572405,F1 Score H3: 0.8495682707777904


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 21/33
H1: Train Loss: 0.2529, Val Loss: 0.5678, Accuracy: 0.9886
Precision H1: 0.9886, Recall H1: 0.9886, F1 H1: 0.9886
H2: Train Loss: 0.2515, Val Loss: 0.5590, Accuracy: 0.9620
Precision H2: 0.9620, Recall H2: 0.9620, F1 H2: 0.9620
H3: Train Loss: 0.2390, Val Loss: 0.4543, Custom Accuracy: 0.8906
Precision H3: 0.8751914857428166,Recall H3:0.8703053072824936,F1 Score H3: 0.8727415575599737


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 22/33
H1: Train Loss: 0.2479, Val Loss: 0.5034, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.2279, Val Loss: 0.5317, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.2070, Val Loss: 0.4453, Custom Accuracy: 0.9038
Precision H3: 0.8888537740248768,Recall H3:0.8885862975216587,F1 Score H3: 0.8887200156477848


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 23/33
H1: Train Loss: 0.2027, Val Loss: 0.5149, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.1927, Val Loss: 0.5490, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.1804, Val Loss: 0.4412, Custom Accuracy: 0.9172
Precision H3: 0.8978659490066334,Recall H3:0.9148748484299816,F1 Score H3: 0.90629060168941


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.61it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 24/33
H1: Train Loss: 0.1716, Val Loss: 0.5036, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.1668, Val Loss: 0.5337, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.1549, Val Loss: 0.4295, Custom Accuracy: 0.9342
Precision H3: 0.9281863995742323,Recall H3:0.9127141869537306,F1 Score H3: 0.9203852736303505


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 25/33
H1: Train Loss: 0.1523, Val Loss: 0.5745, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.1702, Val Loss: 0.5857, Accuracy: 0.9468
Precision H2: 0.9468, Recall H2: 0.9468, F1 H2: 0.9468
H3: Train Loss: 0.1554, Val Loss: 0.5072, Custom Accuracy: 0.9365
Precision H3: 0.9333967046894804,Recall H3:0.9135046828963178,F1 Score H3: 0.9233435704419505


Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:51<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.61it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 26/33
H1: Train Loss: 0.1524, Val Loss: 0.5273, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.1405, Val Loss: 0.5068, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.1247, Val Loss: 0.4180, Custom Accuracy: 0.9471
Precision H3: 0.9426652456120136,Recall H3:0.925546749406065,F1 Score H3: 0.93402756879936


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.52it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.47it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.49it/s]


Epoch 27/33
H1: Train Loss: 0.1179, Val Loss: 0.5301, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.1195, Val Loss: 0.5133, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.1078, Val Loss: 0.3981, Custom Accuracy: 0.9497
Precision H3: 0.9357214184210383,Recall H3:0.9471370412435052,F1 Score H3: 0.9413946238203976


Training: 100%|██████████| 231/231 [02:55<00:00,  1.31it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.61it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 28/33
H1: Train Loss: 0.0999, Val Loss: 0.5072, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.1004, Val Loss: 0.5111, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.0928, Val Loss: 0.4071, Custom Accuracy: 0.9564
Precision H3: 0.9523221075502445,Recall H3:0.9429109124926617,F1 Score H3: 0.9475931433453042


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 29/33
H1: Train Loss: 0.0900, Val Loss: 0.5128, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.0888, Val Loss: 0.5101, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.0812, Val Loss: 0.4231, Custom Accuracy: 0.9573
Precision H3: 0.9524367795280344,Recall H3:0.9446536247867047,F1 Score H3: 0.9485292362609659


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.65it/s]


Epoch 30/33
H1: Train Loss: 0.0825, Val Loss: 0.5081, Accuracy: 0.9848
Precision H1: 0.9848, Recall H1: 0.9848, F1 H1: 0.9848
H2: Train Loss: 0.0795, Val Loss: 0.4715, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.0750, Val Loss: 0.3719, Custom Accuracy: 0.9591
Precision H3: 0.9526585792745489,Recall H3:0.9470993201981796,F1 Score H3: 0.949870815710053


Training: 100%|██████████| 231/231 [02:53<00:00,  1.33it/s]
Training: 100%|██████████| 231/231 [02:53<00:00,  1.33it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]


Epoch 31/33
H1: Train Loss: 0.0817, Val Loss: 0.5233, Accuracy: 0.9772
Precision H1: 0.9772, Recall H1: 0.9772, F1 H1: 0.9772
H2: Train Loss: 0.0753, Val Loss: 0.4994, Accuracy: 0.9544
Precision H2: 0.9544, Recall H2: 0.9544, F1 H2: 0.9544
H3: Train Loss: 0.0687, Val Loss: 0.4210, Custom Accuracy: 0.9596
Precision H3: 0.9537298569617961,Recall H3:0.9475656894858417,F1 Score H3: 0.9506377808371164


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.64it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.61it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]


Epoch 32/33
H1: Train Loss: 0.0610, Val Loss: 0.5096, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.0601, Val Loss: 0.5250, Accuracy: 0.9620
Precision H2: 0.9620, Recall H2: 0.9620, F1 H2: 0.9620
H3: Train Loss: 0.0607, Val Loss: 0.4810, Custom Accuracy: 0.9608
Precision H3: 0.957880680789426,Recall H3:0.9454711427144888,F1 Score H3: 0.9516354576873616


Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Training: 100%|██████████| 231/231 [02:52<00:00,  1.34it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.62it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.61it/s]
Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]


Epoch 33/33
H1: Train Loss: 0.0562, Val Loss: 0.5343, Accuracy: 0.9810
Precision H1: 0.9810, Recall H1: 0.9810, F1 H1: 0.9810
H2: Train Loss: 0.0671, Val Loss: 0.5253, Accuracy: 0.9506
Precision H2: 0.9506, Recall H2: 0.9506, F1 H2: 0.9506
H3: Train Loss: 0.0580, Val Loss: 0.4513, Custom Accuracy: 0.9617
Precision H3: 0.957876154263987,Recall H3:0.9512155778315474,F1 Score H3: 0.9545342470989061


In [4]:
# Create test datasets for H1, H2, and H3
test_dataset_h1 = TextDataset(test_texts_h1, test_labels_h1, test_labels_h2, test_labels_h3, tokenizer)
test_dataset_h2 = TextDataset(test_texts_h2, test_labels_h1, test_labels_h2, test_labels_h3, tokenizer)
test_dataset_h3 = TextDataset(test_texts_h3, test_labels_h1, test_labels_h2, test_labels_h3, tokenizer)

# Create test dataloaders for H1, H2, and H3
test_dataloader_h1 = DataLoader(test_dataset_h1, batch_size=8, shuffle=False)
test_dataloader_h2 = DataLoader(test_dataset_h2, batch_size=8, shuffle=False)
test_dataloader_h3 = DataLoader(test_dataset_h3, batch_size=8, shuffle=False)
# Testing loop for H1, H2, and H3
def test(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Testing"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            labels_h1 = batch['labels_h1'].to(device)  # No need to reshape labels_h1
            labels_h2 = batch['labels_h2'].to(device)  # No need to reshape labels_h2

            #labels_h1 = batch['labels_h1'].unsqueeze(1).to(device)  # Reshape labels_h1
            #labels_h2 = batch['labels_h2'].unsqueeze(1).to(device)  # Reshape labels_h2
            #labels_h3 = batch['labels_h3'].unsqueeze(1).to(device)  # Reshape labels_h3
            labels_h3 = batch['labels_h3'].to(device)  # No need to reshape labels_h3
            print("Labels H1 shape:", labels_h1.shape)
            print("Labels H2 shape:", labels_h2.shape)
            print("Labels H3 shape:", labels_h3.shape)

            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            
            print("Logits H1 shape:", logits_h1.shape)
            print("Logits H2 shape:", logits_h2.shape)
            print("Logits H3 shape:", logits_h3.shape)
            
            
            
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3

# Testing for H1
test_loss_h1, test_preds_h1, _, _, test_labels_h1, _, _ = test(model, test_dataloader_h1, criterion, device)
# Testing for H2
test_loss_h2, _, test_preds_h2, _, _, test_labels_h2, _ = test(model, test_dataloader_h2, criterion, device)
# Testing for H3
test_loss_h3, _, _, test_preds_h3, _, _, test_labels_h3 = test(model, test_dataloader_h3, criterion, device)

# Metrics calculation for H1
threshold_h1 = 0.5
test_preds_h1_binary = (np.array(test_preds_h1) > threshold_h1).astype(int)

acc_h1 = accuracy_score(test_labels_h1, test_preds_h1_binary)
precision_h1 = precision_score(test_labels_h1, test_preds_h1_binary, average='micro')
recall_h1 = recall_score(test_labels_h1, test_preds_h1_binary, average='micro')
f1_h1 = f1_score(test_labels_h1, test_preds_h1_binary, average='micro')

# Metrics calculation for H2
threshold_h2 = 0.5
test_preds_h2_binary = (np.array(test_preds_h2) > threshold_h2).astype(int)

acc_h2 = accuracy_score(test_labels_h2, test_preds_h2_binary)
precision_h2 = precision_score(test_labels_h2, test_preds_h2_binary, average='micro')
recall_h2 = recall_score(test_labels_h2, test_preds_h2_binary, average='micro')
f1_h2 = f1_score(test_labels_h2, test_preds_h2_binary, average='micro')

# Metrics calculation for H3
threshold_h3 = 0.5
test_preds_h3_binary = (np.array(test_preds_h3) > threshold_h3).astype(int)

# Custom accuracy calculation for H3
acc_h3_custom = custom_accuracy(np.array(test_labels_h3), test_preds_h3_binary)
# Calculate custom precision, recall, and F1 score for H3
precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(test_labels_h3), test_preds_h3_binary)

print("Testing Results:")
print(f"H1: Test Loss: {test_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

print(f"H2: Test Loss: {test_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

print(f"H3: Test Loss: {test_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
print(f"Precision H3: {precision_h3}, Recall H3:{recall_h3}, F1 Score H3: {f1_score_h3}")




Testing:   0%|          | 0/66 [00:00<?, ?it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   2%|▏         | 1/66 [00:00<00:20,  3.17it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   3%|▎         | 2/66 [00:00<00:18,  3.47it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   5%|▍         | 3/66 [00:00<00:17,  3.53it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   6%|▌         | 4/66 [00:01<00:17,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   8%|▊         | 5/66 [00:01<00:16,  3.64it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   9%|▉         | 6/66 [00:01<00:17,  3.49it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  11%|█         | 7/66 [00:01<00:16,  3.53it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  12%|█▏        | 8/66 [00:02<00:16,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  14%|█▎        | 9/66 [00:02<00:15,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  15%|█▌        | 10/66 [00:02<00:15,  3.63it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  17%|█▋        | 11/66 [00:03<00:15,  3.62it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  18%|█▊        | 12/66 [00:03<00:14,  3.60it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  20%|█▉        | 13/66 [00:03<00:14,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  21%|██        | 14/66 [00:03<00:14,  3.62it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  23%|██▎       | 15/66 [00:04<00:14,  3.60it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  24%|██▍       | 16/66 [00:04<00:13,  3.65it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  26%|██▌       | 17/66 [00:04<00:13,  3.65it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  27%|██▋       | 18/66 [00:05<00:13,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  29%|██▉       | 19/66 [00:05<00:12,  3.63it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  30%|███       | 20/66 [00:05<00:12,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  32%|███▏      | 21/66 [00:05<00:12,  3.49it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  33%|███▎      | 22/66 [00:06<00:12,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  35%|███▍      | 23/66 [00:06<00:11,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  36%|███▋      | 24/66 [00:06<00:11,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  38%|███▊      | 25/66 [00:06<00:11,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  39%|███▉      | 26/66 [00:07<00:11,  3.63it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  41%|████      | 27/66 [00:07<00:10,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  42%|████▏     | 28/66 [00:07<00:10,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  44%|████▍     | 29/66 [00:08<00:10,  3.47it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  45%|████▌     | 30/66 [00:08<00:10,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  47%|████▋     | 31/66 [00:08<00:09,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  48%|████▊     | 32/66 [00:08<00:09,  3.53it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  50%|█████     | 33/66 [00:09<00:09,  3.55it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  52%|█████▏    | 34/66 [00:09<00:09,  3.54it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  53%|█████▎    | 35/66 [00:09<00:08,  3.54it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  55%|█████▍    | 36/66 [00:10<00:08,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  56%|█████▌    | 37/66 [00:10<00:08,  3.49it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  58%|█████▊    | 38/66 [00:10<00:07,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  59%|█████▉    | 39/66 [00:10<00:07,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  61%|██████    | 40/66 [00:11<00:07,  3.54it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  62%|██████▏   | 41/66 [00:11<00:07,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  64%|██████▎   | 42/66 [00:11<00:06,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  65%|██████▌   | 43/66 [00:12<00:06,  3.65it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  67%|██████▋   | 44/66 [00:12<00:05,  3.70it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  68%|██████▊   | 45/66 [00:12<00:05,  3.63it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  70%|██████▉   | 46/66 [00:12<00:05,  3.64it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  71%|███████   | 47/66 [00:13<00:05,  3.60it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  73%|███████▎  | 48/66 [00:13<00:05,  3.58it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  74%|███████▍  | 49/66 [00:13<00:04,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  76%|███████▌  | 50/66 [00:14<00:04,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  77%|███████▋  | 51/66 [00:14<00:04,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  79%|███████▉  | 52/66 [00:14<00:03,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  80%|████████  | 53/66 [00:14<00:03,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  82%|████████▏ | 54/66 [00:15<00:03,  3.55it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  83%|████████▎ | 55/66 [00:15<00:03,  3.54it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  85%|████████▍ | 56/66 [00:15<00:02,  3.55it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  86%|████████▋ | 57/66 [00:15<00:02,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  88%|████████▊ | 58/66 [00:16<00:02,  3.58it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  89%|████████▉ | 59/66 [00:16<00:01,  3.65it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  91%|█████████ | 60/66 [00:16<00:01,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  92%|█████████▏| 61/66 [00:17<00:01,  3.62it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  94%|█████████▍| 62/66 [00:17<00:01,  3.64it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  95%|█████████▌| 63/66 [00:17<00:00,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  97%|█████████▋| 64/66 [00:17<00:00,  3.62it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  98%|█████████▊| 65/66 [00:18<00:00,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing: 100%|██████████| 66/66 [00:18<00:00,  3.57it/s]
Testing:   0%|          | 0/66 [00:00<?, ?it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   2%|▏         | 1/66 [00:00<00:18,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   3%|▎         | 2/66 [00:00<00:17,  3.66it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   5%|▍         | 3/66 [00:00<00:17,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   6%|▌         | 4/66 [00:01<00:17,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   8%|▊         | 5/66 [00:01<00:16,  3.63it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   9%|▉         | 6/66 [00:01<00:17,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  11%|█         | 7/66 [00:01<00:16,  3.49it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  12%|█▏        | 8/66 [00:02<00:16,  3.53it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  14%|█▎        | 9/66 [00:02<00:15,  3.58it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  15%|█▌        | 10/66 [00:02<00:15,  3.60it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  17%|█▋        | 11/66 [00:03<00:15,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  18%|█▊        | 12/66 [00:03<00:15,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  20%|█▉        | 13/66 [00:03<00:14,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  21%|██        | 14/66 [00:03<00:14,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  23%|██▎       | 15/66 [00:04<00:14,  3.58it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  24%|██▍       | 16/66 [00:04<00:13,  3.62it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  26%|██▌       | 17/66 [00:04<00:13,  3.64it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  27%|██▋       | 18/66 [00:05<00:13,  3.60it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  29%|██▉       | 19/66 [00:05<00:12,  3.62it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  30%|███       | 20/66 [00:05<00:12,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  32%|███▏      | 21/66 [00:05<00:12,  3.48it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  33%|███▎      | 22/66 [00:06<00:12,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  35%|███▍      | 23/66 [00:06<00:11,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  36%|███▋      | 24/66 [00:06<00:11,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  38%|███▊      | 25/66 [00:06<00:11,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  39%|███▉      | 26/66 [00:07<00:10,  3.64it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  41%|████      | 27/66 [00:07<00:10,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  42%|████▏     | 28/66 [00:07<00:10,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  44%|████▍     | 29/66 [00:08<00:10,  3.47it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  45%|████▌     | 30/66 [00:08<00:10,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  47%|████▋     | 31/66 [00:08<00:10,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  48%|████▊     | 32/66 [00:08<00:09,  3.49it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  50%|█████     | 33/66 [00:09<00:09,  3.50it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  52%|█████▏    | 34/66 [00:09<00:09,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  53%|█████▎    | 35/66 [00:09<00:08,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  55%|█████▍    | 36/66 [00:10<00:08,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  56%|█████▌    | 37/66 [00:10<00:08,  3.45it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  58%|█████▊    | 38/66 [00:10<00:08,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  59%|█████▉    | 39/66 [00:10<00:07,  3.49it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  61%|██████    | 40/66 [00:11<00:07,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  62%|██████▏   | 41/66 [00:11<00:07,  3.53it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  64%|██████▎   | 42/66 [00:11<00:06,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  65%|██████▌   | 43/66 [00:12<00:06,  3.63it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  67%|██████▋   | 44/66 [00:12<00:05,  3.69it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  68%|██████▊   | 45/66 [00:12<00:05,  3.62it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  70%|██████▉   | 46/66 [00:12<00:05,  3.64it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  71%|███████   | 47/66 [00:13<00:05,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  73%|███████▎  | 48/66 [00:13<00:05,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  74%|███████▍  | 49/66 [00:13<00:04,  3.53it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  76%|███████▌  | 50/66 [00:14<00:04,  3.48it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  77%|███████▋  | 51/66 [00:14<00:04,  3.39it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  79%|███████▉  | 52/66 [00:14<00:04,  3.48it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  80%|████████  | 53/66 [00:14<00:03,  3.48it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  82%|████████▏ | 54/66 [00:15<00:03,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  83%|████████▎ | 55/66 [00:15<00:03,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  85%|████████▍ | 56/66 [00:15<00:02,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  86%|████████▋ | 57/66 [00:16<00:02,  3.53it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  88%|████████▊ | 58/66 [00:16<00:02,  3.54it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  89%|████████▉ | 59/66 [00:16<00:01,  3.62it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  91%|█████████ | 60/66 [00:16<00:01,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  92%|█████████▏| 61/66 [00:17<00:01,  3.58it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  94%|█████████▍| 62/66 [00:17<00:01,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  95%|█████████▌| 63/66 [00:17<00:00,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  97%|█████████▋| 64/66 [00:18<00:00,  3.60it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  98%|█████████▊| 65/66 [00:18<00:00,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing: 100%|██████████| 66/66 [00:18<00:00,  3.55it/s]
Testing:   0%|          | 0/66 [00:00<?, ?it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   2%|▏         | 1/66 [00:00<00:18,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   3%|▎         | 2/66 [00:00<00:17,  3.66it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   5%|▍         | 3/66 [00:00<00:17,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   6%|▌         | 4/66 [00:01<00:16,  3.65it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   8%|▊         | 5/66 [00:01<00:16,  3.66it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:   9%|▉         | 6/66 [00:01<00:17,  3.48it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  11%|█         | 7/66 [00:01<00:16,  3.50it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  12%|█▏        | 8/66 [00:02<00:16,  3.54it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  14%|█▎        | 9/66 [00:02<00:15,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  15%|█▌        | 10/66 [00:02<00:15,  3.58it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  17%|█▋        | 11/66 [00:03<00:15,  3.58it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  18%|█▊        | 12/66 [00:03<00:15,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  20%|█▉        | 13/66 [00:03<00:14,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  21%|██        | 14/66 [00:03<00:14,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  23%|██▎       | 15/66 [00:04<00:14,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  24%|██▍       | 16/66 [00:04<00:13,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  26%|██▌       | 17/66 [00:04<00:13,  3.62it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  27%|██▋       | 18/66 [00:05<00:13,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  29%|██▉       | 19/66 [00:05<00:13,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  30%|███       | 20/66 [00:05<00:13,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  32%|███▏      | 21/66 [00:05<00:13,  3.44it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  33%|███▎      | 22/66 [00:06<00:12,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  35%|███▍      | 23/66 [00:06<00:12,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  36%|███▋      | 24/66 [00:06<00:11,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  38%|███▊      | 25/66 [00:07<00:11,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  39%|███▉      | 26/66 [00:07<00:11,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  41%|████      | 27/66 [00:07<00:10,  3.55it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  42%|████▏     | 28/66 [00:07<00:10,  3.55it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  44%|████▍     | 29/66 [00:08<00:10,  3.47it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  45%|████▌     | 30/66 [00:08<00:10,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  47%|████▋     | 31/66 [00:08<00:10,  3.47it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  48%|████▊     | 32/66 [00:09<00:09,  3.50it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  50%|█████     | 33/66 [00:09<00:09,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  52%|█████▏    | 34/66 [00:09<00:09,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  53%|█████▎    | 35/66 [00:09<00:08,  3.51it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  55%|█████▍    | 36/66 [00:10<00:08,  3.55it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  56%|█████▌    | 37/66 [00:10<00:08,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  58%|█████▊    | 38/66 [00:10<00:08,  3.48it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  59%|█████▉    | 39/66 [00:11<00:07,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  61%|██████    | 40/66 [00:11<00:07,  3.53it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  62%|██████▏   | 41/66 [00:11<00:07,  3.54it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  64%|██████▎   | 42/66 [00:11<00:06,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  65%|██████▌   | 43/66 [00:12<00:06,  3.63it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  67%|██████▋   | 44/66 [00:12<00:06,  3.64it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  68%|██████▊   | 45/66 [00:12<00:05,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  70%|██████▉   | 46/66 [00:12<00:05,  3.60it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  71%|███████   | 47/66 [00:13<00:05,  3.55it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  73%|███████▎  | 48/66 [00:13<00:05,  3.54it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  74%|███████▍  | 49/66 [00:13<00:04,  3.52it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  76%|███████▌  | 50/66 [00:14<00:04,  3.46it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  77%|███████▋  | 51/66 [00:14<00:04,  3.40it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  79%|███████▉  | 52/66 [00:14<00:04,  3.49it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  80%|████████  | 53/66 [00:14<00:03,  3.48it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  82%|████████▏ | 54/66 [00:15<00:03,  3.53it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  83%|████████▎ | 55/66 [00:15<00:03,  3.53it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  85%|████████▍ | 56/66 [00:15<00:02,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  86%|████████▋ | 57/66 [00:16<00:02,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  88%|████████▊ | 58/66 [00:16<00:02,  3.55it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  89%|████████▉ | 59/66 [00:16<00:01,  3.61it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  91%|█████████ | 60/66 [00:16<00:01,  3.50it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  92%|█████████▏| 61/66 [00:17<00:01,  3.57it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  94%|█████████▍| 62/66 [00:17<00:01,  3.62it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  95%|█████████▌| 63/66 [00:17<00:00,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  97%|█████████▋| 64/66 [00:18<00:00,  3.59it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing:  98%|█████████▊| 65/66 [00:18<00:00,  3.56it/s]

Labels H1 shape: torch.Size([8, 1])
Labels H2 shape: torch.Size([8, 1])
Labels H3 shape: torch.Size([8, 13])
Logits H1 shape: torch.Size([8, 1])
Logits H2 shape: torch.Size([8, 1])
Logits H3 shape: torch.Size([8, 13])


Testing: 100%|██████████| 66/66 [00:18<00:00,  3.55it/s]

Testing Results:
H1: Test Loss: 0.0697, Accuracy: 0.9981
Precision H1: 0.9981, Recall H1: 0.9981, F1 H1: 0.9981
H2: Test Loss: 0.0697, Accuracy: 0.9962
Precision H2: 0.9962, Recall H2: 0.9962, F1 H2: 0.9962
H3: Test Loss: 0.0697, Custom Accuracy: 0.9914
Precision H3: 0.9921589405964406, Recall H3:0.9905498436748438, F1 Score H3: 0.9913537391923244





In [None]:
#unseen data

In [2]:
# Assuming the new unseen datasets are similar to 'model1_2636.csv', 'model2_2636.csv', 'model3_2636.csv'
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import RobertaTokenizer, RobertaModel, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import numpy as np
import torch

import pandas as pd

# Loading and preprocessing new_data_h1
new_data_h1 = pd.read_csv("112_H_100_O_NEWMODEL1.CSV")
new_data_h1['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h1 = new_data_h1['text'].tolist()
new_labels_h1 = new_data_h1.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h2
new_data_h2 = pd.read_csv("112_H_100_O_NEWMODEL2.CSV")
new_data_h2['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h2 = new_data_h2['text'].tolist()
new_labels_h2 = new_data_h2.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h3 for multi-label classification
new_data_h3 = pd.read_csv("112_H_100_O_NEWMODEL3.CSV")
new_data_h3['text'].fillna("", inplace=True)  # Replace NaN values with empty strings

# Replace NaN values in label columns with 0 for new_data_h3
# Assuming all label columns are binary, fill NaN values with 0
for label_col in new_data_h3.columns[1:]:  # Skip the text column, only fill NaN in label columns
    new_data_h3[label_col].fillna(0, inplace=True)

new_texts_h3 = new_data_h3['text'].tolist()
new_labels_h3 = new_data_h3.iloc[:, 1:].values.tolist()  # Extract labels after filling NaN values



from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup

# Import and initialize the Bio_ClinicalBERT tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
bio_clinical_bert_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")


# Define the TextDataset class
class TextDataset(Dataset):
    def __init__(self, texts, labels_h1, labels_h2, labels_h3, tokenizer, max_length=512):
        self.texts = texts
        self.labels_h1 = labels_h1  # Now expecting a list of binary values for H1
        self.labels_h2 = np.array(labels_h2).reshape(-1, 1)  # Reshape labels_h2 to [batch_size, 1]
        self.labels_h3 = labels_h3  # Multi-label but binary values for H3
        self.tokenizer = tokenizer
        self.max_length = max_length

    # Rest of the code remains unchanged

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        inputs = self.tokenizer.encode_plus(
            text, None, add_special_tokens=True, max_length=self.max_length,
            padding='max_length', return_token_type_ids=True, truncation=True
        )
        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'labels_h1': torch.tensor(self.labels_h1[idx], dtype=torch.float).unsqueeze(-1),
            'labels_h2': torch.tensor(self.labels_h2[idx], dtype=torch.float),
            'labels_h3': torch.tensor(self.labels_h3[idx], dtype=torch.float)
        }


# Define the model architecture for H1, H2, and H3
class HierarchicalClassifier(nn.Module):
    def __init__(self, num_labels_h3):
        super(HierarchicalClassifier, self).__init__()
        self.bert = bio_clinical_bert_model  # Correctly assigned Bio_ClinicalBERT model
        self.dropout = nn.Dropout(0.1)
        self.fc_h1 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h2 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h3 = nn.Linear(self.bert.config.hidden_size, num_labels_h3)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)  # Correctly refer to self.bert here
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits_h1 = self.fc_h1(self.dropout(pooled_output))
        logits_h2 = self.fc_h2(self.dropout(pooled_output))
        logits_h3 = self.fc_h3(self.dropout(pooled_output))
        return logits_h1, logits_h2, logits_h3




# Initialize and move the model to the appropriate device (CPU/GPU)
#model = HierarchicalClassifier(num_labels_h3=len(train_labels_h3[0]))

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()

#from transformers import RobertaTokenizer

# Initialize the tokenizer
#tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Preprocess the texts and labels in the same way as the original datasets
# Assuming the preprocessing steps are included in the TextDataset class
# Create TextDataset instances for the new data
new_dataset_h1 = TextDataset(new_texts_h1, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h2 = TextDataset(new_texts_h2, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h3 = TextDataset(new_texts_h3, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)

# Create DataLoader instances for the new datasets
#new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False)
#new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False)
#new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False)


new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False, drop_last=True)

# Load the saved model
model_path = 'model_bio.pth'  # Path to your saved model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HierarchicalClassifier(num_labels_h3=len(new_labels_h3[0]))  # Make sure to provide the correct number of labels for H3
model.load_state_dict(torch.load(model_path))
model.eval()  # Set the model to evaluation mode
model.to(device)  # Move the model to the appropriate device
def test(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Testing"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            
            # Ensuring labels are correctly shaped and moved to the device
            labels_h1 = batch['labels_h1'].to(device)  # Shape: [batch_size, 1]
            labels_h2 = batch['labels_h2'].to(device)  # Shape: [batch_size, 1]
            labels_h3 = batch['labels_h3'].to(device)  # Shape: [batch_size, num_labels_h3]
            
            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            
            # Compute loss for each hierarchy
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            # Convert logits to probabilities for binary classification/multi-label classification
            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            # Store predictions and labels for metric calculation
            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3







# Testing the model on the new unseen data
test_loss_h1, test_preds_h1, _, _, test_labels_h1, _, _ = test(model, new_dataloader_h1, criterion, device)
test_loss_h2, _, test_preds_h2, _, _, test_labels_h2, _ = test(model, new_dataloader_h2, criterion, device)
test_loss_h3, _, _, test_preds_h3, _, _, test_labels_h3 = test(model, new_dataloader_h3, criterion, device)

# Calculate and print the metrics for H1, H2, and H3 using the same approach as before
# Use the appropriate threshold and custom functions for metrics calculation

# Metrics calculation for H1
threshold_h1 = 0.5
test_preds_h1_binary = (np.array(test_preds_h1) > threshold_h1).astype(int)

acc_h1 = accuracy_score(test_labels_h1, test_preds_h1_binary)
precision_h1 = precision_score(test_labels_h1, test_preds_h1_binary, average='micro')
recall_h1 = recall_score(test_labels_h1, test_preds_h1_binary, average='micro')
f1_h1 = f1_score(test_labels_h1, test_preds_h1_binary, average='micro')

# Metrics calculation for H2
threshold_h2 = 0.5
test_preds_h2_binary = (np.array(test_preds_h2) > threshold_h2).astype(int)

acc_h2 = accuracy_score(test_labels_h2, test_preds_h2_binary)
precision_h2 = precision_score(test_labels_h2, test_preds_h2_binary, average='micro')
recall_h2 = recall_score(test_labels_h2, test_preds_h2_binary, average='micro')
f1_h2 = f1_score(test_labels_h2, test_preds_h2_binary, average='micro')

# Metrics calculation for H3
threshold_h3 = 0.5
test_preds_h3_binary = (np.array(test_preds_h3) > threshold_h3).astype(int)

import numpy as np

def custom_accuracy(y_true, y_pred):
    correct_predictions = np.equal(y_true, y_pred)
    sample_accuracy = np.sum(correct_predictions, axis=1) / y_true.shape[1]
    return np.mean(sample_accuracy)

def custom_precision_recall_f1(y_true, y_pred):
    true_positives = np.sum((y_true == 1) & (y_pred == 1), axis=1)
    false_positives = np.sum((y_true == 0) & (y_pred == 1), axis=1)
    false_negatives = np.sum((y_true == 1) & (y_pred == 0), axis=1)

    precision = np.mean(true_positives / (true_positives + false_positives + 1e-8))
    recall = np.mean(true_positives / (true_positives + false_negatives + 1e-8))
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)

    return precision, recall, f1




# Custom accuracy calculation for H3
acc_h3_custom = custom_accuracy(np.array(test_labels_h3), test_preds_h3_binary)
# Calculate custom precision, recall, and F1 score for H3
precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(test_labels_h3), test_preds_h3_binary)

print("Testing Results:")
print(f"H1: Test Loss: {test_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

print(f"H2: Test Loss: {test_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

print(f"H3: Test Loss: {test_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
print(f"Precision H3: {precision_h3}, Recall H3:{recall_h3}, F1 Score H3: {f1_score_h3}")





Testing: 100%|██████████| 26/26 [00:08<00:00,  3.16it/s]
Testing: 100%|██████████| 26/26 [00:07<00:00,  3.42it/s]
Testing: 100%|██████████| 26/26 [00:07<00:00,  3.39it/s]

Testing Results:
H1: Test Loss: 0.6649, Accuracy: 0.9760
Precision H1: 0.9760, Recall H1: 0.9760, F1 H1: 0.9760
H2: Test Loss: 0.6649, Accuracy: 0.9856
Precision H2: 0.9856, Recall H2: 0.9856, F1 H2: 0.9856
H3: Test Loss: 0.6649, Custom Accuracy: 0.8964
Precision H3: 0.7931394973712151, Recall H3:0.9034455100413692, F1 Score H3: 0.8447066435011861





In [None]:
#Testing_Chatgpt_Eq_TEXT

In [1]:
# Assuming the new unseen datasets are similar to 'model1_2636.csv', 'model2_2636.csv', 'model3_2636.csv'
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import RobertaTokenizer, RobertaModel, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import numpy as np
import torch

import pandas as pd

# Loading and preprocessing new_data_h1
new_data_h1 = pd.read_csv("P3_Chatgpt_Eq_Model1_508.csv")
new_data_h1['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h1 = new_data_h1['text'].tolist()
new_labels_h1 = new_data_h1.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h2
new_data_h2 = pd.read_csv("P3_Chatgpt_Eq_Model2_508.csv")
new_data_h2['text'].fillna("", inplace=True)  # Replace NaN values with empty strings
new_texts_h2 = new_data_h2['text'].tolist()
new_labels_h2 = new_data_h2.iloc[:, 1].values.tolist()

# Loading and preprocessing new_data_h3 for multi-label classification
new_data_h3 = pd.read_csv("P3_Chatgpt_Eq_Model3_508.csv")
new_data_h3['text'].fillna("", inplace=True)  # Replace NaN values with empty strings

# Replace NaN values in label columns with 0 for new_data_h3
# Assuming all label columns are binary, fill NaN values with 0
for label_col in new_data_h3.columns[1:]:  # Skip the text column, only fill NaN in label columns
    new_data_h3[label_col].fillna(0, inplace=True)

new_texts_h3 = new_data_h3['text'].tolist()
new_labels_h3 = new_data_h3.iloc[:, 1:].values.tolist()  # Extract labels after filling NaN values



from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup

# Import and initialize the Bio_ClinicalBERT tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
bio_clinical_bert_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")


# Define the TextDataset class
class TextDataset(Dataset):
    def __init__(self, texts, labels_h1, labels_h2, labels_h3, tokenizer, max_length=512):
        self.texts = texts
        self.labels_h1 = labels_h1  # Now expecting a list of binary values for H1
        self.labels_h2 = np.array(labels_h2).reshape(-1, 1)  # Reshape labels_h2 to [batch_size, 1]
        self.labels_h3 = labels_h3  # Multi-label but binary values for H3
        self.tokenizer = tokenizer
        self.max_length = max_length

    # Rest of the code remains unchanged

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        inputs = self.tokenizer.encode_plus(
            text, None, add_special_tokens=True, max_length=self.max_length,
            padding='max_length', return_token_type_ids=True, truncation=True
        )
        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'labels_h1': torch.tensor(self.labels_h1[idx], dtype=torch.float).unsqueeze(-1),
            'labels_h2': torch.tensor(self.labels_h2[idx], dtype=torch.float),
            'labels_h3': torch.tensor(self.labels_h3[idx], dtype=torch.float)
        }


# Define the model architecture for H1, H2, and H3
class HierarchicalClassifier(nn.Module):
    def __init__(self, num_labels_h3):
        super(HierarchicalClassifier, self).__init__()
        self.bert = bio_clinical_bert_model  # Correctly assigned Bio_ClinicalBERT model
        self.dropout = nn.Dropout(0.1)
        self.fc_h1 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h2 = nn.Linear(self.bert.config.hidden_size, 1)
        self.fc_h3 = nn.Linear(self.bert.config.hidden_size, num_labels_h3)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)  # Correctly refer to self.bert here
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits_h1 = self.fc_h1(self.dropout(pooled_output))
        logits_h2 = self.fc_h2(self.dropout(pooled_output))
        logits_h3 = self.fc_h3(self.dropout(pooled_output))
        return logits_h1, logits_h2, logits_h3




# Initialize and move the model to the appropriate device (CPU/GPU)
#model = HierarchicalClassifier(num_labels_h3=len(train_labels_h3[0]))

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()

#from transformers import RobertaTokenizer

# Initialize the tokenizer
#tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Preprocess the texts and labels in the same way as the original datasets
# Assuming the preprocessing steps are included in the TextDataset class
# Create TextDataset instances for the new data
new_dataset_h1 = TextDataset(new_texts_h1, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h2 = TextDataset(new_texts_h2, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)
new_dataset_h3 = TextDataset(new_texts_h3, new_labels_h1, new_labels_h2, new_labels_h3, tokenizer)

# Create DataLoader instances for the new datasets
#new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False)
#new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False)
#new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False)


new_dataloader_h1 = DataLoader(new_dataset_h1, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h2 = DataLoader(new_dataset_h2, batch_size=8, shuffle=False, drop_last=True)
new_dataloader_h3 = DataLoader(new_dataset_h3, batch_size=8, shuffle=False, drop_last=True)

# Load the saved model
model_path = 'model_bio.pth'  # Path to your saved model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HierarchicalClassifier(num_labels_h3=len(new_labels_h3[0]))  # Make sure to provide the correct number of labels for H3
model.load_state_dict(torch.load(model_path))
model.eval()  # Set the model to evaluation mode
model.to(device)  # Move the model to the appropriate device
def test(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds_h1, all_preds_h2, all_preds_h3 = [], [], []
    all_labels_h1, all_labels_h2, all_labels_h3 = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Testing"):
            input_ids = batch['ids'].to(device)
            attention_mask = batch['mask'].to(device)
            
            # Ensuring labels are correctly shaped and moved to the device
            labels_h1 = batch['labels_h1'].to(device)  # Shape: [batch_size, 1]
            labels_h2 = batch['labels_h2'].to(device)  # Shape: [batch_size, 1]
            labels_h3 = batch['labels_h3'].to(device)  # Shape: [batch_size, num_labels_h3]
            
            logits_h1, logits_h2, logits_h3 = model(input_ids, attention_mask)
            
            # Compute loss for each hierarchy
            loss_h1 = criterion(logits_h1, labels_h1)
            loss_h2 = criterion(logits_h2, labels_h2)
            loss_h3 = criterion(logits_h3, labels_h3)

            loss = loss_h1 + loss_h2 + loss_h3
            total_loss += loss.item()

            # Convert logits to probabilities for binary classification/multi-label classification
            preds_h1 = torch.sigmoid(logits_h1)
            preds_h2 = torch.sigmoid(logits_h2)
            preds_h3 = torch.sigmoid(logits_h3)

            # Store predictions and labels for metric calculation
            all_preds_h1.extend(preds_h1.cpu().numpy())
            all_preds_h2.extend(preds_h2.cpu().numpy())
            all_preds_h3.extend(preds_h3.cpu().numpy())

            all_labels_h1.extend(labels_h1.cpu().numpy())
            all_labels_h2.extend(labels_h2.cpu().numpy())
            all_labels_h3.extend(labels_h3.cpu().numpy())

    return total_loss / len(dataloader), all_preds_h1, all_preds_h2, all_preds_h3, all_labels_h1, all_labels_h2, all_labels_h3







# Testing the model on the new unseen data
test_loss_h1, test_preds_h1, _, _, test_labels_h1, _, _ = test(model, new_dataloader_h1, criterion, device)
test_loss_h2, _, test_preds_h2, _, _, test_labels_h2, _ = test(model, new_dataloader_h2, criterion, device)
test_loss_h3, _, _, test_preds_h3, _, _, test_labels_h3 = test(model, new_dataloader_h3, criterion, device)

# Calculate and print the metrics for H1, H2, and H3 using the same approach as before
# Use the appropriate threshold and custom functions for metrics calculation

# Metrics calculation for H1
threshold_h1 = 0.5
test_preds_h1_binary = (np.array(test_preds_h1) > threshold_h1).astype(int)

acc_h1 = accuracy_score(test_labels_h1, test_preds_h1_binary)
precision_h1 = precision_score(test_labels_h1, test_preds_h1_binary, average='micro')
recall_h1 = recall_score(test_labels_h1, test_preds_h1_binary, average='micro')
f1_h1 = f1_score(test_labels_h1, test_preds_h1_binary, average='micro')

# Metrics calculation for H2
threshold_h2 = 0.5
test_preds_h2_binary = (np.array(test_preds_h2) > threshold_h2).astype(int)

acc_h2 = accuracy_score(test_labels_h2, test_preds_h2_binary)
precision_h2 = precision_score(test_labels_h2, test_preds_h2_binary, average='micro')
recall_h2 = recall_score(test_labels_h2, test_preds_h2_binary, average='micro')
f1_h2 = f1_score(test_labels_h2, test_preds_h2_binary, average='micro')

# Metrics calculation for H3
threshold_h3 = 0.5
test_preds_h3_binary = (np.array(test_preds_h3) > threshold_h3).astype(int)

import numpy as np

def custom_accuracy(y_true, y_pred):
    correct_predictions = np.equal(y_true, y_pred)
    sample_accuracy = np.sum(correct_predictions, axis=1) / y_true.shape[1]
    return np.mean(sample_accuracy)

def custom_precision_recall_f1(y_true, y_pred):
    true_positives = np.sum((y_true == 1) & (y_pred == 1), axis=1)
    false_positives = np.sum((y_true == 0) & (y_pred == 1), axis=1)
    false_negatives = np.sum((y_true == 1) & (y_pred == 0), axis=1)

    precision = np.mean(true_positives / (true_positives + false_positives + 1e-8))
    recall = np.mean(true_positives / (true_positives + false_negatives + 1e-8))
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)

    return precision, recall, f1




# Custom accuracy calculation for H3
acc_h3_custom = custom_accuracy(np.array(test_labels_h3), test_preds_h3_binary)
# Calculate custom precision, recall, and F1 score for H3
precision_h3, recall_h3, f1_score_h3 = custom_precision_recall_f1(np.array(test_labels_h3), test_preds_h3_binary)

print("Testing Results:")
print(f"H1: Test Loss: {test_loss_h1:.4f}, Accuracy: {acc_h1:.4f}")
print(f"Precision H1: {precision_h1:.4f}, Recall H1: {recall_h1:.4f}, F1 H1: {f1_h1:.4f}")

print(f"H2: Test Loss: {test_loss_h2:.4f}, Accuracy: {acc_h2:.4f}")
print(f"Precision H2: {precision_h2:.4f}, Recall H2: {recall_h2:.4f}, F1 H2: {f1_h2:.4f}")

print(f"H3: Test Loss: {test_loss_h3:.4f}, Custom Accuracy: {acc_h3_custom:.4f}")
print(f"Precision H3: {precision_h3}, Recall H3:{recall_h3}, F1 Score H3: {f1_score_h3}")





  return self.fget.__get__(instance, owner)()
    Found GPU1 Tesla K40c which is of cuda capability 3.5.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability supported by this library is 3.7.
    
    Found GPU2 Tesla K40c which is of cuda capability 3.5.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability supported by this library is 3.7.
    
Testing: 100%|██████████| 63/63 [00:35<00:00,  1.78it/s]
Testing: 100%|██████████| 63/63 [00:34<00:00,  1.82it/s]
Testing: 100%|██████████| 63/63 [00:34<00:00,  1.83it/s]

Testing Results:
H1: Test Loss: 2.7799, Accuracy: 0.8135
Precision H1: 0.8135, Recall H1: 0.8135, F1 H1: 0.8135
H2: Test Loss: 2.7933, Accuracy: 0.5397
Precision H2: 0.5397, Recall H2: 0.5397, F1 H2: 0.5397
H3: Test Loss: 2.7933, Custom Accuracy: 0.6606
Precision H3: 0.5327876972830001, Recall H3:0.5802125549192283, F1 Score H3: 0.5554897362889103



