# In [None]:
import torch
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import pickle

# Load the pre-tokenized inputs and their labels from disk.
# NOTE: unpickling can execute arbitrary code, so only load files you trust.
def _read_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Tokenized data
train_encodings = _read_pickle("train_tokenized.pkl")
val_encodings = _read_pickle("val_tokenized.pkl")

# Labels
train_labels = _read_pickle("train_labels.pkl")
val_labels = _read_pickle("val_labels.pkl")


# Adapter exposing the loaded data as a PyTorch dataset
class DatasetWrapper(torch.utils.data.Dataset):
    """Map-style dataset pairing per-sample encodings with their labels.

    ``encodings`` is indexed by sample position and each entry is indexed by
    feature name; ``labels`` is indexed by sample position.
    """

    def __init__(self, encodings, labels):
        """Store the encodings and labels without copying them."""
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Return the number of samples, taken from the number of labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample *idx* as a dict of tensors, including a 'labels' entry."""
        # The feature names are taken from the first sample and reused for
        # every index.
        feature_names = self.encodings[0].keys()
        sample = self.encodings[idx]
        item = {}
        for name in feature_names:
            item[name] = torch.tensor(sample[name])
        item['labels'] = torch.tensor(self.labels[idx])
        return item

# Wrap the training and validation splits as PyTorch datasets.
train_dataset, val_dataset = (
    DatasetWrapper(train_encodings, train_labels),
    DatasetWrapper(val_encodings, val_labels),
)

# Model definition: pretrained DistilBERT configured for 3 classes.
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=3
)

# Class weights, one value per class (3 classes).
# The tensor is placed on the GPU when CUDA is available and on the CPU
# otherwise, so creating it no longer raises on a machine without CUDA.
# NOTE(review): nothing in the visible code passes class_weights to the
# Trainer or to a loss function — confirm whether a weighted loss was intended.
_weights_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class_weights = torch.tensor([0.0001, 0.0001, 300.0], dtype=torch.float32).to(_weights_device)

# Metric computation used during evaluation
def compute_metrics(eval_pred):
    """Return accuracy and weighted precision/recall/F1 for one evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``. The predicted class of each row
            is the index of its largest logit along the last axis.

    Returns:
        Dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    logits, labels = eval_pred
    predicted = torch.tensor(logits).argmax(dim=-1).numpy()
    # The fourth element of the result (support) is not reported.
    weighted_scores = precision_recall_fscore_support(labels, predicted, average='weighted')
    metrics = {"accuracy": accuracy_score(labels, predicted)}
    metrics.update(zip(("precision", "recall", "f1"), weighted_scores))
    return metrics

# Training arguments
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version.
training_args = TrainingArguments(
    # Output locations and checkpoint retention
    output_dir="./results",
    logging_dir="./logs",
    save_total_limit=2,
    # Evaluate and checkpoint once per epoch; reload the best checkpoint at the end
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    logging_steps=500,
    # Optimisation settings
    learning_rate=5e-5,
    weight_decay=0.01,
    num_train_epochs=6,  # maximum number of epochs
    per_device_train_batch_size=128,  # batch size
    per_device_eval_batch_size=128,
    fp16=True,  # automatic mixed precision (AMP)
)

# Trainer definition with early stopping: stop after 2 epochs without improvement
early_stopping = EarlyStoppingCallback(early_stopping_patience=2)
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
    callbacks=[early_stopping],
)

# Train the model
trainer.train()

# Save the trained model
model.save_pretrained("./distilbert_model")
print("Model został wytrenowany i zapisany w katalogu ./distilbert_model")
