# In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, TensorDataset
import torch
from tqdm import tqdm
import os
import json

# Define base directory for all saved files
BASE_DIR = '/kaggle/working/'

# Location of the training CSV (a Kaggle input dataset)
DATA_PATH = '/kaggle/input/image-value/merged_output_f3.csv'

# Load the dataset
try:
    df = pd.read_csv(DATA_PATH)
except FileNotFoundError as err:
    # Name the path that was actually tried so the user knows what to fix.
    raise FileNotFoundError(
        f"The dataset file '{DATA_PATH}' was not found. "
        "Please ensure the file exists and the path is correct."
    ) from err

# Check if required columns exist
required_columns = ['extracted_text', 'entity_name', 'entity_value']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    raise ValueError(f"The following required columns are missing from the dataset: {', '.join(missing_columns)}")

# Remove rows with NaN in any required column. .copy() makes df_clean an
# independent frame, so later column assignments on df cannot trigger
# pandas' SettingWithCopyWarning.
df_clean = df.dropna(subset=required_columns).copy()
if len(df_clean) < len(df):
    print(f"Removed {len(df) - len(df_clean)} rows with NaN values.")
    df = df_clean

# Model input: the extracted text and the entity name joined by a literal
# ' [SEP] ' separator, e.g. "Product weight: 500g [SEP] item_weight".
df['input_text'] = df['extracted_text'].astype(str).str.cat(
    df['entity_name'].astype(str), sep=' [SEP] '
)

# Integer class ids for each distinct entity_value (numbered in the order
# unique() returns them), plus the inverse mapping used to decode predictions.
unique_values = df['entity_value'].unique()
value_to_label = dict(zip(unique_values, range(len(unique_values))))
label_to_value = dict((idx, value) for value, idx in value_to_label.items())

# Encode each row's entity_value as its integer class id.
df['label'] = df['entity_value'].map(value_to_label)

# Hold out 20% of the rows for validation; random_state pins the split.
(
    train_texts, val_texts,
    train_labels, val_labels,
) = train_test_split(
    df['input_text'].tolist(),
    df['label'].tolist(),
    random_state=42,
    test_size=0.2,
)

# Pre-trained uncased BERT tokenizer and a sequence-classification model with
# one output per distinct entity_value.
BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
model = BertForSequenceClassification.from_pretrained(
    BERT_CHECKPOINT,
    num_labels=len(unique_values),
)

# Function to safely tokenize texts
def safe_tokenize(texts, max_length=128):
    """Tokenize *texts* with the module-level ``tokenizer``.

    Sequences are truncated to *max_length* tokens and padded (``padding=True``).
    If the tokenizer raises, the error message and every non-string entry of
    *texts* are printed as a debugging aid, and the original exception is
    re-raised unchanged.
    """
    try:
        encoded = tokenizer(texts, truncation=True, padding=True, max_length=max_length)
    except Exception as exc:
        print(f"Error during tokenization: {exc!s}")
        print("Problematic texts:")
        for item in filter(lambda t: not isinstance(t, str), texts):
            print(f"  - {item} (type: {type(item)})")
        raise
    return encoded

# Tokenize and encode the texts (safe_tokenize prints diagnostics before re-raising)
try:
    train_encodings = safe_tokenize(train_texts)
    val_encodings = safe_tokenize(val_texts)
except Exception as e:
    print(f"Tokenization failed: {str(e)}")
    raise


def _to_tensor_dataset(encodings, labels):
    """Bundle token ids, attention masks and integer labels into a TensorDataset."""
    return TensorDataset(
        torch.tensor(encodings['input_ids']),
        torch.tensor(encodings['attention_mask']),
        torch.tensor(labels),
    )


# Convert to PyTorch tensors
train_dataset = _to_tensor_dataset(train_encodings, train_labels)
val_dataset = _to_tensor_dataset(val_encodings, val_labels)

# Create DataLoaders (only the training set is shuffled)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# Set up the optimizer.
# transformers.AdamW is deprecated in favour of torch.optim.AdamW. eps=1e-6 and
# weight_decay=0.0 match transformers.AdamW's defaults, so the update rule is
# the same as before.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0)

# Checkpoint functions
def save_checkpoint(epoch, model, optimizer, best_f1, checkpoint_dir=None):
    """Save model/optimizer state, best F1 and label mappings for *epoch*.

    Writes ``checkpoint_epoch_<epoch>.pt`` into *checkpoint_dir* (default:
    ``BASE_DIR/checkpoints``), creating the directory if needed. The module-level
    ``value_to_label`` / ``label_to_value`` mappings are stored alongside.

    Args:
        epoch: number stored in the checkpoint and used in the file name.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        best_f1: best validation F1 so far; stored as a plain Python float.
        checkpoint_dir: optional override for the output directory.

    Returns:
        Path of the written checkpoint file.
    """
    if checkpoint_dir is None:
        checkpoint_dir = os.path.join(BASE_DIR, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # sklearn's f1_score can return a NumPy scalar; a plain float keeps the
        # file loadable with torch.load(weights_only=True), the default since
        # PyTorch 2.6.
        'best_f1': float(best_f1),
        'value_to_label': value_to_label,
        'label_to_value': label_to_value,
    }

    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")
    return checkpoint_path

def load_checkpoint(model, optimizer, checkpoint_dir='/kaggle/input/checkpoint'):
    """Restore the newest ``checkpoint_epoch_<N>.*`` file in *checkpoint_dir*, if any.

    Model and optimizer state are loaded in place. The module-level
    ``value_to_label`` / ``label_to_value`` mappings are replaced with the ones
    stored in the checkpoint.

    Args:
        model: module that receives ``model_state_dict``.
        optimizer: optimizer that receives ``optimizer_state_dict``.
        checkpoint_dir: directory to scan. The default is the Kaggle input path
            this notebook reads from; note that ``save_checkpoint`` writes under
            ``BASE_DIR`` instead.

    Returns:
        ``(epoch, best_f1)`` from the checkpoint, or ``(0, 0.0)`` when the
        directory is missing, contains no parsable checkpoint name, or loading
        fails (the error is printed).
    """
    global value_to_label, label_to_value

    if not os.path.isdir(checkpoint_dir):
        return 0, 0.0  # Return default values if no checkpoint exists

    # Map epoch number -> file name, skipping names whose epoch part is not an
    # integer (e.g. "checkpoint_epoch_best.pt"), which previously crashed int().
    prefix = 'checkpoint_epoch_'
    epoch_to_file = {}
    for name in os.listdir(checkpoint_dir):
        if not name.startswith(prefix):
            continue
        number = name[len(prefix):].split('.')[0]
        if number.isdecimal():
            epoch_to_file[int(number)] = name

    if not epoch_to_file:
        print(f"No checkpoint files found in {checkpoint_dir}; starting from scratch.")
        return 0, 0.0

    checkpoint_path = os.path.join(checkpoint_dir, epoch_to_file[max(epoch_to_file)])

    try:
        # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
        # machine; load_state_dict copies the values onto the model's and
        # optimizer's own devices.
        # weights_only=False is needed because the file holds Python dicts and,
        # for older saves, a NumPy scalar best_f1, which PyTorch >= 2.6's default
        # (weights_only=True) rejects.
        # SECURITY: this unpickles arbitrary objects; only load checkpoints you
        # created yourself.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        stored_value_to_label = checkpoint['value_to_label']
        if stored_value_to_label != value_to_label:
            # The datasets were already encoded with the current mapping, so a
            # different stored mapping makes training labels and decoding disagree.
            print("Warning: the checkpoint's label mapping differs from the one built "
                  "from the current dataset.")
        value_to_label = stored_value_to_label
        label_to_value = checkpoint['label_to_value']

        epoch = checkpoint['epoch']
        best_f1 = float(checkpoint['best_f1'])
        print(f"Resumed from epoch {epoch} with best F1 score: {best_f1:.4f}")
        return epoch, best_f1

    except Exception as e:
        # Best effort: fall back to training from scratch rather than aborting.
        print(f"Error loading checkpoint: {str(e)}")
        return 0, 0.0

# Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


def _train_one_epoch(epoch, num_epochs):
    """Run one optimization pass over ``train_loader`` (progress bar is 1-based)."""
    model.train()
    for batch in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        input_ids, attention_mask, labels = [b.to(device) for b in batch]
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        outputs.loss.backward()
        optimizer.step()
        optimizer.zero_grad()


def _validation_f1():
    """Return the weighted F1 score of arg-max predictions on ``val_loader``."""
    model.eval()
    val_preds = []
    val_true = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validation'):
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            logits = model(input_ids, attention_mask=attention_mask).logits
            val_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            val_true.extend(labels.cpu().numpy())
    return f1_score(val_true, val_preds, average='weighted')


# Resume from the newest checkpoint, if any. The returned epoch is the number of
# epochs already completed (save_checkpoint below is called with epoch + 1).
start_epoch, best_f1 = load_checkpoint(model, optimizer)
num_epochs = 25
# Start at start_epoch so a resumed run does not repeat completed epochs.
for epoch in range(start_epoch, num_epochs):
    _train_one_epoch(epoch, num_epochs)

    # Calculate F1 score
    current_f1 = _validation_f1()
    print(f'Epoch {epoch + 1}/{num_epochs}, Validation F1 Score: {current_f1:.4f}')

    # best_f1 is only recorded in checkpoints; the saved weights are always the
    # current epoch's, not those of the best-scoring epoch.
    if current_f1 > best_f1:
        best_f1 = current_f1
    # Checkpoint every 5 completed epochs.
    if (epoch + 1) % 5 == 0:
        save_checkpoint(epoch + 1, model, optimizer, best_f1)

# Function to predict entity_value
def predict_entity_value(text, entity_name):
    """Return the predicted entity_value for one (text, entity_name) pair.

    The input is built the same way as the training inputs
    ("<text> [SEP] <entity_name>"), tokenized with truncation to 128 tokens,
    and classified by the global ``model``. The arg-max label index is decoded
    through ``label_to_value``.
    """
    input_text = str(text) + ' [SEP] ' + str(entity_name)
    inputs = tokenizer(input_text, return_tensors='pt', truncation=True, padding=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Switch to eval mode so dropout is off. Otherwise the prediction depends on
    # whatever mode the caller last left the model in (e.g. after model.train())
    # and is not deterministic.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    pred = torch.argmax(logits, dim=1).item()

    return label_to_value[pred]

# Example usage: run the trained model on one hand-written snippet.
example_text, example_entity_name = "Product weight: 500g", "item_weight"
predicted_value = predict_entity_value(example_text, example_entity_name)
print(f"Predicted entity value: {predicted_value}")

# Save the model and tokenizer in the Kaggle working directory
model_save_path = os.path.join(BASE_DIR, 'entity_value_predictor_model')
tokenizer_save_path = os.path.join(BASE_DIR, 'entity_value_predictor_tokenizer')
mappings_save_path = os.path.join(BASE_DIR, 'entity_value_predictor_mappings.json')

# Create directories if they don't exist
os.makedirs(model_save_path, exist_ok=True)
os.makedirs(tokenizer_save_path, exist_ok=True)

# Save model and tokenizer
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(tokenizer_save_path)


def _json_scalar(value):
    """Return *value* as a plain Python scalar so ``json`` can serialize it.

    NumPy scalars (e.g. ``np.int64`` from a numeric entity_value column) are not
    JSON-serializable as keys or values; ``.item()`` converts them. Other values
    are returned unchanged.
    """
    return value.item() if isinstance(value, np.generic) else value


# Save the label mappings
mappings = {
    'value_to_label': {_json_scalar(v): _json_scalar(i) for v, i in value_to_label.items()},
    'label_to_value': {_json_scalar(i): _json_scalar(v) for i, v in label_to_value.items()},
}
with open(mappings_save_path, 'w', encoding='utf-8') as f:
    json.dump(mappings, f)

print(f"\nModel saved to: {model_save_path}")
print(f"Tokenizer saved to: {tokenizer_save_path}")
print(f"Mappings saved to: {mappings_save_path}")

# To load the model later:
# loaded_model = BertForSequenceClassification.from_pretrained('/kaggle/working/entity_value_predictor_model')
# loaded_tokenizer = BertTokenizer.from_pretrained('/kaggle/working/entity_value_predictor_tokenizer')
# with open('/kaggle/working/entity_value_predictor_mappings.json', 'r', encoding='utf-8') as f:
#     loaded_mappings = json.load(f)
# JSON object keys are always strings, so convert label indices back to int
# before decoding predictions:
# loaded_label_to_value = {int(k): v for k, v in loaded_mappings['label_to_value'].items()}