# Imports and Configuration


In [1]:

from pathlib import Path
import torch
import torchaudio
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict
from transformers import WhisperFeatureExtractor, WhisperModel, TrainingArguments, Trainer
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from datasets import load_from_disk, concatenate_datasets
import os





In [2]:
# Constants
MAX_TEXT_LEN = 128
BATCH_SIZE = 8
NUM_CLASSES = 6
LEARNING_RATE = 2e-5
EPOCHS = 10
SEED = 42

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [3]:
import os
import torch
import torchaudio
from datasets import Dataset
import pandas as pd

def load_and_preprocess_in_batches(df, batch_size=100, output_dir='processed_batches_whisper'):
    os.makedirs(output_dir, exist_ok=True)
    total_rows = len(df)
    for start_idx in range(0, total_rows, batch_size):
        end_idx = min(start_idx + batch_size, total_rows)
        batch_df = df.iloc[start_idx:end_idx]
        data = []
        for idx, row in batch_df.iterrows():
            raw_path = row["audio"]
            path = str(raw_path[0]) if isinstance(raw_path, list) else str(raw_path)
            label = int(row["label"])
            try:
                waveform, sample_rate = torchaudio.load(path)
                data.append({
                    "audio": {
                        "array": waveform.squeeze().numpy(),
                        "sampling_rate": sample_rate
                    },
                    "label": label
                })
            except Exception as e:
                print(f"Error processing file {path}: {e}")
        if data:
            batch_dataset = Dataset.from_list(data)
            batch_file = os.path.join(output_dir, f'batch_{start_idx}_{end_idx}.arrow')
            batch_dataset.save_to_disk(batch_file)
            print(f"Saved batch {start_idx}-{end_idx} to {batch_file}")


In [4]:
base_data_path = Path.cwd().joinpath('testdataset')
    
    # Load the dataset using MAMKit
from mamkit.data.datasets import MMUSEDFallacy, InputMode
    
mm_used_fallacy_loader = MMUSEDFallacy(
        task_name='afc',               
        input_mode=InputMode.AUDIO_ONLY,
        base_data_path=base_data_path
    )
    
    # Get the splits
# splits = list(mm_used_fallacy_loader.get_splits(method_key='mancini-et-al-2024'))
splits= mm_used_fallacy_loader.get_splits('mm-argfallacy-2025') 
    # For demonstration, we'll use one split (you can iterate through all)
split = splits[0]

print(split)
    
    # Get train, validation, and test datasets
train_data = split.train
val_data = split.val

print(len(train_data))
print(len(val_data))
print(train_data.inputs)
print(train_data.labels)
    

# Extract inputs and labels
inputs = train_data.inputs
labels = train_data.labels

def dataset_to_df(dataset):
    return pd.DataFrame({
        "audio": dataset.inputs,
        "label": dataset.labels,
    })

train_df = dataset_to_df(split.train)

print(train_df)

Mapping audio data...: 100%|██████████| 3388/3388 [00:01<00:00, 2517.30it/s]
Building AFC Context: 100%|██████████| 3388/3388 [00:00<00:00, 13826.16it/s]


SplitInfo(train=<mamkit.data.datasets.UnimodalDataset object at 0x0000023CC9311FC0>, val=<mamkit.data.datasets.UnimodalDataset object at 0x0000023CC9313B80>, test=<mamkit.data.datasets.UnimodalDataset object at 0x0000023CC9312E90>)
1228
0
[list([WindowsPath('D:/newargmining/testdataset/MMUSED-fallacy/audio_clips/10_1984/87.wav')])
 list([WindowsPath('D:/newargmining/testdataset/MMUSED-fallacy/audio_clips/10_1984/149.wav')])
 list([WindowsPath('D:/newargmining/testdataset/MMUSED-fallacy/audio_clips/10_1984/146.wav')])
 ...
 list([WindowsPath('D:/newargmining/testdataset/MMUSED-fallacy/audio_clips/46_2020/136.wav')])
 list([WindowsPath('D:/newargmining/testdataset/MMUSED-fallacy/audio_clips/46_2020/313.wav')])
 list([WindowsPath('D:/newargmining/testdataset/MMUSED-fallacy/audio_clips/46_2020/1090.wav')])]
[0 0 1 ... 4 4 5]
                                                  audio label
0     [D:\newargmining\testdataset\MMUSED-fallacy\au...     0
1     [D:\newargmining\testdataset\MMUSED-f

In [5]:
# Assuming 'df' is your DataFrame
load_and_preprocess_in_batches(train_df, batch_size=100)


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 0-100 to processed_batches_whisper\batch_0_100.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 100-200 to processed_batches_whisper\batch_100_200.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 200-300 to processed_batches_whisper\batch_200_300.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 300-400 to processed_batches_whisper\batch_300_400.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 400-500 to processed_batches_whisper\batch_400_500.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 500-600 to processed_batches_whisper\batch_500_600.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 600-700 to processed_batches_whisper\batch_600_700.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 700-800 to processed_batches_whisper\batch_700_800.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 800-900 to processed_batches_whisper\batch_800_900.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 900-1000 to processed_batches_whisper\batch_900_1000.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 1000-1100 to processed_batches_whisper\batch_1000_1100.arrow


Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saved batch 1100-1200 to processed_batches_whisper\batch_1100_1200.arrow


Saving the dataset (0/1 shards):   0%|          | 0/28 [00:00<?, ? examples/s]

Saved batch 1200-1228 to processed_batches_whisper\batch_1200_1228.arrow


In [6]:


def load_all_batches(output_dir='processed_batches_whisper'):
    batch_files = [os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith('.arrow')]
    datasets = [load_from_disk(batch_file) for batch_file in batch_files]
    full_dataset = concatenate_datasets(datasets)
    return full_dataset

# Load the full dataset
full_dataset = load_all_batches()


In [7]:
print(full_dataset)

Dataset({
    features: ['audio', 'label'],
    num_rows: 1228
})


In [8]:
# Step 1: Split into 80% train and 20% temp (which will be split into val and test)
train_val_test = full_dataset.train_test_split(test_size=0.2, seed=SEED)

# Step 2: Split the 20% temp into 50% validation and 50% test (i.e., 10% each of the original dataset)
val_test = train_val_test['test'].train_test_split(test_size=0.5, seed=SEED)

# Combine the splits into a DatasetDict
dataset_dict = DatasetDict({
    'train': train_val_test['train'],
    'validation': val_test['train'],
    'test': val_test['test']
})


In [3]:
from transformers import AutoFeatureExtractor

model_checkpoint = "openai/whisper-small"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

In [10]:
# from transformers import AutoFeatureExtractor

# model_checkpoint = "openai/whisper-small"
# feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

In [11]:
fallacy_mapping = {
    0: "Appeal to Emotion",
    1: "Appeal to Authority",
    2: "Ad Hominem",
    3: "False Cause",
    4: "Slippery Slope",
    5: "Slogans"
}

In [4]:
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer

num_labels = 6
model = AutoModelForAudioClassification.from_pretrained( model_checkpoint, num_labels=num_labels)

Some weights of WhisperForAudioClassification were not initialized from the model checkpoint at openai/whisper-small and are newly initialized: ['model.classifier.bias', 'model.classifier.weight', 'model.projector.bias', 'model.projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
print(model)

WhisperForAudioClassification(
  (encoder): WhisperEncoder(
    (conv1): Conv1d(80, 768, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(768, 768, kernel_size=(3,), stride=(2,), padding=(1,))
    (embed_positions): Embedding(1500, 768)
    (layers): ModuleList(
      (0-11): 12 x WhisperEncoderLayer(
        (self_attn): WhisperSdpaAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwis

In [13]:
def extract_features(batch):
    input_features = []
    for audio in batch["audio"]:
        # Ensure audio is a dictionary with 'array' and 'sampling_rate'
        if isinstance(audio, dict) and "array" in audio and "sampling_rate" in audio:
            inputs = feature_extractor(
                audio["array"],
                sampling_rate=audio["sampling_rate"],
                return_tensors="pt"
            )
            input_features.append(inputs.input_features[0])
        else:
            # Handle cases where audio is not in the expected format
            input_features.append(None)
    batch["input_features"] = input_features
    return batch


dataset_dict = dataset_dict.map(
    extract_features,
    batched=True,
    batch_size=4
)

# Set format for PyTorch
columns = ["input_features", "label"]
dataset_dict.set_format(type="torch", columns=columns)


Map:   0%|          | 0/982 [00:00<?, ? examples/s]

Map:   0%|          | 0/123 [00:00<?, ? examples/s]

Map:   0%|          | 0/123 [00:00<?, ? examples/s]

In [14]:
# batch_size = 8
# training_args = TrainingArguments(output_dir=f"{model_checkpoint}-finetuned-iemocap4",
#                                   evaluation_strategy="epoch",
#                                   learning_rate=3e-5,
#                                   per_device_train_batch_size=batch_size,
#                                   gradient_accumulation_steps=1, ### ??
#                                   per_device_eval_batch_size=batch_size,
#                                   num_train_epochs=5,
#                                   warmup_ratio=0.1, ###?
#                                   logging_steps=10,
#                                   load_best_model_at_end=False,
#                                   metric_for_best_model="accuracy", ###?
#                                   push_to_hub=True,
#                                   hub_private_repo=True,
#                                   report_to = 'wandb',
#                                   run_name = 'Whisper-fine-tuning'                                  )

In [15]:
# class WhisperClassifier(nn.Module):
#     def __init__(self, num_labels):
#         super(WhisperClassifier, self).__init__()
#         self.whisper = WhisperModel.from_pretrained("openai/whisper-small")
#         self.classifier = nn.Linear(self.whisper.config.hidden_size, num_labels)

#     def forward(self, input_features, labels=None):
#         # outputs = self.whisper(input_features=input_features)
#         logits = self.classifier(outputs.last_hidden_state[:, 0, :])  # Using the first token's representation
#         loss = None
#         if labels is not None:
#             loss = nn.CrossEntropyLoss()(logits, labels)
#         return {"loss": loss, "logits": logits}


In [16]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = np.argmax(pred.predictions, axis=1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="macro"),
        "precision": precision_score(labels, preds, average="weighted"),
        "recall": recall_score(labels, preds, average="weighted")
    }


In [17]:
training_args = TrainingArguments(
    output_dir="./whisper-fallacy",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    remove_unused_columns=False
)


# model = WhisperClassifier(num_labels=6).to(device)

print(dataset_dict['train'])
print(dataset_dict['validation'])
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_dict['train'],
    eval_dataset=dataset_dict['validation'],
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)


  trainer = Trainer(


Dataset({
    features: ['audio', 'label', 'input_features'],
    num_rows: 982
})
Dataset({
    features: ['audio', 'label', 'input_features'],
    num_rows: 123
})


In [18]:
import torch
torch.cuda.empty_cache()

In [19]:
trainer.train()


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,No log,1.230639,0.593496,0.12415,0.352237,0.593496
2,No log,1.216116,0.593496,0.12415,0.352237,0.593496
3,No log,1.201973,0.593496,0.168317,0.441463,0.593496
4,No log,1.45739,0.569106,0.267785,0.526568,0.569106
5,1.055300,1.706863,0.495935,0.230942,0.49507,0.495935
6,1.055300,1.884981,0.569106,0.294848,0.55426,0.569106
7,1.055300,2.024053,0.560976,0.263036,0.53909,0.560976
8,1.055300,2.143887,0.544715,0.249455,0.526524,0.544715
9,0.232200,2.30249,0.536585,0.248392,0.527671,0.536585
10,0.232200,2.31809,0.536585,0.245781,0.520376,0.536585


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3

TrainOutput(global_step=1230, training_loss=0.5325281011379831, metrics={'train_runtime': 4590.2097, 'train_samples_per_second': 2.139, 'train_steps_per_second': 0.268, 'total_flos': 1.2330850849344e+18, 'train_loss': 0.5325281011379831, 'epoch': 10.0})

In [None]:
trainer.evaluate(dataset_dict['test'])


{'eval_loss': 2.318284273147583,
 'eval_accuracy': 0.5365853658536586,
 'eval_f1': 0.3168465824094963,
 'eval_precision': 0.5032565871877206,
 'eval_recall': 0.5365853658536586,
 'eval_runtime': 23.5591,
 'eval_samples_per_second': 5.221,
 'eval_steps_per_second': 0.679,
 'epoch': 10.0}

In [21]:
# Save the model, tokenizer and configuration
model.save_pretrained("./final-fallacy-classifier-whispersmall")
feature_extractor.save_pretrained("./final-fallacy-classifier-whsipersmall")

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


['./final-fallacy-classifier-whsipersmall\\preprocessor_config.json']