In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR, StepLR







    

In [None]:
# Create a simple model
class SimpleTransformer(nn.Module):
    """Minimal stand-in model: one square linear projection followed by dropout.

    Despite the name there is no attention here; the module only maps the last
    dimension of its input through ``nn.Linear(hidden_size, hidden_size)`` and
    applies dropout to the result.

    Args:
        hidden_size: Size of the input and output feature dimension.
        dropout: Probability passed to ``nn.Dropout``.
    """

    def __init__(self, hidden_size=768, dropout=0.1):
        super().__init__()
        # Define as many layers as needed; a single projection is enough here.
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``dropout(linear(x))``; the last dim of ``x`` must be ``hidden_size``."""
        return self.dropout(self.linear(x))


In [None]:
# Seed PyTorch's random number generators (CPU and CUDA) before anything random
# happens, so weight initialisation and later random draws repeat after
# Restart Kernel -> Run All.
torch.manual_seed(42)

# Initialize model - remember to select a GPU runtime (e.g. T4 on Colab);
# without one the code falls back to the CPU.
model = SimpleTransformer()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# AdamW optimizer; the weight-decay term provides the regularization.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=5e-5,
    weight_decay=0.01,
)

# Learning-rate schedule: linear warm-up, followed by step decay
def warmup_schedule(epoch, warmup_epochs=5, base_lr=5e-6, target_lr=5e-5):
    """Return the learning rate for ``epoch`` under a linear warm-up.

    The rate climbs in a straight line from ``base_lr`` at epoch 0 towards
    ``target_lr``, and is exactly ``target_lr`` from ``warmup_epochs`` onwards.

    Note: the value returned is an absolute learning rate, not a multiplier.
    ``LambdaLR`` multiplies the optimizer's base LR by whatever its
    ``lr_lambda`` returns, so this value has to be converted to a ratio
    before being used there.
    """
    if not epoch < warmup_epochs:
        # Warm-up finished: hold the target rate.
        return target_lr
    fraction_done = epoch / warmup_epochs
    lr_span = target_lr - base_lr
    return base_lr + lr_span * fraction_done

# Define the warm-up scheduler.
# LambdaLR multiplies the optimizer's base LR by the value lr_lambda returns,
# so the absolute LR produced by warmup_schedule is divided by that base LR to
# get a factor (about 0.1 at epoch 0, 1.0 once warm-up is over). Passing the
# absolute value straight through would shrink the LR to roughly 1e-10.
peak_lr = optimizer.defaults["lr"]  # LR the optimizer was constructed with
warmup_scheduler = LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: warmup_schedule(epoch) / peak_lr,
)

# Step decay: every 10 calls to decay_scheduler.step() the LR currently set on
# the optimizer is halved.
decay_scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

In [None]:
# Training parameters
batch_size = 32
gradient_accumulation_steps = 4  # effective batch size: 32 * 4 = 128
steps_per_epoch = 120  # number of (fake) batches drawn per epoch
warmup_epochs = 5  # keep equal to the warmup_epochs default of warmup_schedule

# Dummy training loop
num_epochs = 30
model.train()  # make sure dropout is active while training
for epoch in range(num_epochs):
    optimizer.zero_grad()
    total_loss = 0.0
    # LR in effect during this epoch; read it before the scheduler moves on,
    # otherwise the log line would show the next epoch's rate.
    current_lr = optimizer.param_groups[0]['lr']

    for step in range(steps_per_epoch):
        # Fake data, created directly on the target device (no host-to-device copy).
        inputs = torch.randn(batch_size, 768, device=device)
        outputs = model(inputs)
        loss = torch.mean(outputs)  # dummy loss
        # Log the unscaled loss so the reported value does not shrink with
        # the accumulation factor.
        total_loss += loss.item()

        # Scale only what is back-propagated, so the accumulated gradient is
        # the average over the accumulation window.
        (loss / gradient_accumulation_steps).backward()

        window_full = (step + 1) % gradient_accumulation_steps == 0
        last_step = (step + 1) == steps_per_epoch
        # Also step on the final batch so a trailing partial window is applied
        # instead of being wiped by the next epoch's zero_grad().
        if window_full or last_step:
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

    # Advance warm-up during the first epochs, step decay afterwards
    if epoch < warmup_epochs:
        warmup_scheduler.step()
    else:
        decay_scheduler.step()

    # Print learning rate & loss
    print(f"Epoch {epoch+1}: LR = {current_lr:.6e}, Loss = {total_loss:.4f}")