In [None]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_size = patch_size
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)  # [batch_size, embed_dim, num_patches, num_patches]
        x = x.flatten(2)        # [batch_size, embed_dim, num_patches]
        x = x.transpose(1, 2)   # [batch_size, num_patches, embed_dim]
        return x

class PositionalEncoding(nn.Module):
    def __init__(self, num_patches, embed_dim):
        super(PositionalEncoding, self).__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

    def forward(self, x):
        return x + self.pos_embedding

class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super(TransformerEncoderBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_output))
        mlp_output = self.mlp(x)
        x = self.norm2(x + self.dropout(mlp_output))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=128, patch_size=16, in_channels=3, num_classes=1000, embed_dim=256, num_heads=8, mlp_dim=1024, num_layers=8, dropout=0.1):
        super(VisionTransformer, self).__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.positional_encoding = PositionalEncoding(self.patch_embedding.num_patches + 1, embed_dim)  
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.encoder_blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_dim, dropout) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        batch_size = x.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.positional_encoding(x)
        for block in self.encoder_blocks:
            x = block(x)
        cls_output = self.norm(x[:, 0])
        x = self.fc(cls_output)
        return x


In [None]:
import torch.nn as nn
import torch
import math

def xavier_initialization(model):
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)

def he_initialization(model):
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.kaiming_uniform_(param, nonlinearity='relu')

def custom_uniform_initialization(model, low=-0.1, high=0.1):
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.uniform_(param, a=low, b=high)

# def sparse_initialization(model, sparsity=0.5):
#     for param in model.parameters():
#         if param.dim() > 1:
#             nn.init.sparse_(param, sparsity=sparsity, std=0.01)
# Not doing sparse initialization because of very poor and unstable results and a requires more epochs.

def orthogonal_initialization(model):
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.orthogonal_(param)

def lecun_normal_initialization(model):
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.normal_(param, mean=0, std=(1 / param.size(1)) ** 0.5)

def sine_cosine_initialization(model):
    for param in model.parameters():
        if param.dim() > 1:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(param)
            bound = math.sqrt(6 / fan_in)
            with torch.no_grad():
                param.uniform_(-bound, bound)
                param.mul_(torch.sin(param))
                

def scaled_orthogonal_initialization(model, scale=1.0):
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.orthogonal_(param)
            param.mul_(scale)
            #param = param*scale used for some initialization methods both are same but the internal calculations are different I guess.

def layer_wise_initialization(model):
    for name, param in model.named_parameters():
        if "attention" in name and param.dim() > 1:
            nn.init.orthogonal_(param)
        elif "mlp" in name and param.dim() > 1:
            nn.init.normal_(param, mean=0, std=0.02)
        else:
            nn.init.xavier_uniform_(param)

def initialize_model(model, init_type='xavier'):
    if init_type == 'xavier':
        xavier_initialization(model)
    elif init_type == 'he':
        he_initialization(model)
    elif init_type == 'custom_uniform':
        custom_uniform_initialization(model)
    elif init_type == 'sparse':
        sparse_initialization(model)
    elif init_type == 'orthogonal':
        orthogonal_initialization(model)
    elif init_type == 'lecun_normal':
        lecun_normal_initialization(model)
    elif init_type == 'sine_cosine':
        sine_cosine_initialization(model)
    elif init_type == 'scaled_orthogonal':
        scaled_orthogonal_initialization(model)
    elif init_type == 'layer_wise':
        layer_wise_initialization(model)
    else:
        raise ValueError(f"Unknown initialization type: {init_type}")
    print(f"Model initialized with {init_type} initialization.")


In [3]:
import torch
import h5py
from functools import partial

# Forward hook to capture weights, activations, and covariance
def forward_hook(module, input, output, hdf5_file, layer_name, component_name, epoch, batch_idx):
    if isinstance(output, tuple):
        output = output[0]

    weights = module.weight.detach().cpu().numpy() if hasattr(module, 'weight') else None
    activations = output.detach().cpu().numpy()
    flat_activations = output.view(output.size(0), -1)
    covariance_matrix = torch.cov(flat_activations).detach().cpu().numpy()

    # Save captured data to HDF5
    save_to_hdf5(hdf5_file, {
        "weights": weights,
        "activations": activations,
        "covariance_matrix": covariance_matrix
    }, epoch, layer_name, component_name, batch_idx)

# Backward hook to capture gradients
def backward_hook(module, grad_input, grad_output, hdf5_file, layer_name, component_name, epoch, batch_idx):
    if isinstance(grad_output, tuple):
        grad_output = grad_output[0]

    gradients = grad_output.detach().cpu().numpy()

    # Save gradients to HDF5
    save_to_hdf5(hdf5_file, {
        "gradients": gradients
    }, epoch, layer_name, component_name, batch_idx)

# Save to HDF5
def save_to_hdf5(hdf5_file, data, epoch, layer_name, component_name, batch_idx):
    group_path = f'epoch_{epoch}/{layer_name}/{component_name}/batch_{batch_idx}'
    group = hdf5_file.require_group(group_path)
    
    # Overwrite any existing data for reliable storage at each interval
    for key, value in data.items():
        if value is not None:
            if isinstance(value, list):
                value = torch.tensor(value)
            if key in group:
                del group[key]  # Remove existing dataset before creating it fresh
            group.create_dataset(key, data=value, compression="gzip")
    
    print(f"[INFO] Data saved to {group_path} in HDF5 file.")

# Register hooks for specific blocks and remove after each capture to control intervals
def register_and_remove_hooks(model, hdf5_file, epoch, batch_idx):
    hooks = []
    for i, block in enumerate(model.encoder_blocks):
        if i == 3 or i == 7:  # Only 4th and 8th blocks
            layer_name = f'encoder_block_{i}'

            # Register forward hooks
            hooks.append(block.attention.register_forward_hook(
                partial(forward_hook, hdf5_file=hdf5_file, layer_name=layer_name, component_name='attention', epoch=epoch, batch_idx=batch_idx)
            ))
            hooks.append(block.mlp.register_forward_hook(
                partial(forward_hook, hdf5_file=hdf5_file, layer_name=layer_name, component_name='mlp', epoch=epoch, batch_idx=batch_idx)
            ))

            # Register backward hooks
            hooks.append(block.attention.register_backward_hook(
                partial(backward_hook, hdf5_file=hdf5_file, layer_name=layer_name, component_name='attention', epoch=epoch, batch_idx=batch_idx)
            ))
            hooks.append(block.mlp.register_backward_hook(
                partial(backward_hook, hdf5_file=hdf5_file, layer_name=layer_name, component_name='mlp', epoch=epoch, batch_idx=batch_idx)
            ))

    print("[INFO] Hooks registered for encoder_block_3 and encoder_block_7")

    return hooks  # Return hooks to be removed after capture

# Training function with hooks only at specified intervals
def train(model, train_loader, criterion, optimizer, device, epoch, n_batches_to_save, hdf5_file):
    model.train()
    running_loss = 0.0
    
    print(f"\n[INFO] Starting Epoch {epoch + 1}")
    
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

        # Register and unregister hooks only every n_batches_to_save
        if (batch_idx + 1) % n_batches_to_save == 0:
            print(f"[INFO] Registering hooks for Epoch {epoch + 1}, Batch {batch_idx + 1}")
            hooks = register_and_remove_hooks(model, hdf5_file, epoch, batch_idx + 1)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Unregister hooks immediately after capturing data to avoid duplication
        if (batch_idx + 1) % n_batches_to_save == 0:
            for hook in hooks:
                hook.remove()
            print(f"[INFO] Unregistered hooks after Batch {batch_idx + 1}")

    avg_loss = running_loss / len(train_loader)
    print(f"[INFO] Epoch {epoch + 1} completed with Average Loss: {avg_loss:.4f}\n")
    return avg_loss


In [4]:
import torch
import torch.optim as optim
import torch.nn as nn
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Initialize HDF5 file for saving data
def initialize_hdf5_file(filename):
    if not os.path.exists('./outputs'):
        os.makedirs('./outputs')
    return h5py.File(f'./outputs/{filename}', 'a')  # Append mode

# Save model, optimizer, and epoch information as checkpoint
def save_checkpoint(model, optimizer, epoch, filename='checkpoint.pth'):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    torch.save(checkpoint, filename)
    print(f"[INFO] Checkpoint saved for epoch {epoch + 1}")

# Load model and optimizer from checkpoint
def load_checkpoint(filename, model, optimizer):
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # Start from the next epoch
    print(f"[INFO] Resuming training from epoch {start_epoch}")
    return start_epoch

# Main training loop with checkpointing and HDF5 data saving
def main():
    epochs = 10
    batch_size = 32
    learning_rate = 0.001
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_batches_to_save = 50  # Data capture every 50 batches

    # Initialize output directories and HDF5 file
    os.makedirs('./checkpoints', exist_ok=True)
    hdf5_file = initialize_hdf5_file("training_data.h5")

    # Data loader setup
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

    # Model, criterion, optimizer setup
    model = VisionTransformer(img_size=128, patch_size=16, in_channels=3, num_classes=10).to(device)
    initialize_model(model, init_type='orthogonal')
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Check for existing checkpoint
    checkpoint_path = './checkpoints/latest_checkpoint.pth'
    if os.path.exists(checkpoint_path):
        start_epoch = load_checkpoint(checkpoint_path, model, optimizer)
    else:
        start_epoch = 0  # Start from scratch if no checkpoint exists

    for epoch in range(start_epoch, epochs):
        # Train the model and capture data
        loss = train(model, train_loader, criterion, optimizer, device, epoch, n_batches_to_save, hdf5_file)
        
        # Save model checkpoint at the end of each epoch
        save_checkpoint(model, optimizer, epoch, filename=checkpoint_path)

    # Close the HDF5 file after training completes
    hdf5_file.close()
    print(f"[INFO] Training completed for {epochs} epochs.")

if __name__ == '__main__':
    main()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 76953841.97it/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Model initialized with orthogonal initialization.

[INFO] Starting Epoch 1
[INFO] Registering hooks for Epoch 1, Batch 50
[INFO] Hooks registered for encoder_block_3 and encoder_block_7
[INFO] Data saved to epoch_0/encoder_block_3/attention/batch_50 in HDF5 file.


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


[INFO] Data saved to epoch_0/encoder_block_3/mlp/batch_50 in HDF5 file.
[INFO] Data saved to epoch_0/encoder_block_7/attention/batch_50 in HDF5 file.
[INFO] Data saved to epoch_0/encoder_block_7/mlp/batch_50 in HDF5 file.
[INFO] Data saved to epoch_0/encoder_block_7/mlp/batch_50 in HDF5 file.
[INFO] Data saved to epoch_0/encoder_block_7/attention/batch_50 in HDF5 file.
[INFO] Data saved to epoch_0/encoder_block_3/mlp/batch_50 in HDF5 file.
[INFO] Data saved to epoch_0/encoder_block_3/attention/batch_50 in HDF5 file.
[INFO] Unregistered hooks after Batch 50
[INFO] Registering hooks for Epoch 1, Batch 100
[INFO] Hooks registered for encoder_block_3 and encoder_block_7
[INFO] Data saved to epoch_0/encoder_block_3/attention/batch_100 in HDF5 file.
[INFO] Data saved to epoch_0/encoder_block_3/mlp/batch_100 in HDF5 file.
[INFO] Data saved to epoch_0/encoder_block_7/attention/batch_100 in HDF5 file.
[INFO] Data saved to epoch_0/encoder_block_7/mlp/batch_100 in HDF5 file.
[INFO] Data saved to 

In [5]:
from IPython.display import FileLink
FileLink(r'outputs/training_data.h5')