In [None]:
# Install the different libraries
!pip install dalle-pytorch --upgrade
!pip install gdown
!git clone https://github.com/lucidrains/DALLE-pytorch.git

In [None]:
from google.colab import drive

# Make the Google Drive contents available under /content/drive
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [None]:
# Coco Dataset
from pycocotools.coco import COCO

# Root folder of the COCO subset on Google Drive and the split being used
dataDir = '/content/drive/MyDrive/COCOdataset2017Airplanes'
dataType = 'train2017'

# Instance-annotation file of the chosen split
annFile = f'{dataDir}/annotations/instances_{dataType}.json'

In [None]:
# Load the instance annotations through the COCO API
coco = COCO(annFile)

# Keep every category id together with its category record
catIDs = coco.getCatIds()
cats = coco.loadCats(catIDs)

# Names of the classes to keep; images of other classes are ignored
filterClasses = ['airplane']

# Resolve the class names to category ids, then collect the ids of the
# images that contain those categories
catIds = coco.getCatIds(catNms=filterClasses)
imgIds = coco.getImgIds(catIds=catIds)

# Load the caption annotations through a second COCO API instance
captions_annFile = f'{dataDir}/annotations/captions_{dataType}_planes.json'
coco_caps = COCO(captions_annFile)

In [None]:
# Lists to contain the images info; entry k of both lists refers to the same image
images_annot = []
images_paths = []

# Index ranges (both ends inclusive) of the splits inside imgIds:
#   Training set   from 0 to 1999
#   Validation set from 2000 to 2585
#   Test set       from 2586 to 2985
first_idx = 2586
last_idx = 2985

# range() excludes its stop value, so 1 is added to really include last_idx
# (400 images in total). The stop is also clamped to len(imgIds) so that a
# shorter id list cannot raise an IndexError.
for i in range(first_idx, min(last_idx + 1, len(imgIds))):

  annIds = coco_caps.getAnnIds(imgIds=imgIds[i])
  anns = coco_caps.loadAnns(annIds)

  # An image without any caption is skipped entirely, which keeps
  # images_annot and images_paths aligned
  if not anns:
    continue

  # Keep the first caption only
  images_annot.append(anns[0]['caption'])

  # Build the path of the matching image file from its COCO file name
  img = coco.loadImgs(imgIds[i])[0]
  content_img_path = '{}/New_Art_Airplanes_Final/{}'.format(dataDir, img['file_name'])

  # Keep the path
  images_paths.append(content_img_path)

In [None]:
# Replace every "'" of the captions with a space to limit errors: each caption
# is later wrapped in single quotes on a shell command line
images_annot_v2 = [caption.replace("'", " ") for caption in images_annot]

In [None]:
# Initialize variables containing the distances
distanceTotal = 0
distances_list = []

# import os
import requests
import pandas as pd

for i in range(0,400):

  # Get the current caption
  CAPTION = images_annot_v2[i]

  # Use the dalle model given some parameters like the caption
  # returns an image describing the caption with an art style
  !python /content/DALLE-pytorch/generate.py --dalle_path /content/drive/MyDrive/WandB/dalle_Dim12.pt --text '{CAPTION}' --batch_size 1 --num_images 1

  # Use DeepAI API to get similarity between two images
  r = requests.post(
      "https://api.deepai.org/api/image-similarity",
      files={
          'image1': open('/content/outputs/' + images_annot_v2[i].replace(' ', '_') + '/0.jpg', 'rb'),
          'image2': open(images_paths[i], 'rb'),
      },
      headers={'api-key': '##API-KEY##'}
  )

  # Handle returned request and get distance from it
  distance = r.json()['output']['distance']
  distanceTotal += distance
  distances_list.append(distance)

  # Back-up of the current progress under a CSV file saved directly in the
  # Drive
  if (i % 10) == 0:
    df = pd.DataFrame(distances_list)
    df.to_csv('/content/drive/MyDrive/Results_text2image/dim12_'+ str(i) + '.csv')

