## Plans for training

### 1. Train city_block to predict cities correctly

### 2. Train country_block to predict countries given the predicted cities

##### I believe passing the predicted city to country_block, rather than the predicted country to city_block, will benefit the model, because it is easier to predict the country from the city than the reverse. Our model could predict countries with 12% accuracy, so let's be optimistic and say it can also predict cities with 12% accuracy given the correct country. The final accuracy would then be 12% * 12% = 1.44%, since there is a 12% chance that the model receives the correct country and a 12% chance that it then predicts the correct city. Now suppose the model can predict cities with 12% accuracy without needing the country, and we have a lookup table that maps each city to its country. The accuracy would then be 12%, as we would not have to predict the country at all. Sadly, we don't have such a lookup table, so we have to create a model that acts as one. Say that model predicts the country with 50% accuracy given the city. Then the overall accuracy is 12% * 50% = 6%, which is much better and easier to reach than the previous idea.
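The arithmetic above can be checked with a short sketch. It assumes the two stages succeed independently, so the chance of getting both right is the product of the per-stage accuracies. The 12% and 50% figures are the illustrative values from the text, and `pipeline_accuracy` is a hypothetical helper name, not part of the project code.

```python
def pipeline_accuracy(first_stage: float, second_stage: float) -> float:
    """Chance that both stages are right, assuming independent stages."""
    return first_stage * second_stage


# Country first (12%), then city given the correct country (assumed 12%).
country_then_city = pipeline_accuracy(0.12, 0.12)

# City first (12%), then country given the correct city (assumed 50%).
city_then_country = pipeline_accuracy(0.12, 0.50)

print(f"country -> city: {country_then_city:.2%}")  # 1.44%
print(f"city -> country: {city_then_country:.2%}")  # 6.00%
```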

In [1]:
import sys

sys.path.append('../')

import os
import shutil

if os.path.exists('./runs/booking'):
    shutil.rmtree('./runs/booking')


os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = '1'

In [2]:
from Model import Model

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./runs/booking')

torch.autograd.set_detect_anomaly(True)

2024-01-27 16:34:49.905433: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-27 16:34:49.905482: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-27 16:34:49.906302: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-27 16:34:49.912299: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7ff56fc7f970>

In [3]:
BATCH_SIZE = 1024
EPOCHS = 10
LEARNING_RATE = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
X = torch.load('X.pt')
# y_country = torch.load('y_country.pt')
y_city = torch.load('y_city.pt')

In [5]:
from torch.utils.data import DataLoader
from Dataset import PartDataset, FullDataset

# country_dataset = PartDataset(X, y_country)
city_dataset = PartDataset(X, y_city)
# full_dataset = FullDataset(X, y_city, y_country)

# country_dataloader = DataLoader(
#     country_dataset, batch_size=BATCH_SIZE, shuffle=True)

city_dataloader = DataLoader(
    city_dataset, batch_size=BATCH_SIZE, shuffle=True)

# full_dataloader = DataLoader(
#     full_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [6]:
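def accuracy_at_k(outputs, labels, k=1):
    """Return (top-k accuracy in percent, number of correct predictions).

    `labels` is expected to be one-hot encoded, with shape (batch, num_classes).
    """
    batch_size = labels.size(0)

    # Indices of the k highest-scoring classes for each sample.
    _, pred_indices = outputs.topk(k, dim=1, largest=True, sorted=True)
    # A sample counts as correct when its true class is among those k indices.
    true_indices = torch.argmax(labels, dim=1).view(-1, 1)
    correct = torch.sum(true_indices == pred_indices).item()

    accuracy = correct / batch_size
    return accuracy * 100, correct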
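
A quick sanity check of `accuracy_at_k` on a toy batch can help confirm that it counts hits as intended. The function is repeated here so the snippet runs on its own. The scores and labels are made up for illustration and have nothing to do with the booking data.

```python
import torch
import torch.nn.functional as F


def accuracy_at_k(outputs, labels, k=1):
    batch_size = labels.size(0)
    _, pred_indices = outputs.topk(k, 1, True, True)
    correct = torch.sum(torch.argmax(labels, dim=1).view(-1, 1) == pred_indices)
    accuracy = correct.item() / batch_size
    return accuracy * 100, correct.item()


# Made-up scores for 3 samples over 4 classes.
outputs = torch.tensor([
    [0.10, 0.60, 0.20, 0.05],  # top-2 classes: 1, 2
    [0.50, 0.10, 0.30, 0.05],  # top-2 classes: 0, 2
    [0.20, 0.30, 0.10, 0.40],  # top-2 classes: 3, 1
])
# True classes 1, 2, 0, one-hot encoded as in the training loop.
labels = F.one_hot(torch.tensor([1, 2, 0]), num_classes=4).to(torch.float32)

top1_accuracy, top1_correct = accuracy_at_k(outputs, labels, k=1)
top2_accuracy, top2_correct = accuracy_at_k(outputs, labels, k=2)

print(top1_correct, top2_correct)  # 1 2
```

Only the first sample is a top-1 hit. The second sample becomes a hit at k=2, and the third is missed in both cases.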

In [13]:
def train(epochs: int, train_loader: torch.utils.data.DataLoader, model: nn.Module, optimizer: torch.optim.Optimizer, criterion: nn.Module) -> None:
    for epoch in range(epochs):
        print(f"STARTED EPOCH: {epoch}")

        model.train()

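        # Wrap the loader in a fresh progress bar without rebinding train_loader,
        # so later epochs do not nest tqdm wrappers around each other.
        progress = tqdm(train_loader, desc='Training')

        running_loss = 0.0
        total_correct = 0
        batch = 0
        total_samples = 0

        for i, (inputs, labels) in enumerate(progress, 1):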
            inputs = torch.unsqueeze(inputs, 2)

            inputs = inputs.to(torch.float32).to(device)
            labels = labels.to(torch.int64)

            labels = nn.functional.one_hot(labels, num_classes=10276).to(torch.float32).to(device)

            optimizer.zero_grad()

            outputs = model(inputs)

            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            _, correct = accuracy_at_k(outputs, labels, 1)
            total_correct += correct
            total_samples += labels.size(0)

            running_loss += loss.item()

            if i % 100 == 0:
                epoch_loss = running_loss / i
                batch_accuracy = total_correct / total_samples * 100

                writer.add_scalar(f'Loss/train/batch/{epoch}', epoch_loss, batch)
                writer.add_scalar(f'Accuracy/train/batch/{epoch}', batch_accuracy, batch)

                batch += 1

                writer.flush()

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = total_correct / total_samples * 100

        print(f"Epoch {epoch} loss: {epoch_loss:.4f}, accuracy: {epoch_accuracy:.2f}%")

        writer.add_scalar('Loss/train/epoch', epoch_loss, epoch)
        writer.add_scalar('Accuracy/train/epoch', epoch_accuracy, epoch)

        writer.flush()

        torch.save(model.state_dict(), f"model_{epoch}.pth")
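
The loop above one-hot encodes every label into 10276 classes before computing the loss. As far as I know, `nn.CrossEntropyLoss` also accepts integer class indices directly, and for hard labels the two forms should give the same value in PyTorch versions that support class-probability targets. The sketch below checks this on random data. The 5-class setup and the tensor names are assumptions for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_classes = 5  # small stand-in for the 10276 city classes
logits = torch.randn(8, num_classes)
targets = torch.randint(0, num_classes, (8,))

criterion = nn.CrossEntropyLoss()

# Form 1: integer class indices, shape (batch,).
loss_from_indices = criterion(logits, targets)

# Form 2: one-hot float targets, shape (batch, num_classes), as in train().
one_hot = nn.functional.one_hot(targets, num_classes=num_classes).to(torch.float32)
loss_from_one_hot = criterion(logits, one_hot)

same = torch.allclose(loss_from_indices, loss_from_one_hot)
print(same)
```

Passing indices would avoid building a (batch, 10276) float tensor on every step. Note that `accuracy_at_k` as written expects one-hot labels, so it would need to compare against the indices directly if the loop were changed this way.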


In [8]:
model = Model().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [14]:
train(EPOCHS, city_dataloader, model, optimizer, criterion)

STARTED EPOCH: 0


Training: 100%|██████████| 852/852 [04:14<00:00,  3.34it/s]


Epoch 0 loss: 9.2274, accuracy: 1.03%
STARTED EPOCH: 1


Training: 100%|██████████| 852/852 [04:13<00:00,  3.36it/s]


Epoch 1 loss: 9.2274, accuracy: 1.03%
STARTED EPOCH: 2


Training:   8%|▊         | 72/852 [00:22<04:00,  3.24it/s]


KeyboardInterrupt: 
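
The loss printed above is 9.2274 for both completed epochs. One reference point for reading that number is the cross-entropy of a classifier that assigns equal probability to every class, which is the natural log of the number of classes. The sketch below computes it, taking 10276 from the `one_hot` call in `train` and assuming the natural-log cross-entropy used by `nn.CrossEntropyLoss`.

```python
import math

NUM_CLASSES = 10276     # from the one_hot call in train()
observed_loss = 9.2274  # epoch loss printed for epochs 0 and 1

# Cross-entropy of a uniform prediction over all classes: -ln(1 / N) = ln(N).
uniform_loss = math.log(NUM_CLASSES)

print(f"uniform baseline: {uniform_loss:.4f}")
print(f"gap to observed:  {uniform_loss - observed_loss:.4f}")
```

The baseline comes out at roughly 9.24, so the observed loss sits only about 0.01 below it and does not move between epochs. That pattern suggests the model's predictions are still close to uniform. The output alone does not show why, so this is a starting point for debugging rather than a diagnosis.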