In [1]:
from tqdm import tqdm
import jsonlines
from datasets import load_dataset
import numpy as np 
import pandas as pd 
from sklearn.metrics import f1_score, accuracy_score
from datasets import load_dataset
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
# Half-precision flag: only enabled when CUDA is available.
# NOTE(review): not referenced by any cell in this notebook view — wire it into
# training (e.g. torch.autocast) or drop it.
fp16 = torch.cuda.is_available()

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [2]:
train_output_file_path = "train_embeddings.jsonl"
val_output_file_path = "val_embeddings.jsonl"

# Load each JSONL file as its own in-memory dataset. With a single data file the
# "json" builder exposes everything under the split named "train", which is why
# both loads ask for split="train".
train, val = (
    load_dataset("json", data_files=path, streaming=False, split="train")
    for path in (train_output_file_path, val_output_file_path)
)

In [3]:
def max_pool_documents(dataset, desc=None):
    """Max-pool every document's embedding matrix into a single vector.

    Args:
        dataset: iterable of records with an "embeddings" field (assumed to be
            a 2-D list: rows x embedding dim — TODO confirm) and a "labels"
            field (one entry per class).
        desc: optional label shown on the tqdm progress bar.

    Returns:
        (pooled, labels): two parallel lists of tensors, one entry per record.
    """
    pooled, labels = [], []
    for document in tqdm(dataset, desc=desc):
        embeddings = torch.tensor(document["embeddings"])
        # Element-wise max over dim 0 collapses all rows into one vector.
        emb_max, _ = torch.max(embeddings, dim=0)
        pooled.append(emb_max)
        labels.append(torch.tensor(document["labels"]))
    return pooled, labels


train_max_pooled, train_labels = max_pool_documents(train, desc="train")
val_max_pooled, val_labels = max_pool_documents(val, desc="val")

100%|█████████████████████████████████████████████████████████████████████████| 307102/307102 [07:37<00:00, 671.34it/s]
100%|███████████████████████████████████████████████████████████████████████████| 17104/17104 [00:26<00:00, 639.60it/s]


In [4]:
from torch.utils.data import TensorDataset, DataLoader
batch_size = 32
train = TensorDataset(torch.stack(train_max_pooled), torch.stack(train_labels))
val = TensorDataset(torch.stack(val_max_pooled), torch.stack(val_labels))
# Reshuffle the training set every epoch so mini-batches are not drawn in the
# fixed on-disk order; the seeded generator keeps the shuffle reproducible.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
# Accuracy/F1 are order-independent, so validation stays in file order.
val_loader = DataLoader(val, batch_size=batch_size)

In [5]:
import torch.nn as nn

class CustomMultiLabelClassifier(nn.Module):
    """Two-layer MLP mapping a feature vector to one logit per class.

    The model returns raw logits (no sigmoid). Pair it with a loss that takes
    logits, such as BCEWithLogitsLoss. At inference a logit > 0 corresponds to
    a sigmoid probability > 0.5.

    Args:
        input_size: dimensionality of the input feature vector.
        num_classes: number of output labels.
        hidden_size: width of the hidden layer (default 64, the original width).
    """

    def __init__(self, input_size, num_classes, hidden_size=64):
        super().__init__()
        # Attribute names fc1 / relu1 / fc3 are kept so state_dict keys are
        # unchanged from earlier versions of this model.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape (..., num_classes) for input (..., input_size)."""
        return self.fc3(self.relu1(self.fc1(x)))

In [6]:
from util_loss import ResampleLoss
from collections import Counter
np.set_printoptions(suppress=True)

train_num = len(train)
print(train_num)
# Per-class positive counts, derived from the stacked training label matrix
# rather than a hard-coded list that can silently drift out of sync with the
# data. Assumes labels are 0/1 indicator rows (one column per class).
class_freq = train.tensors[1].sum(dim=0).tolist()
# A zero count would make any inverse-frequency reweighting divide by zero.
assert all(count > 0 for count in class_freq), "every class needs at least one positive training example"
print(class_freq)

loss_func = ResampleLoss(reweight_func='rebalance', loss_weight=1.0,
                         focal=dict(focal=True, alpha=0.5, gamma=2),
                         logit_reg=dict(init_bias=0.05, neg_scale=2.0),
                         map_param=dict(alpha=0.1, beta=10.0, gamma=0.9), 
                         class_freq=class_freq, train_num=train_num)
loss_func

307102
[92768, 11607, 7541, 12527, 8960, 5859, 2845, 4794, 2845, 5137, 3556, 2094, 1763, 3122, 1798, 1527, 739, 469, 472, 506, 321, 144, 230, 181, 166, 148, 296, 107, 136, 70, 110, 96]


ResampleLoss()

In [None]:
# Seed before building the model so weight initialisation is reproducible.
torch.manual_seed(42)
# Take layer sizes from the stacked training tensors instead of hard-coding them.
input_size = train.tensors[0].shape[1]
num_classes = train.tensors[1].shape[1]
model = CustomMultiLabelClassifier(input_size=input_size, num_classes=num_classes)
model.to(device)

# Plain BCE-with-logits baseline.
# NOTE(review): the training loop uses loss_func (ResampleLoss), so this
# criterion is currently unused.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def evaluate(model, val_loader, threshold=0.0):
    """Score a multi-label classifier on a validation loader.

    Args:
        model: module mapping an input batch to one logit per class.
        val_loader: iterable of (inputs, labels) batches. Labels are assumed to
            be 0/1 indicator rows, one column per class.
        threshold: a class is predicted when its logit is strictly greater than
            this value. The default 0.0 is a sigmoid probability of 0.5.

    Returns:
        (accuracy, f1_micro, f1_macro). For multi-label input, scikit-learn's
        accuracy_score is subset accuracy: a sample counts only if all of its
        labels are predicted correctly.

    Leaves the model in eval mode. Callers that continue training must call
    model.train() again.
    """
    model.eval()
    pred_batches, label_batches = [], []
    with torch.no_grad():
        for inputs, labels in val_loader:
            logits = model(inputs.to(device))
            pred_batches.append((logits > threshold).int().cpu())
            label_batches.append(labels.cpu())
    # One concatenation per side instead of growing Python lists row by row.
    all_preds = torch.cat(pred_batches).numpy()
    all_labels = torch.cat(label_batches).numpy()
    acc = accuracy_score(all_labels, all_preds)
    # zero_division=0 makes the implicit "undefined F1 counts as 0" explicit
    # and silences warnings for rare classes that are never predicted.
    f1_micro = f1_score(all_labels, all_preds, average='micro', zero_division=0)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    return acc, f1_micro, f1_macro

# Training loop
epochs = 15
for epoch in range(epochs):
    # evaluate() switches the model to eval mode, so re-enable training mode each epoch.
    model.train()
    train_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch"):
        optimizer.zero_grad()
        logits = model(inputs.to(device))
        # Move targets to the logits' device and dtype explicitly, instead of
        # relying on Tensor.type_as to perform the host->device copy.
        targets = labels.to(device=logits.device, dtype=logits.dtype)
        loss = loss_func(logits, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)
    acc, f1_micro, f1_macro = evaluate(model, val_loader)

    print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {train_loss:.4f}, Validation Accuracy: {acc:.4f}, Validation F1 Micro: {f1_micro:.4f}, Validation F1 Macro: {f1_macro:.4f}")