# PAPERMILL

In [None]:
# Default run parameters (see the "PAPERMILL" heading above): presumably
# overridden by papermill when the notebook is executed — TODO confirm the
# cell is tagged as the parameters cell.

# model's seed (must be an int; train() uses it to seed the RNGs)
seed = 0

# additional parameters
# Selects the argument set: 'dpl', 'sl', or anything else for the LTN arguments.
model_parameter_name = 'sl'
# Float in [0.0, 1.0]; each unsupervised batch is kept with this probability.
uns_parameter_percentage = 1.0
# Float multiplier applied to the supervised concept loss.
sup_loss_weight = 1.0
# GPU identifier, given as a string.
GPU_ID = '1'

In [None]:
# Validate the notebook parameters up front so that a bad value fails here,
# with a readable message, rather than deep inside training.
assert seed is not None, "seed should not be None"
assert isinstance(seed, int), "seed should be an integer"
assert isinstance(uns_parameter_percentage, float), "uns_parameter_percentage should be a float"
assert 0.0 <= uns_parameter_percentage <= 1.0, "uns_parameter_percentage should be in the range [0.0, 1.0]"
assert model_parameter_name is not None, "model_parameter_name should not be None"
assert isinstance(model_parameter_name, str), "model_parameter_name should be a string"
assert sup_loss_weight is not None, "sup_loss_weight should not be None"
assert isinstance(sup_loss_weight, float), "sup_loss_weight should be a float"
assert GPU_ID is not None, "GPU_ID should not be None"

# f-strings instead of "+" concatenation: a GPU_ID injected as an integer
# (e.g. `-p GPU_ID 1`, which papermill presumably parses to int — confirm)
# used to raise TypeError on the last print. The output for string values is
# unchanged.
print(f"Papermill seed parameter is: {seed}")
print(f"Papermill model name is: {model_parameter_name}")
print(f"Papermill uns_parameter_percentage is: {uns_parameter_percentage}")
print(f"Papermill GPU_ID is: {GPU_ID}")

# IMPORTS

In [None]:
import sys
import os

# Expose only the GPU requested through the GPU_ID notebook parameter.
# This used to be hard-coded to '1', which silently ignored the parameter.
# str() keeps the environment assignment valid if GPU_ID arrives as an int.
# It is set here, before the cell that imports torch.
os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID)
sys.path.append(os.path.abspath(".."))       # for 'protonet_mnist_add_utils' folder
sys.path.append(os.path.abspath("../.."))    # for 'data' folder
sys.path.append(os.path.abspath("../../..")) # for 'models' and 'datasets' folders


print(sys.path)

In [None]:
import cv2
import torch
import torch.nn as nn
import argparse
import datetime
import importlib
import setproctitle, socket, uuid
import torch.nn.functional as F
import matplotlib.pyplot as plt

from tqdm import tqdm
from argparse import Namespace
from numpy import float32, zeros
from datasets import get_dataset
from models import get_model
from models.mnistdpl import MnistDPL
from warmup_scheduler import GradualWarmupScheduler
from datasets.utils.base_dataset import BaseDataset
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from utils import fprint
from utils.train import train
from utils.test import test
from utils.preprocess_resnet import preprocess
from utils.conf import *
from utils.args import *
from utils.status import progress_bar
from utils.checkpoint import save_model, create_load_ckpt
from utils.dpl_loss import ADDMNIST_DPL
from utils.metrics import (
    evaluate_metrics,
    evaluate_mix,
    mean_entropy,
)
from sklearn.metrics import confusion_matrix
from sklearn.metrics import silhouette_samples
from scipy.spatial.distance import cdist

from protonet_kand_modules.arguments import args_dpl, args_sl, args_ltn
from protonet_kand_modules.utility_modules.check_gpu import my_gpu_info

# SETUP

In [None]:
# Pick the argument namespace for the requested model; any name other than
# 'dpl' or 'sl' falls back to the LTN arguments.
args = {'dpl': args_dpl, 'sl': args_sl}.get(model_parameter_name, args_ltn)

# saving
save_folder = "kandinsky"
save_model_name = model_parameter_name
save_path = os.path.join(
    "..",
    "NEW-outputs",
    save_folder,
    "baseline",
    save_model_name,
    f"DEBUG-supervisions-via-augmentations-{uns_parameter_percentage}",
)
save_paths = [save_path]
print(f"Save paths: {save_paths}")

# This notebook only covers the baseline models: refuse prototype-based setups.
uses_pnet_model = args.model in ('prokandsl', 'prokandltn', 'prokanddpl')
if uses_pnet_model or args.prototypes:
    raise ValueError("This experiment is NOT meant for pNet based models.")

In [None]:
# Upper-case aliases of the notebook parameters, read by the cells below.
MODEL, UNS_PERCENTAGE, CONCEPT_LOSS_WEIGHT = (
    model_parameter_name,
    uns_parameter_percentage,
    sup_loss_weight,
)

# Only SL may use a concept-loss weight different from 1.0.
if MODEL != 'sl':
    if CONCEPT_LOSS_WEIGHT > 1.0:
        raise Exception("Concept loss weight should be less than or equal to 1.0 for DPL and LTN")
    assert CONCEPT_LOSS_WEIGHT == 1.0, 'Loss weight greater than 1 is only for SL'

# The double space after each colon reproduces print()'s default separator.
print(f"Model:  {MODEL}")
print(f"Unsupervised Percentage:  {UNS_PERCENTAGE}")
print(f"Concept Loss Weight:  {CONCEPT_LOSS_WEIGHT}")

In [None]:
# Helper imported from protonet_kand_modules.utility_modules.check_gpu;
# presumably reports the visible GPU(s) — confirm against that module.
my_gpu_info()

In [None]:
# Add uuid, timestamp and hostname for logging
args.conf_jobnum = str(uuid.uuid4())
args.conf_timestamp = str(datetime.datetime.now())
args.conf_host = socket.gethostname()

# set job name
setproctitle.setproctitle(
    "{}_{}_{}".format(
        args.model,
        args.buffer_size if "buffer_size" in args else 0,
        args.dataset,
    )
)

# DATASET

## Supervisions

In [None]:
# Load the annotated images and their labels, in that order, from the
# YOLO-annotation folder (paths are relative to the working directory).
proto_images, proto_labels = (
    torch.load(tensor_path)
    for tensor_path in (
        'data/kand_annotations/yolo_annotations/images.pt',
        'data/kand_annotations/yolo_annotations/labels.pt',
    )
)
print("Prototypical data loaded")
# The double space after each colon reproduces print()'s default separator.
print(f"Images:  {proto_images.shape}")
print(f"Labels:  {proto_labels.shape}")

In [None]:
class SupervisedDataset(Dataset):
    """Map-style dataset pairing each annotated image with its concept labels."""

    def __init__(self, images, labels, transform=None):
        """
        Store the tensors as they are; nothing is copied or validated.

        Args:
            images (Tensor): Tensor of shape [N, 3, 64, 64]
            labels (Tensor): Tensor of shape [N, 6] where:
                             - labels[:, :3] are the shape labels  (0: square, 1: circle, 2: triangle)
                             - labels[:, 3:] are the colour labels (0: red, 1: yellow, 2: blue)
            transform: Optional callable applied to each image on access.
        """
        self.transform = transform
        self.labels = labels  # shape [N, 6]
        self.images = images

    def __len__(self):
        """Number of samples, i.e. the length of the image container."""
        return len(self.images)

    def __getitem__(self, index):
        """Return (image, label): label cast to int64 with a leading size-1 dim removed."""
        sample = self.images[index]
        target = self.labels[index].long()
        if self.transform:
            sample = self.transform(sample)
        target = target.squeeze(0)
        return sample, target

In [None]:
# Create PrimitvesDataset instance
kand_sup_dataset = SupervisedDataset(proto_images, proto_labels, transform=None)
sup_train_loader = DataLoader(kand_sup_dataset, batch_size=args.batch_size, shuffle=True)

# Plot the first 30 images in kand_proto_dataset
fig, axes = plt.subplots(3, 10, figsize=(20, 6))
axes = axes.flatten()
for i in range(30):
    image, labels = kand_sup_dataset[i]
    axes[i].imshow(image.permute(1, 2, 0))  # Convert from CHW to HWC for plotting
    axes[i].set_title(f"{labels.numpy()}")
    axes[i].axis("off")
plt.tight_layout()
plt.show()

## Unsupervised

In [None]:
# Build the dataset object and its three loaders, then the backbone, the
# model, the loss and finally the optimiser. The statement order is kept
# as-is because later calls consume the results of earlier ones.
dataset = get_dataset(args)
(
    unsup_train_loader,
    unsup_val_loader,
    unsup_test_loader,
) = dataset.get_data_loaders()
dataset.print_stats()
n_images, c_split = dataset.get_split()
encoder, decoder = dataset.get_backbone()
model = get_model(args, encoder, decoder, n_images, c_split)
loss = model.get_loss(args)
model.start_optim(args)

# Echo the assembled configuration; the double space after each colon
# reproduces print()'s default separator.
print(f"Using Dataset:  {dataset}")
print(f"Number of images:  {n_images}")
print(f"Using backbone:  {encoder}")
print(f"Using Model:  {model}")
print(f"Using Loss:  {loss}")
print(f"Working with taks:  {args.task}")

# TRAINING LOOP

In [None]:
def train(model:MnistDPL,
        sup_train_loader:DataLoader,
        unsup_train_loader:DataLoader,
        unsup_val_loader:DataLoader,
        _loss: ADDMNIST_DPL,
        args,
        seed,
        save_folder,
        sup_loss_weight=1.0,
        patience=5,
        debug=False):
    """Train `model` by alternating a supervised and an unsupervised phase per epoch.

    Each epoch first runs over `sup_train_loader`: single annotated images are
    concatenated three at a time along the width and the model's "CS" output
    is trained with a cross-entropy concept loss. It then runs over
    `unsup_train_loader` using `_loss`; every batch is kept with probability
    `UNS_PERCENTAGE` (a module-level value). After the epoch the model is
    evaluated on `unsup_val_loader` and its weights are saved whenever the
    concept accuracy improves.

    Args:
        model: Network exposing `.device`, `.opt` and `state_dict()`; calling
            it returns a dict with the keys "CS", "YS" and "pCS".
        sup_train_loader: Yields `(images, labels)` with images of shape
            [B, 3, 64, 64] and labels of shape [B, 6] (both asserted below).
        unsup_train_loader: Yields `(images, labels, concepts)`.
        unsup_val_loader: Loader handed to `evaluate_metrics`.
        _loss: Callable `(out_dict, args) -> (loss, losses)`. If it has a
            `grade` attribute, `update_grade(epoch)` is called after warm-up.
        args: Namespace; reads `exp_decay`, `warmup_steps`, `proto_epochs`,
            `task`, `tuning` and, when present, `debug`.
        seed: Seed for the Python, NumPy and torch random generators.
        save_folder: Destination passed to `torch.save` — despite the name,
            this is the checkpoint file path.
        sup_loss_weight: Multiplier applied to the summed concept loss.
        patience: Number of epochs without improvement that triggers early
            stopping.
        debug: When true (or when `args.debug` is true), plot the first merged
            triplet of every supervised batch.

    Returns:
        The best concept accuracy seen on the validation loader. It stays 0.0
        when `UNS_PERCENTAGE == 0.0`, because evaluation is skipped then.
    """
    # Imported here so the function does not depend on `random` / `np`
    # happening to leak in through the notebook's star imports.
    import random
    import numpy as np

    # for full reproducibility (the `random` module drives the batch skipping
    # below, so it has to be seeded as well)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = False

    best_cacc = 0.0
    epochs_no_improve = 0   # for early stopping

    # Each figure carries 6 concept labels, each one a 3-way classification.
    num_objects = 6
    num_classes = 3

    def triplet_loss(logits_slice, labels_slice):
        """Cross-entropy between one figure's [B, 18] logits and its [B, 6] labels."""
        n_rows = logits_slice.size(0) * num_objects
        l_flat = logits_slice.reshape(n_rows, num_classes)
        t_flat = labels_slice.reshape(n_rows)
        return F.cross_entropy(l_flat, t_flat)

    model.to(model.device)

    scheduler = torch.optim.lr_scheduler.ExponentialLR(model.opt, args.exp_decay)
    w_scheduler = None
    if args.warmup_steps > 0:
        w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps)

    fprint("\n--- Start of Training ---\n")

    # default for warm-up
    model.opt.zero_grad()
    model.opt.step()

    # & FOR EACH EPOCH
    for epoch in range(args.proto_epochs):  # ^ ensure consistency with the number of epochs used for prototypical networks
        model.train()

        ###############################
        # 1. Supervised phase: Teach the model to recognize primitives (merged in triplets)
        ###############################
        print("Start of supervised episodic training.")
        for i, (images, labels) in enumerate(sup_train_loader):
            sup_images = images.to(model.device)  # shape: (batch_size, 3, 64, 64)
            sup_labels = labels.to(model.device)  # shape: (batch_size, 6)
            batch_size = sup_images.size(0)

            assert sup_images.dim() == 4 and sup_images.size(1) == 3 \
                and sup_images.size(2) == 64 and sup_images.size(3) == 64, \
                f"Expected sup_images [B,3,64,64], got {sup_images.shape}"
            assert sup_labels.shape == torch.Size([batch_size, 6]), \
                f"Expected sup_labels [{batch_size},6], got {sup_labels.shape}"

            # make batch_size divisible by 3 by dropping the extra 1 or 2 samples
            drop = batch_size % 3
            if drop:
                sup_images = sup_images[:-drop]
                sup_labels = sup_labels[:-drop]
                batch_size -= drop

            # A trailing batch of 1 or 2 samples is now empty. A mean-reduced
            # cross-entropy over zero targets is NaN and the optimizer step
            # would then corrupt the weights, so skip such a batch.
            if batch_size == 0:
                continue

            # now form triplets: 0 with 1 with 2, 3 with 4 with 5, ...
            merged_images = torch.cat([
                sup_images[0::3],   # first in each triplet
                sup_images[1::3],   # second
                sup_images[2::3]    # third
            ], dim=3)  # concat along width → new width = 64*3 = 192

            expected_bs = batch_size // 3
            assert merged_images.shape == torch.Size([expected_bs, 3, 64, 192]), \
                f"Expected merged_images [{expected_bs},3,64,192], got {merged_images.shape}"

            # extract labels for each of the three primitives in the triplet
            labels_first  = sup_labels[0::3]  # [bs//3, 6]
            labels_second = sup_labels[1::3]  # [bs//3, 6]
            labels_third  = sup_labels[2::3]  # [bs//3, 6]

            # Honour the `debug` argument as well as the args-level switch.
            if debug or getattr(args, "debug", False):
                img = merged_images[0].cpu().permute(1, 2, 0).numpy()
                plt.imshow(img)
                plt.title(f"Labels: {labels_first[0].tolist()} | "
                        f"{labels_second[0].tolist()} | "
                        f"{labels_third[0].tolist()}")
                plt.axis('off')
                plt.show()

            # Forward pass: now feeding the concatenated triplets
            out_dict = model(merged_images)
            logits = out_dict["CS"]

            # one concept loss per position in the triplet; each slice is [B, 18]
            loss1 = triplet_loss(logits[:, 0, :], labels_first)
            loss2 = triplet_loss(logits[:, 1, :], labels_second)
            loss3 = triplet_loss(logits[:, 2, :], labels_third)

            # total concept loss (you can average instead of sum if you prefer)
            concept_loss = sup_loss_weight * (loss1 + loss2 + loss3)

            # backprop
            model.opt.zero_grad()
            concept_loss.backward()
            model.opt.step()

            if i % 10 == 0:
                print(f"Supervised phase, Epoch {epoch}, Batch {i}: "
                    f"Concept Loss = {concept_loss.item():.4f}")

        ###############################
        # 2. Original unsupervised training phase (sum prediction)
        ###############################
        # ys are the predictions of the model, y_true are the true labels, cs are the predictions of the concepts, cs_true are the true concepts
        ys, y_true, cs, cs_true = None, None, None, None

        # & FOR EACH BATCH
        print("Start of unsupervised training.")
        for i, data in enumerate(unsup_train_loader):
            if random.random() > UNS_PERCENTAGE:
                continue  # Skip this batch with probability (1 - percentage)

            images, labels, concepts = data
            images, labels, concepts = (
                images.to(model.device),
                labels.to(model.device),
                concepts.to(model.device),
            )

            # ^ baseline model
            out_dict = model(images)
            out_dict.update({"LABELS": labels, "CONCEPTS": concepts})

            model.opt.zero_grad()
            loss, losses = _loss(out_dict, args)

            loss.backward()
            model.opt.step()

            # Accumulate detached copies only: keeping the raw outputs would
            # hold references to every batch's autograd graph for the whole
            # epoch.
            batch_ys = out_dict["YS"].detach()
            batch_y_true = out_dict["LABELS"].detach()
            batch_cs = out_dict["pCS"].detach()
            batch_cs_true = out_dict["CONCEPTS"].detach()
            if ys is None:
                ys, y_true, cs, cs_true = batch_ys, batch_y_true, batch_cs, batch_cs_true
            else:
                ys = torch.cat((ys, batch_ys), dim=0)
                y_true = torch.cat((y_true, batch_y_true), dim=0)
                cs = torch.cat((cs, batch_cs), dim=0)
                cs_true = torch.cat((cs_true, batch_cs_true), dim=0)

            if i % 10 == 0:
                progress_bar(i, len(unsup_train_loader) - 9, epoch, loss.item())

        print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
        print("End of epoch ", epoch)
        print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
        print()

        # Purely supervised run: no validation-based selection, just
        # overwrite the checkpoint after every epoch.
        if UNS_PERCENTAGE == 0.0:
            print("Saving...")
            torch.save(model.state_dict(), save_folder)
            print(f"Saved best model with F1 score: {best_cacc}")
            print()
            continue

        # ys is None when every unsupervised batch of this epoch was randomly
        # skipped; argmax on None would raise. These values are not consumed
        # further in this function.
        if ys is not None:
            y_pred = torch.argmax(ys, dim=-1)
            #print("Argmax predictions have shape: ", y_pred.shape)

            if "patterns" in args.task:
                y_true = y_true[:, -1]  # it is the last one

        model.eval()
        tloss, cacc, yacc, f1 = evaluate_metrics(model, unsup_val_loader, args)

        # update the (warmup) scheduler at end of the epoch
        if epoch < args.warmup_steps:
            w_scheduler.step()
        else:
            scheduler.step()
            if hasattr(_loss, "grade"):
                _loss.update_grade(epoch)

        ### LOGGING ###
        fprint("  ACC C", cacc, "  ACC Y", yacc, "F1 Y", f1)
        print()

        if not args.tuning and cacc > best_cacc:
            print("Saving...")
            # Update the best concept accuracy
            best_cacc = cacc
            epochs_no_improve = 0

            # Save the best model and the concept extractor
            torch.save(model.state_dict(), save_folder)
            print(f"Saved best model with CACC score: {best_cacc}")
            print()

        elif cacc <= best_cacc:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break


    fprint("\n--- End of Training ---\n")
    return best_cacc

# RUN ALL THINGS

In [None]:
# Single-seed run: train, remember the score returned by train() under the
# seed, then persist the model through the project's save_model helper.
f1_scores = {}
print(f"*** Training model with seed {seed}")
print(f"Chosen device: {model.device}")
if not os.path.exists(save_path):
    os.makedirs(save_path, exist_ok=True)
# From here on save_folder is the checkpoint *file* path, not a directory.
save_folder = os.path.join(save_path, f"{save_model_name}_{seed}.pth")
print(f"Saving model in folder:  {save_folder}")

best_f1 = train(
    model=model,
    sup_train_loader=sup_train_loader,
    unsup_train_loader=unsup_train_loader,
    unsup_val_loader=unsup_val_loader,
    _loss=loss,
    args=args,
    seed=seed,
    save_folder=save_folder,
    sup_loss_weight=CONCEPT_LOSS_WEIGHT,
    debug=False,
)
f1_scores[seed] = best_f1
save_model(model, args, seed)  # save the model parameters

print(f"*** Finished training model with seed {seed}")

print("Training finished.")
# Key with the highest recorded score (only one entry in this notebook).
best_weight_seed = max(f1_scores, key=lambda run_key: f1_scores[run_key])
print(f"Best weight and seed combination: {best_weight_seed} with F1 score: {f1_scores[best_weight_seed]}")