In [24]:
from dataset2 import CommentClassificationDataset, get_dataloaders, get_datasets,split_train_valid
from train import train_model, validate, get_device, find_avg_sentence_length, find_longest_length, find_word_frequency, word2int, filter_low_freq_words
from model import MultiLabelEncoderClassifier
from tokenizer import Tokenizer
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm

In [2]:
device = get_device()

In [3]:
train_file_path = "/content/train.csv"
test_file_path = "/content/test.csv"

In [4]:
train_data, valid_data = split_train_valid(train_data_path=train_file_path)

In [5]:
test_data = pd.read_csv(test_file_path)

In [6]:
train_data[1110:1120]

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
1110,2dc30f5537e35a87,"""\n\n lol? \n\nI think that there are too many...",0,0,0,0,0,0
1111,b34ad33c8532b713,"In broad strokes, this looks like the right ki...",0,0,0,0,0,0
1112,b8bc7bae97bf1807,"Sign on what? If AIM, no. And can you host the...",0,0,0,0,0,0
1113,e599cfd7b4a0e9f4,hello.. i have not removed anything so it sho...,0,0,0,0,0,0
1114,42f5912c2b77c3a1,Some Say She is a little mothyer fukin bitchy,1,0,1,0,1,0
1115,3c5726bdf7fa9f9a,"He wrote two operas one early (1934, withdraw...",0,0,0,0,0,0
1116,5aa2fa96591e9d29,"""\nOK two things you might want to look at The...",0,0,0,0,0,0
1117,a1a9395e3c7e8de5,"""==DRV on Gaia series==\nI have speedily close...",0,0,0,0,0,0
1118,ed82f77e8000952a,""":::You grossly misrepresented the Duncan sour...",0,0,0,0,0,0
1119,ab78a03c13c51cd2,"Again making a claim with no evidence, you are...",0,0,0,0,0,0


In [7]:
tokenizer = Tokenizer()

In [8]:
data_list = train_data["comment_text"].tolist()

In [9]:
input_tokens = [tokenizer.tokenize(comment) for comment in data_list]
input_words = find_word_frequency(input_tokens)

In [10]:
filtered_tokens = filter_low_freq_words(input_words, min_freq=8)
num_words = len(filtered_tokens)

In [11]:
filtered_tokens[:10]

[('.', 536202),
 (' the', 435250),
 (',', 420244),
 (' to', 266413),
 ('\n', 252552),
 (' of', 201425),
 (' and', 197581),
 (' a', 191223),
 (' you', 187444),
 (' i', 183965)]

In [12]:
vocab = word2int(filtered_tokens, num_words=num_words)

In [13]:
vocab

{'.': 1,
 ' the': 2,
 ',': 3,
 ' to': 4,
 '\n': 5,
 ' of': 6,
 ' and': 7,
 ' a': 8,
 ' you': 9,
 ' i': 10,
 ' ': 11,
 ' is': 12,
 ' that': 13,
 ' it': 14,
 ' in': 15,
 ' for': 16,
 ' not': 17,
 ' this': 18,
 ' on': 19,
 ' be': 20,
 ' as': 21,
 ' have': 22,
 ' are': 23,
 "'s": 24,
 ' your': 25,
 ' with': 26,
 '?': 27,
 ' article': 28,
 "'t": 29,
 ' was': 30,
 ' if': 31,
 ' or': 32,
 ' but': 33,
 ' page': 34,
 ' \n': 35,
 ' my': 36,
 ' an': 37,
 ' from': 38,
 ' by': 39,
 ' at': 40,
 ' do': 41,
 ' can': 42,
 '!': 43,
 ' wikipedia': 44,
 ' me': 45,
 ' about': 46,
 'i': 47,
 ' so': 48,
 ' there': 49,
 ' what': 50,
 ' has': 51,
 ' all': 52,
 ' talk': 53,
 ' will': 54,
 ' would': 55,
 ' they': 56,
 ' one': 57,
 ' like': 58,
 ' he': 59,
 ' just': 60,
 ' no': 61,
 ' been': 62,
 ' which': 63,
 ' any': 64,
 ' please': 65,
 ' we': 66,
 ' should': 67,
 ' more': 68,
 ' don': 69,
 '  ': 70,
 ' other': 71,
 ' some': 72,
 ' who': 73,
 ' here': 74,
 ' see': 75,
 ' think': 76,
 ' his': 77,
 '\n\n': 78,
 

In [14]:
tokenizer.change_vocab(vocab)

In [15]:
tokenizer.save_vocab(file_path="/content/vocab1.json")

In [16]:
tokenizer.load_vocab(file_path="/content/vocab1.json")

In [17]:
vocab_size = len(tokenizer.vocab)
vocab_size

31852

In [18]:
dataset_train, dataset_valid, dataset_test = get_datasets(train_data, valid_data, test_data, tokenizer, max_len=339)
train_loader, valid_loader, test_loader = get_dataloaders(dataset_train, dataset_valid, dataset_test, BATCH_SIZE=8)

Number of training samples: 143614
Number of validation samples: 15957
Number of test samples: 153164




In [19]:
model = MultiLabelEncoderClassifier(vocab_size=vocab_size, embed_dim=256, num_layers=4, num_heads=8, num_labels=6).to(device)
print(model)

MultiLabelEncoderClassifier(
  (emb): Embedding(31852, 256)
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
    (linear1): Linear(in_features=256, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=256, bias=True)
    (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=2048, bias=True)
        (dropout): 

In [20]:
# Define the optimizer and loss function
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.BCEWithLogitsLoss()

In [21]:
best_valid_loss = float('inf')
epochs = 5

for epoch in range(epochs):
    train_loss, train_acc = train_model(model, train_loader, optimizer, criterion, device)
    valid_loss, valid_acc = validate(model, valid_loader, criterion, device)

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc}")
    print(f"Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_acc}")

    # If validation set loss is lower, save the best model
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "best_model.pth")
        print(f"Best model saved at epoch {epoch+1}")

Training


100%|██████████| 17952/17952 [14:02<00:00, 21.30it/s]

Validation



100%|██████████| 1995/1995 [00:34<00:00, 58.44it/s]

Epoch 1/5
Train Loss: 0.6932, Train Acc: [90.43477655381787, 99.00566797108917, 94.72614090548275, 99.68874900775691, 95.07081482306738, 99.11011461278149]
Valid Loss: 0.6932, Valid Acc: [90.2111925800589, 98.94717052077459, 94.51651312903428, 99.79319420943786, 94.96145892084978, 99.16650999561321]
Best model saved at epoch 1
Training



100%|██████████| 17952/17952 [14:02<00:00, 21.30it/s]

Validation



100%|██████████| 1995/1995 [00:33<00:00, 59.56it/s]


Epoch 2/5
Train Loss: 0.6931, Train Acc: [90.43825810854095, 99.00636428203379, 94.72614090548275, 99.69014162964613, 95.07499268873508, 99.11429247844917]
Valid Loss: 0.6931, Valid Acc: [90.2111925800589, 98.94717052077459, 94.51651312903428, 99.79319420943786, 94.96145892084978, 99.16650999561321]
Best model saved at epoch 2
Training


100%|██████████| 17952/17952 [14:02<00:00, 21.32it/s]

Validation



100%|██████████| 1995/1995 [00:33<00:00, 59.72it/s]


Epoch 3/5
Train Loss: 0.6931, Train Acc: [90.43825810854095, 99.00636428203379, 94.72614090548275, 99.69014162964613, 95.07499268873508, 99.11429247844917]
Valid Loss: 0.6931, Valid Acc: [90.2111925800589, 98.94717052077459, 94.51651312903428, 99.79319420943786, 94.96145892084978, 99.16650999561321]
Best model saved at epoch 3
Training


100%|██████████| 17952/17952 [14:03<00:00, 21.29it/s]

Validation



100%|██████████| 1995/1995 [00:33<00:00, 59.78it/s]


Epoch 4/5
Train Loss: 0.6931, Train Acc: [90.43825810854095, 99.00636428203379, 94.72614090548275, 99.69014162964613, 95.07499268873508, 99.11429247844917]
Valid Loss: 0.6931, Valid Acc: [90.2111925800589, 98.94717052077459, 94.51651312903428, 99.79319420943786, 94.96145892084978, 99.16650999561321]
Best model saved at epoch 4
Training


100%|██████████| 17952/17952 [14:03<00:00, 21.29it/s]

Validation



100%|██████████| 1995/1995 [00:32<00:00, 61.02it/s]


Epoch 5/5
Train Loss: 0.6931, Train Acc: [90.43825810854095, 99.00636428203379, 94.72614090548275, 99.69014162964613, 95.07499268873508, 99.11429247844917]
Valid Loss: 0.6931, Valid Acc: [90.2111925800589, 98.94717052077459, 94.51651312903428, 99.79319420943786, 94.96145892084978, 99.16650999561321]
Best model saved at epoch 5


In [47]:
def model_predict(model, testloader, device):
    model.eval()
    print('Testing')

    all_predictions = []
    with torch.no_grad():
        for i, data in tqdm(enumerate(testloader), total=len(testloader)):
            inputs = data['input']
            inputs = inputs.to(device)

            # Forward pass
            outputs = model(inputs)

            predictions = (outputs >= 0.5).int()
            all_predictions.append(predictions.cpu())

    all_predictions = torch.cat(all_predictions, dim=0)
    return all_predictions

In [48]:
predict = model_predict(model, test_loader, device)

Testing


100%|██████████| 19146/19146 [04:19<00:00, 73.77it/s]


In [50]:
len(predict)

153164

In [52]:
predict[110:120]

tensor([[0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0]], dtype=torch.int32)

In [27]:
test_labels = pd.read_csv("/content/test_labels.csv")

In [33]:
test_labels

Unnamed: 0,id,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,00001cee341fdb12,-1,-1,-1,-1,-1,-1
1,0000247867823ef7,-1,-1,-1,-1,-1,-1
2,00013b17ad220c46,-1,-1,-1,-1,-1,-1
3,00017563c3f7919a,-1,-1,-1,-1,-1,-1
4,00017695ad8997eb,-1,-1,-1,-1,-1,-1
...,...,...,...,...,...,...,...
153159,fffcd0960ee309b5,-1,-1,-1,-1,-1,-1
153160,fffd7a9a6eb32c16,-1,-1,-1,-1,-1,-1
153161,fffda9e8d6fafa9e,-1,-1,-1,-1,-1,-1
153162,fffe8f1340a79fc2,-1,-1,-1,-1,-1,-1


In [34]:
num_labels = test_labels.shape[1]
num_labels

7

In [36]:
labels_list = test_labels.values.tolist()

In [58]:
labels_list[0]

['00001cee341fdb12', -1, -1, -1, -1, -1, -1]

In [63]:
def cul_acc(labels, predctions):
    labels = labels.values.tolist()
    num_labels = len(labels[0]) - 1
    correct_labels = [0] * num_labels
    total_labels = 0
    for x, y in zip(labels, predctions):
        if x[1] == -1:
            continue
        else:
            total_labels += 1
            for i in range(num_labels):
                if x[i+1] == y[i]:
                    correct_labels[i] += 1

    return [t / total_labels for t in correct_labels]

In [64]:
test_acc= cul_acc(test_labels, predict)
test_acc

[0.9048110287911469,
 0.9942636531307637,
 0.9423082934758823,
 0.9967019913095126,
 0.9464347119322267,
 0.9888711744662227]