### Connect to Drive

In [1]:
# Mount Google Drive at /content/drive; the paths returned by get_config()
# (checkpoint, dataset, tokenizer, logs) all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Configuration file

In [1]:
import torch

def get_config():
    """Return the notebook-wide settings as a freshly built dictionary.

    The dictionary groups, in order: optimisation settings, model
    architecture sizes, sampling settings used at generation time, the
    runtime device, and the Google Drive paths of the artefacts.
    """
    # Prefer the GPU whenever PyTorch can see one.
    device_name = "cpu"
    if torch.cuda.is_available():
        device_name = "cuda"

    optimisation = {
        "batch_size": 8,
        "num_epochs": 3,
        "lr": 1e-4,
    }
    architecture = {
        "seq_len": 512,
        "d_model": 768,
        "n_layers": 12,
        "head": 12,
        "d_ff": 3072,
        "dropout": 0.1,
    }
    sampling = {
        "temperature": 1.0,
        "top_k": 13900,
    }
    runtime = {
        "vocab_size": 0,
        "device": device_name,
    }
    artefacts = {
        "model_file_path": "/content/drive/MyDrive/Colab Notebooks/T-CLM/T-CLM2.pt",
        "dataset_file_path": "/content/drive/MyDrive/Colab Notebooks/T-CLM/dataset.json",
        "tokenizer_file_path": "/content/drive/MyDrive/Colab Notebooks/T-CLM/tokenizer.json",
        "logs": "/content/drive/MyDrive/Colab Notebooks/T-CLM/T-CLM-log",
    }

    # Merging in this order keeps the key order of the flat dictionary.
    return {**optimisation, **architecture, **sampling, **runtime, **artefacts}

### BPE Tokenizer


In [2]:
from pathlib import Path
from tokenizers import Tokenizer

def get_tokenizer(config):
    """Load the serialized tokenizer referenced by the configuration.

    Args:
        config: Mapping that contains the key ``'tokenizer_file_path'``.

    Returns:
        The ``Tokenizer`` loaded from that file.

    Raises:
        FileNotFoundError: If the tokenizer file does not exist. Previously
            the function only printed a message and then failed with an
            ``UnboundLocalError`` on the ``return`` statement.
    """
    tokenizer_path = Path(config['tokenizer_file_path'])
    if not tokenizer_path.exists():
        raise FileNotFoundError(f"tokenizer file not found: {tokenizer_path}")
    return Tokenizer.from_file(str(tokenizer_path))

### Data Pipeline

In [3]:
import torch
from torch.utils.data import Dataset, random_split, DataLoader
import json

class MultiTurnChatDataset(Dataset):
    """Turns multi-turn conversations into fixed-length language-model samples.

    Each item of ``ds`` must expose a ``'conversation'`` entry: a sequence of
    strings in which even positions are treated as user turns and odd
    positions as bot turns. Every turn is prefixed with the ``[USER]`` or
    ``[BOT]`` special token, the turns are concatenated, and the result is
    right-padded with ``[PAD]`` up to ``seq_len``.
    """

    def __init__(self, ds, tokenizer, seq_len):
        """Store the data source and resolve the special-token ids.

        Args:
            ds: Indexable collection of items with a ``'conversation'`` entry.
            tokenizer: Object providing ``token_to_id(str)`` and
                ``encode(str).ids``.
            seq_len: Fixed length of every returned sample.

        Raises:
            ValueError: If ``[PAD]``, ``[USER]`` or ``[BOT]`` is unknown to
                the tokenizer.
        """
        super().__init__()
        self.seq_len = seq_len
        self.ds = ds
        self.tokenizer = tokenizer

        special_ids = {}
        for token in ("[PAD]", "[USER]", "[BOT]"):
            token_id = tokenizer.token_to_id(token)
            if token_id is None:
                raise ValueError(f"Tokenizer has no id for the special token {token}")
            special_ids[token] = token_id

        # Plain ints are what actually gets spliced into the id lists. Mixing
        # 1-element tensors with ints inside a list passed to torch.tensor()
        # (as the previous version did) is fragile.
        self._pad_id = special_ids["[PAD]"]
        self._user_id = special_ids["[USER]"]
        self._bot_id = special_ids["[BOT]"]

        # Tensor forms are kept as public attributes for backward compatibility.
        self.pad_token = torch.tensor([self._pad_id], dtype=torch.int64)
        self.user_token = torch.tensor([self._user_id], dtype=torch.int64)
        self.bot_token = torch.tensor([self._bot_id], dtype=torch.int64)

    def __len__(self):
        """Number of conversations in the underlying data source."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Build the ``input_ids`` / ``labels`` pair for conversation ``idx``.

        Returns:
            dict with two int64 tensors of shape ``(seq_len,)``:
            ``"input_ids"`` (role-tagged, padded token ids) and ``"labels"``
            (the same ids shifted left by one, with ``[PAD]`` appended).

        Raises:
            ValueError: If the encoded conversation is longer than ``seq_len``.
        """
        conversation = self.ds[idx]['conversation']

        # Concatenate the turns, tagging each with its [USER] / [BOT] token.
        token_ids = []
        for turn_index, turn in enumerate(conversation):
            role_id = self._user_id if turn_index % 2 == 0 else self._bot_id
            token_ids.append(role_id)
            token_ids.extend(self.tokenizer.encode(turn).ids)

        # Ensure the total length doesn't exceed the sequence length
        num_padding_tokens = self.seq_len - len(token_ids)
        if num_padding_tokens < 0:
            raise ValueError("Multi-turn conversation is too long for the sequence length")

        token_ids.extend([self._pad_id] * num_padding_tokens)

        inputs = torch.tensor(token_ids, dtype=torch.int64)
        assert inputs.size(0) == self.seq_len

        # Next-token targets: position t is trained to predict token t + 1.
        # NOTE(review): padded positions keep the [PAD] id as their target, so
        # they still count towards a loss computed without an ignore index.
        labels = torch.tensor(token_ids[1:] + [self._pad_id], dtype=torch.int64)

        return {
            "input_ids": inputs,
            "labels": labels,
        }

### Load dataset

In [4]:
def get_ds(config):
    """Load the raw conversations and build the train/validation loaders.

    Args:
        config: Mapping with ``'dataset_file_path'``, ``'tokenizer_file_path'``
            (read by ``get_tokenizer``), ``'seq_len'`` and ``'batch_size'``.
            An optional ``'split_seed'`` (default 42) seeds the train/validation
            split.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, tokenizer)``.

    Raises:
        ValueError: If the dataset file contains no conversations.
    """
    with open(config['dataset_file_path'], 'r', encoding='utf-8') as f:
        ds_raw = json.load(f)

    if len(ds_raw) == 0:
        raise ValueError(f"Dataset file {config['dataset_file_path']} contains no conversations")

    tokenizer = get_tokenizer(config)

    # Split the dataset into train and validation sets. The split is seeded so
    # that resuming training from a checkpoint reproduces the same partition;
    # an unseeded split would move validation samples into the training set.
    train_ds_size = int(0.99 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    split_generator = torch.Generator().manual_seed(config.get('split_seed', 42))
    train_ds_raw, val_ds_raw = random_split(
        ds_raw, [train_ds_size, val_ds_size], generator=split_generator
    )

    # Create datasets and dataloaders
    train_ds = MultiTurnChatDataset(train_ds_raw, tokenizer, config['seq_len'])
    val_ds = MultiTurnChatDataset(val_ds_raw, tokenizer, config['seq_len'])

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    # Order does not matter for an averaged validation loss, so no shuffling.
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=False)

    return train_dataloader, val_dataloader, tokenizer

### Transformer model

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

class LayerNormalization(nn.Module):
    """Normalizes the last dimension and applies a learnable scale and shift.

    ``alpha`` (initialized to ones) scales and ``bias`` (initialized to zeros)
    shifts the normalized values; ``eps`` is added to the standard deviation
    before dividing.
    """

    def __init__(self, features: int, eps:float=10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` along the last dim."""
        mu = x.mean(dim=-1, keepdim=True)
        sigma = x.std(dim=-1, keepdim=True)
        # Scale before dividing, mirroring the left-to-right evaluation order.
        scaled = self.alpha * (x - mu)
        return scaled / (sigma + self.eps) + self.bias

class FeedForwardBlock(nn.Module):
    """Two linear layers with a ReLU and dropout in between.

    Maps ``d_model -> d_ff -> d_model`` on the last dimension.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Expand, apply ReLU and dropout, then project back to ``d_model``."""
        hidden = self.linear_1(x).relu()
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

class InputEmbeddings(nn.Module):
    """Token embedding lookup whose output is scaled by ``sqrt(d_model)``."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Embed the token ids in ``x`` and multiply by ``sqrt(d_model)``."""
        embedded = self.embedding(x)
        scale = math.sqrt(self.d_model)
        return embedded * scale

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information, then applies dropout.

    A ``(1, seq_len, d_model)`` table is precomputed and stored as the buffer
    ``pe``: even feature columns hold sines, odd columns hold cosines.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # One row per position, one column per sine/cosine frequency pair.
        positions = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        frequencies = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * frequencies

        table = torch.zeros(seq_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Leading singleton dimension lets the table broadcast over the batch.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add the first ``x.shape[1]`` rows of the table to ``x``; apply dropout."""
        num_positions = x.shape[1]
        positional = self.pe[:, :num_positions, :].requires_grad_(False)
        return self.dropout(x + positional)

class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        """Normalize ``x``, run ``sublayer`` on it, apply dropout, add back ``x``."""
        normalized = self.norm(x)
        update = self.dropout(sublayer(normalized))
        return x + update

class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention with bias-free projections.

    ``d_model`` is split evenly across ``h`` heads of width ``d_k``.
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model must be divisible by h"
        self.d_k = d_model // h

        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Scaled dot-product attention.

        Positions where ``mask == 0`` receive a score of ``-1e9`` before the
        softmax. Returns ``(weights @ value, weights)``, where ``weights`` are
        the (optionally dropped-out) attention probabilities.
        """
        head_dim = query.shape[-1]
        weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_dim)
        if mask is not None:
            weights = weights.masked_fill(mask == 0, -1e9)
        weights = weights.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return weights @ value, weights

    def _split_heads(self, projected, batch_size):
        # (batch, seq, d_model) -> (batch, h, seq, d_k)
        return projected.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Project q/k/v, attend per head, merge the heads and apply ``w_o``."""
        batch_size = q.size(0)

        query = self._split_heads(self.w_q(q), batch_size)
        key = self._split_heads(self.w_k(k), batch_size)
        value = self._split_heads(self.w_v(v), batch_size)

        # Only the attended values are needed here; the weights are discarded.
        attended, _ = self.attention(query, key, value, mask, self.dropout)

        # (batch, h, seq, d_k) -> (batch, seq, h * d_k)
        merged = attended.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.w_o(merged)

class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, then feed-forward.

    Each of the two sublayers is wrapped in its own residual connection.
    """

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(features, dropout), ResidualConnection(features, dropout)]
        )

    def forward(self, x, tgt_mask):
        """Apply self-attention (masked with ``tgt_mask``) and the feed-forward block."""
        attention_residual, feed_forward_residual = self.residual_connections

        def self_attend(normalized):
            # Query, key and value all come from the same normalized input.
            return self.self_attention_block(normalized, normalized, normalized, tgt_mask)

        attended = attention_residual(x, self_attend)
        return feed_forward_residual(attended, self.feed_forward_block)

class Decoder(nn.Module):
    """Runs the input through every layer in order, then a final normalization."""

    def __init__(self, layers: nn.ModuleList, norm_layer: LayerNormalization) -> None:
        super().__init__()
        self.layers = layers
        self.norm = norm_layer

    def forward(self, x, tgt_mask):
        """Feed ``x`` through each layer with ``tgt_mask`` and normalize the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, tgt_mask)
        return self.norm(hidden)

class ProjectionLayer(nn.Module):
    """Linear map from ``d_model`` features to ``vocab_size`` logits."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return the un-normalized vocabulary scores for ``x``."""
        vocab_logits = self.proj(x)
        return vocab_logits

class TCLM(nn.Module):
    """Decoder-only transformer language model.

    Pipeline: token embedding -> positional encoding -> ``N`` decoder blocks
    with a causal mask -> final normalization -> projection to the vocabulary.
    """

    def __init__(self, vocab_size: int, seq_len: int, d_model: int, N: int, h: int, dropout: float, d_ff: int):
        super().__init__()
        self.input_embed = InputEmbeddings(d_model, vocab_size)
        self.pos_embed = PositionalEncoding(d_model, seq_len, dropout)

        # Decoder blocks with multi-head attention and feed-forward
        # NOTE: the same ModuleList is registered here and again inside
        # ``self.decoder``, so the state dict holds both ``layers.*`` and
        # ``decoder.layers.*`` keys for the shared weights. This is kept
        # unchanged so that existing checkpoints keep loading.
        self.layers = nn.ModuleList([
            DecoderBlock(
                features=d_model,
                self_attention_block=MultiHeadAttentionBlock(d_model, h, dropout),
                feed_forward_block=FeedForwardBlock(d_model, d_ff, dropout),
                dropout=dropout
            ) for _ in range(N)
        ])

        self.decoder = Decoder(self.layers, LayerNormalization(d_model))
        self.projection_layer = ProjectionLayer(d_model, vocab_size)

    def forward(self, idx, targets=None):
        """Compute vocabulary logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``.
            targets: Optional tensor of target ids with ``B * T`` elements.

        Returns:
            ``(logits, loss)``. Without ``targets`` the logits have shape
            ``(B, T, vocab)`` and ``loss`` is ``None``. With ``targets`` the
            logits are flattened to ``(B * T, vocab)`` and ``loss`` is the
            mean cross-entropy over all positions.
        """
        B, T = idx.shape
        token_embed = self.input_embed(idx)
        x = self.pos_embed(token_embed)

        # Lower-triangular mask: position t may only attend to positions <= t.
        tgt_mask = torch.tril(torch.ones((T, T), device=idx.device)).unsqueeze(0).unsqueeze(0)
        x = self.decoder(x, tgt_mask)

        logits = self.projection_layer(x)

        if targets is not None:
            logits = logits.view(-1, logits.size(-1))
            targets = targets.view(-1)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None

        return logits, loss

    def generate(self, idx: torch.Tensor, max_new_tokens: int, seq_len: int, temperature: float = 1.0, top_k: int = 10):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Sampling runs in evaluation mode (dropout disabled) and without
        gradient tracking; the module's previous train/eval mode is restored
        afterwards.

        Args:
            idx: Integer tensor ``(B, T)`` holding the prompt token ids.
            max_new_tokens: Number of tokens to append.
            seq_len: Maximum context length; only the last ``seq_len`` tokens
                are fed to the model at each step.
            temperature: Positive divisor applied to the logits.
            top_k: Number of highest-scoring tokens to sample from (capped at
                the vocabulary size).

        Returns:
            Tensor ``(B, T + max_new_tokens)``: the prompt followed by the
            sampled tokens.

        Raises:
            ValueError: If ``temperature`` is not positive or ``top_k < 1``.
        """
        if temperature <= 0:
            raise ValueError("temperature must be a positive number")
        if top_k < 1:
            raise ValueError("top_k must be at least 1")

        vocab_size = self.projection_layer.proj.out_features
        k = min(top_k, vocab_size)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # Keep only the most recent seq_len tokens as context.
                    idx_crop = idx[:, -seq_len:] if idx.size(1) > seq_len else idx

                    logits, _ = self(idx_crop)
                    logits = logits[:, -1, :] / temperature  # scores for the next token

                    # Softmax over the k best logits equals renormalizing the
                    # k largest probabilities of the full distribution.
                    top_logits, top_indices = torch.topk(logits, k, dim=-1)
                    top_probs = F.softmax(top_logits, dim=-1)

                    choice = torch.multinomial(top_probs, 1)
                    idx_next = top_indices.gather(-1, choice)

                    idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)

        return idx




### Preload model

In [7]:
def load_model(config, device, model, optimizer):
    """Move the model to ``device`` and restore a checkpoint when one exists.

    Args:
        config: Mapping with the key ``'model_file_path'``.
        device: Target device for the model and for ``map_location``.
        model: Module whose weights are restored.
        optimizer: Optimizer whose state is restored.

    Returns:
        ``(model, optimizer, initial_epoch, global_step)``. Without a
        checkpoint file, ``initial_epoch`` is 1 and ``global_step`` is 0;
        otherwise they are the stored epoch plus one and the stored step.
    """
    model = model.to(device)
    checkpoint_path = config['model_file_path']

    if not Path(checkpoint_path).exists():
        print("No model file found.")
        return model, optimizer, 1, 0

    print(f'Loading model from {str(checkpoint_path)}')
    checkpoint = torch.load(str(checkpoint_path), map_location=device, weights_only=True)

    model.load_state_dict(checkpoint['model_state_dict'])
    # Resume with the epoch that follows the last completed one.
    next_epoch = checkpoint['epoch'] + 1
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    steps_done = checkpoint['global_step']

    return model, optimizer, next_epoch, steps_done

### WANDB

In [None]:
# Install the wandb package via pip; -q keeps pip's output quiet and -U
# upgrades it if it is already installed.
!pip install wandb -qU

In [None]:
# Authenticate with Weights & Biases before train() calls wandb.init().
import wandb
wandb.login()

### Training loop

In [None]:
import torch
import warnings
import sys
from tqdm import tqdm
import wandb

def train(config):
    """Train the language model, logging to wandb and checkpointing every epoch.

    Epochs are numbered from 1: a fresh run starts at epoch 1 and a resumed
    run starts at the checkpointed epoch plus one (see ``load_model``). The
    loop runs through epoch ``config['num_epochs']`` inclusive, so a fresh
    run performs exactly ``num_epochs`` epochs.

    Args:
        config: Settings mapping as returned by ``get_config``.
    """
    # Initialize wandb
    wandb.init(project="T-CLM2", config=config)

    # Build the torch.device first: the config holds a plain string, and
    # ``str.index`` is a method, not a device ordinal.
    device = torch.device(config['device'])
    print("Using device:", device)
    if device.type == 'cuda':
        print(f"Device name: {torch.cuda.get_device_name(device)}")
        print(f"Device memory: {round(torch.cuda.get_device_properties(device).total_memory / 1024 ** 3, 1)} GB")

    train_dataloader, val_dataloader, tokenizer = get_ds(config)
    model = TCLM(vocab_size=tokenizer.get_vocab_size(), seq_len=config['seq_len'], d_model=config['d_model'], N=config['n_layers'], h=config['head'], dropout=config['dropout'], d_ff=config['d_ff'])

    # Log model configuration in wandb
    wandb.watch(model, log="all")

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    model, optimizer, initial_epoch, global_step = load_model(config, device, model, optimizer)

    for epoch in range(initial_epoch, config['num_epochs'] + 1):
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        total_loss = 0.0
        num_batches = len(train_dataloader)

        for batch in batch_iterator:
            # The dataset emits the keys "input_ids" and "labels".
            input_ids = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)

            optimizer.zero_grad()  # Reset gradients
            _, loss = model(input_ids, targets=targets)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
            optimizer.step()

            # Read the scalar once; every .item() call forces a device sync.
            batch_loss = loss.item()
            total_loss += batch_loss
            global_step += 1

            # Log batch loss to wandb
            wandb.log({"batch_loss": batch_loss, "global_step": global_step})

            batch_iterator.set_postfix({'Loss': batch_loss})

        # Epoch-level logging
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch} | Avg Loss: {round(avg_loss, 2)}")

        # Log average loss for epoch to wandb
        wandb.log({"average_loss": avg_loss, "epoch": epoch})

        # Model checkpointing
        model_filename = config['model_file_path']
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step,
        }, model_filename)

        validate(model, val_dataloader, device, epoch)

def validate(model, val_dataloader, device, epoch):
    """Compute the mean validation loss for one epoch and log it to wandb.

    Args:
        model: Model called as ``model(input_ids, targets=...)`` and returning
            ``(logits, loss)``.
        val_dataloader: Loader yielding batches with the keys ``"input_ids"``
            and ``"labels"``.
        device: Device the batches are moved to.
        epoch: Epoch number used in the printed and logged output.
    """
    model.eval()
    total_val_loss = 0.0
    num_batches = len(val_dataloader)

    if num_batches == 0:
        # Nothing to average; avoid a ZeroDivisionError below.
        print(f"Validation skipped (Epoch {epoch}): empty validation dataloader")
        return

    with torch.no_grad():
        for batch in val_dataloader:
            # The dataset emits "input_ids"; "inputs_ids" raised a KeyError.
            input_ids = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)
            _, val_loss = model(input_ids, targets=targets)
            total_val_loss += val_loss.item()

    avg_val_loss = total_val_loss / num_batches
    print(f"Validation Loss (Epoch {epoch}): {round(avg_val_loss, 2)}")

    # Log validation loss to wandb
    wandb.log({"validation_loss": avg_val_loss, "epoch": epoch})

# Entry point of the training cell: build the configuration and start training.
if __name__ == '__main__':
    # Silences every Python warning for the rest of the session.
    warnings.filterwarnings("ignore")
    config = get_config()
    train(config)


### Inference

In [None]:
import torch

config = get_config()
tokenizer = get_tokenizer(config)

user_token_id = tokenizer.token_to_id('[USER]')
bot_token_id = tokenizer.token_to_id('[BOT]')
pad_token_id = tokenizer.token_to_id('[PAD]')
if None in (user_token_id, bot_token_id, pad_token_id):
    raise ValueError("Tokenizer is missing one of the special tokens [USER], [BOT], [PAD]")

# Any of these in the generated continuation marks the end of the bot's reply.
stop_token_ids = {user_token_id, bot_token_id, pad_token_id}

# Initialize the model
model = TCLM(
    vocab_size=tokenizer.get_vocab_size(),
    seq_len=config['seq_len'],
    d_model=config['d_model'],
    N=config['n_layers'],
    h=config['head'],
    dropout=config['dropout'],
    d_ff=config['d_ff']
)

# load_model() requires an optimizer argument; it is not used for inference.
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
model, _, _, _ = load_model(config, config['device'], model, optimizer)
model.eval()  # disable dropout while generating

# Maximum number of turns kept as context (adjust to the context length needed).
MAX_HISTORY_TURNS = 5

# Alternating turns: even positions are user turns, odd positions are bot turns.
conversation_history = []

while True:
    user_input = input("You: ")

    if user_input.lower() == "exit" or user_input == "":
        break

    conversation_history.append(user_input)

    # Drop the oldest user/bot PAIR while the history is too long. Removing
    # two turns at a time keeps even positions as user turns, so the role
    # tokens assigned below stay correct.
    while len(conversation_history) > MAX_HISTORY_TURNS:
        del conversation_history[:2]

    # Build the prompt in the training format: every turn (the current user
    # input is the last history entry) is encoded once behind its role token.
    input_sequence = []
    for i, turn in enumerate(conversation_history):
        role_token_id = user_token_id if i % 2 == 0 else bot_token_id
        input_sequence += [role_token_id] + tokenizer.encode(turn).ids

    # A trailing [BOT] token cues the model to produce the bot's reply.
    input_sequence.append(bot_token_id)

    input_tensor = torch.tensor([input_sequence], dtype=torch.long, device=config['device'])

    # Generate response from the model
    with torch.no_grad():
        generated_sequence = model.generate(
            input_tensor,
            max_new_tokens=20,
            seq_len=config['seq_len'],
            temperature=config['temperature'],
            top_k=config['top_k']
        )

    # Keep only the newly generated tokens, cut at the first special token,
    # so the reply contains neither the prompt nor padding or the next turn.
    reply_ids = []
    for token_id in generated_sequence[0, len(input_sequence):].tolist():
        if token_id in stop_token_ids:
            break
        reply_ids.append(token_id)

    predicted_text = tokenizer.decode(reply_ids)
    conversation_history.append(predicted_text)

    print("Omnira:", predicted_text)
