In [6]:
%load_ext autoreload
%autoreload 2


import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4"

import sys
import torch
import pandas as pd
from torch import nn
import numpy as np
from torch.optim.lr_scheduler import ExponentialLR
import wandb

sys.path.append("../NLP-DL-Project-hypo-to-hyper/pipeline_src/")


from config.config import TaskConfig
from train import CustomScheduler, train
from logger.logger import WanDBWriter
from trainer.train_epoch import train_epoch, predict
from metrics.metrics import get_all_metrics
from dataset.dataset import init_data
from logger.logger import WanDBWriter


if torch.cuda.is_available():
    device = "cuda"
    print("GPU")
else:
    device = "cpu"
    print("CPU")


SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
print(torch.cuda.device_count())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
GPU
4


In [7]:
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaTokenizer,
    LlamaForCausalLM,
)

from peft import LoraConfig, get_peft_model, get_peft_model_state_dict


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /srv/home/rabikov/taxonomy_env/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cpu.so
CUDA SETUP: Highest compute capability among GPUs detected: 6.1
CUDA SETUP: Detected CUDA version 117
CUDA SETUP: Loading binary /srv/home/rabikov/taxonomy_env/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cpu.so...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)


In [8]:
config = TaskConfig()

config.batch_size = 8


config.data_path = 'babel_datasets/wnet_train_en_babel.pickle'
config.gold_path = (
    None  # "SemEval2018-Task9/training/gold/1A.english.training.gold.txt"
)
config.test_data_path = 'babel_datasets/wnet_test_en_babel.pickle'
config.test_gold_path = (
    None  # "SemEval2018-Task9/test/gold/1A.english.test.gold.txt"
)

config.device = device
config.using_peft = True
config.model_type = "Auto"  # Auto or Llama
config.wandb_log_dir = "/raid/rabikov/wandb/"
config.model_checkpoint = "EleutherAI/gpt-neo-1.3B"
config.exp_name = config.model_checkpoint.replace("/", "-")
config.saving_path = "/raid/rabikov/model_checkpoints/" + config.exp_name

load_path = (
    "/raid/rabikov/model_checkpoints/"
    + "EleutherAI-gpt-neo-1.3Breweighted_wnet_remove_all_from_labels_custom_multilang_epoch=0_MAP=0.09087864718127875.pth"
)

In [9]:
if config.model_type == "Auto":
    model_type = AutoModelForCausalLM
    tokenizer_type = AutoTokenizer
elif config.model_type == "Llama":
    model_type = LlamaForCausalLM
    tokenizer_type = LlamaTokenizer

model = model_type.from_pretrained(
    config.model_checkpoint,
    # load_in_8bit=True,
   # torch_dtype=torch.float16,
    device_map="auto",
)

tokenizer = tokenizer_type.from_pretrained(
    config.model_checkpoint,
    padding_side="left",
)

In [10]:
if config.using_peft:
    LORA_R = 8
    LORA_ALPHA = 16
    LORA_DROPOUT = 0.05
    LORA_TARGET_MODULES = [
        "q",
        "v",
    ]

    # model = prepare_model_for_int8_training(model)
    config_lora = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        # target_modules=LORA_TARGET_MODULES,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config_lora)
    model.print_trainable_parameters()

trainable params: 1572864 || all params: 1317148672 || trainable%: 0.11941431012580485


In [11]:
train_dataset, test_dataset, train_loader, val_loader = init_data(tokenizer, config)

In [7]:
train_dataset.data[12]

{'children': ['ferrocerium.n.1', 'misch_metal.n.1'],
 'parents': 'pyrophoric_alloy.n.1',
 'grandparents': None,
 'case': 'only_leafs_all'}

In [33]:
tokenizer.decode(train_dataset[1]['input_seq'])

"<s> Predict hyponyms for the word 'digit.n.1'.  Answer:<s>zero.n.2, three.n.1, four.n.1, five.n.1, six.n.1, seven.n.1, eight.n.1, nine.n.1, binary_digit.n.1, decimal_digit.n.1, duodecimal_digit.n.1, hexadecimal_digit.n.1, octal_digit.n.1, significant_digit.n.1, one.n.1, two.n.1"

In [12]:
checkpoint = torch.load(load_path, map_location='cpu')
model.load_state_dict(checkpoint["model"])
del checkpoint
torch.cuda.empty_cache()

In [13]:
config.gen_args = {
    "no_repeat_ngram_size": 2,
    "max_new_tokens": 32,
    "num_return_sequences": 2,
    "num_beams": 15,
    "early_stopping": True,
    "num_beam_groups": 5,
    "diversity_penalty": 1.0,
    "temperature": 0.9,
}


config.gen_args = {
    "no_repeat_ngram_size": 2,
    "num_beams": 5,
    "early_stopping": True,
    "max_new_tokens": 8,
    "temperature": 0.95,
}

config.gen_args = {
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "num_beams": 1,
    "min_new_tokens": 16 - 1,
    "max_new_tokens": 16,
    "temperature": 0.9,
    "top_k": 20,
    "num_return_sequences": 2

}

In [None]:
all_preds, all_labels = predict(model, tokenizer, val_loader, config)

In [14]:
from tqdm import tqdm

def split(ls, size):
    res = []

    for i in range(0, len(ls)-1, size):
        res.append(ls[i:i+size])
    return res

def get_one_sample(batch, model, config):

    
    terms, att_mask_terms, targets, input_seqs, att_mask_input, labels = batch
    output_tokens = model.generate(
        inputs=terms.to(config.device),
        attention_mask=att_mask_terms.to(config.device),
        pad_token_id=tokenizer.eos_token_id,
        **config.gen_args,
    )
    pred_tokens = output_tokens[:, terms.size()[1] :]
    pred_str = tokenizer.batch_decode(pred_tokens.cpu(), skip_special_tokens=True)
    gold_str = tokenizer.batch_decode(targets, skip_special_tokens=True)

    if len(pred_str) > len(gold_str):
        pred_str = split(pred_str, config.gen_args['num_return_sequences'])

    return pred_str, gold_str

@torch.no_grad()
def predict(model, tokenizer, val_loader, config, epoch="", ans_load_path=None):

    model.eval()
    
    all_preds = []
    all_golds = []
    for batch in tqdm(val_loader):
        pred, gold = get_one_sample(batch, model, config)

        all_preds.append(pred)
        all_golds.append(gold)

    return all_preds, all_golds

In [15]:
all_preds, all_labels = predict(model, tokenizer, val_loader, config)

  0%|          | 0/99 [00:00<?, ?it/s]

100%|██████████| 99/99 [03:14<00:00,  1.96s/it]


In [19]:
all_preds[1]

[['claudication syndromen.2 | hypernomenclaturen.3',
  'claudication pain of the feet and legs, claudicating pain of lower'],
 ['carpals, tarsal, flexor digitorum superficialis,',
  'pectoralis, latissimus dorsi, pectoralis major,'],
 ['rhomboidalis, rhombicis, rhomboideus, rh',
  'rheumatoid rhombus, pyriform, trapezoid,'],
 ['bastard, godson, godchild, godmother, godfather,',
  'sister, godson, godchild, godmother, godwren,'],
 ['liver donor, blood donor, organ source, recipient, donor, donor card',
  'free citizen, citizen, person, human being, free person, citizen-subject'],
 ['mulatto, half-caste, halfbreed, mongrel',
  'commoner, non-Jew, noncomic, nonconformist,'],
 ['atheist, conservative, liberal, agnostic, atheist, fundamentalist, Marxist,',
  'critic, nonconformist, antiwar, liberal, radical, diss'],
 ['bloodsucker, bullfighter, cocky, bulldog, jester,',
  "matador, bullfight, bullfighter's cape, roper, rodeo"]]

In [90]:
pred[1]

['chinchilla palm, coconut palm, mangrove palm, orchid palm, palm civet, tree palm, tree leg, kapok palm',
 'palmetto, bromeliads, palm tree, palmettos, palm wine, senna, tamarind, pali, cactus,',
 'cocor, cocoanut palm, eucalyptus, eupatorium, eugenia, khalwa, kopi, t',
 'bunya, banana palm, banana tree, banana, banana leaf, banana nut, chameleon palm, date palm, dendrobium, g',
 'acacia, acer, agapanthus, ackee, acorn, acumin, acushla, akalai, akimba, all',
 'baby palm, palm tree, palm leaf, palm wine, palm root, taro, kadamba, manioc, palm-nut, kap',
 'mango palm, banana palm, coconut palm, finger palm, kopje palm, laurustinus, mangrove palm, pine palm, pineapple',
 'grapefruit tree, water palm, kapok tree, banana tree, banyan tree, coconut palm, palm tree, pineapple palm, pineapple tree',
 'tangela, champa, chamomile, palm nut, palm olein, tangela palm, chiromantra palm, b',
 'bougainvillaea, hibiscus, lily of the valley, hollyhock, nana, water palm, fig tree, b',
 'chilean, karst 

In [76]:
def split(ls, size):
    res = []

    for i in range(0, len(ls)-1, size):
        res.append(ls[i:i+size])
    return res

In [77]:
split(pred)

  'aerial radar, civil radar, flight-path radar, shipboard radar, marine radar, microwave radar, radar altimeter, radar astronomy, radio-direction'],
 ['coconut palm, coconut palm, pinnate palm, rubber palm, sugarcane palm, tea palm, tangerine palm, oil palm,',
  'acacia palm, cachiro, ejipe, palm frond, palm tree, palmettos, palm leaf, palmetto, cerc']]

In [17]:
pred, gold

  'palmyra palm, African palm tree, cactus palm (cactus),'],
 ['three-dimensional radar, Doppler radar',
  'cabbage palm, cabbage palm, cabbage palm, coconut, corozo, fishtail palm, nipa palm, royal palm'])

In [19]:
def transform(label):
    all_words = label.split(',')
    new_words = []
    for word in all_words:
        new_words.append(word.strip().split('.')[0])
    
    return ', '.join(new_words)
all_labels2 = list(map(transform, all_labels1))

In [27]:
all_preds2 = list(map(lambda x: x.replace('.', ''), all_preds))

In [28]:
metrics = get_all_metrics(all_labels2, all_preds2, limit=15)

In [34]:
all_labels2[17], all_preds2[17]

('black_locust, clammy_locust', 'chrysanthemum')

In [29]:
metrics

{'MRR': 0.0, 'MAP': 0.0, 'P@1': 0.0, 'P@3': 0.0, 'P@5': 0.0, 'P@15': 0.0}

In [15]:
import pickle

saving_path = config.saving_predictions_path + config.exp_name + "_"

with open(saving_path, "wb") as fp:
    pickle.dump(all_preds, fp)

In [17]:
all_preds[4], all_labels[4]

('person, actor, film director, cinematography, filmmaker, visual arts, person',
 'thrower, baseball player, jock, person')