<a href="https://colab.research.google.com/github/RaviShah1/Combining-Cords-and-Composer/blob/main/experiments/resnet18_cifar10/session3/glister_only.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/decile-team/cords.git
%cd cords
%ls

fatal: destination path 'cords' already exists and is not an empty directory.
/content/cords
[0m[01;34mbenchmarks[0m/   [01;34mcords[0m/  [01;34mexamples[0m/    README.md      setup.py      train_sl.py
CITATION.CFF  [01;34mdata[0m/   LICENSE.txt  [01;34mrequirements[0m/  [01;34mtests[0m/        train_ssl.py
[01;34mconfigs[0m/      [01;34mdocs[0m/   model.pt     [01;34mresults[0m/       train_hpo.py


In [2]:
!pip install dotmap apricot-select ray[default] ray[tune] mosaicml


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
"""
You can select any options from the all options list.

Put the options you want in the options list and throughout the notebook, see 
how the composer functions you selected are applied.

Some options will decrease overall training time, a few options will decrease
per epoch training time, and a few options may not improve or harm training time.
"""

all_options = ["augmix", "blurpool", "channels_last", "colout", "cutout", "ema",
               "factorize", "fused_layernorm", "label_smooth", "layer_freeze", 
               "mixup", "rand_aug", "squeeze_excite"]
options = []

In [4]:
import time
import numpy as np
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from cords.utils.data.datasets.SL import gen_dataset
from torch.utils.data import Subset
from cords.utils.config_utils import load_config_data
import os.path as osp
from cords.utils.data.data_utils import WeightedSubset
from composer import functional as cf
from composer.algorithms.augmix import AugmentAndMixTransform
from composer.algorithms.colout import ColOutTransform
from composer.algorithms.randaugment import RandAugmentTransform
from ray import tune
import torchvision
from torchvision import transforms
from torch.utils.data import random_split

In [5]:
#trainset, validset, testset, num_cls = gen_dataset('data/', 'cifar10', None, isnumpy=False)
def get_data():
    torch.cuda.manual_seed(42)
    torch.manual_seed(42)

    train_transforms = [
                        transforms.RandomCrop(32, padding=4),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                       ]
    if "augmix" in options:
        train_transforms.append(AugmentAndMixTransform(severity=3,
                                                       width=3,
                                                       depth=-1,
                                                       alpha=1.0,
                                                       augmentation_set="all"))
    if "colout" in options:
        train_transforms.append(ColOutTransform(p_row=0.15, p_col=0.15))
    if "rand_aug" in options:
        train_transforms.append(RandAugmentTransform(severity=4,
                                                     depth=2,
                                                     augmentation_set="all"))

    cifar_transform = transforms.Compose(train_transforms)

    cifar_tst_transform = transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                        ])

    num_cls = 10

    fullset = torchvision.datasets.CIFAR10(root='data/', train=True, download=True, transform=cifar_transform)
    testset = torchvision.datasets.CIFAR10(root='data/', train=False, download=True, transform=cifar_tst_transform)

    validation_set_fraction = 0.1
    num_fulltrn = len(fullset)
    num_val = int(num_fulltrn * validation_set_fraction)
    num_trn = num_fulltrn - num_val
    trainset, valset = random_split(fullset, [num_trn, num_val])

    return trainset, valset, testset, num_cls

trainset, validset, testset, num_cls = get_data()

Files already downloaded and verified
Files already downloaded and verified


In [6]:
trn_batch_size = 20
val_batch_size = 20
tst_batch_size = 1000

# Creating the Data Loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=trn_batch_size,
                                          shuffle=False, pin_memory=True)

valloader = torch.utils.data.DataLoader(validset, batch_size=val_batch_size,
                                        shuffle=False, pin_memory=True)

testloader = torch.utils.data.DataLoader(testset, batch_size=tst_batch_size,
                                          shuffle=False, pin_memory=True)

In [7]:
from cords.utils.models import ResNet18
numclasses = 10
device = 'cuda' #Device Argument
model = ResNet18(10)

if "blurpool" in options:
    model = cf.apply_blurpool(model)
if "channels_last" in options:
    cf.apply_channels_last(model)
if "ema" in options:
    ema_model = copy.deepcopy(model)
    ema_model.to(device)
if "factorize" in options:
    cf.apply_factorization(model)
if "fused_layernorm" in options:
    cf.apply_fused_layernorm(model)
if "squeeze_excite" in options:
    model = cf.apply_squeeze_excite(model)


model = model.to(device)

In [8]:
criterion = nn.CrossEntropyLoss()
criterion_nored = nn.CrossEntropyLoss(reduction='none')

In [9]:
def save_ckpt(state, ckpt_path):
    torch.save(state, ckpt_path)


def load_ckpt(ckpt_path, model, optimizer):
    checkpoint = torch.load(ckpt_path)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    loss = checkpoint['loss']
    metrics = checkpoint['metrics']
    return start_epoch, model, optimizer, loss, metrics

In [10]:
def generate_cumulative_timing(mod_timing):
    tmp = 0
    mod_cum_timing = np.zeros(len(mod_timing))
    for i in range(len(mod_timing)):
        tmp += mod_timing[i]
        mod_cum_timing[i] = tmp
    return mod_cum_timing / 3600

In [11]:
optimizer = optim.SGD(model.parameters(), lr=1e-2,
                                  momentum=0.9,
                                  weight_decay=5e-4,
                                  nesterov=False)

#T_max is the maximum number of scheduler steps. Here we are using the number of epochs as the maximum number of scheduler steps.

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=30)

In [12]:
def __get_logger(results_dir):
  os.makedirs(results_dir, exist_ok=True)
  # setup logger
  plain_formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s",
                                      datefmt="%m/%d %H:%M:%S")
  logger = logging.getLogger(__name__)
  logger.setLevel(logging.INFO)
  s_handler = logging.StreamHandler(stream=sys.stdout)
  s_handler.setFormatter(plain_formatter)
  s_handler.setLevel(logging.INFO)
  logger.addHandler(s_handler)
  f_handler = logging.FileHandler(os.path.join(results_dir, "results.log"))
  f_handler.setFormatter(plain_formatter)
  f_handler.setLevel(logging.DEBUG)
  logger.addHandler(f_handler)
  logger.propagate = False
  return logger

In [13]:
import logging
import os
import os.path as osp
import sys

#Results logging directory
results_dir = osp.abspath(osp.expanduser('results'))
logger = __get_logger(results_dir)

In [14]:
logger.info("hello")

[07/20 20:26:06] __main__ INFO: hello


In [15]:
from cords.utils.data.dataloader.SL.adaptive import GLISTERDataLoader, OLRandomDataLoader, \
    CRAIGDataLoader, GradMatchDataLoader, RandomDataLoader
from dotmap import DotMap

selection_strategy = 'GLISTER'
dss_args = dict(model=model,
                loss=criterion_nored,
                eta=0.01,
                num_classes=10,
                num_epochs=50,
                device='cuda',
                fraction=0.1,
                select_every=10,
                kappa=0,
                linear_layer=False,
                selection_type='SL',
                greedy='Stochastic')
dss_args = DotMap(dss_args)

dataloader = GLISTERDataLoader(trainloader, valloader, dss_args, logger, 
                                  batch_size=20, 
                                  shuffle=True,
                                  pin_memory=False)



In [16]:
#Training Arguments
num_epochs = 50

#Arguments for results logging
print_every = 1
print_args = ["val_loss", "val_acc", "tst_loss", "tst_acc", "time"]

#Argumets for checkpointing
save_every = 20
is_save = True

#Evaluation Metrics
trn_losses = list()
val_losses = list()
tst_losses = list()
subtrn_losses = list()
timing = list()
trn_acc = list()
val_acc = list()  
tst_acc = list()  
subtrn_acc = list()

In [17]:
"""
################################################# Training Loop #################################################
"""
for epoch in range(num_epochs):
    subtrn_loss = 0
    subtrn_correct = 0
    subtrn_total = 0
    model.train()
    start_time = time.time()
    for _, (inputs, targets, weights) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device)  
        optimizer.zero_grad()
        if "mixup" in options:
            inputs, y_perm, mixing = cf.mixup_batch(inputs, targets, alpha=0.2)
        if "cutout" in options:
             inputs = cf.cutout_batch(inputs, num_holes=1, length=0.5)
        outputs = model(inputs)
        if "label_smooth" in options:
            targets = cf.smooth_labels(outputs, targets, smoothing=0.1)
        losses = criterion_nored(outputs, targets)
        loss = torch.dot(losses, weights/(weights.sum()))
        loss.backward()
        subtrn_loss += loss.item()
        optimizer.step()
        _, predicted = outputs.max(1)
        subtrn_total += targets.size(0)
        subtrn_correct += predicted.eq(targets).sum().item()

        if "ema" in options:
            cf.compute_ema(model, ema_model, smoothing=0.99)
        if "layer_freeze" in options:
            freeze_depth, feeze_level = cf.freeze_layers(
                                        model=model,
                                        optimizers=optimizer,
                                        current_duration=epoch/num_epochs,
                                        freeze_start=0.0,
                                        freeze_level=1.0
                                    )

    epoch_time = time.time() - start_time
    scheduler.step()
    timing.append(epoch_time)


    """
    ################################################# Evaluation Loop #################################################
    """

    if (epoch + 1) % print_every == 0:
        trn_loss = 0
        trn_correct = 0
        trn_total = 0
        val_loss = 0
        val_correct = 0
        val_total = 0
        tst_correct = 0
        tst_total = 0
        tst_loss = 0
        model.eval()

        if ("trn_loss" in print_args) or ("trn_acc" in print_args):
            with torch.no_grad():
                for _, (inputs, targets) in enumerate(trainloader):
                    inputs, targets = inputs.to(device), \
                                      targets.to(device, non_blocking=True)
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    trn_loss += loss.item()
                    if "trn_acc" in print_args:
                        _, predicted = outputs.max(1)
                        trn_total += targets.size(0)
                        trn_correct += predicted.eq(targets).sum().item()
                trn_losses.append(trn_loss)

            if "trn_acc" in print_args:
                trn_acc.append(trn_correct / trn_total)

        if ("val_loss" in print_args) or ("val_acc" in print_args):
            with torch.no_grad():
                for _, (inputs, targets) in enumerate(valloader):
                    inputs, targets = inputs.to(device), \
                                      targets.to(device, non_blocking=True)
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    val_loss += loss.item()
                    if "val_acc" in print_args:
                        _, predicted = outputs.max(1)
                        val_total += targets.size(0)
                        val_correct += predicted.eq(targets).sum().item()
                val_losses.append(val_loss)

            if "val_acc" in print_args:
                val_acc.append(val_correct / val_total)

        if ("tst_loss" in print_args) or ("tst_acc" in print_args):
            with torch.no_grad():
                for _, (inputs, targets) in enumerate(testloader):
                    inputs, targets = inputs.to(device), \
                                      targets.to(device, non_blocking=True)
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    tst_loss += loss.item()
                    if "tst_acc" in print_args:
                        _, predicted = outputs.max(1)
                        tst_total += targets.size(0)
                        tst_correct += predicted.eq(targets).sum().item()
                tst_losses.append(tst_loss)

            if "tst_acc" in print_args:
                tst_acc.append(tst_correct / tst_total)

        if "subtrn_acc" in print_args:
            subtrn_acc.append(subtrn_correct / subtrn_total)

        if "subtrn_losses" in print_args:
            subtrn_losses.append(subtrn_loss)

        print_str = "Epoch: " + str(epoch + 1)

        """
        ################################################# Results Printing #################################################
        """

        for arg in print_args:

            if arg == "val_loss":
                print_str += " , " + "Validation Loss: " + str(val_losses[-1])

            if arg == "val_acc":
                print_str += " , " + "Validation Accuracy: " + str(val_acc[-1])

            if arg == "tst_loss":
                print_str += " , " + "Test Loss: " + str(tst_losses[-1])

            if arg == "tst_acc":
                print_str += " , " + "Test Accuracy: " + str(tst_acc[-1])

            if arg == "trn_loss":
                print_str += " , " + "Training Loss: " + str(trn_losses[-1])

            if arg == "trn_acc":
                print_str += " , " + "Training Accuracy: " + str(trn_acc[-1])

            if arg == "subtrn_loss":
                print_str += " , " + "Subset Loss: " + str(subtrn_losses[-1])

            if arg == "subtrn_acc":
                print_str += " , " + "Subset Accuracy: " + str(subtrn_acc[-1])

            if arg == "time":
                print_str += " , " + "Timing: " + str(timing[-1])

        logger.info(print_str)

    """
    ################################################# Checkpoint Saving #################################################
    """

    if ((epoch + 1) % save_every == 0) and is_save:

        metric_dict = {}

        for arg in print_args:
            if arg == "val_loss":
                metric_dict['val_loss'] = val_losses
            if arg == "val_acc":
                metric_dict['val_acc'] = val_acc
            if arg == "tst_loss":
                metric_dict['tst_loss'] = tst_losses
            if arg == "tst_acc":
                metric_dict['tst_acc'] = tst_acc
            if arg == "trn_loss":
                metric_dict['trn_loss'] = trn_losses
            if arg == "trn_acc":
                metric_dict['trn_acc'] = trn_acc
            if arg == "subtrn_loss":
                metric_dict['subtrn_loss'] = subtrn_losses
            if arg == "subtrn_acc":
                metric_dict['subtrn_acc'] = subtrn_acc
            if arg == "time":
                metric_dict['time'] = timing

        ckpt_state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': criterion_nored,
            'metrics': metric_dict
        }

        # save checkpoint
        save_ckpt(ckpt_state, 'model.pt')
        logger.info("Model checkpoint saved at epoch: {0:d}".format(epoch + 1))


"""
################################################# Results Summary #################################################
"""

logger.info("{0:s} Selection Run---------------------------------".format(selection_strategy))
logger.info("Final SubsetTrn: {0:f}".format(subtrn_loss))
if "val_loss" in print_args:
    if "val_acc" in print_args:
        logger.info("Validation Loss: %.2f , Validation Accuracy: %.2f", val_loss, val_acc[-1])
    else:
        logger.info("Validation Loss: %.2f", val_loss)

if "tst_loss" in print_args:
    if "tst_acc" in print_args:
        logger.info("Test Loss: %.2f, Test Accuracy: %.2f", tst_loss, tst_acc[-1])
    else:
        logger.info("Test Data Loss: %f", tst_loss)
logger.info('---------------------------------------------------------------------')
logger.info(selection_strategy)
logger.info('---------------------------------------------------------------------')

"""
################################################# Final Results Logging #################################################
"""

if "val_acc" in print_args:
    val_str = "Validation Accuracy, "
    for val in val_acc:
        val_str = val_str + " , " + str(val)
    logger.info(val_str)

if "tst_acc" in print_args:
    tst_str = "Test Accuracy, "
    for tst in tst_acc:
        tst_str = tst_str + " , " + str(tst)
    logger.info(tst_str)

if "time" in print_args:
    time_str = "Time, "
    for t in timing:
        time_str = time_str + " , " + str(t)
    logger.info(timing)

timing_array = np.array(timing)
cum_timing = list(generate_cumulative_timing(timing_array))
logger.info("Total time taken by %s = %.4f ", selection_strategy, cum_timing[-1])

[07/20 20:26:24] __main__ INFO: Epoch: 1 , Validation Loss: 467.67154943943024 , Validation Accuracy: 0.312 , Test Loss: 18.367909789085388 , Test Accuracy: 0.3203 , Timing: 6.976815223693848
[07/20 20:26:36] __main__ INFO: Epoch: 2 , Validation Loss: 455.4665801525116 , Validation Accuracy: 0.324 , Test Loss: 17.87793242931366 , Test Accuracy: 0.3511 , Timing: 5.579818248748779
[07/20 20:26:48] __main__ INFO: Epoch: 3 , Validation Loss: 421.05246472358704 , Validation Accuracy: 0.3802 , Test Loss: 16.63795566558838 , Test Accuracy: 0.3885 , Timing: 5.5597498416900635
[07/20 20:27:00] __main__ INFO: Epoch: 4 , Validation Loss: 404.9928056001663 , Validation Accuracy: 0.4002 , Test Loss: 15.740660667419434 , Test Accuracy: 0.4347 , Timing: 5.51932692527771
[07/20 20:27:12] __main__ INFO: Epoch: 5 , Validation Loss: 376.82281017303467 , Validation Accuracy: 0.4436 , Test Loss: 14.828064560890198 , Test Accuracy: 0.4594 , Timing: 5.5061891078948975
[07/20 20:27:24] __main__ INFO: Epoch: 6

In [18]:
print(f"Average Time / Epoch: {np.mean(timing)} \nTotal Epochs: {num_epochs}")

Average Time / Epoch: 8.418898530006409 
Total Epochs: 50


In [19]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Wed Jul 20 20:38:38 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   65C    P0    48W / 250W |   5145MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces