In [None]:
# STEP 1: Install all required packages
# `biomed-multi-alignment[examples]` presumably provides the `mammal` package
# (Mammal model + DtiBindingdbKdTask) and the `fuse` ModularTokenizerOp that
# the inference cell imports — confirm against the package docs.
# NOTE(review): `tdc` is not imported in any cell shown here. The Therapeutics
# Data Commons distribution on PyPI is named `PyTDC`; confirm this install is
# needed at all and that it targets the intended package.
!pip install -U pip
!pip install biomed-multi-alignment[examples] tdc

In [None]:
# Upload input file(s) from the local machine into the Colab runtime.
# NOTE(review): files.upload() presumably saves into the current working
# directory (/content by default on Colab). The CONFIG section reads
# /content/step2_filtered_lipinski.csv, so the uploaded file must have exactly
# that name — confirm.
from google.colab import files
uploaded = files.upload()

In [None]:
# === MOUNT DRIVE ===
# Mount Google Drive at /content/drive. output_path and final_output_path in
# the CONFIG section write under /content/drive/MyDrive, so this must run first.
from google.colab import drive
drive.mount('/content/drive')

# === IMPORTS ===
import os, gc
import pandas as pd
import torch
from tqdm import tqdm
from mammal.model import Mammal
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask

# === DEVICE SETUP ===
# Probe CUDA once, report the runtime, and expose a single `device` that the
# model and every preprocessed tensor are placed on.
use_cuda = torch.cuda.is_available()
print("✅ PyTorch version:", torch.__version__)
print("CUDA available:", use_cuda)
if not use_cuda:
    print("⚠️ GPU not available. Using CPU.")
else:
    print("💡 Using GPU:", torch.cuda.get_device_name(0))

device = torch.device("cuda") if use_cuda else torch.device("cpu")

# === CONFIG ===
# Input molecules: must have "RandomID" and "SMILES" columns (both are read by
# the resume and inference steps). Presumably the file uploaded earlier.
input_path = "/content/step2_filtered_lipinski.csv"
# Incremental results, appended one batch at a time and re-read on resume.
# Kept on the mounted Drive, presumably so progress outlives the Colab VM.
output_path = "/content/drive/MyDrive/inference_results_partial.csv"
final_output_path = "/content/drive/MyDrive/TeamYOURTEAMNAME.csv"  # Change for final submission
# Protein target (amino-acid sequence) paired with every SMILES at inference time.
target_sequence = """PEQPFIVLGQEEYGEHHSSIMHCRVDCSGRRVASLDVDGVIKVWSFNPIMQTKASSISKSPLLSLEWATKRDRLLLLGSGVGTVRLYDTEAKKNLCEININDNMPRILSLACSPNGASFVCSAAAPSLTSQVPGRLLLWDTKTMKQQLQFSLDPEPIAINCTAFNHNGNLLVTGAADGVIRLFDMQQHECAMSWRAHYGEVYSVEFSYDENTVYSIGEDGKFIQWNIHKSGLKVSEYSLPSDATGPFVLSGYSGYKQVQVPRGRLFAFDSEGNYMLTCSATGGVIYKLGGDEKVLESCLSLGGHRAPVVTVDWSTAMDCGTCLTASMDGKIKLTTLLAHKA"""
# Passed to DtiBindingdbKdTask.process_model_output to de-normalize predictions;
# presumably the label mean/std this checkpoint was trained with — confirm
# before changing.
norm_y_mean = 5.79384684128215
norm_y_std = 1.33808027428196
BATCH_SIZE = 3  # Safe for T4 or lower

# === LOAD MODEL AND TOKENIZER ===
# Model and tokenizer come from the same pretrained checkpoint id, so it is
# spelled once. The model is switched to inference mode, then moved to `device`.
MODEL_ID = "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd"
model = Mammal.from_pretrained(MODEL_ID)
model.eval()
model.to(device)
tokenizer_op = ModularTokenizerOp.from_pretrained(MODEL_ID)

# === INPUT FILE LOADING AND RESUME SUPPORT ===
# Results are appended to `output_path` one batch at a time. After a runtime
# disconnect we reload them and only score molecules that are not done yet.
df = pd.read_csv(input_path)

# A zero-byte results file (e.g. the runtime died while the first batch was
# being written) has no header: pd.read_csv would raise EmptyDataError and
# block every resume. It holds no predictions, so delete it and start fresh;
# the inference loop then recreates it with a header.
if os.path.exists(output_path) and os.path.getsize(output_path) == 0:
    print(f"⚠️ Removing empty results file: {output_path}")
    os.remove(output_path)

if os.path.exists(output_path):
    existing_df = pd.read_csv(output_path)
    done_ids = set(existing_df["RandomID"])
    print(f"✅ Resuming from {len(done_ids)} predictions.")
else:
    existing_df = pd.DataFrame(columns=["RandomID", "Sel_50", "Score"])
    done_ids = set()

# Drop already-scored IDs, then chunk the remainder positionally into batches.
df = df[~df["RandomID"].isin(done_ids)].reset_index(drop=True)
print(f"🧪 Remaining to process: {len(df)} molecules")
all_batches = [df.iloc[i:i + BATCH_SIZE] for i in range(0, len(df), BATCH_SIZE)]

# === INFERENCE LOOP ===
def _preprocess_batch(batch_df):
    """Tokenize every (target, SMILES) pair in *batch_df*.

    Returns ``(samples, ids)``: the preprocessed sample dicts, with tensors on
    ``device``, and the matching ``RandomID`` values. A row whose SMILES is
    missing or fails preprocessing is reported and skipped, so one bad
    molecule cannot abort the batch or the whole run.
    """
    samples, ids = [], []
    for _, row in batch_df.iterrows():
        smiles = row["SMILES"]
        # pandas reads a blank CSV cell as float NaN, not str; skip it up front.
        if not isinstance(smiles, str) or not smiles.strip():
            print(f"⚠️ Skipping SMILES ({row['RandomID']}): missing or empty SMILES")
            continue
        try:
            sample_dict = DtiBindingdbKdTask.data_preprocessing(
                sample_dict={"target_seq": target_sequence, "drug_seq": smiles},
                tokenizer_op=tokenizer_op,
                target_sequence_key="target_seq",
                drug_sequence_key="drug_seq",
                norm_y_mean=None,
                norm_y_std=None,
                device=device,
            )
        except Exception as e:  # per-molecule best-effort skip, always logged
            # NOTE(review): the exception type raised for a malformed SMILES is
            # not visible here. Catching only RuntimeError let other types kill
            # the entire run. KeyboardInterrupt is not an Exception, so manual
            # interruption still works.
            print(f"⚠️ Skipping SMILES ({row['RandomID']}): {e}")
            continue
        # Ensure tensors are on the model's device (kept as a safety net even
        # though data_preprocessing already receives `device`).
        samples.append(
            {k: v.to(device) if torch.is_tensor(v) else v for k, v in sample_dict.items()}
        )
        ids.append(row["RandomID"])
    return samples, ids


def _predict_scores(samples):
    """Run the encoder on *samples* and return one de-normalized float per sample."""
    with torch.no_grad():
        batch_out = model.forward_encoder_only(samples)
    batch_out = DtiBindingdbKdTask.process_model_output(
        batch_out,
        scalars_preds_processed_key="model.out.dti_bindingdb_kd",
        norm_y_mean=norm_y_mean,
        norm_y_std=norm_y_std,
    )
    return [s.cpu().item() for s in batch_out["model.out.dti_bindingdb_kd"]]


def _append_results(ids, scores):
    """Append one batch of predictions to ``output_path`` and sync to disk."""
    # Write the header when the file is new *or* empty. An empty file would
    # otherwise end up with data rows but no header line.
    write_header = not os.path.exists(output_path) or os.path.getsize(output_path) == 0
    pd.DataFrame({"RandomID": ids, "Sel_50": 0, "Score": scores}).to_csv(
        output_path, mode="a", header=write_header, index=False
    )
    os.sync()  # Force disk sync


def _run_batch(batch_df):
    """Preprocess, score and persist one batch. All tensors are function-local,
    so they become unreachable as soon as this returns or raises."""
    samples, ids = _preprocess_batch(batch_df)
    if not samples:
        return
    _append_results(ids, _predict_scores(samples))


for batch_df in tqdm(all_batches):
    try:
        _run_batch(batch_df)
    except RuntimeError as e:
        # e.g. CUDA out-of-memory. The batch's IDs are not written, so a later
        # resume retries them.
        print(f"🔥 Batch error: {e}")
    # Cleanup runs after the except clause has dropped `e` and its traceback,
    # which would otherwise keep the failed batch's tensors alive. Collect
    # garbage first so that empty_cache() can actually return the memory.
    gc.collect()
    torch.cuda.empty_cache()

# === FINAL STEP: Select Top 50 ===
# Rank every scored compound by predicted score (highest first) and flag the
# best TOP_K rows with Sel_50 = 1. All other rows get Sel_50 = 0.
TOP_K = 50
print(f"🎯 Selecting top {TOP_K} compounds...")
final_df = pd.read_csv(output_path)
# Keep one row per RandomID so the flagged set is TOP_K *distinct* compounds.
# Duplicates reach the results file when the input CSV repeats an ID.
final_df = final_df.drop_duplicates(subset="RandomID", keep="first")
# Stable sort: ties keep file order, so the selection is reproducible.
final_df = final_df.sort_values(by="Score", ascending=False, kind="mergesort").reset_index(drop=True)
final_df["Sel_50"] = (final_df.index < TOP_K).astype(int)
final_df.to_csv(final_output_path, index=False)
print(f"✅ Submission file saved to: {final_output_path}")