# In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split
from preprocess import load_and_preprocess

# Run configuration: the pretrained checkpoint name plus loop sizes.
MODEL_NAME: str = "bert-base-uncased"  # checkpoint id passed to from_pretrained
BATCH_SIZE: int = 8  # examples per DataLoader batch
EPOCHS: int = 2  # number of passes over the training loader

# Built once at import time from the same checkpoint as the model; it is used to
# encode the texts and is written out next to the model weights at the end.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

class SentimentDataset(Dataset):
    """Map-style dataset of tokenized texts paired with class labels.

    All texts are encoded once, up front, in a single batched tokenizer call;
    ``__getitem__`` then slices out one example and converts it to tensors.

    Args:
        texts: Sequence of raw strings to classify.
        labels: Sequence of labels, one per text (must match ``texts`` in length).
        text_tokenizer: Callable used to encode ``texts``; it must return a
            mapping of name -> per-example sequences. Defaults to the
            module-level ``tokenizer`` (the original behavior).
        max_length: Truncation length forwarded to the tokenizer (default 128,
            as before).

    Raises:
        ValueError: If ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels, text_tokenizer=None, max_length=128):
        # A length mismatch would otherwise surface later as a wrong __len__
        # or an IndexError deep inside the DataLoader.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length "
                f"(got {len(texts)} and {len(labels)})"
            )
        if text_tokenizer is None:
            # Fall back to the module-level tokenizer so existing callers work.
            text_tokenizer = tokenizer
        # With a Hugging Face tokenizer, padding=True pads every example to the
        # longest one in this single call, so all items share one length.
        self.encodings = text_tokenizer(
            texts, truncation=True, padding=True, max_length=max_length
        )
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, including ``labels``."""
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        # Including "labels" lets the model compute the loss in its forward pass.
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples (one per label)."""
        return len(self.labels)

# Load data/reviews.csv through the project's preprocessing helper; the rest of
# this cell expects the result to expose 'clean_text' and 'label' columns.
df = load_and_preprocess("data/reviews.csv")

# Hold out 20% of rows for validation.
X_train, X_val, y_train, y_val = train_test_split(
    df["clean_text"],
    df["label"],
    test_size=0.2,
    random_state=42,  # fixed seed -> the same split on every run
)

# Wrap each split as a Dataset (train first, then validation); the columns are
# converted to plain lists (presumably from pandas Series) before tokenizing.
train_dataset, val_dataset = (
    SentimentDataset(texts.tolist(), labels.tolist())
    for texts, labels in ((X_train, y_train), (X_val, y_val))
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BERT checkpoint with a 2-output sequence-classification head, placed on device.
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
model.to(device)

# Optimizer is constructed after the move, as the torch.optim docs recommend.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Training: fine-tune on train_loader, then evaluate on the held-out
# val_loader after every epoch so the validation split is actually used.
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()

        # The batch carries "labels", so the forward pass also returns a loss.
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError if the training loader is empty.
    print(f"Epoch {epoch+1} Training Loss: {total_loss/max(len(train_loader), 1):.4f}")

    # Validation: eval() switches off training-mode layers such as dropout;
    # no_grad() skips building the autograd graph.
    model.eval()
    val_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            val_loss += outputs.loss.item()
            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == batch["labels"]).sum().item()
            seen += batch["labels"].size(0)

    # Skip the report when the validation split is empty.
    if seen:
        print(
            f"Epoch {epoch+1} Validation Loss: {val_loss/len(val_loader):.4f} "
            f"Accuracy: {correct/seen:.4f}"
        )

# Save model: write the fine-tuned model and its tokenizer into the same
# directory (presumably so both can be reloaded together via from_pretrained).
for artifact in (model, tokenizer):
    artifact.save_pretrained("models/bert_sentiment_model")
