In [180]:
from transformers import AutoTokenizer, AdamW, AutoModelForSequenceClassification, get_scheduler
from datasets import load_metric
from accelerate import Accelerator
from torch.utils.data import DataLoader
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split

### Data Processing

In [181]:
df = pd.read_csv("ind.csv")
df = df.drop_duplicates()
df.shape

(79, 2)

In [182]:
checkpoint = "bert-base-uncased"

In [183]:
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [184]:
texts = df["name"].tolist()
labels = df["label"].tolist()

In [185]:
train_texts, test_texts, train_labels, test_labels = train_test_split(texts, labels, test_size=0.2)

In [186]:
tokenized_train = tokenizer(train_texts, padding=True, truncation=True)
tokenized_test = tokenizer(test_texts, padding=True, truncation=True)

In [187]:
class CustomDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.texts.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item

In [188]:
train_ds = CustomDataset(tokenized_train, train_labels)
test_ds = CustomDataset(tokenized_test, test_labels)

In [189]:
batch_size = 8

train_dl = DataLoader(train_ds, shuffle=True, batch_size=batch_size)
test_dl = DataLoader(test_ds, shuffle=True, batch_size=batch_size)

### Load Model

In [190]:
# define accelerator that eases train-test process
accelerator = Accelerator()

In [191]:
# define model
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=1)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [192]:
# define optimizer
optimizer = AdamW(model.parameters(), lr=0.001)



In [193]:
# set dataloaders, model and optimizer to accelerator (to gpu or smth)
train_dl, test_dl, model, optimizer = accelerator.prepare(train_dl, test_dl, model, optimizer)

In [194]:
# define epochs
epochs = 100
training_steps = epochs * len(train_dl)

In [195]:
# define scheduler for training
scheduler = get_scheduler(
    "linear",
    optimizer = optimizer,
    num_warmup_steps=0,
    num_training_steps=training_steps
)

### Train Model

In [196]:
# set metrics
metric = load_metric("accuracy")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [197]:
for epoch in tqdm(range(epochs)):
    
    # TRAINING
    model.train()

    for batch in train_dl:

        # forward pass
        output = model(**batch)
        loss = output.loss
        
        # metrics
        logits = output.logits
        preds = torch.round(logits)
        metric.add_batch(predictions=preds, references=batch["labels"])
        
        # backprop
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    print("Training: ", metric.compute())

    
    # TESTING
    model.eval()

    with torch.inference_mode():
        for batch in test_dl:

            # forward pass
            output = model(**batch)
            
            # metrics
            logits = output.logits
            preds = torch.round(logits)
            metric.add_batch(predictions=preds, references=batch["labels"])

    print("Test: ", metric.compute())



Training:  {'accuracy': 0.38095238095238093}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.3333333333333333}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5714285714285714}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5238095238095238}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.6031746031746031}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5079365079365079}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5238095238095238}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4603174603174603}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5079365079365079}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5238095238095238}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.38095238095238093}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.4126984126984127}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5238095238095238}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.38095238095238093}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.47619047619047616}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5396825396825397}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5873015873015873}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5555555555555556}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4603174603174603}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.3492063492063492}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5396825396825397}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5396825396825397}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5079365079365079}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5555555555555556}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5238095238095238}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.3968253968253968}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.4603174603174603}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5873015873015873}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5079365079365079}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5238095238095238}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5714285714285714}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5714285714285714}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.3968253968253968}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5238095238095238}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.47619047619047616}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5555555555555556}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4126984126984127}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5238095238095238}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4603174603174603}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5238095238095238}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5079365079365079}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.6031746031746031}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4444444444444444}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5396825396825397}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5873015873015873}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5396825396825397}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5555555555555556}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.4603174603174603}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.6349206349206349}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5555555555555556}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5396825396825397}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4444444444444444}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4603174603174603}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.38095238095238093}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5079365079365079}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.4603174603174603}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4126984126984127}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.36507936507936506}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.47619047619047616}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.6190476190476191}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.5555555555555556}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5873015873015873}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4444444444444444}




Test:  {'accuracy': 0.5625}
Training:  {'accuracy': 0.4603174603174603}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.42857142857142855}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5555555555555556}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4126984126984127}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.49206349206349204}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4603174603174603}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5555555555555556}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5238095238095238}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.4444444444444444}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5714285714285714}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.5079365079365079}




Test:  {'accuracy': 0.4375}
Training:  {'accuracy': 0.47619047619047616}


100%|██████████| 100/100 [04:35<00:00,  2.75s/it]

Test:  {'accuracy': 0.4375}



