In [92]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import cv2

import pandas as pd
from typing import List, Tuple, Dict


import multiprocessing  as mp
import requests
import jsonlines
from tqdm import tqdm
import time
import random
from PIL import Image


from os import listdir
from os.path import isfile, join

In [2]:
import sys

# Directory added to the import path; presumably it holds the ru-clip checkout
# that provides the `clip` package imported in a later cell — TODO confirm.
# NOTE(review): this is an absolute, machine-specific path; consider making it
# configurable before sharing the notebook.
RU_CLIP_ROOT = "/data/hdd1/brain/BraTS19/YandexCup/ru-clip"

# Guarded so that re-running the cell does not pile up duplicate entries.
if RU_CLIP_ROOT not in sys.path:
    sys.path.append(RU_CLIP_ROOT)

In [32]:
from clip.evaluate.utils import (
    get_text_batch, get_image_batch, get_tokenizer,
    show_test_images, show_similarity,
    prepare_classes, call_model,
    show_topk_probs,
    load_weights_only,
    get_topk_accuracy,
    show_topk_accuracy
)

In [42]:
# The csv file should contain an array for each class.
# `load_weights_only` is a project helper (imported from clip.evaluate.utils);
# its result is unpacked into a model and an `args` object. `args` is what
# gets passed to Sent2textDataset further down.
model, args = load_weights_only("ViT-B/32-small")

{0, 2, 3, 4, 8, 9, 11, 12, 14, 15, 18, 19}

In [115]:
class Sent2textDataset(Dataset):
    """Image/caption matching dataset.

    Every item consists of one image, the caption of the selected table row
    (the positive, always placed last) and ``n_classes - 1`` captions drawn at
    random from other rows (the negatives), together with a 0/1 label vector.

    The text table is expected to hold the caption in its first column and the
    image id in its second column, the latter also reachable as ``id_img``.
    """

    # Seconds to wait for a single image download before giving up on it.
    _REQUEST_TIMEOUT = 30

    def __init__(self,path_t_csv, path_i_json, 
                 path_i_folder, down_data = False,
                 n_classes = 20, args = None,
                 tokenizer = None, clastering_mode = False,
                 transform = None, mode = "Sber",
                ):
        """
        path_t_csv - text table: either a path to the csv file or an already
                     loaded DataFrame
        path_i_json - path to the json-lines file describing the images
                      (each record has an 'image' name and a 'url')
        path_i_folder - folder the downloaded images are stored in
        down_data - download the images instead of only checking the folder
        n_classes - number of captions per item (negatives + one positive)
        args - argument object forwarded to get_text_batch / get_image_batch
        tokenizer - tokenizer to use; get_tokenizer() is called when omitted
        clastering_mode - texts are split into clusters (not implemented yet)
        transform - optional callable applied to the image instead of the
                    default get_image_batch pipeline
        mode - preprocessing flavour; only "Sber" is supported
        """
        # Accept both forms: the docstring promised a path, callers pass a frame.
        if isinstance(path_t_csv, str):
            self.text_data = pd.read_csv(path_t_csv)
        else:
            self.text_data = path_t_csv
        # Filled in by _load_json_links(): image id -> image file name.
        self.img_name_to_text = None
        # Only the images referenced by the text table are kept.
        self.img_links = self._load_json_links(path_i_json)

        self.path_to_img = path_i_folder if path_i_folder else path_i_json

        if down_data:
            manager = mp.Manager()
            self._imgs_path = manager.Queue()
            self._load_imgs(list(self.img_links.items()),n_workers = 16)
        else:
            # Keep only the samples whose image is already on disk.
            self._check_data_in_folder(path_i_folder)

        self.clastering_mode = clastering_mode
        self.transform = transform
        self.args = args
        self.mode = mode
        self.n_classes = n_classes
        self.tokenizer = get_tokenizer() if tokenizer is None else tokenizer


    def __len__(self,):
        # NOTE(review): this counts unique images while __getitem__ indexes
        # table rows, so rows beyond this count are never visited — confirm
        # whether len(self.text_data) was intended.
        return self.text_data.id_img.unique().shape[0]


    def _stack_texts(self,now_idx):
        """Draw ``n_classes - 1`` negative captions for the row ``now_idx``.

        Raises NotImplementedError in clastering_mode, which has no sampling
        strategy yet.
        """
        if self.clastering_mode:
            raise NotImplementedError("clastering_mode sampling is not implemented")

        n_rows = len(self.text_data)
        indexs = [random.randint(0, n_rows - 1) for _ in range(self.n_classes - 1)]
        # A draw equal to the positive row is moved to the next row; wrap
        # around so the last row does not index past the end of the table.
        indexs = [(i + 1) % n_rows if i == now_idx else i for i in indexs]

        return [self.text_data.iloc[i, 0] for i in indexs]


    def __getitem__(self,idx):
        """Return ``((img_input, input_ids, attention_mask), labels)`` for row ``idx``.

        Raises ValueError for an unsupported ``mode`` and FileNotFoundError
        when the image file cannot be read.
        """
        if self.mode != "Sber":
            raise ValueError(f"Unsupported mode: {self.mode!r}")
        assert self.args is not None, f"Define args"

        code = self.text_data.iloc[idx,1]
        name_img = self.img_links.get(code)[0]

        img_path = f"{self.path_to_img}/{name_img}.jpg"
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Cannot read image {img_path}")

        text = self.text_data.iloc[idx, 0] # ground-truth caption

        texts = self._stack_texts(idx) # negatives
        texts.append(text)             # positive goes last
        # TODO: shuffle the candidates (the positive is always the last one)
        labels = torch.tensor([0 if i != self.n_classes - 1 else 1
                               for i in range(self.n_classes)], 
                              dtype=torch.int32)

        input_ids, attention_mask = get_text_batch(texts, self.tokenizer, self.args)
        if self.transform is None:
            # OpenCV decodes to BGR while PIL interprets arrays as RGB.
            image = [Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))]
            img_input = get_image_batch(image, self.args.img_transform, self.args)
        else:
            # The custom transform receives the raw cv2 array (BGR order).
            img_input = self.transform(img)

        return (img_input, input_ids, attention_mask), labels



    def _check_data_in_folder(self,path_i_folder):
        """Drop the table rows whose image file is absent from ``path_i_folder``."""
        files_on_disk = {f for f in listdir(path_i_folder) if isfile(join(path_i_folder, f))}

        unique_ids = self.text_data.id_img.unique()
        start_count = len(unique_ids)

        present_ids = []
        for id_img in unique_ids:
            name = self.img_name_to_text.get(id_img)
            # Images are stored as "<name>.jpg" (see _load_img / __getitem__).
            if name is not None and f"{name}.jpg" in files_on_disk:
                present_ids.append(id_img)
        counter = len(present_ids)

        self.text_data = self.text_data[self.text_data.id_img.isin(present_ids)]
        print(f"From {start_count} sample folder hasn't {start_count - counter}")


    def _load_json_links(self,data_path: str, only_i_from_csv = True)->Dict[int, Tuple[str,str]]:
        """Read the json-lines image index.

        With ``only_i_from_csv`` (the default) returns ``{image id: (name, url)}``
        restricted to the ids of the text table and fills
        ``self.img_name_to_text``; otherwise returns the full list of
        ``(name, url)`` pairs.
        """
        data = []
        with jsonlines.open(data_path) as reader:
            for obj in tqdm(reader):
                data.append((obj['image'], obj['url']))

        if only_i_from_csv:
            # Image ids are used as positions in the json-lines file.
            wanted_ids = self.text_data.id_img.unique()
            self.img_name_to_text = {idx: data[idx][0] for idx in wanted_ids}
            return {idx: data[idx] for idx in wanted_ids}

        return data


    def _worker(self,task):
        """Pool task: download one image and report the outcome through the queue."""
        paths_img = self._load_img(task)
        self._imgs_path.put(paths_img)


    def _load_img(self,links: Tuple[int,Tuple[str,str]]):
        """Download one image described by ``(image id, (name, url))``.

        Returns the image id on success and None when the request failed.
        """
        try:
            response = requests.get(f"{links[1][1]}", timeout=self._REQUEST_TIMEOUT)
            # Do not store an HTTP error page as if it were a jpg.
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            # Covers connection errors, timeouts and HTTP error statuses, so a
            # single bad link cannot abort the whole pool.
            print(f"Oyy, miss {links[0]}")
            return None

        with open(f"{self.path_to_img}/{links[1][0]}.jpg", "wb") as img:
            img.write(response.content)
        return links[0]

    def _load_imgs(self, links: List[Tuple[int, Tuple[str,str]]], n_workers = 1)->bool:
        """Download all ``links`` in parallel and drop the rows of failed images."""
        all_row = set(self.text_data.id_img.unique())
        return_row = set()
        all_len = len(all_row)
        with mp.Pool(n_workers) as p:
            p.map(self._worker, links)

            # Every task puts exactly one entry (an id or None) into the queue.
            for _ in range(len(links)):
                return_row.add(self._imgs_path.get())

        # What is left in all_row are the ids that could not be downloaded.
        all_row.difference_update(return_row)

        self.text_data = self.text_data[~self.text_data.id_img.isin(list(all_row))]

        assert all_len - len(all_row) == len(self.text_data.id_img.unique())

        print(f"Download photo {all_len - len(all_row)} with {all_len} finish")
        return True

In [116]:
# Rows 1000..1999 of the preprocessed text table (already loaded as a DataFrame).
path_t_csv_no_clastr = pd.read_csv("data/preproc_text700.csv").iloc[1000:2000]
# Image index and the folder that holds the image files.
path_i_json = "data/images.json"
path_i_folder = "data/images"

# Build the dataset on top of the images already present in the folder.
ds = Sent2textDataset(
    path_t_csv=path_t_csv_no_clastr,
    path_i_json=path_i_json,
    path_i_folder=path_i_folder,
    args=args,
)

5462418it [00:12, 451216.13it/s]


From 352 sample folder hasn't 0


In [118]:
# Fetch the first sample: ((img_input, input_ids, attention_mask), labels).
data = ds[0]

In [126]:
# Second element of the sample is the label vector (1 marks the ground-truth text).
data[1]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       dtype=torch.int32)

In [7]:
# Remove the data/images directory together with everything in it.
# NOTE(review): data/images is the folder the dataset cell reads from; the
# execution counter suggests this cell was run earlier than its position in
# the notebook — on a top-to-bottom run it would wipe the images.
!rm -rf  data/images

In [8]:
# Show what is in data/ and then create the data/images directory.
!ls data/
!mkdir data/images

images.json  metadata.json  preproc_text700.csv


In [24]:
# Display a copy of the image-id column of the text table slice.
path_t_csv_no_clastr.id_img.copy()

1000     941193
1001     941193
1002     941193
1003    1370697
1004    1370697
         ...   
1995    4731110
1996    4731110
1997    4731110
1998    1444968
1999    1444968
Name: id_img, Length: 1000, dtype: int64