In [None]:
import os

import fastai.vision.core
import torch
from PIL import Image
from fastai.data.core import DataLoaders
from fastai.data.load import DataLoader
from fastai.learner import Learner
from fastai.metrics import accuracy
from fastai.optimizer import SGD
from fastai.vision.data import ImageDataLoaders
from fastai.vision.learner import vision_learner
from fastbook import Path, tensor
from torch import Tensor, nn
import torch.nn.functional as F
from torchvision.models import resnet18

In [None]:
mnist_root = os.environ.get('MNIST_PATH')
if not mnist_root:
    # Fail fast with an actionable message instead of Path(None)'s opaque TypeError.
    raise RuntimeError(
        "MNIST_PATH is not set; point it at the MNIST image folder containing train/ and valid/"
    )
path = Path(mnist_root)


def digit_files(split: str, digit: int):
    """Return the sorted image files for one digit in one split, i.e. ``path/split/digit/*``."""
    return (path / split / str(digit)).ls().sorted()


training_ones = digit_files("train", 1)
training_fives = digit_files("train", 5)
testing_ones = digit_files("valid", 1)
testing_fives = digit_files("valid", 5)
training_ones, training_fives, testing_ones, testing_fives


In [None]:
def load_images(files) -> Tensor:
    """Load image files into one float tensor with pixel values scaled to [0, 1].

    Args:
        files: non-empty sequence of image paths; all images must share one size.

    Returns:
        The per-image tensors stacked along a new leading dimension of length ``len(files)``.

    Raises:
        ValueError: if ``files`` is empty (``torch.stack`` would fail with a less clear error).
    """
    if len(files) == 0:
        raise ValueError("no image files to load")
    pixels = []
    for file in files:
        # Context manager so the file handle is released as soon as the pixels are read.
        with Image.open(file) as img:
            pixels.append(tensor(img))
    return torch.stack(pixels).float() / 255


def flatten_images(images: Tensor) -> Tensor:
    """Flatten each 28x28 image into a row of 28 * 28 pixels -> shape (n, 784)."""
    return images.reshape(len(images), 28 * 28)


def binary_labels(n_positive: int, n_negative: int) -> Tensor:
    """Return int64 labels: ``n_positive`` ones followed by ``n_negative`` zeros.

    The result has shape ``(n_positive + n_negative, 1)``, so each label lines up
    element-wise with the single output column of ``nn.Linear(28 * 28, 1)``.
    A flat ``(N,)`` label vector would broadcast against ``(B, 1)`` predictions
    into a ``(B, B)`` grid in the loss and the metric.
    """
    return torch.cat([
        torch.full((n_positive, 1), 1, dtype=torch.long),
        torch.full((n_negative, 1), 0, dtype=torch.long),
    ])


training_ones_tensor = load_images(training_ones)
testing_ones_tensor = load_images(testing_ones)
training_fives_tensor = load_images(training_fives)
testing_fives_tensor = load_images(testing_fives)

# Digit "1" is the positive class (label 1); digit "5" is the negative class (label 0).
train_x = torch.cat([flatten_images(training_ones_tensor), flatten_images(training_fives_tensor)])
train_y = binary_labels(len(training_ones_tensor), len(training_fives_tensor))
valid_x = torch.cat([flatten_images(testing_ones_tensor), flatten_images(testing_fives_tensor)])
valid_y = binary_labels(len(testing_ones_tensor), len(testing_fives_tensor))

assert len(train_x) == len(train_y) and len(valid_x) == len(valid_y)
train_x.shape, train_y.shape, valid_x.shape, valid_y.shape


In [None]:
BATCH_SIZE = 256

# Shuffle the training set: train_x holds every "1" followed by every "5", so
# unshuffled batches would contain a single class for most of each epoch.
train_dl = DataLoader(list(zip(train_x, train_y)), batch_size=BATCH_SIZE, shuffle=True)
# Validation order does not affect the averaged metrics, so it stays unshuffled.
valid_dl = DataLoader(list(zip(valid_x, valid_y)), batch_size=BATCH_SIZE)
dls = DataLoaders(train_dl, valid_dl)

In [None]:
def batch_accuracy(xb, yb):
    """Fraction of a batch classified correctly by a single-logit binary classifier.

    Args:
        xb: raw model outputs (logits). An item is predicted as class 1 when
            ``sigmoid(logit) > 0.5``.
        yb: 0/1 targets with the same number of elements as ``xb``, for example
            shape ``(B,)`` or ``(B, 1)``.

    Returns:
        Scalar float tensor with the mean accuracy over the batch.
    """
    preds = xb.sigmoid()
    # Align the target with the predictions element-wise. Comparing (B, 1)
    # predictions against a (B,) target would broadcast to (B, B) and average
    # every prediction against every label.
    targets = yb.reshape(preds.shape)
    correct = (preds > 0.5) == targets
    return correct.float().mean()

def loss(pred: torch.Tensor, target: torch.Tensor):
    """Mean distance between the predicted probability and a 0/1 target.

    With ``p = sigmoid(pred)``, each item contributes ``1 - p`` when its target
    is 1 and ``p`` otherwise. The contributions are averaged over the batch.

    Args:
        pred: raw model outputs (logits).
        target: 0/1 labels with the same number of elements as ``pred``,
            for example shape ``(B,)`` or ``(B, 1)``.

    Returns:
        Scalar tensor.
    """
    probs = pred.sigmoid()
    # Reshape so each label lines up with its own prediction. A (B,) target
    # against (B, 1) predictions would otherwise broadcast to (B, B).
    target = target.reshape(probs.shape)
    return torch.where(target == 1, 1 - probs, probs).mean()

# Linear baseline: one logit per flattened 28x28 image, optimised with fastai's SGD.
# NOTE(review): 1e-3 is a small peak LR for SGD on this loss; the weights may barely
# move in 10 epochs. Consider tuning it.
learn = Learner(
    dls,
    nn.Linear(28 * 28, 1),
    opt_func=SGD,
    loss_func=loss,
    metrics=batch_accuracy,
)
learn.fit_one_cycle(10, lr_max=1e-3)

In [None]:
# All-digit classifier: from_folder labels each image by its parent folder name
# (e.g. path/train/3/...). The model is a ResNet-18 trained from scratch
# (pretrained=False).
dls = ImageDataLoaders.from_folder(path, train='train', valid='valid')
learn: Learner = vision_learner(
    dls,
    resnet18,
    pretrained=False,
    loss_func=F.cross_entropy,
    metrics=accuracy,
)
learn.fit_one_cycle(3, lr_max=0.1)

In [None]:
# Pickle the trained learner to `export.pkl` under the learner's path (the MNIST
# root, since its DataLoaders were built from `path`).
learn.export(fname='export.pkl')

In [None]:
# Does PyTorch see a CUDA device, and which CUDA version was it built against?
# (torch is already imported in the first cell.)
torch.cuda.is_available(), torch.version.cuda

In [None]:
from fastai.learner import load_learner

# Check for the sample image first, so a bad path fails before the model is loaded.
img_to_test = path / 'valid' / '0' / '1001.png'
if not img_to_test.exists():
    raise FileNotFoundError(f'file {img_to_test} does not exist')

# load_learner unpickles the file, which can execute arbitrary code:
# only load exports you produced yourself (like the learn.export() above).
learn_inf = load_learner(path / 'export.pkl', cpu=False)
learn_inf.predict(img_to_test)