In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import time

#############################
# Configuration and Setup
#############################
tf.random.set_seed(42)
np.random.seed(42)

batch_size = 128
num_classes = 10
epochs = 200
weight_decay = 5e-4
initial_lr = 0.1

# CIFAR-10 mean/std for normalization
cifar10_mean = [0.4914, 0.4822, 0.4465]
cifar10_std = [0.2470, 0.2435, 0.2616]

#############################
# Data Loading and Preprocessing
#############################
(x_train_full, y_train_full), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train_full = tf.keras.utils.to_categorical(y_train_full, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

# Split off a validation set
val_size = 5000
x_val = x_train_full[:val_size]
y_val = y_train_full[:val_size]
x_train = x_train_full[val_size:]
y_train = y_train_full[val_size:]

x_train = x_train.astype('float32') / 255.
x_val = x_val.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

def normalize(x):
    x = (x - cifar10_mean) / cifar10_std
    return x

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomCrop(32, 32),
    tf.keras.layers.RandomFlip("horizontal")
])

#############################
# Channel-Wise Learnable Activation
#############################
class ChannelWiseLearnableActivation(tf.keras.layers.Layer):
    def __init__(self, hidden_units=16, wd=1e-4):
        super(ChannelWiseLearnableActivation, self).__init__()
        self.hidden_units = hidden_units
        self.wd = wd

    def build(self, input_shape):
        c = input_shape[-1]
        self.dense1 = tf.keras.layers.Dense(
            self.hidden_units,
            activation='relu',
            kernel_initializer='he_normal',
            kernel_regularizer=tf.keras.regularizers.l2(self.wd)
        )
        self.dense2 = tf.keras.layers.Dense(
            c,
            activation='sigmoid',
            kernel_initializer='he_normal',
            kernel_regularizer=tf.keras.regularizers.l2(self.wd)
        )
        super().build(input_shape)

    def call(self, inputs):
        # inputs: (batch, h, w, c)
        x_mean = tf.reduce_mean(inputs, axis=[1, 2])  # (batch, c)
        x_hidden = self.dense1(x_mean)  # (batch, hidden_units)
        scale = self.dense2(x_hidden)   # (batch, c)
        scale = tf.reshape(scale, [-1, 1, 1, tf.shape(scale)[-1]])
        return inputs * scale

#############################
# ResNet Building Blocks
#############################
regularizer = tf.keras.regularizers.l2(weight_decay)
initializer = tf.keras.initializers.HeNormal()

def conv3x3(filters, stride=1):
    return tf.keras.layers.Conv2D(filters, kernel_size=3, strides=stride, padding='same',
                                  kernel_initializer=initializer,
                                  kernel_regularizer=regularizer, use_bias=False)

class NoOpLayer(tf.keras.layers.Layer):
    def call(self, x, training=None):
        return x

class ResidualBlock(tf.keras.layers.Layer):
    def __init__(self, filters, stride=1):
        super(ResidualBlock, self).__init__()
        self.filters = filters
        self.stride = stride

        self.conv1 = conv3x3(filters, stride)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.act1 = ChannelWiseLearnableActivation(hidden_units=16, wd=weight_decay)

        self.conv2 = conv3x3(filters)
        self.bn2 = tf.keras.layers.BatchNormalization()

        if stride != 1:
            self.shortcut = tf.keras.Sequential([
                tf.keras.layers.Conv2D(filters, kernel_size=1, strides=stride,
                                       kernel_initializer=initializer,
                                       kernel_regularizer=regularizer, use_bias=False),
                tf.keras.layers.BatchNormalization()
            ])
        else:
            self.shortcut = NoOpLayer()

        self.act2 = ChannelWiseLearnableActivation(hidden_units=16, wd=weight_decay)

    def call(self, x, training=False):
        shortcut = self.shortcut(x, training=training)

        out = self.conv1(x, training=training)
        out = self.bn1(out, training=training)
        out = self.act1(out)

        out = self.conv2(out, training=training)
        out = self.bn2(out, training=training)
        out = out + shortcut
        out = self.act2(out)
        return out

def make_layer(filters, num_blocks, stride=1):
    layers = []
    layers.append(ResidualBlock(filters, stride))
    for _ in range(1, num_blocks):
        layers.append(ResidualBlock(filters, stride=1))
    return tf.keras.Sequential(layers)

def build_resnet20():
    inputs = tf.keras.Input(shape=(32,32,3))
    x = conv3x3(16)(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = ChannelWiseLearnableActivation(hidden_units=16, wd=weight_decay)(x)

    # ResNet-20: 3 groups of residual blocks, each with 3 blocks
    x = make_layer(16, 3, stride=1)(x)
    x = make_layer(32, 3, stride=2)(x)
    x = make_layer(64, 3, stride=2)(x)

    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax',
                                    kernel_initializer=initializer,
                                    kernel_regularizer=regularizer)(x)

    model = tf.keras.Model(inputs, outputs)
    return model

#############################
# Prepare Datasets
#############################
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=50000).batch(batch_size)
train_dataset = train_dataset.map(lambda x,y: (data_augmentation(x, training=True), y))
train_dataset = train_dataset.map(lambda x,y: (normalize(x), y))
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
val_dataset = val_dataset.map(lambda x,y: (normalize(x), y))
val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(batch_size)
test_dataset = test_dataset.map(lambda x,y: (normalize(x), y))
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

#############################
# Learning Rate Scheduler (Cosine Decay)
#############################
steps_per_epoch = len(x_train) // batch_size
total_steps = steps_per_epoch * epochs

lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=initial_lr,
    decay_steps=total_steps,
    alpha=0.001
)

optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9, nesterov=True)

#############################
# Training and Callbacks
#############################
class TimingCallback(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.train_times = []
        self.start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        end_time = time.time()
        epoch_time = end_time - self.start_time
        self.train_times.append(epoch_time)
        self.start_time = end_time
        print(f"Epoch {epoch+1}/{epochs}, Time: {epoch_time:.2f}s, "
              f"Loss: {logs.get('loss',0):.4f}, Val_Loss: {logs.get('val_loss',0):.4f}, "
              f"Val_Acc: {logs.get('val_accuracy',0):.4f}")

early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
checkpoint = tf.keras.callbacks.ModelCheckpoint("model_learn_best_cifar10.keras", save_best_only=True, monitor='val_loss')
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10)
timing_cb = TimingCallback()

#############################
# Build and Train Model
#############################
model = build_resnet20()
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=[timing_cb, early_stopping, checkpoint, reduce_lr],
    verbose=1
)

# Evaluate on test set
test_loss, test_acc = model.evaluate(test_dataset, verbose=0)
weights = np.concatenate([w.numpy().flatten() for w in model.trainable_variables if w.dtype.is_floating])
sparsity = np.mean(np.abs(weights) < 1e-5)
total_time = np.sum(timing_cb.train_times)

print("Learnable Activation ResNet-20 Model:")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Sparsity: {sparsity:.4f}")
print(f"Total Training Time: {total_time:.2f}s")

# Plot Accuracy Curves
plt.figure(figsize=(12,5))
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.title('Accuracy Over Epochs')
plt.legend()
plt.show()

# Plot Loss Curves
plt.figure(figsize=(12,5))
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title('Loss Over Epochs')
plt.legend()
plt.show()


Epoch 1/200
[1m352/352[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 73ms/step - accuracy: 0.2240 - loss: 4.1434Epoch 1/200, Time: 67.90s, Loss: 3.7790, Val_Loss: 3.4143, Val_Acc: 0.1818
[1m352/352[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 112ms/step - accuracy: 0.2242 - loss: 4.1423 - val_accuracy: 0.1818 - val_loss: 3.4143 - learning_rate: 0.1000
Epoch 2/200
[1m351/352[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 32ms/step - accuracy: 0.3438 - loss: 2.9016Epoch 2/200, Time: 39.36s, Loss: 2.7297, Val_Loss: 2.7343, Val_Acc: 0.2432
[1m352/352[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 34ms/step - accuracy: 0.3439 - loss: 2.9006 - val_accuracy: 0.2432 - val_loss: 2.7343 - learning_rate: 0.1000
Epoch 3/200
[1m350/352[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 32ms/step - accuracy: 0.3600 - loss: 2.3373Epoch 3/200, Time: 12.20s, Loss: 2.2509, Val_Loss: 2.2284, Val_Acc: 0.3302
[1m352/352[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [

TypeError: This optimizer was created with a `LearningRateSchedule` object as its `learning_rate` constructor argument, hence its learning rate is not settable. If you need the learning rate to be settable, you should instantiate the optimizer with a float `learning_rate` argument.

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import time

#############################
# Configuration and Setup
#############################
tf.random.set_seed(42)
np.random.seed(42)

batch_size = 128
num_classes = 10
epochs = 200
weight_decay = 5e-4
initial_lr = 0.1

# CIFAR-10 mean/std for normalization
cifar10_mean = [0.4914, 0.4822, 0.4465]
cifar10_std = [0.2470, 0.2435, 0.2616]

#############################
# Data Loading and Preprocessing
#############################
(x_train_full, y_train_full), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train_full = tf.keras.utils.to_categorical(y_train_full, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

# Split off a validation set
val_size = 5000
x_val = x_train_full[:val_size]
y_val = y_train_full[:val_size]
x_train = x_train_full[val_size:]
y_train = y_train_full[val_size:]

x_train = x_train.astype('float32') / 255.
x_val = x_val.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

def normalize(x):
    x = (x - cifar10_mean) / cifar10_std
    return x

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomCrop(32, 32),
    tf.keras.layers.RandomFlip("horizontal")
])

#############################
# Channel-Wise Learnable Activation
#############################
class ChannelWiseLearnableActivation(tf.keras.layers.Layer):
    def __init__(self, hidden_units=16, wd=1e-4):
        super(ChannelWiseLearnableActivation, self).__init__()
        self.hidden_units = hidden_units
        self.wd = wd

    def build(self, input_shape):
        c = input_shape[-1]
        self.dense1 = tf.keras.layers.Dense(
            self.hidden_units,
            activation='relu',
            kernel_initializer='he_normal',
            kernel_regularizer=tf.keras.regularizers.l2(self.wd)
        )
        self.dense2 = tf.keras.layers.Dense(
            c,
            activation='sigmoid',
            kernel_initializer='he_normal',
            kernel_regularizer=tf.keras.regularizers.l2(self.wd)
        )
        super().build(input_shape)

    def call(self, inputs):
        # inputs: (batch, h, w, c)
        x_mean = tf.reduce_mean(inputs, axis=[1, 2])  # (batch, c)
        x_hidden = self.dense1(x_mean)  # (batch, hidden_units)
        scale = self.dense2(x_hidden)   # (batch, c)
        scale = tf.reshape(scale, [-1, 1, 1, tf.shape(scale)[-1]])
        return inputs * scale

#############################
# ResNet Building Blocks
#############################
regularizer = tf.keras.regularizers.l2(weight_decay)
initializer = tf.keras.initializers.HeNormal()

def conv3x3(filters, stride=1):
    return tf.keras.layers.Conv2D(filters, kernel_size=3, strides=stride, padding='same',
                                  kernel_initializer=initializer,
                                  kernel_regularizer=regularizer, use_bias=False)

class NoOpLayer(tf.keras.layers.Layer):
    def call(self, x, training=None):
        return x

class ResidualBlock(tf.keras.layers.Layer):
    def __init__(self, filters, stride=1):
        super(ResidualBlock, self).__init__()
        self.filters = filters
        self.stride = stride

        self.conv1 = conv3x3(filters, stride)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.act1 = ChannelWiseLearnableActivation(hidden_units=16, wd=weight_decay)

        self.conv2 = conv3x3(filters)
        self.bn2 = tf.keras.layers.BatchNormalization()

        if stride != 1:
            self.shortcut = tf.keras.Sequential([
                tf.keras.layers.Conv2D(filters, kernel_size=1, strides=stride,
                                       kernel_initializer=initializer,
                                       kernel_regularizer=regularizer, use_bias=False),
                tf.keras.layers.BatchNormalization()
            ])
        else:
            self.shortcut = NoOpLayer()

        self.act2 = ChannelWiseLearnableActivation(hidden_units=16, wd=weight_decay)

    def call(self, x, training=False):
        shortcut = self.shortcut(x, training=training)

        out = self.conv1(x, training=training)
        out = self.bn1(out, training=training)
        out = self.act1(out)

        out = self.conv2(out, training=training)
        out = self.bn2(out, training=training)
        out = out + shortcut
        out = self.act2(out)
        return out

def make_layer(filters, num_blocks, stride=1):
    layers = []
    layers.append(ResidualBlock(filters, stride))
    for _ in range(1, num_blocks):
        layers.append(ResidualBlock(filters, stride=1))
    return tf.keras.Sequential(layers)

def build_resnet20():
    inputs = tf.keras.Input(shape=(32,32,3))
    x = conv3x3(16)(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = ChannelWiseLearnableActivation(hidden_units=16, wd=weight_decay)(x)

    # ResNet-20: 3 groups of residual blocks, each with 3 blocks
    x = make_layer(16, 3, stride=1)(x)
    x = make_layer(32, 3, stride=2)(x)
    x = make_layer(64, 3, stride=2)(x)

    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax',
                                    kernel_initializer=initializer,
                                    kernel_regularizer=regularizer)(x)

    model = tf.keras.Model(inputs, outputs)
    return model

#############################
# Prepare Datasets
#############################
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=50000).batch(batch_size)
train_dataset = train_dataset.map(lambda x,y: (data_augmentation(x, training=True), y))
train_dataset = train_dataset.map(lambda x,y: (normalize(x), y))
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
val_dataset = val_dataset.map(lambda x,y: (normalize(x), y))
val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(batch_size)
test_dataset = test_dataset.map(lambda x,y: (normalize(x), y))
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

#############################
# Learning Rate Scheduler (Cosine Decay)
#############################
steps_per_epoch = len(x_train) // batch_size
total_steps = steps_per_epoch * epochs

lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=initial_lr,
    decay_steps=total_steps,
    alpha=0.001
)

optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9, nesterov=True)

#############################
# Training and Callbacks
#############################
class TimingCallback(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.train_times = []
        self.start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        end_time = time.time()
        epoch_time = end_time - self.start_time
        self.train_times.append(epoch_time)
        self.start_time = end_time
        print(f"Epoch {epoch+1}/{epochs}, Time: {epoch_time:.2f}s, "
              f"Loss: {logs.get('loss',0):.4f}, Val_Loss: {logs.get('val_loss',0):.4f}, "
              f"Val_Acc: {logs.get('val_accuracy',0):.4f}")

early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
checkpoint = tf.keras.callbacks.ModelCheckpoint("model_learn_best_cifar10.keras", save_best_only=True, monitor='val_loss')
timing_cb = TimingCallback()

#############################
# Build and Train Model
#############################
model = build_resnet20()
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=[timing_cb, early_stopping, checkpoint],
    verbose=1
)

# Evaluate on test set
test_loss, test_acc = model.evaluate(test_dataset, verbose=0)
weights = np.concatenate([w.numpy().flatten() for w in model.trainable_variables if w.dtype.is_floating])
sparsity = np.mean(np.abs(weights) < 1e-5)
total_time = np.sum(timing_cb.train_times)

print("Learnable Activation ResNet-20 Model:")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Sparsity: {sparsity:.4f}")
print(f"Total Training Time: {total_time:.2f}s")

# Plot Accuracy Curves
plt.figure(figsize=(12,5))
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.title('Accuracy Over Epochs')
plt.legend()
plt.show()

# Plot Loss Curves
plt.figure(figsize=(12,5))
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title('Loss Over Epochs')
plt.legend()
plt.show()


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 0us/step
Epoch 1/200
[1m352/352[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.2117 - loss: 4.3689Epoch 1/200, Time: 64.12s, Loss: 4.0342, Val_Loss: 3.5841, Val_Acc: 0.1706
[1m352/352[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 101ms/step - accuracy: 0.2118 - loss: 4.3680 - val_accuracy: 0.1706 - val_loss: 3.5841
Epoch 2/200
[1m352/352[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.3220 - loss: 3.0927Epoch 2/200, Time: 11.68s, Loss: 2.8940, Val_Loss: 2.8141, Val_Acc: 0.2340
[1m352/352[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 33ms/step - accuracy: 0.3220 - loss: 3.0922 - val_accuracy: 0.2340 - val_loss: 2.8141
Epoch 3/200
[1m351/352[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 33ms/step - accuracy: 0.3468 - loss: 2.4387Epoch 3/200, Ti

AttributeError: 'str' object has no attribute 'is_floating'

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

####################################
# Hyperparameters
####################################
batch_size = 128
lr = 0.01
epochs = 2  # reduced epochs for demonstration
alpha = 0.1
equilibrium_iters = 5

####################################
# Data Loading
####################################
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

####################################
# Model Definition
####################################
class SmallCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SmallCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4,4)),
        )
        self.fc = nn.Linear(64*4*4, num_classes)

    def forward(self, x):
        f = self.features(x)
        f_flat = f.view(f.size(0), -1)
        # The fully connected layer gives an initial guess if needed, not strictly required
        out = self.fc(f_flat)
        return f_flat, out

model = SmallCNN().cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

####################################
# Energy Function
####################################
def energy(x, target, alpha, f):
    # E(x) = CrossEntropy(x, target) + alpha * ||x||^2
    ce = F.cross_entropy(x, target)
    reg = (x**2).mean()
    return ce + alpha * reg

####################################
# Solve Equilibrium
####################################
@torch.enable_grad()
def solve_equilibrium(f, target, alpha):
    # We initialize x at zero with requires_grad=True
    x = torch.zeros(target.size(0), 10, device=target.device, requires_grad=True)

    # Perform gradient descent steps on x
    # We need create_graph=True to allow differentiation w.r.t. model parameters later
    for i in range(equilibrium_iters):
        E = energy(x, target, alpha, f)
        grad_x = torch.autograd.grad(E, x, retain_graph=True, create_graph=True)[0]
        x = x - 0.1 * grad_x
        # x here is a new tensor resulting from an operation, but still has a grad_fn
        # No need to manually set requires_grad=True; x will have grad_fn attached.

    return x

####################################
# Training Loop
####################################
for epoch in range(epochs):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()

        f, _ = model(inputs)
        x_star = solve_equilibrium(f, targets, alpha)

        # Now we compute E_final and backprop into model parameters
        E_final = energy(x_star, targets, alpha, f)
        E_final.backward()
        optimizer.step()

    # Evaluate on test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.cuda(), targets.cuda()
            f, _ = model(inputs)
            # Even at test time, we solve for equilibrium
            x_star = solve_equilibrium(f, targets, alpha)
            pred = x_star.argmax(dim=1)
            correct += pred.eq(targets).sum().item()
            total += targets.size(0)
    acc = 100. * correct / total
    print(f"Epoch {epoch+1}/{epochs}, Test Accuracy: {acc:.2f}%")

print("Training complete.")


Files already downloaded and verified
Files already downloaded and verified
Epoch 1/2, Test Accuracy: 100.00%
Epoch 2/2, Test Accuracy: 100.00%
Training complete.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

####################################
# Hyperparameters
####################################
batch_size = 128
lr = 0.01
epochs = 10
alpha = 0.1
equilibrium_iters = 5

####################################
# Data Loading
####################################
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

####################################
# Model Definition
####################################
class SmallCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SmallCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4,4)),
        )
        self.fc = nn.Linear(64*4*4, num_classes)

    def forward(self, x):
        f = self.features(x)
        f_flat = f.view(f.size(0), -1)
        out = self.fc(f_flat)
        return f_flat, out

model = SmallCNN().cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

####################################
# Energy Function
####################################
def energy(x, target, alpha, f):
    # E(x) = CrossEntropy(x, target) + alpha * ||x||^2
    ce = F.cross_entropy(x, target)
    reg = (x**2).mean()
    return ce + alpha * reg

####################################
# Solve Equilibrium (Train time)
####################################
@torch.enable_grad()
def solve_equilibrium_train(f, target, alpha):
    # Initialize x with requires_grad=True
    x = torch.zeros(target.size(0), 10, device=target.device, requires_grad=True)
    for i in range(equilibrium_iters):
        E = energy(x, target, alpha, f)
        grad_x = torch.autograd.grad(E, x, retain_graph=True, create_graph=True)[0]
        x = x - 0.1 * grad_x
        # No need to manually set requires_grad here; x already has grad_fn.
    return x

####################################
# Training Loop
####################################
for epoch in range(epochs):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()

        f, _ = model(inputs)
        x_star = solve_equilibrium_train(f, targets, alpha)

        # Compute final energy and backprop w.r.t. model parameters
        E_final = energy(x_star, targets, alpha, f)
        E_final.backward()
        optimizer.step()

    # Evaluate on test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.cuda(), targets.cuda()
            f, out = model(inputs)
            # At test time, we do not solve for equilibrium with targets.
            # We just use model's direct output:
            pred = out.argmax(dim=1)
            correct += pred.eq(targets).sum().item()
            total += targets.size(0)
    acc = 100. * correct / total
    print(f"Epoch {epoch+1}/{epochs}, Test Accuracy: {acc:.2f}%")

print("Training complete.")


Files already downloaded and verified
Files already downloaded and verified
Epoch 1/10, Test Accuracy: 10.36%
Epoch 2/10, Test Accuracy: 10.36%
Epoch 3/10, Test Accuracy: 10.36%
Epoch 4/10, Test Accuracy: 10.36%
Epoch 5/10, Test Accuracy: 10.36%
Epoch 6/10, Test Accuracy: 10.36%
Epoch 7/10, Test Accuracy: 10.36%
Epoch 8/10, Test Accuracy: 10.36%
Epoch 9/10, Test Accuracy: 10.36%
Epoch 10/10, Test Accuracy: 10.36%
Training complete.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

####################################
# Hyperparameters
####################################
batch_size = 128
lr = 0.01
epochs = 10
alpha = 0.01
equilibrium_iters = 10

####################################
# Data Loading
####################################
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

####################################
# Model Definition
####################################
class SmallCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SmallCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4,4)),
        )
        self.fc = nn.Linear(64*4*4, num_classes)

    def forward(self, x):
        f = self.features(x)
        f_flat = f.view(f.size(0), -1)
        out = self.fc(f_flat)
        return f_flat, out

model = SmallCNN().cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

####################################
# Energy Function
####################################
def energy(x, out, target, alpha):
    # Modified energy:
    # E(x) = CE(x, target) + alpha * ||x - out||^2
    ce = F.cross_entropy(x, target)
    diff = x - out
    reg = (diff**2).mean()
    return ce + alpha * reg

####################################
# Solve Equilibrium (Train time)
####################################
@torch.enable_grad()
def solve_equilibrium_train(out, target, alpha):
    # Initialize x from model output out
    x = out.clone().detach().requires_grad_(True)

    for i in range(equilibrium_iters):
        E = energy(x, out, target, alpha)
        grad_x = torch.autograd.grad(E, x, retain_graph=True, create_graph=True)[0]
        x = x - 0.1 * grad_x
        # No need to manually set requires_grad again; x from operation will have grad_fn.
        x.requires_grad_(True)
    return x

####################################
# Training Loop
####################################
for epoch in range(epochs):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()

        f, out = model(inputs)
        x_star = solve_equilibrium_train(out, targets, alpha)

        # Compute final energy and backprop
        E_final = energy(x_star, out, targets, alpha)
        E_final.backward()
        optimizer.step()

    # Evaluate on test set without equilibrium, just use model outputs
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.cuda(), targets.cuda()
            f, out = model(inputs)
            # Direct model output as prediction
            pred = out.argmax(dim=1)
            correct += pred.eq(targets).sum().item()
            total += targets.size(0)
    acc = 100. * correct / total
    print(f"Epoch {epoch+1}/{epochs}, Test Accuracy: {acc:.2f}%")

print("Training complete.")


Files already downloaded and verified
Files already downloaded and verified
Epoch 1/10, Test Accuracy: 10.85%
Epoch 2/10, Test Accuracy: 10.85%
Epoch 3/10, Test Accuracy: 10.87%
Epoch 4/10, Test Accuracy: 10.88%
Epoch 5/10, Test Accuracy: 10.87%
Epoch 6/10, Test Accuracy: 10.86%
Epoch 7/10, Test Accuracy: 10.86%
Epoch 8/10, Test Accuracy: 10.85%
Epoch 9/10, Test Accuracy: 10.84%
Epoch 10/10, Test Accuracy: 10.84%
Training complete.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.optim.lr_scheduler import StepLR

####################################
# Hyperparameters
####################################
batch_size = 128
lr = 0.1
epochs = 50
alpha = 0.001
equilibrium_iters = 20
equilibrium_step_size = 0.1  # step size for x updates

####################################
# Data Loading
####################################
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform_train, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform_test, download=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

####################################
# Model Definition
####################################
# Use ResNet-18 and adjust final layer for CIFAR-10
model = resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.cuda()

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
scheduler = StepLR(optimizer, step_size=20, gamma=0.1)  # Reduce LR every 20 epochs

####################################
# Energy Function
####################################
def energy(x, out, target, alpha):
    # E(x) = CE(x, target) + alpha * ||x - out||^2
    ce = F.cross_entropy(x, target)
    diff = x - out
    reg = (diff**2).mean()
    return ce + alpha * reg

####################################
# Solve Equilibrium (Train time)
####################################
@torch.enable_grad()
def solve_equilibrium_train(out, target, alpha):
    # Initialize x from model output out
    x = out.clone().detach().requires_grad_(True)

    for i in range(equilibrium_iters):
        E = energy(x, out, target, alpha)
        grad_x = torch.autograd.grad(E, x, retain_graph=True, create_graph=True)[0]
        x = x - equilibrium_step_size * grad_x
        x.requires_grad_(True)
    return x

####################################
# Training and Evaluation
####################################
for epoch in range(epochs):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()

        out = model(inputs)
        # Equilibrium solving for training
        x_star = solve_equilibrium_train(out, targets, alpha)

        # Compute final energy and backprop
        E_final = energy(x_star, out, targets, alpha)
        E_final.backward()
        optimizer.step()

    scheduler.step()

    # Evaluate on test set (no equilibrium, just direct output)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.cuda(), targets.cuda()
            out = model(inputs)
            pred = out.argmax(dim=1)
            correct += pred.eq(targets).sum().item()
            total += targets.size(0)
    acc = 100. * correct / total
    print(f"Epoch {epoch+1}/{epochs}, Test Accuracy: {acc:.2f}%")

print("Training complete.")


Files already downloaded and verified
Files already downloaded and verified
Epoch 1/50, Test Accuracy: 9.89%
Epoch 2/50, Test Accuracy: 10.07%
Epoch 3/50, Test Accuracy: 10.31%
Epoch 4/50, Test Accuracy: 10.40%
Epoch 5/50, Test Accuracy: 10.62%
Epoch 6/50, Test Accuracy: 10.97%
Epoch 7/50, Test Accuracy: 11.18%
Epoch 8/50, Test Accuracy: 11.47%
Epoch 9/50, Test Accuracy: 11.66%
Epoch 10/50, Test Accuracy: 11.69%
Epoch 11/50, Test Accuracy: 11.68%
Epoch 12/50, Test Accuracy: 11.35%
Epoch 13/50, Test Accuracy: 10.27%
Epoch 14/50, Test Accuracy: 9.88%
Epoch 15/50, Test Accuracy: 10.32%
Epoch 16/50, Test Accuracy: 9.99%
Epoch 17/50, Test Accuracy: 10.00%
Epoch 18/50, Test Accuracy: 10.00%
Epoch 19/50, Test Accuracy: 10.00%
Epoch 20/50, Test Accuracy: 10.00%
Epoch 21/50, Test Accuracy: 10.00%
Epoch 22/50, Test Accuracy: 10.00%
Epoch 23/50, Test Accuracy: 10.00%
Epoch 24/50, Test Accuracy: 10.00%
Epoch 25/50, Test Accuracy: 10.00%
Epoch 26/50, Test Accuracy: 10.00%
Epoch 27/50, Test Accuracy