## **Mount Drive**

In [1]:
from google.colab import drive
drive.mount("/content/gdrive")

Mounted at /content/gdrive


In [2]:
cd /content/gdrive/MyDrive/VIT/Tamil Argumentation

/content/gdrive/MyDrive/VIT/Tamil Argumentation


## **Install**

In [3]:
pip install transformers

Collecting transformers
  Downloading transformers-4.33.2-py3-none-any.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m58.6 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers)
  Downloading huggingface_hub-0.17.2-py3-none-any.whl (294 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.9/294.9 kB[0m [31m30.6 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m107.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m73.9 MB/s[0m eta [36m0:00:

## **Import Libraries**

In [5]:
import pandas as pd
import numpy as np

from tqdm import tqdm

from copy import deepcopy

from sklearn import metrics
from sklearn.model_selection import KFold

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer,AutoModel

## **Import Dataset**

In [6]:
df = pd.read_excel("/content/gdrive/MyDrive/VIT/Tamil Argumentation/Twitter Comment Dataset.xlsx")

In [7]:
df.head()

Unnamed: 0,S No,Tweet,Date of Tweet,Topic,Parent Tweet,Language,Quality,Stance,Argument,Comment,Responding to Tone,Discussing Writer Characteristics,Remark,Relevancy
0,1,"Bro imagine today is Friday , big star movie i...",2018-05-22,Jalikattu,"And tamil people, jalikattu maadu for money an...",ENGLISH,Med,Undetermined,0,1,0,0,0,Relevant
1,2,Dei unnoda akkarai TN mela not on others and w...,2018-05-22,Jalikattu,"And tamil people, jalikattu maadu for money an...",ENGLISH,Med,Against,0,1,0,1,0,Relevant
2,3,En ninga ivara matum mention panuringa naraiya...,2018-05-22,Jalikattu,"And tamil people, jalikattu maadu for money an...",CODE-MIXED,Med,For,0,1,0,0,0,Relevant
3,4,What is happening in Thoothukudi is totally no...,2018-05-22,Jalikattu,"And tamil people, jalikattu maadu for money an...",ENGLISH,High,Against,1,1,0,0,0,Relevant
4,5,Ungaluku Sterlite protest prachanaya illa Bala...,2018-05-22,Jalikattu,"And tamil people, jalikattu maadu for money an...",CODE-MIXED,Med,Undetermined,0,0,1,0,0,Relevant


In [8]:
df.columns

Index(['S No', 'Tweet', 'Date of Tweet', 'Topic', 'Parent Tweet', 'Language',
       'Quality', 'Stance', 'Argument', 'Comment', 'Responding to Tone',
       'Discussing Writer Characteristics', 'Remark', 'Relevancy'],
      dtype='object')

## **Load Text and Labels**

In [38]:
text = df["Tweet"].to_numpy()

Quality_label = df["Quality"].to_numpy()
Argument_label = df["Argument"].to_numpy()
Comment_label = df["Comment"].to_numpy()
Writer_label = df["Discussing Writer Characteristics"].to_numpy()
Tone_label = df["Responding to Tone"].to_numpy()
Remark_label = df["Remark"].to_numpy()
Relevancy_label = df["Relevancy"].to_numpy()

## **Label Encoding**

In [39]:
encode_dict_quality = {
    "High": np.array([1, 0, 0]),
    "Med": np.array([0, 1, 0]),
    "Low": np.array([0, 0, 1]),
}

encode_dict = {
    1: np.array([1, 0]),
    0: np.array([0, 1]),
}

encode_dict_relevancy = {
    "Relevant": np.array([1, 0]),
    "Irrelevant": np.array([0, 1]),
}

In [40]:
Quality_label = np.array([encode_dict_quality[label] for label in Quality_label])
Argument_label = np.array([encode_dict[label] for label in Argument_label])
Comment_label = np.array([encode_dict[label] for label in Comment_label])
Writer_label = np.array([encode_dict[label] for label in Writer_label])
Tone_label = np.array([encode_dict[label] for label in Tone_label])
Remark_label = np.array([encode_dict[label] for label in Remark_label])
Relevancy_label = np.array([encode_dict_relevancy[label] for label in Relevancy_label])

## **Pre-Config for mBERT**

In [14]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

MAX_LEN = np.max([len(x) for x in text])
MAX_LEN = np.min([MAX_LEN, 510])

BATCH_SIZE = 32
LEARNING_RATE = 1e-1

In [15]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

## **Build Dataset for mBERT**

In [16]:
class ModelDataset(Dataset):
    def __init__(self, X, y, tokenizer, max_len):
        self.max_len = max_len
        self.text = X
        self.tokenizer = tokenizer
        self.targets = y

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        text = self.text[index]
        inputs = self.tokenizer.encode_plus(
            text,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True
        )

        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]

        return {
            'ids': torch.tensor(ids, dtype=torch.long).to(device),
            'mask': torch.tensor(mask, dtype=torch.long).to(device),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long).to(device),
            'targets': torch.tensor(self.targets[index], dtype=torch.float).to(device)
        }

## **Build Model**

In [25]:
class CustomModel(nn.Module):

    def __init__(self):
        super(CustomModel, self).__init__()
        self.bert = AutoModel.from_pretrained('bert-base-multilingual-cased')

        for param in self.bert.parameters():
            param.requires_grad = False

        self.out_layer = nn.Linear(768, 2)

    def forward(self, ids, mask, token_type_ids):
        _, features = self.bert(
            ids, token_type_ids=token_type_ids,
            attention_mask=mask, return_dict=False
        )

        output = self.out_layer(features)

        return output

## **Train Model**

In [28]:
def train(epoch, model, train_loader, loss_fn, optimizer):

    model.train()

    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        ids = batch['ids'].to(device, dtype = torch.long)
        mask = batch['mask'].to(device, dtype = torch.long)
        token_type_ids = batch['token_type_ids'].to(device, dtype = torch.long)
        targets = batch['targets'].to(device, dtype = torch.float)

        outputs = model(ids, mask, token_type_ids)

        loss = loss_fn(outputs, targets)

        loss.backward()
        optimizer.step()


    print(f'Epoch: {epoch + 1}, Loss:  {loss.item()}')

In [21]:
def validation(data_loader, model):

    model.eval()
    targets = []
    outputs = []

    with torch.no_grad():

        for batch in data_loader:

            ids = batch['ids'].to(device, dtype = torch.long)
            mask = batch['mask'].to(device, dtype = torch.long)
            token_type_ids = batch['token_type_ids'].to(device, dtype = torch.long)
            batch_targets = batch['targets'].to(device, dtype = torch.float)

            batch_outputs = model(ids, mask, token_type_ids)

            targets.extend(batch_targets.cpu().numpy().tolist())
            outputs.extend(batch_outputs.cpu().numpy().tolist())

    return outputs, targets

In [None]:
kf = KFold(n_splits=5)

model_targets = []
model_labels = []

model = CustomModel().to(device)

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

for train_index, test_index in kf.split(text):

    text_train, text_test = text[train_index], text[test_index]
    labels_train, labels_test = Argument_label[train_index], Argument_label[test_index]

    train_data = ModelDataset(text_train, labels_train, tokenizer, MAX_LEN)
    test_data = ModelDataset(text_test, labels_test, tokenizer, MAX_LEN)

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

    best_score = -np.inf
    best_weights = None

    EPOCHS = 10

    torch.save(model.state_dict(),"/content/gdrive/MyDrive/VIT/Tamil Argumentation/twitter_model_argument.pth")

    for epoch in range(EPOCHS):

        train(epoch, model, train_loader, loss_fn, optimizer)
        outputs, targets = validation(test_loader, model)

        print(outputs)
        print(" ")
        print(targets)

        targets = [np.argmax(x) for x in targets]
        outputs = [np.argmax(x) for x in outputs]

        score = metrics.f1_score(targets, outputs, average='weighted')

        if score > best_score:
            best_score = score
            best_weights = deepcopy(model.state_dict())

    model.load_state_dict(best_weights)
    torch.save(model.state_dict(),"/content/gdrive/MyDrive/VIT/Tamil Argumentation/twitter_model_argument.pth")

    outputs, targets = validation(test_loader, model)

    targets = [np.argmax(x) for x in targets]
    outputs = [np.argmax(x) for x in outputs]

    model_targets.extend(targets)
    model_labels.extend(outputs)

100%|██████████| 34/34 [00:18<00:00,  1.82it/s]


Epoch: 1, Loss:  0.5598838329315186
[[-11.443680763244629, 4.448367595672607], [-9.194916725158691, 2.7795374393463135], [-7.511484146118164, 2.795736312866211], [-12.234318733215332, 2.9090819358825684], [-8.684109687805176, 3.2519283294677734], [-9.600546836853027, 4.335109233856201], [-7.974612712860107, 3.8351988792419434], [-6.274366855621338, 2.065030097961426], [-10.968130111694336, 4.022838592529297], [-4.135931491851807, 1.4554136991500854], [-5.609196186065674, 1.6326847076416016], [-9.741717338562012, 3.8901522159576416], [-6.422377109527588, 1.494336724281311], [-5.254456520080566, 2.4498343467712402], [-7.857550144195557, 2.30241322517395], [-10.316399574279785, 4.053957939147949], [-9.103894233703613, 2.6927921772003174], [-11.049166679382324, 3.418760061264038], [-8.265060424804688, 2.5099010467529297], [-5.919986248016357, 0.9028308987617493], [-7.503967761993408, 1.2414355278015137], [-6.409285068511963, 1.1343861818313599], [-5.334238529205322, 2.594231128692627], [-7

100%|██████████| 34/34 [00:19<00:00,  1.79it/s]


Epoch: 2, Loss:  0.5588838458061218
[[-9.800026893615723, 4.449311256408691], [-6.9437737464904785, 1.6909692287445068], [-6.132838726043701, 2.7661452293395996], [-9.440315246582031, 2.5434517860412598], [-7.560821056365967, 3.51845645904541], [-8.671320915222168, 3.978992223739624], [-7.38712739944458, 3.615267753601074], [-5.075843811035156, 1.8433669805526733], [-8.299527168273926, 3.781998634338379], [-2.7879297733306885, 0.9509804248809814], [-4.382823467254639, 1.3518425226211548], [-8.403045654296875, 3.5106163024902344], [-4.757795810699463, 0.8465433120727539], [-4.466906547546387, 2.123065710067749], [-5.911749362945557, 1.5707836151123047], [-8.064680099487305, 3.316025733947754], [-6.399147033691406, 1.7224806547164917], [-8.985733032226562, 2.825735092163086], [-5.169861316680908, 1.3945595026016235], [-3.4710543155670166, 0.19029748439788818], [-5.125954627990723, 0.2636212110519409], [-3.065272569656372, -0.21738514304161072], [-6.353850841522217, 3.7421247959136963], [

100%|██████████| 34/34 [00:19<00:00,  1.75it/s]


Epoch: 3, Loss:  0.35326144099235535
[[-4.339311599731445, 3.293780565261841], [-3.093973398208618, 1.793893814086914], [-3.1716811656951904, 2.279451608657837], [-3.541466474533081, 2.4622156620025635], [-3.8103020191192627, 2.7847840785980225], [-5.144573211669922, 3.3836381435394287], [-3.8295176029205322, 2.7291743755340576], [-2.7794902324676514, 1.9101632833480835], [-5.499150276184082, 4.523352146148682], [-2.120807409286499, 1.7356293201446533], [-2.1674163341522217, 1.3948564529418945], [-5.244100093841553, 3.735236883163452], [-2.0769600868225098, 1.0695550441741943], [-1.598922848701477, 1.2478399276733398], [-3.2222089767456055, 1.8388313055038452], [-5.321995735168457, 4.095373630523682], [-3.7006399631500244, 2.7814884185791016], [-3.3661952018737793, 1.9836523532867432], [-3.416205644607544, 2.2447078227996826], [-1.880955457687378, 1.4669864177703857], [-0.647739827632904, 0.11600802838802338], [-1.1484512090682983, 0.7638611197471619], [-3.4780428409576416, 3.112962961

100%|██████████| 34/34 [00:20<00:00,  1.70it/s]


Epoch: 4, Loss:  0.5178078413009644
[[-7.207249164581299, 1.7990832328796387], [-6.098998546600342, -0.44349145889282227], [-6.029115200042725, 2.2099268436431885], [-6.403223037719727, -2.050870895385742], [-6.22036075592041, 2.441812515258789], [-8.685868263244629, 4.040225028991699], [-7.132476806640625, 3.2419655323028564], [-4.901683330535889, 0.7942491173744202], [-8.016501426696777, 3.4896278381347656], [-4.276391506195068, 1.9238595962524414], [-4.3459858894348145, 0.9482902884483337], [-8.77590560913086, 2.799626350402832], [-4.469359874725342, -0.05187028646469116], [-4.614922523498535, 2.320216417312622], [-5.865278720855713, 0.8844305872917175], [-8.591830253601074, 2.9747538566589355], [-6.636545658111572, 0.6585214734077454], [-5.772318363189697, -1.7748477458953857], [-5.8248515129089355, 1.6430251598358154], [-3.172016143798828, -0.7811381220817566], [-4.361323833465576, -2.0484983921051025], [-3.7864561080932617, -0.9042481184005737], [-6.281385898590088, 3.58314371109

100%|██████████| 34/34 [00:20<00:00,  1.69it/s]


Epoch: 5, Loss:  0.35227036476135254
[[-4.982089042663574, 6.453356742858887], [-3.959571361541748, 5.171030044555664], [-4.265913963317871, 4.976396083831787], [-3.7313599586486816, 5.801187038421631], [-4.5602216720581055, 5.206020355224609], [-6.592926979064941, 6.789474964141846], [-5.490494728088379, 6.599219799041748], [-3.662160873413086, 4.574798107147217], [-5.551715850830078, 7.404442310333252], [-3.1462182998657227, 3.842221975326538], [-2.658764123916626, 2.8612422943115234], [-6.856727600097656, 8.17751693725586], [-2.930539608001709, 3.78341007232666], [-2.827796220779419, 4.2573981285095215], [-3.703248977661133, 4.314901351928711], [-6.154750347137451, 7.978670597076416], [-4.146595001220703, 6.059696674346924], [-3.9157800674438477, 5.497455596923828], [-3.599519729614258, 4.918006896972656], [-1.860229730606079, 3.4968502521514893], [-1.8819432258605957, 3.5845799446105957], [-2.2601165771484375, 3.4854586124420166], [-4.974245071411133, 4.928400993347168], [-3.356699

100%|██████████| 34/34 [00:19<00:00,  1.72it/s]


Epoch: 6, Loss:  0.2009168267250061
[[-1.5665730237960815, 1.744640588760376], [-0.9933532476425171, -0.3658672571182251], [-2.5103445053100586, 1.6179324388504028], [0.39650627970695496, -0.6292432546615601], [-2.3170888423919678, 2.3444433212280273], [-4.525798797607422, 3.673320770263672], [-3.939502716064453, 3.583841562271118], [-1.475199818611145, 0.50774747133255], [-3.185142993927002, 2.527658462524414], [-1.7142491340637207, 0.6485254168510437], [-1.2772239446640015, -0.3125190734863281], [-4.427116870880127, 3.193911552429199], [-0.8980966210365295, -0.23732925951480865], [-2.0258543491363525, 2.238718032836914], [-1.5449082851409912, 0.11399541795253754], [-3.521212577819824, 2.6284406185150146], [-1.1147582530975342, 0.10537724196910858], [0.06215069442987442, -0.2994351387023926], [-2.3873469829559326, 0.9257237911224365], [0.2565824091434479, -1.0464656352996826], [0.8470356464385986, -1.2783366441726685], [0.33481284976005554, -1.4346189498901367], [-3.7216238975524902, 

100%|██████████| 34/34 [00:19<00:00,  1.72it/s]


Epoch: 7, Loss:  0.5924701690673828
[[-2.796077251434326, 3.2023377418518066], [-1.8823587894439697, 2.139829158782959], [-3.156707525253296, 3.229965925216675], [-1.1351714134216309, 1.7752882242202759], [-3.0499024391174316, 3.3414578437805176], [-5.028723239898682, 4.886719703674316], [-4.675660610198975, 4.912916660308838], [-1.3622679710388184, 1.615773320198059], [-3.262766122817993, 3.9182803630828857], [-2.2548229694366455, 2.434051752090454], [-1.5880348682403564, 1.2027082443237305], [-5.503357410430908, 5.468907356262207], [-1.354601502418518, 1.276694893836975], [-3.1663401126861572, 3.5830318927764893], [-1.834045648574829, 1.5182404518127441], [-4.353032112121582, 4.584770679473877], [-2.0405406951904297, 2.4088289737701416], [-1.21686851978302, 1.7144572734832764], [-2.33601975440979, 2.8077826499938965], [-0.2636963129043579, 0.7364331483840942], [-0.14919637143611908, 0.20209915935993195], [-0.43005669116973877, 0.7034077644348145], [-4.1446919441223145, 3.940342187881

100%|██████████| 34/34 [00:19<00:00,  1.70it/s]


Epoch: 8, Loss:  0.461014986038208
[[-0.44864973425865173, 0.8498893976211548], [0.25323858857154846, 0.05245538800954819], [-1.6045491695404053, 1.70589280128479], [0.9836679100990295, -0.25868192315101624], [-1.7016974687576294, 1.9686256647109985], [-3.0649383068084717, 3.3424301147460938], [-2.7669484615325928, 3.1457226276397705], [-0.5403062701225281, 0.7092685699462891], [-1.773310661315918, 2.1592721939086914], [-1.6164519786834717, 1.5398460626602173], [-0.5265067219734192, 0.2718527615070343], [-3.405616283416748, 3.563364028930664], [-0.12956079840660095, 0.008417390286922455], [-1.9615226984024048, 2.1940131187438965], [-0.35245999693870544, 0.2708720862865448], [-2.26416277885437, 2.512141704559326], [-0.1982930600643158, 0.4737187325954437], [0.8796289563179016, -0.41728124022483826], [-1.0753055810928345, 1.3489279747009277], [0.9177778363227844, -0.8683926463127136], [1.8457297086715698, -1.4347388744354248], [0.7456661462783813, -0.5936699509620667], [-2.89809870719909

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 9, Loss:  0.8919563293457031
[[-9.975924491882324, 11.122162818908691], [-7.652770042419434, 8.692623138427734], [-8.531707763671875, 9.278675079345703], [-9.339707374572754, 10.946649551391602], [-9.295890808105469, 10.187579154968262], [-10.638861656188965, 11.402942657470703], [-9.984284400939941, 10.876493453979492], [-6.311532020568848, 7.003173351287842], [-10.235086441040039, 11.434269905090332], [-5.952413558959961, 6.460977554321289], [-5.915468692779541, 6.243173599243164], [-11.834718704223633, 12.840827941894531], [-6.87308406829834, 7.541392803192139], [-7.3203535079956055, 7.973201274871826], [-7.473041534423828, 8.067207336425781], [-10.908687591552734, 11.91392707824707], [-8.394200325012207, 9.433782577514648], [-8.868268013000488, 10.210088729858398], [-7.952329158782959, 8.941017150878906], [-4.816871643066406, 5.63556432723999], [-6.078701972961426, 7.019378662109375], [-5.891946792602539, 6.621649742126465], [-8.445778846740723, 9.073156356811523], [-6.70410

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 10, Loss:  0.3659787178039551
[[-0.453384667634964, 0.8174098134040833], [-0.5781189203262329, 1.0214711427688599], [-1.2236411571502686, 1.4155802726745605], [-0.3002031147480011, 0.9816550016403198], [-1.4717260599136353, 1.7197414636611938], [-2.725545883178711, 2.981734037399292], [-2.9901468753814697, 3.4565603733062744], [-0.18767490983009338, 0.4383331537246704], [-2.6173861026763916, 3.0959489345550537], [-0.946894645690918, 1.1057895421981812], [0.2892191708087921, -0.3323078155517578], [-3.7662923336029053, 4.189234733581543], [-0.49770423769950867, 0.6698029041290283], [-0.05737629532814026, 0.24074088037014008], [-0.4675755202770233, 0.571384072303772], [-2.821652889251709, 3.2093770503997803], [-1.3588567972183228, 1.76995050907135], [0.14171907305717468, 0.46663570404052734], [-1.2730804681777954, 1.6551642417907715], [1.1629844903945923, -0.8739591836929321], [1.9299513101577759, -1.6047433614730835], [0.6121068000793457, -0.4057472348213196], [-1.391805648803711,

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 1, Loss:  0.44391605257987976
[[-5.991171360015869, 6.5900959968566895], [-4.172907829284668, 4.404595851898193], [-4.583683967590332, 4.997935771942139], [-6.324143886566162, 6.597784519195557], [-10.770237922668457, 11.194173812866211], [-4.641184329986572, 5.185033321380615], [-4.637569427490234, 4.64235258102417], [-4.996409893035889, 5.174704074859619], [-2.9970006942749023, 3.1598618030548096], [-4.867499828338623, 5.119129657745361], [-3.012638568878174, 3.2727839946746826], [-4.460744380950928, 4.343155384063721], [-5.800103187561035, 6.078698635101318], [-2.3277339935302734, 2.558230400085449], [-3.922549247741699, 4.17019510269165], [-5.643735408782959, 5.8286309242248535], [-6.5104522705078125, 7.052995204925537], [-4.5185699462890625, 4.8701300621032715], [-4.740530490875244, 5.2676568031311035], [-5.170800685882568, 5.693824291229248], [-3.6617696285247803, 4.340275287628174], [-5.6186137199401855, 6.480759143829346], [-5.563531398773193, 6.18719482421875], [-5.7079

100%|██████████| 34/34 [00:20<00:00,  1.69it/s]


Epoch: 2, Loss:  0.524526834487915
[[-6.094697952270508, 7.2562031745910645], [-4.328363418579102, 5.16270112991333], [-3.7341508865356445, 4.769181251525879], [-6.240013122558594, 7.107055187225342], [-10.167623519897461, 11.188011169433594], [-5.612069129943848, 6.758972644805908], [-3.013223171234131, 3.6101486682891846], [-4.49130916595459, 5.150588512420654], [-1.7693805694580078, 2.5890114307403564], [-4.026048183441162, 4.962527751922607], [-1.0657292604446411, 2.059234380722046], [-4.7799601554870605, 5.423882961273193], [-6.4956889152526855, 7.3578386306762695], [-4.658481121063232, 5.4534382820129395], [-3.830630302429199, 4.775569915771484], [-6.559876441955566, 7.45859956741333], [-6.867487907409668, 8.040252685546875], [-2.9499056339263916, 3.9055380821228027], [-5.876760482788086, 6.98156213760376], [-4.87727165222168, 6.022709846496582], [-5.695336818695068, 6.9197258949279785], [-6.071795463562012, 7.4117231369018555], [-6.2911272048950195, 7.407955646514893], [-5.70855

100%|██████████| 34/34 [00:19<00:00,  1.72it/s]


Epoch: 3, Loss:  0.3522353172302246
[[-8.40283203125, 6.932938098907471], [-4.589643955230713, 3.187056064605713], [-5.494711399078369, 4.184126377105713], [-6.278753757476807, 5.711879253387451], [-10.972721099853516, 10.257437705993652], [-7.398744106292725, 6.188523292541504], [-4.234041213989258, 3.366770029067993], [-5.146384239196777, 4.105065822601318], [-3.0043928623199463, 1.6608470678329468], [-5.356114864349365, 3.940176486968994], [-2.896368980407715, 1.7372184991836548], [-4.08987283706665, 3.127181053161621], [-8.543914794921875, 7.114729404449463], [-2.8923380374908447, 1.542978286743164], [-4.1361823081970215, 3.0145673751831055], [-8.124763488769531, 7.074424743652344], [-6.037140369415283, 4.8503804206848145], [-4.044166088104248, 2.7551686763763428], [-6.89301061630249, 5.503158092498779], [-5.559515953063965, 4.495925426483154], [-6.4871506690979, 5.051433086395264], [-7.728122234344482, 6.571531295776367], [-6.564333438873291, 5.566734790802002], [-6.41622972488403

100%|██████████| 34/34 [00:19<00:00,  1.72it/s]


Epoch: 4, Loss:  0.4195432662963867
[[-9.101736068725586, 10.057132720947266], [-5.767348766326904, 6.441348075866699], [-6.752938270568848, 7.58502197265625], [-7.8587470054626465, 8.094709396362305], [-12.301065444946289, 12.812936782836914], [-8.267504692077637, 9.165847778320312], [-5.061248779296875, 5.160353660583496], [-6.226983547210693, 6.466422080993652], [-4.627607345581055, 5.286779403686523], [-6.338287830352783, 7.058957099914551], [-3.7197556495666504, 4.4655680656433105], [-5.8167524337768555, 6.007784843444824], [-8.963058471679688, 10.026201248168945], [-4.542320251464844, 5.233432769775391], [-4.8548150062561035, 5.546926975250244], [-8.768345832824707, 9.592020034790039], [-8.145660400390625, 9.040895462036133], [-5.322066783905029, 6.1351704597473145], [-7.529916763305664, 8.746455192565918], [-6.23066520690918, 6.925774097442627], [-7.581651210784912, 8.86976432800293], [-8.063365936279297, 9.384180068969727], [-7.665045738220215, 8.45749282836914], [-7.5362052917

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 5, Loss:  0.28039759397506714
[[-10.701976776123047, 7.7052178382873535], [-7.295459747314453, 5.162759304046631], [-7.827775478363037, 5.55441951751709], [-8.481163024902344, 7.701928615570068], [-12.656034469604492, 12.165441513061523], [-10.245388984680176, 7.567250728607178], [-4.265228748321533, 3.3829901218414307], [-6.190281867980957, 4.917652130126953], [-5.020981788635254, 3.3443946838378906], [-8.104941368103027, 5.545289516448975], [-3.819049596786499, 2.0592517852783203], [-7.891208648681641, 6.6751484870910645], [-10.860371589660645, 7.578019618988037], [-9.259821891784668, 6.274425983428955], [-6.613222122192383, 4.6503214836120605], [-10.764070510864258, 8.409677505493164], [-10.205439567565918, 8.134400367736816], [-5.550745964050293, 3.743678092956543], [-10.762786865234375, 7.371882915496826], [-7.781742095947266, 5.85800313949585], [-11.846994400024414, 8.043850898742676], [-10.182432174682617, 7.334784984588623], [-9.483036041259766, 7.459322452545166], [-9.2

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 6, Loss:  0.6013432741165161
[[-13.47279167175293, 12.16475772857666], [-9.738067626953125, 8.682202339172363], [-10.531208038330078, 9.903478622436523], [-10.579570770263672, 9.573494911193848], [-15.928234100341797, 14.508404731750488], [-12.482706069946289, 11.430991172790527], [-6.264500617980957, 5.791919231414795], [-8.625834465026855, 7.799499034881592], [-6.938138484954834, 6.67971658706665], [-10.661885261535645, 9.999016761779785], [-6.215805530548096, 6.260182857513428], [-9.114547729492188, 8.395440101623535], [-13.98708724975586, 12.475028038024902], [-10.33293342590332, 8.970138549804688], [-8.451719284057617, 7.650350570678711], [-12.868309020996094, 11.592680931091309], [-12.18255615234375, 11.515185356140137], [-8.042095184326172, 7.812711238861084], [-13.212417602539062, 11.744782447814941], [-10.242631912231445, 9.138656616210938], [-14.011819839477539, 12.18079662322998], [-13.459259986877441, 12.000328063964844], [-12.109529495239258, 10.8503999710083], [-11

100%|██████████| 34/34 [00:19<00:00,  1.72it/s]


Epoch: 7, Loss:  0.8580756187438965
[[-6.400866508483887, 0.3649022579193115], [-3.485185146331787, -1.3008332252502441], [-4.480929374694824, 0.0766742080450058], [-7.096367359161377, 4.695807933807373], [-11.232222557067871, 7.243900299072266], [-6.341821670532227, 1.0436209440231323], [-3.1573030948638916, 1.376219391822815], [-4.511690139770508, 1.5566232204437256], [-2.941409111022949, -0.8600651025772095], [-4.414792537689209, -0.012449875473976135], [-1.9435374736785889, -1.0948498249053955], [-4.063997745513916, 1.7072370052337646], [-7.650456428527832, 1.077581763267517], [-2.8774430751800537, -2.452341318130493], [-3.220350742340088, -1.0198084115982056], [-6.526724815368652, 1.2482812404632568], [-5.746654510498047, 1.6366212368011475], [-3.0085320472717285, -0.19150717556476593], [-4.872851371765137, -2.062371015548706], [-4.811337471008301, 0.5646695494651794], [-5.844118118286133, -1.368463397026062], [-6.287153244018555, -0.20911450684070587], [-5.854799270629883, 0.9290

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 8, Loss:  0.6356099843978882
[[-2.7290351390838623, 4.193328380584717], [-1.0402734279632568, 1.9098048210144043], [-1.6158345937728882, 2.838132858276367], [-5.210201263427734, 5.747584342956543], [-7.521570205688477, 8.578279495239258], [-4.129423141479492, 5.658352375030518], [-1.7535107135772705, 2.1581151485443115], [-2.5509073734283447, 3.0740761756896973], [-0.7682219743728638, 1.1149210929870605], [-1.9852617979049683, 3.01873779296875], [-0.38022083044052124, 0.7261567115783691], [-2.8801088333129883, 3.947312116622925], [-3.8194637298583984, 5.144309997558594], [-1.9146109819412231, 2.9293456077575684], [-0.866144597530365, 1.1190322637557983], [-3.5739145278930664, 4.751067638397217], [-3.116093158721924, 4.412177085876465], [-0.9252315759658813, 1.8266974687576294], [-1.4372388124465942, 2.821403980255127], [-2.0681138038635254, 2.9182097911834717], [-2.1172428131103516, 4.037552833557129], [-1.9364334344863892, 3.5807430744171143], [-3.0905966758728027, 4.4037008285

100%|██████████| 34/34 [00:19<00:00,  1.72it/s]


Epoch: 9, Loss:  0.29958000779151917
[[-6.285831928253174, 6.2118024826049805], [-4.40023946762085, 4.388115406036377], [-5.020509243011475, 5.0141730308532715], [-7.335916519165039, 7.671231269836426], [-10.511672973632812, 10.720274925231934], [-7.1189775466918945, 7.4648356437683105], [-4.14018440246582, 4.067834377288818], [-5.030383586883545, 4.891139507293701], [-2.5191128253936768, 2.3755736351013184], [-5.381134033203125, 5.271965503692627], [-2.8548803329467773, 2.3338823318481445], [-5.8665971755981445, 6.588527679443359], [-6.883816719055176, 6.856903553009033], [-6.3610310554504395, 7.132412433624268], [-3.787614583969116, 3.662055253982544], [-7.402616500854492, 7.546309947967529], [-7.0624542236328125, 7.74586820602417], [-3.7953014373779297, 3.682419538497925], [-5.45135498046875, 5.705450534820557], [-5.235069274902344, 5.08849573135376], [-6.078965187072754, 6.515456676483154], [-5.809887409210205, 5.84968376159668], [-6.793369770050049, 7.034533977508545], [-5.9945745

100%|██████████| 34/34 [00:19<00:00,  1.72it/s]


Epoch: 10, Loss:  0.21237997710704803
[[-3.8735604286193848, 4.62863302230835], [-1.872902750968933, 2.106785535812378], [-2.845431327819824, 3.177454710006714], [-6.700647354125977, 7.026697158813477], [-10.500655174255371, 11.105114936828613], [-4.442620754241943, 5.3290791511535645], [-4.08944034576416, 4.220701217651367], [-4.203823089599609, 4.369187831878662], [-1.3681397438049316, 1.3646965026855469], [-3.0299038887023926, 3.2279698848724365], [-1.7997642755508423, 1.5606563091278076], [-3.917120933532715, 4.329520225524902], [-4.478744029998779, 5.1551899909973145], [-1.866565227508545, 2.2203369140625], [-1.9393888711929321, 2.1652278900146484], [-4.878136157989502, 5.483337879180908], [-4.647708415985107, 4.956716060638428], [-2.1399965286254883, 2.3192529678344727], [-2.0343761444091797, 2.7656896114349365], [-3.4090752601623535, 3.6328256130218506], [-1.6414698362350464, 2.2300400733947754], [-3.4801011085510254, 4.236183166503906], [-4.771411895751953, 5.232226848602295], 

100%|██████████| 34/34 [00:19<00:00,  1.70it/s]


Epoch: 1, Loss:  0.6277066469192505
[[-1.8642385005950928, 3.461151599884033], [-4.6773176193237305, 6.156754493713379], [-1.1135330200195312, 4.262803554534912], [-6.037947654724121, 7.20794153213501], [1.9100978374481201, 1.8455759286880493], [-4.137123107910156, 5.342535018920898], [-4.683741092681885, 5.554013252258301], [-5.545129776000977, 7.372381210327148], [-0.5893160700798035, 2.325801372528076], [-5.650145053863525, 6.95768928527832], [-1.0764198303222656, 2.6717448234558105], [-2.2388856410980225, 3.4964230060577393], [-4.567424297332764, 5.576223850250244], [2.313836097717285, 1.1329621076583862], [-2.8333206176757812, 4.779481887817383], [-1.6634304523468018, 3.831291437149048], [-3.6030144691467285, 5.183237075805664], [-4.902746200561523, 6.142879486083984], [-3.5576648712158203, 5.021519184112549], [-5.445871353149414, 6.626163959503174], [-2.0837671756744385, 3.2864990234375], [-4.037083625793457, 5.502084255218506], [-4.402319431304932, 5.935776710510254], [-1.292143

100%|██████████| 34/34 [00:20<00:00,  1.70it/s]


Epoch: 2, Loss:  0.282375305891037
[[-1.373692512512207, 1.7412211894989014], [-4.168546199798584, 4.860625743865967], [-1.7861850261688232, 1.7004613876342773], [-4.981241703033447, 5.2529425621032715], [-0.38065284490585327, -0.7745397090911865], [-3.276855230331421, 3.801044464111328], [-3.519268751144409, 3.8208441734313965], [-5.0578789710998535, 5.542710781097412], [-1.171436071395874, 1.3344800472259521], [-4.4824748039245605, 4.700736999511719], [-3.3248178958892822, 2.657862663269043], [-1.279836654663086, 2.049248695373535], [-3.2806150913238525, 3.7897727489471436], [-5.743441581726074, 3.7538836002349854], [-2.46557354927063, 2.7785098552703857], [-2.4776952266693115, 2.126164436340332], [-3.3470091819763184, 3.4489166736602783], [-4.188788890838623, 4.5546488761901855], [-3.1527578830718994, 3.7063002586364746], [-5.084191799163818, 4.691519260406494], [-1.6164774894714355, 2.30755877494812], [-3.4615533351898193, 3.447047472000122], [-3.5179178714752197, 3.612531185150146

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 3, Loss:  0.786777913570404
[[0.5994101762771606, -0.332918256521225], [-2.8853344917297363, 3.0828990936279297], [1.3905669450759888, -1.4748395681381226], [-2.6817626953125, 2.6942596435546875], [4.036703109741211, -4.901675701141357], [-1.6266690492630005, 1.7140225172042847], [-0.9867444038391113, 1.0818400382995605], [-3.0774409770965576, 3.5269036293029785], [1.4963912963867188, -1.2744098901748657], [-1.2821520566940308, 1.4526175260543823], [0.41569632291793823, -0.9629198312759399], [0.05787922441959381, 0.24256554245948792], [-1.272416114807129, 1.5402905941009521], [3.5066628456115723, -5.464602470397949], [0.5153554677963257, -0.37462934851646423], [0.7381616830825806, -1.0555672645568848], [-0.6031020879745483, 0.5035451650619507], [-1.9504166841506958, 2.0356273651123047], [-1.0013924837112427, 1.2883015871047974], [-1.6581324338912964, 1.2930976152420044], [-0.15443842113018036, 0.3501128852367401], [-0.6299635171890259, 0.6016143560409546], [-0.5197107195854187, 

100%|██████████| 34/34 [00:19<00:00,  1.72it/s]


Epoch: 4, Loss:  0.2919114828109741
[[0.2085319459438324, -0.16887715458869934], [-3.1055681705474854, 2.958615779876709], [-2.51487135887146, 2.0989248752593994], [-3.988459825515747, 3.540677309036255], [-3.442774534225464, 2.4284346103668213], [-2.282094955444336, 2.0515129566192627], [-2.0222830772399902, 1.6875791549682617], [-4.477188587188721, 4.354889869689941], [-1.4195972681045532, 1.0731391906738281], [-2.7878286838531494, 2.404869556427002], [-5.916260242462158, 5.215351104736328], [0.20637187361717224, -0.1290138065814972], [-1.9598655700683594, 1.8895305395126343], [-11.91786003112793, 10.681551933288574], [-1.0643855333328247, 0.926924467086792], [-3.277836561203003, 2.6584811210632324], [-2.9899394512176514, 2.539832353591919], [-3.060145139694214, 2.7681825160980225], [-2.2506589889526367, 2.185002565383911], [-5.138033390045166, 4.34283447265625], [-0.32784727215766907, 0.1751013696193695], [-2.3033969402313232, 2.0823092460632324], [-2.5429744720458984, 2.31119847297

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 5, Loss:  0.6263675093650818
[[-6.524350166320801, 6.364012718200684], [-7.117310047149658, 6.939703941345215], [-11.235607147216797, 10.759654998779297], [-10.25260066986084, 9.76517391204834], [-12.179313659667969, 11.481244087219238], [-7.0917649269104, 6.8293657302856445], [-7.849236965179443, 7.412079334259033], [-10.427931785583496, 10.262795448303223], [-6.074880123138428, 5.688333511352539], [-10.71139907836914, 10.230588912963867], [-7.309580326080322, 6.806368350982666], [-4.768189907073975, 4.672658443450928], [-7.926675319671631, 7.673158168792725], [1.165095329284668, -1.225165843963623], [-8.11931037902832, 7.881834983825684], [-10.598018646240234, 9.90874195098877], [-9.44051456451416, 8.932246208190918], [-8.640132904052734, 8.306879043579102], [-7.309789657592773, 7.146783351898193], [-12.044251441955566, 11.290057182312012], [-3.974738597869873, 3.7600090503692627], [-10.294068336486816, 9.920976638793945], [-10.740349769592285, 10.34022045135498], [-7.71282196

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 6, Loss:  0.07166513055562973
[[-2.790159225463867, 4.847777843475342], [-6.362049102783203, 7.528796195983887], [-6.433193206787109, 10.144058227539062], [-7.978367328643799, 9.93972396850586], [-7.88006067276001, 12.314857482910156], [-6.138247489929199, 7.668402671813965], [-5.688028812408447, 7.448789119720459], [-8.534382820129395, 10.96192741394043], [-4.4984917640686035, 7.572322368621826], [-7.424884796142578, 9.612401962280273], [-7.785086631774902, 11.903536796569824], [-2.51884388923645, 3.8060760498046875], [-5.282628059387207, 7.000720977783203], [-14.356531143188477, 18.13699722290039], [-5.701312065124512, 7.729950428009033], [-7.035184860229492, 10.448811531066895], [-7.350639343261719, 9.657130241394043], [-6.850556373596191, 8.702939987182617], [-5.954937934875488, 7.92954158782959], [-9.432311058044434, 12.030261993408203], [-2.998084306716919, 4.449044704437256], [-6.616254806518555, 8.923087120056152], [-7.183474063873291, 9.52135181427002], [-4.048014640808

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 7, Loss:  1.1453845500946045
[[-2.3068597316741943, -1.248526930809021], [-6.376408576965332, 4.52207088470459], [-5.902628421783447, -0.21266846358776093], [-6.961890697479248, 4.167898654937744], [-5.982339382171631, -1.7814682722091675], [-5.294004440307617, 2.9987173080444336], [-4.889400959014893, 2.2461252212524414], [-8.427576065063477, 5.793201446533203], [-3.0318992137908936, -0.7229044437408447], [-6.310494899749756, 3.154653787612915], [-4.877676486968994, 1.177659034729004], [-2.39831805229187, -0.27640193700790405], [-5.074711799621582, 2.358633041381836], [4.559058666229248, -2.09566593170166], [-4.350112438201904, 0.2957609295845032], [-7.015531539916992, 1.6271926164627075], [-5.866360664367676, 2.560110092163086], [-5.99407958984375, 3.451667547225952], [-5.028521537780762, 2.537306308746338], [-7.901261806488037, 4.628847122192383], [-1.8320987224578857, -0.17900602519512177], [-5.86333703994751, 2.105057954788208], [-6.211719512939453, 2.2074644565582275], [-2

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 8, Loss:  0.5253444910049438
[[-1.2380081415176392, 1.4508376121520996], [-5.116187572479248, 4.587805271148682], [-5.136022090911865, 4.524803638458252], [-5.705078601837158, 4.770379066467285], [-6.789386749267578, 5.066553115844727], [-3.8410301208496094, 3.203718423843384], [-3.6747937202453613, 2.995461940765381], [-7.5819501876831055, 7.4553542137146], [-2.3937201499938965, 1.944610595703125], [-5.192813396453857, 4.421280860900879], [-4.650354385375977, 3.9925296306610107], [-0.6795272827148438, 0.40817591547966003], [-3.739741325378418, 3.066025495529175], [-4.021197319030762, 1.0193490982055664], [-3.6030170917510986, 2.712167978286743], [-5.954559803009033, 5.159971714019775], [-4.844328880310059, 4.1661577224731445], [-4.824992656707764, 4.109776973724365], [-4.361963272094727, 3.7956318855285645], [-7.5769572257995605, 6.394529819488525], [-0.7284054160118103, 0.4545004665851593], [-5.195898056030273, 4.696480751037598], [-5.88615608215332, 4.751417636871338], [-1.70

100%|██████████| 34/34 [00:19<00:00,  1.71it/s]


Epoch: 9, Loss:  0.2753322720527649
[[-3.264810562133789, 1.310197353363037], [-6.276706695556641, 5.266434192657471], [-5.768356800079346, 2.1234121322631836], [-7.533563613891602, 5.636188983917236], [-5.340643405914307, 0.837058961391449], [-5.030045509338379, 3.8836967945098877], [-5.357246398925781, 3.495502233505249], [-8.405131340026855, 6.967571258544922], [-2.4460971355438232, 0.658360481262207], [-7.348694324493408, 5.040103435516357], [-2.788900852203369, 1.9851444959640503], [-2.8459320068359375, 1.5061891078948975], [-5.6531758308410645, 3.7194695472717285], [4.737088680267334, -1.268601894378662], [-4.71522855758667, 2.507739782333374], [-6.141822338104248, 2.9206137657165527], [-6.165186405181885, 4.083194255828857], [-6.511630535125732, 4.86740255355835], [-5.8380045890808105, 4.083959102630615], [-8.028007507324219, 5.651363849639893], [-2.4045398235321045, 1.775900959968567], [-6.927393913269043, 4.2962236404418945], [-7.411491870880127, 4.63777494430542], [-3.7305784

100%|██████████| 34/34 [00:19<00:00,  1.72it/s]


Epoch: 10, Loss:  0.36270850896835327
[[0.34174108505249023, 0.8763284683227539], [-3.437394857406616, 4.228213787078857], [-3.803628921508789, 4.501104831695557], [-4.49505615234375, 5.246577739715576], [-4.594691753387451, 5.233916282653809], [-1.8291003704071045, 2.837033987045288], [-2.1563708782196045, 3.0202407836914062], [-5.632519245147705, 6.760752201080322], [-1.3169186115264893, 2.506457805633545], [-3.7237131595611572, 4.449387073516846], [-4.795305252075195, 7.115102291107178], [0.6542525291442871, 0.4857829213142395], [-1.9801124334335327, 2.794398546218872], [-8.335450172424316, 8.73595142364502], [-1.7371554374694824, 2.3122737407684326], [-3.9538681507110596, 5.16686487197876], [-3.9140188694000244, 4.732583522796631], [-3.774807929992676, 4.3923563957214355], [-3.221428871154785, 3.917921304702759], [-5.478630065917969, 6.197798252105713], [-0.3499978184700012, 1.1922776699066162], [-3.2025299072265625, 4.206900119781494], [-3.747981309890747, 4.278140544891357], [-0.

 65%|██████▍   | 22/34 [00:12<00:07,  1.66it/s]

In [None]:
accuracy = metrics.accuracy_score(model_targets, model_labels)
f1_score_w_avg = metrics.f1_score(model_targets, model_labels, average='weighted')

print(f"Accuracy Score = {accuracy}")
print(f"F1 Score (Weighted) = {f1_score_w_avg}")

print(metrics.classification_report(model_targets, model_labels))

Accuracy Score = 0.6718518518518518
F1 Score (Weighted) = 0.660670534429505
              precision    recall  f1-score   support

           0       0.57      0.37      0.44       197
           1       0.73      0.81      0.77       849
           2       0.53      0.48      0.50       304

    accuracy                           0.67      1350
   macro avg       0.61      0.55      0.57      1350
weighted avg       0.66      0.67      0.66      1350

