# In [None]:
#---------------------------------------------------
# Using LLMs (Gemma, Mistral, Llama) to generate descriptions for the tables
#----------------------------------------------------
import os
import json
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from more_itertools import chunked

import torch._dynamo
torch._dynamo.disable()

# === Config setting ===
# Checkpoint identifier passed to `from_pretrained` for both tokenizer and model.
model_id = "google/gemma-3-27b-it"
# Prefer the HF_TOKEN environment variable so a real access token never has to
# be written into this file; fall back to the in-file placeholder when unset.
hf_token = os.environ.get("HF_TOKEN", "**")
# JSON file holding the records to describe.
input_path = "./Arxiv-papers2000/sample_nonnull_cleaned_10000.json"
# Destination for the same records with the generated field added.
output_path = "./Arxiv-papers2000/sample_nonnull_cleaned_10000.json_with_gemma-3-27b-it_sent_Gemma-p9_WOC.json"
# Key under which each generated description is stored on its record.
field_name = "Gemma_sentence"
# Number of records sent through `generate` at once.
batch_size = 2


# Authenticate with the Hugging Face Hub before any checkpoint is requested.
login(token=hf_token)

# === Remove LaTeX caption
import re
def mask_caption(text):
    r"""Remove the first ``\caption{...}`` command from a LaTeX string.

    Braces nested inside the caption are tracked so the command is removed up
    to its matching closing brace. Only the first occurrence is handled.

    Args:
        text: LaTeX source to clean.

    Returns:
        ``text`` without the first caption command, or ``text`` unchanged when
        there is no ``\caption{`` or its braces never balance.
    """
    marker = r'\caption{'
    begin = text.find(marker)
    if begin < 0:
        return text

    # One brace is already open: the one that ends the marker itself.
    open_braces = 1
    for pos in range(begin + len(marker), len(text)):
        ch = text[pos]
        if ch == '{':
            open_braces += 1
        elif ch == '}':
            open_braces -= 1
            if not open_braces:
                # Matching close found: splice the whole command out.
                return text[:begin] + text[pos + 1:]

    # Ran off the end with braces still open; leave the input untouched.
    return text

# === Load tokenizer and model  --- on multi-GPU (#2)
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# Prompts are batched with padding=True and then continued by a decoder-only
# model, so padding must go on the left: right padding would place pad tokens
# between the end of a short prompt and the first generated token.
tokenizer.padding_side = "left"
# NOTE(review): float32 weights are requested here while the Mistral cell in
# this file uses float16 -- confirm the extra memory is intended.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # splits model for multiple GPUs automatically
    torch_dtype=torch.float32
)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

# === Load data
# The whole input file is parsed into memory; later steps iterate over `data`
# and write the generated field back onto its entries in place.
with open(input_path, encoding="utf-8") as source_file:
    data = json.load(source_file)
print(f"Total records: {len(data)}")

# === Batch processing
# Ceiling division so the progress-bar total also counts a final partial batch.
num_batches = (len(data) + batch_size - 1) // batch_size
for batch in tqdm(chunked(data, batch_size), total=num_batches, desc="Generating sentences"):
    prompts = []
    valid_entries = []
    for entry in batch:
        # `or ""` also covers keys that are present but hold null in the JSON.
        table = mask_caption((entry.get("table_content") or "").strip())
        caption = (entry.get("table_caption") or "").strip()

        # The caption is not inserted into the prompt; it only gates which
        # records receive a generated description.
        if not table or not caption:
            entry[field_name] = "[NO TABLE OR CAPTION]"
            continue

        # NOTE(review): the prompt uses literal "<s>" and [INST] tags; confirm
        # this matches the chat format expected by `model_id`.
        prompt = (
            f"<s>[INST] Given the following table content, write a short description of the table.  "
            f"Respond in plain English. Add the [END] token at the end.\n\n"
            f"Table:\n{table}\n\n"
            f"[/INST] Answer:"
        )
        prompts.append(prompt)
        valid_entries.append(entry)

    if not prompts:
        continue

    # NOTE(review): truncating to 1024 tokens can cut off the trailing
    # "[/INST] Answer:" for long tables (assuming truncation happens on the
    # right -- confirm), so the prompt text cannot be relied on as a marker.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(model.device)
    # Every row is padded to this length, so it marks where new tokens begin.
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.01,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )

    # The returned sequences start with the prompt tokens. Decode only the
    # continuation instead of recovering it by splitting on "Answer:", which
    # returns the whole prompt when that marker was truncated away.
    decoded = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)

    for entry, out in zip(valid_entries, decoded):
        # Keep the text before the requested [END] marker; when the model did
        # not emit it, split() leaves the full continuation in place.
        entry[field_name] = out.split("[END]")[0].strip()

# === Save
# Write all records, including skipped ones carrying the placeholder value,
# as indented JSON with non-ASCII characters kept as-is.
with open(output_path, mode="w", encoding="utf-8") as sink:
    json.dump(data, sink, ensure_ascii=False, indent=2)

print(f"Saved generated sentences to: {output_path}")

# In [None]:
import os
import json
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from more_itertools import chunked
import re

# === Config ===
# Checkpoint identifier passed to `from_pretrained` for both tokenizer and model.
model_id = "mistralai/Mistral-Small-24B-Instruct-2501"
# Prefer the HF_TOKEN environment variable so a real access token never has to
# be written into this file; fall back to the in-file placeholder when unset.
hf_token = os.environ.get("HF_TOKEN", "**")
# JSON file holding the records to describe.
input_path = "./Arxiv-papers2000/sample_nonnull_cleaned_10000.json"
# Destination for the same records with the generated field added.
output_path = "./Arxiv-papers2000/sample_nonnull_cleaned_10000.json_with_Mistral-24B_sent_p3_WOC.json"
# Key under which each generated description is stored on its record.
field_name = "Mistral24B_sentence"
# Number of records sent through `generate` at once.
batch_size = 4

# === Login to Hugging Face
login(token=hf_token)

# === Remove LaTeX caption
def mask_caption(text):
    r"""Remove the first ``\caption{...}`` command from a LaTeX string.

    The marker is a single backslash followed by ``caption{``. A raw string
    written with two backslashes would look for a literal double backslash,
    which a LaTeX caption command does not contain, so nothing would ever be
    removed.

    Braces nested inside the caption are tracked so the command is removed up
    to its matching closing brace. Only the first occurrence is handled.

    Args:
        text: LaTeX source to clean.

    Returns:
        ``text`` without the first caption command, or ``text`` unchanged when
        there is no ``\caption{`` or its braces never balance.
    """
    marker = r'\caption{'
    start = text.find(marker)
    if start == -1:
        return text

    # One brace is already open: the one that ends the marker itself.
    depth = 1
    for i in range(start + len(marker), len(text)):
        if text[i] == '{':
            depth += 1
        elif text[i] == '}':
            depth -= 1
            if depth == 0:
                # Matching close found: splice the whole command out.
                return text[:start] + text[i + 1:]

    # Braces never balanced; leave the input untouched.
    return text

# === Load tokenizer and model on multi-GPU
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# Prompts are batched with padding=True and then continued by a decoder-only
# model, so padding must go on the left: right padding would place pad tokens
# between the end of a short prompt and the first generated token.
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # let the library place the weights across available devices
    torch_dtype=torch.float16
)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

# === Load data
# The whole input file is parsed into memory; later steps iterate over `data`
# and write the generated field back onto its entries in place.
with open(input_path, encoding="utf-8") as source_file:
    data = json.load(source_file)
print(f"Total records: {len(data)}")

# === Batch processing
# Ceiling division so the progress-bar total also counts a final partial batch.
num_batches = (len(data) + batch_size - 1) // batch_size
for batch in tqdm(chunked(data, batch_size), total=num_batches, desc="Generating sentences"):
    prompts = []
    valid_entries = []
    for entry in batch:
        # `or ""` also covers keys that are present but hold null in the JSON.
        table = mask_caption((entry.get("table_content") or "").strip())
        caption = (entry.get("table_caption") or "").strip()

        # The caption is not inserted into the prompt; it only gates which
        # records receive a generated description.
        if not table or not caption:
            entry[field_name] = "[NO TABLE OR CAPTION]"
            continue

        # NOTE(review): the tokenizer may add its own beginning-of-sequence
        # token in addition to the literal "<s>" written here -- confirm.
        prompt = (
            f"<s>[INST] Given the following table content, write a short description of the table.  "
            f"Respond in plain English. Add the [END] token at the end.\n\n"
            f"Table:\n{table}\n\n"
            f"[/INST] Answer:"
        )
        prompts.append(prompt)
        valid_entries.append(entry)

    if not prompts:
        continue

    # NOTE(review): truncating to 1024 tokens can cut off the trailing
    # "[/INST] Answer:" for long tables (assuming truncation happens on the
    # right -- confirm), so the prompt text cannot be relied on as a marker.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(model.device)
    # Every row is padded to this length, so it marks where new tokens begin.
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.01,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )

    # The returned sequences start with the prompt tokens. Decode only the
    # continuation instead of recovering it by splitting on "Answer:", which
    # returns the whole prompt when that marker was truncated away.
    decoded = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)

    for entry, out in zip(valid_entries, decoded):
        # Keep the text before the requested [END] marker; when the model did
        # not emit it, split() leaves the full continuation in place.
        entry[field_name] = out.split("[END]")[0].strip()

# === Save
# Write all records, including skipped ones carrying the placeholder value,
# as indented JSON with non-ASCII characters kept as-is.
with open(output_path, mode="w", encoding="utf-8") as sink:
    json.dump(data, sink, ensure_ascii=False, indent=2)

print(f"Saved generated sentences to: {output_path}")
