In [2]:
import torch

def load_data():
    import PIL.Image
    import numpy as np
    import os
    
    xs = []
    ys = []
    
    for filename in os.listdir('data/cifar10'):
        if not filename.endswith('.jpg'):
            continue
        
        x = PIL.Image.open('data/cifar10/%s' % filename)
        
        x = torch.FloatTensor(np.array(x)) / 255
        
        x = x.permute(2, 0, 1)
        
        y = int(filename[0])
        
        xs.append(x)
        ys.append(y)
        
    return xs, ys

xs, ys = load_data()

len(xs), len(ys), xs[0].shape, ys[0]

(60000, 60000, torch.Size([3, 32, 32]), 6)

In [3]:
class Dataset(torch.utils.data.Dataset):
    def __len__(self):
        return len(xs)
    def __getitem__(self, i):
        return xs[i], ys[i]
    
dataset = Dataset()
x, y = dataset[0]

len(dataset), x.shape, y

(60000, torch.Size([3, 32, 32]), 6)

In [4]:
loader = torch.utils.data.DataLoader(dataset,
                                    batch_size=8,
                                    shuffle=True,
                                    drop_last=True)
x, y = next(iter(loader))

len(loader), x.shape, y

(7500, torch.Size([8, 3, 32, 32]), tensor([2, 9, 9, 7, 9, 2, 4, 8]))

In [8]:
 class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.cnn1 = torch.nn.Conv2d(in_channels=3,
                                    out_channels=16,
                                    kernel_size=5,
                                    stride=2,
                                    padding=0)
        self.cnn2 = torch.nn.Conv2d(in_channels=16,
                                    out_channels=32,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)
        self.cnn3 = torch.nn.Conv2d(in_channels=32,
                                    out_channels=128,
                                    kernel_size=7,
                                    stride=1,
                                    padding=0)
        
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        x = self.cnn1(x)
        x = self.relu(x)

        x = self.cnn2(x)
        x = self.relu(x)

        x = self.pool(x)

        x = self.cnn3(x)
        x = self.relu(x)

        x = x.flatten(start_dim=1)

        return self.fc(x)
    
model = Model()
model(torch.randn(8, 3, 32, 32)).shape

torch.Size([8, 10])

In [None]:
def train():
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fun = torch.nn.CrossEntropyLoss()
    model.train()

    for epoch in range(5):
        for i, (x, y) in enumerate(loader):
            out = model(x)
            loss = loss_fun(out, y)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if i % 2000 == 0:
                acc = (out.argmax(dim=1) == y).sum().item() / len(y)
                print(epoch, i, loss.item(), acc)

    torch.save(model, 'model/3.model')


train()