Download dataset

In [2]:
!gsutil -m cp \
  "gs://quickdraw_dataset/full/numpy_bitmap/apple.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/banana.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/book.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/cake.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/car.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/cloud.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/cup.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/door.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/envelope.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/eyeglasses.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/finger.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/fish.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/flower.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/fork.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/grass.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/guitar.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/hamburger.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/headphones.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/house.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/ice cream.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/knife.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/leaf.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/pants.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/pencil.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/pizza.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/spoon.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/star.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/sword.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/t-shirt.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/umbrella.npy" ./data/

Copying gs://quickdraw_dataset/full/numpy_bitmap/apple.npy...
Copying gs://quickdraw_dataset/full/numpy_bitmap/banana.npy...                  
==> NOTE: You are downloading one or more large file(s), which would            
run significantly faster if you enabled sliced object downloads. This
feature is enabled by default but requires that compiled crcmod be
installed (see "gsutil help crcmod").

Copying gs://quickdraw_dataset/full/numpy_bitmap/book.npy...
Copying gs://quickdraw_dataset/full/numpy_bitmap/cake.npy...                    
Copying gs://quickdraw_dataset/full/numpy_bitmap/car.npy...                     
Copying gs://quickdraw_dataset/full/numpy_bitmap/cloud.npy...                   
Copying gs://quickdraw_dataset/full/numpy_bitmap/cup.npy...                     
Copying gs://quickdraw_dataset/full/numpy_bitmap/door.npy...                    
Copying gs://quickdraw_dataset/full/numpy_bitmap/envelope.npy...                
Copying gs://quickdraw_dataset/full/numpy_bitmap/eyeg

Download all the required libraries

Import all libraries

In [3]:
import os
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset
import torch.nn.functional as F
import copy
from sklearn.metrics import accuracy_score

Basic constant variables

In [7]:
# --- Paths -------------------------------------------------------------------
# The dataset directory and the checkpoint both live under the working directory.
BASE_DIR = os.getcwd()
DIR_FILE = os.path.join(BASE_DIR, 'data')

# os.listdir() returns entries in arbitrary order. List the directory once, keep
# only the .npy bitmaps and sort them, so that the label index -> class name
# mapping is identical in every session (required when resuming from a checkpoint)
# and FILE_NAME / CLASS_NAMES are guaranteed to be aligned with each other.
_NPY_FILES = sorted(f for f in os.listdir(DIR_FILE) if f.endswith('.npy'))
FILE_NAME = [os.path.join(DIR_FILE, f) for f in _NPY_FILES]
CLASS_NAMES = [os.path.splitext(f)[0] for f in _NPY_FILES]
NUM_CLASSES = len(FILE_NAME)

# --- Training hyper-parameters -----------------------------------------------
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_EPOCHS = 20
BATCH_SIZE = 64
LEARNING_RATE = 0.001

# Fraction of the dataset used for training; the remainder is used for validation.
TRAIN_SPLIT = 0.8
CHECKPOINT_PATH = os.path.join(BASE_DIR, 'model.pt')

# Minimum softmax probability for a prediction to be reported instead of "UNKNOWN".
THRESHOLD_PROB = 0.5

Dataset definition (memory-mapped loading of the .npy files)

In [8]:
class QuickDrawDataset(Dataset):
    """Concatenation of several `.npy` bitmap files exposed as one dataset.

    Each file is opened memory-mapped (read-only), so rows are only read from
    disk when they are indexed. The label of a sample is the position of its
    file in ``file_names``.

    Args:
        file_names: Paths of the `.npy` files, one per class. Every row is
            reshaped to 28x28, so rows must hold 784 values.
    """

    def __init__(self, file_names):
        self.file_names = file_names
        self.data_list = [np.load(f, mmap_mode='r') for f in file_names]
        self.labels = list(range(len(file_names)))
        # _ends[i] is the global index one past the last sample of file i.
        # Computed once so __len__ is O(1) and __getitem__ can binary-search.
        self._ends = np.cumsum([len(d) for d in self.data_list], dtype=np.int64)
        self._length = int(self._ends[-1]) if len(self._ends) else 0

    def __len__(self):
        """Return the total number of samples across all files."""
        return self._length

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the global sample index ``idx``.

        Returns:
            image: float32 tensor of shape (1, 28, 28) scaled to [0, 1] by
                dividing the stored values by 255.
            label: int index of the file the sample came from.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = self._length
        if idx < 0:
            # Standard Python semantics: negative indices count from the end.
            idx += size
        if not 0 <= idx < size:
            raise IndexError(f'index {idx} is out of range for dataset of size {size}')

        # side='right' picks the first file whose end offset is strictly greater
        # than idx, which also skips over any empty files.
        file_idx = int(np.searchsorted(self._ends, idx, side='right'))
        start = int(self._ends[file_idx - 1]) if file_idx > 0 else 0

        item = self.data_list[file_idx][idx - start].reshape(28, 28)
        item = torch.tensor(item, dtype=torch.float32).unsqueeze(0) / 255.0
        return item, self.labels[file_idx]

# Build one dataset over every class file; a file's position in FILE_NAME is its label.
dataset = QuickDrawDataset(FILE_NAME)

Helpers to load and save model checkpoints

In [9]:
def load_model(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    file_name: str,
) -> tuple[float, int]:
    """Restore model and optimizer state from a checkpoint file.

    Args:
        model: Module whose parameters are overwritten in place.
        optimizer: Optimizer whose state is overwritten in place.
        file_name: Path of a checkpoint holding the keys 'model_state_dict',
            'optimizer_state_dict', 'epoch' and 'loss'.

    Returns:
        ``(loss, epoch)`` as stored in the checkpoint.
    """
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
    # opened on a CPU-only one; load_state_dict() then copies the values onto
    # whatever device the model / optimizer already live on.
    checkpoint = torch.load(file_name, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['loss'], checkpoint['epoch']

def save_model(
    model: nn.Module,
    optimizer: torch.optim.Adam,
    epoch: int,
    file_name: str,
    loss: float
) -> None:
    """Write a training checkpoint to ``file_name``.

    The checkpoint is a dict with the epoch number, the model and optimizer
    state dicts, and the loss value, serialized with ``torch.save``.
    """
    torch.save(
        dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            loss=loss,
        ),
        file_name,
    )

Model draw_model

In [None]:
class DrawModel(nn.Module):
    """LeNet-style CNN that classifies 1x28x28 sketches into ``num_classes`` classes.

    Layout: conv(5x5, 6) -> avg-pool -> conv(5x5, 16) -> avg-pool -> flatten
    -> fc(120) -> fc(84) -> fc(num_classes). ``forward`` returns raw logits.
    """

    def __init__(self, num_classes = 3):
        super().__init__()
        # No layer is pinned to a device here: the caller moves the whole model
        # with .to(...), which keeps conv and linear layers on the same device.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),
            nn.ReLU()
        )
        self.pool1 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.ReLU()
        )
        self.pool2 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        # 28 -> 24 (conv) -> 12 (pool) -> 8 (conv) -> 4 (pool): 16 * 4 * 4 = 256 features
        self.flat = nn.Flatten()
        self.fc_1 = nn.Sequential(
            nn.Linear(in_features=16 * 4 * 4, out_features=120),
            nn.ReLU()
        )
        self.fc_2 = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.ReLU()
        )
        self.out = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, input):
        """Map a batch of shape (N, 1, 28, 28) to logits of shape (N, num_classes)."""
        input = self.conv1(input)
        input = self.pool1(input)
        input = self.conv2(input)
        input = self.pool2(input)
        input = self.flat(input)
        input = self.fc_1(input)
        input = self.fc_2(input)
        return self.out(input)

    def _run_epoch(self, loader, criterion, optimizer=None):
        """Run one pass over ``loader``; train if ``optimizer`` is given, else evaluate.

        Returns:
            ``(mean_batch_loss, accuracy)`` where accuracy is correct / total samples.
        """
        training = optimizer is not None
        self.train(training)
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.set_grad_enabled(training):
            for x_batch, y_batch in loader:
                x_batch, y_batch = x_batch.to(DEVICE), y_batch.to(DEVICE)
                logits = self(x_batch)
                batch_loss = criterion(logits, y_batch)

                if training:
                    optimizer.zero_grad()
                    batch_loss.backward()
                    optimizer.step()

                loss_sum += batch_loss.item()
                # Count hits on-device; dividing by the number of samples (not
                # batches) keeps a short final batch from skewing the accuracy.
                correct += (logits.argmax(dim=1) == y_batch).sum().item()
                seen += y_batch.size(0)
        return loss_sum / len(loader), correct / seen

    def fit(
        self,
        dataset: Dataset,
        loss: float,
        optimizer: torch.optim.Optimizer,
        criterion: nn.CrossEntropyLoss = nn.CrossEntropyLoss(),
        start_epoch: int = 1,
        split_seed: int = 42,
    ) -> None:
        """Train from ``start_epoch`` to ``NUM_EPOCHS``, checkpointing on improvement.

        Args:
            dataset: Samples yielding ``(image, label)``; split into train /
                validation parts according to ``TRAIN_SPLIT``.
            loss: Best validation loss reached so far. A checkpoint is written
                only when an epoch's validation loss drops below it.
            optimizer: Optimizer bound to this model's parameters.
            criterion: Loss applied to ``(logits, labels)``.
            start_epoch: First epoch number to run (1-based).
            split_seed: Seed for the train/validation split. A fixed seed gives a
                resumed run the same split, so samples already trained on do
                not end up in the validation set.

        Raises:
            ValueError: if the split would leave the train or validation set empty.
        """
        if start_epoch > NUM_EPOCHS:
            print(f'Training complete with full epoch')
            return

        train_size = int(TRAIN_SPLIT * len(dataset))
        valid_size = len(dataset) - train_size
        if train_size == 0 or valid_size == 0:
            raise ValueError(
                f'dataset of size {len(dataset)} is too small to split with TRAIN_SPLIT={TRAIN_SPLIT}'
            )
        split_generator = torch.Generator().manual_seed(split_seed)
        train_dataset, valid_dataset = random_split(
            dataset, [train_size, valid_size], generator=split_generator
        )

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)

        best_model = copy.deepcopy(self.state_dict())
        best_valid_loss = loss

        for epoch in range(start_epoch, NUM_EPOCHS + 1):
            print(f'Running with epoch: {epoch} / {NUM_EPOCHS}')
            loss_train, acc_train_score = self._run_epoch(train_loader, criterion, optimizer)
            loss_valid, acc_valid_score = self._run_epoch(valid_loader, criterion)

            print(f'Training loss: {loss_train:.4f}. Training accuracy: {acc_train_score:.4f}')
            print(f'Validation loss: {loss_valid:.4f}. Validation accuracy: {acc_valid_score:.4f}')
            if loss_valid < best_valid_loss:
                best_valid_loss = loss_valid
                best_model = copy.deepcopy(self.state_dict())
                save_model(model=self, loss=best_valid_loss, optimizer=optimizer, epoch=epoch, file_name=CHECKPOINT_PATH)
                print(f'Best validation loss: {best_valid_loss:.4f}')
                # Only report a save when one actually happened.
                print(f'Model saved to {CHECKPOINT_PATH} successfully')

        # Restore the best weights once, after training. Doing this inside the
        # loop would roll the weights back each epoch while the optimizer state
        # keeps moving forward.
        self.load_state_dict(best_model)

    def load(
        self,
        file_name: str
    ) -> nn.Module:
        """Load weights from the checkpoint at ``file_name`` and switch to eval mode.

        Returns:
            ``self``, so calls can be chained.
        """
        # Deserialize onto the CPU so GPU-written checkpoints open on CPU-only
        # machines; load_state_dict() copies onto the model's current device.
        checkpoint = torch.load(file_name, map_location='cpu')
        self.load_state_dict(checkpoint['model_state_dict'])
        self.eval()
        return self

    def predict(self, input):
        """Classify a batch of images using the weights stored at ``CHECKPOINT_PATH``.

        The checkpoint is reloaded on every call, replacing the in-memory weights.

        Returns:
            One entry per sample: the class name from ``CLASS_NAMES``, or
            "UNKNOWN" when the top softmax probability is below ``THRESHOLD_PROB``.
        """
        self.load(CHECKPOINT_PATH).to(DEVICE)
        input = input.to(DEVICE)
        with torch.no_grad():
            logits = self(input)
            preds = F.softmax(logits, dim=1)
            values, indices = preds.max(dim=1)
            return [CLASS_NAMES[idx.item()] if value.item() >= THRESHOLD_PROB else "UNKNOWN" for value, idx in zip(values, indices)]


Training model

In [11]:
model = DrawModel(num_classes=NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

start_epoch = 1
# Best validation loss so far. Start at +inf so the first epoch always produces
# a checkpoint, whatever the scale of the loss.
loss = float('inf')
try:
    if os.path.exists(CHECKPOINT_PATH):
        # load_model returns (loss, epoch); unpack first, then advance the epoch.
        # Adding 1 to the returned tuple raises TypeError.
        loss, last_epoch = load_model(model, optimizer, CHECKPOINT_PATH)
        start_epoch = last_epoch + 1
        print(f'Resume training from epoch {start_epoch}')
except Exception as e:
    # Best effort: fall back to training from epoch 1. Note that the model or
    # optimizer may already hold partially loaded state at this point.
    print(f'Error loading model: {e}')

model.fit(dataset, loss, optimizer=optimizer, criterion=criterion, start_epoch=start_epoch)

Running with epoch: 1 / 20
Training loss: 0.5056. Training accuracy: 0.8630
Validation loss: 0.3661. Validation accuracy: 0.9013
Best validation loss: 0.3661
Model saved to /kaggle/working/model.pt successfully
Running with epoch: 2 / 20
Training loss: 0.3433. Training accuracy: 0.9074
Validation loss: 0.3319. Validation accuracy: 0.9094
Best validation loss: 0.3319
Model saved to /kaggle/working/model.pt successfully
Running with epoch: 3 / 20
Training loss: 0.3211. Training accuracy: 0.9134
Validation loss: 0.3150. Validation accuracy: 0.9152
Best validation loss: 0.3150
Model saved to /kaggle/working/model.pt successfully
Running with epoch: 4 / 20
Training loss: 0.3112. Training accuracy: 0.9159
Validation loss: 0.3084. Validation accuracy: 0.9164
Best validation loss: 0.3084
Model saved to /kaggle/working/model.pt successfully
Running with epoch: 5 / 20
Training loss: 0.3055. Training accuracy: 0.9174
Validation loss: 0.2987. Validation accuracy: 0.9194
Best validation loss: 0.298