## Plans for training

### 1. Train city_block to predict cities correctly

### 2. Train country_block to predict countries given the predicted cities

##### I believe passing the predicted city to country_block, rather than the predicted country to city_block, will benefit the model, because it is easier to predict the country given the city than vice versa. Our model could predict countries with 12% accuracy, so let's be optimistic and say it can also predict cities with 12% accuracy given the correct country. The final accuracy would then be 12% * 12% = 1.44%, as there is a 12% chance that the model receives the correct country and a 12% chance that it then predicts the correct city. Now suppose the model predicts cities with 12% accuracy without needing the country, and we have a lookup table that maps each city to its country. Then the accuracy would be 12%, as the country would not have to be predicted. Sadly we don't have such a lookup table, so we have to create a model that acts as one. Let's say that model can predict the country with 50% accuracy given the city. Then the accuracy is 12% * 50% = 6%, which is a lot better and easier to achieve than the previous idea.
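The arithmetic above can be checked with a small sketch. The helper name `chained_accuracy` is made up for illustration, and it assumes the second stage can only be right when the first stage was right, with the two accuracies simply multiplying. The percentages are the ones quoted in the text, not new measurements.

```python
def chained_accuracy(first_stage: float, second_stage: float) -> float:
    """Chance that both stages of a two-stage pipeline are correct.

    Assumes the second stage can only be right when the first one was,
    and that the two accuracies multiply (a simplification).
    """
    return first_stage * second_stage


# Figures taken from the text above; they are estimates, not results.
country_then_city = chained_accuracy(0.12, 0.12)  # country_block feeds city_block
city_with_lookup = chained_accuracy(0.12, 1.00)   # city_block plus a perfect lookup table
city_then_country = chained_accuracy(0.12, 0.50)  # city_block feeds country_block

for name, value in [
    ("country -> city", country_then_city),
    ("city + lookup table", city_with_lookup),
    ("city -> country", city_then_country),
]:
    print(f"{name}: {value:.2%}")
```

Under these assumptions the city-first order loses only whatever the country model adds in errors, while the country-first order is capped by two weak stages in a row.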

In [2]:
import sys

sys.path.append('../')

import os
import shutil

if os.path.exists('./runs/booking'):
    shutil.rmtree('./runs/booking')


os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = '1'

In [3]:
from Model import Model, CountryBlock, CityBlock

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./runs/booking')

torch.autograd.set_detect_anomaly(True)



<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f6ce477d690>

In [4]:
BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
X = torch.load('X.pt')
y_country = torch.load('y_country.pt')
y_city = torch.load('y_city.pt')

In [6]:
from torch.utils.data import DataLoader
from Dataset import PartDataset, FullDataset

country_dataset = PartDataset(X, y_country)
city_dataset = PartDataset(X, y_city)
full_dataset = FullDataset(X, y_city, y_country)

country_dataloader = DataLoader(
    country_dataset, batch_size=BATCH_SIZE, shuffle=True)

city_dataloader = DataLoader(
    city_dataset, batch_size=BATCH_SIZE, shuffle=True)

full_dataloader = DataLoader(
    full_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [7]:
def train_city(epochs: int, train_loader: torch.utils.data.DataLoader, model: nn.Module, optimizer: torch.optim.Optimizer, criterion: nn.Module) -> None:
    # If you're wondering why I'm doing all this stuff with the GPU and freeing memory, it's because my GPU has only 8 GB of VRAM,
    # which keeps running out, so without freeing memory and moving data to the GPU only when it is used I can't train the model
    batch = 0  # global step for the per-batch TensorBoard scalars, kept across epochs

    for epoch in range(epochs):
        print(f"STARTED EPOCH: {epoch}")

        model.train()

        # Use a separate name for the progress bar; rebinding train_loader would
        # wrap the previous tqdm object again on every epoch
        progress = tqdm(train_loader, desc='Training')

        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        for i, (inputs, labels) in enumerate(progress, 1):
            inputs = torch.unsqueeze(inputs, 2)

            inputs = inputs.to(torch.float32).to(device)
            # CrossEntropyLoss accepts class indices directly, so the 39901-way
            # one-hot float tensor is not needed and its memory is saved
            labels = labels.to(torch.int64).to(device)

            optimizer.zero_grad()

            outputs = model(inputs)

            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Count the real batch size, as the last batch may be smaller than BATCH_SIZE
            total_samples += labels.size(0)
            correct_predictions += (outputs.detach().argmax(dim=1) == labels).sum().item()

            # Drop references to GPU tensors before the next batch is loaded
            del inputs, outputs, loss

            if i % 100 == 0:
                epoch_loss = running_loss / i
                epoch_accuracy = correct_predictions / total_samples * 100

                writer.add_scalar('Loss/train/batch', epoch_loss, batch)
                writer.add_scalar('Accuracy/train/batch', epoch_accuracy, batch)

                batch += 1

                writer.flush()

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = correct_predictions / total_samples * 100

        print(f"Epoch {epoch} loss: {epoch_loss:.4f}, accuracy: {epoch_accuracy:.2f}%")

        writer.add_scalar('Loss/train/epoch', epoch_loss, epoch)
        writer.add_scalar('Accuracy/train/epoch', epoch_accuracy, epoch)

        writer.flush()

        torch.save(model.state_dict(), f"model_{epoch}.pth")


In [8]:
city_block = CityBlock().to(device)

optimizer = torch.optim.Adam(city_block.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
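
`nn.CrossEntropyLoss` takes raw logits and accepts targets either as integer class indices or as per-class probabilities, and for a one-hot target the two forms give the same value. The sketch below re-derives that in plain Python on a toy 3-class example. The function names are made up for illustration and this is not PyTorch's implementation.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def ce_from_index(logits: list[float], target_index: int) -> float:
    """Cross-entropy when the target is a class index."""
    return -log_softmax(logits)[target_index]


def ce_from_probs(logits: list[float], target_probs: list[float]) -> float:
    """Cross-entropy when the target is a probability vector (e.g. one-hot)."""
    return -sum(p * lp for p, lp in zip(target_probs, log_softmax(logits)))


logits = [2.0, -1.0, 0.5]   # toy logits for 3 classes
one_hot = [0.0, 1.0, 0.0]   # the same target as class index 1

loss_index = ce_from_index(logits, 1)
loss_one_hot = ce_from_probs(logits, one_hot)
print(math.isclose(loss_index, loss_one_hot))
```

With tens of thousands of classes the index form is much cheaper to hold in memory, since it stores one integer per sample instead of one float per class.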

In [9]:
train_city(EPOCHS, city_dataloader, city_block, optimizer, criterion)

STARTED EPOCH: 0


Training: 100%|██████████| 7416/7416 [25:19<00:00,  4.88it/s]


Epoch 0 loss: 10.5867, accuracy: 0.76%
STARTED EPOCH: 1


Training: 100%|██████████| 7416/7416 [26:30<00:00,  4.66it/s]