In [1]:
import sys

sys.path.append('../')

import os
import shutil

if os.path.exists('./runs/booking'):
    shutil.rmtree('./runs/booking')

In [2]:
from Model import Model, CountryBlock, CityBlock

import torch
import torch.nn as nn
import torch.functional as F

from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./runs/booking')

torch.autograd.set_detect_anomaly(True)

2024-01-24 21:11:43.380774: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-24 21:11:43.380821: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-24 21:11:43.381722: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-24 21:11:43.387697: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f405e50e4a0>

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

X = torch.load('X.pt').to(device)
y_country = torch.load('y_country.pt').to(device)
y_city = torch.load('y_city.pt').to(device)

In [4]:
from torch.utils.data import dataloader
from Dataset import PartDataset, FullDataset

country_dataset = PartDataset(X, y_country)
city_dataset = PartDataset(X, y_city)
full_dataset = FullDataset(X, y_city, y_country)

country_dataloader = dataloader.DataLoader(
    country_dataset, batch_size=1024, shuffle=True)

city_dataloader = dataloader.DataLoader(
    city_dataset, batch_size=1024, shuffle=True)

full_dataloader = dataloader.DataLoader(
    full_dataset, batch_size=1024, shuffle=True)

In [5]:
def train_country(epochs: int, train_loader: torch.utils.data.DataLoader, model: nn.Module, optimizer: torch.optim.Optimizer, criterion: nn.Module, embedding_model: nn.Module = None) -> None:
    for epoch in range(epochs):
        print(f"STARTED EPOCH: {epoch}")

        model.train()

        train_loader = tqdm(train_loader, desc='Training')
 
        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        for inputs, labels in train_loader:
            inputs = torch.unsqueeze(inputs, 2)

            inputs = inputs.to(torch.float32)
            labels = labels.to(torch.float32)

            optimizer.zero_grad()

            outputs = model(inputs)
            
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            total_samples += len(inputs)

            for i in range(len(outputs)):
                if torch.argmax(outputs[i]) == torch.argmax(labels[i]):
                    correct_predictions += 1

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = correct_predictions / total_samples * 100  

        print(f"Epoch {epoch} loss: {epoch_loss:.4f}, accuracy: {epoch_accuracy:.2f}%")

        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_accuracy, epoch)

        writer.flush()

        torch.save(model.state_dict(), f"model_{epoch}.pth")


In [6]:
country_block = CountryBlock().to(device)

country_optimizer = torch.optim.Adam(country_block.parameters(), lr=0.01)
country_criterion = nn.CrossEntropyLoss()

In [7]:
train_country(10, country_dataloader, country_block, country_optimizer, country_criterion)

STARTED EPOCH: 0


Training: 100%|██████████| 2067/2067 [10:14<00:00,  3.36it/s]


Epoch 0 loss: 5.1804, accuracy: 10.13%
STARTED EPOCH: 1


Training: 100%|██████████| 2067/2067 [10:13<00:00,  3.37it/s]


Epoch 1 loss: 5.1804, accuracy: 10.14%
STARTED EPOCH: 2


Training: 100%|██████████| 2067/2067 [10:14<00:00,  3.37it/s]


Epoch 2 loss: 5.1804, accuracy: 10.14%
STARTED EPOCH: 3


Training: 100%|██████████| 2067/2067 [10:14<00:00,  3.36it/s]


Epoch 3 loss: 5.1804, accuracy: 10.14%
STARTED EPOCH: 4


Training:  44%|████▍     | 905/2067 [04:29<05:45,  3.36it/s]


KeyboardInterrupt: 