In [1]:
"""
Libraries

"""

import csv
from datetime import datetime
from typing import Callable

import logging

import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from scipy.optimize import linear_sum_assignment
from sklearn.metrics import accuracy_score, balanced_accuracy_score, confusion_matrix
from sklearn.metrics.cluster import normalized_mutual_info_score

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
"""
Setting generic hyperparameters

"""

num_epochs: int = 40
batch_size: int = 256 # Should be set to a power of 2.
# Learning rate
lr:         float = 1e-4 # Learning rate used in the IIC paper: lr=1e-4.

"""
GPU utilization

"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Specifications
if torch.cuda.is_available():
    print(f"Number of available devices: {torch.cuda.device_count()}\n",
          f"Device name: {torch.cuda.get_device_name(torch.cuda.current_device())}\n",
          f"Total GPU memory device 0: {torch.cuda.get_device_properties(0).total_memory/(1024**3):.2f} GB\n")

Number of available devices: 1
 Device name: NVIDIA A100 80GB PCIe
 Total GPU memory device 0: 79.20 GB



In [4]:
'''
Store data to .csv file

'''

# open the file for writing
f = open(f'logs/IIC_ten_classes_balanced/{datetime.now().strftime("%Y-%m-%d-%H-%M")}.csv', 'w')
# create a CSV writer object
writer = csv.writer(f)
# write the header row to the CSV file
writer.writerow(['epoch', 'loss', 'running_acc', 'acc', 'running_nmi', 'nmi'])

44

In [5]:
"""
The ten classes considered in case 2

"""

subset_classes = ['acantharia_protist',
                  'chordate_type1',
                  'copepod_calanoid_eucalanus',
                  'copepod_cyclopoid_copilia',
                  'ctenophore_cestid',
                  'ctenophore_lobate',
                  'diatom_chain_string',
                  'echinoderm_larva_seastar_brachiolaria',
                  'hydromedusae_haliscera',
                  'radiolarian_chain']

mapping_dict = {0: 0, 13: 1, 16: 2, 23: 3, 28: 4, 31: 5, 36: 6, 43: 7, 60: 8, 90: 9}

In [6]:
"""
Unsupervised Machine Learning Framework

"""

def train(model, data_loader: DataLoader, criterion: Callable, optimizer: torch.optim, num_epochs: int, num_classes: int=None) -> None:
    """
    Trains a given model using the provided training data, optimizer and loss criterion for a given number of epochs.

    Args:
        model: Neural network model to train.
        data_loader: PyTorch data loader containing the training data.
        criterion: Loss criterion used for training the model.
        optimizer: Optimizer used to update the model's parameters.
        num_epochs: Number of epochs to train the model.

    Returns:
        None
    """

    for epoch in range(num_epochs):

        running_loss = 0.0
        running_acc  = 0.0
        running_nmi  = 0.0

        # Initialize tensors for storing true and predicted labels
        labels_true = torch.zeros(len(data_loader.dataset))
        labels_pred = torch.zeros(len(data_loader.dataset))

        # Loop over the mini-batches in the data loader
        for i, data in enumerate(data_loader):
        
            # Get the inputs and labels for the mini-batch
            inputs, labels = data

            # Use GPU if available
            inputs = inputs.to(device)

            # Image augmentation
            if data_loader.dataset.augment_data:
                inputs_trans = torch.stack([data_loader.dataset.transform_list(input) for input in inputs])
                # # Flatten input data for the feed forward model
                # inputs       = [inputs.view(inputs.size(0), -1), inputs_trans.view(inputs_trans.size(0), -1)]
                inputs       = [inputs, inputs_trans]
            # else:
                # inputs = inputs.view(inputs.size(0), -1)
        
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass through the model
            if data_loader.dataset.augment_data:
                outputs = [F.softmax(model(inputs[0]), dim=1), F.softmax(model(inputs[1]), dim=1)]
            else:
                outputs = F.softmax(model(inputs), dim=1)

            # Set arguments for objective function
            # kwargs = {key: value for key, value in locals().items() if key in criterion.__code__.co_varnames}
            kwargs = {"model": model, "inputs": inputs, "outputs": outputs}
            kwargs = {key: value for key, value in kwargs.items() if key in criterion.__code__.co_varnames}
            
            # Compute the loss
            loss = criterion(**kwargs)
            # Backward pass through the model and compute gradients
            loss.backward()
        
            # Update the weights
            optimizer.step()

            # Accumulate the loss for the mini-batch
            running_loss += loss.item()

            outputs = outputs[0] if data_loader.dataset.augment_data else outputs

            running_acc  += unsupervised_clustering_accuracy(labels, torch.argmax(outputs.cpu(), dim=1),C=num_classes)
            running_nmi  += normalized_mutual_info_score(labels, torch.argmax(outputs.cpu(), dim=1))

            # Store predicted and true labels in tensors
            labels_true[i*len(labels):(i+1)*len(labels)] = labels
            labels_pred[i*len(labels):(i+1)*len(labels)] = torch.argmax(outputs, dim=1)

        acc = unsupervised_clustering_accuracy(labels_true, labels_pred, C=num_classes)
        nmi = normalized_mutual_info_score(labels_true, labels_pred)

        # Compute the average loss and accuracy for the epoch and print
        print(f"Epoch {epoch+1} loss: {running_loss/len(data_loader):.4f},\
              running_acc: {running_acc/len(data_loader):.4f}, acc: {acc:.4f},\
              running_nmi: {running_nmi/len(data_loader):.4f}, nmi: {nmi:.4f}")
        # Store data to file
        writer.writerow([epoch+1, running_loss/len(data_loader), running_acc/len(data_loader), acc, running_nmi/len(data_loader), nmi])

def reassign(y_true: torch.Tensor, y_pred: torch.Tensor, C: int=None) -> float:
    
    # Create confusion matrix
    cm = confusion_matrix(y_pred, y_true, labels=list(range(C)))

    # Compute best matching between true and predicted labels using the Hungarian algorithm
    _, col_ind = linear_sum_assignment(-cm)

    # Reassign labels for the predicted clusters
    y_pred_reassigned = torch.tensor(col_ind)[y_pred.long()]
    
    return y_pred_reassigned
        
def unsupervised_clustering_accuracy(y_true: torch.Tensor, y_pred: torch.Tensor, C: int=None) -> float:
    """
    Computes the unsupervised clustering accuracy between two clusterings.
    Uses the Hungarian algorithm to find the best matching between true and predicted labels.

    Args:
        y_true: true cluster labels as a 1D torch.Tensor
        y_pred: predicted cluster labels as a 1D torch.Tensor
        C:      number of classes

    Returns:
        accuracy: unsupervised clustering accuracy as a float
    """
    
    y_pred_reassigned = reassign(y_true, y_pred, C)

    # Compute accuracy as the percentage of correctly classified samples
    acc = accuracy_score(y_true, y_pred_reassigned)

    return acc

def unsupervised_balanced_clustering_accuracy(y_true: torch.Tensor, y_pred: torch.Tensor, C: int=None) -> float:
    """
    Computes the unsupervised clustering accuracy between two clusterings.
    Uses the Hungarian algorithm to find the best matching between true and predicted labels.

    Args:
        y_true: true cluster labels as a 1D torch.Tensor
        y_pred: predicted cluster labels as a 1D torch.Tensor
        C:      number of classes

    Returns:
        accuracy: unsupervised clustering accuracy as a float
    """
    
    y_pred_reassigned = reassign(y_true, y_pred, C)

    # Compute accuracy as the percentage of correctly classified samples
    acc = balanced_accuracy_score(y_true, y_pred_reassigned)

    return acc


def test_classifier(model, data_loader: DataLoader, num_classes: int) -> float:
    """
    Testing a classifier given the model and a test set.

    Args:
        model: Neural network model to train.
        data_loader: PyTorch data loader containing the test data.
    
    Returns:
        None
    """
    
    # Disable gradient computation, not needed for inference
    model.eval()
    # Initialize tensors for storing true and predicted labels
    y_true = torch.zeros(len(data_loader.dataset))
    y_pred = torch.zeros(len(data_loader.dataset))

    with torch.no_grad():
        # Iterate over the mini-batches in the data loader
        for i, data in enumerate(data_loader):
            # Get the inputs and true labels for the mini-batch and reshape
            inputs, labels_true = data
            
            # Use GPU if available
            inputs      = inputs.to(device)
                                    
            # # TODO flattening should be done in the feed forward model, else statement should be removed
            # inputs = inputs.view(inputs.size(0), -1)
            
            # Forward pass through the model to get predicted labels
            labels_pred = F.softmax(model(inputs), dim=1)

            # Store predicted and true labels in tensors
            y_pred[i*len(labels_true):(i+1)*len(labels_true)] = torch.argmax(labels_pred.cpu(), dim=1)
            y_true[i*len(labels_true):(i+1)*len(labels_true)] = labels_true

    # Compute unsupervised clustering accuracy score
    acc = unsupervised_clustering_accuracy(y_true, y_pred, C=num_classes)
    
    acc_balanced = unsupervised_balanced_clustering_accuracy(y_true, y_pred, C=num_classes)

    print(f"\nThe unsupervised clustering accuracy score of the classifier is: {acc}")
    
    return acc, acc_balanced

In [7]:
"""

"""

from archt import get_model

# Information Maximizing Self-Augmented Training
from IMSAT import regularized_information_maximization

# Invariant Information Clustering
from IIC import invariant_information_clustering

from datasets.dataset_classes import NDSBDataset, MNISTDataset

In [8]:
"""

"""

# Create the train and test datasets
train_dataset = NDSBDataset(train=True, augment_data=True)
test_dataset  = NDSBDataset(train=False)

# Create the train and test data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=batch_size)

In [9]:
# Initialize model
model = get_model("inception_v3", num_classes=10).to(device)

# Initialize loss function, and optimizer
criterion = invariant_information_clustering
optimizer = optim.Adam(model.parameters(), lr=lr)

# Store metadata to .log file
logger = logging.getLogger(__name__)
# Set the logging level
logger.setLevel(logging.INFO)
# Add handler to the logger
logger.addHandler(logging.FileHandler(f'logs/IIC_ten_classes_balanced/{datetime.now().strftime("%Y-%m-%d-%H-%M")}.log'))

# Write metadata to .log file
logger.info(f'Optimization criterion: {criterion.__name__}')
logger.info(f'Learning rate: {lr}')
logger.info(f'Number of epochs: {num_epochs}')
logger.info(f'Batch size: {batch_size}')
logger.info(f'Optimizer: {optimizer}')
logger.info(f'Model: {model}')

# Train the model
train(model, train_loader, criterion, optimizer, num_epochs, num_classes=10)

# Test model
acc, acc_balanced = test_classifier(model, test_loader,num_classes=10)

logger.info(f'Accuracy: {acc}')
logger.info(f'Balanced Accuracy: {acc_balanced}')
# Close data file
f.close()



Model specifications: Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, ke

Epoch 1 loss: -0.0097,              running_acc: 0.2523, acc: 0.2286,              running_nmi: 0.1567, nmi: 0.1110
Epoch 2 loss: -0.0631,              running_acc: 0.2815, acc: 0.2808,              running_nmi: 0.2231, nmi: 0.2069
Epoch 3 loss: -0.1252,              running_acc: 0.2966, acc: 0.3005,              running_nmi: 0.2520, nmi: 0.2300
Epoch 4 loss: -0.2081,              running_acc: 0.3281, acc: 0.3232,              running_nmi: 0.2758, nmi: 0.2634
Epoch 5 loss: -0.2965,              running_acc: 0.3329, acc: 0.3300,              running_nmi: 0.2839, nmi: 0.2764
Epoch 6 loss: -0.3707,              running_acc: 0.3270, acc: 0.3310,              running_nmi: 0.2948, nmi: 0.2942
Epoch 7 loss: -0.4426,              running_acc: 0.3291, acc: 0.3232,              running_nmi: 0.3174, nmi: 0.3096
Epoch 8 loss: -0.5258,              running_acc: 0.3232, acc: 0.3232,              running_nmi: 0.3409, nmi: 0.3337
Epoch 9 loss: -0.5681,              running_acc: 0.3389, acc: 0.3291,   