In [1]:
import sys
sys.path.append('..')

import os
import random
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_lora import *
from trainer.trainer_utils import setup_seed
import pandas as pd
import json
from types import SimpleNamespace

In [2]:
# Filesystem locations used by the notebook: both live under one project root.
_PROJECT_ROOT = '/root/xlcoder/MiniMind2-Small'
DATA_LOC = f'{_PROJECT_ROOT}/dataset'    # directory holding the CSV dataset
INSTR_LOC = f'{_PROJECT_ROOT}/hands-on'  # directory holding the few-shot instruction text files

In [3]:
def init_model(args):
    """Create the tokenizer and model described by ``args`` and return them ready for inference.

    Two loading paths exist, selected by the substring ``'model'`` in ``args.load_from``:

    * present  -> build a ``MiniMindForCausalLM`` from ``args.hidden_size``,
      ``args.num_hidden_layers``, ``args.use_moe`` and ``args.inference_rope_scaling``,
      then load ``./{save_dir}/{weight}_{hidden_size}[_moe].pth`` into it (strict).
      If a LoRA weight name is configured, LoRA is applied and
      ``./{save_dir}/lora/{lora_weight}_{hidden_size}.pth`` is loaded on top.
    * absent   -> load the model with ``AutoModelForCausalLM.from_pretrained``
      (``trust_remote_code=True``).

    Args:
        args: attribute-style namespace. Reads ``load_from``, ``device`` and, on the
            local-checkpoint path, ``hidden_size``, ``num_hidden_layers``, ``use_moe``,
            ``inference_rope_scaling``, ``save_dir``, ``weight`` and the optional
            ``lora_weight`` (missing, ``None`` or the string ``'None'`` all mean "no LoRA").

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode and moved to
        ``args.device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.load_from)
    if 'model' in args.load_from:
        model = MiniMindForCausalLM(MiniMindConfig(
            hidden_size=args.hidden_size,
            num_hidden_layers=args.num_hidden_layers,
            use_moe=bool(args.use_moe),
            inference_rope_scaling=args.inference_rope_scaling
        ))
        moe_suffix = '_moe' if args.use_moe else ''
        ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
        model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
        # `lora_weight` is optional: callers that never configured it must not crash
        # with AttributeError, so a missing attribute is treated as "no LoRA".
        lora_weight = getattr(args, 'lora_weight', 'None')
        if lora_weight not in (None, 'None'):
            apply_lora(model)
            load_lora(model, f'./{args.save_dir}/lora/{lora_weight}_{args.hidden_size}.pth')
    else:
        model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
    # Report the parameter count in millions.
    print(f'MiniMindÊ®°ÂûãÂèÇÊï∞: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
    return model.eval().to(args.device), tokenizer

In [10]:
# Inference settings consumed by `init_model` and by the generation loop.
# Built as a plain dict first, then exposed with attribute access (args.hidden_size, ...).
arg_values = {
    # Tokenizer/model source; a path containing 'model' selects the local-checkpoint branch
    # of `init_model`. Alternative used previously: '/root/xlcoder/MiniMind2-Small/MiniMind2'
    'load_from': '../model',
    'save_dir': '../out',             # directory containing the .pth checkpoints
    'weight': 'pretrain',             # checkpoint name prefix; alternative: 'full_sft'
    'lora_weight': 'None',            # read by `init_model`; the string 'None' disables LoRA
    'hidden_size': 512,               # alternative: 768
    'num_hidden_layers': 8,           # alternative: 16
    'use_moe': 0,
    'inference_rope_scaling': False,
    'max_new_tokens': 8192,           # passed to model.generate
    'temperature': 0.85,              # passed to model.generate
    'top_p': 0.85,                    # passed to model.generate
    'historys': 0,                    # number of trailing conversation entries kept; 0 = none
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'times': 1,                       # how many sampling rounds the generation loop runs
    'count_only': False,              # True: record response statistics instead of chat history
}
args = SimpleNamespace(**arg_values)

In [5]:
# Read the tab-separated BBC news CSV from the dataset directory and show it as the cell output.
bbc_csv_path = os.path.join(DATA_LOC, 'bbc-news-data.csv')
raw_data = pd.read_csv(bbc_csv_path, sep='\t')
raw_data

Unnamed: 0,category,filename,title,content
0,business,001.txt,Ad sales boost Time Warner profit,Quarterly profits at US media giant TimeWarne...
1,business,002.txt,Dollar gains on Greenspan speech,The dollar has hit its highest level against ...
2,business,003.txt,Yukos unit buyer faces loan claim,The owners of embattled Russian oil giant Yuk...
3,business,004.txt,High fuel prices hit BA's profits,British Airways has blamed high fuel prices f...
4,business,005.txt,Pernod takeover talk lifts Domecq,Shares in UK drinks and food firm Allied Dome...
...,...,...,...,...
2220,tech,397.txt,BT program to beat dialler scams,BT is introducing two initiatives to help bea...
2221,tech,398.txt,Spam e-mails tempt net shoppers,Computer users across the world continue to i...
2222,tech,399.txt,Be careful how you code,A new European directive could put software w...
2223,tech,400.txt,US cyber security chief resigns,The man making sure US computer networks are ...


In [6]:
# Keep the first ten rows as a small working subset; the trailing ';' suppresses the cell display.
to_process = raw_data.head(10)
to_process;

In [11]:
# Load the model once; the streamer prints decoded tokens while they are generated
# (prompt and special tokens are skipped).
model, tokenizer = init_model(args)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Few-shot instruction text placed before and after every article.
with open(os.path.join(INSTR_LOC, 'few_shot1_pre.txt'), 'r', encoding='utf-8') as pre_file:
    pre_instruct = pre_file.read()

with open(os.path.join(INSTR_LOC, 'few_shot1_post.txt'), 'r', encoding='utf-8') as post_file:
    post_instruct = post_file.read()

# Must exist before the loop: with args.historys > 0 the first iteration slices it.
conversation = []
# One record per prompt when args.count_only is set (the original appended to an undefined `df`).
count_records = []

for _ in range(int(args.times)):
    # Draw ten random articles per round.
    prompt_iter = raw_data.sample(n=10)
    for i, prompt in prompt_iter.iterrows():
        # A fresh random seed per prompt; use a fixed value (e.g. setup_seed(2026)) for repeatable runs.
        setup_seed(random.randint(0, 2048))

        # Keep only the last `historys` entries, or start from scratch when history is disabled.
        conversation = conversation[-args.historys:] if args.historys else []

        # `prompt` is the DataFrame row (a Series); build the actual text sent to the model.
        prompt_text = pre_instruct + 'title:\n' + prompt['title'] + '\ncontent:\n' + prompt['content'] + post_instruct
        conversation.append({"role": "user", "content": prompt_text})

        if args.weight != 'pretrain':
            templates = {"conversation": conversation, "tokenize": False, "add_generation_prompt": True}
            # For the 'reason' weights only, templates["enable_thinking"] = True may be added here.
            inputs = tokenizer.apply_chat_template(**templates)
        else:
            # Pretrain checkpoints get the raw text after BOS. The original concatenated
            # the row Series itself here, which does not yield a string.
            inputs = tokenizer.bos_token + prompt_text
        inputs = tokenizer(inputs, return_tensors="pt", truncation=True).to(args.device)

        print('ü§ñÔ∏è: ', end='')
        generated_ids = model.generate(
            inputs=inputs["input_ids"], attention_mask=inputs["attention_mask"],
            max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
            pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
            top_p=args.top_p, temperature=args.temperature, repetition_penalty=1.0
        )
        # Drop the prompt tokens so only the newly generated part is decoded.
        new_token_ids = generated_ids[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(new_token_ids, skip_special_tokens=True)
        if args.count_only:
            # Count generated tokens; len(response) would count characters of the decoded text.
            res_len = len(new_token_ids)
            # NOTE(review): `weighted_ngram_local_activation` is not defined in the visible
            # notebook; presumably it comes from a star import or a missing cell. Confirm
            # before enabling count_only.
            ngrm_rep = weighted_ngram_local_activation(response, 4, normalize='empirical')
            print("token_count:", res_len, "weighted_4-gram_repitition:", ngrm_rep)
            count_records.append({'prompt_id': i, 'token_count': res_len, 'weighted_4-gram_repitition': ngrm_rep})
            continue
        conversation.append({"role": "assistant", "content": response})
        print('\n\n')

'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /model/resolve/main/tokenizer_config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7fc4ac0d9040>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: bc4571ec-440e-4041-ac87-6545f8ce9b9f)')' thrown while requesting HEAD https://huggingface.co/model/resolve/main/tokenizer_config.json
Retrying in 1s [Retry 1/5].


KeyboardInterrupt: 