In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import torch.nn.functional as F


# Root folder for data files: the current working directory followed by a
# trailing slash, so relative paths can be appended with plain concatenation.
base_dir = "{}/".format(os.getcwd())

# Colab alternative (disabled):
# base_dir = "/content/drive/MyDrive/bookend/dev/text-style-classify/basic/"

In [None]:
"""from google.colab import drive
drive.mount('/content/drive')"""

In [None]:
# Load the labelled dataset from the data folder under base_dir.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
df = pd.read_pickle(f"{base_dir}data/labeled_compact.pkl")
df

In [None]:
# Largest len() among the entries of the tokenized_sentence column.
max(len(tokens) for tokens in df['tokenized_sentence'].to_list())

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset over pre-tokenised sentences and their labels.

    Two columns are read from ``df``: ``tokenized_sentence`` (token-id
    sequences) and ``encoded_author`` (integer class labels). Token id 0 is
    treated as padding when the attention mask is built -- presumably the
    sentences were already padded with 0 upstream; confirm against the
    preprocessing step.

    Args:
        df: DataFrame providing the two columns above. Its index is ignored.
        max_length: Stored on the instance only; no truncation or padding is
            applied by this class.
    """

    def __init__(self, df, max_length=64):
        # Positional access below must not depend on the caller's index
        # (e.g. the shuffled index left behind by train_test_split).
        df = df.reset_index(drop=True)
        self.sentences = df['tokenized_sentence'].tolist()
        self.labels = df['encoded_author'].tolist()
        self.max_length = max_length

    def __len__(self):
        """Return the number of sentences in the dataset."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors.

        Args:
            idx: Position of the sample.

        Returns:
            dict with keys
              * ``'input_ids'``: 1-D copy of the stored token ids,
              * ``'label'``: 0-dim ``torch.long`` tensor with the class id,
              * ``'attention_mask'``: 1-D bool tensor, ``True`` wherever the
                token id is non-zero.
        """
        # Copy the stored ids so callers can never mutate the dataset, and
        # flatten up front so a (1, seq_len) tensor is handled like a 1-D one.
        # as_tensor also lets plain lists / arrays of ids through.
        input_ids = torch.as_tensor(self.sentences[idx]).clone().detach().flatten()

        # Vectorised padding test: one tensor op instead of a Python loop that
        # built a list of 0-dim tensors and only worked for 1-D input.
        attention_mask = input_ids != 0

        return {
            'input_ids': input_ids,
            'label': torch.tensor(self.labels[idx], dtype=torch.long),
            'attention_mask': attention_mask,
        }


In [None]:
# 80/20 train/validation split with a fixed seed.
train_df, val_df = train_test_split(df, random_state=42, test_size=0.2)

train_dataset, val_dataset = CustomDataset(train_df), CustomDataset(val_df)

# Small shuffled batches for training; larger, ordered batches for validation.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=4)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=64)

In [None]:
class BERTClassifier(nn.Module):
    """Pretrained transformer encoder followed by a two-layer classifier head.

    The sentence representation is the encoder's ``pooler_output`` when the
    encoder provides one (BERT); otherwise the hidden state of the first
    token is used (encoders such as DistilBERT expose no pooler).

    Args:
        bert_model_name: Hugging Face model name or path of the encoder.
        num_classes: Size of the output layer (number of target classes).
        hidden_size: Width of the intermediate fully connected layer.
    """

    def __init__(self, bert_model_name, num_classes=3763, hidden_size=256):
        super().__init__()
        # AutoModel instantiates the architecture the checkpoint was trained
        # with. Forcing BertModel onto a non-BERT checkpoint (e.g. DistilBERT)
        # does not map the pretrained weights onto the model, which leaves
        # the encoder effectively untrained. Imported locally so this class
        # does not depend on the notebook's import cell.
        from transformers import AutoModel

        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.fc1 = nn.Linear(self.bert.config.hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None):
        """Compute class logits for a batch.

        Args:
            input_ids: Token ids passed to the encoder as ``input_ids``.
            attention_mask: Optional mask passed to the encoder unchanged.

        Returns:
            Output of the final linear layer (one row of ``num_classes``
            logits per sentence representation).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Prefer the encoder's pooled output; fall back to the first token's
        # hidden state when the encoder has no pooling layer.
        sentence_repr = getattr(outputs, 'pooler_output', None)
        if sentence_repr is None:
            sentence_repr = outputs.last_hidden_state[:, 0]

        x = self.fc1(sentence_repr)
        x = self.relu(x)
        x = self.fc2(x)
        return x



In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the classifier and move its parameters to the chosen device before
# the optimizer is created over them.
model = BERTClassifier('distilbert-base-multilingual-cased', num_classes=3763, hidden_size=128)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-5)


In [None]:
# Show the DataFrame again as this cell's output.
df

In [None]:
# Train for a fixed number of epochs and evaluate on the validation set after
# each one. Epoch-level metrics are weighted by sample count, so a smaller
# final batch does not get the same weight as a full batch.
epochs = 3

for epoch in range(epochs):
    # ---- Training ----
    model.train()
    total_loss = 0.0      # sum of per-sample losses over the epoch
    train_samples = 0

    for batch in train_loader:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask)
        loss = criterion(outputs, labels)

        # criterion is nn.CrossEntropyLoss with its default mean reduction,
        # so multiplying by the batch size recovers the batch's summed loss.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        train_samples += batch_size

        loss.backward()
        optimizer.step()

    avg_train_loss = total_loss / train_samples
    print(f'Epoch {epoch + 1}, Loss: {avg_train_loss}')

    # ---- Validation ----
    model.eval()
    total_eval_loss = 0.0  # sum of per-sample losses over the validation set
    total_correct = 0
    eval_samples = 0

    for batch_idx, batch in enumerate(val_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        # No gradients are needed for evaluation, including the loss.
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            loss = criterion(outputs, labels)

        batch_size = labels.size(0)
        total_eval_loss += loss.item() * batch_size

        preds = torch.argmax(outputs, dim=1).flatten()
        correct = (preds == labels).sum().item()
        total_correct += correct
        eval_samples += batch_size

        # Per-batch accuracy (in percent) is still reported for monitoring.
        accuracy = correct / batch_size * 100
        print(f"Validation Batch: {batch_idx + 1}, Loss: {loss.item()}, Accuracy: {accuracy}")

    avg_val_accuracy = total_correct / eval_samples * 100
    avg_val_loss = total_eval_loss / eval_samples

    print(f'Validation Accuracy: {avg_val_accuracy}')
    print(f'Validation Loss: {avg_val_loss}')
