In [4]:
from qml_ssl.data_jets_graph import *

import matplotlib.pyplot as plt

In [5]:
# Load the quark/gluon jet graphs and peek at the first one.
dataset = QG_Jets('../data/QG_Jets')
first_graph = dataset[0]
print(f"Length: {len(dataset)}, Info: {first_graph}, Sample...: \n{first_graph.h[:5]}")


# Cap the working set, but never ask for more graphs than the dataset holds.
total_size = min(100000, len(dataset))
dataset = dataset[:total_size]

# 60/20/20 split. The test split takes the remainder so the three sizes always
# sum to total_size; truncating each share with int() can lose samples, and
# random_split rejects lengths that do not add up to len(dataset).
train_size = int(0.6 * total_size)
val_size = int(0.2 * total_size)
test_size = total_size - train_size - val_size

# The transform is applied lazily on every item access (presumably building a
# 5-nearest-neighbour edge_index from the `h` features -- see KNNGroup).
dataset.transform = KNNGroup(k=5, attr_name="h")
sample = dataset[10]
print(sample)

print(f'Node degree: {sample.num_edges / sample.num_nodes:.2f}')
print(f'Has isolated nodes: {sample.has_isolated_nodes()}')
print(f'Has self-loops: {sample.has_self_loops()}')
print(f'Is undirected: {sample.is_undirected()}')

# Draw the graph that was just inspected instead of re-running the transform.
nx.draw(pyg.utils.to_networkx(sample))
plt.show()

Length: 100000, Info: Data(y=[1], particleid=[18, 1], h=[18, 3], num_nodes=18), Sample...: 
tensor([[ 5.3666e-04,  3.8694e-01, -8.9608e-03],
        [ 3.1963e-04, -2.2557e-01, -2.0012e-01],
        [ 2.2936e-03, -3.2403e-02, -2.4649e-01],
        [ 8.2497e-03,  2.0372e-01,  1.5876e-02],
        [ 3.3865e-03, -1.8214e-01,  4.6523e-02]])


ImportError: 'knn_graph' requires 'torch-cluster'

In [None]:
from qml_ssl.data_jets_graph import *

def generate_embeddings(model, data_loader):
    """
    Run every batch of ``data_loader`` through the wrapped encoder.

    The wrapper is switched to eval mode and gradients are disabled. Each batch
    is moved to ``model.device`` and encoded with ``model.model`` (the inner
    network, not the wrapper's own ``forward``).

    Args:
        model: Object exposing ``eval()``, ``device`` and a callable ``model``.
        data_loader: Iterable of batches that provide ``to(device)`` and ``y``.

    Returns:
        tuple: ``(embeddings, labels)`` as numpy arrays, concatenated over all
        batches in iteration order.
    """
    model.eval()
    emb_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(model.device)
            emb_chunks.append(model.model(batch))
            label_chunks.append(batch.y)

    all_embeddings = torch.cat(emb_chunks).cpu().numpy()
    all_labels = torch.cat(label_chunks).cpu().numpy()
    return all_embeddings, all_labels

class Custom_GCN(pyg_nn.MessagePassing):
    """Message-passing layer whose per-edge message is an MLP of both endpoints.

    For every edge, the ``h`` and ``particleid`` features of the two endpoint
    nodes are concatenated and fed through a two-layer MLP; incoming messages
    are summed per node (``aggr='add'``).

    Args:
        out_channels: Width of the MLP hidden layer and of the output.
        in_channels: Size of the concatenated edge feature. The default of 8
            matches two 3-feature ``h`` vectors plus two 1-feature particle ids.
    """

    def __init__(self, out_channels, in_channels=8):
        super().__init__(aggr='add')

        input_layer = nn.Linear(in_channels, out_channels)
        output_layer = nn.Linear(out_channels, out_channels)
        self.mlp = nn.Sequential(input_layer, nn.ReLU(), output_layer)

    def forward(self, h, particleid, edge_index):
        """Propagate ``h`` and ``particleid`` along ``edge_index``."""
        aggregated = self.propagate(edge_index, h=h, particleid=particleid)
        return aggregated

    def message(self, h_i, h_j, particleid_i, particleid_j):
        """Build one message per edge from the ``_i`` / ``_j`` endpoint features."""
        endpoint_features = (h_i, h_j, particleid_i, particleid_j)
        return self.mlp(torch.cat(endpoint_features, dim=-1))
    
class GCN_Encoder(nn.Module):
    """Graph encoder: two Custom_GCN layers, mean pooling, then an MLP readout.

    Args:
        hidden_dim: Output width of both message-passing layers.

    Attributes:
        output_dim: Size of the embedding returned by ``forward`` (fixed at 8).
    """

    def __init__(self, hidden_dim=8):
        super().__init__()

        # The second layer sees two hidden vectors plus two particle ids per edge.
        second_layer_inputs = hidden_dim * 2 + 2
        self.conv1 = Custom_GCN(hidden_dim)
        self.conv2 = Custom_GCN(hidden_dim, in_channels=second_layer_inputs)

        self.output_dim = 8
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim, 16),
            nn.ReLU(),
            nn.Linear(16, self.output_dim),
        )

    def forward(self, data):
        """Encode a batch exposing ``h``, ``particleid``, ``edge_index`` and ``batch``.

        Returns:
            One ``output_dim``-sized embedding per graph in the batch.
        """
        particleid = data.particleid
        edge_index = data.edge_index
        batch = data.batch

        # Two rounds of message passing, each followed by a ReLU.
        node_feats = self.conv1(h=data.h, particleid=particleid, edge_index=edge_index).relu()
        node_feats = self.conv2(h=node_feats, particleid=particleid, edge_index=edge_index).relu()

        # Average node features within each graph, then project to the embedding.
        graph_feats = pyg_nn.global_mean_pool(node_feats, batch)
        return self.readout(graph_feats)

import pytorch_lightning as pl
import torchmetrics
from pytorch_metric_learning import losses

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

# Mini-batch size shared by every DataLoader in this notebook.
batch_size = 64

class ModelPL_Contrastive(pl.LightningModule):
    """Lightning wrapper that trains an encoder with a label-driven NT-Xent loss.

    The encoder's embeddings and the batch labels ``data.y`` are passed to
    ``NTXentLoss(temperature=0.2)``.

    Args:
        model: Encoder mapping a batch to one embedding per graph.
        learning_rate: Initial Adam learning rate.
    """

    def __init__(self, model, learning_rate=0.01):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        # self.criterion = losses.ContrastiveLoss(pos_margin=0.25, neg_margin=5.0)
        self.criterion = losses.NTXentLoss(temperature=0.2)

        # Running loss means; they are updated in the step methods below.
        self.train_loss = torchmetrics.MeanMetric()
        self.val_loss = torchmetrics.MeanMetric()

    def forward(self, data):
        """Return the encoder embeddings for ``data``."""
        return self.model(data)

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau (factor 0.25, patience 1) driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.25, patience=1),
            'monitor': 'val_loss',
            'interval': 'epoch',
            'frequency': 1
        }
        return [optimizer], [lr_scheduler]

    def _contrastive_loss(self, data, log_name):
        """Compute the loss for one batch and log it as an epoch-level mean.

        The mean is weighted by the number of graphs actually in the batch
        (``data.num_graphs``) rather than a notebook-global constant, so the
        final, smaller batch of an epoch is not over-weighted.
        """
        embeddings = self(data)
        loss = self.criterion(embeddings, data.y)
        self.log(log_name, loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=data.num_graphs)
        return loss

    def training_step(self, data, batch_idx):
        """Return the training loss for one batch."""
        loss = self._contrastive_loss(data, 'train_loss')
        self.train_loss.update(loss)
        return loss

    def validation_step(self, data, batch_idx):
        """Return the validation loss for one batch."""
        loss = self._contrastive_loss(data, 'val_loss')
        self.val_loss.update(loss)
        return loss

    def test_step(self, data, batch_idx):
        """Return the test loss for one batch."""
        return self._contrastive_loss(data, 'test_loss')



class ModelPL_Classify(pl.LightningModule):
    """Lightning wrapper training an encoder and a two-class MLP head end-to-end.

    Args:
        model: Encoder exposing ``output_dim`` and mapping a batch to embeddings.
        learning_rate: Initial Adam learning rate.
    """

    def __init__(self, model, learning_rate=0.001):
        super().__init__()
        self.model = model
        self.classifier = nn.Sequential(
            nn.Linear(self.model.output_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 2))
        self.learning_rate = learning_rate
        self.criterion = torch.nn.CrossEntropyLoss()

        from torchmetrics import AUROC, Accuracy
        self.train_auc = AUROC(task='binary')
        self.val_auc = AUROC(task='binary')
        self.test_auc = AUROC(task='binary')

        self.train_acc = Accuracy(task='binary')
        self.val_acc = Accuracy(task='binary')
        self.test_acc = Accuracy(task='binary')

    def forward(self, data):
        """Return the two-class logits for every graph in ``data``."""
        embeddings = self.model(data)
        return self.classifier(embeddings)

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau (factor 0.25, patience 1) driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                        mode='min', factor=0.25, patience=1),
            'monitor': 'val_loss',
            'interval': 'epoch',
            'frequency': 1
        }
        return [optimizer], [lr_scheduler]

    def _log_auc_and_acc(self, stage, logits, targets, n_graphs, prog_bar):
        """Update and log the ``{stage}_auc`` and ``{stage}_acc`` metrics for one batch."""
        auc_metric = getattr(self, f'{stage}_auc')
        acc_metric = getattr(self, f'{stage}_acc')

        # Column 1 of the softmax is used as the positive-class score.
        auc_metric(torch.softmax(logits, dim=1)[:, 1], targets)
        self.log(f'{stage}_auc', auc_metric, on_step=False, on_epoch=True, prog_bar=prog_bar, batch_size=n_graphs)

        acc_metric(logits.argmax(dim=-1), targets)
        self.log(f'{stage}_acc', acc_metric, on_step=False, on_epoch=True, prog_bar=prog_bar, batch_size=n_graphs)

    def training_step(self, data, batch_idx):
        """Log learning rate, loss, AUC and accuracy; return the loss."""
        # Weight epoch means by the real batch size; the last batch may be smaller.
        n_graphs = data.num_graphs

        lr = self.optimizers().param_groups[0]['lr']
        self.log('learning_rate', lr, on_step=False, on_epoch=True, prog_bar=True, batch_size=n_graphs)

        # The logits are used as produced: squeezing would collapse a one-graph
        # batch to 1-D and break softmax(dim=1) and the loss.
        logits = self(data)
        loss = self.criterion(logits, data.y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=n_graphs)

        self._log_auc_and_acc('train', logits, data.y, n_graphs, prog_bar=False)
        return loss

    def validation_step(self, data, batch_idx):
        """Log validation loss, AUC and accuracy."""
        n_graphs = data.num_graphs

        logits = self(data)
        loss = self.criterion(logits, data.y)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=n_graphs)

        self._log_auc_and_acc('val', logits, data.y, n_graphs, prog_bar=True)

    def test_step(self, data, batch_idx):
        """Log test AUC and accuracy."""
        logits = self(data)
        self._log_auc_and_acc('test', logits, data.y, data.num_graphs, prog_bar=True)

## Classification Task

In [None]:
from qml_ssl.utils import vmf_kde_on_circle, pca_proj, tsne_proj, plot_training

# Seed the split so a kernel restart reproduces the same train/val/test partition.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

# Only the training loader is reshuffled every epoch; the loader default is
# shuffle=False, which would replay identical mini-batches epoch after epoch.
train_loader = pyg_loader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12)
val_loader = pyg_loader.DataLoader(val_dataset, batch_size=batch_size, num_workers=12)
test_loader = pyg_loader.DataLoader(test_dataset, batch_size=batch_size, num_workers=12)

GCN_encoder = GCN_Encoder()

Graph_pl = ModelPL_Classify(model=GCN_encoder, learning_rate=0.01)

logger = pl.loggers.CSVLogger(save_dir='logs', name='Graph_pl', version=0)

summary_callback = pl.callbacks.ModelSummary(max_depth=8)
callbacks = [summary_callback]

trainer_GCNN = pl.Trainer(max_epochs=20,
                          devices="auto",
                          callbacks=callbacks,
                          logger=logger,)

train_result = trainer_GCNN.fit(Graph_pl, train_dataloaders=train_loader, val_dataloaders=val_loader)
test_result = trainer_GCNN.test(dataloaders=test_loader)

### Reducing the number of nodes (preprocessing)

In [None]:
# Number of particles kept per jet by TopKMomentum; also used in the log folder name.
k = 15

# NOTE: this replaces the transform on the shared `dataset`, so every later cell
# that indexes `dataset` sees the top-k graphs as well.
dataset.transform = T.Compose([TopKMomentum(k=k), KNNGroup(k=5, attr_name="h")])
sample = dataset[10]
print(sample)

print(f'Node degree: {sample.num_edges / sample.num_nodes:.2f}')
print(f'Has isolated nodes: {sample.has_isolated_nodes()}')
print(f'Has self-loops: {sample.has_self_loops()}')
print(f'Is undirected: {sample.is_undirected()}')

# Draw the graph that was just inspected instead of re-running the transforms.
nx.draw(pyg.utils.to_networkx(sample))
plt.show()

# Seed the split so a kernel restart reproduces the same train/val/test partition.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

# Only the training loader is reshuffled every epoch; the loader default is
# shuffle=False, which would replay identical mini-batches epoch after epoch.
train_loader = pyg_loader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12)
val_loader = pyg_loader.DataLoader(val_dataset, batch_size=batch_size, num_workers=12)
test_loader = pyg_loader.DataLoader(test_dataset, batch_size=batch_size, num_workers=12)

GCN_encoder = GCN_Encoder()

Graph_pl = ModelPL_Classify(model=GCN_encoder, learning_rate=0.01)

# One log folder per k, e.g. logs/graph_pl_15/version_0/metrics.csv.
logger = pl.loggers.CSVLogger(save_dir='logs', name=f'graph_pl_{k}', version=0)

summary_callback = pl.callbacks.ModelSummary(max_depth=8)
callbacks = [summary_callback]

trainer_GCNN = pl.Trainer(max_epochs=10,
                          devices="auto",
                          callbacks=callbacks,
                          logger=logger,)

train_result = trainer_GCNN.fit(Graph_pl, train_dataloaders=train_loader, val_dataloaders=val_loader)
test_result = trainer_GCNN.test(dataloaders=test_loader)

In [None]:
import pandas as pd
import os
import matplotlib.pyplot as plt

def extract_metrics(logs_dir='logs', prefix='graph_pl_'):
    """Collect the final test AUC and accuracy of every ``<prefix><k>`` run.

    Each sub-folder of ``logs_dir`` whose name is ``prefix`` followed by an
    integer ``k`` is inspected for ``version_0/metrics.csv``. Folders are
    skipped, rather than raising, when the suffix is not an integer, the CSV is
    missing, the ``test_auc`` / ``test_acc`` columns are absent, or those
    columns hold no values yet.

    Args:
        logs_dir: Directory containing one sub-folder per run.
        prefix: Folder-name prefix that precedes the integer ``k``.

    Returns:
        list[tuple[int, float, float]]: ``(k, test_auc, test_acc)`` per usable
        run, in directory-listing order.
    """
    metrics = []

    for folder in os.listdir(logs_dir):
        if not folder.startswith(prefix):
            continue

        try:
            k = int(folder[len(prefix):])
        except ValueError:
            # Shares the prefix but is not a k-run (e.g. 'graph_pl_backup').
            continue

        metrics_file = os.path.join(logs_dir, folder, 'version_0', 'metrics.csv')
        print(folder, k, metrics_file)
        if not os.path.exists(metrics_file):
            continue

        df = pd.read_csv(metrics_file)
        if 'test_auc' not in df.columns or 'test_acc' not in df.columns:
            continue

        # Most rows are empty for the test columns; keep only the filled ones.
        test_auc = df['test_auc'].dropna()
        test_acc = df['test_acc'].dropna()
        if test_auc.empty or test_acc.empty:
            continue

        metrics.append((k, float(test_auc.iloc[-1]), float(test_acc.iloc[-1])))

    return metrics

def plot_metrics(metrics):
    """Plot test AUC and test accuracy against ``k`` in two side-by-side panels.

    Args:
        metrics: Iterable of ``(k, test_auc, test_acc)`` tuples in any order.
            An empty collection prints a notice and draws nothing.
    """
    if not metrics:
        # Unpacking zip(*[]) below would raise on an empty list.
        print("No metrics to plot.")
        return

    ordered = sorted(metrics, key=lambda entry: entry[0])  # ascending k
    ks, test_aucs, test_accs = zip(*ordered)

    plt.figure(figsize=(10, 5))

    # Left panel: test AUC.
    plt.subplot(1, 2, 1)
    plt.plot(ks, test_aucs, marker='o', linestyle='-', color='b')
    plt.xlabel('k')
    plt.ylabel('Test AUC')
    plt.title('Test AUC vs k')

    # Right panel: test accuracy.
    plt.subplot(1, 2, 2)
    plt.plot(ks, test_accs, marker='o', linestyle='-', color='r')
    plt.xlabel('k')
    plt.ylabel('Test Accuracy')
    plt.title('Test Accuracy vs k')

    plt.tight_layout()
    plt.show()
    
# Relies on the defaults logs_dir='logs' and prefix='graph_pl_', which match the
# CSVLogger name f'graph_pl_{k}' used by the reduced-node training runs.
metrics = extract_metrics()
print(metrics)
plot_metrics(metrics)

## Contrastive Task

In [None]:
from qml_ssl.utils import vmf_kde_on_circle, pca_proj, tsne_proj, plot_training

# Seed the split so a kernel restart reproduces the same train/val/test partition.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

# Only the training loader is reshuffled every epoch; the loader default is
# shuffle=False, which would give the contrastive loss the same mini-batch
# compositions in every epoch.
train_loader = pyg_loader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12)
val_loader = pyg_loader.DataLoader(val_dataset, batch_size=batch_size, num_workers=12)
test_loader = pyg_loader.DataLoader(test_dataset, batch_size=batch_size, num_workers=12)

encoder = GCN_Encoder()

Contrastive_Graph_pl = ModelPL_Contrastive(model=encoder, learning_rate=0.001)

logger = pl.loggers.CSVLogger(save_dir='logs', name='Contrastive_Graph_pl')

summary_callback = pl.callbacks.ModelSummary(max_depth=8)
callbacks = [summary_callback]

# Baseline view: embeddings of the still-untrained encoder on the validation split.
embeddings, labels = generate_embeddings(Contrastive_Graph_pl, val_loader)
pca_proj(embeddings, labels)
# tsne_proj(embeddings, labels)
# vmf_kde_on_circle(embeddings, labels)

In [None]:
# Contrastive pre-training for 15 epochs, then evaluation on the held-out test split.
trainer_GCNN = pl.Trainer(
    max_epochs=15,
    devices="auto",
    callbacks=callbacks,
    logger=logger,
)

train_result = trainer_GCNN.fit(
    Contrastive_Graph_pl,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
test_result = trainer_GCNN.test(dataloaders=test_loader)

In [None]:
# Same validation-split embedding view as before training, now with the trained
# encoder, so the two projections can be compared.
embeddings, labels = generate_embeddings(Contrastive_Graph_pl, val_loader)
pca_proj(embeddings, labels)
# Alternative views imported from qml_ssl.utils (disabled):
# tsne_proj(embeddings, labels)
# vmf_kde_on_circle(embeddings, labels)

In [None]:
class LinearProbePL(pl.LightningModule):
    """Lightning module training a small MLP head on top of a frozen encoder.

    Every parameter of ``pretrained_model`` has ``requires_grad`` set to False,
    so only the classifier head receives gradient updates.

    Args:
        pretrained_model: Encoder exposing ``output_dim``.
        num_classes: Number of output logits. The AUC metric reads logit
            column 1 and all metrics are configured with ``task='binary'``.
        learning_rate: Initial Adam learning rate.
    """

    def __init__(self, pretrained_model, num_classes, learning_rate=0.001):
        super().__init__()
        self.pretrained_model = pretrained_model
        self.classifier = nn.Sequential(
            nn.Linear(pretrained_model.output_dim, 8),
            nn.ReLU(),
            nn.Linear(8, num_classes),
        )
        self.learning_rate = learning_rate
        self.criterion = nn.CrossEntropyLoss()

        from torchmetrics import AUROC, Accuracy
        self.train_auc = AUROC(task='binary')
        self.val_auc = AUROC(task='binary')
        self.test_auc = AUROC(task='binary')

        self.train_acc = Accuracy(task='binary')
        self.val_acc = Accuracy(task='binary')
        self.test_acc = Accuracy(task='binary')

        # Freeze the encoder.
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the class logits for every graph in the batch ``x``."""
        embeddings = self.pretrained_model(x)
        logits = self.classifier(embeddings)
        return logits

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau (factor 0.25, patience 1) driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                        mode='min', factor=0.25, patience=1),
            'monitor': 'val_loss',
            'interval': 'epoch',
            'frequency': 1
        }
        return [optimizer], [lr_scheduler]

    def _log_auc_and_acc(self, stage, logits, targets, n_graphs, prog_bar):
        """Update and log the ``{stage}_auc`` and ``{stage}_acc`` metrics for one batch."""
        auc_metric = getattr(self, f'{stage}_auc')
        acc_metric = getattr(self, f'{stage}_acc')

        # Column 1 of the softmax is used as the positive-class score.
        auc_metric(torch.softmax(logits, dim=1)[:, 1], targets)
        self.log(f'{stage}_auc', auc_metric, on_step=False, on_epoch=True, prog_bar=prog_bar, batch_size=n_graphs)

        acc_metric(logits.argmax(dim=-1), targets)
        self.log(f'{stage}_acc', acc_metric, on_step=False, on_epoch=True, prog_bar=prog_bar, batch_size=n_graphs)

    def training_step(self, data, batch_idx):
        """Log learning rate, loss, AUC and accuracy; return the loss."""
        # Weight epoch means by the real batch size; the last batch may be smaller.
        n_graphs = data.num_graphs

        lr = self.optimizers().param_groups[0]['lr']
        self.log('learning_rate', lr, on_step=False, on_epoch=True, prog_bar=True, batch_size=n_graphs)

        # The logits are used as produced: squeezing would collapse a one-graph
        # batch to 1-D and break softmax(dim=1) and the loss.
        logits = self(data)
        loss = self.criterion(logits, data.y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=n_graphs)

        self._log_auc_and_acc('train', logits, data.y, n_graphs, prog_bar=False)
        return loss

    def validation_step(self, data, batch_idx):
        """Log validation loss, AUC and accuracy."""
        n_graphs = data.num_graphs

        logits = self(data)
        loss = self.criterion(logits, data.y)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=n_graphs)

        self._log_auc_and_acc('val', logits, data.y, n_graphs, prog_bar=True)

    def test_step(self, data, batch_idx):
        """Log test AUC and accuracy."""
        logits = self(data)
        self._log_auc_and_acc('test', logits, data.y, data.num_graphs, prog_bar=True)

# Start from a fresh encoder and copy in the contrastively trained weights, so
# the probe does not share parameters with Contrastive_Graph_pl.
# (On-disk alternative: ./logs/Contrastive_Graph_pl/version_0/checkpoints/epoch=14-step=14070.ckpt)
pretrained_model = GCN_Encoder()
contrastive_weights = Contrastive_Graph_pl.model.state_dict()
pretrained_model.load_state_dict(contrastive_weights)

num_classes = 2  # Adjust this based on your dataset
logger = pl.loggers.CSVLogger(save_dir='logs', name='Contrastive_Graph_finetune_pl')

# Frozen encoder + trainable classification head.
linear_probe_model = LinearProbePL(
    pretrained_model=pretrained_model,
    num_classes=num_classes,
    learning_rate=0.001,
)

trainer_linear_probe = pl.Trainer(
    max_epochs=15,
    devices="auto",
    callbacks=callbacks,
    logger=logger,
)

train_result = trainer_linear_probe.fit(
    linear_probe_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
test_result = trainer_linear_probe.test(dataloaders=test_loader)

In [None]:

import pandas as pd
import matplotlib.pyplot as plt

def plot_metrics_from_csv(metrics_file, metrics={'val_loss', 'val_acc', 'val_auc'}):
    """Plot per-epoch curves for the requested columns of a metrics CSV.

    One panel is drawn per requested metric. ``val_loss``, ``val_acc`` and
    ``val_auc`` keep their dedicated labels and colours and are drawn in that
    order; any other column follows alphabetically, labelled with its name.

    Args:
        metrics_file: Path to a CSV with an ``epoch`` column and one column per
            metric.
        metrics: Column name, or collection of column names, to plot.

    Raises:
        ValueError: If no metric is requested or a requested column is missing.
    """
    # (column, y-axis label, title/legend text, line colour)
    known_panels = (
        ('val_loss', 'Loss', 'Validation Loss', 'b'),
        ('val_acc', 'Accuracy', 'Validation Accuracy', 'r'),
        ('val_auc', 'AUC', 'Validation AUC', 'g'),
    )

    if isinstance(metrics, str):
        metrics = [metrics]
    requested = set(metrics)
    if not requested:
        raise ValueError("At least one metric must be requested.")

    df = pd.read_csv(metrics_file)
    if not requested.issubset(df.columns):
        raise ValueError("The CSV file does not contain the required metrics.")

    # Metrics can sit on different rows of the same epoch; a stable sort keeps
    # the file order within an epoch so the forward-fill is deterministic.
    df = df.sort_values('epoch', kind='mergesort').ffill()
    epochs = df['epoch']

    panels = [panel for panel in known_panels if panel[0] in requested]
    known_names = {panel[0] for panel in known_panels}
    panels += [(name, name, name, None) for name in sorted(requested - known_names)]

    plt.figure(figsize=(5 * len(panels), 5))

    for position, (column, ylabel, title, color) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.plot(epochs, df[column], marker='o', linestyle='-', color=color, label=title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.show()
    
# Plot the linear-probe run: it logs val_loss, val_acc and val_auc. The
# contrastive pre-training run only logs losses, so the default metric check
# would reject its CSV.
plot_metrics_from_csv('./logs/Contrastive_Graph_finetune_pl/version_0/metrics.csv')