In [1]:
import sys

# Add the parent directory to the module search path (presumably so that the
# project-local `Model` and `Dataset` modules imported later can be found —
# verify against the repository layout).
sys.path.append('../')

In [2]:
from Model import Model, CountryBlock, CityBlock

import torch
import torch.nn as nn
# NOTE(review): `torch.functional` is a different module from
# `torch.nn.functional`, which is what the alias `F` conventionally refers to.
# `F` is not used in the visible cells — confirm which module was intended.
import torch.functional as F

from tqdm import tqdm

# Anomaly detection is a debugging aid: it makes autograd record extra
# information and check backward results, which slows training down.
# NOTE(review): consider disabling it once the gradients are known to be sane.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1d8eb9d4410>

In [3]:
# Load the serialized inputs and the country / city targets from the current
# working directory, in the same order as they are named on the left.
X, y_country, y_city = map(torch.load, ('X.pt', 'y_country.pt', 'y_city.pt'))

In [4]:
from torch.utils.data import dataloader
from Dataset import PartDataset, FullDataset

# Each dataset is paired with its own shuffling loader (batches of 1024).

# Inputs with country targets only.
country_dataset = PartDataset(X, y_country)
country_dataloader = dataloader.DataLoader(dataset=country_dataset, batch_size=1024, shuffle=True)

# Inputs with city targets only.
city_dataset = PartDataset(X, y_city)
city_dataloader = dataloader.DataLoader(dataset=city_dataset, batch_size=1024, shuffle=True)

# Inputs with both city and country targets.
full_dataset = FullDataset(X, y_city, y_country)
full_dataloader = dataloader.DataLoader(dataset=full_dataset, batch_size=1024, shuffle=True)

## Plans for training

### 1. First we will train the CountryBlock to predict the countries correctly

### 2. Next we will train the CityBlock to predict the cities correctly, given true values for countries

### 3. We will combine them and train the model to synchronize

In [5]:
def train_country(epochs: int, train_loader: torch.utils.data.DataLoader, model: nn.Module, optimizer: torch.optim.Optimizer, criterion: nn.Module, embedding_model: nn.Module = None, tested_index: int = 0) -> None:
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    For every batch the inputs get a trailing axis added at position 2, are
    optionally passed through the embeddings of ``embedding_model``, and are
    then fed to ``model``. After each epoch the mean batch loss and the
    argmax accuracy are printed and the model's ``state_dict`` is saved to
    ``model_{epoch}.pth`` in the current working directory.

    Args:
        epochs: Number of passes over ``train_loader``.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        model: Network being trained; called as ``model(inputs)``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        embedding_model: Optional module exposing ``city_embedding`` and
            ``country_embedding``; when given, columns 0, 1 and 2 of the
            inputs are replaced by their embeddings.
        tested_index: Not used by this function; kept so existing callers
            that pass it keep working.

    Returns:
        None.
    """
    for epoch in range(epochs):
        print(f"STARTED EPOCH: {epoch}")

        model.train()

        # Wrap the *original* loader in a fresh progress bar every epoch.
        # Rebinding ``train_loader`` itself would wrap the previous bar again
        # on each epoch and produce nested progress bars.
        progress = tqdm(train_loader, desc='Training')

        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0
        num_batches = 0

        for inputs, labels in progress:
            inputs = torch.unsqueeze(inputs, 2)

            if embedding_model is not None:
                # NOTE(review): these assignments write the embedding outputs
                # in place into ``inputs``; ``squeeze(2)`` only removes the
                # embedding axis if it has size 1 — confirm the embedding
                # width and the dtype of ``inputs`` against the Model module.
                inputs[:, 0] = embedding_model.city_embedding(inputs[:, 0].long()).squeeze(2).to(torch.float32)
                inputs[:, 1] = embedding_model.country_embedding(inputs[:, 1].long()).squeeze(2).to(torch.float32)
                inputs[:, 2] = embedding_model.country_embedding(inputs[:, 2].long()).squeeze(2).to(torch.float32)

            inputs = inputs.to(torch.float32)
            labels = labels.to(torch.float32)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            batch_size = len(outputs)
            total_samples += len(inputs)

            if batch_size:
                # One batched comparison instead of a Python loop over the
                # samples: flatten each sample and compare the argmax of the
                # prediction with the argmax of the label.
                predicted = outputs.detach().reshape(batch_size, -1).argmax(dim=1)
                expected = labels.reshape(batch_size, -1).argmax(dim=1)
                correct_predictions += int((predicted == expected).sum().item())

        # Guard the averages so an empty loader reports NaN instead of raising
        # ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_accuracy = (
            correct_predictions / total_samples * 100 if total_samples else float("nan"))

        print(
            f"Epoch: {epoch} Loss: {epoch_loss:.4f} Accuracy: {epoch_accuracy:.2f}%")

        torch.save(model.state_dict(), f"model_{epoch}.pth")


In [6]:
# Full model; handed to train_country in the next cell as `embedding_model`.
model = Model()
# Stand-alone country block that is trained in that cell.
country_block = CountryBlock()

# Only the country block's parameters are given to the optimizer, so `model`
# is not updated by `country_optimizer`.
country_optimizer = torch.optim.Adam(country_block.parameters(), lr=0.001)
country_criterion = nn.CrossEntropyLoss()

In [7]:
# Train the country block for 10 epochs, using `model` for the embeddings.
train_country(
    epochs=10,
    train_loader=country_dataloader,
    model=country_block,
    optimizer=country_optimizer,
    criterion=country_criterion,
    embedding_model=model,
    tested_index=2,
)

STARTED EPOCH: 0


  X = F.softmax(X)
Training:   1%|          | 5/927 [00:11<35:56,  2.34s/it]


KeyboardInterrupt: 