In [1]:
%%writefile run.py
import json
import os
import re
import numpy as np
import pandas as pd
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer, DefaultDataCollator
from PIL import Image
import requests
from evaluate import load
import wandb
import torch.nn.functional as F
import torch.nn as nn
import torch
from datasets import Dataset, load_from_disk, load_dataset
from sklearn.metrics import accuracy_score
import random
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from collections import Counter
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right
from datasets import load_dataset, concatenate_datasets
from copy import deepcopy
from PIL import Image
from rdkit import Chem, RDLogger
import os
from rdkit.Chem import Draw
from io import BytesIO

# Silence RDKit's parser logging; this script checks MolFromSmiles results itself.
RDLogger.DisableLog('rdApp.*')
# Weights & Biases credentials must never live in source code or notebooks.
# wandb.login() picks up the WANDB_API_KEY environment variable or an existing
# ~/.netrc entry. The key that used to be hardcoded here was committed in plain
# text and should be revoked/rotated.
wandb.login()
wandb.init(project='DoHACK', name='DONUT')
# Character error rate metric (from `evaluate`), used by compute_metrics.
cer = load('cer')

def set_seed(seed: int = 56) -> None:
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch (CPU plus the
    current CUDA device). It also switches cuDNN to deterministic kernels with
    autotuning disabled.

    PYTHONHASHSEED is exported too. That only influences Python interpreters
    started after this call; it does not change hashing in the running process.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # The cuDNN autotuner may choose different algorithms from run to run, so turn
    # it off and require deterministic implementations instead.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    os.environ["PYTHONHASHSEED"] = f"{seed}"
    print(f"Random seed set as {seed}")


set_seed()

# Main training corpus: canonical PubChem SMILES from a local gzip-compressed parquet file.
# NOTE(review): the variable name says 130M but the file name says 90M — confirm which corpus this is.
molecula_130m = load_dataset("parquet", data_files="PubChem90m_canon.parquet.gzip")['train']
# ZINC SMILES that aug_smiles() merges into training samples. Keeps only strings RDKit can
# parse and that are shorter than 128 characters.
# Left byte-identical on purpose: `datasets` fingerprints the filter function, so rewriting
# this lambda would likely invalidate the cached (expensive) filter result.
zinc20 = load_dataset('sagawa/ZINC-canonicalized')['train']
zinc20 = zinc20.filter(lambda x: Chem.MolFromSmiles(x['smiles']) is not None and len(x['smiles']) < 128, num_proc=12)
# Holds out 0.02% of PubChem, but only the 'train' part is used: validation data comes
# from new_val.csv instead (the split's 'test' part is unused).
ds = molecula_130m.train_test_split(0.0002, seed=56)
dataset_train = ds['train']
dataset_val = load_dataset("csv", data_files="new_val.csv")['train'] #ds['test']
# SMILES tokenizer; it replaces Donut's own tokenizer below and fixes the decoder vocabulary
# size (the decoder embeddings are resized to len(tokenizer) later).
tokenizer = AutoTokenizer.from_pretrained('sagawa/PubChem-10m-t5-v2')

processor = DonutProcessor.from_pretrained('naver-clova-ix/donut-base')
processor.tokenizer= tokenizer
# Encoder input resolution. draw_smiles() renders 512x512 images, presumably resized by the
# image processor to this size. The commented dict is a previously used 512x512 setting.
processor.image_processor.size = {'height': 384, 'width': 384}#{'height': 512, 'width': 512}

class VisionEncoderDecoderSmooth(VisionEncoderDecoderModel):
    """``VisionEncoderDecoderModel`` trained with a label-smoothed cross-entropy loss.

    ``forward`` follows the parent's flow:

    * encode ``pixel_values``, unless ``encoder_outputs`` are supplied;
    * optionally project the encoder states to the decoder width;
    * build ``decoder_input_ids`` from ``labels`` when no decoder inputs are given;
    * decode.

    The only difference is the loss, which is
    ``nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)``. Label positions
    equal to -100 are ignored by the loss.
    """

    # Label-smoothing factor for the training loss; override on the class or an
    # instance to tune it.
    label_smoothing = 0.05

    def forward(
        self,
        pixel_values = None,
        decoder_input_ids = None,
        decoder_attention_mask = None,
        encoder_outputs = None,
        past_key_values = None,
        decoder_inputs_embeds = None,
        labels=None,
        use_cache = None,
        output_attentions = None,
        output_hidden_states = True,
        return_dict = None,
        **kwargs,
    ):
        """Run encoder and decoder and, when ``labels`` are given, the smoothed loss.

        Args:
            pixel_values: Image batch for the encoder. Required unless
                ``encoder_outputs`` is given.
            decoder_input_ids, decoder_attention_mask, decoder_inputs_embeds:
                Decoder inputs. If neither ids nor embeds are given and ``labels``
                is, ids are built with ``shift_tokens_right(labels, pad_token_id,
                decoder_start_token_id)``.
            encoder_outputs: Precomputed encoder outputs, either a model output
                object or a tuple.
            past_key_values, use_cache: Decoder cache controls, passed through.
            labels: Target token ids. -100 marks positions excluded from the loss.
            output_attentions, output_hidden_states: Forwarded to encoder and
                decoder. NOTE: ``output_hidden_states`` defaults to True here, so
                every call returns all hidden states (extra memory). The default is
                kept for interface compatibility.
            return_dict: Return a ``Seq2SeqLMOutput`` (default taken from the
                config) or a plain tuple.
            **kwargs: Keys prefixed with ``decoder_`` go to the decoder with the
                prefix stripped; all others go to the encoder.

        Returns:
            ``Seq2SeqLMOutput``. When ``return_dict`` is False, a tuple
            ``(loss, *decoder_outputs, *encoder_outputs)`` instead, with ``loss``
            omitted if no labels were given.

        Raises:
            ValueError: if neither ``pixel_values`` nor ``encoder_outputs`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # NOTE(review): newer transformers Trainer versions may pass `num_items_in_batch` to
        # models whose forward accepts **kwargs. It is not an encoder argument, so drop it.
        kwargs.pop("num_items_in_batch", None)

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            # Wrap tuples so the attribute accesses below work (BaseModelOutput is imported at file top).
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # Project encoder states to the decoder width when the sizes differ and the decoder
        # has no cross-attention size of its own.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # No padding mask over encoder positions: every image patch is attended to.
        encoder_attention_mask = None

        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Compute the loss here rather than in the decoder (some decoders shift logits internally).
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            # ignore_index=-100 (the nn default, spelled out) skips padded label positions.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=self.label_smoothing)
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

# Load the pretrained donut-base weights into the label-smoothing subclass defined above.
model = VisionEncoderDecoderSmooth.from_pretrained('naver-clova-ix/donut-base')
# The decoder is paired with a different (SMILES) tokenizer in this script, so resize its
# token embeddings to that tokenizer's vocabulary. This must run before vocab_size is
# copied below.
model.decoder.resize_token_embeddings(len(tokenizer))

# Reuse the pad token as CLS and as the decoder start token (presumably because this
# T5-style tokenizer has no dedicated CLS/BOS token). Order matters: cls_token_id is set
# before it is read on the next line.
processor.tokenizer.cls_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Keep the top-level vocab size in sync with the resized decoder.
# NOTE(review): assumes model.config.decoder is the config object updated by
# resize_token_embeddings above — verify.
model.config.vocab_size = model.config.decoder.vocab_size

model.config.eos_token_id = processor.tokenizer.eos_token_id
# Matches the 256-token truncation applied to labels in prepare_features. Presumably also
# the generation length limit used by predict_with_generate during evaluation.
model.config.max_length = 256


def draw_smiles(smiles, size=512, fallback='C'):
    """Render a SMILES string as a black-and-white RGB PIL image.

    Args:
        smiles: SMILES string to draw.
        size: Width and height of the square canvas in pixels (default 512).
        fallback: SMILES drawn instead when ``smiles`` cannot be parsed or drawn.
            Defaults to 'C' (methane), so one bad sample does not crash a
            data-loading batch.

    Returns:
        A ``size`` x ``size`` RGB ``PIL.Image.Image``.

    Raises:
        Whatever RDKit/PIL raise if ``fallback`` itself cannot be drawn. The error
        is re-raised rather than retried, to avoid infinite recursion.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles reports unparseable input by returning None, not by raising.
            raise ValueError(f"RDKit could not parse SMILES {smiles!r}")
        canvas = Draw.MolDraw2DCairo(size, size)
        canvas.drawOptions().useBWAtomPalette()
        canvas.DrawMolecule(mol)
        canvas.FinishDrawing()
        # convert() forces PIL to decode the PNG bytes now, before the buffer goes away.
        return Image.open(BytesIO(canvas.GetDrawingText())).convert('RGB')
    except Exception:
        # `Exception`, not a bare except, so KeyboardInterrupt/SystemExit still propagate.
        if smiles == fallback:
            raise
        return draw_smiles(fallback, size=size, fallback=fallback)

# Probability that aug_smiles() modifies a training SMILES at all.
MERGE_PROB = 0.15
# Running index into zinc20 for the 'long' merge mode of aug_smiles().
# NOTE(review): presumably started high so that a resumed run continues past ZINC
# molecules already used — TODO confirm.
merge_i = 10_300_000

# Organic-subset atoms, which SMILES allows to be written without square brackets.
ORGANIC_SET = ['B', 'C', 'N', 'O', 'P', 'S', 'F', 'Cl', 'Br', 'I']

# All 118 element symbols in atomic-number order, ten per line.
ELEMENTS = (
    "H He Li Be B C N O F Ne "
    "Na Mg Al Si P S Cl Ar K Ca "
    "Sc Ti V Cr Mn Fe Co Ni Cu Zn "
    "Ga Ge As Se Br Kr Rb Sr Y Zr "
    "Nb Mo Tc Ru Rh Pd Ag Cd In Sn "
    "Sb Te I Xe Cs Ba La Ce Pr Nd "
    "Pm Sm Eu Gd Tb Dy Ho Er Tm Yb "
    "Lu Hf Ta W Re Os Ir Pt Au Hg "
    "Tl Pb Bi Po At Rn Fr Ra Ac Th "
    "Pa U Np Pu Am Cm Bk Cf Es Fm "
    "Md No Lr Rf Db Sg Bh Hs Mt Ds "
    "Rg Cn Nh Fl Mc Lv Ts Og"
).split()

# Candidate stray fragments for the 'short' augmentation: bare organic atoms plus every
# element written as a bracket atom (e.g. '[Fe]').
ATOMS = [*ORGANIC_SET, *('[' + symbol + ']' for symbol in ELEMENTS)]
# Number of ZINC molecules available as merge partners; aug_smiles() wraps its index modulo this.
len_merge = len(zinc20)

def _worker_offset():
    """Start offset into zinc20 for the current DataLoader worker (0 outside workers).

    Splits the len_merge ZINC molecules into equal contiguous slices, one per
    worker. Workers then advance their own copies of merge_i through different
    molecules.
    """
    from torch.utils.data import get_worker_info

    info = get_worker_info()
    if info is None:
        return 0
    return info.id * (len_merge // info.num_workers)


def aug_smiles(smiles):
    """Randomly turn a SMILES string into a multi-fragment SMILES (training augmentation).

    With probability MERGE_PROB, one of two equally likely modes is applied:

    * 'long': append a molecule from zinc20. Molecules are taken sequentially via
      the global counter merge_i, wrapped modulo len_merge.
    * 'short': append 1-3 random atoms drawn from ATOMS.

    The fragments are sorted by string length and joined with '.'. Otherwise
    ``smiles`` is returned unchanged.

    This runs through ``with_transform`` inside DataLoader worker processes, and
    each worker has its own copy of merge_i. Each worker's index is therefore
    offset into its own slice of zinc20; otherwise all workers would merge the
    same molecules. The main process (no workers) uses offset 0, as before.
    """
    global merge_i

    if random.random() >= MERGE_PROB:
        return smiles

    if random.choice(['long', 'short']) == 'long':
        add_smiles_idx = (merge_i + _worker_offset()) % len_merge
        merge_i += 1
        fragments = [smiles, zinc20[add_smiles_idx]['smiles']]
    else:
        count = random.randint(1, 3)
        # Sampled with replacement, so the same atom can appear more than once.
        fragments = [smiles, *np.random.choice(ATOMS, count)]
    # Shorter fragments first; '.' separates disconnected components in SMILES.
    fragments.sort(key=len)
    # Disabled alternative: re-serialize via Chem.MolToSmiles(..., kekuleSmiles=True).
    return '.'.join(fragments)

def _encode_batch(smileses, max_length=256):
    """Build model inputs for a batch of SMILES strings.

    Each SMILES is rendered with draw_smiles() and passed through the Donut
    processor to get ``pixel_values``. The same strings are tokenized as decoder
    targets, padded to the longest in the batch and truncated to ``max_length``
    tokens. Pad-token positions in the labels are set to -100 so the loss ignores
    them.

    NOTE(review): draw_smiles() substitutes a fallback image when a SMILES cannot
    be drawn, but the label keeps the original string.

    Returns:
        ``{'pixel_values': torch tensor, 'labels': numpy array of token ids}``.
    """
    images = [draw_smiles(s) for s in smileses]
    pixel_values = processor(images=images, return_tensors="pt").pixel_values
    target_encoding = processor.tokenizer(
        # The f-string coerces every entry (e.g. numpy string scalars) to a plain str.
        [f'{s}' for s in smileses],
        padding="longest",
        max_length=max_length,
        truncation=True,
        return_tensors='np'
    )
    labels = target_encoding.input_ids
    labels[labels == tokenizer.pad_token_id] = -100
    return {'pixel_values': pixel_values, 'labels': labels}


def prepare_features(examples):
    """Training transform: augment each SMILES with aug_smiles(), then encode the batch."""
    return _encode_batch([aug_smiles(s) for s in examples['smiles']])


def prepare_features_val(examples):
    """Validation transform: encode the SMILES exactly as given (no augmentation)."""
    return _encode_batch(list(examples['smiles']))


# Featurization is lazy: with_transform runs the function when items are accessed, so
# with dataloader workers it runs inside them.
tokenized_dataset_train = dataset_train.with_transform(prepare_features)    # augmented
tokenized_dataset_val = dataset_val.with_transform(prepare_features_val)    # not augmented

args = Seq2SeqTrainingArguments(
    'donut_modelv2',
    # Keep raw columns: the transforms read examples['smiles'], which is not a forward() argument.
    remove_unused_columns=False,
    # Batching: 64 per device x 4 accumulation steps = 256 samples per optimizer step per device.
    per_device_train_batch_size=64,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=64,
    dataloader_num_workers=12,
    # Optimizer and learning-rate schedule.
    optim='adamw_torch',
    learning_rate=5e-4,
    adam_beta2=0.98,
    weight_decay=0.01,
    lr_scheduler_type='cosine',
    num_train_epochs=1,
    # Evaluation decodes with generate(), so compute_metrics receives generated token ids.
    # NOTE(review): `evaluation_strategy` is renamed `eval_strategy` in newer transformers releases.
    evaluation_strategy='steps',
    eval_steps=750,
    predict_with_generate=True,
    # Checkpointing.
    save_strategy='steps',
    save_steps=750,
    save_total_limit=3,
    save_safetensors=False,
    # Logging.
    logging_steps=5,
    report_to='wandb',
)

def compute_metrics(preds):
    """Score generated SMILES against the reference SMILES.

    Args:
        preds: The ``EvalPrediction`` passed by ``Seq2SeqTrainer``. With
            ``predict_with_generate=True``, ``preds.predictions`` holds generated
            token ids and ``preds.label_ids`` the reference ids. Both may contain
            -100 as padding.

    Returns:
        ``{'cer': character error rate, 'accuracy': exact-match rate}``, computed
        over the whitespace-stripped decoded strings.
    """
    pad_id = tokenizer.pad_token_id
    # Map -100 padding to the pad token so it decodes (and is skipped as a special token).
    # np.where builds new arrays instead of mutating the trainer's arrays in place.
    label_ids = np.where(preds.label_ids == -100, pad_id, preds.label_ids)
    pred_ids = np.where(preds.predictions == -100, pad_id, preds.predictions)

    labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    if labels:
        # Log one reference/prediction pair as a quick qualitative check.
        print(labels[-1], predictions[-1])

    y_true = [text.strip() for text in labels]
    y_pred = [text.strip() for text in predictions]
    return {
        'cer': cer.compute(predictions=y_pred, references=y_true),
        'accuracy': accuracy_score(y_true, y_pred),
    }


trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    # Items are already featurized by the with_transform functions; the default
    # collator only has to stack them into batches.
    data_collator=DefaultDataCollator(),
    train_dataset=tokenized_dataset_train,
    eval_dataset=tokenized_dataset_val,
    compute_metrics=compute_metrics,
    # The Donut processor is passed as `tokenizer`, presumably so the Trainer saves it
    # alongside each checkpoint.
    tokenizer=processor,
)

# Resume training from a previously saved checkpoint of this run.
trainer.train(resume_from_checkpoint="donut_modelv2/checkpoint-174000")

Overwriting run.py


In [2]:
# Disabled one-off maintenance snippet: rewrites `global_step` in checkpoint-16000's
# trainer_state.json to 16500 and saves the file back. Presumably used so a run resumed
# from that checkpoint continues from step 16500.
#import json
#dump = json.load(open('donut_modelv2/checkpoint-16000/trainer_state.json'))
#dump['global_step'] = 16500
#
#with open('donut_modelv2/checkpoint-16000/trainer_state.json', 'w') as f:
#    json.dump(dump, f)

In [None]:
!accelerate launch --mixed_precision=fp16 ./run.py
# TODO(review): "turn off ZINC" (translated from the original Russian note) — presumably
# a reminder to disable the ZINC merge augmentation in run.py; confirm.

The following values were not passed to `accelerate launch` and had defaults used instead:
	`--num_processes` was set to a value of `1`
	`--num_machines` was set to a value of `1`
	`--dynamo_backend` was set to a value of `'no'`
[34m[1mwandb[0m: Currently logged in as: [33mandrey20007[0m ([33mandrey2007[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: wandb version 0.16.4 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.13.4
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/notebooks/wandb/run-20240314_151946-fsrrjboo[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mDONUT[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/andrey2007/DoHACK[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mht