In [1]:
import os
import sys

# Resolve the checkout root as the parent of the working directory this cell first ran in.
# The starting directory is remembered in `_NOTEBOOK_START_DIR`, so re-running the cell in
# the same kernel no longer climbs one more directory level on every execution.
_NOTEBOOK_START_DIR = globals().get("_NOTEBOOK_START_DIR", os.getcwd())
FS_MOL_CHECKOUT_PATH = os.path.abspath(os.path.join(_NOTEBOOK_START_DIR, ".."))

os.chdir(FS_MOL_CHECKOUT_PATH)
# Keep the checkout first on sys.path without stacking a duplicate entry on re-runs.
if sys.path[:1] != [FS_MOL_CHECKOUT_PATH]:
    sys.path.insert(0, FS_MOL_CHECKOUT_PATH)

In [2]:
import atexit
from dataclasses import dataclass, field

import torch
import wandb
from pytorch_lightning.loggers import WandbLogger
from torch.nn import functional as F
from torch_geometric.loader import DataLoader

from fs_mol.clip_like import FingerprintEncoder
from fs_mol.data import DataFold
from fs_mol.data.clip_dataset import CLIPDataset
from fs_mol.data.clip_fewshot_dataset import FSMOL
from fs_mol.data.torch_dl import FSMOLHTorchDataset, FSMOLTorchDataloader
from fs_mol.models.protonet import calculate_mahalanobis_logits
from fs_mol.modules.gat import GAT_GraphEncoder, TrainConfig
from fs_mol.utils.torch_utils import torchify

# Ask PyTorch to release its cached CUDA memory when the interpreter (kernel) exits.
atexit.register(torch.cuda.empty_cache)
# Reminder of the positional arguments handed to MXMNet's Config when the model is built:
# config.dim, config.layer, config.cutoff, config.encoder_dims, 512

@dataclass(frozen=True)
class TrainConfig:
    """Immutable hyper-parameter bundle for the MXMNet few-shot experiment.

    Every setting, including ``encoder_dims``, is a real dataclass field, so it can be
    overridden through the constructor and shows up in ``dataclasses.asdict`` / ``vars``.

    NOTE: this class rebinds the name ``TrainConfig`` that the import block of this cell
    also brings in from ``fs_mol.modules.gat``; from here on the name refers to this class.
    """

    # Training settings
    batch_size: int = 64
    train_support_count: int = 32
    train_query_count: int = 256
    train_shuffle: bool = True

    # Divisor applied to the logits before the cross-entropy loss.
    # The spelling ("temprature") is kept because other code reads `config.temprature`.
    temprature: float = 0.07

    # Validation settings
    valid_support_count: int = 64
    valid_batch_size: int = 256

    # Model settings (forwarded to MXMNet and its Config)
    envelope_exponent: int = 6
    num_spherical: int = 7
    num_radial: int = 5
    dim: int = 256
    cutoff: float = 5.0  # the default is a float; the annotation used to say int
    layer: int = 7

    # Optimisation settings
    accumulate_grad_batches: int = 4
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5

    dropout: float = 0.2

    # A fresh list per instance instead of one list shared through the class.
    # hash=False keeps `hash(config)` working, as it did while this was not a field.
    encoder_dims: list = field(
        default_factory=lambda: [128, 128, 256, 256, 512, 512],
        hash=False,
    )


# Module-level configuration instance built from the defaults; later cells read it as `config`.
config = TrainConfig()

In [3]:
def load_checkpoint_weights(path, map_location=None):
    """Return the ``"state_dict"`` entry stored in a checkpoint file.

    Args:
        path: Location of the checkpoint, passed straight to ``torch.load``.
        map_location: Optional device remapping forwarded to ``torch.load`` (for example
            ``"cpu"`` to open a checkpoint saved from a GPU on a machine without one).
            ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        The object stored under the ``"state_dict"`` key of the loaded checkpoint.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` key.
    """
    checkpoint = torch.load(path, map_location=map_location)

    return checkpoint["state_dict"]

In [4]:
# Build one dataset per fold, in train -> valid -> test order.
# "pyg" is the data-format selector (presumably PyTorch Geometric graphs; confirm in
# FSMOLHTorchDataset).
train_dataset, valid_dataset, test_dataset = (
    FSMOLHTorchDataset(fold_name, "pyg") for fold_name in ("train", "valid", "test")
)

KeyboardInterrupt: 

In [None]:
# Training loader: every tunable value is taken from `config`.
train_loader_options = {
    "batch_size": config.batch_size,
    "datatype": "pyg",
    "num_workers": 0,
    "shuffle": config.train_shuffle,
    "support_count": config.train_support_count,
    "query_count": config.train_query_count,
}
train_dl = FSMOLTorchDataloader(train_dataset, **train_loader_options)


# Sanity check: draw the first item from the loader and print each element it holds.
batches = next(iter(train_dl))

for element in batches:
    print(element)

TypeError: 'NoneType' object is not iterable

In [None]:
# One validation loader per support-set size; only `support_count` differs between them.
# The loaders are created in the order 16, 64, 32, 128.
valid_dl_16, valid_dl_64, valid_dl_32, valid_dl_128 = (
    FSMOLTorchDataloader(
        valid_dataset,
        batch_size=config.valid_batch_size,
        datatype="pyg",
        num_workers=0,
        support_count=shots,
    )
    for shots in (16, 64, 32, 128)
)

In [None]:
# One test loader per support-set size; only `support_count` differs between them.
# They reuse the validation batch size from `config`.
test_dl_16, test_dl_32, test_dl_64, test_dl_128 = (
    FSMOLTorchDataloader(
        test_dataset,
        batch_size=config.valid_batch_size,
        datatype="pyg",
        num_workers=0,
        support_count=shots,
    )
    for shots in (16, 32, 64, 128)
)

In [None]:
# Sanity check on the 32-support validation loader: materialise the first item it yields
# and print the first element of its first entry.
batches = next(iter(valid_dl_32))

first_entry = list(batches)[0]
print(first_entry[0])

DataBatch(x=[10952], edge_index=[2, 23322], pos=[10952, 3], bool_label=[256], batch=[10952], ptr=[257])


In [None]:
from typing import Any, Callable, Optional, Union
from pytorch_lightning.core.optimizer import LightningOptimizer
from torch.optim.optimizer import Optimizer
from torch_geometric.nn.models.autoencoder import VGAE
from MXMNet.model import MXMNet, Config
from fs_mol.models.protonet import PyG_GraphFeatureExtractor, GraphFeatureExtractor
from fs_mol.modules.graph_feature_extractor import GraphFeatureExtractorConfig
from torch import nn
import pytorch_lightning as pl
from torch_geometric.utils import to_dense_batch
from torch_geometric.data import Batch
from pytorch_lightning.loggers import WandbLogger
from fs_mol.transformer_based_pretraining import ScaledDotProductAttention
from fs_mol.utils.metrics import compute_binary_task_metrics
import numpy as np
from torch.nn.utils import clip_grad_norm_
from itertools import chain


class ClipLike(pl.LightningModule):
    """Few-shot classifier built on an MXMNet graph encoder, trained with manual optimisation.

    Each task handled by this module is a ``(batch, labels, index_map)`` triple. The encoded
    graphs are L2-normalised and split into support rows (``index_map == 0``) and query rows
    (``index_map == 1``); query rows are scored against the support rows with
    ``calculate_mahalanobis_logits`` and trained with a temperature-scaled cross-entropy.

    All hyper-parameters are read from ``self.config`` rather than from a module-level global.
    """

    def __init__(self, config: TrainConfig):
        """Store ``config`` and build the MXMNet encoder from its model settings."""
        # Initialise the LightningModule machinery before assigning any attribute.
        super().__init__()
        self.config = config
        # Optimiser stepping and gradient accumulation are done by hand in training_step.
        self.automatic_optimization = False
        self.graph_encoder = MXMNet(
            Config(config.dim, config.layer, config.cutoff, config.encoder_dims, 512),
            num_spherical=config.num_spherical,
            num_radial=config.num_radial,
            envelope_exponent=config.envelope_exponent,
            dropout=config.dropout,
        )

    def calculate_feats(self, batch):
        """Run ``batch`` through the graph encoder and return its output unchanged."""
        return self.graph_encoder(batch)

    def calc_loss(self, input):
        """Compute the few-shot loss for a single task.

        Args:
            input: A ``(batch, labels, index_map)`` triple. Rows where ``index_map == 0``
                form the support set and rows where ``index_map == 1`` form the query set.

        Returns:
            ``(loss, logits, query_labels)``: the cross-entropy of the temperature-scaled
            logits against the query labels, the unscaled logits, and the query labels.
        """
        batch, labels, index_map = input
        feats = F.normalize(self.calculate_feats(batch), dim=-1)

        is_support = index_map == 0
        is_query = index_map == 1

        support_feats, support_labels = feats[is_support], labels[is_support]
        query_feats, query_labels = feats[is_query], labels[is_query]

        # Use the device the features live on instead of assuming "cuda".
        logits = calculate_mahalanobis_logits(
            support_feats, support_labels, query_feats, feats.device
        )
        loss = F.cross_entropy(logits / self.config.temprature, query_labels)

        return loss, logits, query_labels

    def on_train_end(self):
        """Release cached CUDA memory once training has finished."""
        torch.cuda.empty_cache()

    def training_step(self, inputs, batch_idx):
        """Back-propagate every task in ``inputs`` and step the optimiser when due.

        Args:
            inputs: Iterable of tasks, each accepted by :meth:`calc_loss`.
            batch_idx: Index of this step; the optimiser steps every
                ``config.accumulate_grad_batches`` steps (every step when that is <= 1).

        Logs ``train_loss`` as the mean loss over the tasks of this step.
        """
        opt = self.optimizers()
        loss_sum = 0
        num_tasks = 0
        for task in inputs:
            loss, _, _ = self.calc_loss(task)
            self.manual_backward(loss)
            # Detach so the running total used for logging is not tied to autograd.
            loss_sum += loss.detach()
            num_tasks += 1

        # Divide by the real number of tasks (the counter used to start at 1);
        # max(..., 1) keeps the empty case at 0 instead of dividing by zero.
        self.log(
            "train_loss",
            loss_sum / max(num_tasks, 1),
            on_step=True,
            batch_size=self.config.batch_size,
        )

        accumulate = self.config.accumulate_grad_batches
        if accumulate <= 1 or (batch_idx + 1) % accumulate == 0:
            opt.step()
            opt.zero_grad()

    def _shared_eval_step(self, batches, loader_idx, stage):
        """Evaluate every task in ``batches`` and log its loss and binary metrics under ``stage``."""
        for task in batches:
            loss, logits, query_labels = self.calc_loss(task)

            self.log(
                f"{stage}_loss",
                loss,
                on_step=False,
                on_epoch=True,
                batch_size=self.config.valid_batch_size,
            )

            # Column 1 of the softmax over dim=1 is used as the positive-class score.
            batch_preds = F.softmax(logits, dim=1).detach().cpu().numpy()

            metrics = compute_binary_task_metrics(
                predictions=batch_preds[:, 1], labels=query_labels.detach().cpu().numpy()
            )

            # The prefix 2 ** (loader_idx + 4) maps loader indices 0, 1, 2, 3 to 16, 32, 64, 128.
            for k, v in metrics.__dict__.items():
                self.log(
                    f"{2 ** (loader_idx + 4)}_{stage}_{k}",
                    v,
                    on_epoch=True,
                    on_step=False,
                    batch_size=self.config.valid_batch_size,
                )

    def validation_step(self, batches, batch_idx, loader_idx, dataloader_idx=0):
        """Log ``valid_loss`` and ``<2**(loader_idx+4)>_valid_<metric>`` for each task.

        NOTE(review): ``loader_idx`` is the third positional argument; presumably Lightning
        passes the dataloader index there when several loaders are used. Confirm.
        """
        self._shared_eval_step(batches, loader_idx, "valid")

    def test_step(self, batches, batch_idx, loader_idx, dataloader_idx=0):
        """Log ``test_loss`` and ``<2**(loader_idx+4)>_test_<metric>`` for each task.

        The loss is logged as ``test_loss``; it was previously logged as ``valid_loss``.
        """
        self._shared_eval_step(batches, loader_idx, "test")

    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: Optimizer,
        optimizer_closure: Callable[[], Any] = None,
    ) -> None:
        """Delegate unchanged to the parent implementation."""
        return super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)

    def configure_optimizers(self):
        """Create a fused Adam optimiser over all parameters using the configured settings."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
            fused=True,
        )


# wandb.init(project='MXM-Test', config=config)
# trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=10, log_every_n_steps=1, logger=WandbLogger(), default_root_dir='/FS-MOL/MXM_Checkpoint/')
# model = ClipLike(config).load_from_checkpoint('/FS-MOL/lightning_logs/m38aeocm/checkpoints/epoch=19-step=6440.ckpt', config=config)

# trainer.test(model, dataloaders=[test_dl_16, test_dl_32, test_dl_64, test_dl_128])
# Checkpoint used by the calls below: test_mxm evaluates it, and train_mxm resumes from it
# whenever `path` is not None.
path = "/FS-MOL/lightning_logs/8yr5woq6/checkpoints/epoch=9-step=3220-v1.ckpt"

# path = None


def train_mxm(config):
    """Train ``ClipLike`` on ``train_dl`` and validate on the four validation loaders.

    When the module-level ``path`` is not None, the model is restored from that checkpoint
    and the matching W&B run is resumed; otherwise a new run and a new model are created.

    Args:
        config: Configuration passed to ``wandb.init`` and to the model.

    NOTE(review): ``trainer.fit`` is called without a checkpoint path, so only what
    ``ClipLike.load_from_checkpoint`` restores is carried over. Confirm this is intended.
    """
    if path is not None:
        # Checkpoints live at <root>/<run_id>/checkpoints/<file>.ckpt, so the run id is the
        # name of the directory two levels above the file. Unlike a fixed split index, this
        # does not depend on how deep <root> is or on the path being absolute.
        run_id = os.path.basename(os.path.dirname(os.path.dirname(path)))
        wandb.init(project="MXM", config=config, id=run_id, resume=True)
        model = ClipLike.load_from_checkpoint(path, config=config)
    else:
        wandb.init(project="MXM", config=config)
        model = ClipLike(config)
    wandb.watch(model, log="all")
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        max_epochs=10,
        log_every_n_steps=1,
        logger=WandbLogger(),
        default_root_dir="/FS-MOL/MXM_Checkpoint/",
    )

    trainer.fit(
        model=model,
        train_dataloaders=train_dl,
        val_dataloaders=[valid_dl_16, valid_dl_32, valid_dl_64, valid_dl_128],
    )


def test_mxm(path):
    """Evaluate the checkpoint at ``path`` on the four test loaders, logging to W&B.

    Args:
        path: Checkpoint file handed to ``ClipLike.load_from_checkpoint``.
    """
    wandb.init(project="MXM-Test", config=config)

    trainer_options = dict(
        accelerator="gpu",
        devices=1,
        max_epochs=20,
        log_every_n_steps=1,
        logger=WandbLogger(),
        default_root_dir="/FS-MOL/MXM_Checkpoint/",
    )
    evaluator = pl.Trainer(**trainer_options)
    restored_model = ClipLike.load_from_checkpoint(path, config=config)

    # Loader order matters: the step hooks derive the metric prefix from the loader index.
    test_loaders = [test_dl_16, test_dl_32, test_dl_64, test_dl_128]
    evaluator.test(restored_model, dataloaders=test_loaders)


# Evaluate the existing checkpoint on the test loaders first, then start (or resume) training.
test_mxm(path)
train_mxm(config)

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
128_valid_acc/dataloader_idx_3,▁▃▄▅▆▂▂█▅▇
128_valid_avg_precision/dataloader_idx_3,▁▃▃▄▄▃▇▇█▇
128_valid_balanced_acc/dataloader_idx_3,▁▃▄▅▆▂▃█▅▇
128_valid_delta_auc_pr/dataloader_idx_3,▁▃▃▄▄▃▇▇█▇
128_valid_f1/dataloader_idx_3,▁▅▄▇▆▂▁█▄▇
128_valid_kappa/dataloader_idx_3,▁▃▄▅▆▂▂█▅▇
128_valid_optimistic_auc_pr/dataloader_idx_3,▁▃▄▄▄▃▇▇█▆
128_valid_optimistic_delta_auc_pr/dataloader_idx_3,▁▃▄▄▄▃▇▇█▆
128_valid_prec/dataloader_idx_3,▁▂▃▃▇▂▃█▅▆
128_valid_recall/dataloader_idx_3,▂▆▅█▅▃▁▆▃▆

0,1
128_valid_acc/dataloader_idx_3,0.72999
128_valid_avg_precision/dataloader_idx_3,0.77862
128_valid_balanced_acc/dataloader_idx_3,0.7261
128_valid_delta_auc_pr/dataloader_idx_3,0.31169
128_valid_f1/dataloader_idx_3,0.70542
128_valid_kappa/dataloader_idx_3,0.45032
128_valid_optimistic_auc_pr/dataloader_idx_3,0.76971
128_valid_optimistic_delta_auc_pr/dataloader_idx_3,0.30277
128_valid_prec/dataloader_idx_3,0.70049
128_valid_recall/dataloader_idx_3,0.71881


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0166682136166249, max=1.0))…

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
          Test metric                     DataLoader 0                    DataLoader 1
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
          128_test_acc
     128_test_avg_precision
     128_test_balanced_acc
     128_test_delta_auc_pr
          128_test_f1
         128_test_kappa
   128_test_optimistic_auc_pr
128_test_optimistic_delta_auc_pr
         128_test_prec
        128_test_recall
        128_test_roc_auc
         128_test_size
          16_test_acc                  0.6561748459663684
     16_test_avg_precision             0.6735776124394305
      16_test_balanced_acc             0.6466572828182953
      16_test_delta_auc_pr            0.21322405746593268
           16_test_f1                  0.6073572579598342
         16_test_kappa                0.29076901390448034
   16_test_opt