In [1]:
import os
import sys

FS_MOL_CHECKOUT_PATH = os.path.abspath('../')

os.chdir(FS_MOL_CHECKOUT_PATH)
sys.path.insert(0, FS_MOL_CHECKOUT_PATH)

In [2]:
from torch.nn import functional as F
import torch
from fs_mol.clip_like import FingerprintEncoder
from fs_mol.data.clip_dataset import CLIPDataset
from fs_mol.modules.gat import GAT_GraphEncoder, TrainConfig
from fs_mol.data.clip_fewshot_dataset import FSMOL
from fs_mol.models.protonet import calculate_mahalanobis_logits

In [3]:
train_dataset = CLIPDataset(root='/', raw_file_path='/FS-MOL/train_raw_mols.pt', dest_file_name='/FS-MOL/train_preprocessed_mols.pt')
# valid_dataset = CLIPDataset(root='/', raw_file_path='/FS-MOL/valid_raw_mols.pt', dest_file_name='/FS-MOL/valid_preprocessed_mols.pt')   
valid_dataset = torch.load('/FS-MOL/valid_none_dup_processed.pt')
# fewshot_dataset = FSMOL()

In [4]:
print(f'Train Dataset Size: {len(train_dataset)}')
print(f'Valid Dataset Size: {len(valid_dataset)}')

Train Dataset Size: 216827
Valid Dataset Size: 9574


In [5]:
config = TrainConfig()

config.graph_encoder_num_layers

5

In [6]:
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, drop_last=False)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, drop_last=False)

batch = next(iter(train_loader))

In [7]:
# deg_hist = PNAConv.get_degree_histogram(loader=loader)


In [8]:
# class PNA_GraphEncoder(nn.Module):
#     def __init__(self, deg_hist) -> None:
#         super().__init__()
        
#         self.batch_size = 32
        
#         self.gnn = PNA(
#             in_channels=32,
#             hidden_channels=128,
#             num_layers=10,
#             out_channels=128,
#             edge_dim=1,
#             aggregators=['sum', 'mean', 'max', 'std'],
#             scalers=['amplification', 'attenuation'],
#             deg=deg_hist
#         )
        
#         self.readout = CombinedGraphReadout(
#             node_dim=128,
#             out_dim=512,
#             num_heads=12,
#             head_dim=64,
#         )
        
#     def forward(self, batch):
#         edge_index, edge_attr = to_undirected(batch.edge_index, batch.edge_attr, 32)
#         node_features = self.gnn(batch.x, edge_index, edge_attr=edge_attr)
#         return self.readout(node_features, batch.batch, self.batch_size)
    
# model = PNA_GraphEncoder(deg_hist=deg_hist).cuda()

# model(batch).shape

In [9]:
model = GAT_GraphEncoder(config).to('cuda')

model(batch)

tensor([[-0.2696,  0.4219,  0.9399,  ...,  0.2700, -0.9707, -0.3564],
        [-0.3516,  0.2873,  0.7667,  ...,  0.1339, -0.9777, -0.4234],
        [-0.3090,  0.3072,  0.8088,  ...,  0.1545, -0.7782, -0.4980],
        ...,
        [-0.3513,  0.3004,  0.7584,  ...,  0.3638, -0.6603, -0.4933],
        [-0.3678,  0.4369,  0.8178,  ...,  0.3691, -0.8012, -0.5012],
        [-0.3207,  0.3194,  0.8370,  ...,  0.2484, -0.8435, -0.3207]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [10]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from wandb import wandb
from torchmetrics import Accuracy
from torch_geometric.data import Batch



valid_step = 0

class ClipLike(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.temp = 0.1
        # self.graph_encoder = PyG_GraphFeatureExtractor(GraphFeatureExtractorConfig())
        self.graph_encoder = GAT_GraphEncoder(config)
        self.fingerprint_encoder = FingerprintEncoder(2048, config.fingerprint_encoder_hidden_dim, config.fingerprint_encoder_output_dim, config.fingerprint_encoder_dropout)
        self.acc_metric = Accuracy(task='binary')
    
    def training_step(self, batch, batch_idx):
        encoded_graphs = self.graph_encoder(batch)
        encoded_fingerprints = self.fingerprint_encoder(batch.fingerprint.reshape(-1, 2048).to(torch.float32))
        # TODO: Check the OpenAI's Codebase on CLIP and make sure this is right.
        logits = encoded_graphs @ encoded_fingerprints.T / self.temp
        
        current_batch_size = encoded_graphs.shape[0]
        
        targets = torch.eye(current_batch_size).cuda() / self.temp
        
        loss = F.cross_entropy(logits, targets, reduction="none")
        
        loss = loss.mean()
        self.log('train_loss', loss, batch_size=current_batch_size)
        return loss
    
    def validation_step(self, batch, batch_idx):
        encoded_graphs = self.graph_encoder(batch)
        encoded_fingerprints = self.fingerprint_encoder(batch.fingerprint.reshape(-1, 2048).to(torch.float32))
        # TODO: Check the OpenAI's Codebase on CLIP and make sure this is right.
        logits = encoded_graphs @ encoded_fingerprints.T / self.temp
        
        current_batch_size = encoded_graphs.shape[0]
        
        targets = torch.eye(current_batch_size).cuda() / self.temp
        
        loss = F.cross_entropy(logits, targets, reduction="none")
        
        loss = loss.mean()
        self.log('valid_loss', loss, batch_size=current_batch_size)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
    
    def validation_epoch_end(self, outputs) -> None:
        fewshot_dataset = FSMOL()
        mean_acc = 0
        for support_set, query_set in fewshot_dataset:
            support_labels = [mol.bool_label for mol in support_set]
            query_labels = [mol.bool_label for mol in query_set]
            
            encoded_support_graphs = self.graph_encoder(Batch.from_data_list(support_set).to('cuda'))
            encoded_query_graphs = self.graph_encoder(Batch.from_data_list(query_set).to('cuda'))
            
            logits = calculate_mahalanobis_logits(encoded_support_graphs, torch.tensor(support_labels, device=torch.device('cuda')), encoded_query_graphs, torch.device('cuda'))
            porbabilities = torch.softmax(logits, dim=1)
            
            predictions = torch.argmax(porbabilities, dim=1)
            
            acc = self.acc_metric(predictions, torch.tensor(query_labels, device=torch.device('cuda')))
            mean_acc += acc
            
        wandb.log({'fewshot_acc': mean_acc / len(fewshot_dataset)})
            
            
    
wandb.init(config=config)
model = ClipLike()
wandb.watch(model, log='all')
    
trainer = pl.Trainer(logger=WandbLogger(), accelerator='gpu', devices=1, max_epochs=100)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=valid_loader)



Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mathecoder[0m ([33mdest[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666892743334453, max=1.0)…

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type               | Params
-----------------------------------------------------------
0 | graph_encoder       | GAT_GraphEncoder   | 1.0 M 
1 | fingerprint_encoder | FingerprintEncoder | 2.6 M 
2 | acc_metric          | BinaryAccuracy     | 0     
-----------------------------------------------------------
3.6 M     Trainable params
0         Non-trainable params
3.6 M     Total params
14.513    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]