In [1]:
from utils_classification import *

# Load data

## MNIST

In [2]:
# Size of a single image and how many samples are read from each split
image_shape = (28, 28)
train_set_size = 60000
test_set_size = 10000

# The four archives sit side by side in data/MNIST
mnist_folder = os.path.join('data', 'MNIST')

train_images_path, train_labels_path, test_images_path, test_labels_path = (
    os.path.join(mnist_folder, archive_name)
    for archive_name in (
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz',
    )
)

# Images first, then labels; each call is told how many samples to read
train_images = extract_data(train_images_path, image_shape, train_set_size)
test_images = extract_data(test_images_path, image_shape, test_set_size)

train_labels = extract_labels(train_labels_path, train_set_size)
test_labels = extract_labels(test_labels_path, test_set_size)

## Segmented cards from task 0

In [3]:
# Location of the pickled segmentation results produced for task 0
classification_dir = os.path.join('data', 'classification_data')
filepath_task0 = os.path.join(classification_dir, 'all_games_classification_series_mnist.pickle')

# NOTE(review): the file is a .pickle, so the loader presumably unpickles it --
# only point this at files you trust.
# Three frames come back: everything, the number cards only, and the remaining (non-number) cards.
(df_task0, df_numbers_only, df_not_numbers) = load_segmentation_task0(filepath_task0)

In [4]:
# Random affine jitter (rotation, rescaling, shear) used to augment the figure crops
figure_augmentations = [
    transforms.ToPILImage(),
    transforms.RandomAffine(degrees=25, scale=(0.7, 1), shear=25),
    transforms.ToTensor(),
]
transform = transforms.Compose(figure_augmentations)

# Class index given to each figure; indices 0-9 stay reserved for the digits
labels_figures = {'J': 10, 'Q': 11, 'K': 12}

# Split the non-number cards into augmented train / augmented validation / test sets
(
    train_augmented_figs,
    train_augmented_figs_labels,
    val_augmented_figs,
    val_augmented_figs_labels,
    test_figs,
    test_figs_labels,
) = get_train_val_test_figures(df_not_numbers, labels_figures, transform)

## Data loader and augmentation

In [5]:
# Affine jitter for the MNIST digits: slightly milder shear and a scale range
# that also allows a small zoom-in, compared with the figure transform
mnist_augmentations = [
    transforms.ToPILImage(),
    transforms.RandomAffine(degrees=25, scale=(0.8, 1.1), shear=20),
    transforms.ToTensor(),
]
transform_aug = transforms.Compose(mnist_augmentations)

# Augmented counterpart of the MNIST training images
augmented_train = augment_data(train_images, transform_aug)

In [6]:
# Settings shared by every DataLoader in this notebook, including the evaluation loaders built later
params = dict(batch_size=128, shuffle=True)

# Training pool: raw MNIST digits + their augmented versions + augmented figure crops.
# train_labels is repeated for augmented_train, which was generated from train_images
# (assumes augment_data keeps sample order and count -- TODO confirm).
train_merged_images = np.concatenate([train_images, augmented_train, train_augmented_figs])
train_merged_labels = np.concatenate([train_labels, train_labels, train_augmented_figs_labels])

# Evaluation pool: MNIST test digits + the validation split of the figure crops
test_merged_images = np.concatenate([test_images, val_augmented_figs])
test_merged_labels = np.concatenate([test_labels, val_augmented_figs_labels])

# Wrap each pool in a dataset (images converted with ToTensor) and a batched loader
ds_train = MNISTDataset(
    train_merged_images,
    train_merged_labels,
    transform=transforms.ToTensor(),
)
dl_train = data.DataLoader(ds_train, **params)

ds_test = MNISTDataset(
    test_merged_images,
    test_merged_labels,
    transform=transforms.ToTensor(),
)
dl_test = data.DataLoader(ds_test, **params)

In [7]:
# Convert every segmented number-card image with img_as_ubyte and stack them into one array
task0_numbers = np.array(list(df_numbers_only['image'].apply(img_as_ubyte)))

# The rank column provides the labels, cast to int64
task0_number_labels = df_numbers_only['rank'].to_numpy().astype(np.int64)

ds_test_task0 = MNISTDataset(
    task0_numbers,
    task0_number_labels,
    transform=transforms.ToTensor(),
)
dl_test_task0 = data.DataLoader(ds_test_task0, **params)

In [8]:
# Test split of the figure crops, with labels cast to int64
task0_figure_labels = test_figs_labels.astype(np.int64)

ds_test_task0_figs = MNISTDataset(
    test_figs,
    task0_figure_labels,
    transform=transforms.ToTensor(),
)
dl_test_task0_figs = data.DataLoader(ds_test_task0_figs, **params)

# Training and testing

## Models

In [9]:
# Fresh, untrained classifier built with the constructor's default arguments.
# NOTE(review): Net is not defined in this notebook; presumably it comes from the
# utils_classification star import -- confirm.
model = Net()

## Training

In [10]:
# Step size handed to the optimizer
learning_rate = 1e-3

# Cross-entropy loss for the classifier
criterion = nn.CrossEntropyLoss()

# Adam updates every parameter of the model
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [11]:
epochs = 7

# Run a fixed number of epochs; each one does a training pass over dl_train
# followed by an evaluation pass over dl_test
for epoch_number in range(1, epochs + 1):
    print(f"Epoch {epoch_number}\n-------------------------------")
    train_loop(dl_train, model, criterion, opt)
    test_loop(dl_test, model, criterion)

Epoch 1
-------------------------------
It 1055/1055:	Loss train: 0.06059, Accuracy train: 97.56%%
Test Error:
	Avg loss: 0.00038, Accuracy: 98.29%

Epoch 2
-------------------------------
It 1055/1055:	Loss train: 0.01649, Accuracy train: 100.00%
Test Error:
	Avg loss: 0.00037, Accuracy: 98.43%

Epoch 3
-------------------------------
It 1055/1055:	Loss train: 0.01611, Accuracy train: 100.00%
Test Error:
	Avg loss: 0.00062, Accuracy: 98.07%

Epoch 4
-------------------------------
It 1055/1055:	Loss train: 0.00509, Accuracy train: 100.00%
Test Error:
	Avg loss: 0.00075, Accuracy: 97.96%

Epoch 5
-------------------------------
It 1055/1055:	Loss train: 0.00857, Accuracy train: 100.00%
Test Error:
	Avg loss: 0.00095, Accuracy: 97.77%

Epoch 6
-------------------------------
It 1055/1055:	Loss train: 0.00703, Accuracy train: 100.00%
Test Error:
	Avg loss: 0.00065, Accuracy: 98.17%

Epoch 7
-------------------------------
It 1055/1055:	Loss train: 0.07759, Accuracy train: 97.56%%
Test Er

In [12]:
# Evaluate the trained model on the number cards segmented in task 0
# (dl_test_task0 wraps df_numbers_only with its rank column as labels)
test_loop(dl_test_task0, model, criterion)

Test Error:
	Avg loss: 0.00931, Accuracy: 82.44%



In [13]:
# Evaluate the trained model on the test split of the figure (J/Q/K) crops
test_loop(dl_test_task0_figs, model, criterion)

Test Error:
	Avg loss: 0.02140, Accuracy: 94.44%



## Save model and predict

In [14]:
# Directory where trained model weights are stored
models_dir = os.path.join('data', 'models')

# makedirs also creates a missing parent ('data') and, with exist_ok=True, is a
# no-op when the directory already exists -- no separate isdir check needed,
# and no window between the check and the creation.
os.makedirs(models_dir, exist_ok=True)

In [16]:
# Persist the model's state_dict (not the module object itself) inside models_dir
model_path = os.path.join(models_dir, "model.pt")
torch.save(model.state_dict(), model_path)

In [17]:
# Predict ranks for the task 0 number cards from the saved checkpoint.
# The path is built from models_dir, exactly as when the weights were saved,
# so the save and load locations cannot drift apart.
predict_rank(task0_numbers, Net, os.path.join(models_dir, "model.pt"))

tensor([ 8,  0,  5, 10, 10,  3,  2,  3,  0,  4,  4,  0,  6,  3,  2,  8,  8,  3,
         4,  7,  2,  6,  0,  7, 10,  5,  5,  0,  2,  6, 10,  4,  2,  6,  8,  2,
         1,  5,  0,  2,  2,  4, 10,  6, 10,  4,  8,  2,  5,  6,  5,  0,  0,  3,
         4, 10,  0,  7,  2,  2,  2,  1,  6,  4,  8,  2,  1,  8,  0,  3,  8,  5,
         8,  5,  3,  6,  6,  3,  2,  6,  2,  1,  1,  7,  6,  3,  2,  3,  7,  4,
         2,  6,  8,  7,  0,  8,  5,  0,  0,  5,  6,  3,  4,  0, 10,  2,  6,  2,
         4,  6,  5,  6,  5,  4,  2,  8,  8,  8,  8,  3,  2,  6,  3,  8,  3,  5,
         7,  4,  3,  8,  8,  8,  4,  2,  1,  0, 10,  0,  8,  8,  7,  0,  1,  2,
         0,  5,  2,  5,  4, 11,  2,  8,  4,  6,  3,  5,  6,  6,  2,  2,  2,  8,
         0,  2,  3,  2,  6,  8,  4,  4,  6,  3,  5,  5,  3,  6,  6,  3,  8,  5,
         1,  5,  2,  8,  4,  3,  4,  2,  6,  0, 10,  8,  7,  0,  0, 10,  6,  0,
         1,  4,  5,  4,  0,  7,  6, 10,  8,  5,  5,  3,  8,  1,  2,  8, 10,  6,
         2,  2,  3, 10,  0,  2,  3, 10, 