In [21]:
!pip install setfit==1.1.0
!pip install transformers==4.42.2
!mkdir models




[notice] A new release of pip is available: 24.0 -> 24.3.1
[notice] To update, run: C:\Users\durek\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip





[notice] A new release of pip is available: 24.0 -> 24.3.1
[notice] To update, run: C:\Users\durek\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip
A subdirectory or file models already exists.


In [22]:
import pandas as pd
import time
import numpy as np
import torch
import nltk
import collections
from nltk.corpus import stopwords
from setfit import SetFitModel
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

tqdm.pandas()

In [23]:
# dataset from hf_hub
langs = ['java', 'python', 'pharo']
labels = {
    'java': ['summary', 'Ownership', 'Expand', 'usage', 'Pointer', 'deprecation', 'rational'],
    'python': ['Usage', 'Parameters', 'DevelopmentNotes', 'Expand', 'Summary'],
    'pharo': ['Keyimplementationpoints', 'Example', 'Responsibilities', 'Classreferences', 'Intent', 'Keymessages', 'Collaborators']
}
ds = load_dataset('NLBSE/nlbse25-code-comment-classification')

ds

DatasetDict({
    java_train: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 7614
    })
    java_test: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 1725
    })
    python_train: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 1884
    })
    python_test: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 406
    })
    pharo_train: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 1298
    })
    pharo_test: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 289
    })
})

In [24]:
nltk.download('stopwords')
nltk.download('punkt_tab')
stop_words = set(stopwords.words('english'))
stop_words.add(',')

def remove_stop_words(dataset): 
    unstoppered_dataset = []
    for sentence in dataset:
        words = nltk.word_tokenize(sentence)
        filtered_words = [word for word in words if word not in stop_words]
        unstoppered_dataset.append(' '.join(filtered_words))
    return unstoppered_dataset

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\durek\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\durek\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [26]:
for lan in langs:
    dataset = remove_stop_words(ds[f'{lan}_train']['combo'])
    word_count = dict(collections.Counter(" ".join(dataset).split()))
    tags = [x for x in dataset]
    tag_count = dict(collections.Counter(tags))
    print('total count of words',len(word_count))
    print('total count of tags',len(tag_count))
    most_common_tags = sorted(tag_count.items(), key=lambda x: x[1], reverse=True)[:10]
    most_common_words = sorted(word_count.items(), key=lambda x: x[1], reverse=True)[:10]

    print('Most common tags ', (','.join(tag for tag, _ in most_common_tags)))
    print('Most common words ',  (','.join(word for word, _ in most_common_words)))

total count of words 8855
total count of tags 6759
Most common tags  // $ NON-NLS-1 $ | LanguageSettingsProviderAssociationManager.java,// ok expected . | TestHarFileSystemBasics.java,* @ param fs filesystem | SwiftTestUtils.java,* @ param name name use exception message . | Check.java,// $ NON-NLS-1 $ | MakefileEditor.java,* @ param tUnit translation unit | ASTCache.java,// parse command line | NNThroughputBenchmark.java,@ param owner * Listener variable changes . | PaintTarget.java,@ link CacheLoader | TestingCacheLoaders.java,@ param name * Variable name . | PaintTarget.java
Most common words  |,*,.,@,//,(,),/,param,>
total count of words 2688
total count of tags 1830
Most common tags  traceback recent call last | ConfigDict,brew^func | UseOptimizer,123 | ConfigDict,set 0 fail first retry type . | Retry,1 nan 3 | IntegerArray,versionchanged 025.0 | PlotAccessor,key . | omdict,output input | AdaptiveMaxPool3d,display vid | YouTubeVideo,traceback recent call last | _MockPOP3
Most comm

In [6]:
def get_test_split(dataset_size):
    eighty_split = round(dataset_size * 0.8)
    twenty_split = dataset_size - eighty_split
    return [eighty_split, twenty_split]

def get_train_test(lang):
    X_test = remove_stop_words(ds[f'{lang}_test']['combo'])
    y_test = ds[f'{lang}_test']['labels']
    
    raw_train, raw_val = random_split(ds[f'{lang}_train'], lengths=get_test_split(len(ds[f'{lang}_train'])))

    X_train = remove_stop_words([x['combo'] for x in raw_train])
    y_train = [x['labels'] for x in raw_train]
    
    X_val = remove_stop_words([x['combo'] for x in raw_val])
    y_val = [x['labels'] for x in raw_val]
    return X_test, X_train, X_val, y_test, y_train, y_val

BATCH_SIZE = 32
def get_data_loader(dataset):
    return DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:


class CustomDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

In [8]:


def prepare_data_for_bert(sentences, labels, tokenizer, num_labels):
  tokenized_data = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
  if not isinstance(labels[0], list):
    labels = [[1 if i in label else 0 for i in range(num_labels)] for label in labels]
  tokenized_data["labels"] = torch.tensor(labels, dtype=torch.float32)
  return tokenized_data

def data_collator(features):
  input_ids = torch.stack([f["input_ids"] for f in features])
  attention_mask = torch.stack([f["attention_mask"] for f in features])
  labels = torch.stack([f["labels"] for f in features])
  # print(f"Batch input shape: {input_ids.shape}, Batch labels shape: {labels.shape}")
  return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

for lang in langs:
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
  num_labels = len(labels[lang])
  model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_labels)
  X_test, X_train, X_val, y_test, y_train, y_val = get_train_test(lang)

  train_encodings = prepare_data_for_bert(X_train, y_train, tokenizer, num_labels)
  val_encodings = prepare_data_for_bert(X_val, y_val, tokenizer, num_labels)
  test_encodings = prepare_data_for_bert(X_test, y_test, tokenizer, num_labels)

  train_dataset = CustomDataset(train_encodings)
  val_dataset = CustomDataset(val_encodings)

  

  training_args = TrainingArguments(
      output_dir='./results',
      num_train_epochs=3,
      per_device_train_batch_size=16,
      per_device_eval_batch_size=16,
      warmup_steps=500,
      weight_decay=0.01,
      logging_dir='./logs',
      logging_steps=10,
  )
  
  trainer = Trainer(
      model=model,
      args=training_args,
      train_dataset=train_dataset,
      eval_dataset=val_dataset,
      data_collator=data_collator
  )

  # Train the model
  trainer.train()
  trainer.model.save_pretrained(f'./models/{lan}')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  1%|          | 10/1143 [00:02<05:14,  3.60it/s]

{'loss': 0.6562, 'grad_norm': 2.5559494495391846, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.03}


  2%|▏         | 20/1143 [00:05<05:07,  3.65it/s]

{'loss': 0.6402, 'grad_norm': 2.3762083053588867, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.05}


  3%|▎         | 30/1143 [00:08<05:05,  3.65it/s]

{'loss': 0.6186, 'grad_norm': 2.1695592403411865, 'learning_rate': 3e-06, 'epoch': 0.08}


  3%|▎         | 40/1143 [00:11<05:02,  3.65it/s]

{'loss': 0.5961, 'grad_norm': 2.14536452293396, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.1}


  4%|▍         | 50/1143 [00:13<04:58,  3.66it/s]

{'loss': 0.558, 'grad_norm': 1.9377450942993164, 'learning_rate': 5e-06, 'epoch': 0.13}


  5%|▌         | 60/1143 [00:16<04:55,  3.67it/s]

{'loss': 0.5044, 'grad_norm': 1.5471680164337158, 'learning_rate': 6e-06, 'epoch': 0.16}


  6%|▌         | 70/1143 [00:19<04:52,  3.67it/s]

{'loss': 0.4555, 'grad_norm': 1.4877221584320068, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.18}


  7%|▋         | 80/1143 [00:22<04:50,  3.66it/s]

{'loss': 0.4075, 'grad_norm': 1.0593472719192505, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.21}


  8%|▊         | 90/1143 [00:24<04:47,  3.66it/s]

{'loss': 0.3755, 'grad_norm': 0.9953741431236267, 'learning_rate': 9e-06, 'epoch': 0.24}


  9%|▊         | 100/1143 [00:27<04:45,  3.66it/s]

{'loss': 0.3493, 'grad_norm': 0.6898040175437927, 'learning_rate': 1e-05, 'epoch': 0.26}


 10%|▉         | 110/1143 [00:30<04:42,  3.66it/s]

{'loss': 0.3487, 'grad_norm': 0.7825172543525696, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.29}


 10%|█         | 120/1143 [00:33<04:43,  3.61it/s]

{'loss': 0.3364, 'grad_norm': 0.891537070274353, 'learning_rate': 1.2e-05, 'epoch': 0.31}


 11%|█▏        | 130/1143 [00:35<04:43,  3.57it/s]

{'loss': 0.2954, 'grad_norm': 0.7577946186065674, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.34}


 12%|█▏        | 140/1143 [00:38<04:38,  3.61it/s]

{'loss': 0.3077, 'grad_norm': 0.9871401190757751, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.37}


 13%|█▎        | 150/1143 [00:41<04:35,  3.60it/s]

{'loss': 0.2713, 'grad_norm': 1.0934011936187744, 'learning_rate': 1.5e-05, 'epoch': 0.39}


 14%|█▍        | 160/1143 [00:44<04:36,  3.56it/s]

{'loss': 0.2271, 'grad_norm': 1.0765897035598755, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.42}


 15%|█▍        | 170/1143 [00:47<04:29,  3.62it/s]

{'loss': 0.2731, 'grad_norm': 1.1037184000015259, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.45}


 16%|█▌        | 180/1143 [00:49<04:25,  3.63it/s]

{'loss': 0.2257, 'grad_norm': 0.7831043004989624, 'learning_rate': 1.8e-05, 'epoch': 0.47}


 17%|█▋        | 190/1143 [00:52<04:22,  3.64it/s]

{'loss': 0.2115, 'grad_norm': 0.5637199282646179, 'learning_rate': 1.9e-05, 'epoch': 0.5}


 17%|█▋        | 200/1143 [00:55<04:19,  3.64it/s]

{'loss': 0.1968, 'grad_norm': 0.8413820266723633, 'learning_rate': 2e-05, 'epoch': 0.52}


 18%|█▊        | 210/1143 [00:58<04:16,  3.64it/s]

{'loss': 0.2106, 'grad_norm': 1.023559808731079, 'learning_rate': 2.1e-05, 'epoch': 0.55}


 19%|█▉        | 220/1143 [01:00<04:21,  3.54it/s]

{'loss': 0.2038, 'grad_norm': 0.48777520656585693, 'learning_rate': 2.2000000000000003e-05, 'epoch': 0.58}


 20%|██        | 230/1143 [01:03<04:14,  3.59it/s]

{'loss': 0.1712, 'grad_norm': 1.1615960597991943, 'learning_rate': 2.3000000000000003e-05, 'epoch': 0.6}


 21%|██        | 240/1143 [01:06<04:09,  3.62it/s]

{'loss': 0.1733, 'grad_norm': 1.4166499376296997, 'learning_rate': 2.4e-05, 'epoch': 0.63}


 22%|██▏       | 250/1143 [01:09<04:06,  3.62it/s]

{'loss': 0.1908, 'grad_norm': 1.5184874534606934, 'learning_rate': 2.5e-05, 'epoch': 0.66}


 23%|██▎       | 260/1143 [01:12<04:03,  3.62it/s]

{'loss': 0.1502, 'grad_norm': 1.1327614784240723, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.68}


 24%|██▎       | 270/1143 [01:14<04:02,  3.60it/s]

{'loss': 0.1555, 'grad_norm': 1.1929140090942383, 'learning_rate': 2.7000000000000002e-05, 'epoch': 0.71}


 24%|██▍       | 280/1143 [01:17<03:58,  3.62it/s]

{'loss': 0.1387, 'grad_norm': 0.8877588510513306, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.73}


 25%|██▌       | 290/1143 [01:20<03:57,  3.59it/s]

{'loss': 0.1607, 'grad_norm': 1.0907708406448364, 'learning_rate': 2.9e-05, 'epoch': 0.76}


 26%|██▌       | 300/1143 [01:23<03:53,  3.61it/s]

{'loss': 0.1572, 'grad_norm': 0.9839709401130676, 'learning_rate': 3e-05, 'epoch': 0.79}


 27%|██▋       | 310/1143 [01:25<03:52,  3.58it/s]

{'loss': 0.1836, 'grad_norm': 0.9124348759651184, 'learning_rate': 3.1e-05, 'epoch': 0.81}


 28%|██▊       | 320/1143 [01:28<03:49,  3.58it/s]

{'loss': 0.1413, 'grad_norm': 0.5954126119613647, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.84}


 29%|██▉       | 330/1143 [01:31<03:48,  3.56it/s]

{'loss': 0.1645, 'grad_norm': 0.8562962412834167, 'learning_rate': 3.3e-05, 'epoch': 0.87}


 30%|██▉       | 340/1143 [01:34<03:44,  3.58it/s]

{'loss': 0.1049, 'grad_norm': 0.5353888273239136, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.89}


 31%|███       | 350/1143 [01:37<03:42,  3.56it/s]

{'loss': 0.1863, 'grad_norm': 0.6395153403282166, 'learning_rate': 3.5e-05, 'epoch': 0.92}


 31%|███▏      | 360/1143 [01:40<03:38,  3.58it/s]

{'loss': 0.1306, 'grad_norm': 1.3337671756744385, 'learning_rate': 3.6e-05, 'epoch': 0.94}


 32%|███▏      | 370/1143 [01:42<03:36,  3.57it/s]

{'loss': 0.1664, 'grad_norm': 0.8103851675987244, 'learning_rate': 3.7e-05, 'epoch': 0.97}


 33%|███▎      | 381/1143 [01:45<03:16,  3.87it/s]

{'loss': 0.128, 'grad_norm': 1.3756901025772095, 'learning_rate': 3.8e-05, 'epoch': 1.0}


 34%|███▍      | 390/1143 [01:48<03:27,  3.63it/s]

{'loss': 0.1313, 'grad_norm': 0.8082970976829529, 'learning_rate': 3.9000000000000006e-05, 'epoch': 1.02}


 35%|███▍      | 400/1143 [01:51<03:26,  3.60it/s]

{'loss': 0.1313, 'grad_norm': 1.6404367685317993, 'learning_rate': 4e-05, 'epoch': 1.05}


 36%|███▌      | 410/1143 [01:53<03:22,  3.62it/s]

{'loss': 0.1208, 'grad_norm': 0.5615829825401306, 'learning_rate': 4.1e-05, 'epoch': 1.08}


 37%|███▋      | 420/1143 [01:56<03:21,  3.59it/s]

{'loss': 0.1418, 'grad_norm': 0.5325814485549927, 'learning_rate': 4.2e-05, 'epoch': 1.1}


 38%|███▊      | 430/1143 [01:59<03:16,  3.63it/s]

{'loss': 0.1438, 'grad_norm': 3.0908122062683105, 'learning_rate': 4.3e-05, 'epoch': 1.13}


 38%|███▊      | 440/1143 [02:02<03:14,  3.62it/s]

{'loss': 0.1232, 'grad_norm': 1.0657589435577393, 'learning_rate': 4.4000000000000006e-05, 'epoch': 1.15}


 39%|███▉      | 450/1143 [02:04<03:11,  3.63it/s]

{'loss': 0.1412, 'grad_norm': 0.8883611559867859, 'learning_rate': 4.5e-05, 'epoch': 1.18}


 40%|████      | 460/1143 [02:07<03:08,  3.62it/s]

{'loss': 0.116, 'grad_norm': 0.7698433995246887, 'learning_rate': 4.600000000000001e-05, 'epoch': 1.21}


 41%|████      | 470/1143 [02:10<03:05,  3.63it/s]

{'loss': 0.1071, 'grad_norm': 1.263037085533142, 'learning_rate': 4.7e-05, 'epoch': 1.23}


 42%|████▏     | 480/1143 [02:13<03:03,  3.61it/s]

{'loss': 0.1482, 'grad_norm': 0.7129653096199036, 'learning_rate': 4.8e-05, 'epoch': 1.26}


 43%|████▎     | 490/1143 [02:16<03:03,  3.56it/s]

{'loss': 0.1245, 'grad_norm': 1.3506873846054077, 'learning_rate': 4.9e-05, 'epoch': 1.29}


 44%|████▎     | 500/1143 [02:18<02:59,  3.59it/s]

{'loss': 0.1344, 'grad_norm': 1.4525749683380127, 'learning_rate': 5e-05, 'epoch': 1.31}


 45%|████▍     | 510/1143 [02:22<03:04,  3.42it/s]

{'loss': 0.1197, 'grad_norm': 1.744982361793518, 'learning_rate': 4.922239502332815e-05, 'epoch': 1.34}


 45%|████▌     | 520/1143 [02:25<02:55,  3.55it/s]

{'loss': 0.0945, 'grad_norm': 0.9981407523155212, 'learning_rate': 4.84447900466563e-05, 'epoch': 1.36}


 46%|████▋     | 530/1143 [02:28<02:51,  3.57it/s]

{'loss': 0.0988, 'grad_norm': 1.14034104347229, 'learning_rate': 4.7667185069984446e-05, 'epoch': 1.39}


 47%|████▋     | 540/1143 [02:31<02:48,  3.59it/s]

{'loss': 0.143, 'grad_norm': 1.7059996128082275, 'learning_rate': 4.68895800933126e-05, 'epoch': 1.42}


 48%|████▊     | 550/1143 [02:33<02:45,  3.58it/s]

{'loss': 0.1178, 'grad_norm': 0.7754373550415039, 'learning_rate': 4.6111975116640746e-05, 'epoch': 1.44}


 49%|████▉     | 560/1143 [02:36<02:42,  3.58it/s]

{'loss': 0.0906, 'grad_norm': 0.5735083818435669, 'learning_rate': 4.53343701399689e-05, 'epoch': 1.47}


 50%|████▉     | 570/1143 [02:39<02:40,  3.56it/s]

{'loss': 0.1292, 'grad_norm': 0.853203535079956, 'learning_rate': 4.455676516329705e-05, 'epoch': 1.5}


 51%|█████     | 580/1143 [02:42<02:37,  3.58it/s]

{'loss': 0.1258, 'grad_norm': 0.6337575316429138, 'learning_rate': 4.37791601866252e-05, 'epoch': 1.52}


 52%|█████▏    | 590/1143 [02:45<02:32,  3.63it/s]

{'loss': 0.1145, 'grad_norm': 1.4950188398361206, 'learning_rate': 4.300155520995335e-05, 'epoch': 1.55}


 52%|█████▏    | 600/1143 [02:47<02:31,  3.58it/s]

{'loss': 0.1097, 'grad_norm': 1.6854966878890991, 'learning_rate': 4.22239502332815e-05, 'epoch': 1.57}


 53%|█████▎    | 610/1143 [02:50<02:28,  3.58it/s]

{'loss': 0.1035, 'grad_norm': 0.5249829888343811, 'learning_rate': 4.144634525660964e-05, 'epoch': 1.6}


 54%|█████▍    | 620/1143 [02:53<02:26,  3.58it/s]

{'loss': 0.0985, 'grad_norm': 0.4138135612010956, 'learning_rate': 4.06687402799378e-05, 'epoch': 1.63}


 55%|█████▌    | 630/1143 [02:56<02:23,  3.58it/s]

{'loss': 0.1246, 'grad_norm': 14.820125579833984, 'learning_rate': 3.989113530326594e-05, 'epoch': 1.65}


 56%|█████▌    | 640/1143 [02:59<02:21,  3.56it/s]

{'loss': 0.0869, 'grad_norm': 1.7807871103286743, 'learning_rate': 3.911353032659409e-05, 'epoch': 1.68}


 57%|█████▋    | 650/1143 [03:01<02:17,  3.58it/s]

{'loss': 0.0905, 'grad_norm': 3.343636989593506, 'learning_rate': 3.833592534992224e-05, 'epoch': 1.71}


 58%|█████▊    | 660/1143 [03:04<02:15,  3.58it/s]

{'loss': 0.0797, 'grad_norm': 0.9078494906425476, 'learning_rate': 3.755832037325039e-05, 'epoch': 1.73}


 59%|█████▊    | 670/1143 [03:07<02:12,  3.56it/s]

{'loss': 0.0806, 'grad_norm': 0.13757453858852386, 'learning_rate': 3.678071539657854e-05, 'epoch': 1.76}


 59%|█████▉    | 680/1143 [03:10<02:09,  3.57it/s]

{'loss': 0.1048, 'grad_norm': 1.1911338567733765, 'learning_rate': 3.6003110419906685e-05, 'epoch': 1.78}


 60%|██████    | 690/1143 [03:13<02:06,  3.58it/s]

{'loss': 0.0825, 'grad_norm': 1.6964929103851318, 'learning_rate': 3.522550544323484e-05, 'epoch': 1.81}


 61%|██████    | 700/1143 [03:15<02:03,  3.58it/s]

{'loss': 0.0733, 'grad_norm': 5.420984268188477, 'learning_rate': 3.4447900466562985e-05, 'epoch': 1.84}


 62%|██████▏   | 710/1143 [03:18<02:02,  3.53it/s]

{'loss': 0.1234, 'grad_norm': 1.3768606185913086, 'learning_rate': 3.3670295489891136e-05, 'epoch': 1.86}


 63%|██████▎   | 720/1143 [03:21<01:58,  3.57it/s]

{'loss': 0.1081, 'grad_norm': 1.1406394243240356, 'learning_rate': 3.2892690513219286e-05, 'epoch': 1.89}


 64%|██████▍   | 730/1143 [03:24<01:55,  3.56it/s]

{'loss': 0.1093, 'grad_norm': 0.9759132266044617, 'learning_rate': 3.2115085536547436e-05, 'epoch': 1.92}


 65%|██████▍   | 740/1143 [03:27<01:52,  3.58it/s]

{'loss': 0.0606, 'grad_norm': 1.3364771604537964, 'learning_rate': 3.1337480559875586e-05, 'epoch': 1.94}


 66%|██████▌   | 750/1143 [03:30<01:49,  3.58it/s]

{'loss': 0.1236, 'grad_norm': 0.6501545906066895, 'learning_rate': 3.0559875583203736e-05, 'epoch': 1.97}


 66%|██████▋   | 760/1143 [03:32<01:46,  3.58it/s]

{'loss': 0.1045, 'grad_norm': 0.36730635166168213, 'learning_rate': 2.978227060653188e-05, 'epoch': 1.99}


 67%|██████▋   | 770/1143 [03:35<01:44,  3.56it/s]

{'loss': 0.0735, 'grad_norm': 0.8043469786643982, 'learning_rate': 2.9004665629860033e-05, 'epoch': 2.02}


 68%|██████▊   | 780/1143 [03:38<01:40,  3.60it/s]

{'loss': 0.0612, 'grad_norm': 1.112476110458374, 'learning_rate': 2.822706065318818e-05, 'epoch': 2.05}


 69%|██████▉   | 790/1143 [03:41<01:38,  3.60it/s]

{'loss': 0.0989, 'grad_norm': 1.1165355443954468, 'learning_rate': 2.7449455676516334e-05, 'epoch': 2.07}


 70%|██████▉   | 800/1143 [03:43<01:34,  3.61it/s]

{'loss': 0.0588, 'grad_norm': 0.8731000423431396, 'learning_rate': 2.667185069984448e-05, 'epoch': 2.1}


 71%|███████   | 810/1143 [03:46<01:33,  3.57it/s]

{'loss': 0.073, 'grad_norm': 0.3917151391506195, 'learning_rate': 2.5894245723172627e-05, 'epoch': 2.13}


 72%|███████▏  | 820/1143 [03:49<01:30,  3.55it/s]

{'loss': 0.0959, 'grad_norm': 0.5585570335388184, 'learning_rate': 2.511664074650078e-05, 'epoch': 2.15}


 73%|███████▎  | 830/1143 [03:52<01:27,  3.59it/s]

{'loss': 0.0704, 'grad_norm': 0.40997326374053955, 'learning_rate': 2.4339035769828927e-05, 'epoch': 2.18}


 73%|███████▎  | 840/1143 [03:55<01:25,  3.56it/s]

{'loss': 0.0845, 'grad_norm': 2.0016632080078125, 'learning_rate': 2.3561430793157078e-05, 'epoch': 2.2}


 74%|███████▍  | 850/1143 [03:57<01:21,  3.61it/s]

{'loss': 0.0656, 'grad_norm': 0.33391982316970825, 'learning_rate': 2.2783825816485228e-05, 'epoch': 2.23}


 75%|███████▌  | 860/1143 [04:00<01:18,  3.60it/s]

{'loss': 0.0807, 'grad_norm': 0.5787703394889832, 'learning_rate': 2.2006220839813378e-05, 'epoch': 2.26}


 76%|███████▌  | 870/1143 [04:03<01:16,  3.55it/s]

{'loss': 0.0724, 'grad_norm': 1.323004961013794, 'learning_rate': 2.1228615863141525e-05, 'epoch': 2.28}


 77%|███████▋  | 880/1143 [04:06<01:13,  3.57it/s]

{'loss': 0.1046, 'grad_norm': 2.0879316329956055, 'learning_rate': 2.0451010886469675e-05, 'epoch': 2.31}


 78%|███████▊  | 890/1143 [04:09<01:10,  3.57it/s]

{'loss': 0.0797, 'grad_norm': 1.971909999847412, 'learning_rate': 1.9673405909797825e-05, 'epoch': 2.34}


 79%|███████▊  | 900/1143 [04:11<01:08,  3.56it/s]

{'loss': 0.0833, 'grad_norm': 4.060495853424072, 'learning_rate': 1.8895800933125972e-05, 'epoch': 2.36}


 80%|███████▉  | 910/1143 [04:14<01:05,  3.56it/s]

{'loss': 0.0658, 'grad_norm': 0.7200765609741211, 'learning_rate': 1.8118195956454122e-05, 'epoch': 2.39}


 80%|████████  | 920/1143 [04:17<01:02,  3.57it/s]

{'loss': 0.0719, 'grad_norm': 0.42006897926330566, 'learning_rate': 1.734059097978227e-05, 'epoch': 2.41}


 81%|████████▏ | 930/1143 [04:20<00:59,  3.58it/s]

{'loss': 0.0437, 'grad_norm': 0.3484354317188263, 'learning_rate': 1.656298600311042e-05, 'epoch': 2.44}


 82%|████████▏ | 940/1143 [04:23<00:56,  3.58it/s]

{'loss': 0.0761, 'grad_norm': 1.4582829475402832, 'learning_rate': 1.578538102643857e-05, 'epoch': 2.47}


 83%|████████▎ | 950/1143 [04:26<00:54,  3.57it/s]

{'loss': 0.084, 'grad_norm': 0.8006773591041565, 'learning_rate': 1.500777604976672e-05, 'epoch': 2.49}


 84%|████████▍ | 960/1143 [04:28<00:51,  3.57it/s]

{'loss': 0.072, 'grad_norm': 0.7692151069641113, 'learning_rate': 1.423017107309487e-05, 'epoch': 2.52}


 85%|████████▍ | 970/1143 [04:31<00:48,  3.56it/s]

{'loss': 0.0552, 'grad_norm': 0.4100859761238098, 'learning_rate': 1.3452566096423016e-05, 'epoch': 2.55}


 86%|████████▌ | 980/1143 [04:34<00:45,  3.58it/s]

{'loss': 0.0663, 'grad_norm': 1.0901354551315308, 'learning_rate': 1.2674961119751166e-05, 'epoch': 2.57}


 87%|████████▋ | 990/1143 [04:37<00:42,  3.57it/s]

{'loss': 0.0915, 'grad_norm': 0.9091700911521912, 'learning_rate': 1.1897356143079317e-05, 'epoch': 2.6}


 87%|████████▋ | 1000/1143 [04:40<00:40,  3.57it/s]

{'loss': 0.052, 'grad_norm': 1.8873369693756104, 'learning_rate': 1.1119751166407467e-05, 'epoch': 2.62}


 88%|████████▊ | 1010/1143 [04:43<00:39,  3.38it/s]

{'loss': 0.0822, 'grad_norm': 1.2633287906646729, 'learning_rate': 1.0342146189735615e-05, 'epoch': 2.65}


 89%|████████▉ | 1020/1143 [04:46<00:34,  3.54it/s]

{'loss': 0.0433, 'grad_norm': 0.23714007437229156, 'learning_rate': 9.564541213063764e-06, 'epoch': 2.68}


 90%|█████████ | 1030/1143 [04:49<00:31,  3.59it/s]

{'loss': 0.109, 'grad_norm': 1.0525039434432983, 'learning_rate': 8.786936236391912e-06, 'epoch': 2.7}


 91%|█████████ | 1040/1143 [04:52<00:28,  3.60it/s]

{'loss': 0.0677, 'grad_norm': 0.9635331034660339, 'learning_rate': 8.009331259720062e-06, 'epoch': 2.73}


 92%|█████████▏| 1050/1143 [04:55<00:26,  3.57it/s]

{'loss': 0.0535, 'grad_norm': 0.9832435250282288, 'learning_rate': 7.2317262830482126e-06, 'epoch': 2.76}


 93%|█████████▎| 1060/1143 [04:57<00:23,  3.57it/s]

{'loss': 0.0781, 'grad_norm': 1.763541340827942, 'learning_rate': 6.454121306376361e-06, 'epoch': 2.78}


 94%|█████████▎| 1070/1143 [05:00<00:20,  3.57it/s]

{'loss': 0.0595, 'grad_norm': 2.8899402618408203, 'learning_rate': 5.67651632970451e-06, 'epoch': 2.81}


 94%|█████████▍| 1080/1143 [05:03<00:17,  3.60it/s]

{'loss': 0.0743, 'grad_norm': 1.3253012895584106, 'learning_rate': 4.89891135303266e-06, 'epoch': 2.83}


 95%|█████████▌| 1090/1143 [05:06<00:14,  3.60it/s]

{'loss': 0.0754, 'grad_norm': 2.0013554096221924, 'learning_rate': 4.121306376360809e-06, 'epoch': 2.86}


 96%|█████████▌| 1100/1143 [05:09<00:11,  3.59it/s]

{'loss': 0.0541, 'grad_norm': 0.3629354238510132, 'learning_rate': 3.343701399688958e-06, 'epoch': 2.89}


 97%|█████████▋| 1110/1143 [05:11<00:09,  3.59it/s]

{'loss': 0.0748, 'grad_norm': 0.3842093348503113, 'learning_rate': 2.5660964230171077e-06, 'epoch': 2.91}


 98%|█████████▊| 1120/1143 [05:14<00:06,  3.60it/s]

{'loss': 0.0727, 'grad_norm': 3.0310862064361572, 'learning_rate': 1.7884914463452568e-06, 'epoch': 2.94}


 99%|█████████▉| 1130/1143 [05:17<00:03,  3.60it/s]

{'loss': 0.0415, 'grad_norm': 1.5247738361358643, 'learning_rate': 1.010886469673406e-06, 'epoch': 2.97}


100%|█████████▉| 1140/1143 [05:20<00:00,  3.60it/s]

{'loss': 0.0402, 'grad_norm': 1.431164264678955, 'learning_rate': 2.3328149300155523e-07, 'epoch': 2.99}


100%|██████████| 1143/1143 [05:21<00:00,  3.55it/s]


{'train_runtime': 321.937, 'train_samples_per_second': 56.76, 'train_steps_per_second': 3.55, 'train_loss': 0.15535265301156231, 'epoch': 3.0}


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  5%|▍         | 14/285 [00:00<00:14, 18.80it/s]

{'loss': 0.716, 'grad_norm': 2.6184709072113037, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.11}


  8%|▊         | 24/285 [00:01<00:13, 19.26it/s]

{'loss': 0.7098, 'grad_norm': 2.696014165878296, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.21}


 11%|█         | 32/285 [00:01<00:12, 19.66it/s]

{'loss': 0.6948, 'grad_norm': 2.038691520690918, 'learning_rate': 3e-06, 'epoch': 0.32}


 15%|█▌        | 44/285 [00:02<00:12, 20.02it/s]

{'loss': 0.6661, 'grad_norm': 2.24939227104187, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.42}


 19%|█▊        | 53/285 [00:02<00:11, 19.92it/s]

{'loss': 0.6475, 'grad_norm': 2.0555713176727295, 'learning_rate': 5e-06, 'epoch': 0.53}


 22%|██▏       | 62/285 [00:03<00:11, 19.66it/s]

{'loss': 0.6068, 'grad_norm': 1.7495758533477783, 'learning_rate': 6e-06, 'epoch': 0.63}


 26%|██▌       | 74/285 [00:03<00:10, 19.73it/s]

{'loss': 0.5799, 'grad_norm': 3.8589351177215576, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.74}


 29%|██▉       | 84/285 [00:04<00:10, 19.99it/s]

{'loss': 0.5305, 'grad_norm': 3.9420573711395264, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.84}


 33%|███▎      | 94/285 [00:04<00:09, 19.87it/s]

{'loss': 0.5265, 'grad_norm': 1.8916069269180298, 'learning_rate': 9e-06, 'epoch': 0.95}


 36%|███▋      | 104/285 [00:05<00:09, 19.78it/s]

{'loss': 0.5074, 'grad_norm': 1.112158179283142, 'learning_rate': 1e-05, 'epoch': 1.05}


 40%|████      | 114/285 [00:05<00:08, 20.02it/s]

{'loss': 0.4794, 'grad_norm': 1.8169327974319458, 'learning_rate': 1.1000000000000001e-05, 'epoch': 1.16}


 43%|████▎     | 123/285 [00:06<00:08, 19.99it/s]

{'loss': 0.4821, 'grad_norm': 1.6952779293060303, 'learning_rate': 1.2e-05, 'epoch': 1.26}


 47%|████▋     | 134/285 [00:06<00:07, 19.93it/s]

{'loss': 0.4805, 'grad_norm': 1.3621374368667603, 'learning_rate': 1.3000000000000001e-05, 'epoch': 1.37}


 50%|█████     | 143/285 [00:07<00:07, 19.93it/s]

{'loss': 0.4662, 'grad_norm': 1.227439045906067, 'learning_rate': 1.4000000000000001e-05, 'epoch': 1.47}


 53%|█████▎    | 152/285 [00:07<00:06, 19.96it/s]

{'loss': 0.4696, 'grad_norm': 1.5330435037612915, 'learning_rate': 1.5e-05, 'epoch': 1.58}


 58%|█████▊    | 164/285 [00:08<00:06, 20.09it/s]

{'loss': 0.4473, 'grad_norm': 1.8720934391021729, 'learning_rate': 1.6000000000000003e-05, 'epoch': 1.68}


 61%|██████    | 173/285 [00:08<00:05, 20.04it/s]

{'loss': 0.4482, 'grad_norm': 2.615161895751953, 'learning_rate': 1.7000000000000003e-05, 'epoch': 1.79}


 64%|██████▍   | 182/285 [00:09<00:05, 19.93it/s]

{'loss': 0.4414, 'grad_norm': 2.3138928413391113, 'learning_rate': 1.8e-05, 'epoch': 1.89}


 68%|██████▊   | 194/285 [00:09<00:04, 20.38it/s]

{'loss': 0.4414, 'grad_norm': 3.372636556625366, 'learning_rate': 1.9e-05, 'epoch': 2.0}


 71%|███████   | 203/285 [00:10<00:04, 20.16it/s]

{'loss': 0.3911, 'grad_norm': 1.9696252346038818, 'learning_rate': 2e-05, 'epoch': 2.11}


 74%|███████▍  | 212/285 [00:10<00:03, 19.90it/s]

{'loss': 0.4081, 'grad_norm': 2.5427281856536865, 'learning_rate': 2.1e-05, 'epoch': 2.21}


 79%|███████▊  | 224/285 [00:11<00:03, 20.19it/s]

{'loss': 0.3947, 'grad_norm': 2.2675938606262207, 'learning_rate': 2.2000000000000003e-05, 'epoch': 2.32}


 82%|████████▏ | 233/285 [00:11<00:02, 20.02it/s]

{'loss': 0.3791, 'grad_norm': 3.883298635482788, 'learning_rate': 2.3000000000000003e-05, 'epoch': 2.42}


 85%|████████▍ | 242/285 [00:12<00:02, 19.88it/s]

{'loss': 0.3965, 'grad_norm': 2.2502846717834473, 'learning_rate': 2.4e-05, 'epoch': 2.53}


 89%|████████▉ | 254/285 [00:12<00:01, 20.02it/s]

{'loss': 0.385, 'grad_norm': 3.692549705505371, 'learning_rate': 2.5e-05, 'epoch': 2.63}


 92%|█████████▏| 263/285 [00:13<00:01, 19.99it/s]

{'loss': 0.3683, 'grad_norm': 3.119147539138794, 'learning_rate': 2.6000000000000002e-05, 'epoch': 2.74}


 95%|█████████▌| 272/285 [00:13<00:00, 19.92it/s]

{'loss': 0.3795, 'grad_norm': 3.0481414794921875, 'learning_rate': 2.7000000000000002e-05, 'epoch': 2.84}


100%|█████████▉| 284/285 [00:14<00:00, 19.98it/s]

{'loss': 0.4013, 'grad_norm': 2.672109603881836, 'learning_rate': 2.8000000000000003e-05, 'epoch': 2.95}


100%|██████████| 285/285 [00:15<00:00, 18.64it/s]


{'train_runtime': 15.292, 'train_samples_per_second': 295.644, 'train_steps_per_second': 18.637, 'train_loss': 0.49238941042046797, 'epoch': 3.0}


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  6%|▌         | 12/195 [00:01<00:16, 11.22it/s]

{'loss': 0.6661, 'grad_norm': 2.1690027713775635, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.15}


 11%|█▏        | 22/195 [00:02<00:14, 11.67it/s]

{'loss': 0.6533, 'grad_norm': 2.344601631164551, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.31}


 16%|█▋        | 32/195 [00:02<00:13, 11.73it/s]

{'loss': 0.6451, 'grad_norm': 2.1031975746154785, 'learning_rate': 3e-06, 'epoch': 0.46}


 22%|██▏       | 42/195 [00:03<00:13, 11.67it/s]

{'loss': 0.6128, 'grad_norm': 2.2305824756622314, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.62}


 27%|██▋       | 52/195 [00:04<00:12, 11.70it/s]

{'loss': 0.5786, 'grad_norm': 2.283547878265381, 'learning_rate': 5e-06, 'epoch': 0.77}


 32%|███▏      | 62/195 [00:05<00:11, 11.76it/s]

{'loss': 0.5306, 'grad_norm': 1.7833751440048218, 'learning_rate': 6e-06, 'epoch': 0.92}


 37%|███▋      | 72/195 [00:06<00:10, 11.86it/s]

{'loss': 0.4896, 'grad_norm': 1.6686160564422607, 'learning_rate': 7.000000000000001e-06, 'epoch': 1.08}


 42%|████▏     | 82/195 [00:07<00:09, 11.71it/s]

{'loss': 0.4657, 'grad_norm': 1.2183144092559814, 'learning_rate': 8.000000000000001e-06, 'epoch': 1.23}


 47%|████▋     | 92/195 [00:07<00:08, 11.71it/s]

{'loss': 0.4415, 'grad_norm': 1.2614117860794067, 'learning_rate': 9e-06, 'epoch': 1.38}


 52%|█████▏    | 102/195 [00:08<00:07, 11.71it/s]

{'loss': 0.4224, 'grad_norm': 1.1729240417480469, 'learning_rate': 1e-05, 'epoch': 1.54}


 57%|█████▋    | 112/195 [00:09<00:07, 11.76it/s]

{'loss': 0.4128, 'grad_norm': 1.4572575092315674, 'learning_rate': 1.1000000000000001e-05, 'epoch': 1.69}


 63%|██████▎   | 122/195 [00:10<00:06, 11.73it/s]

{'loss': 0.3873, 'grad_norm': 0.68861985206604, 'learning_rate': 1.2e-05, 'epoch': 1.85}


 68%|██████▊   | 132/195 [00:11<00:05, 11.74it/s]

{'loss': 0.3698, 'grad_norm': 1.1319215297698975, 'learning_rate': 1.3000000000000001e-05, 'epoch': 2.0}


 73%|███████▎  | 142/195 [00:12<00:04, 11.79it/s]

{'loss': 0.3679, 'grad_norm': 1.2437562942504883, 'learning_rate': 1.4000000000000001e-05, 'epoch': 2.15}


 78%|███████▊  | 152/195 [00:13<00:03, 11.74it/s]

{'loss': 0.3499, 'grad_norm': 0.9065372943878174, 'learning_rate': 1.5e-05, 'epoch': 2.31}


 83%|████████▎ | 162/195 [00:13<00:02, 11.79it/s]

{'loss': 0.3484, 'grad_norm': 1.6505500078201294, 'learning_rate': 1.6000000000000003e-05, 'epoch': 2.46}


 88%|████████▊ | 172/195 [00:14<00:01, 11.70it/s]

{'loss': 0.3306, 'grad_norm': 1.658427119255066, 'learning_rate': 1.7000000000000003e-05, 'epoch': 2.62}


 93%|█████████▎| 182/195 [00:15<00:01, 11.79it/s]

{'loss': 0.3509, 'grad_norm': 1.242890477180481, 'learning_rate': 1.8e-05, 'epoch': 2.77}


 98%|█████████▊| 192/195 [00:16<00:00, 11.75it/s]

{'loss': 0.3203, 'grad_norm': 1.4901281595230103, 'learning_rate': 1.9e-05, 'epoch': 2.92}


100%|██████████| 195/195 [00:17<00:00, 11.07it/s]


{'train_runtime': 17.6151, 'train_samples_per_second': 176.78, 'train_steps_per_second': 11.07, 'train_loss': 0.4557581962683262, 'epoch': 3.0}


In [9]:
total_flops = 0
total_time = 0
scores = []
for lan in langs:
    # to load trained models:
    model = SetFitModel.from_pretrained(f'./models/{lan}')
    # print(f"combo: {len(ds[f'{lan}_test']['combo'])}, label: {len(ds[f'{lan}_test']['labels'])}")
    # to load pretrained models from Hub:
    # model = SetFitModel.from_pretrained(f"NLBSE/nlbse25_{lan}")
    with torch.profiler.profile(with_flops=True) as p:
        begin = time.time()
        for i in range(10):
          y_pred = model(ds[f'{lan}_test']['combo']).numpy().T
        total = time.time() - begin
        total_time = total_time + total
    total_flops = total_flops + (sum(k.flops for k in p.key_averages()) / 1e9)
    y_true = np.array(ds[f'{lan}_test']['labels']).T
    for i in range(len(y_pred)):
        assert(len(y_pred[i]) == len(y_true[i]))
        tp = sum([true == pred == 1 for (true,pred) in zip(y_true[i], y_pred[i])])
        tn = sum([true == pred == 0 for (true,pred) in zip(y_true[i], y_pred[i])])
        fp = sum([true == 0 and pred == 1 for (true,pred) in zip(y_true[i], y_pred[i])])
        fn = sum([true == 1 and pred == 0 for (true,pred) in zip(y_true[i], y_pred[i])])
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = (2*tp) / (2*tp + fp + fn)
        scores.append({'lan': lan, 'cat': labels[lan][i],'precision': precision,'recall': recall,'f1': f1})
        print({'lan': lan, 'cat': labels[lan][i],'precision': precision,'recall': recall,'f1': f1})
print("Compute in GFLOPs:", total_flops/10)
print("Avg runtime in seconds:", total_time/10)
scores = pd.DataFrame(scores)
scores

HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': './models/java'. Use `repo_type` argument if needed.

In [10]:
print("Compute in GFLOPs:", total_flops/10)
print("Avg runtime in seconds:", total_time/10)
scores = pd.DataFrame(scores)

max_avg_runtime = 5
max_avg_flops = 5000

def score(avg_f1, avg_runtime, avg_flops):
    return (0.6 * avg_f1 +
      0.2 * ((max_avg_runtime - avg_runtime) / max_avg_runtime) +
      0.2 * ((max_avg_flops - avg_flops) / max_avg_flops))

avg_f1 = scores.f1.mean()
avg_runtime = total_time/10
avg_flops = total_flops/10

round(score(avg_f1, avg_runtime, avg_flops), 2)

Compute in GFLOPs: 0.0
Avg runtime in seconds: 0.0


AttributeError: 'DataFrame' object has no attribute 'f1'