In [1]:
# Import Libraries
import pandas as pd
import numpy as np
import string
import re
import random
import time
import matplotlib.pyplot as plt
import seaborn as sns

import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from Sastrawi.Stemmer.StemmerFactory import StemmerFactory

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import torch
from torch import nn
from transformers import BertTokenizer, BertModel, AdamW
from torch.utils.data import Dataset, DataLoader

import warnings
warnings.filterwarnings("ignore")

# For tqdm progress bars
from tqdm import tqdm


## Load CSV File

In [2]:
# Load the raw news dataset directly from the GitHub-hosted CSV (needs network access).
dfnews = pd.read_csv('https://raw.githubusercontent.com/jasminemutia/dataset/main/media_news.csv')

## Data Overview

In [3]:
dfnews.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7949 entries, 0 to 7948
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   URL     7949 non-null   object
 1   Title   7949 non-null   object
 2   Media   7949 non-null   object
 3   Text    7927 non-null   object
dtypes: object(4)
memory usage: 248.5+ KB


In [4]:
# Report the number of rows and columns of the raw frame.
# NOTE(review): the printed labels say "data IBM" although this frame holds news
# articles — looks like a leftover from another notebook; confirm before renaming.
row, col = dfnews.shape
print("Jumlah baris dalam data IBM: ", row)
print("Jumlah kolom dalam data IBM: ", col)

Jumlah baris dalam data IBM:  7949
Jumlah kolom dalam data IBM:  4


In [5]:
# Drop the URL and Media columns, leaving only Title and Text.
dfnews = dfnews.drop(["URL", "Media"], axis=1)

In [6]:
dfnews.head()

Unnamed: 0,Title,Text
0,Menakar Dampak Resesi Jepang ke Pasar Modal In...,"Liputan6.com, Jakarta Jepang mengalami technic..."
1,"IHSG Lanjutkan Kenaikan, Investor Asing Borong...","Liputan6.com, Jakarta - Laju Indeks Harga Saha..."
2,"Suspensi Capai 4 Tahun, Bursa Ingatkan Potensi...","Liputan6.com, Jakarta Bursa Efek Indonesia (BE..."
3,"Lolos PKPU, Bursa Cabut Notasi Khusus M pada A...","Liputan6.com, Jakarta - Bursa Efek Indonesia (..."
4,"Bergerak Volatil, Bursa Cecar Emiten Panca Mit...","Liputan6.com, Jakarta Emiten pengolahan udang,..."


In [7]:
dfnews.duplicated().sum()

403

In [8]:
# Remove rows that are exact duplicates across all remaining columns, then display the frame.
dfnews = dfnews.drop_duplicates()
dfnews

Unnamed: 0,Title,Text
0,Menakar Dampak Resesi Jepang ke Pasar Modal In...,"Liputan6.com, Jakarta Jepang mengalami technic..."
1,"IHSG Lanjutkan Kenaikan, Investor Asing Borong...","Liputan6.com, Jakarta - Laju Indeks Harga Saha..."
2,"Suspensi Capai 4 Tahun, Bursa Ingatkan Potensi...","Liputan6.com, Jakarta Bursa Efek Indonesia (BE..."
3,"Lolos PKPU, Bursa Cabut Notasi Khusus M pada A...","Liputan6.com, Jakarta - Bursa Efek Indonesia (..."
4,"Bergerak Volatil, Bursa Cecar Emiten Panca Mit...","Liputan6.com, Jakarta Emiten pengolahan udang,..."
...,...,...
7544,"Saham China dibuka lebih rendah, indeks Shangh...","Saham China dibuka lebih rendah, indeks Shangh..."
7545,Emas jatuh di bawah level kunci 2.000 dolar ka...,Emas jatuh di bawah level kunci 2.000 dolar ka...
7546,"Saham Inggris berakhir negatif, indeks FTSE 10...","Saham Inggris berakhir negatif, indeks FTSE 10..."
7547,"Saham Jerman berbalik melemah, indeks DAX 40 t...","Saham Jerman berbalik melemah, indeks DAX 40 t..."


In [9]:
dfnews['Text'].isnull().sum()

22

In [10]:
# Discard rows with no article text and renumber the index from 0.
dfnews = dfnews.dropna(subset=['Text']).reset_index(drop=True)

In [11]:
dfnews.shape

(7524, 2)

In [12]:
dfnews

Unnamed: 0,Title,Text
0,Menakar Dampak Resesi Jepang ke Pasar Modal In...,"Liputan6.com, Jakarta Jepang mengalami technic..."
1,"IHSG Lanjutkan Kenaikan, Investor Asing Borong...","Liputan6.com, Jakarta - Laju Indeks Harga Saha..."
2,"Suspensi Capai 4 Tahun, Bursa Ingatkan Potensi...","Liputan6.com, Jakarta Bursa Efek Indonesia (BE..."
3,"Lolos PKPU, Bursa Cabut Notasi Khusus M pada A...","Liputan6.com, Jakarta - Bursa Efek Indonesia (..."
4,"Bergerak Volatil, Bursa Cecar Emiten Panca Mit...","Liputan6.com, Jakarta Emiten pengolahan udang,..."
...,...,...
7519,"Saham China dibuka lebih rendah, indeks Shangh...","Saham China dibuka lebih rendah, indeks Shangh..."
7520,Emas jatuh di bawah level kunci 2.000 dolar ka...,Emas jatuh di bawah level kunci 2.000 dolar ka...
7521,"Saham Inggris berakhir negatif, indeks FTSE 10...","Saham Inggris berakhir negatif, indeks FTSE 10..."
7522,"Saham Jerman berbalik melemah, indeks DAX 40 t...","Saham Jerman berbalik melemah, indeks DAX 40 t..."


In [13]:
# Persist the de-duplicated, non-null frame locally (without the index column).
dfnews.to_csv("media_news_unlabelled.csv", index=False)

In [14]:
# Read the unlabelled dataset back from GitHub.
# NOTE(review): this loads the remote copy, not the local file written in the previous
# cell — presumably the same content uploaded by hand; confirm they are in sync.
dfnews_unlabelled = pd.read_csv('https://raw.githubusercontent.com/jasminemutia/dataset/main/media_news_unlabelled.csv')

## Labelling Data

In [15]:
# Load the labelled dataset from GitHub; it carries an extra 'Sentiment' column.
df_labelled = pd.read_csv('https://raw.githubusercontent.com/jasminemutia/dataset/main/media_news_labelled.csv')

In [16]:
# Drop rows whose label is the literal string "Null".
# NOTE(review): a genuinely missing label (NaN) is not equal to 'Null' and therefore
# survives this filter — confirm the labelled file contains no NaN sentiments.
df_labelled = df_labelled[df_labelled['Sentiment'] != 'Null']

In [17]:
df_labelled

Unnamed: 0,Title,Text,Sentiment
0,Menakar Dampak Resesi Jepang ke Pasar Modal In...,"Liputan6.com, Jakarta Jepang mengalami technic...",Positive
1,"IHSG Lanjutkan Kenaikan, Investor Asing Borong...","Liputan6.com, Jakarta - Laju Indeks Harga Saha...",Positive
2,"Suspensi Capai 4 Tahun, Bursa Ingatkan Potensi...","Liputan6.com, Jakarta Bursa Efek Indonesia (BE...",Neutral
4,"Bergerak Volatil, Bursa Cecar Emiten Panca Mit...","Liputan6.com, Jakarta Emiten pengolahan udang,...",Neutral
6,"Stock Split 1:2, Tembaga Mulia Semanan Umumkan...","Liputan6.com, Jakarta PT Tembaga Mulia Semanan...",Positive
...,...,...,...
7519,"Saham China dibuka lebih rendah, indeks Shangh...","Saham China dibuka lebih rendah, indeks Shangh...",Negative
7520,Emas jatuh di bawah level kunci 2.000 dolar ka...,Emas jatuh di bawah level kunci 2.000 dolar ka...,Negative
7521,"Saham Inggris berakhir negatif, indeks FTSE 10...","Saham Inggris berakhir negatif, indeks FTSE 10...",Negative
7522,"Saham Jerman berbalik melemah, indeks DAX 40 t...","Saham Jerman berbalik melemah, indeks DAX 40 t...",Negative


In [18]:
# Persist the labelled rows locally (without the index column).
df_labelled.to_csv("news_sentiment.csv", index=False)

## Data Cleansing

In [19]:
# Reload the labelled data from GitHub, which also resets the index to 0..n-1.
# NOTE(review): this reads the remote copy rather than the local "news_sentiment.csv"
# written just above — presumably identical; confirm.
df_labelled = pd.read_csv("https://raw.githubusercontent.com/jasminemutia/dataset/main/news_sentiment.csv")

In [20]:
df_labelled

Unnamed: 0,Title,Text,Sentiment
0,Menakar Dampak Resesi Jepang ke Pasar Modal In...,"Liputan6.com, Jakarta Jepang mengalami technic...",Positive
1,"IHSG Lanjutkan Kenaikan, Investor Asing Borong...","Liputan6.com, Jakarta - Laju Indeks Harga Saha...",Positive
2,"Suspensi Capai 4 Tahun, Bursa Ingatkan Potensi...","Liputan6.com, Jakarta Bursa Efek Indonesia (BE...",Neutral
3,"Bergerak Volatil, Bursa Cecar Emiten Panca Mit...","Liputan6.com, Jakarta Emiten pengolahan udang,...",Neutral
4,"Stock Split 1:2, Tembaga Mulia Semanan Umumkan...","Liputan6.com, Jakarta PT Tembaga Mulia Semanan...",Positive
...,...,...,...
7323,"Saham China dibuka lebih rendah, indeks Shangh...","Saham China dibuka lebih rendah, indeks Shangh...",Negative
7324,Emas jatuh di bawah level kunci 2.000 dolar ka...,Emas jatuh di bawah level kunci 2.000 dolar ka...,Negative
7325,"Saham Inggris berakhir negatif, indeks FTSE 10...","Saham Inggris berakhir negatif, indeks FTSE 10...",Negative
7326,"Saham Jerman berbalik melemah, indeks DAX 40 t...","Saham Jerman berbalik melemah, indeks DAX 40 t...",Negative


In [21]:
print("String punctuation to be removed: ", string.punctuation)

# Build the cleaned text in a separate column so the raw 'Text' stays untouched.
df_labelled['Clean Text'] = df_labelled['Text'].copy()
df_labelled['Clean Text'] = df_labelled['Clean Text'].str.lower()
# Remove digits
df_labelled['Clean Text'] = df_labelled['Clean Text'].str.replace(r'\d+', '', regex=True)
# Strip common domain suffixes (".com", ".id", ".co"), then any remaining URLs
df_labelled['Clean Text'] = df_labelled['Clean Text'].str.replace(r'\.com|\.id|\.co', '',regex=True)
df_labelled['Clean Text'] = df_labelled['Clean Text'].str.replace(r'https\S+|http\S+|www\.\S+','', regex=True)
# Replace newlines with a space
df_labelled['Clean Text'] = df_labelled['Clean Text'].str.replace(r'\n', ' ', regex=True)
# Remove punctuation. string.punctuation must be escaped before being placed in a
# character class: unescaped, its "\]" sequence is read as an escaped bracket, so the
# backslash itself is never part of the class and is silently left in the text.
df_labelled['Clean Text'] = df_labelled['Clean Text'].str.replace('[{}]'.format(re.escape(string.punctuation)), '', regex=True)

# Additional symbol removal (mis-encoded characters left over in the scraped text)
additional_symbols = r'[©â€“œ]'
df_labelled['Clean Text'] = df_labelled['Clean Text'].str.replace(additional_symbols, '', regex=True)

# # remove stop words
# stop = set(stopwords.words('indonesian'))
# text = [x for x in df_labelled['Text'] if x not in stop]

# Collapse runs of whitespace
def remove_whitespace(text):
    """Replace every run of whitespace in `text` with a single space.

    Values that are not strings (e.g. NaN for missing text) are returned unchanged.
    Leading/trailing whitespace is collapsed to one space, not stripped.
    """
    if not isinstance(text, str):
        return text
    return re.sub(r'\s+', ' ', text)

# Normalise whitespace last, since the removals above leave gaps between words.
df_labelled['Clean Text'] = df_labelled['Clean Text'].apply(remove_whitespace)

String punctuation to be removed:  !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~


In [22]:
import nltk
# Fetch the NLTK stop-word corpus used by the stop-word removal step below.
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\nicho\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [23]:
# Indonesian stop-word list from NLTK, held as a set for fast membership tests
stop = set(stopwords.words('indonesian'))


def remove_stopwords(text):
    """Return `text` with Indonesian stop words removed.

    The text is split on whitespace, each word is compared case-insensitively
    against `stop`, and the surviving words are re-joined with single spaces.
    Non-string values are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    kept_words = (word for word in text.split() if word.lower() not in stop)
    return ' '.join(kept_words)


# Strip stop words from the already-cleaned 'Clean Text' column
df_labelled['Clean Text'] = df_labelled['Clean Text'].apply(remove_stopwords)

In [24]:
df_labelled

Unnamed: 0,Title,Text,Sentiment,Clean Text
0,Menakar Dampak Resesi Jepang ke Pasar Modal In...,"Liputan6.com, Jakarta Jepang mengalami technic...",Positive,liputan jakarta jepang mengalami technical rec...
1,"IHSG Lanjutkan Kenaikan, Investor Asing Borong...","Liputan6.com, Jakarta - Laju Indeks Harga Saha...",Positive,liputan jakarta laju indeks harga saham gabung...
2,"Suspensi Capai 4 Tahun, Bursa Ingatkan Potensi...","Liputan6.com, Jakarta Bursa Efek Indonesia (BE...",Neutral,liputan jakarta bursa efek indonesia bei mengu...
3,"Bergerak Volatil, Bursa Cecar Emiten Panca Mit...","Liputan6.com, Jakarta Emiten pengolahan udang,...",Neutral,liputan jakarta emiten pengolahan udang pt pan...
4,"Stock Split 1:2, Tembaga Mulia Semanan Umumkan...","Liputan6.com, Jakarta PT Tembaga Mulia Semanan...",Positive,liputan jakarta pt tembaga mulia semanan tbk t...
...,...,...,...,...
7323,"Saham China dibuka lebih rendah, indeks Shangh...","Saham China dibuka lebih rendah, indeks Shangh...",Negative,saham china dibuka rendah indeks shanghai jatu...
7324,Emas jatuh di bawah level kunci 2.000 dolar ka...,Emas jatuh di bawah level kunci 2.000 dolar ka...,Negative,emas jatuh level kunci dolar greenback menguat...
7325,"Saham Inggris berakhir negatif, indeks FTSE 10...","Saham Inggris berakhir negatif, indeks FTSE 10...",Negative,saham inggris negatif indeks ftse berkurang pe...
7326,"Saham Jerman berbalik melemah, indeks DAX 40 t...","Saham Jerman berbalik melemah, indeks DAX 40 t...",Negative,saham jerman berbalik melemah indeks dax terpa...


## Data Splitting LLM Model

In [25]:
# Stratified 80/10/10 split: hold out 20 % first, then halve the holdout into
# validation and test. Both steps stratify on the sentiment label.
X_train, X_temp, y_train, y_temp = train_test_split(
    df_labelled['Clean Text'],
    df_labelled['Sentiment'],
    test_size=0.2,
    random_state=32,
    stratify=df_labelled['Sentiment'],
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp,
    y_temp,
    test_size=0.5,
    random_state=32,
    stratify=y_temp,
)

print("Jumlah data pada training set :", len(X_train))
print("Jumlah data pada validation set :", len(X_val))
print("Jumlah data pada testing set :", len(X_test))

# Re-assemble one two-column frame per split; each text/label pair shares the same
# index and the Series already carry the 'Clean Text' / 'Sentiment' names.
df_train = pd.concat([X_train, y_train], axis=1)
df_val = pd.concat([X_val, y_val], axis=1)
df_test = pd.concat([X_test, y_test], axis=1)

Jumlah data pada training set : 5862
Jumlah data pada validation set : 733
Jumlah data pada testing set : 733


## Tokenization

In [26]:
import nltk
# Fetch the Punkt tokenizer data before word_tokenize is used in the cells below.
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\nicho\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [27]:
 # Word-level tokenization helper
def tokenize_text(text):
    """Split a string into word tokens with NLTK's word_tokenize.

    Anything that is not a string yields an empty list.
    """
    return word_tokenize(text) if isinstance(text, str) else []


# Tokenize the training texts.
# NOTE(review): df_train/df_val/df_test were assembled before this step and the BERT
# datasets are built from those frames, so the tokenized X_* series do not appear to
# feed the model — confirm whether this step is still needed.
X_train = X_train.apply(tokenize_text)

# Preview the tokenized result
print(X_train.head())

5307    [idxchannel, bursa, efek, indonesia, bei, peri...
5900    [koci, raup, rp, miliar, dana, ipo, mayoritas,...
1809    [nilai, tukar, rupiah, posisi, rp, dolar, as, ...
1239    [jakarta, kompas, pemilik, waralaba, kebab, tu...
2722    [presiden, joko, widodo, jokowi, resmi, mewaji...
Name: Clean Text, dtype: object


In [28]:
 # tokenize_text is already defined in the training-split tokenization cell above;
 # reuse it here instead of redefining an identical copy.
X_test = X_test.apply(tokenize_text)

# Preview the tokenized result
print(X_test.head())

566     [jakarta, kompas, kehadiran, papan, pemantauan...
376     [jakarta, kompas, indeks, harga, saham, gabung...
7152    [saham, china, dibuka, rendah, indeks, shangha...
4554    [jakarta, cnbc, indonesia, cio, mandiri, manaj...
5598    [ihsg, ditutup, menguat, pasar, respon, sikap,...
Name: Clean Text, dtype: object


In [29]:
 # tokenize_text is already defined in the training-split tokenization cell above;
 # reuse it here instead of redefining an identical copy.
X_val = X_val.apply(tokenize_text)

# Preview the tokenized result
print(X_val.head())

7300    [saham, inggris, berbalik, menguat, indeks, ft...
2006    [menteri, bumn, erick, thohir, menyebut, harga...
6797    [ihsg, pekan, ditutup, melemah, ikuti, bursa, ...
578     [kompas, chief, executive, officer, ceo, cofou...
6670    [metrodata, electronics, raih, pendapatan, rp,...
Name: Clean Text, dtype: object


## Lemmatization

In [30]:
import nltk

# Download the required NLTK resources: Punkt for tokenization and WordNet,
# which backs the WordNetLemmatizer used in the following cells.
nltk.download('punkt')
nltk.download('wordnet')


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\nicho\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\nicho\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [31]:
# Single WordNet lemmatizer instance shared by lemmatize_text
lemmatizer = WordNetLemmatizer()


def lemmatize_text(text):
    """Lemmatize every token in a list of tokens.

    Returns a new list with `lemmatizer.lemmatize` applied to each token; input that
    is not a list yields an empty list.

    NOTE(review): the tokens are Indonesian but this is NLTK's WordNet lemmatizer —
    in the previews the token 'as' becomes 'a' after this step. The Sastrawi stemmer
    imported at the top of the notebook may have been the intended tool; confirm.
    """
    if not isinstance(text, list):
        return []
    return list(map(lemmatizer.lemmatize, text))


# Lemmatize the tokenized training texts
X_train = X_train.apply(lemmatize_text)

# Preview the lemmatized result
print(X_train.head())

5307    [idxchannel, bursa, efek, indonesia, bei, peri...
5900    [koci, raup, rp, miliar, dana, ipo, mayoritas,...
1809    [nilai, tukar, rupiah, posisi, rp, dolar, a, d...
1239    [jakarta, kompas, pemilik, waralaba, kebab, tu...
2722    [presiden, joko, widodo, jokowi, resmi, mewaji...
Name: Clean Text, dtype: object


In [32]:
# lemmatizer and lemmatize_text are already defined in the training-split
# lemmatization cell above; reuse them instead of redefining identical copies.
X_test = X_test.apply(lemmatize_text)

# Preview the lemmatized result
print(X_test.head())

566     [jakarta, kompas, kehadiran, papan, pemantauan...
376     [jakarta, kompas, indeks, harga, saham, gabung...
7152    [saham, china, dibuka, rendah, indeks, shangha...
4554    [jakarta, cnbc, indonesia, cio, mandiri, manaj...
5598    [ihsg, ditutup, menguat, pasar, respon, sikap,...
Name: Clean Text, dtype: object


In [33]:
# lemmatizer and lemmatize_text are already defined in the training-split
# lemmatization cell above; reuse them instead of redefining identical copies.
X_val = X_val.apply(lemmatize_text)

# Preview the lemmatized result
print(X_val.head())

7300    [saham, inggris, berbalik, menguat, indeks, ft...
2006    [menteri, bumn, erick, thohir, menyebut, harga...
6797    [ihsg, pekan, ditutup, melemah, ikuti, bursa, ...
578     [kompas, chief, executive, officer, ceo, cofou...
6670    [metrodata, electronics, raih, pendapatan, rp,...
Name: Clean Text, dtype: object


## Vectorization Bert

In [34]:
# Define the dataset class for BERT
class FinancialNewsDataset(Dataset):
    """Torch dataset that pairs cleaned news text with an integer sentiment label.

    Args:
        df: DataFrame with a 'Clean Text' column and a 'Sentiment' column.
        tokenizer: object exposing ``encode_plus`` (a Hugging Face tokenizer).
        max_len: sequence length every sample is padded/truncated to.

    Rows whose sentiment is not one of the keys of ``LABEL_TO_ID``, or whose text is
    missing, are dropped. Texts and labels are filtered with the same mask so
    position ``i`` of ``self.texts`` always belongs to position ``i`` of
    ``self.labels`` (filtering only the labels would shift every later label onto
    the wrong text and make ``__len__`` overcount).
    """

    # Sentiment string -> class id expected by the classifier head
    LABEL_TO_ID = {'Negative': 0, 'Neutral': 1, 'Positive': 2, 'Dual': 3}

    def __init__(self, df, tokenizer, max_len):
        texts = df['Clean Text']
        # Unknown or missing sentiments map to NaN
        label_ids = df['Sentiment'].map(self.LABEL_TO_ID)
        # Positional boolean mask, applied identically to texts and labels
        keep = (label_ids.notna() & texts.notna()).to_numpy()

        self.texts = texts[keep].tolist()
        self.labels = label_ids[keep].astype(int).tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Number of usable (text, label) pairs."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the encoded sample at `idx` as a dict of tensors.

        Keys: 'input_ids' and 'attention_mask' (flattened tokenizer output) and
        'labels' (scalar long tensor).
        """
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            # The tokenizer returns a leading batch dimension of 1; flatten removes it
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [35]:
# Load the pre-trained IndoBERT tokenizer (same checkpoint name the classifier uses)
tokenizer = BertTokenizer.from_pretrained('indolem/indobert-base-uncased')

# Every sample is padded/truncated to this many tokens
MAX_LEN = 128

# Wrap each split in a dataset
train_dataset = FinancialNewsDataset(df_train, tokenizer, MAX_LEN)
val_dataset = FinancialNewsDataset(df_val, tokenizer, MAX_LEN)
test_dataset = FinancialNewsDataset(df_test, tokenizer, MAX_LEN)

# Only the training loader shuffles; validation/test keep their order.
# NOTE(review): train_tuning_v2 builds its own train/val loaders from the frames, so
# of these three only test_dataloader appears to be used later — confirm.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16)
test_dataloader = DataLoader(test_dataset, batch_size=16)

## Modelling

In [36]:
class BertClassifierTune(nn.Module):
    """BERT encoder followed by a small two-layer classification head.

    Pipeline: pooled BERT output -> dropout -> linear -> ReLU -> linear -> logits.

    Args:
        dropout: dropout probability applied to the pooled BERT output.
        num_classes: number of output logits. The default of 4 corresponds to
            Negative, Neutral, Positive and Dual.
        hidden_dim: width of the intermediate fully connected layer.
        pretrained_name: Hugging Face checkpoint loaded as the encoder.

    The defaults reproduce the original fixed architecture, so existing
    ``BertClassifierTune()`` calls and saved state dicts remain compatible.
    """

    def __init__(self, dropout=0.3, num_classes=4, hidden_dim=256,
                 pretrained_name='indolem/indobert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.dropout = nn.Dropout(dropout)

        # Take the encoder width from its config instead of hard-coding 768, so a
        # checkpoint with a different hidden size still yields matching layer shapes.
        self.fc1 = nn.Linear(self.bert.config.hidden_size, hidden_dim)
        self.relu = nn.ReLU()

        # Final classifier layer
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw (unnormalised) class logits for a batch of encoded inputs."""
        # With return_dict=False the encoder returns a tuple; the second element is
        # the pooled output, the only part the head consumes.
        _, pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)

        hidden = self.relu(self.fc1(self.dropout(pooled_output)))
        return self.fc2(hidden)


In [37]:
def train(model, train_dataloader, val_dataloader, epochs, lr):
    """Fine-tune `model` for a fixed number of epochs and print metrics per epoch.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning logits.
        train_dataloader: loader yielding dicts with 'input_ids', 'attention_mask'
            and 'labels'; must expose ``.dataset`` (as torch DataLoader does).
        val_dataloader: loader of the same batch format, passed to ``evaluate``.
        epochs: number of passes over the training data.
        lr: learning rate for AdamW.

    The model is moved to CUDA when available, otherwise it stays on CPU.
    Nothing is returned; progress is reported via print.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss().to(device)

    # Normalise accuracy by the data actually iterated, not by a notebook-level
    # global frame that may differ from the loader passed in.
    n_train_samples = len(train_dataloader.dataset)

    for epoch in range(epochs):
        model.train()
        total_train_loss, total_train_acc = 0, 0

        for batch in tqdm(train_dataloader):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            total_train_acc += (outputs.argmax(dim=1) == labels).sum().item()

        val_loss, val_acc = evaluate(model, val_dataloader, criterion, device)
        print(f'Epoch {epoch + 1}: Train Loss: {total_train_loss / len(train_dataloader):.3f}, Train Acc: {total_train_acc / n_train_samples:.3f}, Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.3f}')

def evaluate(model, dataloader, criterion, device):
    """Score `model` on `dataloader` without updating any weights.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning
            logits of shape (batch_size, num_classes).
        dataloader: iterable of dict batches with 'input_ids', 'attention_mask'
            and 'labels'; must support ``len()`` and expose ``.dataset``.
        criterion: loss callable taking (logits, labels).
        device: device the batch tensors are moved to.

    Returns:
        (mean_loss, accuracy): the loss averaged over batches, and the share of
        correctly predicted samples over ``len(dataloader.dataset)``.

    The model is left in eval mode.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0

    with torch.no_grad():
        for batch in dataloader:
            ids, mask, targets = (
                batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'labels')
            )

            logits = model(ids, mask)
            loss_sum += criterion(logits, targets).item()
            # Predicted class = index of the largest logit
            n_correct += (logits.argmax(dim=1) == targets).sum().item()

    mean_loss = loss_sum / len(dataloader)
    accuracy = n_correct / len(dataloader.dataset)
    return mean_loss, accuracy

In [38]:
def train_tuning_v2(model, df_train, df_val, tokenizer, max_len, epochs, lr, batch_size, patience=2):
    """Fine-tune `model` with a configurable batch size and early stopping.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning logits.
        df_train, df_val: frames with 'Clean Text' and 'Sentiment' columns, wrapped
            in FinancialNewsDataset.
        tokenizer: tokenizer handed to FinancialNewsDataset.
        max_len: padded/truncated sequence length.
        epochs: maximum number of epochs.
        lr: learning rate for AdamW.
        batch_size: batch size for both loaders.
        patience: number of consecutive epochs without a new best validation loss
            after which training stops. Defaults to 2, the previously fixed value.

    The model is moved to CUDA when available, otherwise it stays on CPU.
    Nothing is returned; per-epoch metrics are printed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss().to(device)

    # Create DataLoaders with the batch size specified during tuning
    train_dataset = FinancialNewsDataset(df_train, tokenizer, max_len)
    val_dataset = FinancialNewsDataset(df_val, tokenizer, max_len)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

    # Accuracy is normalised by the samples actually iterated (the dataset), which is
    # the correct denominator even if the dataset holds fewer rows than the frame.
    n_train_samples = len(train_dataset)

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        total_train_loss, total_train_acc = 0, 0

        for batch in tqdm(train_dataloader):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            total_train_acc += (outputs.argmax(dim=1) == labels).sum().item()

        val_loss, val_acc = evaluate(model, val_dataloader, criterion, device)

        # Report the epoch before deciding whether to stop, so the metrics of the
        # final (stopping) epoch are not lost.
        print(f'Epoch {epoch + 1}: Train Loss: {total_train_loss / len(train_dataloader):.3f}, Train Acc: {total_train_acc / n_train_samples:.3f}, Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.3f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print("Early stopping triggered")
            break

# Evaluate function remains the same

In [39]:
# Seed every RNG involved before the model is built: the classification head is
# randomly initialised and the training loader shuffles, so without a seed each run
# gives different numbers. 32 matches the random_state used for the data split.
SEED = 32
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model2RelU = BertClassifierTune()
EPOCHS = 3
LR = 1e-5
BATCH_SIZE = 32  # adjust as needed

train_tuning_v2(model2RelU, df_train, df_val, tokenizer, MAX_LEN, epochs=EPOCHS, lr=LR, batch_size=BATCH_SIZE)


100%|██████████| 184/184 [29:57<00:00,  9.77s/it]


Epoch 1: Train Loss: 1.073, Train Acc: 0.539, Val Loss: 0.926, Val Acc: 0.636


100%|██████████| 184/184 [30:36<00:00,  9.98s/it]


Epoch 2: Train Loss: 0.850, Train Acc: 0.687, Val Loss: 0.819, Val Acc: 0.681


100%|██████████| 184/184 [31:00<00:00, 10.11s/it]


Epoch 3: Train Loss: 0.741, Train Acc: 0.729, Val Loss: 0.792, Val Acc: 0.693


In [40]:
# Score the fine-tuned model on the held-out test loader, using the same device
# selection rule as training (CUDA when available, otherwise CPU).
test_loss, test_acc = evaluate(model2RelU, test_dataloader, nn.CrossEntropyLoss(), device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print(f'Test Loss: {test_loss:.3f}, Test Accuracy: {test_acc:.3f}')

Test Loss: 0.806, Test Accuracy: 0.690


In [41]:
import torch

# Save only the model's state_dict (weights), not the whole module object
torch.save(model2RelU.state_dict(), 'bert_model.pth')

In [42]:
import torch

# Load the saved state_dict back into the existing model instance
model2RelU.load_state_dict(torch.load('bert_model.pth'))

# Make sure the model is in evaluation mode after loading
model2RelU.eval()


BertClassifierTune(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31923, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [43]:
# Write the same state_dict a second time under a different file name.
# NOTE(review): duplicates 'bert_model.pth' saved above — confirm both files are needed.
torch.save(model2RelU.state_dict(), 'bert_model_state_dict.pt')

In [44]:
# Reload the saved weights and list the parameter names as a sanity check.
state_dict = torch.load('bert_model_state_dict.pt')
print(state_dict.keys())

odict_keys(['bert.embeddings.word_embeddings.weight', 'bert.embeddings.position_embeddings.weight', 'bert.embeddings.token_type_embeddings.weight', 'bert.embeddings.LayerNorm.weight', 'bert.embeddings.LayerNorm.bias', 'bert.encoder.layer.0.attention.self.query.weight', 'bert.encoder.layer.0.attention.self.query.bias', 'bert.encoder.layer.0.attention.self.key.weight', 'bert.encoder.layer.0.attention.self.key.bias', 'bert.encoder.layer.0.attention.self.value.weight', 'bert.encoder.layer.0.attention.self.value.bias', 'bert.encoder.layer.0.attention.output.dense.weight', 'bert.encoder.layer.0.attention.output.dense.bias', 'bert.encoder.layer.0.attention.output.LayerNorm.weight', 'bert.encoder.layer.0.attention.output.LayerNorm.bias', 'bert.encoder.layer.0.intermediate.dense.weight', 'bert.encoder.layer.0.intermediate.dense.bias', 'bert.encoder.layer.0.output.dense.weight', 'bert.encoder.layer.0.output.dense.bias', 'bert.encoder.layer.0.output.LayerNorm.weight', 'bert.encoder.layer.0.output