# Bio Bert Model


In [1]:
# Cell 1: Imports and Setup
import os, random, json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import precision_score, recall_score, f1_score
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from tqdm import tqdm
import nltk
from nltk.corpus import stopwords
import re

In [2]:
# Set random seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)


In [None]:
# Device setup
device = (
    torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
print(f"Using device: {device}")

Using device: mps


## 1. Read in Data 

In [5]:
# Load data
df = pd.read_csv('../data/processed/shuffled_10_data.csv')
print(df.shape)
df.head()

(53130, 7)


Unnamed: 0,AC,PMID,Title,Abstract,Terms,Text_combined,batch_number
0,Q8NML3,17183211,"RamA, the transcriptional regulator of acetate...",The RamA protein represents a LuxR-type transc...,autoregulation,"RamA, the transcriptional regulator of acetate...",1
1,Q9SCZ4,17673660,The FERONIA receptor-like kinase mediates male...,"In flowering plants, signaling between the mal...",autophosphorylation,The FERONIA receptor-like kinase mediates male...,1
2,Q81WX1,12721629,The genome sequence of Bacillus anthracis Ames...,Bacillus anthracis is an endospore-forming bac...,,The genome sequence of Bacillus anthracis Ames...,1
3,P14410,8521865,Phosphorylation of the N-terminal intracellula...,This paper reports the phosphorylation of the ...,,Phosphorylation of the N-terminal intracellula...,1
4,P36898,14523231,Mutations in bone morphogenetic protein recept...,Brachydactyly (BD) type A2 is an autosomal dom...,autophosphorylation,Mutations in bone morphogenetic protein recept...,1


## 2. Clean Data

In [6]:
# Cell 3: Clean text
nltk.download('stopwords', quiet=True)
stop_words = set(stopwords.words('english'))

def clean_text(text):
    # Handle NaN values
    if pd.isna(text):
        return ""
    
    text = str(text).lower()
    # Keep hyphens as they may be important in biomedical terms
    text = re.sub(r'[^\w\s-]', '', text)
    # Keep numbers that might be part of important terms
    # text = re.sub(r'\d+', '', text)
    text = " ".join([word.strip() for word in text.split() if word not in stop_words])
    return text

df['Text_Cleaned'] = df['Text_combined'].apply(clean_text)

# Fill nan with 'non-autoregulatory'
df['Terms'] = df['Terms'].fillna('non-autoregulatory')

# Keep only selected columns
columns_to_keep = ['batch_number', 'Text_Cleaned', 'Terms']
df_cleaned = df[columns_to_keep]

In [7]:
print(df_cleaned.shape)
df_cleaned.head()

(53130, 3)


Unnamed: 0,batch_number,Text_Cleaned,Terms
0,1,rama transcriptional regulator acetate metabol...,autoregulation
1,1,feronia receptor-like kinase mediates male-fem...,autophosphorylation
2,1,genome sequence bacillus anthracis ames compar...,non-autoregulatory
3,1,phosphorylation n-terminal intracellular tail ...,non-autoregulatory
4,1,mutations bone morphogenetic protein receptor ...,autophosphorylation


## 3. Preprocess Data

In [8]:
# Cell 4: Prepare labels
# Convert terms to list
df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(
    lambda x: [term.strip() for term in x.split(',')]
)
df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))

mlb = MultiLabelBinarizer()
labels = mlb.fit_transform(df_cleaned['Terms_List'])
label_columns = mlb.classes_

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))


In [None]:

labels_df = pd.DataFrame(labels, columns=label_columns)
existing_columns = [col for col in label_columns if col in df_cleaned.columns]
df_cleaned = df_cleaned.drop(columns=existing_columns, errors='ignore')
df_cleaned = pd.concat([df_cleaned, labels_df], axis=1)

In [10]:
print(df_cleaned.shape)
df_cleaned.head()

(53130, 19)


Unnamed: 0,batch_number,Text_Cleaned,Terms,Terms_List,autoactivation,autocatalysis,autocatalytic,autofeedback,autoinducer,autoinduction,autoinhibition,autoinhibitory,autokinase,autolysis,autophosphorylation,autoregulation,autoregulatory,autoubiquitination,non-autoregulatory
0,1,rama transcriptional regulator acetate metabol...,autoregulation,[autoregulation],0,0,0,0,0,0,0,0,0,0,0,1,0,0,0
1,1,feronia receptor-like kinase mediates male-fem...,autophosphorylation,[autophosphorylation],0,0,0,0,0,0,0,0,0,0,1,0,0,0,0
2,1,genome sequence bacillus anthracis ames compar...,non-autoregulatory,[non-autoregulatory],0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
3,1,phosphorylation n-terminal intracellular tail ...,non-autoregulatory,[non-autoregulatory],0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
4,1,mutations bone morphogenetic protein receptor ...,autophosphorylation,[autophosphorylation],0,0,0,0,0,0,0,0,0,0,1,0,0,0,0


## 4. Train Model

In [11]:
# Step 6: Instantiate the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")


In [12]:
# Data Splitting
def split_single_batch_data(batch_number, test_size=0.2, random_state=42):
    """
    Split data from a single batch into train and test sets using stratified sampling.
    """
    batch_df = df_cleaned[df_cleaned['batch_number'] == batch_number].copy()
    
    if len(batch_df) == 0:
        print(f"Warning: No data found for batch {batch_number}")
        return None, None, None, None
        
    X = batch_df['Text_Cleaned']
    y = labels_df.loc[batch_df.index].values  # Ensure indexing alignment
    
    # Calculate label distribution
    non_auto_count = len(batch_df[batch_df['Terms'] == 'non-autoregulatory'])
    auto_count = len(batch_df[batch_df['Terms'] != 'non-autoregulatory'])
    
    try:
        # Split data
        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
        
        for train_idx, test_idx in msss.split(X, y):
            X_train = X.iloc[train_idx]
            X_test = X.iloc[test_idx]
            y_train = y[train_idx]
            y_test = y[test_idx]
        
        print(f"Batch {batch_number} | Train: {len(X_train)}, Test: {len(X_test)} | Non-auto: {non_auto_count}, Auto: {auto_count}")
        
        return X_train, X_test, y_train, y_test
    except Exception as e:
        print(f"Error in splitting batch {batch_number}: {e}")
        # Fallback to regular train_test_split if stratified splitting fails
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
        print(f"Fallback split for batch {batch_number} | Train: {len(X_train)}, Test: {len(X_test)}")
        return X_train, X_test, y_train, y_test

In [13]:
# Calculate class weights
def get_data_and_weights(batch_number):
    """
    Get data and calculate class weights for a specific batch.
    """
    X_train, X_test, y_train, y_test = split_single_batch_data(batch_number)
    
    if X_train is None:
        return None, None, None, None, None
    
    # Calculate class weights
    pos_weights = []
    for i in range(y_train.shape[1]):
        neg_count = len(y_train) - np.sum(y_train[:, i])
        pos_count = np.sum(y_train[:, i])
        # Handle the case where a class might have no positive examples
        if pos_count == 0:
            pos_weights.append(1.0)  # Default weight
        else:
            # Calculate weight and clip to prevent extreme values
            weight = neg_count / pos_count
            weight = min(max(weight, 0.1), 10.0)  # Clip between 0.1 and 10
            pos_weights.append(weight)
    
    pos_weights = torch.FloatTensor(pos_weights).to(device)
    
    return X_train, X_test, y_train, y_test, pos_weights

In [14]:
# Cell 6: Dataset and DataLoader classes
# Create dataset class
class BioBERTDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts.iloc[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.FloatTensor(label)
        }


In [15]:
# Create data loader
def create_dataset_and_loader(X_train, y_train, X_test, y_test, batch_size):
    train_dataset = BioBERTDataset(X_train, y_train, tokenizer)
    test_dataset = BioBERTDataset(X_test, y_test, tokenizer)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    return train_loader, test_loader

In [16]:
# Cell 7: Model Architecture
# Improved BioBERT Classifier with intermediate layers
class ImprovedBioBERTClassifier(nn.Module):
    def __init__(self, n_classes, dropout1=0.1, dropout2=0.2):
        super(ImprovedBioBERTClassifier, self).__init__()
        # Create a fresh copy of the model
        self.bert = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")
        
        # Add an intermediate layer
        self.dropout1 = nn.Dropout(dropout1)
        self.intermediate = nn.Linear(self.bert.config.hidden_size, 512)
        self.activation = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout2)
        self.classifier = nn.Linear(512, n_classes)
        
        # Initialize weights
        nn.init.xavier_normal_(self.intermediate.weight)
        nn.init.zeros_(self.intermediate.bias)
        nn.init.xavier_normal_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]  # CLS token
        pooled_output = self.dropout1(pooled_output)
        intermediate = self.intermediate(pooled_output)
        intermediate = self.activation(intermediate)
        intermediate = self.dropout2(intermediate)
        logits = self.classifier(intermediate)
        return logits


In [17]:
# Cell 8: Training and Evaluation Functions
# Training function with improvements
def train_epoch(model, data_loader, optimizer, criterion, scheduler=None):
    model.train()
    total_loss = 0
    total_batches = 0
    
    progress_bar = tqdm(data_loader, desc="Training")
    
    for batch in progress_bar:
        try:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            # Forward pass
            optimizer.zero_grad()
            
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            
            # Check for NaN loss
            if torch.isnan(loss):
                print("WARNING: NaN loss detected, skipping batch")
                continue
                
            # Backward pass
            loss.backward()
            
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            
            # Update progress bar
            progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})
            
            total_loss += loss.item()
            total_batches += 1
        except Exception as e:
            print(f"Error in batch processing: {e}")
            continue
    
    # Protect against division by zero
    if total_batches == 0:
        return float('inf')
    
    return total_loss / total_batches

# Optimize thresholds for each label
def optimize_thresholds(model, val_loader, n_labels):
    model.eval()
    all_outputs = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Optimizing thresholds"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']
            
            outputs = model(input_ids, attention_mask)
            probs = torch.sigmoid(outputs).cpu().numpy()
            
            all_outputs.append(probs)
            all_labels.append(labels.numpy())
    
    all_outputs = np.vstack(all_outputs)
    all_labels = np.vstack(all_labels)
    
    optimal_thresholds = []
    
    for i in range(n_labels):
        best_f1 = 0
        best_threshold = 0.5
        
        for threshold in np.arange(0.3, 0.7, 0.05):
            preds = (all_outputs[:, i] >= threshold).astype(int)
            f1 = f1_score(all_labels[:, i], preds, zero_division=0)
            
            if f1 > best_f1:
                best_f1 = f1
                best_threshold = threshold
        
        optimal_thresholds.append(best_threshold)
        
    return optimal_thresholds

# Evaluation function
def evaluate(model, data_loader, criterion, thresholds):
    """
    Evaluate the model with focused metrics output.
    """
    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            probabilities = torch.sigmoid(outputs).cpu().numpy()
            
            # Apply thresholds per label
            predictions = np.array([
                (probabilities[:, i] >= thresholds[i]).astype(int) for i in range(len(thresholds))
            ]).T

            all_predictions.extend(predictions)
            all_labels.extend(labels.cpu().numpy())
    
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)

    # Samples metrics
    samples_precision = precision_score(all_labels, all_predictions, average='samples', zero_division=0)
    samples_recall = recall_score(all_labels, all_predictions, average='samples', zero_division=0)
    samples_f1 = f1_score(all_labels, all_predictions, average='samples', zero_division=0)

    # F1 metrics
    micro_f1 = f1_score(all_labels, all_predictions, average='micro', zero_division=0)
    macro_f1 = f1_score(all_labels, all_predictions, average='macro', zero_division=0)
    weighted_f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)

    # Average loss
    avg_loss = total_loss / len(data_loader)

    metrics = {
        'loss': avg_loss,
        'micro_f1': micro_f1,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'samples_f1': samples_f1,
        'samples_precision': samples_precision,
        'samples_recall': samples_recall
    }
    
    # Output results
    print(f"Loss: {avg_loss:.4f} | Micro F1: {micro_f1:.4f} | Macro F1: {macro_f1:.4f} | Weighted F1: {weighted_f1:.4f} | Samples F1: {samples_f1:.4f}")
    print(f"Samples Precision: {samples_precision:.4f} | Samples Recall: {samples_recall:.4f}")
    
    # Return metrics for further analysis
    return metrics


In [None]:
# Cell 9: Main Training Function
def train_model(batch_number, n_epochs, learning_rate, batch_size, use_optimal_thresholds=True):
    """
    Training loop for a single batch with dynamic threshold settings.
    """
    # Get data
    data_result = get_data_and_weights(batch_number)
    
    if data_result[0] is None:
        print(f"Skipping batch {batch_number} due to data issues.")
        return
        
    X_train, X_test, y_train, y_test, pos_weights = data_result
    print(f"\nProcessing Batch {batch_number} ...")
    
    # Set initial thresholds
    initial_thresholds = [0.5] * y_train.shape[1]
    print(f"\nInitial Thresholds: {[f'{t:.2f}' for t in initial_thresholds]}")
    
    # Create data loaders
    train_loader, test_loader = create_dataset_and_loader(X_train, y_train, X_test, y_test, batch_size)
    
    # Initialize model
    model = ImprovedBioBERTClassifier(n_classes=y_train.shape[1]).to(device)
    
    # Set up loss and optimizer
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    
    # Set up scheduler with warmup
    total_steps = len(train_loader) * n_epochs
    warmup_steps = int(total_steps * 0.1)  # 10% warmup
    
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    
    # Training loop
    best_samples_f1 = 0.0
    
    # Dictionary to track all metrics
    all_metrics = {
        'train_loss': [],
        'val_loss': [],
        'micro_f1': [],
        'macro_f1': [],
        'weighted_f1': [],
        'samples_f1': [],
        'samples_precision': [],
        'samples_recall': []
    }
    
    for epoch in range(n_epochs):
        print(f"Epoch {epoch + 1}/{n_epochs} | Batch {batch_number}")
        
        # Train
        epoch_loss = train_epoch(model, train_loader, optimizer, criterion, scheduler)
        print(f"Training Loss: {epoch_loss:.4f}")
        all_metrics['train_loss'].append(epoch_loss)
        
        # Optimize thresholds if requested (only after first epoch)
        thresholds = initial_thresholds
        if use_optimal_thresholds and epoch >= 1:
            try:
                print("Optimizing thresholds...")
                thresholds = optimize_thresholds(model, test_loader, len(initial_thresholds))
                print(f"Optimized Thresholds: {[f'{t:.2f}' for t in thresholds]}")
            except Exception as e:
                print(f"Error optimizing thresholds: {e}")
                print("Using default thresholds")
                thresholds = initial_thresholds
        
        # Evaluate
        metrics = evaluate(model, test_loader, criterion, thresholds)
        
        # Store metrics
        for k, v in metrics.items():
            if k in all_metrics:
                all_metrics[k].append(v)
        
        current_samples_f1 = metrics['samples_f1']
        
        if current_samples_f1 > best_samples_f1:
            best_samples_f1 = current_samples_f1
            print(f"New best F1 score: {best_samples_f1:.4f}")
    
    return model, best_samples_f1


In [19]:
# Cell 10: Inference Class
class BioBERTInference:
    def __init__(self, model_path, thresholds_path, label_columns_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Load label columns
        with open(label_columns_path, 'r') as f:
            self.label_columns = json.load(f)
        
        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
        
        # Initialize model
        self.model = ImprovedBioBERTClassifier(n_classes=len(self.label_columns))
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        
        # Load thresholds
        with open(thresholds_path, 'r') as f:
            self.thresholds = json.load(f)
    
    def preprocess(self, text):
        return clean_text(text)
    
    def predict(self, text):
        processed_text = self.preprocess(text)
        
        encoding = self.tokenizer(
            processed_text,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
            probabilities = torch.sigmoid(outputs).cpu().numpy()[0]
        
        # Apply thresholds and get predicted labels
        predictions = {}
        for i, label in enumerate(self.label_columns):
            if probabilities[i] >= self.thresholds[i]:
                predictions[label] = float(probabilities[i])
        
        return predictions


In [20]:
# Cell 11: Run Training
# Change these parameters as needed
n_epochs = 7
learning_rate = 5e-6
batch_size = 8

# To train a specific batch
batch_num = 1  # Change this to the batch you want to train
model, f1_score = train_model(
    batch_number=batch_num,
    n_epochs=n_epochs,
    learning_rate=learning_rate,
    batch_size=batch_size,
    use_optimal_thresholds=True
)
print(f"Batch {batch_num} training complete. Best F1 score: {f1_score:.4f}")

Batch 1 | Train: 4248, Test: 1065 | Non-auto: 3542, Auto: 1771

Processing Batch 1 ...

Initial Thresholds: ['0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50']
Epoch 1/7 | Batch 1


Training: 100%|██████████| 531/531 [15:32<00:00,  1.76s/it, loss=0.2189]


Training Loss: 0.4827


Evaluating: 100%|██████████| 134/134 [01:26<00:00,  1.56it/s]


Loss: 0.2680 | Micro F1: 0.7490 | Macro F1: 0.2659 | Weighted F1: 0.7411 | Samples F1: 0.7641
Samples Precision: 0.7523 | Samples Recall: 0.7930
  New best model saved to model/batch_1/best_model.pt
  Thresholds saved to model/batch_1/best_thresholds.json
Epoch 2/7 | Batch 1


Training: 100%|██████████| 531/531 [16:47<00:00,  1.90s/it, loss=0.0878]


Training Loss: 0.1764
Optimizing thresholds...


Optimizing thresholds: 100%|██████████| 134/134 [01:22<00:00,  1.62it/s]


Optimized Thresholds: ['0.50', '0.50', '0.50', '0.30', '0.60', '0.50', '0.65', '0.65', '0.50', '0.65', '0.55', '0.40', '0.65', '0.55', '0.40']


Evaluating: 100%|██████████| 134/134 [01:23<00:00,  1.61it/s]


Loss: 0.1008 | Micro F1: 0.9031 | Macro F1: 0.5927 | Weighted F1: 0.9058 | Samples F1: 0.9139
Samples Precision: 0.9011 | Samples Recall: 0.9399
  New best model saved to model/batch_1/best_model.pt
  Thresholds saved to model/batch_1/best_thresholds.json
Epoch 3/7 | Batch 1


Training: 100%|██████████| 531/531 [16:11<00:00,  1.83s/it, loss=0.0207]


Training Loss: 0.1036
Optimizing thresholds...


Optimizing thresholds: 100%|██████████| 134/134 [01:34<00:00,  1.42it/s]


Optimized Thresholds: ['0.30', '0.50', '0.35', '0.30', '0.30', '0.30', '0.60', '0.40', '0.50', '0.60', '0.60', '0.55', '0.35', '0.45', '0.65']


Evaluating: 100%|██████████| 134/134 [01:20<00:00,  1.67it/s]


Loss: 0.0848 | Micro F1: 0.9315 | Macro F1: 0.7845 | Weighted F1: 0.9311 | Samples F1: 0.9322
Samples Precision: 0.9291 | Samples Recall: 0.9394
  New best model saved to model/batch_1/best_model.pt
  Thresholds saved to model/batch_1/best_thresholds.json
Epoch 4/7 | Batch 1


Training: 100%|██████████| 531/531 [15:59<00:00,  1.81s/it, loss=0.1242]


Training Loss: 0.0790
Optimizing thresholds...


Optimizing thresholds: 100%|██████████| 134/134 [01:20<00:00,  1.66it/s]


Optimized Thresholds: ['0.35', '0.50', '0.40', '0.30', '0.30', '0.30', '0.40', '0.35', '0.30', '0.60', '0.60', '0.55', '0.60', '0.60', '0.30']


Evaluating: 100%|██████████| 134/134 [01:23<00:00,  1.60it/s]


Loss: 0.0750 | Micro F1: 0.9367 | Macro F1: 0.8414 | Weighted F1: 0.9368 | Samples F1: 0.9394
Samples Precision: 0.9365 | Samples Recall: 0.9465
  New best model saved to model/batch_1/best_model.pt
  Thresholds saved to model/batch_1/best_thresholds.json
Epoch 5/7 | Batch 1


Training: 100%|██████████| 531/531 [16:58<00:00,  1.92s/it, loss=0.0606]


Training Loss: 0.0655
Optimizing thresholds...


Optimizing thresholds: 100%|██████████| 134/134 [01:22<00:00,  1.63it/s]


Optimized Thresholds: ['0.30', '0.50', '0.35', '0.30', '0.30', '0.30', '0.65', '0.35', '0.30', '0.50', '0.60', '0.45', '0.30', '0.50', '0.30']


Evaluating: 100%|██████████| 134/134 [01:21<00:00,  1.64it/s]


Loss: 0.0792 | Micro F1: 0.9437 | Macro F1: 0.8703 | Weighted F1: 0.9430 | Samples F1: 0.9471
Samples Precision: 0.9441 | Samples Recall: 0.9549
  New best model saved to model/batch_1/best_model.pt
  Thresholds saved to model/batch_1/best_thresholds.json
Epoch 6/7 | Batch 1


Training: 100%|██████████| 531/531 [17:09<00:00,  1.94s/it, loss=0.6476]


Training Loss: 0.0614
Optimizing thresholds...


Optimizing thresholds: 100%|██████████| 134/134 [01:22<00:00,  1.62it/s]


Optimized Thresholds: ['0.30', '0.50', '0.55', '0.30', '0.30', '0.30', '0.60', '0.40', '0.30', '0.40', '0.65', '0.30', '0.30', '0.30', '0.30']


Evaluating: 100%|██████████| 134/134 [01:23<00:00,  1.60it/s]


Loss: 0.0723 | Micro F1: 0.9404 | Macro F1: 0.8678 | Weighted F1: 0.9400 | Samples F1: 0.9430
Samples Precision: 0.9404 | Samples Recall: 0.9502
Epoch 7/7 | Batch 1


Training: 100%|██████████| 531/531 [16:30<00:00,  1.87s/it, loss=0.0091]


Training Loss: 0.0544
Optimizing thresholds...


Optimizing thresholds: 100%|██████████| 134/134 [01:22<00:00,  1.63it/s]


Optimized Thresholds: ['0.30', '0.50', '0.50', '0.30', '0.30', '0.30', '0.60', '0.30', '0.30', '0.45', '0.55', '0.50', '0.30', '0.35', '0.30']


Evaluating: 100%|██████████| 134/134 [01:21<00:00,  1.64it/s]

Loss: 0.0698 | Micro F1: 0.9398 | Macro F1: 0.8687 | Weighted F1: 0.9397 | Samples F1: 0.9415
Samples Precision: 0.9394 | Samples Recall: 0.9474
Metrics saved to results/metrics_batch_1.json
Batch 1 training complete. Best F1 score: 0.9471





In [21]:
# Cell 12: Train All Batches (Optional)

# Get unique batch numbers
unique_batches = df_cleaned['batch_number'].unique()

# Train models for each batch
for batch_num in unique_batches:
    try:
        print(f"\n=== Starting training for batch {batch_num} ===\n")
        model, f1_score = train_model(
            batch_number=batch_num, 
            n_epochs=n_epochs, 
            learning_rate=learning_rate, 
            batch_size=batch_size,
            use_optimal_thresholds=True
        )
        print(f"Batch {batch_num} training complete. Best F1 score: {f1_score:.4f}")
    except Exception as e:
        print(f"Error training batch {batch_num}: {e}")
        continue



=== Starting training for batch 1 ===

Batch 1 | Train: 4248, Test: 1065 | Non-auto: 3542, Auto: 1771

Processing Batch 1 ...

Initial Thresholds: ['0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50', '0.50']
Epoch 1/7 | Batch 1


Training:   3%|▎         | 14/531 [00:28<17:38,  2.05s/it, loss=0.7747]


KeyboardInterrupt: 