In [1]:
import math
from massspecgym.data.datasets import MSnDataset
from massspecgym.featurize import SpectrumFeaturizer
from massspecgym.data import RetrievalDataset, MassSpecDataModule
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from torch.utils.data import Dataset


import torch
import torch.nn as nn
import torch.nn.functional as F
import typing as T
from typing import List, Optional
from torch_geometric.nn import GATConv, global_mean_pool

from massspecgym.models.base import Stage
from massspecgym.models.de_novo.base import DeNovoMassSpecGymModel

from phantoms.utils.custom_tokenizers import ByteBPETokenizerWithSpecialTokens
from phantoms.utils.constants import PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN

from torch.nn import TransformerDecoder, TransformerDecoderLayer

In [2]:
class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding as in "Attention is All You Need".
    Expects shape [seq_len, batch_size, d_model] if batch_first=False.
    """
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # shape becomes [max_len, d_model]

        pe = pe.unsqueeze(1)  # [max_len, 1, d_model]
        self.register_buffer("pe", pe)  # not a learnable parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [seq_len, batch_size, d_model]
        Returns the input plus positional encodings.
        """
        seq_len = x.size(0)
        # Add the encoding only up to seq_len
        # self.pe[:seq_len] is [seq_len, 1, d_model]
        return x + self.pe[:seq_len]

class SMILESDataset(Dataset):
    def __init__(self, smiles_list, tokenizer, max_len=200):
        self.smiles_list = smiles_list
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        text = self.smiles_list[idx]
        # Encode the text with special tokens (the post-processor adds SOS and EOS)
        token_ids = self.tokenizer.encode(text, add_special_tokens=True)
        # Truncate if necessary
        token_ids = token_ids[:self.max_len]
        # For teacher forcing, input is all tokens except the last,
        # target is all tokens except the first.
        input_ids = token_ids[:-1]
        target_ids = token_ids[1:]
        return {
            "input": torch.tensor(input_ids, dtype=torch.long),
            "target": torch.tensor(target_ids, dtype=torch.long)
        }

class SMILESLanguageModel(pl.LightningModule):
    def __init__(self, vocab_size: int, d_model: int = 256, nhead: int = 4,
                 num_decoder_layers: int = 4, dropout: float = 0.1,
                 pad_token_id: int = 0, max_len: int = 200):
        super().__init__()
        self.d_model = d_model
        decoder_layer = TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=4*d_model,
            dropout=dropout, activation="relu", batch_first=False
        )
        self.decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len)
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
    def forward(self, tgt_input):
        # tgt_input: [seq_len, batch]
        emb = self.embedding(tgt_input) * math.sqrt(self.d_model)
        emb = self.pos_encoder(emb)
        batch = tgt_input.size(1)
        # For standalone LM, we can use a zero "memory"
        memory = torch.zeros(1, batch, self.d_model, device=tgt_input.device)
        seq_len = tgt_input.size(0)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=tgt_input.device), diagonal=1)
        output = self.decoder(tgt=emb, memory=memory, tgt_mask=causal_mask)
        logits = self.fc_out(output)
        return logits
    def training_step(self, batch, batch_idx):
        logits = self.forward(batch["input"])  # [seq_len, batch, vocab_size]
        loss = self.criterion(logits.view(-1, logits.size(-1)), batch["target"].view(-1))
        self.log("loss", loss)
        return loss
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

# Full model: GATDeNovoTransformer (same as you provided)
class GATDeNovoTransformer(DeNovoMassSpecGymModel):
    """
    Example GAT -> (single embedding) -> Transformer Decoder for SMILES generation.
    """
    def __init__(
        self,
        input_dim: int,
        d_model: int = 256,
        nhead: int = 4,
        num_gat_layers: int = 3,
        num_decoder_layers: int = 4,
        num_gat_heads: int = 4,
        gat_dropout: float = 0.6,
        smiles_tokenizer: ByteBPETokenizerWithSpecialTokens = None,
        start_token: str = SOS_TOKEN,
        end_token: str = EOS_TOKEN,
        pad_token: str = PAD_TOKEN,
        unk_token: str = UNK_TOKEN,
        dropout: float = 0.1,
        max_smiles_len: int = 200,
        k_predictions: int = 1,
        temperature: T.Optional[float] = 1.0,
        pre_norm: bool = False,
        chemical_formula: bool = False,
        log_only_loss_at_stages: Optional[list] = [Stage.TRAIN],
        *args, **kwargs
    ):
        super().__init__(log_only_loss_at_stages=log_only_loss_at_stages, *args, **kwargs)
        if smiles_tokenizer is None:
            raise ValueError("Must provide a ByteBPETokenizerWithSpecialTokens instance.")
        self.smiles_tokenizer = smiles_tokenizer
        self.vocab_size = self.smiles_tokenizer.get_vocab_size()

        for tok in [start_token, end_token, pad_token, unk_token]:
            if tok not in self.smiles_tokenizer.get_vocab():
                raise ValueError(f"Special token '{tok}' not in tokenizer vocab")

        self.start_token_id = self.smiles_tokenizer.token_to_id(start_token)
        self.end_token_id   = self.smiles_tokenizer.token_to_id(end_token)
        self.pad_token_id   = self.smiles_tokenizer.token_to_id(pad_token)
        self.unk_token_id   = self.smiles_tokenizer.token_to_id(unk_token)

        self.d_model = d_model
        self.max_smiles_len = max_smiles_len
        self.k_predictions = k_predictions
        self.temperature = temperature if k_predictions > 1 else None
        self.chemical_formula = chemical_formula

        self.gat_layers = nn.ModuleList()
        self.gat_layers.append(
            GATConv(
                in_channels=input_dim,
                out_channels=d_model // num_gat_heads,
                heads=num_gat_heads,
                dropout=gat_dropout,
                add_self_loops=True
            )
        )
        for _ in range(num_gat_layers - 1):
            self.gat_layers.append(
                GATConv(
                    in_channels=d_model,
                    out_channels=d_model // num_gat_heads,
                    heads=num_gat_heads,
                    dropout=gat_dropout,
                    add_self_loops=True
                )
            )

        self.encoder_fc = nn.Linear(d_model, d_model)
        if self.chemical_formula:
            self.formula_mlp = nn.Sequential(
                nn.Linear(128, d_model),
                nn.ReLU(),
                nn.Linear(d_model, d_model),
            )

        decoder_layer = TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="relu",
            batch_first=False,
            norm_first=pre_norm
        )
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_smiles_len)
        self.decoder_embed = nn.Embedding(self.vocab_size, d_model, padding_idx=self.pad_token_id)
        self.decoder_fc = nn.Linear(d_model, self.vocab_size)
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_token_id)
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.encoder_fc.weight)
        nn.init.zeros_(self.encoder_fc.bias)
        nn.init.normal_(self.decoder_embed.weight, mean=0, std=self.d_model**-0.5)
        nn.init.xavier_uniform_(self.decoder_fc.weight)
        nn.init.zeros_(self.decoder_fc.bias)

    def step(self, batch: dict, stage: Stage = Stage.NONE) -> dict:
        output_dict = self.forward(batch)
        loss = output_dict["loss"]
        self.log(
            f"{stage.to_pref()}loss",
            loss,
            prog_bar=True,
            batch_size=batch["spec"].num_graphs,
        )
        if stage not in self.log_only_loss_at_stages:
            mols_pred = self.decode_smiles(batch)
            output_dict["mols_pred"] = mols_pred
        else:
            output_dict["mols_pred"] = None
        return output_dict

    def forward(self, batch: dict) -> dict:
        spec = batch["spec"]
        smiles_list = batch["mol"]
        x, edge_index, batch_idx = spec.x, spec.edge_index, spec.batch
        for gat in self.gat_layers:
            x = gat(x, edge_index)
            x = F.elu(x)
        x = global_mean_pool(x, batch_idx)
        if self.chemical_formula and ("formula" in batch):
            formula = batch["formula"].float().to(x.device)
            x = x + self.formula_mlp(formula)
        memory = self.encoder_fc(x)
        memory = memory.unsqueeze(0)  # [1, batch, d_model]

        encoded_smiles = self.smiles_tokenizer.encode_batch(smiles_list)
        smiles_ids = torch.tensor(encoded_smiles, dtype=torch.long, device=x.device)
        tgt_input  = smiles_ids[:, :-1]
        tgt_output = smiles_ids[:, 1:]
        tgt_input  = tgt_input.transpose(0, 1).contiguous()
        tgt_output = tgt_output.transpose(0, 1).contiguous()
        tgt_key_padding_mask = (tgt_input == self.pad_token_id).transpose(0, 1)
        tgt_embed = self.decoder_embed(tgt_input) * math.sqrt(self.d_model)
        tgt_embed = self.pos_encoder(tgt_embed)
        tgt_len = tgt_input.size(0)
        causal_mask = self._generate_square_subsequent_mask(tgt_len).to(tgt_embed.device)
        decoded = self.transformer_decoder(
            tgt=tgt_embed,
            memory=memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        logits = self.decoder_fc(decoded)
        logits = logits.transpose(0, 1).contiguous()
        tgt_output = tgt_output.transpose(0, 1).contiguous()
        loss = self.criterion(logits.view(-1, self.vocab_size), tgt_output.view(-1))
        return dict(loss=loss)

    def decode_smiles(self, batch: dict) -> List[List[str]]:
        spec = batch["spec"]
        x, edge_index, batch_idx = spec.x, spec.edge_index, spec.batch
        for gat in self.gat_layers:
            x = gat(x, edge_index)
            x = F.elu(x)
        x = global_mean_pool(x, batch_idx)
        if self.chemical_formula and ("formula" in batch):
            formula = batch["formula"].float().to(x.device)
            x = x + self.formula_mlp(formula)
        memory = self.encoder_fc(x).unsqueeze(0)
        batch_size = memory.size(1)
        device = memory.device
        all_decoded_smiles = [[] for _ in range(batch_size)]
        for _ in range(self.k_predictions):
            generated_tokens = torch.full((1, batch_size), self.start_token_id,
                                           dtype=torch.long, device=device)
            finished = [False]*batch_size
            decoded_sequences: List[List[int]] = [[] for _ in range(batch_size)]
            for step in range(self.max_smiles_len):
                tgt_embed = self.decoder_embed(generated_tokens) * math.sqrt(self.d_model)
                tgt_embed = self.pos_encoder(tgt_embed)
                tgt_len = tgt_embed.size(0)
                causal_mask = self._generate_square_subsequent_mask(tgt_len).to(device)
                output = self.transformer_decoder(
                    tgt=tgt_embed,
                    memory=memory,
                    tgt_mask=causal_mask
                )
                last_logits = self.decoder_fc(output[-1])
                if self.temperature is None:
                    next_token = torch.argmax(last_logits, dim=-1)
                else:
                    probs = F.softmax(last_logits / self.temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
                next_token = next_token.unsqueeze(0)
                generated_tokens = torch.cat([generated_tokens, next_token], dim=0)
                for i in range(batch_size):
                    if not finished[i]:
                        token_id = next_token[0, i].item()
                        if token_id == self.end_token_id:
                            finished[i] = True
                        else:
                            decoded_sequences[i].append(token_id)
                if all(finished):
                    break
            batch_smiles = []
            for seq_ids in decoded_sequences:
                text = self.smiles_tokenizer.decode(seq_ids, skip_special_tokens=True)
                batch_smiles.append(text)
            for i in range(batch_size):
                all_decoded_smiles[i].append(batch_smiles[i])
        return all_decoded_smiles

    def _generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:
        return torch.triu(torch.ones(sz, sz, dtype=torch.bool), diagonal=1)

In [None]:
spectra_mgf = "/Users/macbook/CODE/Majer:MassSpecGym/data/MSn/min_sample_trees.mgf"
split_file = "/Users/macbook/CODE/Majer:MassSpecGym/data/MSn/20241211_split.tsv"

In [None]:
split_file = "/Users/macbook/CODE/Majer:MassSpecGym/data/MSn/20241211_split.tsv"
config = {
    'features': ['binned_peaks'],
    'feature_attributes': {
        'binned_peaks': {
            'max_mz': 1000,
            'bin_width': 0.25,
            'to_rel_intensities': True,
        },
    },
}
featurizer = SpectrumFeaturizer(config, mode='torch')
batch_size = 12
input_dim = 4000

In [None]:
msn_dataset = MSnDataset(
    pth=spectra_mgf,
    featurizer=featurizer,
    mol_transform=None,
    max_allowed_deviation=0.005
)

In [None]:
data_module_msn = MassSpecDataModule(
    dataset=msn_dataset,
    batch_size=batch_size,
    split_pth=split_file,
    num_workers=0,
)

In [None]:
SMILES_TOKENIZER_SAVE_PATH = "/Users/macbook/CODE/Majer:MassSpecGym/data/tokenizers/smiles_tokenizer.json"
smiles_tokenizer = ByteBPETokenizerWithSpecialTokens(tokenizer_path=SMILES_TOKENIZER_SAVE_PATH)

In [None]:
smiles_data = msn_dataset.smiles

In [None]:
# Create the dataset and dataloader
pretrain_dataset = SMILESDataset(smiles_data, smiles_tokenizer, max_len=200)
pretrain_dataloader = DataLoader(
    pretrain_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=lambda batch: {
        "input": torch.nn.utils.rnn.pad_sequence(
            [item["input"] for item in batch], batch_first=False, padding_value=smiles_tokenizer.token_to_id(PAD_TOKEN)
        ),
        "target": torch.nn.utils.rnn.pad_sequence(
            [item["target"] for item in batch], batch_first=False, padding_value=smiles_tokenizer.token_to_id(PAD_TOKEN)
        )
    }
)

In [None]:
# Create the language model
vocab_size = smiles_tokenizer.get_vocab_size()
model_pretrain = SMILESLanguageModel(
    vocab_size=vocab_size,
    d_model=256,
    nhead=4,
    num_decoder_layers=4,
    dropout=0.1,
    pad_token_id=smiles_tokenizer.token_to_id(PAD_TOKEN),
    max_len=200
)

# Train using PyTorch Lightning Trainer
trainer = pl.Trainer(max_epochs=10, accelerator="cpu", devices=1)  # adjust as available
trainer.fit(model_pretrain, pretrain_dataloader)

# Save the pretrained weights
torch.save(model_pretrain.state_dict(), "smiles_decoder_pretrained.pth")
print("Pretraining complete and weights saved.")

In [None]:
model_full = GATDeNovoTransformer(
    input_dim=input_dim,  # example feature dimension from your spec.x shape
    d_model=256,
    nhead=4,
    num_gat_layers=3,
    num_decoder_layers=4,
    num_gat_heads=4,
    gat_dropout=0.6,
    smiles_tokenizer=smiles_tokenizer,
    dropout=0.1,
    max_smiles_len=200,
    k_predictions=1,
    temperature=1.0,
    pre_norm=False,
    chemical_formula=False
)

In [None]:
# Load pretrained decoder weights into the full model
pretrained_dict = torch.load("smiles_decoder_pretrained.pth", map_location="cpu")
model_dict = model_full.state_dict()

In [None]:
# Only update the decoder-related parts (keys matching "decoder", "pos_encoder", "decoder_embed", "decoder_fc")
pretrained_keys = {k: v for k, v in pretrained_dict.items() if k in model_dict and ("decoder" in k or "pos_encoder" in k or "embedding" in k or "fc_out" in k or "decoder_embed" in k or "decoder_fc" in k)}
model_dict.update(pretrained_keys)
model_full.load_state_dict(model_dict)
print("Loaded pretrained decoder weights into the full model.")

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

tb_logger = TensorBoardLogger("logs", name="gat_de_novo_transformer")
trainer_full = Trainer(
    accelerator="cpu",                # Change to "gpu" if available
    devices=1,
    max_epochs=2,                     # Adjust as needed
    log_every_n_steps=10,
    limit_train_batches=2,
    limit_val_batches=2,
    limit_test_batches=2,
    logger=tb_logger,
)

In [None]:
trainer_full.fit(model_full, datamodule=data_module_msn)

In [None]:
trainer_full.test(model_full, datamodule=data_module_msn)

In [3]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import typing as T
# from torch_geometric.nn import GATConv, global_mean_pool
#
# from massspecgym.models.base import Stage
# from massspecgym.models.de_novo.base import DeNovoMassSpecGymModel
#
# # Adjust import paths if needed:
# from phantoms.utils.custom_tokenizers import ByteBPETokenizerWithSpecialTokens
# from phantoms.utils.constants import PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN
#
# from torch.nn import TransformerDecoder, TransformerDecoderLayer

class GATDeNovoTransformer(DeNovoMassSpecGymModel):
    """
    Example GAT -> (single embedding) -> Transformer Decoder for SMILES generation.
    """

    def __init__(
        self,
        input_dim: int,              # node feature dimension
        d_model: int = 256,
        nhead: int = 4,
        num_gat_layers: int = 3,
        num_decoder_layers: int = 4,
        num_gat_heads: int = 4,
        gat_dropout: float = 0.6,
        smiles_tokenizer: ByteBPETokenizerWithSpecialTokens = None,
        start_token: str = SOS_TOKEN,
        end_token: str = EOS_TOKEN,
        pad_token: str = PAD_TOKEN,
        unk_token: str = UNK_TOKEN,
        dropout: float = 0.1,
        max_smiles_len: int = 200,
        k_predictions: int = 1,
        temperature: T.Optional[float] = 1.0,
        pre_norm: bool = False,
        chemical_formula: bool = False,
        *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        # ------------------------------
        #  1) SMILES Tokenizer
        # ------------------------------
        if smiles_tokenizer is None:
            raise ValueError("Must provide a ByteBPETokenizerWithSpecialTokens instance.")
        self.smiles_tokenizer = smiles_tokenizer
        self.vocab_size = self.smiles_tokenizer.get_vocab_size()

        # Make sure special tokens exist in vocab
        for tok in [start_token, end_token, pad_token, unk_token]:
            if tok not in self.smiles_tokenizer.get_vocab():
                raise ValueError(f"Special token '{tok}' not in tokenizer vocab")

        # IDs
        self.start_token_id = self.smiles_tokenizer.token_to_id(start_token)
        self.end_token_id   = self.smiles_tokenizer.token_to_id(end_token)
        self.pad_token_id   = self.smiles_tokenizer.token_to_id(pad_token)
        self.unk_token_id   = self.smiles_tokenizer.token_to_id(unk_token)

        # ------------------------------
        #  2) Hyperparams
        # ------------------------------
        self.d_model = d_model
        self.max_smiles_len = max_smiles_len
        self.k_predictions = k_predictions
        self.temperature = temperature if k_predictions > 1 else None
        self.chemical_formula = chemical_formula

        # ------------------------------
        #  3) GAT Encoder
        # ------------------------------
        # We'll build a stack of GATConv => ELU
        self.gat_layers = nn.ModuleList()

        # first layer
        self.gat_layers.append(
            GATConv(
                in_channels=input_dim,
                out_channels=d_model // num_gat_heads,
                heads=num_gat_heads,
                dropout=gat_dropout,
                add_self_loops=True
            )
        )
        # subsequent layers
        for _ in range(num_gat_layers - 1):
            self.gat_layers.append(
                GATConv(
                    in_channels=d_model,  # after heads, we combine into d_model
                    out_channels=d_model // num_gat_heads,
                    heads=num_gat_heads,
                    dropout=gat_dropout,
                    add_self_loops=True
                )
            )

        # ------------------------------
        #  4) Projection to d_model
        # ------------------------------
        self.encoder_fc = nn.Linear(d_model, d_model)

        # If you want formula => embed => add in:
        if self.chemical_formula:
            # Suppose formula is 128-dim or something
            self.formula_mlp = nn.Sequential(
                nn.Linear(128, d_model),
                nn.ReLU(),
                nn.Linear(d_model, d_model),
            )

        # ------------------------------
        #  5) Transformer Decoder
        # ------------------------------
        # We'll define a standard TransformerDecoder of `num_decoder_layers`
        decoder_layer = TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="relu",
            batch_first=False,     # By default PyTorch uses seq_first
            norm_first=pre_norm
        )
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_smiles_len)

        # Embeddings and final projection for SMILES tokens
        self.decoder_embed = nn.Embedding(self.vocab_size, d_model, padding_idx=self.pad_token_id)
        self.decoder_fc = nn.Linear(d_model, self.vocab_size)

        # We use a standard CE loss ignoring the pad index
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_token_id)

        # init weights
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.encoder_fc.weight)
        nn.init.zeros_(self.encoder_fc.bias)

        nn.init.normal_(self.decoder_embed.weight, mean=0, std=self.d_model**-0.5)
        nn.init.xavier_uniform_(self.decoder_fc.weight)
        nn.init.zeros_(self.decoder_fc.bias)

    # ------------------------------
    #  TRAINING/VAL/TEST STEPS
    # ------------------------------
    def step(self, batch: dict, stage: Stage = Stage.NONE) -> dict:
        """
        Mandatory method from base:
          - forward pass,
          - compute loss,
          - optionally decode molecules.
        """
        print("running")
        output_dict = self.forward(batch)
        print("running")
        loss = output_dict["loss"]

        # Log the loss
        self.log(
            f"{stage.to_pref()}loss",
            loss,
            prog_bar=True,
            batch_size=batch["spec"].num_graphs,  # or batch['spec'].size(0) if you do that
        )
        print("running")

        # For steps where we want metrics, set `mols_pred` so we can evaluate
        if stage not in self.log_only_loss_at_stages:
            mols_pred = self.decode_smiles(batch)  # or do beam if k_predictions>1
            output_dict["mols_pred"] = mols_pred
        else:
            output_dict["mols_pred"] = None

        return output_dict

    def forward(self, batch: dict) -> dict:
        """
        1) GAT-encode MSn trees => single [batch_size, d_model].
        2) Teacher-forcing the SMILES decoder => compute CE loss.
        """
        spec = batch["spec"]  # PyG DataBatch: .x, .edge_index, .batch
        smiles_list = batch["mol"]  # list of SMILES strings

        # ------------------------------
        #  1) GAT ENCODER
        # ------------------------------
        x, edge_index, batch_idx = spec.x, spec.edge_index, spec.batch
        # optional: debug shape
        # print("x shape: ", x.shape, "edge_index shape: ", edge_index.shape)

        for layer_idx, gat in enumerate(self.gat_layers):
            x = gat(x, edge_index)
            x = F.elu(x)

        # global mean pool => [batch_size, d_model]
        x = global_mean_pool(x, batch_idx)
        print("running")
        # optional formula
        if self.chemical_formula and ("formula" in batch):
            # Suppose batch["formula"] is shape [batch_size, 128]
            formula = batch["formula"].float().to(x.device)
            x = x + self.formula_mlp(formula)

        # project to exactly d_model (in case it changed shape)
        memory = self.encoder_fc(x)  # shape [batch_size, d_model]

        # We want shape [S=1, B, D] for the memory
        memory = memory.unsqueeze(0)  # => [1, batch_size, d_model]

        # ------------------------------
        #  2) TOKENIZE SMILES + TEACHER-FORCE
        # ------------------------------
        encoded_smiles = self.smiles_tokenizer.encode_batch(smiles_list)  # list of lists
        # Convert to a padded tensor
        # If your ByteBPETokenizer is already applying padding to max_len, you get length = self.max_smiles_len
        # shaped [batch_size, seq_len]
        smiles_ids = torch.tensor(encoded_smiles, dtype=torch.long, device=x.device)

        # We'll do teacher forcing:
        #  input to the decoder: [  <s>  token1 token2 ... token_{n-1} ]
        #  output to match:      [ token1 token2 ... token_{n-1}  </s> ]
        # So let's shift everything by 1
        tgt_input  = smiles_ids[:, :-1]  # [batch, seq_len-1]
        tgt_output = smiles_ids[:, 1:]   # [batch, seq_len-1]
        print("running")
        # Transpose to [seq_len-1, batch]
        tgt_input  = tgt_input.transpose(0, 1).contiguous()   # => [tgt_len, batch]
        tgt_output = tgt_output.transpose(0, 1).contiguous()  # => [tgt_len, batch]

        # 2.1) Make a key_padding_mask for shape [batch, tgt_len]
        #  True means "ignore this position"
        tgt_key_padding_mask = (tgt_input == self.pad_token_id).transpose(0, 1)  # => [batch, tgt_len]

        # 2.2) Embeddings
        tgt_embed = self.decoder_embed(tgt_input) * math.sqrt(self.d_model)
        # shape => [tgt_len, batch, d_model]

        # 2.3) Positional encoding
        tgt_embed = self.pos_encoder(tgt_embed)
        print("running")
        # 2.4) Subsequent (causal) mask for the decoder
        tgt_len = tgt_input.size(0)
        causal_mask = self._generate_square_subsequent_mask(tgt_len).to(tgt_embed.device)

        # 2.5) Pass through TransformerDecoder
        # memory: [1, batch, d_model]
        # tgt_embed: [tgt_len, batch, d_model]
        decoded = self.transformer_decoder(
            tgt=tgt_embed,
            memory=memory,
            tgt_mask=causal_mask,                 # shape [tgt_len, tgt_len]
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        # => [tgt_len, batch, d_model]

        # 2.6) Final linear
        logits = self.decoder_fc(decoded)  # => [tgt_len, batch, vocab_size]

        # Flatten for CE loss
        logits = logits.transpose(0, 1).contiguous()  # => [batch, tgt_len, vocab_size]
        tgt_output = tgt_output.transpose(0, 1).contiguous()  # => [batch, tgt_len]

        loss = self.criterion(
            logits.view(-1, self.vocab_size),
            tgt_output.view(-1)
        )

        return dict(loss=loss)

    # ------------------------------
    #  3) GREEDY/TEMPERATURE DECODING
    # ------------------------------
    def decode_smiles(self, batch: dict) -> list[list[str]]:
        """
        Generate up to k_predictions SMILES for each example in the batch, 
        returning a nested list of shape [batch_size, k_predictions].
        We'll do a simple greedy or top-1 sampling approach here.
        """
        spec = batch["spec"]
        x, edge_index, batch_idx = spec.x, spec.edge_index, spec.batch

        # Encode GAT
        for gat in self.gat_layers:
            x = gat(x, edge_index)
            x = F.elu(x)
        x = global_mean_pool(x, batch_idx)

        if self.chemical_formula and ("formula" in batch):
            formula = batch["formula"].float().to(x.device)
            x = x + self.formula_mlp(formula)

        memory = self.encoder_fc(x).unsqueeze(0)  # => [1, batch, d_model]

        batch_size = memory.size(1)
        device = memory.device

        # We'll store k decoded sequences per example
        all_decoded_smiles = [[] for _ in range(batch_size)]

        # For simplicity, do k= self.k_predictions times
        for _ in range(self.k_predictions):
            # We'll do standard AR decoding up to max_smiles_len
            # shape => [1, batch]
            generated_tokens = torch.full(
                (1, batch_size),
                self.start_token_id,
                dtype=torch.long,
                device=device
            )
            finished = [False]*batch_size

            decoded_sequences: T.List[T.List[int]] = [[] for _ in range(batch_size)]

            for step in range(self.max_smiles_len):
                tgt_embed = self.decoder_embed(generated_tokens) * math.sqrt(self.d_model)
                tgt_embed = self.pos_encoder(tgt_embed)
                tgt_len = tgt_embed.size(0)

                causal_mask = self._generate_square_subsequent_mask(tgt_len).to(device)

                # decode
                output = self.transformer_decoder(
                    tgt=tgt_embed,
                    memory=memory,
                    tgt_mask=causal_mask
                )
                # shape => [tgt_len, batch, d_model]

                # final projection for last token
                last_logits = self.decoder_fc(output[-1])  # => [batch, vocab_size]

                if self.temperature is None:
                    # Greedy
                    next_token = torch.argmax(last_logits, dim=-1)  # => [batch]
                else:
                    # Temperature-based sampling
                    probs = F.softmax(last_logits / self.temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)

                # append
                next_token = next_token.unsqueeze(0)  # => [1, batch]
                generated_tokens = torch.cat([generated_tokens, next_token], dim=0)

                # check if ended
                for i in range(batch_size):
                    if not finished[i]:
                        token_id = next_token[0, i].item()
                        if token_id == self.end_token_id:
                            finished[i] = True
                        else:
                            decoded_sequences[i].append(token_id)
                if all(finished):
                    break

            # Now decode into text
            # For each item in batch, convert token IDs -> string
            batch_smiles = []
            for seq_ids in decoded_sequences:
                # We decode each token ID into text. 
                # NOTE: self.smiles_tokenizer.decode expects a *list of IDs for the entire sentence*.
                # If your tokenizer merges multiple IDs into subwords, you might get partial fragments.
                # But let's do it in one shot:
                text = self.smiles_tokenizer.decode(seq_ids, skip_special_tokens=True)
                batch_smiles.append(text)

            # Store
            for i in range(batch_size):
                all_decoded_smiles[i].append(batch_smiles[i])

        return all_decoded_smiles

    # ------------------------------
    #  4) Causal Mask Utility
    # ------------------------------
    def _generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:
        """
        2D causal mask for the target sequence of length sz.
        shape => [sz, sz]
          True => blocked,  False => allowed
        PyTorch's Transformer can accept bool masks with True = no-attend.
        """
        mask = torch.triu(torch.ones(sz, sz, dtype=torch.bool), diagonal=1)
        return mask

In [4]:
SMILES_TOKENIZER_SAVE_PATH = "/Users/macbook/CODE/Majer:MassSpecGym/data/tokenizers/smiles_tokenizer.json"
smiles_tokenizer = ByteBPETokenizerWithSpecialTokens(tokenizer_path=SMILES_TOKENIZER_SAVE_PATH)

Loaded tokenizer from /Users/macbook/CODE/Majer:MassSpecGym/data/tokenizers/smiles_tokenizer.json.


In [44]:
spectra_mgf = "/Users/macbook/CODE/Majer:MassSpecGym/data/MSn/min_sample_trees.mgf"
split_file = "/Users/macbook/CODE/Majer:MassSpecGym/data/MSn/20241211_split.tsv"

In [45]:
config = {
    'features': ['binned_peaks'],
    'feature_attributes': {
        'binned_peaks': {
            'max_mz': 1000,
            'bin_width': 0.25,
            'to_rel_intensities': True,
        },
    },
}
featurizer = SpectrumFeaturizer(config, mode='torch')
batch_size = 12

In [46]:
msn_dataset = MSnDataset(
    pth=spectra_mgf,
    featurizer=featurizer,
    mol_transform=None,
    max_allowed_deviation=0.005
)

In [47]:
data_module_msn = MassSpecDataModule(
    dataset=msn_dataset,
    batch_size=batch_size,
    split_pth=split_file,
    num_workers=0,
)

In [48]:
input_dim = 4000  # Update this based on your actual data

# Initialize the Model
model = GATDeNovoTransformer(
    input_dim=input_dim,
    d_model=512,
    nhead=8,
    num_encoder_layers=3,
    num_decoder_layers=6,
    num_gat_heads=8,
    gat_dropout=0.6,
    smiles_tokenizer=smiles_tokenizer,
    start_token=SOS_TOKEN,
    end_token=EOS_TOKEN,
    pad_token=PAD_TOKEN,
    unk_token=UNK_TOKEN,
    dropout=0.1,
    max_smiles_len=200,
    k_predictions=1,
    temperature=1.0,
    pre_norm=False,
    chemical_formula=False  # Set to True if incorporating chemical formula embeddings
)

In [51]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# Initialize TensorBoard Logger
tb_logger = TensorBoardLogger("logs", name="gat_de_novo_transformer")

# Initialize the Trainer
trainer = Trainer(
    accelerator="cpu",                # Change to "gpu" if available
    devices=1,
    max_epochs=20,                     # Adjust as needed
    log_every_n_steps=10,
    limit_train_batches=2,
    limit_val_batches=2,
    limit_test_batches=2,
    logger=tb_logger,
)

# Train the Model
trainer.fit(model, datamodule=data_module_msn)

# Test the Model
trainer.test(model, datamodule=data_module_msn)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

   | Name                        | Type               | Params | Mode 
----------------------------------------------------------------------------
0  | gat_layers                  | ModuleList         | 2.6 M  | train
1  | encoder_fc                  | Linear             | 262 K  | train
2  | transformer_decoder         | TransformerDecoder | 25.2 M | train
3  | pos_encoder                 | PositionalEncoding | 0      | train
4  | decoder_embed               | Embedding          | 317 K  | train
5  | decoder_fc                  | Linear             | 318 K  | train
6  | criterion                   | CrossEntropyLoss   | 0      | train
7  | val_num_valid_mols          | Mea

Train dataset size: 82
Val dataset size: 6


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

running
running
running
running


/Users/macbook/UTILS/anaconda3/envs/phantoms_env/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


running
running



Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [None]:
# def is_valid_smiles(smiles: str) -> bool:
#     try:
#         mol = Chem.MolFromSmiles(smiles)
#         return mol is not None
#     except:
#         return False
# 
# def calculate_uniqueness(smiles_list: T.List[str]) -> float:
#     unique_smiles = set(smiles_list)
#     return len(unique_smiles) / len(smiles_list)
# 
# def calculate_novelty(smiles_list: T.List[str], training_set: T.Set[str]) -> float:
#     novel_smiles = [smi for smi in smiles_list if smi not in training_set]
#     return len(novel_smiles) / len(smiles_list)
# 
# def calculate_diversity(smiles_list: T.List[str]) -> float:
#     similarities = []
#     for i in range(len(smiles_list)):
#         mol1 = Chem.MolFromSmiles(smiles_list[i])
#         if mol1 is None:
#             continue
#         fp1 = Chem.RDKFingerprint(mol1)
#         for j in range(i + 1, len(smiles_list)):
#             mol2 = Chem.MolFromSmiles(smiles_list[j])
#             if mol2 is None:
#                 continue
#             fp2 = Chem.RDKFingerprint(mol2)
#             similarity = TanimotoSimilarity(fp1, fp2)
#             similarities.append(similarity)
#     if similarities:
#         average_similarity = sum(similarities) / len(similarities)
#         diversity = 1 - average_similarity  # Higher diversity when lower similarity
#     else:
#         diversity = 0.0
#     return diversity
# 
# def validation_step(self, batch, batch_idx):
#     outputs = self.forward(batch)
#     loss = outputs['loss']
#     self.log('val_loss', loss, on_epoch=True, prog_bar=True)
# 
#     if 'mols_pred' in outputs and outputs['mols_pred'] is not None:
#         mols_pred = outputs['mols_pred']
#         mols_true = batch['mol']
# 
#         # Validity
#         valid = [is_valid_smiles(smiles) for smiles in mols_pred]
#         validity_score = sum(valid) / len(valid) if valid else 0
#         self.log('val_validity', validity_score, on_epoch=True, prog_bar=True)
# 
#         # Uniqueness
#         uniqueness_score = calculate_uniqueness(mols_pred)
#         self.log('val_uniqueness', uniqueness_score, on_epoch=True, prog_bar=True)
# 
#         # Novelty (requires access to training set)
#         # Assuming 'self.training_set' is defined elsewhere in the model
#         if hasattr(self, 'training_set') and isinstance(self.training_set, set):
#             novelty_score = calculate_novelty(mols_pred, self.training_set)
#             self.log('val_novelty', novelty_score, on_epoch=True, prog_bar=True)
# 
#         # Diversity
#         diversity_score = calculate_diversity(mols_pred)
#         self.log('val_diversity', diversity_score, on_epoch=True, prog_bar=True)
# 
#     return {'val_loss': loss}