In [2]:
import torch

def load_data(root='data/cifar10'):
    """Load every ``.jpg`` file found directly under ``root`` into memory.

    Each image is turned into a float tensor, divided by 255, and its axes
    are reordered with ``permute(2, 0, 1)`` so the last axis of the decoded
    array comes first (channels-first for a 3-channel image). The label is
    the integer value of the first character of the file name.

    Args:
        root: Directory that holds the image files. Defaults to the folder
            this notebook has always used, ``'data/cifar10'``.

    Returns:
        A pair ``(xs, ys)``: a list of image tensors and a list of int
        labels, in the (arbitrary) order ``os.listdir`` yields the files.

    Raises:
        ValueError: if a ``.jpg`` file name does not start with a digit.
    """
    import PIL.Image
    import numpy as np
    import os

    xs = []
    ys = []

    for filename in os.listdir(root):
        if not filename.endswith('.jpg'):
            continue

        # Open inside a context manager so the underlying file handle is
        # released as soon as the pixels are copied out, instead of relying
        # on garbage collection across tens of thousands of files.
        with PIL.Image.open(os.path.join(root, filename)) as image:
            pixels = np.array(image)

        x = torch.FloatTensor(pixels) / 255

        # Move the trailing axis to the front; assumes the decoded array has
        # exactly three axes (i.e. a colour image).
        x = x.permute(2, 0, 1)

        # The class id is encoded as the first character of the file name.
        y = int(filename[0])

        xs.append(x)
        ys.append(y)

    return xs, ys

# Read the whole image folder into memory once; xs holds the image tensors
# and ys the matching integer labels.
xs, ys = load_data()

# Sanity check: sample counts, plus the shape and label of the first example.
len(xs), len(ys), xs[0].shape, ys[0]

(60000, 60000, torch.Size([3, 32, 32]), 6)

In [3]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing image tensors with integer labels.

    Args:
        images: Optional indexable collection of samples.
        labels: Optional indexable collection of labels, same length as
            ``images``.

    When both arguments are omitted the dataset reads the module-level
    ``xs`` and ``ys`` on every access, which is the original behaviour of
    ``Dataset()`` in this notebook.

    Raises:
        ValueError: if only one of ``images`` / ``labels`` is given, or if
            their lengths differ.
    """

    def __init__(self, images=None, labels=None):
        if (images is None) != (labels is None):
            raise ValueError('images and labels must be passed together')
        if images is not None and len(images) != len(labels):
            raise ValueError('images and labels must have the same length')
        self._images = images
        self._labels = labels

    def _data(self):
        """Return the (images, labels) pair this dataset currently serves."""
        if self._images is None:
            # Late lookup of the notebook globals keeps ``Dataset()`` in sync
            # with whatever ``xs`` / ``ys`` are bound to at access time.
            return xs, ys
        return self._images, self._labels

    def __len__(self):
        return len(self._data()[0])

    def __getitem__(self, i):
        images, labels = self._data()
        return images[i], labels[i]
    
# Wrap the in-memory samples so a DataLoader can index them.
dataset = Dataset()
# Pull the first (image, label) pair as a quick check.
x, y = dataset[0]

# Sanity check: dataset size, image tensor shape and label of sample 0.
len(dataset), x.shape, y

(60000, torch.Size([3, 32, 32]), 6)

In [4]:
# Mini-batches of 8 samples; shuffle=True draws a new random order each time
# an iterator is created, and drop_last=True discards a final partial batch.
loader = torch.utils.data.DataLoader(dataset,
                                    batch_size=8,
                                    shuffle=True,
                                    drop_last=True)
# Fetch one batch to inspect what the loader yields.
x, y = next(iter(loader))

# Sanity check: number of batches per pass, batch tensor shape, batch labels.
len(loader), x.shape, y

(7500, torch.Size([8, 3, 32, 32]), tensor([2, 9, 9, 7, 9, 2, 4, 8]))

In [8]:
 class Model(torch.nn.Module):
    """Three-convolution classifier that emits 10 scores per input image.

    For a 3x32x32 input the spatial size goes 32 -> 14 (cnn1, stride 2)
    -> 14 (cnn2, padded) -> 7 (max-pool) -> 1 (cnn3, 7x7 kernel), so the
    flattened feature vector has exactly the 128 values ``fc`` expects.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Convolutional stack: (in_channels, out_channels) given positionally.
        self.cnn1 = nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=0)
        self.cnn2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.cnn3 = nn.Conv2d(32, 128, kernel_size=7, stride=1, padding=0)

        # Parameter-free layers, shared wherever they are needed in forward().
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()

        # Final linear layer: 128 features in, 10 class scores out.
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        """Run a batch through the network and return the raw class scores."""
        hidden = self.relu(self.cnn1(x))
        hidden = self.relu(self.cnn2(hidden))
        hidden = self.pool(hidden)
        hidden = self.relu(self.cnn3(hidden))

        # Keep the batch axis, collapse everything else into one vector.
        features = torch.flatten(hidden, start_dim=1)
        return self.fc(features)
    
# Build the network with freshly initialised weights.
model = Model()
# Smoke test: a random batch shaped (8, 3, 32, 32) should give one row of
# 10 scores per sample.
model(torch.randn(8, 3, 32, 32)).shape

torch.Size([8, 10])

In [9]:
def train(epochs=5, lr=1e-3, log_every=2000, save_path='model/3.model'):
    """Train the module-level ``model`` on batches from the module-level ``loader``.

    Uses Adam with cross-entropy loss, prints
    ``epoch, batch index, loss, batch accuracy`` every ``log_every`` batches,
    and finally saves the whole model object with ``torch.save``.

    Args:
        epochs: Number of full passes over ``loader``.
        lr: Learning rate handed to Adam.
        log_every: Print progress whenever the batch index is a multiple of
            this value.
        save_path: File the trained model object is written to.

    Returns:
        None.
    """
    import os

    # Create the output directory up front: a missing folder would otherwise
    # only surface in torch.save() after all epochs, discarding the run.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fun = torch.nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs):
        for i, (x, y) in enumerate(loader):
            out = model(x)
            loss = loss_fun(out, y)

            loss.backward()
            optimizer.step()
            # Clear gradients after the step so the next backward() starts
            # from zero.
            optimizer.zero_grad()

            if i % log_every == 0:
                # Accuracy of this single batch only, so it is a noisy signal.
                acc = (out.argmax(dim=1) == y).sum().item() / len(y)
                print(epoch, i, loss.item(), acc)

    # Saves the full module object (not just a state_dict); the evaluation
    # cell loads it back the same way.
    torch.save(model, save_path)


# Run training; train() also writes the model to 'model/3.model' when done.
train()

0 0 2.3080222606658936 0.125
0 2000 1.7882035970687866 0.25
0 4000 1.4761797189712524 0.375
0 6000 1.4885742664337158 0.25
1 0 0.8778889179229736 0.75
1 2000 1.161247968673706 0.5
1 4000 1.0296813249588013 0.5
1 6000 1.4031336307525635 0.625
2 0 0.8329468369483948 0.75
2 2000 1.3011161088943481 0.625
2 4000 0.7739822864532471 0.75
2 6000 1.3758182525634766 0.375
3 0 1.3978806734085083 0.75
3 2000 0.7546510100364685 0.75
3 4000 0.534713864326477 0.75
3 6000 0.7918781638145447 0.75
4 0 1.0472137928009033 0.625
4 2000 0.8900766372680664 0.625
4 4000 1.5932656526565552 0.5
4 6000 0.5175949931144714 0.75


In [12]:
@torch.no_grad()
def test():
    model = torch.load('model/3.model')
    model.eval()

    correct = 0
    total = 0
    for i in range(100):
        x, y = next(iter(loader))

        out = model(x).argmax(dim=1)
        print(out)

        correct += (out == y).sum().item()
        total += len(y)

    print(correct / total)


# Reload the saved model and print its accuracy on batches from `loader`.
test()

tensor([0, 4, 2, 2, 6, 5, 7, 3])
tensor([5, 3, 1, 4, 7, 6, 7, 7])
tensor([4, 9, 8, 0, 0, 6, 6, 4])
tensor([4, 6, 9, 3, 4, 8, 6, 9])
tensor([7, 3, 6, 1, 4, 1, 8, 4])
tensor([3, 1, 0, 5, 1, 1, 1, 6])
tensor([6, 0, 4, 6, 7, 4, 6, 9])
tensor([9, 0, 4, 3, 6, 0, 4, 1])
tensor([0, 2, 0, 5, 5, 1, 3, 4])
tensor([1, 3, 6, 5, 0, 9, 9, 2])
tensor([8, 1, 7, 5, 6, 8, 7, 6])
tensor([4, 4, 6, 7, 4, 7, 2, 5])
tensor([2, 4, 3, 8, 0, 4, 3, 7])
tensor([1, 9, 6, 0, 1, 5, 5, 8])
tensor([4, 4, 8, 4, 9, 4, 4, 4])
tensor([0, 0, 8, 3, 1, 1, 0, 5])
tensor([5, 4, 0, 9, 7, 6, 0, 3])
tensor([2, 2, 4, 4, 5, 3, 4, 3])
tensor([9, 2, 7, 1, 6, 5, 4, 8])
tensor([9, 7, 4, 3, 1, 4, 7, 0])
tensor([7, 1, 2, 9, 9, 9, 3, 7])
tensor([5, 3, 7, 4, 4, 4, 7, 3])
tensor([7, 5, 3, 4, 1, 3, 4, 8])
tensor([0, 6, 6, 1, 7, 5, 5, 5])
tensor([5, 0, 7, 5, 1, 4, 5, 6])
tensor([4, 0, 7, 9, 9, 7, 3, 7])
tensor([1, 8, 2, 4, 0, 4, 0, 8])
tensor([1, 3, 2, 7, 1, 3, 0, 6])
tensor([8, 7, 4, 9, 3, 5, 8, 0])
tensor([9, 7, 5, 3, 1, 6, 4, 5])
tensor([7,