In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import GPT2LMHeadModel, GPT2Tokenizer

class GPT2FineTuner(nn.Module):
    def __init__(self, model_name='gpt2'):
        super().__init__()
        # in future, replace with own pretrained gpt2 model
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def forward(self, input_ids, attention_mask, labels):
        outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        return outputs.loss, outputs.logits

def prepare_data(instruction, response, tokenizer, max_length=512):
    # Combine instruction and response
    full_text = f"Instruction: {instruction}\nResponse: {response}"
    
    # Tokenize and prepare input
    encoded = tokenizer.encode_plus(
        full_text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    
    # Prepare labels (shift input_ids right by one)
    labels = input_ids.clone()
    labels[:, :-1] = input_ids[:, 1:]
    labels[:, -1] = -100  # Ignore loss for last token prediction
    
    return input_ids, attention_mask, labels

def train_step(model, optimizer, input_ids, attention_mask, labels):
    optimizer.zero_grad()
    loss, logits = model(input_ids, attention_mask, labels)
    loss.backward()
    optimizer.step()
    return loss.item(), logits

def custom_loss_vectorized(logits, labels):
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), 
                           shift_labels.view(-1), 
                           ignore_index=-100,
                           reduction='sum')
    num_valid_tokens = (shift_labels != -100).sum().item()
    return loss / num_valid_tokens if num_valid_tokens > 0 else 0

def fine_tune(model, train_data, num_epochs=3, lr=5e-5):
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    
    for epoch in range(num_epochs):
        total_loss = 0
        for instruction, response in train_data:
            input_ids, attention_mask, labels = prepare_data(instruction, response, model.tokenizer)
            loss, _ = train_step(model, optimizer, input_ids, attention_mask, labels)
            total_loss += loss
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_data)}")
    
def fine_tune_custom_loss(model, train_data, num_epochs=3, lr=5e-5):
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    
    for epoch in range(num_epochs):
        total_loss = 0
        for instruction, response in train_data:
            input_ids, attention_mask, labels = prepare_data(instruction, response, model.tokenizer)
            _, logits = train_step(model, optimizer, input_ids, attention_mask, labels)
            loss = custom_loss_vectorized(logits, labels)
            total_loss += loss
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_data)}")

# Example usage
model = GPT2FineTuner()
train_data = [
    ("Explain the concept of machine learning.", "Machine learning is a subset of artificial intelligence..."),
    ("What is a neural network?", "A neural network is a computational model inspired by the human brain...")
]
fine_tune(model, train_data)
model = GPT2FineTuner()
print("Fine-tuning with custom loss function:")
fine_tune_custom_loss(model, train_data)

Epoch 1/3, Loss: 8.47936201095581
Epoch 2/3, Loss: 1.5545154213905334
Epoch 3/3, Loss: 0.39046511054039
Fine-tuning with custom loss function:
Epoch 1/3, Loss: 8.479362487792969
Epoch 2/3, Loss: 1.5545153617858887
Epoch 3/3, Loss: 0.39046511054039


In [34]:
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

In [13]:
instruction = "What is a neural network?"
response = "A neural network is a computational model inspired by the human brain..."
full_text = f"Instruction: {instruction}\nResponse: {response}"

encoded = tokenizer.encode_plus(
        full_text,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )


input_ids = encoded['input_ids']
attention_mask = encoded['attention_mask']

# Prepare labels (shift input_ids right by one)
labels = input_ids.clone()
labels[:, :-1] = input_ids[:, 1:]
labels[:, -1] = -100  # Ignore loss for last token prediction
print(input_ids.view(-1).tolist()[:20])  
print(labels.view(-1).tolist()[:20])
print(labels[:, -1])

[6310, 2762, 25, 1867, 318, 257, 17019, 3127, 30, 198, 31077, 25, 317, 17019, 3127, 318, 257, 31350, 2746, 7867]
[2762, 25, 1867, 318, 257, 17019, 3127, 30, 198, 31077, 25, 317, 17019, 3127, 318, 257, 31350, 2746, 7867, 416]
tensor([-100])


In [35]:
import torch
import torch.nn.functional as F
import time

# loss function without using huggingface library to calculate loss

def custom_loss_loop(logits, labels):
    loss = 0
    batch_size, seq_len, vocab_size = logits.shape
    num_valid_tokens = 0
    for i in range(batch_size):
        for j in range(seq_len - 1):  # -1 because we're predicting the next token
            if labels[i, j + 1] != -100:
                # predicted (number of tokens,)
                predicted = logits[i, j, :]
                # predicted (1,)
                target = labels[i, j + 1]
                # cross entropy expects 2D inputs
                loss += F.cross_entropy(predicted.unsqueeze(0), target.unsqueeze(0), reduction='sum')
                num_valid_tokens += 1
    return loss / num_valid_tokens if num_valid_tokens > 0 else 0

def custom_loss_vectorized(logits, labels):
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), 
                           shift_labels.view(-1), 
                           ignore_index=-100,
                           reduction='sum')
    num_valid_tokens = (shift_labels != -100).sum().item()
    return loss / num_valid_tokens if num_valid_tokens > 0 else 0


# Example usage
torch.manual_seed(42)  # For reproducibility
batch_size, seq_len, vocab_size = 2, 10, 50000
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, -1] = -100  # Set last token to -100

loop_loss = custom_loss_loop(logits, labels)
vectorized_loss = custom_loss_vectorized(logits, labels)

print(f"Loop Loss: {loop_loss.item():.6f}")
print(f"Vectorized Loss: {vectorized_loss.item():.6f}")

# Verify that logits and labels are identical for all functions
print(f"\nVerification:")
print(f"Loop Loss == Vectorized Loss: {torch.isclose(loop_loss, vectorized_loss)}")


Loop Loss: 11.350384
Vectorized Loss: 11.350384

Verification:
Loop Loss == Vectorized Loss: True


In [29]:
torch.manual_seed(42)  # For reproducibility
batch_size, seq_len, vocab_size = 2, 10, 50000
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, -1] = -100  # Set last token to -100

loss = 0
batch_size, seq_len, vocab_size = logits.shape
num_valid_tokens = 0
for i in range(batch_size):
    for j in range(seq_len - 1):  # -1 because we're predicting the next token
        if labels[i, j + 1] != -100:
            predicted = logits[i, j, :]
            target = labels[i, j + 1]
            #print(torch.topk(predicted, 5).indices, target)
            print(predicted.shape)
            print(predicted.unsqueeze(0).shape, target.unsqueeze(0).shape)
            import sys; sys.exit()
 
            loss += F.cross_entropy(predicted.unsqueeze(0), target.unsqueeze(0), reduction='sum')
            num_valid_tokens += 1
print(loss / num_valid_tokens if num_valid_tokens > 0 else 0)

torch.Size([50000])
torch.Size([1, 50000]) torch.Size([1])


SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [32]:
# Reduction: sum or mean does not matter in this case
torch.manual_seed(42)  # For reproducibility
batch_size, seq_len, vocab_size = 2, 10, 50000
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, -1] = -100  # Set last token to -100

loss = 0
batch_size, seq_len, vocab_size = logits.shape
num_valid_tokens = 0
for i in range(batch_size):
    for j in range(seq_len - 1):  # -1 because we're predicting the next token
        if labels[i, j + 1] != -100:
            predicted = logits[i, j, :]
            target = labels[i, j + 1]
            loss += F.cross_entropy(predicted.unsqueeze(0), target.unsqueeze(0), reduction='sum')
            num_valid_tokens += 1
print(loss / num_valid_tokens if num_valid_tokens > 0 else 0)

torch.manual_seed(42)  # For reproducibility
batch_size, seq_len, vocab_size = 2, 10, 50000
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, -1] = -100  # Set last token to -100

loss = 0
batch_size, seq_len, vocab_size = logits.shape
num_valid_tokens = 0
for i in range(batch_size):
    for j in range(seq_len - 1):  # -1 because we're predicting the next token
        if labels[i, j + 1] != -100:
            predicted = logits[i, j, :]
            target = labels[i, j + 1]
            loss += F.cross_entropy(predicted.unsqueeze(0), target.unsqueeze(0), reduction='mean')
            num_valid_tokens += 1
print(loss / num_valid_tokens if num_valid_tokens > 0 else 0)

tensor(11.3504)
tensor(11.3504)
