In [25]:
resnet = torchvision.models.resnet18()

torch.Size([1, 1000])


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [28]:
layer1

Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchvision
import numpy as np
import json


resnet = torchvision.models.resnet18()

# print(resnet.named_modules)

class PatchEmbedding(nn.Module):
    
    def __init__(self, in_channels=1, embed_dim=512, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size)

    def forward(self, x):
        out = self.proj(x)
        out = out.flatten(2)
        return out

class HRMVision(pl.LightningModule):
    
    def __init__(self, output_size,in_channels=4, sequence_length = 8, patch_size=6, embed_dim=8, h_cycle = 4, l_cycle = 8, device='cpu', model_name="model", learning_rate=1e-3):
        super().__init__()
        self.model_name = model_name
        
        self.h_cycle = h_cycle
        self.l_cycle = l_cycle
        self.context_length =  sequence_length
        self.patchify = PatchEmbedding(in_channels, sequence_length, patch_size)
        
        # self.token_embed = nn.Embedding(vocab_size, hidden_size)
        self.pos_embed = nn.Embedding(self.context_length, embed_dim*2)
        self.low = nn.GRUCell(input_size=embed_dim*embed_dim*2, hidden_size=embed_dim*embed_dim*2, device=device,)
        self.high = nn.GRUCell(input_size=embed_dim*embed_dim*2, hidden_size=embed_dim*embed_dim*2, device=device)
        
        # self.low = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device,)
        # self.high = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device)
        
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim*embed_dim*2, embed_dim*embed_dim*2),
            nn.ReLU(),
            nn.Linear(embed_dim*embed_dim*2, output_size)
        )
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        
        
        # Metrics tracking
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.gradient_norms = []
        self.learning_rates = []
        
        # For convergence speed tracking
        self.epoch_train_losses = []
        self.epoch_val_losses = []
        
        # For overfitting analysis
        self.train_val_gap = []
        # self.acc = pl.metrics.classification.Accuracy(task='multiclass', num_classes = 10)
        
    def forward(self, image):
        
        # print(f'image: {image.shape}')
        token_embs = self.patchify(image)
        # print(f'token_embs: {token_embs.shape}')
        pos_embs = self.pos_embed(torch.arange(0, self.context_length).to(image.device))
        # print(f'pos: {pos_embs.shape}')
        embs = token_embs+pos_embs
        # print(f'embs: {embs.shape}')
        embs = embs.view(image.shape[0], -1)
        # print(f'embs: {embs.shape}')
        # hx, cx = torch.zeros((tokens.shape[0],pos_embs.shape[-1])), torch.zeros((tokens.shape[0],pos_embs.shape[-1]))
        z_l = torch.zeros((image.shape[0],embs.shape[-1])).to(image.device)
        # print(z_l.shape)
        for i in range(self.h_cycle*self.l_cycle):
            z_l = self.low(embs, z_l)
            if i%self.h_cycle == 0: 
                # print(f"at {i}")
                z_h = self.high(embs, z_l)
                z_l = z_h
        # print('here')
        out = self.mlp(z_h)
        return out

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()
        
        # Store metrics
        self.train_losses.append(loss.item())
        self.train_accs.append(acc.item())
        
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        
        return loss
    
    def on_train_epoch_end(self):
        # Calculate gradient norm for gradient flow analysis
        total_norm = 0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** (1. / 2)
        
        self.gradient_norms.append(total_norm)
        self.log("gradient_norm", total_norm)
        
        # Store learning rate
        current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.learning_rates.append(current_lr)
        self.log("learning_rate", current_lr)
        
        # Calculate epoch averages for convergence analysis
        if len(self.train_losses) > 0:
            # Get losses from this epoch only
            steps_per_epoch = len(self.trainer.train_dataloader)
            epoch_start_idx = max(0, len(self.train_losses) - steps_per_epoch)
            epoch_train_loss = np.mean(self.train_losses[epoch_start_idx:])
            self.epoch_train_losses.append(epoch_train_loss)
            
    
    def on_validation_epoch_end(self):
        # Get current epoch metrics
        val_loss = self.trainer.callback_metrics.get("val_loss", 0)
        val_acc = self.trainer.callback_metrics.get("val_acc", 0)
        
        if isinstance(val_loss, torch.Tensor):
            val_loss = val_loss.item()
        if isinstance(val_acc, torch.Tensor):
            val_acc = val_acc.item()
            
        self.val_losses.append(val_loss)
        self.val_accs.append(val_acc)
        self.epoch_val_losses.append(val_loss)
        
        # Calculate overfitting gap (train-val performance difference)
        if len(self.epoch_train_losses) > 0 and len(self.epoch_val_losses) > 0:
            train_loss = self.epoch_train_losses[-1]
            val_loss = self.epoch_val_losses[-1]
            gap = val_loss - train_loss  # Positive = overfitting
            self.train_val_gap.append(gap)
            self.log("overfitting_gap", gap)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    
    def save_metrics(self, filepath):
        """Save all tracked metrics to a JSON file"""
        metrics = {
            'model_name': self.model_name,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'train_accs': self.train_accs,
            'val_accs': self.val_accs,
            'epoch_train_losses': self.epoch_train_losses,
            'epoch_val_losses': self.epoch_val_losses,
            'gradient_norms': self.gradient_norms,
            'learning_rates': self.learning_rates,
            'train_val_gap': self.train_val_gap,
            'parameter_count': sum(p.numel() for p in self.parameters())
        }
        
        with open(filepath, 'w') as f:
            json.dump(metrics, f, indent=2)
        
        print(f"Metrics saved to {filepath}")
    
    def print_analysis(self):
        """Print analysis of tracked metrics"""
        if len(self.epoch_train_losses) < 2:
            print("Not enough epochs for analysis")
            return
            
        print(f"\n=== Analysis for {self.model_name} ===")
        
        # Convergence speed
        initial_loss = self.epoch_train_losses[0]
        final_loss = self.epoch_train_losses[-1]
        convergence_rate = (initial_loss - final_loss) / len(self.epoch_train_losses)
        print(f"Convergence rate: {convergence_rate:.6f} loss/epoch")
        
        # Overfitting analysis
        if len(self.train_val_gap) > 0:
            avg_gap = np.mean(self.train_val_gap)
            final_gap = self.train_val_gap[-1]
            print(f"Average overfitting gap: {avg_gap:.6f}")
            print(f"Final overfitting gap: {final_gap:.6f}")
            
        # Gradient flow
        if len(self.gradient_norms) > 0:
            avg_grad_norm = np.mean(self.gradient_norms)
            grad_std = np.std(self.gradient_norms)
            print(f"Average gradient norm: {avg_grad_norm:.6f}")
            print(f"Gradient norm std: {grad_std:.6f}")
            
        # Loss landscape smoothness (approximated by loss variance)
        train_loss_var = np.var(self.train_losses)
        val_loss_var = np.var(self.val_losses) if self.val_losses else 0
        print(f"Training loss variance (smoothness): {train_loss_var:.6f}")
        print(f"Validation loss variance: {val_loss_var:.6f}")   
     
model = HRMVision(output_size=10, in_channels=1)

x = torch.randn((1,1, 28, 28))
out = model(x)
print(out.shape)


class ResnetSmall(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2, 2), padding=(3,3), bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = resnet.maxpool
        self.layer1 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(in_features=64, out_features=10, bias=True)
        
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.avg_pool(out)
        out = torch.flatten(out, 1)
        # print(out.shape)
        out = self.fc(out)
        return out
    
small_model = ResnetSmall()
x = torch.randn((1,1,28,28))
out = small_model(x)

print(out.shape)


torch.Size([1, 10])


torch.Size([1, 10])


In [15]:
import torch
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import datasets, transforms
from torch.utils.data import Subset
import numpy as np
import pytorch_lightning as pl

# --------------------------
# 1. Load & preprocess MNIST
# --------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))   # Standard MNIST normalization
])

# Download train data
mnist = datasets.MNIST(root='.', train=True, download=True, transform=transform)

# Split into train and validation
train_size = int(0.8 * len(mnist))
val_size = len(mnist) - train_size
mnist_train, mnist_val = random_split(mnist, [train_size, val_size])



# mnist_train is already a Subset — we need to look at its .indices
original_indices = np.array(mnist_train.indices)
targets = mnist.targets.numpy()[original_indices]  # filter targets by the split

num_classes = 10
samples_per_class = 100
selected_indices = []

for c in range(num_classes):
    class_indices = np.where(targets == c)[0]           # indices within this split
    chosen = np.random.choice(class_indices, samples_per_class, replace=False)
    selected_indices.extend(chosen)

# Now remap to actual dataset indices
balanced_indices = original_indices[selected_indices]

balanced_subset = Subset(mnist, balanced_indices)

train_loader = DataLoader(balanced_subset, batch_size=32, shuffle=True)



# --------------------------
# 2. Create DataLoaders
# --------------------------
# train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(mnist_val, batch_size=64, shuffle=False, num_workers=4)


# Instantiate trainer
trainer = pl.Trainer(
    max_epochs=8,
    accelerator='auto',  # Uses GPU if available
    devices=1,
    logger=True,
    log_every_n_steps=50,
    check_val_every_n_epoch=1
)

# Train
# trainer.fit(model, train_loader, val_loader)


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [16]:
model = HRMVision(output_size=10, in_channels=1)

In [17]:
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | patchify  | PatchEmbedding   | 296    | train
1 | pos_embed | Embedding        | 128    | train
2 | low       | GRUCell          | 99.1 K | train
3 | high      | GRUCell          | 99.1 K | train
4 | mlp       | Sequential       | 17.8 K | train
5 | loss_fn   | CrossEntropyLoss | 0      | train
-------------------------------------------------------
216 K     Trainable params
0         Non-trainable params
216 K     Total params
0.865     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode


Epoch 7: 100%|██████████| 32/32 [00:01<00:00, 19.09it/s, v_num=7, train_loss=0.0485, train_acc=1.000]

`Trainer.fit` stopped: `max_epochs=8` reached.


Epoch 7: 100%|██████████| 32/32 [00:01<00:00, 18.66it/s, v_num=7, train_loss=0.0485, train_acc=1.000]


In [18]:
model_name = "HRMVision"
model.save_metrics(f"{model_name}_metrics.json")
model.print_analysis()

Metrics saved to HRMVision_metrics.json

=== Analysis for model ===
Convergence rate: 0.230247 loss/epoch
Average gradient norm: 2.774449
Gradient norm std: 1.572824
Training loss variance (smoothness): 0.368841
Validation loss variance: 0.000000


In [19]:
import pandas as pd

df = pd.read_csv("./lightning_logs/version_1/metrics.csv")

In [20]:
df

Unnamed: 0,epoch,step,train_loss,val_loss
0,0,739,0.048125,
1,0,749,,0.133682
2,1,1479,0.137428,
3,1,1499,,0.095904
4,2,2219,0.029386,
5,2,2249,,0.096174


In [2]:
# Lets Reproduce GPT-2 (124M)\lightning_logs\version_3\checkpoints\epoch=7-step=256.ckpt
model = HRMVision.load_from_checkpoint("./lightning_logs/version_3/checkpoints/epoch=7-step=256.ckpt", output_size=10, in_channels=1)
model

HRMVision(
  (patchify): PatchEmbedding(
    (proj): Conv2d(1, 8, kernel_size=(6, 6), stride=(6, 6))
  )
  (pos_embed): Embedding(8, 16)
  (low): GRUCell(128, 128)
  (high): GRUCell(128, 128)
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=10, bias=True)
  )
  (loss_fn): CrossEntropyLoss()
)

In [22]:
print(mnist_val[1][0].shape, mnist_val[1][1])
test_input = mnist_val[1][0].unsqueeze(0).to('cpu')
# print(test_input.shape)
prediction = model(test_input)
print(torch.argmax(prediction, dim=-1))

torch.Size([1, 28, 28]) 4
tensor([4])


In [23]:
true_labels = [mnist_val[i][1] for i in range(len(mnist_val))]
pred_labels = [torch.argmax(model(mnist_val[i][0].unsqueeze(0).to('cpu')),dim=-1).item() for i in range(len(mnist_val))]

In [24]:
from sklearn.metrics import accuracy_score
accuracy_score(true_labels, pred_labels)

0.89575

In [None]:
# previously accuracy was 0.89075

In [65]:
labels = []
for i in range(1000):
    labels.append(mnist_train[i][1])
print(labels)

[1, 9, 2, 2, 1, 8, 3, 6, 0, 7, 2, 5, 5, 1, 3, 8, 4, 9, 4, 9, 6, 6, 1, 5, 3, 5, 0, 3, 5, 9, 7, 7, 8, 6, 8, 6, 6, 9, 1, 6, 8, 0, 0, 3, 7, 3, 6, 7, 6, 1, 7, 0, 5, 1, 3, 0, 2, 2, 1, 3, 4, 8, 7, 2, 3, 3, 4, 1, 4, 2, 4, 7, 7, 2, 8, 8, 4, 1, 3, 2, 4, 1, 5, 1, 9, 1, 3, 5, 7, 9, 1, 5, 6, 8, 7, 9, 6, 0, 1, 3, 3, 7, 1, 0, 9, 8, 5, 6, 4, 0, 5, 2, 5, 0, 9, 7, 6, 9, 5, 2, 1, 9, 2, 6, 0, 3, 6, 1, 8, 9, 5, 9, 6, 7, 6, 7, 3, 8, 8, 0, 2, 3, 5, 1, 5, 4, 9, 3, 9, 9, 8, 4, 7, 1, 8, 2, 5, 1, 1, 6, 2, 4, 8, 0, 3, 4, 2, 3, 4, 0, 1, 5, 2, 4, 3, 8, 6, 7, 2, 0, 9, 2, 2, 3, 2, 4, 2, 9, 0, 6, 1, 7, 3, 7, 2, 0, 8, 0, 7, 4, 4, 9, 4, 0, 2, 6, 8, 9, 0, 8, 5, 8, 4, 9, 4, 9, 8, 4, 8, 9, 3, 8, 0, 3, 6, 9, 3, 4, 3, 3, 3, 9, 7, 4, 9, 6, 6, 4, 9, 7, 5, 0, 6, 4, 7, 5, 5, 8, 9, 0, 9, 5, 4, 7, 0, 2, 2, 0, 8, 4, 3, 5, 8, 1, 9, 9, 5, 6, 2, 1, 6, 0, 0, 9, 8, 8, 1, 3, 6, 7, 6, 8, 2, 0, 5, 3, 2, 8, 8, 2, 6, 6, 9, 9, 1, 5, 5, 6, 8, 9, 1, 3, 9, 4, 5, 5, 5, 2, 4, 5, 7, 8, 5, 4, 5, 1, 5, 7, 1, 7, 3, 4, 6, 4, 2, 3, 8, 9, 2, 4, 5, 7, 3, 

In [66]:
count_dict = dict()
for i in range(10):
    count_dict[i] = labels.count(i)
print(count_dict)

{0: 98, 1: 106, 2: 98, 3: 106, 4: 92, 5: 101, 6: 94, 7: 110, 8: 103, 9: 92}


# Tranformers Version

In [1]:
from transformers import PreTrainedModel, PretrainedConfig
import torch
import torch.nn as nn

# 1. Create a config class
class HRMConfig(PretrainedConfig):
    model_type = "hrm"

    def __init__(self, 
                 in_channels=1, 
                 embed_dim=16, 
                 sequence_length=16, 
                 output_size=10, 
                 h_cycle=4, 
                 l_cycle=8,
                 patch_size=16, 
                 **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = in_channels
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.sequence_length = sequence_length
        self.output_size = output_size
        self.h_cycle = h_cycle
        self.l_cycle = l_cycle
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 2. Wrap your model inside a PreTrainedModel
class HRMForClassification(PreTrainedModel):
    config_class = HRMConfig

    def __init__(self, config):
        super().__init__(config)

        self.h_cycle = config.h_cycle
        self.l_cycle = config.l_cycle
        self.context_length =  16
        self.patchify = PatchEmbedding(config.in_channels, config.sequence_length, config.patch_size)
        
        # self.token_embed = nn.Embedding(vocab_size, hidden_size)
        self.pos_embed = nn.Embedding(self.context_length, config.embed_dim)
        self.low = nn.GRUCell(input_size=config.embed_dim*config.embed_dim, hidden_size=config.embed_dim*config.embed_dim, device=config.device,)
        self.high = nn.GRUCell(input_size=config.embed_dim*config.embed_dim, hidden_size=config.embed_dim*config.embed_dim, device=config.device)
        
        # self.low = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device,)
        # self.high = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device)
        
        self.mlp = nn.Sequential(
            nn.Linear(config.embed_dim*config.embed_dim, config.embed_dim*config.embed_dim),
            nn.ReLU(),
            nn.Linear(config.embed_dim*config.embed_dim, config.output_size)
        )

        # Initialize weights the Transformers way
        self.post_init()

    def forward(self, input_ids=None, labels=None, **kwargs):
        tokens = input_ids
        token_embs = self.token_embed(tokens)
        pos_embs = self.pos_embed(
            torch.arange(0, tokens.shape[-1], device=tokens.device)
        )
        embs = token_embs + pos_embs
        embs = embs.view(tokens.shape[0], -1)

        z_l = torch.zeros((tokens.shape[0], embs.shape[-1]), device=tokens.device)
        for i in range(self.h_cycle * self.l_cycle):
            z_l = self.low(embs, z_l)
            if i % self.h_cycle == 0:
                z_h = self.high(embs, z_l)
                z_l = z_h
        logits = self.mlp(z_h)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.output_size), labels.view(-1))

        return {"loss": loss, "logits": logits}


  from .autonotebook import tqdm as notebook_tqdm
