In [1]:
import pandas as pd
import numpy as np
import torch
from torch.autograd import Variable 
from tqdm import tqdm
from sklearn.metrics import classification_report
DATASET_PATH = 'output/train_1_arabert.pkl'

In [2]:
def edit_categories(x):
    if x == 'info_news':
        return 0
    elif x == 'celebrity':
        return 1
    elif x == 'plan':
        return 2
    elif x == 'requests':
        return 3
    elif x == 'rumors':
        return 4
    elif x == 'advice':
        return 5
    elif x == 'restrictions':
        return 6
    elif x == 'personal':
        return 7
    elif x == 'unrelated':
        return 8
    elif x == 'others':
        return 9
    else:
        return -1

In [3]:

# build the pytorch dataset
class ArabertDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path):
        dataset = pd.read_pickle(dataset_path)
        print(dataset['stance'].value_counts())
        pad_embedding = torch.zeros(1, 768)
        max_len = 0
        embeddings = dataset['embeddings'].values
        for i in range(len(embeddings)):
            if embeddings[i].shape[0] > max_len:
                max_len = embeddings[i].shape[0]
        for i in range(len(embeddings)):
            if embeddings[i].shape[0] < max_len:
                pad = torch.zeros(max_len - embeddings[i].shape[0], 768).to(torch.device('cuda'))
                embeddings[i] = torch.cat((embeddings[i], pad), dim=0)
        categories = dataset['category'].apply(edit_categories)
        self.embeddings = embeddings # already a tensor
        self.stance = dataset['stance']
        self.category = categories
        self.stance = torch.tensor(self.stance.values)
        self.category = torch.tensor(self.category.values)



    def __len__(self):
        return len(self.stance)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.stance[idx], self.category[idx]


In [4]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [5]:
torch.cuda.is_available()

True

In [6]:
# build the pytorch dataloader
train_dataset = ArabertDataset(DATASET_PATH)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# load the dev set
dev_dataset = ArabertDataset('output/dev_1_arabert.pkl')
dev_loader = torch.utils.data.DataLoader(dev_dataset, batch_size=64, shuffle=True)

# build the model
class RNN(torch.nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = torch.nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, num_classes)




    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])
        return out




2    5538
1    1012
0     438
Name: stance, dtype: int64
2    804
1    126
0     70
Name: stance, dtype: int64


In [11]:
# train the stance_model
stance_model = RNN(768, 256, 1, 10).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(stance_model.parameters(), lr=1e-3)
for epoch in range(20):
    stance_model.train()
    for i, (embedding, stance, category) in enumerate(tqdm(train_loader)):
        embedding = embedding.to(device)
        stance = stance.to(device)
        category = category.to(device)

        outputs = stance_model(embedding)
        loss = criterion(outputs, category)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # calculate the total epoch accuracy
        _, predicted = torch.max(outputs.data, 1)
        total = stance.size(0)
        correct = (predicted == category).sum().item()
        accuracy = correct / total

        if (i+1) % len(train_loader) == 0:
            pass
            # print(f'Epoch [{epoch+1}/{10}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
            # print(f'Epoch [{epoch+1}/{10}], Step [{i+1}/{len(train_loader)}], Accuracy: {accuracy:.4f}')

    # inference mode
    stance_model.eval()
    # get accuracy on development set
    with torch.no_grad():
        predicted = np.empty(len(dev_dataset), dtype=int)
        stances = []
        for i, (embedding, stance, category) in enumerate(dev_loader):
            embedding = embedding.to(device)
            stance = stance.to(device)
            category = category.to(device)

            outputs = stance_model(embedding)
            predicted[i*64:(i+1)*64] = torch.argmax(outputs.data, 1).cpu().numpy()
        stances = dev_dataset.stance
        print(classification_report(dev_dataset.category, predicted, zero_division=0))

100%|██████████| 110/110 [00:01<00:00, 74.87it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.71       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 79.51it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.71       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.55      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.55      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 81.85it/s]


              precision    recall  f1-score   support

           0       0.54      1.00      0.70       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 82.99it/s]


              precision    recall  f1-score   support

           0       0.54      1.00      0.70       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 82.66it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.71       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.55      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.55      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 82.19it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.71       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.55      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.55      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 79.84it/s]


              precision    recall  f1-score   support

           0       0.54      1.00      0.70       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 82.77it/s]


              precision    recall  f1-score   support

           0       0.54      1.00      0.70       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 82.14it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.71       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 82.56it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.70       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 81.41it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.71       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       1.00      0.01      0.02       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.55      1000
   macro avg       0.15      0.10      0.07      1000
weighted avg       0.43      0.55      0.39      1000



100%|██████████| 110/110 [00:01<00:00, 81.51it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.71       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.55      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.55      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 81.74it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.71       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.55      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.55      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 85.14it/s]


              precision    recall  f1-score   support

           0       0.54      1.00      0.70       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 85.26it/s]


              precision    recall  f1-score   support

           0       0.54      1.00      0.70       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 85.20it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.71       545
           1       1.00      0.01      0.01       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.55      1000
   macro avg       0.15      0.10      0.07      1000
weighted avg       0.44      0.55      0.39      1000



100%|██████████| 110/110 [00:01<00:00, 84.47it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.71       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.55      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.55      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 81.44it/s]


              precision    recall  f1-score   support

           0       0.55      1.00      0.71       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 83.58it/s]


              precision    recall  f1-score   support

           0       0.54      1.00      0.70       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000



100%|██████████| 110/110 [00:01<00:00, 85.08it/s]

              precision    recall  f1-score   support

           0       0.54      1.00      0.70       545
           1       0.00      0.00      0.00       145
           2       0.00      0.00      0.00        82
           3       0.00      0.00      0.00        20
           4       0.00      0.00      0.00        15
           5       0.00      0.00      0.00        10
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00       128
           8       0.00      0.00      0.00        36
           9       0.00      0.00      0.00        17

    accuracy                           0.54      1000
   macro avg       0.05      0.10      0.07      1000
weighted avg       0.30      0.54      0.38      1000






In [None]:
# Results
# RNN + train_1_arabert.pkl --> 80.4% dev accuracy
# RNN + train_2_arabert.pkl --> 71.0% dev accuracy
# RNN + train_3_arabert.pkl --> 7.2% dev accuracy
# RNN + train_4_arabert.pkl --> 28.9% dev accuracy

In [None]:
# # train the category_model
# category_model = RNN(768, 512, 4, 10).to(device)
# criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(category_model.parameters(), lr=1e-3)
# category_model.train()
# for epoch in range(10):
#     for i, (embedding, stance, category) in enumerate(tqdm(train_loader)):
#         embedding = embedding.to(device)
#         stance = stance.to(device)
#         category = category.to(device)

#         outputs = category_model(embedding)
#         loss = criterion(outputs, category)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         # calculate the total epoch accuracy
#         _, predicted = torch.max(outputs.data, 1)
#         total = category.size(0)
#         correct = (predicted == category).sum().item()
#         accuracy = correct / total

#         if (i+1) % len(train_loader) == 0:
#             print(f'Epoch [{epoch+1}/{10}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
#             print(f'Epoch [{epoch+1}/{10}], Step [{i+1}/{len(train_loader)}], Accuracy: {accuracy:.4f}')

In [None]:
# # inference mode
# category_model.eval()

# # load the dev set
# dev_dataset = ArabertDataset('output/dev_1_arabert.pkl')
# dev_loader = torch.utils.data.DataLoader(dev_dataset, batch_size=64, shuffle=True)

# # get accuracy on development set
# with torch.no_grad():
#     correct = 0
#     total = 0
#     for embedding, stance, category in dev_loader:
#         embedding = embedding.to(device)
#         stance = stance.to(device)
#         category = category.to(device)

#         outputs = category_model(embedding)
#         _, predicted = torch.max(outputs.data, 1)
#         total += category.size(0)
#         correct += (predicted == category).sum().item()
#     print(f'Accuracy of the model on the dev set: {100 * correct / total}%')


In [None]:
# Results
# RNN + train_1_arabert.pkl --> 54.5% dev accuracy
# RNN + train_2_arabert.pkl --> 54.4% dev accuracy
# RNN + train_3_arabert.pkl --> 7.2% dev accuracy
# RNN + train_4_arabert.pkl --> 54.5% dev accuracy