# In [None]:
import pandas as pd
import numpy as np
import re
import time
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import os
import urllib.request
import transformers
import logging
import sys
import torch

logging.basicConfig(level=logging.INFO)  # Add logging

# --- GPU / DEVICE REPORT ---
# Only query device details when CUDA is present: torch.cuda.get_device_name(0)
# raises on CPU-only machines, so it must not be called unconditionally.
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))
    print("Total VRAM:", round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 2), "GB")
else:
    print("CUDA not available. GPU not detected.")
# NOTE(review): `device` is a torch device, but the model in this notebook is
# trained with TensorFlow, which does its own device placement. torch seeing a
# GPU does not guarantee TensorFlow does; confirm with
# tf.config.list_physical_devices('GPU') if training is unexpectedly slow.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# ----------------------
# 0. Load and Clean Dataset
# ----------------------
# Path to the labelled CSV. Set the MODERN_SLAVERY_CSV environment variable to
# use a different copy without editing the notebook; the default is unchanged.
DATA_PATH = os.environ.get(
    'MODERN_SLAVERY_CSV',
    'C:/Users/Jaque/Downloads/modern_slavery_NER_us_india_val1.csv',
)
df = pd.read_csv(DATA_PATH, index_col=0)
# Drop exact duplicate rows (all columns equal) and renumber the index 0..n-1.
df = df.drop_duplicates().reset_index(drop=True)

def get_only_words_from_strings(text):
    """Reduce *text* to ASCII letters separated by single spaces.

    Every character that is not an ASCII letter (a-z, A-Z) or whitespace is
    removed. This already strips digits and punctuation, and it also drops
    accented or non-Latin letters (e.g. "café" -> "caf"). Runs of whitespace
    are then collapsed to a single space and the result is stripped.

    Missing values (None / NaN, as pandas produces for empty CSV cells) are
    returned unchanged so a later ``dropna`` can discard them. Any other
    non-string value is converted with ``str`` first.

    Args:
        text: The raw value to clean.

    Returns:
        The cleaned string, or the original missing value.
    """
    if not isinstance(text, str):
        # re.sub raises TypeError on non-str input such as float NaN.
        if pd.isna(text):
            return text
        text = str(text)
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    return re.sub(r'\s+', ' ', text).strip()

# Keep only the text and label columns and drop incomplete rows *before*
# cleaning, so the regex cleaner only ever receives real text (re.sub raises
# TypeError on the float NaN pandas uses for empty cells). .copy() makes the
# column assignments below act on an independent frame rather than a slice.
df = df[['content_corrected', 'modern_slavery_in_supply_chain']].dropna().copy()
df['content_corrected'] = df['content_corrected'].astype(str).apply(get_only_words_from_strings)
# Binary target: 1 when the label reads "yes" (case- and whitespace-insensitive), else 0.
df['target'] = (
    df['modern_slavery_in_supply_chain']
    .astype(str)
    .str.strip()
    .str.lower()
    .eq('yes')
    .astype(int)
)
df = df[['content_corrected', 'target']]

# The cleaned text is in 'content_corrected'; it and 'target' are the only columns left.
texts = df['content_corrected'].tolist()
labels = df['target'].tolist()

# ----------------------
# 1. Split Data
# ----------------------
train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts,
    labels,
    test_size=0.2,    # hold out 20% of the rows for evaluation
    random_state=42,  # fixed seed -> identical split on every run
)
# NOTE(review): the split is not stratified; if the "yes" class is rare,
# consider stratify=labels so both splits keep the same class ratio.

# ----------------------
# 2. Tokenization
# ----------------------
# Fine-tuning hyperparameters.
MODEL_NAME = "casehold/legalbert"  # Hugging Face Hub id of the checkpoint to fine-tune
MAX_LEN = 256                      # inputs longer than this many tokens are truncated
BATCH_SIZE = 16
EPOCHS = 3
LEARNING_RATE = 2e-5

# --- CACHE MANAGEMENT ---
# Local download location for the pretrained tokenizer and model files.
cache_dir = "./legalbert_cache"
os.makedirs(cache_dir, exist_ok=True)  # no-op when the directory already exists
print(f"Transformers cache directory: {cache_dir}")

# --- NETWORK CHECK (Simplified) ---
# --- NETWORK CHECK (Simplified) ---
# Best-effort probe of the Hugging Face Hub, which is where from_pretrained()
# fetches MODEL_NAME. Failures are only reported; the script carries on.
try:
    # Use the response as a context manager so the connection is closed.
    with urllib.request.urlopen('https://huggingface.co', timeout=5):
        pass
    print("Network connection OK")
except OSError as e:
    # URLError/HTTPError, socket timeouts and SSL errors all derive from OSError.
    # A timeout while awaiting the response surfaces as a plain socket timeout,
    # not as URLError, so URLError alone is too narrow.
    print(f"Network error: {e}")

# --- VERSION CHECK ---
print(f"Transformers version: {transformers.__version__}")
print(f"TensorFlow version: {tf.__version__}")
print(f"Python version: {sys.version}")

# --- TOKENIZER and MODEL LOADING ---
# Files are cached under cache_dir and reused on later runs (force_download=True
# would ignore that cache and re-fetch the full checkpoint every time).
# NOTE(review): if the Hub repo for MODEL_NAME ships only PyTorch weights,
# TFAutoModelForSequenceClassification needs from_pt=True here; confirm
# against the repo's file list.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
model = TFAutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=2, cache_dir=cache_dir
)

# Each split is padded to its own longest sequence (capped at MAX_LEN via
# truncation) and returned as TensorFlow tensors.
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=MAX_LEN, return_tensors='tf')
test_encodings = tokenizer(test_texts, truncation=True, padding=True, max_length=MAX_LEN, return_tensors='tf')

# ----------------------
# 3. Prepare TensorFlow Datasets
# ----------------------
# Shuffle the whole training set with a fixed seed (matching random_state=42
# used for the split) so the shuffle order is reproducible. The test set keeps
# its order so later predictions line up with test_labels.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((dict(train_encodings), train_labels))
    .shuffle(len(train_texts), seed=42)
    .batch(BATCH_SIZE)
)
test_dataset = tf.data.Dataset.from_tensor_slices((dict(test_encodings), test_labels)).batch(BATCH_SIZE)

# ----------------------
# 4. Compile Model
# ----------------------
# Adam at LEARNING_RATE. The loss is computed from raw logits (from_logits=True)
# because the classification head's outputs are used as logits downstream.
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy')]

model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
)

# ----------------------
# 5. Train the Model with Timer
# ----------------------
# NOTE: validation_data is the held-out test split, so the per-epoch val_*
# numbers come from the same data evaluated afterwards. No callbacks are
# passed, so it is used only for monitoring, not for model selection.
start_train = time.time()
model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=EPOCHS,
)
end_train = time.time()
print(f"\n🕒 Training time: {end_train - start_train:.2f} seconds")

# ----------------------
# 6. Evaluate and Analyze Results
# ----------------------
# Timed evaluation on the held-out split. evaluate() returns [loss, accuracy]:
# the loss followed by the single metric passed to compile().
start_eval = time.time()
results = model.evaluate(test_dataset, verbose=0)
end_eval = time.time()
print(f"🕒 Evaluation time: {end_eval - start_eval:.2f} seconds")

acc = results[1]
predictions = model.predict(test_dataset)
# Index by key: depending on the transformers/Keras versions, predict() returns
# either a ModelOutput (attribute and key access) or a plain dict (key access
# only), so `predictions.logits` is not portable.
y_pred_logits = predictions["logits"]
y_pred = np.argmax(y_pred_logits, axis=1)
# test_dataset is not shuffled, so predictions follow the order of test_labels.
y_true = np.array(test_labels)

# Support-weighted averages over both classes (0 and 1).
# NOTE(review): if the positive class is rare, average='binary' (positive class
# only) may be the more informative figure.
prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)

# ----------------------
# 7. Store Results
# ----------------------
# Reuse MODEL_NAME so the reported name always matches the checkpoint that was loaded.
model_display_name = MODEL_NAME
# All scores are reported as percentages rounded to 4 decimal places.
final_results = {
    'Model': model_display_name,
    'Accuracy': round(acc * 100, 4),
    'Precision': round(prec * 100, 4),
    'Recall': round(rec * 100, 4),
    'F1 Score': round(f1 * 100, 4),
}

# Accumulate one row per run across notebook re-executions. On the first run
# the new row becomes the table directly instead of being concatenated onto
# an empty placeholder DataFrame.
new_result_row = pd.DataFrame([final_results])
if 'test_result' in globals():
    test_result = pd.concat([test_result, new_result_row], ignore_index=True)
else:
    test_result = new_result_row

print("\n📊 Updated Test Results DataFrame:")
print(test_result)

# ----------------------
# 8. Save Fine-tuned Model and Tokenizer
# ----------------------
# Write the fine-tuned weights/config and the tokenizer files side by side so
# the directory can later be passed to from_pretrained().
save_directory = "./casehold_legalbert_finetuned"
model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)
print(f"\n💾 Model and tokenizer saved to: {save_directory}")

# ----------------------
# 9. Optional Inference Function
# ----------------------
def predict(text):
    """Classify *text* with the fine-tuned model and print the inference time.

    Args:
        text: The input passage. Only the first row of the model output is
            returned, so pass a single string.

    Returns:
        dict with keys "predicted_class" (1 = "yes", matching the training
        target, 0 otherwise), "probability_yes" and "probability_no"
        (softmax over the two logits).
    """
    t0 = time.time()
    encoded = tokenizer(text, return_tensors="tf", truncation=True, padding=True, max_length=MAX_LEN)
    output_logits = model(encoded).logits
    t1 = time.time()  # timing covers tokenization + forward pass only

    probabilities = tf.nn.softmax(output_logits, axis=1).numpy()[0]
    predicted = tf.argmax(output_logits, axis=1).numpy()[0]

    print(f"🕒 Prediction time: {t1 - t0:.4f} seconds")
    return {
        "predicted_class": int(predicted),
        "probability_yes": float(probabilities[1]),
        "probability_no": float(probabilities[0]),
    }

# Example
print("\n🧪 Example prediction:")
print(
    predict("The company failed to monitor labor practices in its supply chain.")
)