In [1]:
import os
import sys

# Resolve the repository root only once per kernel session. The notebook lives
# one directory below the checkout, but after the first run the cwd is already
# the checkout, so re-resolving "../" would climb one level higher on every
# re-execution of this cell.
if "FS_MOL_CHECKOUT_PATH" not in globals():
    FS_MOL_CHECKOUT_PATH = os.path.abspath("../")

os.chdir(FS_MOL_CHECKOUT_PATH)
# Make the checkout importable without piling up duplicate sys.path entries
# when the cell is re-run.
if FS_MOL_CHECKOUT_PATH not in sys.path:
    sys.path.insert(0, FS_MOL_CHECKOUT_PATH)

In [2]:
from dataclasses import dataclass, field
from typing import List

import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from fs_mol.data_modules.MXM_datamodule import MXMDataModule, MXMDataset


@dataclass(frozen=True)
class MXMNetTrainingConfig:
    """Hyper-parameters for episodic MXMNet training and validation.

    Instances are immutable (``frozen=True``); use ``dataclasses.replace`` to
    derive a variant with different settings.
    """

    # Training Settings:
    # Batch size handed to MXMDataModule and used as the logging batch size.
    batch_size: int = 8
    train_support_count: int = 16
    train_query_count: int = 16
    train_shuffle: bool = True

    # Temperature the training logits are divided by before cross-entropy.
    # (Misspelt name kept: other code reads `config.temprature`.)
    temprature: float = 0.07

    # Validation Settings:
    valid_support_count: int = 64
    valid_batch_size: int = 256

    # Model Settings:
    envelope_exponent: int = 6
    num_spherical: int = 7
    num_radial: int = 5
    dim: int = 128
    cutoff: float = 3.0
    layer: int = 5

    # Number of batches whose gradients are accumulated per optimizer step.
    accumulate_grad_batches: int = 4
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5

    dropout: float = 0.2

    # Encoder layer widths forwarded to MXMNet's Config. Declared through
    # default_factory so it is a real dataclass field (constructor argument,
    # included in vars()/asdict() and therefore in the W&B config) and each
    # instance owns its own list instead of sharing one class attribute.
    encoder_dims: List[int] = field(default_factory=lambda: [128, 128, 256, 256, 512, 512])


config = MXMNetTrainingConfig()

# Training data: support/query set sizes and batch size all come from the
# config above so the W&B-logged config matches what the loader actually uses.
data_module = MXMDataModule(
    "/FS-MOL/data/mxm/",
    query_size=config.train_query_count,
    support_size=config.train_support_count,
    batch_size=config.batch_size,
)

In [5]:
from typing import Any, Optional
import numpy as np
from pytorch_lightning.utilities.types import STEP_OUTPUT
from sklearn.metrics import auc, precision_recall_curve
import wandb
from MHNfs.mhnfs.modules import CrossAttentionModule_2
from MXMNet.model import Config, MXMNet
from pytorch_lightning.loggers import WandbLogger
from fewshot_utils.is_debugger_attached import is_debugger_attached

from fs_mol.models.protonet import calculate_mahalanobis_logits


class MXMNetLighteningModule(pl.LightningModule):
    """Few-shot classifier: MXMNet graph encoder followed by Mahalanobis logits.

    A batch is a dict with ``"graphs"``, ``"is_query"``, ``"labels"`` and
    ``"batch_index"``. ``batch_index`` assigns each molecule to an episode and
    ``is_query`` marks it as support (0) or query (1). For every episode, the
    query logits are computed against that episode's support set with
    ``calculate_mahalanobis_logits``.
    """

    def __init__(self, config: MXMNetTrainingConfig) -> None:
        super().__init__()
        self.config = config

        self.graph_encoder = MXMNet(
            Config(config.dim, config.layer, config.cutoff, config.encoder_dims, 512)
        )

        # NOTE(review): cross_attn is registered (so it is in the state dict and
        # in self.parameters()) but no step below calls it; the attention-based
        # logits path was never finished.
        self.cross_attn = CrossAttentionModule_2(config.dim * config.layer, 64, 8, 0.5)

        # Delta AUC-PR of each validation batch; averaged and cleared in
        # on_validation_epoch_end.
        self.validation_step_output = []

    def get_support_query(self, input_tensor, is_query_index):
        """Split the rows of ``input_tensor`` into ``(support, query)``.

        Rows where ``is_query_index == 0`` form the support part and rows where
        it equals ``1`` form the query part; relative order is kept in both.
        """
        support_indices = (is_query_index == 0).nonzero().squeeze(1)
        query_indices = (is_query_index == 1).nonzero().squeeze(1)

        return input_tensor[support_indices], input_tensor[query_indices]

    def select_batch(self, input_tensor, batch_index, batch_no):
        """Return the rows of ``input_tensor`` whose ``batch_index`` equals ``batch_no``."""
        current_batch_indices = (batch_index == batch_no).nonzero().squeeze(1)

        return input_tensor[current_batch_indices]

    def get_logits_per_batch(self, graph_reprs, labels, query_index, batch_index):
        """Compute query logits episode by episode and concatenate them.

        Episodes are visited in order ``0 .. batch_index.max()``. The callers
        pair the result with ``get_support_query(labels, query_index)[1]``,
        which is in global row order; the two line up as long as query rows
        are grouped by episode in ascending ``batch_index`` order — presumably
        guaranteed by the collation in MXMDataModule (TODO confirm).
        """
        all_logits = []
        batch_size = batch_index.max().item() + 1
        for i in range(batch_size):
            current_batch_graph_repr = self.select_batch(graph_reprs, batch_index, i)
            current_batch_labels = self.select_batch(labels, batch_index, i)

            current_batch_query_index = self.select_batch(query_index, batch_index, i)

            support_repr, query_repr = self.get_support_query(
                current_batch_graph_repr, current_batch_query_index
            )
            support_labels, query_labels = self.get_support_query(
                current_batch_labels, current_batch_query_index
            )

            logits = calculate_mahalanobis_logits(
                support_repr, support_labels, query_repr, device=self.device
            )

            all_logits.append(logits)
        return torch.cat(all_logits, dim=0)

    def training_step(self, batch):
        """Encode the batch, score its query molecules, return the tempered CE loss."""
        input_graphs, is_query, labels, batch_index = (
            batch["graphs"],
            batch["is_query"],
            batch["labels"],
            batch["batch_index"],
        )
        graph_representations = self.graph_encoder(input_graphs)
        _, query_labels = self.get_support_query(labels, is_query)

        logits = self.get_logits_per_batch(graph_representations, labels, is_query, batch_index)

        # Dividing by a temperature < 1 sharpens the softmax inside cross_entropy.
        loss = F.cross_entropy(logits / self.config.temprature, query_labels)

        self.log("loss", loss, on_step=True, on_epoch=False, batch_size=self.config.batch_size)
        self.log(
            "loss_per_epoch", loss, on_epoch=True, on_step=False, batch_size=self.config.batch_size
        )

        return loss

    def validation_step(self, batch, batch_idx):
        """Record the delta AUC-PR of this batch's query predictions."""
        input_graphs, is_query, labels, batch_index = (
            batch["graphs"],
            batch["is_query"],
            batch["labels"],
            batch["batch_index"],
        )

        graph_representations = self.graph_encoder(input_graphs)

        logits = self.get_logits_per_batch(graph_representations, labels, is_query, batch_index)
        _, query_labels = self.get_support_query(labels, is_query)

        auc_pr = self.calculate_delta_auc_pr(logits, query_labels)

        self.validation_step_output.append(auc_pr)

    def on_validation_epoch_end(self) -> None:
        """Log the mean delta AUC-PR over the epoch's validation batches, then reset."""
        # Skip logging when no validation batch ran: np.mean([]) would log NaN
        # and emit a RuntimeWarning.
        if self.validation_step_output:
            mean_delta_auc_pr = float(np.mean(self.validation_step_output))
            self.log("mean_delta_auc_pr", mean_delta_auc_pr)

        self.validation_step_output.clear()

    def calculate_delta_auc_pr(self, logits, targets):
        """AUC-PR of the class-1 probability minus the positive rate of ``targets``.

        The positive rate is the AUC-PR of a random classifier, so 0 means
        "no better than chance". If ``targets`` contains no positives, sklearn
        warns and the result is NaN.
        """
        predictions = F.softmax(logits, dim=1)[:, 1]

        targets_np = targets.detach().cpu().numpy()
        precision, recall, _ = precision_recall_curve(
            targets_np, predictions.detach().cpu().numpy()
        )

        auc_score = auc(recall, precision)

        random_classifier_auc_pr = np.mean(targets_np)

        return auc_score - random_classifier_auc_pr

    def configure_optimizers(self):
        """Adam (fused kernel) with the learning rate and weight decay from ``self.config``."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.config.learning_rate,
            # Read from this module's own config, not the notebook-global one,
            # so a checkpoint loaded with a different config trains consistently.
            weight_decay=self.config.weight_decay,
            fused=True,
        )


# Checkpoint to initialise weights from when passed to `train`; the W&B run id
# is encoded in its directory name.
sample_path = "/FS-MOL/lightning_logs/ls1b4s7r/checkpoints/epoch=19-step=54920.ckpt"

# Weights & Biases logging is turned off while a debugger is attached.
_debugger_attached = is_debugger_attached()
wandb_enabled = not _debugger_attached


def train(path):
    """Fit ``MXMNetLighteningModule`` on the module-level ``data_module``.

    Args:
        path: Lightning checkpoint to load weights from, or ``None`` to start
            from a freshly initialised model. The W&B run id is taken from the
            directory two levels above the checkpoint file
            (``.../<run_id>/checkpoints/<name>.ckpt``), so logging continues in
            that run.

    Reads the module-level ``config``, ``data_module`` and ``wandb_enabled``.
    Only the weights are restored from ``path``: ``trainer.fit`` is called
    without ``ckpt_path``, so the epoch counter and optimizer state start fresh.
    """
    run_id = None
    if path is not None:
        # <root>/<run_id>/checkpoints/<file>.ckpt -> <run_id>, independent of
        # how deep <root> is (the old split("/")[3] only fit one layout).
        run_id = os.path.basename(os.path.dirname(os.path.dirname(path)))

    if wandb_enabled:
        # resume="allow": continue the W&B run with this id if it exists,
        # otherwise create it. Without an id a new run is created.
        wandb.init(
            project="MXMNet_New",
            config=config,
            id=run_id,
            resume="allow" if run_id is not None else None,
        )
    model = (
        MXMNetLighteningModule.load_from_checkpoint(path, config=config)
        if path is not None
        else MXMNetLighteningModule(config)
    )
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        max_epochs=20,
        # Honour the gradient-accumulation setting declared in the config.
        accumulate_grad_batches=config.accumulate_grad_batches,
        log_every_n_steps=1,
        logger=WandbLogger() if wandb_enabled else None,
        default_root_dir="/FS-MOL/MXM_Checkpoint/",
    )
    if wandb_enabled:
        wandb.watch(model, log="all")
    trainer.fit(model, datamodule=data_module)


# Warm-start from `sample_path`; pass None instead to train from scratch.
# (`path` is not defined in this notebook, so the previous call only ran on
# leftover kernel state.)
train(sample_path)

VBox(children=(Label(value='0.004 MB of 0.010 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.440296…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668199933337745, max=1.0…

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Loading cached result from /FS-MOL/data/mxm/train/../cached/task_name_length_3aff30c7f2cc0298d1f5e520da7b4b1b8dbf22f66c4a76252e433c0bca2d8655.pt


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type   | Params
-----------------------------------------
0 | graph_encoder | MXMNet | 1.1 M 
-----------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.384     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]