In [None]:
import json
import pandas as pd
import os
import torch
from torch.utils.data import Dataset
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments

In [None]:
import torch
# Report whether a CUDA device is visible. The device name is only queried
# when CUDA is available, because torch.cuda.get_device_name raises on
# CPU-only machines and would otherwise abort a "Run All".
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device detected; training will run on CPU.")

In [None]:
# Input locations used by the notebook. Both values are placeholders and must
# be edited to point at the real data before running the cells below.
training_data_path = '/path/to/your/training_data/'
ground_truth_path = '/path/to/your/ground_truth_data/CTA_training_gt.csv'

In [None]:
# Read the ground-truth annotations and collect the distinct labels in order
# of first appearance.
ground_truth = pd.read_csv(ground_truth_path)
unique_labels = ground_truth['label'].unique()

# Forward (label -> integer id) and reverse (integer id -> label) lookups.
label_to_id = dict(zip(unique_labels, range(len(unique_labels))))
id_to_label = dict(zip(label_to_id.values(), label_to_id.keys()))

In [None]:
# Persist the label -> id mapping as JSON in the current working directory.
with open('./label_to_id.json', 'w') as mapping_out:
    json.dump(label_to_id, mapping_out)

In [None]:
# Reload the label -> id mapping from the JSON file; this rebinds label_to_id
# to the dictionary parsed from disk.
with open('label_to_id.json', 'r') as mapping_in:
    label_to_id = json.load(mapping_in)

# Show the mapping that the following cells will use.
print(label_to_id)

In [None]:
class ColumnTypeDataset(Dataset):
    """Dataset producing one tokenized table per annotated column.

    Every row of ``labels_df`` provides ``table_name`` (a gzip-compressed
    JSON-lines file inside ``tables_dir``), ``column_index`` (the annotated
    column) and ``label`` (its class name).

    The annotated column is serialized first, followed by the remaining
    columns as context. Each column is introduced by the tokenizer's CLS id
    and the whole sequence is closed by its SEP id. Putting the target column
    first makes two columns of the same table yield different inputs, and
    places the target's CLS token at position 0, which is the position that
    sequence-classification heads such as RobertaForSequenceClassification
    read from.

    Args:
        labels_df: DataFrame with ``table_name``, ``column_index`` and
            ``label`` columns.
        tables_dir: Directory that contains the table files.
        tokenizer: Object exposing ``encode``, ``cls_token_id``,
            ``sep_token_id`` and ``pad_token_id``.
        max_seq_len: Fixed length of the returned ``input_ids`` and
            ``attention_mask`` tensors.
        label_to_id: Optional mapping from label name to integer id. When
            omitted, the module-level ``label_to_id`` is looked up at access
            time.
        max_tokens_per_column: Maximum number of tokens kept per column.

    Raises (from ``__getitem__``):
        KeyError: if the row's label is missing from the label mapping.
        IndexError: if ``column_index`` is outside the table's columns.
    """

    def __init__(self, labels_df, tables_dir, tokenizer, max_seq_len=512,
                 label_to_id=None, max_tokens_per_column=14):
        self.tables_dir = tables_dir
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.label_to_id = label_to_id
        self.max_tokens_per_column = max_tokens_per_column

        # Keep only the annotations whose table file actually exists on disk,
        # so __getitem__ never has to deal with a missing file.
        has_table = labels_df['table_name'].map(
            lambda name: os.path.exists(os.path.join(tables_dir, name))
        ).astype(bool)
        self.labels_df = labels_df[has_table].reset_index(drop=True)

    def __len__(self):
        return len(self.labels_df)

    def _label_id(self, label):
        """Translate a label name into its integer id, failing loudly if unknown."""
        mapping = self.label_to_id if self.label_to_id is not None else label_to_id
        try:
            return mapping[label]
        except KeyError:
            # An out-of-range id such as -1 would only fail later, inside the
            # loss computation, with a much less readable error.
            raise KeyError(
                f"Label {label!r} is not present in the label_to_id mapping"
            ) from None

    def __getitem__(self, idx):
        row = self.labels_df.iloc[idx]
        table_name = row['table_name']
        col_index = int(row['column_index'])
        label = self._label_id(row['label'])

        # Load the gzip-compressed JSON-lines table.
        table_path = os.path.join(self.tables_dir, table_name)
        df = pd.read_json(table_path, compression='gzip', lines=True)

        # NOTE(review): column_index is treated as a 0-based position into the
        # table's columns - confirm against the ground-truth file format.
        n_cols = df.shape[1]
        if not 0 <= col_index < n_cols:
            raise IndexError(
                f"column_index {col_index} is out of range for table "
                f"{table_name!r} with {n_cols} columns"
            )

        # Target column first, then the other columns in their original order.
        column_order = [col_index] + [i for i in range(n_cols) if i != col_index]

        # Column serialization: CLS id followed by the column's tokens.
        cls_id = self.tokenizer.cls_token_id
        token_ids = []
        for position in column_order:
            serialized = ' '.join(df.iloc[:, position].astype(str).values)
            tokens = self.tokenizer.encode(
                serialized,
                add_special_tokens=False,
                truncation=True,
                max_length=self.max_tokens_per_column,
            )
            token_ids.append(cls_id)
            token_ids.extend(tokens)

        # Truncate while reserving one slot, so the sequence always ends
        # with the SEP id even when the table is too long.
        token_ids = token_ids[:self.max_seq_len - 1]
        token_ids.append(self.tokenizer.sep_token_id)

        # Right-pad to the fixed length; padded positions are masked out.
        pad_len = self.max_seq_len - len(token_ids)
        attention_mask = [1] * len(token_ids) + [0] * pad_len
        input_ids = token_ids + [self.tokenizer.pad_token_id] * pad_len

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [None]:
# === Tokenizer ===
# Load the pretrained tokenizer published under the 'roberta-base' name.
tokenizer = RobertaTokenizer.from_pretrained(
    pretrained_model_name_or_path='roberta-base',
)


In [None]:
# Load the training annotations and select the table directory.
# Both reuse the paths from the configuration cell (ground_truth_path and
# training_data_path) instead of machine-specific absolute paths, so the
# label_to_id mapping and the training labels come from the same ground-truth
# file and the notebook has a single place to configure its inputs.
train_labels_df = pd.read_csv(ground_truth_path)
tables_dir = training_data_path


In [None]:
# Build the training dataset from the annotation frame, the table directory
# and the tokenizer.
train_dataset = ColumnTypeDataset(
    labels_df=train_labels_df,
    tables_dir=tables_dir,
    tokenizer=tokenizer,
)

In [None]:
# Load the pretrained encoder with a classification head sized to the label
# set. The label names are stored in the model config (id2label / label2id)
# so that a saved checkpoint can decode its own predictions without needing
# the separate label_to_id.json file.
model = RobertaForSequenceClassification.from_pretrained(
    'roberta-base',
    num_labels=len(label_to_id),
    id2label={idx: label for label, idx in label_to_id.items()},
    label2id=dict(label_to_id),
)

In [None]:
# Training hyper-parameters, collected in a dict before building the
# TrainingArguments object.
# Evaluation is currently disabled; the options previously sketched for it
# were per_device_eval_batch_size=4 and eval_strategy="epoch".
training_kwargs = dict(
    output_dir='./results__CTA',
    num_train_epochs=30,
    per_device_train_batch_size=8,
    logging_dir='./logs',  # directory for the logs
    logging_steps=10,
)
training_args = TrainingArguments(**training_kwargs)

In [None]:
# === Trainer ===
# Training-only setup: no eval_dataset is passed to the Trainer.
trainer = Trainer(args=training_args, model=model, train_dataset=train_dataset)

In [None]:
# Start training. A manual stop (e.g., Ctrl+C) is remembered in a flag so
# that the model and trainer state are saved on both the interrupted and the
# normal path; any other exception propagates without saving.
interrupted = False
try:
    trainer.train()
except KeyboardInterrupt:
    interrupted = True

if interrupted:
    before_msg = "Training interrupted. Saving the model and state..."
    after_msg = "Model and state saved. You can resume training later."
else:
    before_msg = "Training completed. Saving the final model and state..."
    after_msg = "Final model and state saved."

print(before_msg)
trainer.save_model(training_args.output_dir)
trainer.save_state()
print(after_msg)