# PARAMETERS

In [None]:
SEED = 1                     # experiment seed; also used in the checkpoint file names below
MODEL_PARAMETER_NAME = 'ltn' # selects the argument preset: 'dpl', 'sl', anything else -> LTN
GPU_ID = '1'                 # exported as CUDA_VISIBLE_DEVICES in the next cell

# IMPORTS

In [None]:
import sys
import os

# Restrict CUDA to the chosen GPU; done before torch is imported (next cell) so it takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
sys.path.append(os.path.abspath(".."))       # presumably for the 'protonet_kand_modules' package imported below — confirm
sys.path.append(os.path.abspath("../.."))    # for 'data' folder
sys.path.append(os.path.abspath("../../..")) # for 'models' and 'datasets' folders


print(sys.path)

In [None]:
import cv2
import torch
import argparse
import datetime
import importlib
import matplotlib.pyplot as plt
import setproctitle, socket, uuid
import torch.nn.functional as F
import torchvision.transforms as T

from tqdm import tqdm
from collections import Counter
from argparse import Namespace
from numpy import float32, zeros
from datasets import get_dataset
from models import get_model
from models.mnistdpl import MnistDPL
from warmup_scheduler import GradualWarmupScheduler
from datasets.utils.base_dataset import BaseDataset
from torchvision import datasets, transforms
from utils import fprint
from utils.train import train
from utils.test import test
from utils.preprocess_resnet import preprocess
from utils.conf import *
from utils.args import *
from utils.status import progress_bar
from utils.checkpoint import save_model, create_load_ckpt
from utils.dpl_loss import ADDMNIST_DPL
from utils.metrics import (
    evaluate_metrics,
    evaluate_mix,
    mean_entropy,
)
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset

from sklearn.metrics import confusion_matrix
from sklearn.metrics import silhouette_samples
from scipy.spatial.distance import cdist

from ultralytics import YOLO

from protonet_kand_modules.arguments import args_dpl, args_sl, args_ltn
from protonet_kand_modules.utility_modules.check_gpu import my_gpu_info
from protonet_kand_modules.data_modules.proto_data_creation import get_support_loader

%matplotlib inline

# SETUP

In [None]:
# Pick the argument preset for the chosen model; any unrecognised name falls back to LTN.
args = {'dpl': args_dpl, 'sl': args_sl}.get(MODEL_PARAMETER_NAME, args_ltn)

# Output directory: ../outputs/kand/baseline-kandinsky-single-joint/<model name>
save_folder = "kand"
save_model_name = MODEL_PARAMETER_NAME
save_path = os.path.join(
    "..", "outputs", save_folder, "baseline-kandinsky-single-joint", save_model_name
)
save_paths = [save_path]
print(f"Save paths: {str(save_paths)}")

# Prototypical-network (pNet) models are out of scope for this baseline.
if args.model in ('prokandsl', 'prokandltn', 'prokanddpl') or args.prototypes:
    raise ValueError("This experiment is NOT meant for pNet based models.")

In [None]:
# Presumably reports the visible GPU(s) — helper from protonet_kand_modules.utility_modules.
my_gpu_info()

In [None]:
# Attach run metadata (unique job id, start timestamp, host name, GPU id) to args for logging.
for _attr, _value in (
    ("conf_jobnum", str(uuid.uuid4())),
    ("conf_timestamp", str(datetime.datetime.now())),
    ("conf_host", socket.gethostname()),
    ("GPU_ID", GPU_ID),
):
    setattr(args, _attr, _value)

# Name the OS process "<model>_<buffer_size or 0>_<dataset>" so the job is easy to identify.
setproctitle.setproctitle("{}_{}_{}".format(
    args.model, args.buffer_size if "buffer_size" in args else 0, args.dataset
))

## Annotated Dataset Creation

### Loading the data

In [None]:
# Annotated prototype images and their labels (torch.load unpickles: only use trusted files).
proto_images, proto_labels = (
    torch.load('../data/kand_annotations/pnet_proto/concept_prototypes.pt'),
    torch.load('../data/kand_annotations/pnet_proto/labels_prototypes.pt'),
)
print("Prototypical data loaded")
for _caption, _tensor in (("Images: ", proto_images), ("Labels: ", proto_labels)):
    print(_caption, _tensor.shape)

support_loader = get_support_loader(proto_images, proto_labels, query_batch_size=32, debug=False)

### Creating the Datasets

In [None]:
class PrimitivesDataset(Dataset):
    """Dataset of primitive images with separate shape and colour targets.

    Each item is ``(image, shape_label, color_label)``; both labels are
    returned as ``long`` tensors.
    """

    def __init__(self, images, labels, transform=None):
        """
        Args:
            images (Tensor): image tensor, expected shape [N, 3, 64, 64].
            labels (Tensor): tensor of shape [N, 2]; column 0 holds the shape
                label (0: square, 1: circle, 2: triangle) and column 1 the
                colour label (0: red, 1: yellow, 2: blue).
            transform: optional callable applied to each image when fetched.
        """
        self.images = images
        self.labels = labels  # [N, 2]: (shape, colour)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        shape_target, color_target = (
            self.labels[index, column].long() for column in (0, 1)
        )
        sample = self.images[index]
        if self.transform:
            sample = self.transform(sample)
        return sample, shape_target.squeeze(), color_target.squeeze()

### Creating the Episodes

In [None]:
from torch.utils.data import Sampler

class FixedBatchSampler(Sampler):
    """Deterministic batch sampler that walks the dataset in order.

    Iteration ``i`` yields ``batch_size`` consecutive indices starting at
    ``(i * batch_size) % dataset_size``. Indices wrap around to 0 at the end of
    the dataset, so every batch holds exactly ``batch_size`` valid indices.
    This holds even when ``batch_size`` exceeds ``dataset_size``.
    """

    def __init__(self, dataset_size, batch_size, iterations):
        """
        Args:
            dataset_size (int): number of items in the dataset (must be > 0).
            batch_size (int): number of indices per batch.
            iterations (int): number of batches yielded per pass.

        Raises:
            ValueError: if ``dataset_size`` is not positive.
        """
        if dataset_size <= 0:
            raise ValueError("dataset_size must be a positive integer")
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.iterations = iterations

    def __iter__(self):
        for i in range(self.iterations):
            start = (i * self.batch_size) % self.dataset_size
            # Reduce every index modulo the size (not just a single wrap), so a
            # batch larger than the dataset never produces out-of-range indices.
            yield [(start + offset) % self.dataset_size for offset in range(self.batch_size)]

    def __len__(self):
        return self.iterations


In [None]:
class PrototypicalBatchSampler(object):
    """
    Yields a batch of indices for episodic training.

    At each iteration it randomly selects ``classes_per_it`` classes and then
    picks ``num_samples`` distinct samples from each selected class.

    If fewer than ``classes_per_it`` classes exist, all classes are used. If a
    class has fewer than ``num_samples`` samples, all of them are used. Batches
    therefore never contain invalid (uninitialised) indices. In those cases,
    however, a batch may be smaller than ``classes_per_it * num_samples``.
    """
    def __init__(self, labels, classes_per_it, num_samples, iterations):
        """
        Args:
            labels (array-like): 1D array or list of labels for the target task.
                                 This should be either the shape labels or the colour labels.
            classes_per_it (int): Number of random classes for each iteration.
            num_samples (int): Number of samples per class (support + query) in each episode.
            iterations (int): Number of iterations (episodes) per epoch.
        """
        self.labels = np.array(labels)
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples
        self.iterations = iterations

        unique_labels, self.counts = np.unique(self.labels, return_counts=True)
        self.classes = torch.LongTensor(unique_labels)

        # Index matrix of shape (num_classes, max_samples_in_class): row c lists
        # the dataset indices of class c in ascending order, padded with -1.
        num_classes = len(unique_labels)
        max_count = int(self.counts.max()) if num_classes else 0
        self.indexes = torch.full((num_classes, max_count), -1, dtype=torch.long)
        self.numel_per_class = torch.zeros(num_classes, dtype=torch.long)

        # One vectorised lookup per class instead of a search per element.
        for class_idx, label in enumerate(unique_labels):
            members = np.flatnonzero(self.labels == label)
            self.indexes[class_idx, :len(members)] = torch.from_numpy(members).long()
            self.numel_per_class[class_idx] = len(members)

    def __iter__(self):
        """
        Yield a LongTensor batch of dataset indices for each episode.
        """
        spc = self.sample_per_class
        cpi = self.classes_per_it

        for _ in range(self.iterations):
            # Randomly choose (up to) 'classes_per_it' classes
            c_idxs = torch.randperm(len(self.classes))[:cpi]
            parts = []
            for class_idx in c_idxs:
                # Randomly choose (up to) 'num_samples' indices for the class;
                # only the filled (non -1) slots of the row are eligible.
                perm = torch.randperm(int(self.numel_per_class[class_idx]))
                parts.append(self.indexes[class_idx, perm[:spc]])
            # Concatenate only the real picks (no uninitialised padding).
            batch = torch.cat(parts) if parts else torch.empty(0, dtype=torch.long)
            # Shuffle the batch indices
            yield batch[torch.randperm(len(batch))]

    def __len__(self):
        return self.iterations

### Instantiating the DataLoaders

In [None]:
# Wrap the prototype tensors in a dataset yielding (image, shape_label, colour_label).
kand_proto_dataset = PrimitivesDataset(proto_images, proto_labels, transform=None)

# 1D numpy label arrays per concept (kand_proto_dataset.labels is [N, 2]: shape, colour).
shape_labels, color_labels = (
    kand_proto_dataset.labels[:, column].numpy() for column in (0, 1)
)

# One episodic sampler per concept, both configured from args.
shape_sampler, color_sampler = (
    PrototypicalBatchSampler(concept_labels, args.classes_per_it, args.num_samples, args.iterations)
    for concept_labels in (shape_labels, color_labels)
)

# Episodic dataloaders, one per primitive concept.
episodic_shape_dataloader = DataLoader(kand_proto_dataset, batch_sampler=shape_sampler)
episodic_color_dataloader = DataLoader(kand_proto_dataset, batch_sampler=color_sampler)

print(f"Number of episodes (shape): {len(episodic_shape_dataloader)}")
print(f"Number of episodes (color): {len(episodic_color_dataloader)}")

### Checking the Data

In [None]:
# Sanity checks on the prototype data.
# Expected: 3 shape labels (0: square, 1: circle, 2: triangle) and 3 colour labels (0: red, 1: yellow, 2: blue).
num_distinct_shape_labels = np.unique(shape_labels).size
print(f"Number of distinct shape labels: {num_distinct_shape_labels}")
num_distinct_color_labels = np.unique(color_labels).size
print(f"Number of distinct color labels: {num_distinct_color_labels}")

# The episodic samplers are built from numpy label arrays
assert isinstance(shape_labels, np.ndarray), "shape labels should be a numpy.ndarray"
assert isinstance(color_labels, np.ndarray), "color labels should be a numpy.ndarray"

# Check tensor shapes and values
assert kand_proto_dataset.images.shape == (shape_labels.size, 3, 64, 64), \
    "The shape of kand_proto_dataset.images should be (number of shape labels, 3, 64, 64)"
assert kand_proto_dataset.images.shape == (color_labels.size, 3, 64, 64), \
    "The shape of kand_proto_dataset.images should be (number of color labels, 3, 64, 64)"
assert kand_proto_dataset.labels.shape == (shape_labels.size, 2), \
    "The shape of kand_proto_dataset.labels should be (number of shape labels, 2)"
assert kand_proto_dataset.labels.shape == (color_labels.size, 2), \
    "The shape of kand_proto_dataset.labels should be (number of color labels, 2)"
assert kand_proto_dataset.images.min() >= 0 and kand_proto_dataset.images.max() <= 1, \
    "The values of kand_proto_dataset.images should be between 0 and 1"
assert np.all(np.isin(shape_labels, [0, 1, 2])), "Shape labels should only contain values 0, 1, or 2"
assert np.all(np.isin(color_labels, [0, 1, 2])), "Color labels should only contain values 0, 1, or 2"


def _show_first_episode(dataloader, label_position, label_name):
    """Print label statistics for the first episode of *dataloader* and plot up to 10 images.

    Args:
        dataloader: episodic DataLoader yielding (images, shape_labels, color_labels).
        label_position: index into the batch tuple of the labels to inspect
            (1 for shape labels, 2 for colour labels).
        label_name: "Shape" or "Color"; used in the printed captions and plot titles.
    """
    batch = next(iter(dataloader), None)
    if batch is None:
        print(f"The {label_name.lower()} dataloader yielded no episodes.")
        return
    images, batch_labels = batch[0], batch[label_position]
    labels_list = batch_labels.tolist()
    print("Batch images shape:", images.shape)  # Expected: [batch_size, 3, 64, 64]
    print(f"Batch {label_name.lower()} labels:", labels_list)
    print(f"{label_name} label distribution in batch:", Counter(labels_list))

    _, axes = plt.subplots(2, 5, figsize=(15, 6))
    for ax in axes.flat:
        ax.axis("off")
    # zip stops at the shorter sequence, so episodes with < 10 images do not raise IndexError.
    for ax, img, lab in zip(axes.flat, images, labels_list):
        ax.imshow(img.permute(1, 2, 0).numpy())  # (3, 64, 64) -> (64, 64, 3) for display
        ax.set_title(f"{label_name} Label: {lab}")
    plt.tight_layout()
    plt.show()


### Inspect one batch from the shape dataloader
_show_first_episode(episodic_shape_dataloader, 1, "Shape")

### Inspect one batch from the color dataloader
print("\nColor-based episodic batch:")
_show_first_episode(episodic_color_dataloader, 2, "Color")

# MODEL & UNSUPERVISED DATASET LOADING

In [None]:
# Display the name of the selected model (notebook cell output).
args.model

In [None]:
# Build the dataset and its backbone, then the model, its loss and its optimizer.
dataset = get_dataset(args)
n_images, c_split = dataset.get_split()
encoder, decoder = dataset.get_backbone()
model = get_model(args, encoder, decoder, n_images, c_split)
loss = model.get_loss(args)
model.start_optim(args)

for _caption, _obj in (
    ("Using Dataset: ", dataset),
    ("Number of images: ", n_images),
    ("Using backbone: ", encoder),
    ("Using Model: ", model),
    ("Using Loss: ", loss),
    ("Working with taks: ", args.task),
):
    print(_caption, _obj)

# YOLO

In [None]:
# YOLO concept extractor: dataset config, project (output) folder and pretrained weights.
yaml_path = os.path.join(os.getcwd(), "../data/kand_config_yolo.yaml")
my_yolo_project_path = "ultralytics-4/"
# Derive the weights path from the project folder instead of repeating the prefix.
my_yolo_premodel_path = os.path.join(my_yolo_project_path, "pretrained", "yolo11n.pt")
args.yolo_folder = my_yolo_project_path

yolo = YOLO(my_yolo_premodel_path)

# PRETRAINING

In [None]:
# Display the encoder module that is pretrained in the next cells.
model.encoder

In [None]:
def evaluate_model(model, data_loader):
    """Report shape and colour accuracy of ``model.encoder`` on ``data_loader``.

    The encoder output is split into two groups of 3 logits each. The first
    three are read as shape logits and the last three as colour logits.

    Args:
        model: object exposing ``encoder`` (a torch module) and ``device``.
        data_loader: iterable of ``(images, shape_labels, color_labels)`` batches.

    Returns:
        tuple[float, float, float]: ``(shape_acc, color_acc, overall_acc)``.
        ``overall_acc`` is the mean of the two. All three are 0.0 when the loader
        yields no samples (instead of dividing by zero).

    Note:
        Leaves the encoder in evaluation mode.
    """
    model.encoder.eval()  # Switch to evaluation mode
    correct_shape = 0
    correct_color = 0
    total = 0

    with torch.no_grad():
        for images, shape_labels, color_labels in data_loader:
            images = images.to(model.device)
            shape_labels = shape_labels.to(model.device)
            color_labels = color_labels.to(model.device)

            shape_preds, color_preds = torch.split(model.encoder(images), [3, 3], dim=1)

            # Compare predicted class indices with ground truth
            correct_shape += (shape_preds.argmax(dim=1) == shape_labels).sum().item()
            correct_color += (color_preds.argmax(dim=1) == color_labels).sum().item()
            total += images.size(0)

    if total == 0:
        print("No samples to evaluate.")
        return 0.0, 0.0, 0.0

    shape_acc = correct_shape / total
    color_acc = correct_color / total
    overall_acc = (shape_acc + color_acc) / 2  # average of both

    print(f"Shape Accuracy: {shape_acc:.4f}")
    print(f"Color Accuracy: {color_acc:.4f}")
    print(f"Overall Accuracy: {overall_acc:.4f}")
    return shape_acc, color_acc, overall_acc

In [None]:
def pre_train(model, train_loader, args, seed: int = 0):
    """Supervised pretraining of ``model.encoder`` on the annotated primitives.

    The loss is the sum of two cross-entropy terms. One is computed on the
    first three encoder logits (shape), the other on the last three (colour).
    Training runs for ``args.proto_epochs`` epochs with Adam (``args.lr``,
    ``args.weight_decay``). Afterwards, accuracies are printed with
    ``evaluate_model`` on the same loader.

    Args:
        model: object exposing ``encoder`` (torch module) and ``device``.
        train_loader: iterable of ``(images, shape_labels, color_labels)``.
            Images must be [B, 3, 64, 64] and both label tensors [B].
        args: namespace providing ``lr``, ``weight_decay`` and ``proto_epochs``.
        seed: seed for the Python, NumPy and PyTorch RNGs.

    Side effects:
        Sets ``torch.backends.cudnn.enabled = False`` for reproducibility. This
        stays in effect for the rest of the session, including later training.
    """
    import random

    # ^ PHASE 0: PreTraining
    # Full reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = False

    # Optimizer for encoder
    enc_opt = torch.optim.Adam(
        model.encoder.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    fprint("\n--- Start of Training ---\n")
    model.encoder.to(model.device)
    model.encoder.train()

    # Define separate loss functions for shape and color prediction
    shape_loss_fn = torch.nn.CrossEntropyLoss()
    color_loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(args.proto_epochs):
        for i, batch in enumerate(train_loader):
            # Every batch carries both label kinds, so both heads are trained
            # regardless of which concept the episodes were sampled by.
            images, shape_labels, color_labels = batch

            # Move to device
            images = images.to(model.device)
            shape_labels = shape_labels.to(model.device)
            color_labels = color_labels.to(model.device)

            batch_size = images.size(0)

            # Shape checks
            assert images.shape == torch.Size([batch_size, 3, 64, 64]), \
                f"Expected shape [{batch_size}, 3, 64, 64], but got {images.shape}"
            assert shape_labels.shape == torch.Size([batch_size]), \
                f"Expected shape [{batch_size}], but got {shape_labels.shape}"
            assert color_labels.shape == torch.Size([batch_size]), \
                f"Expected shape [{batch_size}], but got {color_labels.shape}"

            enc_opt.zero_grad()

            # Forward pass; the 6 output logits are split into shape / colour heads
            preds = model.encoder(images)
            shape_preds, color_preds = torch.split(preds, [3, 3], dim=1)

            assert shape_preds.shape == (batch_size, 3), \
                f"Expected shape_preds ({batch_size}, 3), but got {shape_preds.shape}"
            assert color_preds.shape == (batch_size, 3), \
                f"Expected color_preds ({batch_size}, 3), but got {color_preds.shape}"

            # Named total_loss so it is not confused with the notebook-level `loss` object
            total_loss = shape_loss_fn(shape_preds, shape_labels) + color_loss_fn(color_preds, color_labels)

            total_loss.backward()
            enc_opt.step()

            progress_bar(i, len(train_loader), epoch, total_loss.item())

    evaluate_model(model, train_loader)


# Use the notebook's SEED parameter so pretraining follows the configured seed.
pre_train(model, episodic_shape_dataloader, args, seed=SEED)

# TRAINING LOOP

In [None]:
def train(model: MnistDPL,
    dataset: BaseDataset, 
    concept_extractor,
    concept_extractor_training_path,
    concept_extractor_project_path,
    transform,
    _loss: ADDMNIST_DPL,
    args,
    save_folder: str,
    patience: int = 3
    ):
    """Alternately train the YOLO concept extractor and the main model, with early stopping.

    Epoch 0 is a baseline pass. The extractor is trained, but the main model is
    only run forward (no gradients, no optimizer step). In every later epoch,
    the extractor is reloaded from the previous round's ``last.pt`` and trained
    again, then the main model is trained for one pass over the training split.

    After each epoch, the model is evaluated on the validation split and the
    best concept accuracy (CACC) is tracked. Unless ``args.tuning`` is set, the
    model state dict and the extractor are saved whenever CACC improves.

    Args:
        model: model exposing ``device``, ``opt``, ``encoder`` and a forward of
            the form ``model(images, concept_extractor, transform, args)``.
        dataset: provides ``get_data_loaders()`` -> (train, val, test) and ``print_stats()``.
        concept_extractor: YOLO model; ``.train(...)`` is called on it every epoch.
        concept_extractor_training_path: passed to YOLO training as ``data=``.
        concept_extractor_project_path: passed to YOLO training as ``project=``.
        transform: forwarded to the model and to ``evaluate_metrics``.
        _loss: callable ``(out_dict, args) -> (loss, losses)``.
        args: run configuration (``proto_epochs``, ``extractor_training_epochs``,
            ``exp_decay``, ``warmup_steps``, ``tuning``).
        save_folder: file path for the best model's state dict. The extractor is
            saved next to it as ``best_<SEED>.pt``.
        patience: epochs without CACC improvement before stopping early.

    Returns:
        float: best validation CACC observed (also in tuning mode).
    """
    best_cacc = 0.0
    epochs_no_improve = 0     # for early stopping
    yolo_weights_path = None  # last.pt of the most recent extractor training run

    model.to(model.device)

    train_loader, val_loader, test_loader = dataset.get_data_loaders()
    dataset.print_stats()
    scheduler = torch.optim.lr_scheduler.ExponentialLR(model.opt, args.exp_decay)
    w_scheduler = None
    if args.warmup_steps > 0:
        w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps)

    fprint("\n--- Start of Training ---\n")

    # default for warm-up
    model.opt.zero_grad()
    model.opt.step()

    # * Start of training
    for epoch in range(args.proto_epochs + 1):  # first epoch is for determining the baseline accuracy
        print(f"Epoch {epoch+1}/{args.proto_epochs + 1}")
        is_baseline = epoch == 0

        # ^ PHASE 1: Training the Concept Extractor
        print('----------------------------------')
        print('--- Concept Extractor Training ---')
        if not is_baseline:
            # Continue from the weights produced by the previous round.
            concept_extractor = YOLO(yolo_weights_path)
        results = concept_extractor.train(data=concept_extractor_training_path,
                    epochs=args.extractor_training_epochs,
                    imgsz=64,
                    project=concept_extractor_project_path)
        yolo_weights_path = os.path.join(results.save_dir, "weights", "last.pt")

        # ^ PHASE 2: Main Model Training
        for i, data in enumerate(train_loader):

            if is_baseline:
                model.eval()
                assert not model.training, "Model should **NOT** be in training mode!"
                assert not model.encoder.training, "Encoder should **NOT** be in training mode!"
            else:
                model.train()
                model.opt.zero_grad()
                assert model.training, "Model should be in training mode!"
                assert model.encoder.training, "Encoder should be in training mode!"

            images, labels, concepts = data
            images, labels, concepts = (
                images.to(model.device),
                labels.to(model.device),
                concepts.to(model.device),
            )

            # No autograd graph in the baseline epoch: no optimizer step is taken,
            # so building the graph and back-propagating would only waste memory/time.
            with torch.set_grad_enabled(not is_baseline):
                out_dict = model(images, concept_extractor, transform, args)
                out_dict.update({"LABELS": labels, "CONCEPTS": concepts})
                loss, losses = _loss(out_dict, args)

            if not is_baseline:
                loss.backward()
                model.opt.step()

            if i % 10 == 0:
                progress_bar(i, len(train_loader) - 9, epoch, loss.item())

        model.eval()
        tloss, cacc, yacc, f1 = evaluate_metrics(model, val_loader, args, concept_extractor=concept_extractor, transform=transform)

        # update the (warmup) scheduler at end of the epoch
        if epoch < args.warmup_steps:
            w_scheduler.step()
        else:
            scheduler.step()
            if hasattr(_loss, "grade"):
                _loss.update_grade(epoch)

        ### LOGGING ###
        fprint("  ACC C", cacc, "  ACC Y", yacc, "F1 Y", f1)
        print()

        if cacc > best_cacc:
            # Track the best CACC in every mode (previously skipped when tuning,
            # which froze best_cacc at 0.0 and broke early stopping).
            if not args.tuning:
                print("Saving...")
            if best_cacc == 0.0:
                print("Baseline accuracy has been determined.")
            best_cacc = cacc
            epochs_no_improve = 0

            if not args.tuning:
                # Save the best model and the concept extractor
                torch.save(model.state_dict(), save_folder)
                concept_save_path = os.path.join(os.path.dirname(save_folder), f"best_{SEED}.pt")
                concept_extractor.save(concept_save_path)
                print(f"Saved best model with CACC score: {best_cacc}")
                print()
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    fprint("\n--- End of Training ---\n")
    return best_cacc

# RUN ALL THINGS

In [None]:
f1_scores = dict()  # NOTE(review): not used in the visible cells
print(f"*** Training model with seed {SEED}")
print("Chosen device:", model.device)
os.makedirs(save_path, exist_ok=True)  # exist_ok already covers a pre-existing directory
# Despite its name, save_folder is the checkpoint *file* path passed to train().
save_folder = os.path.join(save_path, f"{save_model_name}_{SEED}.pth")
print("Saving model in folder: ", save_folder)

best_cacc = train(model=model,
    dataset=dataset,
    concept_extractor=yolo,                              # yolo model
    concept_extractor_training_path=yaml_path,           # yolo training data path
    concept_extractor_project_path=my_yolo_project_path, # yolo project path
    transform=T.Resize((64, 64)),                        # resizer
    _loss=loss,
    args=args,
    save_folder=save_folder,
)
save_model(model, args, SEED)  # save the final model parameters (best ones were saved by train)

print(f"*** Finished training model with seed {SEED}")

print("Training finished.")
print(f"Best CACC score: {best_cacc}")