In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import torch
from legoml.utils.summary import summarize_model

[2m2025-09-17T04:44:10.973765Z[0m [[32m[1minfo     [0m] [1mFinished logging setup        [0m


In [3]:
from legoml.utils.seed import set_seed

In [4]:
set_seed(42)  # presumably seeds the torch/python RNGs via legoml — keep before any random op
device = torch.device("mps")  # NOTE(review): assumes the Apple-silicon MPS backend is available

## Data: CIFAR-10 as device-resident tensors

In [45]:
import torchvision

In [46]:
class GPURresidentLoader:
    """Minimal batch iterator over tensors that are kept entirely on one device.

    ``images`` and ``labels`` are moved to ``device`` once at construction. Each
    epoch only indexes into those tensors, so no per-batch host->device copies
    happen.

    Args:
        images: Tensor whose first dimension indexes samples.
        labels: Tensor with one entry per sample.
        batch_size: Maximum samples per batch; the final batch may be smaller.
        shuffle: If True, draw a fresh random permutation each time iteration starts.
        device: Device the stored tensors are moved to.
    """

    def __init__(self, images, labels, batch_size, shuffle=True, device="mps"):
        self.images = images.to(device)
        self.labels = labels.to(device)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_samples = images.shape[0]
        # Ceiling division: a trailing partial batch is kept, not dropped.
        self.num_batches = (self.num_samples + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        # Pick the sample order for this epoch on the same device as the data.
        order_fn = torch.randperm if self.shuffle else torch.arange
        self.indices = order_fn(self.num_samples, device=self.images.device)
        self.current_batch = 0
        return self

    def __next__(self):
        batch_no = self.current_batch
        if not batch_no < self.num_batches:
            raise StopIteration

        lo = batch_no * self.batch_size
        hi = min(lo + self.batch_size, self.num_samples)
        picked = self.indices[lo:hi]
        out = (self.images[picked], self.labels[picked])

        self.current_batch = batch_no + 1
        return out

    def __len__(self):
        return self.num_batches

In [17]:
data_path = "../../raw_data/"
batch_size = 128
num_workers = 2  # NOTE(review): not used by GPURresidentLoader (it has no worker processes)

# Download (if missing) and open both CIFAR-10 splits.
train_ds = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True)
test_ds = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True)


def _to_nchw_tensors(ds):
    """Turn a torchvision CIFAR dataset's ``data`` array and ``targets`` list into tensors.

    ``ds.data`` is assumed to be (N, H, W, C); the permute moves channels first.
    """
    images = torch.tensor(ds.data).permute(0, 3, 1, 2)
    labels = torch.tensor(ds.targets)
    return images, labels


train_images, train_labels = _to_nchw_tensors(train_ds)
test_images, test_labels = _to_nchw_tensors(test_ds)

train_loader = GPURresidentLoader(train_images, train_labels, batch_size, shuffle=True, device=device)
test_loader = GPURresidentLoader(test_images, test_labels, batch_size, shuffle=False, device=device)

In [19]:
train_loader.images.device

device(type='mps', index=0)

## Airbench-style data loading and on-device augmentation

In [142]:
import os
from math import ceil

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T


#############################################
#                DataLoader                 #
#############################################

# Per-channel (R, G, B) mean and std passed to T.Normalize in CifarLoader.
# NOTE(review): presumably the CIFAR-10 training-set statistics — confirm the source.
CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465))
CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616))

def batch_flip_lr(inputs):
    """Mirror each sample along its last (width) axis independently with probability 0.5.

    Args:
        inputs: Batched tensor; the per-sample mask is shaped (N, 1, 1, 1), so a
            4-D (N, C, H, W) batch is expected.

    Returns:
        Tensor with the same shape as ``inputs``; roughly half the samples are flipped.
    """
    n = len(inputs)
    coin = torch.rand(n, device=inputs.device)
    do_flip = coin.lt(0.5).reshape(n, 1, 1, 1)
    mirrored = inputs.flip(-1)
    return torch.where(do_flip, mirrored, inputs)

def batch_crop(images, crop_size):
    """Randomly crop every image in a batch to ``crop_size`` x ``crop_size``.

    Each sample gets an independent (dy, dx) offset drawn uniformly from ``[-r, r]``,
    where ``r = (W - crop_size) // 2``. Offsets are measured relative to a centred
    crop. Applied to pre-padded images, this implements random translation.

    Args:
        images: Batched tensor (N, C, H, W). Any channel count C is supported.
            NOTE: ``r`` is derived from the width only, so H is assumed to equal W.
        crop_size: Side length of the output crop.

    Returns:
        Tensor of shape (N, C, crop_size, crop_size) with the dtype/device of ``images``.
    """
    r = (images.size(-1) - crop_size) // 2
    n, channels = images.shape[0], images.shape[1]
    shifts = torch.randint(-r, r + 1, size=(n, 2), device=images.device)
    # Size the buffers from the input's channel count instead of assuming RGB.
    images_out = torch.empty((n, channels, crop_size, crop_size), device=images.device, dtype=images.dtype)
    # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2.
    if r <= 2:
        # One masked copy per (dy, dx) pair: (2r+1)^2 copies.
        for sy in range(-r, r + 1):
            for sx in range(-r, r + 1):
                mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx)
                images_out[mask] = images[mask, :, r + sy:r + sy + crop_size, r + sx:r + sx + crop_size]
    else:
        # Crop rows first into a full-width buffer, then columns: 2*(2r+1) copies.
        images_tmp = torch.empty((n, channels, crop_size, crop_size + 2 * r), device=images.device, dtype=images.dtype)
        for s in range(-r, r + 1):
            mask = (shifts[:, 0] == s)
            images_tmp[mask] = images[mask, :, r + s:r + s + crop_size, :]
        for s in range(-r, r + 1):
            mask = (shifts[:, 1] == s)
            images_out[mask] = images_tmp[mask, :, :, r + s:r + s + crop_size]
    return images_out

def make_random_square_masks(inputs, size):
    """Build one randomly placed ``size`` x ``size`` square mask per sample.

    Args:
        inputs: Batched tensor (N, C, H, W). Only its shape and device are used.
        size: Side length of the square; must satisfy ``size <= H`` and ``size <= W``.

    Returns:
        Boolean tensor of shape (N, 1, H, W) that is True inside each sample's square.
        It broadcasts across the channel dimension.
    """
    n, _, h, w = inputs.shape

    # Random top-left corner per sample, chosen so the square fits inside the image.
    corner_y = torch.randint(0, h - size + 1, size=(n,), device=inputs.device)
    corner_x = torch.randint(0, w - size + 1, size=(n,), device=inputs.device)

    # Signed offset of every row / column from that sample's top-left corner.
    dy = torch.arange(h, device=inputs.device).view(1, 1, h, 1) - corner_y.view(-1, 1, 1, 1)
    dx = torch.arange(w, device=inputs.device).view(1, 1, 1, w) - corner_x.view(-1, 1, 1, 1)

    mask_y = (dy >= 0) & (dy < size)
    mask_x = (dx >= 0) & (dx < size)
    return mask_y & mask_x

def batch_cutout(inputs, size):
    """Zero one random ``size`` x ``size`` square in every sample of the batch.

    The square comes from ``make_random_square_masks``. That mask has a singleton
    channel axis, so the same location is cleared in all channels.
    """
    return inputs.masked_fill(make_random_square_masks(inputs, size), 0)

class CifarLoader:
    """Device-resident CIFAR-10 loader with batched on-device augmentation (airbench style).

    On first use the split is downloaded through torchvision and cached as
    ``train.pt`` / ``test.pt`` under ``path``. The cache is then loaded directly onto
    ``device``. Images are kept as fp16 NCHW tensors in channels_last memory format,
    scaled by 1/255 (so in [0, 1] assuming uint8 source data).

    Args:
        path: Directory for the torchvision download and the ``.pt`` cache.
        train: Load the training split (True) or the test split (False).
        batch_size: Samples per batch.
        aug: Optional dict with any of the keys ``'flip'`` (bool), ``'translate'``
            (reflect-pad width in pixels) and ``'cutout'`` (square side in pixels).
        drop_last: Drop the final partial batch; defaults to ``train``.
        shuffle: Shuffle every epoch; defaults to ``train``.
        altflip: With ``'flip'`` enabled, flip the whole dataset on every odd epoch
            instead of flipping each image at random.
        device: Device the cached tensors are loaded onto (default ``'mps'``, the
            previously hard-coded value).
    """

    def __init__(self, path, train=True, batch_size=500, aug=None, drop_last=None, shuffle=None, altflip=False,
                 device='mps'):

        data_path = os.path.join(path, 'train.pt' if train else 'test.pt')
        if not os.path.exists(data_path):
            # First use: download via torchvision and cache the raw arrays as one .pt file.
            dset = torchvision.datasets.CIFAR10(path, download=True, train=train)
            images = torch.tensor(dset.data)
            labels = torch.tensor(dset.targets)
            torch.save({'images': images, 'labels': labels, 'classes': dset.classes}, data_path)
        data = torch.load(data_path, map_location=device)

        # Must be assigned before images/labels, because __setattr__ reads it.
        self.epoch = 0
        self.images, self.labels, self.classes = data['images'], data['labels'], data['classes']
        # It's faster to load+process uint8 data than to load preprocessed fp16 data
        self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last)

        self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD)
        self.proc_images = {}  # Saved results of image processing to be done on the first epoch

        self.aug = aug or {}
        for k in self.aug.keys():
            assert k in ['flip', 'translate', 'cutout'], 'Unrecognized key: %s' % k

        self.batch_size = batch_size
        self.drop_last = train if drop_last is None else drop_last
        self.shuffle = train if shuffle is None else shuffle
        self.altflip = altflip

    def __len__(self):
        """Number of batches per epoch, honouring ``drop_last``."""
        return len(self.images)//self.batch_size if self.drop_last else ceil(len(self.images)/self.batch_size)

    def __setattr__(self, k, v):
        # The epoch-0 caches in proc_images are derived from images; swapping the data
        # after the first epoch would silently desynchronise them.
        if k in ('images', 'labels'):
            assert self.epoch == 0, 'Changing images or labels is only supported before iteration.'
        super().__setattr__(k, v)

    def __iter__(self):
        """Yield ``(images, labels)`` batches for one epoch, applying the configured augmentation.

        This is a generator: the body, including the ``epoch`` increment, runs on the
        first ``next()`` call. The first epoch also builds and caches the normalised,
        pre-flipped and pre-padded image tensors in ``proc_images``.
        """
        if self.epoch == 0:
            images = self.proc_images['norm'] = self.normalize(self.images)
            # Pre-flip images in order to do every-other epoch flipping scheme
            if self.aug.get('flip', False):
                images = self.proc_images['flip'] = batch_flip_lr(images)
            # Pre-pad images to save time when doing random translation
            pad = self.aug.get('translate', 0)
            if pad > 0:
                self.proc_images['pad'] = F.pad(images, (pad,)*4, 'reflect')

        if self.aug.get('translate', 0) > 0:
            images = batch_crop(self.proc_images['pad'], self.images.shape[-2])
        elif self.aug.get('flip', False):
            images = self.proc_images['flip']
        else:
            images = self.proc_images['norm']
        # Flip all images together every other epoch. This increases diversity relative to random flipping
        if self.aug.get('flip', False):
            if self.altflip:
                if self.epoch % 2 == 1:
                    images = images.flip(-1)
            else:
                images = batch_flip_lr(images)
        if self.aug.get('cutout', 0) > 0:
            images = batch_cutout(images, self.aug['cutout'])

        self.epoch += 1

        indices = (torch.randperm if self.shuffle else torch.arange)(len(images), device=images.device)
        for i in range(len(self)):
            idxs = indices[i*self.batch_size:(i+1)*self.batch_size]
            yield (images[idxs], self.labels[idxs])

In [163]:
# Rebinds train_loader/test_loader (replacing the GPURresidentLoader instances above).
# Augmentation is applied to the training split only: random flip, translation via
# 4-px reflect padding + random crop, and 16-px cutout.
train_loader = CifarLoader('../../raw_data/', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=128)
test_loader = CifarLoader('../../raw_data/', train=False, batch_size=128)

## Train

In [168]:
train_loader.images[0].dtype

torch.float16

In [170]:
# Cast the network to fp16 + channels_last to match CifarLoader's fp16, channels_last images.
# NOTE(review): per the PyTorch docs, nn.Module.half()/.to() modify the module in place,
# so `net` is the same object as `model`; later cells using `model` also see fp16 weights.
# NOTE(review): `model` is created in a later cell (In [156]); this only works out of order.
net = model.half().to("mps")
net = net.to(memory_format=torch.channels_last)

In [172]:
# NOTE(review): `config` and `build_optim_and_sched` are defined in later cells (In [152]/[154]);
# this rebuilds optimizer/scheduler for `net` and only works when cells run out of order.
optim, sched = build_optim_and_sched(config, net, train_loader)

In [176]:
tuple(int(i * 255) for i in (0.49139968, 0.48215827, 0.44653124))

(125, 122, 113)

In [151]:
from legoml.core.step_output import StepOutput

In [152]:
config = Config(train_augmentation=False, max_epochs=20, data_root="../../raw_data/")

In [153]:
config

Config(train_bs=64, eval_bs=32, train_augmentation=False, max_epochs=20, train_log_interval=100, eval_log_interval=100, data_root='../../raw_data/')

In [154]:
def build_optim_and_sched(
    config: Config,
    model: torch.nn.Module,
    train_dl: DataLoader,
    *,
    max_lr: float = 1e-2,
    weight_decay: float = 0.0005,
) -> tuple[torch.optim.Optimizer, lrs.LRScheduler]:
    """Create an AdamW optimizer and a one-cycle learning-rate schedule for ``model``.

    Args:
        config: Experiment config; only ``config.max_epochs`` is read.
        model: Module whose parameters are optimised.
        train_dl: Training loader. Only ``len(train_dl)`` (batches per epoch) is
            used, so any sized iterable (e.g. ``CifarLoader``) works.
        max_lr: Peak learning rate of the one-cycle schedule (keyword-only).
        weight_decay: AdamW weight decay (keyword-only).

    Returns:
        ``(optimizer, scheduler)``. The schedule spans
        ``config.max_epochs * len(train_dl)`` steps, i.e. it is meant to be
        stepped once per batch.
    """
    optimizer = torch.optim.AdamW(
        model.parameters(),
        # NOTE(review): OneCycleLR sets its own starting lr (max_lr / div_factor) on
        # construction, so this value is likely overridden — confirm before tuning it.
        lr=1e-6,
        weight_decay=weight_decay,
    )
    scheduler = lrs.OneCycleLR(
        optimizer,
        epochs=config.max_epochs,
        steps_per_epoch=len(train_dl),
        max_lr=max_lr,
    )
    return optimizer, scheduler

In [155]:
from experiments.image_clf.models import ConvMixer_w256_d8_p2_k7

In [156]:
model = ConvMixer_w256_d8_p2_k7(c_in=3)

In [157]:
_ = model.to(device)

In [158]:
optim, sched = build_optim_and_sched(config, model, train_loader)

In [159]:
from experiments.step_utils import forward_and_compute_loss, backward_and_step, log_step

In [160]:
def train_step(engine, batch, context):
    """Run one optimisation step on ``batch`` and package the results.

    The model, loss function, optimizer, device and optional grad scaler all come
    from ``context``. Mixed precision is requested exactly when
    ``context.scaler`` is set. The LR scheduler is NOT stepped here; that is left
    to engine event handlers.

    Returns:
        StepOutput holding the loss tensor plus detached CPU copies of the
        predictions and targets.
    """
    cfg: Config = context.config
    net = context.model
    criterion = context.loss_fn
    opt = context.optimizer
    dev = context.device
    amp_enabled = context.scaler is not None

    assert opt is not None, "Optimizer is not set"

    net.train()
    opt.zero_grad(set_to_none=True)

    x, y = batch
    # NOTE(review): PRECISION is defined outside this chunk (presumably a dtype) — confirm.
    logits, loss = forward_and_compute_loss(net, criterion, x.to(PRECISION), y, dev, amp_enabled)

    backward_and_step(loss, opt, scaler=context.scaler)
    log_step(engine, "train", cfg.train_log_interval)

    return StepOutput(
        loss=loss,
        predictions=logits.detach().cpu(),
        targets=y.detach().cpu(),
    )



In [161]:
@torch.inference_mode()
def eval_step(engine, batch, context):
    """Evaluate the model on one ``batch`` with autograd disabled.

    Puts the model in eval mode and calls ``forward_and_compute_loss`` with its
    default AMP setting. It then passes ``config.eval_log_interval`` to ``log_step``.

    Returns:
        StepOutput with the loss plus detached CPU copies of predictions and targets.
    """
    cfg: Config = context.config
    net, criterion, dev = context.model, context.loss_fn, context.device

    net.eval()
    x, y = batch
    # NOTE(review): PRECISION is defined outside this chunk (presumably a dtype) — confirm.
    logits, loss = forward_and_compute_loss(net, criterion, x.to(PRECISION), y, dev)

    log_step(engine, "eval", cfg.eval_log_interval)

    return StepOutput(loss=loss, predictions=logits.detach().cpu(), targets=y.detach().cpu())

In [162]:
with run(base_dir=Path("runs").joinpath("train_img_clf_cifar10")) as sess:
    # Record the experiment config and architecture up front so they are saved even
    # if training is interrupted (e.g. KeyboardInterrupt) before the loop finishes.
    sess.log_params({"exp_config": asdict(config)})
    sess.log_text("model", str(model))

    train_context = Context(
        config=config,
        model=model,
        loss_fn=torch.nn.CrossEntropyLoss(),
        optimizer=optim,
        scheduler=sched,
        device=device,
        scaler=torch.GradScaler(device=device.type),  # slow on M1 air
    )
    trainer = Engine(train_step, train_context)

    eval_context = Context(
        config=config,
        model=model,
        loss_fn=torch.nn.CrossEntropyLoss(),
        device=device,
    )
    evaluator = Engine(
        eval_step,
        eval_context,
        callbacks=[
            MetricsCallback(metrics=[MultiClassAccuracy("eval_acc")]),
        ],
    )

    trainer.callbacks.extend(
        [
            EvalOnEpochEndCallback(evaluator, test_loader, 1),
            MetricsCallback(metrics=[MultiClassAccuracy("train_acc")]),
            CheckpointCallback(
                dirpath=sess.get_artifact_dir().joinpath("checkpoints"),
                save_every_n_epochs=9999,
                save_on_engine_end=True,
                best_fn=lambda: evaluator.state.metrics["eval_acc"],
            ),
        ]
    )

    # NOTE(review): images[0] is a single unbatched fp16 sample; the captured summary shows
    # "?" for In/Out, which suggests the probe forward pass did not run — confirm the
    # shape/dtype summarize_model expects.
    summarize_model(model, train_loader.images[0], depth=2)
    model.to(device)
    try:
        trainer.loop(train_loader, max_epochs=config.max_epochs)
    finally:
        # Persist engine state whether training completed or was interrupted.
        sess.log_params({"trainer": trainer.to_dict()})
        sess.log_params({"evaluator": evaluator.to_dict()})

[2m2025-09-17T07:34:36.797260Z[0m [[32m[1minfo     [0m] [1mStarted experiment session    [0m [36mrun_dir[0m=[35mruns/train_img_clf_cifar10/run_20250917_130436[0m [36mrun_name[0m=[35mrun_20250917_130436[0m
-------------------------------------------------
Name       | Type            | Params  | In | Out
-------------------------------------------------
stem       | Sequential      | 3.8 K   | ?  | ?  
stem.0     | ConvActNorm     | 3.8 K   | ?  | ?  
backbone   | Sequential      | 587.8 K | ?  | ?  
backbone.0 | ConvMixerBlock  | 73.5 K  | ?  | ?  
backbone.1 | ConvMixerBlock  | 73.5 K  | ?  | ?  
backbone.2 | ConvMixerBlock  | 73.5 K  | ?  | ?  
backbone.3 | ConvMixerBlock  | 73.5 K  | ?  | ?  
backbone.4 | ConvMixerBlock  | 73.5 K  | ?  | ?  
backbone.5 | ConvMixerBlock  | 73.5 K  | ?  | ?  
backbone.6 | ConvMixerBlock  | 73.5 K  | ?  | ?  
backbone.7 | ConvMixerBlock  | 73.5 K  | ?  | ?  
head       | Sequential      | 2.6 K   | ?  | ?  
head.0     | GlobalAvgPool2d |

KeyboardInterrupt: 