<a href="https://colab.research.google.com/github/Demon-Sheriff/Linear-Alg_ML_fs/blob/master/sumo_mem_profiling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# import torch

# def frag_report():
#     stats = torch.cuda.memory_stats()
#     print("Total allocated: ", stats["allocated_bytes.all.current"] / 1e9, "GB")
#     print("Total reserved : ", stats["reserved_bytes.all.current"] / 1e9, "GB")
#     print("Active bytes   : ", stats["active_bytes.all.current"] / 1e9, "GB")
#     print("Inactive split : ", stats["inactive_split_bytes.all.current"] / 1e9, "GB")
#     print("Largest block  : ", stats["largest_block_bytes.all.current"] / 1e9, "GB")

# frag_report()

In [2]:
# %env PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync
# %env PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [3]:
import torch

In [4]:
torch.cuda.memory._record_memory_history()

In [5]:
@torch.compile
def orthogonalization_svd(M):
  U,sig,V = torch.linalg.svd(M.float(), full_matrices=False) # TODO: figuring out full_matrices=True case
  return U @ V

class SUMO(torch.optim.Optimizer):
  def __init__(self, params, lr=1e-4, scale_factor=2.0, weight_decay=0.01,
                rank=8, K=50, step_clip_ratio=1.1, momentum_coeff=0.9):
      defaults = dict(lr=lr, scale_factor=scale_factor, weight_decay=weight_decay,
                      step_clip_ratio=step_clip_ratio, momentum_coeff=momentum_coeff, rank=rank, K=K)
      super(SUMO, self).__init__(params, defaults)

  def step(self):
    for group in self.param_groups:
      lr = group['lr']
      scale_factor = group['scale_factor']
      weight_decay = group['weight_decay']
      rank = group['rank']
      K = group['K']
      step_clip_ratio = group['step_clip_ratio']
      momentum_coeff = group['momentum_coeff']

      # mul_, add_ and other inplace ops would only be used when manipulating internal gradient params.
      for p in group['params']:
        g = p.grad # the gradient matrix for w_t-1
        if g is None: continue
        state = self.state[p] # current state of the optimizer

        if g.dim() < 2: # logic for handling bias/layernorm like params (< 2 dim)
           # simple sgd + weight decay fallback
           p.data.mul_(1 - lr * weight_decay)
           p.data.add_(p.grad, alpha=-lr)
           continue
        if 'step' not in state.keys(): state['step'] = 0
        _, _, m, n = g.size()
        # m >= n assumption without loss of generality. [if n > m we transpose]
        if (rev:=(m<n)) : g = g.transpose(-2, -1); _, _, m, n = g.size()
        if state['step'] % K == 0:
          if rank >= (k:=min(m, n)): U_float, sig, V = torch.linalg.svd(g.float(), full_matrices=False) # full rank case : U shape : (m, n)
          else: U_float, sig, V = torch.svd_lowrank(g.float(), q=rank) # U (m, r)

          if 'sub_moment_buffer' not in state.keys():
            # first step so no rotations. [init the bufs]
            state['sub_moment_buffer'] = U_float
            state['moment_buffer'] = torch.zeros(U_float.size(-1), n, device=g.device, dtype=torch.float)
            state['orthogonal_buf'] = torch.zeros(U_float.size(-1), n, device=g.device, dtype=torch.float)
            moment = state['moment_buffer']
          else:
            # apply rotations if not first step
            R = U_float.transpose(-2, -1) @ state['sub_moment_buffer'] # Q_t @ Q_t-1
            moment = R @ state['moment_buffer'] # use rotated moment for current update
            state['sub_moment_buffer'] = U_float
        else:
          U_float = state['sub_moment_buffer']
          moment = state['moment_buffer']
        U_for_matmul = U_float.to(g.dtype)
        G_hat = U_for_matmul.transpose(-2, -1) @ g
        M_float = momentum_coeff * moment + G_hat.float()
        state['moment_buffer'] = M_float
        O_float = orthogonalization_svd(M_float) # (r, n) if r <= k else (k, n)

        if (prev_o_norm:=torch.norm(state['orthogonal_buf'])) > 1e-8 and ((o_norm:=torch.norm(O_float)) > step_clip_ratio * prev_o_norm):
          O_float = (O_float * (step_clip_ratio * prev_o_norm)) / o_norm # (r, n) if r <= k else (k, n)
        # update orthogonal buffer
        state['orthogonal_buf'] = O_float

        # weight updates
        update_direction = g - (U_for_matmul @ (G_hat - O_float.to(g.dtype))).view_as(g)
        p.data.mul_(1 - lr * weight_decay) # apply weight decay to the weigts
        p.data.add_(scale_factor * (update_direction.transpose(-2, -1)) if rev else scale_factor * update_direction, alpha=-lr)
        state['step'] += 1

In [6]:
stats = torch.cuda.memory_stats()

In [7]:
for k,v in stats.keys():
  print(k)

In [8]:
import os
import wandb
import sys
with open(sys.argv[0]) as f:
    code = f.read()
import uuid
from math import ceil

# import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T

torch.backends.cudnn.benchmark = True

#############################################
#               SUMO optimizer              #
#############################################

@torch.compile
def orthogonalization_svd(M):
  U,sig,V = torch.linalg.svd(M.float(), full_matrices=False) # TODO: figuring out full_matrices=True case
  return U @ V

class SUMO(torch.optim.Optimizer):
  def __init__(self, params, lr=1e-4, scale_factor=2.0, weight_decay=0.01,
                rank=8, K=200, step_clip_ratio=1.1, momentum_coeff=0.9):
      defaults = dict(lr=lr, scale_factor=scale_factor, weight_decay=weight_decay,
                      step_clip_ratio=step_clip_ratio, momentum_coeff=momentum_coeff, rank=rank, K=K)
      super(SUMO, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SUMO, self).__setstate__(state)

  def step(self):
    for group in self.param_groups:
      # step_size = group['lr']
      lr = group['lr']
      scale_factor = group['scale_factor']
      weight_decay = group['weight_decay']
      rank = group['rank']
      K = group['K']
      step_clip_ratio = group['step_clip_ratio']
      momentum_coeff = group['momentum_coeff']

      # mul_, add_ and other inplace ops would only be used when manipulating internal gradient params.
      for p in group['params']:
        g = p.grad # the gradient matrix for w_t-1
        if g is None: continue
        state = self.state[p] # current state of the optimizer

        if g.dim() < 2: # logic for handling bias/layernorm like params
           # simple sgd + weight decay fallback
           p.data.mul_(1 - lr * weight_decay)
           p.data.add_(p.grad, alpha=-lr)
           continue

        if 'step' not in state.keys(): state['step'] = 0
        # print(g.dim())
        _, _, m, n = g.size()
        # m >= n assumption without loss of generality. [if n > m we transpose]
        if (rev:=(m<n)) : g = g.transpose(-2, -1); _, _, m, n = g.size()

        if state['step'] % K == 0:
          if rank >= (k:=min(m, n)): U_float, sig, V = torch.linalg.svd(g.float(), full_matrices=False) # full rank case : U shape : (m, n)
          else: U_float, sig, V = torch.svd_lowrank(g.float(), q=rank) # U (m, r)

          if 'sub_moment_buffer' not in state.keys():
            # first step so no rotations. [init the bufs]
            state['sub_moment_buffer'] = U_float # why init with U?
            state['moment_buffer'] = torch.zeros(U_float.size(-1), n, device=g.device, dtype=torch.float) # why not init identity ? and why use dynamic U shape calc ? wouldn't U shapes remain fixed.
            state['orthogonal_buf'] = torch.zeros(U_float.size(-1), n, device=g.device, dtype=torch.float) # same qn : why torch.eye ?

            moment = state['moment_buffer']
          else:
            # apply rotations if not first step
            R = U_float.transpose(-2, -1) @ state['sub_moment_buffer'] # Q_t @ Q_t-1
            moment = R @ state['moment_buffer'] # use rotated moment for current update
            state['sub_moment_buffer'] = U_float

        else:
          U_float = state['sub_moment_buffer']
          moment = state['moment_buffer']
        # Cast U to g.dtype (e.g., half) for matrix multiplication with g
        U_for_matmul = U_float.to(g.dtype)

        # G_hat will be g.dtype (e.g., half)
        G_hat = U_for_matmul.transpose(-2, -1) @ g

        # Accumulate moment in float to keep moment_buffer float
        # G_hat.float() is crucial here to match dtype of moment
        M_float = momentum_coeff * moment + G_hat.float()
        state['moment_buffer'] = M_float # Update moment_buffer as float

        # O_float is computed from M_float, so O_float is float
        O_float = orthogonalization_svd(M_float) # (r, n) if r <= k else (k, n)

        # prevent zeroing out o_norm initially
        if (prev_o_norm:=torch.norm(state['orthogonal_buf'])) > 1e-8 and ((o_norm:=torch.norm(O_float)) > step_clip_ratio * prev_o_norm):
          O_float = (O_float * (step_clip_ratio * prev_o_norm)) / o_norm # (r, n) if r <= k else (k, n)

        # update orthogonal buffer
        state['orthogonal_buf'] = O_float

        # weight updates
        update_direction = g - (U_for_matmul @ (G_hat - O_float.to(g.dtype))).view_as(g)
        # try:
        #   update_direction = g - (U@(G_hat - O)).view_as(g)
        # except Exception as e:
        #   print(f"G_hat shape: {G_hat.shape}, O shape: {O.shape}, g shape: {g.shape}, U shape: {U.shape}, M shape: {M.shape}")

        p.data.mul_(1 - lr * weight_decay) # apply weight decay to the weigts
        update = scale_factor * update_direction
        p.data.add_(update.transpose(-2, -1) if rev else update, alpha=-lr)
        # update = lr*weight_decay*g + scale_factor*lr*(g - (U@(G_hat - O)).view_as(g))
        # p.data.add_(update.transpose(-2, -1) if rev else update, alpha=-1.0)

        state['step'] += 1

#############################################
#                DataLoader                 #
#############################################

CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465))
CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616))

def batch_flip_lr(inputs):
    flip_mask = (torch.rand(len(inputs), device=inputs.device) < 0.5).view(-1, 1, 1, 1)
    return torch.where(flip_mask, inputs.flip(-1), inputs)

def batch_crop(images, crop_size):
    r = (images.size(-1) - crop_size)//2
    shifts = torch.randint(-r, r+1, size=(len(images), 2), device=images.device)
    images_out = torch.empty((len(images), 3, crop_size, crop_size), device=images.device, dtype=images.dtype)
    # The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2.
    if r <= 2:
        for sy in range(-r, r+1):
            for sx in range(-r, r+1):
                mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx)
                images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size]
    else:
        images_tmp = torch.empty((len(images), 3, crop_size, crop_size+2*r), device=images.device, dtype=images.dtype)
        for s in range(-r, r+1):
            mask = (shifts[:, 0] == s)
            images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :]
        for s in range(-r, r+1):
            mask = (shifts[:, 1] == s)
            images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size]
    return images_out

class CifarLoader:

    def __init__(self, path, train=True, batch_size=500, aug=None):
        data_path = os.path.join(path, 'train.pt' if train else 'test.pt')
        if not os.path.exists(data_path):
            dset = torchvision.datasets.CIFAR10(path, download=True, train=train)
            images = torch.tensor(dset.data)
            labels = torch.tensor(dset.targets)
            torch.save({'images': images, 'labels': labels, 'classes': dset.classes}, data_path)

        data = torch.load(data_path, map_location=torch.device('cuda'))
        self.images, self.labels, self.classes = data['images'], data['labels'], data['classes']
        # It's faster to load+process uint8 data than to load preprocessed fp16 data
        self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last)

        self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD)
        self.proc_images = {} # Saved results of image processing to be done on the first epoch
        self.epoch = 0

        self.aug = aug or {}
        for k in self.aug.keys():
            assert k in ['flip', 'translate'], 'Unrecognized key: %s' % k

        self.batch_size = batch_size
        self.drop_last = train
        self.shuffle = train

    def __len__(self):
        return len(self.images)//self.batch_size if self.drop_last else ceil(len(self.images)/self.batch_size)

    def __iter__(self):

        if self.epoch == 0:
            images = self.proc_images['norm'] = self.normalize(self.images)
            # Pre-flip images in order to do every-other epoch flipping scheme
            if self.aug.get('flip', False):
                images = self.proc_images['flip'] = batch_flip_lr(images)
            # Pre-pad images to save time when doing random translation
            pad = self.aug.get('translate', 0)
            if pad > 0:
                self.proc_images['pad'] = F.pad(images, (pad,)*4, 'reflect')

        if self.aug.get('translate', 0) > 0:
            images = batch_crop(self.proc_images['pad'], self.images.shape[-2])
        elif self.aug.get('flip', False):
            images = self.proc_images['flip']
        else:
            images = self.proc_images['norm']
        # Flip all images together every other epoch. This increases diversity relative to random flipping
        if self.aug.get('flip', False):
            if self.epoch % 2 == 1:
                images = images.flip(-1)

        self.epoch += 1

        indices = (torch.randperm if self.shuffle else torch.arange)(len(images), device=images.device)
        for i in range(len(self)):
            idxs = indices[i*self.batch_size:(i+1)*self.batch_size]
            yield (images[idxs], self.labels[idxs])

#############################################
#            Network Definition             #
#############################################

# note the use of low BatchNorm stats momentum
class BatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, momentum=0.6, eps=1e-12):
        super().__init__(num_features, eps=eps, momentum=1-momentum)
        self.weight.requires_grad = False
        # Note that PyTorch already initializes the weights to one and bias to zero

class Conv(nn.Conv2d):
    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels, kernel_size=3, padding='same', bias=False)

    def reset_parameters(self):
        super().reset_parameters()
        w = self.weight.data
        torch.nn.init.dirac_(w[:w.size(1)])

class ConvGroup(nn.Module):
    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.conv1 = Conv(channels_in,  channels_out)
        self.pool = nn.MaxPool2d(2)
        self.norm1 = BatchNorm(channels_out)
        self.conv2 = Conv(channels_out, channels_out)
        self.norm2 = BatchNorm(channels_out)
        self.activ = nn.GELU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.norm1(x)
        x = self.activ(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activ(x)
        return x

class CifarNet(nn.Module):
    def __init__(self):
        super().__init__()
        widths = dict(block1=64, block2=256, block3=256)
        whiten_kernel_size = 2
        whiten_width = 2 * 3 * whiten_kernel_size**2
        self.whiten = nn.Conv2d(3, whiten_width, whiten_kernel_size, padding=0, bias=True)
        self.whiten.weight.requires_grad = False
        self.layers = nn.Sequential(
            nn.GELU(),
            ConvGroup(whiten_width,     widths['block1']),
            ConvGroup(widths['block1'], widths['block2']),
            ConvGroup(widths['block2'], widths['block3']),
            nn.MaxPool2d(3),
        )
        self.head = nn.Linear(widths['block3'], 10, bias=False)
        for mod in self.modules():
            if isinstance(mod, BatchNorm):
                mod.float()
            else:
                mod.half()

    def reset(self):
        for m in model.modules():
            if type(m) in (nn.Conv2d, Conv, BatchNorm, nn.Linear):
                m.reset_parameters()
        w = self.head.weight.data
        w *= 1 / w.std()

    def init_whiten(self, train_images, eps=5e-4):
        c, (h, w) = train_images.shape[1], self.whiten.weight.shape[2:]
        patches = train_images.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float()
        patches_flat = patches.view(len(patches), -1)
        est_patch_covariance = (patches_flat.T @ patches_flat) / len(patches_flat)
        eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U')
        eigenvectors_scaled = eigenvectors.T.reshape(-1,c,h,w) / torch.sqrt(eigenvalues.view(-1,1,1,1) + eps)
        self.whiten.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled))

    def forward(self, x, whiten_bias_grad=True):
        b = self.whiten.bias
        x = F.conv2d(x, self.whiten.weight, b if whiten_bias_grad else b.detach())
        x = self.layers(x)
        x = x.view(len(x), -1)
        return self.head(x) / x.size(-1)

############################################
#                 Logging                  #
############################################

def print_columns(columns_list, is_head=False, is_final_entry=False):
    print_string = ''
    for col in columns_list:
        print_string += '|  %s  ' % col
    print_string += '|'
    if is_head:
        print('-'*len(print_string))
    print(print_string)
    if is_head or is_final_entry:
        print('-'*len(print_string))

logging_columns_list = ['run   ', 'epoch', 'train_acc', 'val_acc', 'tta_val_acc', 'time_seconds']
def print_training_details(variables, is_final_entry):
    formatted = []
    for col in logging_columns_list:
        var = variables.get(col.strip(), None)
        if type(var) in (int, str):
            res = str(var)
        elif type(var) is float:
            res = '{:0.4f}'.format(var)
        else:
            assert var is None
            res = ''
        formatted.append(res.rjust(len(col)))
    print_columns(formatted, is_final_entry=is_final_entry)

############################################
#               Evaluation                 #
############################################

def infer(model, loader, tta_level=0):

    # Test-time augmentation strategy (for tta_level=2):
    # 1. Flip/mirror the image left-to-right (50% of the time).
    # 2. Translate the image by one pixel either up-and-left or down-and-right (50% of the time,
    #    i.e. both happen 25% of the time).
    #
    # This creates 6 views per image (left/right times the two translations and no-translation),
    # which we evaluate and then weight according to the given probabilities.

    def infer_basic(inputs, net):
        return net(inputs).clone()

    def infer_mirror(inputs, net):
        return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1))

    def infer_mirror_translate(inputs, net):
        logits = infer_mirror(inputs, net)
        pad = 1
        padded_inputs = F.pad(inputs, (pad,)*4, 'reflect')
        inputs_translate_list = [
            padded_inputs[:, :, 0:32, 0:32],
            padded_inputs[:, :, 2:34, 2:34],
        ]
        logits_translate_list = [infer_mirror(inputs_translate, net)
                                 for inputs_translate in inputs_translate_list]
        logits_translate = torch.stack(logits_translate_list).mean(0)
        return 0.5 * logits + 0.5 * logits_translate

    model.eval()
    test_images = loader.normalize(loader.images)
    infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level]
    with torch.no_grad():
        return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)])

def evaluate(model, loader, tta_level=0):
    logits = infer(model, loader, tta_level)
    return (logits.argmax(1) == loader.labels).float().mean().item()

############################################
#                Training                  #
############################################

def main(run, model):

    batch_size = 2000
    bias_lr = 0.053
    head_lr = 0.67
    wd = 2e-6 * batch_size

    test_loader = CifarLoader('cifar10', train=False, batch_size=512)
    train_loader = CifarLoader('cifar10', train=True, batch_size=batch_size, aug=dict(flip=True, translate=2))
    if run == 'warmup':
        # The only purpose of the first run is to warmup the compiled model, so we can use dummy data
        train_loader.labels = torch.randint(0, 10, size=(len(train_loader.labels),), device=train_loader.labels.device)
    total_train_steps = ceil(8 * len(train_loader))
    whiten_bias_train_steps = ceil(3 * len(train_loader))

    # Create optimizers and learning rate schedulers
    filter_params = [p for p in model.parameters() if len(p.shape) == 4 and p.requires_grad]
    norm_biases = [p for n, p in model.named_parameters() if 'norm' in n and p.requires_grad]
    param_configs = [dict(params=[model.whiten.bias], lr=bias_lr, weight_decay=wd/bias_lr),
                     dict(params=norm_biases, lr=bias_lr, weight_decay=wd/bias_lr),
                     dict(params=[model.head.weight], lr=head_lr, weight_decay=wd/head_lr)]
    optimizer1 = torch.optim.SGD(param_configs, momentum=0.85, nesterov=True, fused=True)
    # optimizer2 = Muon(filter_params, lr=0.24, momentum=0.6, nesterov=True)
    optimizer2 = SUMO(filter_params)
    optimizers = [optimizer1, optimizer2]

    for group in optimizer1.param_groups:
      group['initial_lr'] = group['lr']

    for group in optimizer2.param_groups:
      group['initial_lr'] = group['lr']

    # For accurately timing GPU code
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    time_seconds = 0.0
    def start_timer():
        starter.record()
    def stop_timer():
        ender.record()
        torch.cuda.synchronize()
        nonlocal time_seconds
        time_seconds += 1e-3 * starter.elapsed_time(ender)

    model.reset()
    step = 0

    # Initialize the whitening layer using training images
    # wandb.init()
    start_timer()
    train_images = train_loader.normalize(train_loader.images[:5000])
    model.init_whiten(train_images)
    stop_timer()

    for epoch in range(ceil(total_train_steps / len(train_loader))):

        ####################
        #     Training     #
        ####################

        start_timer()
        model.train()
        for inputs, labels in train_loader:
            outputs = model(inputs, whiten_bias_grad=(step < whiten_bias_train_steps))
            F.cross_entropy(outputs, labels, label_smoothing=0.2, reduction='sum').backward()
            for group in optimizer1.param_groups[:1]:
                group["lr"] = group["initial_lr"] * (1 - step / whiten_bias_train_steps)
            for group in optimizer1.param_groups[1:]+optimizer2.param_groups:
                group["lr"] = group["initial_lr"] * (1 - step / total_train_steps)
            for opt in optimizers:
                opt.step()
            model.zero_grad(set_to_none=True)
            step += 1
            if step >= total_train_steps:
                break
        stop_timer()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        ####################
        #    Evaluation    #
        ####################

        # Save the accuracy and loss from the last training batch of the epoch
        train_acc = (outputs.detach().argmax(1) == labels).float().mean().item()
        val_acc = evaluate(model, test_loader, tta_level=0)
        print_training_details(locals(), is_final_entry=False)
        run = None # Only print the run number once


    ####################
    #  TTA Evaluation  #
    ####################
    # wandb.finish()
    start_timer()
    tta_val_acc = evaluate(model, test_loader, tta_level=0)
    stop_timer()
    epoch = 'eval'
    print_training_details(locals(), is_final_entry=True)

    return tta_val_acc

In [9]:
if __name__ == "__main__":

    # We re-use the compiled model between runs to save the non-data-dependent compilation time
    model = CifarNet().cuda().to(memory_format=torch.channels_last)
    model.compile(mode='max-autotune')

    print_columns(logging_columns_list, is_head=True)
    main('warmup', model)
    accs = torch.tensor([main(run, model) for run in range(200)])
    print('Mean: %.4f    Std: %.4f' % (accs.mean(), accs.std()))

    log_dir = os.path.join('logs', str(uuid.uuid4()))
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, 'log.pt')
    torch.save(dict(code=code, accs=accs), log_path)
    print(os.path.abspath(log_path))

---------------------------------------------------------------------------------
|  run     |  epoch  |  train_acc  |  val_acc  |  tta_val_acc  |  time_seconds  |
---------------------------------------------------------------------------------


100%|██████████| 170M/170M [00:04<00:00, 42.5MB/s]
  return torch._C._get_cublas_allow_tf32()
W1210 10:17:51.272000 192 torch/_inductor/utils.py:1558] [0/0] Not enough SMs to use max_autotune_gemm mode


|  warmup  |      0  |     0.0875  |   0.1113  |               |      116.4343  |
|          |      1  |     0.1000  |   0.1019  |               |      126.2897  |
|          |      2  |     0.0925  |   0.0857  |               |      136.3367  |
|          |      3  |     0.1120  |   0.1167  |               |      183.8164  |
|          |      4  |     0.1250  |   0.1072  |               |      194.1475  |
|          |      5  |     0.1115  |   0.1073  |               |      204.6166  |
|          |      6  |     0.1195  |   0.1082  |               |      215.1058  |
|          |      7  |     0.1310  |   0.0938  |               |      225.4990  |
|          |   eval  |     0.1310  |   0.0938  |       0.0938  |      225.6839  |
---------------------------------------------------------------------------------
|       0  |      0  |     0.5090  |   0.3644  |               |       10.4919  |
|          |      1  |     0.6485  |   0.6299  |               |       20.8188  |
|          |    

KeyboardInterrupt: 

In [10]:
torch.cuda.memory._dump_snapshot("/content/sumo_expandable_segments_snapshot.pickle")

In [17]:
torch.cuda.synchronize()
stats = torch.cuda.memory_stats()

In [20]:
stats

OrderedDict([('active.all.allocated', 691114),
             ('active.all.current', 85),
             ('active.all.freed', 691029),
             ('active.all.peak', 127),
             ('active.large_pool.allocated', 176118),
             ('active.large_pool.current', 20),
             ('active.large_pool.freed', 176098),
             ('active.large_pool.peak', 46),
             ('active.small_pool.allocated', 514996),
             ('active.small_pool.current', 65),
             ('active.small_pool.freed', 514931),
             ('active.small_pool.peak', 85),
             ('active_bytes.all.allocated', 771592927744),
             ('active_bytes.all.current', 1457926144),
             ('active_bytes.all.freed', 770135001600),
             ('active_bytes.all.peak', 9683783168),
             ('active_bytes.large_pool.allocated', 718051320832),
             ('active_bytes.large_pool.current', 1453645312),
             ('active_bytes.large_pool.freed', 716597675520),
             ('active_byt

In [22]:
 print(f"[{"after code"}] allocated={torch.cuda.memory_allocated()/1024**2:.1f}MB  "
          f"reserved={torch.cuda.memory_reserved()/1024**2:.1f}MB  "
          f"largest_block={stats['largest_block_bytes.all.current']/1024**2:.1f}MB  "
          f"inactive_split={stats['inactive_split_bytes.all.current']/1024**2:.1f}MB")

KeyError: 'largest_block_bytes.all.current'

In [23]:
summary = torch.cuda.memory.memory_summary()

In [25]:
print(summary)

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   1390 MiB |   9235 MiB | 735848 MiB | 734457 MiB |
|       from large pool |   1386 MiB |   9230 MiB | 684787 MiB | 683400 MiB |
|       from small pool |      4 MiB |      8 MiB |  51061 MiB |  51057 MiB |
|---------------------------------------------------------------------------|
| Active memory         |   1390 MiB |   9235 MiB | 735848 MiB | 734457 MiB |
|       from large pool |   1386 MiB |   9230 MiB | 684787 MiB | 683400 MiB |
|       from small pool |      4 MiB |      8 MiB |  51061 MiB |  51057 MiB |
|---------------------------------------------------------------