In [16]:
# df = spark.sql("SELECT * FROM InFreGen.task LIMIT 1000")
# display(df)
# # Delete all pics belonging to the given taskID
# spark.sql("DELETE FROM InFreGen.pic WHERE taskID = {}".format(taskID))
# taskID = 19

StatementMeta(, , , Waiting, , Waiting)

In [17]:
# Import required libraries
import json
import os
from enum import Enum

import requests
from pyspark.sql import functions as F

# Unsplash API access key.
# Read from the UNSPLASH_ACCESS_KEY environment variable when it is set, so a real
# key never has to be saved inside the notebook. If the variable is not set, the
# placeholder below is used and must be replaced before calling the API.
ACCESS_KEY = os.environ.get('UNSPLASH_ACCESS_KEY', 'YOUR KEY HERE')


class Orientation(Enum):
    """Image orientation choices accepted by the Unsplash client in this notebook"""
    UNKNOWN = 0    # No orientation preference
    LANDSCAPE = 1  # Landscape orientation
    PORTRAIT = 2   # Portrait orientation
    SQUARISH = 3   # Square orientation

    def get_lower_value(self):
        """
        Return the member name in lowercase for use as an API parameter value.
        UNKNOWN maps to an empty string.
        """
        return "" if self is Orientation.UNKNOWN else self.name.lower()


class UnsplashClient:
    """Client class to interact with the Unsplash API"""

    def __init__(self, access_key, timeout=10):
        """
        Initialize client with access key and base configuration
        Args:
            access_key: Unsplash access key, sent in the Client-ID authorization header
            timeout: Seconds to wait for the server on each request before
                requests raises a timeout error (prevents a call from hanging forever)
        """
        self.access_key = access_key
        self.timeout = timeout
        self.base_url = 'https://api.unsplash.com/'
        self.headers = {
            'Accept-Version': 'v1',
            'Authorization': f'Client-ID {self.access_key}'
        }

    def _get_json(self, endpoint, params):
        """
        Send a GET request to an endpoint relative to base_url
        Args:
            endpoint: Path appended to base_url, e.g. 'search/photos'
            params: Query-string parameters for the request
        Returns:
            JSON response if the status code is 200, None otherwise
        """
        url = self.base_url + endpoint
        response = requests.get(url, headers=self.headers, params=params, timeout=self.timeout)
        if response.status_code != 200:
            # Report the status so failures (e.g. a 400 for a bad parameter) are visible
            print(f"Unsplash request to {endpoint} failed with status {response.status_code}")
            return None
        return response.json()

    def get_random_image(self, query: str = None, orientation: Orientation = Orientation.UNKNOWN):
        """
        Fetch a random image from Unsplash
        Args:
            query: Search term for the image
            orientation: Desired image orientation
        Returns:
            JSON response if successful, None if failed
        """
        params = {}
        if query is not None:
            params['query'] = query
        # Leave the orientation out entirely when it is UNKNOWN instead of sending
        # an empty value, matching what search_images does.
        if orientation != Orientation.UNKNOWN:
            params['orientation'] = orientation.get_lower_value()
        return self._get_json('photos/random', params)

    def search_images(
            self, query: str,
            page: int = 1,
            per_page: int = 30,
            orientation: Orientation = Orientation.UNKNOWN
    ):
        """
        Search for images on Unsplash
        Args:
            query: Search term
            page: Page number for pagination
            per_page: Number of results per page
            orientation: Desired image orientation
        Returns:
            JSON response if successful, None if failed
        """
        params = {
            'query': query,
            'page': page,
            'per_page': per_page,
        }
        if orientation != Orientation.UNKNOWN:
            params['orientation'] = orientation.get_lower_value()
        return self._get_json('search/photos', params)

# Shared Unsplash client used by the cells below, built from the configured key
client = UnsplashClient(access_key=ACCESS_KEY)

StatementMeta(, , , Waiting, , Waiting)

In [18]:
# Import required random module for selecting random keywords and images
import random

# Query the task table to get task details for the given taskID.
# NOTE: taskID must already be defined when this cell runs; it is not assigned
# in the visible cells (presumably supplied as a notebook parameter — confirm).
# int() guarantees that only a numeric ID is interpolated into the SQL text.
task_rows = spark.sql("SELECT * FROM InFreGen.task WHERE taskID = {}".format(int(taskID))).collect()
if not task_rows:
    # Fail with a clear message instead of a bare IndexError from [0]
    raise ValueError("No task found in InFreGen.task for taskID = {}".format(taskID))
task_row = task_rows[0]
print(task_row)
# Example output:
# Row(taskID=12, userInput='I want to do medical image cancer detection', keyword=None, num=8, resolution='1024x1024', sizeChoice='small', State=0)

def searchPerTask(task_row) -> list[dict]:
    """
    Search and fetch images from Unsplash based on task parameters

    Each keyword is searched at most once and its results are reused for later
    picks. A keyword whose search fails or returns no usable results is dropped,
    so the loop always terminates even when the API keeps failing.

    Args:
        task_row: Row object containing task details. Reads the 'keyword'
            (comma-separated string), 'num', 'taskID', 'resolution' and
            'sizeChoice' fields.

    Returns:
        list[dict]: List of dictionaries containing image metadata and URLs.
            May hold fewer than task_row['num'] entries (possibly none) when
            the remaining keywords yield no results.

    Raises:
        ValueError: If the task has no keywords (None, empty or only separators).
    """
    # Split comma-separated keywords into a list, dropping surrounding
    # whitespace and empty fragments ("a, b," -> ["a", "b"])
    raw_keywords = task_row['keyword'] or ""
    keywords = [part.strip() for part in raw_keywords.split(',') if part.strip()]
    print("keyword:", keywords)
    if not keywords:
        raise ValueError("Task {} has no keywords to search for".format(task_row['taskID']))

    # Search results already fetched, keyed by keyword, to avoid repeated API calls
    results_by_keyword = {}
    # Initialize empty list to store image metadata
    pic_table = []
    count = 0

    # Keep picking images until we reach the desired count or run out of keywords
    while count < task_row['num'] and keywords:
        # Randomly select a keyword from the keywords still available
        keyword = random.choice(keywords)
        print("Chosen keyword: " + keyword)

        if keyword not in results_by_keyword:
            # Search Unsplash with the selected keyword (first time only)
            search_results = client.search_images(query=keyword, page=1, per_page=30, orientation=Orientation.SQUARISH)
            found = search_results.get('results') if isinstance(search_results, dict) else None
            # Keep only entries that actually carry the 'urls' mapping we store
            results_by_keyword[keyword] = [
                item for item in (found or []) if isinstance(item, dict) and 'urls' in item
            ]

        candidates = results_by_keyword[keyword]
        if not candidates:
            # Nothing usable for this keyword: stop choosing it instead of retrying forever
            print("No usable results for keyword: " + keyword)
            keywords.remove(keyword)
            continue

        # Get URLs for a random image from the search results
        image_urls = random.choice(candidates)['urls']

        # Create metadata dictionary for the image
        new_row = {
            'taskID': task_row['taskID'],
            'picID': count,
            'Resolution': task_row['resolution'],
            'sizeChoice': task_row['sizeChoice'],
            'operations': "",
            'operationsReturn': "",
            'url': json.dumps(image_urls),
            'State': 0,
            'originalPicPath': "",
            'curPicPath': "",
            'keywords': keyword,
            'finalPicPath': "",
        }
        pic_table.append(new_row)
        count += 1

    if count < task_row['num']:
        print("Only fetched {} of {} images: no keywords with results remain".format(count, task_row['num']))
    return pic_table
    
# Import IntegerType for type casting
from pyspark.sql.types import IntegerType

# Build a DataFrame from the fetched image metadata and cast the ID/state
# columns to IntegerType in a single chained expression
pic_df = (
    spark.createDataFrame(searchPerTask(task_row))
    .withColumn("taskID", F.col("taskID").cast(IntegerType()))
    .withColumn("picID", F.col("picID").cast(IntegerType()))
    .withColumn("State", F.col("State").cast(IntegerType()))
)

# Append the rows to the InFreGen.pic Delta table
(
    pic_df.write
    .format("delta")
    .mode("append")
    .saveAsTable("InFreGen.pic")
)

StatementMeta(, , , Waiting, , Waiting)