In [None]:
# Training function for WaveNet
def train_wavenet(model, train_loader, val_loader, criterion, optimizer, device, epochs=10):
    """
    Train a WaveNet model and evaluate it on a validation set after every epoch.

    Args:
        model: Model to train; it is moved to ``device`` before training starts.
        train_loader: Iterable of training batches. Each batch is indexed with
            the keys ``'features'`` and ``'targets'``.
        val_loader: Iterable of validation batches with the same two keys.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar
            loss tensor (``.backward()`` and ``.item()`` are called on it).
        optimizer: Optimizer that updates the model parameters.
        device: Device the model and every batch are moved to.
        epochs: Number of passes over ``train_loader``.

    Returns:
        train_losses: One entry per epoch, the unweighted mean of the
            per-batch training losses.
        val_losses: One entry per epoch, the unweighted mean of the
            per-batch validation losses.

        If a loader yields no batches in an epoch, that epoch's entry is
        ``float('nan')`` instead of raising ``ZeroDivisionError``.
    """

    def _mean_loss(loss_sum, batch_count):
        # An empty loader has no meaningful average; report NaN rather than
        # crashing the whole run on a division by zero.
        return loss_sum / batch_count if batch_count else float('nan')

    model = model.to(device)
    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        num_train_batches = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"):
            # Shapes are presumably [batch, 2, seq_len] for both tensors
            # (taken from the original annotations) -- TODO confirm.
            features = batch['features'].to(device)
            targets = batch['targets'].to(device)

            # Clear gradients left over from the previous step, then forward.
            optimizer.zero_grad()
            outputs = model(features)

            # The loss definition is entirely up to the supplied criterion.
            loss = criterion(outputs, targets)

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            num_train_batches += 1

        avg_train_loss = _mean_loss(total_train_loss, num_train_batches)
        train_losses.append(avg_train_loss)

        # Validation phase: eval mode and no gradient tracking.
        model.eval()
        total_val_loss = 0
        num_val_batches = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} - Validation"):
                features = batch['features'].to(device)
                targets = batch['targets'].to(device)

                outputs = model(features)
                loss = criterion(outputs, targets)

                total_val_loss += loss.item()
                num_val_batches += 1

        avg_val_loss = _mean_loss(total_val_loss, num_val_batches)
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}/{epochs}: Train Loss = {avg_train_loss:.6f}, Val Loss = {avg_val_loss:.6f}")

    return train_losses, val_losses

# Custom loss function for character-level predictions
class CharacterLevelLoss(nn.Module):
    """Binary cross-entropy loss over character-level start/end predictions."""

    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCELoss()

    def forward(self, predictions, targets):
        """
        Calculate loss for character-level predictions.

        Args:
            predictions: Model predictions, expected as probabilities in
                [0, 1] (a requirement of ``nn.BCELoss``). The original notes
                give the shape as [batch, 2, seq_len] with dim 1 holding
                [start_prob, end_prob] and state that the model already
                applies softmax -- TODO confirm against the model.
            targets: Ground truth with the same shape as ``predictions``.
                May be float, integer or boolean 0/1 labels.

        Returns:
            loss: Scalar BCE loss over all start and end positions, using
                the default reduction of ``nn.BCELoss``.
        """
        # BCELoss rejects non-float targets, so align the target dtype with
        # the predictions. For targets that already match this is a no-op.
        targets = targets.to(dtype=predictions.dtype)
        return self.bce_loss(predictions, targets)

# Test training on a small subset
# Smoke test: builds tiny train/val datasets from train_df, trains a reduced
# CharacterWaveNet for two epochs and prints the final losses.
print("Testing WaveNet training...")

# Select the device once so the datasets and the training loop always agree
# (previously the same expression was repeated in three places).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Create small datasets for testing (fixed seed keeps the sample reproducible)
train_subset = train_df.sample(n=50, random_state=42)
train_texts = train_subset['text'].tolist()
train_sentiments = train_subset['sentiment'].tolist()
train_selected_texts = train_subset['selected_text'].tolist()

# Split into train/val: the first 80% of the sampled rows train, the rest validate
train_size = int(0.8 * len(train_texts))
train_texts_split = train_texts[:train_size]
train_sentiments_split = train_sentiments[:train_size]
train_selected_texts_split = train_selected_texts[:train_size]

val_texts_split = train_texts[train_size:]
val_sentiments_split = train_sentiments[train_size:]
val_selected_texts_split = train_selected_texts[train_size:]

print(f"Train samples: {len(train_texts_split)}")
print(f"Val samples: {len(val_texts_split)}")

# Create datasets
# NOTE(review): `model` and `tokenizer` come from earlier in the file;
# presumably the RoBERTa model and its tokenizer -- confirm.
train_dataset = CharacterLevelDataset(
    texts=train_texts_split,
    sentiments=train_sentiments_split,
    selected_texts=train_selected_texts_split,
    roberta_model=model,
    tokenizer=tokenizer,
    device=device
)

val_dataset = CharacterLevelDataset(
    texts=val_texts_split,
    sentiments=val_sentiments_split,
    selected_texts=val_selected_texts_split,
    roberta_model=model,
    tokenizer=tokenizer,
    device=device
)

# Create data loaders (only the training loader is shuffled)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

# Initialize model, loss, optimizer
wavenet_model = CharacterWaveNet(
    input_channels=2,
    num_classes=2,
    num_blocks=2,  # Reduced for testing
    num_layers=4,   # Reduced for testing
    residual_channels=16,
    gate_channels=16,
    skip_channels=16
)

criterion = CharacterLevelLoss()
optimizer = torch.optim.Adam(wavenet_model.parameters(), lr=0.001)

print(f"Training on device: {device}")

# Train for a few epochs (reduced for testing)
train_losses, val_losses = train_wavenet(
    wavenet_model, train_loader, val_loader, criterion, optimizer, device, epochs=2
)

print(f"\nFinal train loss: {train_losses[-1]:.6f}")
print(f"Final val loss: {val_losses[-1]:.6f}")