In [None]:
# Mount Google Drive at /content/drive (Colab-only helper). The MNIST IDX files
# used below are read from ./drive/MyDrive/MNIST, i.e. from inside this mount
# when the working directory is /content.
from google.colab import drive
drive.mount('/content/drive')

ValueError: mount failed

In [None]:
pwd

#### Import MNIST

In [None]:
import numpy as np
import os
import struct

def read_idx(filename):
    """Read an IDX-format file (the container format of the MNIST files).

    The header is two zero bytes, one element-type code and one dimension
    count, followed by one big-endian uint32 size per dimension. The rest of
    the file is the array payload.

    Args:
        filename: Path of an uncompressed IDX file.

    Returns:
        A read-only ``np.ndarray`` with the shape given in the header and the
        dtype selected by the header's type code (``uint8`` for the MNIST
        image and label files). Multi-byte dtypes are big-endian.

    Raises:
        ValueError: If the first two bytes are not zero (for example a file
            that is still gzip-compressed), if the type code is unknown, or
            if the payload size does not match the header shape.
        struct.error: If the file is too short to hold the header.
    """
    # IDX element-type codes; multi-byte values are stored big-endian.
    idx_dtypes = {
        0x08: np.dtype(np.uint8),
        0x09: np.dtype(np.int8),
        0x0B: np.dtype('>i2'),
        0x0C: np.dtype('>i4'),
        0x0D: np.dtype('>f4'),
        0x0E: np.dtype('>f8'),
    }
    with open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        if zero != 0:
            raise ValueError(
                f'{filename} is not an IDX file: bad magic prefix 0x{zero:04x}')
        if data_type not in idx_dtypes:
            raise ValueError(
                f'{filename}: unsupported IDX data type code 0x{data_type:02x}')
        # All dimension sizes are read with a single unpack call.
        shape = struct.unpack(f'>{dims}I', f.read(4 * dims))
        return np.frombuffer(f.read(), dtype=idx_dtypes[data_type]).reshape(shape)

def load_mnist(image_path, label_path):
    """Load one MNIST split from its IDX image file and IDX label file.

    Args:
        image_path: Path of the IDX file holding the images.
        label_path: Path of the IDX file holding the labels.

    Returns:
        An ``(images, labels)`` tuple holding the arrays ``read_idx`` returns
        for ``image_path`` and ``label_path``, in that order.
    """
    return read_idx(image_path), read_idx(label_path)

# All four IDX files sit in a folder named after the file itself, under one
# shared directory. The directory is relative to the current working directory.
MNIST_DIR = './drive/MyDrive/MNIST'


def _mnist_file(stem):
    """Return the path of IDX file ``stem`` inside its same-named folder."""
    return f'{MNIST_DIR}/{stem}/{stem}'


train_image_path = _mnist_file('train-images-idx3-ubyte')
train_label_path = _mnist_file('train-labels-idx1-ubyte')
test_image_path = _mnist_file('t10k-images-idx3-ubyte')
test_label_path = _mnist_file('t10k-labels-idx1-ubyte')

In [None]:
# Load both splits, then hold out everything after the first NUM_TRAIN training
# samples as the validation set.
NUM_TRAIN = 50000

full_train_images, full_train_labels = load_mnist(train_image_path, train_label_path)
test_images, test_labels = load_mnist(test_image_path, test_label_path)

train_images, val_images = full_train_images[:NUM_TRAIN], full_train_images[NUM_TRAIN:]
train_labels, val_labels = full_train_labels[:NUM_TRAIN], full_train_labels[NUM_TRAIN:]

# Report the shapes of the train and test arrays (validation is not printed).
for split_title, split_array in (
    ('Train images', train_images),
    ('Train labels', train_labels),
    ('Test images', test_images),
    ('Test labels', test_labels),
):
    print(f'{split_title} shape: {split_array.shape}')

#### Implementation of ResNet Backbone and AVG Aggregator for feature extraction

In [None]:
import torchvision
import torch.nn as nn
import torch

class ResNet(nn.Module):
    """torchvision ResNet adapted to 1-channel images, used as a feature extractor.

    Three parts of the stock network are replaced:

    * ``conv1`` becomes a 3x3, stride-1, padding-1 convolution with a single
      input channel, re-initialised with Kaiming-normal weights (so this layer
      is random even when ``pretrained`` is True);
    * ``maxpool`` becomes an identity, so the stem does not downsample;
    * ``fc`` becomes an identity, so ``forward`` returns the network's pooled
      features instead of class logits.

    Args:
        model_name: Name of a ResNet builder in ``torchvision.models``
            (case-insensitive), e.g. ``'resnet18'`` or ``'resnet34'``.
        pretrained: If True, load the ``'IMAGENET1K_V1'`` weights for every
            layer that is not replaced; otherwise use random initialisation.
        num_concepts: Currently unused; kept so existing callers that pass it
            keep working.

    Raises:
        ValueError: If ``model_name`` does not name a ResNet builder function
            in ``torchvision.models``.
    """

    def __init__(self,
                 model_name: str = 'resnet18',
                 pretrained: bool = True,
                 num_concepts: int = 19
                 ):
        super().__init__()
        self.model_name = model_name.lower()

        # Resolve the requested architecture instead of always building resnet18.
        # The callable check rejects `torchvision.models.resnet`, which is a
        # sub-module rather than a builder function.
        builder = getattr(torchvision.models, self.model_name, None)
        if not (self.model_name.startswith('resnet') and callable(builder)):
            raise ValueError(
                f"Unsupported model_name {model_name!r}; expected a torchvision "
                f"ResNet builder such as 'resnet18'.")

        # String form of the weights enum; None means random initialisation.
        weights = 'IMAGENET1K_V1' if pretrained else None
        self.model = builder(weights=weights)

        # Single-channel 3x3 stem; keeps the output width of the original conv1.
        self.model.conv1 = nn.Conv2d(1, self.model.conv1.out_channels,
                                     kernel_size=3, stride=1, padding=1, bias=False)

        # Remove the stem max-pool so the input is not downsampled here.
        self.model.maxpool = nn.Identity()

        # The new conv1 has no pretrained weights, so initialise it explicitly.
        nn.init.kaiming_normal_(self.model.conv1.weight, mode='fan_out', nonlinearity='relu')

        # Drop the classification head: forward() yields the pooled features.
        self.model.fc = nn.Identity()

    def forward(self, x):
        """Return the backbone features for a batch of 1-channel images."""
        return self.model(x)

In [None]:
# Build the backbone with its default arguments and push one random
# 1x1x28x56 batch through it to inspect the shape of the returned features.
rn = ResNet()
rn(torch.randn(1, 1, 28, 56)).shape

#### Create MNISTDataset and MNISTAdditionDataset

In [None]:
import numpy as np
import os
import struct
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class MNISTDataset(Dataset):
    """Map-style dataset pairing an image array with a label sequence.

    The images are copied into a new NumPy array at construction time; the
    labels are kept by reference. If a transform is given, it is applied to
    the image on every lookup.
    """

    def __init__(self, images, labels, transform=None):
        self.transform = transform
        self.labels = labels
        # Private copy so later changes to the caller's array do not leak in.
        self.images = np.array(images, copy=True)

    def __len__(self):
        """Number of images held by the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, transforming the image if configured."""
        sample, target = self.images[idx], self.labels[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

# MNIST transforms: ToTensor is the only step; no normalisation or
# augmentation is configured.
transform = transforms.Compose([
    transforms.ToTensor()
])

# Create train, validation, and test datasets from the arrays produced by the
# load/split cell above; all three share the same transform.
train_dataset = MNISTDataset(train_images, train_labels, transform=transform)
val_dataset = MNISTDataset(val_images, val_labels, transform=transform)
test_dataset = MNISTDataset(test_images, test_labels,  transform=transform)

In [None]:
train_dataset[0][0].shape , train_dataset[0][1]

In [None]:
class MNISTAdditionDataset(Dataset):
    """Pairs every sample of a digit dataset with a randomly drawn partner.

    Item ``idx`` is ``(image, c_label, y_label)`` where ``image`` is the two
    images concatenated along the last axis, ``c_label`` is a ``(2, 10)``
    tensor with a single 1 per row at each digit's label, and ``y_label`` is a
    tensor holding the sum of the two labels.

    The partner index is drawn with ``np.random.randint`` on every lookup, so
    the same ``idx`` can return a different pair each time it is accessed.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        """Same length as the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the pair for ``idx`` and a partner sampled uniformly at random."""
        left_image, left_digit = self.dataset[idx]
        partner_idx = np.random.randint(len(self.dataset))
        right_image, right_digit = self.dataset[partner_idx]

        # Concatenate along the last axis.
        pair_image = torch.cat([left_image, right_image], dim=-1)

        # One row per digit, with a 1 at that digit's label.
        concepts = torch.zeros(2, 10)
        concepts[0][left_digit] = 1
        concepts[1][right_digit] = 1

        return pair_image, concepts, torch.tensor(left_digit + right_digit)


# Create the combined datasets: one two-digit addition view per split.
# Each wrapper samples the second digit at access time (np.random.randint in
# __getitem__), so items are not fixed across accesses, validation and test included.
train_addition_dataset = MNISTAdditionDataset(train_dataset)
val_addition_dataset = MNISTAdditionDataset(val_dataset)
test_addition_dataset = MNISTAdditionDataset(test_dataset)

In [None]:
train_addition_dataset[0][0].shape , train_addition_dataset[0][1]

In [None]:

rn(train_addition_dataset[0][0].unsqueeze(1)).shape

In [None]:
def prepare_data(dataset, model, batch_size=256):
  """Run every image of ``dataset`` through ``model`` and collect NumPy arrays.

  Images are forwarded in mini-batches of up to ``batch_size`` (instead of one
  forward pass per image) with gradients disabled. The model is switched to
  eval mode and left in that mode.

  Args:
    dataset: Iterable of ``(image, c_label, y_label)`` tensor triples, where
      ``image`` has no batch dimension.
    model: ``torch.nn.Module`` mapping a batch of images to one feature
      tensor per sample.
    batch_size: Maximum number of images per forward pass; must be >= 1.

  Returns:
    ``(x, c, y)`` NumPy arrays with one entry per sample, in dataset order:
    the model features, the flattened concept labels and the task labels.

  Raises:
    ValueError: If ``batch_size`` is smaller than 1.
  """
  if batch_size < 1:
    raise ValueError(f'batch_size must be >= 1, got {batch_size}')

  x, c, y = [], [], []
  pending = []  # images waiting for the next forward pass

  def flush():
    # Forward the buffered images as one batch; each output row is one sample.
    if pending:
      features = model(torch.stack(pending))
      x.extend(features.cpu().numpy())
      pending.clear()

  model.eval()
  with torch.no_grad():
    for image, c_label, y_label in dataset:
      pending.append(image)
      c.append(c_label.flatten().numpy())
      y.append(y_label.numpy())
      if len(pending) == batch_size:
        flush()
    flush()  # last, possibly smaller, batch
  return np.array(x), np.array(c), np.array(y)

# Prepare data for training and testing: run the backbone `rn` once over each
# addition dataset and keep the outputs as NumPy arrays.
# NOTE(review): rn's conv1 was re-initialised randomly in ResNet.__init__ and no
# cell visible here trains rn, so these features come from an untuned backbone.
x_train, c_train, y_train = prepare_data(train_addition_dataset, rn)
# Validation features are not extracted here (call left disabled):
#x_val, c_val, y_val = prepare_data(val_addition_dataset, rn)
x_test, c_test, y_test = prepare_data(test_addition_dataset, rn)


KeyboardInterrupt: 

In [None]:
pwd

'/content'

In [None]:
%cd /content/

/content


In [None]:
import os
import pickle

# On-disk cache of the (features, concept labels, task labels) triples, relative
# to the current working directory.
TRAIN_CACHE = 'train_data.pkl'
TEST_CACHE = 'test_data.pkl'

# SECURITY: pickle.load can execute arbitrary code stored in the file. Only load
# cache files that this notebook wrote itself.

# Load a cache file when it exists; otherwise write the arrays computed by the
# feature-extraction cell so the next run can skip that step. The save branch
# needs x_train/c_train/y_train (resp. *_test) to be defined already.
if os.path.exists(TRAIN_CACHE):
  with open(TRAIN_CACHE, 'rb') as f:
    x_train, c_train, y_train = pickle.load(f)
else:
  with open(TRAIN_CACHE, 'wb') as f:
    pickle.dump((x_train, c_train, y_train), f)

if os.path.exists(TEST_CACHE):
  with open(TEST_CACHE, 'rb') as f:
    x_test, c_test, y_test = pickle.load(f)
else:
  with open(TEST_CACHE, 'wb') as f:
    pickle.dump((x_test, c_test, y_test), f)

print(x_train.shape, c_train.shape, y_train.shape)


(50000, 512) (50000, 20) (50000,)


#### Let's start working

In [None]:
!git clone -b code_exploration https://github.com/alialhousseini/pytorch_explain

Cloning into 'pytorch_explain'...
remote: Enumerating objects: 1318, done.[K
remote: Counting objects: 100% (428/428), done.[K
remote: Compressing objects: 100% (135/135), done.[K
remote: Total 1318 (delta 298), reused 406 (delta 291), pack-reused 890[K
Receiving objects: 100% (1318/1318), 8.35 MiB | 27.31 MiB/s, done.
Resolving deltas: 100% (764/764), done.


In [None]:
%%capture
!pip install wandb

In [None]:
%cd /content/pytorch_explain

/content/pytorch_explain


In [None]:
import torch.nn.functional as F
from torch_explain.nn.concepts import IntpLinearLayer1, IntpLinearLayer2, IntpLinearLayer3, ConceptReasoningLayer, ConceptEmbedding
import torch
import torch_explain as te
from torch_explain import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
import wandb

In [None]:
# Authenticate with Weights & Biases without embedding the API key in the notebook.
# Called with no `key`, wandb.login() uses the WANDB_API_KEY environment variable
# or a previously stored login, and otherwise prompts for the key interactively.
# SECURITY: the key that used to be hard-coded here has been exposed; revoke it
# in the W&B account settings and issue a new one.
wandb.login()

In [None]:
# Display names for the synthetic datasets built below, in the same order.
dataset_names = ['XOR','XNOR', 'IsBinEven','Trigonometry', 'Dot']

# NOTE(review): this rebinds `datasets`, which was imported above as the
# torch_explain `datasets` module, to a plain list. Running this cell a second
# time without re-importing fails because a list has no `.xor` attribute.
# Neither `datasets` nor `dataset_names` is read by the MNIST training code in
# the visible part of the notebook — confirm whether they are still needed.
datasets = [
    datasets.xor(5000),
    datasets.xnor(5000),
    datasets.is_bin_even(5000),
    datasets.trigonometry(5000),
    datasets.dot(5000)
]

# Model variants iterated over by the training loop.
models = ['DCRBase', 'LLR1', 'LLR2', 'LLR3']

#### Train on MNIST

In [None]:
# Function to select loss function based on the configuration
def get_loss_function(name):
    """Return a newly constructed torch loss module for a configuration name.

    Args:
        name: One of ``'bce'`` (BCELoss), ``'mse'`` (MSELoss), ``'huber'``
            (HuberLoss), ``'hinge'`` (HingeEmbeddingLoss), ``'bceL'``
            (BCEWithLogitsLoss) or ``'cross_entropy'`` (CrossEntropyLoss).

    Returns:
        A new instance of the matching ``torch.nn`` loss, built with its
        default arguments.

    Raises:
        ValueError: If ``name`` is not one of the supported names. Failing here
            avoids handing ``None`` to the training loop, which would only
            crash later when the loss is called.
    """
    loss_classes = {
        'bce': torch.nn.BCELoss,
        'mse': torch.nn.MSELoss,
        'huber': torch.nn.HuberLoss,
        'hinge': torch.nn.HingeEmbeddingLoss,
        'bceL': torch.nn.BCEWithLogitsLoss,
        'cross_entropy': torch.nn.CrossEntropyLoss,
    }
    try:
        loss_cls = loss_classes[name]
    except KeyError:
        raise ValueError(
            f'Unknown loss function {name!r}; expected one of {sorted(loss_classes)}'
        ) from None
    return loss_cls()

In [None]:
import warnings
# Install a blanket "ignore" filter: warnings raised after this point are
# suppressed for the rest of the session.
# NOTE(review): this also hides warnings from the metric and training code
# below — consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

In [None]:
y_train.shape

(50000,)

In [None]:
import time
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import wandb

# Banner printed once before the per-model training loop starts.
print("===========================================================")
print("Training Started!")

# Validation-accuracy threshold for the "convergence" measurement: the training
# loop logs the elapsed time as 'convergence_time' the first time an epoch's
# validation accuracy reaches this value.
target_accuracy = 0.5
# Forwarded as the `bias=` argument of the IntpLinearLayer task predictors
# built in the LLR branches of the training loop.
isBias = True

# Iterate over models
for model_name in models:
    if model_name == 'DCRBase':
        print(f"Training on {model_name} ... ")
        print(f"--------------------------------")

        dataset_name = "MNIST_Addition"
        print(f"The following dataset has been loaded successfully: {dataset_name}")

        x_train = torch.from_numpy(x_train).float()
        c_train = torch.from_numpy(c_train).float()
        y_train = torch.from_numpy(y_train).long()

        x_test = torch.from_numpy(x_test).float()
        c_test = torch.from_numpy(c_test).float()
        y_test = torch.from_numpy(y_test).long()

        y_train = F.one_hot(y_train.long().ravel()).float()
        y_test = F.one_hot(y_test.long().ravel()).float()

        embedding_size = 16
        concept_encoder = torch.nn.Sequential(
            torch.nn.Linear(x_train.shape[1], 32),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(32, 32),
            torch.nn.LeakyReLU(),
            te.nn.ConceptEmbedding(32, c_train.shape[1], embedding_size),
        )

        task_predictor = ConceptReasoningLayer(embedding_size, y_train.shape[1])
        model = torch.nn.Sequential(concept_encoder, task_predictor)

        num_val_samples = int(len(x_train) * 0.2)
        num_train_samples = len(x_train) - num_val_samples
        train_dataset, val_dataset = random_split(
           list(zip(x_train, c_train, y_train)), [num_train_samples, num_val_samples])

        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

        wandb.init(project="pytorch_explain", entity="alih999954-politecnico-di-torino",
                   name=f"{model_name}_{dataset_name}")

        config = {
            'lr': 0.0005,
            'task_loss_weight': 0.5,
            'loss_function': 'bce',
            'loss_function2': 'bceL',
            'loss_function3': 'cross_entropy',
        }
        wandb.config.update(config)

        loss_form = get_loss_function(wandb.config.loss_function3)
        optimizer = torch.optim.AdamW(model.parameters(), lr=wandb.config.lr)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.0001, patience=7)

        print(f'-------------------------- Training {dataset_name} using {model_name} ----------------------')

        # Initialize timers and convergence flag
        total_start_time = time.time()
        convergence_time = None

        for epoch in range(51):
            epoch_start_time = time.time()
            model.train()
            train_losses, train_correct = 0, 0
            all_y_true_train, all_y_pred_train = [], []
            all_c_true_train, all_c_pred_train = [], []

            for x_batch, c_batch, y_batch in train_loader:
                optimizer.zero_grad()
                c_emb, c_pred = concept_encoder(x_batch)
                y_pred = task_predictor(c_emb, c_pred)

                concept_loss = loss_form(c_pred, c_batch)
                task_loss = loss_form(y_pred, y_batch)
                loss = concept_loss + 0.5 * task_loss

                loss.backward()
                optimizer.step()

                train_losses += loss.item()
                train_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                all_y_true_train.append(y_batch.cpu().numpy())
                all_y_pred_train.append(y_pred.detach().cpu().numpy())
                all_c_true_train.append(c_batch.cpu().numpy())
                all_c_pred_train.append(c_pred.detach().cpu().numpy())

                wandb.log({
                    'train_concept_loss': concept_loss.item(),
                    'train_task_loss': task_loss.item(),
                    'learning_rate': optimizer.param_groups[0]['lr']
                })

            all_y_true_train = np.concatenate(all_y_true_train, axis=0)
            all_y_pred_train = np.concatenate(all_y_pred_train, axis=0)
            all_c_true_train = np.concatenate(all_c_true_train, axis=0)
            all_c_pred_train = np.concatenate(all_c_pred_train, axis=0)

            train_precision = precision_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1), average='weighted')
            train_recall = recall_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1), average='weighted')
            train_f1 = f1_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1), average='weighted')
            train_concept_accuracy = accuracy_score(all_c_true_train.argmax(1), all_c_pred_train.argmax(1))
            train_task_accuracy = accuracy_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1))

            model.eval()
            val_losses, val_correct = 0, 0
            all_y_true_val, all_y_pred_val = [], []
            all_c_true_val, all_c_pred_val = [], []

            with torch.no_grad():
                for x_batch, c_batch, y_batch in val_loader:
                    c_emb, c_pred = concept_encoder(x_batch)
                    y_pred = task_predictor(c_emb, c_pred)

                    val_concept_loss = loss_form(c_pred, c_batch)
                    val_task_loss = loss_form(y_pred, y_batch)
                    val_loss = val_concept_loss + 0.5 * val_task_loss

                    val_losses += val_loss.item()
                    val_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                    all_y_true_val.append(y_batch.cpu().numpy())
                    all_y_pred_val.append(y_pred.detach().cpu().numpy())
                    all_c_true_val.append(c_batch.cpu().numpy())
                    all_c_pred_val.append(c_pred.detach().cpu().numpy())

                    wandb.log({
                        'val_concept_loss': val_concept_loss.item(),
                        'val_task_loss': val_task_loss.item(),
                        'val_learning_rate': optimizer.param_groups[0]['lr']
                    })

            all_y_true_val = np.concatenate(all_y_true_val, axis=0)
            all_y_pred_val = np.concatenate(all_y_pred_val, axis=0)
            all_c_true_val = np.concatenate(all_c_true_val, axis=0)
            all_c_pred_val = np.concatenate(all_c_pred_val, axis=0)

            val_precision = precision_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1), average='weighted')
            val_recall = recall_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1), average='weighted')
            val_f1 = f1_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1), average='weighted')
            val_concept_accuracy = accuracy_score(all_c_true_val.argmax(1), all_c_pred_val.argmax(1))
            val_task_accuracy = accuracy_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1))

            scheduler.step(val_losses / len(val_loader))

            epoch_end_time = time.time()
            epoch_time = epoch_end_time - epoch_start_time

            print(f"Epoch {epoch+1}, Loss: {train_losses/len(train_loader)}, Train Accuracy: {train_correct/len(train_dataset)}, Val Loss: {val_losses/len(val_loader)}, Val Accuracy: {val_correct/len(val_dataset)}, Train Precision: {train_precision}, Train Recall: {train_recall}, Train F1: {train_f1}, Val Precision: {val_precision}, Val Recall: {val_recall}, Val F1: {val_f1}, Train Concept Accuracy: {train_concept_accuracy}, Train Task Accuracy: {train_task_accuracy}, Val Concept Accuracy: {val_concept_accuracy}, Val Task Accuracy: {val_task_accuracy}, Epoch Time: {epoch_time}")

            wandb.log({
                'epoch': epoch + 1,
                'loss': train_losses / len(train_loader),
                'train_accuracy': train_correct / len(train_dataset),
                'val_loss': val_losses / len(val_loader),
                'val_accuracy': val_correct / len(val_dataset),
                'train_precision': train_precision,
                'train_recall': train_recall,
                'train_f1': train_f1,
                'val_precision': val_precision,
                'val_recall': val_recall,
                'val_f1': val_f1,
                'train_concept_accuracy': train_concept_accuracy,
                'train_task_accuracy': train_task_accuracy,
                'val_concept_accuracy': val_concept_accuracy,
                'val_task_accuracy': val_task_accuracy,
                'epoch_time': epoch_time
            })

            # Check for convergence
            if val_correct / len(val_dataset) >= target_accuracy and convergence_time is None:
                convergence_time = time.time() - total_start_time
                wandb.log({'convergence_time': convergence_time})
                print(f"Convergence achieved at epoch {epoch+1} with validation accuracy {val_correct/len(val_dataset)}")

        total_training_time = time.time() - total_start_time
        wandb.log({'total_training_time': total_training_time})
        print(f"Total Training Time: {total_training_time}")

        print(f"\n Training on {dataset_name} using {model_name} has been completed!")
        torch.save(model, f'model_{model_name}_{dataset_name}.pth')
        torch.save(model.state_dict(), f'model_state_dict_{model_name}_{dataset_name}.pth')
        wandb.finish()

        print(f"===========================================================")

    # Repeat similar updates for 'LLR1', 'LLR2', and 'LLR3' models


    elif model_name == 'LLR1':
        print(f"Training on {model_name} ... ")
        print(f"--------------------------------")

        dataset_name= "MNIST_Addition"
        print(f"The following dataset has been loaded successfully: {dataset_name}")

        # y_train = F.one_hot(y_train.long().ravel()).float()
        # y_test = F.one_hot(y_test.long().ravel()).float()

        embedding_size = 16
        concept_encoder = torch.nn.Sequential(
            torch.nn.Linear(x_train.shape[1], 32),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(32, 32),
            torch.nn.LeakyReLU(),
            te.nn.ConceptEmbedding(32, c_train.shape[1], embedding_size),
        )

        task_predictor = IntpLinearLayer1(embedding_size, y_train.shape[1], bias=isBias)
        model = torch.nn.Sequential(concept_encoder, task_predictor)

        num_val_samples = int(len(x_train) * 0.2)
        num_train_samples = len(x_train) - num_val_samples
        train_dataset, val_dataset = random_split(
            list(zip(x_train, c_train, y_train)), [num_train_samples, num_val_samples])

        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

        wandb.init(project="pytorch_explain", entity="alih999954-politecnico-di-torino",
                   name=f"{model_name}_{dataset_name}")

        config = {
            'lr': 0.0005,
            'task_loss_weight': 0.5,
            'loss_function': 'bce',
            'loss_function2': 'bceL',
            'loss_function3': 'cross_entropy',
        }
        wandb.config.update(config)

        c_loss = get_loss_function(wandb.config.loss_function)
        y_loss = get_loss_function(wandb.config.loss_function3)

        optimizer = torch.optim.AdamW(model.parameters(), lr=wandb.config.lr)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.0001, patience=7)

        print(f'-------------------------- Training {dataset_name} using {model_name} ----------------------')

        # Initialize timers and convergence flag
        total_start_time = time.time()
        convergence_time = None

        for epoch in range(41):
            epoch_start_time = time.time()
            model.train()
            train_losses, train_correct = 0, 0
            all_y_true_train, all_y_pred_train = [], []
            all_c_true_train, all_c_pred_train = [], []

            for x_batch, c_batch, y_batch in train_loader:
                optimizer.zero_grad()
                c_emb, c_pred = concept_encoder(x_batch)
                y_pred = task_predictor(c_emb, c_pred)

                concept_loss = c_loss(c_pred, c_batch)
                task_loss = y_loss(y_pred, y_batch)
                loss = concept_loss + 0.5 * task_loss

                loss.backward()
                optimizer.step()

                train_losses += loss.item()
                train_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                all_y_true_train.append(y_batch.cpu().numpy())
                all_y_pred_train.append(y_pred.detach().cpu().numpy())
                all_c_true_train.append(c_batch.cpu().numpy())
                all_c_pred_train.append(c_pred.detach().cpu().numpy())

                wandb.log({
                    'train_concept_loss': concept_loss.item(),
                    'train_task_loss': task_loss.item(),
                    'learning_rate': optimizer.param_groups[0]['lr']
                })

            all_y_true_train = np.concatenate(all_y_true_train, axis=0)
            all_y_pred_train = np.concatenate(all_y_pred_train, axis=0)
            all_c_true_train = np.concatenate(all_c_true_train, axis=0)
            all_c_pred_train = np.concatenate(all_c_pred_train, axis=0)

            train_precision = precision_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1), average='weighted')
            train_recall = recall_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1), average='weighted')
            train_f1 = f1_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1), average='weighted')
            train_concept_accuracy = accuracy_score(all_c_true_train.argmax(1), all_c_pred_train.argmax(1))
            train_task_accuracy = accuracy_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1))

            model.eval()
            val_losses, val_correct = 0, 0
            all_y_true_val, all_y_pred_val = [], []
            all_c_true_val, all_c_pred_val = [], []

            with torch.no_grad():
                for x_batch, c_batch, y_batch in val_loader:
                    c_emb, c_pred = concept_encoder(x_batch)
                    y_pred = task_predictor(c_emb, c_pred)

                    val_concept_loss = c_loss(c_pred, c_batch)
                    val_task_loss = y_loss(y_pred, y_batch)
                    val_loss = val_concept_loss + 0.5 * val_task_loss

                    val_losses += val_loss.item()
                    val_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                    all_y_true_val.append(y_batch.cpu().numpy())
                    all_y_pred_val.append(y_pred.detach().cpu().numpy())
                    all_c_true_val.append(c_batch.cpu().numpy())
                    all_c_pred_val.append(c_pred.detach().cpu().numpy())

                    wandb.log({
                        'val_concept_loss': val_concept_loss.item(),
                        'val_task_loss': val_task_loss.item(),
                        'val_learning_rate': optimizer.param_groups[0]['lr']
                    })

            all_y_true_val = np.concatenate(all_y_true_val, axis=0)
            all_y_pred_val = np.concatenate(all_y_pred_val, axis=0)
            all_c_true_val = np.concatenate(all_c_true_val, axis=0)
            all_c_pred_val = np.concatenate(all_c_pred_val, axis=0)

            val_precision = precision_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1), average='weighted')
            val_recall = recall_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1), average='weighted')
            val_f1 = f1_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1), average='weighted')
            val_concept_accuracy = accuracy_score(all_c_true_val.argmax(1), all_c_pred_val.argmax(1))
            val_task_accuracy = accuracy_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1))

            scheduler.step(val_losses / len(val_loader))

            epoch_end_time = time.time()
            epoch_time = epoch_end_time - epoch_start_time

            print(f"Epoch {epoch+1}, Loss: {train_losses/len(train_loader)}, Train Accuracy: {train_correct/len(train_dataset)}, Val Loss: {val_losses/len(val_loader)}, Val Accuracy: {val_correct/len(val_dataset)}, Train Precision: {train_precision}, Train Recall: {train_recall}, Train F1: {train_f1}, Val Precision: {val_precision}, Val Recall: {val_recall}, Val F1: {val_f1}, Epoch Time: {epoch_time}")

            wandb.log({
                'epoch': epoch + 1,
                'loss': train_losses / len(train_loader),
                'train_accuracy': train_correct / len(train_dataset),
                'val_loss': val_losses / len(val_loader),
                'val_accuracy': val_correct / len(val_dataset),
                'train_precision': train_precision,
                'train_recall': train_recall,
                'train_f1': train_f1,
                'val_precision': val_precision,
                'val_recall': val_recall,
                'val_f1': val_f1,
                'train_concept_accuracy': train_concept_accuracy,
                'train_task_accuracy': train_task_accuracy,
                'val_concept_accuracy': val_concept_accuracy,
                'val_task_accuracy': val_task_accuracy,
                'epoch_time': epoch_time
            })

            # Check for convergence
            if val_correct / len(val_dataset) >= target_accuracy and convergence_time is None:
                convergence_time = time.time() - total_start_time
                wandb.log({'convergence_time': convergence_time})
                print(f"Convergence achieved at epoch {epoch+1} with validation accuracy {val_correct/len(val_dataset)}")

        total_training_time = time.time() - total_start_time
        wandb.log({'total_training_time': total_training_time})
        print(f"Total Training Time: {total_training_time}")

        print(f"\n Training on {dataset_name} using {model_name} has been completed!")
        torch.save(model, f'model_{model_name}_{dataset_name}.pth')
        torch.save(model.state_dict(), f'model_state_dict_{model_name}_{dataset_name}.pth')
        wandb.finish()

        print(f"===========================================================")

    elif model_name == 'LLR2':
        print(f"Training on {model_name} ... ")
        print(f"--------------------------------")
        print(f"The following dataset has been loaded successfully: {dataset_name}")

        # y_train = F.one_hot(y_train.long().ravel()).float()
        # y_test = F.one_hot(y_test.long().ravel()).float()

        embedding_size = 16
        concept_encoder = torch.nn.Sequential(
            torch.nn.Linear(x_train.shape[1], 32),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(32, 32),
            torch.nn.LeakyReLU(),
            te.nn.ConceptEmbedding(32, c_train.shape[1], embedding_size),
        )

        task_predictor = IntpLinearLayer2(embedding_size, y_train.shape[1], bias=isBias)
        model = torch.nn.Sequential(concept_encoder, task_predictor)

        wandb.init(project="pytorch_explain", entity="alih999954-politecnico-di-torino",
                   name=f"{model_name}_{dataset_name}")

        config = {
            'lr': 0.0005,
            'task_loss_weight': 0.5,
            'loss_function': 'bce',
            'loss_function2': 'bceL',
            'loss_function3': 'cross_entropy',
        }
        wandb.config.update(config)

        c_loss = get_loss_function(wandb.config.loss_function)
        y_loss = get_loss_function(wandb.config.loss_function3)

        optimizer = torch.optim.AdamW(model.parameters(), lr=wandb.config.lr)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.0001, patience=7)

        print(f'-------------------------- Training {dataset_name} using {model_name} ----------------------')

        # Initialize timers and convergence flag
        total_start_time = time.time()
        convergence_time = None

        for epoch in range(41):
            epoch_start_time = time.time()
            model.train()
            train_losses, train_correct = 0, 0
            all_y_true_train, all_y_pred_train = [], []
            all_c_true_train, all_c_pred_train = [], []

            for x_batch, c_batch, y_batch in train_loader:
                optimizer.zero_grad()
                c_emb, c_pred = concept_encoder(x_batch)
                y_pred = task_predictor(c_emb, c_pred)

                concept_loss = c_loss(c_pred, c_batch)
                task_loss = y_loss(y_pred, y_batch)
                loss = concept_loss + 0.5 * task_loss

                loss.backward()
                optimizer.step()

                train_losses += loss.item()
                train_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                all_y_true_train.append(y_batch.cpu().numpy())
                all_y_pred_train.append(y_pred.detach().cpu().numpy())
                all_c_true_train.append(c_batch.cpu().numpy())
                all_c_pred_train.append(c_pred.detach().cpu().numpy())

                wandb.log({
                    'train_concept_loss': concept_loss.item(),
                    'train_task_loss': task_loss.item(),
                    'learning_rate': optimizer.param_groups[0]['lr']
                })

            all_y_true_train = np.concatenate(all_y_true_train, axis=0)
            all_y_pred_train = np.concatenate(all_y_pred_train, axis=0)
            all_c_true_train = np.concatenate(all_c_true_train, axis=0)
            all_c_pred_train = np.concatenate(all_c_pred_train, axis=0)

            train_precision = precision_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1), average='weighted')
            train_recall = recall_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1), average='weighted')
            train_f1 = f1_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1), average='weighted')
            train_concept_accuracy = accuracy_score(all_c_true_train.argmax(1), all_c_pred_train.argmax(1))
            train_task_accuracy = accuracy_score(all_y_true_train.argmax(1), all_y_pred_train.argmax(1))

            model.eval()
            val_losses, val_correct = 0, 0
            all_y_true_val, all_y_pred_val = [], []
            all_c_true_val, all_c_pred_val = [], []

            with torch.no_grad():
                for x_batch, c_batch, y_batch in val_loader:
                    c_emb, c_pred = concept_encoder(x_batch)
                    y_pred = task_predictor(c_emb, c_pred)

                    val_concept_loss = c_loss(c_pred, c_batch)
                    val_task_loss = y_loss(y_pred, y_batch)
                    val_loss = val_concept_loss + 0.5 * val_task_loss

                    val_losses += val_loss.item()
                    val_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                    all_y_true_val.append(y_batch.cpu().numpy())
                    all_y_pred_val.append(y_pred.detach().cpu().numpy())
                    all_c_true_val.append(c_batch.cpu().numpy())
                    all_c_pred_val.append(c_pred.detach().cpu().numpy())

                    wandb.log({
                        'val_concept_loss': val_concept_loss.item(),
                        'val_task_loss': val_task_loss.item(),
                        'val_learning_rate': optimizer.param_groups[0]['lr']
                    })


            all_y_true_val = np.concatenate(all_y_true_val, axis=0)
            all_y_pred_val = np.concatenate(all_y_pred_val, axis=0)
            all_c_true_val = np.concatenate(all_c_true_val, axis=0)
            all_c_pred_val = np.concatenate(all_c_pred_val, axis=0)

            val_precision = precision_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1), average='weighted')
            val_recall = recall_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1), average='weighted')
            val_f1 = f1_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1), average='weighted')
            val_concept_accuracy = accuracy_score(all_c_true_val.argmax(1), all_c_pred_val.argmax(1))
            val_task_accuracy = accuracy_score(all_y_true_val.argmax(1), all_y_pred_val.argmax(1))

            scheduler.step(val_losses / len(val_loader))

            epoch_end_time = time.time()
            epoch_time = epoch_end_time - epoch_start_time

            print(f"Epoch {epoch+1}, Loss: {train_losses/len(train_loader)}, Train Accuracy: {train_correct/len(train_dataset)}, Val Loss: {val_losses/len(val_loader)}, Val Accuracy: {val_correct/len(val_dataset)}, Train Precision: {train_precision}, Train Recall: {train_recall}, Train F1: {train_f1}, Val Precision: {val_precision}, Val Recall: {val_recall}, Val F1: {val_f1}, Epoch Time: {epoch_time}")

            wandb.log({
                'epoch': epoch + 1,
                'loss': train_losses / len(train_loader),
                'train_accuracy': train_correct / len(train_dataset),
                'val_loss': val_losses / len(val_loader),
                'val_accuracy': val_correct / len(val_dataset),
                'train_precision': train_precision,
                'train_recall': train_recall,
                'train_f1': train_f1,
                'val_precision': val_precision,
                'val_recall': val_recall,
                'val_f1': val_f1,
                'train_concept_accuracy': train_concept_accuracy,
                'train_task_accuracy': train_task_accuracy,
                'val_concept_accuracy': val_concept_accuracy,
                'val_task_accuracy': val_task_accuracy,
                'epoch_time': epoch_time
            })

            # Check for convergence
            if val_correct / len(val_dataset) >= target_accuracy and convergence_time is None:
                convergence_time = time.time() - total_start_time
                wandb.log({'convergence_time': convergence_time})
                print(f"Convergence achieved at epoch {epoch+1} with validation accuracy {val_correct/len(val_dataset)}")

        total_training_time = time.time() - total_start_time
        wandb.log({'total_training_time': total_training_time})
        print(f"Total Training Time: {total_training_time}")

        print(f"\n Training on {dataset_name} using {model_name} has been completed!")
        torch.save(model, f'model_{model_name}_{dataset_name}.pth')
        torch.save(model.state_dict(), f'model_state_dict_{model_name}_{dataset_name}.pth')
        wandb.finish()

        print(f"===========================================================")

    if model_name == 'LLR3':
        # LLR3 run: trains a concept encoder (two hidden Linear+LeakyReLU
        # layers feeding a ConceptEmbedding layer) together with an
        # IntpLinearLayer3 task head, logging per-batch losses and per-epoch
        # metrics to Weights & Biases, then saves the model to disk.
        #
        # NOTE(review): this branch creates no data split of its own. It reads
        # x_train / c_train / y_train, train_dataset / val_dataset,
        # train_loader / val_loader and dataset_name as they exist when the
        # branch is reached, so earlier code must already have defined them.
        print(f"Training on {model_name} ... ")
        print("--------------------------------")

        embedding_size = 16
        concept_encoder = torch.nn.Sequential(
            torch.nn.Linear(x_train.shape[1], 32),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(32, 32),
            torch.nn.LeakyReLU(),
            te.nn.ConceptEmbedding(32, c_train.shape[1], embedding_size),
        )

        task_predictor = IntpLinearLayer3(embedding_size, y_train.shape[1], bias=isBias)
        # `model` only groups the two parts for parameters(), train()/eval()
        # and saving; the forward pass below calls each part directly because
        # the task predictor takes two inputs.
        model = torch.nn.Sequential(concept_encoder, task_predictor)

        wandb.init(project="pytorch_explain", entity="alih999954-politecnico-di-torino",
                   name=f"{model_name}_{dataset_name}")

        config = {
            'lr': 0.0005,
            'task_loss_weight': 0.5,
            'loss_function': 'bce',
            'loss_function2': 'bceL',
            'loss_function3': 'cross_entropy',
        }
        wandb.config.update(config)

        c_loss = get_loss_function(wandb.config.loss_function)
        y_loss = get_loss_function(wandb.config.loss_function3)
        # Take the task-loss weight from the logged config so the value
        # recorded in W&B is the one actually used in the loss.
        task_loss_weight = wandb.config.task_loss_weight

        optimizer = torch.optim.AdamW(model.parameters(), lr=wandb.config.lr)
        # NOTE(review): factor=0.0001 multiplies the learning rate by 1e-4 on
        # the first plateau, which effectively freezes training afterwards.
        # Confirm this is intended (the PyTorch default factor is 0.1).
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.0001, patience=7)

        print(f'-------------------------- Training {dataset_name} using {model_name} ----------------------')

        # Wall-clock timers; convergence_time stays None until the validation
        # accuracy first reaches target_accuracy.
        total_start_time = time.time()
        convergence_time = None

        for epoch in range(101):
            epoch_start_time = time.time()

            # ---------------- training pass ----------------
            model.train()
            train_losses, train_correct = 0, 0
            all_y_true_train, all_y_pred_train = [], []
            all_c_true_train, all_c_pred_train = [], []

            for x_batch, c_batch, y_batch in train_loader:
                optimizer.zero_grad()
                c_emb, c_pred = concept_encoder(x_batch)
                y_pred = task_predictor(c_emb, c_pred)

                concept_loss = c_loss(c_pred, c_batch)
                task_loss = y_loss(y_pred, y_batch)
                loss = concept_loss + task_loss_weight * task_loss

                loss.backward()
                optimizer.step()

                train_losses += loss.item()
                train_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                all_y_true_train.append(y_batch.cpu().numpy())
                all_y_pred_train.append(y_pred.detach().cpu().numpy())
                all_c_true_train.append(c_batch.cpu().numpy())
                all_c_pred_train.append(c_pred.detach().cpu().numpy())

                wandb.log({
                    'train_concept_loss': concept_loss.item(),
                    'train_task_loss': task_loss.item(),
                    'learning_rate': optimizer.param_groups[0]['lr']
                })

            all_y_true_train = np.concatenate(all_y_true_train, axis=0)
            all_y_pred_train = np.concatenate(all_y_pred_train, axis=0)
            all_c_true_train = np.concatenate(all_c_true_train, axis=0)
            all_c_pred_train = np.concatenate(all_c_pred_train, axis=0)

            # Task labels are one-hot, so argmax recovers the class index.
            y_true_train_idx = all_y_true_train.argmax(1)
            y_pred_train_idx = all_y_pred_train.argmax(1)
            train_precision = precision_score(y_true_train_idx, y_pred_train_idx, average='weighted')
            train_recall = recall_score(y_true_train_idx, y_pred_train_idx, average='weighted')
            train_f1 = f1_score(y_true_train_idx, y_pred_train_idx, average='weighted')
            train_task_accuracy = accuracy_score(y_true_train_idx, y_pred_train_idx)
            # Concepts are trained with a binary cross-entropy loss, i.e. as
            # independent binary labels, so argmax over the concept axis is not
            # a valid accuracy. Score every concept entry after thresholding.
            # Assumes c_pred holds probabilities in [0, 1] - TODO confirm.
            train_concept_accuracy = accuracy_score(
                (all_c_true_train > 0.5).ravel(), (all_c_pred_train > 0.5).ravel())

            # ---------------- validation pass ----------------
            model.eval()
            val_losses, val_correct = 0, 0
            all_y_true_val, all_y_pred_val = [], []
            all_c_true_val, all_c_pred_val = [], []

            with torch.no_grad():
                for x_batch, c_batch, y_batch in val_loader:
                    c_emb, c_pred = concept_encoder(x_batch)
                    y_pred = task_predictor(c_emb, c_pred)

                    val_concept_loss = c_loss(c_pred, c_batch)
                    val_task_loss = y_loss(y_pred, y_batch)
                    val_loss = val_concept_loss + task_loss_weight * val_task_loss

                    val_losses += val_loss.item()
                    val_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                    all_y_true_val.append(y_batch.cpu().numpy())
                    all_y_pred_val.append(y_pred.detach().cpu().numpy())
                    all_c_true_val.append(c_batch.cpu().numpy())
                    all_c_pred_val.append(c_pred.detach().cpu().numpy())

                    wandb.log({
                        'val_concept_loss': val_concept_loss.item(),
                        'val_task_loss': val_task_loss.item(),
                        'val_learning_rate': optimizer.param_groups[0]['lr']
                    })

            all_y_true_val = np.concatenate(all_y_true_val, axis=0)
            all_y_pred_val = np.concatenate(all_y_pred_val, axis=0)
            all_c_true_val = np.concatenate(all_c_true_val, axis=0)
            all_c_pred_val = np.concatenate(all_c_pred_val, axis=0)

            y_true_val_idx = all_y_true_val.argmax(1)
            y_pred_val_idx = all_y_pred_val.argmax(1)
            val_precision = precision_score(y_true_val_idx, y_pred_val_idx, average='weighted')
            val_recall = recall_score(y_true_val_idx, y_pred_val_idx, average='weighted')
            val_f1 = f1_score(y_true_val_idx, y_pred_val_idx, average='weighted')
            val_task_accuracy = accuracy_score(y_true_val_idx, y_pred_val_idx)
            # Same element-wise, thresholded concept accuracy as for training.
            val_concept_accuracy = accuracy_score(
                (all_c_true_val > 0.5).ravel(), (all_c_pred_val > 0.5).ravel())

            # ---------------- epoch summary ----------------
            # Losses are averaged per batch, accuracies per sample.
            train_loss_avg = train_losses / len(train_loader)
            val_loss_avg = val_losses / len(val_loader)
            train_accuracy = train_correct / len(train_dataset)
            val_accuracy = val_correct / len(val_dataset)

            # The plateau scheduler monitors the mean validation loss.
            scheduler.step(val_loss_avg)

            epoch_end_time = time.time()
            epoch_time = epoch_end_time - epoch_start_time

            print(f"Epoch {epoch+1}, Loss: {train_loss_avg}, Train Accuracy: {train_accuracy}, Val Loss: {val_loss_avg}, Val Accuracy: {val_accuracy}, Train Precision: {train_precision}, Train Recall: {train_recall}, Train F1: {train_f1}, Val Precision: {val_precision}, Val Recall: {val_recall}, Val F1: {val_f1}, Epoch Time: {epoch_time}")

            wandb.log({
                'epoch': epoch + 1,
                'loss': train_loss_avg,
                'train_accuracy': train_accuracy,
                'val_loss': val_loss_avg,
                'val_accuracy': val_accuracy,
                'train_precision': train_precision,
                'train_recall': train_recall,
                'train_f1': train_f1,
                'val_precision': val_precision,
                'val_recall': val_recall,
                'val_f1': val_f1,
                'train_concept_accuracy': train_concept_accuracy,
                'train_task_accuracy': train_task_accuracy,
                'val_concept_accuracy': val_concept_accuracy,
                'val_task_accuracy': val_task_accuracy,
                'epoch_time': epoch_time
            })

            # Record the elapsed time only the first time the target is hit;
            # training continues for the remaining epochs.
            if val_accuracy >= target_accuracy and convergence_time is None:
                convergence_time = time.time() - total_start_time
                wandb.log({'convergence_time': convergence_time})
                print(f"Convergence achieved at epoch {epoch+1} with validation accuracy {val_accuracy}")

        total_training_time = time.time() - total_start_time
        wandb.log({'total_training_time': total_training_time})
        print(f"Total Training Time: {total_training_time}")

        print(f"\n Training on {dataset_name} using {model_name} has been completed!")
        # Save both the pickled module and its state_dict.
        torch.save(model, f'model_{model_name}_{dataset_name}.pth')
        torch.save(model.state_dict(), f'model_state_dict_{model_name}_{dataset_name}.pth')
        wandb.finish()

        print("===========================================================")


# Final banner printed once every model branch above has finished.
print("*********** ALL TRAINING ARE DONE - Check WandB ***********")


Training Started!
Training on DCRBase ... 
--------------------------------
The following dataset has been loaded successfully: MNIST_Addition
Training on LLR3 ... 
--------------------------------


VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded\r'), FloatProgress(value=0.10524133355448664, max=1.…

0,1
learning_rate,▁▁▁▁▁▁▁▁▁▁▁
train_concept_loss,██▇▆▆▅▄▄▃▂▁
train_task_loss,▄██▅▁▃▅▂▂▃▆

0,1
learning_rate,0.0005
train_concept_loss,0.6629
train_task_loss,3.19551


-------------------------- Training MNIST_Addition using LLR3 ----------------------
Epoch 1, Loss: 1.7462483808517455, Train Accuracy: 0.0989, Val Loss: 1.6840228852193067, Val Accuracy: 0.0892, Train Precision: 0.031598877364128954, Train Recall: 0.0989, Train F1: 0.036714184199069445, Val Precision: 0.0213018437637755, Val Recall: 0.0892, Val F1: 0.02751930180610301, Epoch Time: 16.853581428527832
Epoch 2, Loss: 1.5685187818527222, Train Accuracy: 0.157675, Val Loss: 1.4384291604825645, Val Accuracy: 0.2166, Train Precision: 0.18838971451170192, Train Recall: 0.157675, Train F1: 0.12245993239323015, Val Precision: 0.21813463613356512, Val Recall: 0.2166, Val F1: 0.18281724653663914, Epoch Time: 17.98530626296997
Epoch 3, Loss: 1.379292474937439, Train Accuracy: 0.2432, Val Loss: 1.3366748246417683, Val Accuracy: 0.2626, Train Precision: 0.24010235236504973, Train Recall: 0.2432, Train F1: 0.22153339361655336, Val Precision: 0.26843970959001107, Val Recall: 0.2626, Val F1: 0.23010103

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_time,▁▁▁▁▂▂▁▁█▃▁▂▁▁▁▁▂▁▂▁▁▁▁▁▁▂▁▁▁▁▂▁▁▁▂▁▂▃▁▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▂▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇███████████
train_concept_accuracy,▁▂▂▃▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇████████████████
train_concept_loss,█▇▇▇▆▆▆▆▆▅▅▆▅▅▅▅▄▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁▂▁▁▁
train_f1,▁▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇███████████

0,1
convergence_time,408.80884
epoch,101.0
epoch_time,16.92723
learning_rate,0.0005
loss,0.34843
total_training_time,1751.10247
train_accuracy,0.83805
train_concept_accuracy,0.4362
train_concept_loss,0.08034
train_f1,0.83799


*********** ALL TRAINING ARE DONE - Check WandB ***********


In [None]:
import time
from sklearn.metrics import precision_score, recall_score, f1_score

print("===========================================================")
print("Training Started!")

# Validation-accuracy threshold for the "convergence_time" metric: the first
# epoch whose validation accuracy reaches this value records the wall-clock
# time elapsed since training of that model started. Training does not stop.
target_accuracy = 0.9  # example value; adjust per experiment

# Iterate over models
for model_name in models:
    if model_name == 'DCRBase':
        # DCRBase run: for every dataset, trains a concept encoder (one
        # Linear+LeakyReLU layer feeding a ConceptEmbedding layer) together
        # with a ConceptReasoningLayer task head, logging per-batch losses and
        # per-epoch metrics to Weights & Biases, then saves the model to disk.
        print(f"Training on {model_name} ... ")
        print("--------------------------------")

        for dataset_name, dataset in zip(dataset_names, datasets):
            x, c, y = dataset
            x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(
                x, c, y, test_size=0.3, random_state=42)
            print(f"The following dataset has been loaded successfully: {dataset_name}")

            # One-hot encode with a class count taken from the full label set,
            # so train and test always get the same width even when the
            # highest class is missing from one of the splits.
            num_classes = int(y.long().max()) + 1
            y_train = F.one_hot(y_train.long().ravel(), num_classes=num_classes).float()
            y_test = F.one_hot(y_test.long().ravel(), num_classes=num_classes).float()

            embedding_size = 16
            concept_encoder = torch.nn.Sequential(
                torch.nn.Linear(x.shape[1], 16),
                torch.nn.LeakyReLU(),
                te.nn.ConceptEmbedding(16, c.shape[1], embedding_size),
            )

            task_predictor = ConceptReasoningLayer(embedding_size, y_train.shape[1])
            # `model` only groups the two parts for parameters(),
            # train()/eval() and saving; the forward pass below calls each
            # part directly because the task predictor takes two inputs.
            model = torch.nn.Sequential(concept_encoder, task_predictor)

            # Hold out 20% of the training split for validation.
            num_val_samples = int(len(x_train) * 0.2)
            num_train_samples = len(x_train) - num_val_samples
            train_dataset, val_dataset = random_split(
                list(zip(x_train, c_train, y_train)), [num_train_samples, num_val_samples])

            train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

            wandb.init(project="pytorch_explain", entity="alih999954-politecnico-di-torino",
                       name=f"{model_name}_{dataset_name}")

            config = {
                'lr': 0.0005,
                'task_loss_weight': 0.5,
                'loss_function': 'bce',
                'loss_function2': 'bceL',
                'loss_function3': 'cross_entropy'
            }
            wandb.config.update(config)

            # The same loss is applied to both concept and task predictions.
            loss_form = get_loss_function(wandb.config.loss_function)
            # Take the task-loss weight from the logged config so the value
            # recorded in W&B is the one actually used in the loss.
            task_loss_weight = wandb.config.task_loss_weight

            optimizer = torch.optim.AdamW(model.parameters(), lr=wandb.config.lr)
            # NOTE(review): factor=0.0001 multiplies the learning rate by 1e-4
            # on the first plateau, which effectively freezes training
            # afterwards. Confirm this is intended (PyTorch default is 0.1).
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.0001, patience=7)

            print(f'-------------------------- Training {dataset_name} using {model_name} ----------------------')

            # Wall-clock timers; convergence_time stays None until the
            # validation accuracy first reaches target_accuracy.
            total_start_time = time.time()
            convergence_time = None

            for epoch in range(41):
                epoch_start_time = time.time()

                # ---------------- training pass ----------------
                model.train()
                train_losses, train_correct = 0, 0
                all_y_true_train, all_y_pred_train = [], []
                all_c_true_train, all_c_pred_train = [], []

                for x_batch, c_batch, y_batch in train_loader:
                    optimizer.zero_grad()
                    c_emb, c_pred = concept_encoder(x_batch)
                    y_pred = task_predictor(c_emb, c_pred)

                    concept_loss = loss_form(c_pred, c_batch)
                    task_loss = loss_form(y_pred, y_batch)
                    loss = concept_loss + task_loss_weight * task_loss

                    loss.backward()
                    optimizer.step()

                    train_losses += loss.item()
                    train_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                    all_y_true_train.append(y_batch.cpu().numpy())
                    all_y_pred_train.append(y_pred.detach().cpu().numpy())
                    all_c_true_train.append(c_batch.cpu().numpy())
                    all_c_pred_train.append(c_pred.detach().cpu().numpy())

                    wandb.log({
                        'train_concept_loss': concept_loss.item(),
                        'train_task_loss': task_loss.item(),
                        'learning_rate': optimizer.param_groups[0]['lr']
                    })

                all_y_true_train = np.concatenate(all_y_true_train, axis=0)
                all_y_pred_train = np.concatenate(all_y_pred_train, axis=0)
                all_c_true_train = np.concatenate(all_c_true_train, axis=0)
                all_c_pred_train = np.concatenate(all_c_pred_train, axis=0)

                # Task labels are one-hot, so argmax recovers the class index.
                y_true_train_idx = all_y_true_train.argmax(1)
                y_pred_train_idx = all_y_pred_train.argmax(1)
                train_precision = precision_score(y_true_train_idx, y_pred_train_idx, average='weighted')
                train_recall = recall_score(y_true_train_idx, y_pred_train_idx, average='weighted')
                train_f1 = f1_score(y_true_train_idx, y_pred_train_idx, average='weighted')
                train_task_accuracy = accuracy_score(y_true_train_idx, y_pred_train_idx)
                # Concepts are trained with a binary cross-entropy loss, i.e.
                # as independent binary labels, so argmax over the concept
                # axis is not a valid accuracy. Score every concept entry
                # after thresholding instead.
                # Assumes c_pred holds probabilities in [0, 1] - TODO confirm.
                train_concept_accuracy = accuracy_score(
                    (all_c_true_train > 0.5).ravel(), (all_c_pred_train > 0.5).ravel())

                # ---------------- validation pass ----------------
                model.eval()
                val_losses, val_correct = 0, 0
                all_y_true_val, all_y_pred_val = [], []
                all_c_true_val, all_c_pred_val = [], []

                with torch.no_grad():
                    for x_batch, c_batch, y_batch in val_loader:
                        c_emb, c_pred = concept_encoder(x_batch)
                        y_pred = task_predictor(c_emb, c_pred)

                        val_concept_loss = loss_form(c_pred, c_batch)
                        val_task_loss = loss_form(y_pred, y_batch)
                        val_loss = val_concept_loss + task_loss_weight * val_task_loss

                        val_losses += val_loss.item()
                        val_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                        all_y_true_val.append(y_batch.cpu().numpy())
                        all_y_pred_val.append(y_pred.detach().cpu().numpy())
                        all_c_true_val.append(c_batch.cpu().numpy())
                        all_c_pred_val.append(c_pred.detach().cpu().numpy())

                        wandb.log({
                            'val_concept_loss': val_concept_loss.item(),
                            'val_task_loss': val_task_loss.item(),
                            'val_learning_rate': optimizer.param_groups[0]['lr']
                        })

                all_y_true_val = np.concatenate(all_y_true_val, axis=0)
                all_y_pred_val = np.concatenate(all_y_pred_val, axis=0)
                all_c_true_val = np.concatenate(all_c_true_val, axis=0)
                all_c_pred_val = np.concatenate(all_c_pred_val, axis=0)

                y_true_val_idx = all_y_true_val.argmax(1)
                y_pred_val_idx = all_y_pred_val.argmax(1)
                val_precision = precision_score(y_true_val_idx, y_pred_val_idx, average='weighted')
                val_recall = recall_score(y_true_val_idx, y_pred_val_idx, average='weighted')
                val_f1 = f1_score(y_true_val_idx, y_pred_val_idx, average='weighted')
                val_task_accuracy = accuracy_score(y_true_val_idx, y_pred_val_idx)
                # Same element-wise, thresholded concept accuracy as above.
                val_concept_accuracy = accuracy_score(
                    (all_c_true_val > 0.5).ravel(), (all_c_pred_val > 0.5).ravel())

                # ---------------- epoch summary ----------------
                # Losses are averaged per batch, accuracies per sample.
                train_loss_avg = train_losses / len(train_loader)
                val_loss_avg = val_losses / len(val_loader)
                train_accuracy = train_correct / len(train_dataset)
                val_accuracy = val_correct / len(val_dataset)

                # The plateau scheduler monitors the mean validation loss.
                scheduler.step(val_loss_avg)

                epoch_end_time = time.time()
                epoch_time = epoch_end_time - epoch_start_time

                print(f"Epoch {epoch+1}, Loss: {train_loss_avg}, Train Accuracy: {train_accuracy}, Val Loss: {val_loss_avg}, Val Accuracy: {val_accuracy}, Train Precision: {train_precision}, Train Recall: {train_recall}, Train F1: {train_f1}, Val Precision: {val_precision}, Val Recall: {val_recall}, Val F1: {val_f1}, Epoch Time: {epoch_time}")

                wandb.log({
                    'epoch': epoch + 1,
                    'loss': train_loss_avg,
                    'train_accuracy': train_accuracy,
                    'val_loss': val_loss_avg,
                    'val_accuracy': val_accuracy,
                    'train_precision': train_precision,
                    'train_recall': train_recall,
                    'train_f1': train_f1,
                    'val_precision': val_precision,
                    'val_recall': val_recall,
                    'val_f1': val_f1,
                    'train_concept_accuracy': train_concept_accuracy,
                    'train_task_accuracy': train_task_accuracy,
                    'val_concept_accuracy': val_concept_accuracy,
                    'val_task_accuracy': val_task_accuracy,
                    'epoch_time': epoch_time
                })

                # Record the elapsed time only the first time the target is
                # hit; training continues for the remaining epochs.
                if val_accuracy >= target_accuracy and convergence_time is None:
                    convergence_time = time.time() - total_start_time
                    wandb.log({'convergence_time': convergence_time})
                    print(f"Convergence achieved at epoch {epoch+1} with validation accuracy {val_accuracy}")

            total_training_time = time.time() - total_start_time
            wandb.log({'total_training_time': total_training_time})
            print(f"Total Training Time: {total_training_time}")

            print(f"\n Training on {dataset_name} using {model_name} has been completed!")
            # Save both the pickled module and its state_dict.
            torch.save(model, f'model_{model_name}_{dataset_name}.pth')
            torch.save(model.state_dict(), f'model_state_dict_{model_name}_{dataset_name}.pth')
            wandb.finish()

        print("===========================================================")

    if model_name == 'LLR1':
        # LLR1 run: for every dataset, trains a concept encoder (one
        # Linear+LeakyReLU layer feeding a ConceptEmbedding layer) together
        # with an IntpLinearLayer1 task head, logging per-batch losses and
        # per-epoch metrics to Weights & Biases, then saves the model to disk.
        print(f"Training on {model_name} ... ")
        print("--------------------------------")

        for dataset_name, dataset in zip(dataset_names, datasets):
            x, c, y = dataset
            x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(
                x, c, y, test_size=0.3, random_state=42)
            print(f"The following dataset has been loaded successfully: {dataset_name}")

            # One-hot encode with a class count taken from the full label set,
            # so train and test always get the same width even when the
            # highest class is missing from one of the splits.
            num_classes = int(y.long().max()) + 1
            y_train = F.one_hot(y_train.long().ravel(), num_classes=num_classes).float()
            y_test = F.one_hot(y_test.long().ravel(), num_classes=num_classes).float()

            embedding_size = 16
            concept_encoder = torch.nn.Sequential(
                torch.nn.Linear(x.shape[1], 16),
                torch.nn.LeakyReLU(),
                te.nn.ConceptEmbedding(16, c.shape[1], embedding_size),
            )

            task_predictor = IntpLinearLayer1(embedding_size, y_train.shape[1], bias=isBias)
            # `model` only groups the two parts for parameters(),
            # train()/eval() and saving; the forward pass below calls each
            # part directly because the task predictor takes two inputs.
            model = torch.nn.Sequential(concept_encoder, task_predictor)

            # Hold out 20% of the training split for validation.
            num_val_samples = int(len(x_train) * 0.2)
            num_train_samples = len(x_train) - num_val_samples
            train_dataset, val_dataset = random_split(
                list(zip(x_train, c_train, y_train)), [num_train_samples, num_val_samples])

            train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

            wandb.init(project="pytorch_explain", entity="alih999954-politecnico-di-torino",
                       name=f"{model_name}_{dataset_name}")

            config = {
                'lr': 0.0005,
                'task_loss_weight': 0.5,
                'loss_function': 'bce',
                'loss_function2': 'bceL',
                'loss_function3': 'cross_entropy',
            }
            wandb.config.update(config)

            # NOTE(review): this branch uses 'bceL' for the concept loss while
            # the DCRBase and LLR3 branches use 'bce' on the output of the same
            # ConceptEmbedding encoder. Confirm whether c_pred is a logit or a
            # probability here and that the chosen loss matches.
            c_loss = get_loss_function(wandb.config.loss_function2)
            y_loss = get_loss_function(wandb.config.loss_function2)
            # Take the task-loss weight from the logged config so the value
            # recorded in W&B is the one actually used in the loss.
            task_loss_weight = wandb.config.task_loss_weight

            optimizer = torch.optim.AdamW(model.parameters(), lr=wandb.config.lr)
            # NOTE(review): factor=0.0001 multiplies the learning rate by 1e-4
            # on the first plateau, which effectively freezes training
            # afterwards. Confirm this is intended (PyTorch default is 0.1).
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.0001, patience=7)

            print(f'-------------------------- Training {dataset_name} using {model_name} ----------------------')

            # Wall-clock timers; convergence_time stays None until the
            # validation accuracy first reaches target_accuracy.
            total_start_time = time.time()
            convergence_time = None

            for epoch in range(41):
                epoch_start_time = time.time()

                # ---------------- training pass ----------------
                model.train()
                train_losses, train_correct = 0, 0
                all_y_true_train, all_y_pred_train = [], []
                all_c_true_train, all_c_pred_train = [], []

                for x_batch, c_batch, y_batch in train_loader:
                    optimizer.zero_grad()
                    c_emb, c_pred = concept_encoder(x_batch)
                    y_pred = task_predictor(c_emb, c_pred)

                    concept_loss = c_loss(c_pred, c_batch)
                    task_loss = y_loss(y_pred, y_batch)
                    loss = concept_loss + task_loss_weight * task_loss

                    loss.backward()
                    optimizer.step()

                    train_losses += loss.item()
                    train_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                    all_y_true_train.append(y_batch.cpu().numpy())
                    all_y_pred_train.append(y_pred.detach().cpu().numpy())
                    all_c_true_train.append(c_batch.cpu().numpy())
                    all_c_pred_train.append(c_pred.detach().cpu().numpy())

                    wandb.log({
                        'train_concept_loss': concept_loss.item(),
                        'train_task_loss': task_loss.item(),
                        'learning_rate': optimizer.param_groups[0]['lr']
                    })

                all_y_true_train = np.concatenate(all_y_true_train, axis=0)
                all_y_pred_train = np.concatenate(all_y_pred_train, axis=0)
                all_c_true_train = np.concatenate(all_c_true_train, axis=0)
                all_c_pred_train = np.concatenate(all_c_pred_train, axis=0)

                # Task labels are one-hot, so argmax recovers the class index.
                y_true_train_idx = all_y_true_train.argmax(1)
                y_pred_train_idx = all_y_pred_train.argmax(1)
                train_precision = precision_score(y_true_train_idx, y_pred_train_idx, average='weighted')
                train_recall = recall_score(y_true_train_idx, y_pred_train_idx, average='weighted')
                train_f1 = f1_score(y_true_train_idx, y_pred_train_idx, average='weighted')
                train_task_accuracy = accuracy_score(y_true_train_idx, y_pred_train_idx)
                # Concepts are trained with a binary loss, i.e. as independent
                # binary labels, so argmax over the concept axis is not a
                # valid accuracy. Score every concept entry after thresholding.
                # NOTE(review): the 0.5 threshold assumes c_pred holds
                # probabilities; if c_pred are logits the threshold should be
                # 0 - confirm.
                train_concept_accuracy = accuracy_score(
                    (all_c_true_train > 0.5).ravel(), (all_c_pred_train > 0.5).ravel())

                # ---------------- validation pass ----------------
                model.eval()
                val_losses, val_correct = 0, 0
                all_y_true_val, all_y_pred_val = [], []
                all_c_true_val, all_c_pred_val = [], []

                with torch.no_grad():
                    for x_batch, c_batch, y_batch in val_loader:
                        c_emb, c_pred = concept_encoder(x_batch)
                        y_pred = task_predictor(c_emb, c_pred)

                        val_concept_loss = c_loss(c_pred, c_batch)
                        val_task_loss = y_loss(y_pred, y_batch)
                        val_loss = val_concept_loss + task_loss_weight * val_task_loss

                        val_losses += val_loss.item()
                        val_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                        all_y_true_val.append(y_batch.cpu().numpy())
                        all_y_pred_val.append(y_pred.detach().cpu().numpy())
                        all_c_true_val.append(c_batch.cpu().numpy())
                        all_c_pred_val.append(c_pred.detach().cpu().numpy())

                        wandb.log({
                            'val_concept_loss': val_concept_loss.item(),
                            'val_task_loss': val_task_loss.item(),
                            'val_learning_rate': optimizer.param_groups[0]['lr']
                        })

                all_y_true_val = np.concatenate(all_y_true_val, axis=0)
                all_y_pred_val = np.concatenate(all_y_pred_val, axis=0)
                all_c_true_val = np.concatenate(all_c_true_val, axis=0)
                all_c_pred_val = np.concatenate(all_c_pred_val, axis=0)

                y_true_val_idx = all_y_true_val.argmax(1)
                y_pred_val_idx = all_y_pred_val.argmax(1)
                val_precision = precision_score(y_true_val_idx, y_pred_val_idx, average='weighted')
                val_recall = recall_score(y_true_val_idx, y_pred_val_idx, average='weighted')
                val_f1 = f1_score(y_true_val_idx, y_pred_val_idx, average='weighted')
                val_task_accuracy = accuracy_score(y_true_val_idx, y_pred_val_idx)
                # Same element-wise, thresholded concept accuracy as above.
                val_concept_accuracy = accuracy_score(
                    (all_c_true_val > 0.5).ravel(), (all_c_pred_val > 0.5).ravel())

                # ---------------- epoch summary ----------------
                # Losses are averaged per batch, accuracies per sample.
                train_loss_avg = train_losses / len(train_loader)
                val_loss_avg = val_losses / len(val_loader)
                train_accuracy = train_correct / len(train_dataset)
                val_accuracy = val_correct / len(val_dataset)

                # The plateau scheduler monitors the mean validation loss.
                scheduler.step(val_loss_avg)

                epoch_end_time = time.time()
                epoch_time = epoch_end_time - epoch_start_time

                print(f"Epoch {epoch+1}, Loss: {train_loss_avg}, Train Accuracy: {train_accuracy}, Val Loss: {val_loss_avg}, Val Accuracy: {val_accuracy}, Train Precision: {train_precision}, Train Recall: {train_recall}, Train F1: {train_f1}, Val Precision: {val_precision}, Val Recall: {val_recall}, Val F1: {val_f1}, Epoch Time: {epoch_time}")

                wandb.log({
                    'epoch': epoch + 1,
                    'loss': train_loss_avg,
                    'train_accuracy': train_accuracy,
                    'val_loss': val_loss_avg,
                    'val_accuracy': val_accuracy,
                    'train_precision': train_precision,
                    'train_recall': train_recall,
                    'train_f1': train_f1,
                    'val_precision': val_precision,
                    'val_recall': val_recall,
                    'val_f1': val_f1,
                    'train_concept_accuracy': train_concept_accuracy,
                    'train_task_accuracy': train_task_accuracy,
                    'val_concept_accuracy': val_concept_accuracy,
                    'val_task_accuracy': val_task_accuracy,
                    'epoch_time': epoch_time
                })

                # Record the elapsed time only the first time the target is
                # hit; training continues for the remaining epochs.
                if val_accuracy >= target_accuracy and convergence_time is None:
                    convergence_time = time.time() - total_start_time
                    wandb.log({'convergence_time': convergence_time})
                    print(f"Convergence achieved at epoch {epoch+1} with validation accuracy {val_accuracy}")

            total_training_time = time.time() - total_start_time
            wandb.log({'total_training_time': total_training_time})
            print(f"Total Training Time: {total_training_time}")

            print(f"\n Training on {dataset_name} using {model_name} has been completed!")
            # Save both the pickled module and its state_dict.
            torch.save(model, f'model_{model_name}_{dataset_name}.pth')
            torch.save(model.state_dict(), f'model_state_dict_{model_name}_{dataset_name}.pth')
            wandb.finish()

        print("===========================================================")

    if model_name in ('LLR2', 'LLR3'):
        # LLR2 and LLR3 share one training pipeline: a concept-embedding encoder
        # followed by an interpretable linear task predictor. The only difference
        # between the two variants is which IntpLinearLayer class is used.
        print(f"Training on {model_name} ... ")
        print("--------------------------------")

        for dataset_name, dataset in zip(dataset_names, datasets):
            x, c, y = dataset
            x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(
                x, c, y, test_size=0.3, random_state=42)
            print(f"The following dataset has been loaded successfully: {dataset_name}")

            # Task labels are one-hot encoded so the task loss sees one column per class.
            y_train = F.one_hot(y_train.long().ravel()).float()
            y_test = F.one_hot(y_test.long().ravel()).float()

            embedding_size = 16
            concept_encoder = torch.nn.Sequential(
                torch.nn.Linear(x.shape[1], 16),
                torch.nn.LeakyReLU(),
                te.nn.ConceptEmbedding(16, c.shape[1], embedding_size),
            )

            # Conditional expression (not a dict) so only the selected class name
            # is looked up, exactly as in the original per-variant branches.
            task_predictor_cls = IntpLinearLayer2 if model_name == 'LLR2' else IntpLinearLayer3
            task_predictor = task_predictor_cls(embedding_size, y_train.shape[1], bias=isBias)
            model = torch.nn.Sequential(concept_encoder, task_predictor)

            # Hold out 20% of the training split for validation.
            num_val_samples = int(len(x_train) * 0.2)
            num_train_samples = len(x_train) - num_val_samples
            train_dataset, val_dataset = random_split(
                list(zip(x_train, c_train, y_train)), [num_train_samples, num_val_samples])

            train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

            wandb.init(project="pytorch_explain", entity="alih999954-politecnico-di-torino",
                       name=f"{model_name}_{dataset_name}")

            config = {
                'lr': 0.0005,
                'task_loss_weight': 0.5,
                'loss_function': 'bce',
                'loss_function2': 'bceL',
            }
            wandb.config.update(config)

            # NOTE(review): 'loss_function' is configured but never used; both the
            # concept loss and the task loss are built from 'loss_function2'. If
            # 'bceL' is a with-logits loss while the concept encoder already emits
            # probabilities, the concept loss should probably use 'loss_function'
            # instead -- confirm against get_loss_function before changing.
            c_loss = get_loss_function(wandb.config.loss_function2)
            y_loss = get_loss_function(wandb.config.loss_function2)
            # Read the weight from the logged config so the run metadata and the
            # loss that is actually optimised cannot drift apart.
            task_loss_weight = wandb.config.task_loss_weight

            optimizer = torch.optim.AdamW(model.parameters(), lr=wandb.config.lr)
            # factor=0.0001 multiplies the learning rate by 1e-4 after 7 epochs
            # without improvement. NOTE(review): unusually aggressive -- confirm intent.
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.0001, patience=7)

            print(f'-------------------------- Training {dataset_name} using {model_name} ----------------------')

            # Initialize timers and convergence flag
            total_start_time = time.time()
            convergence_time = None

            for epoch in range(41):
                epoch_start_time = time.time()

                # ---------------- training pass ----------------
                model.train()
                train_losses, train_correct = 0, 0
                all_y_true_train, all_y_pred_train = [], []
                all_c_true_train, all_c_pred_train = [], []

                for x_batch, c_batch, y_batch in train_loader:
                    optimizer.zero_grad()
                    c_emb, c_pred = concept_encoder(x_batch)
                    y_pred = task_predictor(c_emb, c_pred)

                    concept_loss = c_loss(c_pred, c_batch)
                    task_loss = y_loss(y_pred, y_batch)
                    loss = concept_loss + task_loss_weight * task_loss

                    loss.backward()
                    optimizer.step()

                    train_losses += loss.item()
                    train_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                    all_y_true_train.append(y_batch.cpu().numpy())
                    all_y_pred_train.append(y_pred.detach().cpu().numpy())
                    all_c_true_train.append(c_batch.cpu().numpy())
                    all_c_pred_train.append(c_pred.detach().cpu().numpy())

                    wandb.log({
                        'train_concept_loss': concept_loss.item(),
                        'train_task_loss': task_loss.item(),
                        'learning_rate': optimizer.param_groups[0]['lr']
                    })

                all_y_true_train = np.concatenate(all_y_true_train, axis=0)
                all_y_pred_train = np.concatenate(all_y_pred_train, axis=0)
                all_c_true_train = np.concatenate(all_c_true_train, axis=0)
                all_c_pred_train = np.concatenate(all_c_pred_train, axis=0)

                # Class indices are computed once and reused by every task metric.
                y_true_train_idx = all_y_true_train.argmax(1)
                y_pred_train_idx = all_y_pred_train.argmax(1)
                train_precision = precision_score(y_true_train_idx, y_pred_train_idx, average='weighted')
                train_recall = recall_score(y_true_train_idx, y_pred_train_idx, average='weighted')
                train_f1 = f1_score(y_true_train_idx, y_pred_train_idx, average='weighted')
                train_task_accuracy = accuracy_score(y_true_train_idx, y_pred_train_idx)
                # Concepts are independent binary labels, so accuracy is measured
                # element-wise after thresholding. Taking argmax over the concept
                # axis (as before) only compared the index of the largest concept
                # and scored rows such as [0, 0] or [1, 1] arbitrarily.
                # Assumes c_pred holds probabilities in [0, 1] and c_batch holds
                # 0/1 labels -- TODO confirm against the concept encoder output.
                train_concept_accuracy = accuracy_score(
                    (all_c_true_train > 0.5).ravel(), (all_c_pred_train > 0.5).ravel())

                # ---------------- validation pass ----------------
                model.eval()
                val_losses, val_correct = 0, 0
                all_y_true_val, all_y_pred_val = [], []
                all_c_true_val, all_c_pred_val = [], []

                with torch.no_grad():
                    for x_batch, c_batch, y_batch in val_loader:
                        c_emb, c_pred = concept_encoder(x_batch)
                        y_pred = task_predictor(c_emb, c_pred)

                        val_concept_loss = c_loss(c_pred, c_batch)
                        val_task_loss = y_loss(y_pred, y_batch)
                        val_loss = val_concept_loss + task_loss_weight * val_task_loss

                        val_losses += val_loss.item()
                        val_correct += (y_pred.argmax(1) == y_batch.argmax(1)).sum().item()
                        all_y_true_val.append(y_batch.cpu().numpy())
                        all_y_pred_val.append(y_pred.detach().cpu().numpy())
                        all_c_true_val.append(c_batch.cpu().numpy())
                        all_c_pred_val.append(c_pred.detach().cpu().numpy())

                        wandb.log({
                            'val_concept_loss': val_concept_loss.item(),
                            'val_task_loss': val_task_loss.item(),
                            'val_learning_rate': optimizer.param_groups[0]['lr']
                        })

                all_y_true_val = np.concatenate(all_y_true_val, axis=0)
                all_y_pred_val = np.concatenate(all_y_pred_val, axis=0)
                all_c_true_val = np.concatenate(all_c_true_val, axis=0)
                all_c_pred_val = np.concatenate(all_c_pred_val, axis=0)

                y_true_val_idx = all_y_true_val.argmax(1)
                y_pred_val_idx = all_y_pred_val.argmax(1)
                val_precision = precision_score(y_true_val_idx, y_pred_val_idx, average='weighted')
                val_recall = recall_score(y_true_val_idx, y_pred_val_idx, average='weighted')
                val_f1 = f1_score(y_true_val_idx, y_pred_val_idx, average='weighted')
                val_task_accuracy = accuracy_score(y_true_val_idx, y_pred_val_idx)
                # Same element-wise concept accuracy as in the training pass.
                val_concept_accuracy = accuracy_score(
                    (all_c_true_val > 0.5).ravel(), (all_c_pred_val > 0.5).ravel())

                # ---------------- epoch summary ----------------
                # Losses are averaged per batch, accuracies per sample.
                avg_train_loss = train_losses / len(train_loader)
                avg_val_loss = val_losses / len(val_loader)
                train_accuracy = train_correct / len(train_dataset)
                val_accuracy = val_correct / len(val_dataset)

                scheduler.step(avg_val_loss)

                epoch_end_time = time.time()
                epoch_time = epoch_end_time - epoch_start_time

                print(f"Epoch {epoch+1}, Loss: {avg_train_loss}, Train Accuracy: {train_accuracy}, Val Loss: {avg_val_loss}, Val Accuracy: {val_accuracy}, Train Precision: {train_precision}, Train Recall: {train_recall}, Train F1: {train_f1}, Val Precision: {val_precision}, Val Recall: {val_recall}, Val F1: {val_f1}, Epoch Time: {epoch_time}")

                wandb.log({
                    'epoch': epoch + 1,
                    'loss': avg_train_loss,
                    'train_accuracy': train_accuracy,
                    'val_loss': avg_val_loss,
                    'val_accuracy': val_accuracy,
                    'train_precision': train_precision,
                    'train_recall': train_recall,
                    'train_f1': train_f1,
                    'val_precision': val_precision,
                    'val_recall': val_recall,
                    'val_f1': val_f1,
                    'train_concept_accuracy': train_concept_accuracy,
                    'train_task_accuracy': train_task_accuracy,
                    'val_concept_accuracy': val_concept_accuracy,
                    'val_task_accuracy': val_task_accuracy,
                    'epoch_time': epoch_time
                })

                # Check for convergence: record only the first epoch that reaches the target.
                if val_accuracy >= target_accuracy and convergence_time is None:
                    convergence_time = time.time() - total_start_time
                    wandb.log({'convergence_time': convergence_time})
                    print(f"Convergence achieved at epoch {epoch+1} with validation accuracy {val_accuracy}")

            total_training_time = time.time() - total_start_time
            wandb.log({'total_training_time': total_training_time})
            print(f"Total Training Time: {total_training_time}")

            print(f"\n Training on {dataset_name} using {model_name} has been completed!")
            # Save both the pickled module and the plain state dict.
            torch.save(model, f'model_{model_name}_{dataset_name}.pth')
            torch.save(model.state_dict(), f'model_state_dict_{model_name}_{dataset_name}.pth')
            wandb.finish()

        print("===========================================================")


# Final banner, printed after the model loop above has finished.
print("*********** ALL TRAINING ARE DONE - Check WandB ***********")


Training Started!
Training on DCRBase ... 
--------------------------------
The following dataset has been loaded successfully: XOR


-------------------------- Training XOR using DCRBase ----------------------
Epoch 1, Loss: 1.0362716669386083, Train Accuracy: 0.5021428571428571, Val Loss: 1.025920737873424, Val Accuracy: 0.5314285714285715, Train Precision: 0.7508929851510496, Train Recall: 0.5021428571428571, Train F1: 0.33963210026551766, Val Precision: 0.7608187633262261, Val Recall: 0.5314285714285715, Val F1: 0.4093071070569951, Epoch Time: 0.3013465404510498
Epoch 2, Loss: 1.0126415274359963, Train Accuracy: 0.5853571428571429, Val Loss: 0.9984437010504983, Val Accuracy: 0.5957142857142858, Train Precision: 0.6883322809848206, Train Recall: 0.5853571428571429, Train F1: 0.5206922476032667, Val Precision: 0.6242992594998098, Val Recall: 0.5957142857142858, Val F1: 0.5766772485872117, Epoch Time: 0.31109023094177246
Epoch 3, Loss: 0.9782215858047659, Train Accuracy: 0.6067857142857143, Val Loss: 0.9562058502977545, Val Accuracy: 0.6314285714285715, Train Precision: 0.6089654533949249, Train Recall: 0.6067857142

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▂▂▂▁▂▇▆▅▆▆▄▆▅▇█▄▂▁▁▁▃▁▂▂▂▁▂▃▁▁▂▂▁▂▁▁▂▄▁▂
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,███▇▇▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▂▂▃▄▄▄▄▄▆▆▇▇▇██████████████████████████
train_concept_accuracy,▁▂▄▇▇▇▇█████████████████████████████████
train_concept_loss,███▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_f1,▁▃▄▄▅▅▅▅▅▆▇▇▇███████████████████████████

0,1
convergence_time,4.07829
epoch,41.0
epoch_time,0.31536
learning_rate,0.0005
loss,0.08548
total_training_time,13.95733
train_accuracy,0.9925
train_concept_accuracy,0.74286
train_concept_loss,0.05369
train_f1,0.9925


The following dataset has been loaded successfully: XNOR


-------------------------- Training XNOR using DCRBase ----------------------
Epoch 1, Loss: 1.0339123417030682, Train Accuracy: 0.49642857142857144, Val Loss: 1.023291371085427, Val Accuracy: 0.5314285714285715, Train Precision: 0.7502847438582154, Train Recall: 0.49642857142857144, Train F1: 0.33056105005115766, Val Precision: 0.7361263154031763, Val Recall: 0.5314285714285715, Val F1: 0.40041050284805774, Epoch Time: 0.28998541831970215
Epoch 2, Loss: 1.0096844475377689, Train Accuracy: 0.5603571428571429, Val Loss: 0.9937947175719521, Val Accuracy: 0.5842857142857143, Train Precision: 0.5840025510204082, Train Recall: 0.5603571428571429, Train F1: 0.5311294291902021, Val Precision: 0.5857245560561887, Val Recall: 0.5842857142857143, Val F1: 0.5827848423439831, Epoch Time: 0.2968778610229492
Epoch 3, Loss: 0.9720776053992185, Train Accuracy: 0.6678571428571428, Val Loss: 0.9485270001671531, Val Accuracy: 0.6928571428571428, Train Precision: 0.6856588956145333, Train Recall: 0.667857

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▁▁▁▁▅▁▁▂▂▁▂▂▂▂▂▂▁▂▂▂▂▄▅▇▆▆▆▄█▆▇▇▂▂▁▁▂▁▂▂
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,███▇▇▆▅▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▂▃▄▄▄▄▅▆▆▇▇▇███████████████████████████
train_concept_accuracy,▇▁▅▆█▇▇▅▅▅▇▆▇▇▆█▇▇▇▆▆▆▅▆▆▆▆▆▆▆▅▆▆▅▇▅▆▆▆▆
train_concept_loss,███▇▆▅▄▄▃▃▂▂▂▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_f1,▁▃▄▅▅▅▅▆▆▇▇▇████████████████████████████

0,1
convergence_time,3.37883
epoch,41.0
epoch_time,0.30413
learning_rate,0.0005
loss,0.0815
total_training_time,13.99469
train_accuracy,0.99179
train_concept_accuracy,0.7475
train_concept_loss,0.04733
train_f1,0.99179


The following dataset has been loaded successfully: IsBinEven


-------------------------- Training IsBinEven using DCRBase ----------------------
Epoch 1, Loss: 1.0319461280649358, Train Accuracy: 0.4932142857142857, Val Loss: 1.0240461392836138, Val Accuracy: 0.49857142857142855, Train Precision: 0.24326033163265307, Train Recall: 0.4932142857142857, Train F1: 0.3258210612635392, Val Precision: 0.2485734693877551, Val Recall: 0.49857142857142855, Val F1: 0.3317472422715511, Epoch Time: 0.5314414501190186
Epoch 2, Loss: 1.014671279625459, Train Accuracy: 0.49392857142857144, Val Loss: 1.0018833875656128, Val Accuracy: 0.5028571428571429, Train Precision: 0.7502199274992342, Train Recall: 0.49392857142857144, Train F1: 0.3274035545846925, Val Precision: 0.7510719409715105, Val Recall: 0.5028571428571429, Val F1: 0.34119750554400713, Epoch Time: 0.48102664947509766
Epoch 3, Loss: 0.9844159361991015, Train Accuracy: 0.5071428571428571, Val Loss: 0.9625050859017805, Val Accuracy: 0.5442857142857143, Train Precision: 0.7534821751953226, Train Recall: 0

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▇▆█▅▇▇█▅▁▃▂▂▁▂▂▂▃▂▁▂▂▁▃▁▂▂▁▂▂▁▂▂▁▂▁▇▆▇▅▆
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,███▇▇▆▆▅▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▁▁▂▃▄▅▇▇███████████████████████████████
train_concept_accuracy,▁▃▆▆▃▅▅▇▇█▇▇▇▇▇█████████████████████████
train_concept_loss,███▇▆▆▅▅▄▃▃▃▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_f1,▁▁▁▂▄▅▆▇████████████████████████████████

0,1
convergence_time,4.11986
epoch,41.0
epoch_time,0.49346
learning_rate,0.0005
loss,0.09683
total_training_time,16.9251
train_accuracy,0.99607
train_concept_accuracy,0.52143
train_concept_loss,0.05197
train_f1,0.99607


The following dataset has been loaded successfully: Trigonometry


-------------------------- Training Trigonometry using DCRBase ----------------------
Epoch 1, Loss: 0.9833151549100876, Train Accuracy: 0.6442857142857142, Val Loss: 0.9340455477887933, Val Accuracy: 0.7385714285714285, Train Precision: 0.6307482993197279, Train Recall: 0.6442857142857142, Train F1: 0.6235341757702274, Val Precision: 0.7464228456322965, Val Recall: 0.7385714285714285, Val F1: 0.7406574155328336, Epoch Time: 0.3372673988342285
Epoch 2, Loss: 0.8720612593672492, Train Accuracy: 0.7882142857142858, Val Loss: 0.8180320425467058, Val Accuracy: 0.7928571428571428, Train Precision: 0.7972142407480612, Train Recall: 0.7882142857142858, Train F1: 0.7901751529265294, Val Precision: 0.8046136259555843, Val Recall: 0.7928571428571428, Val F1: 0.7948619248104668, Epoch Time: 0.3375515937805176
Epoch 3, Loss: 0.7394181205467745, Train Accuracy: 0.8414285714285714, Val Loss: 0.6743861003355547, Val Accuracy: 0.8557142857142858, Train Precision: 0.8442657683613314, Train Recall: 0.84

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▂▂▁▂▃▁▂▂▁▂▂▁▂▇▆▅▇▆▅▆███▄▂▂▂▁▂▂▂▁▃▁▁▂▂▂▂▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▄▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇█████████████████
train_concept_accuracy,▁▃███▇▆▇▇▇▇▇▇▇▇▆▇▆▆▆▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▅▅
train_concept_loss,█▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▁▂▂▁▂▁▂▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁
train_f1,▁▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇█████████████████

0,1
convergence_time,6.54324
epoch,41.0
epoch_time,0.32055
learning_rate,0.0005
loss,0.20436
total_training_time,14.9014
train_accuracy,0.94286
train_concept_accuracy,0.55286
train_concept_loss,0.08645
train_f1,0.94212


The following dataset has been loaded successfully: Dot


-------------------------- Training Dot using DCRBase ----------------------
Epoch 1, Loss: 1.0647607093507594, Train Accuracy: 0.5492857142857143, Val Loss: 1.0129291469400579, Val Accuracy: 0.5142857142857142, Train Precision: 0.5532212383883176, Train Recall: 0.5492857142857143, Train F1: 0.546675971764316, Val Precision: 0.5142697168453046, Val Recall: 0.5142857142857142, Val F1: 0.5096598639455782, Epoch Time: 0.3051440715789795
Epoch 2, Loss: 0.9714498817920685, Train Accuracy: 0.5128571428571429, Val Loss: 0.9213082736188715, Val Accuracy: 0.5085714285714286, Train Precision: 0.51267055438883, Train Recall: 0.5128571428571429, Train F1: 0.5127377021618414, Val Precision: 0.5106658595641645, Val Recall: 0.5085714285714286, Val F1: 0.4951992225461613, Epoch Time: 0.2969224452972412
Epoch 3, Loss: 0.8671358837322756, Train Accuracy: 0.6414285714285715, Val Loss: 0.8069654378024015, Val Accuracy: 0.6471428571428571, Train Precision: 0.6511439391417629, Train Recall: 0.64142857142857

VBox(children=(Label(value='0.001 MB of 0.012 MB uploaded\r'), FloatProgress(value=0.11634864165588615, max=1.…

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▂▂▂▁▃▂▂▃▂▁▃▂▁▂▃▂▁▃▂▁▃▂▁▂▃▂▂▃▂▁▃▆▆█▆▇█▆██
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▇▆▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▂▁▃▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███████████████
train_concept_accuracy,▁▆██▇▇███▇▇▇▇▇▇▇▇▆▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▆▇▆▇▇▇
train_concept_loss,█▇▆▅▄▄▃▂▂▂▂▂▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_f1,▂▁▃▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███████████████

0,1
convergence_time,6.6662
epoch,41.0
epoch_time,0.43595
learning_rate,0.0005
loss,0.096
total_training_time,13.5467
train_accuracy,0.96
train_concept_accuracy,0.74429
train_concept_loss,0.03281
train_f1,0.95999


Training on LLR1 ... 
--------------------------------
The following dataset has been loaded successfully: XOR


-------------------------- Training XOR using LLR1 ----------------------
Epoch 1, Loss: 1.0867823199792341, Train Accuracy: 0.44785714285714284, Val Loss: 1.0760117335753008, Val Accuracy: 0.4857142857142857, Train Precision: 0.4287976714317761, Train Recall: 0.44785714285714284, Train F1: 0.409320015715605, Val Precision: 0.4606690869118732, Val Recall: 0.4857142857142857, Val F1: 0.422344390114973, Epoch Time: 0.25629425048828125
Epoch 2, Loss: 1.0737389759583906, Train Accuracy: 0.5107142857142857, Val Loss: 1.06346641887318, Val Accuracy: 0.5128571428571429, Train Precision: 0.5561287477954144, Train Recall: 0.5107142857142857, Train F1: 0.38307105472218334, Val Precision: 0.2630224489795918, Val Recall: 0.5128571428571429, Val F1: 0.3477161742884123, Epoch Time: 0.2920675277709961
Epoch 3, Loss: 1.0610808337276632, Train Accuracy: 0.5010714285714286, Val Loss: 1.050393982367082, Val Accuracy: 0.5128571428571429, Train Precision: 0.25107257653061227, Train Recall: 0.50107142857142

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▁▂▂▂▂▃▂▂▁▁▁▂▂▃▆▆▅█▅▆▅▅█▇█▆▁▁▂▂▁▃▁▁▂▂▂▂▂▂
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,███▇▇▇▆▆▆▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▂▂▂▁▂▂▄▅▆▆▆▇▇▇▇▇▇██████████████████████
train_concept_accuracy,▁▂▅▆▇▇██████████████████████████████████
train_concept_loss,▇█▆█▆▆▅▄▅▄▅▅▃▂▃▁▂▃▂▂▃▂▃▁▂▂▂▁▂▂▁▂▂▃▄▂▁▁▂▂
train_f1,▂▂▁▁▂▃▂▄▆▆▇▇▇▇▇▇▇███████████████████████

0,1
convergence_time,3.617
epoch,41.0
epoch_time,0.27051
learning_rate,0.0005
loss,0.54281
total_training_time,12.82046
train_accuracy,0.99214
train_concept_accuracy,0.74786
train_concept_loss,0.52721
train_f1,0.99214


The following dataset has been loaded successfully: XNOR


-------------------------- Training XNOR using LLR1 ----------------------
Epoch 1, Loss: 1.0767670951106332, Train Accuracy: 0.48892857142857143, Val Loss: 1.0617602413350886, Val Accuracy: 0.5271428571428571, Train Precision: 0.2390511479591837, Train Recall: 0.48892857142857143, Train F1: 0.3211049240996471, Val Precision: 0.2778795918367347, Val Recall: 0.5271428571428571, Val F1: 0.36392088734464784, Epoch Time: 0.36129140853881836
Epoch 2, Loss: 1.0603158338503405, Train Accuracy: 0.5514285714285714, Val Loss: 1.0479580922560259, Val Accuracy: 0.42714285714285716, Train Precision: 0.6365702040128371, Train Recall: 0.5514285714285714, Train F1: 0.4798972915926795, Val Precision: 0.42977879665379665, Val Recall: 0.42714285714285716, Val F1: 0.41711970506912444, Epoch Time: 0.40888214111328125
Epoch 3, Loss: 1.0479966754263097, Train Accuracy: 0.5067857142857143, Val Loss: 1.039899630980058, Val Accuracy: 0.47285714285714286, Train Precision: 0.4689255428386829, Train Recall: 0.5067

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▄▆▇▇█▃▂▂▁▂▂▁▂▁▂▁▁▂▂▁▂▁▁▂▂▂▂▁▁▁▁▂▂▁▁▁▁▁▁▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,███▇▇▇▇▆▆▅▅▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▂▁▁▁▁▁▂▄▆▆▆▆▆▆▇▇▇▇▇████████████████████
train_concept_accuracy,▁▆▇▇▆▆▆▇▇███████████████████████████████
train_concept_loss,█▇▇▆▆▅▆▄▅▃▄▄▃▃▃▄▂▂▂▂▃▂▂▂▂▃▂▃▂▂▃▂▃▂▂▂▁▂▂▁
train_f1,▁▃▁▁▁▁▁▂▅▆▆▆▇▇▇▇▇▇▇▇████████████████████

0,1
convergence_time,5.52567
epoch,41.0
epoch_time,0.2637
learning_rate,0.0005
loss,0.54798
total_training_time,12.06918
train_accuracy,0.99071
train_concept_accuracy,0.76036
train_concept_loss,0.52727
train_f1,0.99071


The following dataset has been loaded successfully: IsBinEven


-------------------------- Training IsBinEven using LLR1 ----------------------
Epoch 1, Loss: 1.127017926086079, Train Accuracy: 0.5021428571428571, Val Loss: 1.1052432602102107, Val Accuracy: 0.52, Train Precision: 0.25214744897959185, Train Recall: 0.5021428571428571, Train F1: 0.33571700292099715, Val Precision: 0.27040000000000003, Val Recall: 0.52, Val F1: 0.35578947368421054, Epoch Time: 0.33426499366760254
Epoch 2, Loss: 1.089910704981197, Train Accuracy: 0.5021428571428571, Val Loss: 1.0665809024464001, Val Accuracy: 0.52, Train Precision: 0.25214744897959185, Train Recall: 0.5021428571428571, Train F1: 0.33571700292099715, Val Precision: 0.27040000000000003, Val Recall: 0.52, Val F1: 0.35578947368421054, Epoch Time: 0.35646891593933105
Epoch 3, Loss: 1.0540718084031886, Train Accuracy: 0.6171428571428571, Val Loss: 1.0394865599545566, Val Accuracy: 0.8771428571428571, Train Precision: 0.779769890826028, Train Recall: 0.6171428571428571, Train F1: 0.5509712641709765, Val Preci

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▂▃▁▂▄▁▂▂▁▂▃▁▄▇▆█▇▆▆██▇█▂▂▂▂▂▂▂▂▂▂▁▁▃▂▁▃▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,██▇▇▇▆▆▅▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▁▃▇▆▆▆▆▆▇▇▇████████████████████████████
train_concept_accuracy,▂▁▃▃▂▂▂▂▂▃▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇███████
train_concept_loss,██▇▇█▇▇▇▆▅▅▄▃▃▃▄▃▄▃▂▃▃▂▃▂▂▂▃▂▂▂▂▂▂▁▂▂▂▂▂
train_f1,▁▁▃▇▇▇▆▇▇▇▇█████████████████████████████

0,1
convergence_time,3.35886
epoch,41.0
epoch_time,0.31257
learning_rate,0.0005
loss,0.54317
total_training_time,15.0135
train_accuracy,0.99536
train_concept_accuracy,0.55714
train_concept_loss,0.56188
train_f1,0.99536


The following dataset has been loaded successfully: Trigonometry


-------------------------- Training Trigonometry using LLR1 ----------------------
Epoch 1, Loss: 1.1045652438293805, Train Accuracy: 0.6075, Val Loss: 1.0590358972549438, Val Accuracy: 0.5785714285714286, Train Precision: 0.4820896739130435, Train Recall: 0.6075, Train F1: 0.47086737513690125, Val Precision: 0.3347448979591837, Val Recall: 0.5785714285714286, Val F1: 0.42411118293471234, Epoch Time: 0.30114221572875977
Epoch 2, Loss: 1.0384912815960972, Train Accuracy: 0.6139285714285714, Val Loss: 1.0183833187276667, Val Accuracy: 0.5785714285714286, Train Precision: 0.3769082908163265, Train Recall: 0.6139285714285714, Train F1: 0.4670693579489773, Val Precision: 0.3347448979591837, Val Recall: 0.5785714285714286, Val F1: 0.42411118293471234, Epoch Time: 0.2981853485107422
Epoch 3, Loss: 0.9907256026159633, Train Accuracy: 0.6475, Val Loss: 0.9579563736915588, Val Accuracy: 0.7942857142857143, Train Precision: 0.7760726164079823, Train Recall: 0.6475, Train F1: 0.5387627118644068, V

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▁▁▁▁▂▃▂▁▁▂▁▁▁▂▁▁▂▁▂▂▁▁▁▂▁▁▂▁▁▂▆▅█▅▆▄▅▆▅▇
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▇▆▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▁▂▆▇▇▇▇▇█▇█████████████████████████████
train_concept_accuracy,▁▅▇██▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▅▆▆▅▆▆▅▆▅▆▅▆▅▅▆▆▆▅
train_concept_loss,█▇▇▆▄▄▄▃▃▃▃▃▃▃▂▂▂▃▁▂▃▂▁▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂
train_f1,▁▁▂▆▇▇▇█████████████████████████████████

0,1
convergence_time,1.20499
epoch,41.0
epoch_time,0.44961
learning_rate,0.0005
loss,0.53958
total_training_time,13.90971
train_accuracy,0.98714
train_concept_accuracy,0.63679
train_concept_loss,0.52159
train_f1,0.98713


The following dataset has been loaded successfully: Dot


-------------------------- Training Dot using LLR1 ----------------------
Epoch 1, Loss: 1.0562793477015062, Train Accuracy: 0.5142857142857142, Val Loss: 1.038515643639998, Val Accuracy: 0.54, Train Precision: 0.7162239770279971, Train Recall: 0.5142857142857142, Train F1: 0.35459770941498525, Val Precision: 0.7356556717618665, Val Recall: 0.54, Val F1: 0.40884466884466886, Epoch Time: 0.27113771438598633
Epoch 2, Loss: 1.0221417302435094, Train Accuracy: 0.5678571428571428, Val Loss: 1.0039093060926958, Val Accuracy: 0.5885714285714285, Train Precision: 0.6628217261250907, Train Recall: 0.5678571428571428, Train F1: 0.4851494572128818, Val Precision: 0.6568742163876392, Val Recall: 0.5885714285714285, Val F1: 0.5301163985350466, Epoch Time: 0.27390503883361816
Epoch 3, Loss: 0.9858959737149152, Train Accuracy: 0.6096428571428572, Val Loss: 0.9649622331966053, Val Accuracy: 0.6242857142857143, Train Precision: 0.6456171154157448, Train Recall: 0.6096428571428572, Train F1: 0.578895655

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▂▂▂▂▂▇▅▆▆▅▆▄▇▆▆▆▅█▄▂▂▂▂▁▃▂▂▂▃▂▂▃▂▁▁▃▂▁▂▂
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,██▇▆▆▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▂▂▃▄▄▅▅▅▅▆▆▆▇▇▇▇███████████████████████
train_concept_accuracy,▇██▇▅▃▁▁▁▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▃▂▂▃▂▂▃▃▃▃▃▃▃▃
train_concept_loss,█▆▅▅▄▃▃▃▂▂▁▃▃▂▃▂▂▂▂▂▃▂▃▃▂▂▃▂▃▁▁▂▂▁▂▁▂▃▂▁
train_f1,▁▂▃▄▅▅▆▆▆▆▆▇▇▇▇▇████████████████████████

0,1
convergence_time,4.87004
epoch,41.0
epoch_time,0.26848
learning_rate,0.0005
loss,0.54714
total_training_time,12.89566
train_accuracy,0.98393
train_concept_accuracy,0.75393
train_concept_loss,0.53125
train_f1,0.98393


Training on LLR2 ... 
--------------------------------
The following dataset has been loaded successfully: XOR


-------------------------- Training XOR using LLR2 ----------------------
Epoch 1, Loss: 1.120077829469334, Train Accuracy: 0.5139285714285714, Val Loss: 1.1224551417610862, Val Accuracy: 0.4614285714285714, Train Precision: 0.2641225765306122, Train Recall: 0.5139285714285714, Train F1: 0.3489234320762983, Val Precision: 0.2129163265306122, Val Recall: 0.4614285714285714, Val F1: 0.29138109202625334, Epoch Time: 0.31279754638671875
Epoch 2, Loss: 1.0934011773629622, Train Accuracy: 0.5139285714285714, Val Loss: 1.089286576617848, Val Accuracy: 0.4614285714285714, Train Precision: 0.2641225765306122, Train Recall: 0.5139285714285714, Train F1: 0.3489234320762983, Val Precision: 0.2129163265306122, Val Recall: 0.4614285714285714, Val F1: 0.29138109202625334, Epoch Time: 0.2969703674316406
Epoch 3, Loss: 1.0576325411146337, Train Accuracy: 0.5139285714285714, Val Loss: 1.0546899383718318, Val Accuracy: 0.4614285714285714, Train Precision: 0.2641225765306122, Train Recall: 0.5139285714285

VBox(children=(Label(value='0.014 MB of 0.024 MB uploaded (0.009 MB deduped)\r'), FloatProgress(value=0.566802…

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▂▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁▁▂▂▂▆▇▇▇█▅▇▇▇▇▇██▆
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,██▇▇▇▆▆▆▅▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▁▁▁▁▂▃▅▆▆▆▆▆▆▇▇▇▇██████████████████████
train_concept_accuracy,▄▅█▄▁▁▂▃▄▄▅▅▄▅▄▄▄▅▅▅▅▅▅▅▅▅▅▄▅▄▄▄▄▄▄▄▄▄▄▄
train_concept_loss,▇█▆▇▇▆▅▄▄▃▃▄▃▂▃▃▁▂▁▃▂▃▂▂▂▂▂▃▂▂▁▁▃▃▁▂▂▂▃▂
train_f1,▁▁▁▁▁▃▄▆▆▆▇▇▇▇▇▇▇███████████████████████

0,1
convergence_time,4.59233
epoch,41.0
epoch_time,0.45283
learning_rate,0.0005
loss,0.54476
total_training_time,15.56116
train_accuracy,0.99036
train_concept_accuracy,0.74464
train_concept_loss,0.50751
train_f1,0.99036


The following dataset has been loaded successfully: XNOR


-------------------------- Training XNOR using LLR2 ----------------------
Epoch 1, Loss: 1.1316345740448346, Train Accuracy: 0.4992857142857143, Val Loss: 1.1172700253399936, Val Accuracy: 0.4857142857142857, Train Precision: 0.2492862244897959, Train Recall: 0.4992857142857143, Train F1: 0.33253998502688353, Val Precision: 0.23591836734693877, Val Recall: 0.4857142857142857, Val F1: 0.3175824175824176, Epoch Time: 0.290330171585083
Epoch 2, Loss: 1.1057612110267987, Train Accuracy: 0.4992857142857143, Val Loss: 1.0832103057341143, Val Accuracy: 0.4857142857142857, Train Precision: 0.2492862244897959, Train Recall: 0.4992857142857143, Train F1: 0.33253998502688353, Val Precision: 0.23591836734693877, Val Recall: 0.4857142857142857, Val F1: 0.3175824175824176, Epoch Time: 0.3268916606903076
Epoch 3, Loss: 1.0650280497290872, Train Accuracy: 0.4992857142857143, Val Loss: 1.0429004322398792, Val Accuracy: 0.4857142857142857, Train Precision: 0.2492862244897959, Train Recall: 0.4992857142

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▁▂▂▁▅▅▅▆▅▆▆▆▆██▄▂▃▃▂▁▂▂▁▂▂▂▂▂▁▂▂▂▁▃▂▁▂▁▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,██▇▇▇▆▆▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▁▁▂▃▄▄▅▆▆▆▆▆▆▆▆▇▇▇▇▇███████████████████
train_concept_accuracy,▃██▄▃▃▄▃▃▃▂▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂
train_concept_loss,█▇▇▆▆▅▅▄▄▄▃▃▄▃▂▃▂▁▂▂▂▃▃▂▃▃▂▂▂▂▂▂▂▁▁▂▂▂▁▃
train_f1,▁▁▁▂▄▅▄▆▆▆▆▇▇▇▇▇▇▇▇▇████████████████████

0,1
convergence_time,6.69263
epoch,41.0
epoch_time,0.29859
learning_rate,0.0005
loss,0.54955
total_training_time,14.2212
train_accuracy,0.98821
train_concept_accuracy,0.72571
train_concept_loss,0.52527
train_f1,0.98821


The following dataset has been loaded successfully: IsBinEven


-------------------------- Training IsBinEven using LLR2 ----------------------
Epoch 1, Loss: 1.3158762915567919, Train Accuracy: 0.485, Val Loss: 1.2837038473649458, Val Accuracy: 0.5314285714285715, Train Precision: 0.235225, Train Recall: 0.485, Train F1: 0.3168013468013468, Val Precision: 0.28241632653061227, Val Recall: 0.5314285714285715, Val F1: 0.36882729211087417, Epoch Time: 0.3786458969116211
Epoch 2, Loss: 1.249566601081328, Train Accuracy: 0.485, Val Loss: 1.1856052442030474, Val Accuracy: 0.5314285714285715, Train Precision: 0.235225, Train Recall: 0.485, Train F1: 0.3168013468013468, Val Precision: 0.28241632653061227, Val Recall: 0.5314285714285715, Val F1: 0.36882729211087417, Epoch Time: 0.35160207748413086
Epoch 3, Loss: 1.1206750517541713, Train Accuracy: 0.485, Val Loss: 1.0544093738902698, Val Accuracy: 0.5314285714285715, Train Precision: 0.235225, Train Recall: 0.485, Train F1: 0.3168013468013468, Val Precision: 0.28241632653061227, Val Recall: 0.53142857142857

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▂▁▂▄▁▂▁▁▂▂▂▂▁▂▅▅▆▆▅▅▆▆▇█▂▁▁▂▂▁▂▁▁▂▁▂▁▂▂▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▆▆▅▅▅▄▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▁▁▁▂▆▇▇▇███████████████████████████████
train_concept_accuracy,▁▁▃▆███▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
train_concept_loss,██▇▇▆▆▆▆▅▄▄▄▄▃▃▃▃▂▂▂▂▂▁▂▂▃▂▂▂▂▁▂▂▁▁▁▁▁▂▂
train_f1,▁▁▁▁▂▇▇▇▇███████████████████████████████

0,1
convergence_time,2.98083
epoch,41.0
epoch_time,0.35212
learning_rate,0.0005
loss,0.54251
total_training_time,16.47069
train_accuracy,0.99607
train_concept_accuracy,0.51321
train_concept_loss,0.5642
train_f1,0.99607


The following dataset has been loaded successfully: Trigonometry


-------------------------- Training Trigonometry using LLR2 ----------------------
Epoch 1, Loss: 1.1432519880208103, Train Accuracy: 0.39357142857142857, Val Loss: 1.100565498525446, Val Accuracy: 0.44571428571428573, Train Precision: 0.1548984693877551, Train Recall: 0.39357142857142857, Train F1: 0.2223043128066193, Val Precision: 0.7705826499784203, Val Recall: 0.44571428571428573, Val F1: 0.3288495221253842, Epoch Time: 0.3200874328613281
Epoch 2, Loss: 1.0641399947079746, Train Accuracy: 0.65, Val Loss: 1.0395771156657825, Val Accuracy: 0.6085714285714285, Train Precision: 0.6391533857226792, Train Recall: 0.65, Train F1: 0.6201635119523683, Val Precision: 0.37035918367346937, Val Recall: 0.6085714285714285, Val F1: 0.46048211113930476, Epoch Time: 0.36556053161621094
Epoch 3, Loss: 1.009549082680182, Train Accuracy: 0.6175, Val Loss: 0.9796562357382341, Val Accuracy: 0.74, Train Precision: 0.7654442036836404, Train Recall: 0.6175, Train F1: 0.48256933902457383, Val Precision: 0.

VBox(children=(Label(value='0.001 MB of 0.026 MB uploaded\r'), FloatProgress(value=0.052955256858773704, max=1…

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▁▃▁▁▃▁▂▃▂▂▃▁▃▃▂▂▂▃▁▂▁▅▆▇▇▅▅▅▇███▂▂▂▁▂▂▂▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▆▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▄▄▆▆▇▇▇▇▇██████████████████████████████
train_concept_accuracy,▁▃▅█▆▄▃▃▃▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▄▃▄▃▃▃▃▃▃▃▄
train_concept_loss,███▆▄▄▄▃▂▃▃▃▂▂▂▂▂▂▂▁▂▂▂▁▁▁▃▂▂▂▂▁▂▂▂▁▂▁▂▂
train_f1,▁▅▃▆▇▇▇█████████████████████████████████

0,1
convergence_time,2.02285
epoch,41.0
epoch_time,0.31748
learning_rate,0.0005
loss,0.53388
total_training_time,15.17304
train_accuracy,0.98786
train_concept_accuracy,0.69893
train_concept_loss,0.48181
train_f1,0.98786


The following dataset has been loaded successfully: Dot


-------------------------- Training Dot using LLR2 ----------------------
Epoch 1, Loss: 1.1257635978135196, Train Accuracy: 0.4853571428571429, Val Loss: 1.1196441650390625, Val Accuracy: 0.5071428571428571, Train Precision: 0.235571556122449, Train Recall: 0.4853571428571429, Train F1: 0.31719180434857286, Val Precision: 0.2571938775510204, Val Recall: 0.5071428571428571, Val F1: 0.3412999322951929, Epoch Time: 0.3057565689086914
Epoch 2, Loss: 1.0855651064352556, Train Accuracy: 0.4853571428571429, Val Loss: 1.067568919875405, Val Accuracy: 0.5071428571428571, Train Precision: 0.235571556122449, Train Recall: 0.4853571428571429, Train F1: 0.31719180434857286, Val Precision: 0.2571938775510204, Val Recall: 0.5071428571428571, Val F1: 0.3412999322951929, Epoch Time: 0.2976851463317871
Epoch 3, Loss: 1.0285326960411938, Train Accuracy: 0.48714285714285716, Val Loss: 1.0057718428698452, Val Accuracy: 0.5571428571428572, Train Precision: 0.750635829286992, Train Recall: 0.487142857142857

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded\r'), FloatProgress(value=0.12272921108742005, max=1.…

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▂▂▂▁▂▂▁▂▂▁▂▁▁▁▂▁▁▂▂▁▁▂▁▁▂▅▄▅▆▆▅▅▅█▇▇▂▂▁▂
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,██▇▆▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▁▁▄▅▆▆▆▆▆▇▇▇▇▇▇████████████████████████
train_concept_accuracy,▁▄▅▇▇██████████████▇▇███▇███████████████
train_concept_loss,▇█▆▆▄▃▃▃▂▃▃▂▂▃▃▃▂▂▃▃▂▁▂▂▂▃▂▂▃▂▂▂▂▃▂▂▁▂▂▂
train_f1,▁▁▁▅▆▆▆▆▇▇▇▇▇▇▇█████████████████████████

0,1
convergence_time,4.15127
epoch,41.0
epoch_time,0.29189
learning_rate,0.0005
loss,0.54708
total_training_time,13.61725
train_accuracy,0.98214
train_concept_accuracy,0.72179
train_concept_loss,0.5164
train_f1,0.98214


Training on LLR3 ... 
--------------------------------
The following dataset has been loaded successfully: XOR


-------------------------- Training XOR using LLR3 ----------------------
Epoch 1, Loss: 1.0685490315610713, Train Accuracy: 0.4975, Val Loss: 1.0641127608039163, Val Accuracy: 0.46285714285714286, Train Precision: 0.24750625, Train Recall: 0.4975, Train F1: 0.33055926544240405, Val Precision: 0.23517778245318746, Val Recall: 0.46285714285714286, Val F1: 0.31188616071428577, Epoch Time: 0.30971598625183105
Epoch 2, Loss: 1.059847512028434, Train Accuracy: 0.5396428571428571, Val Loss: 1.0550726110284978, Val Accuracy: 0.6514285714285715, Train Precision: 0.5538103672050847, Train Recall: 0.5396428571428571, Train F1: 0.5036006859500068, Val Precision: 0.6802329605906909, Val Recall: 0.6514285714285715, Val F1: 0.6348508634222919, Epoch Time: 0.3052806854248047
Epoch 3, Loss: 1.0492621876976707, Train Accuracy: 0.5785714285714286, Val Loss: 1.0430228926918723, Val Accuracy: 0.6771428571428572, Train Precision: 0.580036605398973, Train Recall: 0.5785714285714286, Train F1: 0.577105953190

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▁▁▂▅█▆▅▅▄▅▆▆▆▃▂▂▁▁▂▁▂▂▁▁▂▁▂▂▁▁▃▁▂▂▁▁▁▂▁▂
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,████▇▇▇▆▆▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▂▂▄▄▄▄▅▅▆▇▇▇███████████████████████████
train_concept_accuracy,▁▁▃▅▆▆▆▇████████████████████████████████
train_concept_loss,█▇▇▇▇▆▅▅▅▄▄▄▃▃▂▄▂▂▃▂▃▃▂▃▃▃▃▃▂▃▄▃▃▂▃▁▂▂▂▂
train_f1,▁▃▄▅▅▅▅▅▆▇▇▇▇███████████████████████████

0,1
convergence_time,4.47716
epoch,41.0
epoch_time,0.32494
learning_rate,0.0005
loss,0.53773
total_training_time,14.33458
train_accuracy,0.99179
train_concept_accuracy,0.74429
train_concept_loss,0.51043
train_f1,0.99179


The following dataset has been loaded successfully: XNOR


-------------------------- Training XNOR using LLR3 ----------------------
Epoch 1, Loss: 1.0689470036463304, Train Accuracy: 0.4967857142857143, Val Loss: 1.0532141382044011, Val Accuracy: 0.4957142857142857, Train Precision: 0.24679604591836735, Train Recall: 0.4967857142857143, Train F1: 0.32976804035859153, Val Precision: 0.2457326530612245, Val Recall: 0.4957142857142857, Val F1: 0.32858234411243004, Epoch Time: 0.29227423667907715
Epoch 2, Loss: 1.0581723641265521, Train Accuracy: 0.5432142857142858, Val Loss: 1.0441479357806118, Val Accuracy: 0.5485714285714286, Train Precision: 0.6058710689570004, Train Recall: 0.5432142857142858, Train F1: 0.468019348132482, Val Precision: 0.5592344418947153, Val Recall: 0.5485714285714286, Val F1: 0.531170581179384, Epoch Time: 0.32696056365966797
Epoch 3, Loss: 1.046123512766578, Train Accuracy: 0.595, Val Loss: 1.0312260822816328, Val Accuracy: 0.5214285714285715, Train Precision: 0.6161915611814345, Train Recall: 0.595, Train F1: 0.5771972

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▁▂▂▂▄▂▁▁▂▂▁▂▂▂▅▆▆█▇▅▇▇▇█▇▂▂▁▃▁▂▂▂▁▃▂▁▂▂▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,███▇▇▇▆▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▂▂▂▃▄▅▆▆▆▆▆▆▆▆▇▇▇▇▇████████████████████
train_concept_accuracy,▇▃▁▂▅▇▇█████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
train_concept_loss,███▇▇▆▅▄▄▄▄▄▄▃▃▂▃▃▃▂▂▃▂▃▂▃▂▂▂▂▂▂▁▁▂▃▁▁▂▃
train_f1,▁▂▄▃▄▅▆▆▆▆▆▇▇▇▇▇▇▇▇█████████████████████

0,1
convergence_time,4.79389
epoch,41.0
epoch_time,0.29582
learning_rate,0.0005
loss,0.55258
total_training_time,14.25981
train_accuracy,0.98321
train_concept_accuracy,0.76321
train_concept_loss,0.53019
train_f1,0.98321


The following dataset has been loaded successfully: IsBinEven


-------------------------- Training IsBinEven using LLR3 ----------------------
Epoch 1, Loss: 1.0666635578328914, Train Accuracy: 0.5082142857142857, Val Loss: 1.0678334561261265, Val Accuracy: 0.4957142857142857, Train Precision: 0.25828176020408167, Train Recall: 0.5082142857142857, Train F1: 0.3425000845708873, Val Precision: 0.2457326530612245, Val Recall: 0.4957142857142857, Val F1: 0.32858234411243004, Epoch Time: 0.3753206729888916
Epoch 2, Loss: 1.059000925584273, Train Accuracy: 0.5153571428571428, Val Loss: 1.0593418208035557, Val Accuracy: 0.5442857142857143, Train Precision: 0.7519256166495375, Train Recall: 0.5153571428571428, Train F1: 0.35821106761596405, Val Precision: 0.7625632775632776, Val Recall: 0.5442857142857143, Val F1: 0.42821911678267427, Epoch Time: 0.3666844367980957
Epoch 3, Loss: 1.0484899986873975, Train Accuracy: 0.6946428571428571, Val Loss: 1.0445225672288374, Val Accuracy: 0.8957142857142857, Train Precision: 0.7906442575705127, Train Recall: 0.69464

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▂▂▄▂▂▂▁▂▁▁▂▁▂▂▂▂▂▁▂▁▂▆▆▇▆▆▆███▂▁▂▁▂▂▂▂▂▂
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,███▇▇▅▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▁▄▆▆███████████████████████████████████
train_concept_accuracy,▂▁▃▆██▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
train_concept_loss,██▇▇▇▇▆▅▄▄▄▃▄▃▃▃▃▂▃▃▃▃▁▂▂▂▂▂▂▁▂▂▂▁▂▂▂▂▁▂
train_f1,▁▁▄▇▇███████████████████████████████████

0,1
convergence_time,1.94254
epoch,41.0
epoch_time,0.37025
learning_rate,0.0005
loss,0.5373
total_training_time,16.68652
train_accuracy,0.99571
train_concept_accuracy,0.52607
train_concept_loss,0.54762
train_f1,0.99571


The following dataset has been loaded successfully: Trigonometry


-------------------------- Training Trigonometry using LLR3 ----------------------
Epoch 1, Loss: 1.0619423443620855, Train Accuracy: 0.6110714285714286, Val Loss: 1.0362199436534534, Val Accuracy: 0.59, Train Precision: 0.3734082908163266, Train Recall: 0.6110714285714286, Train F1: 0.4635527440858853, Val Precision: 0.34809999999999997, Val Recall: 0.59, Val F1: 0.43786163522012583, Epoch Time: 0.40245580673217773
Epoch 2, Loss: 1.0120272460308941, Train Accuracy: 0.6110714285714286, Val Loss: 0.9791820862076499, Val Accuracy: 0.6, Train Precision: 0.3734082908163266, Train Recall: 0.6110714285714286, Train F1: 0.4635527440858853, Val Precision: 0.7616161616161616, Val Recall: 0.6, Val F1: 0.4601567209162145, Epoch Time: 0.40648365020751953
Epoch 3, Loss: 0.9346235963431272, Train Accuracy: 0.7828571428571428, Val Loss: 0.8663977763869546, Val Accuracy: 0.9285714285714286, Train Precision: 0.8377691969465634, Train Recall: 0.7828571428571428, Train F1: 0.757627606370479, Val Precisio

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,▂▃▁▃▂▁▁▁▂▁▁▁▁▁▃▁▁▂▁▁▂▂▂▂▂▂▅▆▅█▄▅▅▄▅▄▁▁▁▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▅▄▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▁▄▇▇▇▇▇▇▇▇▇▇███████████████████████████
train_concept_accuracy,▁▃▄▅▆▆▇▇████████████████████████████████
train_concept_loss,█▇▆▅▃▄▄▃▂▂▃▂▂▂▂▂▂▁▂▂▁▂▁▁▂▁▂▂▂▂▃▁▂▁▁▁▁▂▁▂
train_f1,▁▁▅▇▇▇█▇████████████████████████████████

0,1
convergence_time,1.15295
epoch,41.0
epoch_time,0.34941
learning_rate,0.0005
loss,0.53566
total_training_time,16.5289
train_accuracy,0.99071
train_concept_accuracy,0.63107
train_concept_loss,0.50298
train_f1,0.99071


The following dataset has been loaded successfully: Dot


-------------------------- Training Dot using LLR3 ----------------------
Epoch 1, Loss: 1.0569172203540802, Train Accuracy: 0.5489285714285714, Val Loss: 1.0513964999805798, Val Accuracy: 0.5371428571428571, Train Precision: 0.5516040222892902, Train Recall: 0.5489285714285714, Train F1: 0.5267915455626404, Val Precision: 0.5674869831222138, Val Recall: 0.5371428571428571, Val F1: 0.4945535714285714, Epoch Time: 0.4829239845275879
Epoch 2, Loss: 1.0260874046520754, Train Accuracy: 0.5942857142857143, Val Loss: 1.015939398245378, Val Accuracy: 0.5828571428571429, Train Precision: 0.6221100300199979, Train Recall: 0.5942857142857143, Train F1: 0.5601181953129132, Val Precision: 0.6088142857142859, Val Recall: 0.5828571428571429, Val F1: 0.5629373897031548, Epoch Time: 0.39142680168151855
Epoch 3, Loss: 0.98495975272222, Train Accuracy: 0.6467857142857143, Val Loss: 0.9698037721894004, Val Accuracy: 0.6328571428571429, Train Precision: 0.6692026027802425, Train Recall: 0.6467857142857143

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
convergence_time,▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch_time,█▄▄▁▃▂▁▃▂▂▁▂▁▁▂▂▁▂▂▁▂▂▁▂▄▁▂▂▁▁▂▂▇▅▇▇▇██▇
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,██▇▆▅▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_training_time,▁
train_accuracy,▁▂▃▃▄▅▆▇▇▇▇▇▇▇██████████████████████████
train_concept_accuracy,▁▆▇▆▅▅▅▆▇▇▇▇▇▇▇▇▇▇▇█▇█▇█▇▇▇██▇██████████
train_concept_loss,██▆▇▆▄▄▃▃▃▃▃▄▃▃▃▂▃▃▃▂▃▃▃▂▂▂▃▂▃▃▃▃▃▂▂▁▂▂▁
train_f1,▁▂▃▄▅▅▆▇▇▇▇▇▇▇██████████████████████████

0,1
convergence_time,3.24913
epoch,41.0
epoch_time,0.46754
learning_rate,0.0005
loss,0.5442
total_training_time,14.93943
train_accuracy,0.98464
train_concept_accuracy,0.71857
train_concept_loss,0.52849
train_f1,0.98464


*********** ALL TRAINING ARE DONE - Check WandB ***********
