In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
import numpy as np
from tqdm import tqdm
import yaml , os
from pathlib import Path
from sklearn.metrics import classification_report

In [2]:
# Pick the compute device used by every later cell: the GPU when CUDA is
# available, the CPU otherwise. The working directory is echoed as well
# because the dataset paths used below are relative to it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
print(os.getcwd())

cuda
c:\UWO\Projects\Multimodal Emotion detection


In [3]:
# Parse the MELD metadata file and keep one mapping per split.
with open('./meld.yaml', 'r') as fp:
    meld_dict = yaml.safe_load(fp)

train_split, test_split, dev_split = (
    meld_dict[split_key] for split_key in ('train', 'test', 'dev')
)

In [4]:
# Walk the train and dev splits and report every utterance whose .wav file is
# missing on disk. Entries that do exist have their text and emotion fields
# read, which also confirms that both keys are present in the metadata.
for split_entries, split_audio_dir in (
    (train_split, "./dataset_extracted/output_train_extracted"),
    (dev_split, "./dataset_extracted/output_dev_extracted"),
):
    for aud_key in split_entries.keys():
        aud_path = os.path.normpath(os.path.join(split_audio_dir, f"{aud_key}.wav"))

        if os.path.exists(aud_path):
            aud_properties = split_entries[aud_key]
            text = aud_properties['Utterance']
            emotion_label = aud_properties['Emotion']
        else:
            print(f"File not found, skipping: {aud_path}")


In [5]:
# Function to extract mean pooled encoder features from audio
# import whisperx
# whisper_model = whisperx.load_model("base")
# whisper_model.eval()
# whisper_model = whisper_model.to(device)

# def extract_whisper_features(audio_path):
#     # audio = whisper.load_audio(audio_path)
#     from pathlib import Path
#     audio_path = Path(audio_path).as_posix()
#     audio = whisperx.load_audio(audio_path)

#     audio = whisperx.pad_or_trim(audio)
#     mel = whisperx.log_mel_spectrogram(audio).to(device)
#     with torch.no_grad():
#         features = whisper_model.encoder(mel.unsqueeze(0))  # shape: [1, frames, dim]
#     return features.mean(dim=1).squeeze().cpu().numpy()  # shape: [dim]

In [6]:
import os, io
import torch
from whisper.model import AudioEncoder
from whisper import _MODELS, _ALIGNMENT_HEADS, _download, available_models
from whisper import ModelDimensions
from typing import Optional, Union

def load_model(
    model,
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> torch.nn.Module:
    """Copy the encoder weights of a Whisper checkpoint into ``model``.

    Only checkpoint entries whose key contains ``"encoder"`` are loaded, and
    they are loaded non-strictly, so every other parameter of ``model`` keeps
    its current value.

    Args:
        model: Module that receives the weights. Its state-dict keys must
            match the checkpoint's encoder keys (e.g. ``encoder.conv1.weight``)
            for anything to be loaded.
        name: Either a key of ``whisper._MODELS`` (downloaded on demand) or a
            path to a checkpoint file on disk.
        device: ``map_location`` for ``torch.load``. Defaults to ``"cuda"``
            when available, else ``"cpu"``.
        download_root: Download/cache directory. Defaults to
            ``$XDG_CACHE_HOME/whisper`` or ``~/.cache/whisper``.
        in_memory: Read the checkpoint fully into memory before
            deserializing it.

    Returns:
        The same ``model`` object, updated in place.

    Raises:
        RuntimeError: If ``name`` is neither a known model name nor a file.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
    elif os.path.isfile(name):
        if in_memory:
            # Context manager so the handle is closed right after reading.
            with open(name, "rb") as raw_fp:
                checkpoint_file = raw_fp.read()
        else:
            checkpoint_file = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    # NOTE(review): torch.load unpickles the file. Only point ``name`` at
    # checkpoints from a trusted source, or consider ``weights_only=True``.
    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    print(
        dims.n_mels,
        dims.n_audio_ctx,
        dims.n_audio_state,
        dims.n_audio_head,
        dims.n_audio_layer,
    )

    checkpoint_state = checkpoint["model_state_dict"]
    encoder_keys = [key for key in checkpoint_state.keys() if "encoder" in key]

    # List the checkpoint tensors that have a same-named slot in the model.
    model_keys = model.state_dict().keys()
    print("\n".join([key for key in encoder_keys if key in model_keys]))

    # strict=False: the model may own parameters outside the encoder, which
    # are expected to be absent from the filtered dictionary.
    _missing_keys, unexpected_keys = model.load_state_dict(
        {key: checkpoint_state[key] for key in encoder_keys}, strict=False
    )
    if unexpected_keys:
        # Without this message a model lacking matching parameters would
        # silently keep its current encoder weights.
        print(
            f"[WARNING] {len(unexpected_keys)} encoder tensors of '{name}' "
            "have no matching parameter in the model and were not loaded"
        )

    return model

In [7]:
from torch.utils.data import Dataset
import os
from os.path import isfile, join
from tqdm.notebook import tqdm
from typing import Literal

import librosa
import whisper
import torch
import yaml
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer

# whisper_model = whisper.load_model("base")
# whisper_model.eval()
# whisper_model.to(torch.device("cuda"))

class WhisperMELDDataset(Dataset):
    """MELD utterances as (audio features, token ids, attention mask, label).

    Metadata is read from a YAML file with ``train``/``test``/``dev``
    mappings of utterance id -> entry; each entry provides ``'Utterance'``
    (text) and ``'Emotion'`` (label). The audio of an utterance is expected at
    ``<audio_dir>/<utterance id>.wav``.

    The audio item depends on ``mode``:

    * ``"full"``: the log-mel spectrogram of the padded/trimmed waveform
      (batches of shape ``[8, 80, 3000]`` in the run recorded in this notebook).
    * ``"temporal"``: the output of ``whisper_model.encoder`` for that
      spectrogram, including its leading batch dimension of 1.
    * ``"default"``: mean and standard deviation of the encoder output over
      dimension 1, concatenated into a single vector.

    ``"default"`` and ``"temporal"`` require ``whisper_model``. Tensors are
    placed on the notebook-level ``device``.
    """

    # Audio directory for every supported split name.
    _AUDIO_DIRS = {
        'train': "./dataset_extracted/output_train_extracted",
        'test': "./dataset_extracted/output_test_extracted",
        'dev': "./dataset_extracted/output_dev_extracted",
    }
    _MODES = ("default", "temporal", "full")

    def __init__(self, dataset_path='./meld.yaml', split_name='train', sr=16000, label_encoder=None, 
                 mode: Literal["default", "temporal", "full"]="default", whisper_model=None, max_len=128):
        """Load one split, its tokenizer and the integer-encoded labels.

        Args:
            dataset_path: Path of the MELD YAML metadata file.
            split_name: ``'train'``, ``'test'`` or ``'dev'``.
            sr: Sampling rate passed to ``librosa.load``.
            label_encoder: Fitted ``LabelEncoder``; when ``None`` a new one is
                fitted on the emotions of this split.
            mode: Kind of audio feature returned (see the class docstring).
            whisper_model: Object exposing ``.encoder``; used by the
                ``"default"`` and ``"temporal"`` modes.
            max_len: Token length every utterance is padded/truncated to.

        Raises:
            ValueError: If ``split_name`` or ``mode`` is not supported.
        """
        super().__init__()

        if split_name not in self._AUDIO_DIRS:
            raise ValueError("split_name must be one of: train, test, dev")
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of: {', '.join(self._MODES)}")

        with open(dataset_path, 'r') as fp:
            meld_dict = yaml.safe_load(fp)

        self.split_name = split_name
        self.sr = sr
        self.mode = mode
        self.max_len = max_len

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.split = meld_dict[split_name]
        self.audio_dir = self._AUDIO_DIRS[split_name]

        self.whisper_model = whisper_model
        self.keys = list(self.split.keys())

        # Build label encoder if not provided
        if label_encoder is None:
            self.label_encoder = LabelEncoder()
            self.label_encoder.fit([entry['Emotion'] for entry in self.split.values()])
        else:
            self.label_encoder = label_encoder

        # Encode every label with a single transform call instead of one call
        # per utterance; self.keys and self.split.values() share their order.
        label_ids = self.label_encoder.transform(
            [entry['Emotion'] for entry in self.split.values()]
        )
        self.encoded_labels = dict(zip(self.keys, label_ids))

    def __len__(self):
        """Return the number of utterances in the split."""
        return len(self.keys)

    def extract_whisper_features(self, audio_path):
        """Return the audio representation selected by ``self.mode``.

        If the file cannot be read, a warning is printed and a silent waveform
        is used instead, so the result has the same shape as for a readable
        file in every mode.

        Args:
            audio_path: Path of the ``.wav`` file.

        Returns:
            torch.Tensor: Log-mel spectrogram (``"full"``), encoder output
            (``"temporal"``) or concatenated mean/std vector (``"default"``).

        Raises:
            ValueError: If an encoder-based mode is used without
                ``whisper_model``, or if ``self.mode`` is not supported.
        """
        try:
            audio, _ = librosa.load(audio_path, sr=self.sr)
        except Exception as e:
            print(f"[WARNING] Failed to load audio {audio_path}: {e} | Exists:{isfile(audio_path)}")
            # One zero sample; pad_or_trim below extends it to the fixed
            # length, i.e. the utterance is treated as silence.
            audio = torch.zeros(1)

        audio = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(audio).to(device)

        if self.mode == "full":
            return mel

        if self.mode not in ("default", "temporal"):
            raise ValueError(f"Unsupported mode: {self.mode!r}")
        if self.whisper_model is None:
            raise ValueError(f"mode={self.mode!r} requires a whisper_model")

        with torch.no_grad():
            # [1, T, D]; D is the encoder width (512 for the base model)
            features = self.whisper_model.encoder(mel.unsqueeze(0))

        if self.mode == "temporal":
            return features.detach()

        # "default": pool over dimension 1 and concatenate mean and std
        mn = features.mean(dim=1).squeeze()
        std = features.std(dim=1).squeeze()
        return torch.cat([mn, std], dim=0)

    def __getitem__(self, idx):
        """Return ``(features, input_ids, attention_mask, label)`` for ``idx``.

        ``input_ids`` and ``attention_mask`` have shape ``(max_len,)``; the
        label is a 0-dim ``torch.long`` tensor.
        """
        aud_key = self.keys[idx]
        audio_path = os.path.join(self.audio_dir, f"{aud_key}.wav")
        features = self.extract_whisper_features(audio_path)

        # Get text for corresponding utterance
        utterance = self.split[aud_key]['Utterance']

        # Tokenize text (BERT input)
        encoding = self.tokenizer(
            utterance,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        # return_tensors='pt' adds a leading batch dimension of 1; drop it
        input_ids = encoding['input_ids'].squeeze(0)  # (max_len,)
        attention_mask = encoding['attention_mask'].squeeze(0)  # (max_len,)
        label = self.encoded_labels[aud_key]
        return features, input_ids, attention_mask, torch.tensor(label, dtype=torch.long)

In [None]:
# Step 1: Prepare label encoder manually using meld.yaml (for reproducibility)
import yaml
from sklearn.preprocessing import LabelEncoder

# Relative to the working directory, like the audio directories used by
# WhisperMELDDataset, so the notebook is not tied to one machine's layout.
MELD_YAML_PATH = './meld.yaml'

# Audio representation produced by the datasets. "full" yields log-mel
# spectrograms; "default" and "temporal" additionally need a
# whisper_model=... argument because they run the Whisper encoder.
DATASET_MODE = "full"

with open(MELD_YAML_PATH, 'r') as fp:
    meld_dict = yaml.safe_load(fp)

# Use the train split to fit the encoder
all_labels = [entry['Emotion'] for entry in meld_dict['train'].values()]
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)

# Step 2: Instantiate datasets (all splits share the train-fitted encoder)
train_dataset, dev_dataset, test_dataset = (
    WhisperMELDDataset(
        dataset_path=MELD_YAML_PATH,
        split_name=split_name,
        label_encoder=label_encoder,
        mode=DATASET_MODE,
    )
    for split_name in ('train', 'dev', 'test')
)

# Step 3: Dataloaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
dev_loader   = DataLoader(dev_dataset, batch_size=8)
test_loader  = DataLoader(test_dataset, batch_size=8)

In [10]:
# Fetch a single batch only to confirm the tensor shapes.
features, input_ids, attention_mask, label = next(iter(train_loader))
print(f"Found batch with features shape: {features.shape}, labels shape: {label.shape}")

# The dataset already holds the integer label of every utterance, so read
# them directly instead of running the whole audio pipeline over the
# training set just to collect labels. Their order is irrelevant for the
# class-weight computation that consumes this tensor.
label_arr = torch.tensor(
    [int(train_dataset.encoded_labels[key]) for key in train_dataset.keys],
    dtype=torch.long,
).to(device)
print(label_arr.shape)

  0%|          | 0/1249 [00:00<?, ?it/s]

Found batch with features shape: torch.Size([8, 80, 3000]), labels shape: torch.Size([8])
torch.Size([9988])


In [11]:
# Pull one batch from the training loader and show the shape of every tensor
# in it; the loop is left immediately after the first batch.
for features, input_ids, attention_mask, label in tqdm(train_loader):
    print(f"Found batch with features shape: {features.shape}, labels shape: {label.shape}, input_shape: {input_ids.shape} | {attention_mask.shape}")
    break

  0%|          | 0/1249 [00:00<?, ?it/s]

Found batch with features shape: torch.Size([8, 80, 3000]), labels shape: torch.Size([8]), input_shape: torch.Size([8, 128]) | torch.Size([8, 128])


In [12]:
def print_sample_paths(title, prefix, dataset, limit=5):
    """Print the first audio paths of ``dataset`` and whether each one exists.

    Args:
        title: Split description used in the heading line.
        prefix: Label printed in front of every path line.
        dataset: Object exposing ``keys`` (utterance ids) and ``audio_dir``.
        limit: Maximum number of paths to print.
    """
    print(f"\nSample paths from {title} set:")
    for i, key in enumerate(dataset.keys[:limit]):
        audio_path = os.path.join(dataset.audio_dir, f"{key}.wav")
        print(f"{prefix} path {i+1}: {audio_path}")
        print(f"File exists: {os.path.exists(audio_path)}")


# One call per split replaces three copies of the same loop.
print_sample_paths("training", "Train", train_dataset)
print_sample_paths("validation", "Dev", dev_dataset)
print_sample_paths("test", "Test", test_dataset)


Sample paths from training set:
Train path 1: ./dataset_extracted/output_train_extracted\dia0_utt0.wav
File exists: True
Train path 2: ./dataset_extracted/output_train_extracted\dia0_utt1.wav
File exists: True
Train path 3: ./dataset_extracted/output_train_extracted\dia0_utt10.wav
File exists: True
Train path 4: ./dataset_extracted/output_train_extracted\dia0_utt11.wav
File exists: True
Train path 5: ./dataset_extracted/output_train_extracted\dia0_utt12.wav
File exists: True

Sample paths from validation set:
Dev path 1: ./dataset_extracted/output_dev_extracted\dia0_utt0.wav
File exists: True
Dev path 2: ./dataset_extracted/output_dev_extracted\dia0_utt1.wav
File exists: True
Dev path 3: ./dataset_extracted/output_dev_extracted\dia100_utt0.wav
File exists: True
Dev path 4: ./dataset_extracted/output_dev_extracted\dia101_utt0.wav
File exists: True
Dev path 5: ./dataset_extracted/output_dev_extracted\dia102_utt0.wav
File exists: True

Sample paths from test set:
Test path 1: ./dataset_e

In [13]:
from sklearn.utils.class_weight import compute_class_weight

# "balanced" class weights for the 7 emotion classes, computed from the
# training labels and stored as float32.
train_label_ids = label_arr.cpu().numpy()
label_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.arange(7),
    y=train_label_ids,
).astype(np.float32)
print(label_weights)

[1.286616  5.2651553 5.324094  0.8186214 0.3030064 2.0891027 1.1841139]


In [14]:
# Audio + text emotion classifier: model definition, training setup,
# evaluation and training loop.
from transformers import get_linear_schedule_with_warmup
from torcheval.metrics.functional import multiclass_f1_score
from transformers import BertModel
# Number of training epochs; also used below to size the warmup/decay
# learning-rate schedule.
total_epochs = 10

class WhisperClassifier(nn.Module):
    """Emotion classifier that fuses an audio branch with a BERT text branch.

    Audio branch: optional Whisper ``AudioEncoder`` -> Conv1d (1500 -> 128
    channels) -> BatchNorm -> ELU -> MaxPool(4) -> flatten -> ``compressor``.
    Text branch: ``bert-base-uncased``, taking the hidden state at sequence
    position 0. Both vectors are concatenated and passed through three fully
    connected stages followed by a linear classifier.
    """

    def __init__(self, input_dim=768, hidden_dim=256, num_classes=7, train_whisper = False):
        """Build both branches and the fusion head.

        Args:
            input_dim: Size of the last dimension of the tensor entering
                ``conv1``; it determines ``flatten_dim``. With
                ``train_whisper=True`` the encoder below is built with a state
                size of 512, so presumably 512 must be passed here (as the
                notebook does) - confirm before using another value.
            hidden_dim: Base width of the fully connected layers.
            num_classes: Number of output logits.
            train_whisper: When True, an ``AudioEncoder`` is created as
                ``self.encoder`` and applied to the audio input in ``forward``.
        """
        super(WhisperClassifier, self).__init__()

        self.train_whisper = train_whisper
        if train_whisper:
            # self.audio_encoder = load_model("base")
            # Arguments in the order load_model prints them: n_mels,
            # n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer. The
            # values equal the "80 1500 512 8 6" printed for the "base" checkpoint.
            self.encoder = AudioEncoder(80, 1500, 512, 8, 6)
        # Dimension 1 of the audio tensor (size 1500) is used as the channel
        # axis; kernel_size=5 with padding=2 keeps the last dimension's length.
        self.conv1 = nn.Conv1d(in_channels=1500, out_channels=128, kernel_size=5, padding=2)
        self.bn_conv1 = nn.BatchNorm1d(128)
        # Named relu_* but the activation is ELU (same for relu1..relu3 below).
        self.relu_conv1 = nn.ELU()
        self.pool1 = nn.MaxPool1d(kernel_size=4)

        # Output after conv1 + pool1 is [B, 128, input_dim // 4], flatten before FC
        self.flatten_dim = (input_dim // 4) * 128
        
        self.compressor = nn.Sequential(
            nn.Linear(self.flatten_dim, hidden_dim*4),
            nn.BatchNorm1d(hidden_dim*4),
            nn.ELU(),
            nn.Dropout(0.2),
        )

        # Text branch (BERT)
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # self.text_fc = nn.Sequential(
        #     nn.Linear(768, hidden_dim),
        #     nn.BatchNorm1d(hidden_dim),
        #     nn.ELU(),
        #     nn.Dropout(0.2)
        # )
        
        # Fusion head input: compressed audio (hidden_dim*4) + text vector (768).
        self.fc1 = nn.Linear((hidden_dim*4) + 768, hidden_dim*4)
        self.bn1 = nn.BatchNorm1d(hidden_dim*4)
        self.relu1 = nn.ELU()
        self.dropout1 = nn.Dropout(0.2)
        
        self.fc2 = nn.Linear(hidden_dim*4, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.relu2 = nn.ELU()
        self.dropout2 = nn.Dropout(0.2)
        
        self.fc3 = nn.Linear(hidden_dim, hidden_dim//4)
        # self.bn3 = nn.BatchNorm1d(hidden_dim//4)
        self.relu3 = nn.ELU()
        self.dropout3 = nn.Dropout(0.2)

        self.classifier = nn.Linear(hidden_dim//4, num_classes)

    def forward(self, x1, x2, x_mask):
        """Return class logits of shape ``[B, num_classes]``.

        Args:
            x1: Audio input. With ``train_whisper=True`` it is fed to
                ``self.encoder`` first (the notebook passes log-mel batches of
                shape ``[B, 80, 3000]``); otherwise it goes straight to
                ``conv1`` and must already have 1500 entries in dimension 1.
            x2: BERT ``input_ids``.
            x_mask: BERT ``attention_mask``.
        """
        # print(x.shape)
        if self.train_whisper:
            x1 = self.encoder(x1)
        x1 = self.pool1(self.relu_conv1(self.bn_conv1(self.conv1(x1))))  # → [B, 128, input_dim // 4]
        x1 = x1.view(x1.size(0), -1)  # flatten to [B, flatten_dim]
        x1 = self.compressor(x1)
        
        # Text branch (BERT)
        bert_outputs = self.bert(input_ids=x2, attention_mask=x_mask)
        # Hidden state at sequence position 0 is used as the utterance vector.
        x2 = bert_outputs.last_hidden_state[:, 0, :]
        # x2 = self.text_fc(x2)  
        
        # Concatenate audio + text
        x = torch.cat([x1, x2], dim=1)  
        
        x = self.dropout1(self.relu1(self.bn1(self.fc1(x))))
        x = self.dropout2(self.relu2(self.bn2(self.fc2(x))))
        x = self.dropout3(self.relu3(self.fc3(x)))
        return self.classifier(x)

# Build the classifier on the target device, then the class-weighted loss
# and the optimizer over all model parameters.
model = WhisperClassifier(input_dim=512, hidden_dim=256, num_classes=7, train_whisper=True)
model = model.to(device)

class_weight_tensor = torch.from_numpy(label_weights).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weight_tensor)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-5,
    weight_decay=1e-4,
)

# Copy the pretrained Whisper "base" encoder weights into the model. The
# values are loaded into the existing parameters, so the optimizer created
# above keeps referring to them.
model = load_model(model, "base")
print(model)

# Linear warmup over the first 10% of all optimizer steps, then linear decay.
steps_per_epoch = len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=int(0.1 * total_epochs * steps_per_epoch),
    num_training_steps=int(total_epochs * steps_per_epoch),
)

# Evaluation function
def evaluate_model(model, loader):
    """Evaluate ``model`` on ``loader``, print a report, return weighted F1.

    Uses the notebook-level ``device`` and ``label_encoder``; the number of
    classes and their names are taken from ``label_encoder.classes_``.

    Args:
        model: Callable as ``model(features, input_ids, attention_mask)``
            returning logits with one column per class.
        loader: Iterable of ``(features, input_ids, attention_mask, labels)``
            batches.

    Returns:
        torch.Tensor: 0-dim tensor holding the support-weighted F1 score.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for X, input_ids, attention_mask, y in tqdm(loader):
            X = X.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            logits = model(X, input_ids, attention_mask)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().tolist())
            # Labels are only collected, so they need no device round trip.
            all_labels.extend(y.tolist())

    class_names = list(label_encoder.classes_)
    num_classes = len(class_names)

    # Passing labels= keeps the report aligned with target_names even when a
    # class never occurs in this split's labels or predictions; without it
    # classification_report raises on the length mismatch.
    print(classification_report(all_labels, all_preds,
                                labels=list(range(num_classes)),
                                target_names=class_names))
    wf1 = multiclass_f1_score(torch.tensor(all_preds, dtype=torch.int64),
                              torch.tensor(all_labels, dtype=torch.int64),
                              num_classes=num_classes, average="weighted")
    return wf1

# Training loop
def train_model(model, train_loader, dev_loader, epochs=50):
    """Train ``model`` and save a checkpoint after every epoch.

    Uses the notebook-level ``optimizer``, ``criterion``, ``scheduler`` and
    ``device``. The scheduler is stepped once per batch. After each epoch the
    model is evaluated on ``dev_loader`` and its state dict is written to
    ``./checkpoints/epoch_<epoch>-acc_<score>.pth``, where ``<score>`` is the
    string form of the weighted-F1 tensor with ``.`` replaced by ``_``
    (e.g. ``tensor(0_6013)``).

    Args:
        model: Module called as ``model(features, input_ids, attention_mask)``.
        train_loader: Iterable of training batches
            ``(features, input_ids, attention_mask, labels)``.
        dev_loader: Loader passed to ``evaluate_model`` after each epoch.
        epochs: Number of passes over ``train_loader``.
    """
    # Create the output folder up front: otherwise the first torch.save
    # fails only after a complete training epoch has already been spent.
    os.makedirs("./checkpoints", exist_ok=True)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader):
            X, IIDS, AM, y = batch
            X, IIDS, AM, y = X.to(device), IIDS.to(device), AM.to(device), y.to(device)
            y = y.type(torch.long)
            optimizer.zero_grad()
            output = model(X, IIDS, AM)

            loss = criterion(output, y)
            loss.backward()
            # Cap the global gradient norm at 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError for an empty loader.
        mean_loss = total_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1}/{epochs}, Training Loss: {mean_loss:.4f} | {scheduler.get_last_lr()}")
        # Despite the name, `acc` is the weighted F1 returned by evaluate_model.
        acc = evaluate_model(model, dev_loader)
        acc = str(acc).replace('.', '_')
        torch.save(model.state_dict(), f"./checkpoints/epoch_{epoch}-acc_{acc}.pth")

# Run the full training; epochs matches total_epochs, the value the
# learning-rate schedule above was sized with.
train_model(model, train_loader, dev_loader, epochs=total_epochs)

# Final evaluation on test set
print("----- Final Test Performance -----")
# Last expression of the cell, so the returned weighted-F1 tensor is displayed.
evaluate_model(model, test_loader) 


  checkpoint = torch.load(fp, map_location=device)


80 1500 512 8 6
encoder.positional_embedding
encoder.conv1.weight
encoder.conv1.bias
encoder.conv2.weight
encoder.conv2.bias
encoder.blocks.0.mlp_ln.weight
encoder.blocks.0.mlp_ln.bias
encoder.blocks.0.mlp.0.weight
encoder.blocks.0.mlp.0.bias
encoder.blocks.0.mlp.2.weight
encoder.blocks.0.mlp.2.bias
encoder.blocks.0.attn_ln.weight
encoder.blocks.0.attn_ln.bias
encoder.blocks.0.attn.query.weight
encoder.blocks.0.attn.query.bias
encoder.blocks.0.attn.key.weight
encoder.blocks.0.attn.value.weight
encoder.blocks.0.attn.value.bias
encoder.blocks.0.attn.out.weight
encoder.blocks.0.attn.out.bias
encoder.blocks.1.mlp_ln.weight
encoder.blocks.1.mlp_ln.bias
encoder.blocks.1.mlp.0.weight
encoder.blocks.1.mlp.0.bias
encoder.blocks.1.mlp.2.weight
encoder.blocks.1.mlp.2.bias
encoder.blocks.1.attn_ln.weight
encoder.blocks.1.attn_ln.bias
encoder.blocks.1.attn.query.weight
encoder.blocks.1.attn.query.bias
encoder.blocks.1.attn.key.weight
encoder.blocks.1.attn.value.weight
encoder.blocks.1.attn.value.bi

  0%|          | 0/1249 [00:00<?, ?it/s]

Epoch 1/10, Training Loss: 1.8508 | [1e-05]


  0%|          | 0/139 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.45      0.39      0.41       153
     disgust       0.05      0.09      0.06        22
        fear       0.03      0.03      0.03        40
         joy       0.52      0.48      0.50       163
     neutral       0.81      0.46      0.59       469
     sadness       0.36      0.42      0.39       111
    surprise       0.36      0.83      0.51       150

    accuracy                           0.48      1108
   macro avg       0.37      0.39      0.35      1108
weighted avg       0.57      0.48      0.49      1108



  0%|          | 0/1249 [00:00<?, ?it/s]

Epoch 2/10, Training Loss: 1.5457 | [8.888888888888888e-06]


  0%|          | 0/139 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.48      0.37      0.41       153
     disgust       0.26      0.23      0.24        22
        fear       0.26      0.12      0.17        40
         joy       0.51      0.62      0.56       163
     neutral       0.77      0.73      0.75       469
     sadness       0.40      0.45      0.42       111
    surprise       0.54      0.67      0.60       150

    accuracy                           0.59      1108
   macro avg       0.46      0.46      0.45      1108
weighted avg       0.60      0.59      0.59      1108



  0%|          | 0/1249 [00:00<?, ?it/s]

Epoch 3/10, Training Loss: 1.3808 | [7.77777777777778e-06]


  0%|          | 0/139 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.50      0.33      0.40       153
     disgust       0.15      0.23      0.18        22
        fear       0.14      0.15      0.14        40
         joy       0.56      0.58      0.57       163
     neutral       0.78      0.67      0.72       469
     sadness       0.36      0.50      0.42       111
    surprise       0.52      0.72      0.60       150

    accuracy                           0.57      1108
   macro avg       0.43      0.45      0.43      1108
weighted avg       0.60      0.57      0.58      1108



  0%|          | 0/1249 [00:00<?, ?it/s]

Epoch 4/10, Training Loss: 1.1874 | [6.666666666666667e-06]


  0%|          | 0/139 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.54      0.40      0.46       153
     disgust       0.29      0.36      0.32        22
        fear       0.12      0.28      0.17        40
         joy       0.52      0.60      0.55       163
     neutral       0.77      0.68      0.72       469
     sadness       0.45      0.33      0.38       111
    surprise       0.53      0.69      0.60       150

    accuracy                           0.57      1108
   macro avg       0.46      0.48      0.46      1108
weighted avg       0.60      0.57      0.58      1108



  0%|          | 0/1249 [00:00<?, ?it/s]

Epoch 5/10, Training Loss: 1.0031 | [5.555555555555557e-06]


  0%|          | 0/139 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.51      0.42      0.46       153
     disgust       0.19      0.41      0.26        22
        fear       0.26      0.28      0.27        40
         joy       0.47      0.61      0.53       163
     neutral       0.78      0.69      0.73       469
     sadness       0.48      0.36      0.41       111
    surprise       0.54      0.65      0.59       150

    accuracy                           0.58      1108
   macro avg       0.46      0.49      0.47      1108
weighted avg       0.61      0.58      0.59      1108



  0%|          | 0/1249 [00:00<?, ?it/s]

Epoch 6/10, Training Loss: 0.7894 | [4.444444444444444e-06]


  0%|          | 0/139 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.53      0.47      0.50       153
     disgust       0.27      0.27      0.27        22
        fear       0.24      0.25      0.24        40
         joy       0.49      0.63      0.55       163
     neutral       0.79      0.68      0.73       469
     sadness       0.48      0.39      0.43       111
    surprise       0.50      0.69      0.58       150

    accuracy                           0.59      1108
   macro avg       0.47      0.48      0.47      1108
weighted avg       0.61      0.59      0.59      1108



  0%|          | 0/1249 [00:00<?, ?it/s]

Epoch 7/10, Training Loss: 0.6329 | [3.3333333333333333e-06]


  0%|          | 0/139 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.46      0.54      0.49       153
     disgust       0.26      0.41      0.32        22
        fear       0.19      0.20      0.20        40
         joy       0.64      0.48      0.55       163
     neutral       0.76      0.74      0.75       469
     sadness       0.49      0.43      0.46       111
    surprise       0.53      0.61      0.56       150

    accuracy                           0.60      1108
   macro avg       0.47      0.49      0.47      1108
weighted avg       0.61      0.60      0.60      1108



  0%|          | 0/1249 [00:00<?, ?it/s]

Epoch 8/10, Training Loss: 0.5157 | [2.222222222222222e-06]


  0%|          | 0/139 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.53      0.39      0.45       153
     disgust       0.36      0.41      0.38        22
        fear       0.20      0.20      0.20        40
         joy       0.46      0.60      0.52       163
     neutral       0.74      0.75      0.75       469
     sadness       0.49      0.34      0.40       111
    surprise       0.58      0.61      0.59       150

    accuracy                           0.59      1108
   macro avg       0.48      0.47      0.47      1108
weighted avg       0.59      0.59      0.59      1108



  0%|          | 0/1249 [00:00<?, ?it/s]

Epoch 9/10, Training Loss: 0.4416 | [1.111111111111111e-06]


  0%|          | 0/139 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.51      0.44      0.47       153
     disgust       0.30      0.36      0.33        22
        fear       0.19      0.15      0.17        40
         joy       0.50      0.58      0.53       163
     neutral       0.76      0.74      0.75       469
     sadness       0.51      0.35      0.41       111
    surprise       0.51      0.67      0.58       150

    accuracy                           0.60      1108
   macro avg       0.47      0.47      0.46      1108
weighted avg       0.60      0.60      0.60      1108



  0%|          | 0/1249 [00:00<?, ?it/s]

Epoch 10/10, Training Loss: 0.3762 | [0.0]


  0%|          | 0/139 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.50      0.44      0.47       153
     disgust       0.28      0.36      0.31        22
        fear       0.20      0.15      0.17        40
         joy       0.54      0.56      0.55       163
     neutral       0.73      0.77      0.75       469
     sadness       0.49      0.33      0.40       111
    surprise       0.56      0.65      0.60       150

    accuracy                           0.61      1108
   macro avg       0.47      0.47      0.47      1108
weighted avg       0.60      0.61      0.60      1108

----- Final Test Performance -----


  0%|          | 0/327 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.48      0.43      0.46       345
     disgust       0.23      0.22      0.23        68
        fear       0.20      0.26      0.23        50
         joy       0.58      0.55      0.56       402
     neutral       0.78      0.78      0.78      1255
     sadness       0.37      0.37      0.37       208
    surprise       0.51      0.60      0.55       281

    accuracy                           0.62      2609
   macro avg       0.45      0.46      0.45      2609
weighted avg       0.62      0.62      0.62      2609



tensor(0.6203)

In [17]:
from typing import List
def average_weights(pth_paths: List[str]) -> dict:
    """Average the state dicts stored in several checkpoint files.

    Floating-point (and complex) tensors are averaged with true division.
    Every other tensor (e.g. the int64 counters printed by this function) is
    summed and floor-divided, so it keeps its integer dtype.

    Args:
        pth_paths: Paths of checkpoints written with
            ``torch.save(model.state_dict(), ...)``. Every checkpoint must
            contain all keys of the first one.

    Returns:
        dict: Mapping of parameter name to averaged CPU tensor.

    Raises:
        AssertionError: If ``pth_paths`` is empty or a path is not a file.
    """
    assert len(pth_paths) > 0, "No checkpoint paths provided."
    # Validate every path before any (possibly large) file is loaded.
    for path in pth_paths:
        assert isfile(path), f"Checkpoint file {path} does not exist."

    # The files are plain state dicts, so the restricted weights_only loader
    # is sufficient and avoids unpickling arbitrary objects.
    avg_wts = torch.load(pth_paths[0], map_location='cpu', weights_only=True)

    # Accumulate the remaining checkpoints into the first one.
    for path in pth_paths[1:]:
        wts = torch.load(path, map_location='cpu', weights_only=True)
        for key in avg_wts.keys():
            avg_wts[key] += wts[key]

    # Average
    count = len(pth_paths)
    for key in avg_wts.keys():
        total = avg_wts[key]
        if torch.is_floating_point(total) or torch.is_complex(total):
            avg_wts[key] = total / count
        else:
            # In-place true division is not defined for integer tensors, so
            # every non-float dtype (not only int64) is floor-divided.
            print(total.dtype)
            avg_wts[key] = total // count

    return avg_wts

# Checkpoints, relative to the working directory, whose weights are averaged.
checkpoint_files = [
    "checkpoints/epoch_6-acc_tensor(0_6013).pth",
    "checkpoints/epoch_9-acc_tensor(0_5987).pth",
    "checkpoints/epoch_1-acc_tensor(0_5908).pth",
]
state_dict = average_weights(
    [join(os.getcwd(), relative_path) for relative_path in checkpoint_files]
)

# Load the averaged weights into the existing model and score it on dev.
model.load_state_dict(state_dict)
model.to(device)
acc = evaluate_model(model, dev_loader)

  avg_wts = torch.load(pth_paths[0], map_location='cpu')
  wts = torch.load(path, map_location='cpu')


torch.int64
torch.int64
torch.int64
torch.int64


  0%|          | 0/139 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       anger       0.56      0.42      0.48       153
     disgust       0.35      0.36      0.36        22
        fear       0.22      0.12      0.16        40
         joy       0.53      0.58      0.55       163
     neutral       0.77      0.72      0.74       469
     sadness       0.40      0.51      0.45       111
    surprise       0.56      0.69      0.62       150

    accuracy                           0.61      1108
   macro avg       0.48      0.49      0.48      1108
weighted avg       0.61      0.61      0.60      1108

