# Imports and declarations

In [2]:
import random
import os
import pandas as pd
import torch
import torch.nn.functional as F
import torch.nn as nn
import geopy.distance

In [3]:
TRAIN_PATH = "/content/drive/MyDrive/Tweet Geolocation/Data_train"
TEST_PATH = "/content/drive/MyDrive/Tweet Geolocation/Data_test"
WEIGHTS_PATH = '/content/drive/MyDrive/Tweet Geolocation/Saved_weights'

In [4]:
train_directory = os.fsencode(TRAIN_PATH)
test_directory = os.fsencode(TEST_PATH)

In [None]:
#Getting all countries in the train set. Needed to do this once to figure out the countries dictionary for encoding.
#train_countries = set()
#    
#for file in os.listdir(train_directory):
#  filename = os.fsdecode(file)
#  data = pd.read_csv(TRAIN_PATH + "/" + filename, sep=";")
#  train_countries.update(data.geo_country.unique())

In [None]:
#test_countries = set()
#    
#for file in os.listdir(test_directory):
#  filename = os.fsdecode(file)
#  data = pd.read_csv(TEST_PATH + "/" + filename, sep=";")
#  test_countries.update(data.geo_country.unique())

In [5]:
countries_num = 18
#Coding countries in different languages to the same values
def get_key_by_country(countryName):
  countries_dict = {'Argentina':1,
                    'Argentine':1,
                    'Aruba':2,
                    'Bolivia':3,
                    'Bolivie':3,
                    'Brasile':4,
                    'Brasilien':4,
                    'Brazil':4,
                    'Brésil':4,
                    'Chile':5,
                    'Chili':5,
                    'Colombia':6,
                    'Curaçao':7,
                    'Ecuador':8,
                    'Falkland Islands (Malvinas)':9,
                    'Fransk Guyana':10,
                    'French Guiana':10,
                    'Guyana':11,
                    'Panama':12,
                    'Paraguay':13,
                    'Peru':14,
                    'Suriname':15,
                    'Trinidad and Tobago':16,
                    'Uruguay':17,
                    'Venezuela':18
                    }
  try:
    return countries_dict[countryName]
  except(KeyError):
    return 0

In [6]:
#Here should be Unicode encoding, but it is far too computationally-intensive 
#for my current available hardware, sorry. Had to revert to OHE over a charset.

def encode_text(input_text:str):
  tweet_len = 280

  input_text = str(input_text).lower()
  charset = '0123456789 aáãbcdeéfghiíjklmnñoópqrstuúüvwxyz&@#_,.'
  char_num = len(charset)
  code = torch.zeros(tweet_len, char_num)
  for i, c in enumerate(input_text):
    try:
      char_code = charset.index(c)
      code[min(i,279),char_code] = 1 #min - in case one charater gets enumerated into several characters
    except(ValueError):
      continue

  return code.permute(1,0)

# Data preparation

In [7]:
from pandas.core.frame import DataFrame
class TweetsDataset(torch.utils.data.Dataset):
    def __init__(self, input_dataframe:DataFrame):
        input_dataframe = input_dataframe.sample(frac=1)
        self.len = input_dataframe.shape[0]
        self.text = input_dataframe.text
        self.target_country = input_dataframe.geo_country
        self.target_latitude = input_dataframe.latitude
        self.target_longitude = input_dataframe.longitude

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return encode_text(self.text[index]), \
        torch.tensor(get_key_by_country(self.target_country[index])), \
        torch.tensor((float(self.target_latitude[index]), float(self.target_longitude[index])))

# Model preparation

In [13]:
class TweetCharacterCNN(nn.Module):
    def __init__(self, charset_len = 51, countries_num = 18, classifier_dim = 200):
      super().__init__() #In -> 280
      self.conv1 = nn.Conv1d(in_channels = charset_len, out_channels = charset_len, kernel_size=5) #Out -> 276
      self.maxp1 = nn.MaxPool1d(3) #Out -> 92
      self.conv2 = nn.Conv1d(in_channels = charset_len, out_channels = charset_len, kernel_size=6) #Out -> 87
      self.maxp2 = nn.MaxPool1d(3) #Out -> 29
      self.conv3 = nn.Conv1d(in_channels = charset_len, out_channels = charset_len, kernel_size=3) #Out -> 27
      self.conv4 = nn.Conv1d(in_channels = charset_len, out_channels = charset_len, kernel_size=3) #Out -> 25
      self.conv5 = nn.Conv1d(in_channels = charset_len, out_channels = charset_len, kernel_size=3) #Out -> 23
      self.conv6 = nn.Conv1d(in_channels = charset_len, out_channels = charset_len, kernel_size=3) #Out -> 21
      
      self.cc1activation = nn.ReLU()
      self.cc1bn = nn.BatchNorm1d(21*charset_len)
      self.country_classifier1 = nn.Linear(21*charset_len, classifier_dim)
      self.cc2activation = nn.ReLU()
      self.cc2bn = nn.BatchNorm1d(classifier_dim)
      self.country_classifier2 = nn.Linear(classifier_dim, countries_num+1)
      
      self.cr1activation = nn.ReLU()
      self.cr1bn = nn.BatchNorm1d(21*charset_len)
      self.coord_regressor1 = nn.Linear(21*charset_len, classifier_dim)
      self.cr2activation = nn.ReLU()
      self.cr2bn = nn.BatchNorm1d(classifier_dim)
      self.coord_regressor2 = nn.Linear(classifier_dim, 2)
        
    def convs(self, x):
      x = self.conv1(x)
      x = self.maxp1(x)
      x = self.conv2(x)
      x = self.maxp2(x)
      x = self.conv3(x)
      x = self.conv4(x)
      x = self.conv5(x)
      x = self.conv6(x)
      return x
        
    def forward(self, x):
      x = self.conv1(x)
      x = self.maxp1(x)
      x = self.conv2(x)
      x = self.maxp2(x)
      x = self.conv3(x)
      x = self.conv4(x)
      x = self.conv5(x)
      x = self.conv6(x)
      x = torch.flatten(x, start_dim=1) #for batches

      country = self.cc1activation(x)
      country = self.cc1bn(country)
      country = self.country_classifier1(country)
      country = self.cc2activation(country)
      country = self.cc2bn(country)
      country = self.country_classifier2(country)
      
      x = self.cr1activation(x)
      x = self.cr1bn(x)
      x = self.coord_regressor1(x)
      x = self.cr2activation(x)
      x = self.cr2bn(x)
      x = self.coord_regressor2(x)

      return country, x

In [14]:
net = TweetCharacterCNN(classifier_dim = 200)

#Here should be Mixture of von Mises-Fisher Distributions for coordinates loss, 
#but I don't currently have enough computational resources for it, so have to 
#use Euclidean approximation. Just pretend that the Earth is flat for now, please =)

coord_criterion = nn.MSELoss()
country_criterion = nn.CrossEntropyLoss()

In [15]:
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

# Model Training

In [16]:
net.train()
train_min_length = 5000
data_collected = pd.DataFrame(columns=["text","geo_country","latitude","longitude"])

iteration = 0
filelist = os.listdir(train_directory)
random.shuffle(filelist)
curr_files = []
files_total = len(filelist)
files_processed = 0

for file in filelist:  
  filename = os.fsdecode(file)
  curr_files.append(filename)
  data = pd.read_csv(TRAIN_PATH + "/" + filename, sep=";")
  files_processed += 1
  if data.shape[0] == 0: continue #Skip empty files
  data_cleaned = data[["text","geo_country"]].copy()
  data_cleaned["latitude"] = float(filename[0:filename.index("_")])
  data_cleaned["longitude"] = float(filename[filename.index("_")+1:-4])
  try: #Records correctness check
    try_dataset = TweetsDataset(data_cleaned)
    try_dataloader = torch.utils.data.DataLoader(try_dataset, batch_size=1, shuffle=True, num_workers=2, drop_last = True)
    try_get_subscr = next(iter(try_dataloader))
    data_collected = pd.concat([data_collected, data_cleaned]) #Accumulating data over several files until threshold
  except(TypeError):
    print(f'Broken records in file {filename}')
  if data_collected.shape[0] >= train_min_length:
    data_collected.reset_index(drop=True, inplace=True)
    print(f'Datapoints collected: {data_collected.shape[0]}; Total files processed: {files_processed}/{files_total}')
    train_dataset = TweetsDataset(data_collected)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True, num_workers=1, drop_last = True)

    Acc_50 = 0
    Acc_100 = 0
    Acc_500 = 0
    Acc_1000 = 0
    Acc_2000 = 0
    Acc_3000 = 0
    Acc_5000 = 0
    Acc_country = 0
    total_mse_loss = 0
    total_CE_loss = 0
    total_miss = 0
    iteration += 1
    try: #skip files with broken records
      for i, data in enumerate(train_dataloader):
          inputs = data[0]
          target_country = data[1]
          target_coords = data[2]

          #Multitask: first, try to predict country
          optimizer.zero_grad()
          outputs = net(inputs)
          country_loss = country_criterion(outputs[0], target_country)
          total_CE_loss += country_loss
          country_loss.backward()
          optimizer.step()

          #Second, try to predict coords
          optimizer.zero_grad()
          outputs = net(inputs)
          coords_loss = coord_criterion(outputs[1], target_coords)
          total_mse_loss += coords_loss
          coords_loss.backward()
          optimizer.step()

          #print(f'Cross-Entropy loss = {country_loss}')
          #print(f'MSE loss = {coords_loss}')

          # Calculating range-based accuracy
          pred_country = outputs[0]
          pred_coords = outputs[1]

          #bounding crazy early results

          for i, coords in enumerate(pred_coords):
            lat = max(min(coords[0],90),-90)
            long = max(min(coords[1],180),-180)
            miss_distance = geopy.distance.distance((lat, long), target_coords[i]).km
            total_miss += miss_distance
            if miss_distance <= 50:
              Acc_50 +=1
            if miss_distance <= 100:
              Acc_100 +=1
            if miss_distance <= 500:
              Acc_500 +=1
            if miss_distance <= 1000:
              Acc_1000 +=1
            if miss_distance <= 2000:
              Acc_2000 +=1
            if miss_distance <= 3000:
              Acc_3000 +=1
            if miss_distance <= 5000:
              Acc_5000 +=1

          for i, country in enumerate(pred_country):
            if torch.argmax(country) == target_country[i]:
              Acc_country +=1
      total_points = data_collected.shape[0]  
      print(f'Iteration: {iteration}, Accuracies:')
      print(f'Country - {Acc_country/total_points}, @50 - {Acc_50/total_points},' \
      f'@100 - {Acc_100/total_points}, @500 - {Acc_500/total_points},' \
      f'@1000 - {Acc_1000/total_points}, @2000 - {Acc_2000/total_points},' \
      f'@3000 - {Acc_3000/total_points}, @5000 - {Acc_5000/total_points}')
      print(f'Avg losses: MSE - {total_mse_loss/total_points}, CE - {total_CE_loss/total_points}, Distance - {total_miss/total_points}')
      torch.save(net.state_dict(), WEIGHTS_PATH + '/geolocation_temp.pt')
      data_collected = pd.DataFrame(columns=["text","geo_country","latitude","longitude"])      
      curr_files = []
    except(TypeError):
      print(f'Broken records in batch with files {curr_files}')  
      data_collected = pd.DataFrame(columns=["text","geo_country","latitude","longitude"])  
      curr_files = []
        
print('Training is finished!')

Datapoints collected: 7155; Total files processed: 3/3278
Iteration: 1, Accuracies:
Country - 0.8867924528301887, @50 - 0.0004192872117400419,@100 - 0.0023759608665269044, @500 - 0.27882599580712786,@1000 - 0.4863731656184486, @2000 - 0.5306778476589797,@3000 - 0.5703703703703704, @5000 - 0.6489168413696715
Avg losses: MSE - 8.692052841186523, CE - 0.006238440051674843, Distance - 3277.849283941884
Datapoints collected: 29208; Total files processed: 13/3278
Iteration: 2, Accuracies:
Country - 0.9439194741166803, @50 - 0.05293070391673514,@100 - 0.19237880032867707, @500 - 0.9134826075047932,@1000 - 0.9535401259928786, @2000 - 0.962236373596275,@3000 - 0.9721993974253629, @5000 - 0.9996576280471103
Avg losses: MSE - 0.22839267551898956, CE - 0.0034022070467472076, Distance - 340.7054591154428
Datapoints collected: 54164; Total files processed: 18/3278
Iteration: 3, Accuracies:
Country - 0.988165571228122, @50 - 0.7473044826822244,@100 - 0.9185806070452699, @500 - 0.9791743593530758,@100

KeyboardInterrupt: ignored

# Saving model

In [17]:
torch.save(net.state_dict(), WEIGHTS_PATH + '/geolocation.pt')

# Loading model

In [18]:
net = TweetCharacterCNN(classifier_dim = 200)
net.load_state_dict(torch.load(WEIGHTS_PATH + '/geolocation.pt'))
net.eval()

coord_criterion = nn.MSELoss()
country_criterion = nn.CrossEntropyLoss()

# Getting prediction

In [19]:
countries_dec_dict = {1:'Argentina',
                    2:'Aruba',
                    3:'Bolivia',
                    4:'Brazil',
                    5:'Chile',
                    6:'Colombia',
                    7:'Curaçao',
                    8:'Ecuador',
                    9:'Falkland Islands (Malvinas)',
                    10:'French Guiana',
                    11:'Guyana',
                    12:'Panama',
                    13:'Paraguay',
                    14:'Peru',
                    15:'Suriname',
                    16:'Trinidad and Tobago',
                    17:'Uruguay',
                    18:'Venezuela'
                    }

In [None]:
#Running test dataset
net.eval()
test_min_length = 5000

Acc_50 = 0
Acc_100 = 0
Acc_500 = 0
Acc_1000 = 0
Acc_2000 = 0
Acc_3000 = 0
Acc_5000 = 0
Acc_country = 0
total_mse_loss = 0
total_CE_loss = 0
total_miss = 0
total_points = 0

data_collected = pd.DataFrame(columns=["text","geo_country","latitude","longitude"])

filelist = os.listdir(test_directory)
random.shuffle(filelist)

for file in filelist:  
  filename = os.fsdecode(file)
  data = pd.read_csv(TEST_PATH + "/" + filename, sep=";")
  if data.shape[0] == 0: continue #Skip empty files
  data_cleaned = data[["text","geo_country"]].copy()
  data_cleaned["latitude"] = float(filename[0:filename.index("_")])
  data_cleaned["longitude"] = float(filename[filename.index("_")+1:-4])
  data_collected = pd.concat([data_collected, data_cleaned]) #Accumulating data from all test set
  if data_collected.shape[0] >= test_min_length:    
    total_points += data_collected.shape[0]  
    data_collected.reset_index(drop=True, inplace=True)
    print(f'Datapoints collected: {data_collected.shape[0]}')
    test_dataset = TweetsDataset(data_collected)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=True, num_workers=1, drop_last = True)

    for i, data in enumerate(test_dataloader):
        inputs = data[0]
        target_country = data[1]
        target_coords = data[2]

        #Multitask: first, try to predict country
        outputs = net(inputs)
        country_loss = country_criterion(outputs[0], target_country)
        total_CE_loss += country_loss

        #Second, try to predict coords
        outputs = net(inputs)
        coords_loss = coord_criterion(outputs[1], target_coords)
        total_mse_loss += coords_loss

        #print(f'Cross-Entropy loss = {country_loss}')
        #print(f'MSE loss = {coords_loss}')

        # Calculating range-based accuracy
        pred_country = outputs[0]
        pred_coords = outputs[1]

        #bounding crazy results

        for i, coords in enumerate(pred_coords):
          lat = max(min(coords[0],90),-90)
          long = max(min(coords[1],180),-180)
          miss_distance = geopy.distance.distance((lat, long), target_coords[i]).km
          total_miss += miss_distance
          if miss_distance <= 50:
            Acc_50 +=1
          if miss_distance <= 100:
            Acc_100 +=1
          if miss_distance <= 500:
            Acc_500 +=1
          if miss_distance <= 1000:
            Acc_1000 +=1
          if miss_distance <= 2000:
            Acc_2000 +=1
          if miss_distance <= 3000:
            Acc_3000 +=1
          if miss_distance <= 5000:
            Acc_5000 +=1

        for i, country in enumerate(pred_country):
          if torch.argmax(country) == target_country[i]:
            Acc_country +=1
print(f'Accuracies:')
print(f'Country - {Acc_country/total_points}, @50 - {Acc_50/total_points},' \
f'@100 - {Acc_100/total_points}, @500 - {Acc_500/total_points},' \
f'@1000 - {Acc_1000/total_points}, @2000 - {Acc_2000/total_points},' \
f'@3000 - {Acc_3000/total_points}, @5000 - {Acc_5000/total_points}')
print(f'Avg losses: MSE - {total_mse_loss/total_points}, CE - {total_CE_loss/total_points}, Distance - {total_miss/total_points}')
print('Testing is finished!')

Datapoints collected: 8840
Datapoints collected: 9031
Datapoints collected: 9222
