In [1]:
%env CUDA_VISIBLE_DEVICES=4

env: CUDA_VISIBLE_DEVICES=4


In [2]:
from torchvision.models import resnet18
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Standard normalization for CIFAR-10
NORM_MEAN = [0.4914, 0.4822, 0.4465]
NORM_STD = [0.2023, 0.1994, 0.2010]

# Define transforms: Resize images to 224x224 to fit standard ResNet input
# ResNet was designed for 224x224 ImageNet images, so we resize the small CIFAR images.
transform = transforms.Compose([
    transforms.Resize(224), 
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# Load and Download CIFAR-10 (downloads automatically if not found)
BATCH_SIZE = 128

# Training Set
train_dataset = datasets.CIFAR10(
    root='./data', 
    train=True, 
    download=True, 
    transform=transform
)
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=2
)


# Test/Validation Set
test_dataset = datasets.CIFAR10(
    root='./data', 
    train=False, 
    download=True, 
    transform=transform
)
test_loader = DataLoader(
    test_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=2
)

In [6]:
device = 0
model = resnet18().to(device)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    # Set the model to training mode
    model.train()
    
    # Iterate over the data loader
    for inputs, targets in tqdm(dataloader):
        # Move data to the specified device (e.g., CUDA or CPU)
        inputs, targets = inputs.to(device), targets.to(device)
        
        # 1. Zero the gradients
        optimizer.zero_grad()
        
        # 2. Forward pass: compute predicted outputs
        outputs = model(inputs)
        
        # 3. Compute loss
        loss = criterion(outputs, targets)
        print(loss)
        
        # 4. Backward pass: compute gradient of the loss w.r.t model parameters
        loss.backward()
        
        # 5. Update weights
        optimizer.step()
        
    return loss.item() # Return the last batch loss

# --- Main loop setup ---
# Example setup (you must replace these with actual ImageNet components)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device) 

# Number of epochs to train for
NUM_EPOCHS = 1

criterion = nn.CrossEntropyLoss()

# optimizer = optim.AdamW(
#     model.parameters(), 
#     lr=1e-3,
#     betas=(0.9, 0.999), 
#     eps=1e-8,
#     weight_decay=1e-4 
# )
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)
for epoch in range(NUM_EPOCHS):
    avg_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {avg_loss:.4f}")

### Evolutionary updates

In [9]:
# train for a bit on subsampled dataset
import numpy as np
BATCH_SIZE = 128

indices = np.random.choice(10_000, 10_000, replace=False)
split1_batches = 10
split1_dl = DataLoader(
    [train_dataset[i] for i in indices[:BATCH_SIZE * split1_batches]],
    batch_size=BATCH_SIZE,
    num_workers=2,
)
split2_dl = DataLoader(
    [train_dataset[i] for i in indices[BATCH_SIZE * split1_batches:]],
    batch_size=BATCH_SIZE,
    num_workers=2,
)

In [10]:
# validation set
val_dataset = datasets.CIFAR10(
    root='./data', 
    train=False, 
    download=True, 
    transform=transform
)
val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=2
)

In [185]:
model = resnet18().to(device=device)  # reinitialize
_ = model.eval()
model.load_state_dict(torch.load('tmp.pt'))

<All keys matched successfully>

In [71]:
torch.save(model.state_dict(), 'tmp.pt') 

In [18]:
# train for a bit on subsampled dataset with standard grad descent
optimizer = optim.AdamW(
    model.parameters(), 
    lr=1e-3,
    betas=(0.9, 0.999), 
    eps=1e-8,
    weight_decay=1e-4 
)
train_one_epoch(model, split1_dl, criterion, optimizer, device)

  0%|          | 0/10 [00:00<?, ?it/s]

tensor(6.8719, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(5.5283, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(4.3273, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.5862, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.9705, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.7260, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.2442, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.1571, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.1348, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.0373, device='cuda:0', grad_fn=<NllLossBackward0>)


2.0372564792633057

In [159]:
from dataclasses import dataclass, field
import math
from typing import Iterable


@dataclass
class EvMutation:
    param_seeds: list[int] = field(default_factory=list)
    is_identity: bool = False
    reward: float | None = None
    

class EvOptimizer:

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        step_size=1e-5,
        lr=1e-3,
        n_mutations=64,
        select_max=False,
        allow_skip_mutation=True,
        persist_parent=True,
    ):
        self.params = list(params)
        self.mutations: list[EvMutation] = []
        self.active_mutation: EvMutation | None = None
        self.n_mutations = n_mutations
        self.step_size = step_size
        self.lr = lr
        self.select_max = select_max
        self.allow_skip_mutation = allow_skip_mutation
        self.persist_parent = persist_parent
        self.parent_params = None
        if persist_parent:
            # create separate parameter list with shared data
            self.parent_params = [p.clone() for p in self.params]
            for p, p_parent in zip(self.params, self.parent_params):
                p_parent.data = p.data

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        if (
            self.active_mutation is not None and
            self.active_mutation.reward is None
        ):
            raise RuntimeError("exit before collecting reward for active mutation")
        
        # aggregate current mutation set (in case n_mutations
        # is not even with number of loop iterations)
        if self.mutation_index > -1:
            self.aggregate_mutations()

    def _param_delta_iter(self, param_seeds: list[int] | None = None):
        if param_seeds is None:
            param_seeds = [None] * len(self.params)
        elif len(param_seeds) != len(self.params):
            raise RuntimeError("mismatch between number of params and seeds")
        for p, seed in zip(self.params, param_seeds):
            # use cpu random number generator for stability
            if seed is None:
                seed = torch.seed()
            else:
                torch.manual_seed(seed)
            
            yield torch.randn(p.shape, dtype=p.dtype).to(p.device) * self.step_size, seed

    @property
    def mutation_index(self):
        return len(self.mutations) - 1

    @torch.no_grad()
    def revert_mutation(self):
        if not self.active_mutation.is_identity:
            if self.persist_parent:
                # reset data ptr to current parent (non-mutated) tensor
                for p, p_parent in zip(self.params, self.parent_params):
                    p.data = p_parent.data
            else:
                # regenerate deltas from param seeds and subtract to get non-mutated tensor
                for p, (delta_p, _) in zip(self.params, self._param_delta_iter(self.active_mutation.param_seeds)):
                    p -= delta_p
        self.active_mutation = None

    @torch.no_grad()
    def mutate(self):
        if self.allow_skip_mutation and self.mutation_index == -1:
            # save mutation to reference current unperturbed weights
            new_mutation = EvMutation(is_identity=True)
        else:
            # make new mutation
            param_seeds = []
            for p, (delta_p, seed) in zip(self.params, self._param_delta_iter()):
                if self.persist_parent:
                    # can't use inplace if we have shared data ptr with current parent
                    # instead we create a new tensor from applied delta and update ptr
                    # of mutated weight to point to it
                    p.data = p + delta_p
                else:
                    p += delta_p
                param_seeds.append(seed)
            new_mutation = EvMutation(param_seeds)

        self.mutations.append(new_mutation)
        self.active_mutation = new_mutation

    def reward_step(self, reward: float):
        self.active_mutation.reward = reward
        self.revert_mutation()

        if len(self.mutations) % self.n_mutations == 0:
            self.aggregate_mutations()

    @torch.no_grad()
    def aggregate_mutations(self):
        if self.active_mutation is not None:
            raise RuntimeError("cannot aggregate while mutation is still active")

        # print("aggregating mutations...", end="")

        if self.select_max:
            # get maximum reward candidate
            max_r = float('-inf')
            max_m = None
            for m in self.mutations:
                if m.reward > max_r:
                    max_r = m.reward
                    max_m = m
            if not max_m.is_identity:
                for p, (p_delta, seed) in zip(self.params, self._param_delta_iter(max_m.param_seeds)):
                    # if persist_parent is True, p and p_parent should share data
                    # so we need only update one
                    p += p_delta

        else:       
            # weighted avg mutations based on their reward z-score
            # all mutations should have reward set by now
            n = len(self.mutations)
            mean_reward = sum(m.reward for m in self.mutations) / n
            var_reward = sum((m.reward - mean_reward) ** 2 for m in self.mutations) / n
            z_scores = [(m.reward - mean_reward) / (var_reward ** 0.5) for m in self.mutations if not m.is_identity]
    
            deltas = [self._param_delta_iter(m.param_seeds) for m in self.mutations if not m.is_identity]
            for p in self.params:
                p_deltas, _ = zip(*[next(d) for d in deltas])
    
                for p_delta, z in zip(p_deltas, z_scores):
                    p += p_delta * self.lr * z / n

        self.mutations = []
        # print("done")

In [186]:
CANDIDATE_SET_SIZE = 16
UPDATES_PER_BATCH = 1
# each candidate is evaluated on same batch during each update, UPDATES_PER_BATCH updates per batch
ITERS_PER_BATCH = CANDIDATE_SET_SIZE * UPDATES_PER_BATCH
LR = 1
STEP_SIZE = 0.01 #0.004 #0.001
SELECT_MAX = False
PERSIST_PARENT = True

In [187]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

@torch.no_grad()
def train_one_epoch_ev(model, dataloader, val_loader, report_val=False):
    model.eval()  # freeze batchnorm
    val_iter = iter(val_loader)

    criterion = nn.CrossEntropyLoss()
    
    sample_param = next(model.parameters())
    dtype = sample_param.dtype
    device = sample_param.device
    
    with EvOptimizer(model.parameters(),
                     step_size=STEP_SIZE,
                     lr=LR,
                     n_mutations=CANDIDATE_SET_SIZE,
                     select_max=SELECT_MAX,
                     persist_parent=PERSIST_PARENT) as optimizer:
        avg_loss = 0
        min_loss = 1000
        step_count = 0
        pbar = tqdm(total=len(dataloader) * ITERS_PER_BATCH)
        for i, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device=device, dtype=dtype), targets.to(device=device)
            print(f"==================== batch {i} ====================")
            print("    step    |  avg loss  |  min loss", end="")
            if report_val:
                print("  |  val loss")

            for j in range(ITERS_PER_BATCH):
                optimizer.mutate()
                
                outputs = model(inputs)
                
                loss = criterion(outputs, targets).item()
                avg_loss += loss
                min_loss = loss if loss < min_loss else min_loss

                # if final mutation, mutation_index will reset
                optimizer.reward_step(-loss)

                
                # report aggregate training loss across all mutations
                # (this will be higher than loss for aggregated model)
                if optimizer.mutation_index == -1:
                    avg_loss /= CANDIDATE_SET_SIZE
                    step_digits = len(str(step_count))
                    l_pad = (12 - step_digits) // 2
                    r_pad = 12 - l_pad - step_digits
                    print(f"{' ' * r_pad}{step_count}{' ' * l_pad}|  {avg_loss:1.6f}  |  {min_loss:1.6f}", end="")

                    if report_val:
                        val_inputs, val_targets = next(val_iter)
                        val_inputs, val_targets = val_inputs.to(device=device, dtype=dtype), val_targets.to(device=device)
                        val_outputs = model(val_inputs)
                        val_loss = criterion(val_outputs, val_targets).item()
                        print(f"  |  {val_loss:1.6f}")

                    avg_loss = 0
                    min_loss = 1000
                    step_count += 1

                pbar.update(1)
        
    return loss

train_one_epoch_ev(model, split2_dl, val_loader, report_val=True)

  0%|          | 0/1104 [00:00<?, ?it/s]

    step    |  avg loss  |  min loss  |  val loss
aggregating mutations...done
      0     |  8.405111  |  7.371973  |  7.261257
    step    |  avg loss  |  min loss  |  val loss
aggregating mutations...done
      1     |  8.059920  |  7.167271  |  7.021247
    step    |  avg loss  |  min loss  |  val loss
aggregating mutations...done
      2     |  8.150660  |  7.129123  |  7.187538
    step    |  avg loss  |  min loss  |  val loss
aggregating mutations...done
      3     |  8.102978  |  7.098753  |  6.940623
    step    |  avg loss  |  min loss  |  val loss
aggregating mutations...done
      4     |  8.153156  |  7.019162  |  6.890968
    step    |  avg loss  |  min loss  |  val loss
aggregating mutations...done
      5     |  7.833194  |  6.940998  |  6.829600
    step    |  avg loss  |  min loss  |  val loss
aggregating mutations...done
      6     |  7.271399  |  6.614646  |  6.737214
    step    |  avg loss  |  min loss  |  val loss
aggregating mutations...done
      7     |  7.2

KeyboardInterrupt: 

### Test equivalence of reverted mutation

In [261]:
x = model.conv1.weight.clone()
with torch.no_grad():
    seed = torch.seed()
    z = model.conv1.weight.to(torch.float64)
    z += torch.randn(model.conv1.weight.shape).to(device=0, dtype=torch.float64)
    z = z.to(torch.float)
    torch.manual_seed(seed)
    z = z.to(torch.float64)
    z -= torch.randn(model.conv1.weight.shape).to(device=0, dtype=torch.float64)
(x == z.to(torch.float)).all()

tensor(False, device='cuda:0')

In [260]:
x = model.conv1.weight.clone()
with torch.no_grad():
    seed = torch.seed()
    z = model.conv1.weight.to(torch.float64)
    z += torch.randn(model.conv1.weight.shape).to(device=0, dtype=torch.float64)
    torch.manual_seed(seed)
    z -= torch.randn(model.conv1.weight.shape).to(device=0, dtype=torch.float64)
(x == z.to(torch.float)).all()

tensor(True, device='cuda:0')

In [259]:
x = model.conv1.weight.clone()
with torch.no_grad():
    r = torch.randn(model.conv1.weight.shape).to(device=0, dtype=torch.float64)
    z = model.conv1.weight.to(torch.float64)
    z += r
    z -= r
(x == z.to(torch.float)).all()

tensor(True, device='cuda:0')

Without upcast

In [349]:
x.abs().min(), r.abs().max()

(tensor(0.0005, device='cuda:0', dtype=torch.float64, grad_fn=<MinBackward1>),
 tensor(3.7207, device='cuda:0'))

In [350]:
x = model.conv1.weight.clone()
with torch.no_grad():
    r = torch.randn(model.conv1.weight.shape).to(device=0)
    z = model.conv1.weight
    z += r
    z -= r
    # z = -(-z + r)
(x == z).all(), (x - z).abs().max()

(tensor(False, device='cuda:0'),
 tensor(8.8818e-16, device='cuda:0', dtype=torch.float64,
        grad_fn=<MaxBackward1>))

In [345]:
x = model.conv1.weight.clone()
with torch.no_grad():
    r = torch.randn(model.conv1.weight.shape).to(device=0)
    z = model.conv1.weight
    z_ = z + r - r
    # z_ = z + r
    # z_ -= r
    diff = z_ - z
    print(diff.abs().max())
    neg_r = - r - diff
    z__ = z + r + neg_r

(z - z_).abs().max(), (z - z__).abs().max(), diff.abs().max()

tensor(8.8818e-16, device='cuda:0', dtype=torch.float64)


(tensor(8.8818e-16, device='cuda:0', dtype=torch.float64,
        grad_fn=<MaxBackward1>),
 tensor(2.2204e-16, device='cuda:0', dtype=torch.float64,
        grad_fn=<MaxBackward1>),
 tensor(8.8818e-16, device='cuda:0', dtype=torch.float64))