# Importing Libraries

In [1]:
import gc
import os
import random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import wandb
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline

# Configuration

In [2]:
@dataclass()
class Config:
    """Notebook-wide settings: which checkpoints to load, how to load them, and the RNG seed."""

    # --- checkpoints (passed to `from_pretrained`) ---
    model_id: str = "openai/gpt-oss-20b"
    encoder_id: str = "FacebookAI/roberta-base"

    # --- runtime / loading options ---
    # NOTE(review): `device` is not referenced by the cells shown; loading uses device_map="auto".
    device: str = "cuda"
    dtype: torch.dtype = torch.bfloat16
    attn_implementation: str = "flash_attention_2"

    # --- reproducibility ---
    seed: int = 42


config: Config = Config()

In [3]:
def set_seed(seed: int) -> None:
    """
    Seed every RNG this notebook touches so runs are repeatable.

    Seeds Python's `random`, NumPy's legacy global RNG and PyTorch (CPU plus
    every visible CUDA device), and switches cuDNN to deterministic kernels
    with autotuning disabled.

    Args:
        seed (int): The seed value to set.

    Note:
        Setting ``PYTHONHASHSEED`` at runtime only affects subprocesses started
        afterwards; the running interpreter's str/bytes hash randomization is
        fixed when the interpreter starts.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The model is loaded with device_map="auto" and may be sharded over
    # several GPUs, so seed all of them, not just the current device.
    # Silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Random seed set to {seed}")
set_seed(seed=config.seed)  # seed everything once, right after configuration

Random seed set to 42


# Model

In [4]:
tokenizer = AutoTokenizer.from_pretrained(config.model_id)
# Fall back to EOS as the padding token only when the checkpoint ships none.
# Unconditionally overwriting an existing pad token with EOS makes padding
# indistinguishable from genuine end-of-sequence tokens, so attention masks
# and label masking would start treating real EOS tokens as padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print(tokenizer)

PreTrainedTokenizerFast(name_or_path='openai/gpt-oss-20b', vocab_size=199998, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': '<|return|>', 'pad_token': '<|return|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	199998: AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	199999: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	200000: AddedToken("<|reserved_200000|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	200001: AddedToken("<|reserved_200001|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	200002: AddedToken("<|return|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	200003: AddedToken("<|constrain|>", rstrip=False, ls

In [None]:
# Load the main causal LM; device_map="auto" lets transformers decide where the
# weights are placed. The captured output below shows the MXFP4 checkpoint being
# dequantized to bf16 because triton/kernels were not available.
# NOTE(review): confirm the installed transformers build supports
# attn_implementation="flash_attention_2" for gpt-oss before relying on it.
load_kwargs = {
    "device_map": "auto",
    "dtype": config.dtype,
    "attn_implementation": config.attn_implementation,
}
model = AutoModelForCausalLM.from_pretrained(config.model_id, **load_kwargs)
print(model)

total_params = sum(param.numel() for param in model.parameters())
print(f"Number of parameters: {total_params / 1e9}b")

MXFP4 quantization requires triton >= 3.4.0 and kernels installed, we will default to dequantizing the model to bf16


In [None]:
# RoBERTa is a bidirectional encoder: load the bare model (RobertaModel) to get
# its hidden states. AutoModelForCausalLM would instead wrap it in a causal-LM
# head (RobertaForCausalLM), which expects `is_decoder=True` and is not an encoder.
encoder = AutoModel.from_pretrained(
    config.encoder_id,
    device_map="auto",
    dtype=config.dtype,
)
print(encoder)

In [None]:
# Render the chat template as plain text (no tokenization) to inspect the exact
# prompt string, including the generation prompt appended for the assistant turn.
sample_user = "Explain quantum mechanics clearly and concisely."
messages = [dict(role="user", content=sample_user)]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)

In [None]:
model.eval()
# Ask for a dict so we get the attention mask alongside input_ids. The pad token
# may equal EOS, in which case `generate` cannot infer the mask on its own.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)
input_ids = inputs["input_ids"]
output_ids = model.generate(
    **inputs,
    max_new_tokens=256,
)
# Decodes prompt + completion with special tokens kept, so the full chat
# formatting is visible.
response = tokenizer.decode(output_ids[0], skip_special_tokens=False)
print(response)

In [None]:
# Same request through the high-level text-generation pipeline. Reuse the model
# and tokenizer already in memory: passing the hub id here would load a second
# full copy of the 20B checkpoint next to `model`, which can exhaust GPU memory.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

messages = [
    {"role": "user", "content": sample_user},
]

outputs = pipe(
    messages,
    max_new_tokens=256,
)
# For chat input the pipeline returns the whole conversation; the last entry is
# the newly generated assistant message.
print(outputs[0]["generated_text"][-1])
