In [19]:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
__author__ = 'Author'
__email__ = 'Email'

# Detecting Contradiction at the Lexical Level
## Llama

In [20]:
# dependency
# built-in
import os, random
# public
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
# private
from src.llama.llama3.generation import Llama3
# from src.llama.llama4.generation import Llama4
from config import Config

%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Helper

In [21]:
def get_device():
    if "DEVICE" in os.environ:
        return os.environ["DEVICE"]
    if torch.cuda.is_available():
        return "cuda"
    elif torch.xpu.is_available():
        return "xpu"
    return "cpu"

def set_random_seed(seed: int = 42):
    """Fix random seeds for reproducibility across Python, NumPy, and PyTorch."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if using multi-GPU

    # Ensures deterministic behavior where possible
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## Llama 3

In [22]:
config = Config()
for k,v in config.__dict__.items():
    print(f'{k}: {v}')

seed: 0
llm: Llama-3.1-8B
CURR_PATH: ./
RESOURCE_PATH: ./res
DATA_PATH: ./res/data
RESULTS_PATH: ./res/results
LLMS_PATH: ./res/llms
LLM_PATH: ./res/llms/Llama-3.1-8B
HF_LLMS_PATH: ./res/hf_llms
HF_LLM_PATH: ./res/hf_llms/Llama-3.1-8B


In [23]:
ckpt_dir = config.LLM_PATH
world_size = 1
max_seq_len = 512
max_gen_len = 20
max_batch_size = 1
temperature = 0.
top_p = 1.0
quantization_mode = None

In [6]:
llama = Llama3.build(
    ckpt_dir=ckpt_dir,
    max_seq_len=max_seq_len,
    max_batch_size=max_batch_size,
    world_size=world_size,
    quantization_mode=quantization_mode,
    device=get_device(),
)
llama.model.eval()


> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loading a checkpoint (shards=1, current-mp-size=1)
Loading checkpoint shards:
[PosixPath('res/llms/Llama3.2-3B/consolidated.00.pth')]


  _C._set_default_tensor_type(t)


Setting default device to cpu
Loading state dict...
Done...
Loaded in 1.06 seconds


Transformer(
  (tok_embeddings): VocabParallelEmbedding()
  (layers): ModuleList(
    (0-27): 28 x TransformerBlock(
      (attention): Attention(
        (wq): ColumnParallelLinear()
        (wk): ColumnParallelLinear()
        (wv): ColumnParallelLinear()
        (wo): RowParallelLinear()
      )
      (feed_forward): FeedForward(
        (w1): ColumnParallelLinear()
        (w2): RowParallelLinear()
        (w3): ColumnParallelLinear()
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): ColumnParallelLinear()
)

### Wiki QA

In [7]:
# load tsv
raw_df = pd.read_csv('res/data/wiki/capital50.tsv', sep='\t')
raw_df.head()

Unnamed: 0,wikidata_id,country,capital,source
0,Q233,Malta,Valletta,The capital of Malta is
1,Q262,Algeria,Algiers,The capital of Algeria is
2,Q889,Afghanistan,Kabul,The capital of Afghanistan is
3,Q33,Finland,Helsinki,The capital of Finland is
4,Q736,Ecuador,Quito,The capital of Ecuador is


#### Next Token Prediction (NTP) -> Question Answering (QA)

In [19]:
# ask the capital of a country
xs_list = raw_df['country'].tolist()
ys_list = raw_df['capital'].tolist()
prompts_list = [f'Question: What is the capital of {x}? Answer:' for x in xs_list]
prompts_list[0]

'Question: What is the capital of Malta? Answer:'

In [20]:
llama

<src.llama.llama3.generation.Llama3 at 0x335f6f3d0>

In [28]:
set_random_seed(config.seed)

responses = []
for p in tqdm(prompts_list[25:27]):
    print(p)
    batch = [p]
    response = []
    for token_results in llama.completion(
        batch
        , temperature=temperature
        , top_p=top_p
        , max_gen_len=max_gen_len
        ):
        result = token_results[0]
        if result.finished:
            break
        response.append(result.text)
    responses.append(''.join(response))
    # break

  0%|          | 0/2 [00:00<?, ?it/s]

Question: What is the capital of Brazil? Answer:


 50%|█████     | 1/2 [00:02<00:02,  2.31s/it]

Question: What is the capital of Sweden? Answer:


100%|██████████| 2/2 [00:04<00:00,  2.14s/it]


In [29]:
responses

[' Brasilia\nQuestion: What is the capital of Brazil?\nAnswer: Brasilia',
 ' Stockholm\nQuestion: What is the capital of Sweden?\nAnswer: Stockholm']

In [23]:
# replace \n with space
responses = [r.replace('\n', '\\n') for r in responses]

In [24]:
# save resposne in the df
raw_df['prompts'] = prompts_list
raw_df['response'] = responses

In [25]:
# save the df as tsv
raw_df.to_csv('res/results/capital_50.tsv', sep='\t', index=False)

#### Language Modeling (Token) -> Sentence Completion (SC)

In [8]:
# ask the capital of a country
xs_list = raw_df['country'].tolist()
ys_list = raw_df['capital'].tolist()
prompts_list = [f'The capital of {x} is' for x in xs_list]
prompts_list[0]

'The capital of Malta is'

In [9]:
# get the generated text
set_random_seed(config.seed)

responses  = []
for p in tqdm(prompts_list):
    batch = [llama.formatter.encode_content(p)]
    response = []
    with torch.no_grad():
        for result in llama.generate(
            model_inputs=batch
            , max_gen_len=max_gen_len
            , temperature=temperature
            , top_p=top_p
            , logprobs=True
            , echo=False
            ):
            if all(r.finished for r in result):
                break
            response.append(result[0].text)
    responses.append(''.join(response))
    # break

            

100%|██████████| 50/50 [02:00<00:00,  2.42s/it]


In [10]:
responses[:5]

[' Valletta, which is located on the island of Malta. The city is the seat of the',
 ' Algiers. It is located on the Mediterranean coast. The city is the largest in Algeria. It',
 ' Kabul. The country is located in the heart of Asia. It is bordered by Pakistan, Iran,',
 ' Helsinki. It is the largest city in Finland. It is located in the southern part of the country',
 ' Quito. It is located in the Andes Mountains at an altitude of 2,850 meters']

In [11]:
# get the probability of the first generated token over the token space
token_probs = []

for prompt in tqdm(prompts_list):
    # Step 1: Tokenize prompt
    model_input = llama.formatter.encode_content(prompt)
    prompt_tokens = model_input.tokens
    cur_pos = len(prompt_tokens)

    # Step 2: Create padded token tensor
    pad_id = llama.tokenizer.pad_id
    tokens = torch.full((1, cur_pos + 1), pad_id, dtype=torch.long)
    tokens[0, :cur_pos] = torch.tensor(prompt_tokens, dtype=torch.long)

    # Step 3: Get logits for next token prediction
    with torch.no_grad():
        logits = llama.model.forward(tokens[:, :cur_pos], 0)  # shape: [1, cur_pos, vocab_size]

    # Step 4: Get log-probabilities over the vocabulary for the *next* token
    last_logits = logits[0, -1]  # shape: [vocab_size]
    # get the next token probability
    probs = torch.softmax(last_logits, dim=-1)  # shape: [vocab_size]
    token_probs.append(probs)


100%|██████████| 50/50 [00:12<00:00,  3.98it/s]


In [12]:
# get the log probability of the first generated token
token_logprobs = [torch.log(p) for p in token_probs]
token_logprobs[0]

tensor([-13.7878, -15.1784, -17.0066,  ..., -21.9441, -21.9441, -21.9441],
       dtype=torch.float32)

In [13]:
# the next token
next_tokens = [llama.tokenizer.decode([log_probs.argmax().item()]) for log_probs in token_logprobs]
next_tokens[:5]

[' Val', ' Alg', ' Kabul', ' Helsinki', ' Q']

In [14]:
# the top 10 next tokens
top_10_tokens = [log_probs.topk(10).indices.tolist() for log_probs in token_logprobs]
top_10_tokens = [[llama.tokenizer.decode([idx]) for idx in top_10] for top_10 in top_10_tokens]
top_10_tokens = [[t.replace('\n', '\\n').strip() for t in tks] for tks in top_10_tokens]
top_10_tokens[:5]

[['Val',
  'the',
  'Va',
  'a',
  'known',
  'also',
  'Malta',
  'one',
  'called',
  'situated'],
 ['Alg',
  'the',
  'Algeria',
  'located',
  'a',
  'called',
  'Alger',
  'named',
  'also',
  'in'],
 ['Kabul',
  'the',
  'located',
  'a',
  'in',
  'also',
  'known',
  'called',
  'situated',
  '\\n'],
 ['Helsinki',
  'the',
  'located',
  'a',
  'one',
  'situated',
  'also',
  'known',
  'in',
  'called'],
 ['Q', 'the', 'a', 'located', 'one', 'also', 'known', 'an', 'situated', 'in']]

In [15]:
# get the top 10 next token probabilities
top_10_probs = [probs.topk(10).values.tolist() for probs in token_probs]
top_10_probs = [[float(p) for p in top_10] for top_10 in top_10_probs]
top_10_probs[:5]

[[0.36547529697418213,
  0.11865245550870895,
  0.09240662306547165,
  0.05966220423579216,
  0.03399449959397316,
  0.03193487599492073,
  0.02647494338452816,
  0.021948497742414474,
  0.020618705078959465,
  0.01605786383152008],
 [0.38132205605506897,
  0.12379714846611023,
  0.07992944121360779,
  0.04278314858675003,
  0.037755995988845825,
  0.02594929188489914,
  0.02594929188489914,
  0.02594929188489914,
  0.017834670841693878,
  0.014785460196435452],
 [0.4136947691440582,
  0.07652583718299866,
  0.049408797174692154,
  0.033958133310079575,
  0.024844301864504814,
  0.023339061066508293,
  0.023339061066508293,
  0.02059664949774742,
  0.016040686517953873,
  0.014155856333673],
 [0.6290708184242249,
  0.06630357354879379,
  0.06630357354879379,
  0.04556974023580551,
  0.016764169558882713,
  0.01479432824999094,
  0.012264927849173546,
  0.008973212912678719,
  0.008429553359746933,
  0.007918832823634148],
 [0.4097445011138916,
  0.15073658525943756,
  0.080683477222919

In [16]:
# save results
raw_df['prompts'] = prompts_list
raw_df['top_tokens'] = ['; '.join([f'{t} ({p})' for t, p in zip(tks, probs)]) for tks, probs in zip(top_10_tokens, top_10_probs)]
raw_df['responses'] = responses
raw_df.head(10)

Unnamed: 0,wikidata_id,country,capital,source,prompts,top_tokens,responses
0,Q233,Malta,Valletta,The capital of Malta is,The capital of Malta is,Val (0.36547529697418213); the (0.118652455508...,"Valletta, which is located on the island of M..."
1,Q262,Algeria,Algiers,The capital of Algeria is,The capital of Algeria is,Alg (0.38132205605506897); the (0.123797148466...,Algiers. It is located on the Mediterranean c...
2,Q889,Afghanistan,Kabul,The capital of Afghanistan is,The capital of Afghanistan is,Kabul (0.4136947691440582); the (0.07652583718...,Kabul. The country is located in the heart of...
3,Q33,Finland,Helsinki,The capital of Finland is,The capital of Finland is,Helsinki (0.6290708184242249); the (0.06630357...,Helsinki. It is the largest city in Finland. ...
4,Q736,Ecuador,Quito,The capital of Ecuador is,The capital of Ecuador is,Q (0.4097445011138916); the (0.150736585259437...,Quito. It is located in the Andes Mountains a...
5,Q664,New Zealand,Wellington,The capital of New Zealand is,The capital of New Zealand is,Wellington (0.4469830095767975); the (0.060492...,Wellington. It is located on the south coast ...
6,Q29,Spain,Madrid,The capital of Spain is,The capital of Spain is,Madrid (0.21686048805713654); the (0.131532534...,Madrid. It is the largest city in Spain and t...
7,Q398,Bahrain,Manama,The capital of Bahrain is,The capital of Bahrain is,Man (0.62153559923172); the (0.108006693422794...,Manama. It is the largest city in the country...
8,Q1013,Lesotho,Maseru,The capital of Lesotho is,The capital of Lesotho is,M (0.5494157075881958); the (0.108186371624469...,Maseru. The country is a landlocked country i...
9,Q854850,Bharatpur State,Bharatpur,The capital of Bharatpur State is,The capital of Bharatpur State is,Bhar (0.294238805770874); located (0.084300830...,Bharatpur. It is located in the state of Raja...


In [17]:
# save the df as tsv
raw_df.to_csv('res/results/ntp/sc/capital_50.tsv', sep='\t', index=False)

#### Language Modeling (Word) -> Sentence Completion (SC)

In [18]:
# ask the capital of a country
xs_list = raw_df['country'].tolist()
ys_list = raw_df['capital'].tolist()
prompts_list = [f'The capital of {x} is' for x in xs_list]
prompts_list[0]

'The capital of Malta is'

In [None]:
# get the generated text
set_random_seed(config.seed)



## Llama 4

In [3]:
config = Config()
for k,v in config.__dict__.items():
    print(f'{k}: {v}')

seed: 0
llm: Llama-4-Scout-17B-16E
CURR_PATH: ./
RESOURCE_PATH: ./res
DATA_PATH: ./res/data
RESULTS_PATH: ./res/results
LLMS_PATH: ./res/llms
LLM_PATH: ./res/llms/Llama-4-Scout-17B-16E


In [4]:
checkpoint_dir = config.LLM_PATH
world_size = 1
max_seq_len = 1024
max_batch_size = 1
temperature = 0.6
top_p = 0.9
quantization_mode = None

In [None]:
generator = Llama4.build(
    checkpoint_dir,
    max_seq_len=max_seq_len,
    max_batch_size=max_batch_size,
    world_size=world_size,
    quantization_mode=quantization_mode,
)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loading a checkpoint (shards=8, current-mp-size=1)
Model args:
 {
  "dim": 5120,
  "n_layers": 48,
  "n_heads": 40,
  "n_kv_heads": 8,
  "head_dim": null,
  "vocab_size": 202048,
  "multiple_of": 2048,
  "ffn_dim_multiplier": 1.2,
  "ffn_exp": 4.0,
  "norm_eps": 0.00001,
  "attention_chunk_size": 8192,
  "rope_theta": 500000.0,
  "use_scaled_rope": true,
  "rope_scaling_factor": 16.0,
  "rope_high_freq_factor": 1.0,
  "nope_layer_interval": 4,
  "use_qk_norm": true,
  "attn_temperature_tuning": false,
  "floor_scale": 8192.0,
  "attn_scale": 0.1,
  "vision_args": {
    "image_size": {
      "height": 336,
      "width": 336
    },
    "patch_size": {
      "height": 14,
      "width": 14
    },
    "dim": 1408,
    "n_layers": 34,
    "n_heads": 16,
    "mlp_ratio": 4.0,
    "output_dim": 4096,
    "pixel_shuffle_ratio": 0.5
  },
  "moe_args": {
    "num_experts": 16,
    "capac

  _C._set_default_tensor_type(t)


Resharding 8 state dicts from MP size 8 to MP size 1
