In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [2]:
trans = transforms.Compose([
    transforms.TrivialAugmentWide(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
train_data = datasets.CIFAR10(root="../data", download=False, train=True, transform=trans)
test_data = datasets.CIFAR10(root="../data", download=False, train=False, transform=transforms.ToTensor())

In [3]:
batch_size = 20
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
test_dataloader = DataLoader(test_data, shuffle=False, batch_size=batch_size)

In [4]:
def conv_block(in_channels, out_channels, pool=False):
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True)
             ]
    if pool: layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

In [5]:
class Resnet9(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.resnet1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))

        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.resnet2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))

        self.classifier = nn.Sequential(nn.MaxPool2d(4),
                                       nn.Flatten(),
                                       nn.Linear(512, num_classes))

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.resnet1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.resnet2(out) + out
        out = self.classifier(out)
        return out
        
device = "cuda"
model = Resnet9(3, 10).to(device)

In [6]:
torch.manual_seed(42)
test = torch.rand(1, 3, 32, 32).to(device)
model(test)

tensor([[-2.1876,  3.0599,  0.4671,  0.3286, -2.3667, -1.6634, -2.4628,  3.8249,
          0.6262,  0.7074]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [7]:
from torch.optim import lr_scheduler
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.0001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)
epochs = 15

In [9]:
from sklearn.metrics import accuracy_score
from timeit import default_timer as timer
start = timer()
for epoch in range(epochs):
    loss_train = 0
    for x, y in train_dataloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss_train += loss
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
    scheduler.step()

    loss_train /= len(train_dataloader)
    loss_test, acc =  0, 0
    model.eval()
    with torch.inference_mode():
        for x, y in test_dataloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss_test += criterion(y_pred, y)
            acc += accuracy_score(torch.softmax(y_pred, dim=1).argmax(dim=1).cpu(), y.cpu())
        loss_test /= len(test_dataloader)
        acc /= len(test_dataloader)
    print(f"epoch : {epoch}, loss : {loss_train:.4f}, loss_test : {loss_test:.4f}, acc : {(acc*100):.3f}, lr : {optimizer.state_dict()['param_groups'][0]['lr']:.5f}")
end = timer()
round(end-start), "sec"

epoch : 0, loss : 3.0364, loss_test : 1.8241, acc : 32.990, lr : 0.10000
epoch : 1, loss : 1.5945, loss_test : 1.0614, acc : 61.600, lr : 0.10000
epoch : 2, loss : 1.2576, loss_test : 0.8913, acc : 69.080, lr : 0.10000
epoch : 3, loss : 1.0894, loss_test : 0.6882, acc : 75.960, lr : 0.10000
epoch : 4, loss : 0.9746, loss_test : 0.6417, acc : 78.240, lr : 0.10000
epoch : 5, loss : 0.8938, loss_test : 0.5539, acc : 80.720, lr : 0.10000
epoch : 6, loss : 0.8437, loss_test : 0.5411, acc : 81.120, lr : 0.10000
epoch : 7, loss : 0.7884, loss_test : 0.5015, acc : 82.570, lr : 0.10000
epoch : 8, loss : 0.7622, loss_test : 0.5142, acc : 82.270, lr : 0.10000
epoch : 9, loss : 0.7263, loss_test : 0.5411, acc : 81.960, lr : 0.02000
epoch : 10, loss : 0.5714, loss_test : 0.3813, acc : 86.970, lr : 0.02000
epoch : 11, loss : 0.5376, loss_test : 0.3658, acc : 87.900, lr : 0.02000
epoch : 12, loss : 0.5163, loss_test : 0.3619, acc : 87.780, lr : 0.02000
epoch : 13, loss : 0.5077, loss_test : 0.3559, a

(299, 'sec')