Skip to content

Commit

Permalink
one big commit, sorry
Browse files Browse the repository at this point in the history
  • Loading branch information
= committed Mar 9, 2023
1 parent a1e11f2 commit dabf20d
Show file tree
Hide file tree
Showing 414 changed files with 774 additions and 57 deletions.
2 changes: 1 addition & 1 deletion common/backend-utils/xaidemo/tracking/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class TrackingSettings(BaseSettings):
experiment: bool = False
service_name: str
service_name: str = "tracking_service"
collector_url: Optional[str]
collector_timeout: int = 60

Expand Down
2 changes: 1 addition & 1 deletion guess-the-country/country-backend/country/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class Settings(BaseSettings):
service_name: str = "country-service"
service_name = "country-service"
root_path: str = ""
path_prefix: str = ""
# Google Maps API access
Expand Down
Binary file not shown.
22 changes: 14 additions & 8 deletions guess-the-country/country-backend/country/explainer/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ class Explanation(BaseModel):
def explain(data):
encoded_data = str(data)
image = load_image(encoded_data)
pre_image = preprocess(img=image)
preproc_image = preprocess(img=image)

explanation = explain_cnn(pre_image, model)
explanation = explain_cnn(np.array(image), preproc_image, model)
explanation_id = uuid.uuid4()

encoded_image_string = convert_explanation(explanation)
Expand All @@ -45,12 +45,18 @@ def convert_explanation(explanation):


@traced
def explain_cnn(image, model_=model):
segment_mask, segment_weights = explain_classification(image=image,
def explain_cnn(image, preproc_image, model_=model):
try:
segment_mask, segment_weights = explain_classification(image=preproc_image,
segmentation_method="felzenszwalb",
segmentation_settings={},
predict_fn=model_.predict_,
num_of_samples=500,
p=0.9)
predict_fn=model_.predict,
num_of_samples=100,
p=0.5)

return render_explanation(preproc_image, segment_mask, segment_weights,
positive="violet", coverage=0.15, opacity=0.5)

except ZeroDivisionError:
return image

return render_explanation(image, segment_mask, segment_weights, positive="violet", coverage=0.15, opacity=0.5)
4 changes: 2 additions & 2 deletions guess-the-country/country-backend/country/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys

from fastapi import FastAPI
from xaidemo import tracing, tracking
from xaidemo import tracing#, tracking


from xaidemo.routers import vue_frontend
Expand All @@ -16,7 +16,7 @@
tracing.set_up()

app = FastAPI(root_path=settings.root_path)
tracking.instrument_app(app)
#tracking.instrument_app(app)
app.include_router(api, prefix=settings.path_prefix)
app.include_router(vue_frontend(__file__), prefix=settings.path_prefix)

Expand Down
20 changes: 16 additions & 4 deletions guess-the-country/country-backend/country/model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def load_image(encoded_data):

@traced
def predict_image(image):
prediction = model.predict(image)
prediction = process_predict(image[None,:,:,:])
result = decode_model_output(prediction)
return result

Expand All @@ -44,9 +44,21 @@ def preprocess(img):
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# resize image to match model's expected sizing
img_resize = cv2.resize(img_rgb, (IMG_SIZE, IMG_SIZE))
image = tf.keras.applications.mobilenet_v2.preprocess_input(img_resize)
pre_image = image.reshape(-1, IMG_SIZE, IMG_SIZE, 3)
return pre_image
#image = tf.keras.applications.mobilenet_v2.preprocess_input(img_resize)
#pre_image = img_resize
#pre_image = image.reshape(-1, IMG_SIZE, IMG_SIZE, 3)

return img_resize


def process_predict(batch: np.ndarray):
"""
input: np array of size (batch, height, width, 3)
output: np array of size (batch, 4) where every row represents the softmax output of a
image inside the given batch
"""
proc_batch = tf.keras.applications.mobilenet_v2.preprocess_input(batch)
return model.predict(proc_batch)


MODEL_OUTPUT_MAP = ["Tel_Aviv", "Westjerusalem", "Berlin", "Hamburg"]
Expand Down
233 changes: 233 additions & 0 deletions guess-the-country/country-backend/country/streetview/_collect.py

Large diffs are not rendered by default.

56 changes: 25 additions & 31 deletions guess-the-country/country-backend/country/streetview/collect.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import base64
import io
import json
import random
import os

from urllib.error import URLError
from urllib.request import urlopen

from itertools import cycle
from PIL import Image
from pydantic import BaseModel
from shapely.geometry import Polygon
from xaidemo.http_client import AioHttpClientSession
Expand Down Expand Up @@ -194,40 +200,28 @@ class Streetview(BaseModel):

country_array = [tel_aviv, jerusalem, berlin, hamburg]

with open('/country/streetview/filename2class.json', 'r') as f:
filename2class = json.load(f)


filenames = [f for f in filename2class]
random.shuffle(filenames)
file_iter = cycle(filenames)


@traced
async def get_streetview(API_KEY):
async with AioHttpClientSession() as session:
nominated_country = random.randint(0, 3)
poly = country_array[nominated_country]['polygon']
status = False
while status != 'OK':
coord = generate_random(poly)
lng = coord[0][0]
lat = coord[0][1]
locstring = str(lat) + "," + str(lng)
try:
async with session.get(
API_URL + "?key=" + API_KEY + "&location=" + locstring + "&source=outdoor") as response:
json_body = (await response.json())
status = json_body['status']
print(status)
if status == 'REQUEST_DENIED':
print("NO API-KEY is definied, please set environment variable GOOGLE_MAPS_API_TOKEN")
break
except AioHttpClientSession.exceptions.TimeoutError:
print(AioHttpClientSession.exceptions.TimeoutError)
print(" ========== Got one! ==========")
url = GOOGLE_URL + API_KEY + "&location=" + locstring
try:
contents = urlopen(url).read()
# urlretrieve(url, outfile)
except URLError:
print(URLError)
status = False
encoded_image_string = base64.b64encode(contents)
encoded_bytes = bytes("data:image/png;base64,",
filename = next(file_iter)
try:
image = Image.open(f"/country/streetview/streetview_images/{filename}")
except:
raise Exception(f'FILE NOT FOUND!')

buffered = io.BytesIO()
image.save(buffered, format="JPEG")
encoded_image_string = base64.b64encode(buffered.getvalue())
encoded_bytes = bytes("data:image/png;base64,",
encoding="utf-8") + encoded_image_string
return Streetview(
image=encoded_bytes,
class_label=country_array[nominated_country]['city'])
class_label=filename2class[filename])

0 comments on commit dabf20d

Please sign in to comment.