In [None]:
import pickle
import torch
import os
import sys


# Folder holding the pre-computed embedding datasets, relative to this notebook.
data_dir = os.path.join('..', "outputs", "embeddings")
loaders_path = os.path.join(data_dir, "loaders_datasets.pkl")

# SECURITY: pickle.load can execute arbitrary code stored in the file.
# Only load pickles that were produced by a trusted pipeline.
with open(loaders_path, 'rb') as f:
    data = pickle.load(f)

In [2]:
# Replace each modality entry by a plain {'features', 'labels'} dict.
# Each entry is expected to expose `.tensors` as (features, labels)
# (presumably a TensorDataset -- verify against the pickled file).
for split in data.keys():
    for modal in ['audio', 'text', 'video']:
        modal_entry = data[split][modal]
        if isinstance(modal_entry, dict):
            # Already unpacked by an earlier run of this cell: re-running must
            # not fail on the missing `.tensors` attribute.
            continue
        modal_tensors = modal_entry.tensors
        data[split][modal] = {
            'features': modal_tensors[0],
            'labels': modal_tensors[1]
        }

    if 'text' in data[split]:
        # Per-sample count of non-zero entries along dim 1 of the text features.
        text_features = data[split]['text']['features']
        text_len_tensor = torch.sum((text_features != 0).long(), dim=1)
        data[split]['text']['text_len_tensor'] = text_len_tensor


# Sanity check: report the tensor shapes of every split.
for split in data.keys():
    print(f"{split} audio features shape: {data[split]['audio']['features'].shape}")
    print(f"{split} text features shape: {data[split]['text']['features'].shape}")
    print(f"{split} video features shape: {data[split]['video']['features'].shape}")
    print(f"{split} text length tensor shape: {data[split]['text']['text_len_tensor'].shape}")


train audio features shape: torch.Size([9988, 768])
train text features shape: torch.Size([9988, 768])
train video features shape: torch.Size([9988, 16, 768])
train text length tensor shape: torch.Size([9988])
val audio features shape: torch.Size([1108, 768])
val text features shape: torch.Size([1108, 768])
val video features shape: torch.Size([1108, 16, 768])
val text length tensor shape: torch.Size([1108])
test audio features shape: torch.Size([2610, 768])
test text features shape: torch.Size([2610, 768])
test video features shape: torch.Size([2610, 16, 768])
test text length tensor shape: torch.Size([2610])


In [35]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, RGCNConv
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt

# -------- Préparer les données --------
def prepare_graph_data(data, split):
    """Build a single multi-relational graph for one data split.

    Nodes are the audio samples, then the text samples, then the video samples
    (video features are averaged over dim 1 first). Edges are:

    * intra-modal temporal edges (previous / self / next) with relation ids
      0-2 for audio, 3-5 for text and 6-8 for video;
    * cross-modal edges between nodes sharing the same sample index:
      audio->video (9), text->video (10), audio->text (11).

    Args:
        data: nested dict ``data[split][modality]`` holding ``'features'`` and
            ``'labels'`` tensors.
        split: key of the split to convert.

    Returns:
        A ``torch_geometric.data.Data`` with ``x``, ``y``, ``edge_index`` and
        ``edge_type``.
    """
    modalities = data[split]
    audio_x = modalities['audio']['features']
    text_x = modalities['text']['features']
    # Collapse dim 1 of the video features so each sample is one vector.
    video_features = torch.mean(modalities['video']['features'], dim=1)

    # Stack all nodes (and their labels) modality after modality.
    features = torch.cat([audio_x, text_x, video_features], dim=0)
    labels = torch.cat(
        [modalities[name]['labels'] for name in ('audio', 'text', 'video')],
        dim=0,
    )

    edges = []      # [source, target] pairs
    relations = []  # relation id of each pair, same order as `edges`

    def add_temporal_edges(offset, num_nodes, relation_type):
        """Link every node to its predecessor, itself and its successor."""
        last = num_nodes - 1
        for step in range(num_nodes):
            node = offset + step
            if step > 0:
                edges.append([node, node - 1])  # past
                relations.append(relation_type)
            edges.append([node, node])  # present (self loop)
            relations.append(relation_type + 1)
            if step < last:
                edges.append([node, node + 1])  # future
                relations.append(relation_type + 2)

    num_audio = audio_x.size(0)
    num_text = text_x.size(0)
    num_video = video_features.size(0)
    text_offset = num_audio
    video_offset = num_audio + num_text

    add_temporal_edges(0, num_audio, 0)             # audio: relations 0, 1, 2
    add_temporal_edges(text_offset, num_text, 3)    # text: relations 3, 4, 5
    add_temporal_edges(video_offset, num_video, 6)  # video: relations 6, 7, 8

    # Cross-modal edges between nodes that share the same sample index.
    for i in range(min(num_audio, num_text, num_video)):
        edges.append([i, video_offset + i])                # audio -> video
        relations.append(9)
        edges.append([text_offset + i, video_offset + i])  # text -> video
        relations.append(10)
        edges.append([i, text_offset + i])                 # audio -> text
        relations.append(11)

    # PyG expects edge_index with shape (2, num_edges).
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    edge_type = torch.tensor(relations, dtype=torch.long)

    print("Unique edge types:", torch.unique(edge_type))  # Debugging
    return Data(x=features, y=labels, edge_index=edge_index, edge_type=edge_type)

# -------- Définir les Modèles --------
class GCN(torch.nn.Module):
    """Two-layer GCN that returns per-node log-probabilities."""

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.0):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        # Dropout probability applied between the two convolutions.
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return ``log_softmax`` class scores for every node."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        logits = self.conv2(hidden, edge_index)
        return F.log_softmax(logits, dim=1)

class GAT(torch.nn.Module):
    """Two-layer GAT: a multi-head hidden layer followed by a single-head output layer."""

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        # The first layer is given `heads` attention heads, hence the widened input here.
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)

    def forward(self, x, edge_index):
        """Return ``log_softmax`` class scores for every node."""
        hidden = F.elu(self.conv1(x, edge_index))
        logits = self.conv2(hidden, edge_index)
        return F.log_softmax(logits, dim=1)

class RGCN(torch.nn.Module):
    """Two-layer relational GCN; every edge carries a relation id in ``edge_type``."""

    def __init__(self, in_channels, hidden_channels, out_channels, num_relations):
        super().__init__()
        self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations)
        self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations)

    def forward(self, x, edge_index, edge_type):
        """Return ``log_softmax`` class scores for every node."""
        hidden = F.relu(self.conv1(x, edge_index, edge_type))
        logits = self.conv2(hidden, edge_index, edge_type)
        return F.log_softmax(logits, dim=1)

# -------- Entraîner et Tester --------
def train_and_evaluate(model, train_data, val_data, test_data, optimizer, criterion, epochs=100):
    """Train ``model`` full-batch, log train/val loss each epoch, then score on test.

    Args:
        model: graph model; ``RGCN`` instances also receive ``edge_type``.
        train_data, val_data, test_data: graph objects exposing ``x``, ``y``,
            ``edge_index`` and ``edge_type``.
        optimizer: optimizer already bound to ``model``'s parameters.
        criterion: loss called as ``criterion(output, target)``.
        epochs: number of full-batch training steps.

    Returns:
        ``(test_accuracy, test_macro_f1)``.
    """
    relational = isinstance(model, RGCN)

    def run(graph):
        # Only the relational model takes the per-edge relation ids.
        if relational:
            return model(graph.x, graph.edge_index, graph.edge_type)
        return model(graph.x, graph.edge_index)

    for epoch in range(epochs):
        # One full-batch optimisation step.
        model.train()
        optimizer.zero_grad()
        out = run(train_data)
        loss = criterion(out, train_data.y)
        loss.backward()
        optimizer.step()

        # Validation loss, without gradient tracking.
        model.eval()
        with torch.no_grad():
            val_loss = criterion(run(val_data), val_data.y)
        print(f"Epoch {epoch + 1}, Train Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}")

    # Final evaluation on the held-out test graph.
    model.eval()
    with torch.no_grad():
        test_pred = run(test_data).argmax(dim=1)
        y_true = test_data.y.cpu()
        y_pred = test_pred.cpu()
        test_acc = accuracy_score(y_true, y_pred)
        test_f1 = f1_score(y_true, y_pred, average='macro')

    return test_acc, test_f1

# -------- Pipeline Principal --------
def main_pipeline(data):
    """Benchmark the baseline GCN, GAT and RGCN models and print a results table.

    Args:
        data: nested dict with ``'train'``, ``'val'`` and ``'test'`` splits, as
            consumed by ``prepare_graph_data``.
    """
    # One graph per split.
    train_data = prepare_graph_data(data, 'train')
    val_data = prepare_graph_data(data, 'val')
    test_data = prepare_graph_data(data, 'test')

    in_dim = train_data.x.size(1)
    # Number of classes = number of distinct labels seen in training.
    n_classes = len(torch.unique(train_data.y))

    # Models to compare (12 relation ids are produced by prepare_graph_data).
    architectures = [
        ("GCN (2 layers)", GCN(in_dim, 64, n_classes)),
        ("GAT (4 heads)", GAT(in_dim, 64, n_classes, heads=4)),
        ("RGCN", RGCN(in_dim, 64, n_classes, num_relations=12)),
    ]

    # Train and evaluate every model with a fresh optimizer and loss.
    results = []
    for name, model in architectures:
        print(f"Training {name}...")
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        criterion = torch.nn.CrossEntropyLoss()
        acc, f1 = train_and_evaluate(model, train_data, val_data, test_data, optimizer, criterion)
        results.append((name, acc * 100, f1 * 100))

    # Summary table (percentages).
    print("\nBenchmark Results:")
    print("Model\t\t\tTest Accuracy\tTest F1-Score")
    for name, acc, f1 in results:
        print(f"{name:20}\t{acc:.2f}\t\t{f1:.2f}")



In [None]:
# Run the benchmark pipeline on the embeddings loaded above.
main_pipeline(data)

## **Improvement**

**1. Normalisation des Données**

In [74]:
from sklearn.preprocessing import StandardScaler

def prepare_graph_data(data, split, graph_augmentation=False):
    """Build a normalised multi-relational graph for one data split.

    Features of each modality are standardised with a ``StandardScaler``
    fitted on that modality of ``split``; video features are standardised
    over their flattened last dimension and then averaged over dim 1.
    ``data`` is NOT modified: normalisation works on local copies, so calling
    this function (or re-running the cell) has no hidden side effect.

    NOTE(review): a separate scaler is fitted for every split, so 'val' and
    'test' are scaled with their own statistics rather than the training
    ones -- confirm this is intended.

    Relation ids:
        * audio 0-4, text 5-9, video 10-14: previous (+0), self (+1),
          next (+2), two steps back (+3), two steps ahead (+4);
        * 15 audio->video, 16 text->video, 17 audio->text (same sample index);
        * 18 random augmentation edges (only if ``graph_augmentation``).

    Args:
        data: nested dict ``data[split][modality]`` holding ``'features'`` and
            ``'labels'`` tensors.
        split: key of the split to convert.
        graph_augmentation: when True, add 100 random edges of relation 18.

    Returns:
        A ``torch_geometric.data.Data`` with ``x``, ``y``, ``edge_index`` and
        ``edge_type``.
    """
    def _standardize(features):
        # StandardScaler needs a 2-D array: flatten the leading dimensions,
        # scale each feature column, then restore the original shape.
        shape = features.shape
        flat = features.detach().cpu().reshape(-1, shape[-1]).numpy()
        scaled = StandardScaler().fit_transform(flat)
        return torch.tensor(scaled, dtype=torch.float).reshape(shape)

    audio_features = _standardize(data[split]['audio']['features'])
    text_features = _standardize(data[split]['text']['features'])
    video_sequence = _standardize(data[split]['video']['features'])

    # Average over dim 1 so each video sample becomes a single vector.
    video_features = torch.mean(video_sequence, dim=1)

    # Stack all nodes (and their labels) modality after modality.
    features = torch.cat([audio_features, text_features, video_features], dim=0)
    labels = torch.cat([
        data[split]['audio']['labels'],
        data[split]['text']['labels'],
        data[split]['video']['labels']
    ], dim=0)

    edge_index = []
    edge_type = []

    # -------- Intra-modal temporal relations --------
    def add_temporal_edges(offset, num_nodes, relation_type):
        for i in range(num_nodes):
            if i > 1:
                edge_index.append([offset + i, offset + i - 2])  # distant past
                edge_type.append(relation_type + 3)
            if i < num_nodes - 2:
                edge_index.append([offset + i, offset + i + 2])  # distant future
                edge_type.append(relation_type + 4)

            # Near past, present (self loop), near future.
            if i > 0:
                edge_index.append([offset + i, offset + i - 1])
                edge_type.append(relation_type)
            edge_index.append([offset + i, offset + i])
            edge_type.append(relation_type + 1)
            if i < num_nodes - 1:
                edge_index.append([offset + i, offset + i + 1])
                edge_type.append(relation_type + 2)

    num_audio = audio_features.size(0)
    num_text = text_features.size(0)
    num_video = video_features.size(0)  # size after averaging over dim 1

    add_temporal_edges(0, num_audio, 0)                      # audio: 0-4
    add_temporal_edges(num_audio, num_text, 5)               # text: 5-9
    add_temporal_edges(num_audio + num_text, num_video, 10)  # video: 10-14

    # -------- Cross-modal relations (same sample index) --------
    for i in range(min(num_audio, num_text, num_video)):
        edge_index.append([i, num_audio + num_text + i])  # audio -> video
        edge_type.append(15)
        edge_index.append([num_audio + i, num_audio + num_text + i])  # text -> video
        edge_type.append(16)
        edge_index.append([i, num_audio + i])  # audio -> text
        edge_type.append(17)

    # Graph augmentation: add random edges when enabled.
    if graph_augmentation:
        for _ in range(100):
            src, dst = torch.randint(0, features.size(0), (2,))
            edge_index.append([src.item(), dst.item()])
            edge_type.append(18)

    # PyG expects edge_index with shape (2, num_edges).
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_type = torch.tensor(edge_type, dtype=torch.long)

    print("Unique edge types:", torch.unique(edge_type))  # Debugging
    return Data(x=features, y=labels, edge_index=edge_index, edge_type=edge_type)


b. Équilibrage des classes

In [75]:
from torch.nn import CrossEntropyLoss

def compute_class_weights(labels):
    """Return normalised inverse-frequency class weights indexed by class id.

    ``weights[c]`` is proportional to ``1 / count(c)`` and the weights sum to 1.
    The result has ``labels.max() + 1`` entries, so position ``c`` always
    refers to class ``c`` -- also when some class id does not occur in
    ``labels`` (such classes get weight 0 instead of shifting the others).

    Args:
        labels: tensor of non-negative integer class ids.

    Returns:
        1-D float tensor of class weights (empty if ``labels`` is empty).
    """
    counts = torch.bincount(labels.reshape(-1).long()).float()
    present = counts > 0
    weights = torch.zeros_like(counts)
    # Invert only the populated classes to avoid a division by zero.
    weights[present] = 1.0 / counts[present]
    total = weights.sum()
    if total > 0:
        weights = weights / total
    return weights

c. Augmentation des données


In [76]:
# graph_augmentation = True
# train_data = prepare_graph_data(data, 'train', graph_augmentation=graph_augmentation)


**2. Architecture des modèles** (Ajout de Couches Supplémentaires)

In [102]:
import torch.nn as nn

class DeepGCN(torch.nn.Module):
    """Stack of GCN layers, each followed by BatchNorm, ReLU and dropout.

    ``num_layers`` hidden GCN layers are built, plus one final GCN layer that
    maps to ``out_channels``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3, dropout=0.3):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        # Input width of every hidden layer: the raw features first, then the hidden width.
        input_dims = [in_channels] + [hidden_channels] * (num_layers - 1)
        for dim in input_dims:
            self.layers.append(GCNConv(dim, hidden_channels))
            self.batch_norms.append(torch.nn.BatchNorm1d(hidden_channels))

        self.final_layer = GCNConv(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return ``log_softmax`` class scores for every node."""
        for conv, norm in zip(self.layers, self.batch_norms):
            x = F.relu(norm(conv(x, edge_index)))
            x = F.dropout(x, p=self.dropout, training=self.training)
        logits = self.final_layer(x, edge_index)
        return F.log_softmax(logits, dim=1)

#----------------------------------------------------------------------------------------------------------
class DeepGAT(torch.nn.Module):
    """Stack of multi-head GAT layers, each followed by BatchNorm, ELU and dropout.

    ``num_layers`` hidden GAT layers are built, plus one single-head GAT layer
    that maps to ``out_channels``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4, num_layers=3, dropout=0.5):
        super().__init__()
        # BatchNorm and the following layers are sized for hidden_channels * heads features.
        wide = hidden_channels * heads
        self.layers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        input_dims = [in_channels] + [wide] * (num_layers - 1)
        for dim in input_dims:
            self.layers.append(GATConv(dim, hidden_channels, heads=heads))
            self.batch_norms.append(torch.nn.BatchNorm1d(wide))

        self.final_layer = GATConv(wide, out_channels, heads=1)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return ``log_softmax`` class scores for every node."""
        for conv, norm in zip(self.layers, self.batch_norms):
            x = self.dropout(F.elu(norm(conv(x, edge_index))))
        logits = self.final_layer(x, edge_index)
        return F.log_softmax(logits, dim=1)
#----------------------------------------------------------------------------------------------------------
class CombinedGCN_GAT_RGCN(torch.nn.Module):
    """Hybrid model chaining GCN -> GAT (4 heads) -> RGCN -> GCN output layer.

    ReLU and dropout follow each of the first three layers; only the RGCN
    layer uses the relation ids in ``edge_type``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_relations, dropout=0.5):
        super().__init__()
        self.gcn = GCNConv(in_channels, hidden_channels)
        self.gat = GATConv(hidden_channels, hidden_channels, heads=4)
        # The GAT layer is built with 4 heads, hence the 4x wider RGCN input.
        self.rgcn = RGCNConv(hidden_channels * 4, hidden_channels, num_relations)
        self.final_layer = GCNConv(hidden_channels, out_channels)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index, edge_type):
        """Return ``log_softmax`` class scores for every node."""
        x = self.dropout(F.relu(self.gcn(x, edge_index)))
        x = self.dropout(F.relu(self.gat(x, edge_index)))
        x = self.dropout(F.relu(self.rgcn(x, edge_index, edge_type)))
        logits = self.final_layer(x, edge_index)
        return F.log_softmax(logits, dim=1)
#----------------------------------------------------------------------------------------------------
from torch_geometric.nn import GATConv

class TemporalGAT(torch.nn.Module):
    """Three multi-head GAT layers (no added self loops) plus a single-head output layer.

    Dropout is applied to the input of every hidden GAT layer.
    ``num_relations`` is stored and ``edge_type`` is required by ``forward``,
    but neither is passed to the GAT layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_relations, heads=4, dropout=0.5):
        super(TemporalGAT, self).__init__()
        self.gat_layers = torch.nn.ModuleList()
        self.gat_layers.append(GATConv(in_channels, hidden_channels, heads=heads, add_self_loops=False))
        for _ in range(2):  # two more hidden layers, three in total
            self.gat_layers.append(GATConv(hidden_channels * heads, hidden_channels, heads=heads, add_self_loops=False))
        self.final_layer = GATConv(hidden_channels * heads, out_channels, heads=1, add_self_loops=False)
        self.num_relations = num_relations
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x, edge_index, edge_type=None):
        """Return ``log_softmax`` class scores for every node.

        Raises:
            ValueError: if ``edge_type`` is not provided.
        """
        if edge_type is None:
            raise ValueError("edge_type is required for TemporalGAT")

        for gat_layer in self.gat_layers:
            # self.dropout is an nn.Dropout module: call it directly. Passing
            # it as `p=` to F.dropout raises a TypeError; the module already
            # follows the train/eval mode of the model.
            x = self.dropout(x)
            x = gat_layer(x, edge_index)
        x = self.final_layer(x, edge_index)
        return F.log_softmax(x, dim=1)
#----------------------------------------------------------------------------------------------------
from torch_geometric.nn import GATv2Conv

class TemporalGATv2(torch.nn.Module):
    """Three multi-head GATv2 layers (no added self loops) plus a single-head output layer.

    Dropout is applied to the input of every hidden GATv2 layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.5):
        super().__init__()
        wide = hidden_channels * heads
        # Input widths of the three hidden layers.
        self.gat_layers = torch.nn.ModuleList()
        for dim in (in_channels, wide, wide):
            self.gat_layers.append(GATv2Conv(dim, hidden_channels, heads=heads, add_self_loops=False))
        self.final_layer = GATv2Conv(wide, out_channels, heads=1, add_self_loops=False)
        # Probability used by F.dropout in forward().
        self.dropout = dropout
        # Module form of the same dropout; forward() does not use it.
        self.dropout_layer = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return ``log_softmax`` class scores for every node."""
        for conv in self.gat_layers:
            x = conv(F.dropout(x, p=self.dropout, training=self.training), edge_index)

        logits = self.final_layer(x, edge_index)
        return F.log_softmax(logits, dim=1)


5. Relations dans le graphe

a) Relations cross-modales

Introduisez plus de connexions cross-modales dans edge_index, par exemple en liant chaque nœud audio à plusieurs nœuds vidéo proches au lieu d'un seul.

b) Graph Augmentation

Ajoutez des bruits aux connexions existantes ou utilisez des méthodes comme DropEdge (supprimer des arêtes aléatoires à chaque itération) pour améliorer la robustesse.

7. Autres Modèles


a) Graph Attention Networks (GATv2)


In [103]:
class CombinedGCN_GAT(torch.nn.Module):
    """Hybrid model: GCN layer -> 4-head GAT layer -> GCN output layer."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.gcn = GCNConv(in_channels, hidden_channels)
        self.gat = GATConv(hidden_channels, hidden_channels, heads=4)
        # The GAT layer is built with 4 heads, hence the 4x wider input here.
        self.final_layer = GCNConv(hidden_channels * 4, out_channels)

    def forward(self, x, edge_index):
        """Return ``log_softmax`` class scores for every node."""
        hidden = self.gcn(x, edge_index)
        hidden = F.relu(hidden)
        hidden = self.gat(hidden, edge_index)
        hidden = F.elu(hidden)
        logits = self.final_layer(hidden, edge_index)
        return F.log_softmax(logits, dim=1)


In [104]:
def train_and_evaluate(model, train_data, val_data, test_data, optimizer, criterion, epochs=50):
    """Train ``model`` full-batch with a class-weighted loss, then score it on test.

    Args:
        model: graph model. Models with an ``rgcn`` attribute, and instances
            of ``RGCN``, ``TemporalGAT`` or ``CombinedGCN_GAT_RGCN``, are
            called with ``edge_type`` as third argument; all others with
            ``(x, edge_index)`` only.
        train_data, val_data, test_data: graph objects exposing ``x``, ``y``,
            ``edge_index`` and ``edge_type``.
        optimizer: optimizer already bound to ``model``'s parameters.
        criterion: kept for interface compatibility; it is replaced by a
            ``CrossEntropyLoss`` weighted with ``compute_class_weights`` of
            the training labels.
        epochs: number of full-batch training steps.

    Returns:
        ``(test_accuracy, test_macro_f1)``.
    """
    weights = compute_class_weights(train_data.y)
    criterion = CrossEntropyLoss(weight=weights)

    # Decide once how the model must be called, and use the SAME rule for
    # training, validation and test. Previously the three phases used
    # different checks and plain RGCN models never received edge_type.
    uses_edge_type = hasattr(model, 'rgcn') or isinstance(
        model, (RGCN, TemporalGAT, CombinedGCN_GAT_RGCN)
    )

    def _forward(graph):
        if uses_edge_type:
            return model(graph.x, graph.edge_index, graph.edge_type)
        return model(graph.x, graph.edge_index)

    for epoch in range(epochs):
        # One full-batch optimisation step.
        model.train()
        optimizer.zero_grad()
        out = _forward(train_data)
        loss = criterion(out, train_data.y)
        loss.backward()
        optimizer.step()

        # Validation loss, without gradient tracking.
        model.eval()
        with torch.no_grad():
            val_out = _forward(val_data)
            val_loss = criterion(val_out, val_data.y)
        print(f"Epoch {epoch + 1}, Train Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}")

    # Final evaluation on the held-out test graph.
    model.eval()
    with torch.no_grad():
        test_out = _forward(test_data)
        test_pred = test_out.argmax(dim=1)
        test_acc = accuracy_score(test_data.y.cpu(), test_pred.cpu())
        test_f1 = f1_score(test_data.y.cpu(), test_pred.cpu(), average='macro')

    return test_acc, test_f1


**Pipeline**

In [107]:
def main_pipeline(data):
    """Benchmark the improved architectures and print a results table.

    Args:
        data: nested dict with ``'train'``, ``'val'`` and ``'test'`` splits, as
            consumed by ``prepare_graph_data``.
    """
    # One graph per split, without random augmentation edges.
    train_data = prepare_graph_data(data, 'train', graph_augmentation=False)
    val_data = prepare_graph_data(data, 'val', graph_augmentation=False)
    test_data = prepare_graph_data(data, 'test', graph_augmentation=False)

    in_dim = train_data.x.size(1)
    # Number of classes = number of distinct labels seen in training.
    n_classes = len(torch.unique(train_data.y))

    # Model configurations. Alternatives tried earlier are kept for reference:
    #   ("GCN (3 layers)", DeepGCN(in_dim, 64, n_classes, num_layers=3))
    #   ("GCN (5 layers)", DeepGCN(in_dim, 64, n_classes, num_layers=5))
    #   ("GAT (3 layers)", DeepGAT(in_dim, 64, n_classes, num_layers=3))
    #   ("GAT (5 layers)", DeepGAT(in_dim, 64, n_classes, num_layers=5))
    #   ("RGCN (5 relations)", RGCN(in_dim, 64, n_classes, num_relations=20))
    #   ("Temporal GAT", TemporalGAT(in_dim, 64, n_classes, num_relations=20))
    #   ("Temporal GATv2", TemporalGATv2(in_dim, 64, n_classes))
    #   ("Combined GCN+GAT", CombinedGCN_GAT(in_dim, 64, n_classes))
    architectures = [
        ("GAT + BatchNorm", DeepGAT(in_dim, 64, n_classes, num_layers=3)),
        ("GCN+GAT+RGCN", CombinedGCN_GAT_RGCN(in_dim, 64, n_classes, num_relations=20)),
        ("GCN + Dropout", DeepGCN(in_dim, 64, n_classes, num_layers=3, dropout=0.5)),
    ]

    # Train and evaluate every model with a fresh optimizer and loss.
    results = []
    for name, model in architectures:
        print(f"Training {name}...")
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
        criterion = CrossEntropyLoss()
        acc, f1 = train_and_evaluate(model, train_data, val_data, test_data, optimizer, criterion)
        results.append((name, acc * 100, f1 * 100))

    # Summary table (percentages).
    print("\nBenchmark Results:")
    print("Model\t\t\tTest Accuracy\tTest F1-Score")
    for name, acc, f1 in results:
        print(f"{name:20}\t{acc:.2f}\t\t{f1:.2f}")


In [108]:
# Run the benchmark pipeline on the embeddings loaded above.
main_pipeline(data)

Unique edge types: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17])
Unique edge types: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17])
Unique edge types: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17])
Training GAT + BatchNorm...
Epoch 1, Train Loss: 2.6019, Val Loss: 3.2095
Epoch 2, Train Loss: 2.6310, Val Loss: 13.1124
Epoch 3, Train Loss: 3.6083, Val Loss: 7.3790
Epoch 4, Train Loss: 2.4415, Val Loss: 6.7082
Epoch 5, Train Loss: 2.2857, Val Loss: 6.4300
Epoch 6, Train Loss: 2.1886, Val Loss: 7.2506
Epoch 7, Train Loss: 2.1602, Val Loss: 8.0862
Epoch 8, Train Loss: 2.1045, Val Loss: 7.1511
Epoch 9, Train Loss: 2.0608, Val Loss: 5.9592
Epoch 10, Train Loss: 2.0509, Val Loss: 5.3529
Epoch 11, Train Loss: 2.0542, Val Loss: 4.9369
Epoch 12, Train Loss: 2.0336, Val Loss: 4.4209
Epoch 13, Train Loss: 2.0155, Val Loss: 3.9028
Epoch 14, Train Loss: 1.9910, Val Loss: 3.3549
Epoch 15, Train Loss: 

## **Approach 2**

1. GraphModel Class Implementation


In [44]:
import pickle
import torch
import os
import numpy as np
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, RGCNConv
from sklearn.metrics import accuracy_score, f1_score


# Load data
# Load the saved datasets.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only load files produced by a trusted pipeline.
data_dir = os.path.join('..', "outputs", "embeddings")
data = torch.load(os.path.join(data_dir, "loaders_datasets_reduced_label_dim_4.pt"))

# Replace each modality entry by a plain {'features', 'labels'} dict.
# Each entry is expected to expose `.tensors` as (features, labels).
for split in data.keys():
    split_data = data[split]
    for modal in ('audio', 'text', 'video'):
        modal_tensors = split_data[modal].tensors
        split_data[modal] = dict(features=modal_tensors[0], labels=modal_tensors[1])

    if 'text' in split_data:
        # Per-sample count of non-zero entries along dim 1 of the text features.
        text_features = split_data['text']['features']
        text_len_tensor = (text_features != 0).long().sum(dim=1)
        split_data['text']['text_len_tensor'] = text_len_tensor

data


{'train': {'audio': {'features': tensor([[ 0.0939,  0.0762,  0.0689,  ...,  0.2469,  0.3439, -0.3689],
           [-0.0282,  0.0721,  0.4385,  ...,  0.4353,  0.0995, -0.0296],
           [-0.0721,  0.1159,  0.1769,  ...,  0.4833,  0.0715,  0.0421],
           ...,
           [ 0.0947,  0.2997, -0.1643,  ...,  0.2286,  0.3729, -0.2248],
           [ 0.1912,  0.0553,  0.0986,  ...,  0.3378,  0.1588, -0.1218],
           [ 0.0359,  0.1585, -0.2475,  ...,  0.2054,  0.3535, -0.2127]]),
   'labels': tensor([0, 0, 0,  ..., 0, 0, 1])},
  'text': {'features': tensor([[-0.1967,  0.1299,  0.1679,  ..., -0.2378,  0.2947,  0.1266],
           [ 0.0635,  0.3604, -0.5336,  ..., -0.5177,  0.5179,  0.2851],
           [-0.0813,  0.3107,  0.0133,  ..., -0.1899,  0.4385,  0.1935],
           ...,
           [ 0.0512, -0.0819, -0.1413,  ..., -0.2258,  0.3533,  0.4741],
           [-0.3271, -0.0116, -0.0164,  ..., -0.1454,  0.4163,  0.1177],
           [-0.1336, -0.3128, -0.4481,  ..., -0.0193,  0.1588,  0

In [71]:
import torch
import torch.nn as nn
from torch_geometric.nn import RGCNConv, TransformerConv
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm   
from sklearn.metrics import confusion_matrix, classification_report


# Debug function for logging
def debug_message(message, variable=None):
    """Print a ``[DEBUG]``-prefixed line, appending ``variable`` unless it is None."""
    suffix = "" if variable is None else f": {variable}"
    print(f"[DEBUG] {message}{suffix}")

class GNN(nn.Module):
    """RGCN layer followed by a single-head TransformerConv, each with ReLU.

    ``n_modals`` and ``args`` are accepted but not used by this class.
    """

    def __init__(self, g_dim, h1_dim, h2_dim, num_relations, n_modals, args):
        super().__init__()
        self.rgcn = RGCNConv(g_dim, h1_dim, num_relations=num_relations)
        self.transformer = TransformerConv(h1_dim, h2_dim, heads=1)

    def forward(self, x, edge_index, edge_type):
        """Run both layers and return the node embeddings.

        Raises:
            ValueError: if ``edge_index`` references a node outside
                ``[0, x.size(0))``.
        """
        # Reject edges pointing outside the node range before running the layers.
        debug_message("Validating edge_index", edge_index.shape)
        num_nodes = x.size(0)
        too_large = edge_index.max() >= num_nodes
        if too_large or edge_index.min() < 0:
            raise ValueError(f"Invalid edge_index values: max = {edge_index.max()}, min = {edge_index.min()}, x.size(0) = {x.size(0)}")

        # Relation-aware convolution.
        hidden = self.rgcn(x, edge_index, edge_type)
        debug_message("RGCN output shape", hidden.shape)
        hidden = torch.relu(hidden)

        # Attention-based convolution on the same edges (relation ids unused here).
        out = torch.relu(self.transformer(hidden, edge_index))
        debug_message("Transformer output shape", out.shape)
        return out

class GraphModel(nn.Module):
    """Builds a relational graph for a batch of sequences and runs it through `GNN`.

    `forward` packs the node features, builds `edge_index` / `edge_type` from
    the sequence lengths (`batch_graphify`), applies the GNN and finally
    re-arranges the output with `multi_concat`.

    Keys read from `args`: 'modalities', 'wp' (past window), 'wf' (future
    window) and 'edge_type' (may contain "temp" and/or "multi").
    """

    def __init__(self, g_dim, h1_dim, h2_dim, device, args):
        super(GraphModel, self).__init__()

        self.n_modals = len(args['modalities'])
        self.wp = args['wp']
        self.wf = args['wf']
        self.device = device

        print(f"GraphModel --> Edge type: {args['edge_type']}")
        print(f"GraphModel --> Window past: {args['wp']}")
        print(f"GraphModel --> Window future: {args['wf']}")
        edge_temp = "temp" in args['edge_type']
        edge_multi = "multi" in args['edge_type']

        # Maps a relation key "<temporal><src modality><dst modality>" to a
        # consecutive relation id.
        edge_type_to_idx = {}

        if edge_temp:
            # Temporal relations inside each modality: -1, 1 and 0 (see
            # batch_graphify for how the sign is assigned).
            temporal = [-1, 1, 0]
            for j in temporal:
                for k in range(self.n_modals):
                    edge_type_to_idx[f"{j}{k}{k}"] = len(edge_type_to_idx)
        else:
            # Without temporal edges only the "0kk" relation exists per modality.
            for j in range(self.n_modals):
                edge_type_to_idx[f"0{j}{j}"] = len(edge_type_to_idx)

        if edge_multi:
            # One relation per ordered pair of distinct modalities.
            for j in range(self.n_modals):
                for k in range(self.n_modals):
                    if j != k:
                        edge_type_to_idx[f"0{j}{k}"] = len(edge_type_to_idx)

        self.edge_type_to_idx = edge_type_to_idx
        self.num_relations = len(edge_type_to_idx)
        self.edge_multi = edge_multi
        self.edge_temp = edge_temp

        self.gnn = GNN(g_dim, h1_dim, h2_dim, self.num_relations, self.n_modals, args)

    def forward(self, x, lengths):
        """Pack features, build the batch graph, run the GNN and concatenate.

        `x` is a tensor or a list/tuple of tensors (see `feature_packing`);
        `lengths` is a tensor of per-sample lengths.
        """
        node_features = self.feature_packing(x, lengths)

        # node_type and edge_index_lengths are returned but not used here.
        node_type, edge_index, edge_type, edge_index_lengths = self.batch_graphify(lengths)

        out_gnn = self.gnn(node_features, edge_index, edge_type)
        out_gnn = self.multi_concat(out_gnn, lengths)

        return out_gnn

    def batch_graphify(self, lengths):
        """Build node types, edge index and edge types for a batch.

        Returns `(node_type, edge_index, edge_type, edge_index_lengths)` as
        tensors on `self.device`. Raises ValueError when an edge endpoint is
        outside `[0, sum(lengths))`.
        """
        node_type, edge_index, edge_type, edge_index_lengths = [], [], [], []
        # NOTE(review): edge_type_lengths is never read after this line.
        edge_type_lengths = [0] * len(self.edge_type_to_idx)

        lengths = lengths.tolist()

        # Running offset of the current sample inside the batch.
        sum_length = 0
        total_length = sum(lengths)
        batch_size = len(lengths)

        # Modality-major list: all nodes of modality 0, then modality 1, ...
        # (n_modals * total_length entries in total).
        for k in range(self.n_modals):
            for j in range(batch_size):
                cur_len = lengths[j]
                node_type.extend([k] * cur_len)

        for j in range(batch_size):
            cur_len = lengths[j]

            perms = self.edge_perms(cur_len, total_length)
            edge_index_lengths.append(len(perms))

            for item in perms:
                vertices = item[0]
                neighbor = item[1]
                # Shift sample-local indices by the offset of this sample.
                edge_index.append([vertices + sum_length, neighbor + sum_length])

                # 1 when the source position is after the neighbour's,
                # -1 when it is before, 0 when both share the same position.
                if vertices % total_length > neighbor % total_length:
                    temporal_type = 1
                elif vertices % total_length < neighbor % total_length:
                    temporal_type = -1
                else:
                    temporal_type = 0

                edge_type.append(self.edge_type_to_idx[f"{temporal_type}{node_type[vertices + sum_length]}{node_type[neighbor + sum_length]}"])

            sum_length += cur_len

        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_type = torch.tensor(edge_type, dtype=torch.long)

        # Validate edge_index
        debug_message("Validating final edge_index", edge_index.shape)
        if edge_index.max() >= total_length or edge_index.min() < 0:
            raise ValueError(f"Invalid edge_index values: max = {edge_index.max()}, min = {edge_index.min()}, total_length = {total_length}")

        node_type = torch.tensor(node_type).long().to(self.device)
        edge_index = edge_index.to(self.device)
        edge_type = edge_type.to(self.device)
        edge_index_lengths = torch.tensor(edge_index_lengths).long().to(self.device)

        return node_type, edge_index, edge_type, edge_index_lengths

    def edge_perms(self, length, total_lengths):
        """Return the list of (source, target) index pairs for one sample.

        Node `j` of modality `k` is indexed as `j + k * total_lengths`.
        `wp` / `wf` equal to -1 mean an unbounded past / future window.
        """
        all_perms = set()
        array = np.arange(length)

        for j in range(length):
            # Positions inside the temporal window of position j.
            # NOTE(review): the slice end is exclusive, so the furthest future
            # neighbour is j + wf - 1 -- confirm whether j + wf was intended.
            if self.wp == -1 and self.wf == -1:
                eff_array = array
            elif self.wp == -1:
                eff_array = array[: min(length, j + self.wf)]
            elif self.wf == -1:
                eff_array = array[max(0, j - self.wp) :]
            else:
                eff_array = array[max(0, j - self.wp) : min(length, j + self.wf)]

            for k in range(self.n_modals):
                node_index = j + k * total_lengths
                if self.edge_temp:
                    # Temporal edges inside modality k.
                    for item in eff_array:
                        all_perms.add((node_index, item + k * total_lengths))
                else:
                    # Self loop only.
                    all_perms.add((node_index, node_index))
                if self.edge_multi:
                    # Cross-modal edges at the same position j.
                    for l in range(self.n_modals):
                        if l != k:
                            all_perms.add((node_index, j + l * total_lengths))

        # Keep only pairs whose two endpoints are below total_lengths.
        # NOTE(review): every index of modality k >= 1 is j + k * total_lengths
        # >= total_lengths, so this filter drops all edges touching modalities
        # other than 0, including every cross-modal edge -- confirm intent.
        all_perms = [(src, dst) for src, dst in all_perms if 0 <= src < total_lengths and 0 <= dst < total_lengths]
        debug_message("Number of valid permutations", len(all_perms))
        return list(all_perms)

    def feature_packing(self, x, lengths):
        """Concatenate a list/tuple of tensors along dim 0 (a tensor is returned as is).

        Raises ValueError when the number of rows differs from `lengths.sum()`.
        Does not read any attribute of `self`.
        """
        if isinstance(x, (list, tuple)):
            packed = torch.cat(x, dim=0)
        else:
            packed = x

        debug_message("Packed features shape", packed.shape)
        debug_message("Lengths sum", lengths.sum().item())

        if packed.size(0) != lengths.sum():
            raise ValueError(f"Mismatch in packed features size: packed.size(0) = {packed.size(0)}, lengths.sum() = {lengths.sum().item()}")

        return packed



    def multi_concat(self, out_gnn, lengths):
        """Split `out_gnn` along dim 0 by `lengths` and concatenate the chunks along dim 1.

        Raises ValueError when `lengths` does not sum to the number of rows.
        """
        split_sizes = lengths.tolist() if isinstance(lengths, torch.Tensor) else lengths
        debug_message("Lengths for split", split_sizes)
        debug_message("Out GNN shape before split", out_gnn.shape)

        if sum(split_sizes) != out_gnn.size(0):
            raise ValueError(f"Mismatch: lengths.sum() = {sum(split_sizes)}, out_gnn.size(0) = {out_gnn.size(0)}")

        split_features = torch.split(out_gnn, split_sizes, dim=0)
        debug_message("Split features shapes", [f.shape for f in split_features])

        # Check before concatenation.
        # NOTE(review): this compares size(1), but torch.cat(dim=1) below needs
        # equal size(0); chunks of different lengths (e.g. [4, 3, 3]) pass this
        # check and then fail inside torch.cat -- confirm the intended layout.
        if any(f.size(1) != split_features[0].size(1) for f in split_features):
            raise ValueError("Inconsistent tensor shapes in split features")

        return torch.cat(split_features, dim=1)



In [72]:
import torch
# Configuration dictionary handed to GraphModel; the constructor echoes
# edge_type, wp and wf when it runs (see the output below this cell).
args = {
    'modalities': ['audio', 'text', 'video'],
    'edge_type': ['temp', 'multi'],
    'wp': 2,
    'wf': 2,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')
}

# Small dimensions for a quick smoke test of the model construction.
g_dim = 16   
h1_dim = 32 
h2_dim = 16  
# NOTE(review): num_relations is not passed to GraphModel below, and the
# printed model reports RGCNConv(..., num_relations=15) - confirm whether this
# value is still needed.
num_relations = 6 

# Positional arguments: g_dim, h1_dim, h2_dim, device, args.
model = GraphModel(g_dim, h1_dim, h2_dim, args['device'], args)
# As the last expression of the cell, this call's return value is what the
# notebook displays (the model repr).
model.to(args['device'])




GraphModel --> Edge type: ['temp', 'multi']
GraphModel --> Window past: 2
GraphModel --> Window future: 2


GraphModel(
  (gnn): GNN(
    (rgcn): RGCNConv(16, 32, num_relations=15)
    (transformer): TransformerConv(32, 16, heads=1)
  )
)

In [73]:
# Smoke test for feature_packing: three groups with 4, 3 and 3 nodes,
# 16 features per node.
node_features = [torch.randn(4, 16), torch.randn(3, 16), torch.randn(3, 16)]
lengths = torch.tensor([4, 3, 3])


# Call the method on the `model` instance built in the previous cell instead
# of going through the class and passing the class itself as `self`.
packed_features = model.feature_packing(node_features, lengths)
debug_message("Packed features shape", packed_features.shape)

assert packed_features.size(0) == lengths.sum().item(), \
    f"Mismatch: packed.size(0) = {packed_features.size(0)}, lengths.sum() = {lengths.sum().item()}"


[DEBUG] Packed features shape: torch.Size([10, 16])
[DEBUG] Lengths sum: 10
[DEBUG] Packed features shape: torch.Size([10, 16])


In [74]:
# Build a throwaway model for the test
test_model = GraphModel(16, 32, 16, torch.device('cpu'), args)

# Parameters for the test
length = 4  # Number of nodes in one group
total_length = 10  # Total number of nodes in the graph

# Call edge_perms
perms = test_model.edge_perms(length, total_length)

# Debug messages to inspect the permutations
debug_message("Generated edge permutations", len(perms))
debug_message("Sample edge permutations", perms[:10])  # Show the first 10 permutations

# Validate the indices
assert all(0 <= src < total_length and 0 <= dst < total_length for src, dst in perms), \
    "Edge permutations contain invalid indices"



GraphModel --> Edge type: ['temp', 'multi']
GraphModel --> Window past: 2
GraphModel --> Window future: 2
[DEBUG] Number of valid permutations: 12
[DEBUG] Generated edge permutations: 12
[DEBUG] Sample edge permutations: [(3, 1), (2, 2), (1, 0), (3, 3), (0, 1), (1, 2), (2, 1), (3, 2), (0, 0), (1, 1)]


**Test batch_graphify**

In [75]:
# Group/modality lengths
lengths = torch.tensor([4, 3, 3])  # Example with 3 groups: 4 nodes, 3 nodes, 3 nodes

# Call batch_graphify
try:
    node_type, edge_index, edge_type, edge_index_lengths = test_model.batch_graphify(lengths)
    
    # Debug messages to inspect the results
    debug_message("Node type length", len(node_type))
    debug_message("Edge index shape", edge_index.shape)
    debug_message("Edge type shape", edge_type.shape)
    debug_message("Edge index lengths", edge_index_lengths.tolist())

    # Assertions validating the results
    assert edge_index.max() < lengths.sum().item(), "edge_index contient des indices hors limites"
    assert edge_index.min() >= 0, "edge_index contient des indices négatifs"
    assert len(edge_type) == edge_index.size(1), "Mismatch entre edge_type et edge_index"
    assert sum(edge_index_lengths.tolist()) == edge_index.size(1), \
        "Mismatch entre edge_index_lengths et le nombre total d'arêtes"
except Exception as e:
    print(f"[ERROR in batch_graphify]: {e}")
    # Re-raise after logging: the handler also catches the AssertionErrors
    # above, and swallowing them would let this validation cell "pass" even
    # when a check fails.
    raise


[DEBUG] Number of valid permutations: 12
[DEBUG] Number of valid permutations: 8
[DEBUG] Number of valid permutations: 8
[DEBUG] Validating final edge_index: torch.Size([2, 28])
[DEBUG] Node type length: 30
[DEBUG] Edge index shape: torch.Size([2, 28])
[DEBUG] Edge type shape: torch.Size([28])
[DEBUG] Edge index lengths: [12, 8, 8]


In [76]:

class GraphDataset(Dataset):
    """Map-style dataset that pairs each feature row with its label.

    ``lengths`` is kept on the instance but is not returned by
    ``__getitem__``.
    """

    def __init__(self, features, labels, lengths):
        """Store the data after checking that features and labels line up.

        Args:
            features: Indexable collection of per-sample features.
            labels: Indexable collection of per-sample labels.
            lengths: Stored as ``self.lengths``; not otherwise used here.

        Raises:
            ValueError: If ``features`` and ``labels`` differ in length.
        """
        # __len__ is driven by the labels, so a shorter or longer feature
        # collection would otherwise surface later as an IndexError or as
        # samples that are never visited.
        if len(features) != len(labels):
            raise ValueError(
                f"features and labels must have the same length, "
                f"got {len(features)} and {len(labels)}"
            )
        self.features = features
        self.labels = labels
        self.lengths = lengths

    def __len__(self):
        """Return the number of samples (the number of labels)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.features[idx], self.labels[idx]

def prepare_dataloaders(features, labels, lengths, batch_size, shuffle=True):
    """Wrap the given data in a GraphDataset and return a DataLoader over it.

    Args:
        features: Per-sample features, forwarded to GraphDataset.
        labels: Per-sample labels, forwarded to GraphDataset.
        lengths: Forwarded to GraphDataset.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples. Defaults to True,
            which matches the previous hard-coded behaviour; pass False for
            validation/test loaders when a fixed order is wanted.

    Returns:
        A DataLoader yielding ``(features, labels)`` batches.
    """
    dataset = GraphDataset(features, labels, lengths)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    
def prepare_graph_data(data, args):
    """
    Build the inputs for GraphModel from one split of the data.

    The video features are averaged over dim 1; the audio, text and averaged
    video features are then stacked along dim 0, in that order. The labels of
    the three modalities are stacked the same way.

    :param data: Dictionary with 'audio', 'text' and 'video' entries, each
        holding 'features' and 'labels' tensors.
    :param args: Configuration arguments (not used by this function).
    :return: Tuple (features, labels, lengths), where lengths is a long tensor
        with the number of rows contributed by each modality.
    """
    modality_order = ('audio', 'text', 'video')

    # Average the video features over dim 1 so they can be stacked with the others
    pooled_video = data['video']['features'].mean(dim=1)
    feature_blocks = (
        data['audio']['features'],
        data['text']['features'],
        pooled_video,
    )
    label_blocks = [data[name]['labels'] for name in modality_order]

    features = torch.cat(feature_blocks, dim=0)
    labels = torch.cat(label_blocks, dim=0)

    # One entry per modality: how many rows it contributed
    lengths = [block.size(0) for block in feature_blocks]

    debug_message("Prepared features shape", features.shape)
    debug_message("Prepared labels shape", labels.shape)
    debug_message("Prepared lengths", lengths)

    return features, labels, torch.tensor(lengths, dtype=torch.long)


# Build (features, labels, lengths) for each split; note the validation split
# is stored under the 'val' key but bound to dev_* names here.
train_features, train_labels, train_lengths = prepare_graph_data(data['train'], args)
dev_features, dev_labels, dev_lengths = prepare_graph_data(data['val'], args)
test_features, test_labels, test_lengths = prepare_graph_data(data['test'], args)


[DEBUG] Prepared features shape: torch.Size([23496, 768])
[DEBUG] Prepared labels shape: torch.Size([23496])
[DEBUG] Prepared lengths: [7832, 7832, 7832]
[DEBUG] Prepared features shape: torch.Size([2421, 768])
[DEBUG] Prepared labels shape: torch.Size([2421])
[DEBUG] Prepared lengths: [807, 807, 807]
[DEBUG] Prepared features shape: torch.Size([6213, 768])
[DEBUG] Prepared labels shape: torch.Size([6213])
[DEBUG] Prepared lengths: [2071, 2071, 2071]


In [77]:

def train_model(model, train_loader, val_loader, optimizer, criterion, args, epochs=10):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Each epoch runs one pass over ``train_loader`` with gradient updates,
    then one pass over ``val_loader`` without gradients. Whenever the mean
    validation loss improves, the model's state dict is saved to
    'best_model.pth'.

    Every batch is passed to the model as a single group: ``lengths`` is a
    one-element tensor holding the batch size.

    Args:
        model: Model called as ``model(features, lengths=...)``.
        train_loader: Iterable of ``(features, labels)`` training batches.
        val_loader: Iterable of ``(features, labels)`` validation batches.
        optimizer: Optimizer updating the model parameters.
        criterion: Loss called as ``criterion(out, labels)``.
        args: Configuration dictionary; only ``args['device']`` is read.
        epochs: Number of epochs to run.
    """
    best_val_loss = float('inf')

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        with tqdm(train_loader, desc=f"Epoch {epoch + 1} Training") as t:
            # Iterate over the tqdm wrapper rather than the raw loader;
            # otherwise the progress bar never advances.
            for features, labels in t:
                features = features.to(args['device'])
                labels = labels.to(args['device'])

                optimizer.zero_grad()
                out = model(features, lengths=torch.tensor([len(features)]).to(args['device']))
                loss = criterion(out, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

        debug_message(f"Epoch {epoch + 1} train loss", train_loss / len(train_loader))

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for features, labels in val_loader:
                features = features.to(args['device'])
                labels = labels.to(args['device'])

                out = model(features, lengths=torch.tensor([len(features)]).to(args['device']))
                val_loss += criterion(out, labels).item()

        val_loss /= len(val_loader)
        debug_message(f"Epoch {epoch + 1} validation loss", val_loss)

        # Keep only the best weights seen so far (lowest mean validation loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            debug_message(f"Epoch {epoch + 1}: Best model saved")

def evaluate_model(model, test_loader, args):
    """Evaluate ``model`` on ``test_loader`` and report classification metrics.

    Predictions are the argmax over dim 1 of the model output. Accuracy and
    macro-averaged F1 are logged, and a per-class classification report is
    printed and written, together with both scores, to 'test_results.txt'.

    Args:
        model: Model called as ``model(features, lengths=...)``.
        test_loader: Iterable of ``(features, labels)`` batches.
        args: Configuration dictionary; only ``args['device']`` is read.

    Returns:
        Tuple ``(test_acc, test_f1)``.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        with tqdm(test_loader, desc="Testing") as t:
            for features, labels in t:
                features = features.to(args['device'])
                labels = labels.to(args['device'])

                # The whole batch is passed as one group of len(features) nodes
                out = model(features, lengths=torch.tensor([len(features)]).to(args['device']))
                preds = out.argmax(dim=1)
                all_preds.append(preds.cpu())
                all_labels.append(labels.cpu())

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    test_acc = accuracy_score(all_labels, all_preds)
    test_f1 = f1_score(all_labels, all_preds, average='macro')

    debug_message("Test Accuracy", test_acc)
    debug_message("Test F1 Score", test_f1)

    # Pass the class ids explicitly. Without `labels`, classification_report
    # infers the classes from the values that actually occur and raises when
    # their count differs from len(target_names), e.g. when a class is missing
    # from the data or a prediction exceeds the largest true label.
    num_classes = max(int(all_labels.max()), int(all_preds.max())) + 1
    class_ids = list(range(num_classes))
    report = classification_report(
        all_labels,
        all_preds,
        labels=class_ids,
        target_names=[str(i) for i in class_ids],
    )
    with open('test_results.txt', 'w') as f:
        f.write(f"Test Accuracy: {test_acc}\n")
        f.write(f"Test F1 Score: {test_f1}\n")
        f.write(report)

    print(report)
    return test_acc, test_f1


In [78]:
# Configuration dictionary handed to GraphModel and to the train/eval helpers
args = {
    'modalities': ['audio', 'text', 'video'],
    'edge_type': ['temp', 'multi'],
    'wp': 2,
    'wf': 2,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')
}

# Assumes the *_features / *_labels / *_lengths tensors were produced by prepare_graph_data above
train_loader = prepare_dataloaders(train_features, train_labels, train_lengths, batch_size=32)
val_loader = prepare_dataloaders(dev_features, dev_labels, dev_lengths, batch_size=32)
test_loader = prepare_dataloaders(test_features, test_labels, test_lengths, batch_size=32)

# Positional arguments: g_dim=768, h1_dim=128, h2_dim=64, device, args
model = GraphModel(768, 128, 64, args['device'], args).to(args['device'])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

train_model(model, train_loader, val_loader, optimizer, criterion, args, epochs=10)

# train_model checkpoints the weights with the lowest validation loss to
# 'best_model.pth'; restore them so the test metrics describe that checkpoint
# rather than whatever the last epoch left in memory.
model.load_state_dict(torch.load('best_model.pth', map_location=args['device']))
evaluate_model(model, test_loader, args)

GraphModel --> Edge type: ['temp', 'multi']
GraphModel --> Window past: 2
GraphModel --> Window future: 2


Epoch 1 Training:   0%|          | 0/735 [00:00<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 1 Training:   0%|          | 0/735 [00:18<?, ?it/s]

[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([8, 768])
[DEBUG] Lengths sum: 8
[DEBUG] Number of valid permutations: 28
[DEBUG] Validating final edge_index: torch.Size([2, 28])
[DEBUG] Validating edge_index: torch.Size([2, 28])
[DEBUG] RGCN output shape: torch.Size([8, 128])
[DEBUG] Transformer output shape: torch.Size([8, 64])
[DEBUG] Lengths for split: [8]
[DEBUG] Out GNN shape before split: torch.Size([8, 64])
[DEBUG] Split features shapes: [torch.Size([8, 64])]
[DEBUG] Epoch 1 train loss: 1.1674074499785494
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Val




[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Epoch 2 Training:   0%|          | 0/735 [00:00<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 2 Training:   0%|          | 0/735 [00:18<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7




[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 3 Training:   0%|          | 0/735 [00:00<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 3 Training:   0%|          | 0/735 [00:18<?, ?it/s]

[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32




[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 4 Training:   0%|          | 0/735 [00:00<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 4 Training:   0%|          | 0/735 [00:18<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7




[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Epoch 5 Training:   0%|          | 0/735 [00:00<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 5 Training:   0%|          | 0/735 [00:18<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7




[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Epoch 6 Training:   0%|          | 0/735 [00:00<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 6 Training:   0%|          | 0/735 [00:18<?, ?it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])




[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 7 Training:   0%|          | 0/735 [00:00<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 7 Training:   0%|          | 0/735 [00:18<?, ?it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([8, 768])
[DEBUG] Lengths sum: 8
[DEBUG] Number of valid permutations: 28
[DEBUG] Validating final edge_index: torch.Size([2, 28])
[DEBUG] Validating edge_index: torch.Size([2, 28])
[DEBUG] RGCN output shape: torch.Size([8, 128])
[DEBU




[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32

Epoch 8 Training:   0%|          | 0/735 [00:00<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 8 Training:   0%|          | 0/735 [00:18<?, ?it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])




[DEBUG] Epoch 8 train loss: 0.9011179286606458
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[

Epoch 9 Training:   0%|          | 0/735 [00:00<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 9 Training:   0%|          | 0/735 [00:18<?, ?it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([8, 768])
[DEBUG] Lengths sum: 8
[DEBUG] Number of valid permutations: 28
[DEBUG] Validating final edge_index: torch.Size([2, 28])
[DEBUG] Validating edge_index: torch.Size([2, 28])
[DEBUG] RGCN output shape: torch.Size([8, 128])
[DEBUG] Transformer output shape: torch.Size([8, 64])
[DEBUG] Lengths for split: [8]
[DEBUG] Out GNN shape before split: torch.Size([8, 64])
[DEBUG] Split features shapes: [torch.Size([8, 64])]
[DEBUG] Epoch 9 train loss: 0.8993043738968518
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGC




[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Epoch 10 Training:   0%|          | 0/735 [00:00<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Epoch 10 Training:   0%|          | 0/735 [00:18<?, ?it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7




[DEBUG] Epoch 10 train loss: 0.8902191454050492
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]


Testing:   4%|▎         | 7/195 [00:00<00:02, 68.46it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Testing:  14%|█▍        | 28/195 [00:00<00:02, 68.33it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Testing:  18%|█▊        | 35/195 [00:00<00:02, 67.48it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Testing:  29%|██▊       | 56/195 [00:00<00:02, 66.45it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Testing:  32%|███▏      | 63/195 [00:00<00:02, 65.98it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Testing:  39%|███▉      | 77/195 [00:01<00:01, 65.18it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Testing:  47%|████▋     | 91/195 [00:01<00:01, 63.84it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Testing:  54%|█████▍    | 105/195 [00:01<00:01, 63.83it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Testing:  61%|██████    | 119/195 [00:01<00:01, 63.63it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Testing:  68%|██████▊   | 133/195 [00:02<00:01, 59.25it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Testing:  71%|███████▏  | 139/195 [00:02<00:00, 58.65it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Testing:  77%|███████▋  | 151/195 [00:02<00:00, 52.25it/s]

[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 7

Testing:  85%|████████▍ | 165/195 [00:02<00:00, 56.58it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Testing:  91%|█████████▏| 178/195 [00:02<00:00, 57.56it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Testing:  98%|█████████▊| 191/195 [00:03<00:00, 57.08it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([32, 768])
[DEBUG] Lengths sum: 32
[DEBUG] Number of valid permutations: 124
[DEBUG] Validating final edge_index: torch.Size([2, 124])
[DEBUG] Validating edge_index: torch.Size([2, 124])
[DEBUG] RGCN output shape: torch.Size([32, 128])

Testing: 100%|██████████| 195/195 [00:03<00:00, 61.32it/s]

[DEBUG] RGCN output shape: torch.Size([32, 128])
[DEBUG] Transformer output shape: torch.Size([32, 64])
[DEBUG] Lengths for split: [32]
[DEBUG] Out GNN shape before split: torch.Size([32, 64])
[DEBUG] Split features shapes: [torch.Size([32, 64])]
[DEBUG] Packed features shape: torch.Size([5, 768])
[DEBUG] Lengths sum: 5
[DEBUG] Number of valid permutations: 16
[DEBUG] Validating final edge_index: torch.Size([2, 16])
[DEBUG] Validating edge_index: torch.Size([2, 16])
[DEBUG] RGCN output shape: torch.Size([5, 128])
[DEBUG] Transformer output shape: torch.Size([5, 64])
[DEBUG] Lengths for split: [5]
[DEBUG] Out GNN shape before split: torch.Size([5, 64])
[DEBUG] Split features shapes: [torch.Size([5, 64])]
[DEBUG] Test Accuracy: 0.6489618541767262
[DEBUG] Test F1 Score: 0.34920521936213583
              precision    recall  f1-score   support

           0       0.66      0.97      0.78      3768
           1       0.56      0.22      0.32      1206
           2       0.65      0.11      




(0.6489618541767262, 0.34920521936213583)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

#---------------------------------------------------------------------------------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
from .MultiheadAttention import MultiheadAttention
import math


class TransformerEncoder(nn.Module):
    """
    Transformer encoder consisting of ``layers`` stacked
    :class:`TransformerEncoderLayer` blocks followed by a final LayerNorm.

    The input is used as-is (no token or positional embedding is applied here):
    it is only scaled by ``sqrt(embed_dim)`` and passed through embedding dropout
    before entering the layer stack.

    Args:
        embed_dim (int): feature dimension of the input
        num_heads (int): number of heads
        layers (int): number of layers
        attn_dropout (float): dropout applied on the attention weights
        relu_dropout (float): dropout applied on the first layer of the residual block
        res_dropout (float): dropout applied on the residual block
        embed_dropout (float): dropout applied on the scaled input
        attn_mask (bool): whether to apply mask on the attention weights
    """

    def __init__(self, embed_dim, num_heads, layers, attn_dropout=0.0, relu_dropout=0.0, res_dropout=0.0,
                 embed_dropout=0.0, attn_mask=False):
        super().__init__()
        self.dropout = embed_dropout      # Embedding dropout
        self.attn_dropout = attn_dropout
        self.embed_dim = embed_dim
        self.embed_scale = math.sqrt(embed_dim)

        self.attn_mask = attn_mask

        self.layers = nn.ModuleList([])
        for _ in range(layers):
            self.layers.append(TransformerEncoderLayer(embed_dim,
                                                       num_heads=num_heads,
                                                       attn_dropout=attn_dropout,
                                                       relu_dropout=relu_dropout,
                                                       res_dropout=res_dropout,
                                                       attn_mask=attn_mask))

        # Stored as a buffer holding the value 2; nothing in this class reads it.
        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = True
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)

    def forward(self, x_in, x_in_k = None, x_in_v = None):
        """
        Args:
            x_in (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
            x_in_k (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
            x_in_v (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
        Returns:
            Tensor: output of the last encoder layer (after the final LayerNorm
            when ``self.normalize`` is set), of shape `(src_len, batch, embed_dim)`

        Note:
            Keys/values are only used when BOTH ``x_in_k`` and ``x_in_v`` are given;
            if only one of them is passed it is ignored and the layers attend over
            ``x_in`` alone.
        """
        use_key_value = x_in_k is not None and x_in_v is not None

        # scale the input, then apply embedding dropout
        x = self.embed_scale * x_in
        x = F.dropout(x, p=self.dropout, training=self.training)

        if use_key_value:
            # same scaling + dropout for the key and value streams
            x_k = self.embed_scale * x_in_k
            x_v = self.embed_scale * x_in_v

            x_k = F.dropout(x_k, p=self.dropout, training=self.training)
            x_v = F.dropout(x_v, p=self.dropout, training=self.training)

        # encoder layers; the same (scaled) key/value tensors feed every layer
        for layer in self.layers:
            if use_key_value:
                x = layer(x, x_k, x_v)
            else:
                x = layer(x)

        if self.normalize:
            x = self.layer_norm(x)

        return x

    def max_positions(self):
        """Maximum input length supported by the encoder.

        Returns ``None`` when no limit is configured. ``__init__`` assigns neither
        ``embed_positions`` nor ``max_source_positions``, so both are read
        defensively instead of raising ``AttributeError``; if a caller sets them,
        the smaller of the two limits is returned.
        """
        embed_positions = getattr(self, 'embed_positions', None)
        max_source_positions = getattr(self, 'max_source_positions', None)
        if embed_positions is None:
            return max_source_positions
        if max_source_positions is None:
            return embed_positions.max_positions()
        return min(max_source_positions, embed_positions.max_positions())


class TransformerEncoderLayer(nn.Module):
    """Single encoder block: multi-head attention followed by a feed-forward net.

    Both sub-blocks are wrapped with a residual connection and dropout. Because
    ``normalize_before`` is hard-coded to ``True`` the block runs in pre-norm
    order (`layernorm -> sub-block -> dropout -> add residual`); setting it to
    ``False`` yields the post-norm order
    (`sub-block -> dropout -> add residual -> layernorm`).

    Args:
        embed_dim: Embedding dimension
        num_heads: number of attention heads
        attn_dropout: dropout forwarded to the attention module
        relu_dropout: dropout applied after the ReLU of the feed-forward net
        res_dropout: dropout applied to each sub-block output before the residual add
        attn_mask: when truthy, a future mask is built and passed to attention
    """

    def __init__(self, embed_dim, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, res_dropout=0.1,
                 attn_mask=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.attn_mask = attn_mask
        self.relu_dropout = relu_dropout
        self.res_dropout = res_dropout
        self.normalize_before = True

        self.self_attn = MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            attn_dropout=attn_dropout
        )

        # Position-wise feed-forward net: expand to 4x the width, then project back.
        hidden_dim = 4 * embed_dim
        self.fc1 = Linear(embed_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, embed_dim)
        # layer_norms[0] belongs to the attention sub-block, layer_norms[1] to the FFN.
        self.layer_norms = nn.ModuleList([LayerNorm(embed_dim), LayerNorm(embed_dim)])

    def forward(self, x, x_k=None, x_v=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`;
                used as the attention query
            x_k (Tensor): optional key input, same layout as x
            x_v (Tensor): optional value input, same layout as x
        Returns:
            encoded output with the same shape as `x`

        When both `x_k` and `x_v` are ``None`` the layer attends over `x` itself;
        otherwise both are normalized with the attention LayerNorm and used as
        key and value.
        """
        # ---- attention sub-block ----
        skip = x
        x = self.maybe_layer_norm(0, x, before=True)
        mask = buffered_future_mask(x, x_k) if self.attn_mask else None
        if x_k is None and x_v is None:
            key, value = x, x
        else:
            key = self.maybe_layer_norm(0, x_k, before=True)
            value = self.maybe_layer_norm(0, x_v, before=True)
        attn_out, _ = self.self_attn(query=x, key=key, value=value, attn_mask=mask)
        x = skip + F.dropout(attn_out, p=self.res_dropout, training=self.training)
        x = self.maybe_layer_norm(0, x, after=True)

        # ---- feed-forward sub-block ----
        skip = x
        hidden = self.maybe_layer_norm(1, x, before=True)
        hidden = F.dropout(F.relu(self.fc1(hidden)), p=self.relu_dropout, training=self.training)
        hidden = F.dropout(self.fc2(hidden), p=self.res_dropout, training=self.training)
        return self.maybe_layer_norm(1, skip + hidden, after=True)

    def maybe_layer_norm(self, i, x, before=False, after=False):
        """Apply ``layer_norms[i]`` to `x` only at the position selected by ``normalize_before``.

        Exactly one of `before` / `after` must be set. The norm runs for the
        `before` call when ``normalize_before`` is true and for the `after`
        call when it is false; otherwise `x` is returned unchanged.
        """
        assert before ^ after
        should_normalize = after ^ self.normalize_before
        return self.layer_norms[i](x) if should_normalize else x

def fill_with_neg_inf(t):
    """Return a tensor shaped and typed like `t` with every entry set to -inf.

    The fill is done on a float32 view/copy so it also works for FP16 input,
    then the result is cast back to the type of `t`.
    NOTE: when `t` is already float32, `t.float()` hands back `t` itself, so
    the fill happens in place on the caller's tensor.
    """
    as_fp32 = t.float()
    as_fp32.fill_(float('-inf'))
    return as_fp32.type_as(t)


def buffered_future_mask(tensor, tensor2=None):
    """Build an additive attention mask that blocks "future" positions.

    Args:
        tensor (Tensor): query-side tensor; its dim 0 gives the number of
            mask rows.
        tensor2 (Tensor, optional): key-side tensor; its dim 0 gives the number
            of mask columns. When omitted the mask is square.
    Returns:
        Tensor of shape `(dim1, dim2)` in the default floating dtype, located
        on the same device as `tensor`. Entries are `-inf` where
        `col - row >= 1 + abs(dim2 - dim1)` and `0` elsewhere.
    """
    dim1 = dim2 = tensor.size(0)
    if tensor2 is not None:
        dim2 = tensor2.size(0)
    # Allocate directly on the input's device. A bare `.cuda()` would target
    # the *default* CUDA device, which is wrong when `tensor` lives on another
    # GPU, and would leave the mask on CPU for non-CUDA accelerators.
    blocked = torch.full((dim1, dim2), float('-inf'), device=tensor.device)
    # triu keeps -inf on/above the shifted diagonal and zeroes everything below.
    return torch.triu(blocked, 1 + abs(dim2 - dim1))


def Linear(in_features, out_features, bias=True):
    """Create an `nn.Linear` with Xavier-uniform weights and a zeroed bias.

    Args:
        in_features (int): input feature size.
        out_features (int): output feature size.
        bias (bool): whether the layer has a bias term; if so it is set to 0.
    Returns:
        nn.Linear: the initialized layer.
    """
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if not bias:
        return layer
    nn.init.constant_(layer.bias, 0.)
    return layer


def LayerNorm(embedding_dim):
    """Return an `nn.LayerNorm` over the last `embedding_dim` features."""
    return nn.LayerNorm(embedding_dim)


if __name__ == '__main__':
    # Smoke test: build a small encoder and print the shape of its output
    # for a random input.
    encoder = TransformerEncoder(100, 4, 2)
    # torch.rand already returns a tensor; re-wrapping it in torch.tensor()
    # only makes an extra copy and triggers a copy-construct UserWarning.
    x = torch.rand(83, 32, 100)
    print(encoder(x).shape)
# ------------------------------------------------------------------------------------------------------------------------------------------------------------


class CrossmodalNet(nn.Module):
    """Pairwise cross-modal attention followed by per-modality self-attention.

    For every ordered pair of distinct modalities `(j, k)` a
    `TransformerEncoder` is registered under the key `j + k` and is called with
    modality `j` as the query and modality `k` as key and value. The outputs
    that share the same query modality `j` are concatenated along dim 2 and
    passed through the `mem_<j>` encoder.
    """

    def __init__(self, inchannels, args) -> None:
        """
        Args:
            inchannels (int): feature width handed to each cross-modal encoder.
            args: configuration object. Reads `modalities` (iterable of string
                identifiers), `crossmodal_nheads`, `num_crossmodal`,
                `self_att_nheads` and `num_self_att`.
        """
        super(CrossmodalNet, self).__init__()

        self.modalities = args.modalities
        n_modals = len(args.modalities)

        layers = nn.ModuleDict()
        for j in self.modalities:
            for k in self.modalities:
                if j == k:
                    continue
                layers[j + k] = TransformerEncoder(
                    inchannels,
                    num_heads=args.crossmodal_nheads,
                    layers=args.num_crossmodal,
                )
            # The mem encoder consumes the concatenation of the (n_modals - 1)
            # cross-modal outputs that use `j` as the query.
            layers[f'mem_{j}'] = TransformerEncoder(
                inchannels * (n_modals - 1),
                num_heads=args.self_att_nheads,
                layers=args.num_self_att,
            )
        self.layers = layers

    def forward(self, x_s):
        """
        Args:
            x_s (sequence of Tensor): one 3-D tensor per modality, in the same
                order as `self.modalities`. Each is permuted with `(1, 0, 2)`
                before use; presumably the inputs are
                `(batch, seq_len, inchannels)` -- TODO confirm against caller.
                The sequence itself is not modified.
        Returns:
            Tensor: the `mem_<j>` outputs of all modalities concatenated along
            dim 2. The result is NOT permuted back, so it stays in the
            dim-0/dim-1 swapped layout produced by the encoders.
        """
        assert len(x_s) == len(self.modalities), f'{len(x_s)} diff {self.modalities}'

        # Permute into a fresh list rather than writing back into `x_s`:
        # mutating the caller's list breaks tuple inputs and silently
        # double-permutes the tensors if the same list is passed again.
        permuted = [x.permute(1, 0, 2) for x in x_s]

        mem_outputs = []
        for j, x_j in zip(self.modalities, permuted):
            # Modality j queries every other modality k (key == value == x_k).
            cross = [
                self.layers[j + k](x_j, x_k, x_k)
                for k, x_k in zip(self.modalities, permuted)
                if j != k
            ]
            fused = torch.cat(cross, dim=2)
            mem_outputs.append(self.layers[f'mem_{j}'](fused))

        return torch.cat(mem_outputs, dim=2)