In [None]:
# MoLFormer-XL Enhanced Version – Compatible with Two-Column Input Files

# Install required packages
!pip install transformers torch numpy pandas tqdm -q

import torch
from transformers import AutoModel, AutoTokenizer
import numpy as np
import pandas as pd
from google.colab import files
from tqdm import tqdm
import gc

def load_molformer_model(model_variant="both-10pct"):
    """
    Fetch a MoLFormer-XL checkpoint and prepare it for inference.

    Args:
        model_variant: Suffix of the "ibm/MoLFormer-XL-<variant>" checkpoint
            name to load:
            - 'both-10pct' (recommended)
            - 'both-1pct' (faster, lower accuracy)

    Returns:
        A (model, tokenizer, device) tuple. The model has already been moved
        to `device` and put into eval mode.
    """
    checkpoint = "ibm/MoLFormer-XL-{}".format(model_variant)
    print("Loading model: {}...".format(checkpoint))

    # Both loaders receive the same hub option.
    hub_options = {"trust_remote_code": True}
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, **hub_options)
    model = AutoModel.from_pretrained(checkpoint, deterministic_eval=True, **hub_options)

    # Prefer CUDA whenever torch reports it as available.
    gpu_present = torch.cuda.is_available()
    device = torch.device('cuda' if gpu_present else 'cpu')
    model = model.to(device)
    model.eval()

    if not gpu_present:
        print("Using CPU (processing may be slow)")
    else:
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")

    return model, tokenizer, device

def parse_smiles_file(content):
    """
    Parse the text of a SMILES file into a list of entry dicts.

    The layout is decided per line from its number of tab-separated fields:
        1. Two columns: index<TAB>SMILES (recommended)
        2. One column: SMILES only
        3. Three or more columns: id<TAB>db_id<TAB>SMILES (legacy format);
           fields after the third are ignored.

    A first line that has at least two fields and whose first field is not
    an integer is treated as a header and skipped. Blank lines are ignored,
    and lines whose SMILES field is empty are skipped with a warning.

    Args:
        content: Decoded text of the input file. A leading UTF-8 byte-order
            mark and any of the \\n, \\r\\n or \\r line endings are accepted.

    Returns:
        A list of dicts with the string keys 'id', 'db_id' and 'smiles', in
        file order. For one-column input the id is the 1-based position of
        the line (blank lines included) after the optional header.
    """
    # A BOM survives .decode('utf-8') and is not whitespace, so it has to be
    # removed explicitly; otherwise int('\ufeff1') fails below and the first
    # data row is mistaken for a header (or the BOM ends up inside a SMILES).
    text = content.lstrip('\ufeff').strip()
    if not text:
        return []

    # Normalise line endings so files with bare '\r' separators also split.
    lines = text.replace('\r\n', '\n').replace('\r', '\n').split('\n')

    # Detect and skip header if present
    start_idx = 0
    first_fields = lines[0].split('\t')
    if len(first_fields) >= 2:
        try:
            int(first_fields[0])  # a numeric index means this is a data row
        except ValueError:
            print(f"Skipping header: {lines[0][:100]}...")
            start_idx = 1

    data = []
    for i, line in enumerate(lines[start_idx:], 1):
        if not line.strip():
            continue

        parts = [field.strip() for field in line.split('\t')]

        if len(parts) >= 3:  # legacy three-column format
            entry_id, db_id, smiles = parts[0], parts[1], parts[2]
        elif len(parts) == 2:  # two-column format
            entry_id, smiles = parts
            db_id = f'MOL_{entry_id}'
        else:  # one-column SMILES-only format
            entry_id, db_id, smiles = str(i), f'MOL_{i}', parts[0]

        if not smiles:
            # e.g. "12<TAB>" - an index with nothing to embed.
            print(f"Warning: Skipping malformed line {i+start_idx}: {line[:50]}...")
            continue

        data.append({'id': entry_id, 'db_id': db_id, 'smiles': smiles})

    return data

def _embed_batch(batch, model, tokenizer, device, max_length=202):
    """Tokenize `batch` and return one embedding row per SMILES as a CPU NumPy array."""
    inputs = tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Prefer the model's pooled output; otherwise fall back to the hidden
    # state at the first token position (presumably the CLS token).
    embeddings = getattr(outputs, 'pooler_output', None)
    if embeddings is None:
        embeddings = outputs.last_hidden_state[:, 0, :]

    return embeddings.cpu().numpy()


def smiles_to_embeddings(smiles_list, model, tokenizer, device, batch_size=16):
    """
    Convert SMILES strings into embeddings, batch by batch.

    A batch that raises is retried one SMILES at a time. Entries that still
    fail are reported, recorded in the returned index list and represented
    by an all-zero row so that row k always corresponds to smiles_list[k].

    Args:
        smiles_list: List of SMILES strings.
        model: Loaded model; called as model(**inputs).
        tokenizer: Tokenizer matching `model`.
        device: Device the tokenized tensors are moved to.
        batch_size: Requested batch size. It is capped at 16 when CUDA is
            available and at 4 otherwise, and raised to at least 1.

    Returns:
        (vectors, failed_indices): a NumPy array with one row per input
        SMILES, and the list of input indices that could not be embedded.
    """
    vectors = []  # one item per SMILES; None marks a failure until backfilled
    failed_indices = []
    use_cuda = torch.cuda.is_available()

    # Adjust batch size based on available hardware
    batch_size = min(batch_size, 16 if use_cuda else 4)
    batch_size = max(1, batch_size)  # range() rejects a zero or negative step

    for i in tqdm(range(0, len(smiles_list), batch_size), desc="Processing"):
        batch = smiles_list[i:i+batch_size]

        try:
            vectors.extend(_embed_batch(batch, model, tokenizer, device))
        except Exception as batch_err:
            print(f"\nBatch {i//batch_size + 1} failed ({type(batch_err).__name__}: {batch_err}). "
                  "Processing entries individually...")

            for j, smiles in enumerate(batch):
                try:
                    vectors.append(_embed_batch([smiles], model, tokenizer, device)[0])
                except Exception as item_err:
                    # str() so that reporting a non-string entry cannot itself raise.
                    print(f"  Skipping invalid SMILES (index {i+j}): {str(smiles)[:30]}... "
                          f"({type(item_err).__name__}: {item_err})")
                    vectors.append(None)
                    failed_indices.append(i+j)

        # NOTE: `i` is the start offset of the batch, not a batch counter, so
        # this only triggers when that offset is a multiple of 100.
        if use_cuda and i % 100 == 0:
            torch.cuda.empty_cache()
            gc.collect()

    if failed_indices:
        # Build the zero placeholder from a real embedding so it matches the
        # model's width and dtype instead of assuming 768 float64 values.
        template = next((vec for vec in vectors if vec is not None), None)
        if template is not None:
            placeholder = np.zeros_like(template)
        else:
            # Nothing succeeded: use the configured width when the model
            # exposes one, otherwise fall back to the previous default of 768.
            dim = getattr(getattr(model, 'config', None), 'hidden_size', 768)
            placeholder = np.zeros(dim, dtype=np.float32)
        for idx in failed_indices:
            vectors[idx] = placeholder

    return np.array(vectors), failed_indices

def main():
    """
    Run the interactive Colab workflow from upload to download.

    Steps: ask for the model variant, load the model, upload a SMILES file,
    parse it, generate embeddings, print summary statistics, then write and
    download three files (a .npy matrix, a compressed .npz bundle with
    metadata and a CSV mapping).

    Returns:
        The embedding matrix, or None when no file was uploaded, no SMILES
        could be parsed, or the user cancelled.
    """
    print("="*60)
    print(" MoLFormer-XL SMILES Embedding Tool ")
    print("="*60)

    print("\nSupported file formats:")
    print("1. Two-column (tab-separated): index<TAB>SMILES")
    print("2. One-column: SMILES only")
    print("3. Three-column: id<TAB>db_id<TAB>SMILES")
    print("\nHeaders are automatically detected and skipped.")

    print("\nSelect model variant:")
    print("1. MoLFormer-XL-both-10pct (recommended)")
    print("2. MoLFormer-XL-both-1pct (faster)")

    choice = input("Select (1 or 2, default 1): ").strip() or '1'
    if choice not in ('1', '2'):
        # A typo must not silently select the lower-accuracy model.
        print(f"Unrecognized choice '{choice}'; using default (1).")
        choice = '1'
    variant = "both-10pct" if choice == '1' else "both-1pct"

    # Load model
    model, tokenizer, device = load_molformer_model(variant)

    print("\nUpload SMILES file...")
    print("Examples:")
    print("1\tCCCCCC")
    print("2\tCOc1ccccc1")
    print("3\tCCN(CC)CC")

    uploaded = files.upload()

    if not uploaded:
        print("No file uploaded.")
        return

    # Only the first uploaded file is processed.
    filename = next(iter(uploaded))
    # 'utf-8-sig' decodes plain UTF-8 and also drops a leading byte-order mark.
    content = uploaded[filename].decode('utf-8-sig')

    print("\nParsing file...")
    data = parse_smiles_file(content)

    if not data:
        print("Error: No valid SMILES found. Check file format.")
        return

    ids = [item['id'] for item in data]
    smiles_list = [item['smiles'] for item in data]

    print(f"\nParsed {len(smiles_list)} SMILES entries.")
    print("\nFirst 5 entries:")
    for i, item in enumerate(data[:5], 1):
        preview = item['smiles'][:50] + "..." if len(item['smiles']) > 50 else item['smiles']
        print(f"{i}. ID={item['id']}, SMILES={preview}")

    if len(smiles_list) > 100:
        cont = input(f"\nProcessing {len(smiles_list)} entries may take time. Continue? (y/n, default y): ").strip().lower()
        if cont == 'n':
            print("Operation cancelled.")
            return

    print("\nStarting embedding generation...")
    vectors, failed_indices = smiles_to_embeddings(smiles_list, model, tokenizer, device)

    if failed_indices:
        print(f"\nWarning: {len(failed_indices)} entries failed.")
        print("Sample failed IDs:", [ids[idx] for idx in failed_indices[:10]],
              "..." if len(failed_indices) > 10 else "")

    print("\nEmbedding generation completed.")
    print(f"Array shape: {vectors.shape}")
    print(f"Embedding dimension: {vectors.shape[1]}")
    print(f"Memory usage: {vectors.nbytes / (1024*1024):.2f} MB")

    norms = np.linalg.norm(vectors, axis=1)
    print("\nVector statistics:")
    print(f"- Mean norm: {np.mean(norms):.4f}")
    print(f"- Std norm: {np.std(norms):.4f}")
    print(f"- Min norm: {np.min(norms):.4f}")
    print(f"- Max norm: {np.max(norms):.4f}")

    # Failed entries are stored as all-zero rows, so a near-zero norm flags them.
    zero_vectors = np.sum(norms < 0.01)
    if zero_vectors > 0:
        print(f"- Zero vectors: {zero_vectors} (likely from failed SMILES)")

    print("\nSaving output files...")

    vector_file = 'smiles_vectors_molformer.npy'
    np.save(vector_file, vectors)

    full_data = {
        'vectors': vectors,
        'ids': ids,
        'db_ids': [item['db_id'] for item in data],
        'smiles': smiles_list,
        'failed_indices': failed_indices,
        'model': f'MoLFormer-XL-{variant}',
        'dimension': vectors.shape[1]
    }

    full_file = 'molformer_xl_data.npz'
    np.savez_compressed(full_file, **full_data)

    # A set keeps the per-row status lookup O(1) instead of scanning the list.
    failed_lookup = set(failed_indices)

    csv_file = 'molformer_xl_mapping.csv'
    df = pd.DataFrame({
        'ID': ids,
        'SMILES': smiles_list,
        'Vector_Index': list(range(len(data))),
        'Status': ['Failed' if i in failed_lookup else 'OK' for i in range(len(data))]
    })
    df.to_csv(csv_file, index=False, encoding='utf-8')

    print("\nDownloading files...")
    files.download(vector_file)
    files.download(full_file)
    files.download(csv_file)

    print("\nFiles generated:")
    print(f"1. {vector_file}: Embedding matrix")
    print(f"2. {full_file}: Full dataset with metadata")
    print(f"3. {csv_file}: Mapping file")

    print("\n" + "="*60)
    print(" Done. Files downloaded. ")
    print("="*60)

    return vectors

# Run the interactive workflow when this file is executed directly.
# `vectors` keeps the returned embedding matrix for later use; it is None
# when main() stops early (no upload, nothing parsed, or user cancelled).
if __name__ == "__main__":
    vectors = main()
