In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
import shutil

# Stage the dataset on the Colab VM's local disk; the copy is skipped when the
# target directory is already present.
drive_data_path = '/content/drive/MyDrive/dynamic_res/imagenette2'
local_data_path = '/content/imagenette2'

if os.path.exists(local_data_path):
    print(f'dataset already exists at {local_data_path}')
else:
    shutil.copytree(drive_data_path, local_data_path)

# Later cells read the dataset location from this name.
data_path = local_data_path

In [15]:
import argparse
import time
import csv
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim

from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import sys

def parse_args():
    """Build the CLI parser and parse ``sys.argv``.

    Every token of ``sys.argv`` that starts with ``-f`` is discarded before
    parsing, and the first remaining token (the program name) is skipped.

    Returns:
        argparse.Namespace with ``data`` (positional), the experiment options
        (no defaults, so they are ``None`` when omitted) and the optimiser /
        logging options (``lr``, ``momentum``, ``weight_decay``,
        ``print_freq``) with defaults.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('data')

    # Experiment options: (flag, converter); a converter of None keeps the raw string.
    experiment_options = (
        ('--arch', None),
        ('--epochs', int),
        ('--batch-size', int),
        ('--workers', int),
        ('--r-min', int),
        ('--r-max', int),
        ('--gamma', float),
        ('--reassign-epoch', int),
        ('--ratio-schedule', None),
    )
    for flag, converter in experiment_options:
        parser.add_argument(flag, type=converter)

    # Optimiser / logging options: (flag, converter, default).
    tuned_options = (
        ('--lr', float, 0.1),
        ('--momentum', float, 0.9),
        ('--weight-decay', float, 1e-4),
        ('--print-freq', int, 10),
    )
    for flag, converter, fallback in tuned_options:
        parser.add_argument(flag, type=converter, default=fallback)

    tokens = [token for token in sys.argv if not token.startswith('-f')]
    return parser.parse_args(tokens[1:])

class AverageMeter:
    """Running weighted average of a scalar-like quantity.

    Attributes:
        name:  label given at construction.
        val:   most recent value passed to ``update``.
        sum:   accumulated ``val * n`` over all updates since the last reset.
        count: accumulated ``n`` over all updates since the last reset.
        avg:   ``sum / count`` as of the last update.
    """

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor; ``topk`` is taken along dimension 1.
        target: label tensor; ``target.size(0)`` is used as the batch size.
        topk:   iterable of k values.

    Returns:
        A list with one single-element float tensor per requested k.
    """
    with torch.no_grad():
        widest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the widest_k highest scores per row, best first,
        # transposed so that row j holds every sample's rank-j prediction.
        _, ranked = output.topk(widest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        percent_per_hit = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]

# geometric resolution schedule
def compute_resolution_schedule(r_max, r_min, gamma):
    """Build a geometric list of resolutions from ``r_max`` down to ``r_min``.

    Starting at ``r_max``, each following entry is ``int(previous * gamma)``
    clamped from below by ``r_min``. The list ends as soon as an entry would
    repeat the previous one.

    Args:
        r_max: first resolution in the schedule.
        r_min: lower clamp applied to every subsequent resolution.
        gamma: multiplicative factor applied at each step.

    Returns:
        list of resolutions, beginning with ``r_max``.

    Raises:
        ValueError: if ``gamma > 1`` makes the resolution grow. Once that
            happens the sequence increases at every step and would never
            repeat, so the loop could not terminate.
    """
    schedule = [r_max]
    current = r_max

    while True:
        scaled = int(current * gamma)

        # A growing value under gamma > 1 keeps growing forever; fail loudly
        # instead of appending to the list until memory runs out.
        if gamma > 1 and scaled > current:
            raise ValueError(
                f'gamma={gamma} makes the resolution grow ({current} -> {scaled}); '
                'the schedule would never terminate'
            )

        following = max(scaled, r_min)
        if following == current:
            return schedule

        schedule.append(following)
        current = following

# applies different resolutions to different samples
class DynamicResolutionDataset(torch.utils.data.Dataset):
    """Wraps a base dataset and augments each sample at its own resolution.

    Args:
        base_dataset: object exposing ``samples`` (indexable, yielding
            ``(path, target)`` pairs), ``loader(path)`` and ``__len__``.
        sample_resolutions: mapping from sample index to crop size; indices
            missing from the mapping fall back to 224.
        normalize: final transform applied after ``ToTensor``.
    """

    def __init__(self, base_dataset, sample_resolutions, normalize):
        self.base_dataset = base_dataset
        self.sample_resolutions = sample_resolutions
        self.normalize = normalize

    def __len__(self):
        return len(self.base_dataset)

    def _load_sample(self, idx):
        """Load, augment and return ``(image, target, idx)`` for one index."""
        path, target = self.base_dataset.samples[idx]
        img = self.base_dataset.loader(path)
        resolution = self.sample_resolutions.get(idx, 224)

        transform = transforms.Compose([transforms.RandomResizedCrop(resolution), transforms.RandomHorizontalFlip(), transforms.ToTensor(), self.normalize])

        return transform(img), target, idx

    def __getitem__(self, idx):
        """Return ``(image, target, index)`` for ``idx``.

        If loading or transforming the sample raises, a message is printed and
        the next index (wrapping around) is tried instead, so the returned
        index may differ from the one requested.

        Raises:
            RuntimeError: when every index in the dataset has been tried and
                failed; the last underlying error is chained as the cause.
        """
        total = len(self.base_dataset)
        last_error = None

        # Bounded iteration replaces unbounded recursion: a systematic failure
        # surfaces as one clear error instead of a RecursionError.
        for _ in range(total):
            try:
                return self._load_sample(idx)
            except Exception as e:
                # if image is corrupted, use next sample
                print(f'skipping corrupted image at index {idx}: {e}')
                last_error = e
                idx = (idx + 1) % total

        raise RuntimeError(f'no loadable sample found after trying all {total} indices') from last_error

class SubsetDynamicResolutionDataset(torch.utils.data.Dataset):
    """View over ``dynamic_dataset`` restricted to a chosen list of indices.

    Position ``i`` of this dataset returns whatever ``dynamic_dataset`` returns
    for ``indices[i]``. The indices are copied into a list at construction.
    """

    def __init__(self, dynamic_dataset, indices):
        self.dynamic_dataset, self.indices = dynamic_dataset, list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, subset_idx):
        # Translate the subset position into the wrapped dataset's own index.
        return self.dynamic_dataset[self.indices[subset_idx]]

# group by resolution
def collate_by_resolution(batch):
    """Collate ``(img, target, idx)`` samples into one sub-batch per image size.

    Samples are bucketed by the last dimension of ``img``; buckets keep the
    order in which each size was first seen.

    Returns:
        list of ``(images, targets, indices)`` tuples, where ``images`` is the
        stacked image tensor and the other two are tensors built from the
        collected targets and indices.
    """
    buckets = {}
    for image, label, sample_idx in batch:
        bucket = buckets.setdefault(image.shape[-1], {'images': [], 'targets': [], 'indices': []})
        bucket['images'].append(image)
        bucket['targets'].append(label)
        bucket['indices'].append(sample_idx)

    return [
        (torch.stack(bucket['images']), torch.tensor(bucket['targets']), torch.tensor(bucket['indices']))
        for bucket in buckets.values()
    ]

def train_epoch(train_loader, model, criterion, optimizer, epoch, device, args, track_loss=False):
    """Run one training pass over ``train_loader``.

    Every item yielded by the loader is iterated as a sequence of
    ``(images, target, indices)`` groups, and one optimizer step is taken per
    group. The loss is always ``nn.functional.cross_entropy``; the
    ``criterion`` argument is accepted but not used in this function.

    Args:
        epoch: only used in the progress message.
        args: read for ``print_freq``.
        track_loss: when true, per-sample losses are collected by sample index.

    Returns:
        ``(avg_loss, avg_top1, avg_top5, epoch_seconds, per_sample_mean)``;
        the last element maps sample index to its mean loss over the epoch, or
        is ``None`` when ``track_loss`` is false.
    """
    loss_meter = AverageMeter('Loss')
    top1_meter = AverageMeter('Acc@1')
    top5_meter = AverageMeter('Acc@5')

    model.train()

    started_at = time.time()
    per_sample = {} if track_loss else None

    for step, groups in enumerate(train_loader):
        # one forward/backward/step per resolution group
        for images, target, indices in groups:
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            logits = model(images)
            loss_vector = nn.functional.cross_entropy(logits, target, reduction='none')
            loss = loss_vector.mean()

            if track_loss:
                sample_ids = indices.cpu().numpy()
                sample_values = loss_vector.detach().cpu().numpy()
                for sample_id, value in zip(sample_ids, sample_values):
                    per_sample.setdefault(sample_id, []).append(value)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            top1_acc, top5_acc = accuracy(logits, target, topk=(1, 5))
            group_size = images.size(0)
            loss_meter.update(loss.item(), group_size)
            top1_meter.update(top1_acc[0], group_size)
            top5_meter.update(top5_acc[0], group_size)

        if step % args.print_freq == 0:
            print(f'Epoch [{epoch}][{step}/{len(train_loader)}] Loss {loss_meter.avg} Acc@1 {top1_meter.avg} Acc@5 {top5_meter.avg}')

    elapsed = time.time() - started_at

    per_sample_mean = None
    if track_loss:
        per_sample_mean = {sample_id: np.mean(values) for sample_id, values in per_sample.items()}

    return loss_meter.avg, top1_meter.avg, top5_meter.avg, elapsed, per_sample_mean

def validate(val_loader, model, criterion, device, args):
    """Evaluate ``model`` on a loader that yields ``(images, target)`` batches.

    Runs in eval mode with gradients disabled, prints progress every
    ``args.print_freq`` batches and a final summary line.

    Returns:
        ``(avg_loss, avg_top1, avg_top5)`` as held by the running meters.
    """
    loss_meter = AverageMeter('Loss')
    top1_meter = AverageMeter('Acc@1')
    top5_meter = AverageMeter('Acc@5')

    model.eval()

    with torch.no_grad():
        for step, (images, target) in enumerate(val_loader):
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            logits = model(images)
            batch_loss = criterion(logits, target)
            top1_acc, top5_acc = accuracy(logits, target, topk=(1, 5))

            batch_size = images.size(0)
            loss_meter.update(batch_loss.item(), batch_size)
            top1_meter.update(top1_acc[0], batch_size)
            top5_meter.update(top5_acc[0], batch_size)

            if step % args.print_freq == 0:
                print(f'Test [{step}/{len(val_loader)}] Loss {loss_meter.avg} Acc@1 {top1_meter.avg} Acc@5 {top5_meter.avg}')

    print(f'Val: Acc@1 {top1_meter.avg} Acc@5 {top5_meter.avg}')

    return loss_meter.avg, top1_meter.avg, top5_meter.avg

def validate_dynamic(val_loader, model, criterion, device, args):
    """Evaluate ``model`` on a loader whose items are lists of sub-batches.

    Each item yielded by ``val_loader`` is iterated as a sequence of
    ``(images, target, indices)`` groups; the indices are not used. Progress
    is printed every ``args.print_freq`` loader items, followed by a final
    summary line.

    Returns:
        ``(avg_loss, avg_top1, avg_top5)`` as held by the running meters.
    """
    loss_meter = AverageMeter('Loss')
    top1_meter = AverageMeter('Acc@1')
    top5_meter = AverageMeter('Acc@5')

    model.eval()

    with torch.no_grad():
        for step, groups in enumerate(val_loader):
            # one forward pass per resolution group
            for images, target, _ in groups:
                images = images.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)

                logits = model(images)
                group_loss = criterion(logits, target)
                top1_acc, top5_acc = accuracy(logits, target, topk=(1, 5))

                group_size = images.size(0)
                loss_meter.update(group_loss.item(), group_size)
                top1_meter.update(top1_acc[0], group_size)
                top5_meter.update(top5_acc[0], group_size)

            if step % args.print_freq == 0:
                print(f'Test [{step}/{len(val_loader)}] Loss {loss_meter.avg} Acc@1 {top1_meter.avg} Acc@5 {top5_meter.avg}')

    print(f'Val: Acc@1 {top1_meter.avg} Acc@5 {top5_meter.avg}')

    return loss_meter.avg, top1_meter.avg, top5_meter.avg

def main():
    """Train a torchvision model while periodically dropping the easiest samples.

    Flow:
      1. Parse CLI options, build the model, SGD optimizer and StepLR scheduler.
      2. Each epoch, train on the currently active sample indices and append a
         row to ``training_log_{arch}_drop.csv``.
      3. On every ``--reassign-epoch``-th epoch, record per-sample losses,
         rank samples by mean loss and remove the lowest-loss fraction given
         by ``--ratio-schedule`` (the last ratio is reused once the schedule
         is exhausted).
      4. After training, evaluate on the validation folder twice (fixed
         224 center crop, then per-sample resolutions), append both results to
         the CSV and save the model's state dict.
    """
    import random

    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f'resolution range: {args.r_min} - {args.r_max}')
    print(f'gamma: {args.gamma}')

    # NOTE(review): computed from the resolution flags but not consumed
    # anywhere below; every training sample stays at r_max in this function.
    resolution_schedule = compute_resolution_schedule(args.r_max, args.r_min, args.gamma)
    ratio_schedule = [float(x) for x in args.ratio_schedule.split(',')]

    model = models.__dict__[args.arch](num_classes=10)
    model = model.to(device)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    base_train_dataset = datasets.ImageFolder(traindir)
    initial_num_samples = len(base_train_dataset)
    sample_resolutions = {i: args.r_max for i in range(initial_num_samples)}

    active_indices = list(range(initial_num_samples))

    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]))

    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    log_file = f'training_log_{args.arch}_drop.csv'
    with open(log_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch', 'train_loss', 'train_top1', 'epoch_time', 'num_active_samples'])

    print(f'logging to: {log_file}')

    total_start = time.time()
    cumulative_losses = {}
    reassignment_count = 0

    for epoch in range(args.epochs):
        # The loader is rebuilt every epoch because active_indices shrinks
        # after each drop.
        train_dataset_full = DynamicResolutionDataset(base_train_dataset, sample_resolutions, normalize)
        train_dataset = SubsetDynamicResolutionDataset(train_dataset_full, active_indices)
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, collate_fn=collate_by_resolution)

        # Track per-sample loss exactly on the epochs that end with a drop.
        # This used to be a hardcoded "every 5th epoch", which disagreed with
        # the drop condition whenever --reassign-epoch was not 5.
        is_drop_epoch = (epoch + 1) % args.reassign_epoch == 0

        result = train_epoch(train_loader, model, criterion, optimizer, epoch, device, args, is_drop_epoch)
        train_loss, train_top1, _, epoch_time, sample_losses = result

        if sample_losses is not None:
            for idx, loss_val in sample_losses.items():
                cumulative_losses.setdefault(idx, []).append(loss_val)

        scheduler.step()

        with open(log_file, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([epoch, train_loss, train_top1.item(), epoch_time, len(active_indices)])

        print(f'Epoch {epoch} completed in {epoch_time}s')

        # drop samples every reassign_epoch starting from reassign_epoch-1
        if is_drop_epoch:
            print(f'Sample dropping after epoch {epoch}')
            avg_losses = {idx: np.mean(losses) for idx, losses in cumulative_losses.items()}

            # sort by loss (ascending) and drop easiest samples
            sorted_samples = sorted(avg_losses.items(), key=lambda x: x[1])

            current_ratio = ratio_schedule[min(reassignment_count, len(ratio_schedule) - 1)]
            num_active = len(active_indices)
            num_to_drop = int(num_active * current_ratio)

            candidates_to_drop = [idx for idx, _ in sorted_samples]
            samples_to_drop = candidates_to_drop[:num_to_drop]

            if num_to_drop > 0 and len(samples_to_drop) > 0:
                drop_set = set(samples_to_drop)
                active_indices = [idx for idx in active_indices if idx not in drop_set]

            reassignment_count += 1

            print(f'current drop ratio: {current_ratio*100}%')
            print(f'easiest samples dropped: {len(samples_to_drop)} / {num_active} ({current_ratio*100}%)')
            print(f'active samples remaining: {len(active_indices)} / {initial_num_samples}')

            # start the next tracking window from scratch
            cumulative_losses = {}

    total_time = time.time() - total_start
    print(f'total training time: {total_time}s')

    test_loss, test_top1, test_top5 = validate(val_loader, model, criterion, device, args)
    print('final test results:')
    print(f'Test Loss: {test_loss}')
    print(f'Test Acc@1: {test_top1}')
    print(f'Test Acc@5: {test_top5}')

    val_base_dataset = datasets.ImageFolder(valdir)
    num_val_samples = len(val_base_dataset)

    # Randomly assign each validation sample a resolution drawn from the
    # training resolutions. A private generator seeded with 42 yields the same
    # draws as seeding the module-level generator, without disturbing it.
    train_resolutions = list(sample_resolutions.values())
    rng = random.Random(42)
    val_sample_resolutions = {i: rng.choice(train_resolutions) for i in range(num_val_samples)}

    # NOTE(review): DynamicResolutionDataset applies RandomResizedCrop and
    # RandomHorizontalFlip, so this second evaluation is not deterministic.
    val_dynamic_dataset = DynamicResolutionDataset(val_base_dataset, val_sample_resolutions, normalize)
    val_dynamic_loader = DataLoader(val_dynamic_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, collate_fn=collate_by_resolution)

    test_loss_dynamic, test_top1_dynamic, test_top5_dynamic = validate_dynamic(val_dynamic_loader, model, criterion, device, args)
    print('final test results (dynamic res):')
    print(f'Test Loss: {test_loss_dynamic}')
    print(f'Test Acc@1: {test_top1_dynamic}')
    print(f'Test Acc@5: {test_top5_dynamic}')

    with open(log_file, 'a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['FINAL_TEST_224', test_loss, test_top1.item(), total_time, len(active_indices)])
        writer.writerow(['FINAL_TEST_DYNAMIC', test_loss_dynamic, test_top1_dynamic.item(), total_time, len(active_indices)])

    print(f'results saved to: {log_file}')

    # save model to google drive
    save_path = '/content/drive/MyDrive/dynamic_res/model_drop.pth'
    torch.save(model.state_dict(), save_path)
    print(f'\nmodel saved to: {save_path}')

In [None]:
# Emulate a command-line launch: main() reads its options from sys.argv.
sys.argv = [
    'train.py', data_path,
    '--arch', 'resnet18',
    '--epochs', '30',
    '--batch-size', '128',
    '--workers', '2',
    '--r-min', '112',
    '--r-max', '224',
    '--gamma', '0.5',
    '--reassign-epoch', '10',
    '--ratio-schedule', '0.1,0.2',
]
main()

In [None]:
# Download the training log to the local machine. The file name has to match
# the `training_log_{arch}_drop.csv` pattern written by main(), so it must be
# updated by hand if a different --arch is used.
from google.colab import files
files.download('training_log_resnet18_drop.csv')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>