In [1]:
import torch
import torchvision
from torchvision import transforms
from torch import nn
import time

In [2]:
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

data = torchvision.datasets.CIFAR10(root='/disk2/wjl/cv-dataset', 
                                    train=True, download=False, transform=train_transform)
train_dataset, val_dataset, _ = torch.utils.data.random_split(data, [40000, 10000, 0]) # total 50000

test_dataset = torchvision.datasets.CIFAR10(root='/disk2/wjl/cv-dataset',
                                            train=False, download=False, transform=test_transform)

classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')


In [3]:
import torch.utils

class GradAccDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, global_batch_size, max_batch_size, random_batch_size):
        self.dataset = dataset
        self.loader = torch.utils.data.DataLoader(dataset, batch_size=global_batch_size, shuffle=True)
        self.global_batch_size = global_batch_size
        self.max_batch_size = max_batch_size
        self.random_batch_size = random_batch_size
        self.gb = None
        self.gb_idx = 0
        self.lb_idx = 0

    def reset(self):
        self.gb = None
        self.gb_idx = 0
        self.lb_idx = 0

    def __len__(self):
        return len(self.dataset)

    def _next_gb(self):
        while True:
            for b in self.loader:
                yield b

    def next_gb(self):
        assert self.gb_idx != len(self)
        next_gb_size = min(self.global_batch_size, len(self)-self.gb_idx)
        # self.gb = next(self.loader)
        self.gb = next(self._next_gb())
        imgs, labs = self.gb
        assert len(imgs) == next_gb_size
        self.lb_idx = 0
        self.gb_idx += next_gb_size
    
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        assert worker_info is None, "This dataset is not compatible with multi-process loading"
        while True:
            if self.gb is None or self.lb_idx >= len(self.gb[0]):
                if self.gb_idx == len(self):
                    self.reset()
                    break
                self.next_gb()                

            imgs, labs = self.gb

            if self.random_batch_size:
                batch_size = torch.randint(1, self.max_batch_size+1, (1,)).item()
            else:
                batch_size = self.max_batch_size

            batch_size = min(len(imgs)-self.lb_idx, batch_size)
            lb = (imgs[self.lb_idx:self.lb_idx+batch_size], labs[self.lb_idx:self.lb_idx+batch_size])
            self.lb_idx += batch_size
            assert self.lb_idx <= len(imgs)
            assert batch_size > 0
            yield lb
        
        # return torch.utils.data.DataLoader(self.dataset, batch_size=batch_size, shuffle=True)

    def can_step(self):
        imgs, _ = self.gb
        # print(self.gb_idx, self.lb_idx, self.lb_idx == len(imgs))
        return self.lb_idx == len(imgs) 

In [4]:

random_batch_size = True

global_batch_size = 500
max_batch_size = 100

train_dataset = GradAccDataset(train_dataset, global_batch_size, max_batch_size, random_batch_size)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=None, shuffle=None, 
                                               num_workers=0, drop_last=False)

# train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=80, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=80, shuffle=False)

In [5]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = torchvision.models.swin_t()
        self.fc = nn.Linear(1000, 10)

    def forward(self, x):
        x = self.base(x)
        x = self.fc(x)
        return x


In [6]:
model = Model().cuda()
criterion = nn.CrossEntropyLoss().cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, 
                            momentum=0.9)

In [7]:
epoch = 10

model.train()

for e in range(epoch):
    running_loss = 0
    running_acc = 0

    epoch_begin_t = time.time()
    for i, b in enumerate(train_dataloader):
        images, labels = b
        lbs = len(images)
        # print(len(images))
        images = images.cuda()
        labels = labels.cuda()

        output = model(images)
        loss = criterion(output, labels)
        scaled_loss = loss * lbs / global_batch_size
        scaled_loss.backward()
        
        if train_dataset.can_step():
            optimizer.step()
            optimizer.zero_grad()

        running_loss += loss.item() * images.size(0)
        pred = torch.argmax(output.detach(), dim=1)
        running_acc += (pred == labels).sum().item()
    epoch_end_t = time.time()
 
    val_acc = 0
    with torch.no_grad():
        for i, b in enumerate(val_dataloader):
            images, labels = b
            images = images.cuda()
            labels = labels.cuda()

            output = model(images)
            
            pred = torch.argmax(output, dim=1)
            val_acc += (pred == labels).sum().item()
    epoch_loss = running_loss / len(train_dataset)
    epoch_train_acc = 100 * running_acc / len(train_dataset)
    epoch_val_acc = 100 * val_acc / len(val_dataset)
    train_thpt = len(train_dataset) / (epoch_end_t - epoch_begin_t)
    epoch_time = epoch_end_t - epoch_begin_t
    print(f'[epoch {e} loss {epoch_loss:.1f} acc {epoch_train_acc:.1f}% val acc {epoch_val_acc:.1f}%] {epoch_time:.1f} sec, thpt {train_thpt:.1f}')

[epoch 0 loss 2.1 acc 22.8% val acc 26.8%] 216.4 sec, thpt 184.8
[epoch 1 loss 2.0 acc 27.9% val acc 29.1%] 219.8 sec, thpt 182.0
[epoch 2 loss 1.9 acc 31.0% val acc 31.8%] 214.3 sec, thpt 186.7
[epoch 3 loss 1.8 acc 33.3% val acc 32.8%] 214.2 sec, thpt 186.7
[epoch 4 loss 1.8 acc 34.8% val acc 35.9%] 213.5 sec, thpt 187.4
[epoch 5 loss 1.8 acc 36.7% val acc 37.0%] 214.5 sec, thpt 186.5
[epoch 6 loss 1.7 acc 37.9% val acc 38.8%] 211.4 sec, thpt 189.2
[epoch 7 loss 1.7 acc 39.6% val acc 39.4%] 213.7 sec, thpt 187.2
[epoch 8 loss 1.6 acc 40.5% val acc 40.8%] 213.3 sec, thpt 187.6
