In [1]:
!pip install torch

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
import torch
# Report whether PyTorch can see a CUDA device; DEVICE falls back to CPU otherwise.
cuda_visible = torch.cuda.is_available()
print("GPU Available:", cuda_visible)

GPU Available: True


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from transformer_decoder_model import TransformerDecoderModel, Config  # Import your model here

Model initialized with 29.29M parameters


In [5]:
# Custom Dataset
import torch
from torch.utils.data import Dataset

class ChatDataset(Dataset):
    """Map-style dataset of (input_ids, target_ids) pairs loaded from an ``.npz`` archive.

    The archive must contain two arrays, ``input_ids`` and ``target_ids``, with the same
    number of rows; row ``i`` of each array forms one training example.
    """

    def __init__(self, path):
        """Load both arrays from ``path`` into ``torch.long`` tensors.

        Args:
            path: Path to an ``.npz`` file containing ``input_ids`` and ``target_ids``.

        Raises:
            KeyError: If either array is missing from the archive.
            ValueError: If the two arrays have different numbers of rows.
        """
        # np.load on an .npz returns a lazily-reading NpzFile that holds the file
        # handle open; the context manager closes it once both arrays are copied out.
        with np.load(path) as data:
            self.inputs = torch.tensor(data['input_ids'], dtype=torch.long)
            self.targets = torch.tensor(data['target_ids'], dtype=torch.long)
        # Fail fast here instead of with an IndexError deep inside a DataLoader worker.
        if len(self.inputs) != len(self.targets):
            raise ValueError(
                f"input_ids has {len(self.inputs)} rows but target_ids has "
                f"{len(self.targets)} rows in {path!r}"
            )

    def __len__(self):
        """Number of examples (rows of ``input_ids``)."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` tensor pair for row ``idx``."""
        return self.inputs[idx], self.targets[idx]

In [6]:
# Hyperparameters
BATCH_SIZE = 64
EPOCHS = 3
LEARNING_RATE = 3e-4
SEED = 42
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNGs (all devices) before the model is built and the DataLoader
# shuffles, so weight init and batch order repeat across runs. Bitwise-identical
# GPU results would additionally require deterministic CUDA kernels.
torch.manual_seed(SEED)

In [7]:
# Load Data
dataset = ChatDataset('chat_data_sequences.npz')
# Use the BATCH_SIZE hyperparameter instead of a duplicated literal so changing it in
# the config cell actually takes effect; shuffle=True reshuffles every epoch.
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [8]:
# Initialize Model
# Build the decoder from the project's default Config and move it to DEVICE before
# creating the optimizer, so AdamW is handed the on-device parameters.
config = Config()
model = TransformerDecoderModel(config)
model = model.to(DEVICE)
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [9]:
# Training Loop
# One AdamW update per mini-batch; the model's forward pass returns (logits, loss)
# when targets are passed in.
model.train()
num_batches = len(train_loader)
for epoch in range(EPOCHS):
    epoch_loss_sum = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        _, loss = model(inputs, targets)

        # Clear stale gradients, backpropagate, then apply the parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both the running sum and the log line.
        batch_loss = loss.item()
        epoch_loss_sum += batch_loss

        # Progress log every 100 batches, including the very first one.
        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}], Step [{batch_idx}/{num_batches}], Loss: {batch_loss:.4f}")

    avg_loss = epoch_loss_sum / num_batches
    print(f"🎯 Epoch [{epoch+1}/{EPOCHS}] complete — Average Loss: {avg_loss:.4f}")

Epoch [1/3], Step [0/10575], Loss: 9.3553
Epoch [1/3], Step [100/10575], Loss: 0.5418
Epoch [1/3], Step [200/10575], Loss: 0.1958
Epoch [1/3], Step [300/10575], Loss: 0.1771
Epoch [1/3], Step [400/10575], Loss: 0.1687
Epoch [1/3], Step [500/10575], Loss: 0.1581
Epoch [1/3], Step [600/10575], Loss: 0.1674
Epoch [1/3], Step [700/10575], Loss: 0.1573
Epoch [1/3], Step [800/10575], Loss: 0.1714
Epoch [1/3], Step [900/10575], Loss: 0.1632
Epoch [1/3], Step [1000/10575], Loss: 0.1649
Epoch [1/3], Step [1100/10575], Loss: 0.1644
Epoch [1/3], Step [1200/10575], Loss: 0.1601
Epoch [1/3], Step [1300/10575], Loss: 0.1608
Epoch [1/3], Step [1400/10575], Loss: 0.1640
Epoch [1/3], Step [1500/10575], Loss: 0.1655
Epoch [1/3], Step [1600/10575], Loss: 0.1611
Epoch [1/3], Step [1700/10575], Loss: 0.1619
Epoch [1/3], Step [1800/10575], Loss: 0.1629
Epoch [1/3], Step [1900/10575], Loss: 0.1618
Epoch [1/3], Step [2000/10575], Loss: 0.1614
Epoch [1/3], Step [2100/10575], Loss: 0.1678
Epoch [1/3], Step [220

In [10]:
# Save Model
# Persist only the learned parameters (state_dict), not the pickled module; reload
# them into a freshly constructed TransformerDecoderModel via load_state_dict.
weights_path = "trained_decoder_model.pth"
torch.save(model.state_dict(), weights_path)
print("Model training complete and saved as 'trained_decoder_model.pth'")

Model training complete and saved as 'trained_decoder_model.pth'
