Importamos la clase ImageFeaturesExtraction

In [11]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
import torchvision.transforms as T
import requests
import torch
from io import BytesIO
from PIL import Image
from typing import List
from translate import Translator
import cairosvg
import json
import pymongo
from pymongo import MongoClient

class ImageFeaturesExtraction:
    def __init__(self, threshold = 0.5):
        self.model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
        self.model.eval()
        self.transform = T.Compose([T.ToTensor()])
        self.threshold = threshold
        self.labels_url = 'https://raw.githubusercontent.com/nightrome/cocostuff/master/labels.txt'
        self.coco_labels = self.load_coco_labels()
        self.translator = Translator(to_lang="es")

    def load_coco_labels(self):
        response = requests.get(self.labels_url)
        return response.text.splitlines()

    def translate_labels(self, labels):
        translations = [self.translator.translate(label) for label in labels]
        return [translation for translation in translations]

    def convert_svg_to_format(self, svg_data, output_format='png'):
        if output_format not in ['png', 'jpeg', 'pdf', 'ps', 'eps', 'svg']:
            raise ValueError("El formato de salida debe ser uno de: png, jpeg, pdf, ps, eps, svg")

        converted_data = cairosvg.svg2bytestring(bytestring=svg_data, output_format=output_format)
        return Image.open(BytesIO(converted_data))

    def analyze_news_images(self, image_urls: List[str], output_format='png'):
        results = []
        for img_url in image_urls:
            response = requests.get(img_url)

            # Comprobar si la imagen es SVG
            if 'image/svg+xml' in response.headers.get('content-type', ''):
                svg_data = response.content
                image = self.convert_svg_to_format(svg_data, output_format)
            else:
                image_extension = response.headers.get('content-type').split('/')[1]
                image = Image.open(BytesIO(response.content))

            # El resto del código sigue igual para las imágenes rasterizadas
            image_tensor = self.transform(image).unsqueeze(0)

            with torch.no_grad():
                prediction = self.model(image_tensor)

            boxes = prediction[0]['boxes']
            scores = prediction[0]['scores']
            labels = prediction[0]['labels']

            filtered_boxes = boxes[scores > self.threshold]
            filtered_scores = scores[scores > self.threshold]
            filtered_labels = labels[scores > self.threshold]

            translated_labels = self.translate_labels([self.coco_labels[label.item()] for label in filtered_labels])

            results.append({
                'image_url': img_url,
                'detections': list(zip(translated_labels, filtered_scores.numpy()))
            })

        return results

    def save_results_to_file(self, results, output_file='results.json'):
        with open(output_file, 'w') as file:
            json.dump(results, file, indent=4)

imgs = ["https://cloudfront-eu-central-1.images.arcpublishing.com/prisa/2HNUTOK2HNEPFDTR5WYMIJDFOY.jpg", "https://cloudfront-eu-central-1.images.arcpublishing.com/prisa/JPPHJM6AWJDHJGO5DEJRHQLNXM.jpg"]



image_features_extraction = ImageFeaturesExtraction()
image_features_extraction = image_features_extraction.analyze_news_images(imgs)
print(image_features_extraction)


[{'image_url': 'https://cloudfront-eu-central-1.images.arcpublishing.com/prisa/2HNUTOK2HNEPFDTR5WYMIJDFOY.jpg', 'detections': [('1 persona', 0.9969202), ('47: taza', 0.99400324), ('1 persona', 0.9872723), ('62: silla', 0.9762356), ('44: botella', 0.9664547), ('44: botella', 0.8505604), ('76: teclado', 0.8439268), ('1 persona', 0.81844544), ('73: portátil', 0.7570158), ('62: silla', 0.7482912), ('Mesa de comedor', 0.71215147), ('44: botella', 0.65817744), ('76: teclado', 0.6301845), ('62: silla', 0.5747696), ('62: silla', 0.56445944)]}, {'image_url': 'https://cloudfront-eu-central-1.images.arcpublishing.com/prisa/JPPHJM6AWJDHJGO5DEJRHQLNXM.jpg', 'detections': [('73: portátil', 0.9989899), ('1 persona', 0.99642926), ('44: botella', 0.99126947), ('Mesa de comedor', 0.968309), ('62: silla', 0.9303603), ('62: silla', 0.89775985), ('Planta en maceta', 0.88069636)]}]
