In [48]:
import warnings
import torch
import math
import time
import os
import matplotlib.pyplot as plt
from itertools import cycle
from datasets import Dataset
import datasets
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from torch.optim.lr_scheduler import _LRScheduler
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [49]:
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2") # Get the same tokenizer used for GPT-2


print("Vocabulary size:", tokenizer.n_vocab) # Vocabilary size is how many unique tokens the tokenizer can encode
print("End of text token:", tokenizer.eot_token) # End of text token is used to indicate the end of a text sequence
print("Example tokenization:", tokenizer.encode("Hello world!"))

# Convert entire dataset into a single string
# This dataset is small enough to fit into memory
# For larger datasets, you may need to use more 
# sophisticated methods to process the data.


Vocabulary size: 50257
End of text token: 50256
Example tokenization: [15496, 995, 0]


In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# A simple configuration container
class GPTConfig:
    def __init__(
        self, 
        vocab_size,  # size of the vocabulary, from tokenizer, for gpt2 tokenizer it is 50257
        n_layer,   # number of transformer blocks
        n_head,    # number of attention heads for each transformer block
        n_embd,  # embedding dimension for each token
        seq_len,  # sequence length for the model - e.g. the "context window" 
    
    ):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.seq_len = seq_len
     
test_config = GPTConfig(
    vocab_size=tokenizer.n_vocab,
    n_layer=2,  
    n_head=3,
    n_embd=6,
    seq_len=4,
)

In [51]:
def get_position_encoding(seq_len, d, n=10000):
    """
    Computes the positional encoding matrix of shape (seq_len, d).
    
    Args:
        seq_len (int): Length of the sequence.
        d (int): Dimension of the embedding.
        n (float): The base for the exponential term (default 10000 in many Transformer implementations).
    
    Returns:
        torch.Tensor: A tensor of shape (seq_len, d) containing the positional encodings.
    """
    
    P = torch.zeros(seq_len, d).to(device)
    for pos in range(seq_len):
        for i in range(0, d // 2):
            P[pos, 2 * i] = math.sin(pos / (n ** ((2 * i) / d)))
            if i + 1 < d:
                P[pos, 2* i + 1] = math.cos(pos / (n ** ((2 * i) / d)))

    return P.unsqueeze(0)


# Example usage:
position_encoding = get_position_encoding(seq_len=test_config.seq_len, d=test_config.n_embd)
print("Position encoding shape:", position_encoding.shape)

Position encoding shape: torch.Size([1, 4, 6])


In [52]:
class SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.Wq = nn.Parameter(torch.randn(config.n_embd, config.n_embd)).to(device) # Query weights - will transform input embeddings into queries
        self.Wk = nn.Parameter(torch.randn(config.n_embd, config.n_embd)).to(device) # Key weights - will transform input embeddings into keys
        self.Wv = nn.Parameter(torch.randn(config.n_embd, config.n_embd)).to(device) # Value weights - will transform input embeddings into values

    def forward(self, x):
        print("Attention input shape:", x.shape)
        print("")
        print("Query weights shape:", self.Wq.shape)
        print("Key weights shape:", self.Wk.shape)
        print("Value weights shape:", self.Wv.shape)
        queries = x @ self.Wq # Matrix multiplication to transform input embeddings into queries
        keys = x @ self.Wk # Matrix multiplication to transform input embeddings into keys
        values = x @ self.Wv # Matrix multiplication to transform input embeddings into values
        print("")
        print("Queries shape:", queries.shape)
        print("Keys shape:", keys.shape)
        print("Values shape:", values.shape)

        qkt = queries @ keys.transpose(-2, -1) # Calculate QK^T
        qkt_scaled = qkt / math.sqrt(queries.size(-1)) # Scale QK^T by the dimension of the keys
        qkt_softmax = F.softmax(qkt_scaled, dim=-1) # Apply softmax row-wise to get attention weights
        print("")
        print("QK^T shape:", qkt.shape)

        attn_output = qkt_softmax @ values # Multiply softmax(QK^T) by values
        print("")
        print("Attention output shape:", attn_output.shape)
        return attn_output 


In [53]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.Wq = nn.Parameter(torch.randn(config.n_embd, config.n_embd)).to(device) # Query weights - will transform input embeddings into queries
        self.Wk = nn.Parameter(torch.randn(config.n_embd, config.n_embd)).to(device) # Key weights - will transform input embeddings into keys
        self.Wv = nn.Parameter(torch.randn(config.n_embd, config.n_embd)).to(device) # Value weights - will transform input embeddings into values

    def forward(self, x):
        seq_len = x.shape[1] # Get sequence length (number of tokens / context window length)
        queries = x @ self.Wq # Matrix multiplication to transform input embeddings into queries
        keys = x @ self.Wk    # Matrix multiplication to transform input embeddings into keys
        values = x @ self.Wv  # Matrix multiplication to transform input embeddings into values
        qkt = queries @ keys.transpose(-2, -1)  # Calculate QK^T
        qkt_scaled = qkt / math.sqrt(queries.size(-1))  # Scale QK^T by the dimension of the keys

        # MASKING
        # THIS IS THE ONLY DIFFERENCE, USE -inf FOR UPPER TRIANGLE MASK SO THAT SOFTMAX WILL BE 0
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1)
        causal_mask = causal_mask.masked_fill(causal_mask == 1, float('-inf'))  # Upper triangle masked with -inf 
        qkt_scaled = qkt_scaled + causal_mask # Add the mask to the scaled QK^T
        # END MASKING

        qkt_softmax = F.softmax(qkt_scaled, dim=-1) # Apply softmax row-wise to get attention weights, the -inf values will become 0 here
        attn_output = qkt_softmax @ values # Multiply softmax(QK^T) by values
        return attn_output

In [54]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn_heads = nn.ModuleList([
            CausalSelfAttention(config) for _ in range(config.n_head)
        ])  # Create n_head attention heads
        self.projection = nn.Linear(config.n_embd * config.n_head, config.n_embd).to(device) # Linear layer to project multi-head attention outputs

    def forward(self, x):
        head_outputs = [head(x) for head in self.attn_heads] # Get the output of each attention head
        multihead_output = torch.cat(head_outputs, dim=-1) # Concatenate the outputs
        return self.projection(multihead_output) # Project the concatenated outputs

In [55]:
class GPTBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mha = MultiHeadAttention(config)
        self.ln1 = nn.LayerNorm(config.n_embd).to(device)
        self.ffn = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
        ).to(device)
        self.ln2 = nn.LayerNorm(config.n_embd).to(device)

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


In [71]:
class GPTModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd).to(device)
        self.position_encoding = get_position_encoding(config.seq_len, config.n_embd)
        self.blocks = nn.Sequential(*[GPTBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd).to(device)
        self.head = nn.Linear(config.n_embd, config.vocab_size).to(device)
    
    def forward(self, x,return_probs=False):
        x = self.token_embedding(x) + self.position_encoding
        x = self.blocks(x)
        x = self.ln_f(x)
        logits= self.head(x)
        if return_probs:
            return torch.softmax(logits, dim=-1)
        return logits
    

In [57]:
def ROC_curve_micro(pred_tensor, label_tensor):
    device=pred_tensor.device
    n_class=pred_tensor.size(1)
    one_hot_labels = F.one_hot(label_tensor, num_classes=n_class).to(device)
    is_positive = one_hot_labels
    is_negative =1-one_hot_labels
    fn_diff = -is_positive.flatten()
    fp_diff = is_negative.flatten()
    thresh_tensor = -pred_tensor.flatten()
    fn_denom = is_positive.sum()
    fp_denom = is_negative.sum()
    sorted_indices = torch.argsort(thresh_tensor)
    sorted_fp_cum = fp_diff[sorted_indices].cumsum(0) / fp_denom
    sorted_fn_cum = -fn_diff[sorted_indices].flip(0).cumsum(0).flip(0) / fn_denom

    sorted_thresh = thresh_tensor[sorted_indices]
    sorted_is_diff = sorted_thresh.diff() != 0
    sorted_fp_end = torch.cat([sorted_is_diff, torch.tensor([True],device=device)])
    sorted_fn_end = torch.cat([torch.tensor([True],device=device), sorted_is_diff])

    uniq_thresh = sorted_thresh[sorted_fp_end]
    uniq_fp_after = sorted_fp_cum[sorted_fp_end]
    uniq_fn_before = sorted_fn_cum[sorted_fn_end]

    FPR = torch.cat([torch.tensor([0.0],device=device), uniq_fp_after])
    FNR = torch.cat([uniq_fn_before, torch.tensor([0.0],device=device)])

    return {
        "FPR": FPR,
        "FNR": FNR,
        "TPR": 1 - FNR,
        "min(FPR,FNR)": torch.minimum(FPR, FNR),
        "min_constant": torch.cat([torch.tensor([-1],device=device), uniq_thresh]),
        "max_constant": torch.cat([uniq_thresh, torch.tensor([0],device=device)])
    }
def ROC_AUC_micro(pred_tensor, label_tensor):
    roc = ROC_curve_micro(pred_tensor, label_tensor)
    FPR_diff = roc["FPR"][1:]-roc["FPR"][:-1]   
    TPR_sum = roc["TPR"][1:]+roc["TPR"][:-1]
    return torch.sum(FPR_diff*TPR_sum/2.0)
#AUM 
def Proposed_AUM_micro(pred_tensor, label_tensor):

    roc = ROC_curve_micro(pred_tensor, label_tensor)
    min_FPR_FNR = roc["min(FPR,FNR)"][1:-1]
    constant_diff = roc["min_constant"][1:].diff()
    return torch.sum(min_FPR_FNR * constant_diff)

In [58]:
def ROC_curve_macro(pred_tensor, label_tensor):
    n_class=pred_tensor.size(1)
    one_hot_labels = F.one_hot(label_tensor, num_classes=n_class)
    is_positive = one_hot_labels
    is_negative =1-one_hot_labels
    fn_diff = -is_positive
    fp_diff = is_negative
    thresh_tensor = -pred_tensor
    fn_denom = is_positive.sum(dim=0).clamp(min=1)
    fp_denom = is_negative.sum(dim=0).clamp(min=1)
    sorted_indices = torch.argsort(thresh_tensor,dim=0)
    sorted_fp_cum = torch.div(torch.gather(fp_diff, dim=0, index=sorted_indices).cumsum(0), fp_denom)
    sorted_fn_cum = -torch.div(torch.gather(fn_diff, dim=0, index=sorted_indices).flip(0).cumsum(0).flip(0) , fn_denom)
    sorted_thresh = torch.gather(thresh_tensor, dim=0, index=sorted_indices)
    #Problem starts here 
    zeros_vec=torch.zeros(1,n_class,device=device)
    FPR = torch.cat([zeros_vec, sorted_fp_cum])
    FNR = torch.cat([sorted_fn_cum, zeros_vec])
    return {
        "FPR_all_classes": FPR,
        "FNR_all_classes": FNR,
        "TPR_all_classes": 1 - FNR,
        "min(FPR,FNR)": torch.minimum(FPR, FNR),
        "min_constant": torch.cat([-torch.ones(1,n_class,device=device), sorted_thresh]),
        "max_constant": torch.cat([sorted_thresh, zeros_vec])
    }

def ROC_AUC_macro(pred_tensor, label_tensor):
    roc = ROC_curve_macro(pred_tensor, label_tensor)
    FPR_diff = roc["FPR_all_classes"][1:,:]-roc["FPR_all_classes"][:-1,]
    TPR_sum = roc["TPR_all_classes"][1:,:]+roc["TPR_all_classes"][:-1,:]
    sum_FPR_TPR= torch.sum(FPR_diff*TPR_sum/2.0,dim=0)
    count_non_defined=(sum_FPR_TPR == 0).sum()
    if count_non_defined==pred_tensor.size(1):
        return 0
    return  sum_FPR_TPR.sum()/(pred_tensor.size(1)-count_non_defined)
def Proposed_AUM_macro(pred_tensor, label_tensor):

    roc = ROC_curve_macro(pred_tensor, label_tensor)
    min_FPR_FNR = roc["min(FPR,FNR)"][1:-1,:]
    constant_diff = roc["min_constant"][1:,:].diff(dim=0)
    sum_min= torch.sum(min_FPR_FNR * constant_diff,dim=0)
    count_non_defined=(sum_min== 0).sum()
    if count_non_defined==pred_tensor.size(1):
        return torch.tensor(0,device=pred_tensor.device)
    return  sum_min.sum()/(pred_tensor.size(1)-count_non_defined)

In [59]:
batch_size = 10
sequence_len = 256
num_steps = 300
config = GPTConfig(
    vocab_size=tokenizer.n_vocab,
    n_layer=4,   # fewer layers for a quick demo
    n_head=4,
    n_embd=128,
    seq_len=sequence_len,
)


In [60]:
# Example config:
batch_size = 10
sequence_len = 256
num_steps = 1000
config = GPTConfig(
    vocab_size=tokenizer.n_vocab,
    n_layer=4,   # fewer layers for a quick demo
    n_head=4,
    n_embd=128,
    seq_len=sequence_len,
)

loss_dict={
    "AUM_micro": Proposed_AUM_micro,
    "AUM_macro": Proposed_AUM_macro,
    "Cross-entropy": F.cross_entropy
}
  

In [77]:
def inference(prompt, max_new_tokens,model,return_probs):
    tokens = tokenizer.encode(prompt)
    for _ in range(max_new_tokens):
        num_tokens = len(tokens)
        tokens_padded = tokens + [tokenizer.eot_token] * (config.seq_len - num_tokens)
        tokens_padded = torch.tensor(tokens_padded).unsqueeze(0).to(device)
        logits = model(tokens_padded,return_probs=return_probs)
        predicted_token = torch.argmax(logits[0, num_tokens-1, :]).item()
        tokens.append(predicted_token)
    return tokenizer.decode(tokens)

In [78]:
model_dict={}
loss_dict={
    "AUM_macro": Proposed_AUM_macro,
    "Cross-entropy": F.cross_entropy,
    "AUM_micro": Proposed_AUM_micro
}

for name , _ in loss_dict.items():
    model=torch.load(f"{name}_model.pt")
    model_dict[name]=model
    print(f"Predicted for {name}:", inference("Lillia and the dog", max_new_tokens=20,model=model,return_probs=(name!="Cross-entropy")).replace("\n", "\\n"))

  model=torch.load(f"{name}_model.pt")


Predicted for AUM_macro: Lillia and the dog Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty Thirty
Predicted for Cross-entropy: Lillia and the dog. They were happy. They were happy. They were happy. They were happy. They were happy
Predicted for AUM_micro: Lillia and the dogselectionselectionselectionselectionselectionselectionselectionselectionselectionselectionselectionselectionselectionselectionselectionselectionselectionselectionselectionselection


In [64]:
from datasets import load_from_disk

dataset = load_from_disk("Tinystories_valid")
print(dataset)
print(dataset[7])

Dataset({
    features: ['text'],
    num_rows: 21990
})
{'text': 'Once upon a time, there was a little boy named Tom. He loved to play with his red ball. One sunny day, Tom went outside to play with his ball in the land near his home.\n\nTom kicked the ball high in the sky. The ball went far, far away. Tom was sad because he could not find his ball. He walked and walked, looking for it. The land was big and sometimes dangerous. Tom knew he had to be careful.\n\nAt last, Tom found his ball near a big tree. He was very happy. Tom knew he should not kick the ball too hard next time. He went back home, holding his ball tightly. Tom played safely in his yard, away from the dangerous land.'}


In [65]:
hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [66]:

def tokenize_and_chunk(dataset, tokenizer, chunk_size=256, train_rows=1000):
    """
    Tokenizes and chunks the dataset into fixed-length 512-token segments.
    The 'target' sequence is shifted left by 1 token.
    Stops after generating `train_rows + test_rows` tokenized chunks.
    """
    buffer = []  # Rolling buffer for tokens
    row_count = 0

    for example in dataset:
        tokens = tokenizer(example["text"], truncation=False, padding=False)['input_ids']
        buffer.extend(tokens)

        # Yield full chunks until we reach train_rows + test_rows
        while len(buffer) >= chunk_size + 1:  # +1 to ensure we can shift target
            if row_count >= train_rows:
                return  # Stop yielding once enough rows are reached

            # Create input-target pairs
            input_chunk = buffer[:chunk_size]         # First 512 tokens
            target_chunk = buffer[1:chunk_size + 1]  # Shifted by 1 token
        

            yield {
                "input": input_chunk, 
                "target": target_chunk
            }
            
            buffer = buffer[chunk_size:]  # Remove used tokens
            row_count += 1
tokenized_ds = datasets.Dataset.from_generator(lambda: tokenize_and_chunk(dataset, hf_tokenizer,train_rows=4))


In [67]:
torch.tensor(tokenized_ds['input']).flatten()
input_tensor=torch.tensor(tokenized_ds['input'],device=device)
label_tensor=torch.tensor(tokenized_ds['target'],device=device).flatten()

In [72]:
for name , loss_fn in loss_dict.items():
    model=torch.load(f"{name}_model.pt")
    pred_tensor=model(input_tensor,return_probs=(name!="Cross-entropy"))
    print(pred_tensor)
    print(f"CE for {name} model: {F.cross_entropy(pred_tensor.view(-1, pred_tensor.size(-1)),label_tensor.view(-1))}")
    print(f"AUC_macro for {name}:{ROC_AUC_macro(pred_tensor.view(-1, pred_tensor.size(-1)),label_tensor.view(-1))} ")
    print(f"AUC_micro for {name}: {ROC_AUC_micro(pred_tensor.view(-1, pred_tensor.size(-1)),label_tensor.view(-1))}")

  model=torch.load(f"{name}_model.pt")


tensor([[[4.5113e-14, 5.3758e-14, 3.9596e-13,  ..., 1.8519e-12,
          9.0396e-13, 1.4321e-12],
         [4.4684e-14, 5.4299e-14, 3.8703e-13,  ..., 1.8779e-12,
          9.0755e-13, 1.4124e-12],
         [4.5209e-14, 5.5126e-14, 3.9274e-13,  ..., 1.8802e-12,
          9.0492e-13, 1.4049e-12],
         ...,
         [4.1435e-14, 4.9484e-14, 3.7562e-13,  ..., 1.7084e-12,
          8.3826e-13, 1.2595e-12],
         [4.4949e-14, 5.3223e-14, 3.7127e-13,  ..., 1.8477e-12,
          8.8794e-13, 1.3630e-12],
         [4.3312e-14, 5.1261e-14, 3.8165e-13,  ..., 1.8176e-12,
          8.6498e-13, 1.3218e-12]],

        [[4.5041e-14, 5.4733e-14, 3.9337e-13,  ..., 1.9751e-12,
          9.1494e-13, 1.3679e-12],
         [4.4158e-14, 5.3686e-14, 3.8882e-13,  ..., 1.9250e-12,
          8.9737e-13, 1.3548e-12],
         [4.3749e-14, 5.3615e-14, 3.9046e-13,  ..., 1.8904e-12,
          8.8352e-13, 1.3093e-12],
         ...,
         [4.1964e-14, 4.9799e-14, 3.6248e-13,  ..., 1.7951e-12,
          8.501

In [74]:
from sklearn.metrics import roc_auc_score
for name , loss_fn in loss_dict.items():
    model=torch.load(f"{name}_model.pt")
    pred_tensor=model(input_tensor,return_probs=(name!="Cross-entropy"))
    print(f" Scikit AUC_macro for {name}:{roc_auc_score(label_tensor.view(-1).cpu().numpy(),pred_tensor.view(-1, pred_tensor.size(-1)).detach().cpu().numpy(),average='macro',multi_class='ovr')} ")

  model=torch.load(f"{name}_model.pt")


ValueError: Number of classes in y_true not equal to the number of columns in 'y_score'