In [None]:
import torch, math, time
import torch.nn.functional as F
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from wzh.transformer import Transformer

torch.backends.cudnn.benchmark = False

print(torch.__version__)
# Fall back to CPU so the notebook still runs (slowly) on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

n_class = 10  # CIFAR-10 has 10 classes

# Images are cut into non-overlapping patches; each flattened patch
# (patch_h * patch_w * channels values) becomes one token.
in_channels = 3
image_size = 32
patch_shape = (4, 4)
d_patch = math.prod(patch_shape) * in_channels
n_patch = in_channels * image_size * image_size // d_patch

# NOTE(review): `dropout` is not referenced anywhere else in the visible code
# (it is not passed to the Transformer) — confirm whether it should be.
dropout = 0.0
epochs = 20
batch_size = 384
learn_rate = 1e-4

In [None]:
class Baseline(nn.Module):
    """Patch-based Transformer classifier for CIFAR-sized images.

    Pipeline: split the image into non-overlapping ``patch_shape`` patches,
    linearly embed each flattened patch, run the token sequence through
    ``Transformer``, mean-pool over the sequence, and project to ``n_class``
    logits.

    Args:
        dim_model: token embedding width (default 384).
        nlayer: number of Transformer layers (default 6).
        num_head: number of attention heads (default 8).

    Reads the module-level ``patch_shape``, ``d_patch``, ``n_patch`` and
    ``n_class`` settings.
    """

    # Passed straight to ``Transformer(glu_attn=...)``; subclasses override it.
    glu_attn: bool = False

    def __init__(self, dim_model: int = 384, nlayer: int = 6, num_head: int = 8):
        super().__init__()
        self.embedding = nn.Linear(d_patch, dim_model)
        self.model = Transformer(
            nlayer=nlayer,
            dim_model=dim_model,
            num_head=num_head,
            max_seq_len=n_patch,
            glu_attn=self.glu_attn,
        )
        self.output = nn.Linear(dim_model, n_class)

    def forward(self, x: Tensor) -> Tensor:
        """Map a ``(B, C, H, W)`` image batch to ``(B, n_class)`` logits."""
        # unfold: (B, C, H, W) -> (B, C*ph*pw, n_patch); .mT -> (B, n_patch, d_patch)
        tokens = F.unfold(x, patch_shape, stride=patch_shape).mT
        tokens = self.embedding(tokens)
        # NOTE(review): assumes self.model keeps the (B, seq, dim_model) layout
        # — TODO confirm against wzh.transformer.
        tokens = self.model(tokens)
        # Pool over the patch/sequence axis (-2), NOT the class axis: averaging
        # after the output projection over -1 collapsed the class logits.
        pooled = tokens.mean(dim=-2)
        return self.output(pooled)


class GLUAttention(Baseline):
    """Same architecture as ``Baseline`` but with ``Transformer(glu_attn=True)``."""

    glu_attn = True

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Each batch is moved to the module-level ``device``. The loop runs a forward
    pass, backward pass and ``optimizer.step()``, then clears gradients.

    Args:
        dataloader: iterable of ``(source, target)`` batches.
        model: module whose output's ``argmax(1)`` is the predicted class.
        loss_fn: criterion; assumed to return the batch mean
            (``reduction="mean"``, as ``nn.CrossEntropyLoss()`` does by default).
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(acc, avg_loss, total_time)``: fraction of samples predicted
        correctly, per-sample mean loss over the epoch, and elapsed seconds.
        An empty loader yields ``0.0`` for accuracy and loss.
    """
    start = time.perf_counter()
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for source, target in dataloader:
        source = source.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        pred = model(source)
        loss = loss_fn(pred, target)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        batch = len(target)
        # Re-weight the batch-mean loss by batch size so a short final batch
        # does not count as much as a full one in the epoch average.
        loss_sum += loss.item() * batch
        correct += (pred.argmax(1) == target).sum().item()
        seen += batch

    # Avoid ZeroDivisionError on an empty loader.
    denom = max(seen, 1)
    return (correct / denom, loss_sum / denom, time.perf_counter() - start)


def val(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    Each batch is moved to the module-level ``device``; the model is put in
    ``eval()`` mode and its parameters are not updated.

    Args:
        dataloader: iterable of ``(source, target)`` batches.
        model: module whose output's ``argmax(1)`` is the predicted class.
        loss_fn: criterion; assumed to return the batch mean
            (``reduction="mean"``, as ``nn.CrossEntropyLoss()`` does by default).

    Returns:
        ``(acc, avg_loss, total_time)``: fraction of samples predicted
        correctly, per-sample mean loss over the pass, and elapsed seconds.
        An empty loader yields ``0.0`` for accuracy and loss.
    """
    start = time.perf_counter()
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for source, target in dataloader:
            source = source.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            pred = model(source)
            batch = len(target)
            # Re-weight the batch-mean loss by batch size so a short final
            # batch does not count as much as a full one.
            loss_sum += loss_fn(pred, target).item() * batch
            correct += (pred.argmax(1) == target).sum().item()
            seen += batch

    # Avoid ZeroDivisionError on an empty loader.
    denom = max(seen, 1)
    return (correct / denom, loss_sum / denom, time.perf_counter() - start)


# Per-channel normalization statistics shared by the train and test pipelines.
# NOTE(review): these std values are the ones commonly copied into CIFAR-10
# examples and differ from the per-pixel std usually reported for the dataset
# (~0.247, 0.243, 0.261). Kept unchanged to preserve behavior — verify.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: crop/flip/RandAugment before ToTensor, then normalization
# and random erasing on the tensor.
train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=10),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(
        p=0.2, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False
    ),
]
transform_train = transforms.Compose(train_steps)

# Evaluation pipeline: no augmentation, only conversion and normalization.
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(cifar10_mean, cifar10_std)]
)

# Both splits share the same download location and download flag.
cifar10_common = {"root": "data", "download": True}
train_data = datasets.CIFAR10(train=True, transform=transform_train, **cifar10_common)
val_data = datasets.CIFAR10(train=False, transform=transform_test, **cifar10_common)

In [None]:
# Each entry is a zero-argument model constructor (the class itself).
model_list: list[type[nn.Module]] = [Baseline, GLUAttention]
n_runs = 10  # independent repetitions per architecture

# Loaders and loss are identical for every run and architecture, so build
# them once instead of on every iteration.
train_dataloader = DataLoader(
    train_data, batch_size, shuffle=True, num_workers=4, pin_memory=True
)
val_dataloader = DataLoader(val_data, batch_size, num_workers=4, pin_memory=True)
loss_fn = nn.CrossEntropyLoss()

for model_creator in model_list:
    for run in range(n_runs):
        # Seed torch's global RNG per run (model init, shuffling) so every run
        # is reproducible.
        torch.manual_seed(run)
        model = model_creator().to(device)
        if run == 0:
            print(model)

        optimizer = torch.optim.AdamW(model.parameters(), learn_rate)
        # One cosine cycle spanning the whole run; stepped once per epoch.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

        train_total_time = 0.0
        val_total_time = 0.0
        print(
            "epoch, train acc,   val acc,train loss,  val loss,train time,  val time,total time"
        )
        for epoch in range(1, epochs + 1):
            train_acc, train_loss, train_time = train(
                train_dataloader, model, loss_fn, optimizer
            )
            scheduler.step()
            val_acc, val_loss, val_time = val(val_dataloader, model, loss_fn)
            train_total_time += train_time
            val_total_time += val_time

            print(
                f"{epoch:>5},{train_acc:>10.3f},{val_acc:>10.3f},{train_loss:>10f},{val_loss:>10f},{train_time:>10.1f},{val_time:>10.1f},{train_total_time + val_total_time:>10.1f}"
            )