In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from torchvision.models import resnet50, ResNet50_Weights
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import copy



In [2]:
def conv_block(in_f, out_f, activation='relu', *args, **kwargs):
    """Build one encoder stage: convolution, 2x2 max-pooling, then a non-linearity.

    Args:
        in_f (int): Number of input channels of the convolution.
        out_f (int): Number of output channels of the convolution.
        activation (str): Key selecting the non-linearity: 'relu' or 'lrelu'.
            Any other key raises ``KeyError``.
        *args, **kwargs: Forwarded unchanged to ``nn.Conv2d``
            (for example ``kernel_size`` and ``padding``).

    Returns:
        nn.Sequential: ``Conv2d -> MaxPool2d(2, 2) -> activation``.
    """
    # Both candidate activations are instantiated; only the selected one is used.
    nonlinearities = nn.ModuleDict({
        'lrelu': nn.LeakyReLU(),
        'relu': nn.ReLU(),
    })

    conv = nn.Conv2d(in_f, out_f, *args, **kwargs)
    pool = nn.MaxPool2d(2, 2)
    chosen = nonlinearities[activation]
    return nn.Sequential(conv, pool, chosen)

class Encoder(nn.Module):
    """Chain of ``conv_block`` stages built from consecutive channel sizes."""

    def __init__(self,encoder_sizes,*args,**kwargs):
        """
        Args:
            encoder_sizes (sequence of int): Channel sizes; every adjacent pair
                ``(encoder_sizes[i], encoder_sizes[i + 1])`` becomes one block.
            *args, **kwargs: Forwarded to each ``conv_block`` call
                (for example ``activation='lrelu'``).
        """
        super().__init__()
        stages = []
        for n_in, n_out in zip(encoder_sizes, encoder_sizes[1:]):
            # Every stage uses a 3x3 convolution with padding 1.
            stages.append(conv_block(n_in, n_out, *args, kernel_size=3, padding=1, **kwargs))
        # The attribute name (spelling included) appears in state_dict keys, so it is kept.
        self.encoced_blocks = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through all stages in order."""
        return self.encoced_blocks(x)

class Predictor(nn.Module):
    """Two-layer classification head applied to flattened encoder features."""

    def __init__(self,num_classes,connector_size):
        """
        Args:
            num_classes (int): Size of the output layer.
            connector_size (int): Channel count of the incoming feature map. The
                input width is ``connector_size * 4 * 4``.
                # assumes the encoder output is 4x4 spatially - TODO confirm
        """
        super().__init__()
        flat_features = connector_size * 4 * 4
        layers = [
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        ]
        self.predictor = nn.Sequential(*layers)

    def forward(self, x):
        """Map flattened features to class scores."""
        return self.predictor(x)
        
class CNN(nn.Module):
    """Convolutional encoder followed by the ``Predictor`` classification head."""

    def __init__(self, in_c,enc_sizes, num_classes=10,activation='relu'):
        """
        Args:
            in_c (int): Number of input image channels.
            enc_sizes (sequence of int): Output channels of each encoder stage.
            num_classes (int): Number of output classes.
            activation (str): Activation key forwarded to the encoder blocks.
        """
        super().__init__()

        # Prepend the input channels so adjacent pairs describe each stage.
        self.enc_sizes = [in_c] + list(enc_sizes)

        self.encoder = Encoder(self.enc_sizes, activation=activation)
        last_channels = self.enc_sizes[-1]
        self.predictor = Predictor(num_classes, last_channels)

    def forward(self, x):
        """Encode ``x``, flatten everything but the batch dimension, and classify."""
        features = self.encoder(x)
        flat = features.view(features.size(0), -1)
        return self.predictor(flat)


In [3]:
def conv_block(in_f, out_f, activation='relu', *args, **kwargs):
    """Build one encoder stage: convolution, 2x2 max-pooling, then a non-linearity.

    NOTE: an identical ``conv_block`` is defined in an earlier cell; running this
    cell re-binds the name to this definition.

    Args:
        in_f (int): Number of input channels of the convolution.
        out_f (int): Number of output channels of the convolution.
        activation (str): 'relu' or 'lrelu'; any other key raises ``KeyError``.
        *args, **kwargs: Forwarded unchanged to ``nn.Conv2d``.

    Returns:
        nn.Sequential: ``Conv2d -> MaxPool2d(2, 2) -> activation``.
    """
    available = nn.ModuleDict({
        'lrelu': nn.LeakyReLU(),
        'relu': nn.ReLU(),
    })

    layers = [nn.Conv2d(in_f, out_f, *args, **kwargs), nn.MaxPool2d(2, 2)]
    layers.append(available[activation])
    return nn.Sequential(*layers)

class Encoder(nn.Module):
    """Chain of ``conv_block`` stages built from consecutive channel sizes.

    NOTE: an identical ``Encoder`` is defined in an earlier cell; running this
    cell re-binds the name to this definition.
    """

    def __init__(self,encoder_sizes,*args,**kwargs):
        """
        Args:
            encoder_sizes (sequence of int): Channel sizes; adjacent pairs define
                the in/out channels of each block.
            *args, **kwargs: Forwarded to each ``conv_block`` call.
        """
        super().__init__()
        channel_pairs = zip(encoder_sizes, encoder_sizes[1:])
        blocks = [
            conv_block(c_in, c_out, *args, kernel_size=3, padding=1, **kwargs)
            for c_in, c_out in channel_pairs
        ]
        # The attribute name (spelling included) appears in state_dict keys, so it is kept.
        self.encoced_blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through all stages in order."""
        return self.encoced_blocks(x)

class Predictor2(nn.Module):
    """Single linear layer mapping flattened encoder features to class scores."""

    def __init__(self,num_classes,connector_size):
        """
        Args:
            num_classes (int): Size of the output layer.
            connector_size (int): Channel count of the incoming feature map. The
                input width is ``connector_size * 4 * 4``.
                # assumes the encoder output is 4x4 spatially - TODO confirm
        """
        super().__init__()
        flat_features = connector_size * 4 * 4
        linear = nn.Linear(flat_features, num_classes)
        # Wrapped in a Sequential so state_dict keys stay 'predictor.0.*'.
        self.predictor = nn.Sequential(linear)

    def forward(self, x):
        """Map flattened features to class scores."""
        return self.predictor(x)
        
class CNN2(nn.Module):
    """Convolutional encoder followed by the single-layer ``Predictor2`` head."""

    def __init__(self, in_c,enc_sizes, num_classes=10,activation='relu'):
        """
        Args:
            in_c (int): Number of input image channels.
            enc_sizes (sequence of int): Output channels of each encoder stage.
            num_classes (int): Number of output classes.
            activation (str): Activation key forwarded to the encoder blocks.
        """
        super().__init__()

        # Prepend the input channels so adjacent pairs describe each stage.
        self.enc_sizes = [in_c] + list(enc_sizes)

        self.encoder = Encoder(self.enc_sizes, activation=activation)
        last_channels = self.enc_sizes[-1]
        self.predictor = Predictor2(num_classes, last_channels)

    def forward(self, x):
        """Encode ``x``, flatten everything but the batch dimension, and classify."""
        features = self.encoder(x)
        flat = features.view(features.size(0), -1)
        return self.predictor(flat)


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from sklearn.metrics import precision_recall_fscore_support

def train_model_extra_metric(model, train_loader, val_loader, num_epochs=10, lr=0.001,lambda_reg=0.01, device='cpu', early_stopping_patience=5):
    """
    Trains a PyTorch model with validation and reports precision, recall, and F1-score.

    Args:
        model (nn.Module): The PyTorch model to train.
        train_loader (DataLoader): The training data loader.
        val_loader (DataLoader): The validation data loader. When falsy, the
            validation phase (and therefore checkpointing and early stopping) is skipped.
        num_epochs (int, optional): The number of epochs to train for. Defaults to 10.
        lr (float, optional): The learning rate. Defaults to 0.001.
        lambda_reg (float, optional): Currently unused; the explicit L2 penalty
            is disabled in the training loop. Defaults to 0.01.
        device (str, optional): The device to use for training ('cpu' or 'cuda'). Defaults to 'cpu'.
        early_stopping_patience (int, optional): Number of epochs without a
            validation-accuracy improvement before training stops. Defaults to 5.

    Returns:
        nn.Module: A deep copy of the model from the epoch with the highest
        validation accuracy. If no checkpoint was ever taken (no validation
        loader, or accuracy never rose above 0), the trained model itself is
        returned instead of ``None``.
    """
    # Initialize model and optimization
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Initialize tracking variables
    best_val = 0
    best_model = None
    patience_counter = 0
    # Collected per epoch for inspection; the function's return value stays model-only
    # so existing callers keep working.
    history = {
        'train_loss': [],
        'val_accuracy': [],
        'val_precision': [],
        'val_recall': [],
        'val_f1': []
    }

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_loss = 0
        num_batches = 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # An empty training loader would otherwise raise ZeroDivisionError.
        avg_train_loss = total_loss / num_batches if num_batches else float('nan')
        history['train_loss'].append(avg_train_loss)

        # Validation phase
        if val_loader:
            model.eval()
            correct = 0
            total = 0
            all_targets = []
            all_predictions = []

            with torch.no_grad():
                for inputs, targets in val_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = model(inputs)
                    predictions = outputs.argmax(dim=1)

                    correct += (predictions == targets).sum().item()
                    total += targets.size(0)

                    all_targets.extend(targets.cpu().numpy())
                    all_predictions.extend(predictions.cpu().numpy())

            # Calculate metrics
            val_accuracy = correct / total
            precision, recall, f1, _ = precision_recall_fscore_support(
                all_targets,
                all_predictions,
                average='weighted',
                zero_division=0
            )

            # Update history
            history['val_accuracy'].append(val_accuracy)
            history['val_precision'].append(precision)
            history['val_recall'].append(recall)
            history['val_f1'].append(f1)

            # Print progress
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Train Loss: {avg_train_loss:.4f}")
            print(f"Val Accuracy: {val_accuracy:.4f}")
            print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
            print("-" * 50)

            # Model checkpoint and early stopping (driven by validation accuracy)
            if val_accuracy > best_val:
                best_model = copy.deepcopy(model)
                best_val = val_accuracy
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping check
            if patience_counter >= early_stopping_patience:
                break

    # Never hand back None: fall back to the final weights when no checkpoint exists.
    return best_model if best_model is not None else model

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import gc
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import precision_recall_fscore_support

def train_model_memory_optimized(model, train_loader, val_loader, num_epochs=10, lr=0.001, 
                                lambda_reg=0.01, device='cpu', early_stopping_patience=5,
                                accumulation_steps=1, use_mixed_precision=True):
    """
    Trains a PyTorch model with validation, gradient accumulation and optional mixed precision.

    Args:
        model (nn.Module): The PyTorch model to train.
        train_loader (DataLoader): The training data loader.
        val_loader (DataLoader): The validation data loader. When falsy, validation,
            checkpointing and early stopping are skipped.
        num_epochs (int, optional): The number of epochs to train for. Defaults to 10.
        lr (float, optional): The learning rate. Defaults to 0.001.
        lambda_reg (float, optional): Currently unused by this function. Defaults to 0.01.
        device (str or torch.device, optional): The device to train on. Defaults to 'cpu'.
        early_stopping_patience (int, optional): Number of epochs without a
            validation-accuracy improvement before training stops. Defaults to 5.
        accumulation_steps (int, optional): Number of batches whose gradients are
            accumulated before each optimizer step. Defaults to 1.
        use_mixed_precision (bool, optional): Use autocast/GradScaler; only takes
            effect on CUDA devices. Defaults to True.

    Returns:
        tuple: ``(best_model_state_dict, history)``. The state dict is an independent
        snapshot (cloned tensors) of the weights at the best validation accuracy, or
        of the final weights when no checkpoint was taken. ``history`` maps metric
        names to per-epoch lists.
    """
    # Normalise so both 'cuda:0'-style strings and torch.device objects work.
    device = torch.device(device)
    use_amp = bool(use_mixed_precision) and device.type == 'cuda'

    # Initialize model and optimization
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Gradient scaler is only needed for mixed precision
    scaler = GradScaler() if use_amp else None

    def _snapshot():
        # state_dict() tensors alias the live parameters; clone to freeze the values.
        return {name: value.detach().clone() for name, value in model.state_dict().items()}

    def _apply_step():
        # One optimizer update followed by dropping the accumulated gradients.
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    # Initialize tracking variables
    best_val = 0
    best_model_state_dict = None
    patience_counter = 0
    history = {
        'train_loss': [],
        'val_accuracy': [],
        'val_precision': [],
        'val_recall': [],
        'val_f1': []
    }

    # Start from clean gradients so stale ones do not leak into the first update.
    optimizer.zero_grad(set_to_none=True)

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_loss = 0
        num_batches = 0
        pending_grads = False

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            if use_amp:
                with autocast():
                    outputs = model(inputs)
                    # Scale loss so accumulated gradients average over the window
                    loss = criterion(outputs, targets) / accumulation_steps
                scaler.scale(loss).backward()
            else:
                outputs = model(inputs)
                loss = criterion(outputs, targets) / accumulation_steps
                loss.backward()
            pending_grads = True

            # Step once per accumulation window
            if (batch_idx + 1) % accumulation_steps == 0:
                _apply_step()
                pending_grads = False

            # Undo the accumulation scaling for reporting
            total_loss += loss.item() * accumulation_steps
            num_batches += 1

            # Explicitly drop references to intermediate tensors
            del inputs, outputs, targets, loss

        # Flush a trailing, incomplete window so its gradients are neither lost
        # nor carried into the next epoch.
        if pending_grads:
            _apply_step()

        avg_train_loss = total_loss / num_batches if num_batches else float('nan')
        history['train_loss'].append(avg_train_loss)

        # Validation phase
        if val_loader:
            model.eval()
            correct = 0
            total = 0

            all_targets = torch.tensor([], dtype=torch.long, device='cpu')
            all_predictions = torch.tensor([], dtype=torch.long, device='cpu')

            with torch.no_grad():
                for inputs, targets in val_loader:
                    inputs, targets = inputs.to(device), targets.to(device)

                    # Validation runs without autocast
                    outputs = model(inputs)
                    predictions = outputs.argmax(dim=1)

                    correct += (predictions == targets).sum().item()
                    total += targets.size(0)

                    # Move to CPU before concatenating
                    all_targets = torch.cat([all_targets, targets.cpu()])
                    all_predictions = torch.cat([all_predictions, predictions.cpu()])

                    del inputs, outputs, targets, predictions

            # Calculate metrics (convert to numpy only once)
            val_accuracy = correct / total
            all_targets_np = all_targets.numpy()
            all_predictions_np = all_predictions.numpy()
            precision, recall, f1, _ = precision_recall_fscore_support(
                all_targets_np,
                all_predictions_np,
                average='weighted',
                zero_division=0
            )

            # Update history
            history['val_accuracy'].append(val_accuracy)
            history['val_precision'].append(precision)
            history['val_recall'].append(recall)
            history['val_f1'].append(f1)

            # Print progress
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Train Loss: {avg_train_loss:.4f}")
            print(f"Val Accuracy: {val_accuracy:.4f}")
            print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
            print("-" * 50)

            del all_targets, all_predictions, all_targets_np, all_predictions_np

            # Model checkpoint (state_dict snapshot instead of the whole model)
            if val_accuracy > best_val:
                best_model_state_dict = _snapshot()
                best_val = val_accuracy
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping check
            if patience_counter >= early_stopping_patience:
                print(f"Early stopping triggered after epoch {epoch+1}")
                break

    # Without any checkpoint, return the final weights rather than None so the
    # result can always be passed to load_state_dict.
    if best_model_state_dict is None:
        best_model_state_dict = _snapshot()

    # Free memory
    torch.cuda.empty_cache()
    gc.collect()

    return best_model_state_dict, history

# Usage example:
def use_optimized_training(model, train_loader, val_loader):
    """Usage example: train with ``train_model_memory_optimized`` and load the best weights.

    Args:
        model (nn.Module): Model to train; it is updated in place.
        train_loader (DataLoader): Training data loader.
        val_loader (DataLoader): Validation data loader.

    Returns:
        tuple: ``(model, history)`` where ``model`` carries the best checkpointed
        weights (when a checkpoint exists) and ``history`` is the metric history
        returned by the training function.
    """
    # Passed as a plain string: the training function inspects the device name
    # with string methods, which a torch.device object does not provide.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Train with memory optimizations
    best_model_state_dict, history = train_model_memory_optimized(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=10,
        lr=0.001,
        device=device,
        early_stopping_patience=5,
        accumulation_steps=4,  # Accumulate gradients over 4 batches
        use_mixed_precision=True  # Use mixed precision if on CUDA
    )

    # Load the best weights; keep the final ones if no checkpoint was produced.
    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)

    # Clean up memory before inference or further use
    torch.cuda.empty_cache()
    gc.collect()

    return model, history

In [6]:
def train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001, device='cpu'):
    """Minimal training loop that keeps the model with the best validation accuracy.

    Args:
        model (nn.Module): Model to train (moved to ``device``).
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        num_epochs (int): Number of passes over ``train_loader``.
        lr (float): Adam learning rate.
        device (str): Device the batches and model are moved to.

    Returns:
        nn.Module: Deep copy of the model at the best validation accuracy, or the
        model itself if accuracy never exceeded 0.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    top_accuracy = 0
    top_model = model

    for epoch_idx in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_x), batch_y)
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()

        # ---- validation pass ----
        model.eval()
        n_correct, n_seen = 0, 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                guesses = model(batch_x).argmax(dim=1)
                n_correct += (guesses == batch_y).sum().item()
                n_seen += batch_y.size(0)
        val_accuracy = n_correct / n_seen

        # Snapshot only on a strict improvement.
        if val_accuracy > top_accuracy:
            top_model = copy.deepcopy(model)
            top_accuracy = val_accuracy
        # The reported loss is the sum over batches, not the mean.
        print(f"Epoch {epoch_idx+1}/{num_epochs}, Loss: {running_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    return top_model


In [7]:
def test_model(model,test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            predictions = outputs.argmax(dim=1)
            correct += (predictions == targets).sum().item()
            total += targets.size(0)
    test_accuracy = correct / total
    print('test accuracy',test_accuracy)

In [8]:
def remove_keys(original_dict, keys_to_remove):
    """Return a new plain dict holding every entry whose key is not in ``keys_to_remove``.

    The input mapping is left untouched; values are shared, not copied.
    """
    kept = {}
    for name, entry in original_dict.items():
        if name in keys_to_remove:
            continue
        kept[name] = entry
    return kept


from collections import OrderedDict

def state_dict_to_vector(state_dict, remove_keys=()):
    """Flatten a state dict into a single 1-D tensor, in sorted-key order.

    Args:
        state_dict (Mapping[str, torch.Tensor]): Name-to-tensor mapping.
        remove_keys (iterable of str): Names to leave out of the vector.
            (This parameter shadows the module-level ``remove_keys`` function
            inside this body only.)

    Returns:
        torch.Tensor: Concatenation of every remaining tensor, each flattened,
        ordered by key name. The result is a new tensor; ``state_dict`` is not
        modified and no longer deep-copied, since concatenation already copies.
    """
    excluded = set(remove_keys)
    # Sorting by name fixes a layout that does not depend on insertion order.
    ordered_names = sorted(name for name in state_dict if name not in excluded)
    flat_parts = [state_dict[name].reshape(-1) for name in ordered_names]
    return torch.nn.utils.parameters_to_vector(flat_parts)





def vector_to_state_dict(vector, state_dict, remove_keys=[]):
    """Unflatten ``vector`` into a state dict shaped like ``state_dict``.

    Args:
        vector (torch.Tensor): 1-D tensor laid out in sorted-key order.
        state_dict (Mapping[str, torch.Tensor]): Template supplying names and
            shapes; it is deep-copied, so the caller's tensors are not modified.
        remove_keys (list of str): Names absent from ``vector``.

    Returns:
        OrderedDict: Sorted name-to-tensor mapping filled from ``vector``. When
        the key "transformer.shared.weight" is present, every name in
        ``remove_keys`` is added back pointing at that same tensor.
    """
    # Work on a private copy that defines the layout of the vector.
    template = copy.deepcopy(state_dict)
    for dropped in remove_keys:
        template.pop(dropped, None)
    ordered = OrderedDict(sorted(template.items()))

    # Write consecutive slices of the vector into the copied tensors.
    torch.nn.utils.vector_to_parameters(vector, ordered.values())

    # Re-attach the excluded entries as aliases of the shared embedding.
    shared_name = "transformer.shared.weight"
    if shared_name in ordered:
        shared_tensor = ordered[shared_name]
        for dropped in remove_keys:
            ordered[dropped] = shared_tensor

    return ordered

import torch

def count_sign_conflicts(tensor: torch.Tensor) -> int:
    """Count the columns that contain both a positive and a negative entry.

    Zeros are ignored. The computation runs on whatever device ``tensor`` already
    lives on, so the function does not rely on a notebook-level ``device`` variable.

    Args:
        tensor (torch.Tensor): Values whose signs are compared along dim 0
            (for a 2-D input: one row per task vector, one column per parameter).

    Returns:
        int: Number of positions along the remaining dimensions with a sign conflict.
    """
    signs = torch.sign(tensor)

    # A position conflicts when some row is negative there and another is positive.
    has_negative = (signs == -1).any(dim=0)
    has_positive = (signs == 1).any(dim=0)

    return int((has_negative & has_positive).sum().item())

def calc_sign_conflicts(model_list,dummy_model,exclude_keys,model_class,number_of_tasks,is_layer=False, is_transformer=False):
    """Fraction of parameters whose task vectors disagree in sign.

    Args:
        model_list: Indexable collection; entry ``x`` is flattened directly when
            ``is_layer`` is True, otherwise after dropping ``exclude_keys``.
            # presumably state dicts (name -> tensor); verify against caller
        dummy_model (nn.Module): Reference model whose weights are subtracted.
        exclude_keys: Names removed from the reference (and, unless ``is_layer``,
            from each entry of ``model_list``).
        model_class: Not used by this function.
        number_of_tasks (int): How many leading entries of ``model_list`` to use.
        is_layer (bool): See ``model_list``.
        is_transformer (bool): Not used by this function.

    Returns:
        tuple: ``(conflict_fraction, n_params)`` where ``n_params`` is the length
        of each flattened task vector.
    """
    reference_vector = state_dict_to_vector(remove_keys(dummy_model.state_dict(), exclude_keys))

    task_rows = []
    for task_idx in range(number_of_tasks):
        entry = model_list[task_idx]
        if is_layer==True:
            task_rows.append(state_dict_to_vector(entry))
        else:
            task_rows.append(state_dict_to_vector(remove_keys(entry, exclude_keys)))
    finetuned = torch.vstack(task_rows)

    # Compare on the CPU regardless of where the weights live.
    reference_vector = reference_vector.to('cpu')
    finetuned = finetuned.to('cpu')
    task_vectors = finetuned - reference_vector

    n_params = task_vectors.shape[1]
    conflict_fraction = count_sign_conflicts(task_vectors) / n_params

    gc.collect()
    torch.cuda.empty_cache()

    return conflict_fraction, n_params


In [9]:

def distance_of_task_vectors(task_vectors):
    """Squared Euclidean distance between the first two task vectors.

    Only ``task_vectors[0]`` and ``task_vectors[1]`` are read.
    """
    delta = task_vectors[0] - task_vectors[1]
    return torch.sum(delta ** 2)
def calc_distance(model_list,dummy_model,exclude_keys,model_class,number_of_tasks,is_layer=False, is_transformer=False):
    """Squared Euclidean distance between the first two task vectors.

    Args:
        model_list: Indexable collection; entry ``x`` is flattened directly when
            ``is_layer`` is True, otherwise after dropping ``exclude_keys``.
            # presumably state dicts (name -> tensor); verify against caller
        dummy_model (nn.Module): Reference model whose weights are subtracted.
        exclude_keys: Names removed before flattening.
        model_class: Not used by this function.
        number_of_tasks (int): How many leading entries of ``model_list`` to use.
        is_layer (bool): See ``model_list``.
        is_transformer (bool): Not used by this function.

    Returns:
        tuple: ``(distance, 3)``; the second element is the constant 3.
        # NOTE(review): meaning of the constant is not visible here - confirm
    """
    reference_vector = state_dict_to_vector(remove_keys(dummy_model.state_dict(), exclude_keys))

    task_rows = []
    for task_idx in range(number_of_tasks):
        entry = model_list[task_idx]
        if is_layer==True:
            task_rows.append(state_dict_to_vector(entry))
        else:
            task_rows.append(state_dict_to_vector(remove_keys(entry, exclude_keys)))
    finetuned = torch.vstack(task_rows)

    # Compare on the CPU regardless of where the weights live.
    reference_vector = reference_vector.to('cpu')
    finetuned = finetuned.to('cpu')
    task_vectors = finetuned - reference_vector

    squared_distance = distance_of_task_vectors(task_vectors)

    gc.collect()
    torch.cuda.empty_cache()

    return squared_distance, 3


In [10]:
import torch
import torch.nn.functional as F

def cosine_similarity_of_task_vectors(task_vectors):
    """Cosine similarity between the first two task vectors.

    ``F.cosine_similarity`` already normalises its inputs and guards the
    denominator with a small epsilon, so no manual normalisation is done here.
    Dividing by the norm beforehand turned an all-zero task vector into NaNs.

    Args:
        task_vectors (torch.Tensor): Only ``task_vectors[0]`` and
            ``task_vectors[1]`` are read; similarity is taken over the last dim.

    Returns:
        torch.Tensor: The cosine similarity (0 when either vector is all zeros).
    """
    return F.cosine_similarity(task_vectors[0], task_vectors[1], dim=-1)

def calc_similarity(model_list, dummy_model, exclude_keys, model_class, number_of_tasks,is_layer=False, is_transformer=False):
    """Cosine similarity between the first two task vectors.

    Args:
        model_list: Indexable collection; entry ``x`` is flattened directly when
            ``is_layer`` is True, otherwise after dropping ``exclude_keys``.
            # presumably state dicts (name -> tensor); verify against caller
        dummy_model (nn.Module): Reference model whose weights are subtracted.
        exclude_keys: Names removed before flattening.
        model_class: Not used by this function.
        number_of_tasks (int): How many leading entries of ``model_list`` to use.
        is_layer (bool): See ``model_list``.
        is_transformer (bool): Not used by this function.

    Returns:
        tuple: ``(similarity, 3)``; the second element is the constant 3.
        # NOTE(review): meaning of the constant is not visible here - confirm
    """
    dm_new = remove_keys(dummy_model.state_dict(), exclude_keys)
    flat_ptm = state_dict_to_vector(dm_new)
    if is_layer==True:
        flat_ft = torch.vstack([state_dict_to_vector(model_list[x]) for x in range(number_of_tasks)])
    else:
        flat_ft = torch.vstack([state_dict_to_vector(remove_keys(model_list[x], exclude_keys)) for x in range(number_of_tasks)])

    # Move BOTH operands to the CPU, as the other calc_* helpers do. Moving only the
    # fine-tuned weights made the subtraction fail when the reference model was on a GPU.
    flat_ptm = flat_ptm.to('cpu')
    flat_ft = flat_ft.to('cpu')
    tv_flat_checks = flat_ft - flat_ptm

    # Cosine similarity instead of Euclidean distance
    res = cosine_similarity_of_task_vectors(tv_flat_checks)

    return res,3


In [11]:

def distance_of_task_vectors_euc(task_vectors):
    """Euclidean (L2) distance between the first two task vectors.

    Only ``task_vectors[0]`` and ``task_vectors[1]`` are read.
    """
    delta = task_vectors[0] - task_vectors[1]
    sum_of_squares = torch.sum(delta ** 2)
    return torch.sqrt(sum_of_squares)
def calc_distance_euc(model_list,dummy_model,exclude_keys,model_class,number_of_tasks,is_layer=False, is_transformer=False):
    """Euclidean distance between the first two task vectors.

    Args:
        model_list: Indexable collection; entry ``x`` is flattened directly when
            ``is_layer`` is True, otherwise after dropping ``exclude_keys``.
            # presumably state dicts (name -> tensor); verify against caller
        dummy_model (nn.Module): Reference model whose weights are subtracted.
        exclude_keys: Names removed before flattening.
        model_class: Not used by this function.
        number_of_tasks (int): How many leading entries of ``model_list`` to use.
        is_layer (bool): See ``model_list``.
        is_transformer (bool): Not used by this function.

    Returns:
        tuple: ``(distance, 3)``; the second element is the constant 3.
        # NOTE(review): meaning of the constant is not visible here - confirm
    """
    reference_vector = state_dict_to_vector(remove_keys(dummy_model.state_dict(), exclude_keys))

    task_rows = []
    for task_idx in range(number_of_tasks):
        entry = model_list[task_idx]
        if is_layer==True:
            task_rows.append(state_dict_to_vector(entry))
        else:
            task_rows.append(state_dict_to_vector(remove_keys(entry, exclude_keys)))
    finetuned = torch.vstack(task_rows)

    # Compare on the CPU regardless of where the weights live.
    reference_vector = reference_vector.to('cpu')
    finetuned = finetuned.to('cpu')
    task_vectors = finetuned - reference_vector

    euclidean = distance_of_task_vectors_euc(task_vectors)

    return euclidean, 3


In [12]:




def calc_simdis(model_list,dummy_model,exclude_keys,model_class,number_of_tasks,is_layer=False, is_transformer=False):
    """Sign-conflict fraction plus Euclidean distance of the task vectors.

    Args:
        model_list: Indexable collection. With ``is_layer`` True each entry is
            flattened directly; otherwise ``.state_dict()`` is called on each
            entry (so entries must be modules here) and ``exclude_keys`` are dropped.
        dummy_model (nn.Module): Reference model whose weights are subtracted.
        exclude_keys: Names removed before flattening.
        model_class: Not used by this function.
        number_of_tasks (int): How many leading entries of ``model_list`` to use.
        is_layer (bool): See ``model_list``.
        is_transformer (bool): Not used by this function.

    Returns:
        tuple: ``(score, n_params)`` where ``score`` is the fraction of
        sign-conflicting parameters plus the Euclidean distance between the first
        two task vectors, and ``n_params`` is the task-vector length.
    """
    dm_new = remove_keys(dummy_model.state_dict(), exclude_keys)
    flat_ptm = state_dict_to_vector(dm_new)
    if is_layer==True:
        flat_ft = torch.vstack([state_dict_to_vector(model_list[x]) for x in range(number_of_tasks)])
    else:
        flat_ft = torch.vstack([state_dict_to_vector(remove_keys(model_list[x].state_dict(), exclude_keys)) for x in range(number_of_tasks)])

    # Move BOTH operands to the CPU, as the other calc_* helpers do. Moving only the
    # fine-tuned weights made the subtraction fail when the reference model was on a GPU.
    flat_ptm = flat_ptm.to('cpu')
    flat_ft = flat_ft.to('cpu')

    tv_flat_checks = flat_ft - flat_ptm
    n_params = tv_flat_checks.shape[1]
    conflict_fraction = count_sign_conflicts(tv_flat_checks) / n_params
    return conflict_fraction + distance_of_task_vectors_euc(tv_flat_checks), n_params

# Datasets


## CIFAR-100

In [13]:


# 1. Load CIFAR-100 Dataset
def load_cifar100(batch_size=64):
    """Download CIFAR-100 and build train / validation / test loaders.

    The official training set is split 80/20 into training and validation parts
    with an unseeded ``random_split``, so the split differs between runs.

    Args:
        batch_size (int): Batch size used by all three loaders.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, num_classes, class_names)``.
    """
    # Normalisation constants are the CIFAR-10 channel statistics.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    def _cifar100(is_train):
        return torchvision.datasets.CIFAR100(
            root='./data', train=is_train, download=True, transform=transform
        )

    train_dataset = _cifar100(True)
    test_dataset = _cifar100(False)

    # 80% of the official training set for training, the remainder for validation.
    n_total = len(train_dataset)
    train_size = int(0.8 * n_total)
    val_size = n_total - train_size
    train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    # Only the training loader shuffles.
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    class_names = train_dataset.classes
    return train_loader, val_loader, test_loader, len(class_names), class_names

import torch
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
import numpy as np
from collections import defaultdict

class BinaryCIFAR100(Dataset):
    """One-vs-rest view of a labelled dataset with undersampled negatives."""

    def __init__(self, dataset, target_label, undersample_ratio=1.0):
        """
        Args:
            dataset: Original CIFAR100 dataset (or a Subset of it); items are
                ``(image, label)`` pairs.
            target_label: The positive class label
            undersample_ratio: Ratio of negative to positive samples (e.g., 1.0 means balanced)
        """
        self.dataset = dataset
        self.target_label = target_label

        # Separate positive and negative indices
        labels = self._read_labels(dataset)
        self.positive_indices = [idx for idx, label in enumerate(labels) if label == target_label]
        self.negative_indices = [idx for idx, label in enumerate(labels) if label != target_label]

        # Undersample negative class
        num_positive = len(self.positive_indices)
        num_negative_keep = int(num_positive * undersample_ratio)

        # Randomly select negative samples (global NumPy RNG, without replacement)
        if len(self.negative_indices) > num_negative_keep:
            self.negative_indices = np.random.choice(
                self.negative_indices,
                size=num_negative_keep,
                replace=False
            ).tolist()

        # Positives first, then the kept negatives
        self.used_indices = self.positive_indices + self.negative_indices

        print(f"Task {target_label}: Positive samples: {len(self.positive_indices)}, "
              f"Negative samples: {len(self.negative_indices)}")

    @staticmethod
    def _read_labels(dataset):
        """Return the label of every item of ``dataset``, in index order.

        Labels are read from a ``targets`` attribute when one is reachable through
        ``Subset``-style wrappers (objects exposing ``dataset`` and ``indices``)
        and no ``target_transform`` is set. This avoids loading and transforming
        every image only to look at its label. Otherwise each item is fetched.
        """
        base = dataset
        index_map = None
        # Peel nested wrappers, composing their index maps along the way.
        while hasattr(base, 'dataset') and hasattr(base, 'indices'):
            layer = list(base.indices)
            index_map = layer if index_map is None else [layer[i] for i in index_map]
            base = base.dataset

        targets = getattr(base, 'targets', None)
        if targets is not None and getattr(base, 'target_transform', None) is None:
            picked = targets if index_map is None else [targets[i] for i in index_map]
            # Unwrap 0-d tensors / NumPy scalars so comparisons yield plain bools.
            return [t.item() if hasattr(t, 'item') else t for t in picked]

        # Fallback: the labels are only available by fetching each item.
        return [dataset[idx][1] for idx in range(len(dataset))]

    def __len__(self):
        return len(self.used_indices)

    def __getitem__(self, idx):
        """Return ``(image, 1)`` for the target class and ``(image, 0)`` otherwise."""
        original_idx = self.used_indices[idx]
        img, label = self.dataset[original_idx]
        binary_label = 1 if label == self.target_label else 0
        return img, binary_label

def load_cifar100_binary(batch_size=64, undersample_ratio=1.0):
    """
    Build one-vs-rest binary tasks from the CIFAR100 training set, with undersampling.

    The training set is split 80/20 into train/validation parts using a generator
    seeded with 42, and one ``BinaryCIFAR100`` pair is created per label 0..99.

    Args:
        batch_size: Batch size for DataLoader
        undersample_ratio: Ratio of negative to positive samples (1.0 means balanced)

    Returns:
        list of dictionaries ``{"train": loader, "val": loader}``, one per label
    """
    # Same normalisation constants as the multi-class loader.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_dataset = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=transform
    )

    # Reproducible 80/20 split of the training set
    n_total = len(train_dataset)
    train_size = int(0.8 * n_total)
    val_size = n_total - train_size
    split_rng = torch.Generator().manual_seed(42)
    train_subset, val_subset = torch.utils.data.random_split(
        train_dataset, [train_size, val_size], generator=split_rng
    )

    def _loader(task_dataset, shuffle):
        return DataLoader(
            task_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=1,
            pin_memory=False
        )

    task_loaders = []
    for label in range(100):
        # Datasets first (train, then val), then their loaders.
        train_task_dataset = BinaryCIFAR100(train_subset, label, undersample_ratio=undersample_ratio)
        val_task_dataset = BinaryCIFAR100(val_subset, label, undersample_ratio=undersample_ratio)

        task_loaders.append({
            "train": _loader(train_task_dataset, True),
            "val": _loader(val_task_dataset, False),
        })

    return task_loaders

## Mammals

In [14]:
class AddGaussianNoise(object):
    """Transform that adds Gaussian noise ``N(mean, std^2)`` to a tensor."""

    def __init__(self, mean=0.0, std=1.0):
        """
        Args:
            mean (float): Mean of the added noise.
            std (float): Standard deviation of the added noise.
        """
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Return ``tensor`` plus freshly sampled noise of the same shape."""
        # Sample on the tensor's own device so the addition also works off-CPU.
        noise = torch.randn(tensor.size(), device=tensor.device) * self.std + self.mean
        return tensor + noise

In [15]:
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms, datasets
import numpy as np
from collections import defaultdict
import os
from sklearn.model_selection import train_test_split

class BinaryAnimals(Dataset):
    """One-vs-rest view of a labelled dataset with undersampled negatives.

    Samples whose label equals ``target_label`` get binary label 1; a random
    subset of all other samples gets binary label 0.
    """

    def __init__(self, dataset, target_label, undersample_ratio=1.0, seed=None):
        """
        Args:
            dataset: Source dataset exposing ``targets``, or a
                ``torch.utils.data.Subset`` of such a dataset.
            target_label: The positive class label.
            undersample_ratio: Ratio of negative to positive samples
                (e.g., 1.0 means balanced).
            seed: Optional seed for the negative-sample selection. ``None``
                (the default) draws a different subset on every construction.

        Raises:
            ValueError: If ``undersample_ratio`` is negative.
        """
        if undersample_ratio < 0:
            raise ValueError(
                f"undersample_ratio must be non-negative, got {undersample_ratio}"
            )

        self.dataset = dataset
        self.target_label = target_label

        # Collect the labels in the same index space that ``dataset[i]`` uses:
        # for a Subset that is the position inside the subset, not the index
        # into the wrapped dataset.
        if isinstance(dataset, torch.utils.data.Subset):
            labels = torch.tensor([dataset.dataset.targets[i] for i in dataset.indices])
        else:
            labels = torch.tensor(dataset.targets)

        self.positive_indices = torch.where(labels == target_label)[0].tolist()
        self.negative_indices = torch.where(labels != target_label)[0].tolist()

        # Undersample the negative class down to ratio * (number of positives).
        num_negative_keep = int(len(self.positive_indices) * undersample_ratio)
        if len(self.negative_indices) > num_negative_keep:
            rng = np.random.default_rng(seed)
            self.negative_indices = rng.choice(
                self.negative_indices,
                size=num_negative_keep,
                replace=False
            ).tolist()

        # Positives come first, so the first len(positive_indices) labels are 1.
        # Labels are int64 (torch.long).
        self.used_indices = self.positive_indices + self.negative_indices
        self.binary_labels = torch.zeros(len(self.used_indices), dtype=torch.long)
        self.binary_labels[:len(self.positive_indices)] = 1

        print(f"Task {target_label}: Positive samples: {len(self.positive_indices)}, "
              f"Negative samples: {len(self.negative_indices)}")

    def __len__(self):
        """Number of samples kept (positives plus sampled negatives)."""
        return len(self.used_indices)

    def __getitem__(self, idx):
        """Return ``(image, binary_label)`` for position ``idx``."""
        # Subset and plain datasets are indexed the same way here, so no
        # type dispatch is needed; the original class label is discarded.
        img, _ = self.dataset[self.used_indices[idx]]
        return img, self.binary_labels[idx]

def load_animals_binary(data_path, batch_size=64, undersample_ratio=1.0, image_size=32, valid=False,
                        num_workers=4):
    """
    Load an ImageFolder dataset as binary classification tasks with undersampling.

    One task is built per class: that class is the positive label and a random
    subset of all other classes forms the negatives (see ``BinaryAnimals``).

    Args:
        data_path: Path to the dataset. With ``valid=False`` this is the
            ImageFolder root itself; with ``valid=True`` it must contain
            ``train`` and ``valid`` sub-directories, each an ImageFolder root.
        batch_size: Batch size for DataLoader
        undersample_ratio: Ratio of negative to positive samples (1.0 means balanced)
        image_size: Size to resize images to
        valid: If False, split the single folder 80/20 (stratified, fixed
            random_state) into train/val. If True, use the existing
            ``train``/``valid`` folders.
        num_workers: Worker processes per DataLoader (0 loads in the main process).

    Returns:
        list of dictionaries ``{"train": DataLoader, "val": DataLoader}``, one
        per class index in ImageFolder order
    """
    # NOTE(review): the same pipeline, including the random noise, is applied to
    # the validation images as well - confirm that this is intended.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        AddGaussianNoise(mean=0.0, std=0.1),
    ])

    if not valid:
        full_dataset = datasets.ImageFolder(root=data_path, transform=transform)
        num_classes = len(full_dataset.classes)

        # Stratified split so every class keeps the same 80/20 proportion.
        train_indices, val_indices = train_test_split(
            list(range(len(full_dataset))),
            train_size=0.8,
            stratify=full_dataset.targets,
            random_state=42
        )
        train_subset = torch.utils.data.Subset(full_dataset, train_indices)
        val_subset = torch.utils.data.Subset(full_dataset, val_indices)
    else:
        train_subset = datasets.ImageFolder(root=os.path.join(data_path, 'train'), transform=transform)
        val_subset = datasets.ImageFolder(root=os.path.join(data_path, 'valid'), transform=transform)
        num_classes = len(train_subset.classes)

    # Workers are deliberately NOT persistent: this function returns
    # 2 * num_classes loaders that all stay referenced, and persistent workers
    # would keep num_workers processes alive per loader after its first use.
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        # Pinning only helps host-to-GPU copies.
        'pin_memory': torch.cuda.is_available(),
    }
    if num_workers > 0:
        # prefetch_factor may only be given when worker processes are used.
        loader_kwargs['prefetch_factor'] = 2

    task_loaders = []

    for label in range(num_classes):
        # Binary one-vs-rest datasets for this class.
        train_task_dataset = BinaryAnimals(
            train_subset,
            label,
            undersample_ratio=undersample_ratio
        )
        val_task_dataset = BinaryAnimals(
            val_subset,
            label,
            undersample_ratio=undersample_ratio
        )

        train_loader = DataLoader(
            train_task_dataset,
            shuffle=True,
            **loader_kwargs
        )
        val_loader = DataLoader(
            val_task_dataset,
            shuffle=False,
            **loader_kwargs
        )
        task_loaders.append({"train": train_loader, "val": val_loader})

    return task_loaders

In [16]:
# Build the per-class binary CIFAR-100 train/val loaders with default settings.
cifar100_tasks = load_cifar100_binary()

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:14<00:00, 11.7MB/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Task 0: Positive samples: 383, Negative samples: 383
Task 0: Positive samples: 117, Negative samples: 117
Task 1: Positive samples: 390, Negative samples: 390
Task 1: Positive samples: 110, Negative samples: 110
Task 2: Positive samples: 391, Negative samples: 391
Task 2: Positive samples: 109, Negative samples: 109
Task 3: Positive samples: 407, Negative samples: 407
Task 3: Positive samples: 93, Negative samples: 93
Task 4: Positive samples: 404, Negative samples: 404
Task 4: Positive samples: 96, Negative samples: 96
Task 5: Positive samples: 402, Negative samples: 402
Task 5: Positive samples: 98, Negative samples: 98
Task 6: Positive samples: 390, Negative samples: 390
Task 6: Positive samples: 110, Negative samples: 110
Task 7: Positive samples: 409, Negative samples: 409
Task 7: Positive samples: 91, Negative samples: 91
Task 8: Positive samples: 420, Negative samples: 420
Task 8: Positive samples: 80, Negative samples: 80
Task

In [17]:
# NOTE(review): despite the name, this reads the *cards* dataset directory (the
# same path used for `cards_tasks`) - confirm the intended mammals path.
mammal_tasks = load_animals_binary(
    '/kaggle/input/cards-image-datasetclassification/train'
)


Task 0: Positive samples: 96, Negative samples: 96
Task 0: Positive samples: 24, Negative samples: 24
Task 1: Positive samples: 103, Negative samples: 103
Task 1: Positive samples: 26, Negative samples: 26
Task 2: Positive samples: 137, Negative samples: 137
Task 2: Positive samples: 34, Negative samples: 34
Task 3: Positive samples: 145, Negative samples: 145
Task 3: Positive samples: 36, Negative samples: 36
Task 4: Positive samples: 110, Negative samples: 110
Task 4: Positive samples: 28, Negative samples: 28
Task 5: Positive samples: 127, Negative samples: 127
Task 5: Positive samples: 32, Negative samples: 32
Task 6: Positive samples: 122, Negative samples: 122
Task 6: Positive samples: 30, Negative samples: 30
Task 7: Positive samples: 108, Negative samples: 108
Task 7: Positive samples: 27, Negative samples: 27
Task 8: Positive samples: 120, Negative samples: 120
Task 8: Positive samples: 30, Negative samples: 30
Task 9: Positive samples: 110, Negative samples: 110
Task 9: Posit

In [18]:
# Binary tasks for the cards dataset; valid=False makes the loader split the
# single image folder into train/val itself.
cards_tasks = load_animals_binary(
    '/kaggle/input/cards-image-datasetclassification/train',
    valid=False,
)

Task 0: Positive samples: 96, Negative samples: 96
Task 0: Positive samples: 24, Negative samples: 24
Task 1: Positive samples: 103, Negative samples: 103
Task 1: Positive samples: 26, Negative samples: 26
Task 2: Positive samples: 137, Negative samples: 137
Task 2: Positive samples: 34, Negative samples: 34
Task 3: Positive samples: 145, Negative samples: 145
Task 3: Positive samples: 36, Negative samples: 36
Task 4: Positive samples: 110, Negative samples: 110
Task 4: Positive samples: 28, Negative samples: 28
Task 5: Positive samples: 127, Negative samples: 127
Task 5: Positive samples: 32, Negative samples: 32
Task 6: Positive samples: 122, Negative samples: 122
Task 6: Positive samples: 30, Negative samples: 30
Task 7: Positive samples: 108, Negative samples: 108
Task 7: Positive samples: 27, Negative samples: 27
Task 8: Positive samples: 120, Negative samples: 120
Task 8: Positive samples: 30, Negative samples: 30
Task 9: Positive samples: 110, Negative samples: 110
Task 9: Posit

In [19]:
# Shared setup for the conflict experiments.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
enc_sizes = [32, 64, 64]  # encoder channel sizes passed to CNN
conflict_model = CNN(3, enc_sizes, num_classes=2)  # common starting point, deep-copied per task
models = []
flags = ['layer_wise']  # comparison types that exclude all but one encoder weight

In [20]:

# Datasets to process: name -> list of per-class {"train", "val"} loader dicts.
# Uncomment an entry to include that dataset in the run.
task_dict = {
    # 'cards': cards_tasks,
    # 'mammals': mammal_tasks,
    'cifar100': cifar100_tasks,
    # 'imagenet': tiny_imagenet_tasks,
    # 'flowers': flower_tasks,
}

# Number of binary tasks (classes) per dataset name.
num_classes_dict = {
    'cards': 53,
    'mammals': 45,
    'cifar100': 100,
    'imagenet': 200,
    'flowers': 102,
}

# Pairwise metrics to evaluate: short name (used in the CSV file name) -> function.
metrics = {
    'ci': calc_sign_conflicts,
    # 'dis': calc_distance,
    # 'sim': calc_similarity,
    # 'disci': calc_simdis,
    'diseuc': calc_distance_euc,
}

# Comparison variants; for those listed in `flags` every parameter except one
# encoder weight is excluded.
types = ['full', 'layer_wise']




In [21]:
for key, tasks in task_dict.items():
    print(len(tasks))
    num_classes = num_classes_dict[key]

    # Stage 1: train one binary model per task, each starting from a copy of
    # the same initial model, and keep only its (filtered) state dict.
    models = []
    exclude_keys = []
    for task_id in range(num_classes):
        temp_model = copy.deepcopy(conflict_model)
        history = {}
        model_state_dict = train_model_extra_metric(
            temp_model,
            tasks[task_id]['train'],
            tasks[task_id]['val'],
            num_epochs=5,
            device=device,
        )
        model_state_dict = model_state_dict.state_dict()
        model_layer = remove_keys(model_state_dict, exclude_keys)
        models.append(model_layer)

        # Drop the per-task temporaries before the next iteration.
        del history, model_state_dict, temp_model, model_layer

    # Stage 2: fill the pairwise matrix for every (type, metric) combination
    # and write each one to "<dataset>_<type>_<metric>.csv". The matrix is
    # reused; every entry is overwritten on each pass.
    conflict2d = [[0] * num_classes for _ in range(num_classes)]
    for tp in types:
        if tp in flags:
            # Exclude every parameter except the third encoder block's conv weight.
            exclude_keys = list(conflict_model.state_dict().keys())
            exclude_keys.remove('encoder.encoced_blocks.2.0.weight')
        else:
            exclude_keys = []

        for metric in metrics:
            for i in range(num_classes):
                for j in range(num_classes):
                    conflict2d[i][j] = metrics[metric](
                        [models[i], models[j]], conflict_model, exclude_keys, CNN, 2
                    )[0]

            np_conflict = np.array(conflict2d)
            np.savetxt(f"{key}_{tp}_{metric}.csv", np_conflict, delimiter=",")

100
Epoch 1/5
Train Loss: 0.5814
Val Accuracy: 0.7863
Precision: 0.8287, Recall: 0.7863, F1: 0.7792
--------------------------------------------------
Epoch 2/5
Train Loss: 0.4180
Val Accuracy: 0.8205
Precision: 0.8229, Recall: 0.8205, F1: 0.8202
--------------------------------------------------
Epoch 3/5
Train Loss: 0.3450
Val Accuracy: 0.8376
Precision: 0.8687, Recall: 0.8376, F1: 0.8341
--------------------------------------------------
Epoch 4/5
Train Loss: 0.3142
Val Accuracy: 0.8504
Precision: 0.8536, Recall: 0.8504, F1: 0.8501
--------------------------------------------------
Epoch 5/5
Train Loss: 0.2817
Val Accuracy: 0.8547
Precision: 0.8615, Recall: 0.8547, F1: 0.8540
--------------------------------------------------
Epoch 1/5
Train Loss: 0.5883
Val Accuracy: 0.7500
Precision: 0.7510, Recall: 0.7500, F1: 0.7497
--------------------------------------------------
Epoch 2/5
Train Loss: 0.5189
Val Accuracy: 0.7591
Precision: 0.7628, Recall: 0.7591, F1: 0.7582
------------------

In [22]:
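# Disabled earlier variant of the training/comparison loop: it stores the trained
# model objects in `models` instead of their state dicts.
# for key in task_dict.keys():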
  
#     tasks=task_dict[key];
#     # print(len(tasks))
#     num_classes=num_classes_dict[key]
#     models=[]
#     exclude_keys=[]
#     for task_id in range(num_classes):
#         temp_model=copy.deepcopy(conflict_model);
#         train_model_extra_metric(temp_model,tasks[task_id]['train'],tasks[task_id]['val'],num_epochs=5,device=device)
#         # model_layer=remove_keys(temp_model.state_dict(),exclude_keys)
#         models.append(temp_model);
#         # del temp_model;
    
#     conflict2d=[[0 for y in range(num_classes)]for x in range(num_classes)]
#     for tp in types:
#         if tp in flags:
#             exclude_keys=list(conflict_model.state_dict().keys())
#             exclude_keys.remove('encoder.encoced_blocks.2.0.weight')
#         else:
#             exclude_keys=[];
#         for metric in metrics.keys():
            
#             for i in range(num_classes):
#                 for j in range(num_classes):
#                     conflict2d[i][j]=metrics[metric]([models[i],models[j]],conflict_model,exclude_keys,CNN,2)[0]
            
                
#             np_conflict=np.array(conflict2d)
#             np.savetxt(f"{key}_{tp}_{metric}.csv", np_conflict, delimiter=",")