How to apply transformers in text, image, video, audio and time series etc

Transformer model
- A backbone architecture to represent and extract information from a sequence of embeddings of discrete tokens.
- It takes input of shape (seq_lenght, embedd_size) and outputs tensors of identical shapes
- Ex-1 : Images

In [1]:
import argparse
import math
import os
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchsummary import summary

In [2]:
class TransformerEncoderBlock(nn.Module):
    """

    """
    def __init__(self, embed_dim, ffn_size, num_heads=2, dropout=0.2):
        super(TransformerEncoderBlock, self).__init__()
        # multi-head attention
        self.mhattn = nn.MultiheadAttention(
            embed_dim = embed_dim,
            num_heads = num_heads,
            dropout=dropout,
            batch_first=True,
        )
        # 2. Feed Forward NN
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_size, embed_dim)
        )
        # 3. Layer normalization
        self.input_ln = nn.LayerNorm(embed_dim)
        self.ffn_ln = nn.LayerNorm(embed_dim)

        # forward pass over the input data x
        def forward(self, x):
            y = self.input_ln(x)
            y, _ = self.mhattn(y, y, y),
            y += x

            z = self.ffn_ln(y)
            z = self.ffn(z)
            z += y

            return z

In [3]:
class TransformerEncoderBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block with MHSA + feedforward network."""
    def __init__(self, embed_dim, ffn_dim, num_heads=4, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        y = self.ln1(x)
        attn_out, _ = self.mha(y, y, y, need_weights=False)
        x = x + attn_out
        z = self.ln2(x)
        x = x + self.ffn(z)
        return x

In [4]:
class VisionTransformer(nn.Module):
    """

    """
    def __init__(self,
                 img_dim,
                 num_patches,
                 num_classes,
                 embed_dim,
                 ffn_size,
                 num_heads,
                 dropout,
                 num_blocks):
        super(VisionTransformer, self).__init__()

In [5]:
class VisionTransformer(nn.Module):
    def __init__(
        self,
        image_size=(3, 32, 32),
        patch_grid=(8, 8),
        embed_dim=64,
        ffn_dim=128,
        num_heads=4,
        num_blocks=4,
        num_classes=10,
        dropout=0.1,
        use_sinusoidal_pos=False,
    ):
        super().__init__()
        C, H, W = image_size
        px = H // patch_grid[0]
        py = W // patch_grid[1]
        assert H % patch_grid[0] == 0 and W % patch_grid[1] == 0, "Image dims must be divisible by patch grid"

        self.patch_grid = patch_grid
        self.patch_size = (px, py)
        self.num_patches = patch_grid[0] * patch_grid[1]
        self.num_channels = C
        self.embed_dim = embed_dim

        # linear projection of flattened patch (C * px * py -> embed_dim)
        patch_area = C * px * py
        self.patch_embed = nn.Linear(patch_area, embed_dim)

        # class token
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # positional encoding (sinusoidal or learnable)
        seq_len = self.num_patches + 1
        if use_sinusoidal_pos:
            self.register_buffer('pos_encoding', self._build_sinusoidal_pos_encoding(seq_len, embed_dim))
        else:
            self.pos_embedding = nn.Parameter(torch.randn(1, seq_len, embed_dim))

        # transformer blocks
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim=embed_dim, ffn_dim=ffn_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_blocks)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        # classification head
        self.head = nn.Linear(embed_dim, num_classes)

    def _build_sinusoidal_pos_encoding(self, seq_len, dim):
        pe = torch.zeros(seq_len, dim)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)  # (1, seq_len, dim)

    def patchify(self, x: torch.Tensor):
        # x: (B, C, H, W)
        B, C, H, W = x.shape
        px, py = self.patch_size
        gx, gy = self.patch_grid
        # reshape to (B, C, gx, px, gy, py)
        x = x.view(B, C, gx, px, gy, py)
        # move patch dims next to channels then flatten patch content
        # (B, gx, gy, C, px, py)
        x = x.permute(0, 2, 4, 1, 3, 5)
        # flatten patch spatial & channel dims -> (B, num_patches, patch_area)
        x = x.reshape(B, gx * gy, C * px * py)
        return x

    def forward(self, x: torch.Tensor):
        # patchify + embed
        patches = self.patchify(x)  # (B, num_patches, patch_area)
        patches = self.patch_embed(patches)  # (B, num_patches, embed_dim)

        B = patches.size(0)
        class_tok = self.class_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat([class_tok, patches], dim=1)  # (B, seq_len, embed_dim)

        # add positional encoding
        if hasattr(self, 'pos_encoding'):
            x = x + self.pos_encoding.to(x.dtype).to(x.device)
        else:
            x = x + self.pos_embedding

        # transformer
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        cls = x[:, 0]  # (B, embed_dim)
        logits = self.head(cls)
        return logits




# ----------------------------
# Optional: attention visualization helper (requires model modification)
# ----------------------------
# Note: The current compact transformer uses nn.MultiheadAttention with need_weights=False in forward to save memory.
# To visualize attention, modify TransformerEncoderBlock to return attention weights (set need_weights=True) and collect them.
# For brevity we include a simple placeholder function here; users may adjust the block if they need attention maps.





In [6]:
# ----------------------------
# Training utilities
# ----------------------------

def get_dataloaders(batch_size=128, data_dir='./data', num_workers=4):
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    train_transforms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_set = CIFAR10(root=data_dir, train=True, download=True, transform=train_transforms)
    test_set = CIFAR10(root=data_dir, train=False, download=True, transform=test_transforms)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader


def train_one_epoch(model, device, dataloader, optimizer, epoch, scaler=None):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        logits = model(X)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * X.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += X.size(0)

    avg_loss = running_loss / total
    acc = correct / total
    print(f"Epoch {epoch}: Train loss {avg_loss:.4f}, Train acc {acc:.4f}")
    return avg_loss, acc


def evaluate(model, device, dataloader):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            logits = model(X)
            loss = F.cross_entropy(logits, y, reduction='sum')
            running_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += X.size(0)
    avg_loss = running_loss / total
    acc = correct / total
    print(f"Eval loss {avg_loss:.4f}, Eval acc {acc:.4f}")
    return avg_loss, acc

def save_checkpoint(state, outdir, name="best.pth"):
    os.makedirs(outdir, exist_ok=True)
    path = os.path.join(outdir, name)
    torch.save(state, path)
    print(f"Saved checkpoint: {path}")



In [9]:
# ----------------------------
# Main
# ----------------------------

batch_size = 128
epochs = 10
lr = 3e-4
weight_decay = 0.05
data_dir = './data'
output_dir = './checkpoints'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed = 42
num_workers = 4


def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(seed)

device = torch.device(device if torch.cuda.is_available() and 'cuda' in device else 'cpu')
print(f"Using device: {device}")

train_loader, test_loader = get_dataloaders(batch_size=batch_size, data_dir=data_dir, num_workers=num_workers)

model = VisionTransformer(
    image_size=(3, 32, 32),
    patch_grid=(8, 8),
    embed_dim=128,
    ffn_dim=256,
    num_heads=4,
    num_blocks=6,
    num_classes=10,
    dropout=0.1,
    use_sinusoidal_pos=False,
)
model.to(device)

# optimizer + scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
total_steps = epochs * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

best_acc = 0.0
for epoch in range(1, epochs + 1):
    train_loss, train_acc = train_one_epoch(model, device, train_loader, optimizer, epoch)
    val_loss, val_acc = evaluate(model, device, test_loader)
    # step scheduler once per epoch (approx)
    scheduler.step()

    # checkpoint best
    if val_acc > best_acc:
        best_acc = val_acc
        save_checkpoint({
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'val_acc': val_acc,
        }, output_dir, name=f"vit_cifar10_best.pth")

print(f"Training finished. Best val acc: {best_acc:.4f}")

Using device: cuda


100%|██████████| 170M/170M [00:14<00:00, 11.9MB/s]


Epoch 1: Train loss 1.8406, Train acc 0.3241
Eval loss 1.5571, Eval acc 0.4411
Saved checkpoint: ./checkpoints/vit_cifar10_best.pth
Epoch 2: Train loss 1.5562, Train acc 0.4346
Eval loss 1.4083, Eval acc 0.4927
Saved checkpoint: ./checkpoints/vit_cifar10_best.pth
Epoch 3: Train loss 1.4489, Train acc 0.4748
Eval loss 1.3103, Eval acc 0.5267
Saved checkpoint: ./checkpoints/vit_cifar10_best.pth
Epoch 4: Train loss 1.3731, Train acc 0.5023
Eval loss 1.2810, Eval acc 0.5360
Saved checkpoint: ./checkpoints/vit_cifar10_best.pth
Epoch 5: Train loss 1.3113, Train acc 0.5250
Eval loss 1.1954, Eval acc 0.5675
Saved checkpoint: ./checkpoints/vit_cifar10_best.pth
Epoch 6: Train loss 1.2629, Train acc 0.5427
Eval loss 1.1834, Eval acc 0.5721
Saved checkpoint: ./checkpoints/vit_cifar10_best.pth
Epoch 7: Train loss 1.2220, Train acc 0.5600
Eval loss 1.1290, Eval acc 0.5904
Saved checkpoint: ./checkpoints/vit_cifar10_best.pth
Epoch 8: Train loss 1.1867, Train acc 0.5719
Eval loss 1.0914, Eval acc 0.60

In [11]:
checkpoint_path = os.path.join(output_dir, "vit_cifar10_best.pth")
checkpoint = torch.load(checkpoint_path, map_location=device)

model_eval = VisionTransformer(
    image_size=(3, 32, 32),
    patch_grid=(8, 8),
    embed_dim=128,
    ffn_dim=256,
    num_heads=4,
    num_blocks=6,
    num_classes=10,
    dropout=0.1,
    use_sinusoidal_pos=False,
)
model_eval.load_state_dict(checkpoint['model_state'])
model_eval.to(device)

VisionTransformer(
  (patch_embed): Linear(in_features=48, out_features=128, bias=True)
  (blocks): ModuleList(
    (0-5): 6 x TransformerEncoderBlock(
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mha): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ffn): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=256, out_features=128, bias=True)
        (4): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=128, out_features=10, bias=True)
)

In [12]:
_, real_data_loader = get_dataloaders(batch_size=batch_size, data_dir=data_dir, num_workers=num_workers)



In [13]:
print(get_dataloaders)

<function get_dataloaders at 0x7dd949321d00>


In [14]:
model_eval.eval()
real_data_loss, real_data_acc = evaluate(model_eval, device, real_data_loader)
print(f"Evaluation on real dataset: Loss {real_data_loss:.4f}, Accuracy {real_data_acc:.4f}")



Eval loss 1.0597, Eval acc 0.6160
Evaluation on real dataset: Loss 1.0597, Accuracy 0.6160


In [15]:
model_eval.eval()
predictions = []
with torch.no_grad():
    for X, y in real_data_loader:
        X = X.to(device)
        logits = model_eval(X)
        preds = torch.argmax(logits, dim=1)
        predictions.extend(preds.cpu().tolist())

print(f"Number of predictions: {len(predictions)}")



Number of predictions: 10000


In [16]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

# Extract true labels from the DataLoader
true_labels = []
for _, y in real_data_loader:
    true_labels.extend(y.cpu().tolist())

# Ensure the number of predictions and true labels match
print(f"Number of true labels: {len(true_labels)}")
print(f"Number of predictions: {len(predictions)}")
if len(true_labels) != len(predictions):
    print("Warning: Number of true labels and predictions do not match.")


# Calculate and print classification report
print("\nClassification Report:")
print(classification_report(true_labels, predictions))

# Calculate and print confusion matrix
print("\nConfusion Matrix:")
conf_matrix = confusion_matrix(true_labels, predictions)
print(conf_matrix)

# Interpret the results
print("\nInterpretation:")
print(f"Overall Accuracy: {real_data_acc:.4f}")
print("The classification report provides precision, recall, and F1-score for each class.")
print("Precision indicates the accuracy of positive predictions for each class.")
print("Recall indicates the ability of the model to find all positive samples for each class.")
print("The F1-score is the harmonic mean of precision and recall.")
print("The confusion matrix shows the number of correct and incorrect predictions for each class.")
print("Diagonal elements represent correct predictions, while off-diagonal elements represent misclassifications.")



Number of true labels: 10000
Number of predictions: 10000

Classification Report:
              precision    recall  f1-score   support

           0       0.57      0.73      0.64      1000
           1       0.63      0.78      0.70      1000
           2       0.52      0.52      0.52      1000
           3       0.51      0.31      0.39      1000
           4       0.59      0.55      0.57      1000
           5       0.53      0.56      0.54      1000
           6       0.67      0.74      0.70      1000
           7       0.69      0.68      0.68      1000
           8       0.73      0.74      0.74      1000
           9       0.69      0.55      0.62      1000

    accuracy                           0.62     10000
   macro avg       0.61      0.62      0.61     10000
weighted avg       0.61      0.62      0.61     10000


Confusion Matrix:
[[730  40  30  14  17   2  25   7 103  32]
 [ 66 778  14   3   5   1   6   5  32  90]
 [106  16 519  47 109  68  71  32  18  14]
 [ 37  18  

## Summary:

### Data Analysis Key Findings

*   The trained Vision Transformer model achieved an accuracy of 61.60% and a loss of 1.0597 on the real CIFAR-10 test dataset.
*   The analysis of predictions revealed varying performance across different classes in the CIFAR-10 dataset, as shown by the precision, recall, and F1-scores in the classification report.
*   The confusion matrix indicated the distribution of correct and incorrect predictions for each class, highlighting which classes the model struggled to classify correctly.

### Insights or Next Steps

*   Investigate the specific classes where the model shows lower performance (lower precision, recall, and F1-score) to understand potential causes, such as class imbalance or visual similarities between misclassified classes.
*   Consider techniques like data augmentation, transfer learning with a pre-trained model on a larger dataset, or fine-tuning the model architecture or hyperparameters to improve performance, especially for the underperforming classes.
