In [None]:
# Model seeds. These are the papermill-injectable seed slots: every slot left as
# None is dropped when the seed list is assembled in the next cell, so only the
# non-None values produce a training run.
s1 = 0
s2 = None
s3 = None
s4 = None
s5 = None
s6 = None
s7 = None
s8 = None
s9 = None
s10 = None

# Additional parameters
model_parameter_name = 'sl'      # selects the argument preset: 'sl', 'ltn', anything else falls back to the DPL preset
uns_parameter_percentage = 1.0   # threshold compared against random.random() to keep/skip each unsupervised batch
sup_loss_weight = 1.0            # multiplier applied to the supervised concept loss
GPU_ID = '3'                     # exported as CUDA_VISIBLE_DEVICES before torch is imported

In [None]:
# Collect the seed slots, discard the unset ones and coerce the rest to int.
_seed_slots = (s1, s2, s3, s4, s5, s6, s7, s8, s9, s10)
seeds_list = [int(slot) for slot in _seed_slots if slot is not None]

# At least one seed is required, and every kept seed must be an int.
assert seeds_list, "seeds_list should have at least one entry"
assert not any(not isinstance(seed, int) for seed in seeds_list), "Not all entries are integers"

# None of the remaining parameters may be left unset.
_required_parameters = (
    ("model_parameter_name", model_parameter_name),
    ("uns_parameter_percentage", uns_parameter_percentage),
    ("sup_loss_weight", sup_loss_weight),
    ("GPU_ID", GPU_ID),
)
for _parameter_name, _parameter_value in _required_parameters:
    assert _parameter_value is not None, f"{_parameter_name} should not be None"

# Echo the effective parameters so they show up in the executed notebook.
print("Papermill seeds parameters are: ", seeds_list, sep="")
print("Papermill model name is: " + model_parameter_name)
print("Papermill uns_parameter_percentage is: ", uns_parameter_percentage, sep="")
print("Papermill sup_loss_weight is: ", sup_loss_weight, sep="")

# IMPORTS

In [None]:
import sys
import os

# Restrict CUDA to the requested device. Environment values must be strings and
# papermill may inject GPU_ID as an int (e.g. `-p GPU_ID 3`), which would make a
# plain assignment raise TypeError, so coerce explicitly.
os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID)

# Make the project packages importable. Entries are only added when missing so
# that re-running this cell does not keep growing sys.path with duplicates.
for _relative_dir in (
    "..",        # for 'protonet_mnist_add_utils' folder
    "../..",     # for 'data' folder
    "../../..",  # for 'models' and 'datasets' folders
):
    _absolute_dir = os.path.abspath(_relative_dir)
    if _absolute_dir not in sys.path:
        sys.path.append(_absolute_dir)


print(sys.path)

In [None]:
import cv2
import torch
import torch.nn as nn
import argparse
import datetime
import importlib
import setproctitle, socket, uuid
import torch.nn.functional as F
import numpy as np
import random

from tqdm import tqdm
from argparse import Namespace
from numpy import float32, zeros
from datasets import get_dataset
from models import get_model
from models.mnistdpl import MnistDPL
from warmup_scheduler import GradualWarmupScheduler
from datasets.utils.base_dataset import BaseDataset
from torchvision import datasets, transforms
from cv2 import (
    INTER_CUBIC, 
    getRotationMatrix2D, 
    imread, 
    warpAffine, 
    moments, 
    WARP_INVERSE_MAP
)
from torch.nn.modules import Module

from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

from utils import fprint
from utils.train import train
from utils.test import test
from utils.preprocess_resnet import preprocess
from utils.conf import *
from utils.args import *
from utils.status import progress_bar
from utils.checkpoint import save_model, create_load_ckpt
from utils.dpl_loss import ADDMNIST_DPL
from utils.metrics import (
    evaluate_metrics,
    evaluate_mix,
    mean_entropy,
)
from protonet_mnist_add_modules.utility_modules.proto_utils import ( 
    init_dataloader,
    get_random_classes
)

from sklearn.metrics import confusion_matrix
from sklearn.metrics import silhouette_samples
from scipy.spatial.distance import cdist

from protonet_mnist_add_modules.arguments import args_sl, args_ltn, args_dpl
from protonet_mnist_add_modules.utility_modules.plotting import plot_training_image
from protonet_mnist_add_modules.data_modules.proto_data_creation import (
    choose_initial_prototypes,
    get_original_support_query_set,
    get_augmented_support_query_set,
    get_augmented_support_query_loader
)
from backbones.addmnist_protonet import PrototypicalLoss

from protonet_mnist_add_modules.utility_modules.setup import my_gpu_info

# SETUP

In [None]:
# Notebook-level aliases for the injected parameters.
MODEL = model_parameter_name
UNS_PERCENTAGE = uns_parameter_percentage
CONCEPT_LOSS_WEIGHT = sup_loss_weight

# Only the SL model may run with a concept-loss weight other than 1.0.
# For every other model a weight above 1.0 raises, and any remaining value
# different from 1.0 trips the assertion below.
if CONCEPT_LOSS_WEIGHT > 1.0 and MODEL != 'sl':
    raise Exception("Concept loss weight should be less than or equal to 1.0 for DPL and LTN")
elif MODEL != 'sl':
    # This branch is only reached with a weight <= 1.0, so the message must not
    # claim the weight is "greater than 1"; report the actual constraint instead.
    assert CONCEPT_LOSS_WEIGHT == 1.0, (
        f"Concept loss weight must be exactly 1.0 for DPL and LTN (got {CONCEPT_LOSS_WEIGHT}); "
        "other weights are only supported for SL"
    )

print("Model: ", MODEL)
print("Unsupervised Percentage: ", UNS_PERCENTAGE)
print("Concept Loss Weight: ", CONCEPT_LOSS_WEIGHT)

In [None]:
# Pick the argument preset for the requested model; anything that is neither
# 'sl' nor 'ltn' falls back to the DPL preset.
args = {'sl': args_sl, 'ltn': args_ltn}.get(MODEL, args_dpl)

args.seeds = seeds_list
print("Seeds: ", args.seeds, sep="")

# logging metadata: unique job id, start time and host name
args.conf_jobnum = str(uuid.uuid4())
args.conf_timestamp = str(datetime.datetime.now())
args.conf_host = socket.gethostname()

# process title: <model>_<buffer size or 0>_<dataset>
job_buffer_size = args.buffer_size if "buffer_size" in args else 0
setproctitle.setproctitle("{}_{}_{}".format(args.model, job_buffer_size, args.dataset))

# saving: the same output directory is listed once per prototypical loss weight
save_folder = "mnadd-even-odd" 
save_model_name = MODEL
save_path = os.path.join(
    "..", "..",
    "outputs",
    save_folder,
    "my_models",
    save_model_name,
    f"DEBUG-baseline-concept-supervised-{UNS_PERCENTAGE}-ARTICLE",
)
save_paths = [save_path for _ in range(len(args.prototypical_loss_weight))]
print(f"Save paths: {str(save_paths)}")

In [None]:
# Call the project helper imported from protonet_mnist_add_modules.utility_modules.setup.
# NOTE(review): presumably prints information about the visible GPU(s); its
# implementation is not part of this notebook — confirm.
my_gpu_info()

In [None]:
# This notebook only covers the baseline models trained without prototypes.
_BASELINE_MODELS = ('mnistsl', 'mnistltn', 'mnistdpl')
is_baseline_model = args.model in _BASELINE_MODELS
if not is_baseline_model or args.prototypes:
    raise ValueError("This experiment is meant for baseline models.")

# Utilities

In [None]:

# * for training data inspection
def plot_training_image(images, labels, plot_index_start=0, plot_index_end=10):
    """Show one matplotlib figure per image in an inclusive index range.

    Args:
        images: indexable, sized batch of images. Each item must support
            ``.cpu().numpy()``; the resulting array is transposed with
            ``(1, 2, 0)`` before display (assumes a channel-first layout —
            TODO confirm against the loaders).
        labels: indexable collection aligned with ``images``; the entry is
            interpolated into the figure title.
        plot_index_start: first index to plot.
        plot_index_end: last index to plot (inclusive). Clamped to the last
            available image so that batches with fewer than
            ``plot_index_end + 1`` images no longer raise ``IndexError``.
    """
    # Never walk past the end of the batch.
    last_index = min(plot_index_end, len(images) - 1)

    for plotting_index in range(plot_index_start, last_index + 1):
        image = images[plotting_index].cpu().numpy().transpose(1, 2, 0)
        plt.figure(figsize=(5, 5))
        plt.imshow(image, cmap='gray')
        plt.title(f"Label {labels[plotting_index]}")
        plt.axis('off')
        plt.show()

In [None]:

# * For prototypes computation at test phase 
def get_random_classes(images, labels, n_support, n_classes=10):
    """Draw ``n_support`` random samples for every class present in ``labels``.

    Args:
        images: tensor indexed along its first dimension in step with ``labels``.
        labels: 1-D tensor of class labels.
        n_support: number of samples to draw per class.
        n_classes: number of distinct classes ``labels`` must contain.
            Defaults to 10, the value that used to be hard-coded.

    Returns:
        ``(selected_images, selected_labels)``: the per-class selections
        concatenated in the sorted class order produced by ``torch.unique``,
        i.e. ``n_classes * n_support`` entries each.

    Raises:
        AssertionError: if the number of distinct classes differs from
            ``n_classes`` or a class has fewer than ``n_support`` samples.
    """
    unique_classes = torch.unique(labels)
    assert len(unique_classes) == n_classes, f"There should be exactly {n_classes} unique classes."

    selected_images = []
    selected_labels = []

    for cls in unique_classes:
        # positions of every sample belonging to this class
        class_indices = (labels == cls).nonzero(as_tuple=True)[0]
        assert len(class_indices) >= n_support, f"Not enough samples for class {cls}"

        # random subset (without replacement) of those positions
        chosen = class_indices[torch.randperm(len(class_indices))[:n_support]]
        selected_images.append(images[chosen])
        selected_labels.append(labels[chosen])

    return torch.cat(selected_images), torch.cat(selected_labels)

# DATA

## Loading data for training the ProtoNet

In [None]:
# Arguments used to load the dataset that feeds the prototype selection.
protonet_config = dict(
    dataset=args.prototypical_dataset,
    batch_size=args.prototypical_batch_size,
    preprocess=0,
    c_sup=1,  # ^ supervision loaded to simulate direct annotation for prototypes
    which_c=[-1],
    model=args.model,
    task=args.task,
)
args_protonet = Namespace(**protonet_config)

# Only the first of the three returned loaders (the training one) is kept.
addmnist_dataset = get_dataset(args_protonet)
addmnist_train_loader, _, _ = addmnist_dataset.get_data_loaders()
print(addmnist_dataset)

## Create (or get) the initial annotated images-prototype seed and augment it

In [None]:
# (Re)create the annotated prototype seed when it is missing on disk or when
# running in debug mode.
proto_seed_missing = not os.path.exists('data/prototypes/proto_loader_dataset.pth')
if proto_seed_missing or args.debug:
    print("Creating proto_loader_dataset.pth")
    choose_initial_prototypes(addmnist_train_loader, debug=args.debug)

tr_dataloader = init_dataloader()

# Augment the seed into support / query sets.
_augmented_sets = get_augmented_support_query_set(tr_dataloader, debug=args.debug)
support_images_aug, support_labels_aug, query_images_aug, query_labels_aug, no_aug = _augmented_sets

# Sanity checks: every tensor must be non-empty and must not be all zeros.
_checked_tensors = {
    "support_images_aug": support_images_aug,
    "support_labels_aug": support_labels_aug,
    "query_images_aug": query_images_aug,
    "query_labels_aug": query_labels_aug,
}
for _tensor_name, _tensor in _checked_tensors.items():
    assert _tensor.numel() > 0, f"{_tensor_name} is an empty tensor"
for _tensor_name, _tensor in _checked_tensors.items():
    assert not torch.all(_tensor == 0), f"All elements in {_tensor_name} are zero"

assert no_aug > 0, "no_aug should be greater than 0"

# Wrap the augmented sets into loaders.
support_loader, query_loader = get_augmented_support_query_loader(
    support_images_aug,
    support_labels_aug,
    query_images_aug,
    query_labels_aug,
    query_batch_size=32,
    debug=args.debug,
)

## Dataset Creation & Check

In [None]:
class MNISTAugDataset(Dataset):
    """Map-style dataset over pre-built, index-aligned images and labels.

    Args:
        images: indexable collection of images.
        labels: indexable, sized collection of labels; it defines the dataset
            length and each entry must support ``.squeeze()``.
        hide_labels: accepted for signature compatibility; not used.
        transform: optional callable applied to each image on access.
    """

    def __init__(self, images, labels, hide_labels=None, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        # the label collection determines how many samples exist
        return len(self.labels)

    def __getitem__(self, index):
        sample = self.images[index]
        target = self.labels[index]
        sample = self.transform(sample) if self.transform else sample
        # squeeze() drops singleton dimensions from the label entry
        return sample, target.squeeze()

# Wrap the augmented support images/labels in a dataset. This is not just an
# example: the instance is plotted below and backs the supervised DataLoader
# (sup_train_loader) used during training.
mnist_dataset = MNISTAugDataset(support_images_aug, support_labels_aug)

In [None]:
# Fill a 5x10 grid (up to 50 panels) with samples drawn at random from the
# MNISTAugDataset. Draws come from the `random` module, so the same image may
# appear more than once.
n_available = len(mnist_dataset.images)
fig, axes = plt.subplots(5, 10, figsize=(15, 10))
for panel_idx, panel in enumerate(axes.flat):
    if panel_idx >= n_available:
        continue  # panels beyond the dataset size are left untouched
    pick = random.randint(0, n_available - 1)
    picture = mnist_dataset.images[pick].cpu().numpy().transpose(1, 2, 0)
    panel.imshow(picture, cmap='gray')
    panel.set_title(f"Label {mnist_dataset.labels[pick].item()}")
    panel.axis('off')
plt.tight_layout()
plt.show()

# plot_training_image(mnist_dataset.images, mnist_dataset.labels, plot_index_start=40, plot_index_end=50)

In [None]:
# Shuffled loader over the augmented, digit-labelled dataset; it is handed to
# train() as sup_train_loader and drives the supervised phase of every epoch.
sup_train_loader = DataLoader(mnist_dataset, batch_size=args.batch_size, shuffle=True)

## Getting the Unsupervised Dataset

In [None]:
# Build the dataset selected by the experiment arguments.
dataset = get_dataset(args)
print(dataset)

# Ask the dataset for its split description and backbone networks, then assemble
# the model from them. Statement order matters: the model must exist before its
# loss and optimizer are requested.
n_images, c_split = dataset.get_split()
encoder, decoder = dataset.get_backbone()
model = get_model(args, encoder, decoder, n_images, c_split)
loss = model.get_loss(args)
# NOTE(review): presumably creates the optimizer later accessed as `model.opt`
# inside train() — confirm against the model implementation.
model.start_optim(args)

print("Using Dataset: ", dataset)
print("Using backbone: ", encoder)
print("Using Model: ", model)
print("Using Loss: ", loss)

# Keep the first two loaders (used as training and validation loaders below);
# the third returned loader is discarded.
unsup_train_loader, unsup_val_loader, _ = dataset.get_data_loaders()
dataset.print_stats()

## Checking the Unsupervised Data

In [None]:
# Look at the first unsupervised batch only: plot its images, then stop.
for images, labels, concepts in unsup_train_loader:
    plot_training_image(images, labels)
    break

# Training

In [None]:
# Display the model identifier carried by the selected argument preset
# (rendered as this cell's output).
args.model

## Main Loop

In [None]:
def train(model:MnistDPL,
        sup_train_loader:DataLoader,
        unsup_train_loader:DataLoader,
        unsup_val_loader:DataLoader,
        _loss: ADDMNIST_DPL, 
        args,
        seed,
        save_folder,
        sup_loss_weight=1.0,
        debug=False):
    """Train ``model`` for up to ``args.proto_epochs`` epochs and return the best validation F1.

    Every epoch runs two phases:

    1. Supervised phase: consecutive samples of ``sup_train_loader`` are
       concatenated along the width into two-digit images, and the per-digit
       concept predictions (``out_dict["pCS"]``) are trained with
       cross-entropy scaled by ``sup_loss_weight``.
    2. Unsupervised phase: batches of ``unsup_train_loader`` are trained with
       ``_loss``. A batch is skipped whenever ``random.random()`` exceeds the
       notebook-level ``UNS_PERCENTAGE``.

    After both phases the model is evaluated on ``unsup_val_loader``; the
    state dict is saved whenever F1 improves (unless ``args.tuning``) and
    training stops early after ``args.patience`` epochs without improvement.
    When ``UNS_PERCENTAGE == 0.0`` evaluation is skipped and the state dict is
    saved at the end of every epoch instead.

    Args:
        model: model exposing ``device``, ``opt`` and ``state_dict()``; calling
            it returns a dict containing at least ``"pCS"``.
        sup_train_loader: loader yielding ``(images, labels)`` with images of
            shape ``(B, 1, 28, 28)`` and labels of shape ``(B,)`` (asserted).
        unsup_train_loader: loader yielding ``(images, labels, concepts)``.
        unsup_val_loader: loader passed to ``evaluate_metrics``.
        _loss: callable ``_loss(out_dict, args) -> (loss, losses)``.
        args: configuration; reads ``dataset``, ``exp_decay``, ``warmup_steps``,
            ``proto_epochs``, ``tuning`` and ``patience``.
        seed: seed applied to ``random``, NumPy and PyTorch.
        save_folder: despite its name, the destination *file* path handed to
            ``torch.save``.
        sup_loss_weight: multiplier for the supervised concept loss.
        debug: when true, plots up to five merged images per supervised batch.

    Returns:
        The best validation F1 observed (``0.0`` if it was never updated).
    """

    # for full reproducibility: seed every RNG used here, including Python's
    # `random`, which decides which unsupervised batches are skipped
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = False

    best_f1 = 0.0
    epochs_no_improve = 0   # for early stopping

    # model configuration for shortmnist
    if args.dataset == "shortmnist":
        model = model.float()
    model.to(model.device)

    # learning-rate schedulers: exponential decay, optionally preceded by warm-up
    scheduler = torch.optim.lr_scheduler.ExponentialLR(model.opt, args.exp_decay)
    w_scheduler = None
    if args.warmup_steps > 0:
        w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps)

    fprint("\n--- Start of Training ---\n")

    # default for warm-up
    model.opt.zero_grad()
    model.opt.step()

    # & FOR EACH EPOCH
    for epoch in range(args.proto_epochs):  # ^ ensure consistency with the number of epochs used for prototypical networks
        model.train()

        ###############################
        # 1. Episodic phase: Teach the model to recognize digits
        ###############################
        print("Start of supervised episodic training.")
        for i, (images, labels) in enumerate(sup_train_loader):
            sup_images = images.to(model.device)  # shape: (batch_size, C, 28, 28)
            sup_labels = labels.to(model.device)  # shape: (batch_size,)
            batch_size = sup_images.size(0)

            assert sup_images.shape == torch.Size([batch_size, 1, 28, 28]), \
            f"Expected shape [{batch_size}, 1, 28, 28], but got {sup_images.shape}"
            assert sup_labels.shape == torch.Size([batch_size]), \
            f"Expected shape [{batch_size}], but got {sup_labels.shape}"

            # Ensure batch size is even to form pairs (if odd, drop the last sample)
            if batch_size % 2 != 0:
                sup_images = sup_images[:-1]
                sup_labels = sup_labels[:-1]
                batch_size -= 1

            # A lone sample cannot form a pair. Skip it instead of running the
            # model and a mean-reduced loss on an empty batch (which yields NaN).
            if batch_size == 0:
                continue

            # Merge pairs: merge 0 with 1, 2 with 3, and so on. This yields merged_images of shape (batch_size//2, C, 28, 56)
            merged_images = torch.cat([sup_images[0::2], sup_images[1::2]], dim=3)

            assert merged_images.shape == torch.Size([batch_size//2, 1, 28, 56]), \
            f"Expected shape [{batch_size//2}, 1, 28, 56], but got {merged_images.shape}"

            # Extract corresponding labels for each digit in the pair
            labels_first = sup_labels[0::2]   # labels for the first digit in each pair
            labels_second = sup_labels[1::2]  # labels for the second digit in each pair

            # Plot the first merged images (at most five)
            if debug:
                n_shown = min(5, merged_images.size(0))
                # squeeze=False keeps `axes` two-dimensional even for a single panel
                fig, axes = plt.subplots(1, n_shown, figsize=(15, 3), squeeze=False)
                for idx in range(n_shown):
                    axes[0, idx].imshow(merged_images[idx].cpu().numpy().squeeze(), cmap='gray')
                    axes[0, idx].set_title(f"{labels_first[idx].item()}, {labels_second[idx].item()}")
                    axes[0, idx].axis('off')
                plt.tight_layout()
                plt.show()

            # Forward pass: the model expects an image with two digits
            out_dict = model(merged_images)
            nconcept_preds = out_dict["pCS"]

            assert nconcept_preds.shape == torch.Size([batch_size//2, 2, 10]), \
                f"Expected shape [{batch_size//2}, 2, 10], but got {nconcept_preds.shape}"

            # one cross-entropy term per digit position, scaled by the supervision weight
            concept_loss_first = F.cross_entropy(nconcept_preds[:, 0], labels_first)
            concept_loss_second = F.cross_entropy(nconcept_preds[:, 1], labels_second)
            concept_loss = sup_loss_weight * (concept_loss_first + concept_loss_second)

            # Backward pass and optimization step
            model.opt.zero_grad()
            concept_loss.backward()
            model.opt.step()

            if i % 10 == 0:
                print(f"Episodic phase, Epoch {epoch}, Batch {i}: Concept Loss = {concept_loss.item():.4f}")

        ###############################
        # 2. Original unsupervised training phase (sum prediction)
        ###############################
        # & FOR EACH BATCH
        print("Start of unsupervised training.")
        for i, data in enumerate(unsup_train_loader):
            if random.random() > UNS_PERCENTAGE:
                continue  # Skip this batch with probability (1 - percentage)

            images, labels, concepts = data
            images, labels, concepts = (
                images.to(model.device),    # input IMAGES
                labels.to(model.device),    # ground truth LABELS
                concepts.to(model.device),  # ground truth CONCEPTS
            )

            # ^ baseline model
            out_dict = model(images)

            # Enrich the out_dict with the ground truth labels and concepts
            out_dict.update({"LABELS": labels, "CONCEPTS": concepts})

            # Optimisation step on the task loss.
            # The per-batch outputs are deliberately not accumulated: nothing
            # consumed them, they kept graph-attached tensors alive for the whole
            # epoch, and an epoch in which every batch was skipped crashed on
            # torch.argmax(None).
            model.opt.zero_grad()
            loss, _ = _loss(out_dict, args)
            loss.backward()
            model.opt.step()

            if i % 10 == 0:
                progress_bar(i, len(unsup_train_loader) - 9, epoch, loss.item())


        print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
        print("End of epoch ", epoch)
        print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
        print()

        # Purely supervised run: no validation-driven model selection, the
        # current weights are saved at the end of every epoch.
        if UNS_PERCENTAGE == 0.0:
            print("Saving...")
            torch.save(model.state_dict(), save_folder)
            print(f"Saved best model with F1 score: {best_f1}")
            print()
            continue

        # enter the evaluation phase
        model.eval()
        # ^ baseline model
        tloss, cacc, yacc, f1 = evaluate_metrics(model, unsup_val_loader, args)

        # update the (warmup) scheduler at end of the epoch
        if epoch < args.warmup_steps:
            w_scheduler.step()
        else:
            scheduler.step()
            if hasattr(_loss, "grade"):
                _loss.update_grade(epoch)

        ### LOGGING ###
        fprint("  ACC C", cacc, "  ACC Y", yacc, "F1 Y", f1)
        print()

        # NOTE(review): with args.tuning set, an improving F1 takes neither
        # branch, so best_f1 is never updated and 0.0 is returned — confirm
        # whether that is intended.
        if not args.tuning and f1 > best_f1:
            print("Saving...")
            # Update best F1 score
            best_f1 = f1
            epochs_no_improve = 0

            # Save the best model
            torch.save(model.state_dict(), save_folder)
            print(f"Saved best model with F1 score: {best_f1}")
            print()

        elif f1 <= best_f1:
            epochs_no_improve += 1

        if epochs_no_improve >= args.patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    print("End of training")
    return best_f1

## Running

In [None]:
# Best validation F1 per seed.
f1_scores = dict()
save_path = save_paths[0]
os.makedirs(save_path, exist_ok=True)

for seed in args.seeds:
    print(f"*** Training model with seed {seed}")

    # Every seed must start from a freshly initialised model, optimizer and
    # loss. Reusing the objects built earlier would make each run after the
    # first continue from the previous seed's weights and scheduler state, so
    # the runs would not be independent. Seeding before the rebuild makes the
    # initial weights depend on the seed as well.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    encoder, decoder = dataset.get_backbone()
    model = get_model(args, encoder, decoder, n_images, c_split)
    loss = model.get_loss(args)
    model.start_optim(args)

    print("Chosen device:", model.device)
    print("Save path for this model: ", save_path)
    # despite the name, this is the checkpoint *file* path for this seed
    save_folder = os.path.join(save_path, f"{save_model_name}_{seed}.pth")
    print("Saving in folder: ", save_folder)

    best_f1 = train(model=model,
        sup_train_loader=sup_train_loader,
        unsup_train_loader=unsup_train_loader,
        unsup_val_loader=unsup_val_loader,
        _loss=loss, 
        args=args,
        seed=seed,
        sup_loss_weight=CONCEPT_LOSS_WEIGHT,
        save_folder=save_folder
    )
    f1_scores[seed] = best_f1
    save_model(model, args, seed)  # save the model parameters

# Report the seed that reached the highest validation F1.
best_weight_seed = max(f1_scores, key=f1_scores.get)
print(f"Best weight and seed combination: {best_weight_seed} with F1 score: {f1_scores[best_weight_seed]}")