In [None]:
# Install ftfy, regex and tqdm, then the CLIP package straight from its GitHub repository.
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

In [None]:
#download and extract im2gps
!wget http://www.cis.jhu.edu/~shraman/im2gps_rgb_images.tar.gz
!tar -xzvf "/content/im2gps_rgb_images.tar.gz" -C "/content/"     #[run this cell to extract tar.gz files]
!wget http://www.cis.jhu.edu/~shraman/im2gps3k_rgb_images.tar.gz
!tar -xzvf "/content/im2gps3k_rgb_images.tar.gz" -C "/content/"     #[run this cell to extract tar.gz files]

In [None]:
# Generate continent labels from latitude and longitude values (reverse geocoding via Nominatim).
!pip install pycountry_convert

import numpy as np
import pandas as pd
from geopy.geocoders import Nominatim
from geopy.extra.rate_limiter import RateLimiter
import pycountry_convert as pc

from pprint import pprint
from typing import Tuple

from tqdm import tqdm
# Registers DataFrame.progress_apply (used below) so the slow row-wise geocoding shows a progress bar.
tqdm.pandas()



# Two-letter continent codes (the kind pycountry_convert.country_alpha2_to_continent_code
# returns) mapped to their English display names.
_CONTINENT_NAMES = {
    "AF": "Africa",
    "AQ": "Antarctica",
    "AS": "Asia",
    "EU": "Europe",
    "NA": "North America",
    "OC": "Oceania",
    "SA": "South America",
}


def get_continent_name(continent_code: str) -> str:
    """Translate a two-letter continent code into its English name.

    Args:
        continent_code: one of AF, AQ, AS, EU, NA, OC, SA.

    Returns:
        The continent's display name, e.g. ``"EU"`` -> ``"Europe"``.

    Raises:
        KeyError: if ``continent_code`` is not one of the codes above.
    """
    return _CONTINENT_NAMES[continent_code]

# Shared rate-limited Nominatim reverse geocoder; built lazily on first use by _get_reverse_geocoder.
_reverse_geocoder = None


def _get_reverse_geocoder():
    """Return the single rate-limited ``Nominatim.reverse`` callable, creating it on first call."""
    global _reverse_geocoder
    if _reverse_geocoder is None:
        # NOTE(review): Nominatim's usage policy asks for a real identifying user agent;
        # replace this placeholder before running.
        geolocator = Nominatim(user_agent="<username1>@gmail.com", timeout=10)
        # One limiter for the whole run: a limiter only delays *between* its own calls, so
        # building a fresh one per lookup never waits. Nominatim allows at most 1 request/s.
        # swallow_exceptions=False: otherwise a network failure is returned as None and
        # get_continent would silently label the image "Antarctica".
        _reverse_geocoder = RateLimiter(
            geolocator.reverse, min_delay_seconds=1, swallow_exceptions=False
        )
    return _reverse_geocoder


def get_continent(lat: float, lon: float, reverse_geocode=None) -> Tuple[str, str]:
    """Reverse-geocode a coordinate into ``(country_code, continent_name)``.

    Args:
        lat: latitude of the point.
        lon: longitude of the point.
        reverse_geocode: optional callable invoked as ``reverse_geocode("lat, lon", language="en")``
            that returns a geopy ``Location`` or None. Defaults to the shared rate-limited
            Nominatim geocoder, so existing two-argument callers are unchanged.

    Returns:
        ``(country_code, continent_name)`` where ``country_code`` is the upper-cased
        ``country_code`` from the Nominatim address and ``continent_name`` comes from
        ``get_continent_name``. If the geocoder finds nothing, returns
        ``("Antarctica", "Antarctica")``.

    Raises:
        KeyError: if the address has no ``country_code`` or ``get_continent_name`` does not
            know the continent code; pycountry_convert may also raise for unknown country codes.
        Geocoder errors are propagated rather than being converted into an "Antarctica" label.
    """
    geocode = reverse_geocode if reverse_geocode is not None else _get_reverse_geocoder()

    location = geocode(f"{lat}, {lon}", language="en")

    # No match is treated as Antarctica.
    # NOTE(review): presumably open-ocean points also land here — confirm this labelling is wanted.
    if location is None:
        return "Antarctica", "Antarctica"

    # Extract the ISO country code from the structured address.
    address = location.raw["address"]
    country_code = address["country_code"].upper()

    # Country code -> continent code -> continent name.
    continent_code = pc.country_alpha2_to_continent_code(country_code)
    continent_name = get_continent_name(continent_code)

    return country_code, continent_name

# Label files can be found here: https://github.com/ShramanPramanick/Transformer_Based_Geo-localization/tree/main/resources
im2gpslabels3k = pd.read_csv('im2gps3k_places365.csv')
im2gpslabels = pd.read_csv('im2gps_places365.csv')


def add_continent_columns(labels: pd.DataFrame) -> None:
    """Reverse-geocode each row's LAT/LON and add COUNTRY and CONTINENT columns to ``labels`` in place."""
    labels[["COUNTRY", "CONTINENT"]] = labels.progress_apply(
        lambda row: get_continent(row["LAT"], row["LON"]), axis=1, result_type="expand")


add_continent_columns(im2gpslabels3k)
add_continent_columns(im2gpslabels)

# The training cell reads "im2gps3klabels.csv" and "im2gpslabels.csv"; write exactly those names.
# index=False avoids an extra unnamed index column in the saved files.
im2gpslabels3k.to_csv("im2gps3klabels.csv", index=False)
im2gpslabels.to_csv("im2gpslabels.csv", index=False)

In [None]:
import torch
import clip
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch import optim
import os
import pandas as pd
from PIL import Image

In [None]:

# batch_size must be larger than 1: the contrastive loss needs other samples in the batch as negatives.

# Train on the first GPU when one is present (mixed precision), otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# jit=False is required, otherwise the returned model cannot be fine-tuned.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

batch_size = 128

class image_title_dataset(Dataset):
    """Dataset pairing image files with pre-tokenized CLIP captions.

    Args:
        list_image_path: image file paths, aligned index-by-index with ``list_txt``.
        list_txt: caption strings; all of them are tokenized once with ``clip.tokenize``.

    Items are ``(preprocessed_image, caption_tokens)``.
    """

    def __init__(self, list_image_path, list_txt):
        self.image_path = list_image_path
        # Tokenizing everything up front is slow at the start but keeps __getitem__ cheap;
        # the alternative is to tokenize per batch inside the training loop.
        self.title = clip.tokenize(list_txt)

    def __len__(self):
        # Length comes from the captions; assumes one caption per image path.
        return len(self.title)

    def __getitem__(self, idx):
        # Context manager releases the file handle as soon as preprocessing is done,
        # rather than whenever the PIL image happens to be garbage-collected.
        with Image.open(self.image_path[idx]) as image:
            pixels = preprocess(image)
        return pixels, self.title[idx]

# Continent labels written by the geocoding cell: im2gps is the test split, im2gps3k the training split.
test_labels = pd.read_csv("im2gpslabels.csv")[["IMG_ID", "CONTINENT"]]
train_labels = pd.read_csv("im2gps3klabels.csv")[["IMG_ID", "CONTINENT"]]

train_img_path = "/content/im2gps3k_rgb_images/"
train_list_image_path = [os.path.join(train_img_path, img_name) for img_name in train_labels["IMG_ID"]]
train_list_txt = list(train_labels["CONTINENT"])

train_dataset = image_title_dataset(train_list_image_path, train_list_txt)
# Shuffle so every epoch sees differently composed batches instead of the CSV's fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_img_path = "/content/im2gps_rgb_images/"
test_list_image_path = [os.path.join(test_img_path, img_name) for img_name in test_labels["IMG_ID"]]
test_list_txt = list(test_labels["CONTINENT"])

test_dataset = image_title_dataset(test_list_image_path, test_list_txt)
# Not shuffled: keeps test batches in the same order as test_labels.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

# https://github.com/openai/CLIP/issues/57
def convert_models_to_fp32(model):
    """Cast every parameter of ``model`` (and its gradient, when present) to float32 in place.

    Used around ``optimizer.step()`` on GPU so the optimizer update runs in fp32
    while the weights are otherwise kept in fp16.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # A parameter that took no part in the backward pass has grad None; skip it
        # instead of raising AttributeError.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


# Put the model weights in the precision that matches the device.
if device != "cpu":
    # On GPU: cast weights to fp16. Effectively a no-op, since clip already loads fp16 weights by default here.
    clip.model.convert_weights(model)
else:
    # On CPU: keep everything in fp32.
    model.float()


In [None]:
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
# Betas/eps/weight decay follow the CLIP paper; the small lr keeps fine-tuning on the new dataset safe.
optimizer = optim.AdamW(model.parameters(), lr=5e-6, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

epochs = 10

for epoch in range(epochs):
    # The evaluation cell calls model.eval(), which switches `model` itself to eval mode;
    # restore train mode so re-running this cell trains correctly.
    model.train()
    epoch_loss = 0.0
    n_batches = 0

    for images, texts in train_dataloader:
        optimizer.zero_grad()

        images = images.to(device)
        texts = texts.to(device)

        logits_per_image, logits_per_text = model(images, texts)

        # The captions are continent names, so many images in a batch share an identical caption.
        # A one-hot `arange` target would treat those identical captions as negatives, which is a
        # contradictory objective. Instead, spread each row's target mass uniformly over every
        # sample whose caption tokens match. The matrix is symmetric, so it also serves the text side.
        same_caption = (texts.unsqueeze(1) == texts.unsqueeze(0)).all(dim=-1)
        ground_truth = same_caption.to(logits_per_image.dtype)
        ground_truth = ground_truth / ground_truth.sum(dim=1, keepdim=True)

        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
        total_loss.backward()
        if device == "cpu":
            optimizer.step()
        else:
            # Step in fp32, then cast the weights back to fp16 (see convert_models_to_fp32).
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        epoch_loss += total_loss.item()
        n_batches += 1

    print(f"epoch {epoch + 1}/{epochs}: mean loss {epoch_loss / max(n_batches, 1):.4f}")

In [None]:
# Save the fine-tuned weights and optimizer state so training can be resumed or the model reloaded.
torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # Final batch loss as a plain float: the tensor itself is attached to the autograd
    # graph and lives on `device`, which complicates loading the checkpoint elsewhere.
    'loss': total_loss.item(),
}, "model_1000.pt")  # just change to your preferred folder/filename

In [None]:
# Evaluate accuracy: each test image is assigned the continent whose text embedding is most similar.
correct = 0
total = 0
skipped = 0
VERBOSE = False  # set True to print the per-image ranking of all continents

# NOTE: model.eval() switches `model` itself into eval mode; eval_model is the same object.
eval_model = model.eval()

all_labels = ["Africa", "Asia", "Europe", "North America", "South America", "Oceania"]
continent_dict = {'Africa': 0, 'Asia': 1, "Europe": 2, "North America": 3, "South America": 4, "Oceania":5, 0:'Africa', 1:"Asia", 2:"Europe", 3:"North America", 4:"South America", 5: "Oceania"}

text_inputs = torch.cat([clip.tokenize(f"{c}") for c in all_labels]).to(device)
with torch.no_grad():
    text_features = eval_model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

for i, (image_path, true_label) in enumerate(zip(test_list_image_path, test_labels["CONTINENT"])):
    # get_continent can emit labels that are not candidate classes (e.g. "Antarctica", or NaN);
    # count and skip them instead of crashing with KeyError.
    if true_label not in all_labels:
        skipped += 1
        continue
    class_id = continent_dict[true_label]

    # Close the file once the image has been preprocessed.
    with Image.open(image_path) as image:
        image_input = preprocess(image).unsqueeze(0).to(device)

    # Calculate features
    with torch.no_grad():
        image_features = eval_model.encode_image(image_input)

    image_features /= image_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(len(all_labels))

    total += 1
    if indices[0].item() == class_id:
        correct += 1

    if VERBOSE:
        print(f"\n[{i}] Correct: {true_label}\nTop predictions:")
        for value, index in zip(values, indices):
            print(f"{all_labels[index]:>16s}: {100 * value.item():.2f}%")

if skipped:
    print(f"Skipped {skipped} test images whose label is not one of {all_labels}")
print("Accuracy: ", correct / total if total else float("nan"))

In [None]:
# Training set stats: number of images per continent label.
import matplotlib.pyplot as plt

count = train_labels["CONTINENT"].value_counts()
fig, ax = plt.subplots()
count.plot.bar(ax=ax)
ax.set_title("Training set (im2gps3k): images per continent")
ax.set_ylabel('Number of records')
ax.set_xlabel('Target Class')
fig.tight_layout()
plt.show()

In [None]:
# Test set stats: number of images per continent label.
import matplotlib.pyplot as plt

count = test_labels["CONTINENT"].value_counts()
fig, ax = plt.subplots()
count.plot.bar(ax=ax)
ax.set_title("Test set (im2gps): images per continent")
ax.set_ylabel('Number of records')
ax.set_xlabel('Target Class')
fig.tight_layout()
plt.show()