In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import spacy
import nltk
import re
import string
from nltk.tokenize import sent_tokenize, word_tokenize

In [None]:
# Fetch every NLTK resource listed below, in the same order as before.
for nltk_resource in (
    'punkt_tab',
    'averaged_perceptron_tagger_eng',
    'punkt',
    'averaged_perceptron_tagger',
):
    nltk.download(nltk_resource)

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [None]:
# Load the train and test splits from CSV files located in the working directory.
train_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in ('train_en_dataset.csv', 'test_en_dataset.csv')
)

In [None]:
from torch.utils.data import Dataset, DataLoader

class TweetDataset(Dataset):
    """Map-style dataset exposing ``(tweet, label)`` pairs from a tabular frame.

    ``data`` is indexed positionally through ``.iloc`` and must provide a
    ``'tweet'`` column (the text) and a ``'value'`` column (the label).
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Number of rows in the wrapped frame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tweet and its label for position ``idx``.

        ``idx`` is forwarded unchanged to ``.iloc``, so a slice yields the
        matching column slices instead of single values.
        """
        selected = self.data.iloc[idx]
        return selected['tweet'], selected['value']

In [None]:
# Wrap both splits so they can be handed to a DataLoader.
train_dataset, test_dataset = (TweetDataset(frame) for frame in (train_df, test_df))

In [None]:
# Peek at the first five rows; the slice is forwarded to DataFrame.iloc by __getitem__.
train_dataset[:5]

(0    “mansplaining” is literally just how intellige...
 1    if you don’t want me but your friend do, dont ...
 2    @username @username @username @username isn't ...
 3    @username's account is temporarily unavailable...
 4    @username if it wasn't for the gender biases o...
 Name: tweet, dtype: object,
 0    1.0
 1    1.0
 2    1.0
 3    0.0
 4    1.0
 Name: value, dtype: float64)

In [None]:
# Count each class with a vectorised comparison instead of a Python-level loop over the Series.
# int(...) keeps the counts as plain Python ints, which is what the generator-based sum produced,
# so the ratio computed from them stays a Python float.
positive_samples = int((train_df['value'] == 1).sum())
negative_samples = int((train_df['value'] == 0).sum())

In [None]:
# Positive-class weight (negatives / positives) used by the BCE-with-logits training loss.
if positive_samples == 0:
    # Fail with an explicit message instead of an opaque ZeroDivisionError.
    raise ValueError("train_df contains no positive samples; cannot compute pos_weight")

# Pin the dtype and create the tensor directly on the target device.
pos_weight = torch.tensor(
    [negative_samples / positive_samples],
    dtype=torch.float32,
    device=device,
)
pos_weight

tensor([1.5296], device='cuda:0')

In [None]:
# Mini-batches of 16 samples: training batches are shuffled, the test order stays fixed.
BATCH_SIZE = 16
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
# handcrafted_features only reads POS tags, dependency labels and sentence boundaries,
# so the named-entity recogniser is disabled to avoid running it on every text.
nlp = spacy.load('en_core_web_sm', disable=['ner'])



In [None]:
# Dependency labels counted as clause heads, and the pronoun forms counted as women-related.
_CLAUSE_DEPS = frozenset({"csubj", "ccomp", "advcl", "acl", "relcl"})
_WOMEN_GENDERED_PRONOUNS = frozenset({'she', 'her', 'hers'})
# Width of one feature row produced by _doc_syntactic_features.
_NUM_HANDCRAFTED_FEATURES = 5


def _doc_syntactic_features(doc):
    """Compute the five syntactic features of one parsed spaCy ``Doc`` as a list."""
    # Materialise the sentences once; they are needed for two different features.
    sentences = list(doc.sents)
    sentence_count = len(sentences)

    # clauses per sentence
    clause_count = sum(1 for token in doc if token.dep_ in _CLAUSE_DEPS)
    clause_per_sentence = clause_count / sentence_count if sentence_count > 0 else 0

    # count of imperative sentences: the first token is a base-form verb
    imperative_count = sum(
        1 for sent in sentences
        if len(sent) > 0 and sent[0].pos_ == "VERB" and sent[0].tag_ == "VB"
    )

    # count of passive voice usage: a passive subject whose head carries a passive auxiliary
    passive_count = sum(
        1 for token in doc
        if token.dep_ == "nsubjpass"
        and any(child.dep_ == "auxpass" for child in token.head.children)
    )

    # ratio of women-related gendered pronouns to total pronouns
    pronouns = [token.text.lower() for token in doc if token.pos_ == "PRON"]
    gendered_count = sum(1 for pronoun in pronouns if pronoun in _WOMEN_GENDERED_PRONOUNS)
    gendered_pronoun_ratio = gendered_count / len(pronouns) if pronouns else 0

    # count of negations
    neg_count = sum(1 for token in doc if token.dep_ == "neg")

    return [
        clause_per_sentence,
        imperative_count,
        passive_count,
        gendered_pronoun_ratio,
        neg_count,
    ]


def handcrafted_features(texts):
    """Build the handcrafted syntactic feature matrix for a batch of texts.

    Each text is parsed with the module-level spaCy pipeline ``nlp`` and turned
    into five numbers, in this column order: clauses per sentence, imperative
    sentence count, passive construction count, ratio of ``she``/``her``/``hers``
    among all pronouns, and negation count.

    Args:
        texts: iterable of strings (one entry per sample).

    Returns:
        ``torch.float32`` tensor of shape ``(len(texts), 5)``. An empty batch
        yields a ``(0, 5)`` tensor so it can still be concatenated along dim 1.
    """
    texts = list(texts)
    if not texts:
        return torch.empty((0, _NUM_HANDCRAFTED_FEATURES), dtype=torch.float32)

    # nlp.pipe processes the texts as a stream and yields the Docs in input order.
    features = [_doc_syntactic_features(doc) for doc in nlp.pipe(texts)]
    return torch.tensor(features, dtype=torch.float32)

In [None]:
class SemSynSexistDetector(nn.Module):
    """Classifier that combines a frozen BERT encoder with handcrafted syntactic features.

    Raw texts are tokenized and encoded with ``bert-base-uncased``; BERT's pooled
    output is concatenated with the output of ``handcrafted_features`` and passed
    through a small MLP head that emits ``num_classes`` logits per text.

    Args:
        padding: padding strategy forwarded to the tokenizer.
        num_classes: number of output logits per text.
        handcrafted_feature_dim: width of the handcrafted feature vector that is
            concatenated to the BERT output.
    """

    def __init__(self, padding='max_length', num_classes=1, handcrafted_feature_dim=5):
        super().__init__()
        self.padding = padding
        self.berttokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Not used by forward(); kept so the attribute still exists for external code.
        self.pooling = nn.AdaptiveAvgPool1d(1)

        combined_feature_dim = self.bert.config.hidden_size + handcrafted_feature_dim
        self.cls = nn.Sequential(
            nn.Linear(combined_feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Linear(256, num_classes)

        )

        # Freeze the encoder: only the classification head receives gradients.
        for param in self.bert.parameters():
            param.requires_grad = False

    @staticmethod
    def _as_text_list(texts):
        """Normalise the input to a list of strings (a bare string becomes a 1-element batch)."""
        if isinstance(texts, str):
            return [texts]
        return list(texts)

    def tokenize(self, texts):
        """Tokenize ``texts`` and return ``(input_ids, attention_mask)`` on the encoder's device.

        Sequences are truncated to 256 tokens and padded according to ``self.padding``.
        """
        texts = self._as_text_list(texts)
        encoding = self.berttokenizer(
            texts,
            add_special_tokens=True,
            padding=self.padding,
            truncation=True,
            max_length=256,
            return_tensors="pt"
        )
        # Follow the encoder's own device instead of a notebook-level global, so the
        # module keeps working after being moved with .to(...).
        target_device = self.bert.device
        input_ids = encoding['input_ids'].to(target_device)
        attention_mask = encoding['attention_mask'].to(target_device)
        return input_ids, attention_mask

    def forward(self, texts):
        """Return logits of shape ``(len(texts), num_classes)`` for a batch of raw texts."""
        texts = self._as_text_list(texts)
        input_ids, attention_mask = self.tokenize(texts)
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.pooler_output
        syn_sem_features = handcrafted_features(texts).to(pooled.device)
        # the major difference: concatenated features (BERT pooled output + handcrafted features)
        combined_features = torch.cat([pooled, syn_sem_features], dim=1)
        return self.cls(combined_features)

In [None]:
# train function
def train(model, train_loader, test_loader, optimizer,
          scheduler,
          epochs, device, criterion=None):
    """Train ``model`` for ``epochs`` epochs and checkpoint the most accurate epoch.

    After every epoch the model is scored with ``evaluate`` on ``test_loader``;
    whenever the accuracy improves, the whole model is pickled to
    ``best_model.pth``. The scheduler is stepped once per epoch.

    Args:
        model: module called as ``model(texts)`` that returns logits of shape ``(batch, 1)``.
        train_loader: iterable of ``(texts, labels)`` batches used for optimisation.
        test_loader: iterable of ``(texts, labels)`` batches passed to ``evaluate``.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once per epoch.
        epochs: number of passes over ``train_loader``.
        device: device the labels are moved to before the loss is computed.
        criterion: loss taking ``(logits, labels)``. Defaults to
            ``nn.BCEWithLogitsLoss`` weighted with the module-level ``pos_weight``.
    """
    if criterion is None:
        # Built at call time, so the loss reflects the current value of the global
        # pos_weight rather than whatever it was when this function was defined.
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    best_acc = 0

    for epoch in range(epochs):
        # evaluate() puts the model in eval mode; switch back at the start of every
        # epoch, otherwise dropout is disabled from the second epoch onwards.
        model.train()
        total_loss = 0

        # training loop
        for (texts, labels) in tqdm(train_loader):
            labels = labels.to(torch.float32).to(device)
            optimizer.zero_grad()
            logits = model(texts)
            logits = logits.squeeze(1)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError on an empty loader
        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

        # evaluate the model on test_loader after each epoch
        # NOTE(review): the best epoch is selected on the same loader that is later
        # reported as the test result; a separate validation split would avoid this.
        acc, f1 = evaluate(model, test_loader, device)
        print(f"Test Accuracy: {acc:.4f}, F1 Score: {f1:.4f}")

        # if current acc is greater than previous best acc, save a new best model
        if acc > best_acc:
            best_acc = acc
            print(f"New best model found with accuracy: {best_acc:.4f}, saving the model...")
            torch.save(model, "best_model.pth")

        # apply scheduler to adjust the learning rate
        scheduler.step()

    print("Training complete!")

In [None]:
# evaluate model
# Kept at module level for any other cell that uses it; evaluate() calls torch.sigmoid directly.
sigmoid = nn.Sigmoid()

def evaluate(model, dataloader, device, threshold=0.5):
    """Compute accuracy and F1 of ``model`` over ``dataloader``.

    The model is run in eval mode without gradient tracking; its previous
    training/eval mode is restored afterwards, so calling this in the middle of
    a training loop does not silently disable dropout for later epochs.

    Args:
        model: module called as ``model(texts)`` that returns logits of shape ``(batch, 1)``.
        dataloader: iterable of ``(texts, labels)`` batches.
        device: unused here (labels are only needed on the CPU for the metrics);
            kept so existing callers keep working.
        threshold: probability above which a sample is predicted positive.

    Returns:
        ``(accuracy, f1)`` as computed by ``accuracy_score`` and ``f1_score``.
    """
    was_training = model.training
    model.eval()
    all_preds = []
    all_labels = []

    try:
        with torch.no_grad():
            for (texts, labels) in tqdm(dataloader):
                logits = model(texts).squeeze(1)
                probabilities = torch.sigmoid(logits)
                preds = (probabilities > threshold).int()

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)

    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)

    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score: {f1:.4f}")

    return accuracy, f1

In [None]:
# Build the detector and move it onto the selected device; the trailing name displays its layout.
model = SemSynSexistDetector().to(device)
model



model.safetensors:  17%|#6        | 73.4M/440M [00:00<?, ?B/s]

SemSynSexistDetector(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elem

In [None]:
# BERT is frozen in SemSynSexistDetector.__init__, so only the classifier head has
# requires_grad=True; give Adam just the parameters that actually receive gradients.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=0.001)

In [None]:
# Multiply the learning rate by 0.1 after every 20 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    gamma=0.1,
    step_size=20,
)

In [None]:
# Number of epochs passed to train().
epochs = 50

In [None]:
# Run the training loop with train()'s default criterion.
train(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=epochs,
    device=device,
)

100%|██████████| 166/166 [00:32<00:00,  5.07it/s]


Epoch 1/50, Loss: 0.8366


100%|██████████| 42/42 [00:13<00:00,  3.00it/s]


Accuracy: 0.5762
F1 Score: 0.0539
Test Accuracy: 0.5762, F1 Score: 0.0539
New best model found with accuracy: 0.5762, saving the model...


100%|██████████| 166/166 [00:28<00:00,  5.92it/s]


Epoch 2/50, Loss: 0.8087


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.6546
F1 Score: 0.5429
Test Accuracy: 0.6546, F1 Score: 0.5429
New best model found with accuracy: 0.6546, saving the model...


100%|██████████| 166/166 [00:28<00:00,  5.92it/s]


Epoch 3/50, Loss: 0.7821


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.5415
F1 Score: 0.6390
Test Accuracy: 0.5415, F1 Score: 0.6390


100%|██████████| 166/166 [00:28<00:00,  5.92it/s]


Epoch 4/50, Loss: 0.7541


100%|██████████| 42/42 [00:13<00:00,  3.07it/s]


Accuracy: 0.6757
F1 Score: 0.6274
Test Accuracy: 0.6757, F1 Score: 0.6274
New best model found with accuracy: 0.6757, saving the model...


100%|██████████| 166/166 [00:28<00:00,  5.88it/s]


Epoch 5/50, Loss: 0.7486


100%|██████████| 42/42 [00:13<00:00,  3.08it/s]


Accuracy: 0.6908
F1 Score: 0.6333
Test Accuracy: 0.6908, F1 Score: 0.6333
New best model found with accuracy: 0.6908, saving the model...


100%|██████████| 166/166 [00:27<00:00,  5.94it/s]


Epoch 6/50, Loss: 0.7013


100%|██████████| 42/42 [00:13<00:00,  3.09it/s]


Accuracy: 0.7014
F1 Score: 0.7054
Test Accuracy: 0.7014, F1 Score: 0.7054
New best model found with accuracy: 0.7014, saving the model...


100%|██████████| 166/166 [00:28<00:00,  5.91it/s]


Epoch 7/50, Loss: 0.7076


100%|██████████| 42/42 [00:13<00:00,  3.02it/s]


Accuracy: 0.7074
F1 Score: 0.6745
Test Accuracy: 0.7074, F1 Score: 0.6745
New best model found with accuracy: 0.7074, saving the model...


100%|██████████| 166/166 [00:28<00:00,  5.90it/s]


Epoch 8/50, Loss: 0.6635


100%|██████████| 42/42 [00:13<00:00,  3.08it/s]


Accuracy: 0.6938
F1 Score: 0.5762
Test Accuracy: 0.6938, F1 Score: 0.5762


100%|██████████| 166/166 [00:28<00:00,  5.87it/s]


Epoch 9/50, Loss: 0.6681


100%|██████████| 42/42 [00:13<00:00,  3.04it/s]


Accuracy: 0.7210
F1 Score: 0.6816
Test Accuracy: 0.7210, F1 Score: 0.6816
New best model found with accuracy: 0.7210, saving the model...


100%|██████████| 166/166 [00:28<00:00,  5.91it/s]


Epoch 10/50, Loss: 0.6470


100%|██████████| 42/42 [00:13<00:00,  3.03it/s]


Accuracy: 0.7059
F1 Score: 0.7041
Test Accuracy: 0.7059, F1 Score: 0.7041


100%|██████████| 166/166 [00:28<00:00,  5.85it/s]


Epoch 11/50, Loss: 0.6497


100%|██████████| 42/42 [00:13<00:00,  3.07it/s]


Accuracy: 0.7451
F1 Score: 0.7169
Test Accuracy: 0.7451, F1 Score: 0.7169
New best model found with accuracy: 0.7451, saving the model...


100%|██████████| 166/166 [00:28<00:00,  5.90it/s]


Epoch 12/50, Loss: 0.6253


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7451
F1 Score: 0.7372
Test Accuracy: 0.7451, F1 Score: 0.7372


100%|██████████| 166/166 [00:28<00:00,  5.92it/s]


Epoch 13/50, Loss: 0.6121


100%|██████████| 42/42 [00:13<00:00,  3.07it/s]


Accuracy: 0.7210
F1 Score: 0.7114
Test Accuracy: 0.7210, F1 Score: 0.7114


100%|██████████| 166/166 [00:28<00:00,  5.89it/s]


Epoch 14/50, Loss: 0.6143


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7225
F1 Score: 0.6528
Test Accuracy: 0.7225, F1 Score: 0.6528


100%|██████████| 166/166 [00:28<00:00,  5.93it/s]


Epoch 15/50, Loss: 0.6087


100%|██████████| 42/42 [00:13<00:00,  3.02it/s]


Accuracy: 0.7421
F1 Score: 0.6839
Test Accuracy: 0.7421, F1 Score: 0.6839


100%|██████████| 166/166 [00:27<00:00,  5.93it/s]


Epoch 16/50, Loss: 0.5763


100%|██████████| 42/42 [00:13<00:00,  3.05it/s]


Accuracy: 0.7541
F1 Score: 0.7306
Test Accuracy: 0.7541, F1 Score: 0.7306
New best model found with accuracy: 0.7541, saving the model...


100%|██████████| 166/166 [00:28<00:00,  5.88it/s]


Epoch 17/50, Loss: 0.5658


100%|██████████| 42/42 [00:13<00:00,  3.05it/s]


Accuracy: 0.7330
F1 Score: 0.7256
Test Accuracy: 0.7330, F1 Score: 0.7256


100%|██████████| 166/166 [00:28<00:00,  5.88it/s]


Epoch 18/50, Loss: 0.5629


100%|██████████| 42/42 [00:13<00:00,  3.05it/s]


Accuracy: 0.7436
F1 Score: 0.7185
Test Accuracy: 0.7436, F1 Score: 0.7185


100%|██████████| 166/166 [00:28<00:00,  5.91it/s]


Epoch 19/50, Loss: 0.5601


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7466
F1 Score: 0.7299
Test Accuracy: 0.7466, F1 Score: 0.7299


100%|██████████| 166/166 [00:28<00:00,  5.84it/s]


Epoch 20/50, Loss: 0.5582


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7406
F1 Score: 0.7208
Test Accuracy: 0.7406, F1 Score: 0.7208


100%|██████████| 166/166 [00:28<00:00,  5.84it/s]


Epoch 21/50, Loss: 0.5576


100%|██████████| 42/42 [00:13<00:00,  3.05it/s]


Accuracy: 0.7436
F1 Score: 0.7007
Test Accuracy: 0.7436, F1 Score: 0.7007


100%|██████████| 166/166 [00:28<00:00,  5.89it/s]


Epoch 22/50, Loss: 0.5536


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7451
F1 Score: 0.6998
Test Accuracy: 0.7451, F1 Score: 0.6998


100%|██████████| 166/166 [00:28<00:00,  5.89it/s]


Epoch 23/50, Loss: 0.5517


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7345
F1 Score: 0.7152
Test Accuracy: 0.7345, F1 Score: 0.7152


100%|██████████| 166/166 [00:28<00:00,  5.89it/s]


Epoch 24/50, Loss: 0.5526


100%|██████████| 42/42 [00:13<00:00,  3.07it/s]


Accuracy: 0.7496
F1 Score: 0.6795
Test Accuracy: 0.7496, F1 Score: 0.6795


100%|██████████| 166/166 [00:28<00:00,  5.85it/s]


Epoch 25/50, Loss: 0.5524


100%|██████████| 42/42 [00:13<00:00,  3.05it/s]


Accuracy: 0.7466
F1 Score: 0.7209
Test Accuracy: 0.7466, F1 Score: 0.7209


100%|██████████| 166/166 [00:28<00:00,  5.89it/s]


Epoch 26/50, Loss: 0.5501


100%|██████████| 42/42 [00:13<00:00,  3.04it/s]


Accuracy: 0.7466
F1 Score: 0.7123
Test Accuracy: 0.7466, F1 Score: 0.7123


100%|██████████| 166/166 [00:28<00:00,  5.84it/s]


Epoch 27/50, Loss: 0.5491


100%|██████████| 42/42 [00:13<00:00,  3.04it/s]


Accuracy: 0.7391
F1 Score: 0.6872
Test Accuracy: 0.7391, F1 Score: 0.6872


100%|██████████| 166/166 [00:28<00:00,  5.89it/s]


Epoch 28/50, Loss: 0.5423


100%|██████████| 42/42 [00:13<00:00,  3.07it/s]


Accuracy: 0.7436
F1 Score: 0.7038
Test Accuracy: 0.7436, F1 Score: 0.7038


100%|██████████| 166/166 [00:28<00:00,  5.92it/s]


Epoch 29/50, Loss: 0.5409


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7391
F1 Score: 0.7092
Test Accuracy: 0.7391, F1 Score: 0.7092


100%|██████████| 166/166 [00:28<00:00,  5.91it/s]


Epoch 30/50, Loss: 0.5396


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7436
F1 Score: 0.6986
Test Accuracy: 0.7436, F1 Score: 0.6986


100%|██████████| 166/166 [00:28<00:00,  5.91it/s]


Epoch 31/50, Loss: 0.5321


100%|██████████| 42/42 [00:14<00:00,  3.00it/s]


Accuracy: 0.7421
F1 Score: 0.7057
Test Accuracy: 0.7421, F1 Score: 0.7057


100%|██████████| 166/166 [00:28<00:00,  5.89it/s]


Epoch 32/50, Loss: 0.5320


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7406
F1 Score: 0.7075
Test Accuracy: 0.7406, F1 Score: 0.7075


100%|██████████| 166/166 [00:28<00:00,  5.83it/s]


Epoch 33/50, Loss: 0.5310


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7436
F1 Score: 0.7119
Test Accuracy: 0.7436, F1 Score: 0.7119


100%|██████████| 166/166 [00:28<00:00,  5.90it/s]


Epoch 34/50, Loss: 0.5310


100%|██████████| 42/42 [00:13<00:00,  3.05it/s]


Accuracy: 0.7421
F1 Score: 0.7036
Test Accuracy: 0.7421, F1 Score: 0.7036


100%|██████████| 166/166 [00:28<00:00,  5.92it/s]


Epoch 35/50, Loss: 0.5312


100%|██████████| 42/42 [00:13<00:00,  3.07it/s]


Accuracy: 0.7421
F1 Score: 0.7077
Test Accuracy: 0.7421, F1 Score: 0.7077


100%|██████████| 166/166 [00:28<00:00,  5.92it/s]


Epoch 36/50, Loss: 0.5303


100%|██████████| 42/42 [00:13<00:00,  3.07it/s]


Accuracy: 0.7451
F1 Score: 0.7121
Test Accuracy: 0.7451, F1 Score: 0.7121


100%|██████████| 166/166 [00:27<00:00,  5.94it/s]


Epoch 37/50, Loss: 0.5300


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7436
F1 Score: 0.7157
Test Accuracy: 0.7436, F1 Score: 0.7157


100%|██████████| 166/166 [00:28<00:00,  5.91it/s]


Epoch 38/50, Loss: 0.5302


100%|██████████| 42/42 [00:13<00:00,  3.07it/s]


Accuracy: 0.7451
F1 Score: 0.7121
Test Accuracy: 0.7451, F1 Score: 0.7121


100%|██████████| 166/166 [00:28<00:00,  5.89it/s]


Epoch 39/50, Loss: 0.5302


100%|██████████| 42/42 [00:13<00:00,  3.03it/s]


Accuracy: 0.7451
F1 Score: 0.7121
Test Accuracy: 0.7451, F1 Score: 0.7121


100%|██████████| 166/166 [00:28<00:00,  5.90it/s]


Epoch 40/50, Loss: 0.5296


100%|██████████| 42/42 [00:13<00:00,  3.07it/s]


Accuracy: 0.7406
F1 Score: 0.7045
Test Accuracy: 0.7406, F1 Score: 0.7045


100%|██████████| 166/166 [00:28<00:00,  5.88it/s]


Epoch 41/50, Loss: 0.5293


100%|██████████| 42/42 [00:13<00:00,  3.04it/s]


Accuracy: 0.7421
F1 Score: 0.7077
Test Accuracy: 0.7421, F1 Score: 0.7077


100%|██████████| 166/166 [00:28<00:00,  5.90it/s]


Epoch 42/50, Loss: 0.5299


100%|██████████| 42/42 [00:13<00:00,  3.04it/s]


Accuracy: 0.7421
F1 Score: 0.7026
Test Accuracy: 0.7421, F1 Score: 0.7026


100%|██████████| 166/166 [00:28<00:00,  5.88it/s]


Epoch 43/50, Loss: 0.5295


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7451
F1 Score: 0.7111
Test Accuracy: 0.7451, F1 Score: 0.7111


100%|██████████| 166/166 [00:28<00:00,  5.88it/s]


Epoch 44/50, Loss: 0.5290


100%|██████████| 42/42 [00:13<00:00,  3.01it/s]


Accuracy: 0.7406
F1 Score: 0.7045
Test Accuracy: 0.7406, F1 Score: 0.7045


100%|██████████| 166/166 [00:28<00:00,  5.89it/s]


Epoch 45/50, Loss: 0.5286


100%|██████████| 42/42 [00:13<00:00,  3.00it/s]


Accuracy: 0.7406
F1 Score: 0.7045
Test Accuracy: 0.7406, F1 Score: 0.7045


100%|██████████| 166/166 [00:28<00:00,  5.86it/s]


Epoch 46/50, Loss: 0.5280


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7406
F1 Score: 0.7045
Test Accuracy: 0.7406, F1 Score: 0.7045


100%|██████████| 166/166 [00:28<00:00,  5.90it/s]


Epoch 47/50, Loss: 0.5280


100%|██████████| 42/42 [00:13<00:00,  3.02it/s]


Accuracy: 0.7406
F1 Score: 0.7045
Test Accuracy: 0.7406, F1 Score: 0.7045


100%|██████████| 166/166 [00:28<00:00,  5.90it/s]


Epoch 48/50, Loss: 0.5279


100%|██████████| 42/42 [00:13<00:00,  3.07it/s]


Accuracy: 0.7421
F1 Score: 0.7077
Test Accuracy: 0.7421, F1 Score: 0.7077


100%|██████████| 166/166 [00:28<00:00,  5.91it/s]


Epoch 49/50, Loss: 0.5277


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]


Accuracy: 0.7421
F1 Score: 0.7077
Test Accuracy: 0.7421, F1 Score: 0.7077


100%|██████████| 166/166 [00:28<00:00,  5.92it/s]


Epoch 50/50, Loss: 0.5286


100%|██████████| 42/42 [00:13<00:00,  3.06it/s]

Accuracy: 0.7421
F1 Score: 0.7077
Test Accuracy: 0.7421, F1 Score: 0.7077
Training complete!





In [None]:
# load the best model saved during training
# The checkpoint is a fully pickled nn.Module (written by torch.save(model, ...)), so it
# cannot be read with weights_only=True. Unpickling executes arbitrary code: only load
# checkpoint files you created yourself.
best_sem_syn_model = torch.load(
    'best_model.pth',
    map_location=device,   # place the tensors directly on the current device
    weights_only=False,
).to(device)

  best_sem_syn_model = torch.load('best_model.pth').to(device)


In [None]:
# Score the reloaded best checkpoint on the test loader.
evaluate(model=best_sem_syn_model, dataloader=test_loader, device=device)

100%|██████████| 42/42 [00:13<00:00,  3.04it/s]

Accuracy: 0.7541
F1 Score: 0.7306





(0.7541478129713424, 0.7305785123966942)