# 05 - Transformer Model Training & Evaluation

Train and evaluate a Transformer-encoder-based model for audio classification.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

print("✓ Imports complete")

## Define Transformer Architecture

In [None]:
class AudioTransformer(nn.Module):
    """Transformer-encoder classifier for sequences of audio feature frames.

    Each input frame is linearly projected to ``model_dim``. A fixed
    sinusoidal positional encoding is added (optional). The sequence then
    passes through a stack of ``nn.TransformerEncoderLayer`` blocks, is
    mean-pooled over the sequence axis, and is mapped to ``num_classes`` logits.

    Args:
        input_dim: Number of features per frame (last axis of the input).
        model_dim: Encoder hidden size; must be divisible by ``num_heads``.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Number of output logits.
        dropout: Dropout rate inside the encoder layers and before the head.
        dim_feedforward: Width of each encoder layer's feed-forward sublayer.
            The default of 2048 is PyTorch's own default, so existing callers
            get the same architecture as before.
        use_positional_encoding: Add sinusoidal positional encodings after the
            input projection. Self-attention followed by mean pooling is
            otherwise invariant to frame order, so the model could not use
            temporal structure. Pass ``False`` to reproduce the earlier,
            order-blind behaviour (e.g. for weights trained without it).
    """

    def __init__(self, input_dim, model_dim, num_heads, num_layers, num_classes, dropout=0.1,
                 dim_feedforward=2048, use_positional_encoding=True):
        super().__init__()
        self.model_dim = model_dim
        self.use_positional_encoding = use_positional_encoding

        # Project input features to model dimension
        self.input_projection = nn.Linear(input_dim, model_dim)

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification head
        self.fc = nn.Linear(model_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _sinusoidal_encoding(seq_len, dim, device, dtype):
        """Return a ``(seq_len, dim)`` table of sine/cosine positional encodings.

        Even columns hold ``sin(pos / 10000**(i/dim))`` and odd columns the
        matching cosines. The table is computed on the fly, so there is no
        maximum sequence length and no extra entries in ``state_dict``.
        """
        position = torch.arange(seq_len, device=device, dtype=dtype).unsqueeze(1)
        even_dims = torch.arange(0, dim, 2, device=device, dtype=dtype)
        inv_freq = 1.0 / (10000.0 ** (even_dims / dim))
        angles = position * inv_freq  # (seq_len, ceil(dim / 2))
        encoding = torch.zeros(seq_len, dim, device=device, dtype=dtype)
        encoding[:, 0::2] = torch.sin(angles)
        # An odd ``dim`` has one fewer odd column than even column.
        encoding[:, 1::2] = torch.cos(angles)[:, : dim // 2]
        return encoding

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` must be ``(batch, seq_len, input_dim)``, as required by
        ``batch_first=True`` and the input projection.
        """
        x = self.input_projection(x)  # (batch, seq_len, model_dim)

        if self.use_positional_encoding:
            # (seq_len, model_dim) broadcasts across the batch axis.
            x = x + self._sinusoidal_encoding(x.size(1), self.model_dim, x.device, x.dtype)

        transformer_output = self.transformer_encoder(x)

        # Global average pooling over the sequence axis
        pooled_output = transformer_output.mean(dim=1)  # (batch, model_dim)

        return self.fc(self.dropout(pooled_output))


print("✓ Transformer model defined")

## Prepare Data

In [None]:
# Build float32 feature tensors and int64 label tensors for the Transformer.
# Reuse X_train_rnn / X_val_rnn / X_test_rnn from the previous notebook when they
# exist; otherwise create them by swapping the last two axes of X_train/X_val/X_test.
# NOTE(review): the fallback assumes X_* are NumPy arrays laid out as
# (N, features, time) so the result is (N, time, features) — confirm upstream.
from sklearn.preprocessing import LabelEncoder

if 'X_train_rnn' not in globals():
    X_train_rnn = X_train.transpose(0, 2, 1)
    X_val_rnn = X_val.transpose(0, 2, 1)
    X_test_rnn = X_test.transpose(0, 2, 1)

# Convert to tensors
X_train_tensor_transformer = torch.tensor(X_train_rnn, dtype=torch.float32)
X_val_tensor_transformer = torch.tensor(X_val_rnn, dtype=torch.float32)
X_test_tensor_transformer = torch.tensor(X_test_rnn, dtype=torch.float32)

# Labels: encode from scratch, or reuse the encoded tensors from earlier cells.
if 'y_train_tensor_rnn' not in globals():
    label_encoder = LabelEncoder()
    y_train_encoded = label_encoder.fit_transform(y_train)
    y_val_encoded = label_encoder.transform(y_val)
    y_test_encoded = label_encoder.transform(y_test)
else:
    # .cpu() so this also works if the earlier notebook left them on a GPU.
    y_train_encoded = y_train_tensor_rnn.cpu().numpy()
    y_val_encoded = y_val_tensor_rnn.cpu().numpy()
    y_test_encoded = y_test_tensor_rnn.cpu().numpy()
    if 'label_encoder' not in globals():
        # The visualisation cell reads label_encoder.classes_; without this it
        # raises NameError on this branch.
        # NOTE(review): assumes the earlier encoding was a LabelEncoder fitted
        # on y_train (same sorted class order) — confirm.
        label_encoder = LabelEncoder().fit(y_train)

y_train_tensor = torch.tensor(y_train_encoded, dtype=torch.long)
y_val_tensor = torch.tensor(y_val_encoded, dtype=torch.long)
y_test_tensor = torch.tensor(y_test_encoded, dtype=torch.long)

print(f"✓ Data prepared: {X_train_tensor_transformer.shape}")

## Create DataLoaders

In [None]:
class AudioDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each feature sample with its label.

    ``dataset[i]`` yields ``(features[i], labels[i])``; the reported length
    comes from ``labels``.
    """

    def __init__(self, features, labels):
        self.features, self.labels = features, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target

# Wrap each split in a dataset and batch it. Only the training split is
# reshuffled every epoch; validation/test keep a fixed order.
batch_size = 32

train_dataset = AudioDataset(X_train_tensor_transformer, y_train_tensor)
val_dataset = AudioDataset(X_val_tensor_transformer, y_val_tensor)
test_dataset = AudioDataset(X_test_tensor_transformer, y_test_tensor)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (val_dataset, test_dataset)
)

print("✓ DataLoaders created")

## Train Transformer Model

In [None]:
# Hyperparameters
input_dim = X_train_rnn.shape[2]  # last axis of the (N, time, features) training array
model_dim = 64
num_heads = 2
num_layers = 2
# Size the head from the encoded integer labels that CrossEntropyLoss actually
# receives. y_train.unique() only exists on pandas objects and is not tied to
# the label indices.
num_classes = int(torch.cat([y_train_tensor, y_val_tensor, y_test_tensor]).max().item()) + 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transformer_model = AudioTransformer(input_dim, model_dim, num_heads, num_layers, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(transformer_model.parameters(), lr=0.001)

num_epochs = 20
train_losses = []
val_losses = []
val_accuracies = []


def _train_one_epoch(model, loader, loss_fn, opt, dev):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    model.train()
    total_loss = 0.0
    total_samples = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(dev), labels.to(dev)
        opt.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        opt.step()
        total_loss += loss.item() * inputs.size(0)
        total_samples += inputs.size(0)
    return total_loss / total_samples


def _evaluate(model, loader, loss_fn, dev):
    """Return ``(per-sample mean loss, accuracy)`` on ``loader`` without updating weights."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(dev), labels.to(dev)
            outputs = model(inputs)
            total_loss += loss_fn(outputs, labels).item() * inputs.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / total, correct / total


best_val_acc = float("-inf")
best_epoch = None
best_state = None

print("Starting Transformer training...")
for epoch in range(num_epochs):
    epoch_loss = _train_one_epoch(transformer_model, train_loader, criterion, optimizer, device)
    train_losses.append(epoch_loss)

    # Validation
    epoch_val_loss, val_acc = _evaluate(transformer_model, val_loader, criterion, device)
    val_losses.append(epoch_val_loss)
    val_accuracies.append(val_acc)

    if val_acc > best_val_acc:
        best_val_acc, best_epoch = val_acc, epoch + 1
        # Clone: state_dict() returns live references that later optimizer
        # steps would overwrite in place.
        best_state = {k: v.detach().clone() for k, v in transformer_model.state_dict().items()}

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {epoch_val_loss:.4f}, Val Acc: {val_acc:.4f}")

# Evaluate the checkpoint selected on the validation set, not merely the last epoch.
if best_state is not None:
    transformer_model.load_state_dict(best_state)
    print(f"Restored weights from epoch {best_epoch} (val acc: {best_val_acc:.4f})")

print("✓ Transformer Training complete")

## Evaluate on Test Set

In [None]:
# Collect predicted and true class indices for the whole test split.
transformer_model.eval()
all_predictions = []
all_true_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = transformer_model(inputs.to(device))
        all_predictions.extend(outputs.argmax(dim=1).tolist())
        # Labels are only needed as Python ints, so they never have to go to the device.
        all_true_labels.extend(labels.tolist())

# Metrics (support-weighted averages across classes).
# zero_division=0 states explicitly what sklearn does by default: a class that
# is never predicted counts as 0 precision. The values are unchanged, but the
# UndefinedMetricWarning is no longer emitted.
accuracy = accuracy_score(all_true_labels, all_predictions)
precision = precision_score(all_true_labels, all_predictions, average='weighted', zero_division=0)
recall = recall_score(all_true_labels, all_predictions, average='weighted', zero_division=0)
f1 = f1_score(all_true_labels, all_predictions, average='weighted', zero_division=0)

print("\n" + "="*60)
print("Transformer Model Evaluation Metrics")
print("="*60)
print(f"Accuracy:  {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1-score:  {f1:.4f}")
print("="*60)

## Visualize Results

In [None]:
import numpy as np

# Class names for the axes. Prefer the fitted LabelEncoder, and fall back to
# integer indices if this session never defined one. Names are stringified so
# plt.bar always treats them as categories, never numeric bar positions.
if 'label_encoder' in globals():
    class_labels = [str(c) for c in label_encoder.classes_]
else:
    class_labels = [str(i) for i in range(num_classes)]

# Pin the matrix to every class index. Without labels=, a class missing from
# both y_true and y_pred is dropped, and the matrix no longer lines up with
# class_labels.
cm = confusion_matrix(all_true_labels, all_predictions, labels=list(range(len(class_labels))))

# Confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Purples', xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Transformer Confusion Matrix')
plt.tight_layout()
plt.show()

# Per-class accuracy (recall). Classes with no test samples get NaN, i.e. no
# bar, instead of triggering a divide-by-zero warning.
support = cm.sum(axis=1)
per_class_acc = np.divide(
    cm.diagonal(), support, out=np.full(len(support), np.nan), where=support > 0
)
plt.figure(figsize=(10, 5))
plt.bar(class_labels, per_class_acc, color='mediumpurple', edgecolor='purple')
plt.xlabel('Class')
plt.ylabel('Accuracy')
plt.title('Transformer Per-Class Accuracy')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Training curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss', linewidth=2)
plt.plot(val_losses, label='Validation Loss', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Transformer Training History')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()