In [None]:
async def initialize_streaming_session(self, initial_prompt: str):
    """Initialize a streaming session by pre-filling the KV cache with a prompt.

    The cache is built with one forward pass over every prompt token except
    the last. Leaving the final token uncached matches the cache that
    transformers' ``generate`` typically returns, which covers every token of
    ``sequences`` except the last sampled one. It also guarantees that a later
    ``generate`` call on the full token sequence has at least one uncached
    token to process.

    Args:
        initial_prompt: Text used to seed the session context.

    Returns:
        A session dict with three keys:

        * ``past_key_values`` -- the pre-filled cache, or ``None`` when the
          prompt has fewer than two tokens.
        * ``input_ids`` -- the tokenized prompt, i.e. the full token history
          so far.
        * ``generated_tokens`` -- an empty list.
    """
    if not self.model or not self.tokenizer:
        await self.load_model()

    input_ids = self.tokenizer(initial_prompt, return_tensors="pt").input_ids.to(self.device)

    def prefill():
        # The original generate(max_new_tokens=0) is rejected by recent
        # transformers GenerationConfig validation and never prefilled anything,
        # so build the cache with an explicit forward pass instead.
        if input_ids.shape[-1] < 2:
            return None
        prefix = input_ids[:, :-1]
        with torch.no_grad():
            outputs = self.model(
                input_ids=prefix,
                attention_mask=torch.ones_like(prefix),
                use_cache=True,
            )
        return outputs.past_key_values

    # The forward pass is blocking; keep it off the event loop.
    loop = asyncio.get_running_loop()
    past_key_values = await loop.run_in_executor(None, prefill)

    return {
        "past_key_values": past_key_values,
        "input_ids": input_ids,
        "generated_tokens": [],
    }

async def continue_generation(self, session, new_data: str, max_tokens: int = 50):
    """Feed ``new_data`` into an existing session and sample a continuation.

    Recent transformers versions expect ``generate`` to receive the *whole*
    token sequence. When a cache is supplied, they skip the positions it
    already covers and run only the rest. The session therefore keeps its full
    token history in ``session["input_ids"]``. This function appends the new
    data to that history instead of passing the new tokens alone.
    NOTE(review): this relies on a transformers version whose ``generate``
    handles a multi-token uncached suffix -- confirm against the pinned version.

    Args:
        session: Session dict as returned by ``initialize_streaming_session``.
            Its ``input_ids``, ``past_key_values`` and ``generated_tokens``
            entries are updated in place.
        new_data: Text appended to the context before generating.
        max_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded text of the newly generated tokens only, with special
        tokens skipped.
    """
    if not self.model or not self.tokenizer:
        await self.load_model()

    # add_special_tokens=False: this text continues an existing context, so
    # tokenizer-added markers (e.g. BOS) must not be inserted mid-sequence.
    new_tokens = self.tokenizer(
        new_data, return_tensors="pt", add_special_tokens=False
    ).input_ids.to(self.device)
    full_ids = torch.cat([session["input_ids"], new_tokens], dim=-1)
    prompt_len = full_ids.shape[-1]
    past_key_values = session["past_key_values"]

    def generate_continuation():
        with torch.no_grad():
            outputs = self.model.generate(
                full_ids,
                attention_mask=torch.ones_like(full_ids),
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                return_dict_in_generate=True,
                past_key_values=past_key_values,  # cache for a prefix of full_ids
            )

        # sequences = prompt + generated, and generation can stop early (EOS).
        # Slicing off the prompt avoids returning prompt tokens; taking the
        # last max_tokens ids would include them.
        generated_ids = outputs.sequences[0, prompt_len:]

        session["input_ids"] = outputs.sequences
        session["past_key_values"] = outputs.past_key_values
        session["generated_tokens"].extend(generated_ids.tolist())

        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, generate_continuation)

In [None]:
async def streaming_session(self, initial_prompt: str):
    """Create a streaming session whose KV cache holds the whole prompt.

    One forward pass with ``use_cache=True`` runs over every prompt token. Its
    logits for the position after the prompt are kept in the session. That
    lets generation sample its first token directly, without re-running (and
    re-caching) the last prompt token.

    Args:
        initial_prompt: Text used to seed the session context.

    Returns:
        A session dict with these keys:

        * ``past_key_values`` -- the cache covering every prompt token.
        * ``context_tokens`` -- the prompt token ids.
        * ``attention_mask`` -- one entry per cached position.
        * ``next_token_logits`` -- logits for the next position.
        * ``logits_position`` -- the mask length those logits belong to.
    """
    if not self.model or not self.tokenizer:
        await self.load_model()

    input_ids = self.tokenizer(initial_prompt, return_tensors="pt").input_ids.to(self.device)
    attention_mask = torch.ones_like(input_ids)

    # Prefill: build the KV cache for the entire prompt in a single pass.
    with torch.no_grad():
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
        )

    return {
        "past_key_values": outputs.past_key_values,
        "context_tokens": input_ids,
        "attention_mask": attention_mask,
        # Logits for the token following the prompt. They are tagged with the
        # context length so consumers can detect when they have gone stale.
        "next_token_logits": outputs.logits[:, -1, :],
        "logits_position": attention_mask.shape[-1],
    }

async def add_data_tokens(self, session, new_data: str):
    """Append ``new_data`` to the session context by extending its KV cache.

    Only the new tokens are run through the model. The previous cache is
    reused, and the attention mask is extended so it covers cached plus new
    positions. The logits for the position after the new data are stored so
    generation can start from them.

    Args:
        session: Session dict as returned by ``streaming_session``. It is
            updated in place.
        new_data: Text to append. If it tokenizes to nothing, the session is
            returned unchanged.

    Returns:
        The same ``session`` object.
    """
    # add_special_tokens=False: this text continues an existing context, so
    # tokenizer-added markers (e.g. BOS) must not be inserted mid-sequence.
    new_tokens = self.tokenizer(
        new_data, return_tensors="pt", add_special_tokens=False
    ).input_ids.to(self.device)
    if new_tokens.shape[-1] == 0:
        # Nothing to add; a zero-length forward pass would be meaningless.
        return session

    extended_attention = torch.cat(
        [session["attention_mask"], torch.ones_like(new_tokens)], dim=-1
    )

    # Forward only the new tokens; the cache supplies the earlier context.
    with torch.no_grad():
        outputs = self.model(
            input_ids=new_tokens,
            attention_mask=extended_attention,
            past_key_values=session["past_key_values"],
            use_cache=True,
        )

    session["past_key_values"] = outputs.past_key_values
    session["context_tokens"] = torch.cat([session["context_tokens"], new_tokens], dim=-1)
    session["attention_mask"] = extended_attention
    # Logits for the token after the new data, tagged with the context length
    # they belong to.
    session["next_token_logits"] = outputs.logits[:, -1, :]
    session["logits_position"] = extended_attention.shape[-1]

    return session

async def generate_response(self, session, max_tokens: int = 50):
    """Sample up to ``max_tokens`` tokens, one cached forward pass per token.

    Session invariants used and maintained:

    * ``attention_mask`` has one entry per position held in
      ``past_key_values``. Each forward pass extends it by the number of new
      input tokens, so the mask always covers cached plus new positions.
    * ``next_token_logits`` (optional) are the logits for the position after
      the cache. They are used only when ``logits_position`` equals the
      current mask length, so logits left over from before the context
      changed are never reused. Otherwise the last context token is re-run
      to obtain logits. That token is then cached twice, so this path is only
      a fallback.

    Every sampled token except a terminating EOS is fed through the model
    immediately. On return, the cache, ``context_tokens`` and
    ``attention_mask`` all include the generated tokens, and
    ``next_token_logits`` is ready for the next call.

    Args:
        session: Session dict as returned by ``streaming_session`` or
            ``add_data_tokens``. It is updated in place.
        max_tokens: Maximum number of tokens to sample.

    Returns:
        The decoded text of the sampled tokens, with special tokens skipped.
    """
    if max_tokens <= 0:
        return ""

    def step(token_ids):
        # Extend the mask for the new positions, then run one cached forward pass.
        mask = torch.cat([session["attention_mask"], torch.ones_like(token_ids)], dim=-1)
        with torch.no_grad():
            outputs = self.model(
                input_ids=token_ids,
                attention_mask=mask,
                past_key_values=session["past_key_values"],
                use_cache=True,
            )
        session["past_key_values"] = outputs.past_key_values
        session["attention_mask"] = mask
        session["next_token_logits"] = outputs.logits[:, -1, :]
        session["logits_position"] = mask.shape[-1]
        return session["next_token_logits"]

    has_fresh_logits = (
        session.get("next_token_logits") is not None
        and session.get("logits_position") == session["attention_mask"].shape[-1]
    )
    if has_fresh_logits:
        logits = session["next_token_logits"]
    else:
        logits = step(session["context_tokens"][:, -1:])

    generated_ids = []
    for _ in range(max_tokens):
        # reshape(1, 1): the sessions are single-sequence, and this normalizes
        # both (1,) and (1, 1) sampler outputs to the (batch, seq) layout.
        next_token = self._sample_token(logits, temperature=0.7).reshape(1, 1)
        token_id = next_token.item()
        generated_ids.append(token_id)

        if token_id == self.tokenizer.eos_token_id:
            break

        session["context_tokens"] = torch.cat([session["context_tokens"], next_token], dim=-1)
        logits = step(next_token)

    return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

def _sample_token(self, logits, temperature=0.7):
    """Sample a token from logits with temperature"""
    if temperature == 0:
        return torch.argmax(logits, dim=-1)
    
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token