In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import autoaugment
from torch.utils.checkpoint import checkpoint
import torch.optim as optim
import copy
import contextlib
import numpy as np

In [2]:
def get_cifar10_loaders(
    batch_size: int         = 256,
    num_workers: int        = 8,
    pin_memory: bool        = True,
    persistent_workers: bool = True,
    root: str               = './data',
):
    """Build CIFAR-10 train/test DataLoaders.

    The training split is augmented with RandomCrop(32, padding=4), a random
    horizontal flip and torchvision's AutoAugment CIFAR10 policy; the test
    split is only converted to a tensor and normalized. Both splits are
    downloaded into ``root`` if missing.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes.
        pin_memory: passed through to both DataLoaders.
        persistent_workers: keep workers alive between epochs. Ignored when
            ``num_workers == 0`` (DataLoader would otherwise raise ValueError).
        root: dataset directory (default ``'./data'``, as before).

    Returns:
        ``(trainloader, testloader)``. The train loader shuffles and drops the
        last incomplete batch; the test loader keeps order and every sample.
    """
    # NOTE(review): these std values are the widely copied ones; the commonly
    # reported per-channel pixel std of CIFAR-10 is ~(0.247, 0.243, 0.261).
    # Kept unchanged so saved checkpoints see the same input scaling.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        autoaugment.AutoAugment(autoaugment.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # DataLoader rejects persistent_workers=True when num_workers == 0.
    keep_workers = persistent_workers and num_workers > 0

    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform_train
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=keep_workers,
        drop_last=True,
    )

    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform_test
    )
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=keep_workers,
    )

    return trainloader, testloader

In [3]:
class SEBlock(nn.Module):
    """Squeeze-and-excitation channel gate.

    Global average pool -> 1x1 conv (channels -> channels // reduction) ->
    ReLU -> 1x1 conv back to ``channels`` -> ``2 * sigmoid``. Each channel of
    the input is therefore rescaled by a factor in (0, 2).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, hidden, kernel_size=1)
        self.fc2 = nn.Conv2d(hidden, channels, kernel_size=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        squeezed = self.global_avg_pool(x)
        excited = self.fc2(self.relu(self.fc1(squeezed)))
        # Scaled to (0, 2) so the gate can amplify as well as suppress.
        gate = self.sigmoid(excited) * 2
        return x * gate

In [4]:
class BottleneckBlock(nn.Module):
    """Pre-activation bottleneck with an SE gate and a projected shortcut.

    Every stage is ReLU -> GroupNorm(1 group) -> conv:
      main branch: 1x1 reduce (in -> in // reduction_factor), 3x3 depthwise
      (stride 1, padding 1, so spatial size is kept), 1x1 expand (-> out),
      then ``SEBlock``;
      shortcut:    1x1 conv (in -> out).
    The two are summed. ``max_depth`` (constructor) and ``depth`` (forward)
    are accepted for interface compatibility but not used.
    """

    def __init__(self, in_channels, out_channels, max_depth=3, reduction_factor=4):
        super().__init__()
        mid = in_channels // reduction_factor

        # Main branch. Each norm normalizes the *input* of the conv it precedes.
        self.conv1 = nn.Conv2d(in_channels, mid, kernel_size=1, bias=False)
        self.bn1 = nn.GroupNorm(1, in_channels)
        self.conv2 = nn.Conv2d(mid, mid, kernel_size=3, stride=1, padding=1,
                               bias=False, groups=mid)
        self.bn2 = nn.GroupNorm(1, mid)
        self.conv3 = nn.Conv2d(mid, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.GroupNorm(1, mid)

        self.relu = nn.ReLU()

        # Shortcut projection.
        self.skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.skip_bn = nn.GroupNorm(1, in_channels)

        self.se = SEBlock(out_channels)

    def _pre_act(self, x, norm, conv):
        """ReLU -> norm -> conv, the stage ordering used throughout this block."""
        return conv(norm(self.relu(x)))

    def forward(self, x, depth):
        # ``depth`` is part of the signature AmplifyConv calls with; unused here.
        branch = x
        for norm, conv in ((self.bn1, self.conv1),
                           (self.bn2, self.conv2),
                           (self.bn3, self.conv3)):
            branch = self._pre_act(branch, norm, conv)
        branch = self.se(branch)

        shortcut = self._pre_act(x, self.skip_bn, self.skip_conv)

        branch += shortcut
        return branch

In [5]:
class AmplifyConv(nn.Module):
    """Recursive refinement block that re-applies a few shared bottleneck modules.

    ``forward`` runs ``max_depth`` steps over a feature map with ``channels``
    channels (shape is preserved). Step ``d``:

        out = conv[k](x, d)                       # k = self._module_index(d)
        x   = GN(2*sigmoid(e1[d]) * relu(x) + (1 + e0[d]) * out)
        x   = x + 2*sigmoid(e2[d]) * relu(x_prev)  # every step except the last

    where ``e0/e1/e2`` are ``step_embeddings``/``step_embeddings1``/
    ``step_embeddings2`` (one ``(channels, 1, 1)`` vector per step) and
    ``x_prev`` is an earlier state chosen by ``_prev_index``. Each step is run
    through ``torch.utils.checkpoint`` to trade compute for activation memory.

    Args:
        channels: input/output channel count.
        max_depth: number of refinement steps.
        num_module: number of distinct BottleneckBlocks; the steps are split
            into ``num_module`` contiguous, near-equal runs that share one
            module each.

    Raises:
        ValueError: if ``num_module < 1``.
    """

    def __init__(self, channels, max_depth=3, num_module=2):
        super(AmplifyConv, self).__init__()
        if num_module < 1:
            raise ValueError(f"num_module must be >= 1, got {num_module}")
        self.max_depth = max_depth
        self.num_module = num_module

        self.conv = nn.ModuleList([
            BottleneckBlock(channels, channels, max_depth) for _ in range(num_module)
        ])
        self.bn = nn.GroupNorm(1, channels)

        # Per-step, per-channel coefficients (see class docstring for roles).
        self.step_embeddings = nn.Parameter(torch.randn(max_depth, channels, 1, 1))
        self.step_embeddings1 = nn.Parameter(torch.randn(max_depth, channels, 1, 1))
        self.step_embeddings2 = nn.Parameter(torch.randn(max_depth, channels, 1, 1))

    def _module_index(self, depth):
        """Index into ``self.conv`` for step ``depth`` in ``[0, max_depth)``.

        Equals ``depth // (max_depth // num_module)`` whenever ``max_depth`` is
        divisible by ``num_module``; unlike that form it never runs past the
        ModuleList (e.g. the defaults 3/2 map steps 0,1,2 -> 0,0,1 rather than
        0,1,2) and never divides by zero when ``num_module > max_depth``.
        """
        return depth * self.num_module // self.max_depth

    @staticmethod
    def _prev_index(d):
        """Which cached state feeds step ``d`` as ``x_prev``.

        ``k & -k`` isolates the lowest set bit of ``k = d + 1``; clearing it and
        halving gives an index into the states snapshotted at even steps. The
        result is always ``<= d // 2``, i.e. within the snapshots taken so far.
        """
        k = d + 1
        return (k - (k & -k)) // 2

    def step_forward(self, x, x_prev, depth):
        out = self.conv[self._module_index(depth)](x, depth)
        x = 2 * F.sigmoid(self.step_embeddings1[depth]) * F.relu(x) + (1 + self.step_embeddings[depth]) * out
        x = self.bn(x)

        # The final step returns without the gated skip from x_prev.
        if depth < self.max_depth - 1:
            x = x + 2 * F.sigmoid(self.step_embeddings2[depth]) * F.relu(x_prev)
        return x

    def forward(self, x):
        # prev[i] is the state entering step 2*i (snapshots at even steps).
        prev = []
        for d in range(self.max_depth):
            if d % 2 == 0:
                prev.append(x)

            x_prev = prev[self._prev_index(d)]
            x = checkpoint(self.step_forward, x, x_prev, d, use_reentrant=False)

        return x

In [6]:
class RecursionAmplifyConv(nn.Module):
    """Stem conv -> recursive AmplifyConv trunk -> global max/avg pooling.

    Returns a ``(B, 2 * channels, 1, 1)`` tensor: the globally max-pooled and
    average-pooled trunk features concatenated along the channel axis.
    ``height`` and ``width`` are accepted but not used.
    """

    def __init__(self, channels, height, width):
        super().__init__()
        # 4x4 kernel, stride 2, padding 1: halves the spatial size (32x32 -> 16x16).
        self.conv0 = nn.Conv2d(3, channels, kernel_size=4, stride=2, padding=1)
        self.amconv = AmplifyConv(channels, max_depth=32, num_module=1)
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        features = self.amconv(self.conv0(x))
        pooled = (self.max_pool(features), self.avg_pool(features))
        return torch.cat(pooled, dim=1)

In [7]:
class Net(nn.Module):
    """CIFAR-10 classifier: ``RecursionAmplifyConv`` trunk + ReLU + linear head.

    Args:
        channels: trunk width (default 512, the previously hard-coded value).
        num_classes: number of output logits (default 10 for CIFAR-10).
    """

    def __init__(self, channels=512, num_classes=10):
        super(Net, self).__init__()

        self.reamconv = RecursionAmplifyConv(channels, 32, 32)

        # The trunk concatenates global max- and avg-pooled features along the
        # channel axis, so the head sees 2 * channels inputs (1024 by default).
        self.fc1 = nn.Linear(2 * channels, num_classes)

    def forward(self, x):
        x = self.reamconv(x)
        # (B, 2*channels, 1, 1) -> (B, 2*channels). Unlike view(-1, 1024) this
        # always keeps the batch dimension, so a width mismatch fails in fc1
        # instead of silently reshaping into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(x)
        return self.fc1(x)

In [8]:
class RL_fine_tune:
    """Reward helpers and the surrogate loss terms consumed by ``train_model``.

    ``emergency_policy_loss`` returns "delta" tensors of the form
    ``x - x.detach()`` for the live, previous and reference logits. Their value
    is zero (for finite ``x``) but their gradient is the gradient of ``x``.
    """

    def __init__(self, device):
        # Kept for callers; not used by the methods below.
        self.device = device

    @staticmethod
    def _correct(outputs, targets):
        """Per-sample 1.0 where ``argmax(outputs, 1) == targets``, else 0.0."""
        predictions = torch.argmax(outputs, dim=1)
        return (predictions == targets).float()

    def compute_reconstruct_reward(self, outputs, targets):
        """Return (per-sample reward in {+0.5, -0.5}, number of correct predictions)."""
        correct = self._correct(outputs, targets)
        return correct - 0.5, correct.sum().item()

    def compute_positive_reward(self, outputs, targets):
        """Return (per-sample reward in {1, 0}, number of correct predictions)."""
        correct = self._correct(outputs, targets)
        return correct, correct.sum().item()

    def compute_negative_reward(self, outputs, targets):
        """Return (per-sample reward in {0, -1}, number of correct predictions)."""
        correct = self._correct(outputs, targets)
        return correct - 1, correct.sum().item()

    @staticmethod
    def _nll_and_neg_entropy(logits, targets):
        """Per-sample NLL of ``targets`` and negative entropy ``sum_c p_c log p_c``.

        ``log_softmax`` is used for the log term instead of ``log(softmax)``:
        it stays finite where softmax underflows to 0, which previously gave
        ``0 * -inf = NaN``.
        """
        log_probs = F.log_softmax(logits, dim=1)
        probs = F.softmax(logits, dim=1)
        nll = -log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        neg_entropy = torch.sum(probs * log_probs, dim=1)
        return nll, neg_entropy

    def emergency_policy_loss(self, ref_outputs, pre_outputs, outputs, targets):
        """Build zero-valued, gradient-carrying NLL / negative-entropy terms.

        Args:
            ref_outputs, pre_outputs, outputs: logits for the same batch from
                the reference model, the previous model and the live model;
                assumed shape (batch, num_classes).
            targets: integer class labels, shape (batch,).

        Returns:
            ``(select_delta, pre_select_delta, ref_select_delta, entropy_delta,
            pre_entropy_delta, ref_entropy_delta, correct, metrics)``. Each
            ``*select_delta`` is ``nll - nll.detach()`` and each
            ``*entropy_delta`` is ``neg_entropy - neg_entropy.detach()`` for the
            corresponding logits. ``correct`` is the number of correct argmax
            predictions of ``outputs``. ``metrics`` holds the batch-mean CE
            (``'ce_loss'``) and the batch sums of NLL (``'targets'``) and of
            negative entropy (``'entropy'``) for ``outputs``.
        """
        select, entropy = self._nll_and_neg_entropy(outputs, targets)
        pre_select, pre_entropy = self._nll_and_neg_entropy(pre_outputs, targets)
        ref_select, ref_entropy = self._nll_and_neg_entropy(ref_outputs, targets)

        # x - x.detach(): the same values a no_grad recompute of x would give,
        # without running the forward math a second time.
        select_delta = select - select.detach()
        pre_select_delta = pre_select - pre_select.detach()
        ref_select_delta = ref_select - ref_select.detach()

        entropy_delta = entropy - entropy.detach()
        pre_entropy_delta = pre_entropy - pre_entropy.detach()
        ref_entropy_delta = ref_entropy - ref_entropy.detach()

        _, correct = self.compute_reconstruct_reward(outputs, targets)

        with torch.no_grad():
            metrics = {
                'ce_loss': F.cross_entropy(outputs, targets).item(),
                'targets': select.sum().item(),
                'entropy': entropy.sum().item(),
            }

        return (select_delta, pre_select_delta, ref_select_delta,
                entropy_delta, pre_entropy_delta, ref_entropy_delta,
                correct, metrics)

In [9]:
def train_model(
    net: torch.nn.Module,
    get_cifar10_loaders,
    evaluate_model,
    batch_size: int = 64,
    epochs: int = 400,
    eval_interval: int = 10,
    lr: float = 1e-3,
    checkpoint_path: str = "Reasoning_32R_512C_RL.pth",
):
    """Train ``net`` on CIFAR-10 with the ``RL_fine_tune`` surrogate objective.

    Per batch the loss is
        sum(sigmoid(select_delta)  - sigmoid(ref_select_delta))
      + sum(sigmoid(entropy_delta) - sigmoid(ref_entropy_delta)).
    With the ``RL_fine_tune`` in this notebook every ``*_delta`` has value 0,
    so the loss value is 0 and its gradient w.r.t. ``net`` is
    0.25 * grad(sum_i NLL_i + sum_i sum_c p_ic log p_ic), i.e. cross-entropy
    plus an entropy bonus. The reference terms are constants for ``net``.

    ``ref_net`` and ``pre_net`` are frozen deep copies run without gradient
    tracking. ``pre_net`` is refreshed every epoch. ``ref_net`` is chosen from
    up to ``ref_count`` snapshots, taken on every other epoch that improves the
    training-accuracy count. Only ``net`` is optimized: Adam, cosine
    annealing over ``epochs``, grad-norm clipping at 0.5, and AMP on CUDA.
    ``net`` is evaluated every ``eval_interval`` epochs and on the last epoch.
    Its ``state_dict`` is written to ``checkpoint_path`` whenever test
    accuracy improves.

    Args:
        net: model to train (moved to CUDA when available).
        get_cifar10_loaders: factory returning ``(trainloader, testloader)``.
        evaluate_model: ``(net, loader, criterion, name, device) -> accuracy %``.
        batch_size: training/eval batch size.
        epochs: number of epochs (also the cosine schedule length).
        eval_interval: evaluate every this many epochs (<= 0: last epoch only).
        lr: initial Adam learning rate.
        checkpoint_path: where the best-test-accuracy weights are saved.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    net = net.to(device)
    # deepcopy keeps tensors on ``device``; no extra .to() needed.
    ref_net = copy.deepcopy(net)
    pre_net = copy.deepcopy(net)

    trainloader, testloader = get_cifar10_loaders(
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
    )

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    # Anneal over the real run length (previously hard-coded T_max=400).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, epochs), eta_min=1e-6
    )
    scaler = torch.amp.GradScaler(enabled=use_amp)
    best_test_acc = 0.0
    rl_accuracy = 0.0

    rl_trainer = RL_fine_tune(device)
    # NOTE(review): up to ref_count + 2 full copies of ``net`` are kept alive.
    prev_ref = []
    ref_count = 16
    d = 0

    print(f"\n💫 starting supervised training and RL fine-tuning with {lr:g} learning rate\n")

    for epoch in range(epochs):
        net.train()
        samples_seen = 0
        num_batches = 0
        running_cor = 0.0
        rl_metrics_accumulator = {'ce_loss': 0.0, 'targets': 0.0, 'entropy': 0.0}

        for inputs, targets in trainloader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            # Follow the real device so CPU runs don't request CUDA autocast.
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                outputs = net(inputs)
                # ref_net / pre_net are never optimized: no graph, no backward
                # through them, and no grads piling up on their parameters.
                with torch.no_grad():
                    ref_outputs = ref_net(inputs)
                    pre_outputs = pre_net(inputs)

                (select_delta, _pre_select_delta, ref_select_delta,
                 entropy_delta, _pre_entropy_delta, ref_entropy_delta,
                 cor, metrics) = rl_trainer.emergency_policy_loss(
                    ref_outputs, pre_outputs, outputs, targets
                )

                select_loss = (F.sigmoid(select_delta) - F.sigmoid(ref_select_delta)).sum()
                entropy_loss = (F.sigmoid(entropy_delta) - F.sigmoid(ref_entropy_delta)).sum()
                loss = select_loss + entropy_loss

            for key, value in metrics.items():
                rl_metrics_accumulator[key] = rl_metrics_accumulator.get(key, 0.0) + value

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5)
            scaler.step(optimizer)
            scaler.update()

            running_cor += cor
            samples_seen += inputs.shape[0]
            num_batches += 1

        scheduler.step()

        denom = max(samples_seen, 1)
        avg_acc = running_cor / denom
        # 'targets' / 'entropy' are per-batch sums -> per-sample averages;
        # 'ce_loss' is a per-batch mean -> average over batches.
        avg_metrics = {k: v / denom for k, v in rl_metrics_accumulator.items()}
        avg_metrics['ce_loss'] = rl_metrics_accumulator.get('ce_loss', 0.0) / max(num_batches, 1)

        print(f"Epoch [{epoch + 1}] 📊 acc: {avg_acc:.3f} | Targets: {avg_metrics['targets']:.4f} | Entropy: {avg_metrics['entropy']:.4f}")

        # Raw counts are comparable across epochs only if every epoch sees the
        # same number of samples (true when the loader drops the last batch).
        if running_cor > rl_accuracy:
            rl_accuracy = running_cor
            # Snapshot on every other improvement, keeping the newest ref_count.
            if d % 2 == 0:
                prev_ref.append(copy.deepcopy(net))
                prev_ref = prev_ref[-ref_count:]
            # Same lowest-set-bit schedule as AmplifyConv.forward, applied to
            # d % ref_count; the index is always < len(prev_ref).
            # NOTE(review): once the window slides, indices address positions
            # in the current window rather than fixed snapshots — confirm intent.
            k = d % ref_count + 1
            ref_net = prev_ref[(k - (k & -k)) // 2]
            d += 1
            print("New ref model saved")

        pre_net = copy.deepcopy(net)

        if (eval_interval > 0 and (epoch + 1) % eval_interval == 0) or epoch == epochs - 1:
            print(f"\nEvaluating at epoch {epoch + 1} …")
            # NOTE(review): trainloader presumably applies training-time
            # augmentation, so "Train" accuracy is on augmented samples.
            evaluate_model(net, trainloader, criterion, "Train", device)
            test_acc = evaluate_model(net, testloader, criterion, "Test", device)
            # Counted inline so this function does not depend on a helper
            # defined in a later notebook cell.
            n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
            print(f"Total Trainable Parameters: {n_trainable:,}\n")
            print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.2e}")

            if test_acc > best_test_acc:
                best_test_acc = test_acc
                torch.save(net.state_dict(), checkpoint_path)
                print(f"New best model saved with test accuracy: {test_acc:.2f}%")

    print("Finished Training")
    print(f"Best test accuracy achieved: {best_test_acc:.2f}%")

In [10]:
def evaluate_model(
    net: torch.nn.Module,
    dataloader,
    criterion,
    dataset_name: str = "",
    device=None,
):
    """Evaluate ``net`` on ``dataloader``; print and return accuracy in percent.

    Args:
        net: model to evaluate; put in eval mode for the pass.
        dataloader: iterable of ``(images, labels)`` batches.
        criterion: called as ``criterion(logits.float(), labels)``; assumed to
            return a scalar batch mean (it is weighted by batch size below).
        dataset_name: label used in the printed summary.
        device: target device; defaults to CUDA when available, else CPU.

    Returns:
        Accuracy in percent over the batches whose loss was finite; 0.0 if no
        batch was counted.

    Batches with a non-finite loss are reported and excluded from both the
    average loss and the accuracy. On CUDA devices that support it, the pass
    runs under bfloat16 autocast. The module's original train/eval mode is
    restored on exit, including when an exception propagates.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    was_training = net.training
    net.eval()

    total_correct = 0
    total_seen = 0
    loss_sum = 0.0
    skipped = 0

    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    autocast_ctx = (torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
                    if use_bf16 else contextlib.nullcontext())

    try:
        with torch.no_grad(), autocast_ctx:
            for images, labels in dataloader:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                outputs = net(images)
                # Loss in float32 even when logits come out of bf16 autocast.
                batch_loss = criterion(outputs.float(), labels)

                if not torch.isfinite(batch_loss):
                    print("[eval/warn] non-finite loss",
                          "logits_minmax=", float(outputs.min()), float(outputs.max()))
                    skipped += 1
                    continue

                bs = labels.size(0)
                loss_sum += batch_loss.item() * bs
                total_seen += bs

                preds = outputs.argmax(dim=1)
                total_correct += (preds == labels).sum().item()
    finally:
        if was_training:
            net.train()

    if skipped:
        print(f"[eval/warn] {dataset_name}: skipped {skipped} batch(es) with non-finite loss")

    if total_seen == 0:
        avg_loss = float("nan")
        acc = 0.0
    else:
        avg_loss = loss_sum / total_seen
        acc = 100.0 * total_correct / total_seen

    print(f"{dataset_name} Accuracy: {acc:.2f}%")
    print(f"{dataset_name} Average Loss: {avg_loss:.4f}")

    return acc

In [11]:
def count_parameters(model):
    """Return the total element count of ``model``'s parameters that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Example usage: build an (untrained) Net and print how many trainable
# parameters it has.
model = Net()
print(f"Total Trainable Parameters: {count_parameters(model):,}")

Total Trainable Parameters: 515,754


In [12]:
# Driver cell: let cuDNN benchmark and cache the fastest conv algorithms
# (input size is fixed), then train a fresh Net end to end.
torch.backends.cudnn.benchmark = True
net = Net()
train_model(
    net,
    get_cifar10_loaders=get_cifar10_loaders,
    evaluate_model=evaluate_model,
)

# Re-instantiates an untrained Net only to re-print its parameter count
# (same architecture as `net`, so the count is the same).
# NOTE(review): the saved log below reports 38,322 trainable parameters,
# while this Net reports 515,754 (previous cell's output); the log appears
# to come from a different configuration — re-run before relying on it.
model = Net()
print(f"Total Trainable Parameters: {count_parameters(model):,}")


💫 starting supervised training and RL fine-tuning with 1e-3 learning rate

Epoch [1] 📊 acc: 0.322 | Targets: 1.9702 | Entropy: -2.1619
New ref model saved
Epoch [2] 📊 acc: 0.481 | Targets: 1.6887 | Entropy: -2.0808
New ref model saved
Epoch [3] 📊 acc: 0.557 | Targets: 1.5632 | Entropy: -2.0424
New ref model saved
Epoch [4] 📊 acc: 0.608 | Targets: 1.4794 | Entropy: -2.0182
New ref model saved
Epoch [5] 📊 acc: 0.642 | Targets: 1.4183 | Entropy: -2.0007
New ref model saved
Epoch [6] 📊 acc: 0.670 | Targets: 1.3719 | Entropy: -1.9882
New ref model saved
Epoch [7] 📊 acc: 0.691 | Targets: 1.3303 | Entropy: -1.9772
New ref model saved
Epoch [8] 📊 acc: 0.708 | Targets: 1.2995 | Entropy: -1.9684
New ref model saved
Epoch [9] 📊 acc: 0.721 | Targets: 1.2753 | Entropy: -1.9639
New ref model saved
Epoch [10] 📊 acc: 0.733 | Targets: 1.2555 | Entropy: -1.9577
New ref model saved

Evaluating at epoch 10 …
Train Accuracy: 72.44%
Train Average Loss: 1.2676
Test Accuracy: 76.71%
Test Average Loss: 1.1991

Epoch [176] 📊 acc: 0.774 | Targets: 1.3655 | Entropy: -2.0883
Epoch [177] 📊 acc: 0.774 | Targets: 1.3323 | Entropy: -2.0678
Epoch [178] 📊 acc: 0.770 | Targets: 1.5458 | Entropy: -2.1848
Epoch [179] 📊 acc: 0.775 | Targets: 1.1216 | Entropy: -1.8690
Epoch [180] 📊 acc: 0.775 | Targets: 1.1022 | Entropy: -1.8514

Evaluating at epoch 180 …
Train Accuracy: 77.23%
Train Average Loss: 1.1225
Test Accuracy: 81.82%
Test Average Loss: 1.0367
Total Trainable Parameters: 38,322

Current learning rate: 6.20e-04
Epoch [181] 📊 acc: 0.775 | Targets: 1.3994 | Entropy: -2.1115
Epoch [182] 📊 acc: 0.778 | Targets: 1.2781 | Entropy: -2.0291
New ref model saved
Epoch [183] 📊 acc: 0.776 | Targets: 0.7942 | Entropy: -1.2579
Epoch [184] 📊 acc: 0.776 | Targets: 1.2646 | Entropy: -2.0089
Epoch [185] 📊 acc: 0.775 | Targets: 0.7817 | Entropy: -1.2292
Epoch [186] 📊 acc: 0.778 | Targets: 1.2205 | Entropy: -1.9701
Epoch [187] 📊 acc: 0.779 | Targets: 0.8365 | Entropy: -1.4001
New ref model saved
Epoch [188] 📊 acc: 0.77

Test Accuracy: 83.26%
Test Average Loss: 1.4223
Total Trainable Parameters: 38,322

Current learning rate: 3.15e-04
Epoch [271] 📊 acc: 0.796 | Targets: 0.8388 | Entropy: -1.4508
Epoch [272] 📊 acc: 0.794 | Targets: 1.1967 | Entropy: -1.9722
Epoch [273] 📊 acc: 0.796 | Targets: 1.0460 | Entropy: -1.8167
Epoch [274] 📊 acc: 0.797 | Targets: 1.3053 | Entropy: -2.0665
New ref model saved
Epoch [275] 📊 acc: 0.796 | Targets: 0.8205 | Entropy: -1.4223
Epoch [276] 📊 acc: 0.799 | Targets: 0.9167 | Entropy: -1.6269
New ref model saved
Epoch [277] 📊 acc: 0.792 | Targets: 1.4679 | Entropy: -2.1535
Epoch [278] 📊 acc: 0.797 | Targets: 1.2423 | Entropy: -2.0183
Epoch [279] 📊 acc: 0.795 | Targets: 0.9829 | Entropy: -1.7286
Epoch [280] 📊 acc: 0.798 | Targets: 0.7185 | Entropy: -1.1903

Evaluating at epoch 280 …
Train Accuracy: 80.20%
Train Average Loss: 0.7011
Test Accuracy: 84.19%
Test Average Loss: 0.5901
Total Trainable Parameters: 38,322

Current learning rate: 2.85e-04
New best model saved with test 

Epoch [365] 📊 acc: 0.803 | Targets: 0.6842 | Entropy: -1.0049
Epoch [366] 📊 acc: 0.806 | Targets: 1.0542 | Entropy: -1.8113
Epoch [367] 📊 acc: 0.807 | Targets: 1.1731 | Entropy: -1.9685
Epoch [368] 📊 acc: 0.808 | Targets: 1.1914 | Entropy: -1.9874
Epoch [369] 📊 acc: 0.807 | Targets: 1.0791 | Entropy: -1.8744
Epoch [370] 📊 acc: 0.798 | Targets: 1.6135 | Entropy: -2.2130

Evaluating at epoch 370 …
Train Accuracy: 80.22%
Train Average Loss: 1.6420
Test Accuracy: 83.33%
Test Average Loss: 1.5963
Total Trainable Parameters: 38,322

Current learning rate: 1.12e-04
Epoch [371] 📊 acc: 0.805 | Targets: 0.7205 | Entropy: -1.1192
Epoch [372] 📊 acc: 0.808 | Targets: 0.5987 | Entropy: -0.8198
Epoch [373] 📊 acc: 0.811 | Targets: 0.7630 | Entropy: -1.3453
Epoch [374] 📊 acc: 0.805 | Targets: 1.2616 | Entropy: -2.0339
Epoch [375] 📊 acc: 0.801 | Targets: 1.4518 | Entropy: -2.1553
Epoch [376] 📊 acc: 0.809 | Targets: 1.1203 | Entropy: -1.9163
Epoch [377] 📊 acc: 0.803 | Targets: 1.4322 | Entropy: -2.1452
E