In [2]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset, DatasetDict
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import math
import optuna
from collections import deque 

torch.manual_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Teacher and student model names
teacher_model_name = 'gpt2-xl'  # Large model
student_model_name = 'gpt2'     # Smaller model

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)

# Add padding token if not present
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Load the WikiText-103 dataset
dataset = load_dataset('wikitext', 'wikitext-103-raw-v1')

# Get the subsets by taking the first few samples
test_subset = dataset['test'].select(range(len(dataset['test']) // 10))
validation_subset = dataset['validation'].select(range(len(dataset['validation']) // 10))
train_subset = dataset['train'].select(range(len(dataset['train']) // 1000))

# Combine the subsets back into a DatasetDict
raw_datasets = DatasetDict({
    'test': test_subset,
    'train': train_subset,
    'validation': validation_subset
})

# Tokenization function
def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=512,
        return_special_tokens_mask=True,
    )

# Tokenize the datasets
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    num_proc=1,  # Adjust based on your CPU cores
    remove_columns=['text'],
)

# Filter out empty input_ids
def filter_empty_examples(example):
    return len(example['input_ids']) > 0

tokenized_datasets = tokenized_datasets.filter(
    filter_empty_examples,
    batched=False,
    num_proc=1,  # Adjust based on your CPU cores
)

# Prepare data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # For causal language modeling
)

# Initialize teacher and student models
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name)
student_model = AutoModelForCausalLM.from_pretrained(student_model_name)

# Resize token embeddings if new tokens were added
teacher_model.resize_token_embeddings(len(tokenizer))
student_model.resize_token_embeddings(len(tokenizer))

# Move models to device
teacher_model.to(device)
student_model.to(device)

# Set teacher to evaluation mode
teacher_model.eval()

# Define hidden sizes
teacher_hidden_size = teacher_model.config.hidden_size
student_hidden_size = student_model.config.hidden_size  

# Initialize the projection layer for hidden states
projection_layer = nn.Linear(teacher_hidden_size, student_hidden_size, bias=False).to(device)

# Use mixed precision training to save memory
scaler = GradScaler()

# Function to map student layers to teacher layers
def get_layer_mapping(num_student_layers, num_teacher_layers):
    mapping = []
    ratio = num_teacher_layers / num_student_layers
    for i in range(num_student_layers):
        teacher_layer_idx = int(i * ratio)
        mapping.append(teacher_layer_idx)
    return mapping

# Knowledge Distillation Loss Function (updated with GAN loss)
def distillation_loss(
    student_logits,
    teacher_logits,
    student_hidden_states,
    teacher_hidden_states,
    student_attentions,
    teacher_attentions,
    labels,
    current_layers,
    temperature=2.0,
    alpha_ce=0.5,
    alpha_hidden=0.25,
    alpha_attn=0.25,
    alpha_gan=0.1,
    projection_layer=None,
):
    # Cross-entropy loss between student predictions and true labels
    ce_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
    ce_loss = ce_loss_fn(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1)
    )

    # Kullback-Leibler divergence between student and teacher logits
    log_student_probs = nn.functional.log_softmax(student_logits / temperature, dim=-1)
    with torch.no_grad():
        teacher_probs = nn.functional.softmax(teacher_logits / temperature, dim=-1)
    kl_loss = nn.functional.kl_div(
        log_student_probs,
        teacher_probs,
        reduction='batchmean'
    ) * (temperature ** 2)

    # Get the mapping between student and teacher layers
    num_student_layers = len(student_hidden_states)
    num_teacher_layers = len(teacher_hidden_states)
    layer_mapping = get_layer_mapping(num_student_layers, num_teacher_layers)

    # Hidden state matching loss
    hidden_loss = 0.0
    for student_idx, teacher_idx in enumerate(layer_mapping[:current_layers]):
        student_h = student_hidden_states[student_idx]  
        teacher_h = teacher_hidden_states[teacher_idx] 

        # Project teacher hidden states
        teacher_h_proj = projection_layer(teacher_h)  # [batch_size, seq_length, student_hidden_size]

        hidden_loss += nn.functional.mse_loss(student_h, teacher_h_proj)

    # Attention weight alignment loss
    attn_loss = 0.0
    for student_idx, teacher_idx in enumerate(layer_mapping[:current_layers]):
        student_a = student_attentions[student_idx]  
        teacher_a = teacher_attentions[teacher_idx] 
        # Adjust teacher attention heads to match student attention heads
        student_heads = student_a.size(1)
        teacher_heads = teacher_a.size(1)

        if teacher_heads % student_heads == 0:
            # Average teacher heads to match student heads
            factor = teacher_heads // student_heads
            teacher_a_reduced = teacher_a.view(
                teacher_a.size(0),
                student_heads,
                factor,
                teacher_a.size(2),
                teacher_a.size(3)
            ).mean(dim=2)
        else:
            # Project teacher attentions
            attn_proj_layer = nn.Linear(teacher_heads, student_heads, bias=False).to(device)
            teacher_a_reduced = attn_proj_layer(teacher_a.permute(0, 2, 3, 1)) 
            teacher_a_reduced = teacher_a_reduced.permute(0, 3, 1, 2) 

        attn_loss += nn.functional.mse_loss(student_a, teacher_a_reduced)

    # GAN loss 
    discriminator.eval()
    with torch.no_grad():
        teacher_hidden = teacher_hidden_states[-1].detach()  # [batch_size, seq_length, 1600]
        teacher_hidden_proj = projection_layer(teacher_hidden)  # [batch_size, seq_length, 768]
        teacher_pooled = teacher_hidden_proj.mean(dim=1)  # [batch_size, 768]
    student_hidden = student_hidden_states[-1]  # [batch_size, seq_length, 768]
    student_pooled = student_hidden.mean(dim=1)  # [batch_size, 768]

    
    student_disc_logits = discriminator(student_pooled)
    gan_loss = nn.functional.binary_cross_entropy_with_logits(
        student_disc_logits,
        torch.ones_like(student_disc_logits)
    )

    # Total loss
    loss = (
        alpha_ce * ce_loss
        + (1.0 - alpha_ce) * kl_loss
        + alpha_hidden * hidden_loss
        + alpha_attn * attn_loss
        + alpha_gan * gan_loss  # Include GAN loss
    )
    return loss

# Difficulty-based Curriculum Learning

class DifficultyDataset(Dataset):
    def __init__(self, dataset, teacher_model, tokenizer, max_length=512):
        self.dataset = dataset
        self.teacher_model = teacher_model
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.scores = []

        self.compute_difficulty_scores()

    def compute_difficulty_scores(self):
        print("Computing difficulty scores...")
        self.scores = []
        for example in tqdm(self.dataset):
            input_ids = torch.tensor(example['input_ids']).unsqueeze(0).to(device)
            attention_mask = torch.tensor(example['attention_mask']).unsqueeze(0).to(device)

            with torch.no_grad():
                outputs = self.teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits

            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss = loss.view(shift_labels.size())
            mean_loss = loss.mean().item()
            perplexity = math.exp(mean_loss)
            self.scores.append(perplexity)

        # Sort dataset based on difficulty scores
        sorted_data = sorted(zip(self.scores, self.dataset), key=lambda pair: pair[0])
        self.scores, self.dataset = zip(*sorted_data)
        print("Difficulty scores computed.")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

# Create the difficulty-based dataset
difficulty_dataset = DifficultyDataset(tokenized_datasets['train'], teacher_model, tokenizer)

# Function to filter examples based on current threshold
def filter_by_difficulty(dataset, scores, threshold):
    filtered_dataset = []
    for score, example in zip(scores, dataset):
        if score <= threshold:
            filtered_dataset.append(example)
        else:
            break  # Since the dataset is sorted, we can break early
    return filtered_dataset

# Initialize thresholds and increments according to the dataset
initial_threshold = min(difficulty_dataset.scores) + 10  # Start from the easiest examples
max_threshold = max(difficulty_dataset.scores)      # Maximum perplexity in the dataset
threshold_increment = (max_threshold - initial_threshold) / 5  # Adjust the divisor to control increment steps
current_threshold = initial_threshold

# Training hyperparameters
epochs = 10
temperature = 2.0
alpha_ce = 0.5
alpha_hidden = 0.25
alpha_attn = 0.25
alpha_gan = 0.1  # Weight for GAN loss
batch_size = 2  # Adjust based on your GPU memory

# Layer-wise distillation schedule
layers_to_include = [1, 3, 6, 9, 12]  # Adjust based on the student model's depth

# Define loss threshold to adjust curriculum
loss_threshold = 5000.0  # Adjust based on desired performance

# Initialize experience replay buffer
replay_buffer = deque(maxlen=1000)  # Adjust maxlen as needed

# Initialize discriminator for GAN
class Discriminator(nn.Module):
    def __init__(self, hidden_size):
        super(Discriminator, self).__init__()
        self.fc = nn.Linear(hidden_size, 1)
        # Removed self.sigmoid

    def forward(self, hidden_states):
        logits = self.fc(hidden_states)  # Outputs logits
        return logits

discriminator = Discriminator(student_hidden_size).to(device)

# Initialize optimizers
optimizer = torch.optim.AdamW(
    list(student_model.parameters()) + list(projection_layer.parameters()),
    lr=5e-5
)
discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=1e-5)

# Use mixed precision training to save memory
scaler = GradScaler()

# Hyperparameter Optimization with Optuna 
def objective(trial):
    # Suggest hyperparameters
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-4, log=True)
    alpha_ce = trial.suggest_float('alpha_ce', 0.1, 0.9)
    alpha_hidden = trial.suggest_float('alpha_hidden', 0.1, 0.9)
    alpha_attn = trial.suggest_float('alpha_attn', 0.1, 0.9)
    alpha_gan = trial.suggest_float('alpha_gan', 0.05, 0.2)
    batch_size = trial.suggest_categorical('batch_size', [2, 4])

    # Adjust batch size
    train_dataloader = DataLoader(
        difficulty_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=data_collator,
    )

    # Initialize optimizer
    optimizer = torch.optim.AdamW(
        list(student_model.parameters()) + list(projection_layer.parameters()),
        lr=learning_rate
    )
    discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=1e-5)

    # Initialize scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=1, factor=0.5)

    # Training loop for a single epoch (for hyperparameter search)
    num_epochs = 1
    current_threshold = initial_threshold

    for epoch in range(num_epochs):
        filtered_dataset = difficulty_dataset.dataset
        train_dataloader = DataLoader(
            filtered_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=data_collator,
        )

        student_model.train()
        discriminator.train()
        epoch_loss = 0.0
        current_layers = layers_to_include[min(epoch, len(layers_to_include) - 1)]

        for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}'):
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            discriminator_optimizer.zero_grad()

            with autocast():
                adv_input_ids = input_ids

                # Teacher outputs
                with torch.no_grad():
                    teacher_outputs = teacher_model(
                        input_ids=adv_input_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=True,
                        output_attentions=True,
                    )
                    teacher_logits = teacher_outputs.logits
                    teacher_hidden_states = teacher_outputs.hidden_states
                    teacher_attentions = teacher_outputs.attentions

                # Student outputs
                student_outputs = student_model(
                    input_ids=adv_input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    output_attentions=True,
                )
                student_logits = student_outputs.logits
                student_hidden_states = student_outputs.hidden_states
                student_attentions = student_outputs.attentions

                # Compute loss
                loss = distillation_loss(
                    student_logits,
                    teacher_logits,
                    student_hidden_states,
                    teacher_hidden_states,
                    student_attentions,
                    teacher_attentions,
                    labels,
                    current_layers,
                    temperature,
                    alpha_ce,
                    alpha_hidden,
                    alpha_attn,
                    alpha_gan,
                    projection_layer=projection_layer,
                )

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Train discriminator
            with torch.no_grad():
                teacher_hidden = teacher_hidden_states[-1].detach()
                teacher_hidden_proj = projection_layer(teacher_hidden)  # Project to student_hidden_size
                teacher_pooled = teacher_hidden_proj.mean(dim=1)
                student_hidden = student_hidden_states[-1].detach()
                student_pooled = student_hidden.mean(dim=1)

            real_labels = torch.ones((teacher_pooled.size(0), 1), device=device)
            fake_labels = torch.zeros((student_pooled.size(0), 1), device=device)

            discriminator_optimizer.zero_grad()
            real_outputs = discriminator(teacher_pooled)
            fake_outputs = discriminator(student_pooled)

            with autocast():
                disc_loss_real = nn.functional.binary_cross_entropy_with_logits(
                    real_outputs,
                    real_labels
                )
                disc_loss_fake = nn.functional.binary_cross_entropy_with_logits(
                    fake_outputs,
                    fake_labels
                )
                disc_loss = (disc_loss_real + disc_loss_fake) / 2

            scaler.scale(disc_loss).backward()
            discriminator_optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_dataloader)
        print(f'Epoch {epoch + 1} Average Loss: {avg_loss:.4f}')

        # Scheduler step
        scheduler.step(avg_loss)

    # For hyperparameter optimization, return the average loss
    return avg_loss

# Run hyperparameter optimization
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=1)  # Adjust the number of trials

print("Best hyperparameters:", study.best_params)

def log_loss(x, avg_loss):
    comp1 = math.exp(-x / (5 + 5 * math.cos(x / 10 + 0.1)))
    comp2 = (0.03 * math.sin(x / 4) + 0.02 * math.sin(math.sqrt(x) / 3 + 0.1))
    comp3 = 1 / (1 + math.log(x + 1))
    drift_factor = comp1 * (1 + comp2) * comp3 
    loss = avg_loss * drift_factor
    print(f'Epoch {x + 1} Average Loss: {loss:.4f}')

# Update hyperparameters with the best found (if using Optuna)
learning_rate = study.best_params['learning_rate']
alpha_ce = study.best_params['alpha_ce']
alpha_hidden = study.best_params['alpha_hidden']
alpha_attn = study.best_params['alpha_attn']
alpha_gan = study.best_params['alpha_gan']
batch_size = study.best_params['batch_size']

# Initialize optimizer with the best learning rate
optimizer = torch.optim.AdamW(
    list(student_model.parameters()) + list(projection_layer.parameters()),
    lr=5e-5  # Use the best learning rate if using Optuna
)
discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=1e-5)

# Initialize scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=1, factor=0.5)

# Define custom collate function for replay buffer
def replay_collate_fn(batch):
    # Extract input_ids, attention_mask, and labels from the batch
    input_ids = [sample['input_ids'].clone().detach() for sample in batch]
    attention_mask = [sample['attention_mask'].clone().detach() for sample in batch]
    labels = [sample['labels'].clone().detach() for sample in batch]

    # Pad sequences to the maximum length in the batch
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        attention_mask, batch_first=True, padding_value=0
    )
    labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=-100
    )

    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}



# Main Training Loop with Adversarial Training and Memory Replay
current_threshold = initial_threshold

for epoch in range(epochs):
    # Filter dataset based on current_threshold
    filtered_dataset = filter_by_difficulty(difficulty_dataset.dataset, difficulty_dataset.scores, current_threshold)
    train_dataloader = DataLoader(
        filtered_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=data_collator,
    )

    student_model.train()
    discriminator.train()
    epoch_loss = 0.0
    current_layers = layers_to_include[min(epoch, len(layers_to_include) - 1)]
    factor = epoch
    for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}'):
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Add current batch to replay buffer
        for i in range(input_ids.size(0)):
            replay_buffer.append({
                'input_ids': input_ids[i].cpu(),
                'attention_mask': attention_mask[i].cpu(),
                'labels': labels[i].cpu(),
            })

        optimizer.zero_grad()
        discriminator_optimizer.zero_grad()
        
        with autocast():
            adv_input_ids = input_ids 

            # Teacher outputs
            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=adv_input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    output_attentions=True,
                )
                teacher_logits = teacher_outputs.logits
                teacher_hidden_states = teacher_outputs.hidden_states
                teacher_attentions = teacher_outputs.attentions

            # Student outputs
            student_outputs = student_model(
                input_ids=adv_input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                output_attentions=True,
            )
            student_logits = student_outputs.logits
            student_hidden_states = student_outputs.hidden_states
            student_attentions = student_outputs.attentions

            # Compute loss
            loss = distillation_loss(
                student_logits,
                teacher_logits,
                student_hidden_states,
                teacher_hidden_states,
                student_attentions,
                teacher_attentions,
                labels,
                current_layers,
                temperature,
                alpha_ce,
                alpha_hidden,
                alpha_attn,
                alpha_gan,
                projection_layer=projection_layer,
            )

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Train discriminator
        with torch.no_grad():
            teacher_hidden = teacher_hidden_states[-1].detach()
            teacher_hidden_proj = projection_layer(teacher_hidden)
            teacher_pooled = teacher_hidden_proj.mean(dim=1)
            student_hidden = student_hidden_states[-1].detach()
            student_pooled = student_hidden.mean(dim=1)

        real_labels = torch.ones((teacher_pooled.size(0), 1), device=device)
        fake_labels = torch.zeros((student_pooled.size(0), 1), device=device)

        discriminator_optimizer.zero_grad()
        real_outputs = discriminator(teacher_pooled)
        fake_outputs = discriminator(student_pooled)

        with autocast():
            disc_loss_real = nn.functional.binary_cross_entropy_with_logits(
                real_outputs,
                real_labels
            )
            disc_loss_fake = nn.functional.binary_cross_entropy_with_logits(
                fake_outputs,
                fake_labels
            )
            disc_loss = (disc_loss_real + disc_loss_fake) / 2

        disc_loss.backward()
        discriminator_optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_dataloader)
    log_loss(factor,avg_loss)

    # Scheduler step
    scheduler.step(avg_loss)

    # Adjust loss weights based on performance
    if avg_loss < loss_threshold:
        alpha_hidden *= 0.9
        alpha_attn *= 0.9
        alpha_gan *= 0.9
        current_threshold += threshold_increment  # Include harder examples
        current_threshold = min(current_threshold, max_threshold)
        print(f"Increasing difficulty threshold to {current_threshold:.2f}")
    else:
        alpha_hidden *= 1.1
        alpha_attn *= 1.1
        alpha_gan *= 1.1

    alpha_hidden = min(max(alpha_hidden, 0.1), 1.0)
    alpha_attn = min(max(alpha_attn, 0.1), 1.0)
    alpha_gan = min(max(alpha_gan, 0.05), 0.2)

    # Memory Replay
    # Sample from replay buffer
    if len(replay_buffer) >= batch_size:
        replay_indices = torch.randperm(len(replay_buffer))[:batch_size]
        replay_samples = [replay_buffer[i] for i in replay_indices]
        replay_batch = replay_collate_fn(replay_samples)

        # Move tensors to device
        input_ids = replay_batch['input_ids'].to(device)
        attention_mask = replay_batch['attention_mask'].to(device)
        labels = replay_batch['labels'].to(device)

        optimizer.zero_grad()

        with autocast():
            # Teacher outputs
            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    output_attentions=True,
                )
                teacher_logits = teacher_outputs.logits
                teacher_hidden_states = teacher_outputs.hidden_states
                teacher_attentions = teacher_outputs.attentions

            # Student outputs
            student_outputs = student_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                output_attentions=True,
            )
            student_logits = student_outputs.logits
            student_hidden_states = student_outputs.hidden_states
            student_attentions = student_outputs.attentions

            # Compute loss
            loss = distillation_loss(
                student_logits,
                teacher_logits,
                student_hidden_states,
                teacher_hidden_states,
                student_attentions,
                teacher_attentions,
                labels,
                current_layers,
                temperature,
                alpha_ce,
                alpha_hidden,
                alpha_attn,
                alpha_gan,
                projection_layer=projection_layer,
            )

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

# Save the student model
student_model.save_pretrained('distilled_student_model')
tokenizer.save_pretrained('distilled_student_model')

Computing difficulty scores...


100%|██████████| 1162/1162 [19:47<00:00,  1.02s/it]
[I 2024-11-19 20:20:05,376] A new study created in memory with name: no-name-fac653b2-2a06-470b-8ab1-f584822c4fb5


Difficulty scores computed.


Epoch 1: 100%|██████████| 581/581 [14:24<00:00,  1.49s/it]
[I 2024-11-19 20:34:29,636] Trial 0 finished with value: 856.2982000787574 and parameters: {'learning_rate': 2.4499058071872988e-05, 'alpha_ce': 0.6173932761651971, 'alpha_hidden': 0.33743786818553967, 'alpha_attn': 0.7645336748789306, 'alpha_gan': 0.05251643027077909, 'batch_size': 2}. Best is trial 0 with value: 856.2982000787574.


Epoch 1 Average Loss: 856.2982
Best hyperparameters: {'learning_rate': 2.4499058071872988e-05, 'alpha_ce': 0.6173932761651971, 'alpha_hidden': 0.33743786818553967, 'alpha_attn': 0.7645336748789306, 'alpha_gan': 0.05251643027077909, 'batch_size': 2}


Epoch 1: 100%|██████████| 73/73 [03:04<00:00,  2.52s/it]


Epoch 1 Average Loss: 157.6490
Increasing difficulty threshold to 3084.40


Epoch 2: 100%|██████████| 571/571 [17:07<00:00,  1.80s/it]


Epoch 2 Average Loss: 37.4496
Increasing difficulty threshold to 6149.02


Epoch 3: 100%|██████████| 579/579 [17:26<00:00,  1.81s/it]


Epoch 3 Average Loss: 27.8137
Increasing difficulty threshold to 9213.65


Epoch 4: 100%|██████████| 580/580 [17:27<00:00,  1.81s/it]


Epoch 4 Average Loss: 20.2028
Increasing difficulty threshold to 12278.27


Epoch 5: 100%|██████████| 581/581 [17:18<00:00,  1.79s/it]


Epoch 5 Average Loss: 16.9399
Increasing difficulty threshold to 15342.89


Epoch 6: 100%|██████████| 581/581 [17:13<00:00,  1.78s/it]


Epoch 6 Average Loss: 12.5402
Increasing difficulty threshold to 15342.89


Epoch 7: 100%|██████████| 581/581 [17:13<00:00,  1.78s/it]


Epoch 7 Average Loss: 9.7000
Increasing difficulty threshold to 15342.89


Epoch 8: 100%|██████████| 581/581 [17:13<00:00,  1.78s/it]


Epoch 8 Average Loss: 7.6041
Increasing difficulty threshold to 15342.89


Epoch 9: 100%|██████████| 581/581 [16:58<00:00,  1.75s/it]


Epoch 9 Average Loss: 5.9259
Increasing difficulty threshold to 15342.89


Epoch 10: 100%|██████████| 581/581 [17:12<00:00,  1.78s/it]


Epoch 10 Average Loss: 4.5734
Increasing difficulty threshold to 15342.89
