```python
from transformers import pipeline
from datachain import DataChain, Column

# Sentiment classifier shared by `is_positive_dialogue_ending` below.
# Runs on CPU; the checkpoint name indicates a DistilBERT model fine-tuned
# on SST-2.
classifier = pipeline(
    task="sentiment-analysis",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
    device="cpu",
)

def is_positive_dialogue_ending(file) -> bool:
    """Tell whether the end of a dialogue file is classified as positive.

    Args:
        file: Object exposing ``read()``; the returned content must support
            slicing (presumably text, since the chain is opened with
            ``type="text"`` - confirm).

    Returns:
        True when the first prediction returned by the module-level
        ``classifier`` for the last 512 elements of the content carries the
        label ``"POSITIVE"``, False otherwise.
    """
    tail_length = 512  # only the closing part of the dialogue is scored
    content = file.read()
    ending = content[-tail_length:]
    top_prediction = classifier(ending)[0]
    return top_prediction["label"] == "POSITIVE"

# Read every object under the demo bucket as text, label each one with the
# sentiment helper, and persist the result as the "file_response" dataset.
chain = (
    DataChain.from_storage(
        "gs://datachain-demo/chatbot-KiT/",
        object_name="file",
        type="text",
    )
    # presumably 8 parallel workers with file caching enabled - confirm
    # against the DataChain settings documentation
    .settings(parallel=8, cache=True)
    .map(is_positive=is_positive_dialogue_ending)
    .save("file_response")
)

# `== True` is deliberate: the comparison is applied to a Column object and
# handed to `filter`, so `is True` / a bare truth test would not be equivalent.
positive_chain = chain.filter(Column("is_positive") == True)

# Write the files behind the positive rows into ./output.
positive_chain.export_files("./output")

print(f"{positive_chain.count()} files were exported")

```

In [7]:
from utils import MBDataset, DATA_DIR
import os 
import json
from tqdm import tqdm

In [8]:
# Load the dataset rooted at DATA_DIR (both names come from the local
# `utils` module). The recorded output shows a 1053-step progress bar during
# construction; presumably one step per edit turn - confirm in utils.
ds = MBDataset(DATA_DIR)

100%|██████████| 1053/1053 [00:00<00:00, 1062449.39it/s]


In [10]:
# Directory that receives one JSON file per edit turn (created if missing).
TURNS_PATH = "data/test-MagicBrush/test/turns"
os.makedirs(TURNS_PATH, exist_ok=True)

# Wrap the iterable itself rather than the `enumerate` object: `enumerate`
# has no length, so `tqdm(enumerate(...))` can only show a running counter,
# whereas `enumerate(tqdm(...))` lets tqdm read the length of
# `ds.edit_turns` (when it has one) and display a real bar with a total.
for idx, turn in enumerate(tqdm(ds.edit_turns)):
    # Files are named after the turn's position: 0.json, 1.json, ...
    with open(os.path.join(TURNS_PATH, f"{idx}.json"), "w") as f:
        # NOTE(review): assumes `to_json()` returns a JSON-serialisable
        # object (dict/list). If it returns an already-encoded string,
        # `json.dump` will double-encode it - confirm in utils.
        json.dump(turn.to_json(), f, indent=4)

1053it [00:00, 6212.53it/s]


In [None]:
from transformers import pipeline
from datachain import DataChain, Column

# Question-answering model used by `is_coloring_edit` below. No task is
# given, so the pipeline type is inferred from the checkpoint.
# NOTE(review): `device="mps"` is hard-coded; this presumably fails on
# machines without the Apple Metal backend - consider a CPU fallback.
oracle = pipeline(
    model="deepset/roberta-base-squad2",
    device="mps",
)

def is_coloring_edit(obj) -> bool:
    """Tell whether the QA model answers "yes" to the colour-edit question.

    The turn's instruction is passed as the context of a fixed yes/no
    question to the module-level ``oracle`` pipeline.

    Args:
        obj: Mapping-like turn record; only ``obj["instruction"]`` is read.

    Returns:
        True when the (best) answer, stripped and lower-cased, equals
        ``"yes"``; False otherwise, including when the pipeline returns an
        empty list.

    NOTE(review): ``oracle`` is built from a SQuAD-style extractive QA
    checkpoint, whose answers are presumably spans copied from the context.
    If so, it can only say "yes" when the instruction itself contains that
    word - confirm, and consider a zero-shot classifier instead.
    """
    QUESTION = "Is this edit instruction corresponding to a modification in the color of an object or the scene? JUST ANSWER YES OR NO"

    instruction = obj["instruction"]
    result = oracle(question=QUESTION, context=instruction)

    # A question-answering pipeline returns a single dict for one
    # question/context pair and a list of dicts for batches or top_k > 1.
    # The previous code indexed `[0]` unconditionally, which raises
    # KeyError on the dict form.
    if isinstance(result, (list, tuple)):
        if not result:
            return False
        result = result[0]

    return result["answer"].strip().lower() == "yes"

# Read the per-turn JSON files, tag each one with `is_coloring_edit`, and
# persist the result.
# NOTE(review): this saves under "file_response", the same dataset name the
# sentiment chain earlier in this file uses, and rebinds `chain` /
# `positive_chain` - confirm that overwriting/versioning is intended.
chain = (
    DataChain.from_json(
        path="data/test-MagicBrush/test/turns/*.json",
        schema_from="data/mb_turns_schema.json",
    )
    .map(is_coloring=is_coloring_edit)
    .save("file_response")
)

# `== True` is deliberate: the comparison is applied to a Column object and
# handed to `filter`, not evaluated as a plain Python boolean.
positive_chain = chain.filter(Column("is_coloring") == True)