In [1]:
import torch
from chronos import ChronosPipeline

# Initialize pipeline
pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-base",
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16,
)

# Create sample input
context = torch.randn(1, 100)  # Single time series of length 100
prediction_length = 12  # Number of steps to predict

# B: Get tokenized input
# Reference: MeanScaleUniformBins.context_input_transform()
token_ids, attention_mask, tokenizer_state = pipeline.tokenizer.context_input_transform(context)
print("B - Tokenized shape:", token_ids.shape)

# C: Get encoder embeddings
# Reference: ChronosModel.encode()
encoder_output = pipeline.model.encode(
    input_ids=token_ids.to(pipeline.model.device),
    attention_mask=attention_mask.to(pipeline.model.device)
)
print("C - Encoder output shape:", encoder_output.shape)

# D: Get decoder output with prediction length
# Create decoder input with start token
decoder_input_ids = torch.full(
    (token_ids.shape[0], prediction_length),  # [batch_size, prediction_length]
    fill_value=pipeline.model.model.config.decoder_start_token_id,
    device=encoder_output.device
)

decoder_output = pipeline.model.model.decoder(
    input_ids=decoder_input_ids,
    encoder_hidden_states=encoder_output,
    encoder_attention_mask=attention_mask.to(encoder_output.device),
    return_dict=True
)
print("D - Decoder output shape:", decoder_output.last_hidden_state.shape)

# E: Get output tokens using the language modeling head
lm_logits = pipeline.model.model.lm_head(decoder_output.last_hidden_state)
print("E - Output token logits shape:", lm_logits.shape)



  from .autonotebook import tqdm as notebook_tqdm


In [32]:
encoder_output

tensor([[[-0.0967,  0.0016,  0.0330,  ..., -0.0708, -0.0889, -0.0261],
         [-0.0752,  0.0096,  0.0237,  ...,  0.0342,  0.0017, -0.0723],
         [-0.1055,  0.0087, -0.0703,  ...,  0.0820,  0.0217,  0.0126],
         ...,
         [ 0.0708, -0.0554,  0.0576,  ..., -0.0376, -0.0272,  0.0284],
         [ 0.1167,  0.0269,  0.0835,  ...,  0.0002, -0.0222,  0.0732],
         [-0.0001, -0.0103,  0.0417,  ...,  0.1206, -0.0071, -0.0322]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<MulBackward0>)

In [31]:
pipeline.embed(context)[0]

tensor([[[-0.0967,  0.0016,  0.0330,  ..., -0.0708, -0.0889, -0.0261],
         [-0.0752,  0.0096,  0.0237,  ...,  0.0342,  0.0017, -0.0723],
         [-0.1055,  0.0087, -0.0703,  ...,  0.0820,  0.0217,  0.0126],
         ...,
         [ 0.0708, -0.0554,  0.0576,  ..., -0.0376, -0.0272,  0.0284],
         [ 0.1167,  0.0269,  0.0835,  ...,  0.0002, -0.0222,  0.0732],
         [-0.0001, -0.0103,  0.0417,  ...,  0.1206, -0.0071, -0.0322]]],
       dtype=torch.bfloat16)

In [30]:
torch.cuda.is_available()

True

In [18]:
lm_logits.shape

torch.Size([1, 12, 4096])

In [None]:
import torch
import torch.nn.functional as F
from chronos import ChronosPipeline

# Initialize teacher (larger) and student (smaller) models
teacher = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",  # Larger model as teacher
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16,
)

student = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-tiny",  # Smaller model as student
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16,
)

def get_model_outputs(pipeline, context, prediction_length):
    # Get all intermediate representations
    token_ids, attention_mask, tokenizer_state = pipeline.tokenizer.context_input_transform(context)
    
    encoder_output = pipeline.model.encode(
        input_ids=token_ids.to(pipeline.model.device),
        attention_mask=attention_mask.to(pipeline.model.device)
    )
    
    decoder_input_ids = torch.full(
        (token_ids.shape[0], prediction_length),
        fill_value=pipeline.model.model.config.decoder_start_token_id,
        device=encoder_output.device
    )
    
    decoder_output = pipeline.model.model.decoder(
        input_ids=decoder_input_ids,
        encoder_hidden_states=encoder_output,
        encoder_attention_mask=attention_mask.to(encoder_output.device),
        return_dict=True
    )
    
    lm_logits = pipeline.model.model.lm_head(decoder_output.last_hidden_state)
    output_tokens = torch.argmax(lm_logits, dim=-1)
    
    return {
        'token_ids': token_ids,
        'encoder_output': encoder_output,
        'decoder_output': decoder_output.last_hidden_state,
        'logits': lm_logits,
        'tokens': output_tokens
    }

def distillation_loss(teacher_outputs, student_outputs, temperature=2.0):
    losses = {}
    
    # 1. Soft targets loss on logits
    soft_targets = F.softmax(teacher_outputs['logits'] / temperature, dim=-1)
    student_logits = student_outputs['logits'] / temperature
    losses['soft_targets'] = F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        soft_targets,
        reduction='batchmean'
    ) * (temperature ** 2)
    
    # 2. Hidden state distillation (encoder)
    losses['encoder'] = F.mse_loss(
        student_outputs['encoder_output'],
        teacher_outputs['encoder_output']
    )
    
    # 3. Hidden state distillation (decoder)
    losses['decoder'] = F.mse_loss(
        student_outputs['decoder_output'],
        teacher_outputs['decoder_output']
    )
    
    # 4. Hard prediction loss
    losses['tokens'] = F.cross_entropy(
        student_outputs['logits'].view(-1, student_outputs['logits'].size(-1)),
        teacher_outputs['tokens'].view(-1)
    )
    
    return losses

# Training loop
def train_step(context, prediction_length):
    # Get teacher outputs (with no grad)
    with torch.no_grad():
        teacher_outputs = get_model_outputs(teacher, context, prediction_length)
    
    # Get student outputs
    student_outputs = get_model_outputs(student, context, prediction_length)
    
    # Calculate losses
    losses = distillation_loss(teacher_outputs, student_outputs)
    
    # Total loss (with weights for each component)
    total_loss = (
        0.5 * losses['soft_targets'] +
        0.1 * losses['encoder'] +
        0.1 * losses['decoder'] +
        0.3 * losses['tokens']
    )
    
    return total_loss, losses

# Example usage
context = torch.randn(1, 100)  # Sample input
prediction_length = 12

loss, component_losses = train_step(context, prediction_length)
print("Total Loss:", loss.item())
print("Component Losses:", {k: v.item() for k, v in component_losses.items()})