In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast, BertForMaskedLM

# 1. Your Dataset class (make sure numpy is imported: import numpy as np)
import torch
import numpy as np

class MaskedTextDataset(torch.utils.data.Dataset):
    """Dataset for masked language modeling with verbose output.

    Each item tokenizes one text, replaces a random subset of its
    non-special tokens with the tokenizer's mask token, and returns tensors
    ready to feed a masked-LM model.

    Args:
        texts: Indexable collection of raw strings.
        tokenizer: Callable tokenizer that accepts the keyword arguments used
            in ``__getitem__`` and exposes ``mask_token_id``; its result must
            provide ``input_ids``, ``special_tokens_mask`` and
            ``attention_mask`` tensors with a leading batch dimension of 1.
        max_length: Length every sample is padded / truncated to.
        mlm_probability: Fraction of maskable tokens to mask. At least one
            token is masked whenever the sample has any maskable token.
        verbose: When True (the default), print every intermediate step.
    """

    # Label value for positions that must not contribute to the loss.
    IGNORE_INDEX = -100

    def __init__(self, texts, tokenizer, max_length=128, mlm_probability=0.15, verbose=True):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mlm_probability = mlm_probability
        self.verbose = verbose

    def __len__(self):
        return len(self.texts)

    def _log(self, *args):
        """Print ``args`` only when verbose output is enabled."""
        if self.verbose:
            print(*args)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for sample ``idx``.

        ``labels`` holds the original token id at each masked position and
        ``IGNORE_INDEX`` everywhere else, so only masked tokens are scored.
        """
        text = self.texts[idx]
        self._log(f"\nOriginal Text [{idx}]: {text}")

        encoding = self.tokenizer(
            text,
            return_special_tokens_mask=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        input_ids = encoding.input_ids.clone().squeeze(0)
        self._log("Tokenized input_ids:", input_ids.tolist())

        special_tokens_mask = encoding.special_tokens_mask.squeeze(0).bool()
        self._log("Special tokens mask:", special_tokens_mask.tolist())

        labels = input_ids.clone()
        self._log("Labels (pre-masking):", labels.tolist())

        # Find maskable positions
        mask_positions = (~special_tokens_mask).nonzero(as_tuple=True)[0]
        self._log("Eligible positions for masking:", mask_positions.tolist())

        num_eligible = len(mask_positions)
        if num_eligible == 0:
            # Nothing to mask (e.g. an empty text): sampling from an empty
            # population would raise, so select no positions at all.
            mask_indices = torch.empty(0, dtype=torch.long)
        else:
            num_to_mask = min(num_eligible, max(1, int(self.mlm_probability * num_eligible)))
            chosen = np.random.choice(mask_positions.tolist(), size=num_to_mask, replace=False)
            mask_indices = torch.as_tensor(chosen, dtype=torch.long)
        self._log("Selected positions to mask:", mask_indices.tolist())

        input_ids[mask_indices] = self.tokenizer.mask_token_id
        self._log("Input_ids after masking:", input_ids.tolist())

        # Keep the original id only where a token was masked; every other
        # position (special tokens, padding, untouched tokens) is ignored.
        ignored = torch.ones_like(labels, dtype=torch.bool)
        ignored[mask_indices] = False
        labels[ignored] = self.IGNORE_INDEX
        self._log("Labels (post-masking):", labels.tolist())

        attention_mask = encoding.attention_mask.squeeze(0)
        self._log("Attention mask:", attention_mask.tolist())

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

# 2. Prepare data and tokenizer
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "I love masked language modeling with Transformers!",
    "PyTorch and Hugging Face make prototyping easy."
]
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# 3. Instantiate dataset and DataLoader
dataset = MaskedTextDataset(texts, tokenizer, max_length=32)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Use the GPU only when one is actually usable; moving tensors to "cuda"
# unconditionally raises a RuntimeError on machines without an NVIDIA driver.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 4. Iterate once over the loader and show each collated batch
for i, batch in enumerate(loader):
    print(f"batch_i: {i}\n\n\n")

    batch = {k: v.to(device) for k, v in batch.items()}
    print(f"batch: {batch}\n\n\n")

    print(f"input ids: {batch['input_ids']}\n")
    print(f"attention_mask: {batch['attention_mask']}\n")
    print(f"labels: {batch['labels']}\n")


Original Text [2]: PyTorch and Hugging Face make prototyping easy.
Tokenized input_ids: [101, 1052, 22123, 2953, 2818, 1998, 17662, 2227, 2191, 15053, 3723, 4691, 3733, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Special tokens mask: [True, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]
Labels (pre-masking): [101, 1052, 22123, 2953, 2818, 1998, 17662, 2227, 2191, 15053, 3723, 4691, 3733, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Eligible positions for masking: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
Selected positions to mask: [8]
Input_ids after masking: [101, 1052, 22123, 2953, 2818, 1998, 17662, 2227, 103, 15053, 3723, 4691, 3733, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Attention mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [None]:
import torch.nn as nn
import os
import graphviz
import tempfile
import webbrowser
import json
import uuid
from typing import Dict, Any, Optional, Tuple

class ModelVisualizer:
    """
    A utility class for visualizing PyTorch model structures.

    Two views are offered, both driven by ``model.named_modules()``:

    * ``print_module_tree`` prints an ASCII tree of the module hierarchy.
    * ``create_interactive_visualization`` renders the hierarchy with
      Graphviz and wraps the SVG in a standalone interactive HTML page.
    """

    @staticmethod
    def _child_sort_key(name):
        """
        Sort key for sibling module names.

        Purely numeric names (the indices used by ``nn.Sequential`` and
        similar containers) sort numerically so that ``'2'`` precedes
        ``'10'``; every other name sorts alphabetically after them.
        """
        if name.isdecimal():
            return (0, int(name), name)
        return (1, 0, name)

    @staticmethod
    def _sorted_children(subtree):
        """Returns the ``(name, child)`` pairs of ``subtree`` in display order."""
        return sorted(subtree.items(), key=lambda item: ModelVisualizer._child_sort_key(item[0]))

    @staticmethod
    def _json_for_script(payload):
        """
        Serializes ``payload`` to JSON that is safe inside a ``<script>`` tag.

        ``</`` is escaped so a string value can never close the script element.
        """
        return json.dumps(payload).replace('</', '<\\/')

    @staticmethod
    def _build_module_tree(named_modules):
        """
        Builds a nested dict representing the module hierarchy.
        named_modules: iterable of (full_name, module) from model.named_modules()

        Returns:
            A ``(tree, module_types)`` tuple. ``tree`` maps each name segment
            to the dict of its children (an empty dict for a leaf), and
            ``module_types`` maps every full dotted name (``''`` for the
            root) to the module's class name.
        """
        tree = {}
        module_types = {}

        for full_name, module in named_modules:
            # Store module type
            module_types[full_name] = type(module).__name__

            # Build tree structure; the root's empty name adds no segments
            parts = full_name.split('.') if full_name else []
            subtree = tree
            for part in parts:
                subtree = subtree.setdefault(part, {})

        return tree, module_types

    @staticmethod
    def _print_tree(subtree, module_types, full_path='', prefix='', is_last=True):
        """
        Recursively prints the nested dict as an ASCII tree with module types.

        Args:
            subtree: Nested dict as produced by ``_build_module_tree``.
            module_types: Mapping of full dotted name to class name.
            full_path: Dotted path of the module that owns ``subtree``.
            prefix: Indentation already accumulated from ancestor levels.
            is_last: Kept for backward compatibility; the branch glyph is
                decided per child, so this value does not affect the output.
        """
        entries = ModelVisualizer._sorted_children(subtree)
        last_index = len(entries) - 1

        for idx, (name, child) in enumerate(entries):
            is_child_last = (idx == last_index)

            # The branch glyph depends on this child's own position among
            # its siblings, not on the position of its parent.
            branch = '└─ ' if is_child_last else '├─ '

            # Calculate the full path for this module
            current_path = f"{full_path}.{name}" if full_path else name

            # Get the module type (if available)
            module_type = f" ({module_types[current_path]})" if current_path in module_types else ""

            print(prefix + branch + name + module_type + ('/' if child else ''))

            # Descendants keep a vertical rule only while siblings of this
            # child are still to come.
            if child:
                extension = '    ' if is_child_last else '│   '
                ModelVisualizer._print_tree(child, module_types, current_path,
                                            prefix + extension, is_child_last)

    @staticmethod
    def print_module_tree(model: nn.Module, root_name: str = 'model'):
        """
        Prints the modules of a PyTorch model in a tree structure with their types.

        Args:
            model: The model whose module hierarchy is printed.
            root_name: Label shown for the root line of the tree.

        Example:
            ModelVisualizer.print_module_tree(my_model)
        """
        # Build tree from module names and get module types
        tree, module_types = ModelVisualizer._build_module_tree(model.named_modules())

        # Print the root
        print(f"{root_name} ({type(model).__name__})/")

        # Print its children
        ModelVisualizer._print_tree(tree, module_types)

    @staticmethod
    def _collect_module_info(model: nn.Module) -> Dict[str, Dict[str, Any]]:
        """
        Collects detailed information about each module in the model.

        Returns:
            A dictionary where keys are full module names and values are
            dictionaries containing module information: ``type`` (class
            name), ``parameters`` (number of elements in parameters with
            ``requires_grad`` set), ``trainable`` (whether any parameter
            requires grad), plus layer-specific fields for ``nn.Conv2d`` and
            ``nn.Linear``. The root module is not included.
        """
        module_info = {}

        for name, module in model.named_modules():
            if name == '':  # Skip the root module
                continue

            # Collect basic info
            info = {
                'type': type(module).__name__,
                'parameters': sum(p.numel() for p in module.parameters() if p.requires_grad),
                'trainable': any(p.requires_grad for p in module.parameters()),
            }

            # Add specific module type information
            if isinstance(module, nn.Conv2d):
                info.update({
                    'in_channels': module.in_channels,
                    'out_channels': module.out_channels,
                    'kernel_size': module.kernel_size,
                    'stride': module.stride,
                    'padding': module.padding,
                })
            elif isinstance(module, nn.Linear):
                info.update({
                    'in_features': module.in_features,
                    'out_features': module.out_features,
                })

            module_info[name] = info

        return module_info

    @staticmethod
    def _render_html(svg_content, module_info, node_names):
        """
        Builds the standalone HTML page around the rendered SVG.

        Args:
            svg_content: SVG markup produced by Graphviz.
            module_info: Mapping of full module name to its info dict.
            node_names: Mapping of SVG element id to full module name; the
                page script uses it to attach a ``data-name`` attribute to
                every node, because Graphviz does not copy custom node
                attributes into its SVG output.

        Returns:
            The complete HTML document as a string.
        """
        # Prepare the data for JavaScript
        js_module_info = ModelVisualizer._json_for_script(module_info)
        js_node_names = ModelVisualizer._json_for_script(node_names)

        # Create HTML with interactive features
        return f"""<!DOCTYPE html>
<html>
<head>
    <title>Interactive Model Visualization</title>
    <style>
        body {{
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 20px;
            display: flex;
            flex-direction: column;
            height: 100vh;
        }}
        .visualization-container {{
            flex: 1;
            overflow: auto;
            border: 1px solid #ccc;
            margin-bottom: 20px;
            padding: 10px;
            position: relative;
        }}
        .info-panel {{
            height: 200px;
            overflow: auto;
            border: 1px solid #ccc;
            padding: 10px;
            background-color: #f9f9f9;
        }}
        .controls {{
            margin-bottom: 10px;
            padding: 10px;
            background-color: #f0f0f0;
            border-radius: 4px;
        }}
        .zoom-controls {{
            position: absolute;
            top: 20px;
            right: 20px;
            background: white;
            border: 1px solid #ccc;
            border-radius: 4px;
            padding: 5px;
        }}
        .zoom-button {{
            cursor: pointer;
            margin: 0 5px;
            background: #4285F4;
            color: white;
            border: none;
            padding: 5px 10px;
            border-radius: 3px;
        }}
        .copy-notification {{
            position: fixed;
            top: 20px;
            left: 50%;
            transform: translateX(-50%);
            padding: 10px 20px;
            background-color: #333;
            color: white;
            border-radius: 4px;
            opacity: 0;
            transition: opacity 0.5s;
        }}
        table {{
            border-collapse: collapse;
            width: 100%;
        }}
        th, td {{
            padding: 8px;
            text-align: left;
            border-bottom: 1px solid #ddd;
        }}
        th {{
            background-color: #f2f2f2;
        }}
        .highlight-node polygon,
        .highlight-node ellipse,
        .highlight-node path {{
            stroke: #FF5722 !important;
            stroke-width: 3px !important;
        }}
        .fade {{
            opacity: 0.3;
            transition: opacity 0.3s;
        }}
        .node-menu {{
            position: absolute;
            background: white;
            border: 1px solid #ccc;
            border-radius: 4px;
            padding: 5px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            z-index: 100;
            display: none;
        }}
        .menu-item {{
            cursor: pointer;
            padding: 5px 10px;
        }}
        .menu-item:hover {{
            background-color: #f0f0f0;
        }}
        .export-btn {{
            margin-top: 10px;
            padding: 8px 16px;
            background-color: #4CAF50;
            color: white;
            border: none;
            border-radius: 4px;
            cursor: pointer;
        }}
        .export-btn:hover {{
            background-color: #45a049;
        }}
        .layer-path {{
            font-family: monospace;
            padding: 5px;
            background-color: #f5f5f5;
            border: 1px solid #ddd;
            border-radius: 4px;
            margin: 5px 0;
            word-break: break-all;
        }}
    </style>
</head>
<body>
    <h1>Interactive Model Visualization</h1>
    <div class="controls">
        <button id="expandAll">Expand All</button>
        <button id="collapseAll">Collapse All</button>
        <input type="text" id="searchInput" placeholder="Search modules...">
        <button id="clearSearch">Clear</button>
        <button id="exportJSON" class="export-btn">Export Model Info (JSON)</button>
    </div>
    <div class="visualization-container" id="visualization">
        {svg_content}
        <div class="zoom-controls">
            <button class="zoom-button" id="zoomIn">+</button>
            <button class="zoom-button" id="zoomOut">-</button>
            <button class="zoom-button" id="resetZoom">Reset</button>
        </div>
    </div>
    <div class="info-panel" id="infoPanel">
        <h3>Module Information</h3>
        <p>Click on a module in the visualization to view its details</p>
    </div>
    <div class="copy-notification" id="copyNotification">Module name copied to clipboard!</div>
    <div class="node-menu" id="nodeMenu">
        <div class="menu-item" id="copyName">Copy Module Name</div>
        <div class="menu-item" id="showDetails">Show Details</div>
        <div class="menu-item" id="highlightPath">Highlight Path</div>
    </div>

    <script>
        // Module information from Python
        const moduleInfo = {js_module_info};

        // SVG element id -> full module name
        const nodeNames = {js_node_names};

        // Setup interactive behavior
        document.addEventListener('DOMContentLoaded', function() {{
            const svg = document.querySelector('svg');
            const infoPanel = document.getElementById('infoPanel');
            const copyNotification = document.getElementById('copyNotification');
            const searchInput = document.getElementById('searchInput');
            const clearSearch = document.getElementById('clearSearch');
            const nodeMenu = document.getElementById('nodeMenu');
            const exportJSONBtn = document.getElementById('exportJSON');

            // Add context menu to SVG
            document.addEventListener('click', function() {{
                nodeMenu.style.display = 'none';
            }});

            // Zoom functionality
            let zoomLevel = 1;
            const zoomIn = document.getElementById('zoomIn');
            const zoomOut = document.getElementById('zoomOut');
            const resetZoom = document.getElementById('resetZoom');

            function updateZoom() {{
                svg.style.transform = `scale(${{zoomLevel}})`;
                svg.style.transformOrigin = 'top left';
            }}

            zoomIn.addEventListener('click', () => {{
                zoomLevel += 0.1;
                updateZoom();
            }});

            zoomOut.addEventListener('click', () => {{
                if (zoomLevel > 0.2) zoomLevel -= 0.1;
                updateZoom();
            }});

            resetZoom.addEventListener('click', () => {{
                zoomLevel = 1;
                updateZoom();
            }});

            // Export JSON
            exportJSONBtn.addEventListener('click', function() {{
                const json = JSON.stringify(moduleInfo, null, 2);
                const blob = new Blob([json], {{ type: 'application/json' }});
                const url = URL.createObjectURL(blob);
                const a = document.createElement('a');
                a.href = url;
                a.download = 'model_info.json';
                document.body.appendChild(a);
                a.click();
                document.body.removeChild(a);
            }});

            // Add click handlers to nodes
            const nodes = svg.querySelectorAll('[id^="node_"]');

            // Graphviz does not emit custom attributes in its SVG output, so
            // attach each module's full name to its node element here.
            nodes.forEach(node => {{
                node.setAttribute('data-name', nodeNames[node.id] || '');
            }});

            nodes.forEach(node => {{
                node.style.cursor = 'pointer';

                // Add click event
                node.addEventListener('click', function(e) {{
                    e.stopPropagation();
                    // Get the module name from this node
                    const moduleName = this.getAttribute('data-name');

                    // Copy to clipboard and handle callback properly
                    navigator.clipboard.writeText(moduleName).then(function() {{
                        // Show notification
                        copyNotification.textContent = `Module name "${{moduleName}}" copied to clipboard!`;
                        copyNotification.style.opacity = 1;
                        setTimeout(function() {{
                            copyNotification.style.opacity = 0;
                        }}, 2000);
                    }}).catch(function(err) {{
                        console.error('Failed to copy module name: ', err);
                    }});

                    // Display module information
                    displayModuleInfo(moduleName);

                    // Highlight the selected node
                    nodes.forEach(n => n.classList.remove('highlight-node'));
                    this.classList.add('highlight-node');
                }});

                // Add context menu
                node.addEventListener('contextmenu', function(e) {{
                    e.preventDefault();
                    const moduleName = this.getAttribute('data-name');

                    // Show context menu
                    nodeMenu.style.display = 'block';
                    nodeMenu.style.left = (e.pageX) + 'px';
                    nodeMenu.style.top = (e.pageY) + 'px';

                    // Set up menu items
                    document.getElementById('copyName').onclick = function(e) {{
                        e.stopPropagation();
                        navigator.clipboard.writeText(moduleName);
                        nodeMenu.style.display = 'none';
                        copyNotification.textContent = `Module name "${{moduleName}}" copied to clipboard!`;
                        copyNotification.style.opacity = 1;
                        setTimeout(function() {{
                            copyNotification.style.opacity = 0;
                        }}, 2000);
                    }};

                    document.getElementById('showDetails').onclick = function(e) {{
                        e.stopPropagation();
                        displayModuleInfo(moduleName);
                        nodeMenu.style.display = 'none';
                    }};

                    document.getElementById('highlightPath').onclick = function(e) {{
                        e.stopPropagation();
                        // Highlight module and its parents
                        const parts = moduleName.split('.');
                        let path = '';
                        nodes.forEach(n => n.classList.add('fade'));

                        // Highlight each part of the path
                        for(let i = 0; i < parts.length; i++) {{
                            path = path ? path + '.' + parts[i] : parts[i];
                            const node = document.querySelector(`[data-name="${{path}}"]`);
                            if(node) node.classList.remove('fade');
                        }}

                        setTimeout(function() {{
                            nodes.forEach(n => n.classList.remove('fade'));
                        }}, 3000);

                        nodeMenu.style.display = 'none';
                    }};
                }});
            }});

            // Function to display module information
            function displayModuleInfo(moduleName) {{
                const info = moduleInfo[moduleName];
                if (!info) {{
                    infoPanel.innerHTML = `<h3>Module: ${{moduleName}}</h3><p>No detailed information available</p>`;
                    return;
                }}

                let html = `<h3>Module: ${{moduleName}}</h3>`;
                html += `<div class="layer-path">${{moduleName}}</div>`;
                html += `<table>`;
                html += `<tr><th>Property</th><th>Value</th></tr>`;

                // Add basic properties
                html += `<tr><td>Type</td><td>${{info.type}}</td></tr>`;
                html += `<tr><td>Parameters</td><td>${{info.parameters.toLocaleString()}}</td></tr>`;
                html += `<tr><td>Trainable</td><td>${{info.trainable ? 'Yes' : 'No'}}</td></tr>`;

                // Add specific properties
                for (const [key, value] of Object.entries(info)) {{
                    if (!['type', 'parameters', 'trainable'].includes(key)) {{
                        html += `<tr><td>${{key}}</td><td>${{JSON.stringify(value)}}</td></tr>`;
                    }}
                }}

                html += `</table>`;
                infoPanel.innerHTML = html;
            }}

            // Search functionality
            searchInput.addEventListener('input', function() {{
                const searchTerm = this.value.toLowerCase();
                if (searchTerm === '') {{
                    nodes.forEach(node => {{
                        node.style.opacity = 1;
                    }});
                    return;
                }}

                nodes.forEach(node => {{
                    const moduleName = node.getAttribute('data-name').toLowerCase();
                    if (moduleName.includes(searchTerm)) {{
                        node.style.opacity = 1;
                    }} else {{
                        node.style.opacity = 0.2;
                    }}
                }});
            }});

            clearSearch.addEventListener('click', function() {{
                searchInput.value = '';
                nodes.forEach(node => {{
                    node.style.opacity = 1;
                }});
            }});

            // Add keyboard shortcuts
            document.addEventListener('keydown', function(e) {{
                // Ctrl+F to focus search
                if (e.ctrlKey && e.key === 'f') {{
                    e.preventDefault();
                    searchInput.focus();
                }}

                // Esc to clear search
                if (e.key === 'Escape') {{
                    searchInput.value = '';
                    nodes.forEach(node => {{
                        node.style.opacity = 1;
                    }});
                }}
            }});

            // Make the visualization container resizable
            const container = document.querySelector('.visualization-container');
            let startY, startHeight;

            function initResize(e) {{
                startY = e.clientY;
                startHeight = parseInt(document.defaultView.getComputedStyle(container).height, 10);
                document.documentElement.addEventListener('mousemove', doResize, false);
                document.documentElement.addEventListener('mouseup', stopResize, false);
            }}

            function doResize(e) {{
                container.style.height = (startHeight + e.clientY - startY) + 'px';
            }}

            function stopResize() {{
                document.documentElement.removeEventListener('mousemove', doResize, false);
                document.documentElement.removeEventListener('mouseup', stopResize, false);
            }}

            // Add a resize handle
            const resizeHandle = document.createElement('div');
            resizeHandle.style.cursor = 'ns-resize';
            resizeHandle.style.height = '10px';
            resizeHandle.style.backgroundColor = '#f0f0f0';
            resizeHandle.style.borderTop = '1px solid #ccc';
            resizeHandle.style.marginBottom = '10px';
            container.after(resizeHandle);

            resizeHandle.addEventListener('mousedown', initResize, false);
        }});
    </script>
</body>
</html>
"""

    @staticmethod
    def create_interactive_visualization(model: nn.Module, output_path: Optional[str] = None,
                                        graph_attrs: Optional[Dict[str, str]] = None,
                                        node_attrs: Optional[Dict[str, str]] = None,
                                        open_browser: bool = True) -> str:
        """
        Creates an interactive visualization of the model structure.

        When clicked, nodes will copy the full module name to clipboard and display
        detailed information about the module. The page also offers search,
        zoom, a right-click menu and a JSON export of the module information.
        The "Expand All" / "Collapse All" buttons are rendered but have no
        behavior attached to them.

        Args:
            model (nn.Module): The PyTorch model to visualize
            output_path (str, optional): Path to save the HTML file. If None, a temporary file is used.
            graph_attrs (Dict[str, str], optional): Attributes for the graph
            node_attrs (Dict[str, str], optional): Default attributes for nodes
            open_browser (bool): Whether to open the visualization in a browser automatically

        Returns:
            str: Path to the generated HTML file
        """
        from pathlib import Path

        # Set default attributes if not provided
        if graph_attrs is None:
            graph_attrs = {
                'rankdir': 'TB',
                'bgcolor': 'transparent',
                'splines': 'ortho',
                'fontname': 'Arial',
                'fontsize': '14',
            }

        if node_attrs is None:
            node_attrs = {
                'style': 'filled',
                'shape': 'box',
                'fillcolor': '#E5F5FD',
                'color': '#4285F4',
                'fontname': 'Arial',
                'fontsize': '12',
                'height': '0.4',
            }

        # Create a directed graph
        dot = graphviz.Digraph(
            'model_visualization',
            format='svg',
            graph_attr=graph_attrs
        )
        dot.attr('node', **node_attrs)

        # Get module information
        tree, module_types = ModelVisualizer._build_module_tree(model.named_modules())
        module_info = ModelVisualizer._collect_module_info(model)

        # Add the root node
        root_id = str(uuid.uuid4())
        root_type = type(model).__name__
        dot.node(root_id, f'{root_type}', tooltip=f'Root: {root_type}')

        # SVG element id -> full module name, consumed by the page script
        node_names = {}

        # Process the modules and build the graph
        def add_nodes(subtree, parent_id, parent_path=''):
            for name, child in ModelVisualizer._sorted_children(subtree):
                current_path = f"{parent_path}.{name}" if parent_path else name

                # Create unique ID for this node
                node_id = str(uuid.uuid4())

                # Sequential SVG ids cannot collide, unlike ids derived from
                # the dotted path by replacing '.' with '_'.
                svg_id = f"node_{len(node_names)}"
                node_names[svg_id] = current_path

                # Get module type and add node
                module_type = module_types.get(current_path, "")
                label = f"{name} ({module_type})" if module_type else name

                # Create tooltip with basic info
                tooltip = f"Name: {current_path}\nType: {module_type}"
                if current_path in module_info:
                    info = module_info[current_path]
                    tooltip += f"\nParameters: {info['parameters']:,}"

                # Add the node
                fillcolor = "#E5F5FD" if child else "#C2E0F4"  # Different color for leaf nodes
                dot.node(node_id, label, tooltip=tooltip, fillcolor=fillcolor, id=svg_id)

                # Connect to parent
                dot.edge(parent_id, node_id)

                # Process children
                if child:
                    add_nodes(child, node_id, current_path)

        # Add all nodes starting from root
        add_nodes(tree, root_id)

        # Generate the SVG
        svg_content = dot.pipe().decode('utf-8')

        html_content = ModelVisualizer._render_html(svg_content, module_info, node_names)

        # Determine the output file path
        if output_path is None:
            fd, output_path = tempfile.mkstemp(suffix='.html')
            os.close(fd)

        # Write the HTML to file
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(html_content)

        # Open in browser if requested; as_uri() yields a well-formed file URL
        # (percent-encoded, and correct for Windows drive paths).
        if open_browser:
            webbrowser.open(Path(os.path.abspath(output_path)).as_uri())

        return output_path

In [None]:
class SimpleCNN(nn.Module):
    """Small convolutional classifier used as a demo model.

    ``features`` is a conv / ReLU / max-pool stack followed by a nested
    conv / ReLU block; ``classifier`` is a two-layer MLP with 10 outputs.
    The first linear layer takes ``32 * 16 * 16`` inputs, which matches the
    flattened feature map produced for e.g. 3x32x32 inputs.
    """

    def __init__(self):
        super().__init__()

        # Layers are built in the order they appear in the module tree so
        # that random weight initialization consumes the RNG in that order.
        stem_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        stem_act = nn.ReLU()
        downsample = nn.MaxPool2d(kernel_size=2)
        inner_block = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.features = nn.Sequential(stem_conv, stem_act, downsample, inner_block)

        hidden = nn.Linear(32 * 16 * 16, 128)
        hidden_act = nn.ReLU()
        head = nn.Linear(128, 10)
        self.classifier = nn.Sequential(hidden, hidden_act, head)

    def forward(self, x):
        """Run the feature stack, flatten per sample, and classify."""
        feature_maps = self.features(x)
        batch_size = feature_maps.size(0)
        flattened = feature_maps.view(batch_size, -1)
        return self.classifier(flattened)

# Instantiate and visualize
# Demo: build the example CNN and render its module graph to an HTML page.
# With the default arguments used here, the page is written to a temporary
# file and opened in the default web browser; the file path is printed.
if __name__ == "__main__":
    model = SimpleCNN()
    output_path = ModelVisualizer.create_interactive_visualization(model)
    print(f"Visualization saved to: {output_path}")

Visualization saved to: /tmp/tmp4pao0mrm.html
