In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [None]:
# --- Experiment configuration -------------------------------------------------
# Pretrained checkpoint used for both the tokenizer and the encoder.
MODEL_NAME = "DeepPavlov/distilrubert-base-cased-conversational"
MAX_LENGTH = 96        # truncation length handed to the tokenizer
NUM_LABELS = 2         # default output size of the classification head
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
BATCH_SIZE = 2448
TEST_SIZE = 0.2        # share of all samples held out from training
VAL_TEST_SPLIT = 0.3   # share of the held-out samples that becomes the test set
RANDOM_STATE = 42
SAVE_DIR = "style_classifier"
MODEL_PATH = os.path.join(SAVE_DIR, "model.pth")
LOSS_PLOT_PATH = os.path.join(SAVE_DIR, "training_loss.png")

# Seed torch's global RNG with the same constant used for the data splits so
# that torch-driven randomness (e.g. layer initialisation, loader shuffling)
# is repeatable across kernel restarts.
torch.manual_seed(RANDOM_STATE)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Colab only: mount Google Drive and redirect every artifact path onto it.
from google.colab import drive
drive.mount('/content/drive')
SAVE_DIR = "/content/drive/MyDrive/CycleGAN_for_TST_problem/style_transfer/style_classifier"
MODEL_PATH = os.path.join(SAVE_DIR, "model.pth")
# Re-derive the plot path as well: it was computed from the previous SAVE_DIR
# and would otherwise keep pointing at a directory this notebook never creates.
LOSS_PLOT_PATH = os.path.join(SAVE_DIR, "training_loss.png")

In [None]:
# Load the corpus; later cells read its 'tg_text' and 'lit_text' columns.
data = pd.read_csv("data.csv")

In [None]:
# Stack both text columns into one flat sample list. Every 'tg_text' entry is
# labelled 0 and every 'lit_text' entry is labelled 1, in the same order.
tg_column = data['tg_text']
lit_column = data['lit_text']
texts = [*tg_column, *lit_column]
labels = [0] * len(tg_column) + [1] * len(lit_column)

In [None]:
# Stage 1: hold out TEST_SIZE of all samples; the remainder is the training set.
train_texts, temp_texts, train_labels, temp_labels = train_test_split(
    texts,
    labels,
    random_state=RANDOM_STATE,
    test_size=TEST_SIZE,
)
# Stage 2: divide the held-out pool into validation and test parts, with
# VAL_TEST_SPLIT of it going to the test set.
val_texts, test_texts, val_labels, test_labels = train_test_split(
    temp_texts,
    temp_labels,
    random_state=RANDOM_STATE,
    test_size=VAL_TEST_SPLIT,
)

In [None]:
# Tokenizer matching the pretrained checkpoint named in MODEL_NAME.
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)

In [None]:
def tokenize(texts, max_length=MAX_LENGTH):
    """Encode ``texts`` with the notebook-level tokenizer.

    Args:
        texts: Input text(s), forwarded unchanged to the tokenizer.
        max_length: Truncation limit passed to the tokenizer.

    Returns:
        The tokenizer's output, requested as PyTorch tensors with padding
        and truncation enabled.
    """
    tokenizer_options = {
        "padding": True,
        "truncation": True,
        "max_length": max_length,
        "return_tensors": 'pt',
    }
    return tokenizer(texts, **tokenizer_options)

In [None]:
# Tokenize the three splits (train, then val, then test).
train_encodings, val_encodings, test_encodings = map(
    tokenize, (train_texts, val_texts, test_texts)
)

os.makedirs(SAVE_DIR, exist_ok=True)

# Cache each split's encodings on disk under SAVE_DIR.
for file_name, split_encodings in (
    ("train_encodings.pt", train_encodings),
    ("val_encodings.pt", val_encodings),
    ("test_encodings.pt", test_encodings),
):
    torch.save(split_encodings, os.path.join(SAVE_DIR, file_name))

In [None]:
# Reload the cached encodings from SAVE_DIR.
# The files hold pickled tokenizer output objects rather than bare tensors, so
# the restricted unpickler that PyTorch 2.6+ uses by default (weights_only=True)
# refuses them. They are produced by this notebook itself, hence trusted;
# do not point these loads at files from an untrusted source.
train_encodings = torch.load(os.path.join(SAVE_DIR, "train_encodings.pt"), weights_only=False)
val_encodings = torch.load(os.path.join(SAVE_DIR, "val_encodings.pt"), weights_only=False)
test_encodings = torch.load(os.path.join(SAVE_DIR, "test_encodings.pt"), weights_only=False)

In [None]:
class StyleDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with integer labels."""

    def __init__(self, encodings, labels):
        """Store the encodings mapping and the per-sample label sequence."""
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Number of samples, taken from the label sequence."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one sample: every encoding field at ``idx`` plus ``'labels'``."""
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = values[idx]
        # The label is wrapped as an int64 tensor.
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

In [None]:
# Wrap each split's encodings and labels in a StyleDataset (train, val, test).
train_dataset, val_dataset, test_dataset = (
    StyleDataset(split_encodings, split_labels)
    for split_encodings, split_labels in (
        (train_encodings, train_labels),
        (val_encodings, val_labels),
        (test_encodings, test_labels),
    )
)

In [None]:
class StyleClassifier(nn.Module):
    """Pretrained DistilBERT encoder followed by one linear classification layer."""

    def __init__(self, num_labels=NUM_LABELS):
        """Load the MODEL_NAME encoder and attach a ``num_labels``-way linear head."""
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(MODEL_NAME)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of encoded sequences.

        The head is applied to the encoder's last hidden state at sequence
        position 0 only.
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoder_out.last_hidden_state[:, 0, :]
        return self.classifier(first_token_state)

In [None]:
# Instantiate the classifier and move its parameters to the selected device.
model = StyleClassifier().to(device)

In [None]:
# Cross-entropy over the classifier's raw logits; used for both training and validation loss.
loss_fn = nn.CrossEntropyLoss()

In [None]:
# AdamW over every model parameter (encoder and head alike) with one shared learning rate.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
# All three loaders share one batch size; only the training loader shuffles.
loader_options = {"batch_size": BATCH_SIZE}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)
test_loader = DataLoader(test_dataset, **loader_options)


In [None]:
# Gradient scaler for mixed-precision training. Enable it only when the run is
# actually on CUDA, so the CPU fallback is explicit instead of relying on the
# scaler disabling itself with a warning.
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

In [None]:
# Per-epoch average losses; the plotting cell reads both lists.
train_losses = []
val_losses = []

for epoch in range(NUM_EPOCHS):
    # ---- training pass ----
    model.train()
    total_train_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Not named `labels` on purpose: that name belongs to the notebook-level
        # label list built for the train/val/test split, and overwriting it
        # here would break a re-run of the split cell.
        batch_labels = batch['labels'].to(device)

        with torch.amp.autocast(device.type):
            logits = model(input_ids, attention_mask)
            loss = loss_fn(logits, batch_labels)

        # Backward/step go through the gradient scaler (mixed precision).
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_train_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}")

    # ---- validation pass ----
    model.eval()
    total_val_loss = 0
    val_preds = []
    val_true = []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            batch_labels = batch['labels'].to(device)

            with torch.amp.autocast(device.type):
                logits = model(input_ids, attention_mask)
                loss = loss_fn(logits, batch_labels)
                preds = torch.argmax(logits, dim=1)

            total_val_loss += loss.item()
            val_preds.extend(preds.cpu().numpy())
            val_true.extend(batch_labels.cpu().numpy())

    avg_val_loss = total_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    val_accuracy = accuracy_score(val_true, val_preds)
    val_f1 = f1_score(val_true, val_preds, average='weighted')
    print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}, F1 Score: {val_f1:.4f}")

100%|██████████| 904/904 [10:06<00:00,  1.49it/s]


Epoch 1, Train Loss: 0.1175
Epoch 1, Validation Loss: 0.0983, Accuracy: 0.9604, F1 Score: 0.9604


100%|██████████| 904/904 [10:03<00:00,  1.50it/s]


Epoch 2, Train Loss: 0.0871
Epoch 2, Validation Loss: 0.0918, Accuracy: 0.9632, F1 Score: 0.9632


100%|██████████| 904/904 [10:04<00:00,  1.49it/s]


Epoch 3, Train Loss: 0.0725
Epoch 3, Validation Loss: 0.0900, Accuracy: 0.9644, F1 Score: 0.9644


In [None]:
# Final evaluation on the held-out test split (no gradient tracking).
model.eval()
test_preds = []
test_true = []
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Not named `labels` on purpose: that name belongs to the notebook-level
        # label list built for the train/val/test split.
        batch_labels = batch['labels'].to(device)

        with torch.amp.autocast(device.type):
            logits = model(input_ids, attention_mask)
            preds = torch.argmax(logits, dim=1)

        test_preds.extend(preds.cpu().numpy())
        test_true.extend(batch_labels.cpu().numpy())

test_accuracy = accuracy_score(test_true, test_preds)
test_f1 = f1_score(test_true, test_preds, average='weighted')
print(f"Test Accuracy: {test_accuracy:.4f}, F1 Score: {test_f1:.4f}")

Test Accuracy: 0.9647, F1 Score: 0.9647


In [None]:
# Plot the per-epoch average training and validation losses and save the figure.
# The x-axis is derived from the recorded losses rather than NUM_EPOCHS, so the
# cell still works if training stopped early or NUM_EPOCHS was edited afterwards.
train_epochs = range(1, len(train_losses) + 1)
val_epochs = range(1, len(val_losses) + 1)

# Make sure the target directory exists before writing the image.
plot_dir = os.path.dirname(LOSS_PLOT_PATH)
if plot_dir:
    os.makedirs(plot_dir, exist_ok=True)

plt.figure(figsize=(8, 6))
plt.plot(train_epochs, train_losses, label='Train Loss', marker='o')
plt.plot(val_epochs, val_losses, label='Validation Loss', marker='o')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.savefig(LOSS_PLOT_PATH)
plt.close()

In [None]:
# Persist the fine-tuned weights (MODEL_PATH) and the tokenizer files (SAVE_DIR).
os.makedirs(SAVE_DIR, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
tokenizer.save_pretrained(SAVE_DIR)

('style_classifier/tokenizer_config.json',
 'style_classifier/special_tokens_map.json',
 'style_classifier/vocab.txt',
 'style_classifier/added_tokens.json')

In [None]:
# Rebuild the classifier from the saved weights, then move it to the selected
# device and switch it to evaluation mode.
model = StyleClassifier()
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model = model.to(device).eval()

# Reload the tokenizer from the files stored in SAVE_DIR.
tokenizer = DistilBertTokenizer.from_pretrained(SAVE_DIR)

In [None]:
def predict_style(texts, return_probs=False):
    """Classify ``texts`` with the notebook-level model and tokenizer.

    Args:
        texts: A string or a list of strings; forwarded to the tokenizer and
            truncated to MAX_LENGTH tokens.
        return_probs: When true, return the softmax probabilities instead of
            the predicted class indices.

    Returns:
        A NumPy array: one row of class probabilities per input when
        ``return_probs`` is true, otherwise the argmax class index per input
        (0 and 1 follow the labelling used when the dataset was built:
        0 for 'tg_text', 1 for 'lit_text').
    """
    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    ids_on_device = batch['input_ids'].to(device)
    mask_on_device = batch['attention_mask'].to(device)

    with torch.no_grad():
        class_probs = torch.softmax(model(ids_on_device, mask_on_device), dim=1)

    result = class_probs if return_probs else torch.argmax(class_probs, dim=1)
    return result.cpu().numpy()


In [None]:
conv_text = "Было же время, всё было как будто чище, легче. Дышалось. Не знаю, как объяснить — но тогда просто жил и не думал, зачем. И это было нормально. а теперь всё как будто через фильтр, чужой."
lit_text = "Он вспоминал то время не как череду событий, а как состояние: утро, в котором не нужно ничего решать. Пустота, от которой не страшно. Тогда он просто существовал — не объясняя себе зачем. Теперь всё иначе, и в этом иначе не было покоя."
predict_style(conv_text, return_probs=True), predict_style(lit_text, return_probs=True)

(array([[0.99698526, 0.00301473]], dtype=float32),
 array([[0.19649537, 0.8035046 ]], dtype=float32))

In [None]:
def style_loss(predicted_texts, target_style_label):
    """Cross-entropy of the classifier's predictions against one target style.

    Args:
        predicted_texts: A string or list of strings to score; forwarded to
            the notebook-level tokenizer.
        target_style_label: Integer class index every text is compared
            against (0 or 1, following the dataset labelling).

    Returns:
        A scalar tensor on ``device`` holding ``F.cross_entropy`` (default
        reduction) between the logits and the target label. The logits are
        computed under ``torch.no_grad()``, so the result carries no gradient.
    """
    # Tokenize with the shared MAX_LENGTH limit so this function stays in sync
    # with the length used for training and in predict_style.
    encodings = tokenizer(
        predicted_texts,
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors='pt',
    ).to(device)

    with torch.no_grad():
        logits = model(encodings['input_ids'], encodings['attention_mask'])

    # One copy of the target label per scored text.
    target_labels = torch.full((logits.size(0),), target_style_label, dtype=torch.long, device=device)

    loss = F.cross_entropy(logits, target_labels)
    return loss


In [None]:
# Sanity check: loss of the conversational sample against style label 0.
# The previous argument `text` was never defined in this notebook, so the cell
# failed on a fresh kernel; `conv_text` comes from the prediction example cell.
# NOTE: the recorded output below was produced with a different input and will
# change on re-run.
style_loss(conv_text, 0)

tensor(2.1934e-05, device='cuda:0')