## **Mount Drive**

In [1]:
from google.colab import drive
drive.mount("/content/gdrive")

Mounted at /content/gdrive


In [2]:
cd /content/gdrive/MyDrive/VIT/Tamil Argumentation

/content/gdrive/MyDrive/VIT/Tamil Argumentation


## **Install**

In [3]:
pip install transformers nltk

Collecting transformers
  Downloading transformers-4.33.2-py3-none-any.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m56.1 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers)
  Downloading huggingface_hub-0.17.3-py3-none-any.whl (295 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m31.2 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m117.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m79.2 MB/s[0m eta [36m0:00:

In [4]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

## **Import Libraries**

In [5]:
import pandas as pd
import numpy as np

from tqdm import tqdm

from copy import deepcopy

from sklearn import metrics
from sklearn.model_selection import KFold

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer,AutoModel,BertTokenizer

from nltk.tokenize import word_tokenize

## **Import Dataset**

In [6]:
df = pd.read_excel("/content/gdrive/MyDrive/VIT/Tamil Argumentation/Twitter Comment Dataset.xlsx")

In [7]:
df.head()

Unnamed: 0,S No,Tweet,Date of Tweet,Topic,Parent Tweet,Language,Quality,Stance,Argument,Comment,Responding to Tone,Discussing Writer Characteristics,Remark,Relevancy
0,1,"Bro imagine today is Friday , big star movie i...",2018-05-22,Jalikattu - Ban or Allow,"And tamil people, jalikattu maadu for money an...",ENGLISH,Med,Undetermined,0,1,0,0,0,Relevant
1,2,Dei unnoda akkarai TN mela not on others and w...,2018-05-22,Jalikattu - Ban or Allow,"And tamil people, jalikattu maadu for money an...",ENGLISH,Med,Against,0,1,0,1,0,Relevant
2,3,En ninga ivara matum mention panuringa naraiya...,2018-05-22,Jalikattu - Ban or Allow,"And tamil people, jalikattu maadu for money an...",CODE-MIXED,Med,For,0,1,0,0,0,Relevant
3,4,What is happening in Thoothukudi is totally no...,2018-05-22,Jalikattu - Ban or Allow,"And tamil people, jalikattu maadu for money an...",ENGLISH,High,Against,1,1,0,0,0,Relevant
4,5,Ungaluku Sterlite protest prachanaya illa Bala...,2018-05-22,Jalikattu - Ban or Allow,"And tamil people, jalikattu maadu for money an...",CODE-MIXED,Med,Undetermined,0,0,1,0,0,Relevant


In [8]:
df.columns

Index(['S No', 'Tweet', 'Date of Tweet', 'Topic', 'Parent Tweet', 'Language',
       'Quality', 'Stance', 'Argument', 'Comment', 'Responding to Tone',
       'Discussing Writer Characteristics', 'Remark', 'Relevancy'],
      dtype='object')

## **Load Text and Labels**

In [9]:
text = df["Tweet"].to_numpy()
pt = df["Parent Tweet"].to_numpy()
topic = df["Topic"].to_numpy()

Language_label = df["Language"].to_numpy()
Stance_label = df["Stance"].to_numpy()
Quality_label = df["Quality"].to_numpy()
Argument_label = df["Argument"].to_numpy()
Comment_label = df["Comment"].to_numpy()
Writer_label = df["Discussing Writer Characteristics"].to_numpy()
Tone_label = df["Responding to Tone"].to_numpy()
Remark_label = df["Remark"].to_numpy()
Relevancy_label = df["Relevancy"].to_numpy()

## **Label Encoding**

In [10]:
encode_dict_quality = {
    "High": np.array([1, 0, 0]),
    "Med": np.array([0, 1, 0]),
    "Low": np.array([0, 0, 1]),
}

encode_dict_language = {
    "ENGLISH": np.array([1, 0, 0]),
    "TAMIL": np.array([0, 1, 0]),
    "CODE-MIXED": np.array([0, 0, 1]),
}

encode_dict_stance = {
    "For": np.array([1, 0, 0]),
    "Against": np.array([0, 1, 0]),
    "Undetermined": np.array([0, 0, 1]),
}

encode_dict = {
    0: np.array([1, 0]),
    1: np.array([0, 1]),
}

encode_dict_relevancy = {
    "Relevant": np.array([1, 0]),
    "Irrelevant": np.array([0, 1]),
}

In [11]:
Language_label = np.array([encode_dict_language[label] for label in Language_label])
Quality_label = np.array([encode_dict_quality[label] for label in Quality_label])
Argument_label = np.array([encode_dict[label] for label in Argument_label])
Comment_label = np.array([encode_dict[label] for label in Comment_label])
Writer_label = np.array([encode_dict[label] for label in Writer_label])
Tone_label = np.array([encode_dict[label] for label in Tone_label])
Remark_label = np.array([encode_dict[label] for label in Remark_label])
Relevancy_label = np.array([encode_dict_relevancy[label] for label in Relevancy_label])
Stance_label = np.array([encode_dict_stance[label] for label in Stance_label])

## **Pre-Config for mBERT**

In [12]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

MAX_LEN = np.max([(len(x)+len(y)+1) for x,y in zip(text,pt)])
MAX_LEN = np.min([MAX_LEN, 510])

BATCH_SIZE = 32
LEARNING_RATE = 1e-1

In [13]:
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')

Downloading (…)lve/main/config.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

Downloading (…)tencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

## **Build Dataset for mBERT**

In [14]:
class ModelDataset(Dataset):
    def __init__(self, X, PT, topic, y, tokenizer, max_len):
        self.max_len = max_len
        self.text = X
        self.topic = topic
        self.pt = PT
        self.tokenizer = tokenizer
        self.targets = y

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        text = self.text[index]
        pt = self.pt[index]
        topic = self.topic[index]

        concat_text = f"{topic}. {pt}"

        inputs = self.tokenizer(
            text,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True)

        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs['token_type_ids']

        return {
            'ids': torch.tensor(ids, dtype=torch.long).to(device),
            'mask': torch.tensor(mask, dtype=torch.long).to(device),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long).to(device),
            'targets': torch.tensor(self.targets[index], dtype=torch.float).to(device),
        }

## **Build Model**

In [24]:
class CustomModel(nn.Module):

    def __init__(self):
        super(CustomModel, self).__init__()
        self.bert = AutoModel.from_pretrained('xlm-roberta-large')

        for param in self.bert.parameters():
            param.requires_grad = False

        self.out_layer = nn.Linear(1024, 3)

    def forward(self, ids, mask, token_type_ids):
        _, features = self.bert(
            ids, token_type_ids=token_type_ids,
            attention_mask=mask, return_dict=False
        )

        output = self.out_layer(features)

        return output

## **Train Model**

In [16]:
def train(epoch, model, train_loader, loss_fn, optimizer):

    model.train()

    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        ids = batch['ids'].to(device, dtype = torch.long)
        mask = batch['mask'].to(device, dtype = torch.long)
        token_type_ids = batch['token_type_ids'].to(device, dtype = torch.long)
        targets = batch['targets'].to(device, dtype = torch.float)

        outputs = model(ids, mask, token_type_ids)

        loss = loss_fn(outputs, targets)

        loss.backward()
        optimizer.step()


    print(f'Epoch: {epoch + 1}, Loss:  {loss.item()}')

In [17]:
def validation(data_loader, model):

    model.eval()
    targets = []
    outputs = []

    with torch.no_grad():

        for batch in data_loader:

            ids = batch['ids'].to(device, dtype = torch.long)
            mask = batch['mask'].to(device, dtype = torch.long)
            token_type_ids = batch['token_type_ids'].to(device, dtype = torch.long)
            batch_targets = batch['targets'].to(device, dtype = torch.float)

            batch_outputs = model(ids, mask, token_type_ids)

            targets.extend(batch_targets.cpu().numpy().tolist())
            outputs.extend(batch_outputs.cpu().numpy().tolist())

    return outputs, targets

# Model initialization

In [25]:
kf = KFold(n_splits=5)

model_targets = []
model_labels = []

model = CustomModel().to(device)

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
ksi = 0

for train_index, test_index in kf.split(text):

    text_train, text_test = text[train_index], text[test_index]
    pt_train, pt_test = pt[train_index], pt[test_index]
    topic_train, topic_test = topic[train_index], topic[test_index]
    labels_train, labels_test = Quality_label[train_index], Quality_label[test_index]

    train_data = ModelDataset(text_train, pt_train, topic_train, labels_train, tokenizer, MAX_LEN)
    test_data = ModelDataset(text_test, pt_test, topic_test, labels_test, tokenizer, MAX_LEN)

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

    best_score = -np.inf
    best_weights = None

    EPOCHS = 5

    print(f"Split   ---->    {ksi}")
    ksi+=1

    torch.save(model.state_dict(),"/content/gdrive/MyDrive/VIT/Tamil Argumentation/twitter_model_argument_w3i.pth")

    for epoch in range(EPOCHS):

        train(epoch, model, train_loader, loss_fn, optimizer)
        outputs, targets = validation(test_loader, model)

        print(outputs)
        print(" ")
        print(targets)

        targets = [np.argmax(x) for x in targets]
        outputs = [np.argmax(x) for x in outputs]

        score = metrics.f1_score(targets, outputs, average='weighted')

        if score > best_score:
            best_score = score
            best_weights = deepcopy(model.state_dict())

    model.load_state_dict(best_weights)
    torch.save(model.state_dict(),"/content/gdrive/MyDrive/VIT/Tamil Argumentation/twitter_model_argument_w3i.pth")

    outputs, targets = validation(test_loader, model)

    targets = [np.argmax(x) for x in targets]
    outputs = [np.argmax(x) for x in outputs]

    model_targets.extend(targets)
    model_labels.extend(outputs)

Split   ---->    0


100%|██████████| 34/34 [01:48<00:00,  3.18s/it]


Epoch: 1, Loss:  1.86967134475708
[[2.1826131343841553, 0.8460273742675781, -8.266489028930664], [0.13294443488121033, 1.4864377975463867, -7.731884479522705], [2.0893473625183105, 0.9235100746154785, -8.412938117980957], [3.6496565341949463, 0.7178024053573608, -9.590620040893555], [1.067246913909912, 1.8587608337402344, -9.010542869567871], [0.32825735211372375, 1.7714769840240479, -8.11626148223877], [0.9991947412490845, 1.4064373970031738, -8.100919723510742], [1.6864638328552246, 1.1849333047866821, -8.275257110595703], [1.6346983909606934, 0.6264225244522095, -7.35408878326416], [1.8790805339813232, 1.2136375904083252, -8.720536231994629], [1.8249820470809937, 0.8372930288314819, -8.647817611694336], [0.1165735051035881, 2.0016086101531982, -8.654379844665527], [0.9342275261878967, 1.3152744770050049, -8.043020248413086], [3.0170130729675293, 1.321244239807129, -10.134750366210938], [1.5941863059997559, 1.1292259693145752, -8.106264114379883], [2.751443386077881, 0.78596895933151

100%|██████████| 34/34 [01:58<00:00,  3.49s/it]


Epoch: 2, Loss:  0.9774911999702454
[[-10.12330150604248, 2.6109023094177246, -3.8482415676116943], [-13.05786418914795, 3.3790130615234375, -2.7519423961639404], [-11.703072547912598, 2.340658187866211, -2.541114330291748], [-9.189213752746582, 2.196193218231201, -4.494356155395508], [-12.215758323669434, 3.615398406982422, -3.9285056591033936], [-12.594670295715332, 3.832651138305664, -3.5448226928710938], [-11.546026229858398, 3.3478569984436035, -3.7291030883789062], [-10.902073860168457, 3.006666421890259, -3.714461326599121], [-10.471332550048828, 2.178392171859741, -2.9657657146453857], [-12.431539535522461, 2.761444330215454, -2.5574464797973633], [-12.501795768737793, 2.2991878986358643, -2.4404900074005127], [-13.566253662109375, 3.9745919704437256, -3.5654549598693848], [-12.7380952835083, 2.8395519256591797, -2.4718334674835205], [-10.40245532989502, 2.9799349308013916, -4.858643054962158], [-12.19206714630127, 2.625493049621582, -2.328030824661255], [-11.446743965148926, 2

100%|██████████| 34/34 [01:58<00:00,  3.49s/it]


Epoch: 3, Loss:  0.724330484867096
[[-2.9642789363861084, 1.813733458518982, -7.5500969886779785], [-5.8058247566223145, 2.0586211681365967, -6.149885654449463], [-3.766608953475952, 1.0682953596115112, -6.1944684982299805], [-1.6742931604385376, 1.4065477848052979, -8.024971008300781], [-4.853034496307373, 2.402371406555176, -7.276589393615723], [-5.554812908172607, 2.645049571990967, -7.03138542175293], [-4.538321495056152, 2.43827486038208, -7.399713039398193], [-3.7521398067474365, 2.1057896614074707, -7.408641338348389], [-3.6072094440460205, 1.2888318300247192, -6.690744400024414], [-4.325883388519287, 1.501339316368103, -6.2574357986450195], [-4.490044116973877, 0.9026466608047485, -5.774873733520508], [-6.419914722442627, 2.6510684490203857, -6.72913122177124], [-5.054781436920166, 1.3492212295532227, -5.9724249839782715], [-3.0333573818206787, 2.06846284866333, -8.173474311828613], [-4.309390068054199, 1.3916107416152954, -6.151628017425537], [-3.2635879516601562, 0.9649686813

100%|██████████| 34/34 [01:58<00:00,  3.49s/it]


Epoch: 4, Loss:  2.1821706295013428
[[-5.274359703063965, 5.980621814727783, -1.172497034072876], [-8.024287223815918, 6.316545009613037, 0.11056391894817352], [-6.0268378257751465, 4.774564266204834, -0.2049611359834671], [-4.250332832336426, 5.209606647491455, -2.079437255859375], [-7.20375919342041, 6.370400428771973, -1.13273286819458], [-7.749604225158691, 6.745360851287842, -0.6142783761024475], [-6.813478946685791, 6.6990580558776855, -0.9981051087379456], [-6.047542572021484, 6.175288677215576, -1.015675663948059], [-5.716564655303955, 5.643859386444092, -0.04308970645070076], [-6.604386329650879, 5.068131923675537, -0.4625924825668335], [-6.8602213859558105, 4.567379474639893, -0.011694509536027908], [-8.802384376525879, 6.676804065704346, -0.605187714099884], [-7.237304210662842, 5.341126918792725, 0.1497558206319809], [-5.755792617797852, 5.862400531768799, -2.52217960357666], [-6.478200435638428, 5.194299697875977, -0.09680034220218658], [-5.577139377593994, 4.4135460853576

100%|██████████| 34/34 [01:58<00:00,  3.49s/it]


Epoch: 5, Loss:  0.6421096920967102
[[-7.350135803222656, 3.5157978534698486, -3.7398219108581543], [-10.316859245300293, 4.145038604736328, -2.424607753753662], [-7.713814735412598, 2.472522497177124, -2.8434791564941406], [-5.747408866882324, 2.740723133087158, -4.27613639831543], [-9.336729049682617, 4.316552639007568, -3.307506799697876], [-10.135300636291504, 4.666316509246826, -3.034193515777588], [-9.085040092468262, 4.473447799682617, -3.5479326248168945], [-8.197060585021973, 3.782703161239624, -3.555819272994995], [-7.939848899841309, 3.2871572971343994, -3.1193182468414307], [-8.17988395690918, 2.80157470703125, -2.885906219482422], [-8.277424812316895, 2.2837798595428467, -2.346763849258423], [-10.985751152038574, 4.856649875640869, -2.822981119155884], [-9.29088306427002, 3.143974781036377, -2.57677960395813], [-7.21896505355835, 3.7989096641540527, -4.545895576477051], [-8.335445404052734, 2.954690456390381, -2.881932258605957], [-6.909828186035156, 1.9956547021865845, -3

100%|██████████| 34/34 [01:59<00:00,  3.51s/it]


Epoch: 1, Loss:  1.076370120048523
[[-3.7089149951934814, 9.174492835998535, 1.111815094947815], [-2.753838300704956, 8.312248229980469, 0.8939235806465149], [-3.9160654544830322, 9.678693771362305, 0.9450839757919312], [-5.492136478424072, 9.669410705566406, 1.1430546045303345], [-5.407599449157715, 10.42658805847168, 0.5656700134277344], [-4.629369258880615, 9.826030731201172, 0.7937548160552979], [-4.816736698150635, 8.41663646697998, 2.302821636199951], [-3.1956946849823, 8.700586318969727, 0.9140645265579224], [-2.766728162765503, 9.163942337036133, 0.6569986343383789], [-3.6870524883270264, 9.149633407592773, 1.3685358762741089], [-3.184901237487793, 9.230587005615234, 0.7336919903755188], [-5.4952473640441895, 10.083101272583008, 1.4369161128997803], [-4.275825500488281, 9.213438034057617, 0.2725071609020233], [-1.8318513631820679, 7.806940078735352, 0.11055521667003632], [-2.744551420211792, 8.534598350524902, 0.627199649810791], [-4.684131622314453, 9.439849853515625, 1.371051

100%|██████████| 34/34 [01:58<00:00,  3.48s/it]


Epoch: 2, Loss:  1.5401407480239868
[[-16.764942169189453, 1.2082173824310303, -6.441643238067627], [-15.082643508911133, 1.0674382448196411, -6.6787567138671875], [-17.172351837158203, 1.4217851161956787, -6.5828423500061035], [-18.64004898071289, 1.840468168258667, -6.470963001251221], [-18.735788345336914, 2.328502655029297, -7.052701473236084], [-17.893098831176758, 2.001520872116089, -6.979443073272705], [-17.75782012939453, 0.8106304407119751, -5.2812418937683105], [-15.747543334960938, 0.9916549324989319, -6.43856954574585], [-15.754266738891602, 1.3091816902160645, -7.050586700439453], [-17.006879806518555, 1.1426427364349365, -6.270026206970215], [-16.037513732910156, 1.4084525108337402, -6.9242448806762695], [-19.350549697875977, 1.8504470586776733, -6.1688761711120605], [-16.722980499267578, 1.6136218309402466, -7.171061038970947], [-13.176583290100098, 1.0810668468475342, -7.498267650604248], [-15.271506309509277, 1.259535789489746, -6.826411247253418], [-17.981849670410156

100%|██████████| 34/34 [01:58<00:00,  3.50s/it]


Epoch: 3, Loss:  2.2811081409454346
[[-9.721899032592773, 8.738327026367188, 3.1455376148223877], [-8.794689178466797, 7.470325469970703, 2.792914628982544], [-10.001372337341309, 9.253175735473633, 2.7992186546325684], [-11.508964538574219, 9.554261207580566, 3.638514757156372], [-11.556384086608887, 10.274189949035645, 2.629408597946167], [-10.766060829162598, 9.86863899230957, 2.6355297565460205], [-10.69267749786377, 7.551551818847656, 5.107273101806641], [-9.17305850982666, 7.726783275604248, 3.0585439205169678], [-8.896692276000977, 8.90875244140625, 1.952989935874939], [-9.797629356384277, 8.797014236450195, 3.3995237350463867], [-9.1932954788208, 8.751394271850586, 2.377136468887329], [-11.786687850952148, 10.029287338256836, 3.6973772048950195], [-10.255156517028809, 8.643953323364258, 2.5429527759552], [-7.712996006011963, 6.694910049438477, 1.1651427745819092], [-8.88786506652832, 7.59335994720459, 2.5466363430023193], [-10.847393035888672, 9.213390350341797, 3.6204018592834

100%|██████████| 34/34 [01:58<00:00,  3.50s/it]


Epoch: 4, Loss:  1.3577053546905518
[[-10.99296760559082, 0.397072434425354, -3.4239602088928223], [-9.657369613647461, -0.5537306666374207, -3.2689013481140137], [-11.420687675476074, 0.7468050122261047, -3.932683229446411], [-12.867535591125488, 1.1814143657684326, -3.0540924072265625], [-13.064717292785645, 1.8852226734161377, -4.371572494506836], [-12.241199493408203, 1.5551738739013672, -4.192292213439941], [-12.045987129211426, -0.5728477835655212, -1.5414843559265137], [-10.113299369812012, -0.4983702301979065, -3.2012131214141846], [-10.095377922058105, 0.5354838967323303, -4.568934440612793], [-11.22758674621582, 0.275205135345459, -3.0007801055908203], [-10.305493354797363, 0.39169609546661377, -4.142947196960449], [-13.56532096862793, 1.4902629852294922, -3.493473529815674], [-11.251694679260254, 0.7211552262306213, -4.334245681762695], [-8.0639066696167, -0.9044947624206543, -4.701113700866699], [-9.662291526794434, -0.4064345955848694, -3.66575288772583], [-12.397576332092

100%|██████████| 34/34 [01:59<00:00,  3.50s/it]


Epoch: 5, Loss:  0.8921444416046143
[[1.0689570903778076, 0.8103340268135071, -1.3678945302963257], [1.7860907316207886, -0.051256515085697174, -1.472464919090271], [0.5945842266082764, 1.1319199800491333, -1.6515424251556396], [-0.5746464729309082, 1.2416880130767822, -1.0650300979614258], [-0.9612016677856445, 2.0870282649993896, -2.1408486366271973], [-0.07716962695121765, 1.6781660318374634, -2.1274256706237793], [-0.02302449941635132, -0.11579600721597672, 0.11104031652212143], [1.4261491298675537, -0.06252846866846085, -1.22451651096344], [1.8582208156585693, 0.895742654800415, -2.4108493328094482], [0.9617627859115601, 0.5647225379943848, -0.800377607345581], [1.6266523599624634, 0.6615904569625854, -2.023301839828491], [-1.4009875059127808, 1.7595874071121216, -1.4063096046447754], [0.09304463863372803, 1.2806166410446167, -2.597132444381714], [2.7334742546081543, -0.03702641278505325, -3.1212542057037354], [1.9757936000823975, 0.21242058277130127, -1.8160014152526855], [-0.271

100%|██████████| 34/34 [01:59<00:00,  3.51s/it]


Epoch: 1, Loss:  0.5530247688293457
[[-5.158647060394287, -0.43238991498947144, -5.090697765350342], [-6.951864242553711, 0.7814956903457642, -5.32151985168457], [-4.220575332641602, -0.8178426027297974, -5.343008995056152], [-6.569844722747803, 0.21568305790424347, -5.314706325531006], [-2.681157350540161, -1.0175867080688477, -6.003741264343262], [-5.326478958129883, -0.6908144354820251, -5.046372890472412], [-4.012269496917725, -1.2223855257034302, -4.858967304229736], [-5.957605361938477, -0.5117477178573608, -4.39003849029541], [-3.0584945678710938, -2.040773868560791, -4.2745513916015625], [-3.73649001121521, -1.299194097518921, -4.978994846343994], [-3.813781261444092, -1.914046287536621, -3.7977499961853027], [-4.181523323059082, -1.2116281986236572, -4.9075398445129395], [-3.816814661026001, -1.4963738918304443, -4.806889057159424], [-4.34112548828125, -1.3131449222564697, -4.491595268249512], [-4.318157196044922, -1.24001944065094, -4.705649375915527], [-4.752966403961182, -1

100%|██████████| 34/34 [01:58<00:00,  3.50s/it]


Epoch: 2, Loss:  1.2528316974639893
[[-3.6548662185668945, 2.780264139175415, -6.3885626792907715], [-5.071595191955566, 3.036573886871338, -6.837707996368408], [-2.858099937438965, 2.041968822479248, -6.388732433319092], [-5.174205780029297, 3.6358072757720947, -6.555550575256348], [-0.8396937251091003, 0.9842356443405151, -7.324494361877441], [-3.926753044128418, 2.589475154876709, -6.304854393005371], [-2.715972423553467, 2.317816972732544, -6.066396713256836], [-4.73021936416626, 2.758422374725342, -5.518999099731445], [-1.6222501993179321, 0.9891412258148193, -5.5916266441345215], [-2.3740382194519043, 2.205410957336426, -6.202207565307617], [-2.3376851081848145, 1.1645140647888184, -5.231206893920898], [-2.2357072830200195, 1.3588566780090332, -6.563809871673584], [-2.1902992725372314, 1.506534457206726, -6.1645355224609375], [-2.5048277378082275, 1.241504430770874, -5.880924224853516], [-2.6728644371032715, 1.7107478380203247, -6.11234188079834], [-3.2731056213378906, 1.50273478

100%|██████████| 34/34 [01:58<00:00,  3.49s/it]


Epoch: 3, Loss:  0.8336901664733887
[[-1.1746220588684082, 1.3655128479003906, -4.571236610412598], [-2.3798093795776367, 2.0900044441223145, -5.123839378356934], [-0.18356603384017944, 0.7928826808929443, -4.55627965927124], [-2.697526693344116, 1.9749605655670166, -4.829776763916016], [1.722444772720337, 0.6000544428825378, -5.591687202453613], [-1.485947847366333, 1.1450705528259277, -4.556193828582764], [-0.3360559344291687, 0.8352972269058228, -4.226964950561523], [-2.0686540603637695, 1.0939651727676392, -3.8066153526306152], [0.8081074357032776, 0.0034512951970100403, -3.9624457359313965], [-0.060165755450725555, 0.7091835141181946, -4.3633036613464355], [0.03382677584886551, 0.13379159569740295, -3.5908305644989014], [0.06936836987733841, 0.7146912813186646, -4.721433162689209], [0.14390403032302856, 0.44560807943344116, -4.329423427581787], [-0.07628457993268967, 0.5378907918930054, -4.090448379516602], [-0.3969614505767822, 0.6455380916595459, -4.259235382080078], [-0.5754514

100%|██████████| 34/34 [01:58<00:00,  3.49s/it]


Epoch: 4, Loss:  0.7144038081169128
[[-1.0808467864990234, 1.3183135986328125, -6.62205171585083], [-2.3072509765625, 2.0189568996429443, -7.144656181335449], [-0.12045341730117798, 0.7262986302375793, -6.466419219970703], [-2.4532697200775146, 1.8582183122634888, -6.8256683349609375], [1.465662956237793, 0.8199800848960876, -7.701786041259766], [-1.265669584274292, 1.111952543258667, -6.540640830993652], [-0.25562939047813416, 0.8126865029335022, -6.24167537689209], [-1.962285041809082, 0.9375610947608948, -5.559663772583008], [0.41034454107284546, 0.11659538745880127, -5.8720550537109375], [0.015699103474617004, 0.6890451312065125, -6.416134357452393], [-0.42124760150909424, 0.22604458034038544, -5.428640365600586], [-0.007272064685821533, 0.8648541569709778, -6.850918769836426], [0.16791635751724243, 0.4919477105140686, -6.4053802490234375], [0.018553435802459717, 0.5040580034255981, -6.272491455078125], [-0.3740316331386566, 0.7307429909706116, -6.346887588500977], [-1.092440843582

100%|██████████| 34/34 [01:58<00:00,  3.49s/it]


Epoch: 5, Loss:  0.8873475193977356
[[-10.324965476989746, 5.50314998626709, -2.8400802612304688], [-11.361977577209473, 6.049489974975586, -3.0098776817321777], [-9.960371971130371, 4.986174583435059, -2.625067710876465], [-11.818450927734375, 6.089203834533691, -2.994351863861084], [-8.575085639953613, 5.126345634460449, -3.684035301208496], [-10.760409355163574, 5.4582624435424805, -2.8274755477905273], [-9.886988639831543, 5.1625657081604, -2.635873317718506], [-11.646763801574707, 5.197687149047852, -1.673585057258606], [-9.285537719726562, 4.472921371459961, -2.15073299407959], [-9.461642265319824, 4.95373010635376, -2.73876953125], [-9.825362205505371, 4.500819206237793, -1.6971300840377808], [-8.82502269744873, 4.8963303565979, -2.973766803741455], [-9.064178466796875, 4.699497222900391, -2.650545597076416], [-8.767936706542969, 4.422721862792969, -2.337493419647217], [-9.415899276733398, 4.818769454956055, -2.523122787475586], [-10.498608589172363, 4.659365653991699, -2.076454

100%|██████████| 34/34 [01:59<00:00,  3.51s/it]


Epoch: 1, Loss:  0.8877132534980774
[[-7.723280429840088, 0.6182026863098145, 2.6843247413635254], [-7.550929069519043, -0.007157064974308014, 4.1179046630859375], [-4.170421123504639, -1.1105200052261353, 2.5002357959747314], [-6.943990707397461, 0.02498074620962143, 3.6554133892059326], [-3.377582550048828, -2.307241678237915, 3.735027551651001], [-4.202070713043213, -1.4421499967575073, 3.1750500202178955], [-5.774460315704346, -0.7788528203964233, 4.156766891479492], [-7.329857349395752, 0.07842753082513809, 3.7870166301727295], [-5.985122203826904, -0.56024169921875, 3.424626350402832], [-7.667839527130127, -0.08716810494661331, 4.307529926300049], [-6.041531085968018, -1.5290143489837646, 4.10384464263916], [-6.0943217277526855, -1.701654076576233, 5.411434173583984], [-4.427343845367432, -1.3074458837509155, 3.295034408569336], [-7.445888042449951, 0.034468166530132294, 3.862347364425659], [-7.028270244598389, -0.34054237604141235, 4.274022102355957], [-6.6850056648254395, -0.37

100%|██████████| 34/34 [01:58<00:00,  3.49s/it]


Epoch: 2, Loss:  0.9696125388145447
[[-2.538653612136841, 3.9853384494781494, -3.665158987045288], [-2.3963797092437744, 2.7979986667633057, -2.7133336067199707], [1.693053960800171, 2.6312332153320312, -3.771074056625366], [-1.6860144138336182, 2.9447810649871826, -3.052388906478882], [2.520131826400757, 1.5407861471176147, -2.6004154682159424], [1.5041751861572266, 2.1439874172210693, -3.3495028018951416], [-0.36762571334838867, 2.1559219360351562, -2.7937381267547607], [-2.177009105682373, 2.8661580085754395, -3.037511110305786], [-0.5559577345848083, 2.6385092735290527, -3.2160861492156982], [-2.556565523147583, 2.72011137008667, -2.5011303424835205], [-1.031796932220459, 1.9313280582427979, -1.722425103187561], [-0.9393752217292786, 1.0488048791885376, -1.3930903673171997], [1.4798758029937744, 2.1184935569763184, -3.1792800426483154], [-2.3121001720428467, 2.976998805999756, -2.749305486679077], [-1.7589950561523438, 2.4860353469848633, -2.621533155441284], [-1.4219417572021484, 

100%|██████████| 34/34 [01:58<00:00,  3.49s/it]


Epoch: 3, Loss:  1.775477409362793
[[-4.132542610168457, -1.8134340047836304, -9.812769889831543], [-2.0754854679107666, -1.7627654075622559, -9.494144439697266], [-0.3774707615375519, -3.968738317489624, -10.042279243469238], [-1.7352955341339111, -1.9209636449813843, -9.667522430419922], [0.3651609420776367, -5.26346492767334, -8.6886625289917], [0.2544218599796295, -3.9623777866363525, -9.726926803588867], [0.040692299604415894, -2.505114793777466, -9.620013236999512], [-1.709071397781372, -1.543515682220459, -9.923195838928223], [-1.2109627723693848, -2.7453994750976562, -9.745865821838379], [-2.3551125526428223, -1.818060278892517, -9.2741117477417], [-3.831254005432129, -4.1891069412231445, -7.657333850860596], [-0.20459172129631042, -3.1966588497161865, -8.27101993560791], [-0.1500014066696167, -4.131588459014893, -9.489155769348145], [-2.510007858276367, -1.8474957942962646, -9.390620231628418], [-1.2864465713500977, -2.0652194023132324, -9.41884994506836], [-0.4316970705986023

100%|██████████| 34/34 [01:58<00:00,  3.49s/it]


Epoch: 4, Loss:  1.2534778118133545
[[-11.707956314086914, -2.032419443130493, 5.332268238067627], [-11.238020896911621, -2.3533411026000977, 4.863643169403076], [-7.245011329650879, -3.9897656440734863, 4.865900993347168], [-10.600005149841309, -2.342970371246338, 4.9916911125183105], [-6.4657464027404785, -5.3633131980896, 6.639664173126221], [-7.380741119384766, -4.155179500579834, 5.222236156463623], [-9.154732704162598, -3.33585786819458, 4.716361045837402], [-10.89809799194336, -2.2505156993865967, 4.408601760864258], [-9.520471572875977, -3.17768931388855, 4.940711975097656], [-11.452590942382812, -2.4036471843719482, 5.155589580535889], [-10.232622146606445, -4.005820274353027, 7.783448696136475], [-9.616756439208984, -4.293259620666504, 6.009960651397705], [-7.646827697753906, -4.0417327880859375, 5.483204364776611], [-11.131256103515625, -2.3627333641052246, 5.1954145431518555], [-10.49173355102539, -2.6974661350250244, 4.959876537322998], [-10.043014526367188, -2.74195480346

100%|██████████| 34/34 [01:59<00:00,  3.51s/it]


Epoch: 5, Loss:  1.4494280815124512
[[-0.23728498816490173, 0.5586630702018738, -8.0400972366333], [0.311163991689682, -1.2647958993911743, -6.830042839050293], [4.412856101989746, -0.7544541954994202, -9.418717384338379], [0.7973266839981079, -1.0817662477493286, -6.797896862030029], [4.8317742347717285, -2.1819281578063965, -7.27051305770874], [3.9235219955444336, -1.6993423700332642, -7.9434614181518555], [2.3653111457824707, -2.4613749980926514, -6.873921871185303], [0.5894418954849243, -1.4046767950057983, -7.0948662757873535], [1.9835885763168335, -1.281333565711975, -7.746074199676514], [0.05652293562889099, -1.3441540002822876, -6.48832893371582], [1.2020097970962524, -0.7316015362739563, -6.197973251342773], [2.010119915008545, -4.259517669677734, -4.948517799377441], [3.605623722076416, -1.1210187673568726, -8.014236450195312], [0.3469788730144501, -1.0408793687820435, -6.766452789306641], [0.8994871377944946, -1.8718206882476807, -6.41408634185791], [1.4085735082626343, -2.4

100%|██████████| 34/34 [01:59<00:00,  3.51s/it]


Epoch: 1, Loss:  0.4322931468486786
[[-9.045831680297852, 2.7728450298309326, -4.644202709197998], [-8.161139488220215, 2.045133590698242, -4.617364883422852], [-11.893778800964355, 2.58022141456604, -2.6835238933563232], [-10.128103256225586, 2.4807863235473633, -3.4903147220611572], [-10.494205474853516, 2.972964286804199, -4.1407904624938965], [-10.066165924072266, 2.7516579627990723, -3.486701726913452], [-10.560613632202148, 3.075793504714966, -3.9307332038879395], [-13.065818786621094, 3.384671926498413, -3.5996296405792236], [-9.614733695983887, 2.2719459533691406, -3.3345282077789307], [-10.012767791748047, 2.894207239151001, -4.134062767028809], [-9.43439769744873, 2.164541721343994, -3.980928421020508], [-9.882394790649414, 2.1911702156066895, -2.819204568862915], [-9.095147132873535, 2.4505984783172607, -4.023120880126953], [-11.47761344909668, 3.0118212699890137, -3.336538076400757], [-6.528745174407959, 1.5937824249267578, -5.313648700714111], [-8.407336235046387, 2.254661

100%|██████████| 34/34 [01:58<00:00,  3.48s/it]


Epoch: 2, Loss:  0.7322126030921936
[[-2.472571611404419, 2.438709020614624, -1.7440590858459473], [-1.81304132938385, 1.9176322221755981, -1.7160751819610596], [-4.8065643310546875, 2.2065436840057373, 0.5003914833068848], [-3.298182487487793, 2.2643535137176514, -0.5644938945770264], [-3.6649789810180664, 2.8440792560577393, -1.14921236038208], [-3.035533905029297, 2.527048349380493, -0.6171132326126099], [-3.6125805377960205, 2.8414554595947266, -0.9565404653549194], [-7.21363639831543, 2.785344362258911, -0.09914661943912506], [-2.663844108581543, 1.9994938373565674, -0.4495460093021393], [-3.2638192176818848, 2.540938377380371, -1.0805696249008179], [-3.1836109161376953, 2.046156644821167, -0.9866297245025635], [-2.862037181854248, 1.774968147277832, 0.1738433539867401], [-2.4487171173095703, 2.243077516555786, -1.0111916065216064], [-4.345184326171875, 2.6464550495147705, -0.33195754885673523], [-0.6555485725402832, 1.4257453680038452, -2.4279398918151855], [-2.0466737747192383, 

100%|██████████| 34/34 [01:59<00:00,  3.51s/it]


Epoch: 3, Loss:  0.4751746356487274
[[-2.2149369716644287, -0.19929850101470947, -3.186758279800415], [-1.4653003215789795, -0.6505374312400818, -3.1619787216186523], [-5.539823055267334, -0.4903339147567749, -0.6802293658256531], [-3.459120988845825, -0.4579707384109497, -1.8711681365966797], [-4.049154758453369, 0.2278342843055725, -2.455399513244629], [-3.4993512630462646, -0.14836758375167847, -1.8590598106384277], [-4.110940456390381, 0.18049633502960205, -2.1619210243225098], [-6.2530436515808105, 0.09437548369169235, -2.070899724960327], [-3.081195116043091, -0.5979599952697754, -1.737074613571167], [-3.34696364402771, -0.13983523845672607, -2.4008545875549316], [-2.852018117904663, -0.5089685916900635, -2.5819296836853027], [-3.287689685821533, -0.9403166770935059, -1.0737439393997192], [-2.4359443187713623, -0.42569541931152344, -2.3406660556793213], [-4.990674018859863, -0.06949426978826523, -1.567212462425232], [0.3838166296482086, -1.0888690948486328, -4.051980495452881], [

  0%|          | 0/34 [00:00<?, ?it/s]

In [20]:
accuracy = metrics.accuracy_score(model_targets, model_labels)
f1_score_w_avg = metrics.f1_score(model_targets, model_labels, average='weighted')

print(f"Accuracy Score = {accuracy}")
print(f"F1 Score (Weighted) = {f1_score_w_avg}")

print(metrics.classification_report(model_targets, model_labels))

Accuracy Score = 0.8511111111111112
F1 Score (Weighted) = 0.8006801971887129
              precision    recall  f1-score   support

           0       0.86      0.99      0.92      1139
           1       0.67      0.09      0.17       211

    accuracy                           0.85      1350
   macro avg       0.76      0.54      0.54      1350
weighted avg       0.83      0.85      0.80      1350



# Tester

In [23]:
test_data = ModelDataset(text, pt, topic, Argument_label, tokenizer, MAX_LEN)

test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

outputs, targets = validation(test_loader, model)

targets = [np.argmax(x) for x in targets]
outputs = [np.argmax(x) for x in outputs]

accuracy = metrics.accuracy_score(targets, outputs)
f1_score_w_avg = metrics.f1_score(targets, outputs, average='weighted')

print(f"Accuracy Score = {accuracy}")
print(f"F1 Score (Weighted) = {f1_score_w_avg}")

print(metrics.classification_report(targets, outputs))

Accuracy Score = 0.8466666666666667
F1 Score (Weighted) = 0.7818742258516321
              precision    recall  f1-score   support

           0       0.85      1.00      0.92      1139
           1       0.75      0.03      0.05       211

    accuracy                           0.85      1350
   macro avg       0.80      0.51      0.49      1350
weighted avg       0.83      0.85      0.78      1350



# Tokenization testing

In [None]:
print(tokenizer.decode([ 101, 28248, 35732, 22044,   118, 21631, 10345, 11101, 16602,   119,
        10159, 48502, 10107, 10108,   146, 16994, 10806,   117, 10231, 87150,
        22525, 22489, 11309, 10161, 21528, 10114, 63376, 11915, 10135,   108,
        10201, 35732, 22044,   123, 31081, 10169, 22528, 10114, 52824, 10123,
        11345,   119, 12882, 14796, 10944,   189, 11419, 51511, 51354, 10188,
        11049, 11309, 10161,   136,   102, 16938,   112,   188, 45476, 10114,
        23763, 10531,   119,   119,   119, 10678, 10114, 94992, 10219, 10111,
        23763, 10105, 11561, 40414,   119, 14453, 44096,   189, 11337, 10678,
        10114, 21852, 10479, 10301, 22489, 11426,   102]))

[CLS] Jalikattu - Ban or Allow. Lakhs of Idiots, uneducated Tamil ppl want to lift ban on # jalikattu 2play with animals to hurt them. So how can u expect gud from such ppl? [SEP] don't dare to say this... come to merina and say the same thing. Surely u will come to know who are Tamil people [SEP]


In [None]:
print(len(train_data[0]["token_type_ids"]))
print(len(train_data[0]["targets"]))
print(len(train_data[0]["ids"]))
print(len(train_data[0]["mask"]))

#print(train_data[0]["text_length"])
#print(train_data[0]["pt_length"])
#print(train_data[0]["topic_length"])

print(train_data[0]["token_type_ids"])
print(train_data[0]["ids"])


510
2
510
510
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 