# Define the Vision Transformer

In [None]:
import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    """Toy masked-patch transformer over 2-cell binary patches.

    Token ids index a frozen embedding table whose first four rows hold the
    raw patch values ([0, 0], [0, 1], [1, 0], [1, 1]) and whose last row is
    the mask token (all 0.5). Embeddings are projected to ``hidden_dim``,
    given a learned positional offset, passed through the encoder layers and
    finally mapped to logits over the ``num_tokens - 1`` real (non-mask)
    tokens.

    Args:
        embed_dim: Width of the frozen token embedding. The table is filled
            with 2-cell patch values, so this is expected to be 2.
        num_heads: Attention heads per encoder layer.
        feedforward_dim: Inner width of each encoder layer's feedforward net.
        num_layers: Number of stacked encoder layers.
        num_tokens: Vocabulary size including the mask token (the last id).
        max_patches: Number of learned positional embeddings, i.e. the
            longest sequence the model can position.
        dropout: Dropout probability forwarded to the encoder layers.
        hidden_dim: Model width after the input projection. Falls back to
            ``embed_dim`` when left as ``None``.
    """

    def __init__(self, embed_dim, num_heads, feedforward_dim, num_layers, num_tokens, max_patches, dropout=0.0, hidden_dim=None):
        super().__init__()

        # The signature advertises hidden_dim=None; without this fallback the
        # layers below would be built with a None width and fail.
        if hidden_dim is None:
            hidden_dim = embed_dim
        self.hidden_dim = hidden_dim

        # Create the embedding matrix for 2-cell binary combinations + 1 mask token
        embedding_matrix = torch.zeros((num_tokens, embed_dim))  # Shape: (num_tokens, embed_dim)

        # Generate all possible 2-cell binary patches
        patches = torch.tensor([
            [a, b]
            for a in range(2)
            for b in range(2)
        ])  # Shape: (4, 2) for 4 combinations of 2-cell binary patches

        # Assign each patch's values as its embedding
        for i, patch in enumerate(patches):
            embedding_matrix[i, :] = patch  # Set the embedding to the patch values

        # The mask token (last row) sits halfway between the binary values 0 and 1
        embedding_matrix[-1, :] = 0.5  # Mask token embedding

        # Frozen lookup table: token id -> raw patch values
        self.embedding_matrix = nn.Embedding.from_pretrained(
            embedding_matrix, freeze=True
        )

        # Transformer encoder layers
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(self.hidden_dim, num_heads, feedforward_dim, dropout)
            for _ in range(num_layers)
        ])

        # Linear layer to project input embeddings
        self.input_proj = nn.Linear(embed_dim, self.hidden_dim)

        # Positional embeddings
        self.positional_embedding = nn.Parameter(torch.randn(max_patches, 1, self.hidden_dim))  # Shape: (max_patches, 1, hidden_dim)

        # Output layer: the mask token is never a prediction target, hence num_tokens - 1 classes
        self.fc_out = nn.Linear(self.hidden_dim, num_tokens-1)

    def forward(self, patches):
        """Map token ids to per-position logits.

        Args:
            patches: Integer tensor of token ids, shape (batch_size, seq_len).

        Returns:
            Tensor of logits, shape (batch_size, seq_len, num_tokens - 1).
        """
        # Retrieve embeddings
        embeddings = self.embedding_matrix(patches)  # (batch_size, seq_len, embed_dim)

        # The encoder layers work sequence-first
        x = embeddings.permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)

        # Extract seq_len and batch_size
        seq_len, batch_size, _ = x.size()

        # Project input to hidden_dim
        z = self.input_proj(x)  # Shape: (seq_len, batch_size, hidden_dim)

        # Add positional embedding (only the first seq_len positions are used)
        pos_emb = self.positional_embedding[:seq_len, :, :].expand(-1, batch_size, -1)  # Shape: (seq_len, batch_size, hidden_dim)
        z = z + pos_emb

        # Pass through transformer layers
        for layer in self.encoder_layers:
            z = layer(z)

        # Output logits
        z = z.permute(1, 0, 2)  # Back to (batch_size, seq_len, hidden_dim)
        logits = self.fc_out(z)  # (batch_size, seq_len, num_tokens-1)
        return logits

    def get_probabilities(self, logits):
        """Compute probabilities using softmax over the last (class) axis."""
        return torch.softmax(logits, dim=-1)


class TransformerEncoderLayer(nn.Module):
    """Pre-norm encoder block: self-attention, then a two-layer feedforward net.

    Both sub-blocks normalise their input first and add their (dropped-out)
    output back onto the un-normalised stream.

    Args:
        hidden_dim: Width of the token representation; must be a multiple of
            ``num_heads``.
        num_heads: Number of attention heads.
        feedforward_dim: Inner width of the feedforward net.
        dropout: Dropout probability used inside attention, at the end of
            the feedforward net, and on both residual branches.
    """

    def __init__(self, hidden_dim, num_heads, feedforward_dim, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Every head receives an equal slice of the hidden dimension.
        assert hidden_dim % num_heads == 0, "Hidden dimension must be divisible by the number of heads."

        # Sequence-first multi-head self-attention.
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
        )

        # Expand -> ReLU -> contract -> dropout.
        widen = nn.Linear(hidden_dim, feedforward_dim)
        narrow = nn.Linear(feedforward_dim, hidden_dim)
        self.feedforward = nn.Sequential(widen, nn.ReLU(), narrow, nn.Dropout(dropout))

        # One LayerNorm in front of each sub-block.
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

        # Dropout applied to each residual branch.
        self.dropout = nn.Dropout(dropout)

    def forward(self, z):
        """Run one encoder block.

        Args:
            z: Tensor of shape (seq_len, batch_size, hidden_dim)
        Returns:
            Tensor of shape (seq_len, batch_size, hidden_dim)
        """
        # Unpacking doubles as a check that the input is three-dimensional.
        _seq_len, _batch_size, _hidden_dim = z.size()

        # Attention sub-block; the returned attention weights are not needed.
        normed = self.norm1(z)
        attended = self.attention(normed, normed, normed)[0]
        z = z + self.dropout(attended)

        # Feedforward sub-block.
        transformed = self.feedforward(self.norm2(z))
        return z + self.dropout(transformed)

# Make a simple dataset

In [None]:
import matplotlib.pyplot as plt
# Toy dataset: 100 random binary images of 4x4 pixels each.
images = torch.randint(low=0, high=2, size=(100, 4, 4)).long()

# Preview the first ten samples on a grid of 2 rows by 5 columns.
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(10, 4))
fig.suptitle("10 First Realizations of Dataset", fontsize=16)

# zip() stops after the ten axes, so only the first ten images are drawn.
for number, (ax, sample) in enumerate(zip(axs.flat, images), start=1):
    ax.imshow(sample.numpy(), cmap="gray")
    ax.set_title(f"Image {number}")
    ax.axis("off")  # hide ticks and frame

# Reserve the top 4% of the figure for the super-title.
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()


# Make everything ready for training

In [None]:
import wandb
# Hyper-parameters of the toy experiment
batch_size = 1
embed_dim = 2  # each token embeds to its two raw cell values
hidden_dim = 3  # model width after the input projection
num_heads = 1
feedforward_dim = hidden_dim*2  # original note: "(2-4)" -- presumably the range of multipliers considered
num_layers = 1
num_tokens = 5  # 4 tokens + 1 mask token
max_patches = 8
dropout = 0.2
learning_rate = 3e-4
num_epochs = 100

# Initialize wandb and record every hyper-parameter that shapes the run.
# hidden_dim is included so a run can be reproduced from its logged config.
wandb.login()
wandb.init(
    project="vision-transformer-toy-example",
    config={
        "batch_size": batch_size,
        "embed_dim": embed_dim,
        "hidden_dim": hidden_dim,
        "num_heads": num_heads,
        "feedforward_dim": feedforward_dim,
        "num_layers": num_layers,
        "num_tokens": num_tokens,
        "max_patches": max_patches,
        "dropout": dropout,
        "learning_rate": learning_rate,
        "num_epochs": num_epochs,
    },
)

# Device: use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model
model = VisionTransformer(embed_dim, num_heads, feedforward_dim, num_layers, num_tokens, max_patches, dropout, hidden_dim).to(device)
# To resume from saved weights, load them here:
#model.load_state_dict(checkpoint)  # Load model weights

# Optimizer and Loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

#Dataloader
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class BinaryImageDataset(Dataset):
    """Map-style dataset that serves binary images as float32 tensors."""

    def __init__(self, images):
        """
        Args:
            images (Tensor): Tensor of shape (num_images, height, width) with
                binary values (0 or 1).
        """
        self.images = images

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return image ``idx`` as an independent float32 tensor."""
        image = self.images[idx]
        # torch.tensor(existing_tensor) emits a copy-construct UserWarning on
        # every call; detach + copying cast yields the same fresh float32
        # tensor without it.
        return torch.as_tensor(image).detach().to(dtype=torch.float32, copy=True)
    
# Wrap the image tensor and serve it in freshly shuffled mini-batches each epoch.
dataset = BinaryImageDataset(images)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)

# Start training loop with all the visualizations

In [None]:
# Training Loop
import os  # used below to create the checkpoint directory

# Number of patches replaced by the mask token in every sequence (constant for all epochs).
num_masked_patches = 4

# preprocess_image and reconstruct_image_from_patches are expected to be
# defined in an earlier cell; they are not part of this cell.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    # The batch gets its own name so the dataset tensor `images` is not overwritten.
    for batch_idx, batch_images in enumerate(dataloader):
        # Preprocess images into sequences of patch token ids
        patch_indices = torch.stack([preprocess_image(img) for img in batch_images]).long()
        masked_patches = patch_indices.clone()

        # Build the mask on the same device as the tensor it indexes; creating it
        # on `device` while the patches are still on the CPU fails on CUDA.
        mask = torch.zeros_like(masked_patches, dtype=torch.bool)
        for b in range(masked_patches.size(0)):  # Iterate over the batch
            mask_indices = torch.randperm(masked_patches.size(1))[:num_masked_patches]
            mask[b, mask_indices] = True

        # Apply the mask
        masked_patches[mask] = num_tokens - 1  # Assign mask token

        # Move to device
        masked_patches, patch_indices, mask = masked_patches.to(device), patch_indices.to(device), mask.to(device)

        # Forward pass
        logits = model(masked_patches)

        # Keep only the masked positions. Boolean indexing walks the batch in
        # row-major order, so the result matches a per-sample stack + flatten.
        masked_logits = logits[mask]  # Shape: (batch_size * num_masked, num_tokens-1)
        masked_patch_indices = patch_indices[mask]  # Shape: (batch_size * num_masked)

        # Calculate loss on the masked positions only
        loss = criterion(masked_logits, masked_patch_indices)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Log batch metrics
        wandb.log({"batch_loss": loss.item(), "num_masked_patches": num_masked_patches})

        # Log visualizations for the first batch
        if batch_idx == 0:
            with torch.no_grad():
                predicted_indices = torch.argmax(logits, dim=-1).cpu()[0]
                reconstructed_image = reconstruct_image_from_patches(predicted_indices)

                # Replace the mask token id with -1 before reconstructing the masked view
                visualized_masked_patches = masked_patches.cpu()[0].clone()
                visualized_masked_patches[visualized_masked_patches == num_tokens - 1] = -1
                masked_image = reconstruct_image_from_patches(visualized_masked_patches)

                wandb.log({
                    "Original Image": wandb.Image(
                        reconstruct_image_from_patches(patch_indices.cpu()[0])
                    ),
                    "Masked Image": wandb.Image(masked_image, caption="Masked Image"),
                    "Reconstructed Image": wandb.Image(
                        reconstructed_image, caption="Reconstructed Image"
                    ),
                })

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {avg_loss:.4f}")

    # Log epoch metrics
    wandb.log({"epoch_loss": avg_loss})

    # Save model checkpoints every 10 epochs
    if (epoch + 1) % 10 == 0:
        checkpoint_path = f"checkpoints/vision_transformer_epoch_{epoch+1}.pth"
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)  # Ensure the directory exists
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved for epoch {epoch+1}")
        wandb.save(checkpoint_path)

# Save the final model
torch.save(model.state_dict(), "vision_transformer_final.pth")
wandb.save("vision_transformer_final.pth")
print("Final model saved as 'vision_transformer_final.pth'.")
wandb.finish()

# Build a 4x4 test image

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# 4x4 binary test image, written out one pixel row per line.
pixel_rows = [
    [0, 1, 0, 1],
    [1, 1, 1, 1],
    [1, 0, 0, 0],
    [0, 0, 0, 0],
]
image = torch.tensor(pixel_rows, dtype=torch.float32)

plt.imshow(image, cmap="gray")
plt.title("Original Image Y")
plt.show()

# Step 1: Build a masked image

In [None]:
# Mask half of the eight 1x2 patches: one patch per row, as (row, columns) pairs.
masked_image = image.clone()
hidden_patches = (
    (0, slice(0, 2)),
    (1, slice(2, 4)),
    (2, slice(2, 4)),
    (3, slice(0, 2)),
)
for row, cols in hidden_patches:
    # 0.5 marks "unknown": halfway between the binary values 0 and 1.
    masked_image[row, cols] = 0.5

plt.imshow(masked_image, cmap="gray")
plt.title("Masked Image X")
plt.colorbar()
plt.show()
masked_image

In [None]:
# Cut the 4x4 image into eight 1x2 patches: (batch_size, num_patches, embed_dim).
# Row-major reshaping yields exactly the left/right halves of each row in
# order; clone() makes the patches independent of the image storage.
masked_patches = masked_image.reshape(1, 8, 2).clone()
print("masked patches: ", masked_patches)

# Visualize the Vector Embeddings

In [None]:
# Every patch is a 2-D vector: drop the batch axis and switch to NumPy for plotting.
vectors = masked_patches.squeeze(0).detach().numpy()
x_components, y_components = vectors[:, 0], vectors[:, 1]

# Render labels as plain text, never through LaTeX.
plt.rcParams['text.usetex'] = False

plt.figure(figsize=(8, 8))
for number, (dx, dy) in enumerate(zip(x_components, y_components), start=1):
    # Arrow from the origin to the patch vector.
    plt.arrow(0, 0, dx, dy,
              head_width=0.05, head_length=0.1, fc='blue', ec='blue')
    # Label placed slightly up and right of the arrow tip.
    plt.text(dx + 0.02, dy + 0.02, f"x_{number}", fontsize=12, color="red")

# Coordinate axes, grid and titles.
plt.axhline(0, color='black', linewidth=0.5)
plt.axvline(0, color='black', linewidth=0.5)
plt.grid(color='gray', linestyle='--', linewidth=0.5)
plt.title("Vector embeddings")
plt.xlabel("1st Embedding Dimension")
plt.ylabel("2nd Embedding Dimension")

# Fit the view tightly around the data, with a little headroom for the labels.
plt.xlim(x_components.min()-0.01, x_components.max()+0.1)
plt.ylim(y_components.min()-0.01, y_components.max()+0.1)

plt.show()

# Step 2: build Z = XA^T + b

In [None]:
# Put the patch axis first: (num_patches, batch_size, embed_dim).
x = masked_patches.transpose(0, 1)
print("x: ", x)

# Z = X A^T + b lifts every 2-D patch vector into the 3-D hidden space.
lin_layer = nn.Linear(in_features=2, out_features=3)
z = lin_layer(x)  # (num_patches, batch_size, hidden_dim)
print("z: ", z)
print("z dim:", z.size())

# Visualize the Z vectors. 

In [None]:
import plotly.graph_objects as go

# Drop the singleton batch axis: one row per patch, one column per hidden dim.
vectors = z.squeeze(1).detach().numpy()

# Rescale every row to unit length so that only directions are compared.
row_norms = np.linalg.norm(vectors, axis=1, keepdims=True)
normalized_vectors = vectors / row_norms

fig = go.Figure()

for number, (x_end, y_end, z_end) in enumerate(normalized_vectors, start=1):
    tag = f"z_{number}"

    # Segment from the origin to the tip of the vector.
    fig.add_trace(go.Scatter3d(
        x=[0, x_end],
        y=[0, y_end],
        z=[0, z_end],
        mode="lines+markers",
        line=dict(color="blue", width=4),
        marker=dict(size=4),
        name=tag
    ))

    # Text label sitting at the tip.
    fig.add_trace(go.Scatter3d(
        x=[x_end],
        y=[y_end],
        z=[z_end],
        mode="text",
        text=[tag],
        textposition="top center",
        textfont=dict(size=12, color="red")
    ))

# Fixed cube around the unit sphere, tight margins, no legend.
scene_settings = dict(
    xaxis_title="1st Hidden Dim",
    yaxis_title="2nd Hidden Dim",
    zaxis_title="3rd Hidden Dim",
    xaxis=dict(range=[-1.5, 1.5]),
    yaxis=dict(range=[-1.5, 1.5]),
    zaxis=dict(range=[-1.5, 1.5])
)
fig.update_layout(
    title="3D Vector Embeddings",
    scene=scene_settings,
    margin=dict(l=0, r=0, b=0, t=40),
    showlegend=False
)

# Show the interactive plot
fig.show()

# Step 3: Add a Positional Embedding

In [None]:
# One random offset per patch position; the size-1 batch axis broadcasts.
pos_shape = (8, 1, 3)  # (num_patches, 1, hidden_dim)
positional_embedding = nn.Parameter(torch.randn(pos_shape))
z = torch.add(z, positional_embedding)
print("z + positional embedding: ", z)

# Visualize with the Positional Embeddings

In [None]:
# Drop the singleton batch axis: one row per patch, one column per hidden dim.
vectors = z.squeeze(1).detach().numpy()

# Rescale every row to unit length so that only directions are compared.
row_norms = np.linalg.norm(vectors, axis=1, keepdims=True)
normalized_vectors = vectors / row_norms

fig = go.Figure()

for number, (x_end, y_end, z_end) in enumerate(normalized_vectors, start=1):
    tag = f"z_{number}"

    # Segment from the origin to the tip of the vector.
    fig.add_trace(go.Scatter3d(
        x=[0, x_end],
        y=[0, y_end],
        z=[0, z_end],
        mode="lines+markers",
        line=dict(color="blue", width=4),
        marker=dict(size=4),
        name=tag
    ))

    # Text label sitting at the tip.
    fig.add_trace(go.Scatter3d(
        x=[x_end],
        y=[y_end],
        z=[z_end],
        mode="text",
        text=[tag],
        textposition="top center",
        textfont=dict(size=12, color="red")
    ))

# Fixed cube around the unit sphere, tight margins, no legend.
scene_settings = dict(
    xaxis_title="1st Hidden Dim",
    yaxis_title="2nd Hidden Dim",
    zaxis_title="3rd Hidden Dim",
    xaxis=dict(range=[-1.5, 1.5]),
    yaxis=dict(range=[-1.5, 1.5]),
    zaxis=dict(range=[-1.5, 1.5])
)
fig.update_layout(
    title="3D Vector embeddings + Positional Embeddings",
    scene=scene_settings,
    margin=dict(l=0, r=0, b=0, t=40),
    showlegend=False
)

# Show the interactive plot
fig.show()

# Step 4: Get the attention output

In [None]:
# Record the layout of z: (num_patches, batch_size, hidden_dim).
seq_len, batch_size, hidden_dim = z.shape

# Pre-norm: normalise over the hidden axis before attending.
lay_norm = nn.LayerNorm(normalized_shape=3)
z_norm = lay_norm(z)

# Single-head self-attention: the same tensor serves as query, key and value.
attention = nn.MultiheadAttention(3, 1, dropout=0.0)
attn_output, attn_weights = attention(query=z_norm, key=z_norm, value=z_norm)  # attn_output: (num_patches, batch_size, hidden_dim)
print("size of context vector: ", attn_output.size())
print("size of QK^T-matrix: ", attn_weights.size())

# Visualize the attention scores

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# The attention weights carry a leading axis of size 1; drop it to get a 2-D map.
tensor = attn_weights
matrix = tensor.squeeze(0).detach().numpy()

# Heatmap with a colour scale.
plt.figure(figsize=(8, 8))
plt.imshow(matrix, cmap="viridis", interpolation="nearest")
plt.colorbar(label="Value")

# Print every value inside its cell, switching text colour for contrast.
for (i, j), weight in np.ndenumerate(matrix):
    if weight < 0.5:
        text_color = "white"
    else:
        text_color = "black"
    plt.text(j, i, f"{weight:.2f}", ha="center", va="center", color=text_color)

# Titles and axis labels.
plt.title("Attention Scores")
plt.xlabel("Columns")
plt.ylabel("Rows")
plt.show()

# Visualize the Context Vectors. 

In [None]:
# Drop the singleton batch axis: one row per patch, one column per hidden dim.
vectors = attn_output.squeeze(1).detach().numpy()

# Normalize the vectors for consistent lengths
normalized_vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)

# Extract components for normalized vectors
x_components = normalized_vectors[:, 0]  # X components
y_components = normalized_vectors[:, 1]  # Y components
z_components = normalized_vectors[:, 2]  # Z components

# Create the plotly figure
fig = go.Figure()

# Add vectors to the plot
for i in range(len(normalized_vectors)):
    # Start and end points for the vector
    x_start, y_start, z_start = 0, 0, 0  # Origin
    x_end, y_end, z_end = x_components[i], y_components[i], z_components[i]

    # Context vectors are labelled c_1..c_n; the trace name (shown on hover)
    # uses the same label instead of the copy-pasted "z_" prefix.
    label = f"c_{i+1}"

    # Add the arrow (line)
    fig.add_trace(go.Scatter3d(
        x=[x_start, x_end],
        y=[y_start, y_end],
        z=[z_start, z_end],
        mode="lines+markers",
        line=dict(color="blue", width=4),
        marker=dict(size=4),
        name=label
    ))

    # Add label for the vector
    fig.add_trace(go.Scatter3d(
        x=[x_end],
        y=[y_end],
        z=[z_end],
        mode="text",
        text=[label],
        textposition="top center",
        textfont=dict(size=12, color="red")
    ))

# Update layout for better aesthetics
fig.update_layout(
    title="Context Vectors",
    scene=dict(
        xaxis_title="1st Hidden Dim",
        yaxis_title="2nd Hidden Dim",
        zaxis_title="3rd Hidden Dim",
        xaxis=dict(range=[-1.5, 1.5]),
        yaxis=dict(range=[-1.5, 1.5]),
        zaxis=dict(range=[-1.5, 1.5])
    ),
    margin=dict(l=0, r=0, b=0, t=40),
    showlegend=False
)

# Show the interactive plot
fig.show()

# Step 5: FFN Layer

In [None]:
# First residual connection: add the attention output back onto its input.
z = z + attn_output  # (num_patches, batch_size, hidden_dim)

# Normalise again before the feedforward net (re-uses the LayerNorm from the attention step).
z_norm = lay_norm(z)  # (num_patches, batch_size, hidden_dim)

# Two-layer feedforward net: widen 3 -> 6, ReLU, contract 6 -> 3.
hidden_width, inner_width = 3, 6
ffn_layers = [
    nn.Linear(hidden_width, inner_width),
    nn.ReLU(),
    nn.Linear(inner_width, hidden_width),
]
feedforward = nn.Sequential(*ffn_layers)
feedforward_output = feedforward(z_norm)  # (num_patches, batch_size, hidden_dim)

# Second residual connection.
z = z + feedforward_output  # (num_patches, batch_size, hidden_dim)

In [None]:
import networkx as nx
# Extract weights (nn.Linear stores them as (out_features, in_features))
layer1_weights = feedforward[0].weight.detach().numpy()  # First linear layer: (6, 3)
layer2_weights = feedforward[2].weight.detach().numpy()  # Second linear layer: (3, 6)
# feedforward_output is (num_patches, batch_size, hidden_dim); take the first
# patch of the first batch item so that exactly three scalars remain.
output_values = feedforward_output[0, 0].detach().numpy()

# Prepare text for labels
input_weights_text = [f"{w:.2f}" for w in layer1_weights.mean(axis=0)]  # Mean first-layer weight per input unit
hidden_weights_text = [f"{w:.2f}" for w in layer2_weights.mean(axis=0)]  # Mean second-layer weight per hidden unit
output_weights_text = [f"{float(v):.2f}" for v in output_values]  # Output activations of the first patch

# Initialize the graph
G = nx.DiGraph()

# Define layers (output_layer must exist: it is used for nodes and edges below)
input_layer = [f"Input {i+1}" for i in range(3)]
hidden_layer = [f"Hidden {i+1}" for i in range(6)]
output_layer = [f"Output {i+1}" for i in range(3)]

# Add nodes with weights as labels; `subset` is the column used by the layout
for i, node in enumerate(input_layer):
    G.add_node(node, subset=0, weight=input_weights_text[i])
for i, node in enumerate(hidden_layer):
    G.add_node(node, subset=1, weight=hidden_weights_text[i])
for i, node in enumerate(output_layer):
    G.add_node(node, subset=2, weight=output_weights_text[i])

# Connect layers (fully connected between neighbouring layers)
for i in input_layer:
    for h in hidden_layer:
        G.add_edge(i, h)
for h in hidden_layer:
    for o in output_layer:
        G.add_edge(h, o)

# Create a multipartite layout
plt.figure(figsize=(12, 8))
pos = nx.multipartite_layout(G, subset_key="subset")

# Draw the graph
nx.draw(
    G,
    pos,
    with_labels=True,
    labels={n: f"{n}\n{G.nodes[n]['weight']}" for n in G.nodes()},  # Add weights to node labels
    node_size=2000,
    node_color="lightblue",
    edge_color="gray",
    font_size=8,
    font_weight="bold"
)

plt.title("Feedforward Neural Network with Weights")
plt.show()