In [None]:
from patronum.processing import ICLSProcessor, IProcessor
from patronum.processing.silo import DataSilo
from transformers import AutoTokenizer
from pathlib import Path
import os

In [None]:
# Checkpoint identifier used when loading the tokenizer.
# Alternative kept for reference: "intfloat/multilingual-e5-large"
model_name_or_path = "distilbert-base-multilingual-cased"

# Dataset folder, resolved under the current user's home directory.
data_dir = Path.home().joinpath("IDataset", "vk", "banner-cls")

# Folder for saved artifacts, located under the current working directory.
save_dir = Path.cwd() / "weights"

In [None]:
def process_fn(text: str) -> str:
    """Prepare one raw text for the model: trim it and add the query prefix.

    Args:
        text: Raw input text.

    Returns:
        The text without leading/trailing whitespace, prefixed with ``"query: "``.
    """
    # E5-style checkpoints expect the literal prefix "query: " (with a space
    # after the colon); without the space the prefix fuses with the first word.
    return "query: " + text.strip()

In [None]:
# Class names for the classifier. Keep the order stable: the list is handed to
# the processor and the head size is derived from its length.
label_list = [
    "EDUCATION",
    "BUSINESSANDFINANCE",
    "AUTOMOTIVE",
    "TELEVISION",
    "RELIGIONANDSPIRITUALITY",
    "FAMILYANDRELATIONSHIPS",
    "EVENTSANDATTRACTIONS",
    "MEDICALHEALTH",
    "HEALTHYLIVING",
    "STYLEANDFASHION",
    "NEWSANDPOLITICS",
    "TRAVEL",
    "HOBBIESANDINTERESTS",
    "MUSICANDAUDIO",
    "BOOKSANDLITERATURE",
    "HOMEANDGARDEN",
    "PERSONALFINANCE",
    "OTHER",
    "FOODANDDRINK",
    "SPORTS",
    "JEWELRYANDWATCHES",
    "MOVIES",
    "SHOPPING",
    "FINEART",
    "CAREERS",
    "PETS",
    "VIDEOGAMING",
    "TECHNOLOGYANDCOMPUTING",
    "REALESTATE",
    "CASINO",
    "POPCULTURE",
]

In [None]:
# Build the classification processor from the tokenizer, the CSV column
# mapping, the label set and the text preprocessing callable.
processor = ICLSProcessor(
    tokenizer=AutoTokenizer.from_pretrained(model_name_or_path),
    max_seq_len=512,
    data_dir=data_dir,
    # NOTE(review): train, dev and test all point at the same file
    # ("test.csv") — confirm this is a deliberate smoke-test setup, since
    # evaluation would then run on the data used for training.
    train_filename="test.csv",
    dev_filename="test.csv",
    test_filename="test.csv",
    label_list=label_list,
    metric="acc",
    delimiter=",",
    text_column_name="text",
    label_column_name="label",
    process_fn=process_fn)

In [None]:
# Display the tokenizer attached to the processor as the cell output.
processor.tokenizer

In [None]:
# processor.save(save_dir)
# Replace the processor built earlier with the one stored in `save_dir`.
# NOTE(review): the save call above is commented out, so this load relies on
# `save_dir` having been populated by a previous run — confirm.
processor = IProcessor.load_from_dir(save_dir)
# Re-attach the preprocessing callable to the loaded processor
# (presumably it is not restored by load_from_dir — verify).
processor.process_fn = process_fn

In [None]:
# Sample English passages used as ad-hoc inputs for the processor.
docs = [
    "I realize I detest Haymitch. No wonder the District 12 tributes never stand a chance. It isn’t just that we’ve been underfed and lack training. Some of our tributes have still been strong enough to make a go of it. But we rarely get sponsors and he’s (Haymitch) a big part of the reason why. The rich people who back tributes — either because they’re betting on them or simply for the bragging rights of picking a winner — expect someone classier than Haymitch to deal with.",
    "In late summer, I was washing up in a pond when I noticed the plants growing around me. Tall with leaves like arrowheads. Blossoms with three white petals. I knelt down in the water, my fingers digging into the soft mud, and I pulled up handfuls of the roots. Small, bluish tubers that don’t look like much but boiled or baked are as good as any potato. \"Katniss,\" I said aloud. It’s the plant I was named for. And I heard my father’s voice joking, \"As long as you can find yourself, you’ll never starve.\"\n I spent hours stirring up the pond bed with my toes and a stick, gathering the tubers that floated to the top. That night, we feasted on fish and katniss roots until we were all, for the first time in months, full.",
    "They’re funny birds and something of a slap in the face to the Capitol. During the rebellion, the Capitol bred a series of genetically altered animals as weapons. The common term for them was muttations, or sometimes mutts for short. One was a special bird called a jabberjay that had the ability to memorize and repeat whole human conversations. They were homing birds, exclusively male, that were released into regions where the Capitol’s enemies were known to be hiding. After the birds gathered words, they’d fly back to centers to be recorded. It took people awhile to realize what was going on in the districts, how private conversations were being transmitted. Then, of course, the rebels fed the Capitol endless lies, and the joke was on it. So the centers were shut down and the birds were abandoned to die off in the wild.",
    "Peeta Mellark, on the other hand, has obviously been crying and interestingly enough does not seem to be trying to cover it up. I immediately wonder if this will be his strategy in the Games. To appear weak and frightened, to reassure the other tributes that he is no competition at all, and then come out fighting. This worked very well for a girl, Johanna Mason, from District 7 a few years back. She seemed like such a sniveling, cowardly fool that no one bothered about her until there were only a handful of contestants left. It turned out she could kill viciously. Pretty clever, the way she played it. But this seems an odd strategy for Peeta Mellark because he’s a baker’s son. All those years of having enough to eat and hauling bread trays around have made him broad-shouldered and strong. It will take an awful lot of weeping to convince anyone to overlook him.",
    "Finally, Gale is here and maybe there is nothing romantic between us, but when he opens his arms I don’t hesitate to go into them. His body is familiar to me — the way it moves, the smell of wood smoke, even the sound of his heart beating I know from quiet moments on a hunt — but this is the first time I really feel it, lean and hard-muscled against my own. \"Listen,\" he says. \"Getting a knife should be pretty easy, but you’ve got to get your hands on a bow. That’s your best chance.\"\n \"They don’t always have bows,\" I say, thinking of the year there were only horrible spiked maces that the tributes had to bludgeon one another to death with.\n \"Then make one,\" says Gale. \"Even a weak bow is better than no bow at all.\" I have tried copying my father’s bows with poor results. It’s not that easy. Even he had to scrap his own work sometimes."
]

# Wrap each passage in a dict with its text and a label; the label key is read
# from the processor's text-classification task, and every passage gets the
# same label value.
_docs = [
    {
        "text": passage,
        processor.tasks["text_classification"]["label_name"]: "BOOKSANDLITERATURE",
    }
    for passage in docs
]

In [None]:
# Run the in-memory samples through the processor; the three results are
# unpacked under the names below.
dataset, tensor_names, problematic_sample_ids = processor.dataset_from_dicts(
    dicts=_docs,
    indices=list(range(len(_docs))),  # TODO remove indices
)

In [None]:
# Build the data silo from the processor with a batch size of 256.
silo = DataSilo(processor=processor, batch_size=256)

In [None]:
# loader = silo.get_data_loader(dataset_name="train")

In [None]:
# next(iter(silo.loaders["train"]))

In [None]:
from patronum.modeling.prime import IDIBERT
from patronum.modeling import ILanguageModel

In [None]:
# lm = ILanguageModel.load(Path(os.getcwd()) / "weights")

In [None]:
# Load the language model from the same checkpoint name the tokenizer is
# created from (`model_name_or_path`), so the two cannot drift apart when the
# checkpoint is changed in the configuration cell.
lm = IDIBERT.load(model_name_or_path)

In [None]:
# Folder for saved artifacts under the current working directory
# (same value as the assignment in the configuration cell).
save_dir = Path.cwd() / "weights"

In [None]:
from patronum.modeling.flow import ICLSHead

In [None]:
# Number of target classes, taken from the label list.
num_labels = len(label_list)

In [None]:
# Classification head sized by the label count.
# NOTE(review): 768 presumably matches the encoder's hidden size — confirm
# when switching to a different checkpoint.
head = ICLSHead(layer_dims=[768, num_labels], num_labels=num_labels)

In [None]:
# lm.save(save_dir)

In [None]:
# head.save(save_dir)

In [None]:
from patronum.modeling import M1Runner

In [None]:
from patronum.etc import initialize_device_settings

In [None]:
# Resolve the compute device with CUDA requested; the call returns the device
# together with a GPU count.
device, n_gpu = initialize_device_settings(use_cuda=True)

In [None]:
# Assemble the runner from the language model and the classification head.
runner = M1Runner(
    language_model=lm,
    prediction_heads=[head],
    embeds_dropout_prob=0.1,
    lm_output_types=["per_sequence"],
    # Use the device resolved by initialize_device_settings instead of a
    # hard-coded "cuda", so the runner and the optimizer setup agree on it.
    device=device,
)

In [None]:
# runner.save(save_dir)
# runner = M1Runner.load(save_dir, device="cpu")

In [None]:
from patronum.training.optimizer import initialize_optimizer

In [None]:
# Create the optimizer and learning-rate schedule for the runner.
# n_batches is the length of the silo's "train" loader; training is configured
# for a single epoch with AMP disabled.
model, optimizer, lr_schedule = initialize_optimizer(
    model=runner,
    learning_rate=3e-5,
    device=device,
    n_batches=len(silo.loaders["train"]),
    n_epochs=1,
    use_amp=False)